保姆級帶你深入閱讀NAS-BERT


摘要:本文用權重共享的one-shot的NAS方式對BERT做NAS搜索。

本文分享自華為雲社區《[NAS論文][Transformer][預訓練模型]精讀NAS-BERT》,作者:蘇道 。

NAS-BERT: Task-Agnostic and Adaptive-Size BERT Compression with Neural Architecture Search

簡介:

論文代碼沒有開源,但是論文寫得挺清晰,應該可以手工實現。BERT參數量太多推理太慢(雖然已經支持用tensorRT8.X取得不錯的推理效果,BERT-Large推理僅需1.2毫秒),但是精益求精一直是科研人員的追求,所以本文用權重共享的one-shot的NAS方式對BERT做NAS搜索。

涉及到的方法包括 block-wise search, progressive shrinking,and performance approximation

講解:

1、搜索空間定義

搜索空間的ops包括深度可分離卷積的卷積核大小[3/5/7],Hidden size大小【128/192/256/384/512】MHA的head數[2/3/4/6/8],FNN[512/768/1021/1536/2048]、和identity 連接,也就是跳層了,一共26個op,具體可見下圖:

注意這里的MHA和FNN是二選一的關系,但是可以比如說第一層選MHA第二層選FNN,這樣構成一個基本的Transformer塊,可以說這個方法打破的定式的Transformer塊的搜索又包含了Transformer和BERT的結構,不同層間也是鏈式鏈接,每層只選擇一個op,如下圖

2、超網絡的訓練方式

【 Block-Wise Training + Knowledge Distillation、分塊訓練+KD蒸餾】

(1)首先把超網絡等分成N個Blocks

(2)以原始的BERT作為Teacher模型,BERT也同樣分為N個Blocks

(3)超網絡(Student)中第n個塊的輸入是teacher模型第n-1個塊的輸出,來和teacher模型的第n個塊的輸出做均方差來作為loss,來預測teacher模型中這第n個block的輸出

(4)超網絡的訓練是單架構隨機采樣訓練

(5)由於student 塊的隱藏大小可能與teacher塊中的hidden size不同,能直接利用教師塊隱藏的輸入,和輸出作為學生塊的訓練數據。為了解決這個問題,需要在學生塊的輸入和輸出處使用一個可學習的線性變換層來轉換每個hidden size,以匹配教師塊的大小,如下圖所示

【 Progressive Shrinking】

搜索空間太大,超網絡需要有效的訓練,可以借助Progressive Shrinking的方式來加速訓練和提高搜索效率,以下簡稱為PS。但是不能簡單粗暴的剔除架構,因為大架構再訓練初期難收斂,效果不好,但是並不能代表其表征能力差,所以本文設置了一個PS規則:

其含義,a^t表示超網絡中最大的架構,p(▪)表示參數量大小,l(▪)表示latency大小,B表示設置B個區間桶,b表示當前為第幾個區間。如果一個架構a不滿足p_b>p(a)>pb_1並且l_b>l(a)>l_b-1這個區間,就剔除。

PS的過程就是從每個B桶中抽E個架構,過驗證集,剔除R個最大loss的架構,重復這個過程直到只有m個架構在每個桶中

3、Model Selection

建一個表,包括 latency 、loss、 參數量 和結構編碼,其中loss和latency是預測評估的方法,評估方法具體可以看論文,對於給定的模型大小和推理延遲約束條件,從滿足參數和延遲約束的表中選擇最低loss的T個架構,然后把這個T個架構過驗證集,選取最好的那個。

實驗結果

1、和原始BERT相比在 GLUE Datasets上都有一定的提升:

2、和其他變種BERT相比效果也不錯:

消融實驗

1、PS是否有效?

如果不用PS方法,需要巨大的驗證上的時間(5min vs 50hours),並且超網絡訓練更難收斂,影響架構排序:

2、是PS架構還是PS掉node

結論是PS掉node太過粗暴,效果不好:

3、二階段蒸餾是否有必有?

本文蒸餾探究了預訓練階段和finetune階段,也就是pre-train KD 和 finetune KD,結論是:

1、預訓練蒸餾效果比finetune時候蒸餾好

2、兩階段一起蒸餾效果最好

 

點擊關注,第一時間了解華為雲新鮮技術~


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM