基於seq2seq文本生成的解碼/采樣策略
基於Seq2Seq模型的文本生成有各種不同的decoding strategy。文本生成中的decoding strategy主要可以分為兩大類:
- Argmax Decoding: 主要包括beam search, class-factored softmax等
- Stochastic Decoding: 主要包括temperature sampling, top-k sampling等。
在Seq2Seq模型中,RNN Encoder對輸入句子進行編碼,生成一個大小固定的hidden state \(h_c\);基於輸入句子的hidden state \(h_c\) 和先前生成的第1到t-1個詞\(x_{1:t-1}\),RNN Decoder會生成當前第t個詞的hidden state \(h_t\) ,最后通過softmax函數得到第t個詞 \(x_t\) 的vocabulary probability distribution \(P(x|x_{1:t-1})\)。
兩類decoding strategy的主要區別就在於,如何從vocabulary probability distribution \(P(x|x_{1:t-1})\)中選取一個詞 \(x_t\) :
- Argmax Decoding的做法是選擇詞表中probability最大的詞,即\(x_t=argmax\quad P(x|x_{1:t-1})\) ;
- Stochastic Decoding則是基於概率分布\(P(x|x_{1:t-1})\) 隨機sample一個詞 \(x_t\),即 \(x_t \sim P(x|x_{1:t-1})\) 。
在做seq predcition時,需要根據假設模型每個時刻softmax的輸出概率來sample單詞,合適的sample方法可能會獲得更有效的結果。
1. 貪婪采樣
-
Greedy Search
核心思想:每一步取當前最大可能性的結果,作為最終結果。
具體方法:獲得新生成的詞是vocab中各個詞的概率,取argmax作為需要生成的詞向量索引,繼而生成后一個詞。
-
Beam Search
核心思想: beam search嘗試在廣度優先基礎上進行進行搜索空間的優化(類似於剪枝)達到減少內存消耗的目的。
具體方法:在decoding的每個步驟,我們都保留着 top K 個可能的候選單詞,然后到了下一個步驟的時候,我們對這 K 個單詞都做下一步 decoding,分別選出 top K,然后對這 K^2 個候選句子再挑選出 top K 個句子。以此類推一直到 decoding 結束為止。當然 Beam Search 本質上也是一個 greedy decoding 的方法,所以我們無法保證自己一定可以得到最好的 decoding 結果。
Greedy Search和Beam Search存在的問題:
- 容易出現重復的、可預測的詞;
- 句子/語言的連貫性差。
2. 隨機采樣
核心思想: 根據單詞的概率分布隨機采樣。
-
Temperature Sampling:
具體方法:在softmax中引入一個temperature來改變vocabulary probability distribution,使其更偏向high probability words:
\[P(x|x_{1:t-1})=\frac{exp(u_t/temperature)}{\sum_{t'}exp(u_{t'}/temperature)},temperature\in[0,1) \]另一種表示:假設\(p(x)\)為模型輸出的原始分布,給定一個 temperature 值,將按照下列方法對原始概率分布(即模型的 softmax 輸出) 進行重新加權,計算得到一個新的概率分布。
\[\pi(x_{k})=\frac{e^{log(p(x_k))/temperature}} {\sum_{i=1}^{n}e^{log(p(x_i))/temperature}},temperature\in[0,1) \]當\(temperature \to 0\),就變成greedy search;當\(temperature \to \infty\),就變成均勻采樣(uniform sampling)。詳見論文:The Curious Case of Neural Text Degeneration
-
Top-k Sampling:
可以緩解生成罕見單詞的問題。比如說,我們可以每次只在概率最高的50個單詞中按照概率分布做采樣。我只保留top-k個probability的單詞,然后在這些單詞中根據概率做sampling。
核心思想:對概率進行降序排序,然后對第k個位置之后的概率轉換為0。
具體方法:在decoding過程中,從 \(P(x|x_{1:t-1})\) 中選取probability最高的前k個tokens,把它們的probability加總得到 \(p'=\sum P(x|x_{1:t-1})\) ,然后將 \(P(x|x_{1:t-1})\) 調整為 \(P'(x|x_{1:t-1})=P(x|x_{1:t-1})/p'\) ,其中 \(x\in V^{(k)}\)! ,最后從 \(P'(x|x_{1:t-1})\) 中sample一個token作為output token。詳見論文:Hierarchical Neural Story Generation
但Top-k Sampling存在的問題是,常數k是提前給定的值,對於長短大小不一,語境不同的句子,我們可能有時需要比k更多的tokens。
-
Top-p Sampling (Nucleus Sampling ):
核心思想:通過對概率分布進行累加,然后當累加的值超過設定的閾值p,則對之后的概率進行置0。
具體方法:提出了Top-p Sampling來解決Top-k Sampling的問題,基於Top-k Sampling,它將 \(p'=\sum P(x|x_{1:t-1})\) 設為一個提前定義好的常數\(p'\in(0,1)\) ,而selected tokens根據句子history distribution的變化而有所不同。詳見論文:The Curious Case of Neural Text Degeneration
本質上Top-p Sampling和Top-k Sampling都是從truncated vocabulary distribution中sample token,區別在於置信區間的選擇。
隨機采樣存在的問題:
- 生成的句子容易不連貫,上下文比較矛盾。
- 容易生成奇怪的句子,出現罕見詞。
3. 參考
LSTM文本生成:《Python深度學習》第8章第1節:8.1 使用LSTM生成文本P228-P234。