論文筆記《Self-distillation with Batch Knowledge Ensembling Improves ImageNet Classification》


Self-distillation with Batch Knowledge Ensembling Improves ImageNet Classification

Introduction

  • 主要目標在於給batch內的每一個作為anchor的圖片通過傳播同一個batch內其他樣本的知識生成細化的軟標簽refined soft label,所傳播的知識為batch內樣本之間的相似度
  • 該方法基於這樣的假設:外觀相似的樣本應有更一致的類別預測
  • 實際中對於每個樣本,將batch內別的樣本的預測結果或通過加權傳播的方式形成軟目標
  • 知識的傳播是經過迭代處理的,直到其收斂
  • 折是第一個沒有用多個網絡或額外的分支來生成 ensembled soft target的自蒸餾方法
  • 所提出的batch knowledge ensembling使用樣本之間的知識生成驚喜的蒸餾目標

Method

  • 具體的實現方法其實通過偽代碼可以非常清晰的理解,這邊還是介紹一下具體過程,整體框架圖如下

  • 對於我們要傳播的label知識,首先是要對batch內所有的樣本進行相似性計算,生成相似性矩陣\(A\in \mathbb{R}^{N\times N}\),相似性計算后,去除對角線刪去自己與自己的相似性,然后進行一個歸一化,對於每一個樣本的相似性向量和各元素和為1,記作\(\hat{A}\)

  • 將原先所預測出每個樣本所對應的logit記作\(P^{\tau}\),然后將上一步計算出的相似性矩陣與之相乘,相當於利用相似性作一個加權\(\hat{P^{\tau}}=\hat{A}P^{\tau}\)

  • 對於相似性傳播過來的label,我們也要進行一個加權,相當於是得到了我們想要的細化的logit:

  • 這樣的知識傳播需要進行數次,直到收斂,這時候公式中t表示第t次傳播與迭代

  • 當我們t趨於無窮大時,我們對此求極限,相當於是一個等比級數的極限,證明也很簡單,值得注意的是上式中第一項極限為0
  • 所以最后我們的知識加權傳播模塊最終可以表示成如下,該公式在下面的代碼中也有所體現,算是真正得到了我們想要的吸取了batch內所有其他樣本后的logit,值得注意的是對於每一個樣本的refined logit和剛好為1,所以可以直接用:
  • 最后就對於原本的logit和我們細化的logit之間做一次KL loss,加上一定的權重后和原本的CE loss成為本自蒸餾項目全部的loss
  • 值得注意的是,在這個工作中有非常重要的一點,因為logit的細化蒸餾主要依賴於相似性,在一個batch內如果沒有相似的樣本其實是本方法是無效的,所以我們引入了一個對每類都采樣的機制,對於batch大小為\(N\)內有一張圖片后隨機選取同類的\(M\)張圖放入同一個batch中,組成新的batch,這時新的batch大小為\(N\times (M+1)\)
# w: ensembling weight
# t: temperature
# r: loss weight
for (x, gt_labels) in loader:
    # features: N×D, logits: N×K 分別是embedding特征和logit
    f, logits = net.forward(x)
    # classification loss with ground-truth labels
    loss = CrossEntropyLoss(logits, gt labels)
    
    # produce soft targets
    f = normalize(f)
    # 計算batch內各樣本之間的相似度並去除中間的自己與自己,進行一個softmax變成0-1之間,得到公式中的A
    A = softmax(mm(f, f.t())-eye(N)*1e-9) # row-wise normalization of affinity matrix with zero diagonal
    # 最后求過極限之后得到的公式 得到soft_target 
    soft_targets = mm((1-w)·inv(eye(N)-w·A),softmax(logits/t)) # approximate inference for propagation and ensembling
    soft_targets = soft_targets.detach() # no gradient
    
    # distillation loss with soft targets 兩個target之間的KL loss
    loss += KLDivLoss(log_softmax(logits/t), soft_targets)*t^2*r
    # SGD update
    loss.backward()
    update(net.params)

Experiments & Result

  • 做了很多實驗來證明其有效性,首先給出了訓練的細節,如\(N=256,M=1,lr = base\_lr×batch\_size/256\)

  • 首先是不同架構下與原baseline之間的差距和別的label regularzation方法和別的self-distillation之間的區別,常規實驗對比

  • 和別的ensembel distillation方法之間的對比

  • 和別的label refinery方法之間的對比

  • Transfer learning下游任務上目標檢測結果

  • 魯棒性測量實驗結果

  • 每類數據采樣方法的實驗,這個實驗很重要,因為它證明了BAKE方法效果好的原因還是在於knowledge ensemble而不是采樣方法,因為可以看到在正常情況下采用這種采樣方法反而會使效果下降,可能是因為這導致了同一個batch內多樣性下降,而且也並不是同一batch內相同樣本越多越好

  • 小數據集上的實驗結果

Conclusion

  • 一種全新的batch knowledge ensemble方法,為自蒸餾生成了refined soft target,不過這也是建立在一定的采樣方法基礎之上的,雖然該方法還挺有意思的,但受限於這個條件顯得就沒有那么厲害了,因為蒸餾中利用batch之內樣本的相似性來作文章真的挺多了,但這個工作是用來生成新的logit,所以我個人感覺還是挺有意思的,而且這篇文章的算法過程描述的非常清楚了很容易就懂。但目前還並不知道這篇文章中了沒有,其實其對比的自蒸餾方法還是相對來說比較少的,不知道最后結果如何,感謝作者的工作給我帶來的啟發。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM