變分自編碼器(variational autoencoder, VAE)是一種生成模型,訓練模型分為編碼器和解碼器兩部分。
編碼器將輸入樣本映射為某個低維分布,這個低維分布通常是不同維度之間相互獨立的多元高斯分布,因此編碼器的輸出為這個高斯分布的均值與對數方差(因為方差總是大於0,為了將它映射到$(-\infty,\infty)$,所以加了對數)。在編碼器的分布中抽樣后,解碼器做的事是將從這個低維抽樣重新解碼,生成與輸入樣本相似的數據。數據可以是圖像、文字、音頻等。
VAE模型的結構不難理解,關鍵在於它的損失函數的定義。我們要讓解碼器的輸出與編碼器的輸入盡量相似,這個損失可以由這二者之間的二元交叉熵(binary crossentropy)來定義。但是僅由這個作為最終的目標函數是不夠的。在這樣的目標函數下,不斷的梯度下降,會使編碼器在不同輸入下的輸出均值之間的差別越來越大,輸出方差則會不斷地趨向於0,也就是對數方差趨向於負無窮。因為只有這樣才會使從生成分布獲取的抽樣更加明確,從而讓解碼器能生成與輸入數據更接近的數據,以使損失變得更小。但是這就與生成器的初衷有悖了,生成器的初衷實際上是為了生成更多“全新”的數據,而不是為了生成與輸入數據“更像”的數據。所以,我們還要再給目標函數加上編碼器生成分布的“正則化損失”:生成分布與標准正態分布之間的KL散度(相對熵)。讓生成分布不至於“太極端、太確定”,從而讓不同輸入數據的生成分布之間有交叉 。於是解碼器通過這些交叉的“緩沖帶”上的抽樣,能夠生成“中間數據”,產生意想不到的效果。
詳細的分析請看:變分自編碼器VAE:原來是這么一回事 - 知乎
以下使用Keras實現VAE生成圖像,數據集是MNIST。
代碼實現
編碼器
編碼器將MNIST的數字圖像轉換為2維的正態分布均值與對數方差。簡單堆疊卷積層與全連接層即可,代碼如下:
#%%編碼器 import numpy as np import keras from keras import layers,Model,models,utils from keras import backend as K from keras.datasets import mnist img_shape = (28,28,1) latent_dim = 2 input_img = layers.Input(shape=img_shape) x = layers.Conv2D(32,3,padding='same',activation='relu')(input_img) x = layers.Conv2D(64,3,padding='same',activation='relu',strides=2)(x) x = layers.Conv2D(64,3,padding='same',activation='relu')(x) x = layers.Conv2D(64,3,padding='same',activation='relu')(x) inter_shape = K.int_shape(x) x = layers.Flatten()(x) x = layers.Dense(32,activation='relu')(x) encode_mean = layers.Dense(2,name = 'encode_mean')(x) #分布均值 encode_log_var = layers.Dense(2,name = 'encode_logvar')(x) #分布對數方差 encoder = Model(input_img,[encode_mean,encode_log_var],name = 'encoder')
解碼器
解碼器接受2維向量,將這個向量“解碼”為圖像。同樣也是簡單的堆疊卷積層、逆卷積層與全連接層即可,代碼如下:
#%%解碼器 input_code = layers.Input(shape=[2]) x = layers.Dense(np.prod(inter_shape[1:]),activation='relu')(input_code) x = layers.Reshape(target_shape=inter_shape[1:])(x) x = layers.Conv2DTranspose(32,3,padding='same',activation='relu',strides=2)(x) x = layers.Conv2D(1,3,padding='same',activation='sigmoid')(x) decoder = Model(input_code,x,name = 'decoder')
整體待訓練模型
整個待訓練模型包括編碼器、抽樣層、解碼器。中間的抽樣操作在獲取編碼器傳出的均值與方差后,通過一個自定義的lambda層來實現。這個抽樣是先從標准正態分布中抽樣,再通過乘生成分布的標准差,加上均值來獲得。因此這個操作並不會把反向傳播中斷,可以將編碼器與解碼器的張量流連接起來。
定義好模型后是損失的定義,如前面所說,最終損失(目標函數)是生成圖像與原圖像之間的二元交叉熵和生成分布的正則化的平均值。使用add_loss方法來添加模型的損失,具體的自定義損失方法看鏈接。
代碼如下:
#%%整體待訓練模型 def sampling(arg): mean = arg[0] logvar = arg[1] epsilon = K.random_normal(shape=K.shape(mean),mean=0.,stddev=1.) #從標准正態分布中抽樣 return mean + K.exp(0.5*logvar) * epsilon #獲取生成分布的抽樣 input_img = layers.Input(shape=img_shape,name = 'img_input') code_mean, code_log_var = encoder(input_img) #獲取生成分布的均值與方差 x = layers.Lambda(sampling,name = 'sampling')([code_mean, code_log_var]) x = decoder(x) training_model = Model(input_img,x,name = 'training_model') decode_loss = keras.metrics.binary_crossentropy(K.flatten(input_img), K.flatten(x)) kl_loss = -5e-4*K.mean(1+code_log_var-K.square(code_mean)-K.exp(code_log_var)) training_model.add_loss(K.mean(decode_loss+kl_loss)) #新出的方法,方便得很 training_model.compile(optimizer='rmsprop')
訓練
因為損失函數並沒有定義真實數據與預測數據直接的損失,因此fit方法只需傳入輸入即可(不用輸出)。代碼如下:
#%%讀取數據集訓練 (x_train,y_train),(x_test,y_test) = mnist.load_data() x_train = x_train.astype('float32')/255 x_train = x_train[:,:,:,np.newaxis] training_model.fit( x_train, batch_size=512, epochs=100, validation_data=(x_train[:2],None))
生成測試
使用scipy.stats中的norm.ppf方法在概率區間(0.01,0.99)內生成20*20個解碼器輸入,這個方法類似在標准正態分布中抽樣,但並不是隨機的,是正態分布下的等概率。生成的二維點分布如下圖:
這樣抽樣而不均勻抽樣為了和編碼器的生成分布契合,因為編碼器正則化后生成的分布是靠近標准正態分布的。然后用解碼器生成圖片,這一部分的代碼如下:
#%%測試 from scipy.stats import norm import numpy as np import matplotlib.pyplot as plt n = 20 x = y = norm.ppf(np.linspace(0.01,0.99,n)) #生成標准正態分布數 X,Y = np.meshgrid(x,y) #形成網格 X = X.reshape([-1,1]) #數組展平 Y = Y.reshape([-1,1]) input_points = np.concatenate([X,Y],axis=-1)#連接為輸入 for i in input_points: plt.scatter(i[0],i[1]) plt.show() img_size = 28 predict_img = decoder.predict(input_points) pic = np.empty([img_size*n,img_size*n,1]) for i in range(n): for j in range(n): pic[img_size*i:img_size*(i+1), img_size*j:img_size*(j+1)] = predict_img[i*n+j] plt.figure(figsize=(10,10)) plt.axis('off') pic = np.squeeze(pic) plt.imshow(pic,cmap='bone') plt.show()
生成的400張圖:
可以看出來,二維坐標系中某個方向的編碼是可以使解碼器的輸出從一個數字變換到另一個數字的。