VAE 與 GAN 的關系
VAE(Variational Auto-Encoder)和 GAN(Ganerative Adversarial Networks)都是生成模型(Generative model)。所謂生成模型,即能生成樣本的模型。我們可以將訓練集中的數據點看作是某個隨機分布抽樣出來的樣本,比如:MNIST 手寫體樣本,我們就可將每一幅圖像看作是一個隨機分布p(x)p(x) 的抽樣(實例)。如果我們能夠得到這樣的一個隨機模型,我們就可以無限制地生成樣本,再無采集樣本的煩惱。但這個隨機分布 p(x)p(x) 我們並不知道,需要通過對訓練集的學習來得到它,或者逼近它。要逼近一個隨機分布,其基本思想是:將一個已知的,可控的隨機分布 q(z)q(z) 映射到目標隨機分布 p(x)p(x) 上。在深度學習領域中,有兩個典型的生成模型:1、VAE,變分自編碼器;2、GAN,生成對抗性網絡,它們的結構如圖 1、圖 2:
圖 1 VAE 結構圖
圖 2 GAN 結構圖
VAE 的工作流程是:
1、在訓練集(Dataset)中抽樣,得到樣本xixi,xixi經過神經網絡編碼器(NN Encoder)得到一個正態分布(N(μ**i,σ2i)N(μi,σi2))的充分統計量:均值(以圖 1 為例進行解釋,μ**i=(m1,m2,m3)μi=(m1,m2,m3))和方差(σi=(σ1,σ2,σ3)σi=(σ1,σ2,σ3));
2、由N(μ**i,σ2i)N(μi,σi2) 抽樣得到 zizi, 已知 zizi 的分布是標准正態分布 N(0,1)N(0,1);
3、zizi 經過 NN_Decoder 得到輸出x****̂ ix^i
4、Los**s=‖x****̂ i−xi‖Loss=‖x^i−xi‖,訓練過程就是讓 Loss 取得最小值。
5、訓練好的模型,我們可以利用 Decoder 生成樣本,即將已知分布 q(z)q(z) 的樣本通過 Decoder 映射成目標 p(x)p(x) 分布的樣本。
以上過程可以用圖 3 進行概括。
圖 3 VAE 原理圖
為比較 VAE 和 GAN 的差異,參考圖 4,簡述 GAN 的工作原理如下:
1、在一個已知的、可控的隨機分布q(z)q(z)(e.g.: 多維正態分布 q(z)=N(μ,σ2)q(z)=N(μ,σ2)) 采樣,得到zizi;
2、zizi 經過生成器映射 G(zi)G(zi) 得到在樣本空間(高維空間)的一個數據點 xgixig,有 xgi=G*(zi)xig=G(zi);
3、xgixig 與樣本流型(Manifolds)之間的距離在圖中表示為 D(‖xgi*,x̂** **gi)=‖xgi−x****̂ g****i‖D(‖xig,xig)=‖xig−xig‖,其中 x****̂ g****ix^ig 表示 xgixig 在流型上的投影,生成器的目標是減少此距離,可以將此距離定義為生成器的損失(Loss)。
4、因為不能直接得到樣本的流型,因而需要借助判別器 (Discriminator) 間接地告訴生成器(Generator)它生成的樣本距樣本流型是 “遠了” 還是“近了”,即判別真(real)和假(fake)正確的概率,一方面要判別器提高判別准確性,一方面又要提高生成器的以假亂真的能力,由此形成了競爭導致的均衡,使判別器和生成器兩者性能同時提高,最后我們可獲得最優的生成器。
5、理想的最優生成器,能將生成的 xg*i*xig 映射到樣本分布的流型中(即圖中陰影部分)
圖 4 GAN 原理
由 GAN 的生成過程,我們可以很直觀地得到兩個 GAN 缺陷的解釋(詳細分析可見《Wasserstein GAN》):
1、模型坍塌(Model collapse)
GAN 只要求將 xgixig 映射至離樣本分布流型盡可能近的地方,卻不管是不是同一個點,於是生成器有將所有 zi*zi 都映射為一點的傾向,於是模型坍塌就發生了。
2、不收斂
由於生成器的 Loss 依賴於判別器 Loss 后向傳遞,而不是直接來自距離 D(xgi,x***̂ g****i)D(xig,x^ig) ,因而若判別器總是能准確地判別出真假,則向后傳遞的信息就非常少(體現為梯度為 0),則生成器便無法形成自己的 Loss,因而便無法有效地訓練生成器。正是因為這個原因,《Wasserstein GAN》才提出了一個新的距離定義(Wasserstein Distance)應用於判別器,而不是原型中簡單粗暴的對真偽樣本的分辨正確的概率。Wasserstein Distance 所針對的就是找一個方法來度量 D(xgi,x̂* g*i)D(xig,x^ig)。
比較 VAE 與 GAN 的效果,我們可以得到以下兩點:
1、GAN 生成的效果要優於 VAE
我認為GAN和VAE的一個本質區別就是loss的區別
VAE是pointwise loss,一個典型的特征就是pointwise loss常常會脫離數據流形面,因此看起來生成的圖片會模糊
GAN是分布匹配的loss,更能貼近流行面,看起來就會清晰
但分布匹配的難度較大,一個例子就是經常發生mode collapse問題,小分布丟失,而pointwise loss就沒有這個問題,可以用於做初始化或做糾正,因此發展了一系列GAN+VAE的工作
VAE 是一個在數學上具有完美證明的模型,其優化目標和過程都是顯式的。由於 VAE 設計的強預設性,其優化過程強制性地把數據擬合到有限維度的混合高斯或者其他分布上,這導致兩個結果:
- 映射過程中必然導致信息損失,特別是次要信息的損失
- 不符合預設分布的信息的編碼和恢復效果差,如果強制性的把這樣的分布投射到高斯分布上,就必然導致模糊
對應地,GAN 的訓練中沒有這樣的強預設,它是通過判別器網絡來進行優化的,讓生成器產生數據的分布直接擬合訓練數據的分布,而對具體的分布沒有特別的要求。
2、GAN 比 VAE 要難於訓練
文章:Variational Inference: A Unified Framework of Generative Models and Some Revelations,原文:arXiv:1807.05936,中文鏈接:https://kexue.fm/archives/5716。文章來自中山大學 數學學院 蘇劍林。文中為兩個生成模型建立了統一的框架,並提出一種為訓練 Generator 而設計的正則項原理,它可以作為《Wasserstein GAN》的一個補充,因為 WGAN 給出的是 GAN 判別器 Loss 的改進意見,而該文卻是對生成器下手,提出生成器的一個正則項原則。
以下是從變分推斷(Variational Inference)角度對兩個模型的推導過程:
p(x)p(x) 是真實隨機變量的分布,樣本集中的數據可認為是從這個分布中抽樣出來的樣本,它是未知的。我們希望用一個可控的、已知的分布 q(x)q(x) 來逼近它,或者說讓這兩個分布盡量重合。如何實現這個目標呢?一個直觀的思路是:從一個已知的分布出發,對其中的隨機變量進行映射,得到一個新的分布,這個映射后分布與目標分布盡量重合。在 VAE 中,這個映射由 Decoder 完成;在 GAN 中,這個映射由 Generator 來完成。而已知分布可以是:均勻分布或正態分布,例如:隨機變量 zz 服從多維標准正態分布 z∼N(0,1)z∼N(0,1),它經過映射得到 xgi*xig,為表述方便我們統一用 G(zi)=xgi,xgi∼q*(x)G(zi)=xig,xig∼q(x) 表示經過映射后的隨機變量 xg*xg 服從分布 q(x)q(x),它與真實變量分布 xr∼p*(x)xr∼p(x) 在同一個空間 XX。我們希望得到盡可能與 p(x)p(x) 重合的 q(x)q(x)。
為達到 “盡量” 這一目標,需要對重合程度進行量化,於是我們定義了一個測度的指標——KL 散度:
K**L(p(x)‖q(x))=∫p(x)logp(x)q(x)d**x(1)KL(p(x)‖q(x))=∫p(x)logp(x)q(x)dx(1)
當 p(x) 與 q(x) 完全重合,有 K**L(p(x)‖q(x))=0KL(p(x)‖q(x))=0,否則 K**L(p(x)‖q(x))>0KL(p(x)‖q(x))>0。直接求上述真實分布 p(x)p(x) 與生成分布 q(x)q(x) 的 KL 散度 K**L(p(x)‖q(x))KL(p(x)‖q(x)) 有時十分困難,往往需要引入隱變量(Latent Variables)構成聯合分布 p(x,z)p(x,z) 和 q(x,z)q(x,z),計算 K**L(p(x,z)‖q(x,z))KL(p(x,z)‖q(x,z)) 來代替直接計算 K**L(p(x)‖q(x))KL(p(x)‖q(x))。
(1)GAN 的聯合分布 p(x,z)p(x,z) 的擬合
GAN 在 Generator 生成了 xgixig 將與真實樣本 xri*xir 一起作為輸入xx,進入一個二選一選擇器,該選擇器以一定的概率(比如說 50%)選擇其一(選擇器可以看成是隨機變量 yy ,它服從服從 0-1 分布), xx接着進入判決器,判決器判定輸入xx的是真實樣本(xrxr ),判斷為 1,或是生成樣本(xgxg),判斷為 0,最后輸出給定xx判為 1 的概率。
令聯合分布是輸入xx 和選擇器 yy 的分布,p(x,y)p(x,y) 表示真實的聯合分布,而 q(x,y)q(x,y) 是判別器可控制的聯合分布,是為擬合真實分布所設計的分布,由上機制可見,在 q(x,y)q(x,y) 中,x*x 與 yy 其實是相互獨立的,因而有:
q(x,y)={p(x)⋅p1q(x)⋅p0if y=1 if y=0 (2)q(x,y)={p(x)⋅p1if y=1 q(x)⋅p0if y=0 (2)
其中,p1p1 表示 y=1y=1 的概率,p0p0 表示 y=0y=0 的概率,有 p1+p0=1p1+p0=1。若判別器有足夠的擬合能力,並達到最優時,則q(x,y)→p(x,y)q(x,y)→p(x,y),這意味着:
q(x)=∑y**q(x,y)→∑y**p(x,y)=p(x)(3)q(x)=∑yq(x,y)→∑yp(x,y)=p(x)(3)
為擬合,由(2)可得兩個聯合分布 KL 散度:
K**L(q(x,y)‖p(x,y))=∫[ p(x)⋅p1logp(x)p1p(y=1|x)p(x)+q(x)⋅p0logq(x)p0p(y=0|x)p(x)] d**x(4)KL(q(x,y)‖p(x,y))=∫[ p(x)⋅p1logp(x)p1p(y=1|x)p(x)+q(x)⋅p0logq(x)p0p(y=0|x)p(x)] dx(4)
令 p1=p0=0.5p1=p0=0.5,以 K**L(q(x,y)‖p(x,y))KL(q(x,y)‖p(x,y)) 為 Loss,因為p1p1、p0p0、p(x)p(x)、q(x)q(x) 與判別器參數無關,另外,判別器的輸出是 p(y=1|x)=D(x)p(y=1|x)=D(x),因而:
K**L(q(x,y)‖p(x,y))∼∫p(x)log1p(y=1|x)d**x+∫q(x)log1p(y=0|x)d**x=∫p(x)log1D(x)d**x+∫q(x)log11−D(x)d**x=−E****x∼p(x)(logD(x))−E****x∼q(x)(log(1−D(x)))(5)KL(q(x,y)‖p(x,y))∼∫p(x)log1p(y=1|x)dx+∫q(x)log1p(y=0|x)dx=∫p(x)log1D(x)dx+∫q(x)log11−D(x)dx=−Ex∼p(x)(logD(x))−Ex∼q(x)(log(1−D(x)))(5)
GAN 的判別器要盡量分辨兩個分布,因而要 KL 散度盡可能大,因而若 (5) 式要作為判別器的 Loss,則需取反,即:
LossD=E****x∼p(x)(logD(x))+E****x∼q(x)(log(1−D(x)))(6)D=argminDE****x∼p(x)(logD(x))+E****x∼q(x)(log(1−D(x)))LossD=Ex∼p(x)(logD(x))+Ex∼q(x)(log(1−D(x)))(6)D=argminDEx∼p(x)(logD(x))+Ex∼q(x)(log(1−D(x)))
(6)式與傳統 GAN 判別器 Loss 分析結果(可參見:https://blog.csdn.net/StreamRock/article/details/81096105)相同。
討論完判別器 Loss_D,接下來討論生成器的 Loss_G,同樣從(4)式入手,此時可調的參數是生成器的參數,p1p1、p0p0、p(x)p(x)、p(y=1|x)p(y=1|x) (即D(x)D(x)) 與生成器參數無關,與之相關的是q(x)q(x),因而有:
K**L(q(x,y)‖p(x,y))=∫[ p(x)⋅p1logp(x)p1p(y=1|x)p(x)+q(x)⋅p0logq(x)p0p(y=0|x)p(x)] d**x∼∫q(x)logq(x)(1−D(x))p(x)dx(8)KL(q(x,y)‖p(x,y))=∫[ p(x)⋅p1logp(x)p1p(y=1|x)p(x)+q(x)⋅p0logq(x)p0p(y=0|x)p(x)] dx∼∫q(x)logq(x)(1−D(x))p(x)dx(8)
最優判決器應該具有如下特性:
D(x)=p(x)p(x)+q(x)(9)D(x)=p(x)p(x)+q(x)(9)
(9)式可以從圖 5,直接得到。
圖 5 最優判決器
將(9)代入(8)有:
Los**s_G=∫q(x)logq(x)(1−D(x))p(x)dx=∫q(x)logq(x)(1−p(x)p(x)+q(x))p(x)dx=∫q(x)log1p(x)p(x)+q(x)dx=−∫q(x)logD(x)dx=−E****x∼q(x)(logD(x))=−E****z∼N(0,1)(logD(G(z)))(10)Loss_G=∫q(x)logq(x)(1−D(x))p(x)dx=∫q(x)logq(x)(1−p(x)p(x)+q(x))p(x)dx=∫q(x)log1p(x)p(x)+q(x)dx=−∫q(x)logD(x)dx=−Ex∼q(x)(logD(x))=−Ez∼N(0,1)(logD(G(z)))(10)
於是可采用(10)作為生成器的損失 Loss_G,與傳統推導的結果一致。考察(9)式,式中 q(x)q(x) 實際上是不斷變化的,因而我們將 (9) 改造一下,q0(x)q0(x) 表示前一次生成器的狀態,於是有:
D(x)=p(x)p(x)+q0(x)(11)Los**s_G=∫q(x)logq(x)(1−D(x))p(x)dx=∫q(x)logq(x)(1−p(x)p(x)+q0(x))p(x)dx=∫q(xlogq(x)D(x)q0(x)dx)=−E****x∼q(x)[logD(x)]+K**L(q(x)‖q0(x))=−E****z∼N(0,1)[logD(G(z))]+K**L(q(x)‖q0(x))(12)D(x)=p(x)p(x)+q0(x)(11)Loss_G=∫q(x)logq(x)(1−D(x))p(x)dx=∫q(x)logq(x)(1−p(x)p(x)+q0(x))p(x)dx=∫q(xlogq(x)D(x)q0(x)dx)=−Ex∼q(x)[logD(x)]+KL(q(x)‖q0(x))=−Ez∼N(0,1)[logD(G(z))]+KL(q(x)‖q0(x))(12)
比較(10)與(12)發現 (12) 多了一項 K**L(q(x)‖q0(x))KL(q(x)‖q0(x)) , 此為生成器 Loss 的正則項,它要求q(x)q(x) 與 q0(x)q0(x) 盡可能小,由此可實現生成器的正則項,詳細分析可見:https://kexue.fm/archives/5716