Training a Simple Generative Adversarial Network (GAN) with TensorFlow


Generative adversarial networks were proposed by Ian Goodfellow in the 2014 paper Generative Adversarial Nets. Conceptually, a GAN boils down to a game played between a generator and a discriminator: the whole training process is these two networks competing against each other.
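In the paper's formulation, this game is the following minimax problem, where $D(x)$ is the discriminator's estimated probability that $x$ is real and $G(z)$ is the image the generator produces from noise $z$:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$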

The roles of the two modules

  • The discriminator network is, viewed plainly, an ordinary neural network: its input is an image and its output is a probability used to judge real versus fake (above 0.5 means real, below 0.5 means fake).
  • The generator network can likewise be seen as a neural network model: its input is a vector of random numbers Z and its output is an image.

The training objectives of the two modules

 

  • The discriminator's objective: to tell whether a given image comes from the real sample set or the fake sample set. If it outputs a value close to 1 for real samples and close to 0 for fake samples, it is doing its job perfectly.
  • The generator's objective: to manufacture samples, and to get so good at it that the discriminator can no longer tell its output apart from real samples.

Training the GAN

  Note that the generator and the discriminator are two completely independent models, much like two separate neural networks with no connection between them.

The overall strategy for training two such models is alternating iterative training: because two networks are hard to train jointly, we train them in alternation. Let's walk through it step by step.

  First, we start from a randomly initialized generator (certainly not a good one yet). Feeding it a batch of random vectors yields a batch of fake samples. Because this is not the final generator, the samples will be poor and probably easy for the discriminator to flag as fake, but that is fine for now. We also have the real sample set, which is always available. We then assign labels ourselves: since we want the discriminator to output 1 for real samples and 0 for fake ones, we label every real sample 1 and every fake sample 0.

  For the generator, recall that the goal is to produce samples that are as realistic as possible. How do we know whether the raw generator's samples are realistic? We feed them to the discriminator. So training the generator requires involving the discriminator: if we used the generator alone, there would be no source of error to train on. By chaining the discriminator onto the end of the generator, we obtain a real/fake judgment and therefore a loss, so training the generator is really training the chained generator-discriminator network. Now consider the samples: we have the original noise vectors Z and the fake samples generated from them, and here comes the crucial point: we set the labels of these fake samples to 1, i.e., during generator training we pretend the fakes are real. Only this way does the gradient push the generator toward fooling the discriminator, gradually making the fake samples approach the real ones. A minimal sketch of this alternating scheme is shown below.
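Here is a minimal sketch of the alternating loop, written with the placeholder and op names that the full training code later in this post defines (so it is a preview of that code, not extra machinery):

import numpy as np

for epoch_i in range(epoch_count):
    for batch_images in get_batches(batch_size):
        batch_z = np.random.uniform(-1, 1, size=(batch_size, z_dim))

        # Step 1: update the discriminator, with real images labelled 1
        # and generated images labelled 0.
        sess.run(d_opt, feed_dict={input_real: batch_images,
                                   input_z: batch_z, lr: learning_rate})

        # Step 2: update the generator; its fakes are labelled 1 so the
        # gradient pushes G toward fooling D. Only the generator's
        # variables are in g_opt's var_list, so D stays fixed here.
        sess.run(g_opt, feed_dict={input_real: batch_images,
                                   input_z: batch_z, lr: learning_rate})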

 

Below is the code. We train on two datasets:

  • mnist
  • CelebA

to generate handwritten digits and faces, respectively.

First, download the datasets.
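The download code below calls three helpers (DLProgress, _unzip, _ungzip) that the original post does not show; they come from the Udacity face-generation project's helper module. Here is a minimal sketch to define first, assuming tqdm is installed. The MNIST extractor unpacks the idx file and writes each digit out as a JPEG, which matches the mnist/*.jpg glob used later; treat the exact details as assumptions rather than the original helper verbatim.

import gzip
import os
import zipfile

import numpy as np
from PIL import Image
from tqdm import tqdm

class DLProgress(tqdm):
    """tqdm progress bar whose hook() matches urlretrieve's reporthook signature."""
    last_block = 0

    def hook(self, block_num=1, block_size=1, total_size=None):
        self.total = total_size
        self.update((block_num - self.last_block) * block_size)
        self.last_block = block_num

def _unzip(save_path, _, database_name, data_path):
    """Extract a zip archive (used for celeba)."""
    print('Extracting {}...'.format(database_name))
    with zipfile.ZipFile(save_path) as zf:
        zf.extractall(data_path)

def _ungzip(save_path, extract_path, database_name, _):
    """Unpack the gzipped MNIST idx file and save each image as a JPEG."""
    print('Extracting {}...'.format(database_name))
    with gzip.open(save_path, 'rb') as f:
        buf = f.read()
    # idx3 layout: 16-byte header (magic, count, rows, cols), then raw pixels
    count = int.from_bytes(buf[4:8], 'big')
    rows = int.from_bytes(buf[8:12], 'big')
    cols = int.from_bytes(buf[12:16], 'big')
    images = np.frombuffer(buf, dtype=np.uint8, offset=16).reshape(count, rows, cols)
    for i, image in enumerate(images):
        Image.fromarray(image, 'L').save(os.path.join(extract_path, 'image_{}.jpg'.format(i)))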

 

import math
import os
import hashlib
from urllib.request import urlretrieve
import zipfile
import gzip
import shutil

data_dir = './data'

def download_extract(database_name, data_path):
    """
    Download and extract database
    :param database_name: Database name
    """
    DATASET_CELEBA_NAME = 'celeba'
    DATASET_MNIST_NAME = 'mnist'

    if database_name == DATASET_CELEBA_NAME:
        url = 'https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip'
        hash_code = '00d2c5bc6d35e252742224ab0c1e8fcb'
        extract_path = os.path.join(data_path, 'img_align_celeba')
        save_path = os.path.join(data_path, 'celeba.zip')
        extract_fn = _unzip
    elif database_name == DATASET_MNIST_NAME:
        url = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
        hash_code = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
        extract_path = os.path.join(data_path, 'mnist')
        save_path = os.path.join(data_path, 'train-images-idx3-ubyte.gz')
        extract_fn = _ungzip
    else:
        raise ValueError('Unknown dataset: {}'.format(database_name))

    if os.path.exists(extract_path):
        print('Found {} Data'.format(database_name))
        return

    if not os.path.exists(data_path):
        os.makedirs(data_path)

    if not os.path.exists(save_path):
        with DLProgress(unit='B', unit_scale=True, miniters=1, desc='Downloading {}'.format(database_name)) as pbar:
            urlretrieve(url, save_path, pbar.hook)

    assert hashlib.md5(open(save_path, 'rb').read()).hexdigest() == hash_code, \
        '{} file is corrupted.  Remove the file and try again.'.format(save_path)

    os.makedirs(extract_path)
    try:
        extract_fn(save_path, extract_path, database_name, data_path)
    except Exception as err:
        shutil.rmtree(extract_path)  # Remove extraction folder if there is an error
        raise err

    # Remove compressed data
    os.remove(save_path)

# download mnist
download_extract('mnist', data_dir)
# download celeba
download_extract('celeba', data_dir)


Let's first take a look at what the mnist and celeba datasets look like.
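The get_batch function in the next cell calls get_image, another helper from the same Udacity module that the post doesn't show. A minimal sketch to define first: for celeba it crops the face region before resizing (the 108x108 centered crop box follows the Udacity helper; treat the exact coordinates as an assumption).

import numpy as np
from PIL import Image

def get_image(image_path, width, height, mode):
    image = Image.open(image_path)

    if image.size != (width, height):
        # celeba images are 178x218; crop a centered 108x108 face region
        face_width = face_height = 108
        j = (image.size[0] - face_width) // 2
        i = (image.size[1] - face_height) // 2
        image = image.crop([j, i, j + face_width, i + face_height])
        image = image.resize([width, height], Image.BILINEAR)

    return np.array(image.convert(mode))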

# the number of images
show_n_images = 16

%matplotlib inline
import os
from glob import glob
from matplotlib import pyplot
import numpy as np
from PIL import Image

def get_batch(image_files, width, height, mode):
    # get_image is defined in the sketch above
    data_batch = np.array(
        [get_image(sample_file, width, height, mode) for sample_file in image_files]).astype(np.float32)

    # Make sure the images are in 4 dimensions
    if len(data_batch.shape) < 4:
        data_batch = data_batch.reshape(data_batch.shape + (1,))

    return data_batch

def images_square_grid(images, mode):
    """Arrange a batch of images into one square grid image."""
    # Get maximum size for square grid of images
    save_size = math.floor(np.sqrt(images.shape[0]))

    # Scale to 0-255
    images = (((images - images.min()) * 255) / (images.max() - images.min())).astype(np.uint8)

    # Put images in a square arrangement
    images_in_square = np.reshape(
            images[:save_size*save_size],
            (save_size, save_size, images.shape[1], images.shape[2], images.shape[3]))
    if mode == 'L':
        images_in_square = np.squeeze(images_in_square, 4)

    # Combine images to grid image
    new_im = Image.new(mode, (images.shape[1] * save_size, images.shape[2] * save_size))
    for col_i, col_images in enumerate(images_in_square):
        for image_i, image in enumerate(col_images):
            im = Image.fromarray(image, mode)
            new_im.paste(im, (col_i * images.shape[1], image_i * images.shape[2]))

    return new_im

mnist_images = get_batch(glob(os.path.join(data_dir, 'mnist/*.jpg'))[:show_n_images], 28, 28, 'L')
pyplot.imshow(images_square_grid(mnist_images, 'L'), cmap='gray')

 

mnist:

 

show_n_images = 9

celeba_images = get_batch(glob(os.path.join(data_dir, 'img_align_celeba/*.jpg'))[:show_n_images], 28, 28, 'RGB')
pyplot.imshow(images_square_grid(celeba_images, 'RGB'))

 

celeba:

Now we start building the network.

I recommend training on a GPU. TensorFlow 1.1.0 is the version this was written against; the check below only enforces 1.0 or newer.

from distutils.version import LooseVersion
import warnings
import tensorflow as tf

# Check TensorFlow Version
assert LooseVersion(tf.__version__) >= LooseVersion('1.0'), 'Please use TensorFlow version 1.0 or newer.  You are using {}'.format(tf.__version__)
print('TensorFlow Version: {}'.format(tf.__version__))

# Check for a GPU
if not tf.test.gpu_device_name():
    warnings.warn('No GPU found. Please use a GPU to train your neural network.')
else:
    print('Default GPU Device: {}'.format(tf.test.gpu_device_name()))

Next, we construct the inputs.

def model_inputs(image_width, image_height, image_channels, z_dim):
    # Real image input
    inputs_real = tf.placeholder(tf.float32, (None, image_width, image_height, image_channels), name='input_real')

    # Input z
    inputs_z = tf.placeholder(tf.float32, (None, z_dim), name='input_z')

    # Learning rate
    learning_rate = tf.placeholder(tf.float32, name='lr')

    return inputs_real, inputs_z, learning_rate
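A quick sanity check (an illustrative snippet, not part of the original post): build the placeholders and confirm their shapes and names.

import tensorflow as tf

tf.reset_default_graph()
inputs_real, inputs_z, lr = model_inputs(28, 28, 1, 100)
print(inputs_real)  # Tensor("input_real:0", shape=(?, 28, 28, 1), dtype=float32)
print(inputs_z)     # Tensor("input_z:0", shape=(?, 100), dtype=float32)
print(lr)           # Tensor("lr:0", dtype=float32)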

Building the Discriminator

def discriminator(images, reuse=False):
    """
    Create the discriminator network
    :param images: Tensor of input image(s)
    :param reuse: Boolean if the weights should be reused
    :return: Tuple of (tensor output of the discriminator, tensor logits of the discriminator)
    """
    with tf.variable_scope('discriminator', reuse=reuse):

        alpha = 0.2      # leaky relu coefficient
        keep_prob = 0.8  # dropout keep probability

        # Input layer: 28 * 28 * color channel
        x1 = tf.layers.conv2d(images, 128, 5, strides=2, padding='same',
                              kernel_initializer=tf.contrib.layers.xavier_initializer(seed=2))
        # No batch norm on the first layer
        # Leaky relu, alpha = 0.2
        relu1 = tf.maximum(alpha * x1, x1)
        # Apply dropout
        drop1 = tf.nn.dropout(relu1, keep_prob=keep_prob)
        # 14 * 14 * 128

        # Layer 2
        x2 = tf.layers.conv2d(drop1, 256, 5, strides=2, padding='same',
                              kernel_initializer=tf.contrib.layers.xavier_initializer(seed=2))
        # Batch norm from here on
        bn2 = tf.layers.batch_normalization(x2, training=True)
        relu2 = tf.maximum(alpha * bn2, bn2)
        drop2 = tf.nn.dropout(relu2, keep_prob=keep_prob)
        # 7 * 7 * 256

        # Layer 3
        x3 = tf.layers.conv2d(drop2, 512, 5, strides=2, padding='same',
                              kernel_initializer=tf.contrib.layers.xavier_initializer(seed=2))
        bn3 = tf.layers.batch_normalization(x3, training=True)
        relu3 = tf.maximum(alpha * bn3, bn3)
        drop3 = tf.nn.dropout(relu3, keep_prob=keep_prob)
        # 4 * 4 * 512

        # Output: flatten the (dropout) features, then a single logit
        flatten = tf.reshape(drop3, (-1, 4 * 4 * 512))
        logits = tf.layers.dense(flatten, 1)
        # Sigmoid activation turns the logit into a probability
        out = tf.nn.sigmoid(logits)

    return out, logits

Next, the Generator

def generator(z, out_channel_dim, is_train=True):
    """
    Create the generator network
    :param z: Input z
    :param out_channel_dim: The number of channels in the output image
    :param is_train: Boolean if generator is being used for training
    :return: The tensor output of the generator
    """
    with tf.variable_scope('generator', reuse=not is_train):
        # First fully connected layer
        x0 = tf.layers.dense(z, 4 * 4 * 512)
        # Reshape to a stack of 4x4 feature maps
        x0 = tf.reshape(x0, (-1, 4, 4, 512))
        # Batch norm
        bn0 = tf.layers.batch_normalization(x0, training=is_train)
        # Plain ReLU in the generator
        relu0 = tf.nn.relu(bn0)
        # 4 * 4 * 512

        # Transposed convolutions upsample back to image size
        x1 = tf.layers.conv2d_transpose(relu0, 256, 4, strides=1, padding='valid')
        bn1 = tf.layers.batch_normalization(x1, training=is_train)
        relu1 = tf.nn.relu(bn1)
        # 7 * 7 * 256

        x2 = tf.layers.conv2d_transpose(relu1, 128, 3, strides=2, padding='same')
        bn2 = tf.layers.batch_normalization(x2, training=is_train)
        relu2 = tf.nn.relu(bn2)
        # 14 * 14 * 128

        # Last transposed convolution, no batch norm, tanh output in [-1, 1]
        logits = tf.layers.conv2d_transpose(relu2, out_channel_dim, 3, strides=2, padding='same')
        out = tf.tanh(logits)

        return out
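A sanity-check sketch (again illustrative, not from the original post): chain the generator and discriminator together and verify the shapes match the comments above, i.e. a 28x28 output for MNIST and a single logit per image.

import tensorflow as tf

tf.reset_default_graph()
z = tf.placeholder(tf.float32, (None, 100), name='z')
fake_images = generator(z, out_channel_dim=1, is_train=True)
d_out, d_logits = discriminator(fake_images, reuse=False)
print(fake_images.get_shape())  # (?, 28, 28, 1)
print(d_logits.get_shape())     # (?, 1)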

Then we define the losses. Here, one-sided label smoothing is applied: the real labels are softened from 1 to 1 - smooth so the discriminator does not become overconfident.

def model_loss(input_real, input_z, out_channel_dim):
    """
    Get the loss for the discriminator and generator
    :param input_real: Images from the real dataset
    :param input_z: Z input
    :param out_channel_dim: The number of channels in the output image
    :return: A tuple of (discriminator loss, generator loss)
    """
    g_model = generator(input_z, out_channel_dim, is_train=True)

    d_model_real, d_logits_real = discriminator(input_real, reuse=False)
    d_model_fake, d_logits_fake = discriminator(g_model, reuse=True)

    # One-sided label smoothing: soften the real labels from 1 to 0.9
    smooth = 0.1
    d_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,
                                                labels=tf.ones_like(d_model_real) * (1 - smooth)))

    d_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.zeros_like(d_model_fake)))

    # The generator is rewarded when the discriminator outputs 1 on fakes
    g_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                labels=tf.ones_like(d_model_fake)))

    d_loss = d_loss_real + d_loss_fake

    return d_loss, g_loss

Next we define the optimization step. Because both networks use batch normalization, the moving-average update ops collected in tf.GraphKeys.UPDATE_OPS have to run alongside every training step, hence the tf.control_dependencies block below; check the batch normalization docs if this is unfamiliar.

def model_opt(d_loss, g_loss, learning_rate, beta1):
    """
    Get optimization operations
    :param d_loss: Discriminator loss Tensor
    :param g_loss: Generator loss Tensor
    :param learning_rate: Learning Rate Placeholder
    :param beta1: The exponential decay rate for the 1st moment in the optimizer
    :return: A tuple of (discriminator training operation, generator training operation)
    """
    # Split the trainable variables by the scopes defined above, so each
    # optimizer only updates its own network
    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if var.name.startswith('discriminator')]
    g_vars = [var for var in t_vars if var.name.startswith('generator')]

    # Make sure the batch norm statistics update with each training step
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        d_train_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(d_loss, var_list=d_vars)
        g_train_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(g_loss, var_list=g_vars)

    return d_train_opt, g_train_opt

Now that the network modules, the loss functions, and the optimization step are all defined, we can start training the network. The training procedure is defined as follows.

def train(epoch_count, batch_size, z_dim, learning_rate, beta1, get_batches, data_shape, data_image_mode):
    """
    Train the GAN
    :param epoch_count: Number of epochs
    :param batch_size: Batch Size
    :param z_dim: Z dimension
    :param learning_rate: Learning Rate
    :param beta1: The exponential decay rate for the 1st moment in the optimizer
    :param get_batches: Function to get batches
    :param data_shape: Shape of the data
    :param data_image_mode: The image mode to use for images ("RGB" or "L")
    """
    losses = []

    input_real, input_z, lr = model_inputs(data_shape[1], data_shape[2], data_shape[3], z_dim)
    d_loss, g_loss = model_loss(input_real, input_z, data_shape[-1])
    d_opt, g_opt = model_opt(d_loss, g_loss, learning_rate, beta1)

    steps = 0

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch_i in range(epoch_count):
            for batch_images in get_batches(batch_size):
                steps += 1

                # Reshape the images before passing them to the discriminator
                batch_images = batch_images.reshape(batch_size,
                                                    data_shape[1],
                                                    data_shape[2],
                                                    data_shape[3])
                # get_batches (from the helper module) yields values in
                # [-0.5, 0.5]; rescale to [-1, 1] to match the tanh output
                batch_images = batch_images * 2

                # Sample the noise
                batch_z = np.random.uniform(-1, 1, size=(batch_size, z_dim))

                # Run the optimizers: discriminator step, then generator step
                _ = sess.run(d_opt, feed_dict={input_real: batch_images,
                                               input_z: batch_z,
                                               lr: learning_rate})
                _ = sess.run(g_opt, feed_dict={input_real: batch_images,
                                               input_z: batch_z,
                                               lr: learning_rate})

                if steps % 10 == 0:
                    train_loss_d = d_loss.eval({input_real: batch_images, input_z: batch_z})
                    train_loss_g = g_loss.eval({input_real: batch_images, input_z: batch_z})
                    losses.append((train_loss_d, train_loss_g))

                    print("Epoch {}/{}...".format(epoch_i+1, epoch_count),
                          "Discriminator Loss: {:.4f}...".format(train_loss_d),
                          "Generator Loss: {:.4f}".format(train_loss_g))

                if steps % 100 == 0:
                    show_generator_output(sess, 25, input_z, data_shape[-1], data_image_mode)
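show_generator_output is called in the loop above but never defined in the post; it also comes from the Udacity project. A minimal sketch, reusing the images_square_grid function defined earlier (treat the details as assumptions):

def show_generator_output(sess, n_images, input_z, out_channel_dim, image_mode):
    """Sample n_images from the generator and display them as a grid."""
    cmap = None if image_mode == 'RGB' else 'gray'
    z_dim = input_z.get_shape().as_list()[-1]
    example_z = np.random.uniform(-1, 1, size=[n_images, z_dim])

    # Rebuild the generator with is_train=False so it reuses the trained weights
    samples = sess.run(
        generator(input_z, out_channel_dim, is_train=False),
        feed_dict={input_z: example_z})

    images_grid = images_square_grid(samples, image_mode)
    pyplot.imshow(images_grid, cmap=cmap)
    pyplot.show()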

Starting training: hyperparameter settings

For MNIST:

import helper  # the Udacity face-generation project's helper module (provides Dataset)

batch_size = 64
z_dim = 100
learning_rate = 0.001
beta1 = 0.5
epochs = 2

mnist_dataset = helper.Dataset('mnist', glob(os.path.join(data_dir, 'mnist/*.jpg')))
with tf.Graph().as_default():
    train(epochs, batch_size, z_dim, learning_rate, beta1, mnist_dataset.get_batches,
          mnist_dataset.shape, mnist_dataset.image_mode)

The training progress looks as follows.

At the start the network's parameters are poor, so the generated handwritten digits naturally look bad.

As training proceeds the outlines become progressively clearer; by the end:

the digit outlines are basically clear and recognizable. Of course, this is only the result of two epochs; with enough time to train longer, the results would be better still.

We likewise show the training results on the celeba face dataset.

batch_size = 32
z_dim = 100
learning_rate = 0.001
beta1 = 0.4
epochs = 1

celeba_dataset = helper.Dataset('celeba', glob(os.path.join(data_dir, 'img_align_celeba/*.jpg')))
with tf.Graph().as_default():
    train(epochs, batch_size, z_dim, learning_rate, beta1, celeba_dataset.get_batches,
          celeba_dataset.shape, celeba_dataset.image_mode)

At the start of training:

After one epoch:

the outlines of the faces are basically clear.

 

Here we implemented a DCGAN in its simplest form. The theory was not covered in great detail, the hyperparameter settings are probably not optimal, and the training was not run for long enough, but I hope this helps you quickly get a handle on implementing a simple DCGAN.

 

