A Simple Walkthrough of the DCGAN Code


    In my earlier post, a simple walkthrough of the DCGAN paper, I explained the principles behind DCGAN. This time we implement a DCGAN and test how well it actually works on real datasets. The code comes from the open-source GitHub project DCGAN-tensorflow; thanks to carpedm20 for the contribution!

1. Code Structure

    The code structure is shown in Figure 1 below:

Figure 1: Code structure

The files we will focus on are download.py, main.py, model.py, ops.py, and utils.py. The file names alone give a good hint of what each one does.

  • download.py downloads the datasets to the local machine; here we need three of them: MNIST, LSUN, and celebA.
  • main.py is the entry point; it configures the command-line flags and drives model training and testing.
  • model.py is where the DCGAN model is defined, and the code we will pay the most attention to.
  • ops.py defines many important model-building functions, such as batch_norm (batch normalization), conv2d (convolution), and deconv2d (transposed convolution).
  • utils.py defines many useful global helper functions.

2. Code Walkthrough

2.1 download.py

The code of download.py is as follows:

"""
Modification of https://github.com/stanfordnlp/treelstm/blob/master/scripts/download.py

Downloads the following:
- Celeb-A dataset
- LSUN dataset
- MNIST dataset
"""

from __future__ import print_function
import os
import sys
import gzip
import json
import shutil
import zipfile
import argparse
import requests
import subprocess
from tqdm import tqdm
from six.moves import urllib

parser = argparse.ArgumentParser(description='Download dataset for DCGAN.')
parser.add_argument('datasets', metavar='N', type=str, nargs='+', choices=['celebA', 'lsun', 'mnist'],
           help='name of dataset to download [celebA, lsun, mnist]')

def download(url, dirpath):
  filename = url.split('/')[-1]
  filepath = os.path.join(dirpath, filename)
  u = urllib.request.urlopen(url)
  f = open(filepath, 'wb')
  filesize = int(u.headers["Content-Length"])
  print("Downloading: %s Bytes: %s" % (filename, filesize))

  downloaded = 0
  block_sz = 8192
  status_width = 70
  while True:
    buf = u.read(block_sz)
    if not buf:
      print('')
      break
    else:
      print('', end='\r')
    downloaded += len(buf)
    f.write(buf)
    status = (("[%-" + str(status_width + 1) + "s] %3.2f%%") %
      ('=' * int(float(downloaded) / filesize * status_width) + '>', downloaded * 100. / filesize))
    print(status, end='')
    sys.stdout.flush()
  f.close()
  return filepath

def download_file_from_google_drive(id, destination):
  URL = "https://docs.google.com/uc?export=download"
  session = requests.Session()

  response = session.get(URL, params={ 'id': id }, stream=True)
  token = get_confirm_token(response)

  if token:
    params = { 'id' : id, 'confirm' : token }
    response = session.get(URL, params=params, stream=True)

  save_response_content(response, destination)

def get_confirm_token(response):
  for key, value in response.cookies.items():
    if key.startswith('download_warning'):
      return value
  return None

def save_response_content(response, destination, chunk_size=32*1024):
  total_size = int(response.headers.get('content-length', 0))
  with open(destination, "wb") as f:
    # show a progress bar
    for chunk in tqdm(response.iter_content(chunk_size), total=total_size,
              unit='B', unit_scale=True, desc=destination):
      if chunk: # filter out keep-alive new chunks
        f.write(chunk)

def unzip(filepath):
  print("Extracting: " + filepath)
  dirpath = os.path.dirname(filepath)
  with zipfile.ZipFile(filepath) as zf:
    zf.extractall(dirpath)
  os.remove(filepath)

def download_celeb_a(dirpath):
  data_dir = 'celebA'
  # ./data/celebA
  if os.path.exists(os.path.join(dirpath, data_dir)):
    print('Found Celeb-A - skip')
    return

  filename, drive_id  = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM"
  # ./data/img_align_celeba.zip
  save_path = os.path.join(dirpath, filename)
  if os.path.exists(save_path):
    print('[*] {} already exists'.format(save_path)) # file already exists
  else:
    download_file_from_google_drive(drive_id, save_path)

  zip_dir = ''
  with zipfile.ZipFile(save_path) as zf:
    zip_dir = zf.namelist()[0] # name of the top-level folder inside the archive
    zf.extractall(dirpath) # extract the files into that folder
  os.remove(save_path) # remove the zip file
  # rename the extracted folder
  os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir))

def _list_categories(tag):
  url = 'http://lsun.cs.princeton.edu/htbin/list.cgi?tag=' + tag
  f = urllib.request.urlopen(url)
  return json.loads(f.read())

def _download_lsun(out_dir, category, set_name, tag):
  # locals(),Return a dictionary containing the current scope's local variables
  url = 'http://lsun.cs.princeton.edu/htbin/download.cgi?tag={tag}' \
      '&category={category}&set={set_name}'.format(**locals())
  print(url)
  if set_name == 'test':
    out_name = 'test_lmdb.zip'
  else:
    out_name = '{category}_{set_name}_lmdb.zip'.format(**locals())
  # out_path:./data/lsun/xxx.zip
  out_path = os.path.join(out_dir, out_name)
  cmd = ['curl', url, '-o', out_path]
  print('Downloading', category, set_name, 'set')
  # run the curl command as a subprocess
  subprocess.call(cmd)

def download_lsun(dirpath):
  data_dir = os.path.join(dirpath, 'lsun')
  if os.path.exists(data_dir):
    print('Found LSUN - skip')
    return
  else:
    os.mkdir(data_dir)

  tag = 'latest'
  #categories = _list_categories(tag)
  categories = ['bedroom']

  for category in categories:
    _download_lsun(data_dir, category, 'train', tag)
    _download_lsun(data_dir, category, 'val', tag)
  _download_lsun(data_dir, '', 'test', tag)

def download_mnist(dirpath):
  data_dir = os.path.join(dirpath, 'mnist')
  if os.path.exists(data_dir):
    print('Found MNIST - skip')
    return
  else:
    os.mkdir(data_dir)
  url_base = 'http://yann.lecun.com/exdb/mnist/'
  file_names = ['train-images-idx3-ubyte.gz',
                'train-labels-idx1-ubyte.gz',
                't10k-images-idx3-ubyte.gz',
                't10k-labels-idx1-ubyte.gz']
  for file_name in file_names:
    url = (url_base+file_name).format(**locals())
    print(url)
    out_path = os.path.join(data_dir,file_name)
    cmd = ['curl', url, '-o', out_path]
    print('Downloading ', file_name)
    subprocess.call(cmd)
    cmd = ['gzip', '-d', out_path]
    print('Decompressing ', file_name)
    subprocess.call(cmd)

def prepare_data_dir(path = './data'):
  if not os.path.exists(path):
    os.mkdir(path)

if __name__ == '__main__':
  args = parser.parse_args()
  prepare_data_dir()

  # if the datasets argument contains one of the celebA spellings
  if any(name in args.datasets for name in ['CelebA', 'celebA']):
    download_celeb_a('./data')
  if 'lsun' in args.datasets:
    download_lsun('./data')
  if 'mnist' in args.datasets:
    download_mnist('./data')

  • Among the imports, gzip and zipfile handle file compression and decompression; argparse builds the command-line interface; requests downloads files over HTTP; subprocess runs shell commands; tqdm displays progress bars; and the six package smooths over Python 2/3 differences, e.g. from six.moves import urllib imports the Python 2.x-style urllib library.
  • Besides the original author's comments, I have added some of my own, so the code should be fairly easy to follow. What it does is use the requests library (and curl) to download the mnist, lsun, and celebA datasets and save them under the data directory (a typical invocation is shown right after this list). Note that the mnist and celebA datasets are also decompressed after downloading.
  • Of the three datasets, mnist is the famous handwritten-digit database everyone is familiar with; lsun is the large-scale scene understanding dataset; celebA is an open face dataset. Except for mnist, the other two are fairly large: celebA has roughly 200k+ images and a 1.4 GB archive, while lsun consists of many scene-specific subsets. The script above downloads the bedroom subset, whose archive is a hefty 46 GB, and what you get after extraction is an LMDB database (*.mdb files) rather than raw images, which is inconvenient to process. In practice we will substitute another dataset instead, for example this room layout estimation set (about 2 GB). If downloading through download.py is slow, you can also fetch the datasets yourself and simply place them under the data directory.
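
For example, to fetch just mnist and celebA with the script (the dataset names are the positional arguments defined by argparse above):

python download.py mnist celebA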

2.2 main.py

    The code of main.py is as follows:

import os
import scipy.misc
import numpy as np

from model import DCGAN
from utils import pp, visualize, to_json, show_all_variables

import tensorflow as tf

# define command-line flags with tensorflow
flags = tf.app.flags
# flag_name, default_value, docstring
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")
flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
flags.DEFINE_float("train_size", np.inf, "The size of train images [np.inf]")
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
flags.DEFINE_integer("input_height", 108, "The size of image to use (will be center cropped). [108]")
flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]")
flags.DEFINE_integer("output_height", 64, "The size of the output images to produce [64]")
flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]")
flags.DEFINE_integer("print_every",100,"print train info every 100 iterations")
flags.DEFINE_integer("checkpoint_every",500,"save checkpoint file every 500 iterations")
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]")
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("data_dir", "./data", "Root directory of dataset [data]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
flags.DEFINE_integer("generate_test_images", 100, "Number of images to generate during test. [100]")
FLAGS = flags.FLAGS

def main(_):
  pp.pprint(flags.FLAGS.__flags)

  # if the width is not specified, make it equal to the height
  if FLAGS.input_width is None:
    FLAGS.input_width = FLAGS.input_height
  if FLAGS.output_width is None:
    FLAGS.output_width = FLAGS.output_height

  if not os.path.exists(FLAGS.checkpoint_dir):
    os.makedirs(FLAGS.checkpoint_dir)
  if not os.path.exists(FLAGS.sample_dir):
    os.makedirs(FLAGS.sample_dir)

  #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
  run_config = tf.ConfigProto()
  run_config.gpu_options.allow_growth=True

  with tf.Session(config=run_config) as sess:
    if FLAGS.dataset == 'mnist':
      dcgan = DCGAN(
          sess,
          input_width=FLAGS.input_width,
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.batch_size,
          y_dim=10,
          z_dim=FLAGS.generate_test_images,
          dataset_name=FLAGS.dataset,
          input_fname_pattern=FLAGS.input_fname_pattern,
          crop=FLAGS.crop,
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir,
          data_dir=FLAGS.data_dir)
    else:
      dcgan = DCGAN(
          sess,
          input_width=FLAGS.input_width,
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.batch_size,
          z_dim=FLAGS.generate_test_images,
          dataset_name=FLAGS.dataset,
          input_fname_pattern=FLAGS.input_fname_pattern,
          crop=FLAGS.crop,
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir,
          data_dir=FLAGS.data_dir)

    show_all_variables()

    if FLAGS.train:
      dcgan.train(FLAGS)
    else:
        # dcgan.load returns (True, counter) on success
      if not dcgan.load(FLAGS.checkpoint_dir)[0]: # failed to load a checkpoint file
        raise Exception("[!] Train a model first, then run test mode")


    # to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],
    #                 [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
    #                 [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],
    #                 [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],
    #                 [dcgan.h4_w, dcgan.h4_b, None])

    # Below is code for visualization
    OPTION = 4
    visualize(sess, dcgan, FLAGS, OPTION)

if __name__ == '__main__':
  tf.app.run()

  • Note that flags = tf.app.flags is TensorFlow's mechanism for building command-line flags, and flags.DEFINE_xxx(name, default, docstring) defines one flag: the first argument is the flag name, the second its default value, and the third a description string. (A minimal sketch of this mechanism follows this list.)
  • After creating the session we have to distinguish between mnist and the other datasets. mnist is special because it contains digit images of 10 classes, so when constructing the DCGAN we pass an extra y_dim=10 argument. The show_all_variables function prints detailed information about all of the model's variables.
  • Next, if we are in training mode (FLAGS.train == True), the model is trained (dcgan.train(FLAGS)); otherwise we test: the checkpoint file saved during training is loaded, and the visualize function runs the test (it can generate images or gifs to visualize how well training went).
  • tf.app.run() is the usual entry point for running a TensorFlow program.
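
As a minimal, self-contained sketch of the tf.app.flags mechanism used above (TF 1.x; the single epoch flag here is just an example):

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")  # name, default, docstring
FLAGS = flags.FLAGS

def main(_):
  # FLAGS.epoch is 25 unless overridden with --epoch on the command line
  print(FLAGS.epoch)

if __name__ == '__main__':
  tf.app.run()  # parses the flags, then calls main()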

2.3 model.py

    The code of model.py is as follows:

from __future__ import division
import os
import time
import math
from glob import glob
import tensorflow as tf
import numpy as np
from six.moves import xrange

from ops import *
from utils import *

def conv_out_size_same(size, stride):
  return int(math.ceil(float(size) / float(stride)))
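
# note: with stride 2 this walks 64 -> 32 -> 16 -> 8 -> 4, exactly the feature-map
# sizes that the generator and sampler below climb back up through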

class DCGAN(object):
  def __init__(self, sess, input_height=108, input_width=108, crop=True,
         batch_size=64, sample_num = 64, output_height=64, output_width=64,
         y_dim=None, z_dim=100, gf_dim=64, df_dim=64,
         gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default',
         input_fname_pattern='*.jpg', checkpoint_dir=None, sample_dir=None, data_dir='data'):
    """

    Args:
      sess: TensorFlow session
      batch_size: The size of batch. Should be specified before training.
      y_dim: (optional) Dimension of dim for y. [None]
      z_dim: (optional) Dimension of dim for Z. [100]
      # number of filters in the generator's first conv layer
      gf_dim: (optional) Dimension of gen filters in first conv layer. [64]
      # number of filters in the discriminator's first conv layer
      df_dim: (optional) Dimension of discrim filters in first conv layer. [64]
      # units in the generator's fully connected layer
      gfc_dim: (optional) Dimension of gen units for fully connected layer. [1024]
      # units in the discriminator's fully connected layer
      dfc_dim: (optional) Dimension of discrim units for fully connected layer. [1024]
      # image channel
      c_dim: (optional) Dimension of image color. For grayscale input, set to 1. [3]
    """
    self.sess = sess
    self.crop = crop

    self.batch_size = batch_size
    self.sample_num = sample_num

    self.input_height = input_height
    self.input_width = input_width
    self.output_height = output_height
    self.output_width = output_width

    self.y_dim = y_dim
    self.z_dim = z_dim

    self.gf_dim = gf_dim
    self.df_dim = df_dim

    self.gfc_dim = gfc_dim
    self.dfc_dim = dfc_dim

    # batch normalization : deals with poor initialization helps gradient flow
    self.d_bn1 = batch_norm(name='d_bn1')
    self.d_bn2 = batch_norm(name='d_bn2')

    if not self.y_dim:
      self.d_bn3 = batch_norm(name='d_bn3')

    self.g_bn0 = batch_norm(name='g_bn0')
    self.g_bn1 = batch_norm(name='g_bn1')
    self.g_bn2 = batch_norm(name='g_bn2')

    if not self.y_dim:
      self.g_bn3 = batch_norm(name='g_bn3')

    self.dataset_name = dataset_name
    self.input_fname_pattern = input_fname_pattern
    self.checkpoint_dir = checkpoint_dir
    self.data_dir = data_dir

    if self.dataset_name == 'mnist':
      self.data_X, self.data_y = self.load_mnist()
      self.c_dim = self.data_X[0].shape[-1]
    else:
      # dir *.jpg
      self.data = glob(os.path.join(self.data_dir, self.dataset_name, self.input_fname_pattern))
      imreadImg = imread(self.data[0])
      if len(imreadImg.shape) >= 3: #check if image is a non-grayscale image by checking channel number
        self.c_dim = imread(self.data[0]).shape[-1] # color image,get image channel
      else:
        self.c_dim = 1

    self.grayscale = (self.c_dim == 1) # whether the images are grayscale

    self.build_model()

  def build_model(self):
    if self.y_dim:
      self.y = tf.placeholder(tf.float32, [self.batch_size, self.y_dim], name='y')
    else:
      self.y = None

    if self.crop:
      image_dims = [self.output_height, self.output_width, self.c_dim]
    else:
      image_dims = [self.input_height, self.input_width, self.c_dim]

    # self.inputs shape:(batch_size,height,width,channel)
    self.inputs = tf.placeholder(
      tf.float32, [self.batch_size] + image_dims, name='real_images')

    inputs = self.inputs

    self.z = tf.placeholder(
      tf.float32, [None, self.z_dim], name='z')
    # histogram summary for TensorBoard
    self.z_sum = histogram_summary("z", self.z)

    self.G                  = self.generator(self.z, self.y)
    self.D, self.D_logits   = self.discriminator(inputs, self.y, reuse=False)
    self.sampler            = self.sampler(self.z, self.y)
    self.D_, self.D_logits_ = self.discriminator(self.G, self.y, reuse=True)
    
    self.d_sum = histogram_summary("d", self.D)
    self.d__sum = histogram_summary("d_", self.D_)
    self.G_sum = image_summary("G", self.G)

    def sigmoid_cross_entropy_with_logits(x, y):
      try:
        return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y)
      except:
        return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, targets=y)

    self.d_loss_real = tf.reduce_mean(
      sigmoid_cross_entropy_with_logits(self.D_logits, tf.ones_like(self.D)))
    self.d_loss_fake = tf.reduce_mean(
      sigmoid_cross_entropy_with_logits(self.D_logits_, tf.zeros_like(self.D_)))
    self.g_loss = tf.reduce_mean(
      sigmoid_cross_entropy_with_logits(self.D_logits_, tf.ones_like(self.D_)))

    # scalar_summary: outputs a `Summary` protocol buffer containing a single scalar value
    self.d_loss_real_sum = scalar_summary("d_loss_real", self.d_loss_real)
    self.d_loss_fake_sum = scalar_summary("d_loss_fake", self.d_loss_fake)
                          
    self.d_loss = self.d_loss_real + self.d_loss_fake

    self.g_loss_sum = scalar_summary("g_loss", self.g_loss)
    self.d_loss_sum = scalar_summary("d_loss", self.d_loss)

    t_vars = tf.trainable_variables()

    self.d_vars = [var for var in t_vars if 'd_' in var.name] # discriminator variables
    self.g_vars = [var for var in t_vars if 'g_' in var.name] # generator variables

    self.saver = tf.train.Saver()

  def train(self, config):
    d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
              .minimize(self.d_loss, var_list=self.d_vars)
    g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
              .minimize(self.g_loss, var_list=self.g_vars)
    try:
      tf.global_variables_initializer().run()
    except:
      tf.initialize_all_variables().run()

    self.g_sum = merge_summary([self.z_sum, self.d__sum,
      self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])
    self.d_sum = merge_summary(
        [self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
    self.writer = SummaryWriter("./logs", self.sess.graph)

    sample_z = np.random.uniform(-1, 1, size=(self.sample_num , self.z_dim))
    
    if config.dataset == 'mnist':
      sample_inputs = self.data_X[0:self.sample_num]
      sample_labels = self.data_y[0:self.sample_num]
    else:
      # self.data is like:["0.jpg","1.jpg",...]
      sample_files = self.data[0:self.sample_num]
      sample = [
          # get_image returns an ndarray of shape (resize_height,resize_width)
          # with values in (-1,1)
          get_image(sample_file,
                    input_height=self.input_height,
                    input_width=self.input_width,
                    resize_height=self.output_height,
                    resize_width=self.output_width,
                    crop=self.crop,
                    grayscale=self.grayscale) for sample_file in sample_files]
      if (self.grayscale):
        # grayscale images have a single channel
        sample_inputs = np.array(sample).astype(np.float32)[:, :, :, None]
      else:
        # color image
        sample_inputs = np.array(sample).astype(np.float32)
  
    counter = 1
    start_time = time.time()
    could_load, checkpoint_counter = self.load(self.checkpoint_dir)
    if could_load:
      counter = checkpoint_counter
      print(" [*] Load SUCCESS")
    else:
      print(" [!] Load failed...")

    for epoch in xrange(config.epoch):
      if config.dataset == 'mnist':
        batch_idxs = min(len(self.data_X), config.train_size) // config.batch_size
      else:
        # self.data is like:["0.jpg","1.jpg",...]
        self.data = glob(os.path.join(
          config.data_dir, config.dataset, self.input_fname_pattern))
        batch_idxs = min(len(self.data), config.train_size) // config.batch_size

      for idx in xrange(0, batch_idxs):
        if config.dataset == 'mnist':
          batch_images = self.data_X[idx*config.batch_size:(idx+1)*config.batch_size]
          batch_labels = self.data_y[idx*config.batch_size:(idx+1)*config.batch_size]
        else:
          batch_files = self.data[idx*config.batch_size:(idx+1)*config.batch_size]
          batch = [
              get_image(batch_file,
                        input_height=self.input_height,
                        input_width=self.input_width,
                        resize_height=self.output_height,
                        resize_width=self.output_width,
                        crop=self.crop,
                        grayscale=self.grayscale) for batch_file in batch_files]
          if self.grayscale:
            # add a channel for grayscale
            # batch_images shape:(batch,height,width,channel)
            batch_images = np.array(batch).astype(np.float32)[:, :, :, None]
          else:
            batch_images = np.array(batch).astype(np.float32)
        # sample the latent noise z
        batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]) \
              .astype(np.float32)

        if config.dataset == 'mnist':
          # Update D network
          _, summary_str = self.sess.run([d_optim, self.d_sum],
            feed_dict={ 
              self.inputs: batch_images,
              self.z: batch_z,
              self.y:batch_labels,
            })
          # write the summary for TensorBoard visualization
          self.writer.add_summary(summary_str, counter)

          # Update G network
          _, summary_str = self.sess.run([g_optim, self.g_sum],
            feed_dict={
              self.z: batch_z, 
              self.y:batch_labels,
            })
          self.writer.add_summary(summary_str, counter)

          # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
          _, summary_str = self.sess.run([g_optim, self.g_sum],
            feed_dict={ self.z: batch_z, self.y:batch_labels })
          self.writer.add_summary(summary_str, counter)
          
          errD_fake = self.d_loss_fake.eval({
              self.z: batch_z, 
              self.y:batch_labels
          })
          errD_real = self.d_loss_real.eval({
              self.inputs: batch_images,
              self.y:batch_labels
          })
          errG = self.g_loss.eval({
              self.z: batch_z,
              self.y: batch_labels
          })
        else:
          # Update D network
          _, summary_str = self.sess.run([d_optim, self.d_sum],
            feed_dict={ self.inputs: batch_images, self.z: batch_z })
          self.writer.add_summary(summary_str, counter)

          # Update G network
          _, summary_str = self.sess.run([g_optim, self.g_sum],
            feed_dict={ self.z: batch_z })
          self.writer.add_summary(summary_str, counter)

          # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
          _, summary_str = self.sess.run([g_optim, self.g_sum],
            feed_dict={ self.z: batch_z })
          self.writer.add_summary(summary_str, counter)
          
          errD_fake = self.d_loss_fake.eval({ self.z: batch_z })
          errD_real = self.d_loss_real.eval({ self.inputs: batch_images })
          errG = self.g_loss.eval({self.z: batch_z})

        counter += 1
        print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
          % (epoch, config.epoch, idx, batch_idxs,
            time.time() - start_time, errD_fake+errD_real, errG))
        # np.mod: return element-wise remainder of division
        # generate samples every print_every iterations
        if np.mod(counter, config.print_every) == 1:
          if config.dataset == 'mnist':
            samples, d_loss, g_loss = self.sess.run(
              [self.sampler, self.d_loss, self.g_loss],
              feed_dict={
                  self.z: sample_z,
                  self.inputs: sample_inputs,
                  self.y:sample_labels,
              }
            )
            # save the generated samples
            save_images(samples, image_manifold_size(samples.shape[0]),
                  './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx))
            print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss)) 
          else:
            try:
              samples, d_loss, g_loss = self.sess.run(
                [self.sampler, self.d_loss, self.g_loss],
                feed_dict={
                    self.z: sample_z,
                    self.inputs: sample_inputs,
                },
              )
              save_images(samples, image_manifold_size(samples.shape[0]),
                    './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx))
              print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss)) 
            except:
              print("one pic error!...")
        # save a checkpoint file every checkpoint_every iterations
        if np.mod(counter, config.checkpoint_every) == 2: # save checkpoint file
          self.save(config.checkpoint_dir, counter)

  def discriminator(self, image, y=None, reuse=False):
    with tf.variable_scope("discriminator") as scope:
      if reuse:
        scope.reuse_variables()

      if not self.y_dim:
        h0 = lrelu(conv2d(image, self.df_dim, name='d_h0_conv'))
        h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim*2, name='d_h1_conv')))
        h2 = lrelu(self.d_bn2(conv2d(h1, self.df_dim*4, name='d_h2_conv')))
        h3 = lrelu(self.d_bn3(conv2d(h2, self.df_dim*8, name='d_h3_conv')))
        h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h4_lin')

        return tf.nn.sigmoid(h4), h4
      else:
        yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
        x = conv_cond_concat(image, yb)

        h0 = lrelu(conv2d(x, self.c_dim + self.y_dim, name='d_h0_conv'))
        h0 = conv_cond_concat(h0, yb)

        h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim + self.y_dim, name='d_h1_conv')))
        h1 = tf.reshape(h1, [self.batch_size, -1])      
        h1 = concat([h1, y], 1)
        
        h2 = lrelu(self.d_bn2(linear(h1, self.dfc_dim, 'd_h2_lin')))
        h2 = concat([h2, y], 1)

        h3 = linear(h2, 1, 'd_h3_lin')
        
        return tf.nn.sigmoid(h3), h3

  def generator(self, z, y=None):
    with tf.variable_scope("generator") as scope:
      if not self.y_dim:
        s_h, s_w = self.output_height, self.output_width
        # 2 is stride
        s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
        s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
        s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
        s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)

        # project `z` and reshape
        self.z_, self.h0_w, self.h0_b = linear(
            z, self.gf_dim*8*s_h16*s_w16, 'g_h0_lin', with_w=True)

        self.h0 = tf.reshape(
            self.z_, [-1, s_h16, s_w16, self.gf_dim * 8])
        h0 = tf.nn.relu(self.g_bn0(self.h0))

        self.h1, self.h1_w, self.h1_b = deconv2d(
            h0, [self.batch_size, s_h8, s_w8, self.gf_dim*4], name='g_h1', with_w=True)
        h1 = tf.nn.relu(self.g_bn1(self.h1))

        h2, self.h2_w, self.h2_b = deconv2d(
            h1, [self.batch_size, s_h4, s_w4, self.gf_dim*2], name='g_h2', with_w=True)
        h2 = tf.nn.relu(self.g_bn2(h2))

        h3, self.h3_w, self.h3_b = deconv2d(
            h2, [self.batch_size, s_h2, s_w2, self.gf_dim*1], name='g_h3', with_w=True)
        h3 = tf.nn.relu(self.g_bn3(h3))

        h4, self.h4_w, self.h4_b = deconv2d(
            h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4', with_w=True)

        return tf.nn.tanh(h4)
      else:
        s_h, s_w = self.output_height, self.output_width
        s_h2, s_h4 = int(s_h/2), int(s_h/4)
        s_w2, s_w4 = int(s_w/2), int(s_w/4)

        # yb = tf.expand_dims(tf.expand_dims(y, 1),2)
        yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
        z = concat([z, y], 1)

        h0 = tf.nn.relu(
            self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin')))
        h0 = concat([h0, y], 1)

        h1 = tf.nn.relu(self.g_bn1(
            linear(h0, self.gf_dim*2*s_h4*s_w4, 'g_h1_lin')))
        h1 = tf.reshape(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2])

        h1 = conv_cond_concat(h1, yb)

        h2 = tf.nn.relu(self.g_bn2(deconv2d(h1,
            [self.batch_size, s_h2, s_w2, self.gf_dim * 2], name='g_h2')))
        h2 = conv_cond_concat(h2, yb)

        return tf.nn.sigmoid(
            deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3'))

  def sampler(self, z, y=None): # sampling for evaluation
    with tf.variable_scope("generator") as scope:
      scope.reuse_variables()

      if not self.y_dim: # unconditional version (same structure as generator)
        s_h, s_w = self.output_height, self.output_width
        s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
        s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
        s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
        s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)

        # project `z` and reshape
        h0 = tf.reshape(
            linear(z, self.gf_dim*8*s_h16*s_w16, 'g_h0_lin'),
            [-1, s_h16, s_w16, self.gf_dim * 8])
        h0 = tf.nn.relu(self.g_bn0(h0, train=False))

        h1 = deconv2d(h0, [self.batch_size, s_h8, s_w8, self.gf_dim*4], name='g_h1')
        h1 = tf.nn.relu(self.g_bn1(h1, train=False))

        h2 = deconv2d(h1, [self.batch_size, s_h4, s_w4, self.gf_dim*2], name='g_h2')
        h2 = tf.nn.relu(self.g_bn2(h2, train=False))

        h3 = deconv2d(h2, [self.batch_size, s_h2, s_w2, self.gf_dim*1], name='g_h3')
        h3 = tf.nn.relu(self.g_bn3(h3, train=False))

        h4 = deconv2d(h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4')

        return tf.nn.tanh(h4)
      else: # conditional (mnist) version
        s_h, s_w = self.output_height, self.output_width
        s_h2, s_h4 = int(s_h/2), int(s_h/4)
        s_w2, s_w4 = int(s_w/2), int(s_w/4)

        # yb = tf.reshape(y, [-1, 1, 1, self.y_dim])
        yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
        z = concat([z, y], 1)

        h0 = tf.nn.relu(self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin'), train=False))
        h0 = concat([h0, y], 1)

        h1 = tf.nn.relu(self.g_bn1(
            linear(h0, self.gf_dim*2*s_h4*s_w4, 'g_h1_lin'), train=False))
        h1 = tf.reshape(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2])
        h1 = conv_cond_concat(h1, yb)

        h2 = tf.nn.relu(self.g_bn2(
            deconv2d(h1, [self.batch_size, s_h2, s_w2, self.gf_dim * 2], name='g_h2'), train=False))
        h2 = conv_cond_concat(h2, yb)

        return tf.nn.sigmoid(deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3'))

  def load_mnist(self):
    data_dir = os.path.join(self.data_dir, self.dataset_name)
    
    fd = open(os.path.join(data_dir,'train-images-idx3-ubyte'))
    loaded = np.fromfile(file=fd,dtype=np.uint8)
    trX = loaded[16:].reshape((60000,28,28,1)).astype(np.float)

    fd = open(os.path.join(data_dir,'train-labels-idx1-ubyte'))
    loaded = np.fromfile(file=fd,dtype=np.uint8)
    trY = loaded[8:].reshape((60000)).astype(np.float)

    fd = open(os.path.join(data_dir,'t10k-images-idx3-ubyte'))
    loaded = np.fromfile(file=fd,dtype=np.uint8)
    teX = loaded[16:].reshape((10000,28,28,1)).astype(np.float)

    fd = open(os.path.join(data_dir,'t10k-labels-idx1-ubyte'))
    loaded = np.fromfile(file=fd,dtype=np.uint8)
    teY = loaded[8:].reshape((10000)).astype(np.float)

    trY = np.asarray(trY)
    teY = np.asarray(teY)
    
    X = np.concatenate((trX, teX), axis=0)
    y = np.concatenate((trY, teY), axis=0).astype(np.int)
    
    seed = 547
    np.random.seed(seed)
    np.random.shuffle(X)
    np.random.seed(seed)
    np.random.shuffle(y)
    
    y_vec = np.zeros((len(y), self.y_dim), dtype=np.float)
    for i, label in enumerate(y):
      y_vec[i,y[i]] = 1.0
    
    return X/255.,y_vec

  @property # can be accessed like an attribute
  def model_dir(self):
    return "{}_{}_{}_{}".format(
        self.dataset_name, self.batch_size,
        self.output_height, self.output_width)
      
  def save(self, checkpoint_dir, step):
    # save checkpoint files
    model_name = "DCGAN.model"
    checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

    if not os.path.exists(checkpoint_dir):
      os.makedirs(checkpoint_dir)

    self.saver.save(self.sess,
            os.path.join(checkpoint_dir, model_name),
            global_step=step)

  # load a checkpoint file
  def load(self, checkpoint_dir):
    import re
    print(" [*] Reading checkpoints...")
    checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
    # a CheckpointState if the state was available, None otherwise
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
      # basename:Returns the final component of a pathname
      ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
      self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
      counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0))
      print(" [*] Success to read {}".format(ckpt_name))
      return True, counter
    else:
      print(" [*] Failed to find a checkpoint")
      return False, 0

  • from __future__ import division takes effect under Python 2.x and makes division between two integers return a float (in Python 2 the default is integer division; in Python 3 it is float division). glob returns the list of file names in a folder that match a simple shell-style wildcard pattern.
  • Besides setting a pile of attributes, DCGAN's constructor has to check whether the dataset is mnist: mnist images are grayscale, so the channel count should be 1 (self.c_dim = 1), whereas for color images self.c_dim = 3 or self.c_dim = 4. After that comes build_model.
  • self.generator builds the generator; self.discriminator builds the discriminator; self.sampler draws random samples (used to generate sample images). Note that self.y is non-None only when the dataset is mnist; for other datasets, self.z alone is enough to generate samples.
  • sigmoid_cross_entropy_with_logits is redefined for compatibility across TensorFlow versions. The function first applies a sigmoid activation and then computes the cross-entropy loss.
  • self.g_loss is the generator loss; self.d_loss_real is the discriminator loss on real images; self.d_loss_fake is the loss on fake images (the fake images produced by the generator); self.d_loss is the total discriminator loss. (The formulas these implement are spelled out after this list.)
  • The histogram_summary and scalar_summary calls are there so that the losses can later be visualized in TensorBoard.
  • tf.trainable_variables() returns all trainable parameters of the model. Since we used distinct names when defining the generator and discriminator variables, we can split them by variable name into self.d_vars (discriminator variables) and self.g_vars (generator variables). self.saver = tf.train.Saver() saves the trained model parameters to checkpoints.
  • train is the core training function. The optimizer is kept consistent with the DCGAN paper: Adam with lr=0.0002 and beta1=0.5. The merge_summary function and SummaryWriter build the summaries shown in TensorBoard.
  • sample_z is drawn from the uniform distribution on [-1,1]. If the dataset is mnist, sample_inputs and sample_labels can be read directly. Otherwise each image has to be processed by hand: get_image returns an ndarray of shape (resize_height, resize_width) with values in (-1,1). If the images are grayscale we need to add one more dimension for the channel (channel=1); the corresponding code is sample_inputs = np.array(sample).astype(np.float32)[:, :, :, None].
  • The discriminator and generator are then updated via self.sess.run([d_optim,...]) and self.sess.run([g_optim,...]). self.writer.add_summary(summary_str, counter) adds the summary to the writer. For the same reason as before, mnist and the other datasets must be handled separately, so the optimization step needs an if and an else.
  • np.mod(counter, config.print_every) == 1 means samples are generated every print_every iterations; np.mod(counter, config.checkpoint_every) == 2 means a checkpoint file is saved every checkpoint_every iterations.
  • Next is the concrete implementation of the discriminator. The discriminator uses conv (convolution) operations, leaky-relu activations, and batch normalization on each layer; TensorFlow's batch normalization is implemented by tf.contrib.layers.batch_norm. If the dataset is not mnist, the first layer is conv2d + leaky-relu, the next three layers are all conv2d + BN + leaky-relu, and finally a linear layer with a single output unit feeds into a sigmoid. If the dataset is mnist, yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim]) first adds two dimensions to y so it can be concatenated with the image; this is really the conditional GAN idea. x = conv_cond_concat(image, yb) concatenates the condition with the image, and then h0 = lrelu(conv2d(x, self.c_dim + self.y_dim, name='d_h0_conv')) applies the convolution. The second stage is conv2d + leaky-relu + concat; the third is conv2d + BN + leaky-relu + reshape + concat; the fourth is linear + BN + leaky-relu + concat; the last is again linear + sigmoid.
  • Then the generator. Unlike the discriminator, the generator uses deconv (transposed convolution) and relu activations. Its structure is: 1. if the dataset is not mnist: linear + reshape + BN + relu ----> (deconv + BN + relu) x3 ----> deconv + tanh; 2. if it is mnist, besides the input z we also have to handle the label y, i.e. z and y are concatenated (conditional GAN); the structure is: reshape + concat ----> linear + BN + relu + concat ----> linear + BN + relu + reshape + concat ----> deconv + BN + relu + concat ----> deconv + sigmoid. Note that the final activation is not the usual tanh but a sigmoid (whose output maps directly into 0-1).
  • The sampler function draws samples from the generator as it is being trained, to inspect training progress. Its logic closely mirrors generator: it also distinguishes mnist from the rest, using a different structure for each. When the dataset is not mnist, y=None suffices; for mnist, y must be taken into account as well.
  • load_mnist loads the mnist dataset; save saves a checkpoint; load loads a checkpoint.
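
For reference, the three losses defined in build_model are the standard non-saturating GAN losses; writing out what the sigmoid cross-entropy with all-ones/all-zeros labels computes:

$$\mathcal{L}_D = -\mathbb{E}_{x\sim p_{\mathrm{data}}}\big[\log D(x)\big] - \mathbb{E}_{z\sim p_z}\big[\log(1 - D(G(z)))\big], \qquad \mathcal{L}_G = -\mathbb{E}_{z\sim p_z}\big[\log D(G(z))\big]$$

The first term of L_D is d_loss_real, the second is d_loss_fake, and L_G is the non-saturating generator loss (training G so that D outputs 1 on fake images) rather than the minimax form log(1 - D(G(z))).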

2.4 ops.py

    The code of ops.py is as follows:

import math
import numpy as np 
import tensorflow as tf

from tensorflow.python.framework import ops

from utils import *

try:
  image_summary = tf.image_summary
  scalar_summary = tf.scalar_summary
  histogram_summary = tf.histogram_summary
  merge_summary = tf.merge_summary
  SummaryWriter = tf.train.SummaryWriter
except:
  image_summary = tf.summary.image
  scalar_summary = tf.summary.scalar
  histogram_summary = tf.summary.histogram
  merge_summary = tf.summary.merge
  SummaryWriter = tf.summary.FileWriter

if "concat_v2" in dir(tf):
  def concat(tensors, axis, *args, **kwargs):
    return tf.concat_v2(tensors, axis, *args, **kwargs)
else:
  def concat(tensors, axis, *args, **kwargs):
    return tf.concat(tensors, axis, *args, **kwargs)

class batch_norm(object):
  def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"):
    with tf.variable_scope(name):
      self.epsilon  = epsilon
      self.momentum = momentum
      self.name = name

  # defines the class's __call__ method, so an instance can be called like a function
  def __call__(self, x, train=True):
    return tf.contrib.layers.batch_norm(x,
                      decay=self.momentum, 
                      updates_collections=None,
                      epsilon=self.epsilon,
                      scale=True,
                      is_training=train,
                      scope=self.name)

def conv_cond_concat(x, y):
  """Concatenate conditioning vector on feature map axis."""
  x_shapes = x.get_shape()
  y_shapes = y.get_shape()
  # concatenate along axis 3 (the last, i.e. channel, dimension)
  return concat([
    x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3)

def conv2d(input_, output_dim, 
       k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
       name="conv2d"):
  with tf.variable_scope(name):
    w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
              initializer=tf.truncated_normal_initializer(stddev=stddev))
    conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')

    biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
    conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())

    return conv

# a transposed-convolution operation, tf.nn.conv2d_transpose
def deconv2d(input_, output_shape,
       k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
       name="deconv2d", with_w=False):
  with tf.variable_scope(name):
    # filter : [height, width, output_channels, in_channels]
    w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
              initializer=tf.random_normal_initializer(stddev=stddev))
    
    try:
      deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,
                strides=[1, d_h, d_w, 1])

    # Support for verisons of TensorFlow before 0.7.0
    except AttributeError:
      deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape,
                strides=[1, d_h, d_w, 1])

    biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
    deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())

    if with_w:
      return deconv, w, biases
    else:
      return deconv

# leaky relu
def lrelu(x, leak=0.2, name="lrelu"):
  return tf.maximum(x, leak*x)

def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
  # essentially just a matmul (plus bias)
  shape = input_.get_shape().as_list()

  with tf.variable_scope(scope or "Linear"):
    matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32,
                 tf.random_normal_initializer(stddev=stddev))
    bias = tf.get_variable("bias", [output_size],
      initializer=tf.constant_initializer(bias_start))
    if with_w:
      return tf.matmul(input_, matrix) + bias, matrix, bias
    else:
      return tf.matmul(input_, matrix) + bias

  • Lines 9 to 20 keep compatibility between tf 0.x and tf 1.x: tf 0.x uses tf.xxx_summary-style functions while tf 1.x uses tf.summary.xxx-style functions, so for consistency everything is renamed to the tf.xxx_summary style.
  • Lines 22 to 27 redefine the concat function, again for compatibility: if "concat_v2" in dir(tf): means that if tf has a concat_v2 method (tf 0.x) it is used, while under tf 1.x the concat function is used.
  • Lines 29 to 44 define the batch_norm class. Note that lines 37-44 define the class's __call__ special method, which lets an instance be called like an ordinary function instead of constructing an object and then invoking a method on it; this is a common trick (see the small sketch after this list). Batch normalization in tf is the function tf.contrib.layers.batch_norm.
  • The conv_cond_concat function concatenates the conditioning vector (cond) onto the feature maps of a convolution output along the channel axis. It is used in the mnist generator and discriminator.
  • The conv2d function on lines 54 to 65 redefines the convolution operation, mostly wrapping tf.nn.conv2d.
  • Lines 68 to 91 define deconv2d (transposed convolution). In tf 0.x the function is tf.nn.deconv2d, in tf 1.x it is tf.nn.conv2d_transpose. A bias is added at the end (tf.nn.bias_add).
  • Lines 94 to 95 define the leaky-relu function lrelu, really a single line of code: tf.maximum(x, leak*x).
  • Lines 97 to 109 define the linear function, which is really just a fully_connected layer.
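
As a tiny illustration of the __call__ trick (the class name Scaler is made up for this example):

class Scaler(object):
  def __init__(self, k):
    self.k = k

  def __call__(self, x):
    # calling the instance executes this method
    return self.k * x

double = Scaler(2)
print(double(10))  # 20 -- the same pattern as bn = batch_norm(...); bn(x, train=True)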

2.5 utils.py

    The code of utils.py is as follows:

"""
Some codes from https://github.com/Newmu/dcgan_code
"""
from __future__ import division
from glob import glob
from os.path import join,basename,exists
from os import makedirs
import math
import json
import random
import pprint
import scipy.misc
import numpy as np
from time import gmtime, strftime
from six.moves import xrange

import tensorflow as tf
import tensorflow.contrib.slim as slim

pp = pprint.PrettyPrinter()

get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1])

def show_all_variables():
  model_vars = tf.trainable_variables()
  # Prints the names and shapes of the variables
  slim.model_analyzer.analyze_vars(model_vars, print_info=True)

def get_image(image_path, input_height, input_width,
              resize_height=64, resize_width=64,
              crop=True, grayscale=False):
  image = imread(image_path, grayscale)
  return transform(image, input_height, input_width,
                   resize_height, resize_width, crop)

def save_images(images, size, image_path):
  return imsave(inverse_transform(images), size, image_path)

def imread(path, grayscale = False):
  if (grayscale):
    return scipy.misc.imread(path, flatten = True).astype(np.float)
  else:
    return scipy.misc.imread(path).astype(np.float)

def merge_images(images, size):
  return inverse_transform(images)

def merge(images, size):
  # actual height and width of the small sample images
  h, w = images.shape[1], images.shape[2]
  # a valid channel count is 3 or 4
  if (images.shape[3] in (3,4)):
    c = images.shape[3]
    # img is the merged big image, tiled size[0] x size[1]
    img = np.zeros((h * size[0], w * size[1], c))
    # iterate over every small image
    for idx, image in enumerate(images):
      i = idx % size[1]
      j = idx // size[1]
      # paste the small images into the big one, row by row
      img[j * h:j * h + h, i * w:i * w + w, :] = image
    return img
  elif images.shape[3]==1:
    # drop channel
    img = np.zeros((h * size[0], w * size[1]))
    for idx, image in enumerate(images):
      i = idx % size[1]
      j = idx // size[1]
      img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0]
    return img
  else:
    raise ValueError('in merge(images,size) images parameter '
                     'must have dimensions: HxW or HxWx3 or HxWx4')

def imsave(images, size, path):
  '''
  modified imsave
  :param images: ndarray,shape:(batch,height,width,channel)
  :param size: (row images num,col images num)
  :param path: save path
  :return:
  '''
  # np.squeeze: remove axes of length one
  image = np.squeeze(merge(images, size))
  return scipy.misc.imsave(path, image)

def center_crop(x, crop_h, crop_w,
                resize_h=64, resize_w=64):
  '''
  Center-crop the image, then resize it
  :param x: image ndarray
  :param crop_h: crop height
  :param crop_w: crop width
  :param resize_h: resize height
  :param resize_w: resize width
  :return: resized image
  '''
  if crop_w is None:
    crop_w = crop_h
  h, w = x.shape[:2]
  j = int(round((h - crop_h)/2.))
  i = int(round((w - crop_w)/2.))
  return scipy.misc.imresize(
      x[j:j+crop_h, i:i+crop_w], [resize_h, resize_w])

def transform(image, input_height, input_width, 
              resize_height=64, resize_width=64, crop=True):
  '''
  Transform the image: optional center crop, resize, rescale to (-1,1)
  :param image: ndarray of image
  :param input_height: image height
  :param input_width:  image width
  :param resize_height: height after resize
  :param resize_width:  width after resize
  :param crop: whether to crop or not
  :return:
  '''
  if crop:
    cropped_image = center_crop(
      image, input_height, input_width, 
      resize_height, resize_width)
  else:
    # resize directly, without cropping
    cropped_image = scipy.misc.imresize(image, [resize_height, resize_width])
  # map (0,255) to (-1,1)
  return np.array(cropped_image)/127.5 - 1.

def inverse_transform(images):
  # (-1,1) ---> (0,1)
  return (images+1.)/2.

def to_json(output_path, *layers):
  with open(output_path, "w") as layer_f:
    lines = ""
    for w, b, bn in layers:
      layer_idx = w.name.split('/')[0].split('h')[1]

      B = b.eval()

      if "lin/" in w.name:
        W = w.eval()
        depth = W.shape[1]
      else:
        W = np.rollaxis(w.eval(), 2, 0)
        depth = W.shape[0]

      biases = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(B)]}
      if bn != None:
        gamma = bn.gamma.eval()
        beta = bn.beta.eval()

        gamma = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(gamma)]}
        beta = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(beta)]}
      else:
        gamma = {"sy": 1, "sx": 1, "depth": 0, "w": []}
        beta = {"sy": 1, "sx": 1, "depth": 0, "w": []}

      if "lin/" in w.name:
        fs = []
        for w in W.T:
          fs.append({"sy": 1, "sx": 1, "depth": W.shape[0], "w": ['%.2f' % elem for elem in list(w)]})

        lines += """
          var layer_%s = {
            "layer_type": "fc", 
            "sy": 1, "sx": 1, 
            "out_sx": 1, "out_sy": 1,
            "stride": 1, "pad": 0,
            "out_depth": %s, "in_depth": %s,
            "biases": %s,
            "gamma": %s,
            "beta": %s,
            "filters": %s
          };""" % (layer_idx.split('_')[0], W.shape[1], W.shape[0], biases, gamma, beta, fs)
      else:
        fs = []
        for w_ in W:
          fs.append({"sy": 5, "sx": 5, "depth": W.shape[3], "w": ['%.2f' % elem for elem in list(w_.flatten())]})

        lines += """
          var layer_%s = {
            "layer_type": "deconv", 
            "sy": 5, "sx": 5,
            "out_sx": %s, "out_sy": %s,
            "stride": 2, "pad": 1,
            "out_depth": %s, "in_depth": %s,
            "biases": %s,
            "gamma": %s,
            "beta": %s,
            "filters": %s
          };""" % (layer_idx, 2**(int(layer_idx)+2), 2**(int(layer_idx)+2),
               W.shape[0], W.shape[3], biases, gamma, beta, fs)
    layer_f.write(" ".join(lines.replace("'","").split()))

def make_gif(images, fname, duration=2, true_image=False):
  # build a gif from a sequence of images
  # duration: length of the gif in seconds
  # images shape: (batch_size,height,width,channel)
  import moviepy.editor as mpy

  def make_frame(t):
    try:
      # x is the frame picked for time t
      x = images[int(len(images)/duration*t)]
    except:
      x = images[-1]

    if true_image: # return the frame as-is (values assumed to be in 0-255 already)
      return x.astype(np.uint8)
    else:
      # (-1,1) ---> (0,255)
      return ((x+1)/2*255).astype(np.uint8)

  clip = mpy.VideoClip(make_frame, duration=duration)
  clip.write_gif(fname, fps = len(images) / duration)

def visualize(sess, dcgan, config, option):
  # visualization at test time
  image_frame_dim = int(math.ceil(config.batch_size**.5)) # images per row/column of the grid
  if option == 0:
    # noise
    z_sample = np.random.uniform(-0.5, 0.5, size=(config.batch_size, dcgan.z_dim))
    samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
    save_images(samples, [image_frame_dim, image_frame_dim], './%s/test_%s.png' % (config.sample_dir,strftime("%Y-%m-%d-%H-%M-%S", gmtime())))
  elif option == 1: # merge the samples into one big image
    values = np.arange(0, 1, 1./config.batch_size)
    for idx in xrange(dcgan.z_dim):
      print(" [*] %d" % idx)
      z_sample = np.random.uniform(-1, 1, size=(config.batch_size , dcgan.z_dim))
      for kdx, z in enumerate(z_sample):
        z[idx] = values[kdx]

      if config.dataset == "mnist":
        # y is batch_size random digits in 0-9
        y = np.random.choice(10, config.batch_size)
        save_random_digits(y,image_frame_dim,image_frame_dim,'./%s/test_arange_%s.txt' % (config.sample_dir,idx))
        y_one_hot = np.zeros((config.batch_size, 10))
        y_one_hot[np.arange(config.batch_size), y] = 1

        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample, dcgan.y: y_one_hot})
      else:
        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})

      save_images(samples, [image_frame_dim, image_frame_dim], './%s/test_arange_%s.png' % (config.sample_dir,idx))
  elif option == 2:
    values = np.arange(0, 1, 1./config.batch_size)
    # idx is chosen at random
    for idx in [random.randint(0, dcgan.z_dim - 1) for _ in xrange(dcgan.z_dim)]:
      print(" [*] %d" % idx)
      # z_dim: number of test images
      z = np.random.uniform(-0.2, 0.2, size=(dcgan.z_dim))
      # np.tile: repeat the array along the given axes
      # z_sample shape: (batch_size, z_dim)
      z_sample = np.tile(z, (config.batch_size, 1))
      #z_sample = np.zeros([config.batch_size, dcgan.z_dim])
      for kdx, z in enumerate(z_sample):
        z[idx] = values[kdx]

      if config.dataset == "mnist":
        y = np.random.choice(10, config.batch_size)
        #save_random_digits(y, image_frame_dim, image_frame_dim, './%s/test_%s.txt' % % (config.sample_dir,strftime("%Y-%m-%d-%H-%M-%S", gmtime())))
        y_one_hot = np.zeros((config.batch_size, 10))
        y_one_hot[np.arange(config.batch_size), y] = 1

        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample, dcgan.y: y_one_hot})
      else:
        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})

      try:
        make_gif(samples, './%s/test_gif_%s.gif' % (config.sample_dir,idx))
      except:
        save_images(samples, [image_frame_dim, image_frame_dim], './%s/test_%s.png' % (config.sample_dir,strftime("%Y-%m-%d-%H-%M-%S", gmtime())))
  elif option == 3: # not for mnist; generate gifs directly
    values = np.arange(0, 1, 1./config.batch_size)
    for idx in xrange(dcgan.z_dim):
      print(" [*] %d" % idx)
      z_sample = np.zeros([config.batch_size, dcgan.z_dim])
      for kdx, z in enumerate(z_sample):
        z[idx] = values[kdx]

      samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
      make_gif(samples, './%s/test_gif_%s.gif' % (config.sample_dir,idx))
  elif option == 4:
    image_set = []
    values = np.arange(0, 1, 1./config.batch_size)

    for idx in xrange(dcgan.z_dim):
      print(" [*] %d" % idx)
      z_sample = np.zeros([config.batch_size, dcgan.z_dim])
      for kdx, z in enumerate(z_sample): z[idx] = values[kdx]

      image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}))
      #make_gif(image_set[-1], './%s/test_gif_%s.gif' % (config.sample_dir,idx))

    # merge into one big gif (64 frames of merged big images)
    new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) \
        for idx in range(63, -1, -1)] # 63 down to 0
    make_gif(new_image_set, './%s/test_gif_merged.gif' % config.sample_dir, duration=8)

def save_random_digits(arr,height,width,save_path):
  '''
  Save the digits in arr to a text file, row by row: height rows, width columns
  :param arr: ndarray
  :param height: number of rows
  :param width: number of columns
  :param save_path: path of the output file
  :return:
  '''
  with open(save_path,"w") as f:
    for i in range(height):
      for j in range(width):
        if j != width-1:
          f.write("%d," % arr[i*width+j])
        else:
          f.write("%d\n" % arr[i*width+j])



def image_manifold_size(num_images):
  manifold_h = int(np.floor(np.sqrt(num_images)))
  manifold_w = int(np.ceil(np.sqrt(num_images)))
  assert manifold_h * manifold_w == num_images
  return manifold_h, manifold_w

def resize_imgs(imgs_path,size,save_dir):
  '''
  Resize all images under imgs_path to the given size and save them to save_dir
  :param imgs_path: folder containing the original images
  :param size: target image size after resizing
  :param save_dir: folder where the resized images are saved
  :return:
  '''
  if not exists(save_dir):
    makedirs(save_dir)
  imgs = glob(imgs_path+"*.jpg")
  for i,img in enumerate(imgs,1):
    try:
      img_arr = scipy.misc.imread(img)
      new_img = scipy.misc.imresize(img_arr,size)
      scipy.misc.imsave(join(save_dir,basename(img)),new_img)
    except Exception as e:
      print(e)
    if i % 100 == 0:
      print("Resize and save %d images!" % i)
  print("Resize and save all %d images!" % len(imgs))


# if __name__ == '__main__':
#     imgs_path = "data/images/"
#     save_dir = "data/lsun_new/"
#     size = (108,108)
#     resize_imgs(imgs_path,size,save_dir)

utils.py defines many useful global helper functions that can be called directly from the other scripts.

  • The glob library lists the files in a folder; the os library handles paths and directories; pprint is for pretty printing; gmtime and strftime format dates; scipy.misc contains many useful image-related functions.
  • The show_all_variables function on lines 24-27 calls slim.model_analyzer.analyze_vars(vars, print_info) to print information about all of the model's variables.
  • The imread function on lines 39-43 wraps scipy.misc.imread; its flatten = True argument flattens the color layers into a single gray-scale layer.
  • The merge function on lines 48-73 builds a big image out of a batch of small images: images.shape[0] is the number of small images, h = images.shape[1] is the small-image height, w = images.shape[2] the width, x_h = size[0] is the factor by which the big image's height is expanded, and x_w = size[1] the same for the width. The function produces a big image of height h*x_h and width w*x_w, i.e. the big image contains x_h small images vertically and x_w horizontally. (A small numpy sketch of this layout follows this list.)
  • Lines 75-85 define the image-saving function imsave. Note that np.squeeze removes axes of length one (reducing dimensionality); the inverse operation is np.expand_dims(arr, axis), which inserts a new axis at the given position.
  • The center_crop function on lines 87-104 performs a centered crop and also resizes the image.
  • The transform function on lines 106-126 likewise applies an (optional) center_crop and a resize, except that at the end it maps each element of the image array from (0,255) to (-1,1), the range of the tanh function.
  • The to_json function on lines 132-193 saves the layer structures to a json file; we don't use it, so I won't go into detail.
  • The make_gif function on lines 195-215 turns a sequence of generated images into a gif, using the moviepy library; for an introduction to moviepy, see my earlier article.
  • The visualize function on lines 217-298 generates image samples at test time: single jpg images, gifs, or big images tiled from small ones. Its option argument (taking one of the five values 0, 1, 2, 3, 4) selects among five ways of saving the results.
  1. option=0: only for datasets other than mnist; the samples are simply merged into one big image and saved. The big image holds batch_size small images, with ceil(sqrt(batch_size)) per row and per column;
  2. option=1: similar to option=0, except that it also handles the mnist case: for mnist it randomly draws batch_size digit labels, has the generator produce the corresponding digits, and tiles them into a big image. Here I defined my own save_random_digits function to write the randomly drawn digits to a txt file, so that we can later verify whether the generated digit images are the ones we asked for;
  3. option=2: instead of one big image, this produces a gif of batch_size frames, 2 seconds long by default; if gif generation fails, it falls back to the same big image as option=1;
  4. option=3: not for the mnist dataset; generates the same kind of gif as option=2.
  5. option=4: builds a gif of big images: batch_size big images in total, each tiled from z_dim (the number of generated samples) small images.
  • The save_random_digits function on lines 300-316 is my own helper for saving the random digits to a txt file;
  • Finally, the resize_imgs function on lines 326-346 is something I added myself: it resizes all images in a given folder to a given size and saves them, so that we can train the model on our own datasets.
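
As a quick numpy-only sketch of the tiling that merge performs (the shapes here are made up for illustration):

import numpy as np

# 16 small 28x28 grayscale images -> one 4x4 grid, i.e. a 112x112 big image
images = np.random.rand(16, 28, 28, 1)
rows, cols = 4, 4
big = np.zeros((28 * rows, 28 * cols))
for idx, image in enumerate(images):
  i = idx % cols   # column index
  j = idx // cols  # row index (tiles are filled row by row)
  big[j*28:(j+1)*28, i*28:(i+1)*28] = image[:, :, 0]
print(big.shape)   # (112, 112)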

3. Running the Code (Checking the Generated Images)

1. mnist

    Based on the walkthrough above, the following command trains DCGAN on mnist:

python3 main.py --dataset=mnist --input_height=28 --output_height=28 --train True

Make sure the data/mnist folder under the main.py directory contains the decompressed mnist data files. Since mnist is not a large dataset, training on a gpu takes only a few tens of minutes. After training finishes, the images sampled during training are saved in the samples folder. The images from the first and the last sampling are shown in Figures 1 and 2 below:

Figure 1: images from the first sampling during mnist training
Figure 2: images from the last sampling during mnist training

As training progresses, the quality of the generated handwritten digits indeed improves steadily. Next, let's use the trained checkpoint to test; here the visualize option is set to 1, and the following command runs the test:
python3 main.py --dataset=mnist --input_height=28 --output_height=28 --train False

By default, testing generates 100 merged big images. Picking one at random, say the 66th, its true random digit layout and the generated handwritten digits are shown in Figures 3 and 4 below:

Figure 3: true random digit layout of the 66th image
Figure 4: generated handwritten digits of the 66th image

The generated handwritten digits match the true digits exactly, and spot-checking the other generated images shows essentially a 100% match everywhere, which shows that the conditional DCGAN is very effective.

2. celebA

    The celebA dataset is much larger than mnist, with roughly 200k+ face images; the images are in color and are center-cropped to 108*108 for training. Run the following command to train:

python3 main.py --dataset celebA --input_height=108 --crop --train True \
                --epoch 2 --sample_dir ./celebA_samples --visualize True 

Note that by default the samples taken during training are saved in the samples folder; since the mnist results are already stored there, continuing to use it would let the celebA results overwrite the earlier files. To avoid that, we redirect the samples to a celebA_samples folder, which is created automatically at run time, no need to create it by hand. The celebA dataset is fairly large. My machine runs ubuntu 16.04, tensorflow 1.4.1, cuda8 + cudnn6, with an nvidia GTX950M GPU and 4G of video memory. With batch_size = 64, one batch takes roughly 1.5s to train, so with the default epoch=25 and batch_num = ceil(202602/64) = 3166 batches per epoch, a full run would take about 1.5*3166*25/3600 ≈ 33h. I don't have a desktop machine, my laptop can't realistically train that long, and the lab machines are too weak to train this at all, so I just trained casually and stopped before even one epoch had finished. The images generated at batch 100 of epoch 1 are shown in Figure 5 below:

Figure 5: images generated at batch 100 of epoch 1

The images generated at batch 2500 of epoch 1 are shown in Figure 6:

Figure 6: images generated at batch 2500 of epoch 1

Even without completing a single epoch, the images generated by batch 2500 already show the rough outline of a face. If you have enough compute, it is worth training to completion; the final results should be quite good.

接着我們可以利用上面那個只訓練了一點點的模型進行測試,測試celebA運行命令:

python3 main.py --dataset celebA --input_height=108 --crop --train False \
                --checkpoint_dir ./checkpoint --sample_dir ./celebA_samples

You can of course still control the test output by setting the option value. Figures 7 and 8 below are the generated gifs (Figure 8 was converted to jpg because the file was too large). Since training was very insufficient the quality is poor, but facial outlines are still visible:

Figure 7: celebA face gif generated after less than one epoch of training (small image)
Figure 8: celebA face gif generated after less than one epoch of training (big image)

3. lsun

    The lsun file downloaded by download.py is very large (46G) and comes in LMDB format, which is awkward to read directly. So I later downloaded a 2 GB image archive from the lsun website myself; after extraction it holds about 9000 images of quite varied categories, mostly natural scenery. Since the dataset is small and the image styles differ a lot, it is not well suited for training DCGAN (although it can certainly be trained), so I skipped the experiment myself. If you are interested, give it a try and see how it goes.

4. beauty_girls

    This is a dataset I collected myself; as the name suggests, it is about pretty girls. It has roughly 2000 images, mostly full-body shots. The original images are large and of varying sizes, so we first use the resize_imgs function from utils.py mentioned above to resize all images to the same size (here I resized both width and height to 108; a sketch of the call follows the command below), save them into a beauty_girls folder, put that folder under the data directory, and then run the following command to train:

python3 main.py --dataset beauty_girls --input_height=108 --crop --train True \
                --epoch 500 --sample_dir ./beauty_girls_samples --visualize True \
                --print_every 10 --checkpoint_every 240
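As promised above, here is a minimal sketch of the resize step. My actual resize_imgs lives in utils.py (lines 326-346); this standalone version assumes every file in the source folder is a readable image and uses the same scipy.misc helpers as the rest of the code base:

import os
import scipy.misc

def resize_imgs(src_dir, dst_dir, height=108, width=108):
  # Resize every image in src_dir to height x width and save it into dst_dir.
  if not os.path.exists(dst_dir):
    os.makedirs(dst_dir)
  for name in os.listdir(src_dir):
    img = scipy.misc.imread(os.path.join(src_dir, name))
    resized = scipy.misc.imresize(img, [height, width])
    scipy.misc.imsave(os.path.join(dst_dir, name), resized)

# e.g. resize_imgs('./raw_beauty_girls', './data/beauty_girls')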

This time, because there are only 2,000 images, I set training to 500 epochs and let my laptop run overnight. That turned out badly; the images sampled during training look like this:

Figure 9: beauty_girls images generated after 1, 66, 200, 300, and 500 epochs (top to bottom)
From epoch 1 through epoch 300 the quality of the generated images improves, but training further, especially by epoch 500, the images clearly fall apart: many small tiles show similar, unintelligible patterns, that is, the mode collapse described in the paper. This suggests the model has more or less converged by around epoch 300 and that training beyond that can make things worse, possibly ending in mode collapse, which matches the remarks at the end of the paper. Note also that even the best generated images are not particularly good, probably because the training set is tiny (only 2,000 images) and the image styles vary too much. Finally, don't ask me for the original training images; surely you can guess from the generated ones? Ha ha.

5. girl_face

    This dataset comes from zhihu user Best July's article 用DCGAN生成女朋友 (using a DCGAN to generate a girlfriend), which is worth a read if you're interested. It contains 8,000+ cropped face images of girls, all 96x96, roughly like this:

Figure 10: sample training images from girl_face
The dataset can be downloaded from [faces](https://pan.baidu.com/s/1dERYUmH), password: 09h9. Start training with:

python3 main.py --dataset girl_face --input_height=96 --crop --train True \
                --epoch 200 --sample_dir ./girl_face --visualize True \
                --print_every 30 --checkpoint_every 300

Make sure the girl_face folder holding the images sits under the data directory. We train for 200 epochs, which takes an estimated 5-6 hours in total. Figure 11 below shows (top to bottom) the images generated after 1, 30, 70, 100, 130, and 170 epochs. Quality rises steadily with the number of epochs; by around epoch 100 the generated faces are already quite good (recognizably pretty faces). Beyond that, a few positions keep improving, but some tiles always carry distortions and never look entirely natural. On the whole, though, the generated images are very decent.

Figure 11: girl_face images generated after 1, 30, 70, 100, 130, and 170 epochs (top to bottom)

After training we use the resulting model to test, but there is a problem we haven't mentioned yet: if the number of training epochs is set too high, the model loaded from the latest checkpoint is not necessarily the best one; the best may sit at some intermediate epoch. The original code can only load the latest checkpoint, so we modify the load function in model.py as follows:

# load checkpoint files
  def load(self, checkpoint_dir, checkpoint_name=None):
    import re
    print(" [*] Reading checkpoints...")
    checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
    # Returns a CheckpointState proto if the state was available, None otherwise.
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
      # basename returns the final component of a pathname.
      ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
      if checkpoint_name is None:
        # Load the latest checkpoint.
        self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
      else:
        # Load the specified checkpoint instead of the latest one.
        self.saver.restore(self.sess, os.path.join(checkpoint_dir, checkpoint_name))
      # Recover the global step counter from the name of the checkpoint
      # that was actually loaded (not necessarily the latest one).
      loaded_name = ckpt_name if checkpoint_name is None else checkpoint_name
      counter = int(next(re.finditer(r"(\d+)(?!.*\d)", loaded_name)).group(0))
      print(" [*] Successfully read {}".format(loaded_name))
      return True, counter
    else:
      print(" [*] Failed to find a checkpoint")
      return False, 0

The main change is the extra checkpoint_name parameter, which selects a specific checkpoint file instead of the latest one. We also add a matching command-line flag, flags.DEFINE_string("checkpoint_name", None, "the name of the checkpoint file to load; default is the latest checkpoint"), so that checkpoint_name can be set from the command line; its default is None.
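For example, to test from an intermediate checkpoint, you could run something like the following (the checkpoint file name here is illustrative; DCGAN-tensorflow saves checkpoints under names of the form DCGAN.model-<step>):

python3 main.py --dataset girl_face --input_height=96 --crop --train False \
                --checkpoint_name DCGAN.model-4500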

One more issue: during training, the input noise z for sampling is drawn from a uniform distribution on (-1, 1), but the original visualize function does not sample from uniform(-1, 1) when option is 1, 2, 3, or 4. In my experiments, running the unmodified test code with those options yields almost entirely blurry images; I suspect the mismatch between the train-time and test-time input distributions is the cause. So I also modified visualize in utils.py as follows:

def visualize(sess, dcgan, config, option):
  # Visualization at test time. The per-dimension interpolation of z in the
  # original code has been replaced with uniform(-1, 1) sampling throughout.
  # (utils.py already imports math, numpy as np, scipy.misc, gmtime/strftime
  # from time, and xrange from six.moves.)
  image_frame_dim = int(math.ceil(config.batch_size**.5))  # grid side length
  if option == -1:
    # Noise from uniform(-1, 1), the same distribution as at training time.
    z_sample = np.random.uniform(-1, 1, size=(config.batch_size, dcgan.z_dim))
    samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
    save_images(samples, [image_frame_dim, image_frame_dim],
                './%s/test_%s.png' % (config.sample_dir, strftime("%Y-%m-%d-%H-%M-%S", gmtime())))
  elif option == 0:
    # Kept unchanged from the original code: noise from uniform(-0.5, 0.5).
    z_sample = np.random.uniform(-0.5, 0.5, size=(config.batch_size, dcgan.z_dim))
    samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
    save_images(samples, [image_frame_dim, image_frame_dim],
                './%s/test_%s.png' % (config.sample_dir, strftime("%Y-%m-%d-%H-%M-%S", gmtime())))
  elif option == 1:
    # Merge each batch of samples into one large image.
    for idx in xrange(dcgan.z_dim):
      print(" [*] %d" % idx)
      z_sample = np.random.uniform(-1, 1, size=(config.batch_size, dcgan.z_dim))
      if config.dataset == "mnist":
        # y holds batch_size random digit labels in [0, 9].
        y = np.random.choice(10, config.batch_size)
        save_random_digits(y, image_frame_dim, image_frame_dim,
                           './%s/test_arange_%s.txt' % (config.sample_dir, idx))
        y_one_hot = np.zeros((config.batch_size, 10))
        y_one_hot[np.arange(config.batch_size), y] = 1
        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample, dcgan.y: y_one_hot})
      else:
        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
      save_images(samples, [image_frame_dim, image_frame_dim],
                  './%s/test_arange_%s.png' % (config.sample_dir, idx))
  elif option == 2:
    # Generate a GIF with batch_size frames; fall back to a large image on failure.
    for idx in xrange(dcgan.z_dim):
      print(" [*] %d" % idx)
      z_sample = np.random.uniform(-1, 1, size=(config.batch_size, dcgan.z_dim))
      if config.dataset == "mnist":
        y = np.random.choice(10, config.batch_size)
        y_one_hot = np.zeros((config.batch_size, 10))
        y_one_hot[np.arange(config.batch_size), y] = 1
        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample, dcgan.y: y_one_hot})
      else:
        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
      try:
        make_gif(samples, './%s/test_gif_%s.gif' % (config.sample_dir, idx), 4)
      except:
        save_images(samples, [image_frame_dim, image_frame_dim],
                    './%s/test_%s.png' % (config.sample_dir, strftime("%Y-%m-%d-%H-%M-%S", gmtime())))
  elif option == 3:
    # Not for mnist: generate GIFs directly.
    for idx in xrange(dcgan.z_dim):
      print(" [*] %d" % idx)
      z_sample = np.random.uniform(-1, 1, size=(config.batch_size, dcgan.z_dim))
      samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
      make_gif(samples, './%s/test_gif_%s.gif' % (config.sample_dir, idx), 4)
  elif option == 4:
    image_set = []
    for idx in xrange(dcgan.z_dim):
      print(" [*] %d" % idx)
      z_sample = np.random.uniform(-1, 1, size=(config.batch_size, dcgan.z_dim))
      image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}))
      make_gif(image_set[-1], './%s/test_gif_%s.gif' % (config.sample_dir, idx), 12)

    # Merge into one large-image GIF with 64 frames (assumes z_dim = 100,
    # so each frame is a 10x10 grid of small images).
    new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) \
        for idx in range(63, -1, -1)]  # frames 63 down to 0
    make_gif(new_image_set, './%s/test_gif_merged.gif' % config.sample_dir, duration=8)

  elif option == 5:
    # Save the generated samples as individual small images.
    z_sample = np.random.uniform(-1, 1, size=(config.batch_size, dcgan.z_dim))
    samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
    for i, sample in enumerate(samples):
      scipy.misc.imsave("./%s/single_test_%s.png" % (config.sample_dir, i), sample)

The main change is that every sampling path now draws from a uniform distribution on (-1, 1): z_sample = np.random.uniform(-1, 1, size=(config.batch_size, dcgan.z_dim)). In my experiments this works very well at test time. I also left option=0 unchanged and added the option=-1 and option=5 cases; option=5 saves the generated images as individual small images. The figures below show some test results:

Figure 12: a randomly chosen large test image generated for girl_face
Figure 13: a collection of randomly chosen small test images generated for girl_face
Figure 14: a randomly chosen generated GIF for girl_face
Figure 15: the merged large-image GIF generated for girl_face
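One practical note: option is not a command-line flag. If I recall the upstream main.py correctly, the choice is hard-coded near the bottom of the file, so you switch output modes by editing the constant there:

# near the bottom of main.py (upstream repo; adjust as needed)
if FLAGS.visualize:
  OPTION = 1  # change this value to pick one of the modes above
  visualize(sess, dcgan, FLAGS, OPTION)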

4. Summary

    This article walked through the TensorFlow implementation of DCGAN in detail and trained and tested it on mnist, celebA, and the custom beauty_girls and girl_face datasets. We found that DCGAN does improve the stability of GAN training to a degree (mode collapse is less likely), and that when the dataset is large enough and training is sufficient, the generated images can be quite good. But with overly long training, mode collapse can still happen, and result quality depends heavily on the dataset: it should be large (at least 10k+ images, I'd say) and stylistically consistent, or the results may disappoint (as with beauty_girls).

That's the end of this article. Thanks for reading!

