知識蒸餾(Distillation)


蒸餾神經網絡取名為蒸餾(Distill),其實是一個非常形象的過程。

我們把數據結構信息和數據本身當作一個混合物,分布信息通過概率分布被分離出來。首先,T值很大,相當於用很高的溫度將關鍵的分布信息從原有的數據中分離,之后在同樣的溫度下用新模型融合蒸餾出來的數據分布,最后恢復溫度,讓兩者充分融合。這也可以看成Prof. Hinton將這一個遷移學習過程命名為蒸餾的原因。

 

蒸餾神經網絡想做的事情,本質上更接近於遷移學習(Transfer Learning),當然也可從模型壓縮(Model Compression)的角度取理解蒸餾神經網絡。

Hinton的這篇論文嚴謹的數學思想推導並不復雜,但是主要是通過巧妙的實驗設計來驗證了蒸餾神經網絡的可行性,所以本專題主要從蒸餾的思想以及實驗的設計來介紹蒸餾神經網絡。

 

 

 

Distillation:

修改后的softmax公式為:

T就是一個調節參數,通常為1;T的數值越大則所有類的分布越‘軟’(平緩)。

公式中,T參數是一個溫度超參數,按照softmax的分布來看,隨着T參數的增大,這個軟目標的分布更加均勻。

一個簡單的知識蒸餾的形式是:用復雜模型得到的“軟目標”為目標(在softmax中T較大),用“轉化”訓練集訓練小模型。訓練小模型時T不變仍然較大,訓練完之后T改為1。 

當正確的標簽是所有的或部分的傳輸集時,這個方法可以通過訓練被蒸餾的模型產生正確的標簽。一種方法是使用正確的標簽來修改軟目標,但是我們發現更好的方法是簡單地使用兩個不同目標函數的加權平均值。第一個目標函數是帶有軟目標的交叉熵,這種交叉熵是在蒸餾模型的softmax中使用相同的T計算的,用於從繁瑣的模型中生成軟目標。第二個目標函數是帶有正確標簽的交叉熵。這是在蒸餾模型的softmax中使用完全相同的邏輯,但在T=1下計算。我們發現,在第二個目標函數中,使用一個較低權重的條件,得到了最好的結果。由於軟目標尺度所產生的梯度的大小為1/T^2,所以在使用硬的和軟的目標時將它們乘以T^2是很重要的。這確保了在使用T時,硬和軟目標的相對貢獻基本保持不變。

   

  1. T參數是什么?有什么作用?

        T參數為了對應蒸餾的概念,在論文中叫的是Temperature,也就是蒸餾的溫度。T越高對應的分布概率越平緩,為什么要使得分布概率變平緩?舉一個例子,假設你是每次都是進行負重登山,雖然過程很辛苦,但是當有一天你取下負重,正常的登山的時候,你就會變得非常輕松,可以比別人登得高登得遠。

        同樣的,在這篇文章里面的T就是這個負重包,我們知道對於一個復雜網絡來說往往能夠得到很好的分類效果,錯誤的概率比正確的概率會小很多很多,但是對於一個小網絡來說它是無法學成這個效果的。我們為了去幫助小網絡進行學習,就在小網絡的softmax加一個T參數,加上這個T參數以后錯誤分類再經過softmax以后輸出會變大(softmax中指數函數的單增特性,這里不做具體解釋),同樣的正確分類會變小。這就人為的加大了訓練的難度,一旦將T重新設置為1,分類結果會非常的接近於大網絡的分類效果

  2. soft target(“軟目標”)是什么?

        soft就是對應的帶有T的目標,是要盡量的接近於大網絡加入T后的分布概率。

  3. hard target(“硬目標”)是什么?

         hard就是正常網絡訓練的目標,是要盡量的完成正確的分類。

  4. 兩個目標函數究竟是什么?

        兩個目標函數也就是對應的上面的soft target和hard target。這個體現在Student Network會有兩個loss,分別對應上面兩個問題求得的交叉熵,作為小網絡訓練的loss function。

  5. 具體蒸餾是如何訓練的?

  Teacher:  對softmax(T=20)的輸出與原始label求loss。

  Student: (1)對softmax(T=20)的輸出與Teacher的softmax(T=20)的輸出求loss1。

       (2)對softmax(T=1)的輸出與原始label求loss2。

       (3)loss = loss1+loss2








    https://blog.csdn.net/paper_reader/article/details/81080857




免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM