自然語言處理中的負樣本挖掘


自然語言處理中的負樣本挖掘 (分類與排序任務中如何選擇負樣本)

1 簡介

首先, 介紹下自然與處理中的分類任務和排序任務的基本定義和常見做法, 然后介紹負樣本在這兩個任務中的意義.

1.1 分類任務

輸入為一段文本, 輸出為這段文本的分類, 是自然語言處理最為常見,應用最為廣泛的任務. 意圖識別, 語義蘊含和情感分析都屬於該類任務.
深度學習沒有大火之前, 主要做法是手工特征+XGBoost(也可以是邏輯斯蒂), 效果的好壞主要看工程師對業務理解的程度和手工特征的質量.
有了深度學習后, 多是將文本映射為字詞向量, 然后通過CNN或LSTM做語義理解, 最后經過全連接層獲取分類結果.
BERT問世后, 分類變得更為無腦, 拿到CLS后接全連接層就行了. 當然這是基本做法, 實際使用必須根據業務需求對算法和訓練方式做相應修改.

1.2 排序任務

輸入為一組樣本, 輸出每個樣本的排名. 搜索引擎就主要使用了這類技術. 排序任務主要用到的技術是Learn To Rank(LTR), 隨着深度學習在圖像領域的光速落地, 出現了各種類似的研究方向, 比如 Metric Learning, Contrastive Learning, Representation Learning等. 其實這些研究方向本質都是比較類似的, 即讓模型學會度量樣本之間的距離. 大名鼎鼎的Triplet Loss本質上就是LTR中的Pair-Wise.
此類任務做法也是五花八門, 各種奇淫巧計, 有興趣的朋友可以專門去看資料. 我只說一下我在LTR中的若干經驗:

  1. 別把排序任務當分類, 這兩個在本質上是不一樣的
  2. 排序任務不管怎么折騰, 最后訓練的還是一個打分模型, 而分值只需用來比較大小, 不需要有實際意義(這個還要看具體使用的算法). 所以本人比較喜愛Triplet Loss, 在我看來它直接表現了排序任務的本質.
  3. 兩個樣本先表征再計算分值, 速度快效果差; 兩個樣本直接送入模型算分值,效果好, 速度慢.

1.3 負樣本的重要性

在上述兩個任務中, 都不可避免的使用到負樣本. 比如分類任務中的短文本相似度匹配, 語義不相關的文本對就是負樣本, 排序任務中, 排名靠后的樣本對就可以看作是負樣本. 正負樣本又有如下兩個類別:
Easy Example (簡單樣本): 即模型非常容易做出正確的判斷.
Hard Example (困難樣本): 即模型難以做出正確的判斷, 訓練時常會給模型帶來較大損失.
相較於簡單樣本, 困難樣本更有價值, 它有助於模型的快速學習到邊界, 也加快了模型收斂速度.
在實際情況中, 正樣本基本上是被確定好的, 也不太好再去擴充和修改. 但是負樣本有非常大的選擇空間, 比如在搜索任務中, 用戶點擊的Document理解為正樣本, 那么該頁面的其余文檔就全是負樣本了. 訓練模型時, 顯然不能全部采用這些負樣本, 因此Hard Negative Example Mining (選取困難負樣本) 就變得非常重要!

2 負樣本選擇方法

2.1 基於統計度量的負樣本選擇方法

計算候選負樣本的一些統計度量值並以此為標准選取負樣本. 比如在短文本匹配任務中, 選取和目標樣本集合相似度值較高的樣本作為負例, 這種類型的負例可以讓模型盡可能學習文本所表示的語義信息, 而不是簡單學習字面意思. 試想在一個語義匹配任務中, 所有負樣本都是隨機生成的, 毫無章法的漢字組合, 那么模型定能快速收斂, 然而在實際生產中毫無用處.
在搜索任務中, 可以使用TF-IDF, BM25等方法檢索出top-k作為負例, 注意要保證訓練數據和測試數據分布一致, 如果你的模型在整個搜索框架中需要為全量文檔打分排序, 那么除了top-k作為負例, 還需隨機采樣一些作為負例, 畢竟見多才能識廣.

2.2 基於模型的負樣本選擇方法

2.1小節的方法太過朴素, 該方法選出的負樣本未必就是最能主導模型梯度更新方向和大小的樣本. 因此一個簡單的做法就是用訓練好的模型預測所有的負樣本, 找出預測錯的或者產生較大loss的樣本作為優質負樣本, 然后再去訓練模型, 不斷迭代優化模型. 整體流程如下圖:

該方法邏輯上沒有問題, 使用效果也不錯, 但是最大的問題時時間消耗太過嚴重, 每訓練若干輪就要在所有負樣本上預測一次找到最有價值的負樣本, 這對於深度學習而言太耗時了. 所以本人有兩個解決方案, 一是不要用模型預測所有負樣本, 預測一部分負樣本就行, 這算是一個折中方案. 第二個方法是與此有些類似, 名字叫做OHEM(Online Hard Example Mining, 在線困難樣本挖掘), 在每次梯度更新前, 選取loss值較大的樣本進行梯度更新. 該方法選取負樣本是從一個Batch中選擇的, 自然節省了時間. 該方法是為目標檢測提出的, 在NLP領域能否適用還需要看實戰效果.

3 基於loss的改進

除了選擇優質負樣本, 還可以考慮在損失函數上做改進, 讓模型自動提高困難樣本的權重.

3.1 Focal Loss

Focal Loss對交叉熵損失函數進行了改進。該損失函數可以通過減少易分類樣本的權重,使得模型在訓練時更專注於難分類的樣本。
首先來看一下二分類上的交叉熵損失函數:

\[CELoss = -y_{true}log(y_{pred}) - (1-y_{true})log(1-y_{pred}) \]

簡單化簡后:

\[CELoss = \begin{cases} -log(y_{pred}), y_{true}=1 \\ -log(1-y_{pred}), y_{true}=0\end{cases} \]

這個公式相比大家非常熟悉了, 我就不再贅述.

Focal Loss對該損失函數做了簡單修改, 具體公式如下:

\[ FocalLoss = \begin{cases} -a(1-y_{pred})^\gamma log(y_{pred}), y_{true}=1 \\ -(1-a)(y_{pred})^\gamma log(1-y_{pred}), y_{true}=0\end{cases} \]

Focal Loss的作者建議alpha取0.25, gamma取2, 其實通過和CELoss的對比就可以發現, Focal Loss主要對損失值做了進一步的縮放, 使得難以區分的樣本會產生更大的損失值, 最終是模型的梯度大小和方向主要由難分樣本來決定.
下圖是Focal Loss的測試結果, 效果還是令人滿意的.

3.2 Gradient Harmonizing Mechanism(GHM)

Focal Loss 會過度關注難分類樣本, 真實數據集往往會有很多噪音, 這容易導致模型過擬合.
與Focal Loss類似, GHM同樣會對損失值做一個抑制,只不過這個抑制是根據樣本數量來的, 梯度小的樣本數比較大,那就給他們乘上一個小系數,梯度大的樣本少乘以一個大的系數。不過這個系數不是靠自己調的,而是根據樣本的梯度分布來確定的. 具體公式可以參考原論文, 這里就不再貼公式了, 到目前未知還沒有聽說有用在NLP上的, 效果如何也要看實戰了.

4 訓練方式的改進

第二節和第三節都是針對優質樣本的探索, 其實可以換個角度思考, 就讓模型見到足夠多的負例, 只要硬件足夠強大, 你可以模型學習所有的負例, 頗有一種大力出奇跡的感覺. 因此可以嘗試改變訓練方式讓模型見到更多的負樣本, 此類方法很想LTR中的List-Wise.

4.1 一個Batch的樣本都作為負樣本

以短文本匹配任務為例, 假設輸入的一個batch是 (a1,b1), (a2,b2), (a3,b3), (a4,b4), 每一對樣本都是相似的, 可以把剩余其他樣本都作為負樣本, 比如對於a1而言b2, b3, b4都是負樣本, 這樣可以在沒有增大batch_size大小的情況下讓模型學到更多負樣本. 具體損失函數既可以是二分類的交叉熵損失函數, 也可以當成多分類損失函數來優化, 即把(a1,b1), (a1,b2), (a1,b3), (a1,b4)的相似度值作為logit送入多分類損失函數, 在這個例子中, 是四分類任務, 標簽是0(a1和b1的相似度值應該最大).
如果要使用該方法記得確保每個batch中其他樣本可以作為負樣本! 如果a1和b4也是相似, 那么訓練數據就存在噪聲了.
此種類型的方法在NLP領域已有使用, 效果還算可以, 有興趣的朋友可以在自己的任務上試一試. SimBERT(一個通用句向量編碼器)就是基於此類方法訓練的, 美團在微軟閱讀理解比賽取得第一名的算法也和此類似. "沒有增大batch_size大小的情況下讓模型學到更多負樣本", 這句話算是第四節的核心了

4.2 MOCO

MOCO方法算是一種經典的圖像預訓練方法, 它是自監督的, 基於表征學習, 類似於NLP里的Bert. MOCO的一大創新就是讓模型一次見到海量負例(以萬為單位), 那么存在的問題就是計算量會爆炸, 每一個step 都要多計算上萬次, 就像是把4.1中的batch_size變成一萬一樣, 時間是吃不消的. 為了解決這個問題, MOCO創建了負例隊列, 不會一次計算所有負例的表征, 而是緩慢更新這個隊列中的一部分負例, 即先進隊列的先被更新, 這樣是減少了計算量, 但是在計算損失函數時使用到的大部分負例表征是過時的, 因此MOCO使用動量方法來更新模型參數, 具體過程可以看下面的偽代碼:

# f_q, f_k: encoder networks for query and key
# queue: dictionary as a queue of K keys (CxK)
# m: momentum
# t: temperature

f_k.params = f_q.params # initialize
for x in loader: # load a minibatch x with N samples
	x_q = aug(x) # a randomly augmented version
	x_k = aug(x) # another randomly augmented version
	q = f_q.forward(x_q) # queries: NxC
	k = f_k.forward(x_k) # keys: NxC
	k = k.detach() # no gradient to keys
	# positive logits: Nx1
	l_pos = bmm(q.view(N,1,C), k.view(N,C,1))
	# negative logits: NxK
	l_neg = mm(q.view(N,C), queue.view(C,K))
	# logits: Nx(1+K)
	logits = cat([l_pos, l_neg], dim=1)
	# contrastive loss, Eqn.(1)
	labels = zeros(N) # positives are the 0-th
	loss = CrossEntropyLoss(logits/t, labels)
	# SGD update: query network
	loss.backward()
	update(f_q.params)
	# momentum update: key network
	f_k.params = m*f_k.params+(1-m)*f_q.params
	# update dictionary
	enqueue(queue, k) # enqueue the current minibatch
	dequeue(queue) # dequeue the earliest minibatch

logits = cat([l_pos, l_neg], dim=1)這一行代碼就是把正樣本的點積值和負樣本的點積值在最后一維進行拼接. f_k.params = m*f_k.params+(1-m)*f_q.params該行代碼是對模型參數進行動量更新. 如果沒有一定的Pytorch基礎還是比較難看懂的, 建議大家去讀一讀論文以便加深理解.
下圖是MOCO的性能評估, 可以看出MOCO性能優於其他同類方法, 且負例隊列數量越大效果越好. 該方法本人已經在NLP領域相關任務上做過嘗試, 效果不錯, 有興趣的可以在自己任務上試一試.

5 總結

天下沒有免費的午餐, 很多方法只有親自試了才知道是否有效. 建議大家多去看測試結果, 基於真實數據分析思考算法的優化方向. 本文所講的方法也是拋磚引玉, 希望各位大佬可以貢獻更多的方法.
NLP被稱作是人工智能皇冠上的明珠, 但是截至目前未知還未看到這顆明珠大放異彩. 本文介紹的方法幾乎全部來源於圖像領域, 想一想還是挺失望的. 學科之間思想方法都是相通的, 希望在以后能看到更多在NLP研究上的創新.

最后感謝各位閱讀, 希望能幫到你們.

文章可以轉載, 但請注明出處:

6 參考文獻

  1. OHEM: Training Region-based Object Detectors with Online Hard
    Example Mining
  2. S-OHEM: Stratified Online Hard Example Mining for Object Detection
    S-OHEM
  3. A-Fast-RCNN: Hard positive generation via adversary for object
    detection
  4. Focal Loss: Focal Loss for Dense Object Detection
  5. GHM: Gradient Harmonized Single-stage Detector
  6. MOCO: Momentum Contrast for Unsupervised Visual Representation
    Learning
  7. https://zhuanlan.zhihu.com/p/60612064
  8. https://www.cnblogs.com/rookiechenv587/p/11973078.html


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM