Generating MNIST data with a GAN (generative adversarial network)
Imports, data conventions, etc.
import numpy as np
import matplotlib.pyplot as plt
import input_data  # helper module for reading the data; not essential to understanding
import tensorflow as tf
# Load the data
mnist = input_data.read_data_sets('data/', one_hot=True)
trainimg = mnist.train.images
X = mnist.train.images[:, :]  # shape (num_samples, 784), pixel values in [0, 1]
batch_size = 64
# Yields minibatches of real data
def iterate_minibatch(x, batch_size, shuffle=True):
    indices = np.arange(x.shape[0])
    if shuffle:
        np.random.shuffle(indices)
    for i in range(0, x.shape[0] - 1000, batch_size):
        temp = x[indices[i:i + batch_size], :]
        temp = np.array(temp) * 2 - 1  # rescale pixels from [0, 1] to [-1, 1] to match the generator's tanh output
        yield np.reshape(temp, (-1, 28, 28, 1))
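As a quick sanity check (a sketch added here, not in the original), each batch yielded by iterate_minibatch has shape (batch_size, 28, 28, 1) with pixels rescaled to [-1, 1], matching the tanh output range of the generator defined below:

# Sketch: inspect the first batch, assuming X and batch_size are defined as above
for x_batch in iterate_minibatch(X, batch_size=batch_size):
    print(x_batch.shape, x_batch.min(), x_batch.max())  # e.g. (64, 28, 28, 1) -1.0 1.0
    break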
GAN object structure
class GAN(object):
    def __init__(self):
        # constructor: builds the placeholders, networks, and losses
        ...
    def netG(self, z):
        # generator network
        ...
    def netD(self, x, reuse=False):
        # discriminator network
        ...
Generator function
The generator wraps the random noise z (a 100-dimensional vector per sample) into fake data.
The pipeline, in brief: fully connected -> reshape -> transposed convolution.
Along the way it uses batch normalization, Leaky ReLU, dropout, and tanh.
def netG(self, z, alpha=0.01):
    with tf.variable_scope('generator') as scope:
        layer1 = tf.layers.dense(z, 4 * 4 * 512)  # fully connected layer, output shape (n, 4*4*512)
        layer1 = tf.reshape(layer1, [-1, 4, 4, 512])
        # batch normalization
        layer1 = tf.layers.batch_normalization(layer1, training=True)
        # Leaky ReLU
        layer1 = tf.maximum(alpha * layer1, layer1)
        # dropout
        layer1 = tf.nn.dropout(layer1, keep_prob=0.8)
        # 4 x 4 x 512 to 7 x 7 x 256
        layer2 = tf.layers.conv2d_transpose(layer1, 256, 4, strides=1, padding='valid')
        layer2 = tf.layers.batch_normalization(layer2, training=True)
        layer2 = tf.maximum(alpha * layer2, layer2)
        layer2 = tf.nn.dropout(layer2, keep_prob=0.8)
        # 7 x 7 x 256 to 14 x 14 x 128
        layer3 = tf.layers.conv2d_transpose(layer2, 128, 3, strides=2, padding='same')
        layer3 = tf.layers.batch_normalization(layer3, training=True)
        layer3 = tf.maximum(alpha * layer3, layer3)
        layer3 = tf.nn.dropout(layer3, keep_prob=0.8)
        # 14 x 14 x 128 to 28 x 28 x 1
        logits = tf.layers.conv2d_transpose(layer3, 1, 3, strides=2, padding='same')
        # The original MNIST pixels lie in [0, 1], while the generated images lie in (-1, 1),
        # so remember to rescale the MNIST pixels accordingly during training
        outputs = tf.tanh(logits)
        return outputs
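As a quick check of the shape comments above (a sketch added here, not in the original): with kernel >= stride, a transposed convolution with 'valid' padding produces (in - 1) * stride + kernel spatial positions, and with 'same' padding it produces in * stride, which reproduces the 4 -> 7 -> 14 -> 28 progression.

# Sketch: spatial sizes produced by each conv2d_transpose in netG
def deconv_out(size, kernel, stride, padding):
    if padding == 'valid':
        return (size - 1) * stride + kernel  # assumes kernel >= stride
    return size * stride  # 'same' padding

print(deconv_out(4, 4, 1, 'valid'))   # 7   (layer2)
print(deconv_out(7, 3, 2, 'same'))    # 14  (layer3)
print(deconv_out(14, 3, 2, 'same'))   # 28  (logits)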
Discriminator function
Through stacked convolutions followed by a fully connected layer, the discriminator classifies its input as real data or fake data.
def netD(self, x, reuse=False, alpha=0.01):
    with tf.variable_scope('discriminator') as scope:
        if reuse:
            scope.reuse_variables()
        # 28 x 28 x 1 to 14 x 14 x 128
        layer1 = tf.layers.conv2d(x, 128, 3, strides=2, padding='same')
        layer1 = tf.maximum(alpha * layer1, layer1)
        layer1 = tf.nn.dropout(layer1, keep_prob=0.8)
        # 14 x 14 x 128 to 7 x 7 x 256
        layer2 = tf.layers.conv2d(layer1, 256, 3, strides=2, padding='same')
        layer2 = tf.layers.batch_normalization(layer2, training=True)
        layer2 = tf.maximum(alpha * layer2, layer2)
        layer2 = tf.nn.dropout(layer2, keep_prob=0.8)
        # 7 x 7 x 256 to 4 x 4 x 512
        layer3 = tf.layers.conv2d(layer2, 512, 3, strides=2, padding='same')
        layer3 = tf.layers.batch_normalization(layer3, training=True)
        layer3 = tf.maximum(alpha * layer3, layer3)
        layer3 = tf.nn.dropout(layer3, keep_prob=0.8)
        # flatten 4 x 4 x 512 to a 4*4*512 vector, then a single logit per sample
        flatten = tf.reshape(layer3, (-1, 4 * 4 * 512))
        f = tf.layers.dense(flatten, 1)
        return f
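Note that netD returns the raw logit f rather than a probability; the losses below feed it to sigmoid_cross_entropy_with_logits, which applies the sigmoid internally in a numerically stable way. If an explicit real/fake probability is needed (for example to report discriminator accuracy), a minimal sketch, where logits stands for whatever netD returned:

prob = tf.nn.sigmoid(logits)  # > 0.5 (equivalently, logit > 0) means "judged real"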
Initialization function
There is a pre-training stage in which real data is fed to the discriminator to train its ability to tell real from fake.
def __init__(self):
    self.z = tf.placeholder(tf.float32, shape=[batch_size, 100], name='z')  # random noise input
    self.x = tf.placeholder(tf.float32, shape=[batch_size, 28, 28, 1], name='real_x')  # real images
    self.fake_x = self.netG(self.z)  # wrap the random noise into fake images
    self.pre_logits = self.netD(self.x, reuse=False)  # discriminator output on real data during pre-training (raw logits, no sigmoid)
    self.real_logits = self.netD(self.x, reuse=True)  # discriminator output on real data (raw logits, no sigmoid)
    self.fake_logits = self.netD(self.fake_x, reuse=True)  # discriminator output on fake data (raw logits, no sigmoid)
    # Pre-training loss: how well the discriminator classifies real data as real.
    self.loss_pre_D = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.pre_logits,
                                                                             labels=tf.ones_like(self.pre_logits)))
    # Training loss for the discriminator: classify real data as real and fake data as fake.
    self.loss_D = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.real_logits,
                                                                         labels=tf.ones_like(self.real_logits))) + \
                  tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_logits,
                                                                         labels=tf.zeros_like(self.fake_logits)))
    # Training loss for the generator: how well its fake data passes as real.
    self.loss_G = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_logits,
                                                                         labels=tf.ones_like(self.fake_logits)))
    # Collect the generator's and discriminator's variables so each optimizer updates only its own network
    t_vars = tf.trainable_variables()
    self.g_vars = [var for var in t_vars if var.name.startswith("generator")]
    self.d_vars = [var for var in t_vars if var.name.startswith("discriminator")]
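Written out as equations (a standard formulation that the code above realizes through sigmoid_cross_entropy_with_logits; the notation is added here, not from the original), with D(·) the sigmoid of the discriminator logit and G the generator:

$$\mathcal{L}_D = -\mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] - \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$
$$\mathcal{L}_G = -\mathbb{E}_{z \sim p_z}[\log D(G(z))]$$

loss_G is the non-saturating form (the generator labels its fake data as real) rather than directly minimizing log(1 - D(G(z))).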
Training
gan = GAN()
# optimizer for discriminator pre-training
d_pre_optim = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.loss_pre_D, var_list=gan.d_vars)
# optimizer for the discriminator
d_optim = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.loss_D, var_list=gan.d_vars)
# optimizer for the generator
g_optim = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.loss_G, var_list=gan.g_vars)
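One caveat about tf.layers.batch_normalization in TF 1.x (a note added here, not from the original): its moving-mean/variance updates are collected in tf.GraphKeys.UPDATE_OPS and are not run automatically, so they are usually attached to the train ops as a control dependency. In this script batch norm is always called with training=True, so the omission does not affect the results shown, but a minimal sketch of the usual pattern:

# Sketch: run the batch-norm update ops together with each training step
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    g_optim = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.loss_G, var_list=gan.g_vars)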
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    # Pre-train the discriminator for two epochs
    for i in range(2):
        print('Discriminator pre-training, epoch ' + str(i))
        for x_batch in iterate_minibatch(X, batch_size=batch_size):
            loss_pre_D, _ = sess.run([gan.loss_pre_D, d_pre_optim],
                                     feed_dict={
                                         gan.x: x_batch
                                     })
    # Adversarial training for 5 epochs
    for epoch in range(5):
        print('Adversarial training, epoch ' + str(epoch))
        avg_loss = 0
        count = 0
        for x_batch in iterate_minibatch(X, batch_size=batch_size):
            z_batch = np.random.uniform(-1, 1, size=(batch_size, 100))  # random noise input
            loss_D, _ = sess.run([gan.loss_D, d_optim],
                                 feed_dict={
                                     gan.z: z_batch,
                                     gan.x: x_batch
                                 })
            loss_G, _ = sess.run([gan.loss_G, g_optim],
                                 feed_dict={
                                     gan.z: z_batch,
                                     # gan.x: np.zeros(z_batch.shape)
                                 })
            avg_loss += loss_D
            count += 1
        # Display the current results
        if True:
            avg_loss /= count
            z = np.random.normal(size=(batch_size, 100))
            excerpt = np.random.randint(100, size=batch_size)
            needTest = np.reshape(X[excerpt, :], (-1, 28, 28, 1))
            fake_x, real_logits, fake_logits = sess.run([gan.fake_x, gan.real_logits, gan.fake_logits],
                                                        feed_dict={gan.z: z, gan.x: needTest})
            # accuracy = (np.sum(real_logits > 0) + np.sum(fake_logits < 0)) / (2 * batch_size)  # logits are raw, so threshold at 0
            print('real_logits')
            print(len(real_logits))
            print('fake_logits')
            print(len(fake_logits))
            print('\ndiscriminator loss at epoch %d: %f' % (epoch, avg_loss))
            # print('\ndiscriminator accuracy at epoch %d: %f' % (epoch, accuracy))
            print('----')
            print()
            # show a few generated samples from the batch
            for idx in (0, 10, 20, 30, 40):
                curr_img = np.reshape(fake_x[idx], (28, 28))
                plt.matshow(curr_img, cmap=plt.get_cmap('gray'))
                plt.show()