論文:CenterMask: Single Shot Instance Segmentation with Point Representation
0.簡介
慣例,有請作者自己介紹——摘要:
- 將實例分割分為兩個子任務:局部形狀預測——區分實例即使重疊在一起;全局顯著性生成——對整個圖像進行pixel-to-pixel分割。
- 局部信息取自物體中心點(center points)的表示。(請見我在第一節中補充的 CenterNet部分)
- 在COCO上34.5maskAP,12fps。
作者認為單階段實體分割現在有兩個主要挑戰:
- 如何區分實體,尤其是同類別,重疊的情況。
- 如何保留像素級的局部位置信息。雙階段和單階段方法都面臨這一像素對齊問題,即特征結果轉換為原始物體大小,或利用固定個數的點對輪廓進行描述(如PolarMask?),無法保留原始圖像的空間信息。
作者也將目前單階段實體分割模型分為兩類:
- 基於全局圖像的方法。一般先生成全局特征圖,然后特征組合生成最終Mask。優點是能較好保留位置信息,實現像素級的特征對齊(pixel-to-pixel alignment);缺點是難以解決重疊遮擋問題。如YOLACT。
- 基於局部圖像的方法。優點是利於解決重疊遮擋問題;但分割精度不足,即不能較好保留位置信息。
而CenterMask同時包含一個全局顯著圖生成分支和一個局部形狀預測分支。分割既精細准確又能區分不同實例。
1.CenterMask結構
CenterMask模型的主要思想就是,把實例分割分為兩個子任務:一個負責預測粗糙但是實例敏感的局部形狀,另一個負責預測精確但是實例不敏感的全局顯著圖。
另外一個顯著特點是,CenterMask基於CenterNet的工作,用center point建模物體,是anchor-free的。這就是上圖所示的point representation。沒有看過CenterNet的話學習CenterMask可能從始至終都會有很多疑惑,所以我們先來看一下另一篇工作:Objects as Points (CenterNet)
CenterNet:Objects as Points
We model an object as a single point — the center point of its bounding box.
正如論文標題及原文摘要的這句話,CenterNet的作者將每個物體建模為一個點——bounding box的中心點(center point),物體的其他屬性通過回歸來預測,預測哪些屬性由特定任務決定,如尺寸、3D位置、方向甚至姿態等。
具體算法流程為:
-
輸入圖像,主干網絡提取特征,得到特征圖。
-
后接三個平行Head網絡,分別生成生成:
-
HeatMap。即用來center points的keypoint heatmap;這一分支也負責了分類,輸出中一個類別對應一個channel。
-
Offset。調整center points的位置,以恢復特征圖與原圖之間不對齊問題造成的偏差( to recover the discretization error caused by the output stride. )
-
Size。即預測center point代表物體的大小。
-
-
通過HeatMap找peak,其反映的就是center points(這里的peak定義為局部最大,找的方法就是某點比8鄰域都大的,通過3×3的maxpooling實現,類似於NMS。這就是作者提到的一個優點——CenterNet無需NMS)。Offset對center point位置進行調整,然后size在此point location上顯示物體大小。
對於其他的任務,可以設計添加不同的分支:
CenterMask
如上圖所示,下面的兩個分支,我們在CenterNet中已經簡單說明,詳情可學習CenterNet的原文了解,也是一項很好的工作,值得一學。下面我們還是將鏡頭拉回這次的主角,負責預測實例mask的上面的分支(其實size分支和CenterNet中的size基本一樣的,只不過這里也同時和shape分支合作預測local shape)。
先說CenterMask的整體流程,圖像輸入到Hourglass network中提取特征,得到center points之后(每一個center point就是一個實例對象),shape和size分支預測出Local shapes,同時saliency分支生成全局顯著圖,然后這兩個結果做Hadamard矩陣乘,得到最后結果。下面詳細展開:
2.Local Shape Prediction
P為主干網絡提取出的特征圖,H×W為shape和size分支生成特征圖的大小,shape分支通道為\(S^2\),\(F_{shape}\in \mathbb{R}^{H×W×S^2}\),size分支通道為2(即長寬h&w)。對於一個在feature map的center ponit \((x,y)\),它的形狀被提取在\(F_{shape}(x,y)\),向量尺寸為\(1×1×S^2\),然后reshape為S×S,再結合size的輸出h&w,resize為h×w,至此就得到了代表着一個實例對象的center ponit對應的局部形狀結果。這一結果因為是由每個center point來的,所以能夠很好區分即使同類重疊的實例對象,但由於是尺寸固定的形狀向量\((s×s)\)只能預測出比較粗糙的mask,所以還需要下面的Global Saliency Generation來提升分割的精度。
3.Global Saliency Generation
這一分支作為一種校准機制來調整local shape的Mask結果。將生成一張全局顯著圖,這張圖的目的是表示每一像素的顯著性,也就是說這個像素是否屬於一個物體的區域。
是一個FCN網絡來對每個像素分類,但與標准語義分割的FCN不同,最后不適用softmax,而是對對每一類使用sigmoid進行二分類,避免類間競爭。(與MaskRCNN一樣)
另外,有兩個模式:
- 類別不可知(class-agnostic),生成H×W×1的二值map。
- 類別明確(class-specific),生成H×W×C的二值map,每一通道都是一個類別對應的map。
4.Mask Assembly
最后對local shape和global saliency進行整合得到最后結果,local shape \(L_{k}\in \mathbb{R}^{h×w}\),把global saliecny也切割為一個個的\(G_{k}\in \mathbb{R}^{h×w}\),然后分別用sigmoid整合到(0,1)的區間,最后兩者做Hadamard乘積(就是兩矩陣對應元素分別相乘):
我們看下圖來直觀理解,only local shape、only global saliency 以及 兩者結合的CenterMask效果,兩者的特點還是很明顯、很符合本文理論論述的:
5.Loss Fuction
整體LossFunction由四部分構成,其中最后一部分\(L_{mask}\)是實例分割mask的,我們先看這部分。
需要注意的是,Mask由local shape和global saliency兩個分支組裝得來,但實際訓練時並沒有為這兩個分支網絡設置各自單獨的loss,而是直接對整個mask結果做懲罰。N是實例的個數,Bce表示對於像素的二分類交叉熵。
其余三項損失並不是本文提出的,而是直接沿用了CenterNet的,在那篇文章和文本都有介紹,下面我們來一項一項過一下:
\(L_p\)為預測center points(heat map)的一個focal loss,\(Y\)是heatmap的ground truth,\(\hat{Y}_{ijc}\)是在預測出的heatmap中位置(i,j)對於類別c的得分。\(\alpha\) 和 $\beta \(的是focal loss的超參,應是延用CenterNet的數值(\)\alpha=2, \beta=4$)。
offset的loss采用L1的形式。\(\hat{O}\)是預測出的偏移offset;p表示ground truth的center point;R是output stride,比如原圖像是400×400,而分割是在50×50的特征圖上做的,則output stride就是8,低分辨率的\(\tilde{p}=\lfloor \frac{P}{R} \rfloor\),這其中的還原取整操作會帶來誤差,所以需要offset分支產生偏移來矯正。
這一項就非常簡單了,分別表示預測出的size和ground truth的size(h,w)。
最后,這四項參數線性組合時的權重分別為1,1,0.1,1。使用Adam迭代更新。
參考及引圖:
https://arxiv.org/pdf/1904.07850.pdf
https://tech.meituan.com/2020/05/21/cvpr2020-centermask.html