GAN由論文《Ian Goodfellow et al., “Generative Adversarial Networks,” arXiv (2014)》提出。
GAN與VAEs的區別
GANs require differentiation through the visible units, and thus cannot model discrete data,
while VAEs require differentiation through the hidden units, and thus cannot have discrete latent variables.
即GAN不能處理離散數據,VAEs不能處理離散隱空間變量。
訓練過程
常見模型是最小化一個loss,GAN里的生成器和鑒別器則是一個minmax操作,即
同時,生成器更新一次后,鑒別器應該更新多次,這樣保證鑒別器可以維持在最優解附近。
如果生成器連續多次更新,而鑒別器不更新,則生成器傾向於生成那些“為難”鑒別器的同一批樣本,這樣生成器就缺乏多樣性。
論文中給出的算法流程(簡單的一次生成器更新對應多次鑒別器更新):
一些細節:
生成器使用relu和sigmoid激活函數,鑒別器使用maxout激活函數,Dropout只添加於鑒別器。
本文代碼使用的一些trick:
- 生成器最后的激活函數使用tanh代替sigmoid
- 隱空間中使用正態分布去采樣
- 添加隨機性因素。GAN是非常難以訓練的,添加一些噪音可以讓訓練不會輕易卡主。除了Dropout外,此處對鑒別器判斷的標簽也添加隨機噪音。
- 稀疏梯度(Sparse gradients)在一些網絡中通常是渴求的目標。但在GAN中,它會妨礙訓練過程。所以將maxpool替換為帶stride的卷積層,並使用leakyRELU代替relu激活函數。
- 為了避免產生的圖像如棋盤狀(即一個個正方形像素塊,而非連續流暢的像素),設定卷積窗口大小為步長的整數倍。
- 優化器使用的是RMSprop,並使用梯度裁剪和梯度衰減。
訓練過程為:
數據集為cifar10
定義生成器網絡,輸入為隱空間中一個矢量,輸出為一個圖片。
定義鑒別器網絡,輸入為生成器網絡采樣所得的圖片和真實圖片(以及標簽),輸出為sigmoid激活函數的標量值,即判斷圖片為真實還是偽造。
定義生成對抗網絡,為D(G(x))即生成網絡與鑒別網絡的嵌套形式。輸入為生成網絡的輸入,輸出為鑒別器網絡的輸出。
訓練時,使用高斯分布從隱空間中采樣,經過生成網絡得到生成的圖片,與真實圖片混合后(以及標簽)作為鑒別器網絡的輸入。
先訓練鑒別器。然后重新采樣生成圖片,此時需將這些圖片的標簽置為真實圖片的標簽(固定標簽后,訓練生成器,即讓其參數調整到鑒別器都以為確實是真實圖片)。再訓練GAN(此時凍結鑒別器參數,訓練的只是生成器)。
可以看到,定義了3個模型,只是因為生成器網絡的訓練要基於鑒別器網絡進行。
代碼如下
import numpy as np from keras.datasets import cifar10 from keras.models import Model from keras.layers import Input,Dense,LeakyReLU,Reshape,Conv2D,Conv2DTranspose,Flatten,Dropout from keras.optimizers import RMSprop from keras.preprocessing import image import os latent_dim=32 # Cifar10圖片尺寸 height,width=(32,32) channels=3
3個網絡定義
# 生成網絡:將隱空間中矢量生成圖片,使用Conv2DTranspose generator_input=Input((latent_dim,)) x=Dense(128*16*16)(generator_input) # 只添加了一個alpha參數,其他地方跟書上一致,alpha默認0.3 x=LeakyReLU(alpha=0.1)(x) x=Reshape((16,16,128))(x) x=Conv2D(256,5,padding='same')(x) x=LeakyReLU(alpha=0.1)(x) # 結果為32*32*256,為避免生成圖片呈現棋盤的點陣格式,凡是使用strides的地方,窗口大小為strides的整數倍 x=Conv2DTranspose(256,4,strides=2,padding='same')(x) x=LeakyReLU(alpha=0.1)(x) x=Conv2D(256,5,padding='same')(x) x=LeakyReLU(alpha=0.1)(x) x=Conv2D(256,5,padding='same')(x) x=LeakyReLU(alpha=0.1)(x) # 結果為32*32*3,即一個圖片正確格式。使用tanh代替sigmoid x=Conv2D(channels,7,activation='tanh',padding='same')(x) generator=Model(generator_input,x)#它在包含在GAN里訓練的,所以這里不用編譯 # generator.summary() # 鑒別網絡 discriminator_input=Input((height,width,channels)) x=Conv2D(128,3)(discriminator_input) x=LeakyReLU(alpha=0.1)(x) x=Conv2D(128,4,strides=2)(x) x=LeakyReLU(alpha=0.1)(x) x=Conv2D(128,4,strides=2)(x) x=LeakyReLU(alpha=0.1)(x) # 2*2*128 x=Conv2D(128,4,strides=2)(x) x=LeakyReLU(alpha=0.1)(x) x=Flatten()(x) # Dropout和給標簽添加噪聲,可以避免GAN卡住 x=Dropout(0.4)(x) x=Dense(1,activation='sigmoid')(x) discriminator=Model(discriminator_input,x) # discriminator.summary() # clipvalue,梯度超過這個值就截斷,decay,衰減,使得訓練穩定 discriminator_optimizer=RMSprop(lr=0.0003,clipvalue=1.0,decay=1e-8) discriminator.compile(optimizer=discriminator_optimizer,loss='binary_crossentropy') # 最后的生成對抗網絡,由生成網絡與對抗網絡組合而成,此時凍結鑒別網絡,訓練的只是生成網絡 discriminator.trainable=False # 組成整個生成對抗網絡 gan_input=Input((latent_dim,)) # 最終網絡形式為鑒別網絡作用於生成網絡,故生成器也不用compile gan_output=discriminator(generator(gan_input)) gan_optimizer=RMSprop(lr=0.0004,clipvalue=1.0,decay=1e-8) gan=Model(gan_input,gan_output) gan.compile(optimizer=gan_optimizer,loss='binary_crossentropy')
訓練過程,此處並未使用多次鑒別器更新一次生成器更新,你可以自己調整(即循環里面開頭添加個循環,訓練鑒別器)。
(x_train,y_train),(x_test,y_test)=cifar10.load_data() # 選擇frog類別,總共10個類 x_train=x_train[y_train.flatten()==6] # reshape到輸入格式 nums*height*width*channels,像素歸一化 x_train=x_train.reshape((x_train.shape[0],)+(height,width,channels)).astype('float32')/255. iters=10000 batch_size=20 save_dir='frog' start=0 for step in range(iters): # 選取潛空間中隨機矢量(正態分布) random_latent_vec=np.random.normal(size=(batch_size,latent_dim)) # 生成網絡產生圖片 generated_images=generator.predict(random_latent_vec) stop=start+batch_size # 真實原始圖片 real_images=x_train[start:stop] # mix生成和真實圖片 combined_images=np.concatenate([generated_images,real_images]) # mix labels labels=np.concatenate([np.ones((batch_size,1)),np.zeros((batch_size,1))]) # trick:標簽添加隨機噪聲 labels+=0.05*np.random.random(labels.shape) # 鑒別loss,可能為負,因為使用的是LeakyReLU d_loss=discriminator.train_on_batch(combined_images,labels) # 重新生成隨機矢量 random_latent_vec=np.random.normal(size=(batch_size,latent_dim)) # 故意設置標簽為真實 misleading_targets=np.zeros((batch_size,1)) a_loss=gan.train_on_batch(random_latent_vec,misleading_targets) start+=batch_size if start>len(x_train)-batch_size: start=0 if step%100==0: # gan.save_weights('gan.h5') print('discriminator loss:',d_loss) print('adversarial loss:',a_loss) # 保存一個batch里的第一個圖片,之前像素歸一化了,這里乘以255還原 img=image.array_to_img(generated_images[0]*255.,scale=False) img.save(os.path.join(save_dir,'generated_frog'+str(step)+'.png')) # 保存一個對比圖片 img=image.array_to_img(real_images[0]*255.,scale=False) img.save(os.path.join(save_dir,'real_frog'+str(step)+'.png'))
loss變化趨勢,可以看到是不穩定的
看真實圖和生成圖片對比,上下2行圖片只是同一批保存的,沒有相關性。這是訓練4000步,也即80000個訓練樣本后的結果。看起來比較丑陋吧。