Self-distillation with Batch Knowledge Ensembling Improves ImageNet Classification
- 2021.5.13
- Project Page: https://geyixiao.com/projects/bake
- https://arxiv.org/abs/2104.13298
Introduction
- 主要目標在於給batch內的每一個作為anchor的圖片通過傳播同一個batch內其他樣本的知識生成細化的軟標簽refined soft label,所傳播的知識為batch內樣本之間的相似度
- 該方法基於這樣的假設:外觀相似的樣本應有更一致的類別預測
- 實際中對於每個樣本,將batch內別的樣本的預測結果或通過加權傳播的方式形成軟目標
- 知識的傳播是經過迭代處理的,直到其收斂
- 折是第一個沒有用多個網絡或額外的分支來生成 ensembled soft target的自蒸餾方法
- 所提出的batch knowledge ensembling使用樣本之間的知識生成驚喜的蒸餾目標
Method
-
具體的實現方法其實通過偽代碼可以非常清晰的理解,這邊還是介紹一下具體過程,整體框架圖如下
-
對於我們要傳播的label知識,首先是要對batch內所有的樣本進行相似性計算,生成相似性矩陣\(A\in \mathbb{R}^{N\times N}\),相似性計算后,去除對角線刪去自己與自己的相似性,然后進行一個歸一化,對於每一個樣本的相似性向量和各元素和為1,記作\(\hat{A}\)
-
將原先所預測出每個樣本所對應的logit記作\(P^{\tau}\),然后將上一步計算出的相似性矩陣與之相乘,相當於利用相似性作一個加權\(\hat{P^{\tau}}=\hat{A}P^{\tau}\)
-
對於相似性傳播過來的label,我們也要進行一個加權,相當於是得到了我們想要的細化的logit:
-
這樣的知識傳播需要進行數次,直到收斂,這時候公式中t表示第t次傳播與迭代

- 當我們t趨於無窮大時,我們對此求極限,相當於是一個等比級數的極限,證明也很簡單,值得注意的是上式中第一項極限為0

- 所以最后我們的知識加權傳播模塊最終可以表示成如下,該公式在下面的代碼中也有所體現,算是真正得到了我們想要的吸取了batch內所有其他樣本后的logit,值得注意的是對於每一個樣本的refined logit和剛好為1,所以可以直接用:

- 最后就對於原本的logit和我們細化的logit之間做一次KL loss,加上一定的權重后和原本的CE loss成為本自蒸餾項目全部的loss
- 值得注意的是,在這個工作中有非常重要的一點,因為logit的細化蒸餾主要依賴於相似性,在一個batch內如果沒有相似的樣本其實是本方法是無效的,所以我們引入了一個對每類都采樣的機制,對於batch大小為\(N\)內有一張圖片后隨機選取同類的\(M\)張圖放入同一個batch中,組成新的batch,這時新的batch大小為\(N\times (M+1)\)
# w: ensembling weight
# t: temperature
# r: loss weight
for (x, gt_labels) in loader:
# features: N×D, logits: N×K 分別是embedding特征和logit
f, logits = net.forward(x)
# classification loss with ground-truth labels
loss = CrossEntropyLoss(logits, gt labels)
# produce soft targets
f = normalize(f)
# 計算batch內各樣本之間的相似度並去除中間的自己與自己,進行一個softmax變成0-1之間,得到公式中的A
A = softmax(mm(f, f.t())-eye(N)*1e-9) # row-wise normalization of affinity matrix with zero diagonal
# 最后求過極限之后得到的公式 得到soft_target
soft_targets = mm((1-w)·inv(eye(N)-w·A),softmax(logits/t)) # approximate inference for propagation and ensembling
soft_targets = soft_targets.detach() # no gradient
# distillation loss with soft targets 兩個target之間的KL loss
loss += KLDivLoss(log_softmax(logits/t), soft_targets)*t^2*r
# SGD update
loss.backward()
update(net.params)
Experiments & Result
-
做了很多實驗來證明其有效性,首先給出了訓練的細節,如\(N=256,M=1,lr = base\_lr×batch\_size/256\)
-
首先是不同架構下與原baseline之間的差距和別的label regularzation方法和別的self-distillation之間的區別,常規實驗對比
-
和別的ensembel distillation方法之間的對比
-
和別的label refinery方法之間的對比
-
Transfer learning下游任務上目標檢測結果
-
魯棒性測量實驗結果
-
每類數據采樣方法的實驗,這個實驗很重要,因為它證明了BAKE方法效果好的原因還是在於knowledge ensemble而不是采樣方法,因為可以看到在正常情況下采用這種采樣方法反而會使效果下降,可能是因為這導致了同一個batch內多樣性下降,而且也並不是同一batch內相同樣本越多越好
-
小數據集上的實驗結果
Conclusion
- 一種全新的batch knowledge ensemble方法,為自蒸餾生成了refined soft target,不過這也是建立在一定的采樣方法基礎之上的,雖然該方法還挺有意思的,但受限於這個條件顯得就沒有那么厲害了,因為蒸餾中利用batch之內樣本的相似性來作文章真的挺多了,但這個工作是用來生成新的logit,所以我個人感覺還是挺有意思的,而且這篇文章的算法過程描述的非常清楚了很容易就懂。但目前還並不知道這篇文章中了沒有,其實其對比的自蒸餾方法還是相對來說比較少的,不知道最后結果如何,感謝作者的工作給我帶來的啟發。