DN-DETR
CVPR 2022 的一篇文章
一、Introduction
之前許多工作對 detr 的encoder或是decoder結構進行了改進,以期改善收斂慢的現象。本文作者從另一個角度(訓練方法的角度)分析和解決了detr收斂慢的問題。
第一次提出了全新的去噪訓練(DeNoising training)解決了DETR decoder在訓練過程中二分圖匹配 (bipartite graph matching)不穩定的問題,可以讓模型收斂速度翻倍,並對檢測結果帶來顯著提升(+1.9AP)。該方法簡易實用,可以廣泛運用到各種DETR模型當中,以微小的訓練代價帶來顯著提升。
二、Model
(一)二分圖匹配的不穩定性導致訓練速度慢
我的理解是:匈牙利匹配算法會根據cost metric將兩個匹配程度最高的框,作為一對匹配,在此基礎上計算損失並更新模型,而在訓練過程中模型的更新會使得產生的預測框發生變化,而這種變化會導致cost metric的變化,進而很容易導致匹配結果與之前的匹配結果不同(例如之前是預測框a匹配gtbox 6,模型訓練更新會向着使預測框a與gtbox6匹配程度更高的方向去調整,但這種調整不僅會影響a與gtbox6的匹配程度,還會無意中影響到a與其他gtbox的匹配程度,所以有可能會產生更新后預測框a與gtbox 10的匹配程度高於與gtbox6的匹配程度的情況,這種情況下a又變為與gtbox10匹配了),即二分圖匹配的不穩定性。而這種不穩定性會使得loss值發生波動,使得優化目標具有不連續性,阻礙模型的收斂
針對這種不穩定性,作者設計了評判標准進行量化實驗驗證:
(二)DN-DETR
為了解決二分圖匹配的不穩定問題,作者提出了一種新的訓練方式,就是在原有基礎上增加一個訓練任務,來提高訓練過程的穩定性。該工作在DAB-DETR基礎上進行展開。
在DAB-DETR中,cross-attention的輸入query由兩部分:learnable anchors(anchor box參數,包括x y w h),decoder embeddings(學習目標的內容信息)。
在DN-DETR中,為了更好的發揮新增加的訓練任務denoising task的作用,將decoder embedding替換為了帶有目標標簽信息的class label embedding,並且附加了一個指示器indicator,用來區分是denoising task還是matching task。
執行matching task時,除了輸入的class label embedding是unknown class之外,其他的部分都與之前的DAB-DETR相同;
執行denoising task時,輸入的learnable anchors是將gtbox信息進行中心點偏移或者邊框縮放得到的,class label embedding是將真是標簽按照一定比例進行隨機翻轉得到的。在denoising task中,由於事先知道輸入的信息對應於哪一個gtbox,所以在計算損失時不需要進行二分圖匹配,就不存在匹配不穩定的問題。
在denoising task的干預下,訓練的不穩定性降低。
原文描述如下:
To address this problem, we propose a novel training method by introducing a query denoising task to help stabilize bipartite graph matching in the training process.
Our solution is to feed noised ground truth bounding boxes as noised queries together with learnable anchor queries into Transformer decoders. Both kinds of queries have the same input format of (x, y, w, h) and can be fed into Transformer decoders simultaneously.
For noised queries, we perform a denoising task to reconstruct their corresponding ground truth boxes. For other learnable anchor queries, we use the same training loss including bipartite matching as in the vanilla DETR.
另外,由於denoising task和matching task時同時進行的,所以在內部計算時可能會出現一些信息的交互,使得matching task部分獲知了denoising task部分輸入的信息(由於是從gtbox加噪來的,所以帶有gtbox信息),也就是提前知道了“答案”,這會損害matching部分的學習(最終預測時只保留matching部分,所以它學習到的能力才是最關鍵的)。因此作者設計了一個attention mask來阻止這種信息的交互。
下面是總體圖: