知識蒸餾的意義
能夠壓縮模型,提升模型性能
為什么能夠壓縮模型?
!!!誰知道了告訴我一下!!!
為什么能提升模型精度?
栗子:分類問題有三個分類:貓,狗,烏龜,實際訓練過程中,比如當前的數據真實標簽是:貓,模型預測出貓,狗,烏龜的概率分別是0.6, 0.3, 0.1,
傳統思路:不錯,識別對了,貓的概率最高,給模型一定的獎勵;
知識蒸餾:不錯,識別對了,貓的概率最高,並且狗比烏龜更像貓,給模型一定的獎勵;
總結:即便是負樣本,也包含大量知識,知識蒸餾能把這部分知識也學習起來。
大致步驟:
1. 基於一個已經訓練好的NET-T模型,該模型經過大量數據的訓練准確度很高,但是模型笨重,將NET-T模型最終softmax結果進行軟化,生成soft-target,繼而生成loss1;
2. 創造一個輕量模型NET-S正常前像傳播,實際標簽用one-hot向量表示即hard-target,生成loss2;
3. 將loss1與loss2加權求和生成loss3;
4. loss3用於更新NET-S網絡;
即將NET-T模型的知識遷移到NET-S上並優化性能