Feature Fusion for Online Mutual Knowledge Distillation (CVPR 2019)


一、解決問題

  • 如何將特征融合與知識蒸餾結合起來,提高模型性能

二、創新點

  • 支持多子網絡分支的在線互學習
  • 子網絡可以是相同結構也可以是不同結構
  • 應用特征拼接、depthwise+pointwise,將特征融合和知識蒸餾結合起來

三、實驗方法和理論

1.Motivation

DML (Deep Mutual Learning)

  • 算法思想:

​ 用兩個子網絡(可以是不同的網絡結構)進行在線互學習,得到比單獨訓練性能更好的網絡

  • 損失函數:

​ 傳統監督損失函數:

​ 模仿性的損失函數:

​ 單個網絡的損失函數:

ONE (On-the-FlyNative Ensemble)

  • 算法思想:

​ 通過在網絡深層次構造多分支結構,每個分支作為學生網絡,融合logit分布,生成更強的教師網絡,進而通過學生/教師網絡的共同在線學習,互相蒸餾,訓練得性能優越的單分支或多分支融合模型。

  • logit融合 (Gate Module:FC、BN、ReLU、Softmax):

  • 損失函數:

DualNet

  • 算法思想:

​ 通過融合兩個互補parallel networks生成的特征,使得融合后的性能比單獨訓練的性能更好

  • 損失函數:

  • 啟發:結合DML、ONE和DualNet的思想,構造一個支持(相同或者不同的)多個子網絡分支進行特征融合的網絡結構,進而讓融合分類器和分類器進行在線互學習,互蒸餾的方式,從而提高網絡的性能。

    2.Network Architecture

  • Fusion Module

  • Fusion Module 將Net1 和Net2 的到的特征張量進行拼接,然后通過Depthwise conv 得到一個通道數為M的特征張量,經過 Pointwise conv 后生成一個通道數為N的特征張量,即為融合后的特征。
  • 子網絡和融合網絡同時訓練,將子網絡最后一層得到的特征,通過一個Fusion Module進行特征融合,得到融合分類器的概率分布。

3.訓練過程

  • 軟分布概率:

其中,

  • 集成logit概率分布計算:

  • 交叉熵損失函數:

  • KL散度損失函數:

​ 這里有兩個KL散度損失函數,分別對應從 Ensemble Classifier 到 Fused Classifier 的知識蒸餾和從 Fused Classifer 到 Sub-network Classifier 的知識蒸餾的損失函數。

  • 總的損失函數:

四、實驗結果

數據集

  • CIFAR-10
    • 50k 訓練集,10k 測試集
    • 10種圖像類別,每類 6k 張圖片
  • CIFAR-100
    • 50k 訓練集,10k 測試集
    • 100種圖像類別,每類600張圖片
  • ImageNet LSVRC2015
    • 1.2M 訓練集,50k 驗證集
    • 1000種圖像類別

特征融合對比(FFL vs DualNet):

  • FFL融合后的性能略比DualNet好
  • FFL得到的子網絡性能明顯比DualNet好

消融實驗

  • 缺少任何一個模塊都會導致融合分類器和子分類的效果下降,尤其當缺少FKD時,子網絡性能下降很多。

在線蒸餾對比(FFL vs ONE):

由於FFL比ONE多了一個Fusion Module為了參數大小公平起見,ONE在Gate模塊前多疊加幾個殘差模塊

  • vanilla 表示單獨訓練的結果,ONE表示兩個子網絡的平均結果,ONE-E表示融合后的結果,ONE-E+表示參數與FFL大小一樣融合后的結果,FFL-S表示子網絡的平均結果,FFL表示融合后的結果
  • 即便增加ONE的殘差模塊,從ONE-E和ONE-E+的對比來看,性能並沒有多大提升,甚至有所下降(例如CIFAR-100)
  • 從表格發現,FFL比ONE的效果略有提升

分支拓展:

  • 隨着分支數增多,性能也略有提升。

ImageNet:

  • ONE 和 FFL性能相似,FFL效果略好一些。
  • 這說明了本文方法一樣適用於大規模的數據集

互學習性能對比(FFL vs DML):

  • 雖然參數量FFL比DML多4%,但性能優於DML,也說明了FFL適用於不同子網絡結構。

定性分析

  • 1-2列,分類都是正確,但FFL關注的特征區域比單獨訓練的ResNet-34好,且置信度更高
  • 3-6列,FFL分類正確,而單獨訓練的ResNet-34分類錯誤
  • 7-9列,兩者分類都是錯誤的,但是FFL關注的特征區域屬於正確類別的關注區域。
  • 同時我們發現Subnet的特征熱區一直在擬合Fusion的結果,這也驗證了互蒸餾的有效性,即的確學習到軟概率分布中含有的豐富的錯誤類別的相關概率信息。

五、 總結

  • 結合預訓練模型,該方法可以適用於圖像檢測(RPN特征),圖像分割(dense feature),風格遷移等任務。

  • 同時兼顧子網絡和融合網絡的性能,根據實際需要,選擇子網絡或者融合網絡

  • Fusion Module 可以得到更為豐富的圖像特征,從而提高整體性能。

  • 子網絡的選擇限制低,可以選擇多個相同或者不同的網絡構成

  • 能夠將多個方法的優點結合起來得到更好的方法,實驗充分

  • 不足:參數量略多一些,以及子網絡結構選取的不確定性


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM