打通多個視覺任務的全能Backbone:HRNet


HRNet是微軟亞洲研究院的王井東老師領導的團隊完成的,打通圖像分類、圖像分割、目標檢測、人臉對齊、姿態識別、風格遷移、Image Inpainting、超分、optical flow、Depth estimation、邊緣檢測等網絡結構。

王老師在ValseWebinar《物體和關鍵點檢測》中親自講解了HRNet,講解地非常透徹。以下文章主要參考了王老師在演講中的解讀,配合論文+代碼部分,來為各位讀者介紹這個全能的Backbone-HRNet。

1. 引入

網絡結構設計思路

在人體姿態識別這類的任務中,需要生成一個高分辨率的heatmap來進行關鍵點檢測。這就與一般的網絡結構比如VGGNet的要求不同,因為VGGNet最終得到的feature map分辨率很低,損失了空間結構。

傳統的解決思路

獲取高分辨率的方式大部分都是如上圖所示,采用的是先降分辨率,然后再升分辨率的方法。U-Net、SegNet、DeconvNet、Hourglass本質上都是這種結構。

雖然看上去不同,但是本質是一致的

2. 核心

普通網絡都是這種結構,不同分辨率之間是進行了串聯

不斷降分辨率

王井東老師則是將不同分辨率的feature map進行並聯:

並聯不同分辨率feature map

在並聯的基礎上,添加不同分辨率feature map之間的交互(fusion)。

具體fusion的方法如下圖所示:

  • 同分辨率的層直接復制。
  • 需要升分辨率的使用bilinear upsample + 1x1卷積將channel數統一。
  • 需要降分辨率的使用strided 3x3 卷積。
  • 三個feature map融合的方式是相加。

至於為何要用strided 3x3卷積,這是因為卷積在降維的時候會出現信息損失,使用strided 3x3卷積是為了通過學習的方式,降低信息的損耗。所以這里沒有用maxpool或者組合池化。

HR示意圖

另外在讀HRNet的時候會有一個問題,有四個分支的到底如何使用這幾個分支呢?論文中也給出了幾種方式作為最終的特征選擇。

三種特征融合方法

(a)圖展示的是HRNetV1的特征選擇,只使用分辨率最高的特征圖。

(b)圖展示的是HRNetV2的特征選擇,將所有分辨率的特征圖(小的特征圖進行upsample)進行concate,主要用於語義分割和面部關鍵點檢測。

(c)圖展示的是HRNetV2p的特征選擇,在HRNetV2的基礎上,使用了一個特征金字塔,主要用於目標檢測網絡。

再補充一個(d)圖

HRNetV2分類網絡后的特征選擇

(d)圖展示的也是HRNetV2,采用上圖的融合方式,主要用於訓練分類網絡。

總結一下HRNet創新點

  • 將高低分辨率之間的鏈接由串聯改為並聯。
  • 在整個網絡結構中都保持了高分辨率的表征(最上邊那個通路)。
  • 在高低分辨率中引入了交互來提高模型性能。

3. 效果

3.1 消融實驗

  1. 對交互方法進行消融實驗,證明了當前跨分辨率的融合的有效性。

交互方法的消融實現

  1. 證明高分辨率feature map的表征能力

1x代表不進行降維,2x代表分辨率變為原來一半,4x代表分辨率變為原來四分之一。W32、W48中的32、48代表卷積的寬度或者通道數。

3.2 姿態識別任務上的表現

以上的姿態識別采用的是top-down的方法。

COCO驗證集的結果

在參數和計算量不增加的情況下,要比其他同類網絡效果好很多。

COCO測試集上的結果

在19年2月28日時的PoseTrack Leaderboard,HRNet占領兩個項目的第一名。

PoseTrack Leaderboard

3.3 語義分割任務中的表現

CityScape驗證集上的結果對比

Cityscapes測試集上的對比

3.4 目標檢測任務中的表現

單模型單尺度模型對比

Mask R-CNN上結果

3.5 分類任務上的表現

ps: 王井東老師在這部分提到,分割的網絡也需要使用分類的預訓練模型,否則結果會差幾個點。

圖像分類任務中和ResNet進行對比

以上是HRNet和ResNet結果對比,同一個顏色的都是參數量大體一致的模型進行的對比,在參數兩差不多甚至更少的情況下,HRNet能夠比ResNet達到更好的效果。

4. 代碼

HRNet( https://github.com/HRNet )工作量非常大,構建了六個庫涉及語義分割、人體姿態檢測、目標檢測、圖片分類、面部關鍵點檢測、Mask R-CNN等庫。全部內容如下圖所示:

筆者對HRNet代碼構建非常感興趣,所以以HRNet-Image-Classification庫為例,來解析一下這部分代碼。

先從簡單的入手,BasicBlock

BasicBlock結構

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

Bottleneck:

Bottleneck結構圖

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion,
                                  momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

HighResolutionModule,這是核心模塊, 主要分為兩個組件:branches和fuse layer。

class HighResolutionModule(nn.Module):
    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True):
        '''
        調用:
        # 調用高低分辨率交互模塊, stage2 為例
        HighResolutionModule(num_branches, # 2
                             block, # 'BASIC'
                             num_blocks, # [4, 4]
                             num_inchannels, # 上個stage的out channel
                             num_channels, # [32, 64]
                             fuse_method, # SUM
                             reset_multi_scale_output)
        '''
        super(HighResolutionModule, self).__init__()
        self._check_branches(
            # 檢查分支數目是否合理
            num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        # 融合選用相加的方式
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        # 兩個核心部分,一個是branches構建,一個是融合layers構建
        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()

        self.relu = nn.ReLU(False)

    def _check_branches(self, num_branches, blocks, num_blocks,
                        num_inchannels, num_channels):
        # 分別檢查參數是否符合要求,看models.py中的參數,blocks參數冗余了
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            logger.error(error_msg)
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        # 構建一個分支,一個分支重復num_blocks個block
        downsample = None

        # 這里判斷,如果通道變大(分辨率變小),則使用下采樣
        if stride != 1 or \
           self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.num_inchannels[branch_index],
                          num_channels[branch_index] * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(num_channels[branch_index] * block.expansion,
                               momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.num_inchannels[branch_index],
                            num_channels[branch_index], stride, downsample))

        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion

        for i in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index],
                                num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []
        
        # 通過循環構建多分支,每個分支屬於不同的分辨率
        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches # 2
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            # i代表枚舉所有分支
            fuse_layer = []
            for j in range(num_branches):
                # j代表處理的當前分支
                if j > i: # 進行上采樣,使用最近鄰插值
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(num_inchannels[j],
                                  num_inchannels[i],
                                  1,
                                  1,
                                  0,
                                  bias=False),
                        nn.BatchNorm2d(num_inchannels[i],
                                       momentum=BN_MOMENTUM),
                        nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
                elif j == i:
                    # 本層不做處理
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    # 進行strided 3x3 conv下采樣,如果跨兩層,就使用兩次strided 3x3 conv
                    for k in range(i-j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3,
                                               momentum=BN_MOMENTUM)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3,
                                nn.ReLU(False)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i]=self.branches[i](x[i])

        x_fuse=[]
        for i in range(len(self.fuse_layers)):
            y=x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y=y + x[j]
                else:
                    y=y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        # 將fuse以后的多個分支結果保存到list中
        return x_fuse

models.py中保存的參數, 可以通過這些配置來改變模型的容量、分支個數、特征融合方法:

# high_resoluton_net related params for classification
POSE_HIGH_RESOLUTION_NET = CN()
POSE_HIGH_RESOLUTION_NET.PRETRAINED_LAYERS = ['*']
POSE_HIGH_RESOLUTION_NET.STEM_INPLANES = 64
POSE_HIGH_RESOLUTION_NET.FINAL_CONV_KERNEL = 1
POSE_HIGH_RESOLUTION_NET.WITH_HEAD = True

POSE_HIGH_RESOLUTION_NET.STAGE2 = CN()
POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_MODULES = 1
POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BRANCHES = 2
POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [4, 4]
POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [32, 64]
POSE_HIGH_RESOLUTION_NET.STAGE2.BLOCK = 'BASIC'
POSE_HIGH_RESOLUTION_NET.STAGE2.FUSE_METHOD = 'SUM'

POSE_HIGH_RESOLUTION_NET.STAGE3 = CN()
POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_MODULES = 1
POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BRANCHES = 3
POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [4, 4, 4]
POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [32, 64, 128]
POSE_HIGH_RESOLUTION_NET.STAGE3.BLOCK = 'BASIC'
POSE_HIGH_RESOLUTION_NET.STAGE3.FUSE_METHOD = 'SUM'

POSE_HIGH_RESOLUTION_NET.STAGE4 = CN()
POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_MODULES = 1
POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BRANCHES = 4
POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
POSE_HIGH_RESOLUTION_NET.STAGE4.BLOCK = 'BASIC'
POSE_HIGH_RESOLUTION_NET.STAGE4.FUSE_METHOD = 'SUM'

然后來看整個HRNet模型的構建, 由於整體代碼量太大,這里僅僅來看forward函數。

def forward(self, x):

    # 使用兩個strided 3x3conv進行快速降維
    x=self.relu(self.bn1(self.conv1(x)))
    x=self.relu(self.bn2(self.conv2(x)))

    # 構建了一串BasicBlock構成的模塊
    x=self.layer1(x)

    # 然后是多個stage,每個stage核心是調用HighResolutionModule模塊
    x_list=[]
    for i in range(self.stage2_cfg['NUM_BRANCHES']):
        if self.transition1[i] is not None:
            x_list.append(self.transition1[i](x))
        else:
            x_list.append(x)
    y_list=self.stage2(x_list)

    x_list=[]
    for i in range(self.stage3_cfg['NUM_BRANCHES']):
        if self.transition2[i] is not None:
            x_list.append(self.transition2[i](y_list[-1]))
        else:
            x_list.append(y_list[i])
    y_list=self.stage3(x_list)

    x_list=[]
    for i in range(self.stage4_cfg['NUM_BRANCHES']):
        if self.transition3[i] is not None:
            x_list.append(self.transition3[i](y_list[-1]))
        else:
            x_list.append(y_list[i])
    y_list=self.stage4(x_list)

    # 添加分類頭,上文中有顯示,在分類問題中添加這種頭
    # 在其他問題中換用不同的頭
    y=self.incre_modules[0](y_list[0])
    for i in range(len(self.downsamp_modules)):
        y=self.incre_modules[i+1](y_list[i+1]) + \
            self.downsamp_modules[i](y)
    y=self.final_layer(y)

    if torch._C._get_tracing_state():
        # 在不寫C代碼的情況下執行forward,直接用python版本
        y=y.flatten(start_dim=2).mean(dim=2)
    else:
        y=F.avg_pool2d(y, kernel_size=y.size()
                            [2:]).view(y.size(0), -1)
    y=self.classifier(y)

    return y

5. 總結

HRNet核心方法是:在模型的整個過程中,保存高分辨率表征的同時使用讓不同分辨率的feature map進行特征交互。

HRNet在非常多的CV領域有廣泛的應用,比如ICCV2019的東北虎關鍵點識別比賽中,HRNet就起到了一定的作用。並且在分類部分的實驗證明了在同等參數量的情況下,可以取代ResNet進行分類。

之前看鄭安坤大佬的一篇文章CNN結構設計技巧-兼顧速度精度與工程實現中提到了一點:

senet是hrnet的一個特例,hrnet不僅有通道注意力,同時也有空間注意力

-- akkaze-鄭安坤

SELayer核心實現

SELayer首先通過一個全局平均池化得到一個一維向量,然后通過兩個全連接層,將信息進行壓縮和擴展,通過sigmoid以后得到每個通道的權值,然后用這個權值與原來的feature map相乘,進行信息上的優化。

HRNet一個結構

可以看到上圖用紅色箭頭串起來的是不是和SELayer很相似。為什么說SENet是HRNet的一個特例,但從這個結構來講,可以這么看:

  • SENet沒有像HRNet這樣分辨率變為原來的一半,分辨率直接變為1x1,比較極端。變為1x1向量以后,SENet中使用了兩個全連接網絡來學習通道的特征分布;但是在HRNet中,使用了幾個卷積(Residual block)來學習特征。
  • SENet在主干部分(高分辨率分支)沒有安排卷積進行特征的學習;HRNet中主干部分(高分辨率分支)安排了幾個卷積(Residual block)來學習特征。
  • 特征融合部分SENet和HRNet區分比較大,SENet使用的對應通道相乘的方法,HRNet則使用的是相加。之所以說SENet是通道注意力機制是因為通過全局平均池化后沒有了空間特征,只剩通道的特征;HRNet則可以看作同時保留了空間特征和通道特征,所以說HRNet不僅有通道注意力,同時也有空間注意力。

HRNet團隊有10人之多,構建了分類、分割、檢測、關鍵點檢測等庫,工作量非常大,而且做了很多扎實的實驗證明了這種思路的有效性。所以是否可以認為HRNet屬於SENet之后又一個更優的backbone呢?還需要自己實踐中使用這種想法和思路來驗證。

6. 參考

https://arxiv.org/pdf/1908.07919

https://www.bilibili.com/video/BV1WJ41197dh?t=508

https://github.com/HRNet


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM