Teacher forcing是什么?
RNN 存在兩種訓練模式(mode):
- free-running mode: 上一個state的輸出作為下一個state的輸入。
- teacher-forcing mode: 使用來自先驗時間步長的輸出作為輸入。
teacher forcing要解決什么問題?
常見的訓練RNN網絡的方式是free-running mode,即將上一個時間步的輸出作為下一個時間步的輸入。可能導致的問題:
- Slow convergence.
- Model instability.
- Poor skill.
訓練迭代過程早期的RNN預測能力非常弱,幾乎不能給出好的生成結果。如果某一個unit產生了垃圾結果,必然會影響后面一片unit的學習。錯誤結果會導致后續的學習都受到不好的影響,導致學習速度變慢,難以收斂。teacher forcing最初的motivation就是解決這個問題的。
使用teacher-forcing,在訓練過程中,模型會有較好的效果,但是在測試的時候因為不能得到ground truth的支持,存在訓練測試偏差,模型會變得脆弱。
什么是teacher forcing?
teacher-forcing 在訓練網絡過程中,每次不使用上一個state的輸出作為下一個state的輸入,而是直接使用訓練數據的標准答案(ground truth)的對應上一項作為下一個state的輸入。
Teacher Forcing工作原理: 在訓練過程的\(t\)時刻,使用訓練數據集的期望輸出或實際輸出: \(y(t)\), 作為下一時間步驟的輸入: \(x(t+1)\),而不是使用模型生成的輸出\(h(t)\)。
一個例子:訓練這樣一個模型,在給定序列中前一個單詞的情況下生成序列中的下一個單詞。
給定如下輸入序列:
Mary had a little lamb whose fleece was white as snow
首先,我們得給這個序列的首尾加上起止符號:
[START] Mary had a little lamb whose fleece was white as snow [END]
對比兩個訓練過程:
No. | Free-running: X | Free-running: \(\hat{y}\) | teacher-forcing: X | teacher-forcing: \(\hat{y}\) | teacher-forcing: Ground truth |
---|---|---|---|---|---|
1 | "[START]" | "a" | "[START]" | "a" | "Marry" |
2 | "[START]", "a" | ? | "[START]", "Marry" | ? | "had" |
3 | ... | ... | "[START]", "Marry", "had" | ? | "a" |
4 | "[START]", "Marry", "had", "a" | ? | "little" | ||
5 | ... | ... | ... | ||
free-running 下如果一開始生成"a",之后作為輸入來生成下一個單詞,模型就偏離正軌。因為生成的錯誤結果,會導致后續的學習都受到不好的影響,導致學習速度變慢,模型也變得不穩定。
而使用teacher-forcing,模型生成一個"a",可以在計算了error之后,丟棄這個輸出,把"Marry"作為后續的輸入。該模型將更正模型訓練過程中的統計屬性,更快地學會生成正確的序列。
teacher-forcing 有什么缺點?
teacher-forcing過於依賴ground truth數據,在訓練過程中,模型會有較好的效果,但是在測試的時候因為不能得到ground truth的支持,所以如果目前生成的序列在訓練過程中有很大不同,模型就會變得脆弱。
換言之,這種模型的cross-domain能力會更差,即如果測試數據集與訓練數據集來自不同的領域,模型的performance就會變差。
那有沒有解決這個限制的辦法呢?
teacher-forcing缺點的解決方法
beam search
在預測單詞這種離散值的輸出時,一種常用方法是:對詞表中每一個單詞的預測概率執行搜索,生成多個候選的輸出序列。
這個方法常用於機器翻譯(MT)等問題,以優化翻譯的輸出序列。
beam search是完成此任務應用最廣的方法,通過這種啟發式搜索(heuristic search),可減小模型學習階段performance與測試階段performance的差異。
curriculum learning
Curriculum Learning是Teacher Forcing的一個變種:一開始老師帶着學,后面慢慢放手讓學生自主學。
Curriculum Learning即有計划地學習:
- 使用一個概率\(p\)去選擇使用ground truth的輸出\(y(t)\)還是前一個時間步驟模型生成的輸出\(h(t)\)作為當前時間步驟的輸入\(x(t+1)\)。
- 這個概率\(p\)會隨着時間的推移而改變,稱為計划抽樣(scheduled sampling)。
- 訓練過程會從force learning開始,慢慢地降低在訓練階段輸入ground truth的頻率。
Further Reading
Papers
- A Learning Algorithm for Continually Running Fully Recurrent Neural Networks, 1989.
- Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks, 2015.
- Professor Forcing: A New Algorithm for Training Recurrent Networks, 2016.
Book
- Section 10.2.1, Teacher Forcing and Networks with Output Recurrence, Deep Learning, Ian Goodfellow, Yoshua Bengio, Aaron Courville, 2016.
問:在訓練中,將teacher forcing替換為使用解碼器在上一時間步的輸出作為解碼器在當前時間步的輸入,結果有什么變化嗎?