Generative Adversarial Nets[AAE]



本文來自《Adversarial Autoencoders》,時間線為2015年11月。是大神Goodfellow的作品。本文還有些部分未能理解完全,不過代碼在AAE_LabelInfo,這里實現了文中2.3小節,當然實現上有點差別,其中one-hot並不是11個類別,只是10個類別。

本文提出“對抗自動編碼器(AAE)”,其本質上是自動編碼器和GAN架構的合體,通過將AE隱藏層編碼向量的聚合后驗與任意先驗分布進行匹配完成變分推論(variational inference)。將聚合后驗與先驗進行匹配確保從該先驗任何部分都能夠生成有意義的樣本。AAE的解碼層可以看成是一個深度生成模型,可以將強加的先驗映射到數據分布上。本文並介紹如何將AAE用在如半監督分類,圖像分類,無監督聚類,維度約間和數據可視化。
本文主要是介紹了幾種AAE的應用:

  • Basic AAE (文中2到2.1之間的部分)
  • Incorporatiing Label Information in the Adversarial Regularization (文中2.3小節)
  • Supervised AAE (文中4小節)
  • Semi-supervised AAE (文中5小節)
  • Unsupervised Clustering with AAE (文中6小節)
  • Dimensionality Reduction with AAE (文中7小節)

0 引言

構建一個可伸縮的生成模型,能夠抓取如語音,圖像,視頻等分布是ML中一個核心問題。近些年的模型如RBM,DBN,DBM都是通過MCMC算法進行訓練的,而MCMC算法是通過計算log-似然的梯度去完成的,該方法在訓練階段並不實用,因為從馬爾可夫鏈中采樣的樣本無法在模式之間快速混合。近些年,生成模型主要都是通過BP進行訓練,避免MCMC的訓練困難。如變分自動編碼器(variational autoencoders,VAE)或者是重要性權重自動編碼器(importance weighted autoencoders)都是采用一個識別網絡(基於潛在變量基礎上)去預測后驗概率。GAN使用對抗訓練過程直接塑造網絡的輸出,如生成時刻匹配網絡(generative moment matching networks,GMMN)使用一個時刻匹配損失函數去學習數據的分布。

本文提出一個通用性算法,對抗自動編碼器,將一個自動編碼器變成生成模型。本文中AE通過兩個目標函數進行訓練,一個是傳統的重構誤差函數,另一個是對抗訓練函數,意在將AE隱藏層向量表示的聚合后驗分布與任意的先驗分布進行匹配。可以發現這個訓練准則和VAE有很強的聯系。訓練的結果是:

  • 編碼器學到將數據分布轉換成該先驗分布;
  • 解碼器學到一個深度生成模型,可以將強加的先驗映射到數據分布上。
    '''update AE network''' 
    _, loss_likehood = sess.run([ae_optim, neg_marginal_likelihood], feed_dict=feed_dict_input)
    '''update discriminator network'''
    _, d_loss = sess.run([d_optim, D_loss], feed_dict=feed_dict_input)
    '''update generator network, run 2 times'''
    _, g_loss = sess.run([g_optim, G_loss], feed_dict=feed_dict_input)
    _, g_loss = sess.run([g_optim, G_loss], feed_dict=feed_dict_input)

1 對抗自動編碼器(Adversarial Autoencoders,AAE)

假設\(\mathbf{x}\)是輸入向量,\(\mathbf{z}\)是AE的隱藏層編碼向量。令\(p(\mathbf{z})\)表示想要加在編碼上的先驗分布,\(q(\mathbf{z}|\mathbf{x})\)是一個編碼分布,\(p(\mathbf{x}|\mathbf{z})\)是一個解碼分布。同時,令\(p_d(\mathbf{x})\)表示數據分布,\(p(\mathbf{x})\)表示模型分布。基於AE的編碼函數 \(q(\mathbf{z}|\mathbf{x})\),定義\(q(\mathbf{z})\)的聚合后驗分布如下:

\[q(\mathbf{z})=\int_{\mathbf{x}}q(\mathbf{z}|\mathbf{x})p_d(\mathbf{x})d\mathbf{x} \tag{1} \]

AAE是在AE基礎上,通過將聚合后驗\(q(\mathbf{z})\)與任意先驗\(p(\mathbf{z})\)進行匹配來完成正則化。為了完成這樣的目標,對抗網絡與AE的隱藏編碼向量相關聯,如圖1.

讓對抗網絡指導\(q(\mathbf{z})\)去匹配\(p(\mathbf{z})\)。同時該AE也嘗試最小化重構誤差。該對抗網絡的生成器同時也是AE的編碼器 \(q(\mathbf{z}|\mathbf{x})\)。該編碼器確保聚合的后驗分布可以愚弄對抗網絡的判別器,讓其誤認為隱藏編碼\(q(\mathbf{z})\)來自真實先驗分布\(p(\mathbf{z})\)

對抗網絡和AE是通過SGD基於兩個階段聯合訓練的:基於mini-batch執行重構階段和正則階段

  • 在重構階段,AE更新編碼器和解碼器,並最小化輸入的重構誤差;
  • 在正則階段,對抗網絡首先更新判別網絡,以區分真實樣本(使用先驗生成)和生成樣本(通過AE計算隱藏編碼);然后,對抗網絡更新生成器(AE的編碼器)去混亂判別器。

一旦訓練階段完成了,AE的解碼器就可以看成是一個生成模型,可以將強加的先驗\(p(\mathbf{z})\)映射回數據分布上。
關於AE的編碼器 \(q(\mathbf{z}|\mathbf{x})\),有幾種可能的選擇:

確定性(Deterministic)
假設\(q(\mathbf{z}|\mathbf{x})\)\(\mathbf{x}\)的確定性函數。這種情況下,編碼器就類似標准AE的編碼器,\(q(\mathbf{z})\)中的隨機來源就是數據分布\(q_d(\mathbf{x})\)

高斯后驗(Gaussian posterior)
假設\(q(\mathbf{z}|\mathbf{x})\)是一個高斯分布,其均值和方差是通過編碼網絡預測的:\(z_i\sim \mathcal{N}(\mu_i(\mathbf{x}),\sigma_i(\mathbf{x}))\)。在這種情況下,\(q(\mathbf{z})\)中的隨機性同時來自數據分布和編碼器輸出的高斯分布隨機性。在網絡的BP過程中,可以使用《Auto-encoding variational bayes》同樣的重新參數技巧。

通用近似后驗(Universal approximator posterior)
AAE可以訓練\(q(\mathbf{z}|\mathbf{x})\)成通用近似后驗。假設AAE的編碼網絡是函數\(f(\mathbf{x},\eta )\),其輸入是\(\mathbf{x}\)和一個固定分布(如高斯)的隨機噪音\(\eta\)。通過在\(\eta\)的不同樣本上評估\(f(\mathbf{x},\eta )\),從而從任意的后驗分布\(q(\mathbf{z}|\mathbf{x})\)中進行采樣。換句話說,假設\(q(\mathbf{z}|\mathbf{x},\eta)=\delta (\mathbf{z}-f(\mathbf{x},\eta))\),那么后驗\(q(\mathbf{z}|\mathbf{x})\)和聚合后驗\(q(\mathbf{z})\)定義如下:

\[q(\mathbf{z}|\mathbf{x})=\int_{\eta}q(\mathbf{z}|\mathbf{x},\eta)p_{\eta}(\eta)d\eta\Rightarrow q(\mathbf{z})=\int_{\mathbf{x}}\int_{\eta}q(\mathbf{z}|\mathbf{x},\eta)p_d(\mathbf{x})p_{\eta}(\eta)d\eta d\mathbf{x} \]

在該情況下,\(q(\mathbf{z})\)的隨機性同時來自數據分布和編碼器輸入上的隨機噪音\(\eta\)。注意到該情況中后驗分布\(q(\mathbf{z}|\mathbf{x})\)不再是受限於高斯,且編碼器可以基於給定輸入\(\mathbf{x}\)學到任意的后驗分布。因為從聚合后驗\(q(\mathbf{z})\)上采樣是一個高效的方法,對抗訓練過程可以通過在編碼網絡\(f(\mathbf{x},\eta)\)上進行BP,讓\(q(\mathbf{z})\)去匹配\(p(\mathbf{z})\)

從上述三種策略,選擇不同類型的\(q(\mathbf{z}|\mathbf{x})\)可以生成不同類型的模型。例如,在\(q(\mathbf{z}|\mathbf{x})\)的確定情況中,網絡只能讓\(q(\mathbf{z})\)去匹配\(p(\mathbf{z})\),此時只利用了數據分布的隨機性。但是因為數據的經驗性分布是被訓練集固定的,映射是確定的,這可能生成一個不是很平滑的\(q(\mathbf{z})\);然而,在高斯或者通用近似情況中,網絡需要額外的隨機性來源,以幫助在對抗正則階段中對\(q(\mathbf{z})\)進行平滑懲罰。
然而,在多次試驗后,作者發現每個\(q(\mathbf{z}|\mathbf{x})\)策略上得到結果大同小異。所以在剩下部分中,只介紹\(q(\mathbf{z}|\mathbf{x})\)的確定性策略。

1.1 與VAE的關系

本文的想法類似《Auto-encoding variational bayes》中變分自動編碼器(variational autoencoders,VAE),然而他們使用的是KL散度懲罰的方法在隱藏層編碼向量上強加一個先驗分布,本文使用的是對抗訓練方法去實現該目的,即讓隱藏層編碼向量的聚合后驗能夠匹配先驗分布。VAE是最小化關於\(\mathbf{x}\)的負log似然上邊界

這里聚合后驗\(q(\mathbf{z})\)定義與式子1中一樣,假設\(q(\mathbf{z}|\mathbf{x})\)是高斯分布,\(p(\mathbf{z})\)是任意分布。變分邊界包含三個部分:第一項可以被認為是AE的重構項,第二項和第三項可以看成是正則項。在沒有正則項的時候,該模型簡單就是個AE。然而在有正則項的時候,VAE學到的隱藏層表征是與\(p(\mathbf{z})\)兼容的。損失函數的第二項鼓勵后驗分布有較大變化,而第三項是最小化聚合后驗\(q(\mathbf{z})\)和先驗\(p(\mathbf{z})\)之間的交叉熵。式子2中KL散度或者交叉熵鼓勵\(q(\mathbf{z})\)能與\(p(\mathbf{z})\)相匹配。而在AAE中,作者將后面兩項替換成一個對抗學習的過程,從而鼓勵\(q(\mathbf{z})\)能與\(p(\mathbf{z})\)整個分布相匹配。

該部分中,將AAE與VAE在編碼分布\(p(\mathbf{z})\)上插入特定先驗的能力做對比。

如圖2a,展示的是在測試數據上的2維編碼空間\(\mathbf{z}\),基於MNIST數據集上訓練AAE,並在隱藏層編碼\(\mathbf{z}\)上強加一個高斯分布。學到的流行顯示不同類別之間明顯的過渡,編碼空間被填充並且沒有空洞存在。實際上,編碼空間中明顯的過渡指的是在位於數據流行上的\(\mathbf{z}\)內插值生成的圖像(圖2e)。圖2c顯示VAE的編碼空間與AAE有相同結構。我們可以發現這種情況下VAE大致是匹配2D的高斯分布形態。然而,沒有數據點映射到幾個編碼空間的局部區域意味着VAE不能如AAE一樣很好的抓取數據流行。

圖2b和圖2d表現的是AAE和AVE的編碼空間,其中插入的分布是10個2D高斯混合分布。AAE成功抓取帶有先驗分布的聚合后驗(圖2b);而VAE表現出與10個組件高斯混合的強烈差別,即VAE更多強調匹配的分布模式(圖2d)。
基於VAE和AAE之間一個重要的差別是在VAE中,為了通過MC采樣對KL散度進行BP,需要得到准確的先驗分布的函數形式。而在AAE中,只需要能從先驗分布中進行采樣就能讓\(q(\mathbf{z})\)匹配\(p(\mathbf{z})\)。后面會介紹AAE還能插入復雜的分布(如swiss roll分布)而並不需要該函數的准確表現形式。

1.2 與GAN和GMMN的關系

1.3 在對抗正則中插入標簽信息

在該場景中,數據是標注過的,可以將標簽信息插入到對抗訓練過程中,以更好的塑造隱藏層編碼的分布。在該部分中,介紹如何使用部分或者所有標簽信息來更好的正則化AE的潛在表征。為了介紹該結構,先返回圖2b,其中AAE是擬合10個成分2維的混合高斯分布。現在讓混合高斯中每個成分表示MNIST中每個標簽。

圖3是半監督方法的訓練過程。這里增加了一個one-hot向量到判別網絡的輸入部分,以將標簽與分布模式相結合。該one-hot向量扮演着(給定類別標簽基礎上)該判別網絡的決策面。該one-hot向量有一個額外的無標簽樣本類別。例如在圖2b和4a中,一個10成分的2D高斯混合模型,one-hot向量有11個類別。前面10個類別對應混合模型中每個獨立的決策面。額外的one-hot向量對應無標簽訓練樣本點(如生成器生成的樣本)。

當一個無標簽樣本點出現在該模型中,額外的類別就會得到響應,以選擇整個高斯混合分布的決策面:

  • 在對抗訓練的正階段,通過one-hot將高斯混合模型生成的樣本的標簽傳給判別器。這些正樣本來自混合高斯模型,而不是某個具體的類別;
  • 在對抗訓練的負階段,通過one-hot將生成器生成的樣本的標簽給判別器。這些負樣本來自生成器。

圖4a中展現的是基於一個AAE的隱藏層表征,該AAE是基於10k個標記的MNIST樣本可40K個無標簽的MNIST樣本,10個成分的2D高斯混合模型上訓練的。此時,先驗中第i個混合懲罰以半監督方式與第i個類別相關。圖4b展示的是前三個混合成分的流行。注意到每個混合成分的類型表征是很一致的,且與各自的類相獨立。例如,圖4b中所有的左上區域對應於直立書寫樣式,右下區域對應於數字的傾斜書寫樣式。

該方法可以擴展到任意分布而不需要參數控制,如將MNIST數據映射到一個“swiss roll”(如條件高斯分布,其均值是均勻分布的,其長度為一個swiss roll的軸)。圖4c是編碼空間\(\mathbf{z}\)的展示,圖4d是沿着swiss roll軸前進生成的圖像。

2 AAE的似然分析

本節使用《Generative adversarial nets》中描述的評估方法,比較該模型在MNIST和toronto人臉數據集(TFD)上生成圖像的能力來測量AAE作為生成模型捕獲數據分布的能力。

圖5展示的就是基於訓練好的AAE生成的樣本。在tfd.gif這里是學到的TFD流行。為了鑒定模型是否過擬合,在最后一列展現的是以歐式距離計算最近的訓練集樣本。

通過在測試集上計算AAE的log似然來對其性能進行評估。不過因為使用似然函數不直觀,不能直接計算圖片的概率,所以這里使用先《Generative adversarial nets》中描述的方法計算真實對數似然的下界。用高斯Parzen窗口(核密度估計器)去擬合10000個 從模型生成的樣本,並計算此分布下的測試數據的可能性。parzen窗口中自由參數\(\sigma\)是通過交叉驗證選擇的。

表1計算在真實數據MNIST和TFD上,AAE和其他如DBN,堆疊CAE,深度GSN,GAN和GMMN+AE模型的對比結果。
注意到parzen窗口是在真實log似然上評估下邊界,略。。。

3 有監督AAE

4 半監督AAE

5 基於AAE的無監督聚類

6 基於AAE的維度約間


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM