基於seq2seq文本生成的解碼/采樣策略


基於seq2seq文本生成的解碼/采樣策略


基於Seq2Seq模型的文本生成有各種不同的decoding strategy。文本生成中的decoding strategy主要可以分為兩大類:

  • Argmax Decoding: 主要包括beam search, class-factored softmax等
  • Stochastic Decoding: 主要包括temperature sampling, top-k sampling等。

在Seq2Seq模型中,RNN Encoder對輸入句子進行編碼,生成一個大小固定的hidden state \(h_c\);基於輸入句子的hidden state \(h_c\) 和先前生成的第1到t-1個詞\(x_{1:t-1}\),RNN Decoder會生成當前第t個詞的hidden state \(h_t\) ,最后通過softmax函數得到第t個詞 \(x_t\) 的vocabulary probability distribution \(P(x|x_{1:t-1})\)

兩類decoding strategy的主要區別就在於,如何從vocabulary probability distribution \(P(x|x_{1:t-1})\)中選取一個詞 \(x_t\)

  • Argmax Decoding的做法是選擇詞表中probability最大的詞,即\(x_t=argmax\quad P(x|x_{1:t-1})\) ;
  • Stochastic Decoding則是基於概率分布\(P(x|x_{1:t-1})\) 隨機sample一個詞 \(x_t\),即 \(x_t \sim P(x|x_{1:t-1})\)

在做seq predcition時,需要根據假設模型每個時刻softmax的輸出概率來sample單詞,合適的sample方法可能會獲得更有效的結果。

1. 貪婪采樣

  1. Greedy Search

    核心思想:每一步取當前最大可能性的結果,作為最終結果。

    具體方法:獲得新生成的詞是vocab中各個詞的概率,取argmax作為需要生成的詞向量索引,繼而生成后一個詞。

  2. Beam Search

    核心思想: beam search嘗試在廣度優先基礎上進行進行搜索空間的優化(類似於剪枝)達到減少內存消耗的目的。

    具體方法:在decoding的每個步驟,我們都保留着 top K 個可能的候選單詞,然后到了下一個步驟的時候,我們對這 K 個單詞都做下一步 decoding,分別選出 top K,然后對這 K^2 個候選句子再挑選出 top K 個句子。以此類推一直到 decoding 結束為止。當然 Beam Search 本質上也是一個 greedy decoding 的方法,所以我們無法保證自己一定可以得到最好的 decoding 結果。

Greedy Search和Beam Search存在的問題:

  1. 容易出現重復的、可預測的詞;
  2. 句子/語言的連貫性差。

2. 隨機采樣

核心思想: 根據單詞的概率分布隨機采樣。

  1. Temperature Sampling:

    具體方法:在softmax中引入一個temperature來改變vocabulary probability distribution,使其更偏向high probability words:

    \[P(x|x_{1:t-1})=\frac{exp(u_t/temperature)}{\sum_{t'}exp(u_{t'}/temperature)},temperature\in[0,1) \]

    另一種表示:假設\(p(x)\)為模型輸出的原始分布,給定一個 temperature 值,將按照下列方法對原始概率分布(即模型的 softmax 輸出) 進行重新加權,計算得到一個新的概率分布。

    \[\pi(x_{k})=\frac{e^{log(p(x_k))/temperature}} {\sum_{i=1}^{n}e^{log(p(x_i))/temperature}},temperature\in[0,1) \]

    \(temperature \to 0\),就變成greedy search;當\(temperature \to \infty\),就變成均勻采樣(uniform sampling)。詳見論文:The Curious Case of Neural Text Degeneration

  2. Top-k Sampling:

    可以緩解生成罕見單詞的問題。比如說,我們可以每次只在概率最高的50個單詞中按照概率分布做采樣。我只保留top-k個probability的單詞,然后在這些單詞中根據概率做sampling。

    核心思想:對概率進行降序排序,然后對第k個位置之后的概率轉換為0。

    具體方法:在decoding過程中,從 \(P(x|x_{1:t-1})\) 中選取probability最高的前k個tokens,把它們的probability加總得到 \(p'=\sum P(x|x_{1:t-1})\) ,然后將 \(P(x|x_{1:t-1})\) 調整為 \(P'(x|x_{1:t-1})=P(x|x_{1:t-1})/p'\) ,其中 \(x\in V^{(k)}\)! ,最后從 \(P'(x|x_{1:t-1})\) 中sample一個token作為output token。詳見論文:Hierarchical Neural Story Generation

    但Top-k Sampling存在的問題是,常數k是提前給定的值,對於長短大小不一,語境不同的句子,我們可能有時需要比k更多的tokens。

  3. Top-p Sampling (Nucleus Sampling ):

    核心思想:通過對概率分布進行累加,然后當累加的值超過設定的閾值p,則對之后的概率進行置0。

    具體方法:提出了Top-p Sampling來解決Top-k Sampling的問題,基於Top-k Sampling,它將 \(p'=\sum P(x|x_{1:t-1})\) 設為一個提前定義好的常數\(p'\in(0,1)\) ,而selected tokens根據句子history distribution的變化而有所不同。詳見論文:The Curious Case of Neural Text Degeneration

    本質上Top-p Sampling和Top-k Sampling都是從truncated vocabulary distribution中sample token,區別在於置信區間的選擇。

隨機采樣存在的問題:

  1. 生成的句子容易不連貫,上下文比較矛盾。
  2. 容易生成奇怪的句子,出現罕見詞。

3. 參考

LSTM文本生成:《Python深度學習》第8章第1節:8.1 使用LSTM生成文本P228-P234。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM