論文筆記之:SeqGAN: Sequence generative adversarial nets with policy gradient


 SeqGAN: Sequence generative adversarial nets with policy gradient 

AAAI-2017 

 

Paperhttps://arxiv.org/abs/1609.05473 

Offical Tensorflow Codehttps://github.com/LantaoYu/SeqGAN 

PyTorch Codehttps://github.com/suragnair/seqGAN 

 

Introduction : 

產生序列模擬數據來模仿 real data 是無監督學習中非常重要的課題之一。最近, RNN/LSTM 框架在文本生成上取得了非常好的效果,最常見的訓練方法是:給定上一個 token,推測當前 token 的最大化似然概率。但是最大似然方法容易受到 “exposure bias” 的干擾:the model generates a sequence iteratively and predicts next token conditioned on its previously predicted ones that may be never observed in the training data。這種 training 和 inference 之間的差異可以招致 accumulatively,隨着 sequence 的累計,將會隨着 sequence 的增長,變得 prominent。為了解決這個問題,Bengio 在 2015 年提出了 schedule sampling (SS) 的方法,但是又有人說這種方法在某些情況下也會失效。另一個可能的解決方案(the training/inference discrepancy problem)是:在整個產生的序列上構建損失函數,而不是每一個翻譯(to build the loss function on the entire generated sequence instead of each trainsition)。但是,在許多其他的應用上,如:poem generation 和 chatbot,一個 task specific loss 無法直接准確的用來評價產生的序列。

 

GAN 是最近比較熱門的研究課題,已經廣泛的應用於 CV 的許多課題上,但是,不幸的是,直接用 GAN 來產生 sequence 有兩個問題:

(1),GAN 被設計用來產生 real-valued, continuous data,但是在直接產生 離散的 tokens 的序列,是有問題的,如:text。The reason is that in GANs, the generator starts with random sampling first and then a determistic transform, govermented by the model parameters. As such, the gradient of the loss from D w.r.t. the outputs by G is used to guide the generative model G (paramters) to slightly change the generated value to make it more realistic. 但是,如果基於離散的 tokens 產生的數據,從 D 的 loss 得到的 “slight change” 卻不是很有道理,因為可能根本不存在這樣的 token 使得這一改變有意義(因為 字典空間是有效的)。

(2),GAN 僅僅可以提供 score/loss 給整個的 sequence,而對於部分產生的序列,卻無法判斷目前已經有多好了。(GAN can only give the score/loss for an entire sequence when it has been generated; for a partially generated sequence, it is non-trivial to balance how good as it is now and the future score as the entire sequence. )

  

本文提出一種思路來解決上述問題,將 序列產生問題 看做是 序列決策問題(consider the sequence generation procedure as a sequential decision making problem)。產生器 被認為是 RL 當中的 agent;狀態是 目前已經產生的 tokens,動作是 下一步需要產生的 token。不像 Bahdanau et al. 2016 提出的方法那樣需要 a task specific sequence score, such as BLEU in machine translation,為了給出獎勵,我們用 discriminator 來評價 sequence,並且反饋評價來引導 generative model 的學習。為了解決 當輸出是離散的,梯度無法回傳給 generative model 的情況,我們將 generative model 看做是 stochastic parameterized policy。在我們的策略梯度,我們采用 MC 搜索來近似 the state-action value。我們直接用 policy gradient 來訓練 policy,很自然的就避免了傳統 GAN 中,離散數據的微分困難問題(the differentiation difficulty for discrete data in a conventional GAN)。

 

Sequence Generative Adversarial Nets :   

  

As illustrated in Figure 1, the discriminative model Dφ is trained by providing positive examples from the real sequence data and negative examples from the synthetic sequences generated from the generative model Gθ. At the same time, the generative model Gθ is updated by employing a policy gradient and MC search on the basis of the expected end reward received from the discriminative model Dφ. The reward is estimated by the likelihood that it would fool the discriminative model Dφ. The specific formulation is given in the next subsection. 

 

SeqGAN via PolicyGradient

Following (Sutton et al. 1999), when there is no intermediate reward, the objective of the generator model (policy) Gθ(yt|Y1:t−1) is to generate a sequence from the start state s0 to maximize its expected end reward: 

  

其中,RT 是整個序列的獎勵,獎勵來自於 判別器 Dφ。QGθ Dφ(s,a) is the action-value function of a sequence, i.e. the expected accumulative reward starting from state s, taking action a, and then following policy Gθ. 目標函數的合理性應該是: 從給定的初始狀態,產生器的目標是產生一個序列,使得 discriminator 認為是真的。

  

下一個問題就是:如何如何預測 the action-value function。本文當中,我們采用 REINFORCE algorithm,consider the estimated probability of being real by the discriminator D as the reward。意思是說,如果 判別器 D 認為給定的 fake sequence 是真的,其概率記為 reward,此時:概率越高,reward 越大,這兩者是成正比例關系的。正式的來說,我們有:

  

然而,這個 discriminator 僅僅提供了一個 reward 給一個已經結束的 sequence。因為我們實際上關心的是長期的匯報,在每一個時間步驟,我們不但應該考慮到 previous tokens 的擬合程度,也考慮到 the resulted future outcome。就像是下棋的游戲,玩家有時會放棄即可的獎賞,而為了得到更加長遠的獎勵。所以,為了評價 the action-value for an intermediate state,我們采用 MC search with a roll-out policy to sample the unkown last T-1 tokens。我們表示一個 N-time 的 MC search 為:

  

其中,Y^n_{1:t} ={y1, ... , yt} and Y^n_{t+1:T} is sampled based on the roll-out policy and the current state。在我們的實驗當中,$G\beta$ 也設置為 the generator。為了降低 variance,並且得到更加精確地  action value 的估計值,我們運行 the roll-out policy starting from current state 直到 序列的結束,N times,以得到一批輸出樣本。所以,我們有:

  

其中,我們看到 當沒有即可獎賞的時候,該函數被迭代的定義為:the next-state value starting from state s' = Y1:t and rolling out to the end。

  

利用 判別器 D 作為獎賞函數的一個函數是:it can be dynamically updated to further improve the generative model interatively(為了進一步的提升產生式模型,它可以被動態的更新)。一旦我們有了一筆新的 更加 realistic 的產生的序列,我們應該重新訓練 the discriminator model as follows:

  

每次當一個新的判別式模型已經被訓練完畢的時候,我們已經准備好來更新 generator。所提出的 基於策略的方法依賴於優化一個參數化的策略,來直接最大化 the long-term reward。目標函數 J 的梯度可以寫為:

  

上述形式是由於 the deterministic state transition and zero intermediate rewards。利用 likelihood ratio,我們構建一種 unbiased estimation of Eq.(6) : 

  

其中,$Y_{1:t}$ 是觀察到的 intermediate state sampled from $G\theta$。因為期望 E[*] 可以通過采樣的方法進行估計,我們然后更新產生器的參數:

  

其中,$\alpha$ 代表了對應的時刻 h-th step 的學習率。

  

整體的算法流程如下圖所示:

  

本文首先用 最大似然估計的方法進行預訓練 產生器 G,然后用 迭代的進行 G, D 的訓練。

  

然后就是對 G 和 D 的具體結構進行了解釋:

The Generative Model for Sequence:

用 LSTM 來編碼 sentences,然后將其映射到 下一個時刻 token 的概率分布。

The Discriminative Model for Sequence

此處的判別器,作者利用 CNN 的方法來進行判別。作者首先將 Word 轉為 vector,然后一句話弄成了一個 matrix,然后用多個卷積核,進行特征提取。為了提升精度,作者也加了 highway architecture based on the pooled feature maps. 最后,添加了 fc layer 以及 sigmoid activation 來輸出 給定的序列為真的概率(to output the probabiltiy that the input sequence is real)。優化的目標是:最小化 the groundtruth labels 和 the predicted probability 之間的 cross entropy loss。

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM