NLP中的預訓練語言模型(五)—— ELECTRA


  這是一篇還在雙盲審的論文,不過看了之后感覺作者真的是很有創新能力,ELECTRA可以看作是開辟了一條新的預訓練的道路,模型不但提高了計算效率,加快模型的收斂速度,而且在參數很小也表現的非常好。

論文:ELECTRA: PRE-TRAINING TEXT ENCODERS AS DISCRIMINATORS RATHER THAN GENERATORS

  ELECTRA全稱為Efficiently Learning an Encoder that Classifies Token Replacements Accurately 。論文中提出了一個新的任務—replaced token detection,簡單來說該任務就是預測預訓練語言模型生成的句子中哪些token是原本句子中的,哪些是由語言模型生成的。

  模型的整個結構如下:

    

  整個訓練模式有點類似於GAN,模型由一個生成器和一個判別器組成的,這個判別器就是我們最終使用的預訓練模型,生成器可以采用任何形式的生成模型,在這里作用采用了MLM語言模型(bert之類的)來作為生成器,具體流程如下:

  1)首先對一個距離mask一些詞,將這個mask后的句子作為生成器的輸入。

  2)生成器將這些mask的詞預測成vocab中的token,如上面將painting mask后輸入到生成器中,然后生成器重構輸入,將mask預測成car。

  3)將生成器的輸出作為判別器的輸入,判別器去預測這個句子中的每個token是真實的token,還是由生成器生成的虛假的token,注意:如果生成器生成的詞和真實詞一致,則當作真實的token,例如上面講the mask后生成器仍預測為the,則the在判別器中也是真實值,標簽為正。

  模型的整個流程確定了,剩下的就是該怎么訓練了,在這里訓練方式和GAN並不相同,在GAN中會將判別器的結果作為訓練生成器的損失,但由於NLP中句子是離散的,因此無法通過梯度下降的方式來將判別器的結果反向傳播來訓練生成器,因此在這里作者將MLM損失作為生成器的損失,而將replaced token detection的損失作為判別器的損失,具體損失函數如下:

    生成器的損失:

      

    生成器的損失就是MLM語言模型中預測mask詞的損失。

    判別器的損失:

      

    判別器的損失就是token detection的損失,每個token都有兩個可能性——真實和虛假,因此每個token是一個二分類,然后在這里作者考慮了所有的token。

    最終整個模型的損失為:

      

    $\lambda$ 是一個權重系數,作者認為生成器的任務比較難,因此損失比較大,但是判別器任務相對簡單,因此損失會比較小,因此將判別器的權重設大一點,作者訓練時使用了50。以上就是整個訓練過程。

  權值共享

  作者在訓練的時候采用了一些策略,在這里作者共享生成器和判別器的權值,作者對比了不共享,共享embdding層,共享所有層(共享所有層時需要保證生成器和判別器的架構一樣),作者得出不共享時性能為83.6,共享embedding層為84.3,共享所有層為84.4,因為共享所有層提升不明顯,且還需要保證生成器和判別器結構一致,因此作者只共享了embedding層。

  更小的生成器

  作者在這里對比了不同尺寸生成器的性能,結果發現當生成器大小為判別器的1/4-1/2時,模型性能最好,具體如下圖:

    

   不同的訓練模式

  作者對比了不同的訓練模式,除當前這種模式之外,又給出了兩種新的模式:

  1)采用GAN的方式,用reinforce算法將生成器的最小化MLM損失改成了最大化replaced token detection 損失。

  2)采用2-stage的方式,首先訓練生成器,然后用生成器的參數初始化判別器,在這里要保證生成器和判別器結構相同。作者也說了如果從0開始訓練判別器,效果很差。

  具體對比結果如下:

    

  小模型性能對比:

    

   在這里作者給出了一個14M的小模型ELECTRA-small,效果超過了distillbert,gpt,同等大小的bert-small。這個模型在V100單卡上訓練即可。

  大模型性能對比:

    

   在同等大小的大模型上,性能也和當前最佳roberta性能相當,但是訓練計算量只有roberta的1/4。

  另外作者還分析了計算效率,作者認為判別器中的全詞預測能充分的利用計算效率,而只預測15%的mask的token是很浪費計算資源的,因此全詞預測可以加快模型的收斂速度。另外同等大小下,ELECTRA的性能優於bert,且模型越小,優勢越明顯,如下圖所示:

    

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM