【語義分割】Stacked Hourglass Networks 以及 PyTorch 實現


Stacked Hourglass Networks(級聯漏斗網絡)

姿態估計(Pose Estimation)是 CV 領域一個非常重要的方向,而級聯漏斗網絡的提出就是為了提升姿態估計的效果,但是其中的經典思想可以擴展到其他方向,比如目標識別方向,代表網絡是 CornerNet(預測目標的左上角和右下角點,再進行組合畫框)。

CNN 之所以有效,是因為它能自動提取出對分類、檢測和識別等任務有幫助的特征,並且隨着網絡層數的增加,所提取的特征逐漸變得抽象。以人臉識別為例,低層卷積網絡能夠提取出一些簡單的特征,如輪廓;中間卷積網絡能夠提取出抽象一些的特征,如眼睛鼻子;較高層的卷積網絡則能提取出更加抽象的特征,比如完整的人臉。這些將有助於我們理解級聯漏斗模型(Stacked Hourglass Model,簡稱SHM)為什么有效。

做姿態估計,需要預測身體不同的關節點,手臂這種線條簡單的結構,可能在中間卷積網絡更容易被識別;而面部這種線條復雜的結構,可能在高層卷積網絡才更容易被識別。因此,如果我們只使用最后一層的 feature map,就會造成一些信息的丟失。SHN 的主要貢獻——利用多尺度特征來識別姿態。

Single Hourglass Network

上圖是單個漏斗網絡的結構。該結構與全卷積網絡和其它設計(以多尺度方式處理空間信息,並進行密集預測)緊密相連。然而漏斗網絡與其它設計有什么不同呢?由圖可以看出,其自底向上(從高分辨率到低分辨率)處理和自頂向下(從低分辨率到高分辨率)處理之間的容量分布(這里實在不知道怎么翻譯。。。)更加對稱。另外還有一點需要注意,在自頂向下處理過程中,使用的不是 unpooing(一種常見的上采樣操作)或者 deconv layers(可稱為去卷積層),而是采用nearest neighbor upsampling(最近鄰上采樣)和 skip connections。這些操作需要在源碼中理解。

Stacked Hourglass Networks

StackedHourglass_1

上圖是單個漏斗網絡后面的一些設計以及兩個漏斗網絡的連接細節

塊1 是上面介紹的單個沙漏網絡,在它后面是一個 1$\times\(1 的全卷積網絡,即塊2;塊2 后面分離出上下兩個分支(塊3 和塊4):上分支(塊3)依然是一個 1\)\times$1 的全卷積網絡,下分支(塊4)為 Heat map(下面重點介紹)。塊5 是對塊4 進行 channal 上的擴增,以方便塊3、塊5 和 上個漏斗網絡的輸出進行合並,一起作為當前漏斗網絡的輸出,同時是下一個漏斗網絡的輸入。

這里對 Heat map 進行解釋:大部分姿態檢測的最后一步是對 feature map 上的每個像素做概率預測,計算該像素是某個關節點的概率,而這里的 feature map 就是上面輸出的 Heat map。使用它與真值進行誤差計算。應用中,如果多個 Hourglass Module 組合在一起進行梯度下降,輸出層的誤差經過多層反向傳播會大幅減小,也就是發生了梯度消失。因此,在整個網絡中每個Hourglass Module 后面都會輸出 Heat map 來計算損失。這種方法稱為 中間監督(Intermediate Supervision),可以保證底層參數正常更新。

之所以使用多個 Stack Hourglass,是為了重復自下而上和自上而下的推理機制,允許重新評估整個圖像的初始估計和特征,實現這一過程的核心就是預測中間的 Heat map,並讓中間 Heat map 參與 loss 計算。


PyTorch 實現 Model

  1. 首先定義殘差網絡的基本模塊:

    HgResBlock
    import torch.nn as nn
    
    
    class HgResBlock(nn.Module):
    
        def __init__(self, inplanes, outplanes, stride=1):
            super(HgResBlock, self).__init__()
    
            self.inplanes = inplanes
            self.outplanes = outplanes
            midplanes = outplanes // 2
    
            self.bn_1 = nn.BatchNorm2d(inplanes)
            self.conv_1 = nn.Conv2d(inplanes, midplanes, kernel_size=1, stride=stride)
            self.bn_2 = nn.BatchNorm2d(midplanes)
            self.conv_2 = nn.Conv2d(midplanes, midplanes, kernel_size=3, stride=1, padding=1)
            self.bn_3 = nn.BatchNorm2d(midplanes)
            self.conv_3 = nn.Conv2d(midplanes, outplanes, kernel_size=1, stride=1)
            self.relu = nn.ReLU(inplanes=True)
            if inplanes != outplanes:
                self.conv_skip = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1)
    
        # Bottle neck
        def forward(self, x):
            residual = x
    
            out = self.bn_1(x)
            out = self.conv_1(out)
            out = self.relu(out)
    
            out = self.bn_2(out)
            out = self.conv_2(out)
            out = self.relu(out)
    
            out = self.bn_3(out)
            out = self.conv_3(out)
            out = self.relu(out)
    
            if self.inplanes != self.outplanes:
                residual = self.conv_skip(residual)
            out += residual
    
            return out
    
  2. 定義單個的 Hourglass Module(注意這里用到了遞歸):

    HourglassNetwork
    import torch.nn as nn
    
    
    class Hourglass(nn.Module):
    
        def __init__(self, depth, nFeat, nModules, resBlocks):
            super(Hourglass, self).__init__()
    
            self.depth = depth
            self.nFeat = nFeat
            self.nModules = nModules
            self.resBlocks = resBlocks
    
            self.hg = self._make_hourglass()
            self.downsample = nn.MaxPool2d(kernel_size=2, stride=2)
            self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
    
        def _make_residual(self, n):
            return nn.Sequential(*[self.resBlocks(self.nFeat, self.nFeat) for _ in range(n)])
    
        def _make_hourglass(self):
            hg = []
    
            for i in range(self.depth):
                res = [self._make_residual(self.nModules) for _ in range(3)]
                if i == (self.depth - 1):
                    res.append(self._make_residual(self.nModules))      # extra one for the middle
                hg.append(nn.ModuleList(res))
    
            return nn.ModuleList(hg)
    
        def _hourglass_forward(self, depth_id, x):
            up_1 = self.hg[depth_id][0](x)
            low_1 = self.downsample(x)
            low_1 = self.hg[depth_id][1](low_1)
    
            if depth_id == (self.depth - 1):
                low_2 = self.hg[depth_id][3](low_1)
            else:
                low_2 = self._hourglass_forward(depth_id+1, low_1)
    
            low_3 = self.hg[depth_id][2](low_2)
            up_2 = self.upsample(low_3)
    
            return up_1 + up_2
    
        def forward(self, x):
            return self._hourglass_forward(0, x)
    
  3. 定義 Stacked Hourglass Network:

    StackedHourglass_2
    import torch.nn as nn
    
    from Model.HgResBlock import HgResBlock
    from Model.SingleHourglass import Hourglass
    
    class HourglassNet(nn.Module):
    
        def __init__(self, nStacks, nModules, nFeat, nClasses, resBlock=HgResBlock, inplanes=3):
            super(HourglassNet, self).__init__()
    
            self.nStacks = nStacks
            self.nModules = nModules
            self.nFeat = nFeat
            self.nClasses = nClasses
            self.resBlock = resBlock
            self.inplanes = inplanes
    
            hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
    
            for i in range(nStacks):
                hg.append(Hourglass(depth=4, nFeat=nFeat, nModules=nModules, resBlocks=resBlock))
                res.append(self._make_residual(nModules))
                fc.append(self._make_fc(nFeat, nFeat))
                score.append(nn.Conv2d(nFeat, nClasses, kernel_size=1))
                if i < (nStacks - 1):
                    fc_.append(nn.Conv2d(nFeat, nFeat, kernel_size=1))
                    score_.append(nn.Conv2d(nClasses, nFeat, kernel_size=1))
    
            self.hg = nn.ModuleList(hg)
            self.res = nn.ModuleList(res)
            self.fc = nn.ModuleList(fc)
            self.score = nn.ModuleList(score)
            self.fc_ = nn.ModuleList(fc_)
            self.score_ = nn.ModuleList(score_)
    
        def _make_head(self):
            self.conv_1 = nn.Conv2d(self.inplanes, 64, kernel_size=7, stride=2, padding=3)
            self.bn_1 = nn.BatchNorm2d(64)
            self.relu = nn.ReLU(inplace=True)
    
            self.res_1 = self.resBlock(64, 128)
            self.pool = nn.MaxPool2d(2, 2)
            self.res_2 = self.resBlock(128, 128)
            self.res_3 = self.resBlock(128, self.nFeat)
    
        def _make_residual(self, n):
            return nn.Sequential(*[self.resBlock(self.nFeat, self.nFeat) for _ in range(n)])
    
        def _make_fc(self, inplanes, outplanes):
            return nn.Sequential(
                nn.Conv2d(inplanes, outplanes, kernel_size=1),
                nn.BatchNorm2d(outplanes),
                nn.ReLU(True))
    
        def forward(self, x):
            # head
            x = self.conv_1(x)
            x = self.bn_1(x)
            x = self.relu(x)
    
            x = self.res_1(x)
            x = self.pool(x)
            x = self.res_2(x)
            x = self.res_3(x)
    
            out = []
    
            for i in range(self.nStacks):
                y = self.hg[i](x)
                y = self.res[i](y)
                y = self.fc[i](y)
                score = self.score[i](y)
                out.append(score)
                if i < (self.nStacks - 1):
                    fc_ = self.fc_[i](y)
                    score_ = self.score_[i](score)
                    x = x + fc_ + score_
    
            return out
    

References:

​ [1] Stacked Hourglass Networks for Human Pose Estimation

​ [2] [hourglass pytorch 實現]
(https://blog.csdn.net/github_36923418/article/details/81030883)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM