(原)人體姿態識別PyraNet


轉載請注明出處:

https://www.cnblogs.com/darkknightzh/p/12424767.html

論文:

Learning Feature Pyramids for Human Pose Estimation

https://arxiv.org/abs/1708.01101

第三方pytorch代碼:

https://github.com/Naman-ntc/Pytorch-Human-Pose-Estimation

1. 整體結構

將hourglass的殘差模塊改為金字塔殘差模塊(白框),用於學習輸入圖像不同尺度的特征。

hourglass見https://www.cnblogs.com/darkknightzh/p/11486185.html。參考代碼中的Hourglass內部也使用了PRM模塊,而不是原始的Hourglass。

該算法在stacked hourglass的基礎上更容易理解。

2. 金字塔殘差模塊PRM

論文給出了4中PRM(金字塔殘差模塊)的結構,最終發現PRM-B的效果最好,如下圖所示。其中虛線代表同等映射,白色虛框代表該處無上采樣或下采樣。

3. 下采樣

由於pooling下采樣速度太快,下采樣倍數最低為2,因而論文未使用pool。而是使用了fractional max-pooling的下采樣方式,第c層的下采樣率(論文中M=1,C=4):

${{s}_{c}}={{2}^{-M\frac{c}{C}}},c=0,\cdots ,C,M\ge 1$

4. 訓練及測試

訓練階段和其他姿態估計算法相似,都是估計熱圖,然后計算真值熱圖和估計熱圖的均方誤差,如下

$L=\frac{1}{2}\sum\limits_{n=1}^{N}{\sum\limits_{k=1}^{K}{{{\left\| {{\mathbf{S}}_{k}}-{{{\mathbf{\hat{S}}}}_{k}} \right\|}^{2}}}}$

其中N為樣本數量,K為關鍵點的數量(也即熱圖數量)

測試階段,使用最后一個hourglass熱圖最大的score的位置作為關鍵點。由於該算法為自頂向下的姿態估計算法,輸入網絡的圖像僅有一個人,因而最大score的位置即為對應的關鍵點。

${{\mathbf{\hat{z}}}_{k}}=\underset{\mathbf{p}}{\mathop{\arg \max }}\,{{\mathbf{\hat{S}}}_{k}}(\mathbf{p}),k=1,L,K$

5. 代碼

PyraNet定義如下:

 1 class PyraNet(nn.Module):
 2     """docstring for PyraNet"""
 3     def __init__(self, nChannels=256, nStack=4, nModules=2, numReductions=4, baseWidth=6, cardinality=30, nJoints=16, inputRes=256):
 4         super(PyraNet, self).__init__()
 5         self.nChannels = nChannels
 6         self.nStack = nStack
 7         self.nModules = nModules
 8         self.numReductions = numReductions
 9         self.baseWidth = baseWidth
10         self.cardinality = cardinality
11         self.inputRes = inputRes
12         self.nJoints = nJoints
13 
14         self.start = M.BnReluConv(3, 64, kernelSize = 7, stride = 2, padding = 3)   # BN+ReLU+conv
15 
16         # 先通過兩分支(1*1 conv+3*3 conv,1*1 conv+不同尺度特征之和+3*3 conv,這兩分支求和,並使用1*1 conv升維),並在輸入輸出通道相等時,直接返回,否則使用1*1 conv相加
17         self.res1 = M.ResidualPyramid(64, 128, self.inputRes//2, self.baseWidth, self.cardinality, 0)
18         self.mp = nn.MaxPool2d(2, 2)
19         self.res2 = M.ResidualPyramid(128, 128, self.inputRes//4, self.baseWidth, self.cardinality,)  # 先通過兩分支,並在輸入輸出通道相等時,直接返回,否則使用1*1 conv相加
20         self.res3 = M.ResidualPyramid(128, self.nChannels, self.inputRes//4, self.baseWidth, self.cardinality)  # 先通過兩分支,並在輸入輸出通道相等時,直接返回,否則使用1*1 conv相加
21 
22         _hourglass, _Residual, _lin1, _chantojoints, _lin2, _jointstochan = [],[],[],[],[],[]
23 
24         for _ in range(self.nStack):   # 堆疊個數
25             _hourglass.append(PyraNetHourGlass(self.nChannels, self.numReductions, self.nModules, self.inputRes//4, self.baseWidth, self.cardinality))
26             _ResidualModules = []
27             for _ in range(self.nModules):
28                 _ResidualModules.append(M.Residual(self.nChannels, self.nChannels))     # 輸入和輸出相等,只有3*(BN+ReLU+conv)
29             _ResidualModules = nn.Sequential(*_ResidualModules)
30             _Residual.append(_ResidualModules)
31             _lin1.append(M.BnReluConv(self.nChannels, self.nChannels))        # BN+ReLU+conv
32             _chantojoints.append(nn.Conv2d(self.nChannels, self.nJoints,1))   # 1*1 conv,維度變換
33             _lin2.append(nn.Conv2d(self.nChannels, self.nChannels,1))         # 1*1 conv,維度變換
34             _jointstochan.append(nn.Conv2d(self.nJoints,self.nChannels,1))    # 1*1 conv,維度變換
35 
36         self.hourglass = nn.ModuleList(_hourglass)
37         self.Residual = nn.ModuleList(_Residual)
38         self.lin1 = nn.ModuleList(_lin1)
39         self.chantojoints = nn.ModuleList(_chantojoints)
40         self.lin2 = nn.ModuleList(_lin2)
41         self.jointstochan = nn.ModuleList(_jointstochan)
42 
43     def forward(self, x):
44         x = self.start(x)
45         x = self.res1(x)
46         x = self.mp(x)
47         x = self.res2(x)
48         x = self.res3(x)
49         out = []
50 
51         for i in range(self.nStack):
52             x1 = self.hourglass[i](x)
53             x1 = self.Residual[i](x1)
54             x1 = self.lin1[i](x1)
55             out.append(self.chantojoints[i](x1))
56             x1 = self.lin2[i](x1)
57             x = x + x1 + self.jointstochan[i](out[i])     # 特征求和
58 
59         return (out)
View Code

ResidualPyramid定義如下:

 1 class ResidualPyramid(nn.Module):
 2     """docstring for ResidualPyramid"""
 3     # 先通過兩分支(1*1 conv+3*3 conv,1*1 conv+不同尺度特征之和+3*3 conv,這兩分支求和,並使用1*1 conv升維),並在輸入輸出通道相等時,直接返回,否則使用1*1 conv相加
 4     def __init__(self, inChannels, outChannels, inputRes, baseWidth, cardinality, type = 1):
 5         super(ResidualPyramid, self).__init__()
 6         self.inChannels = inChannels
 7         self.outChannels = outChannels
 8         self.inputRes = inputRes
 9         self.baseWidth = baseWidth
10         self.cardinality = cardinality
11         self.type = type
12         # PyraConvBlock:兩分支,一個是1*1 conv+3*3 conv,一個是1*1 conv+不同尺度特征之和+3*3 conv,這兩分支求和,並使用1*1 conv升維
13         self.cb = PyraConvBlock(self.inChannels, self.outChannels, self.inputRes, self.baseWidth, self.cardinality, self.type)
14         self.skip = SkipLayer(self.inChannels, self.outChannels)         # 輸入和輸出通道相等,則為None,否則為1*1 conv
15 
16     def forward(self, x):
17         out = 0
18         out = out + self.cb(x)
19         out = out + self.skip(x)
20         return out
View Code

PyraConvBlock如下:

 1 class PyraConvBlock(nn.Module):
 2     """docstring for PyraConvBlock"""     # 兩分支,一個是1*1 conv+3*3 conv,一個是1*1 conv+不同尺度特征之和+3*3 conv,這兩分支求和,並使用1*1 conv升維
 3     def __init__(self, inChannels, outChannels, inputRes, baseWidth, cardinality, type = 1):
 4         super(PyraConvBlock, self).__init__()
 5         self.inChannels = inChannels
 6         self.outChannels = outChannels
 7         self.inputRes = inputRes
 8         self.baseWidth = baseWidth
 9         self.cardinality = cardinality
10         self.outChannelsby2 = outChannels//2
11         self.D = self.outChannels // self.baseWidth
12         self.branch1 = nn.Sequential(   # 第一個分支,1*1 conv + 3*3 conv
13                 BnReluConv(self.inChannels, self.outChannelsby2, 1, 1, 0),           # BN+ReLU+conv
14                 BnReluConv(self.outChannelsby2, self.outChannelsby2, 3, 1, 1)        # BN+ReLU+conv
15             )
16         self.branch2 = nn.Sequential(   # 第二個分支,1*1 conv + 3*3 conv
17                 BnReluConv(self.inChannels, self.D, 1, 1, 0),                        # BN+ReLU+conv
18                 BnReluPyra(self.D, self.cardinality, self.inputRes),                 # BN+ReLU+不同尺度的特征之和
19                 BnReluConv(self.D, self.outChannelsby2, 1, 1, 0)                     # BN+ReLU+conv
20             )
21         self.afteradd = BnReluConv(self.outChannelsby2, self.outChannels, 1, 1, 0)   # BN+ReLU+conv
22 
23     def forward(self, x):
24         x = self.branch2(x) + self.branch1(x)                                        # 兩個分支特征之和
25         x = self.afteradd(x)                                                         # 1*1 conv進行升維
26         return x
View Code

BnReluPyra如下

 1 class BnReluPyra(nn.Module):
 2     """docstring for BnReluPyra"""     # BN + ReLU + 不同尺度的特征之和
 3     def __init__(self, D, cardinality, inputRes):
 4         super(BnReluPyra, self).__init__()
 5         self.D = D
 6         self.cardinality = cardinality
 7         self.inputRes = inputRes
 8         self.bn = nn.BatchNorm2d(self.D)
 9         self.relu = nn.ReLU()
10         self.pyra = Pyramid(self.D, self.cardinality, self.inputRes)     # 將不同尺度的特征求和
11 
12     def forward(self, x):
13         x = self.bn(x)
14         x = self.relu(x)
15         x = self.pyra(x)
16         return x
View Code

Pyramid如下:

 1 class Pyramid(nn.Module):
 2     """docstring for Pyramid"""     # 將不同尺度的特征求和
 3     def __init__(self, D, cardinality, inputRes):
 4         super(Pyramid, self).__init__()
 5         self.D = D
 6         self.cardinality = cardinality     # 論文中公式3的C,金字塔層數
 7         self.inputRes = inputRes
 8         self.scale = 2**(-1/self.cardinality)   # 金字塔第1層的下采樣率,后面層在此基礎上+1
 9         _scales = []
10         for card in range(self.cardinality):
11             temp = nn.Sequential(    # 下采樣 + 3*3 conv + 上采樣
12                     nn.FractionalMaxPool2d(2, output_ratio = self.scale**(card + 1)),  # 每一層在第1層基礎上+1的下采樣率
13                     nn.Conv2d(self.D, self.D, 3, 1, 1),
14                     nn.Upsample(size = self.inputRes)#, mode='bilinear')   # 上采樣到輸入分辨率
15                 )
16             _scales.append(temp)
17         self.scales = nn.ModuleList(_scales)
18 
19     def forward(self, x):
20         #print(x.shape, self.inputRes)
21         out = torch.zeros_like(x)         # 初始化和輸入大小一樣的0矩陣
22         for card in range(self.cardinality):
23             out += self.scales[card](x)    # 將所有尺度的特征求和
24         return out
View Code

PyraNetHourGlass如下:

 1 class PyraNetHourGlass(nn.Module):
 2     """docstring for PyraNetHourGlass"""
 3     def __init__(self, nChannels=256, numReductions=4, nModules=2, inputRes=256, baseWidth=6, cardinality=30, poolKernel=(2,2), poolStride=(2,2), upSampleKernel=2):
 4         super(PyraNetHourGlass, self).__init__()
 5         self.numReductions = numReductions
 6         self.nModules = nModules
 7         self.nChannels = nChannels
 8         self.poolKernel = poolKernel
 9         self.poolStride = poolStride
10         self.upSampleKernel = upSampleKernel
11 
12         self.inputRes = inputRes
13         self.baseWidth = baseWidth
14         self.cardinality = cardinality
15 
16         """ For the skip connection, a residual module (or sequence of residuaql modules)  """
17         # ResidualPyramid:先通過兩分支,並在輸入輸出通道相等時,直接返回,否則使用1*1 conv相加
18         # Residual:輸入和輸出相等,只有3*(BN+ReLU+conv)
19         Residualskip = M.ResidualPyramid if numReductions > 1 else M.Residual
20         Residualmain = M.ResidualPyramid if numReductions > 2 else M.Residual
21         _skip = []
22         for _ in range(self.nModules):  # 根據numReductions確定使用金字塔還是3*(BN+ReLU+conv)
23             _skip.append(Residualskip(self.nChannels, self.nChannels, self.inputRes, self.baseWidth, self.cardinality))
24         self.skip = nn.Sequential(*_skip)
25 
26         """ First pooling to go to smaller dimension then pass input through
27         Residual Module or sequence of Modules then  and subsequent cases:
28             either pass through Hourglass of numReductions-1 or pass through Residual Module or sequence of Modules """
29         self.mp = nn.MaxPool2d(self.poolKernel, self.poolStride)
30 
31         _afterpool = []
32         for _ in range(self.nModules):   # 根據numReductions確定使用金字塔還是3*(BN+ReLU+conv)
33             _afterpool.append(Residualmain(self.nChannels, self.nChannels, self.inputRes//2, self.baseWidth, self.cardinality))
34         self.afterpool = nn.Sequential(*_afterpool)
35 
36         if (numReductions > 1):     # 嵌套調用本身
37             self.hg = PyraNetHourGlass(self.nChannels, self.numReductions-1, self.nModules, self.inputRes//2, self.baseWidth,
38                                        self.cardinality, self.poolKernel, self.poolStride, self.upSampleKernel)
39         else:
40             _num1res = []
41             for _ in range(self.nModules):    # 根據numReductions確定使用金字塔還是3*(BN+ReLU+conv)
42                 _num1res.append(Residualmain(self.nChannels,self.nChannels, self.inputRes//2, self.baseWidth, self.cardinality))
43             self.num1res = nn.Sequential(*_num1res)  # doesnt seem that important ?
44 
45         """ Now another Residual Module or sequence of Residual Modules """
46         _lowres = []
47         for _ in range(self.nModules):    # 根據numReductions確定使用金字塔還是3*(BN+ReLU+conv)
48             _lowres.append(Residualmain(self.nChannels,self.nChannels, self.inputRes//2, self.baseWidth, self.cardinality))
49         self.lowres = nn.Sequential(*_lowres)
50 
51         """ Upsampling Layer (Can we change this??????) As per Newell's paper upsamping recommended  """
52         self.up = nn.Upsample(scale_factor = self.upSampleKernel)     # 將高和寬擴充,實現上采樣
53 
54     def forward(self, x):
55         out1 = x
56         out1 = self.skip(out1)             # 根據numReductions確定使用金字塔還是3*(BN+ReLU+conv)
57         out2 = x
58         out2 = self.mp(out2)               # 根據numReductions確定使用金字塔還是3*(BN+ReLU+conv)
59         out2 = self.afterpool(out2)
60         if self.numReductions>1:
61             out2 = self.hg(out2)           # 嵌套調用本身
62         else:
63             out2 = self.num1res(out2)      # 根據numReductions確定使用金字塔還是3*(BN+ReLU+conv)
64         out2 = self.lowres(out2)           # 根據numReductions確定使用金字塔還是3*(BN+ReLU+conv)
65         out2 = self.up(out2)               # 升維
66 
67         return out2 + out1                 # 求和
View Code

Residual如下:

 1 class Residual(nn.Module):
 2     """docstring for Residual"""     # 輸入和輸出相等,只有3*(BN+ReLU+conv);否則輸入通過1*1conv結果和3*(BN+ReLU+conv)求和
 3     def __init__(self, inChannels, outChannels, inputRes=None, baseWidth=None, cardinality=None, type=None):
 4         super(Residual, self).__init__()
 5         self.inChannels = inChannels
 6         self.outChannels = outChannels
 7         self.cb = ConvBlock(self.inChannels, self.outChannels)      # 3 * (BN+ReLU+conv) 其中第一組降維,第二組不變,第三組升維
 8         self.skip = SkipLayer(self.inChannels, self.outChannels)    # 輸入和輸出通道相等,則為None,否則為1*1 conv
 9 
10     def forward(self, x):
11         out = 0
12         out = out + self.cb(x)
13         out = out + self.skip(x)
14         return out
View Code

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM