GAN-生成對抗網絡原理


最近一直在看GAN,我一直認為只有把博客看了一遍,然后再敲一遍。這樣才會有深刻的感悟。

 

GAN(生成式對抗網絡)(GAN, Generative Adversarial Networks )是一種深度學習模型,分布在無監督學習上。

分成兩個模塊:生成模型(Generative Model)和判別模型(Discriminative Model)。簡單來說就是:兩個人比賽,看是 A 的矛厲害,還是 B 的盾厲害。。

比如:我們有一些真實數據,同時也有隨機生成的假數據。A把假數據拼命地模仿成真數據,B拼命地想把真實數據和假數據分開。

 

這里,A就是一個生成模型,類似於造假鈔,一個勁的學習如何騙過B。B是一個判別模型,類似與警察,一個勁地學習如何分辨出A的造假技巧

然后,B的鑒別技巧越來越厲害,A的造假技術越來越逼真,成為一個一流的假幣制造者。而GAN就是獲得上述的兩個模型。

我們需要同時訓練兩個模型。G:生成器。D:判別器。生成器G的訓練過程是最大化判別器犯錯誤的概率,即判別器誤以為數據是真實樣本而不是生成器生成的假樣本。因此,這一框架就對應於兩個參與者的極小極大博弈。在所有可能的函數G和D中,我們可以求出唯一的均衡解,即G可以生成與訓練樣本相同的分布,而D判斷的概率為1/2,意思就是D已經無法判別數據的真假。

為了學習到生成器在數據x上的分布P_g,我們先定義一個輸入的噪聲變量z,然后根據G將其映射到數據空間中,其中G為多層感知機所表征的可微函數。

同樣要定義第二個多層感知機D,它的輸出是單個標量。D(x)表示x是真實數據。我們訓練D以最大化正確分配真實樣本和生成樣本的概率,引起我們就可以最小化log(1-D(G(z)))而同時訓練G。也就是說判別器D和生成器對價值函數V(G, D)進行極小極大化博弈。

 

如上圖所示,生成對抗網絡會訓練並更新判別分布(即D,藍色的虛線),更新判別器后就能將數據真實分布(黑點組成的線)從生成分布P_g(G)(綠色實線)中判別出來。下方的水平線代表采樣域Z,其中等距表示Z中的樣本為均勻分布,上方的水平線代表真實數據X中的一部分。向上的箭頭表示映射x = G(z)如何對噪聲樣本(均勻采樣)施加一個不均勻的分布P_g.

(a) 考慮在收斂點附近的對抗訓練:P_g和P_data已經十分相似,D是一個局部准確的分類器。

(b) 在算法內部循環中訓練D,從數據中判別出真實樣本,該循環最終會收斂到D(x) = P_data(x) / (P_data(x) + P_g(x))

(c) 隨后固定判別器並訓練生成器,在更新G后,D的梯度會引導G(z)流向更可能被D分類為真實數據的方向。

(d) 經過若干次訓練后,如果G和D有足夠的復雜度,那么他們就會到達一個均衡點,這時:P_g = P_data,即生成數據的概率密度函數等於真實數據的概率密度函數,生成數據 = 真實數據。在均衡點上D和G都不能進一步提升,並且判別器無法判斷數據到底是來自真實樣本還是偽造的數據,即D(x) = 1/2

公式推導(公式推導部分來自機器之心):

下面,我們必須證明該最優化問題也就是價值函數V(G, D),有唯一解並且該解滿足P_G = P_data

將數學期望展開為積分形式:

其實求積分的最大值可以轉化為求被積函數的最大值。而求被積函數的最大值是為了求得最優判別器D,因此不涉及判別器的項都可以被看做為常數項。如下所示:P_data(x)和P_G(x)都為標量,因此被積函數可表示為 a * D(x) + b * log(1 - D(x)).

若令判別器D(x)等於y,那么被積函數可以寫為:

為了找到最優的極值點,如果a + b ≠ 0,我們可以用以下一階導求解:

如果我們繼續求表達式f(y)在駐點的二階導:

 

最優生成器

當然GAN過程的目標是令P_G = P_data。

這意味着判別器已經完全困惑,它完全分辨不出P_date和P_G的區別,即判斷樣本來自P_data和P_G的概率都為1/2。基於這一觀點,GAN的作者證明了G就是極小極大博弈的解。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM