CenterNet算法介紹


CenterNet算法介紹(學習自objects as points)

論文依據:objects as points

博客參考

img

CenterNet不僅可以用於目標檢測,還可以用於其他的一些任務,如 肢體識別或者3D目標檢測 等等,我們團隊當下在實現的主要是目標檢測的部分。

img

與傳統的one-stage和two-stage的區別:

  • CenterNet的“anchor”(錨)僅僅會出現在當前目標的位置處而不是整張圖上撒,所以也沒有所謂的box overlap大於多少多少的算positive anchor這一說,也不需要區分這個anchor是物體還是背景 - 因為每個目標只對應一個“anchor”,這個anchor是從heatmap中提取出來的,所以不需要NMS再進行來篩選
  • CenterNet的輸出分辨率的下采樣因子是4,比起其他的目標檢測框架算是比較小的(Mask-Rcnn最小為16、SSD為最小為16)。

網絡結構與前提條件

  • 網絡結構

論文中CenterNet提到了三種用於目標檢測的網絡,這三種網絡都是編碼解碼(encoder-decoder)的結構:

  1. Resnet-18 with up-convolutional layers : 28.1% coco and 142 FPS
  2. DLA-34 : 37.4% COCOAP and 52 FPS
  3. Hourglass-104 : 45.1% COCOAP and 1.4 FPS

每個網絡內部的結構不同,但是在模型的最后都是加了三個網絡構造來輸出預測值,默認是80個類、2個預測的中心點坐標、2個中心點的偏置。

用官方的源碼(使用Pytorch)來表示一下最后三層,其中hm為heatmap、wh為對應中心點的width和height、reg為偏置量,這些值在后文中會有講述。

(hm): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1))
)
(wh): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)
(reg): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)
  • 檢測方法

img

  	首先假設輸入圖像為 ![[公式]](https://www.zhihu.com/equation?tex=I+%5Cin+R%5E%7BW+%5Ctimes+H+%5Ctimes+3%7D) ,其中 ![[公式]](https://www.zhihu.com/equation?tex=W) 和 ![[公式]](https://www.zhihu.com/equation?tex=H) 分別為圖像的寬和高,然后在預測的時候,我們要產生出關鍵點的熱點圖(keypoint heatmap): ![[公式]](https://www.zhihu.com/equation?tex=%5Chat%7BY%7D+%5Cin+%5B0%2C1%5D%5E+%7B%5Cfrac%7BW%7D%7BR%7D+%5Ctimes+%5Cfrac%7BH%7D%7BR%7D+%5Ctimes+C%7D) ,其中 ![[公式]](https://www.zhihu.com/equation?tex=R) 為輸出對應原圖的步長,而 ![[公式]](https://www.zhihu.com/equation?tex=C) 是在目標檢測中對應着檢測點的數量,如在COCO目標檢測任務中,這個 ![[公式]](https://www.zhihu.com/equation?tex=C+) 的值為80,代表當前有80個類別。

插一段官方代碼,其中 [公式] 就是self.opt.down_ratio也就是4,代表下采樣的因子。

# 其中input_h和input_w為512,而self.opt.down_ratio為4,最終的output_h為128
# self.opt.down_ratio就是上述的R即輸出對應原圖的步長
output_h = input_h // self.opt.down_ratio
output_w = input_w // self.opt.down_ratio

​ 這樣, [公式] 就是一個檢測到物體的預測值,對於 [公式] ,表示對於類別 [公式] ,在當前 [公式] 坐標中檢測到了這種類別的物體,而 [公式] 則表示當前當前這個坐標點不存在類別為 [公式] 的物體。

​ 在整個訓練的流程中,CenterNet學習了CornerNet的方法。對於每個標簽圖(ground truth)中的某一 [公式] 類,我們要將真實關鍵點(true keypoint) [公式] 計算出來用於訓練,中心點的計算方式為 [公式] ,對於下采樣后的坐標,我們設為 [公式] ,其中 [公式] 是上文中提到的下采樣因子4。所以我們最終計算出來的中心點是對應低分辨率的中心點。

​ 然后我們利用 [公式] 來對圖像進行標記,在下采樣的[128,128]圖像中將ground truth point[公式] 的形式,用一個高斯核 [公式] 來將關鍵點分布到特征圖上,其中 [公式] 是一個與目標大小(也就是w和h)相關的標准差。如果某一個類的兩個高斯分布發生了重疊,直接取元素間最大的就可以。

​ 這么說可能不是很好理解,那么直接看一個官方源碼中生成的一個高斯分布[9,9]:

preview

  • 損失函數(中心點預測)

img

其中 [公式][公式] 是Focal Loss的超參數, [公式] 是圖像 [公式] 的的關鍵點數量,用於將所有的positive focal loss標准化為1。在這篇論文中 [公式][公式] 分別是2和4。這個損失函數是Focal Loss的修改版,適用於CenterNet。

這個損失也比較關鍵,需要重點說一下。和Focal Loss類似,對於easy example的中心點,適當減少其訓練比重也就是loss值。

[公式] 的時候, [公式] 就充當了矯正的作用,假如 [公式] 接近1的話,說明這個是一個比較容易檢測出來的點,那么 [公式] 就相應比較低了。而當 [公式] 接近0的時候,說明這個中心點還沒有學習到,所以要加大其訓練的比重,因此 [公式] 就會很大, [公式] 是超參數,這里取2。

img

高斯生成的中心點

[公式] 的時候,這里對實際中心點的其他近鄰點的訓練比重(loss)也進行了調整,首先可以看到 [公式] ,因為當 [公式] 的時候 [公式] 的預測值理應是0,如果不為0的且越來越接近1的話, [公式] 的值就會變大從而使這個損失的訓練比重也加大;而 [公式] 則對中心點周圍的,和中心點靠得越近的點也做出了調整(因為與實際中心點靠的越近的點可能會影響干擾到實際中心點,造成誤檢測),因為 [公式] 在上文中已經提到,是一個高斯核生成的中心點,在中心點 [公式]但是在中心點周圍擴散 [公式] 會由1慢慢變小但是並不是直接為0,類似於上圖,因此 [公式] ,與中心點距離越近, [公式] 越接近1,這個值越小,相反則越大。那么 [公式] [公式] 是怎么協同工作的呢?

目標中心的偏置損失

因為上文中對圖像進行了 [公式] 的下采樣,這樣的特征圖重新映射到原始圖像上的時候會帶來精度誤差,因此對於每一個中心點,額外采用了一個local offset[公式] 去補償它。所有類 [公式] 的中心點共享同一個offset prediction,這個偏置值(offset)用L1 loss來訓練:

[公式]

上述公式直接看可能不是特別容易懂,其實 [公式] 是原始圖像經過下采樣得到的,對於[512,512]的圖像如果 [公式] 的話那么下采樣后就是[128,128]的圖像,下采樣之后對標簽圖像用高斯分布來在圖像上撒熱點,怎么撒呢?首先將box坐標也轉化為與[128,128]大小圖像匹配的形式,但是因為我們原始的annotation是浮點數的形式(COCO數據集),使用轉化后的box計算出來的中心點也是浮點型的,假設計算出來的中心點是[98.97667,2.3566666]。

推斷階段

在預測階段,首先針對一張圖像進行下采樣,隨后對下采樣后的圖像進行預測,對於每個類在下采樣的特征圖中預測中心點,然后將輸出圖中的每個類的熱點單獨地提取出來。具體怎么提取呢?就是檢測當前熱點的值是否比周圍的八個近鄰點(八方位)都大(或者等於),然后取100個這樣的點,采用的方式是一個3x3的MaxPool,類似於anchor-based檢測中nms的效果。

這里假設 [公式] 為檢測到的點,

img

代表 [公式] 類中檢測到的一個點。每個關鍵點的位置用整型坐標表示 [公式] ,然后使用 [公式] 表示當前點的confidence,隨后使用坐標來產生標定框:

[公式] 其中 [公式] 是當前點對應原始圖像的偏置點, [公式] 代表預測出來當前點對應目標的長寬。

下圖展示網絡模型預測出來的中心點、中心點偏置以及該點對應目標的長寬:

img

后記

這篇論文(objects as points)厲害的地方在於:

  1. 設計模型的結構比較簡單,像我這么頭腦愚笨的人也可以輕松看明白,不僅對於two-stage( faster-rcnn ),對於one-stage( yolo )的目標檢測算法來說該網絡的模型設計也是優雅簡單的。
  2. 該模型的思想不僅可以用於目標檢測,還可以用於3D檢測和人體姿態識別,雖然論文中沒有是深入探討這個,但是可以說明這個網絡的設計還是很好的,我們可以借助這個框架去做一些其他的任務。
  3. 雖然目前尚未嘗試輕量級的模型(這是我接下來要做的!),但是可以猜到這個模型對於嵌入式端這種算力比較小的平台還是很有優勢的,希望大家多多嘗試一些新的backbone(不知道mobilenetv3+CenterNet會是什么樣的效果),測試一下,歡迎和我交流呀~

當然說了一堆優點,CenterNet的缺點也是有的,那就是:

  • 在實際訓練中,如果在圖像中,同一個類別中的某些物體的GT中心點,在下采樣時會擠到一塊,也就是兩個物體在GT中的中心點重疊了,CenterNet對於這種情況也是無能為力的,也就是將這兩個物體的當成一個物體來訓練(因為只有一個中心點)。同理,在預測過程中,如果兩個同類的物體在下采樣后的中心點也重疊了,那么CenterNet也是只能檢測出一個中心點,不過CenterNet對於這種情況的處理要比faster-rcnn強一些的,具體指標可以查看論文相關部分。
  • 有一個需要注意的點,CenterNet在訓練過程中,如果同一個類的不同物體的高斯分布點互相有重疊,那么則在重疊的范圍內選取較大的高斯點。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM