GAN Generative Adversarial Networks (3): Generating MNIST Data


Use a generative adversarial network (GAN) to generate MNIST data.

Imports and data setup

import numpy as np
import matplotlib.pyplot as plt
import input_data  # helper module for reading the data; not essential to the discussion
import tensorflow as tf


# Load the data
mnist = input_data.read_data_sets('data/', one_hot=True)
trainimg = mnist.train.images

X = mnist.train.images[:, :]
batch_size = 64

# Yields minibatches of real data
def iterate_minibatch(x, batch_size, shuffle=True):
    indices = np.arange(x.shape[0])
    if shuffle:
        np.random.shuffle(indices)
    for i in range(0, x.shape[0] - 1000, batch_size):  # the last 1000 samples are skipped
        temp = x[indices[i:i + batch_size], :]
        temp = np.array(temp) * 2 - 1  # rescale pixels from [0, 1] to (-1, 1)
        yield np.reshape(temp, (-1, 28, 28, 1))
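
Multiplying by 2 and subtracting 1 maps the MNIST pixel range [0, 1] to (-1, 1), so the real images live in the same range as the generator's tanh output. A minimal sketch of the mapping and its inverse (the inverse is handy later when displaying generated images; to_gan_range and to_display are hypothetical helpers, not part of the original code):

import numpy as np

def to_gan_range(x):
    # MNIST pixels in [0, 1] -> (-1, 1), matching the generator's tanh output
    return x * 2 - 1

def to_display(x):
    # generated pixels in (-1, 1) -> [0, 1] for plotting
    return (x + 1) / 2

batch = np.random.rand(64, 784)  # stand-in for one real MNIST batch
assert np.allclose(to_display(to_gan_range(batch)), batch)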

GAN class structure

class GAN(object):
    def __init__(self):
        """Initialization: the model is built here."""
    def netG(self, z):
        """Generator network."""
    def netD(self, x, reuse=False):
        """Discriminator network."""

Generator function

The generator wraps a random noise vector z (shape (batch_size, 100)) into fake data.
The pipeline can be summarized as: fully connected layer -> reshape -> transposed convolutions.
Along the way it uses batch normalization, Leaky ReLU, dropout and tanh.

    # Wrap the random noise z into fake data.
    # Pipeline: fully connected -> reshape -> transposed convolutions
    # Uses batch normalization, Leaky ReLU, dropout and tanh
    def netG(self, z, alpha=0.01):
        with tf.variable_scope('generator') as scope:
            layer1 = tf.layers.dense(z, 4 * 4 * 512)  # fully connected layer, output shape (n, 4*4*512)
            layer1 = tf.reshape(layer1, [-1, 4, 4, 512])
            # batch normalization
            layer1 = tf.layers.batch_normalization(layer1, training=True)  # batch normalization
            # Leaky ReLU
            layer1 = tf.maximum(alpha * layer1, layer1)
            # dropout
            layer1 = tf.nn.dropout(layer1, keep_prob=0.8)

            # 4 x 4 x 512 to 7 x 7 x 256
            layer2 = tf.layers.conv2d_transpose(layer1, 256, 4, strides=1, padding='valid')
            layer2 = tf.layers.batch_normalization(layer2, training=True)
            layer2 = tf.maximum(alpha * layer2, layer2)
            layer2 = tf.nn.dropout(layer2, keep_prob=0.8)

            # 7 x 7 x 256 to 14 x 14 x 128
            layer3 = tf.layers.conv2d_transpose(layer2, 128, 3, strides=2, padding='same')
            layer3 = tf.layers.batch_normalization(layer3, training=True)
            layer3 = tf.maximum(alpha * layer3, layer3)
            layer3 = tf.nn.dropout(layer3, keep_prob=0.8)

            # 14 x 14 x 128 to 28 x 28 x 1
            logits = tf.layers.conv2d_transpose(layer3, 1, 3, strides=2, padding='same')
            # MNIST pixels lie in [0, 1] while the generated images lie in (-1, 1),
            # so remember to rescale the MNIST pixel range during training
            outputs = tf.tanh(logits)

            return outputs
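
The spatial sizes in the comments (4x4 -> 7x7 -> 14x14 -> 28x28) follow the usual transposed-convolution output formula. A minimal sketch of the arithmetic, using a hypothetical helper deconv_out written just for this check:

def deconv_out(size, kernel, stride, padding):
    # output spatial size of tf.layers.conv2d_transpose
    if padding == 'valid':
        return (size - 1) * stride + kernel
    return size * stride  # 'same' padding

s = 4
s = deconv_out(s, 4, 1, 'valid')  # 4 -> 7
s = deconv_out(s, 3, 2, 'same')   # 7 -> 14
s = deconv_out(s, 3, 2, 'same')   # 14 -> 28
print(s)  # 28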

Discriminator function

Through stacked convolutions followed by a fully connected layer, the discriminator classifies its input as real data or fake data.

    def netD(self, x, reuse=False, alpha=0.01):
        with tf.variable_scope('discriminator') as scope:
            if reuse:
                scope.reuse_variables()
            # 28 x 28 x 1 to 14 x 14 x 128
            layer1 = tf.layers.conv2d(x, 128, 3, strides=2, padding='same')
            layer1 = tf.maximum(alpha * layer1, layer1)
            layer1 = tf.nn.dropout(layer1, keep_prob=0.8)

            # 14 x 14 x 128 to 7 x 7 x 256
            layer2 = tf.layers.conv2d(layer1, 256, 3, strides=2, padding='same')
            layer2 = tf.layers.batch_normalization(layer2, training=True)
            layer2 = tf.maximum(alpha * layer2, layer2)
            layer2 = tf.nn.dropout(layer2, keep_prob=0.8)

            # 7 x 7 x 256 to 4 x 4 x 512
            layer3 = tf.layers.conv2d(layer2, 512, 3, strides=2, padding='same')
            layer3 = tf.layers.batch_normalization(layer3, training=True)
            layer3 = tf.maximum(alpha * layer3, layer3)
            layer3 = tf.nn.dropout(layer3, keep_prob=0.8)

            # 4 x 4 x 512 to 4*4*512 x 1
            flatten = tf.reshape(layer3, (-1, 4 * 4 * 512))
            f = tf.layers.dense(flatten, 1)
            return f
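
Both netG and netD implement Leaky ReLU by hand as tf.maximum(alpha * x, x): for positive inputs this returns x, for negative inputs alpha * x. A minimal sketch checking that the pattern matches the built-in tf.nn.leaky_relu (available from TensorFlow 1.4 on):

import numpy as np
import tensorflow as tf

x = tf.constant([-2.0, -0.5, 0.0, 1.0, 3.0])
alpha = 0.01
manual = tf.maximum(alpha * x, x)           # the pattern used in netG/netD
builtin = tf.nn.leaky_relu(x, alpha=alpha)  # built-in equivalent

with tf.Session() as sess:
    m, b = sess.run([manual, builtin])
    print(np.allclose(m, b))  # True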

Initialization function

There is a pre-training step that feeds real data to the discriminator to warm up its ability to tell real from fake.

    # Pre-training feeds real data to the discriminator to warm up its discriminative ability
    def __init__(self):
        self.z = tf.placeholder(tf.float32, shape=[batch_size, 100], name='z')  # random noise input
        self.x = tf.placeholder(tf.float32, shape=[batch_size, 28, 28, 1], name='real_x')  # real images

        self.fake_x = self.netG(self.z)  # wrap the random noise into fake images

        self.pre_logits = self.netD(self.x, reuse=False)  # discriminator output on real data during pre-training (raw logits, no sigmoid)
        self.real_logits = self.netD(self.x, reuse=True)  # discriminator output on real data (raw logits, no sigmoid)
        self.fake_logits = self.netD(self.fake_x, reuse=True)  # discriminator output on fake data (raw logits, no sigmoid)

        # Pre-training loss: how well the discriminator labels real data as real.
        self.loss_pre_D = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.pre_logits,
                                                                                 labels=tf.ones_like(self.pre_logits)))
        # Discriminator loss: label real data as real and fake data as fake.
        self.loss_D = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.real_logits,
                                                                             labels=tf.ones_like(self.real_logits))) + \
                      tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_logits,
                                                                             labels=tf.zeros_like(self.fake_logits)))
        # Generator loss: how well the generator's fake data gets judged as real.
        self.loss_G = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_logits,
                                                                             labels=tf.ones_like(self.fake_logits)))

        # Collect the generator and discriminator variables so that each optimizer only updates its own.
        t_vars = tf.trainable_variables()
        self.g_vars = [var for var in t_vars if var.name.startswith("generator")]
        self.d_vars = [var for var in t_vars if var.name.startswith("discriminator")]
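
Since netD returns raw logits, the sigmoid cross-entropy terms above work out to the standard (non-saturating) GAN losses, where D(.) denotes the sigmoid of the discriminator output and G(z) is the generator's fake image:

loss_D = -E_x[ log D(x) ] - E_z[ log(1 - D(G(z))) ]
loss_G = -E_z[ log D(G(z)) ]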

Training

gan = GAN()
# optimizer for discriminator pre-training
d_pre_optim = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.loss_pre_D, var_list=gan.d_vars)
# optimizer for the discriminator
d_optim = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.loss_D, var_list=gan.d_vars)
# optimizer for the generator
g_optim = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.loss_G, var_list=gan.g_vars)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    # Pre-train the discriminator for two epochs
    for i in range(2):
        print('Discriminator pre-training, epoch ' + str(i))
        for x_batch in iterate_minibatch(X, batch_size=batch_size):
            loss_pre_D, _ = sess.run([gan.loss_pre_D, d_pre_optim],
                                     feed_dict={
                                         gan.x: x_batch
                                     })
    # Adversarial training for 5 epochs
    for epoch in range(5):
        print('Adversarial training, epoch ' + str(epoch))
        avg_loss = 0
        count = 0
        for x_batch in iterate_minibatch(X, batch_size=batch_size):
            z_batch = np.random.uniform(-1, 1, size=(batch_size, 100))  # random noise starting values

            loss_D, _ = sess.run([gan.loss_D, d_optim],
                                 feed_dict={
                                     gan.z: z_batch,
                                     gan.x: x_batch
                                 })

            loss_G, _ = sess.run([gan.loss_G, g_optim],
                                 feed_dict={
                                     gan.z: z_batch,
                                     # gan.x: np.zeros(z_batch.shape)
                                 })

            avg_loss += loss_D
            count += 1

        # Show how the generation is going
        if True:
            avg_loss /= count
            z = np.random.normal(size=(batch_size, 100))
            excerpt = np.random.randint(100, size=batch_size)
            needTest = np.reshape(X[excerpt, :], (-1, 28, 28, 1))
            fake_x, real_logits, fake_logits = sess.run([gan.fake_x, gan.real_logits, gan.fake_logits],
                                                        feed_dict={gan.z: z, gan.x: needTest})
            # accuracy = (np.sum(real_logits > 0.5) + np.sum(fake_logits < 0.5)) / (2 * batch_size)
            print('real_logits')
            print(len(real_logits))
            print('fake_logits')
            print(len(fake_logits))
            print('\ndiscriminator loss at epoch %d: %f' % (epoch, avg_loss))
            # print('\ndiscriminator accuracy at epoch %d: %f' % (epoch, accuracy))
            print('----')
            print()

            # curr_img = np.reshape(trainimg[i, :], (28, 28))  # 28 by 28 matrix
            # show a few of the generated digits
            for idx in (0, 10, 20, 30, 40):
                curr_img = np.reshape(fake_x[idx], (28, 28))
                plt.matshow(curr_img, cmap=plt.get_cmap('gray'))
                plt.show()
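
Calling plt.matshow and plt.show once per digit opens a separate figure for each image. As an alternative, here is a minimal sketch (show_grid is a hypothetical helper, not part of the original code) that tiles a few generated digits into a single figure:

import numpy as np
import matplotlib.pyplot as plt

def show_grid(images, cols=5):
    # images: array of shape (n, 28, 28, 1) in the generator's (-1, 1) output range
    fig, axes = plt.subplots(1, cols, figsize=(2 * cols, 2))
    for ax, img in zip(axes, images[:cols]):
        ax.matshow(np.reshape(img, (28, 28)), cmap='gray')
        ax.axis('off')
    plt.show()

# e.g. show_grid(fake_x) right after the sess.run call above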

Results

Download link

