GAN
原始GAN中判別器要最小化如下損失函數,盡可能把真實樣本分為正例,生成樣本分為負例:
其中是真實樣本分布,
是由生成器產生的樣本分布。
第一個式子我們不看梯度符號的話即為判別器的損失函數,logD(xi)為判別器將真實數據判定為真實數據的概率,log(1-D(G(zi)))為判別器將生成器生成的虛假數據判定為真實數據的對立面即將虛假數據仍判定為虛假數據的概率。判別器就相當於警察,在鑒別真偽時,必須要保證鑒別的結果真的就是真的,假的就是假的,所以判別器的總損失即為二者之和,應當最大化該損失。由於判別器(警察)鑒別真偽的能力隨着訓練次數的增加越來越高,生成器就要與之“對抗”,生成器就要相應地提高“造假”技術,來迷惑判別器。第二個式子為第一個式子的第二項,含義相同,只不過對於生成器應當最小化該項,生成器當然希望辨別器將虛假數據仍判定為虛假數據的概率越低越好,即將虛假數據誤判定為真實數據的概率越大越好,即最大化log(D(G(zi)))損失函數。所以二者相互提高或者減小自身的損失,以不斷互相對抗。
從公式1可以得到,在生成器G固定參數時最優的判別器D應該是什么。對於一個具體的樣本,它可能來自真實分布也可能來自生成分布,它對公式1損失函數的貢獻是
令其關於的導數為0,得
化簡得最優判別器為:
(公式4)
看一個樣本來自真實分布和生成分布的可能性的相對比例。如果
且
,最優判別器就應該非常自信地給出概率0;如果
,說明該樣本是真是假的可能性剛好一半一半,此時最優判別器也應該給出概率0.5。
(公式2)
GAN訓練的問題之一,就是別把判別器訓練得太好,否則在實驗中生成器會完全學不動(loss降不下去),為了探究背后的原因,我們就可以看看在極端情況——判別器最優時,生成器的損失函數變成什么。給公式2加上一個不依賴於生成器的項,使之變成
最小化這個損失函數等價於最小化公式2,而且它剛好是判別器損失函數的相反數。代入最優判別器即公式4,再進行簡單的變換可以得到
(公式5)
從而可以得到KL散度和JS散度(衡量量兩個分布的差異區別)
於是公式5就可以繼續寫成
(公式8)
目前得到的結論:根據原始GAN定義的判別器loss,我們可以得到最優判別器的形式;而在最優判別器的下,我們可以把原始GAN定義的生成器loss等價變換為最小化真實分布與生成分布
之間的JS散度。我們越訓練判別器,它就越接近最優,最小化生成器的loss也就會越近似於最小化
和
之間的JS散度。
問題就出在這個JS散度上。我們會希望如果兩個分布之間越接近它們的JS散度越小,我們通過優化JS散度就能將“拉向”
,最終以假亂真。這個希望在兩個分布有所重疊的時候是成立的,但是如果兩個分布完全沒有重疊的部分,或者它們重疊的部分可忽略(下面解釋什么叫可忽略),它們的JS散度是多少呢?
答案是,因為對於任意一個x只有四種可能:
且
且
且
且
第一種對計算JS散度無貢獻,第二種情況由於重疊部分可忽略所以貢獻也為0,第三種情況對公式7右邊第一個項的貢獻是,第四種情況與之類似,所以最終
。
換句話說,無論跟
是遠在天邊,還是近在眼前,只要它們倆沒有一點重疊或者重疊部分可忽略,JS散度就固定是常數
,而這對於梯度下降方法意味着——梯度為0!此時對於最優判別器來說,生成器肯定是得不到一丁點梯度信息的;即使對於接近最優的判別器來說,生成器也有很大機會面臨梯度消失的問題。
但是與
不重疊或重疊部分可忽略的可能性有多大?不嚴謹的答案是:非常大。比較嚴謹的答案是:當
與
的支撐集(support)是高維空間中的低維流形(manifold)時,
與
重疊部分測度(measure)為0的概率為1。
- 支撐集(support)其實就是函數的非零部分子集,比如ReLU函數的支撐集就是
,一個概率分布的支撐集就是所有概率密度非零部分的集合。
- 流形(manifold)是高維空間中曲線、曲面概念的拓廣,我們可以在低維上直觀理解這個概念,比如我們說三維空間中的一個曲面是一個二維流形,因為它的本質維度(intrinsic dimension)只有2,一個點在這個二維流形上移動只有兩個方向的自由度。同理,三維空間或者二維空間中的一條曲線都是一個一維流形。
- 測度(measure)是高維空間中長度、面積、體積概念的拓廣,可以理解為“超體積”。
在(近似)最優判別器下,最小化生成器的loss等價於最小化與
之間的JS散度,而由於
與
幾乎不可能有不可忽略的重疊,所以無論它們相距多遠JS散度都是常數
,最終導致生成器的梯度(近似)為0,梯度消失。
- 首先,
與
之間幾乎不可能有不可忽略的重疊,所以無論它們之間的“縫隙”多狹小,都肯定存在一個最優分割曲面把它們隔開,最多就是在那些可忽略的重疊處隔不開而已。
- 由於判別器作為一個神經網絡可以無限擬合這個分隔曲面,所以存在一個最優判別器,對幾乎所有真實樣本給出概率1,對幾乎所有生成樣本給出概率0,而那些隔不開的部分就是難以被最優判別器分類的樣本,但是它們的測度為0,可忽略。
- 最優判別器在真實分布和生成分布的支撐集上給出的概率都是常數(1和0),導致生成器的loss梯度為0,梯度消失。
有了這些理論分析,原始GAN不穩定的原因就徹底清楚了:判別器訓練得太好,生成器梯度消失,生成器loss降不下去;判別器訓練得不好,生成器梯度不准,四處亂跑。只有判別器訓練得不好不壞才行,但是這個火候又很難把握,甚至在同一輪訓練的前后不同階段這個火候都可能不一樣,所以GAN才那么難訓練。
WGAN
引入Wasserstein距離
希望建立一個平滑的,處處可導的cost function。在圖中,藍色為真實分布,綠色為生成數據的分布。紅色為discriminator的cost function,我們發現雖然discriminator有效的區分了兩個分布,但是當藍綠兩個分布沒有交集時,在大量的點上的cost function為常數值,梯度為0,generator 不能更新了。這時看一下wasserstein 距離,它體現為那個草綠色的線,它平滑,可導這就是我們要尋找的cost function。
數學定義如下:
(公式12)
解釋如下:是
和
組合起來的所有可能的聯合分布的集合,反過來說,
中每一個分布的邊緣分布都是
和
。對於每一個可能的聯合分布
而言,可以從中采樣
得到一個真實樣本
和一個生成樣本
,並算出這對樣本的距離
,所以可以計算該聯合分布
下樣本對距離的期望值
。在所有可能的聯合分布中能夠對這個期望值取到的下界
,就定義為Wasserstein距離。
Wasserstein距離相比KL散度、JS散度的優越性在於,即便兩個分布沒有重疊,Wasserstein距離仍然能夠反映它們的遠近。
但是
因為Wasserstein距離定義(公式12)中的沒法直接求解,不過沒關系,作者用了一個已有的定理把它變換為如下形式
(公式13)
Lipschitz連續。它其實就是在一個連續函數上面額外施加了一個限制,
的導函數絕對值不超過
。限制了一個連續函數的最大局部變動幅度。
公式13的意思就是在要求函數的Lipschitz常數
不超過
的條件下,對所有可能滿足條件的
取到
的上界,然后再除以
。特別地,我們可以用一組參數
來定義一系列可能的函數
,此時求解公式13可以近似變成求解如下形式
(公式14)
再用上我們搞深度學習的人最熟悉的那一套,不就可以把用一個帶參數
的神經網絡來表示嘛!由於神經網絡的擬合能力足夠強大,我們有理由相信,這樣定義出來的一系列
雖然無法囊括所有可能,但是也足以高度近似公式13要求的那個
了。
最后,還不能忘了滿足公式14中這個限制。我們其實不關心具體的K是多少,只要它不是正無窮就行,因為它只是會使得梯度變大
倍,並不會影響梯度的方向。所以作者采取了一個非常簡單的做法,就是限制神經網絡
的所有參數
的不超過某個范圍
,比如
,此時關於輸入樣本
的導數
也不會超過某個范圍,所以一定存在某個不知道的常數
使得
的局部變動幅度不會超過它,Lipschitz連續條件得以滿足。具體在算法實現中,只需要每次更新完
后把它clip回這個范圍就可以了。
到此為止,我們可以構造一個含參數、最后一層不是非線性激活層的判別器網絡
,在限制
不超過某個范圍的條件下,使得
(公式15)
盡可能取到最大,此時就會近似真實分布與生成分布之間的Wasserstein距離(忽略常數倍數
)。注意原始GAN的判別器做的是真假二分類任務,所以最后一層是sigmoid,但是現在WGAN中的判別器
做的是近似擬合Wasserstein距離,屬於回歸任務,所以要把最后一層的sigmoid拿掉。
接下來生成器要近似地最小化Wasserstein距離,可以最小化,由於Wasserstein距離的優良性質,我們不需要擔心生成器梯度消失的問題。再考慮到
的第一項與生成器無關,就得到了WGAN的兩個loss。
(公式16,WGAN生成器loss函數)
(公式17,WGAN判別器loss函數)
公式15是公式17的反,可以指示訓練進程,其數值越小,表示真實分布與生成分布的Wasserstein距離越小,GAN訓練得越好。
WGAN與原始GAN第一種形式相比,只改了四點:
- 判別器最后一層去掉sigmoid
- 生成器和判別器的loss不取log
- 每次更新判別器的參數之后把它們的絕對值截斷到不超過一個固定常數c
- 不要用基於動量的優化算法(包括momentum和Adam),推薦RMSProp,SGD也行