Distilling the Knowledge in Neural Network
Geoffrey Hinton, Oriol Vinyals, Jeff Dean
preprint arXiv:1503.02531, 2015
NIPS 2014 Deep Learning Workshop
簡單總結
主要工作(What)
- “蒸餾”(distillation):把大網絡的知識壓縮成小網絡的一種方法
- “專用模型”(specialist models):對於一個大網絡,可以訓練多個專用網絡來提升大網絡的模型表現
具體做法(How)
- 蒸餾:先訓練好一個大網絡,在最后的softmax層使用合適的溫度參數T,最后訓練得到的概率稱為“軟目標”。以這個軟目標和真實標簽作為目標,去訓練一個比較小的網絡,訓練的時候也使用在大模型中確定的溫度參數T
- 專用模型:對於一個已經訓練好的大網絡,可以訓練一系列的專用模型,每個專用模型只訓練一部分專用的類以及一個“不屬於這些專用類的其它類”,比如專用模型1訓練的類包括“顯示器”,“鼠標”,“鍵盤”,...,“其它”;專用模型2訓練的類包括“玻璃杯”,“保溫杯”,“塑料杯”,“其它“。最后以專用模型和大網絡的預測輸出作為目標,訓練一個最終的網絡來擬合這個目標。
意義(Why)
- 蒸餾把大網絡壓成小網絡,這樣就可以先在訓練階段花費大精力訓練一個大網絡,然后在部署階段以較小的計算代價來產生一個較小的網絡,同時保持一定的網絡預測表現。
- 對於一個已經訓練好的大網絡,如果要去做集成的話計算開銷是很大的,可以在這個基礎上訓練一系列專用模型,因為這些模型通常比較小,所以訓練會快很多,而且有了這些專用模型的輸出可以得到一個軟目標,實驗證明使用軟目標訓練可以減小過擬合。最后根據這個大網絡和一系列專用模型的輸出作為目標,訓練一個最終的網絡,可以得到不錯的表現,而且不需要對大網絡做大量的集成計算。
Abstract
提高機器學習算法表現的一個簡單方法就是,訓練不同模型然后對預測結果取平均。
但是要訓練多個模型會帶來過高的計算復雜度和部署難度。
可以將集成的知識壓縮在單一的模型中。
論文使用這種方法在MNIST上做實驗,發現取得了不錯的效果。
論文還介紹了一種新型的集成,包括一個或多個完整模型和專用模型,能夠學習區分完整模型容易混淆的細粒度的類別。
1 Introduction
昆蟲有幼蟲期和成蟲期,幼蟲期主要行為是吸收養分,成蟲期主要行為是生長繁殖。
類似地,大規模機器學習應用可以分為訓練階段和部署階段,訓練階段不要求實時操作,允許訓練一個復雜緩慢的模型,這個模型可以是分別訓練多個模型的集成,也可以是單獨的一個很大的帶有強正則比如dropout的模型。
一旦模型訓練好,可以用不同的訓練,這里稱為“蒸餾”,去把知識轉移到更適合部署的小模型上。
復雜模型學習區分大量的類,通常的訓練目標是最大化正確答案的平均log概率,這么做有一個副作用就是訓練模型同時也會給所有的錯誤答案分配概率,即使這個概率很小,而有一些概率會比其它的大很多。錯誤答案的相對概率體現了復雜模型的泛化能力。舉個例子,寶馬的圖像被錯認為垃圾箱的概率很低,但是這被個錯認為垃圾桶的概率相比於被錯認為胡蘿卜的概率來說,是很大的。(可以認為模型不止學到了訓練集中的寶馬圖像特征,還學到了一些別的特征,比如和垃圾桶共有的一些特征,這樣就可能捕捉到在新的測試集上的寶馬出現這些的特征,這就是泛化能力的體現)
將復雜模型轉為小模型需要保留模型的泛化能力,一個方法就是用復雜模型產生的分類概率作為“軟目標”來訓練小模型。
當軟目標的熵值較高時,相對於硬目標,每個訓練樣本提供更多的信息,訓練樣本之間會有更小的梯度方差。
所以小模型經常可以被訓練在小數據集上,而且可以使用更高的學習率。
像MNIST這種分類任務,復雜模型可以產生很好的表現,大部分信息分布在小概率的軟目標中。
為了規避這個問題,Caruana和他的合作者們使用softmax輸出前的units值,而不是softmax后的概率,最小化復雜模型和簡單模型的units的平方誤差來訓練小模型。
而更通用的方法,蒸餾法,先提高softmax的溫度參數直到模型能產生合適的軟目標。然后在訓練小模型匹配軟目標的時候使用相同的溫度T。
被用於訓練小模型的轉移訓練集可以包括未打標簽的數據(可以沒有原始的實際標簽,因為可以通過復雜模型獲取一個軟目標作為標簽),或者使用原始的數據集,使用原始數據集可以得到更好的表現。
2 Distillation
softmax公式: $ q_{i} = \frac{exp(z_{i}/T)}{\sum_{j}^{ }exp(z_{j}/T)} $
其中溫度參數T通常設置為1,T越大可以得到更“軟”的概率分布。
(T越大,不同激活值的概率差異越小,所有激活值的概率趨於相同;T越小,不同激活值的概率差異越大)
(在蒸餾訓練的時候使用較大的T的原因是,較小的T對於那些遠小於平均激活值的單元會給予更少的關注,而這些單元是有用的,使用較高的T能夠捕捉這些信息)
最簡單的蒸餾形式就是,訓練小模型的時候,以復雜模型得到的“軟目標”為目標,采用復雜模型中的較高的T,訓練完之后把T改為1。
當部分或全部轉移訓練集的正確標簽已知時,蒸餾得到的模型會更優。一個方法就是使用正確標簽來修改軟目標。
但是我們發現一個更好的方法,簡單對兩個不同的目標函數進行權重平均,第一個目標函數是和復雜模型的軟目標做一個交叉熵,使用的復雜模型的溫度T;第二個目標函數是和正確標簽的交叉熵,溫度設置為1。我們發現第二個目標函數被分配一個低權重時通常會取得最好的結果。
3 Preliminary experiments on MNIST
| net | layers | units of each layer | activation | regularization | test errors |
|---|---|---|---|---|---|
| single net1 | 2 | 1600 | relu | dropout | 67 |
| single net2 | 2 | 800 | relu | no | 146 |
(防止表格黏在一起)
| net | large net | small net | temperature | test errors |
|---|---|---|---|---|
| distilled net | single net1 | single net2 | 20 | 74 |
(第一個表格中是兩個單獨的網絡,一個大網絡和一個小網絡。)
(第二個表格是使用了蒸餾的方法,先訓練大網絡,然后根據大網絡的“軟目標”結果和溫度T來訓練小網絡。)
(可以看到,通過蒸餾的方法將大網絡中的知識壓縮到小網絡中,取得了不錯的效果。)
4 Experiments on speech recognition
| system | Test Frame Accuracy | Word Error Rate on dev set |
|---|---|---|
| baseline | 58.9% | 10.9% |
| 10XEnsemble | 61.1% | 10.7% |
| Distilled model | 60.8% | 10.7% |
其中basline的配置為
- 8 層,每層2560個relu單元
- softmax層的單元數為14000
- 訓練樣本大小約為 700M,2000個小時的語音文本數據
10XEnsemble是對baseline訓練10次(隨機初始化為不同參數)然后取平均
蒸餾模型的配置為
- 使用的候選溫度為{1, 2, 5, 10}, 其中T為2時表現最好
- hard target 的目標函數給予0.5的相對權重
可以看到,相對於10次集成后的模型表現提升,蒸餾保留了超過80%的效果提升
5 Training ensembles of specialists on very big datasets
訓練一個大的集成模型可以利用並行計算來訓練,訓練完成后把大模型蒸餾成小模型,但是另一個問題就是,訓練本身就要花費大量的時間,這一節介紹的是如何學習專用模型集合,集合中的每個模型集中於不同的容易混淆的子類集合,這樣可以減小計算需求。專用模型的主要問題是容易集中於區分細粒度特征而導致過擬合,可以使用軟目標來防止過擬合。
5.1 JFT數據集
JFT是一個谷歌的內部數據集,有1億的圖像,15000個標簽。google用一個深度卷積神經網絡,訓練了將近6個月。
我們需要更快的方法來提升baseline模型。
5.2 專用模型
將一個復雜模型分為兩部分,一部分是一個用於訓練所有數據的通用模型,另一部分是很多個專用模型,每個專用模型訓練的數據集是一個容易混淆的子類集合。這些專用模型的softmax結合所有不關心的類為一類來使模型更小。
為了減少過擬合,共享學習到的低水平特征,每個專用模型用通用模型的權重進行初始化。另外,專用模型的訓練樣本一半來自專用子類集合,另一半從剩余訓練集中隨機抽取。
5.3 將子類分配到專用模型
專用模型的子類分組集中於容易混淆的那些類別,雖然計算出了混淆矩陣來尋找聚類,但是可以使用一種更簡單的辦法,不需要使用真實標簽來構建聚類。對通用模型的預測結果計算協方差,根據協方差把經常一起預測的類作為其中一個專用模型的要預測的類別。幾個簡單的例子如下。
JFT 1: Tea party; Easter; Bridal shower; Baby shower; Easter Bunny; ...
JFT 2: Bridge; Cable-stayed bridge; Suspension bridge; Viaduct; Chimney; ...
JFT 3: Toyota Corolla E100; Opel Signum; Opel Astra; Mazda Familia; ...
5.4 實驗表現
| system | Conditional Test Accuracy | Test Accuracy |
|---|---|---|
| baseline | 43.1% | 25.0% |
| 61 specialist models | 45.9% | 26.1% |
6 Soft Targets as Regularizers
對於前面提到過的,對於大量數據訓練好的語音baseline模型,用更少的數據去擬合這個模型的時候,使用軟目標可以達到更好的效果,減小過擬合。實驗結果如下。
| system & training set | Train Frame Accuracy | Test Frame Accuracy |
|---|---|---|
| baseline(100% training set) | 63.4% | 58.9% |
| baseline(3% training set) | 67.3% | 44.5% |
| soft targets(3% training set) | 65.4% | 57.0% |
