【從零開始學CenterNet】6. CenterNet之loss計算代碼解析


[GiantPandaCV導語] 本文主要講解CenterNet的loss,由偏置部分(reg loss)、熱圖部分(heatmap loss)、寬高(wh loss)部分三部分loss組成,附代碼實現。

1. 網絡輸出

論文中提供了三個用於目標檢測的網絡,都是基於編碼解碼的結構構建的。

  1. ResNet18 + upsample + deformable convolution : COCO AP 28%/142FPS
  2. DLA34 + upsample + deformable convolution : COCO AP 37.4%/52FPS
  3. Hourglass104: COCO AP 45.1%/1.4FPS

這三個網絡中輸出內容都是一樣的,80個類別,2個預測中心對應的長和寬,2個中心點的偏差。

# heatmap 輸出的tensor的通道個數是80,每個通道代表對應類別的heatmap
(hm): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1))
)
# wh 輸出是中心對應的長和寬,通道數為2
(wh): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)
# reg 輸出的tensor通道個數為2,分別是w,h方向上的偏移量
(reg): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)

2. 損失函數

2.1 heatmap loss

輸入圖像\(I\in R^{W\times H\times 3}\), W為圖像寬度,H為圖像高度。網絡輸出的關鍵點熱圖heatmap為\(\hat{Y}\in [0,1]^{\frac{W}{R}\times \frac{H}{R}\times C}\)其中,R代表得到輸出相對於原圖的步長stride。C代表類別個數。

下面是CenterNet中核心loss公式:

\[L_k=\frac{-1}{N}\sum_{xyc}\begin{cases} (1-\hat{Y}_{xyc})^\alpha log(\hat{Y}_{xyc})& Y_{xyc}=1\\ (1-Y_{xyc})^\beta(\hat{Y}_{xyc})^\alpha log(1-\hat{Y}_{xyc})& otherwise \end{cases} \]

這個和Focal loss形式很相似,\(\alpha\)\(\beta\)是超參數,N代表的是圖像關鍵點個數。

  • \(Y_{xyc}=1\)的時候,

對於易分樣本來說,預測值\(\hat{Y}_{xyc}\)接近於1,\((1-\hat{Y}_{xyc})^\alpha\)就是一個很小的值,這樣loss就很小,起到了矯正作用。

對於難分樣本來說,預測值\(\hat{Y}_{xyc}\)接近於0,$ (1-\hat{Y}_{xyc})^\alpha$就比較大,相當於加大了其訓練的比重。

  • otherwise的情況下:

otherwise分為兩個情況A和B

上圖是一個簡單的示意,縱坐標是\({Y}_{xyc}\),分為A區(距離中心點較近,但是值在0-1之間)和B區(距離中心點很遠接近於0)。

對於A區來說,由於其周圍是一個高斯核生成的中心,\(Y_{xyc}\)的值是從1慢慢變到0。

舉個例子(CenterNet中默認\(\alpha=2,\beta=4\)):

\(Y_{xyc}=0.8\)的情況下,

  • 如果\(\hat{Y}_{xyc}=0.99\),那么loss=\((1-0.8)^4(0.99)^2log(1-0.99)\),這就是一個很大的loss值。

  • 如果\(\hat{Y}_{xyc}=0.8\), 那么loss=\((1-0.8)^4(0.8)^2log(1-0.8)\), 這個loss就比較小。

  • 如果\(\hat{Y}_{xyc}=0.5\), 那么loss=\((1-0.8)^4(0.5)^2log(1-0.5)\),

  • 如果\(\hat{Y}_{xyc}=0.99\),那么loss=\((1-0.5)^4(0.99)^2log(1-0.99)\),這就是一個很大的loss值。

  • 如果\(\hat{Y}_{xyc}=0.8\), 那么loss=\((1-0.5)^4(0.8)^2log(1-0.8)\), 這個loss就比較小。

  • 如果\(\hat{Y}_{xyc}=0.5\), 那么loss=\((1-0.5)^4(0.5)^2log(1-0.5)\),

總結一下:為了防止預測值\(\hat{Y}_{xyc}\)過高接近於1,所以用\((\hat{Y}_{xyc})^\alpha\)來懲罰Loss。而\((1-Y_{xyc})^\beta\)這個參數距離中心越近,其值越小,這個權重是用來減輕懲罰力度。

對於B區來說\(\hat{Y}_{xyc}\)的預測值理應是0,如果該值比較大比如為1,那么\((\hat{Y}_{xyc})^\alpha\)作為權重會變大,懲罰力度也加大了。如果預測值接近於0,那么\((\hat{Y}_{xyc})^\alpha\)會很小,讓其損失比重減小。對於\((1-Y_{xyc})^\beta\)來說,B區的值比較大,弱化了中心點周圍其他負樣本的損失比重。

2.2 offset loss

由於三個骨干網絡輸出的feature map的空間分辨率變為原來輸入圖像的四分之一。相當於輸出feature map上一個像素點對應原始圖像的4x4的區域,這會帶來較大的誤差,因此引入了偏置值和偏置的損失值。設骨干網絡輸出的偏置值為\(\hat{O}\in R^{\frac{W}{R}\times \frac{H}{R}\times 2}\), 這個偏置值用L1 loss來訓練:

\[L_{offset}=\frac{1}{N}\sum_{p}|\hat{O}_{\tilde{p}}-(\frac{p}{R}-\tilde{p})| \]

p代表目標框中心點,R代表下采樣倍數4,\(\tilde{p}=\lfloor \frac{p}{R} \rfloor\), \(\frac{p}{R}-\tilde{p}\)代表偏差值。

2.3 size loss/wh loss

假設第k個目標,類別為\(c_k\)的目標框的表示為\((x_1^{(k)},y_1^{(k)},x_2^{(k)},y_2^{(k)})\),那么其中心點坐標位置為\((\frac{x_1^{(k)}+x_2^{(k)}}{2}, \frac{y_1^{(k)}+y_2^{(k)}}{2})\), 目標的長和寬大小為\(s_k=(x_2^{(k)}-x_1^{(k)},y_2^{(k)}-y_1^{(k)})\)。對長和寬進行訓練的是L1 Loss函數:

\[L_{size}=\frac{1}{N}\sum^{N}_{k=1}|\hat{S}_{pk}-s_k| \]

其中\(\hat{S}\in R^{\frac{W}{R}\times \frac{H}{R}\times 2}\)是網絡輸出的結果。

2.4 CenterNet Loss

整體的損失函數是以上三者的綜合,並且分配了不同的權重。

\[L_{det}=L_k+\lambda_{size}L_{size}+\lambda_{offset}L_{offset} \]

其中\(\lambda_{size}=0.1, \lambda_{offsize}=1\)

3. 代碼解析

來自train.py中第173行開始進行loss計算:

# 得到heat map, reg, wh 三個變量
hmap, regs, w_h_ = zip(*outputs)

regs = [
_tranpose_and_gather_feature(r, batch['inds']) for r in regs
]
w_h_ = [
_tranpose_and_gather_feature(r, batch['inds']) for r in w_h_
]

# 分別計算loss
hmap_loss = _neg_loss(hmap, batch['hmap'])
reg_loss = _reg_loss(regs, batch['regs'], batch['ind_masks'])
w_h_loss = _reg_loss(w_h_, batch['w_h_'], batch['ind_masks'])

# 進行loss加權,得到最終loss
loss = hmap_loss + 1 * reg_loss + 0.1 * w_h_loss

上述transpose_and_gather_feature函數具體實現如下,主要功能是將ground truth中計算得到的對應中心點的值獲取。

def _tranpose_and_gather_feature(feat, ind):
  # ind代表的是ground truth中設置的存在目標點的下角標
  feat = feat.permute(0, 2, 3, 1).contiguous()# from [bs c h w] to [bs, h, w, c] 
  feat = feat.view(feat.size(0), -1, feat.size(3)) # to [bs, wxh, c]
  feat = _gather_feature(feat, ind)
  return feat

def _gather_feature(feat, ind, mask=None):
  # feat : [bs, wxh, c]
  dim = feat.size(2)
  # ind : [bs, index, c]
  ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
  feat = feat.gather(1, ind) # 按照dim=1獲取ind
  if mask is not None:
    mask = mask.unsqueeze(2).expand_as(feat)
    feat = feat[mask]
    feat = feat.view(-1, dim)
  return feat

3.1 hmap loss代碼

調用:hmap_loss = _neg_loss(hmap, batch['hmap'])

def _neg_loss(preds, targets):
    ''' Modified focal loss. Exactly the same as CornerNet.
        Runs faster and costs a little bit more memory
        Arguments:
        preds (B x c x h x w)
        gt_regr (B x c x h x w)
    '''
    pos_inds = targets.eq(1).float()# heatmap為1的部分是正樣本
    neg_inds = targets.lt(1).float()# 其他部分為負樣本

    neg_weights = torch.pow(1 - targets, 4)# 對應(1-Yxyc)^4

    loss = 0
    for pred in preds: # 預測值
        # 約束在0-1之間
        pred = torch.clamp(torch.sigmoid(pred), min=1e-4, max=1 - 1e-4)
        pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
        neg_loss = torch.log(1 - pred) * torch.pow(pred,
                                                   2) * neg_weights * neg_inds
        num_pos = pos_inds.float().sum()
        pos_loss = pos_loss.sum()
        neg_loss = neg_loss.sum()

        if num_pos == 0:
            loss = loss - neg_loss # 只有負樣本
        else:
            loss = loss - (pos_loss + neg_loss) / num_pos
    return loss / len(preds)

\[L_k=\frac{-1}{N}\sum_{xyc}\begin{cases} (1-\hat{Y}_{xyc})^\alpha log(\hat{Y}_{xyc})& Y_{xyc}=1\\ (1-Y_{xyc})^\beta(\hat{Y}_{xyc})^\alpha log(1-\hat{Y}_{xyc})& otherwise \end{cases} \]

代碼和以上公式一一對應,pos代表正樣本,neg代表負樣本。

3.2 reg & wh loss代碼

調用:reg_loss = _reg_loss(regs, batch['regs'], batch['ind_masks'])

調用:w_h_loss = _reg_loss(w_h_, batch['w_h_'], batch['ind_masks'])

def _reg_loss(regs, gt_regs, mask):
    mask = mask[:, :, None].expand_as(gt_regs).float()
    loss = sum(F.l1_loss(r * mask, gt_regs * mask, reduction='sum') /
               (mask.sum() + 1e-4) for r in regs)
    return loss / len(regs)

4. 參考

https://zhuanlan.zhihu.com/p/66048276

http://xxx.itp.ac.cn/pdf/1904.07850


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM