GAN模型生成手寫字


概述:在前期的文章中,我們用TensorFlow完成了對手寫數字的識別,得到了94.09%的識別准確度,效果還算不錯。在這篇文章中,筆者將帶領大家用GAN模型,生成我們想要的手寫數字。

GAN簡介

對抗性生成網絡(GenerativeAdversarial Network),由 Ian Goodfellow 首先提出,由兩個網絡組成,分別是generator網絡(用於生成)和discriminator網絡(用於判別)。GAN網絡的目的就是使其自己生成一副圖片,比如說經過對一系列貓的圖片的學習,generator網絡可以自己“繪制”出一張貓的圖片,且盡量真實。discriminator網絡則是用來進行判斷的,將一張真實的圖片和一張由generator網絡生成的照片同時交給discriminator網絡,不斷訓練discriminator網絡,使其可以准確將discriminator網絡生成的“假圖片”找出來。就這樣,generator網絡不斷改進使其可以騙過discriminator網絡,而discriminator網絡不斷改進使其可以更准確找到“假圖片”,這種相互促進相互對抗的關系,就叫做對抗網絡。圖一中展示了GAN模型的結構。

思路梳理

將MNIST數據集中標簽為0的圖片提取出來,然后訓練discriminator網絡,進行手寫數字0識別,接着讓generator產生一張隨機圖片,讓訓練好的discriminator去識別這張生成的圖片,不斷訓練discriminator,直到discriminator網絡將生成的圖片當做數字0為止。

生成“假圖片

生成一張隨機像素的28*28的圖片,分別進行全連接,Leaky ReLU函數激活,dropout處理(隨機丟棄一些神經元,防止過擬合),全連接,tanh函數激活,最終生成一張“假圖片”,TensorFlow代碼如下:

 

1def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
2    with tf.variable_scope("generator", reuse=reuse):
3        hidden1 = tf.layers.dense(noise_img, n_units)  # 全連接層
4        hidden1 = tf.maximum(alpha * hidden1, hidden1)
5        hidden1 = tf.layers.dropout(hidden1, rate=0.2)
6        logits = tf.layers.dense(hidden1, out_dim)
7        outputs = tf.tanh(logits)
8        return logits, outputs

圖像判別

將需要進行判別的圖片先后經過全連接,Leaky ReLU函數激活,全連接,sigmoid函數激活處理,最終輸出圖片的識別結果,TensorFlow代碼如下:

1def get_discriminator(img, n_units, reuse=False, alpha=0.01):
2    with tf.variable_scope("discriminator", reuse=reuse):
3        hidden1 = tf.layers.dense(img, n_units)
4        hidden1 = tf.maximum(alpha * hidden1, hidden1)
5        logits = tf.layers.dense(hidden1, 1)
6        outputs = tf.sigmoid(logits)
7        return logits, outputs

完整代碼

GAN手寫數字識別的完整代碼如下:

  1import tensorflow as tf
 2from tensorflow.examples.tutorials.mnist import input_data
 3import matplotlib.pyplot as plt
 4import numpy as np
 5
 6mnist = input_data.read_data_sets("E:/Tensor/MNIST_data/")
 7img = mnist.train.images[50]
 8
 9
10def get_inputs(real_size, noise_size):
11    real_img = tf.placeholder(tf.float32, [None, real_size], name="real_img")
12    noise_img = tf.placeholder(tf.float32, [None, noise_size], name="noise_img")
13    return real_img, noise_img
14
15
16# 生成圖像
17def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
18    with tf.variable_scope("generator", reuse=reuse):
19        hidden1 = tf.layers.dense(noise_img, n_units)  # 全連接層
20        hidden1 = tf.maximum(alpha * hidden1, hidden1)
21        hidden1 = tf.layers.dropout(hidden1, rate=0.2)
22        logits = tf.layers.dense(hidden1, out_dim)
23        outputs = tf.tanh(logits)
24        return logits, outputs
25
26
27# 圖像判別
28def get_discriminator(img, n_units, reuse=False, alpha=0.01):
29    with tf.variable_scope("discriminator", reuse=reuse):
30        hidden1 = tf.layers.dense(img, n_units)
31        hidden1 = tf.maximum(alpha * hidden1, hidden1)
32        logits = tf.layers.dense(hidden1, 1)
33        outputs = tf.sigmoid(logits)
34        return logits, outputs
35#真實圖像size
36img_size = mnist.train.images[0].shape[0]
37#傳入generator的噪聲size
38noise_size = 100
39#生成器隱層參數
40g_units = 128
41#判別器隱層參數
42d_units = 128
43#Leaky ReLU參數
44alpha = 0.01
45#學習率
46learning_rate = 0.001
47#label smoothing
48smooth = 0.1
49tf.reset_default_graph()
50real_img, noise_img = get_inputs(img_size, noise_size)
51g_logits, g_outputs = get_generator(noise_img, g_units, img_size)
52
53d_logits_real, d_outputs_real = get_discriminator(real_img, d_units)
54d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, d_units, reuse=True)
55
56d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
57    logits=d_logits_real, labels=tf.ones_like(d_logits_real)
58) * (1 - smooth))
59d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
60    logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)
61))
62d_loss = tf.add(d_loss_real, d_loss_fake)
63g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
64    logits=d_logits_fake, labels=tf.ones_like(d_logits_fake)
65) * (1 - smooth))
66
67train_vars = tf.trainable_variables()
68g_vars = [var for var in train_vars if var.name.startswith("generator")]
69d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
70
71d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
72g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)
73
74
75epochs = 10000
76samples = []
77n_sample = 10
78losses = []
79
80i = j = 0
81while i<10000:
82    if mnist.train.labels[j] == 0:
83        samples.append(mnist.train.images[j])
84        i += 1
85    j += 1
86
87print(len(samples))
88size = samples[0].size
89
90with tf.Session() as sess:
91    tf.global_variables_initializer().run()
92    for e in range(epochs):
93        batch_images = samples[e] * -1
94        batch_noise = np.random.uniform(-1, 1, size=noise_size)
95
96        _ = sess.run(d_train_opt, feed_dict={real_img:[batch_images], noise_img:[batch_noise]})
97        _ = sess.run(g_train_opt, feed_dict={noise_img:[batch_noise]})
98
99    sample_noise = np.random.uniform(-1, 1, size=noise_size)
100    g_logit, g_output = sess.run(get_generator(noise_img, g_units, img_size,
101                                         reuse=True), feed_dict={
102        noise_img:[sample_noise]
103    })
104    print(g_logit.size)
105    g_output = (g_output+1)/2
106    plt.imshow(g_output.reshape([28, 28]), cmap='Greys_r')
107    plt.show()

 

訓練效果

在經過了10000次的迭代后,generator網絡生成的圖片已經接近手寫數字零的形狀。

  

  本文是對GAN模型的初次探索,在后續GAN模型的系列文章中,筆者將層層深入的去講解GAN模型復雜的應用。

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM