WGAN的改進點和實操


包含三部分:1、WGAN改進點  2、代碼修改  3、訓練心得

一、WGAN的改進部分:

  • 判別器最后一層去掉sigmoid    (相當於最后一層做了一個y = x的激活)
  • 生成器和判別器的loss不取log
  • 每次更新判別器的參數之后把它們的絕對值截斷到不超過一個固定常數c
  • 不要用基於動量的優化算法(包括momentum和Adam),推薦RMSProp,SGD也行        (這部分很玄學)

去掉sigmoid會出現什么問題?

優點: 去掉sigmoid 只要二者存在差值就會學習讓他們盡量小

缺點:去掉sigmoid 判別器的輸出會到無窮大 生成器也會到無窮大(只要二者的差值很小就滿足條件)無法優化。

                         (公式1)

 

如何解決(上述)無法優化問題(loss可能一直上升)?

這就是WGAN的第三個改進點。(每次更新判別器的參數之后把它們的絕對值截斷到不超過一個固定常數c

                   (公式2)(作者用這個公式來表達,證明過程再論文附錄中)

 

詳細解讀(這部分參看:https://blog.csdn.net/omnispace/article/details/54942668

分析

首先需要介紹一個概念——Lipschitz連續。它其實就是在一個連續函數f上面額外施加了一個限制,要求存在一個常數K\geq 0使得定義域內的任意兩個元素x_1x_2都滿足

|f(x_1) - f(x_2)| \leq K |x_1 - x_2|

此時稱函數f的Lipschitz常數為K

簡單理解,比如說f的定義域是實數集合,那上面的要求就等價於f的導函數絕對值不超過K(這里是導數概念(f(x1) - f(x2))/(x1-x2) 為導數)。再比如說\log (x)就不是Lipschitz連續,因為它的導函數沒有上界。Lipschitz連續條件限制了一個連續函數的最大局部變動幅度。

公式2的意思就是在要求函數f的Lipschitz常數||f||_L不超過K的條件下,對所有可能滿足條件的f取到\mathbb{E}_{x \sim P_r} [f(x)] - \mathbb{E}_{x \sim P_g} [f(x)]的上界,然后再除以K。特別地,我們可以用一組參數w來定義一系列可能的函數f_w,此時求解公式2可以近似變成求解如下形式

K \cdot W(P_r, P_g) \approx \max_{w: |f_w|_L \leq K} \mathbb{E}_{x \sim P_r} [f_w(x)] - \mathbb{E}_{x \sim P_g} [f_w(x)]                      (公式3)

再用上我們搞深度學習的人最熟悉的那一套,不就可以把f用一個帶參數w的神經網絡來表示嘛!由於神經網絡的擬合能力足夠強大,我們有理由相信,這樣定義出來的一系列f_w雖然無法囊括所有可能,但是也足以高度近似公式2要求的那個sup_{||f||_L \leq K}了。

最后,還不能忘了滿足公式3中||f_w||_L \leq K這個限制。我們其實不關心具體的K是多少,只要它不是正無窮就行,因為它只是會使得梯度變大K倍,並不會影響梯度的方向。所以作者采取了一個非常簡單的做法,就是限制神經網絡f_\theta的所有參數w_i的不超過某個范圍[-c, c],比如w_i \in [- 0.01, 0.01],此時關於輸入樣本x的導數\frac{\partial f_w}{\partial x}也不會超過某個范圍,所以一定存在某個不知道的常數K使得f_w的局部變動幅度不會超過它,Lipschitz連續條件得以滿足。具體在算法實現中,只需要每次更新完w后把它clip回這個范圍就可以了。

到此為止,我們可以構造一個含參數w、最后一層不是非線性激活層的判別器網絡f_w,在限制w不超過某個范圍的條件下,使得

L = \mathbb{E}_{x \sim P_r} [f_w(x)] - \mathbb{E}_{x \sim P_g} [f_w(x)]                       (公式4)

盡可能取到最大,此時L就會近似真實分布與生成分布之間的Wasserstein距離(忽略常數倍數K)。注意原始GAN的判別器做的是真假二分類任務,所以最后一層是sigmoid,但是現在WGAN中的判別器f_w做的是近似擬合Wasserstein距離,屬於回歸任務,所以要把最后一層的sigmoid拿掉。

接下來生成器要近似地最小化Wasserstein距離,可以最小化L,由於Wasserstein距離的優良性質,我們不需要擔心生成器梯度消失的問題。再考慮到L的第一項與生成器無關,就得到了WGAN的兩個loss。

 

二、代碼修改:

根據改進的四個部分來修改代碼(TF下):

加變量:

 

1 CLIP = [-0.01, 0.01]  #用來截斷w(第三個改進點)
2 CRITIC_NUM = 5     #權衡訓練次數  Discrimnator要訓練的比Genenrator多(5 次Discrimnator 一次 G

① 判別器最后一層去掉sigmoid

1 return tf.nn.sigmoid(h4), h4
2 替換后:
3 return h4, h4

② 生成器和判別器的loss不取log

 

原始的GAN loss為:

                       min GmaxD Exq(x)[logD(x)]+Ezp(z)[log(1D(G(z)))

去掉log為        min GmaxD    D(x) + 1D(G(z))

由於最大化D 我們在代碼中應該加 “-”     D loss:  minD   -(D(x) + 1D(G(z)))  

                                                               G loss   minG  D(G(z))

 

1 self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits, labels=tf.ones_like(self.D)))
2 self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.zeros_like(self.D_)))
3 self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.ones_like(self.D_))) \
4 self.d_loss = self.d_loss_real + self.d_loss_fake

 

修改D loss為:

1 self.d_loss_real = tf.reduce_mean(self.D_logits)
2 self.d_loss_fake = -tf.reduce_mean(self.D_logits_)
3 self.d_loss = -(self.d_loss_real + self.d_loss_fake)

修改G loss為:

1 self.g_loss = -tf.reduce_mean(self.D_logits_)

 

③ ④  每次更新判別器的參數之后把它們的絕對值截斷到不超過一個固定常數c(放到參數更新后)   修改優化器

原始:

1 d_optim = tf.train.AdamOptimizer(args.lr, beta1=args.beta1) \
2                           .minimize(self.d_loss, var_list=self.d_vars)
3 g_optim = tf.train.AdamOptimizer(args.lr, beta1=args.beta1) \
4                           .minimize(self.g_loss, var_list=self.g_vars)

修改為:

1 d_optim = tf.train.RMSPropOptimizer(args.lr, beta1=args.beta1) \
2             .minimize(self.d_loss, var_list=self.d_vars)
3 g_optim = tf.train.RMSPropOptimizer(args.lr, beta1=args.beta1) \
4             .minimize(self.g_loss, var_list=self.g_vars)
5 clip_d_op = [var.assign(tf.clip_by_value(var, CILP[0], CILP[1])) for var in self.d_vars]     #進行截斷

 

三、訓練心得:

一、權重

a. 調節Generator loss中GAN loss的權重
G loss和Gan loss在一個尺度上或者G loss比Gan loss大一個尺度。但是千萬不能讓Gan loss占主導地位, 這樣整個網絡權重會被帶偏。

二、訓練次數
b. 調節Generator和Discrimnator的訓練次數比
一般來說,Discrimnator要訓練的比Genenrator多。比如訓練五次Discrimnator,再訓練一次Genenrator(WGAN論文 是這么干的)。

三、學習率
c. 調節learning rate
這個學習速率不能過大。一般要比Genenrator的速率小一點。

四、優化器
d. Optimizer的選擇不能用基於動量法的
如Adam和momentum。可使用RMSProp或者SGD。

五、結構
e. Discrimnator的結構可以改變
如果用WGAN,判別器的最后一層需要去掉sigmoid。但是用原始的GAN,需要用sigmoid,因為其loss function里面需要取log,所以值必須在[0,1]。這里用的是鄧煒的critic模型當作判別器。之前twitter的論文里面的判別器即使去掉了sigmoid也不好訓練。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM