GAN在seq2seq中的應用 Application to Sequence Generation


Improving Supervised Seq-to-seq Model
有監督的 seq2seq ,比如機器翻譯、聊天機器人、語音辨識之類的 。

 

而 generator 其實就是典型的 seq2seq model ,可以把 GAN 應用到這個任務中。

 

RL(human feedback)
訓練目標是,最大化 expected reward。很大的不同是,並沒有事先給定的 label,而是人類來判斷,生成的 x 好還是不好。
 
簡單介紹一下 policy gradient。更新 encoder 和 generator 的參數來最大化 human 函數的輸出。最外層對所有可能的輸入 h 求和(weighted sum,因為不同的 h 有不同的采樣概率);對一個給定的 h,對所有的可能的 x 求和(因為同樣的 seq 輸入可能會產生不一樣的 seq 輸出);求和項為 R(h, x)*P_θ (x | h) ,表示給定一個 h 產生 x 的概率以及對應得到的 reward(整項合起來看,就是 reward 的期望)

 

用 sampling 后求平均來近似求期望:

 

 

但是 R_θ 近似后並沒有體現 θ(隱藏到 sampling 過程中去了),怎么算梯度?先對 P_θ (x | h) 求梯度,然后分子分母同乘 P_θ (x | h) ,而 grad(P_θ (x | h)) / P_θ (x | h) 就等於 grad(log P_θ (x | h)),所以就在 R_θ 原本的近似項上乘一個 grad(log P_θ (x | h))

 

如果是 positive 的 reward(R(hi, xi) > 0), 更新 θ 后  P_θ (xi | hi) 會增加;反之會減小(所以最好人類給的 reward 是有正有負的)

 

整個 implement 的過程就如下圖所示,注意每次更新 θ 后,都要重新 sampling

 

RL 的方法和之前所說的 seq2seq model (based on maximum likelihood)的區別

 

GAN(discriminator feedback)
不再是人給 feedback,而是 discriminator 給 feedback。

 

 

訓練流程。訓練 D 來分辨 <c, x> pair 到底是來自於 chatbot 還是人類的對話;訓練 G 來使得固定的 D 給來自 chatbot 的 (c', x~) 高分。

 

 

仔細想一下,訓練 G 的過程中是存在問題的,因為決定 LSTM 在每一個 time step 的 token 的時候實際上做了 sampling (或者取argmax),所以最后的 discriminator 的輸出的梯度傳不到 generator(不可微)。

 

怎么解決?

  1. Gumbel-softmax https://casmls.github.io/general/2017/02/01/GumbelSoftmax.html

  首先需要可以采樣,使得離散的概率分布有意義而不是只能取 argmax。對於 n 維概率向量 π,其對應的離散隨機變量 x π 添加 Gumbel 噪聲再采樣。
  x π  = argmax(log(π i) + G i)
  其中,G 是獨立同分布的標准 Gumbel 分布的隨機變量,cdf 為 F(x) = exp(-exp(-x))。為了要可微,用 softmax 代替 argmax(因為 argmax 不可微,所以光滑地逼近),G 可以通過 Gumbel 分布求逆,從均勻分布中生成 G i = -log(-log(U i)),U i ~ U(0, 1) 
  

 

 

  2. Continuous Input for Discriminator 

  避免 sampling 過程,直接把每一個 time step 的 word distribution 當作 discriminator 的輸入。

   

  這樣做有問題嗎?明顯有,real sentence 的 word distribution 就是每個詞 one-hot 的,而 generated sentence 的 word distribution 本質上就不會是 1-of-N,這樣 discriminator 很容易就能分辨了,而且判斷准則沒有在考慮語義了(直接看是不是 one-hot 就行了)。

  

 

  3. Reinforcement Learning

   

  把 discriminator 的 output 看作是 reward:

    • Update generator to increase discriminator = to get maximum reward    
    • Using the formulation of policy gradient, replace reward  R(c, x) with discriminator output D(c, x)
  
  和典型的 RL 不同的是,discriminator 參數是要 update 的,還是要輸入給 discriminator 現在 chatbot 產生的對話和人類的對話,訓練 discriminator 來分辨。
  

 

 

 
Unsupervised Seq-to-seq Model
 
Text Style Transfer
用 cycle GAN 來實現,訓練兩個 GAN,實現兩個 domain 的互相轉。仍舊要面對 generator 的輸出要 sampling 的情況,選擇上述第二種解決方案,就是連續化。直接用 word embedding 的向量。

 

 

也可以用映射到 common space 的方法,sampling 后離散化的問題,可以用一個新的技巧解決:把 decoder LSTM 的 hidden layer 當作 discriminator 的輸入,就是連續的了。

 
 
Unsupervised Abstractive Summarization
 
Unsupervised Translation

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM