Generative Adversarial Networks
-
GAN框架
GAN框架是有兩個對象(discriminator,generator)的對抗游戲。generator是一個生成器,generator產生來自和訓練樣本一樣的分布的樣本。discriminator是一個判別器,判別是真實數據還是generator產生的偽造數據。discriminator使用傳統的監督學習技術進行訓練,將輸入分成兩類(真實的或者偽造的)。generator訓練的目標就是欺騙判別器。
游戲中的兩個參與對象由兩個函數表示,每個都是關於輸入和參數的可微分函數。discriminator是一個以 x 作為輸入和使用θ(D) 為參數的函數D,D(x)是指判斷輸入樣本x是真實樣本的概率 ;generator由一個以z為輸入使用 θ(G) 為參數的函數G,G(z)是指輸入樣本z產生一個新的樣本,這個新樣本希望接近真實樣本的分布。
discriminator與generator都用兩個參與對象的參數定義的代價函數。discriminator希望僅控制住θ(D) 情形下最小化 J(D)(θ(D), θ(G))。generator希望在僅控制θ(D) 情形下最小化 J(G)(θ(D),θ(G))。因為每個參與對象的代價依賴於其他參與對象的參數,但是每個參與對象不能控制別人的參數,這個場景其實更為接近一個博弈而非優化問題。優化問題的解是一個局部最小,這是參數空間的點其鄰居有着不小於它的代價。而對一個博弈的解釋一個納什均衡。在這樣的設定下,Nash 均衡是一個元組,( θ(D), θ(G)) 既是關於θ(D)的 J(D) 的局部最小值和也是關於θ(G)的 J(G) 局部最小值。
圖 1 GAN兩種場景
如圖 1所示GAN有兩種場景,第一種場景(左圖),discriminator對象隨機從樣本集中取一個元素X作為輸入,discriminator對象的目標是以真實樣本X作為輸入時,盡量判斷D(x)為1;而第二種場景(右圖),具有discriminator和generator兩個對象的參與,generator對象以噪聲變量z作為輸入,然后產生一個樣本G(z),discriminator對象以G(z)作為輸入並盡量判斷D(G(z) )為0;而generator對象的目標是盡量讓discriminator對象計算D(G(z) )為1。最后這個游戲是達到納什均衡(Nash equilibrium),即G(z)產生的數據樣本分布與真實數據樣本分布一樣,即對於所有的輸入x,D(x) 的計算結果為0.5。
-
ANN函數
GAN是由一個判別模型(discriminator)和生成模(generator)型組成。其中discriminator和generator可以由任何可微函數來描述,如圖 4所示是采用兩個多層的神經網絡來描述discriminator和generator模型,即圖中的G和D函數。
圖 2
generator是一個可微分函數 G。當 z 從某個簡單的先驗分布中采樣出來時,G(z) 產生一個從 pmodel 中的樣本。一般來說, GAN對於generator神經網絡只有很少的限制。如果我們希望 pmodel 是 x 空間的支集(support),我們需要 z 的維度需要至少和 x 的維度一樣大,而且 G 必須是可微分的,但是這些其實就是僅有的要求了。
-
損失函數
-
discriminator的代價
-
交叉熵[2]
交叉熵代價函數(Cross-entropy cost function)是用來衡量人工神經網絡(ANN)的預測值與實際值的一種方式。交叉熵損失函數定義如下:
其中:
- x表示樣本
- y表示樣本x對應的標簽
- a表示以樣本x作為輸入,模型的輸出標簽
- n表示樣本的總數,當為二分類時n為2
-
目前為 GANs 設計的所有不同的博弈針對discriminator的 J(D) 使用了同樣的代價函數。他們僅僅是generator J(G) 的代價函數不同。
discriminator的代價函數是:
其中: 表示在分布上的期望,D(x)為概率函數。
其實就是標准地訓練一個sigmoid 輸出的標准二分類器交叉熵代價函數。唯一的不同就是分類器在兩個 minibatch 的數據上進行訓練;一個來自數據集(其中的標簽均是 1),另一個來自生成器(其標簽均是 0)。
通過給discriminator模型定義損失函數后,將優化discriminator模型轉移為優化等式,即訓練discriminator模型就是為了最小化discriminator的等式。
-
Minimax
GAN框架有兩個參與對象discriminator和generator ,上一節只考慮優化discriminator模型,還需要考慮優化generator模型。GAN使用了零和博弈思想為generator模型定義損失函數。在零和博弈游戲中,其所有參與人的代價總是 0,即在游戲中贏的得正數,輸的得負數,所以總和為0。在零和博弈中,參加游戲雙方的得分互為相反數,所以根據discriminator的損失函數,可推導出generator的損失函數為:
所以優化generator模型,一樣是優化 損失函數,即最小化該損失函數。由於和兩個損失函數只是互為相反數,所以可以將兩個等式合並為一個優化等式。即
由於我們訓練D來最大化分配正確標簽給不管是來自於訓練樣例還是G生成的樣例的概率.我們同時訓練G來最小化。換句話說,D和G的訓練是關於值函數V(G,D)的極小化極大的二人博弈問題:
其中:
- G表示生成模型,D表示分類模型
- x~pdata(x) 表示x取自訓練數據的分布
- z~p(z) 表示z取自我們模擬數據的分布
圖 3
如圖 2所示a-b是模型G和D的優化過程,黑色的虛線表示訓練數據的分布;綠色的實線表示模型G產生的分布;藍色的虛線表示模型D的計算值;水平X軸表示D函數的計算值;水平z軸表示噪聲值。一開始G的產生分布於真實數據分布偏離較大,且模型D對真實數據和偽造數據區分能力較強,即對真實數據D函數的計算值較大,而對偽造數據D函數的計算值較小,如圖a;隨着模型的訓練,G數據分布於真實數據分布逐漸重合,如圖d,最后D的計算值恆等為0.5。
-
訓練過程
訓練過程包含同時隨機梯度下降 simultaneous SGD。在每一步,會采樣兩個 minibatch:一個來自數據集的 x 的 minibatch 和一個從隱含變量的模型先驗采樣的 z 的 minibatch。然后兩個梯度步驟同時進行:一個更新 θ(D) 來降低 J(D),另一個更新 θ(G) 來降低 J(G)。這兩個步驟都可以使用你選擇的基於梯度的優化算法。
生成對抗網絡的minibatch隨機梯度下降訓練。判別器的訓練步數,k是一個超參數。在我們的試驗中使用k=1,使消耗最小。
圖 4
-
理論分析
GAN的設計思想采用discriminator和generator兩個模型進行對抗優化,本章用兩個證明來從理論上論證了對抗網絡的合理性。
-
命題一:全局最優
命題:當G固定的時候,D會有唯一的最優解。真實描述如下:
證明如下:
-
首先,根據連續函數的期望計算方式,對V(G,D)進行變換:
-
對於任意的a,b ∈ R2 \ {0, 0}, 下面的式子在a/(a+b)處達到最優:
所以得證。
-
命題二:收斂性
命題:如果G和D有足夠的性能,對於算法中的每一步,給定G時,判別器能夠達到它的最優,並且通過更新pg來提高這個判別准則。
則pg收斂為pdata。
證明略,看不太懂。
-
CycleGAN[5]
-
概述
CycleGAN的原理可以概述為: 將一類圖片轉換成另一類圖片 。也就是說,現在有兩個樣本空間X和Y,我們希望把X空間中的樣本轉換成Y空間中的樣本。(獲取一個數據集的特征,並轉化成另一個數據集的特征).
圖 5
-
形式化
CycleGAN模型的學習目標是訓練兩個映射函數:G:XàY和F:YàX,同時CycleGAN模型還包含了兩個相關的discriminator對象:Dx和Dy。Dy是為了區分G函數產生的數據和Y數據;而Dx是為了區分F函數產生的數據和X數據,如圖 5(a)所示。
-
對抗損失函數
如3.2小節所示介紹的對抗網絡,對於一個映射函數G:XàY,和discriminator對象DY,則GAN的損失函數定義為:
其中,映射函數G是將X領域的數據轉換為類似Y領域的數據,而DY就是判別真實的Y數據和G偽造的Y數據。即GAN的優化目標是:。同樣的對於映射F:YàX,和discriminator對象DX,可以定義一個GAN損失函數的優化目標:.
-
循環一致損失函數
理論上GAN能夠學習兩個映射函數G和F,其能夠分別從X或Y一個領域的數據生成到另一個領域的數據。但是由於映射函數變換可能性非常多,無法保證映射函數能夠將一個領域的輸入數據xi轉換為其它領域的數據yi。為了減少映射函數的變換范圍或可能性,CycleGAN增加了一些約束函數來限制這種變換范圍過大的問題。
如圖 5(b)所示,通過映射函數G和F,可以從X領域的數據樣本變換為領域Y的數據樣本,再變換為X領域的數據樣本,從而生成一個環,即:,同理有圖 5(c)的。所以原始數據樣本x和循環產生的數據F(G(x))之間肯定有差異,那么可以定義一致性損失函數為:
其中式中的方括號是使用了L1規范化。
-
完整表達式
綜上所述,CycleGAN的損失函數可以完整表達為:
其中控制了映射函數G和F的相對重要性。所以CycleGAN的優化目標是:
其中G和F兩個映射函數的內部結構互相彼此獨立,即它們能將一個數據樣本映射到另一個領域的數據樣本。
-
實現
CycleGAN網絡的實現就是定義四個神經網絡:G、F、Dx和Dy;然后優化這個最終的表達式,
-
參考文獻
-
Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.
-
