[GiantPandaCV導語] 本文主要講解CenterNet的loss,由偏置部分(reg loss)、熱圖部分(heatmap loss)、寬高(wh loss)部分三部分loss組成,附代碼實現。
1. 網絡輸出
論文中提供了三個用於目標檢測的網絡,都是基於編碼解碼的結構構建的。
- ResNet18 + upsample + deformable convolution : COCO AP 28%/142FPS
- DLA34 + upsample + deformable convolution : COCO AP 37.4%/52FPS
- Hourglass104: COCO AP 45.1%/1.4FPS
這三個網絡中輸出內容都是一樣的,80個類別,2個預測中心對應的長和寬,2個中心點的偏差。
# heatmap 輸出的tensor的通道個數是80,每個通道代表對應類別的heatmap
(hm): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1))
)
# wh 輸出是中心對應的長和寬,通道數為2
(wh): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)
# reg 輸出的tensor通道個數為2,分別是w,h方向上的偏移量
(reg): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)
2. 損失函數
2.1 heatmap loss
輸入圖像\(I\in R^{W\times H\times 3}\), W為圖像寬度,H為圖像高度。網絡輸出的關鍵點熱圖heatmap為\(\hat{Y}\in [0,1]^{\frac{W}{R}\times \frac{H}{R}\times C}\)其中,R代表得到輸出相對於原圖的步長stride。C代表類別個數。
下面是CenterNet中核心loss公式:
這個和Focal loss形式很相似,\(\alpha\)和\(\beta\)是超參數,N代表的是圖像關鍵點個數。
- 在\(Y_{xyc}=1\)的時候,
對於易分樣本來說,預測值\(\hat{Y}_{xyc}\)接近於1,\((1-\hat{Y}_{xyc})^\alpha\)就是一個很小的值,這樣loss就很小,起到了矯正作用。
對於難分樣本來說,預測值\(\hat{Y}_{xyc}\)接近於0,$ (1-\hat{Y}_{xyc})^\alpha$就比較大,相當於加大了其訓練的比重。
- otherwise的情況下:
上圖是一個簡單的示意,縱坐標是\({Y}_{xyc}\),分為A區(距離中心點較近,但是值在0-1之間)和B區(距離中心點很遠接近於0)。
對於A區來說,由於其周圍是一個高斯核生成的中心,\(Y_{xyc}\)的值是從1慢慢變到0。
舉個例子(CenterNet中默認\(\alpha=2,\beta=4\)):
\(Y_{xyc}=0.8\)的情況下,
-
如果\(\hat{Y}_{xyc}=0.99\),那么loss=\((1-0.8)^4(0.99)^2log(1-0.99)\),這就是一個很大的loss值。
-
如果\(\hat{Y}_{xyc}=0.8\), 那么loss=\((1-0.8)^4(0.8)^2log(1-0.8)\), 這個loss就比較小。
-
如果\(\hat{Y}_{xyc}=0.5\), 那么loss=\((1-0.8)^4(0.5)^2log(1-0.5)\),
-
如果\(\hat{Y}_{xyc}=0.99\),那么loss=\((1-0.5)^4(0.99)^2log(1-0.99)\),這就是一個很大的loss值。
-
如果\(\hat{Y}_{xyc}=0.8\), 那么loss=\((1-0.5)^4(0.8)^2log(1-0.8)\), 這個loss就比較小。
-
如果\(\hat{Y}_{xyc}=0.5\), 那么loss=\((1-0.5)^4(0.5)^2log(1-0.5)\),
總結一下:為了防止預測值\(\hat{Y}_{xyc}\)過高接近於1,所以用\((\hat{Y}_{xyc})^\alpha\)來懲罰Loss。而\((1-Y_{xyc})^\beta\)這個參數距離中心越近,其值越小,這個權重是用來減輕懲罰力度。
對於B區來說,\(\hat{Y}_{xyc}\)的預測值理應是0,如果該值比較大比如為1,那么\((\hat{Y}_{xyc})^\alpha\)作為權重會變大,懲罰力度也加大了。如果預測值接近於0,那么\((\hat{Y}_{xyc})^\alpha\)會很小,讓其損失比重減小。對於\((1-Y_{xyc})^\beta\)來說,B區的值比較大,弱化了中心點周圍其他負樣本的損失比重。
2.2 offset loss
由於三個骨干網絡輸出的feature map的空間分辨率變為原來輸入圖像的四分之一。相當於輸出feature map上一個像素點對應原始圖像的4x4的區域,這會帶來較大的誤差,因此引入了偏置值和偏置的損失值。設骨干網絡輸出的偏置值為\(\hat{O}\in R^{\frac{W}{R}\times \frac{H}{R}\times 2}\), 這個偏置值用L1 loss來訓練:
p代表目標框中心點,R代表下采樣倍數4,\(\tilde{p}=\lfloor \frac{p}{R} \rfloor\), \(\frac{p}{R}-\tilde{p}\)代表偏差值。
2.3 size loss/wh loss
假設第k個目標,類別為\(c_k\)的目標框的表示為\((x_1^{(k)},y_1^{(k)},x_2^{(k)},y_2^{(k)})\),那么其中心點坐標位置為\((\frac{x_1^{(k)}+x_2^{(k)}}{2}, \frac{y_1^{(k)}+y_2^{(k)}}{2})\), 目標的長和寬大小為\(s_k=(x_2^{(k)}-x_1^{(k)},y_2^{(k)}-y_1^{(k)})\)。對長和寬進行訓練的是L1 Loss函數:
其中\(\hat{S}\in R^{\frac{W}{R}\times \frac{H}{R}\times 2}\)是網絡輸出的結果。
2.4 CenterNet Loss
整體的損失函數是以上三者的綜合,並且分配了不同的權重。
其中\(\lambda_{size}=0.1, \lambda_{offsize}=1\)
3. 代碼解析
來自train.py中第173行開始進行loss計算:
# 得到heat map, reg, wh 三個變量
hmap, regs, w_h_ = zip(*outputs)
regs = [
_tranpose_and_gather_feature(r, batch['inds']) for r in regs
]
w_h_ = [
_tranpose_and_gather_feature(r, batch['inds']) for r in w_h_
]
# 分別計算loss
hmap_loss = _neg_loss(hmap, batch['hmap'])
reg_loss = _reg_loss(regs, batch['regs'], batch['ind_masks'])
w_h_loss = _reg_loss(w_h_, batch['w_h_'], batch['ind_masks'])
# 進行loss加權,得到最終loss
loss = hmap_loss + 1 * reg_loss + 0.1 * w_h_loss
上述transpose_and_gather_feature
函數具體實現如下,主要功能是將ground truth中計算得到的對應中心點的值獲取。
def _tranpose_and_gather_feature(feat, ind):
# ind代表的是ground truth中設置的存在目標點的下角標
feat = feat.permute(0, 2, 3, 1).contiguous()# from [bs c h w] to [bs, h, w, c]
feat = feat.view(feat.size(0), -1, feat.size(3)) # to [bs, wxh, c]
feat = _gather_feature(feat, ind)
return feat
def _gather_feature(feat, ind, mask=None):
# feat : [bs, wxh, c]
dim = feat.size(2)
# ind : [bs, index, c]
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
feat = feat.gather(1, ind) # 按照dim=1獲取ind
if mask is not None:
mask = mask.unsqueeze(2).expand_as(feat)
feat = feat[mask]
feat = feat.view(-1, dim)
return feat
3.1 hmap loss代碼
調用:hmap_loss = _neg_loss(hmap, batch['hmap'])
def _neg_loss(preds, targets):
''' Modified focal loss. Exactly the same as CornerNet.
Runs faster and costs a little bit more memory
Arguments:
preds (B x c x h x w)
gt_regr (B x c x h x w)
'''
pos_inds = targets.eq(1).float()# heatmap為1的部分是正樣本
neg_inds = targets.lt(1).float()# 其他部分為負樣本
neg_weights = torch.pow(1 - targets, 4)# 對應(1-Yxyc)^4
loss = 0
for pred in preds: # 預測值
# 約束在0-1之間
pred = torch.clamp(torch.sigmoid(pred), min=1e-4, max=1 - 1e-4)
pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
neg_loss = torch.log(1 - pred) * torch.pow(pred,
2) * neg_weights * neg_inds
num_pos = pos_inds.float().sum()
pos_loss = pos_loss.sum()
neg_loss = neg_loss.sum()
if num_pos == 0:
loss = loss - neg_loss # 只有負樣本
else:
loss = loss - (pos_loss + neg_loss) / num_pos
return loss / len(preds)
代碼和以上公式一一對應,pos代表正樣本,neg代表負樣本。
3.2 reg & wh loss代碼
調用:reg_loss = _reg_loss(regs, batch['regs'], batch['ind_masks'])
調用:w_h_loss = _reg_loss(w_h_, batch['w_h_'], batch['ind_masks'])
def _reg_loss(regs, gt_regs, mask):
mask = mask[:, :, None].expand_as(gt_regs).float()
loss = sum(F.l1_loss(r * mask, gt_regs * mask, reduction='sum') /
(mask.sum() + 1e-4) for r in regs)
return loss / len(regs)