總結一些常用的訓練 GANs 的方法


眾所周知,GANs 的訓練尤其困難,筆者自從跳入了 GANs 這個領域(坑),就一直在跟如何訓練 GANs 做「對抗訓練」,受啟發於 ganhacks,並結合自己的經驗記錄總結了一些常用的訓練 GANs 的方法,以備后用。

什么是 GANs?

GANs(Generative Adversarial Networks)可以說是一種強大的「萬能」數據分布擬合器,主要由一個生成器(generator)和判別器(discriminator)組成。生成器主要從一個低維度的數據分布中不斷擬合真實的高維數據分布,而判別器主要是為了區分數據是來源於真實數據還是生成器生成的數據,他們之間相互對抗,不斷學習,最終達到Nash均衡,即任何一方的改進都不會導致總體的收益增加,這個時候判別器再也無法區分是生成器生成的數據還是真實數據。

GANs 最初由 Ian Goodfellow [1] 於 2014 年提出,目前已經在圖像、語音、文字等方面得到廣泛研究和應用,特別是在圖像生成方面,可謂是遍地開花,例如圖像風格遷移(style transfer)、圖像修復(image inpainting)、超分辨率(super resolution)等。

GANs 出了什么問題?

GANs 通常被定義為一個 minimax 的過程:

其中 P_r 是真實數據分布,P_z 是隨機噪聲分布。乍一看這個目標函數,感覺有點相互矛盾,其實這就是 GANs 的精髓所在—— 對抗訓練。

在原始的 GANs 中,判別器要不斷的提高判別是非的能力,即盡可能的將真實樣本分類為正例,將生成樣本分類為負例,所以判別器需要優化如下損失函數:

作為對抗訓練,生成器需要不斷將生成數據分布拉到真實數據分布,Ian Goodfellow 首先提出了如下式的生成器損失函數:

由於在訓練初期階段,生成器的能力比較弱,判別器這時候也比較弱,但仍然可以足夠精准的區分生成樣本和真實樣本,這樣 D(x) 就非常接近1,導致 log(1-D(x)) 達到飽和,后續網絡就很難再調整過來。為了解決訓練初期階段飽和問題,作者提出了另外一個損失函數,即:

以上面這個兩個生成器目標函數為例,簡單地分析一下GAN模型存在的幾個問題:

Ian Goodfellow 論文里面已經給出,固定 G 的參數,我們得到最優的 D^*:

也就是說,只有當 P_r=P_g 時候,不管是真實樣本和生成樣本,判別器給出的概率都是 0.5,這個時候就無法區分樣本到底是來自於真實樣本還是來自於生成樣本,這是最理想的情況。

1. 對於第一種目標函數

在最優判別器下 D^* 下,我們給損失函數加上一個與 G 無關的項,(3) 式變成:

注意,該式子其實就是判別器的損失函數的相反數。

把最優判別器 D^* 帶入,可以得到:

到這里,我們就可以看清楚我們到底在優化什么東西了,在最優判別器的情況下,其實我們在優化兩個分布的 JS 散度。當然在訓練過程中,判別器一開始不是最優的,但是隨着訓練的進行,我們優化的目標也逐漸接近JS散度,而問題恰恰就出現在這個 JS 散度上面。一個直觀的解釋就是只要兩個分布之間的沒有重疊或者重疊部分可以忽略不計,那么大概率上我們優化的目標就變成了一個常數 -2log2,這種情況通過判別器傳遞給生成器的梯度就是零,也就是說,生成器不可能從判別器那里學到任何有用的東西,這也就導致了無法繼續學習。

Arjovsky [2] 以其精湛的數學技巧提供一個更嚴謹的一個數學推導(手動截圖原論文了)。

在 Theorm2.4 成立的情況下:

拋開上面這些文縐縐的數學表述,其實上面講的核心內容就是當兩個分布的支撐集是沒有交集的或者說是支撐集是低維的流形空間,隨着訓練的進行,判別器不斷接近最優判別器,會導致生成器的梯度處處都是為0。

2. 對於第二種目標函數

同樣在最優判別器下,優化 (4) 式等價優化如下

仔細盯着上面式子幾秒鍾,不難發現我們優化的目標是相互悖論的,因為 KL 散度和 JS 散度的符號相反,優化 KL 是把兩個分布拉近,但是優化 -JS 是把兩個分布推遠,這「一推一拉」就會導致梯度更新非常不穩定。此外,我們知道 KL 不是對稱的,對於生成器無法生成真實樣本的情況,KL 對 loss 的貢獻非常大,而對於生成器生成的樣本多樣性不足的時候,KL 對 loss 的貢獻非常小。

而 JS 是對稱的,不會改變 KL 的這種不公平的行為。這就解釋了我們經常在訓練階段經常看見兩種情況,一個是訓練 loss 抖動非常大,訓練不穩定;另外一個是即使達到了穩定訓練,生成器也大概率上只生成一些安全保險的樣本,這樣就會導致模型缺乏多樣性。

此外,在有監督的機器學習里面,經常會出現一些過擬合的情況,然而 GANs 也不例外。當生成器訓練得越來越好時候,生成的數據越接近於有限樣本集合里面的數據。特別是當訓練集里面包含有錯誤數據時候,判別器會過擬合到這些錯誤的數據,對於那些未見的數據,判別器就不能很好的指導生成器去生成可信的數據。這樣就會導致 GANs 的泛化能力比較差。

綜上所述,原始的 GANs 在訓練穩定性、模式多樣性以及模型泛化性能方面存在着或多或少的問題,后續學術上的工作大多也是基於此進行改進(填坑)。

訓練 GAN 的常用策略

上一節都是基於一些簡單的數學或者經驗的分析,但是根本原因目前沒有一個很好的理論來解釋;盡管理論上的缺陷,我們仍然可以從一些經驗中發現一些實用的 tricks,讓你的 GANs 不再難訓。這里列舉的一些 tricks 可能跟 ganhacks 里面的有些重復,更多的是補充,但是為了完整起見,部分也添加在這里。

1. model choice

如果你不知道選擇什么樣的模型,那就選擇 DCGAN[3] 或者 ResNet[4] 作為 base model。

2. input layer

假如你的輸入是一張圖片,將圖片數值歸一化到 [-1, 1];假如你的輸入是一個隨機噪聲的向量,最好是從 N(0, 1) 的正態分布里面采樣,不要從 U(0,1) 的均勻分布里采樣。

3. output layer

使用輸出通道為 3 的卷積作為最后一層,可以采用 1x1 或者 3x3 的 filters,有的論文也使用 9x9 的 filters。(注:ganhacks 推薦使用 tanh)

4. transposed convolution layer

在做 decode 的時候,盡量使用 upsample+conv2d 組合代替 transposed_conv2d,可以減少 checkerboard 的產生 [5];

在做超分辨率等任務上,可以采用 pixelshuffle [6]。在 tensorflow 里,可以用 tf.depth_to_sapce 來實現 pixelshuffle 操作。

5. convolution layer

由於筆者經常做圖像修復方向相關的工作,推薦使用 gated-conv2d [7]。

6. normalization

雖然在 resnet 里的標配是 BN,在分類任務上表現很好,但是圖像生成方面,推薦使用其他 normlization 方法,例如 parameterized 方法有 instance normalization [8]、layer normalization [9] 等,non-parameterized 方法推薦使用 pixel normalization [10]。假如你有選擇困難症,那就選擇大雜燴的 normalization 方法——switchable normalization [11]。

7. discriminator

想要生成更高清的圖像,推薦 multi-stage discriminator [10]。簡單的做法就是對於輸入圖片,把它下采樣(maxpooling)到不同 scale 的大小,輸入三個不同參數但結構相同的 discriminator。

8. minibatch discriminator

由於判別器是單獨處理每張圖片,沒有一個機制能告訴 discriminator 每張圖片之間要盡可能的不相似,這樣就會導致判別器會將所有圖片都 push 到一個看起來真實的點,缺乏多樣性。minibatch discriminator [22] 就是這樣這個機制,顯式地告訴 discriminator 每張圖片應該要不相似。在 tensorflow 中,一種實現 minibatch discriminator 方式如下:

上面是通過一個可學習的網絡來顯示度量每個樣本之間的相似度,PGGAN 里提出了一個更廉價的不需要學習的版本,即通過統計每個樣本特征每個像素點的標准差,然后取他們的平均,把這個平均值復制到與當前 feature map 一樣空間大小單通道,作為一個額外的 feature maps 拼接到原來的 feature maps 里,一個簡單的 tensorflow 實現如下:

9. GAN loss

除了第二節提到的原始 GANs 中提出的兩種 loss,還可以選擇 wgan loss [12]、hinge loss、lsgan loss [13]等。wgan loss 使用 Wasserstein 距離(推土機距離)來度量兩個分布之間的差異,lsgan 采用類似最小二乘法的思路設計損失函數,最后演變成用皮爾森卡方散度代替了原始 GAN 中的 JS 散度,hinge loss 是遷移了 SVM 里面的思想,在 SAGAN [14] 和 BigGAN [15] 等都是采用該損失函數。

ps: 我自己經常使用沒有 relu 的 hinge loss 版本。

10. other loss

  • perceptual loss [17]
  • style loss [18]
  • total variation loss [17]
  • l1 reconstruction loss

通常情況下,GAN loss 配合上面幾種 loss,效果會更好。

11. gradient penalty

Gradient penalty 首次在 wgan-gp 里面提出來的,記為 1-gp,目的是為了讓 discriminator 滿足 1-lipchitchz 連續,后續 Mescheder, Lars M. et al [19] 又提出了只針對正樣本或者負樣本進行梯度懲罰,記為 0-gp-sample。Thanh-Tung, Hoang et al [20] 提出了 0-gp,具有更好的訓練穩定性。三者的對比如下:

12. Spectral normalization [21]

譜歸一化是另外一個讓判別器滿足 1-lipchitchz 連續的利器,建議在判別器和生成器里同時使用。

ps: 在個人實踐中,它比梯度懲罰更有效。

13. one-size label smoothing [22]

平滑正樣本的 label,例如 label 1 變成 0.9-1.1 之間的隨機數,保持負樣本 label 仍然為 0。個人經驗表明這個 trick 能夠有效緩解訓練不穩定的現象,但是不能根本解決問題,假如模型不夠好的話,隨着訓練的進行,后期 loss 會飛。

14. add supervised labels

  • add labels
  • conditional batch normalization

15. instance noise (decay over time)

在原始 GAN 中,我們其實在優化兩個分布的 JS 散度,前面的推理表明在兩個分布的支撐集沒有交集或者支撐集是低維的流形空間,他們之間的 JS 散度大概率上是 0;而加入 instance noise 就是強行讓兩個分布的支撐集之間產生交集,這樣 JS 散度就不會為 0。新的 JS 散度變為:

16. TTUR [23]

在優化 G 的時候,我們默認是假定我們的 D 的判別能力是比當前的 G 的生成能力要好的,這樣 D 才能指導 G 朝更好的方向學習。通常的做法是先更新 D 的參數一次或者多次,然后再更新 G 的參數,TTUR 提出了一個更簡單的更新策略,即分別為 D 和 G 設置不同的學習率,讓 D 收斂速度更快。

17. training strategy

  • PGGAN [10]

PGGAN 是一個漸進式的訓練技巧,因為要生成高清(eg, 1024x1024)的圖片,直接從一個隨機噪聲生成這么高維度的數據是比較難的;既然沒法一蹴而就,那就循序漸進,首先從簡單的低緯度的開始生成,例如 4x4,然后 16x16,直至我們所需要的圖片大小。在 PGGAN 里,首次實現了高清圖片的生成,並且可以做到以假亂真,可見其威力。此外,由於我們大部分的操作都是在比較低的維度上進行的,訓練速度也不比其他模型遜色多少。

  • coarse-to-refine

coarse-to-refine 可以說是 PGGAN 的一個特例,它的做法就是先用一個簡單的模型,加上一個 l1 loss,訓練一個模糊的效果,然后再把這個模糊的照片送到后面的 refine 模型里,輔助對抗 loss 等其他 loss,訓練一個更加清晰的效果。這個在圖片生成里面廣泛應用。

18. Exponential Moving Average [24]

EMA主要是對歷史的參數進行一個指數平滑,可以有效減少訓練的抖動。強烈推薦!!!

總結

訓練 GAN 是一個精(折)細(磨)的活,一不小心你的 GAN 可能就是一部驚悚大片。筆者結合自己的經驗以及看過的一些文獻資料,列出了常用的 tricks,在此拋磚引玉,由於筆者能力和視野有限,有些不正確之處或者沒補全的 tricks,還望斧正。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM