bert剪枝系列——Are Sixteen Heads Really Better than One?


1,概述

  剪枝可以分為兩種:一種是無序的剪枝,比如將權重中一些值置為0,這種也稱為稀疏化,在實際的應用上這種剪枝基本沒有意義,因為它只能壓縮模型的大小,但很多時候做不到模型推斷加速,而在當今的移動設備上更多的關注的是系統的實時相應,也就是模型的推斷速度。另一種是結構化的剪枝,比如卷積中對channel的剪枝,這種不僅可以降低模型的大小,還可以提升模型的推斷速度。剪枝之前在卷積上應用較多,而隨着bert之類的預訓練模型的出現,這一類模型通常比較大,且推斷速度較慢。例如bert在文本分類的任務上,128的序列長度,其推斷速度都只有80ms左右,這還只是單個模型,而一個大的系統,往往是有多個模型組成的。因此bert要想在工業界,尤其是移動端落地,是極度需要模型壓縮的。

2,具體方法

  看完這篇論文之后,更多的感覺是這篇論文並沒有在剪枝上有太多的貢獻,更像是對multi head中head的數量做了一個實驗性的工作,探索了在multi head中並不是所有的head都需要,有很多head提取的信息對最終的結果並沒有什么影響,是冗余存在的。

  本論文在探討在test階段,去掉一部分head是否會影響模型的性能,得到的結論是大多數都不會,而且部分還會提升性能,作者給出了三種實驗方法來證明這一點:

  1,每次去掉一層中一個head,測試模型的性能

  2,每次去掉一層中剩余的層,只保存一個head,測試模型的性能

  3,通過梯度來判斷每個head的重要性,然后去掉一部分不重要的head,測試模型的性能

  為了實現上述的實驗,作者對multi head的計算做了一些修改,修改后的公式如下:

    

  在這里引入了一個系數$\zeta_h$,該值的取值為0或1,它的作用是用來mask不重要的head。在訓練時保持為1,到test的時候對部分head mask掉。

  作者在基於transformer的機器翻譯模型上和基於bert的NLI任務上做了實驗,我們來看看上面三個實驗的結果

  Ablating One Head

  去掉一個head,作者給出了實驗結果如下:

    

   從上面的圖中可以看到大多數head去掉之后的模型分數還基本分布在baseline附近,從作者給的表格數據看會更加的清晰:

    

   上面給出的是機器翻譯的表格數據,藍色的值表示性能增加,紅色的值表示性能下降,大多數情況下性能是增加的,只有少部分性能會有所下降,只有極少部分性能會下降的比較多。

  Ablating All Heads but One

  當去掉一層中的其余head只保留一個head時,我們來看下模型的結果,這回作者給出了一個離散圖:

    

   同樣的,大多數情況下的性能都分布在baseline附近,同樣看看表格會更清晰:

    

   從上面來看除了機器翻譯中的encoder-decoder之間的attention的最后一層會出現性能明顯的下降,其他大多數情況都還好,甚至有的情況下性能反而上升。

  上面兩種實驗都有一個共同的弊端,就是每次實驗只能對一層做head的mask,但實際過程中所有層的head都有可能會被去除,且至於去除哪些還和層與層之間的依賴性有關,因此第三種方法可以來改善這個問題。

  Head Importance Score for Pruning

  在這里作者引入了梯度來衡量head的重要性,首先給出一個公式如下:

    

   上面公式是對mask系數的偏導,我們知道偏導的值的大小可以衡量這個維度上對損失的影響大小,在這里作者對偏導取了個絕對值,避免在求期望的時候正負抵消,因為無論是正值還是負值,只要絕對值比較大,就可以衡量偏導對損失的影響是比較大的,這里的期望是對所有樣本X的,因為單個batch是存在誤差的,因此對全量樣本計算的偏導求均值。

   對上面的公式做一個鏈式轉換,可以得到:

    

   這樣我們就可以用這個對head的期望梯度值來衡量其重要性,然后按百分比去除head,得到的結果如下:

    

   上面圖中的實驗是通過梯度來進行剪枝的,虛線是通過第一種方法中的分數來衡量head的重要性進行剪枝的,可以看到基於梯度的效果還是很明顯的,但是剪枝范圍也是有限的,超過這個范圍之后,性能會急劇下降。

  作者還測了下剪枝后模型的推斷速度,個人感覺這個推斷速度的減小真的是毫無意義:

    

   如上圖所示,只有在batch達到16的時候才有比較明顯的速度提升,但是大多數線上運行的時候都是batch為1的。不過也不能就此下定論說減少head的數量是起不到加速效果的,個人感覺作者在這里測推斷速度的時候是存在一些問題的:作者是先訓練,后剪枝,但剪枝之后沒有再訓練,這也就意味着這些head仍然存在,只是將不需要的head前面的mask系數置為0而已。為什么做出這樣的認定呢?因為在實際的multi head設計中,我們是要保證每個head得到的詞向量拼接在一起等於原始的詞向量,因為后面要進入到前向層,必須保持維度一致,我猜這里作者可能是將mask掉的head得到的向量置為0,這樣這些值在下一層計算self-attention就沒有意義了,至於為什么還是有加速,原因不明。以上個人猜測。

  此外單純得減少head的數量好像對加速意義不大,只有配合減小embedding size才有意義,否則計算復雜度基本一致,因為我們在做multi-attention時映射到不同子空間時,實際上是一個大的矩陣映射,這個大的矩陣的維度取決於embedding size,映射完之后再分割成多個而已。從計算上來看self-attention是耗時的,因為減少embedding size,減小序列長度都可以極大的提速(減小序列長度還會影響到前向層的速度)。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM