TransCoder代碼詳解(二):MLM的訓練過程


前言

在上一篇blog里,ATP分析了TransCoder模型最頂層的main函數,理清了它的訓練過程是怎么循環的。

這次ATP本來想要看一下它的模型具體是什么樣子的。但ATP發現,pretrain過程(只有encoder)和后續的過程(同時有encoder和decoder)它模型的結構與訓練過程還是差別很大的。

為了避免ATP的blog寫得太亂七八糟,ATP決定這次先有針對性地去看一下MLM的訓練過程,也就是只有encoder的時候它是怎么操作的。

建立模型build_model

只考慮MLM的過程的話,build_model這塊內容非常簡單,就是建立了一個Transformer的encoder。基本結構整理出來就像下面這樣:

def build_model(params, dico):
    """
    Build model.
    """
    if params.encoder_only:
        # build
        model = TransformerModel(
            params, dico, is_encoder=True, with_output=True)

        # reload pretrained word embeddings
        if params.reload_emb != '':
	      ......

        # reload a pretrained model
        if params.reload_model != '':
	      ......

        ......
		
        return [model.cuda()]

在用MLM進行pretrain的時候,參數里面的“reload_emb”和“reload_model”都是空串,意思是既不需要載入已有的embedding,也不需要載入已有的model(因為MLM過程是訓練的第一個過程,不需要從別的地方載入什么東西)。

而通過對比可以發現,在進行DAE/BT的訓練時,reload_model這個參數有值,指向的是用MLM訓練好的model。這也進一步印證了該模型的訓練過程是先MLM,再DAE/BT。

Transformer內部的細節ATP沒有仔細看。ATP傾向於認為它就是一個普通的transformer。

訓練過程:trainer和mlm_step

在main函數中,模型建立完成后,又定義了一個trainer。這個類的定義位於XLM/src/trainer.py中,作用是執行訓練的步驟。

例如在主循環中,mlm_step這個函數就是trainer類的一個成員函數,作用是執行一次MLM的訓練。

# generate batch / select words to predict
x, lengths, positions, langs, _ = self.generate_batch(lang1, lang2, 'pred')
x, lengths, positions, langs, _ = self.round_batch(x, lengths, positions, langs)
x, y, pred_mask = self.mask_out(x, lengths)

mlm_step函數首先通過generate_batch這個函數生成一批數據。雖然這個函數返回很多個值,但在MLM過程中我們只需要關注x(返回的數據)和lengths(數據的長度)。

round_batch是與fp16有關的。mask_out是給數據打mask的,返回的x,y,pred_mask三個參數分別是打過mask的數據、原始數據,以及一個布爾數組表示哪里打了mask。

接下來,將得到的數據推送到顯存上后,就可以開始訓練了。mlm_step的核心語句是這幾句:

# forward / loss
tensor = model('fwd', x=x, lengths=lengths, positions=positions, langs=langs, causal=False)
_, loss = model('predict', tensor=tensor, pred_mask=pred_mask, y=y, get_scores=False)
self.stats[('MLM-%s' % lang1) if lang2 is None else ('MLM-%s-%s' % (lang1, lang2))].append(loss.item())
loss = lambda_coeff * loss

這段語句的前兩行是在調用transformer類的成員函數。它們的作用光看字面意思就能猜個大概,就是把數據送入transformer,過了encoder以后再預測mask的內容,然后與真實的數據(y)算出loss進行優化。

其中,fwd函數返回的是輸入數據過了encoder與一個額外的全連接層(FFN)后的輸出,而predict函數利用這個輸出來進行預測並計算loss。

原理和這個圖是一樣的:

這個圖是從李宏毅的講BERT的課程視頻里截出來的。關於這個訓練過程他的解釋是,因為線性分類器是相對比較弱的一種分類器,所以分類的效果更多地取決於encoder所作出的embedding是不是准確。所以這個MLM的訓練過程能有效地訓練模型的embedding能力。

另外,TransCoder的原論文中提到,模型能work的關鍵是它找到了不同語言之間的anchor point,也就是具有相同表示的token。ATP其實對這個地方的理解一直比較模糊。它現在認為這個anchor point應該指的是在embedding之后位置相近(或相同)的token,也就是說不同語言中上下文語境相似的token。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM