由於筆者水平有限,如有錯,歡迎指正。
0 GAN的思想
GAN,全稱為 Generative Adversarial Nets,直譯為生成式對抗網絡,是一種非監督式模型。
GAN的主要靈感來源於博弈論中零和博弈的思想,應用到深度學習神經網絡上來說,就是通過生成網絡G(Generator)和判別網絡D(Discriminator)不斷博弈,進而使G學習到數據的分布,
GAN網絡最強大的地方就是可以幫助我們建立模型,而不像傳統的網絡那樣是在已有模型上幫我們更新參數而已。同時,因為GAN網絡是一種無監督的學習方式,它的泛化性非常好。
1 GAN模型
1.1網絡結構

上圖都描述了GAN的核心網絡,在生成網絡中,得到假的數據,然后和真的數據一起喂入判別模型,判別模型判斷輸入的樣本是真是假,先訓練識別網絡,再訓練生成網絡,再訓練識別網絡,如此反復,直到平衡。
1.2具體過程
-
生成模型:比作是一個樣本生成器,輸入一個噪聲/樣本,然后把它包裝成一個逼真的樣子,也就是輸出。
-
生成網絡是造樣本,它的目的就是使得自己造樣本的能力盡可能強,強到什么程度呢,判別網絡沒法判斷我是真樣本還是假樣本。
-
通常這個網絡選用最普通的多層隨機網絡即可,網絡太深容易引起梯度消失或者梯度爆炸。
-
-
判別模型:比作一個二分類器(如同0-1分類器),來判斷輸入的樣本是真是假。(就是輸出值大於0.5還是小於0.5)
- 判別出來屬於的一張圖它是來自真實樣本集還是假樣本集。若輸入的是真樣本,輸出就接近1,輸出的是假樣本,輸出接近0。
訓練過程中,生成網絡G的目標就是盡量生成真實的圖片去欺騙判別網絡D。而D的目標就是盡量辨別出G生成的假圖像和真實的圖像。這樣,G和D構成了一個動態的“博弈過程”,最終的平衡點即納什均衡點.。
納什均衡是指博弈中這樣的局面,對於每個參與者來說,只要其他人不改變策略,他就無法改善自己的狀況。

上圖是是論文中的一張過程圖,判別分布(藍色,虛線) ,生成數據的實際分布(黑色,虛線),數據的生成分布(綠色,實線)
(a) 對於D(判別網絡)剛開始訓練,有波動,但基本可以區分實際數據和生成數據;
(b) 隨着訓練的進行,D可以明顯的區分實際數據和生成數據;
(c) 隨着G的更新,綠色的線能夠趨近於黑色的線;
(d) 經過幾步訓練,如果G和D有足夠的能力,他們將達到平衡,辨別器無法區分兩個分布,即D(x)= 1;
1.3訓練結果
最終,訓練結束后,生成模型 G 恢復了訓練數據的分布(造出了和真實數據一模一樣的樣本),判別模型再也判別不出來結果,准確率為 50%,約等於亂猜。這是雙方網路都得到利益最大化,不再改變自己的策略,也就是不再更新自己的權重。如果loss值很低,則生成器成功欺騙了識別器(把假數據當成和label一樣也是1了),如果loss很大(label盡管是1,但是識別器還是預測為0,識別器判斷出了真假),說明生成器還需提升)。
3 代碼實現
- 避免使用RELU和pooling層,減少稀疏梯度的可能性,使用leakrelu激活函數;
- 最后一層的激活函數使用tanh;
- 在鑒別器中使用dropout;
3.1 Generative model:
model = Sequential()
model.add(Dense(256, input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(np.prod(self.img_shape), activation='tanh'))
model.add(Reshape(self.img_shape))
model.summary()
noise = Input(shape=(self.latent_dim,))
img = model(noise)
3.2 Discriminator model:
model = Sequential()
model.add(Flatten(input_shape=self.img_shape))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(256))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1, activation='sigmoid'))
model.summary()
img = Input(shape=self.img_shape)
validity = model(img)
3.3 GAN
discriminator.trainable = False
gan_input = keras.Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))
gan = keras.models.Model(gan_input, gan_output)
gan_optimizer = keras.optimizers.RMSprop(lr=4e-4, clipvalue=1.0, decay=1e-8)
gan.compile(optimizer=gan_optimizer, loss='binary_crossentropy')