三分鍾理解知識蒸餾


知識蒸餾的意義

能夠壓縮模型,提升模型性能

為什么能夠壓縮模型?

!!!誰知道了告訴我一下!!!

為什么能提升模型精度?

栗子:分類問題有三個分類:貓,狗,烏龜,實際訓練過程中,比如當前的數據真實標簽是:貓,模型預測出貓,狗,烏龜的概率分別是0.6, 0.3, 0.1,

傳統思路:不錯,識別對了,貓的概率最高,給模型一定的獎勵;

知識蒸餾:不錯,識別對了,貓的概率最高,並且狗比烏龜更像貓,給模型一定的獎勵;

總結:即便是負樣本,也包含大量知識,知識蒸餾能把這部分知識也學習起來。

大致步驟:

1. 基於一個已經訓練好的NET-T模型,該模型經過大量數據的訓練准確度很高,但是模型笨重,將NET-T模型最終softmax結果進行軟化,生成soft-target,繼而生成loss1;

2. 創造一個輕量模型NET-S正常前像傳播,實際標簽用one-hot向量表示即hard-target,生成loss2;

3. 將loss1與loss2加權求和生成loss3;

4. loss3用於更新NET-S網絡;

即將NET-T模型的知識遷移到NET-S上並優化性能


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM