原來你是這樣的BERT,i了i了! —— 超詳細BERT介紹(二)BERT預訓練
BERT(Bidirectional Encoder Representations from Transformers)是谷歌在2018年10月推出的深度語言表示模型。
一經推出便席卷整個NLP領域,帶來了革命性的進步。
從此,無數英雄好漢競相投身於這場追劇(芝麻街)運動。
只聽得這邊G家110億,那邊M家又1750億,真是好不熱鬧!
然而大家真的了解BERT的具體構造,以及使用細節嗎?
本文就帶大家來細品一下。
前言
本系列文章分成三篇介紹BERT,上一篇介紹了BERT主模型的結構及其組件相關,本篇則主要介紹BERT預訓練相關知識,其后還會有一篇介紹如何將BERT應用到不同的下游任務。
文章中的一些縮寫:NLP(natural language processing)自然語言處理;CV(computer vision)計算機視覺;DL(deep learning)深度學習;NLP&DL 自然語言處理和深度學習的交叉領域;CV&DL 計算機視覺和深度學習的交叉領域。
文章公式中的向量均為行向量,矩陣或張量的形狀均按照PyTorch的方式描述。
向量、矩陣或張量后的括號表示其形狀。
本系列文章的代碼均是基於transformers庫(v2.11.0)的代碼(基於Python語言、PyTorch框架)。
為便於理解,簡化了原代碼中不必要的部分,並保持主要功能等價。
閱讀本系列文章需要一些背景知識,包括Word2Vec、LSTM、Transformer-Base、ELMo、GPT等,由於本文不想過於冗長(其實是懶),以及相信來看本文的讀者們也都是沖着BERT來的,所以這部分內容還請讀者們自行學習。
本文假設讀者們均已有相關背景知識。
目錄
2、預訓練
BERT的預訓練是一大特色,BERT經過預訓練后,只需要在下游任務的數據集上進行少數幾輪(epoch)的監督學習(supervised learning),就可以大幅度提升下游任務的精度。
另外,BERT的預訓練是通過無監督學習(unsupervised learning)實現的。
預訓練所使用的無監督數據集往往非常大,而下游任務的監督數據集則可以很小。
由於網絡上文本數據非常多,所以獲取大規模無監督的文本數據集是相對容易的。
BERT在預訓練時學習兩種任務:遮蓋的語言模型(masked language model, MLM)、下一句預測(next sentence prediction,NSP)。
- 遮蓋的語言模型:在輸入的序列中隨機把原標記替換成
[MASK]
標記,然后用主模型輸出的標記表示來預測所有原標記,即學習標記的概率分布。 - 下一句預測:訓練數據隨機取同一篇文章中連續兩句話,或分別來自不同文章的兩句話,用序列表示來預測是否是連續的兩句話(二分類)。
下面先來講解一下BERT用到的損失函數,然后再講解以上兩個學習任務。
2.1、損失函數
BERT中主要使用到了回歸(regression)和分類(classification)損失函數。
回歸任務常用均方誤差(mean square error,MSE)作為損失函數,而分類任務一般用交叉熵(cross entropy,CE)作為損失函數。
2.1.1、均方誤差損失函數
均方誤差可以度量連續的預測值和真實值之間的差異。
假設\(\hat{y_i}\)是預測值,\(y_i\)是真實值,\(i = 0, 1, ..., (N-1)\)是樣本的編號,總共有\(N\)個樣本,那么損失\(loss\)為:
有時為了方便求導,會對這個\(loss\)除以\(2\)。
由於學習任務是讓\(loss\)最小化,給\(loss\)乘以一個常量對學習任務是沒有影響的。
2.1.2、交叉熵損失函數
交叉熵理解起來略有些復雜。
2.1.2.1、熵
首先來看看什么是熵(entropy)。
熵是用來衡量隨機變量的不確定性的,熵越大,不確定性越大,就需要越多的信息來消除不確定性。
比如小明考試打小抄,對於答案簡短的題目,只需要簡單做個標記,而像背古詩這種的,就需要更多字來記錄了。
假設離散隨機變量\(X \sim P\),則
稱為\(x\)(\(X\)的某個取值)的信息量,單位是奈特(nat)或比特(bit),取決於對數是以\(e\)還是以\(2\)為底的。
\(H(X)\)為\(X\)(\(X\)的所有取值)的信息量的期望,即熵,單位同上。
另外,由於熵和概率分布有關,所以很多時候寫作某個概率分布的熵,而不是某個隨機變量的熵。
2.1.2.2、KL散度
KL散度(Kullback-Leibler divergence),也叫相對熵(relative entropy)。
假設有概率分布\(P\)、\(Q\),一般分別表示真實分布和預測分布,KL散度可以用來衡量兩個分布的差異。
KL散度定義為:
將公式稍加修改:
可以看出,KL散度實際上是用概率分布\(P\)來計算\(Q\)對\(P\)的信息量差的期望。
如果\(P = Q\),那么\(D_{KL}(P||Q) = 0\)。
另外可以證明,\(D_{KL}(P||Q) \ge 0\)總是成立,本文證明略。
如果把上式拆開,就是
其中,\(H(P, Q)\)稱為\(Q\)對\(P\)的交叉熵,\(H(P)\)是\(P\)的熵。
2.1.2.3、交叉熵
由於KL散度描述了兩個分布信息量差的期望,所以可以通過最小化KL散度來使得兩個分布接近。
而在一些學習任務中,比如分類任務,真實分布是固定的,即\(H(P)\)是固定的,所以最小化KL散度等價於最小化交叉熵。
交叉熵公式為:
單標簽分類任務中,假設\(\hat{y}_{ik} \in [0, 1]\)和\(y_{ik} \in \{0, 1\}\)分別是第\(i = 0, 1, ..., (N-1)\)個樣本第\(k = 0, 1, ..., (K-1)\)類的預測值和真實值,其中\(\sum_{k=0}^{K-1} \hat{y}_{ik} = \sum_{k=0}^{K-1} y_{ik} = 1\),\(N\)是樣本數量,\(K\)是類別數量,則損失\(loss\)為:
損失函數代碼如下:
代碼
# 損失函數之回歸
class LossRgrs(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
# 均方誤差損失函數
self.loss_fct = nn.MSELoss(*args, **kwargs)
def forward(self, logits, labels):
return self.loss_fct(logits.view(-1), labels.view(-1))
# 損失函數之分類
class LossCls(nn.Module):
def __init__(self, num_classes, *args, **kwargs):
super().__init__()
# 標簽的類別數量
self.num_classes = num_classes
# 交叉熵損失函數
self.loss_fct = nn.CrossEntropyLoss(*args, **kwargs)
def forward(self, logits, labels):
return self.loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
# 損失函數之回歸或分類
class LossRgrsCls(nn.Module):
def __init__(self, num_classes, *args, **kwargs):
super().__init__()
self.loss_fct = (LossRgrs(*args, **kwargs) if not num_classes<=1
else LossCls(num_classes, *args, **kwargs))
def forward(self, logits, labels):
return self.loss_fct(logits, labels)
其中,
num_classes
是標簽的類別數量,BERT應用到序列分類任務時,如果類別數量=1則為回歸任務,否則為分類任務。
注意:無論是回歸還是分類任務,模型輸出表示后,都需要將表示轉化成預測值,一般是在模型最后通過一個線性回歸或分類器(其實也是一個線性變換)來實現,回歸任務得到的就是預測值,然后直接輸入MSE損失函數計算損失就可以了;而分類任務得到的是對數幾率(logit),還要用softmax函數轉化成概率,再通過CE損失函數計算損失,而PyTorch中softmax和CE封裝在一起了,所以直接輸入對數幾率就可以了。
2.2、遮蓋的語言模型
語言模型(language model,LM)是對語言(字符串)進行數學建模(表示)。
傳統的語言模型包括離散的和連續的(分布式的),離散的最經典的是詞袋(bag of words,BOW)模型和N元文法(N-gram)模型,連續的包括Word2Vec等,這些本文就不細說了。
然而到了DL時代,語言模型就是想個辦法讓神經網絡學習序列的概率分布。
MLM采用了降噪自編碼器(denoising autoencoder,DAE)的思想,簡單來說,就是在輸入數據中加噪聲,輸入神經網絡后再讓神經網絡恢復出原本無噪聲的數據,從而讓模型學習到了聯想能力,即輸入數據的概率分布。
具體來說,MLM將序列中的標記隨機替換成[MASK]
標記,例如
I ' m repair ##ing immortal ##s .
這句話,修改成
I ' m repair [MASK] immortal ##s .
如果模型可以成功預測出[MASK]
對應的原標記是##ing
,那么就可以認為模型學到了現在進行時要加ing的知識。
另外在計算損失的時候,是所有標記都要參與計算的。
2.3、下一句預測
NSP是為了讓模型學會表示句子連貫性等較為深層次的語言特征而設計的。
具體來說,首先看如下例子(來自transformers庫的示例):
This text is included to make sure ...
Text should be one-sentence-per-line ...
This sample text is public domain ...
The rain had only ceased with the gray streaks ...
Indeed, it was recorded in Blazing Star that ...
Possibly this may have been the reason ...
"Cass" Beard had risen early that morning ...
A leak in his cabin roof ...
The fountain of classic wisdom ...
As the ancient sage ...
From my youth I felt in me a soul ...
She revealed to me the glorious fact ...
A fallen star, I am, sir ...
其中,每一行都是一個句子(原句子太長,所以省略了一部分),不同的文章用空行來隔開。
如果選擇來自同一篇文章連續的兩個句子:
This text is included to make sure ... ||| Text should be one-sentence-per-line ...
則NSP的標簽為1;如果選擇來自不同文章的兩個句子:
This text is included to make sure ... ||| The fountain of classic wisdom ...
則NSP的標簽為0。
另外無論是MLM還是NSP,BERT預訓練的數據是在訓練之前靜態生成好的。
預訓練代碼如下:
代碼
# BERT之預訓練
class BertForPreTrain(BertPreTrainedModel):
# noinspection PyUnresolvedReferences
def __init__(self, config):
super().__init__(config)
self.config = config
# 主模型
self.bert = BertModel(config)
self.linear = nn.Linear(config.hidden_size, config.hidden_size)
self.act_fct = F.gelu
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# 標記線性分類器
self.cls = nn.Linear(config.hidden_size, config.vocab_size)
# 句子關系線性分類器
self.nsp_cls = nn.Linear(config.hidden_size, 2)
# 標記分類損失函數
self.loss_fct = LossCls(config.vocab_size)
# 句子關系分類損失函數
self.nsp_loss_fct = LossCls(2)
self.init_weights()
def get_output_embeddings(self):
return self.cls
def forward(self,
tok_ids, # 標記編碼(batch_size * seq_length)
pos_ids=None, # 位置編碼(batch_size * seq_length)
sent_pos_ids=None, # 句子位置編碼(batch_size * seq_length)
att_masks=None, # 注意力掩碼(batch_size * seq_length)
mlm_labels=None, # MLM標記標簽(batch_size * seq_length)
nsp_labels=None, # NSP句子關系標簽(batch_size)
):
outputs, pooled_outputs = self.bert(
tok_ids,
pos_ids=pos_ids,
sent_pos_ids=sent_pos_ids,
att_masks=att_masks,
)
outputs = self.linear(outputs)
outputs = self.act_fct(outputs)
outputs = self.layer_norm(outputs)
logits = self.cls(outputs)
nsp_logits = self.nsp_cls(pooled_outputs)
if mlm_labels is None and nsp_labels is None:
return (
logits, # 標記對數幾率(batch_size * seq_length * vocab_size)
nsp_logits, # 句子關系對數幾率(batch_size * 2)
)
loss = 0
if mlm_labels is not None:
loss = loss + self.loss_fct(logits, mlm_labels)
if nsp_labels is not None:
loss = loss + self.nsp_loss_fct(nsp_logits, nsp_labels)
return loss
后記
本文詳細地介紹了BERT預訓練,BERT預訓練是BERT有出色性能的關鍵,其中所使用的學習任務也是BERT的一大亮點。
后續一篇文章會介紹BERT下游任務相關。
從BERT預訓練的實現中可以發現,BERT巧妙地充分利用了主模型輸出的標記表示和序列表示,並分別學習標記分布概率和句子連貫性,並且運用了DAE的思想,以及兩種學習任務都可以通過無監督的方式實現。
然而后續的一些研究也對BERT提出了批評,例如采用MLM學習,預訓練時訓練數據中有[MASK]
標記,而微調時沒有這個標記,這就導致預訓練和微調的數據分布不一致;NSP並不能使模型學習到句子連貫性特征,因為來自不同文章的句子可能主題(topic)不一樣,NSP最終可能只學習了主題特征,而主題特征是文本中的淺層次特征,應該改為來自同一篇文章連續的或不連續的兩句話作為改良版的NSP訓練數據;MLM是一種生成式任務,生成式任務也可以看成分類任務,只不過一個類是詞匯表里的一個標記,詞匯表往往比較大,所以類別往往很多,計算損失之前要將表示轉化成長度為詞匯表長度的對數幾率向量,這個計算量是比較大的,如果模型又很大,那么整個學習任務對算力要求就會很高;BERT在計算每個標記的標簽時,是獨立計算的,即認為標記之間的標簽是相互獨立的,這往往不符合實際,所以其實BERT對標記分類(序列標注)任務的效果不是非常好。