VAEs(變分自編碼)之keras實踐


 

VAEs最早由“Diederik P. Kingma and Max Welling, “Auto-Encoding Variational Bayes, arXiv (2013)”和“Danilo Jimenez Rezende, Shakir Mohamed, and Daan Wierstra, “Stochastic Backpropagation and Approximate Inference in Deep Generative Models,” arXiv (2014)”同時發現。

原理:

對自編碼器來說,它只是將輸入數據投影到隱空間中,這些數據在隱空間中的位置是離散的,因此在此空間中進行采樣,解碼后的輸出很可能是毫無意義的。

而對VAEs來說,它將輸入數據轉換成2個分布,一個是平均值的分布,一個是方差的分布(這就像高斯混合型了),添加上一些噪音,組合后,再進行解碼。

如圖(網上找的,應該是論文里的,暫時沒看論文)

 

 為什么分為2個分布?

可以這么理解:假設均值和方差都有n個,那么編碼部分相當於用n個高斯分布(每個輸入是不同權重的n個分布的組合)去模擬輸入。

再通過一系列變換,轉化為隱空間的若干維度,其每個維度可能具有某種意義。比如下面代碼使用2維隱空間,可以看作是均值和方差維度。

方差部分指數化,保證非負。添加噪音讓隱空間更具有意義的連續性。

然后我們從隱空間采樣,由於隱空間具有意義上的連續性,那么解碼后的東東就可能類似輸入。

損失loss如何定義?為什么?

loss由2部分構成,第一部分就是解碼輸出與原始輸入的loss,可以定義為交叉熵或者均方誤差等。

第二部分是約束項。如上圖黃色框,m平方作為L2正則化項,前2項可以看做方差減去其泰勒展開,當σ趨近0時,方差也即e^σ為1。那么最小化前2項必然使得σ趨近0(求導即可知)。

由此,這第二部分,m平方項約束使得均值為0,前2項約束使得方差為1。這樣約束使得隱空間具有連續性,且強制輸入數據在隱空間中的表示范圍收攏。

這樣在隱空間中2個數據表示的中間,就有一種過渡區域。如果僅以第一部分約束,效果可能就和自編碼器一樣了,模型會過擬合。


 

下面進入代碼部分

以MNIST數據集作為訓練樣本。

from keras import backend as K

from keras.models import Model

from keras.metrics import binary_crossentropy

import numpy as np

from keras.layers import Conv2D,Flatten,Dense,Input,Lambda,Reshape,Conv2DTranspose,Layer

from keras.datasets import mnist

from keras.callbacks import EarlyStopping

編碼器使用卷積層,輸出2個部分

img_shape=(28,28,1)
batch_size=16
latent_dim=2

input_img=Input(shape=img_shape)
x=Conv2D(32,3,padding='same',activation='relu')(input_img)# 28,28,32
x=Conv2D(64,3,padding='same',activation='relu',strides=(2,2))(x)# 14,14,64
x=Conv2D(64,3,padding='same',activation='relu')(x)#14,14,64
x=Conv2D(64,3,padding='same',activation='relu')(x)#14,14,64
# 保存Flatten之前的shape
shape_before_flattening=K.int_shape(x)
x=Flatten()(x)#14*14*64
x=Dense(32,activation='relu')(x)#32
# 將輸入圖像拆分為2個向量
z_mean=Dense(latent_dim)(x)#2
z_log_var=Dense(latent_dim)(x)

定義采樣方法

def sampling(args):
    z_mean,z_log_var=args
#     得到一個平均值為0,方差為1的正態分布,shape為(?,2)
    epsilon=K.random_normal(shape=(K.shape(z_mean)[0],latent_dim),mean=0,stddev=1.)#K.shape返回仍是tensor
#     tensor*tensor為elementwise操作
    return z_mean+K.exp(z_log_var)*epsilon
z=Lambda(sampling)([z_mean,z_log_var])# 采樣

解碼

# 解碼過程,逆操作
decode_input=Input(K.int_shape(z)[1:])
# np.prod表示對數組某個axis進行乘法操作,如果axis不指定,則將所有的元素乘積返回一個值
x=Dense(np.prod(shape_before_flattening[1:]),activation='relu')(decode_input)#14*14*64
# 逆Flatten操作
x=Reshape(shape_before_flattening[1:])(x)#14,14,64
# 反卷積,strides=2將14*14變為28*28,跟Conv2D相反
x=Conv2DTranspose(32,3,padding='same',activation='relu',strides=2)(x)#28,28,32
# 注意這里的激活函數
x=Conv2D(1,3,padding='same',activation='sigmoid')(x)#28,28,1
# 解碼model
decoder=Model(decode_input,x)
# 解碼后的圖片數據
z_decoded=decoder(z)

定義loss,使用一個自定義layer實現

class CustomVariationalLayer(Layer):
    def vae_loss(self,x,z_decoded):
        x=K.flatten(x)
        z_decoded=K.flatten(z_decoded)
#         loss為原始輸入和編碼-解碼后的輸出比較
        xent_loss=binary_crossentropy(x,z_decoded)
#         約束
#         mean部分表示L2正則損失,K.exp(z_log_var)-(1+z_log_var)保證方差為1,如果不約束,網絡可能偷懶
        kl_loss=5e-4*K.mean(K.exp(z_log_var)-(1+z_log_var)+K.square(z_mean),axis=-1)
        return K.mean(xent_loss+kl_loss)

    def call(self,inputs):
        x=inputs[0]
        z_decoded=inputs[1]
        loss=self.vae_loss(x,z_decoded)
#         繼承方法
        self.add_loss(loss,inputs=inputs)#將根據inputs計算的損失loss加到本layer
        return x #不用,但是需要返回點啥

y=CustomVariationalLayer()([input_img,z_decoded])

加載數據,定義、訓練模型

(x_train,y_train),(x_test,y_test)=mnist.load_data()

x_train=x_train.astype('float32')/255.
# 表示添加一個通道維度,通道數為1(顏色只有一種模式)
x_train=x_train.reshape(x_train.shape+(1,))
x_test=x_test.astype('float32')/255.
x_test=x_test.reshape(x_test.shape+(1,))
vae=Model(input_img,y)
# 自定義層y里面已經包含了loss,這里不需要指定
vae.compile(optimizer='rmsprop',loss=None)
# 不需要標簽,所以y為None,我們只需要知道一個圖片的原始輸入是否和編碼-解碼后的輸出一致
vae.fit(x=x_train,y=None,shuffle=True,epochs=10,batch_size=batch_size,validation_data=(x_test,None),callbacks=[EarlyStopping(patience=2)],verbose=2)

測試

import matplotlib.pyplot as plt
from scipy.stats import norm

# 潛空間中任意矢量可以解碼成數字
n = 10
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
# norm.ppf([v1,v2...])表示正態分布積分值為vi時,對應的x軸坐標值xi
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))#可以看作均值
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))#方差
for i, yi in enumerate(grid_x):
    for j, xi in enumerate(grid_y):
        z_sample = np.array([[xi, yi]])
#         np.tile將數組重復n次,如[1,2]->[1,2,1,2]。然后reshape到輸入格式
        z_sample = np.tile(z_sample, batch_size).reshape(batch_size, 2)
        x_decoded = decoder.predict(z_sample, batch_size=batch_size)
#         因為x_decoded為16個相同矢量得到的推導,取第一個就行,再將 28*28*1 reshape到 28*28
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit
plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()

結果如下,可以看到,圖片是連續變化的。

 

 

 

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM