1 """ 2 生成對抗網絡(GAN,Generative Adversarial Networks)的基本原理很簡單: 3 假設有兩個網絡,生成網絡G和判別網絡D。生成網絡G接受一個隨機的噪聲z並生成圖片, 4 記為G(z);判別網絡D的作用是判別一張圖片x是否真實,對於輸入x,D(x)是x為真實圖片的概率。 5 在訓練過程中, 生成器努力讓生成的圖片更加真實從而使得判別器無法辨別圖像的真假, 6 而D的目標就是盡量把分辨出真實圖片和生成網絡G產出的圖片,這個過程就類似於二人博弈, 7 G和D構成了一個動態的“博弈過程”。隨着時間的推移,生成器和判別器在不斷地進行對抗, 8 最終兩個網絡達到一個動態平衡:生成器生成的圖像G(z)接近於真實圖像分布,而判別器識別不出真假圖像, 9 即D(G(z))=0.5。最后,我們就可以得到一個生成網絡G,用來生成圖片。 10 """ 11 import tensorflow as tf 12 from matplotlib import pyplot as plt 13 import os 14 import numpy as np 15 from tensorflow.examples.tutorials.mnist import input_data 16 mnist=input_data.read_data_sets('/MNIST_data/',one_hot=True) 17 batch_size=64 18 units_size=128 19 learning_rate=0.001 20 epoch=300 21 smooth=0.1 22 """定義生成模型""" 23 def generatorModel(noise_img,units_size,out_size,alpha=0.01): 24 """生成器的目的是:對於生成的圖片,G希望D打上標簽1""" 25 with tf.variable_scope('generator'): 26 FC=tf.layers.dense(noise_img,units_size) 27 relu=tf.nn.leaky_relu(FC,alpha) 28 drop=tf.layers.dropout(relu,rate=0.2) 29 logits=tf.layers.dense(drop,out_size) 30 outputs=tf.tanh(logits) 31 return logits,outputs 32 33 """定義判別模型""" 34 def discriminatorModel(images,unite_size,alpha=0.01,reuse=False): 35 """ 36 判別器的目的是: 37 1. 對於真實圖片,D要為其打上標簽1 38 2. 對於生成圖片,D要為其打上標簽0 39 """ 40 with tf.variable_scope('discriminator',reuse=reuse): 41 FC=tf.layers.dense(images,units_size) 42 relu=tf.nn.leaky_relu(FC,alpha) 43 logits=tf.layers.dense(relu,1) 44 outputs=tf.sigmoid(logits) 45 return logits,outputs 46 """定義損失函數""" 47 def loss_fenction(real_logits,fake_logits,smooth): 48 """生成器希望判別器判別出來的標簽為1; tf.ones_like()創建一個將所有元素都設置為1的張量""" 49 G_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 50 logits=fake_logits, 51 labels=tf.ones_like(fake_logits)*(1-smooth)) 52 ) 53 """判別器識別生成器產出的圖片,希望識別出來的標簽為0;tf.zeros_like()創建一個將所有元素都設置為0的張量""" 54 fake_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 55 logits=fake_logits, 56 labels=tf.zeros_like(fake_logits)) 57 ) 58 """判別器判別真實圖片,希望判別出來的標簽為1;tf.ones_like()創建一個將所有元素都設置為1的張量""" 59 real_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 60 logits=real_logits, 61 labels=tf.ones_like(real_logits)*(1-smooth)) 62 ) 63 D_loss=tf.add(fake_loss,real_loss) 64 return G_loss,fake_loss,real_loss,D_loss 65 """定義優化器""" 66 def optimizer(G_loss,D_loss,learning_rate): 67 """因為GAN中一共訓練了兩個網絡,所以分別對G和D進行優化""" 68 train_var=tf.trainable_variables() #需要訓練的變量 69 G_var=[var for var in train_var if var.name.startswith('generator')] 70 D_var=[var for var in train_var if var.name.startswith('discriminator')] 71 G_optimizer=tf.train.AdadeltaOptimizer(learning_rate).minimize(G_loss,var_list=G_var) 72 D_optimizer=tf.train.AdadeltaOptimizer(learning_rate).minimize(D_loss,var_list=D_var) 73 return G_optimizer,D_optimizer 74 """訓練""" 75 def train(mnist): 76 image_size = mnist.train.images[0].shape[0] 77 real_images = tf.placeholder(tf.float32,[None,image_size]) 78 fake_images = tf.placeholder(tf.float32,[None,image_size]) 79 """調用生成模型生成假圖片G_output""" 80 G_logits,G_output = generatorModel(fake_images,units_size,image_size) 81 """D對真實圖像的判別""" 82 real_logits,real_output = discriminatorModel(real_images,units_size) 83 """D對G生成圖像的判別""" 84 fake_logits,fake_output=discriminatorModel(G_output,units_size,reuse=True) 85 G_loss,real_loss,fake_loss,D_loss=loss_fenction(real_logits,fake_logits,smooth) 86 G_optimizer,D_optimizer=optimizer(G_loss,D_loss,learning_rate) 87 88 saver=tf.train.Saver() 89 step=0 90 with tf.Session() as session: 91 session.run(tf.global_variables_initializer()) 92 for Epoch in range(epoch): 93 for batch_i in 
range(mnist.train.num_examples//batch_size): 94 batch_image,_=mnist.train.next_batch(batch_size) 95 """對圖像像素進行scale,tanh的輸出結果為(-1,1)""" 96 batch_image=batch_image*2-1 97 """模型的輸入噪聲""" 98 noise_image=np.random.uniform(-1,1,size=(batch_size,image_size))#從均勻分布[-1,1)中隨機采樣 99 session.run(G_optimizer,feed_dict={fake_images:noise_image}) 100 session.run(D_optimizer,feed_dict={real_images:batch_image,fake_images:noise_image}) 101 step=step+1 102 loss_D= session.run(D_loss, feed_dict={real_images: batch_image, fake_images: noise_image}) 103 loss_real= session.run(real_loss, feed_dict={real_images: batch_image, fake_images: noise_image}) 104 loss_fake= session.run(fake_loss, feed_dict={real_images: batch_image, fake_images: noise_image}) 105 loss_G= session.run(G_loss, feed_dict={fake_images: noise_image}) 106 print('epoch:', Epoch, 'loss_D:', loss_D,'loss_real', loss_real,'loss_fake', loss_fake, 'loss_G', loss_G) 107 model_path=os.getcwd()+os.sep+"mnist.model" 108 saver.save(session,model_path,global_step=step) 109 """定義主函數""" 110 def main(argv=None): 111 train(mnist) 112 if __name__ =='__main__': 113 tf.app.run()
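A quick, self-contained way to check the equilibrium claimed in the module docstring (D(G(z)) drifting toward 0.5 as training converges) is to rebuild G and D from a checkpoint and report the discriminator's mean sigmoid output on a batch of noise. This is an added sketch, not part of the original scripts: it assumes it runs from the checkpoint directory, and the helper name check_equilibrium is ours.

import numpy as np
import tensorflow as tf
import example88_0

def check_equilibrium(image_size, n=64):
    tf.reset_default_graph()  # start from a clean graph before rebuilding
    noise_in = tf.placeholder(tf.float32, [None, image_size])
    # Rebuild the same G and D graphs so the checkpoint variables match by name
    _, G_output = example88_0.generatorModel(noise_in, example88_0.units_size, image_size)
    _, fake_output = example88_0.discriminatorModel(G_output, example88_0.units_size)
    saver = tf.train.Saver()
    with tf.Session() as session:
        saver.restore(session, tf.train.latest_checkpoint('.'))
        noise = np.random.uniform(-1, 1, size=(n, image_size))
        # Near equilibrium this mean should be close to 0.5
        print('mean D(G(z)):', session.run(fake_output, feed_dict={noise_in: noise}).mean())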
# Sampling script (imports the training module above as example88_0):
# restore the trained generator from the latest checkpoint, generate a
# 5x5 grid of images, and display them.
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
import pickle
import example88_0

UNITS_SIZE = example88_0.units_size


def generatorImage(image_size):
    # Rebuild the generator graph, then overwrite the freshly initialized
    # weights with the trained ones from the latest checkpoint
    sample_images = tf.placeholder(tf.float32, [None, image_size])
    G_logits, G_output = example88_0.generatorModel(sample_images, UNITS_SIZE, image_size)
    saver = tf.train.Saver()
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        saver.restore(session, tf.train.latest_checkpoint('.'))
        # 25 noise vectors -> 25 generated images
        sample_noise = np.random.uniform(-1, 1, size=(25, image_size))
        samples = session.run(G_output, feed_dict={sample_images: sample_noise})
    with open('samples.pkl', 'wb') as f:
        pickle.dump(samples, f)


def show():
    with open('samples.pkl', 'rb') as f:
        samples = pickle.load(f)
    # Draw the generated images in a 5x5 grid without axis ticks
    fig, axes = plt.subplots(figsize=(7, 7), nrows=5, ncols=5, sharey=True, sharex=True)
    for ax, image in zip(axes.flatten(), samples):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        ax.imshow(image.reshape((28, 28)), cmap='Greys_r')
    plt.show()


def main(argv=None):
    image_size = example88_0.mnist.train.images[0].shape[0]
    generatorImage(image_size)
    show()


if __name__ == '__main__':
    tf.app.run()
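Once the generator is trained, a common complement to the fixed 5x5 grid is to interpolate between two noise vectors and push the whole line through the generator, so the outputs morph smoothly from one sample to another. A minimal sketch under the same checkpoint assumptions as generatorImage() above; the helper name interpolate is ours, not part of the original scripts.

import numpy as np
import tensorflow as tf
import example88_0

def interpolate(image_size, steps=8):
    tf.reset_default_graph()  # avoid clashing with generator variables built earlier
    noise_in = tf.placeholder(tf.float32, [None, image_size])
    _, G_output = example88_0.generatorModel(noise_in, example88_0.units_size, image_size)
    saver = tf.train.Saver()
    with tf.Session() as session:
        saver.restore(session, tf.train.latest_checkpoint('.'))
        z0 = np.random.uniform(-1, 1, size=image_size)
        z1 = np.random.uniform(-1, 1, size=image_size)
        # Linear interpolation in noise space: z(t) = (1 - t) * z0 + t * z1
        noise = np.stack([(1 - t) * z0 + t * z1 for t in np.linspace(0.0, 1.0, steps)])
        return session.run(G_output, feed_dict={noise_in: noise})

Plotting each returned row reshaped to 28x28, the same way show() does, gives the morphing sequence.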