使用生成對抗網絡(GAN)生成手寫字


先放結果


這是通過GAN迭代訓練30萬次、耗時3小時生成的手寫數字圖片效果,大部分還是能看出來是數字的。

實現原理

簡單說下原理:生成對抗網絡需要訓練兩個任務,一個叫生成器,一個叫判別器。如字面意思,一個負責生成圖片,一個負責判別圖片。生成器不斷生成新的圖片,然後判別器去判斷哪兒不行,生成器再不斷改進,不斷向真實的圖片靠近。

這就如同一個造假團伙一樣,A負責生產,B負責鑒定。剛開始的時候,兩個人都是菜鳥,A隨便畫了一幅畫拿給B看,B說你這不行,然後A再改進。當然需要改進的不止A,隨著A的改進,B也得不斷提升,B需要發現更細微的差異,直至他們覺得已經沒什麼差異了(實際肯定還存在差異),他們便決定停止"訓練",開始賣吧。

實現代碼
# -*- coding: utf-8 -*-

# @author: Awesome_Tang
# @date: 2019-02-22
# @version: python2.7


import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from datetime import datetime
import numpy as np
import os
import matplotlib.pyplot as plt

# Presumably lowers the verbosity of TensorFlow's native (C++) logging;
# confirm the exact meaning of level '2' against the TensorFlow docs.
os.environ.update({'TF_CPP_MIN_LOG_LEVEL': '2'})


class Config:
    """Hyper-parameters shared by the GAN graph and its training loop."""

    # ---- network shape ----
    size = 28 * 28  # length of one flattened 28x28 image vector
    noise_size = 100  # length of the noise vector fed to the generator
    num_units = 128  # width of the hidden dense layer in both networks
    alpha = 0.01  # slope applied to negative inputs by the leaky ReLU

    # ---- optimisation ----
    learning_rate = 0.0001  # step size handed to both Adam optimizers
    smooth = 0.01  # label-smoothing amount; the code uses the factor (1 - smooth)
    drop_rate = 0.5  # value fed to the generator's dropout-rate placeholder while training

    # ---- training schedule ----
    steps = 300000  # number of training iterations
    batch_size = 128  # samples per training batch
    epochs = 100  # training epochs (not read by any code visible in this file)

    # ---- reporting ----
    print_per_step = 1000  # print the losses every this many steps


class Gan:
    """Small fully-connected GAN trained on MNIST (TensorFlow 1.x graph mode).

    Constructing an instance runs the complete pipeline: it loads the data,
    builds the generator/discriminator graph, executes the training loop and
    finally displays a 5x5 grid of generated samples.
    """

    def __init__(self):
        print('Loading data......')
        # Read the MNIST data set from MNIST_data/.
        self.mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

        # Placeholders: flattened real images and the generator's input noise.
        self.real_images = tf.placeholder(tf.float32, [None, Config.size], name='real_images')
        self.noise = tf.placeholder(tf.float32, [None, Config.noise_size], name='noise')
        # Dropout rate (fraction of units dropped) used inside the generator.
        # It defaults to 0 so that evaluation and sampling runs need not feed
        # it; the training steps feed Config.drop_rate explicitly.
        self.drop_rate = tf.placeholder_with_default(0.0, shape=[], name='drop_rate')

        self.train_step()

    def generator_graph(self, noise, n_units, out_dim, alpha, reuse=False):
        """Build the generator: dense -> leaky ReLU -> dropout -> dense -> tanh.

        Args:
            noise: input noise tensor.
            n_units: number of units in the hidden dense layer.
            out_dim: size of the generated (flattened) image vector.
            alpha: slope of the leaky ReLU for negative inputs.
            reuse: passed to ``tf.variable_scope`` to reuse existing variables.

        Returns:
            The tanh-activated output tensor, i.e. values in [-1, 1].
        """
        with tf.variable_scope('generator', reuse=reuse):
            # Hidden layer
            h1 = tf.layers.dense(noise, n_units, activation=None)
            # Leaky ReLU
            h1 = tf.maximum(alpha * h1, h1)
            # training=True is required here: tf.layers.dropout defaults to
            # training=False, which returns its input untouched and would
            # ignore self.drop_rate entirely.  Dropout is switched off by
            # feeding (or defaulting to) a rate of 0 instead.
            h1 = tf.layers.dropout(h1, rate=self.drop_rate, training=True)
            # Logits and tanh output
            logits = tf.layers.dense(h1, out_dim, activation=None)
            out = tf.tanh(logits)

        return out

    @staticmethod
    def discriminator_graph(image, n_units, alpha, reuse=False):
        """Build the discriminator: dense -> leaky ReLU -> dense(1).

        Args:
            image: batch of flattened images (real or generated).
            n_units: number of units in the hidden dense layer.
            alpha: slope of the leaky ReLU for negative inputs.
            reuse: passed to ``tf.variable_scope`` to reuse existing variables.

        Returns:
            The raw (pre-sigmoid) logits; the losses in ``net`` apply the
            sigmoid themselves via ``sigmoid_cross_entropy_with_logits``.
        """
        with tf.variable_scope('discriminator', reuse=reuse):
            # Hidden layer
            h1 = tf.layers.dense(image, n_units, activation=None)
            # Leaky ReLU
            h1 = tf.maximum(alpha * h1, h1)

            logits = tf.layers.dense(h1, 1, activation=None)

        return logits

    def net(self):
        """Assemble both networks, their losses and their optimizers.

        Also stores the generator output tensor on ``self.fake_image`` so that
        ``gen_image`` can sample from it without adding new ops to the graph.

        Returns:
            Tuple ``(dis_optimizer, gen_optimizer, d_loss, g_loss)``.
        """
        # generator
        fake_image = self.generator_graph(self.noise, Config.num_units, Config.size, Config.alpha)
        self.fake_image = fake_image

        # discriminator (the second call shares the variables of the first)
        real_logits = self.discriminator_graph(self.real_images, Config.num_units, Config.alpha)
        fake_logits = self.discriminator_graph(fake_image, Config.num_units, Config.alpha, reuse=True)

        # Label smoothing must be applied to the *labels*; multiplying the
        # resulting loss by (1 - smooth) only rescales it and smooths nothing.
        real_targets = tf.ones_like(real_logits) * (1 - Config.smooth)
        gen_targets = tf.ones_like(fake_logits) * (1 - Config.smooth)

        # Discriminator loss: recognise real images ...
        d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logits, labels=real_targets))
        # ... and recognise generated images.
        d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=tf.zeros_like(fake_logits)))
        # Total discriminator loss
        d_loss = tf.add(d_loss_real, d_loss_fake)

        # Generator loss: make the discriminator label generated images "real".
        g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=gen_targets))

        net_vars = tf.trainable_variables()

        # Split the trainable variables by the scope that created them so each
        # optimizer only updates its own network.
        g_vars = [var for var in net_vars if var.name.startswith("generator")]
        d_vars = [var for var in net_vars if var.name.startswith("discriminator")]

        # optimizer
        dis_optimizer = tf.train.AdamOptimizer(Config.learning_rate).minimize(d_loss, var_list=d_vars)
        gen_optimizer = tf.train.AdamOptimizer(Config.learning_rate).minimize(g_loss, var_list=g_vars)

        return dis_optimizer, gen_optimizer, d_loss, g_loss

    def train_step(self):
        """Run the full training loop, then display generated samples."""
        dis_optimizer, gen_optimizer, d_loss, g_loss = self.net()

        print('Training & Evaluating......')
        start_time = datetime.now()
        # The context manager closes the session even if training raises.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            for step in range(Config.steps):
                real_image, _ = self.mnist.train.next_batch(Config.batch_size)

                # Map pixel values from (presumably) [0, 1] to [-1, 1], the
                # range of the generator's tanh output.
                real_image = real_image * 2 - 1

                # Input noise for the generator
                batch_noise = np.random.uniform(-1, 1, size=(Config.batch_size, Config.noise_size))

                # Both updates run the generator, so both get the training
                # dropout rate.
                train_feed = {
                    self.noise: batch_noise,
                    self.real_images: real_image,
                    self.drop_rate: Config.drop_rate,
                }
                sess.run(gen_optimizer, feed_dict=train_feed)
                sess.run(dis_optimizer, feed_dict=train_feed)

                if step % Config.print_per_step == 0:
                    # drop_rate is left at its default of 0: losses are
                    # reported with dropout disabled.
                    eval_feed = {self.noise: batch_noise, self.real_images: real_image}
                    dis_loss, gen_loss = sess.run([d_loss, g_loss], feed_dict=eval_feed)
                    # total_seconds() keeps counting past 24h, whereas the
                    # .seconds attribute wraps around at one day.
                    elapsed_minutes = (datetime.now() - start_time).total_seconds() / 60.

                    msg = 'Step {:3}k Dis_Loss:{:6.2f}, Gen_Loss:{:6.2f}, Time_Usage:{:6.2f} mins.'
                    print(msg.format(step // 1000, dis_loss, gen_loss, elapsed_minutes))

            self.gen_image(sess)

    def gen_image(self, sess):
        """Sample 25 images from the generator and show them in a 5x5 grid.

        Args:
            sess: an open ``tf.Session`` holding the trained variables.
        """
        sample_noise = np.random.uniform(-1, 1, size=(25, Config.noise_size))
        # Reuse the generator tensor built in net(); drop_rate is not fed, so
        # it takes its default of 0 and dropout is disabled while sampling.
        samples = sess.run(self.fake_image, feed_dict={self.noise: sample_noise})

        plt.figure(figsize=(8, 8), dpi=80)
        for i, img in enumerate(samples):
            plt.subplot(5, 5, i + 1)
            plt.imshow(img.reshape((28, 28)), cmap='Greys_r')
            plt.axis('off')
        plt.show()


# Script entry point: constructing Gan builds the graph, trains the model and
# shows the generated samples (all of that is triggered from Gan.__init__).
if __name__ == '__main__':
    Gan()


Peace~~


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯繫本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM