第二十一節,條件變分自編碼


一 條件變分自編碼(CVAE)

變分自編碼存在一個問題,雖然可以生成一個樣本,但是只能輸出與輸入圖片相同類別的樣本。雖然也可以隨機從符合模型生成的高斯分布中取數據來還原成樣本,但是這樣的話餓哦們並不知道生成的樣本屬於哪個類別。條件變分編碼則可以解決這個問題,讓網絡按指定的類別生成樣本。

在變分自編碼的基礎上,再取理解條件編碼自編碼會很容易。主要的改動是,在訓練測試時加入一個one-hot向量,用於表示標簽向量。其實就是給編碼自編碼網絡加入一個條件,讓網絡學習圖片時加入標簽因素,這樣就可以按照指定的標簽生成圖片。 

二 CVAE實例 

在編碼節點需要在輸入端添加標簽對應的特征,在解碼階段同樣也需要將標簽加入輸入,這樣,再解碼的結果向原始的輸入樣本不斷逼近,最終得到的模型會把輸入的標簽的特征當成MNIST數據的一部分,從而實現通過標簽生成指定的圖片。

 該程序在上一節程序上作了一些修改,代碼如下:

# -*- coding: utf-8 -*-
"""
Created on Thu May 31 15:34:08 2018

@author: zy
"""

'''
條件變分自編碼
'''


import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data


mnist = input_data.read_data_sets('MNIST-data',one_hot=True)

print(type(mnist)) #<class 'tensorflow.contrib.learn.python.learn.datasets.base.Datasets'>

print('Training data shape:',mnist.train.images.shape)           #Training data shape: (55000, 784)
print('Test data shape:',mnist.test.images.shape)                #Test data shape: (10000, 784)
print('Validation data shape:',mnist.validation.images.shape)    #Validation data shape: (5000, 784)
print('Training label shape:',mnist.train.labels.shape)          #Training label shape: (55000, 10)

train_X = mnist.train.images
train_Y = mnist.train.labels
test_X = mnist.test.images
test_Y = mnist.test.labels


'''
定義網絡參數
'''
n_input = 784
n_hidden_1 = 256
n_hidden_2 = 2
n_classes = 10
learning_rate = 0.001
training_epochs = 20               #迭代輪數
batch_size = 128                   #小批量數量大小
display_epoch = 3
show_num = 10

x = tf.placeholder(dtype=tf.float32,shape=[None,n_input])
y = tf.placeholder(dtype=tf.float32,shape=[None,n_classes])
#后面通過它輸入分布數據,用來生成模擬樣本數據
zinput = tf.placeholder(dtype=tf.float32,shape=[None,n_hidden_2])


'''
定義學習參數
'''
weights = {
        'w1':tf.Variable(tf.truncated_normal([n_input,n_hidden_1],stddev = 0.001)),
        'w_lab1':tf.Variable(tf.truncated_normal([n_classes,n_hidden_1],stddev = 0.001)),
        'mean_w1':tf.Variable(tf.truncated_normal([n_hidden_1*2,n_hidden_2],stddev = 0.001)),
        'log_sigma_w1':tf.Variable(tf.truncated_normal([n_hidden_1*2,n_hidden_2],stddev = 0.001)),
        'w2':tf.Variable(tf.truncated_normal([n_hidden_2+n_classes,n_hidden_1],stddev = 0.001)),
        'w3':tf.Variable(tf.truncated_normal([n_hidden_1,n_input],stddev = 0.001))
        }

biases = {
        'b1':tf.Variable(tf.zeros([n_hidden_1])),
        'b_lab1':tf.Variable(tf.zeros([n_hidden_1])),
        'mean_b1':tf.Variable(tf.zeros([n_hidden_2])),
        'log_sigma_b1':tf.Variable(tf.zeros([n_hidden_2])),
        'b2':tf.Variable(tf.zeros([n_hidden_1])),
        'b3':tf.Variable(tf.zeros([n_input]))
        }


'''
定義網絡結構
'''
#第一個全連接層是由784個維度的輸入樣->256個維度的輸出
h1 = tf.nn.relu(tf.add(tf.matmul(x,weights['w1']),biases['b1']))
#輸入標簽
h_lab1 = tf.nn.relu(tf.add(tf.matmul(y,weights['w_lab1']),biases['b_lab1']))
#合並
hall1 = tf.concat([h1,h_lab1],1)

#第二個全連接層並列了兩個輸出網絡
z_mean = tf.add(tf.matmul(hall1,weights['mean_w1']),biases['mean_b1'])
z_log_sigma_sq = tf.add(tf.matmul(hall1,weights['log_sigma_w1']),biases['log_sigma_b1'])


#然后將兩個輸出通過一個公式的計算,輸入到以一個2節點為開始的解碼部分 高斯分布樣本
eps = tf.random_normal(tf.stack([tf.shape(h1)[0],n_hidden_2]),0,1,dtype=tf.float32)
z = tf.add(z_mean,tf.multiply(tf.sqrt(tf.exp(z_log_sigma_sq)),eps))
#合並
zall = tf.concat([z,y],1)    #None x 12


#解碼器 由12個維度的輸入->256個維度的輸出
h2 = tf.nn.relu(tf.matmul(zall,weights['w2']) + biases['b2'])
#解碼器 由256個維度的輸入->784個維度的輸出  即還原成原始輸入數據
reconstruction = tf.matmul(h2,weights['w3']) + biases['b3']


#這兩個節點不屬於訓練中的結構,是為了生成指定數據時用的
zinputall = tf.concat([zinput,y],1)
h2out = tf.nn.relu(tf.matmul(zinputall,weights['w2']) + biases['b2'])
reconstructionout = tf.matmul(h2out,weights['w3']) + biases['b3']

'''
構建模型的反向傳播
'''
#計算重建loss
#計算原始數據和重構數據之間的損失,這里除了使用平方差代價函數,也可以使用交叉熵代價函數  
reconstr_loss = 0.5*tf.reduce_sum((reconstruction-x)**2)
print(reconstr_loss.shape)    #(,) 標量
#使用KL離散度的公式
latent_loss = -0.5*tf.reduce_sum(1 + z_log_sigma_sq - tf.square(z_mean) - tf.exp(z_log_sigma_sq),1)
print(latent_loss.shape)      #(128,)
cost = tf.reduce_mean(reconstr_loss+latent_loss)


#定義優化器    
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)

num_batch = int(np.ceil(mnist.train.num_examples / batch_size))

'''
開始訓練
'''
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    
    print('開始訓練')
    for epoch in range(training_epochs):
        total_cost = 0.0
        for i in range(num_batch):
            batch_x,batch_y = mnist.train.next_batch(batch_size)            
           _,loss = sess.run([optimizer,cost],feed_dict={x:batch_x,y:batch_y})
            total_cost += loss
            
        #打印信息
        if epoch % display_epoch == 0:
            print('Epoch {}/{}  average cost {:.9f}'.format(epoch+1,training_epochs,total_cost/num_batch))
                        
    print('訓練完成')
    
    #測試
    print('Result:',cost.eval({x:mnist.test.images,y:mnist.test.labels}))
    #數據可視化   根據原始圖片生成自編碼數據                  
    reconstruction = sess.run(reconstruction,feed_dict = {x:mnist.test.images[:show_num],y:mnist.test.labels[:show_num]})
    plt.figure(figsize=(1.0*show_num,1*2))        
    for i in range(show_num):
        #原始圖像
        plt.subplot(2,show_num,i+1)            
        plt.imshow(np.reshape(mnist.test.images[i],(28,28)),cmap='gray')   
        plt.axis('off')
           
        #變分自編碼器重構圖像
        plt.subplot(2,show_num,i+show_num+1)
        plt.imshow(np.reshape(reconstruction[i],(28,28)),cmap='gray')       
        plt.axis('off')
    plt.show()
    

        
    '''
    高斯分布取樣,根據標簽生成模擬數據
    '''        
    z_sample = np.random.randn(show_num,2)
    reconstruction = sess.run(reconstructionout,feed_dict={zinput:z_sample,y:mnist.test.labels[:show_num]})  
    plt.figure(figsize=(1.0*show_num,1*2))        
    for i in range(show_num):
        #原始圖像
        plt.subplot(2,show_num,i+1)            
        plt.imshow(np.reshape(mnist.test.images[i],(28,28)),cmap='gray')   
        plt.axis('off')
           
        #根據標簽成成模擬數據
        plt.subplot(2,show_num,i+show_num+1)
        plt.imshow(np.reshape(reconstruction[i],(28,28)),cmap='gray')       
        plt.axis('off')
    plt.show()
    

上面第一幅圖是根據原始圖片生成的自編碼數據,第一行為原始數據,第二行為自編碼數據,該數據仍然保留一些原始圖片的特征。

第二幅圖片是利用樣本數據的標簽和高斯分布之z_sample一起生成的模擬數據,我們可以看到通過標簽生成的數據,已經徹底學會了樣本數據的分布,並生成了與輸入截然不同但帶有相同意義的數據。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM