生成式對抗網絡(GAN, Generative Adversarial Networks)是一種深度學習模型,是最近超級火的一個無監督學習方法,它主要由兩部分組成,一部分是生成模型G(generator),另一部分是判別模型D(discriminator),它的訓練過程可大致描述如下:
生成模型通過接收一個隨機噪聲來生成圖片,判別模型用來判斷這個圖片是不是“真實的”,也就是說,生成網絡的目標是盡量生成真實的圖片去欺騙判別網絡,判別網絡的目標就是把G生成的圖片和真實的圖片區分開來,從而構成一個動態的博弈過程。
GAN主要用來解決的問題是:在數據量不足的情況下,通過小型數據集去生成一些數據。
從理論上來說,GAN系列神經網絡可以用來模擬任何數據分布,但是目前更主要用於圖像。
而事實也證明,GAN生成的數據是可以直接用在實際的圖像問題上的,如行人重識別數據集,細粒度識別等。

(GAN的網絡結構及訓練流程)
下面是用keras實現的GAN:
1 from __future__ import print_function, division 2 3 from keras.datasets import mnist 4 from keras.layers import Input, Dense, Reshape, Flatten, Dropout 5 from keras.layers import BatchNormalization, Activation, ZeroPadding2D 6 from keras.layers.advanced_activations import LeakyReLU 7 from keras.layers.convolutional import UpSampling2D, Conv2D 8 from keras.models import Sequential, Model 9 from keras.optimizers import Adam 10 11 import matplotlib.pyplot as plt 12 13 import sys 14 15 import numpy as np 16 17 class GAN(): 18 def __init__(self): 19 # 定義輸入圖像尺寸及通道 20 self.img_rows = 28 21 self.img_cols = 28 22 self.channels = 1 23 self.img_shape = (self.img_rows, self.img_cols, self.channels) 24 self.latent_dim = 100 25 26 # 設置網絡優化器 27 optimizer = Adam(0.0002, 0.5) 28 29 # 構建判別網絡 30 self.discriminator = self.build_discriminator() 31 self.discriminator.compile(loss='binary_crossentropy', 32 optimizer=optimizer, 33 metrics=['accuracy']) 34 35 # 構建生成網絡 36 self.generator = self.build_generator() 37 38 # 生成器根據噪聲生成圖像 39 z = Input(shape=(self.latent_dim,)) 40 img = self.generator(z) 41 42 # 在聯合模型中,設置判別器參數不可訓練 43 self.discriminator.trainable = False 44 45 # 判別器驗證生成圖像 46 validity = self.discriminator(img) 47 48 # 訓練生成器來欺騙判別器 49 self.combined = Model(z, validity) 50 self.combined.compile(loss='binary_crossentropy', optimizer=optimizer) 51 52 53 # 生成器結構 54 def build_generator(self): 55 56 model = Sequential() 57 58 model.add(Dense(256, input_dim=self.latent_dim)) 59 model.add(LeakyReLU(alpha=0.2)) 60 model.add(BatchNormalization(momentum=0.8)) 61 model.add(Dense(512)) 62 model.add(LeakyReLU(alpha=0.2)) 63 model.add(BatchNormalization(momentum=0.8)) 64 model.add(Dense(1024)) 65 model.add(LeakyReLU(alpha=0.2)) 66 model.add(BatchNormalization(momentum=0.8)) 67 model.add(Dense(np.prod(self.img_shape), activation='tanh')) 68 model.add(Reshape(self.img_shape)) 69 70 model.summary() 71 72 noise = Input(shape=(self.latent_dim,)) 73 img = model(noise) 74 75 return Model(noise, img) 76 77 
# 判別器結構 78 def build_discriminator(self): 79 80 model = Sequential() 81 82 model.add(Flatten(input_shape=self.img_shape)) 83 model.add(Dense(512)) 84 model.add(LeakyReLU(alpha=0.2)) 85 model.add(Dense(256)) 86 model.add(LeakyReLU(alpha=0.2)) 87 model.add(Dense(1, activation='sigmoid')) 88 model.summary() 89 90 img = Input(shape=self.img_shape) 91 validity = model(img) 92 93 return Model(img, validity) 94 95 # 定義訓練過程 96 def train(self, epochs, batch_size=128, sample_interval=50): 97 (X_train, _), (_, _) = mnist.load_data() 98 99 X_train = X_train / 127.5 - 1. 100 X_train = np.expand_dims(X_train, axis=3) 101 102 valid = np.ones((batch_size, 1)) 103 fake = np.zeros((batch_size, 1)) 104 105 for epoch in range(epochs): 106 107 idx = np.random.randint(0, X_train.shape[0], batch_size) 108 imgs = X_train[idx] 109 110 noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) 111 112 gen_imgs = self.generator.predict(noise) 113 114 d_loss_real = self.discriminator.train_on_batch(imgs, valid) 115 d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake) 116 d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) 117 118 noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) 119 120 # 根據判別器valid訓練生成器 121 g_loss = self.combined.train_on_batch(noise, valid) 122 123 print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss)) 124 125 # 保存生成圖像 126 if epoch % sample_interval == 0: 127 self.sample_images(epoch) 128 129 def sample_images(self, epoch): 130 r, c = 5, 5 131 noise = np.random.normal(0, 1, (r * c, self.latent_dim)) 132 gen_imgs = self.generator.predict(noise) 133 134 gen_imgs = 0.5 * gen_imgs + 0.5 135 136 fig, axs = plt.subplots(r, c) 137 cnt = 0 138 for i in range(r): 139 for j in range(c): 140 axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray') 141 axs[i,j].axis('off') 142 cnt += 1 143 fig.savefig("images/%d.png" % epoch) 144 plt.close() 145 146 147 if __name__ == '__main__': 148 gan = GAN() 149 gan.train(epochs=30000, 
batch_size=32, sample_interval=200)
程序初始運行結果如下:

訓練完成後效果如下:

