最近學習了生成對抗網絡(GAN),基於幾個經典GAN網絡結構做了些小實驗,包括dcgan,wgan,wgan-gp。坦率的說,wgan,wgan-gp論文的原理還是有點小復雜,我也沒有完全看明白,因此在此就不詳細介紹了,如果感興趣可以閱讀參考部分的論文,本篇博客主要着重於記錄如何利用tensorflow實現這幾種網絡的訓練、預測。下面先簡單介紹下GAN的原理和個人理解,以及dcgan,wgan,wgan-gp的改進,最后給出代碼。
1.GAN原理和個人理解
Generative Adversarial Nets(GAN),生成對抗網絡,是2014年由Goodfellow等人提出的一個深度學習框架,這個框架的目的是生成和真實圖片概率分布一致的輸出,也就是所謂的生成模型。該框架包含兩部分:生成網絡(Generator,簡稱G)、判別網絡(Discriminator,簡稱D)。假設G的輸入是z,輸出是x’,真實數據是x,則G要盡量將輸入映射到輸出數據的分布上,即利用z盡量生成和x接近的x’。而D要盡量將x’和x區分開。具體公式如下:
上面公式中,G(z)是G的輸出即x’。D要最大化這個公式,其實就是區分生成數據和真實數據之間的真假,是個二分類問題,也就是盡量將x分為1類、x’分為0類。另一方面,G又想盡量生成和真實數據近似的x’,這個就體現在欺騙D上,如果D無法分辨x’和x,就說明D成功生成了符合x分布的x’,因此G的目的是要讓x'都被D分為1類(而不是0類)。又因第一部分logD(x)和G沒關系,所以G的目標是最小化下面這個式子,
根據論文的說法,上面這個式子容易saturates,導致G訓練出現困難,所以轉化成最大化log(D(G(z)))。因為一般做優化都是求目標函數的最小值(當然tf也支持最大值),因此實際編程中是這么寫的:
1. min[-log(D(x))-log(1-D(G(z)))]
2. min[-log(1-D(G(z)))]
實際上就是加了個負號。另外,在做前向后向的時候,一般是分別對D、G進行更新,也就是說在后向誤差傳播更新梯度的時候,會先固定G的參數(權重、偏置等),然后前向計算式1並做后向誤差傳播,更新G的參數。接着再固定G的參數,前向計算式2並做后向誤差傳播,更新D的參數。
最后再說下輸入z,在沒有label的經典gan中,z是一個符合正態分布的隨機高維向量,至於在其他經典gan的變體中,z可以是正態分布的隨機高維向量加上label,或者是純label。
2.dcgan、wgan、wgan-gp
Dcgan於2015年提出,相比於最早的GAN相比,主要做了網絡結構方面的改進,例如:引入bm層,去掉pooling以strided conv層代之,去掉fc層,引入relu和leakyRelu等等。這些改進效果帶來了更好的生成圖片效果,不過和早期GAN相比,依然很難訓練,所以這個階段也引入了訓練時的一些小竅門,比如每個stepG迭代兩次,D迭代一次。在一開始的時候D不能訓練的太好,防止G難以收斂等等。
wgan在2017年橫空出世,通過理論分析,wgan指出了傳統gan的為什么訓練時難以收斂,並進行了改進,使其訓練難度大幅降低、收斂速度加快。主要有兩點改進:一是把目標函數中的log去掉,二是每次迭代更新權重后做weight clipping,把權重限制到一個范圍內(例如限定范圍[-0.1,+0.1],則超出這個范圍的權重都會被修剪到-0.1或+0.1)。
同年,wgan-gp又基於wgan提出了改進方案,因為wgan雖然降低了gan的訓練難度,但在一些設定下仍然難以收斂,並且生成圖片效果相比dcgan還要差,wgan-gp將weight clipping改為penalize the norm of the gradient of the critic with respect to its input(根據D的輸入后向計算出權重梯度,並針對梯度的范數進行懲罰),解決了上述問題。
3.基於tf的代碼實現
https://github.com/handspeaker/gan_practice
每種結構的代碼都比較簡潔,屬於實驗性質,沒有加太多東西,看起來比一些github上很大的工程容易懂些。
4.參考
Generative Adversarial Nets
Wasserstein GAN
Improved Training of Wasserstein GANs
https://github.com/shekkizh/WassersteinGAN.tensorflow
https://github.com/igul222/improved_wgan_training