通常我們訓練出的模型都比較大,將這些模型部署到例如手機、機器人等移動設備上時比較困難。模型壓縮(model compression)可以將大模型壓縮成小模型,壓縮后的小模型也能得到和大模型接近甚至更好的性能。這篇文章總結了幾種常用的模型壓縮方法:網絡裁剪(network pruning)、知識蒸餾(knowledge distillation)、參數量化(parameter quantization)以及模型結構設計(architecture design)。
網絡裁剪
一些研究表明,我們訓練出的很多大模型的參數都是過多的(over-paramerterized),有很多冗余參數或者神經元。我們可以將這些冗余的參數和神經元給裁減掉,從而減小模型的體積。網絡裁剪的方法有參數裁剪和神經元裁剪。
參數裁剪
如果一些參數接近於 0,那么我們就可以把這些參數給裁減掉。也就是說,可以通過判斷某參數的 l1 或者 l2 范數是否接近於 0 來決定是否要裁剪該參數。
參數裁剪之后,模型變得不對稱了。不對稱的模型難以用代碼實現(通常使用 0 來代替被裁剪的權重),也難以利用 GPU 進行加速。
神經元裁剪
如果我們在訓練的過程中,某一神經元的輸出在大多數情況下都為 0 或者接近 0,那么我們就可以把這個神經元給裁減掉。
神經元裁剪之后網絡還是對稱的,所以神經元裁剪比參數裁剪更容易實現,也更容易利用 GPU 加速。
裁剪后微調
我們在模型裁剪之后會得到一個小模型,小模型通常會損失一些性能。為了提高小模型的性能,我們可以將小模型在原來的訓練集上進行微調,這樣可以將丟失的性能填補回來。
上圖是網絡裁剪的流程圖。需要的注意的是,不要一次將模型裁剪太多,這樣的話模型的性能很難恢復。所以,每次對模型裁剪一點,循環裁剪多次,直到模型達到自己的要求。
知識蒸餾
知識蒸餾(knowledge distillation)的思想是先訓練一個大模型(teacher),然后再訓練一個小模型(student)來擬合大模型的輸出。這樣,最終得到的小模型體積比大模型小,但能獲得和大模型接近的性能。
在上圖中,我們要訓練一個模型對手寫數字進行分類。首先將“1”輸入到大模型中,大模型的輸入為:"1":0.7, "7":0.2, "9":0.1
,也就是模型認為這張圖片是 1 的可能性為 0.7, 是 7 的可能性為 0.2, 是 9 的可能性為 0.1,然后我們再訓練一個小模型來同時擬合原來的標簽 1 (hard target)和大模型的輸出(soft target)。大模型的輸出能帶來更多的信息,例如,通過大模型的輸出我們可以得到輸入的標簽為 1 ,而且我們還可以知道 1、7、9 這 3 個數字是相似的,這是原始數據集中沒有的。假設輸入為 x,正確標簽為 y,大模型(teacher)的輸出概率為 p,小模型(student)的輸出概率為 q,則小模型的訓練目標為:
其中,\(CE\) 為交叉熵(Cross Entropy),\(\alpha\) 為權重。
知識蒸餾也可以應用於模型融合(ensemble)。例如,在 kaggle 比賽中,為了獲得更好的成績,我們通常會融合多個模型,在分類問題中,一種融合模型的方法是將各個模型的輸出概率求平均作為最終的輸出概率。這種方法在比賽中可以使用,但在實際使用中意義不大,因為融合多個模型會增加推斷時間,而且總的模型體積也會成倍增加。可以使用知識蒸餾來解決這個問題。拿分類問題舉例,我們融合后的結果是多個模型輸出概率的平均,所以我們可以再訓練一個小模型來擬合大模型融合后的結果。這樣就相當於用一個小模型來替代了參與融合的所有模型。
在分類問題中,網絡的最后一層通常是一個 softmax 層,softmax 的計算方法如下
這樣會有一個問題。例如,我們在訓練大模型(teacher)的時候,假設是 3 分類問題,softmax 的輸入為:x1:100, x2:10, x3:1
,這樣 softmax 的輸出會接近 1, 0, 0
,因為這 3 個輸入差的太多了,所以輸出會接近 one hot vector。而1, 0, 0
這樣的 one hot 輸出和原始數據集中的輸出的一樣的,小模型(student)不會從這樣的輸出中學到其他的信息,蒸餾學習也就失去了意義。
為了避免這個問題,我們需要在 softmax 中增加溫度參數 T,此時,softmax 的計算方法如下
也就是先對 softmax 的輸入除以 T。上圖中 T=100,所以輸入由x1:100, x2:10, x3:1
變成了x1:1, x2:0.1, x3:0.01
,此時 softmax 的輸出為 0.56, 023, 0.21
,這樣小模型(student)就可以從大模型(teacher)的輸出中學到其他的信息。
參數量化
參數量化(parameter quantization)通過對模型的參數做一些限制來減小模型的體積。
使用 16 bit
模型參數模型是 32 位的,我們知道如果在訓練時使用 16 位來進行訓練,會大大降低顯存占用並加快訓練速度。所以,我們可以用 16 位的參數來代替 32 位的參數來實現模型壓縮。
參數聚類
參數聚類(parameter clustering)使用聚類算法(例如 k-means)將相似的參數轉為同一個值,然后記錄某個參數屬於哪個簇以及該簇對應的參數值是多少即可。例如,聚類前的參數如下
我們通過聚類得到了 4 個簇,每個簇用一種顏色表示
因為有 4 個簇,所以我們可以用 2 個二進制位來表示每個參數,例如 00 表示藍色簇。然后再使用一個表來記錄顏色和參數的對應關系。這樣的話可以大大減少模型的體積。還可以使用哈夫曼編碼來對簇的編碼進行進一步的優化。
Binary Weights
我們還可以使用 binary weights,也就是每個參數的取值只有兩種 0 或者 1。直覺上這樣的網絡的性能可能會很差,實際上,使用 binary weights 的網絡和使用普通 weights 的網絡在一些分類問題上的性能差距不是很大
表格中的值為分類錯誤率。Binary weights 能達到這樣性能的原因可以從正則化方面進行解釋。Binary weights 對參數進行了限制。
模型結構設計
還可以從模型結構設計(architecture design)方面入手來減少模型的參數。
考慮如下圖所示的神經網絡
第一層有 N 個神經元,第二層有 M 個神經元,則參數大小為 M × N。我們在兩層之間插入一個包含 K 個神經元的新層且該層不使用激活函數,其中 K<M,N,如下
則參數矩陣變成了兩個矩陣 M × K 和 K × N 之和。當 K<M,N 時,參數數量會降低(ALBERT就使用了這個技巧)。
這一技巧會降低參數矩陣的秩(rank),所以也叫做低秩近似(low rank approximation)。
我們再把這一技巧推廣到卷積神經網絡(CNN)中。假設我們的輸入形狀為 6 × 6 × 2,也就是 2 個信道,每個信道的形狀為 6 × 6,如下
因為輸入的信道為 2,則卷積核的信道也為 2,假設我們使用 4 個 3 × 3 × 2 的卷積核對輸入進行卷積,則最終結果的形狀為 4 × 4 × 4,如下
這樣共需要 72 個參數。
我們對卷積的過程分為兩個步驟:depthwise convolution 和 pointwise convolution,這叫做 Depthwise Separable Convolution。首先,我們使用 2 個 3 × 3 的卷積核分別對輸入的兩個信道進行卷積
在上圖中,深藍色卷積核對深灰色輸入信道進行卷積,淺藍色卷積核對淺藍色輸入信道進行卷積最終得到 2 個 4 × 4 的輸出,這一步驟叫做 depthwise convolution。
然后,我們對這 2 個 4 × 4 的輸出進行 pointwise convolution,具體是使用 4 個 1 × 1 × 2 的卷積核對其進行卷積
最終會得到的輸出和常規卷積的輸出具有相同的尺寸,但是我們只用了 18 + 8 = 26 個參數。
其實,我們將卷積分為兩個步驟就相當於在兩個線性層中間插入一個更小的層
下面估計一下這種卷積參數量的差距。使用 I 表示輸入信道的個數,在上面的例子中,I=2,;使用 O 表示輸出信道的個數,上面的例子中 O=4;使用 k 表示卷積核的大小。則常規卷積需要的參數數量為 (k × k × I)× O,而 Depthwise Separable Convolution 需要的參數數量為 k × k × I + I × O。
兩者的比值為
通常,O 比較大,例如 256,所以最終的結果接近 \(\frac{1}{k*k}\),含義是假設卷積核的尺寸為 3,那么 Depthwise Separable Convolution 需要的參數量為常規卷積參數量的九分之一。這一技巧在 MobileNet 中被使用。
總結
這篇文章介紹了網絡裁剪(network pruning)、知識蒸餾(knowledge distillation)、參數量化(parameter quantization)以及模型結構設計(architecture design)這 4 種模型壓縮的方法,主要參考了李宏毅《深度學習2019》課程。這 4 種方法基本上都是從減少模型參數的角度來進行模型壓縮,只有使用 16 位模型這種方法是從存儲角度來壓縮模型。還有一個問題,就是既然我們需要小模型,那么為什么不直接訓練小模型呢?原因是大模型通常更加容易訓練而且性能也會更好。
參考
1、李宏毅《深度學習2019》
2、https://posts.careerengine.us/p/5e040074089a4c71be7da859