轉載請注明出處:
https://www.cnblogs.com/darkknightzh/p/12424767.html
論文:
Learning Feature Pyramids for Human Pose Estimation
https://arxiv.org/abs/1708.01101
第三方pytorch代碼:
https://github.com/Naman-ntc/Pytorch-Human-Pose-Estimation
1. 整體結構
將hourglass的殘差模塊改為金字塔殘差模塊(白框),用於學習輸入圖像不同尺度的特征。
hourglass見https://www.cnblogs.com/darkknightzh/p/11486185.html。參考代碼中的Hourglass內部也使用了PRM模塊,而不是原始的Hourglass。
該算法在stacked hourglass的基礎上更容易理解。
2. 金字塔殘差模塊PRM
論文給出了4種PRM(金字塔殘差模塊)的結構,最終發現PRM-B的效果最好,如下圖所示。其中虛線代表同等映射,白色虛框代表該處無上采樣或下采樣。
3. 下采樣
由於pooling下采樣速度太快,下采樣倍數最低為2,因而論文未使用pool。而是使用了fractional max-pooling的下采樣方式,第c層的下采樣率(論文中M=1,C=4):
${{s}_{c}}={{2}^{-M\frac{c}{C}}},c=0,\cdots ,C,M\ge 1$
4. 訓練及測試
訓練階段和其他姿態估計算法相似,都是估計熱圖,然后計算真值熱圖和估計熱圖的均方誤差,如下
$L=\frac{1}{2}\sum\limits_{n=1}^{N}{\sum\limits_{k=1}^{K}{{{\left\| {{\mathbf{S}}_{k}}-{{{\mathbf{\hat{S}}}}_{k}} \right\|}^{2}}}}$
其中N為樣本數量,K為關鍵點的數量(也即熱圖數量)
測試階段,使用最后一個hourglass熱圖最大的score的位置作為關鍵點。由於該算法為自頂向下的姿態估計算法,輸入網絡的圖像僅有一個人,因而最大score的位置即為對應的關鍵點。
${{\mathbf{\hat{z}}}_{k}}=\underset{\mathbf{p}}{\mathop{\arg \max }}\,{{\mathbf{\hat{S}}}_{k}}(\mathbf{p}),\quad k=1,\cdots ,K$
5. 代碼
PyraNet定義如下:

class PyraNet(nn.Module):
    """Stacked-hourglass network with Pyramid Residual Modules (PRM).

    Pipeline: stem conv (stride 2) -> PRM -> max-pool -> 2 x PRM, then
    `nStack` hourglass stages. Each stage emits one joint heatmap tensor;
    `forward` returns the list of all `nStack` heatmap outputs.

    Args:
        nChannels: feature width used throughout the hourglass stages.
        nStack: number of stacked hourglass stages.
        nModules: residual modules per position inside each stage.
        numReductions: number of pooling levels inside each hourglass.
        baseWidth: controls bottleneck width D = outChannels // baseWidth in PRM.
        cardinality: number of pyramid scales (C in Eq. 3 of the paper).
        nJoints: number of keypoints (= number of output heatmaps).
        inputRes: input image resolution (assumed square) — TODO confirm square input.
    """
    def __init__(self, nChannels=256, nStack=4, nModules=2, numReductions=4, baseWidth=6, cardinality=30, nJoints=16, inputRes=256):
        super(PyraNet, self).__init__()
        self.nChannels = nChannels
        self.nStack = nStack
        self.nModules = nModules
        self.numReductions = numReductions
        self.baseWidth = baseWidth
        self.cardinality = cardinality
        self.inputRes = inputRes
        self.nJoints = nJoints

        self.start = M.BnReluConv(3, 64, kernelSize = 7, stride = 2, padding = 3)  # BN+ReLU+conv stem, halves resolution

        # PRM: two branches (1x1 conv + 3x3 conv; 1x1 conv + multi-scale feature sum
        # + 3x3 conv), summed and lifted by a 1x1 conv; skip is identity when the
        # in/out channel counts match, otherwise a 1x1 conv.
        self.res1 = M.ResidualPyramid(64, 128, self.inputRes//2, self.baseWidth, self.cardinality, 0)
        self.mp = nn.MaxPool2d(2, 2)
        self.res2 = M.ResidualPyramid(128, 128, self.inputRes//4, self.baseWidth, self.cardinality,)  # same two-branch PRM as above
        self.res3 = M.ResidualPyramid(128, self.nChannels, self.inputRes//4, self.baseWidth, self.cardinality)  # same two-branch PRM as above

        _hourglass, _Residual, _lin1, _chantojoints, _lin2, _jointstochan = [],[],[],[],[],[]

        for _ in range(self.nStack):  # one set of layers per stacked stage
            _hourglass.append(PyraNetHourGlass(self.nChannels, self.numReductions, self.nModules, self.inputRes//4, self.baseWidth, self.cardinality))
            _ResidualModules = []
            for _ in range(self.nModules):
                _ResidualModules.append(M.Residual(self.nChannels, self.nChannels))  # in == out channels: plain 3 x (BN+ReLU+conv)
            _ResidualModules = nn.Sequential(*_ResidualModules)
            _Residual.append(_ResidualModules)
            _lin1.append(M.BnReluConv(self.nChannels, self.nChannels))  # BN+ReLU+conv
            _chantojoints.append(nn.Conv2d(self.nChannels, self.nJoints,1))  # 1x1 conv: features -> joint heatmaps
            _lin2.append(nn.Conv2d(self.nChannels, self.nChannels,1))  # 1x1 conv on the feature path
            _jointstochan.append(nn.Conv2d(self.nJoints,self.nChannels,1))  # 1x1 conv: heatmaps -> features (for re-injection)

        self.hourglass = nn.ModuleList(_hourglass)
        self.Residual = nn.ModuleList(_Residual)
        self.lin1 = nn.ModuleList(_lin1)
        self.chantojoints = nn.ModuleList(_chantojoints)
        self.lin2 = nn.ModuleList(_lin2)
        self.jointstochan = nn.ModuleList(_jointstochan)

    def forward(self, x):
        # Stem: conv + PRM + pool + 2 x PRM brings x to (nChannels, inputRes/4, inputRes/4).
        x = self.start(x)
        x = self.res1(x)
        x = self.mp(x)
        x = self.res2(x)
        x = self.res3(x)
        out = []  # one heatmap tensor per stage (intermediate supervision)

        for i in range(self.nStack):
            x1 = self.hourglass[i](x)
            x1 = self.Residual[i](x1)
            x1 = self.lin1[i](x1)
            out.append(self.chantojoints[i](x1))
            x1 = self.lin2[i](x1)
            # Sum features with the heatmap re-projected to feature space; feeds next stage.
            x = x + x1 + self.jointstochan[i](out[i])

        return (out)
ResidualPyramid定義如下:

class ResidualPyramid(nn.Module):
    """Pyramid residual unit (PRM).

    Main path is a PyraConvBlock (a plain 1x1+3x3 branch summed with a
    1x1 + multi-scale-pyramid branch, then a 1x1 projection); the skip path
    is a SkipLayer (identity-like when in/out channels match, otherwise a
    1x1 conv). The unit returns the sum of both paths.
    """

    def __init__(self, inChannels, outChannels, inputRes, baseWidth, cardinality, type = 1):
        super(ResidualPyramid, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        self.inputRes = inputRes
        self.baseWidth = baseWidth
        self.cardinality = cardinality
        self.type = type
        # Main path: two-branch pyramid conv block.
        self.cb = PyraConvBlock(self.inChannels, self.outChannels, self.inputRes, self.baseWidth, self.cardinality, self.type)
        # Skip path: 1x1 conv only when a channel change is needed.
        self.skip = SkipLayer(self.inChannels, self.outChannels)

    def forward(self, x):
        main = self.cb(x)
        shortcut = self.skip(x)
        return main + shortcut
PyraConvBlock如下:

class PyraConvBlock(nn.Module):
    """Two-branch conv block of the PRM.

    Branch 1: 1x1 reduction then 3x3 conv (both BN+ReLU+conv).
    Branch 2: 1x1 reduction to width D, multi-scale pyramid feature sum
    (BnReluPyra), then 1x1 back to outChannels//2.
    The branch outputs are summed and lifted to outChannels by a final
    1x1 BN+ReLU+conv.
    """

    def __init__(self, inChannels, outChannels, inputRes, baseWidth, cardinality, type = 1):
        super(PyraConvBlock, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        self.inputRes = inputRes
        self.baseWidth = baseWidth
        self.cardinality = cardinality
        self.outChannelsby2 = outChannels // 2
        # Bottleneck width of the pyramid branch.
        self.D = self.outChannels // self.baseWidth
        # Branch 1: plain 1x1 -> 3x3 path at half the output width.
        self.branch1 = nn.Sequential(
            BnReluConv(self.inChannels, self.outChannelsby2, 1, 1, 0),
            BnReluConv(self.outChannelsby2, self.outChannelsby2, 3, 1, 1),
        )
        # Branch 2: 1x1 reduce -> multi-scale pyramid sum -> 1x1 expand.
        self.branch2 = nn.Sequential(
            BnReluConv(self.inChannels, self.D, 1, 1, 0),
            BnReluPyra(self.D, self.cardinality, self.inputRes),
            BnReluConv(self.D, self.outChannelsby2, 1, 1, 0),
        )
        # Final 1x1 conv lifts the summed branches to outChannels.
        self.afteradd = BnReluConv(self.outChannelsby2, self.outChannels, 1, 1, 0)

    def forward(self, x):
        merged = self.branch2(x) + self.branch1(x)
        return self.afteradd(merged)
BnReluPyra如下

class BnReluPyra(nn.Module):
    """BatchNorm -> ReLU -> multi-scale pyramid feature sum (see Pyramid)."""

    def __init__(self, D, cardinality, inputRes):
        super(BnReluPyra, self).__init__()
        self.D = D
        self.cardinality = cardinality
        self.inputRes = inputRes
        self.bn = nn.BatchNorm2d(self.D)
        self.relu = nn.ReLU()
        # Sums features computed at `cardinality` fractional scales.
        self.pyra = Pyramid(self.D, self.cardinality, self.inputRes)

    def forward(self, x):
        normalized = self.bn(x)
        activated = self.relu(normalized)
        return self.pyra(activated)
Pyramid如下:

class Pyramid(nn.Module):
    """Sum of features computed at several fractional scales (PRM pyramid).

    Level c (c = 0..cardinality-1) downsamples by 2**(-(c+1)/cardinality)
    via fractional max-pooling, applies a 3x3 conv, upsamples back to
    `inputRes`, and all level outputs are summed.
    """

    def __init__(self, D, cardinality, inputRes):
        super(Pyramid, self).__init__()
        self.D = D
        self.cardinality = cardinality  # C in Eq. 3 of the paper: number of pyramid levels
        self.inputRes = inputRes
        # Downsampling ratio of the first level; level c uses scale**(c+1).
        self.scale = 2 ** (-1 / self.cardinality)
        levels = []
        for c in range(self.cardinality):
            levels.append(nn.Sequential(
                # fractional max-pool down -> 3x3 conv -> upsample back to inputRes
                nn.FractionalMaxPool2d(2, output_ratio=self.scale ** (c + 1)),
                nn.Conv2d(self.D, self.D, 3, 1, 1),
                nn.Upsample(size=self.inputRes),
            ))
        self.scales = nn.ModuleList(levels)

    def forward(self, x):
        # Accumulate every pyramid level's contribution at the input resolution.
        total = torch.zeros_like(x)
        for level in self.scales:
            total = total + level(x)
        return total
PyraNetHourGlass如下:

class PyraNetHourGlass(nn.Module):
    """Recursive hourglass whose residual units are PRMs.

    The skip branch keeps full resolution; the main branch pools, recurses
    into an hourglass of depth numReductions-1 (or a residual stack at the
    bottom), then upsamples. Output is skip + upsampled main branch.
    Whether the PRM or the plain Residual is used depends on the remaining
    recursion depth (Residualskip / Residualmain below).
    """
    def __init__(self, nChannels=256, numReductions=4, nModules=2, inputRes=256, baseWidth=6, cardinality=30, poolKernel=(2,2), poolStride=(2,2), upSampleKernel=2):
        super(PyraNetHourGlass, self).__init__()
        self.numReductions = numReductions
        self.nModules = nModules
        self.nChannels = nChannels
        self.poolKernel = poolKernel
        self.poolStride = poolStride
        self.upSampleKernel = upSampleKernel

        self.inputRes = inputRes
        self.baseWidth = baseWidth
        self.cardinality = cardinality

        """ For the skip connection, a residual module (or sequence of residuaql modules) """
        # ResidualPyramid: two-branch PRM with identity/1x1-conv skip.
        # Residual: in == out channels, plain 3 x (BN+ReLU+conv).
        Residualskip = M.ResidualPyramid if numReductions > 1 else M.Residual
        Residualmain = M.ResidualPyramid if numReductions > 2 else M.Residual
        _skip = []
        for _ in range(self.nModules):  # PRM vs plain residual chosen by numReductions
            _skip.append(Residualskip(self.nChannels, self.nChannels, self.inputRes, self.baseWidth, self.cardinality))
        self.skip = nn.Sequential(*_skip)

        """ First pooling to go to smaller dimension then pass input through
        Residual Module or sequence of Modules then and subsequent cases:
        either pass through Hourglass of numReductions-1 or pass through Residual Module or sequence of Modules """
        self.mp = nn.MaxPool2d(self.poolKernel, self.poolStride)

        _afterpool = []
        for _ in range(self.nModules):  # PRM vs plain residual chosen by numReductions
            _afterpool.append(Residualmain(self.nChannels, self.nChannels, self.inputRes//2, self.baseWidth, self.cardinality))
        self.afterpool = nn.Sequential(*_afterpool)

        if (numReductions > 1):  # recurse into a shallower hourglass
            self.hg = PyraNetHourGlass(self.nChannels, self.numReductions-1, self.nModules, self.inputRes//2, self.baseWidth,
                                       self.cardinality, self.poolKernel, self.poolStride, self.upSampleKernel)
        else:
            _num1res = []
            for _ in range(self.nModules):  # bottom of the recursion: residual stack instead of recursing
                _num1res.append(Residualmain(self.nChannels,self.nChannels, self.inputRes//2, self.baseWidth, self.cardinality))
            self.num1res = nn.Sequential(*_num1res)  # doesnt seem that important ?

        """ Now another Residual Module or sequence of Residual Modules """
        _lowres = []
        for _ in range(self.nModules):  # PRM vs plain residual chosen by numReductions
            _lowres.append(Residualmain(self.nChannels,self.nChannels, self.inputRes//2, self.baseWidth, self.cardinality))
        self.lowres = nn.Sequential(*_lowres)

        """ Upsampling Layer (Can we change this??????) As per Newell's paper upsamping recommended """
        self.up = nn.Upsample(scale_factor = self.upSampleKernel)  # expand H and W back to the skip branch's resolution

    def forward(self, x):
        # Skip branch at full resolution.
        out1 = x
        out1 = self.skip(out1)
        # Main branch: pool, residuals, recurse (or bottom stack), residuals, upsample.
        out2 = x
        out2 = self.mp(out2)
        out2 = self.afterpool(out2)
        if self.numReductions>1:
            out2 = self.hg(out2)  # recursive call
        else:
            out2 = self.num1res(out2)  # bottom of the recursion
        out2 = self.lowres(out2)
        out2 = self.up(out2)  # upsample, not a channel change

        return out2 + out1  # merge main branch with skip branch
Residual如下:

class Residual(nn.Module):
    """Plain residual unit: ConvBlock (3 x BN+ReLU+conv bottleneck) plus a
    SkipLayer shortcut (identity-like when in/out channels match, otherwise
    a 1x1 conv). The extra arguments exist only for interface parity with
    ResidualPyramid and are ignored.
    """

    def __init__(self, inChannels, outChannels, inputRes=None, baseWidth=None, cardinality=None, type=None):
        super(Residual, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        # Main path: reduce / keep / expand bottleneck of BN+ReLU+conv triplets.
        self.cb = ConvBlock(self.inChannels, self.outChannels)
        # Shortcut: 1x1 conv only when a channel change is needed.
        self.skip = SkipLayer(self.inChannels, self.outChannels)

    def forward(self, x):
        main = self.cb(x)
        shortcut = self.skip(x)
        return main + shortcut