【CV中的Attention機制】BiSeNet中的FFM模塊與ARM模塊


前言:之前介紹過一個語義分割中的注意力機制模塊-scSE模塊,效果很不錯。今天講的也是語義分割中使用到注意力機制的網絡BiSeNet,這個網絡有兩個模塊,分別是FFM模塊和ARM模塊。其實現也很簡單,不過作者對注意力機制模塊理解比較深入,提出的FFM模塊進行的特征融合方式也很新穎。

1. 簡介

語義分割需要豐富的空間信息和相關大的感受野,目前很多語義分割方法為了達到實時推理的速度選擇犧牲空間分辨率,這可能會導致比較差的模型表現。

BiSeNet(Bilateral Segmentation Network)中提出了空間路徑和上下文路徑:

  • 空間路徑用於保留語義信息生成較高分辨率的feature map(減少下采樣的次數)
  • 上下文路徑使用了快速下采樣的策略,用於獲取充足的感受野。
  • 提出了一個FFM模塊,結合了注意力機制進行特征融合。

本文主要關注的是速度和精度的權衡,對於分辨率為2048×1024的輸入,BiSeNet能夠在NVIDIA Titan XP顯卡上達到105FPS的速度,做到了實時語義分割。

2. 分析

提升語義分割速度主要有三種方法,如下圖所示:

  1. 通過resize的方式限定輸入大小,降低計算復雜度。缺點是空間細節有損失,尤其是邊界部分。
  2. 通過減少網絡通道的個數來加快處理速度。缺點是會弱化空間信息。
  3. 放棄最后階段的下采樣(如ENet)。缺點是模型感受野不足以覆蓋大物體,判別能力差。

語義分割中,U型結構也被廣泛使用,如下圖所示:

這種U型網絡通過融合backbone不同層次的特征,在U型結構中逐漸增加空間分辨率,保留更多的細節特征。不過有兩個缺點:

  1. 高分辨率特征圖計算量非常大,影響計算速度。
  2. 由於resize或者減少網絡通道而丟失的空間信息無法通過引入淺層而輕易復原。

3. 細節

下圖是BiSeNet的架構圖,從圖中可看到主要包括兩個部分:空間路徑和上下文路徑。

代碼實現來自:https://github.com/ooooverflow/BiSeNet,其CP部分沒有使用Xception39而使用的ResNet18。

空間路徑SP

減少下采樣次數,只使用三個卷積層(stride=2)獲得1/8的特征圖,由於它利用了較大尺度的特征圖,所以可以編碼比較豐富的空間信息。

class ConvBlock(torch.nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=2,
                 padding=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels,
                               out_channels,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=padding,
                               bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, input):
        x = self.conv1(input)
        return self.relu(self.bn(x))


class Spatial_path(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.convblock1 = ConvBlock(in_channels=3, out_channels=64)
        self.convblock2 = ConvBlock(in_channels=64, out_channels=128)
        self.convblock3 = ConvBlock(in_channels=128, out_channels=256)

    def forward(self, input):
        x = self.convblock1(input)
        x = self.convblock2(x)
        x = self.convblock3(x)
        return x

上下文路徑CP

為了增大感受野,論文提出上下文路徑,在Xception尾部添加全局平均池化層,從而提供更大的感受野。可以看出CP中進行了32倍的下采樣。(示例中CP部分使用的是ResNet18,不是論文中的xception39)

class resnet18(torch.nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        self.features = models.resnet18(pretrained=pretrained)
        self.conv1 = self.features.conv1
        self.bn1 = self.features.bn1
        self.relu = self.features.relu
        self.maxpool1 = self.features.maxpool
        self.layer1 = self.features.layer1
        self.layer2 = self.features.layer2
        self.layer3 = self.features.layer3
        self.layer4 = self.features.layer4

    def forward(self, input):
        x = self.conv1(input)
        x = self.relu(self.bn1(x))
        x = self.maxpool1(x)
        feature1 = self.layer1(x)  # 1 / 4
        feature2 = self.layer2(feature1)  # 1 / 8
        feature3 = self.layer3(feature2)  # 1 / 16
        feature4 = self.layer4(feature3)  # 1 / 32
        # global average pooling to build tail
        tail = torch.mean(feature4, 3, keepdim=True)
        tail = torch.mean(tail, 2, keepdim=True)
        return feature3, feature4, tail

組件融合

為了SP和CP更好的融合,提出了特征融合模塊FFM還有注意力優化模塊ARM。

ARM:

ARM使用在上下文路徑中,用於優化每一階段的特征,使用全局平均池化指導特征學習,計算成本可以忽略。其具體實現方式與SE模塊很類似,屬於通道注意力機制。

class AttentionRefinementModule(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.sigmoid = nn.Sigmoid()
        self.in_channels = in_channels
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, input):
        # global average pooling
        x = self.avgpool(input)
        assert self.in_channels == x.size(
            1), 'in_channels and out_channels should all be {}'.format(
                x.size(1))
        x = self.conv(x)
        # x = self.sigmoid(self.bn(x))
        x = self.sigmoid(x)
        # channels of input and x should be same
        x = torch.mul(input, x)
        return x

FFM:

特征融合模塊用於融合CP和SP提供的輸出特征,由於兩路特征並不相同,所以不能對這兩部分特征進行簡單的加權。SP提供的特征是低層次的(8×down),CP提供的特征是高層語義的(32×down)。

將兩個部分特征圖通過concate方式疊加,然后使用類似SE模塊的方式計算加權特征,起到特征選擇和結合的作用。(這種特征融合方式值得學習)

class FeatureFusionModule(torch.nn.Module):
    def __init__(self, num_classes, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.convblock = ConvBlock(in_channels=self.in_channels,
                                   out_channels=num_classes,
                                   stride=1)
        self.conv1 = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, input_1, input_2):
        x = torch.cat((input_1, input_2), dim=1)
        assert self.in_channels == x.size(
            1), 'in_channels of ConvBlock should be {}'.format(x.size(1))
        feature = self.convblock(x)
        x = self.avgpool(feature)

        x = self.relu(self.conv1(x))
        x = self.sigmoid(self.conv2(x))
        x = torch.mul(feature, x)
        x = torch.add(x, feature)
        return x

BiSeNet網絡整個模型:

class BiSeNet(torch.nn.Module):
    def __init__(self, num_classes, context_path):
        super().__init__()
        self.spatial_path = Spatial_path()
        self.context_path = build_contextpath(name=context_path)
        if context_path == 'resnet101':
            self.attention_refinement_module1 = AttentionRefinementModule(
                1024, 1024)
            self.attention_refinement_module2 = AttentionRefinementModule(
                2048, 2048)
            self.supervision1 = nn.Conv2d(in_channels=1024,
                                          out_channels=num_classes,
                                          kernel_size=1)
            self.supervision2 = nn.Conv2d(in_channels=2048,
                                          out_channels=num_classes,
                                          kernel_size=1)
            self.feature_fusion_module = FeatureFusionModule(num_classes, 3328)

        elif context_path == 'resnet18':
            self.attention_refinement_module1 = AttentionRefinementModule(
                256, 256)
            self.attention_refinement_module2 = AttentionRefinementModule(
                512, 512)
            self.supervision1 = nn.Conv2d(in_channels=256,
                                          out_channels=num_classes,
                                          kernel_size=1)
            self.supervision2 = nn.Conv2d(in_channels=512,
                                          out_channels=num_classes,
                                          kernel_size=1)
            self.feature_fusion_module = FeatureFusionModule(num_classes, 1024)
        else:
            print('Error: unspport context_path network \n')
        self.conv = nn.Conv2d(in_channels=num_classes,
                              out_channels=num_classes,
                              kernel_size=1)

    def forward(self, input):
        sx = self.spatial_path(input)
        cx1, cx2, tail = self.context_path(input)
        cx1 = self.attention_refinement_module1(cx1)
        cx2 = self.attention_refinement_module2(cx2)
        cx2 = torch.mul(cx2, tail)
        cx1 = torch.nn.functional.interpolate(cx1,
                                              size=sx.size()[-2:],
                                              mode='bilinear')
        cx2 = torch.nn.functional.interpolate(cx2,
                                              size=sx.size()[-2:],
                                              mode='bilinear')
        cx = torch.cat((cx1, cx2), dim=1)
        if self.training == True:
            cx1_sup = self.supervision1(cx1)
            cx2_sup = self.supervision2(cx2)
            cx1_sup = torch.nn.functional.interpolate(cx1_sup,
                                                      size=input.size()[-2:],
                                                      mode='bilinear')
            cx2_sup = torch.nn.functional.interpolate(cx2_sup,
                                                      size=input.size()[-2:],
                                                      mode='bilinear')
        result = self.feature_fusion_module(sx, cx)
        result = torch.nn.functional.interpolate(result,
                                                 scale_factor=8,
                                                 mode='bilinear')
        result = self.conv(result)
        if self.training == True:
            return result, cx1_sup, cx2_sup
        return result

4. 實驗

使用了Xception39處理實時語義分割任務,在CityScapes, CamVid和COCO stuff三個數據集上進行評估。

消融實驗:

測試了basemodel xception39,參數量要比ResNet18小得多,同時MIOU只略低於與ResNet18。

以上是BiSeNet各個模塊的消融實驗,可以看出,每個模塊都是有效的。

統一使用了640×360分辨率的圖片進行對比參數量和FLOPS狀態。

上表對BiSeNet網絡和其他網絡就MIOU和FPS上進行比較,可以看出該方法相比於其他方法在速度和精度方面有很大的優越性。

在使用ResNet101等比較深的網絡作為backbone的情況下,效果也是超過了其他常見的網絡,這證明了這個模型的有效性。

5. 結論

BiSeNet 旨在同時提升實時語義分割的速度與精度,它包含兩路網絡:Spatial Path 和 Context Path。Spatial Path 被設計用來保留原圖像的空間信息,Context Path 利用輕量級模型和全局平均池化快速獲取大感受野。由此,在 105 fps 的速度下,該方法在 Cityscapes 測試集上取得了 68.4% mIoU 的結果。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM