第七章 生成對抗網絡
7.1 GAN基本概念
7.1.1 如何通俗理解GAN?
生成對抗網絡(GAN Generative adversarial network)自從2014年被Ian Goodfellow提出以來,掀起來了一股研究熱潮。GAN由生成器和判別器組成,生成器負責生成樣本,判別器負責判斷生成器生成的樣本是否為真。生成器要盡可能迷惑判別器,而判別器要盡可能區分生成器生成的樣本和真實樣本。
在GAN的原作[1]中,作者將生成器比喻為印假鈔票的犯罪分子,判別器則類比為警察。犯罪分子努力讓鈔票看起來逼真,警察則不斷提升對於假鈔的辨識能力。二者互相博弈,隨着時間的進行,都會越來越強。那么類比於圖像生成任務,生成器不斷生成盡可能逼真的假圖像。判別器則判斷圖像是否是真實的圖像,還是生成的圖像,二者不斷博弈優化。最終生成器生成的圖像使得判別器完全無法判別真假。
7.1.2 GAN的形式化表達
上述例子只是簡要介紹了一下GAN的思想,下面對於GAN做一個形式化的,更加具體的定義。通常情況下,無論是生成器還是判別器,我們都可以用神經網絡來實現。那么,我們可以把通俗化的定義用下面這個模型來表示:
上述模型左邊是生成器G,其輸入是\(z\),對於原始的GAN,\(z\)是由高斯分布隨機采樣得到的噪聲。噪聲\(z\)通過生成器得到了生成的假樣本。
生成的假樣本與真實樣本放到一起,被隨機抽取送入到判別器D,由判別器去區分輸入的樣本是生成的假樣本還是真實的樣本。整個過程簡單明了,生成對抗網絡中的“生成對抗”主要體現在生成器和判別器之間的對抗。
7.1.3 GAN的目標函數是什么?
對於上述神經網絡模型,如果想要學習其參數,首先需要一個目標函數。GAN的目標函數定義如下:
這個目標函數可以分為兩個部分來理解:
第一部分:判別器的優化通過$$\mathop {\max}\limits_D V(DG)$$實現,$$V(DG)$$為判別器的目標函數,其第一項 $${\rm E}{x\sim{p{data}(x)}}[\log D(x)]$$表示對於從真實數據分布 中采用的樣本 其被判別器判定為真實樣本概率的數學期望。對於真實數據分布 中采樣的樣本,其預測為正樣本的概率當然是越接近1越好。因此希望最大化這一項。第二項 $${\rm E}_{z\sim{p_z}(z)}[\log (1 - D(G(z)))]$$表示:對於從噪聲$$P_z(z)$$分布當中采樣得到的樣本,經過生成器生成之后得到的生成圖片,然后送入判別器,其預測概率的負對數的期望,這個值自然是越大越好,這個值越大, 越接近0,也就代表判別器越好。
第二部分:生成器的優化通過 $$\mathop {\min }\limits_G({\mathop {\max }\limits_D V(DG)})$$來實現。注意,生成器的目標不是 $$\mathop {\min }\limits_GV(DG)$$ ,即生成器不是最小化判別器的目標函數,二是最小化判別器目標函數的最大值,判別器目標函數的最大值代表的是真實數據分布與生成數據分布的JS散度(詳情可以參閱附錄的推導),JS散度可以度量分布的相似性,兩個分布越接近,JS散度越小。
7.1.4 GAN的目標函數和交叉熵有什么區別?
判別器目標函數寫成離散形式即為:
可以看出,這個目標函數和交叉熵是一致的,即判別器的目標是最小化交叉熵損失,生成器的目標是最小化生成數據分布和真實數據分布的JS散度。
7.1.5 GAN的Loss為什么降不下去?
對於很多GAN的初學者在實踐過程中可能會納悶,為什么GAN的Loss一直降不下去。GAN到底什么時候才算收斂?其實,作為一個訓練良好的GAN,其Loss就是降不下去的。衡量GAN是否訓練好了,只能由人肉眼去看生成的圖片質量是否好。不過,對於沒有一個很好的評價是否收斂指標的問題,也有許多學者做了一些研究,后文提及的WGAN就提出了一種新的Loss設計方式,較好的解決了難以判斷收斂性的問題。下面我們分析一下GAN的Loss為什么降不下去?
對於判別器而言,GAN的Loss如下:
從 \(\mathop {\min }\limits_G \mathop {\max }\limits_D V(DG)\) 可以看出,生成器和判別器的目的相反,也就是說兩個生成器網絡和判別器網絡互為對抗,此消彼長。不可能Loss一直降到一個收斂的狀態。
- 對於生成器,其Loss下降快,很有可能是判別器太弱,導致生成器很輕易的就"愚弄"了判別器。
- 對於判別器,其Loss下降快,意味着判別器很強,判別器很強則說明生成器生成的圖像不夠逼真,才使得判別器輕易判別,導致Loss下降很快。
也就是說,無論是判別器,還是生成器。loss的高低不能代表生成器的好壞。一個好的GAN網絡,其GAN Loss往往是不斷波動的。
看到這里可能有點讓人絕望,似乎判斷模型是否收斂就只能看生成的圖像質量了。實際上,后文探討的WGAN,提出了一種新的loss度量方式,讓我們可以通過一定的手段來判斷模型是否收斂。
7.1.6 生成式模型、判別式模型的區別?
對於機器學習模型,我們可以根據模型對數據的建模方式將模型分為兩大類,生成式模型和判別式模型。如果我們要訓練一個關於貓狗分類的模型,對於判別式模型,只需要學習二者差異即可。比如說貓的體型會比狗小一點。而生成式模型則不一樣,需要學習貓張什么樣,狗張什么樣。有了二者的長相以后,再根據長相去區分。具體而言:
-
生成式模型:由數據學習聯合概率分布P(XY) 然后由P(Y|X)=P(XY)/P(X)求出概率分布P(Y|X)作為預測的模型。該方法表示了給定輸入X與產生輸出Y的生成關系
-
判別式模型:由數據直接學習決策函數Y=f(X)或條件概率分布P(Y|X)作為預測模型,即判別模型。判別方法關心的是對於給定的輸入X,應該預測什么樣的輸出Y。
對於上述兩種模型,從文字上理解起來似乎不太直觀。我們舉個例子來闡述一下,對於性別分類問題,分別用不同的模型來做:
1)如果用生成式模型:可以訓練一個模型,學習輸入人的特征X和性別Y的關系。比如現在有下面一批數據:
Y(性別) | 0 | 1 | |
---|---|---|---|
X(特征) | 0 | 1/4 | 3/4 |
1 | 3/4 | 1/4 |
這個數據可以統計得到,即統計人的特征X=01….的時候,其類別為Y=01的概率。統計得到上述聯合概率分布P(X Y)后,可以學習一個模型,比如讓二維高斯分布去擬合上述數據,這樣就學習到了X,Y的聯合分布。在預測時,如果我們希望給一個輸入特征X,預測其類別,則需要通過貝葉斯公式得到條件概率分布才能進行推斷:
2)如果用判別式模型:可以訓練一個模型,輸入人的特征X,這些特征包括人的五官,穿衣風格,發型等。輸出則是對於性別的判斷概率,這個概率服從一個分布,分布的取值只有兩個,要么男,要么女,記這個分布為Y。這個過程學習了一個條件概率分布P(Y|X),即輸入特征X的分布已知條件下,Y的概率分布。
顯然,從上面的分析可以看出。判別式模型似乎要方便很多,因為生成式模型要學習一個X,Y的聯合分布往往需要很多數據,而判別式模型需要的數據則相對少,因為判別式模型更關注輸入特征的差異性。不過生成式既然使用了更多數據來生成聯合分布,自然也能夠提供更多的信息,現在有一個樣本(XY)其聯合概率P(XY)經過計算特別小,那么可以認為這個樣本是異常樣本。這種模型可以用來做outlier detection。
7.1.7 什么是mode collapsing?
某個模式(mode)出現大量重復樣本,例如:
上圖左側的藍色五角星表示真實樣本空間,黃色的是生成的。生成樣本缺乏多樣性,存在大量重復。比如上圖右側中,紅框里面人物反復出現。
7.1.8 如何解決mode collapsing?
方法一:針對目標函數的改進方法
為了避免前面提到的由於優化maxmin導致mode跳來跳去的問題,UnrolledGAN采用修改生成器loss來解決。具體而言,UnrolledGAN在更新生成器時更新k次生成器,參考的Loss不是某一次的loss,是判別器后面k次迭代的loss。注意,判別器后面k次迭代不更新自己的參數,只計算loss用於更新生成器。這種方式使得生成器考慮到了后面k次判別器的變化情況,避免在不同mode之間切換導致的模式崩潰問題。此處務必和迭代k次生成器,然后迭代1次判別器區分開[8]。DRAGAN則引入博弈論中的無后悔算法,改造其loss以解決mode collapse問題[9]。前文所述的EBGAN則是加入VAE的重構誤差以解決mode collapse。
方法二:針對網絡結構的改進方法
Multi agent diverse GAN(MAD-GAN)采用多個生成器,一個判別器以保障樣本生成的多樣性。具體結構如下:
相比於普通GAN,多了幾個生成器,且在loss設計的時候,加入一個正則項。正則項使用余弦距離懲罰三個生成器生成樣本的一致性。
MRGAN則添加了一個判別器來懲罰生成樣本的mode collapse問題。具體結構如下:
輸入樣本\(x\)通過一個Encoder編碼為隱變量\(E(x)\),然后隱變量被Generator重構,訓練時,Loss有三個。\(D_M\)和\(R\)(重構誤差)用於指導生成real-like的樣本。而\(D_D\)則對\(E(x)\)和\(z\)生成的樣本進行判別,顯然二者生成樣本都是fake samples,所以這個判別器主要用於判斷生成的樣本是否具有多樣性,即是否出現mode collapse。
方法三:Mini-batch Discrimination
Mini-batch discrimination在判別器的中間層建立一個mini-batch layer用於計算基於L1距離的樣本統計量,通過建立該統計量,實現了一個batch內某個樣本與其他樣本有多接近。這個信息可以被判別器利用到,從而甄別出哪些缺乏多樣性的樣本。對生成器而言,則要試圖生成具有多樣性的樣本。
7.2 GAN的生成能力評價
7.2.1 如何客觀評價GAN的生成能力?
最常見評價GAN的方法就是主觀評價。主觀評價需要花費大量人力物力,且存在以下問題:
-
評價帶有主管色彩,有些bad case沒看到很容易造成誤判
-
如果一個GAN過擬合了,那么生成的樣本會非常真實,人類主觀評價得分會非常高,可是這並不是一個好的GAN。
因此,就有許多學者提出了GAN的客觀評價方法。
7.2.2 Inception Score
對於一個在ImageNet訓練良好的GAN,其生成的樣本丟給Inception網絡進行測試的時候,得到的判別概率應該具有如下特性:
-
對於同一個類別的圖片,其輸出的概率分布應該趨向於一個脈沖分布。可以保證生成樣本的准確性。
-
對於所有類別,其輸出的概率分布應該趨向於一個均勻分布,這樣才不會出現mode dropping等,可以保證生成樣本的多樣性。
因此,可以設計如下指標:
根據前面分析,如果是一個訓練良好的GAN,\(p_M(y|x)\)趨近於脈沖分布,\(p_M(y)\)趨近於均勻分布。二者KL散度會很大。Inception Score自然就高。實際實驗表明,Inception Score和人的主觀判別趨向一致。IS的計算沒有用到真實數據,具體值取決於模型M的選擇。
特點:可以一定程度上衡量生成樣本的多樣性和准確性,但是無法檢測過擬合。Mode Score也是如此。不推薦在和ImageNet數據集差別比較大的數據上使用。
7.2.3 Mode Score
Mode Score作為Inception Score的改進版本,添加了關於生成樣本和真實樣本預測的概率分布相似性度量一項。具體公式如下:
7.2.4 Kernel MMD (Maximum Mean Discrepancy)
計算公式如下:
對於Kernel MMD值的計算,首先需要選擇一個核函數\(k\),這個核函數把樣本映射到再生希爾伯特空間(Reproducing Kernel Hilbert Space RKHS) ,RKHS相比於歐幾里得空間有許多優點,對於函數內積的計算是完備的。將上述公式展開即可得到下面的計算公式:
MMD值越小,兩個分布越接近。
特點:可以一定程度上衡量模型生成圖像的優劣性,計算代價小。推薦使用。
7.2.5 Wasserstein distance
Wasserstein distance在最優傳輸問題中通常也叫做推土機距離。這個距離的介紹在WGAN中有詳細討論。公式如下:
Wasserstein distance可以衡量兩個分布之間的相似性。距離越小,分布越相似。
特點:如果特征空間選擇合適,會有一定的效果。但是計算復雜度為\(O(n^3)\)太高
7.2.6 Fréchet Inception Distance (FID)
FID距離計算真實樣本,生成樣本在特征空間之間的距離。首先利用Inception網絡來提取特征,然后使用高斯模型對特征空間進行建模。根據高斯模型的均值和協方差來進行距離計算。具體公式如下:
\(\muC\)分別代表協方差和均值。
特點:盡管只計算了特征空間的前兩階矩,但是魯棒,且計算高效。
7.2.7 1-Nearest Neighbor classifier
使用留一法,結合1-NN分類器(別的也行)計算真實圖片,生成圖像的精度。如果二者接近,則精度接近50%,否則接近0%。對於GAN的評價問題,作者分別用正樣本的分類精度,生成樣本的分類精度去衡量生成樣本的真實性,多樣性。
- 對於真實樣本\(x_r\),進行1-NN分類的時候,如果生成的樣本越真實。則真實樣本空間\(\mathbb R\)將被生成的樣本\(x_g\)包圍。那么\(x_r\)的精度會很低。
- 對於生成的樣本\(x_g\),進行1-NN分類的時候,如果生成的樣本多樣性不足。由於生成的樣本聚在幾個mode,則\(x_g\)很容易就和\(x_r\)區分,導致精度會很高。
特點:理想的度量指標,且可以檢測過擬合。
7.2.8 其他評價方法
AIS,KDE方法也可以用於評價GAN,但這些方法不是model agnostic metrics。也就是說,這些評價指標的計算無法只利用:生成的樣本,真實樣本來計算。
7.3 其他常見的生成式模型有哪些?
7.3.1 什么是自回歸模型:pixelRNN與pixelCNN?
自回歸模型通過對圖像數據的概率分布\(p_{data}(x)\)進行顯式建模,並利用極大似然估計優化模型。具體如下:
上述公式很好理解,給定\(x_1x_2...x_{i-1}\)條件下,所有\(p(x_i)\)的概率乘起來就是圖像數據的分布。如果使用RNN對上述依然關系建模,就是pixelRNN。如果使用CNN,則是pixelCNN。具體如下[5]:
顯然,不論是對於pixelCNN還是pixelRNN,由於其像素值是一個個生成的,速度會很慢。語音領域大火的WaveNet就是一個典型的自回歸模型。
7.3.2 什么是VAE?
PixelCNN/RNN定義了一個易於處理的密度函數,我們可以直接優化訓練數據的似然;對於變分自編碼器我們將定義一個不易處理的密度函數,通過附加的隱變量\(z\)對密度函數進行建模。 VAE原理圖如下[6]:
在VAE中,真實樣本\(X\)通過神經網絡計算出均值方差(假設隱變量服從正太分布),然后通過采樣得到采樣變量\(Z\)並進行重構。VAE和GAN均是學習了隱變量\(z\)到真實數據分布的映射。但是和GAN不同的是:
- GAN的思路比較粗暴,使用一個判別器去度量分布轉換模塊(即生成器)生成分布與真實數據分布的距離。
- VAE則沒有那么直觀,VAE通過約束隱變量\(z\)服從標准正太分布以及重構數據實現了分布轉換映射\(X=G(z)\)
生成式模型對比
- 自回歸模型通過對概率分布顯式建模來生成數據
- VAE和GAN均是:假設隱變量\(z\)服從某種分布,並學習一個映射\(X=G(z)\),實現隱變量分布\(z\)與真實數據分布\(p_{data}(x)\)的轉換。
- GAN使用判別器去度量映射\(X=G(z)\)的優劣,而VAE通過隱變量\(z\)與標准正太分布的KL散度和重構誤差去度量。
7.4 GAN的改進與優化
7.4.1 如何生成指定類型的圖像——條件GAN
條件生成對抗網絡(CGAN Conditional Generative Adversarial Networks)作為一個GAN的改進,其一定程度上解決了GAN生成結果的不確定性。如果在Mnist數據集上訓練原始GAN,GAN生成的圖像是完全不確定的,具體生成的是數字1,還是2,還是幾,根本不可控。為了讓生成的數字可控,我們可以把數據集做一個切分,把數字0~9的數據集分別拆分開訓練9個模型,不過這樣太麻煩了,也不現實。因為數據集拆分不僅僅是分類麻煩,更主要在於,每一個類別的樣本少,拿去訓練GAN很有可能導致欠擬合。因此,CGAN就應運而生了。我們先看一下CGAN的網絡結構:
從網絡結構圖可以看到,對於生成器Generator,其輸入不僅僅是隨機噪聲的采樣z,還有欲生成圖像的標簽信息。比如對於mnist數據生成,就是一個one-hot向量,某一維度為1則表示生成某個數字的圖片。同樣地,判別器的輸入也包括樣本的標簽。這樣就使得判別器和生成器可以學習到樣本和標簽之間的聯系。Loss如下:
Loss設計和原始GAN基本一致,只不過生成器,判別器的輸入數據是一個條件分布。在具體編程實現時只需要對隨機噪聲采樣z和輸入條件y做一個級聯即可。
7.4.2 CNN與GAN——DCGAN
前面我們聊的GAN都是基於簡單的神經網絡構建的。可是對於視覺問題,如果使用原始的基於DNN的GAN,則會出現許多問題。如果輸入GAN的隨機噪聲為100維的隨機噪聲,輸出圖像為256x256大小。也就是說,要將100維的信息映射為65536維。如果單純用DNN來實現,那么整個模型參數會非常巨大,而且學習難度很大(低維度映射到高維度需要添加許多信息)。因此,DCGAN就出現了。具體而言,DCGAN將傳統GAN的生成器,判別器均采用GAN實現,且使用了一下tricks:
-
將pooling層convolutions替代,其中,在discriminator上用strided convolutions替代,在generator上用fractional-strided convolutions替代。
-
在generator和discriminator上都使用batchnorm。
-
移除全連接層,global pooling增加了模型的穩定性,但傷害了收斂速度。
-
在generator的除了輸出層外的所有層使用ReLU,輸出層采用tanh。
-
在discriminator的所有層上使用LeakyReLU。
網絡結構圖如下:
7.4.3 如何理解GAN中的輸入隨機噪聲?
為了了解輸入隨機噪聲每一個維度代表的含義,作者做了一個非常有趣的工作。即在隱空間上,假設知道哪幾個變量控制着某個物體,那么僵這幾個變量擋住是不是就可以將生成圖片中的某個物體消失?論文中的實驗是這樣的:首先,生成150張圖片,包括有窗戶的和沒有窗戶的,然后使用一個邏輯斯底回歸函數來進行分類,對於權重不為0的特征,認為它和窗戶有關。將其擋住,得到新的生成圖片,結果如下:
此外,將幾個輸入噪聲進行算數運算,可以得到語義上進行算數運算的非常有趣的結果。類似於word2vec。
7.4.4 GAN為什么容易訓練崩潰?
所謂GAN的訓練崩潰,指的是訓練過程中,生成器和判別器存在一方壓倒另一方的情況。
GAN原始判別器的Loss在判別器達到最優的時候,等價於最小化生成分布與真實分布之間的JS散度,由於隨機生成分布很難與真實分布有不可忽略的重疊以及JS散度的突變特性,使得生成器面臨梯度消失的問題;可是如果不把判別器訓練到最優,那么生成器優化的目標就失去了意義。因此需要我們小心的平衡二者,要把判別器訓練的不好也不壞才行。否則就會出現訓練崩潰,得不到想要的結果
7.4.5 WGAN如何解決訓練崩潰問題?
WGAN作者提出了使用Wasserstein距離,以解決GAN網絡訓練過程難以判斷收斂性的問題。Wasserstein距離定義如下:
通過最小化Wasserstein距離,得到了WGAN的Loss:
- WGAN生成器Loss:\(- {\rm E}_{x\sim{p_g}(x)}[f_w(x)]\)
- WGAN判別器Loss:\(L=-{\rm E}_{x\sim{p_{data}}(x)}[f_w(x)] + {\rm E}_{x\sim{p_g}(x)}[f_w(x)]\)
從公式上GAN似乎總是讓人摸不着頭腦,在代碼實現上來說,其實就以下幾點:
- 判別器最后一層去掉sigmoid
- 生成器和判別器的loss不取log
- 每次更新判別器的參數之后把它們的絕對值截斷到不超過一個固定常數c
7.4.6 WGAN-GP:帶有梯度正則的WGAN
實際實驗過程發現,WGAN沒有那么好用,主要原因在於WAGN進行梯度截斷。梯度截斷將導致判別網絡趨向於一個二值網絡,造成模型容量的下降。
於是作者提出使用梯度懲罰來替代梯度裁剪。公式如下:
由於上式是對每一個梯度進行懲罰,所以不適合使用BN,因為它會引入同個batch中不同樣本的相互依賴關系。如果需要的話,可以選擇Layer Normalization。實際訓練過程中,就可以通過Wasserstein距離來度量模型收斂程度了:
上圖縱坐標是Wasserstein距離,橫坐標是迭代次數。可以看出,隨着迭代的進行,Wasserstein距離趨於收斂,生成圖像也趨於穩定。
7.4.7 LSGAN
LSGAN(Least Squares GAN)這篇文章主要針對標准GAN的穩定性和圖片生成質量不高做了一個改進。作者將原始GAN的交叉熵損失采用最小二乘損失替代。LSGAN的Loss:
實際實現的時候非常簡單,最后一層去掉sigmoid,並且計算Loss的時候用平方誤差即可。之所以這么做,作者在原文給出了一張圖:

上面是作者給出的基於交叉熵損失以及最小二乘損失的Loss函數。橫坐標代表Loss函數的輸入,縱坐標代表輸出的Loss值。可以看出,隨着輸入的增大,sigmoid交叉熵損失很快趨於0,容易導致梯度飽和問題。如果使用右邊的Loss設計,則只在x=0點處飽和。因此使用LSGAN可以很好的解決交叉熵損失的問題。
7.4.8 如何盡量避免GAN的訓練崩潰問題?
-
歸一化圖像輸入到(-1,1)之間;Generator最后一層使用tanh激活函數
-
生成器的Loss采用:min (log 1-D)。因為原始的生成器Loss存在梯度消失問題;訓練生成器的時候,考慮反轉標簽,real=fake fake=real
-
不要在均勻分布上采樣,應該在高斯分布上采樣
-
一個Mini-batch里面必須只有正樣本,或者負樣本。不要混在一起;如果用不了Batch Norm,可以用Instance Norm
-
避免稀疏梯度,即少用ReLU,MaxPool。可以用LeakyReLU替代ReLU,下采樣可以用Average Pooling或者Convolution + stride替代。上采樣可以用PixelShuffle ConvTranspose2d + stride
-
平滑標簽或者給標簽加噪聲;平滑標簽,即對於正樣本,可以使用0.7-1.2的隨機數替代;對於負樣本,可以使用0-0.3的隨機數替代。 給標簽加噪聲:即訓練判別器的時候,隨機翻轉部分樣本的標簽。
-
如果可以,請用DCGAN或者混合模型:KL+GAN,VAE+GAN。
-
使用LSGAN,WGAN-GP
-
Generator使用Adam,Discriminator使用SGD
-
盡快發現錯誤;比如:判別器Loss為0,說明訓練失敗了;如果生成器Loss穩步下降,說明判別器沒發揮作用
-
不要試着通過比較生成器,判別器Loss的大小來解決訓練過程中的模型坍塌問題。比如:
While Loss D > Loss A:
Train D
While Loss A > Loss D:
Train A -
如果有標簽,請盡量利用標簽信息來訓練
-
給判別器的輸入加一些噪聲,給G的每一層加一些人工噪聲。
-
多訓練判別器,尤其是加了噪聲的時候
-
對於生成器,在訓練,測試的時候使用Dropout
7.3 GAN的應用(圖像翻譯)
7.3.1 什么是圖像翻譯?
GAN作為一種強有力的生成模型,其應用十分廣泛。最為常見的應用就是圖像翻譯。所謂圖像翻譯,指從一副圖像到另一副圖像的轉換。可以類比機器翻譯,一種語言轉換為另一種語言。常見的圖像翻譯任務有:
-
圖像去噪
-
圖像超分辨
-
圖像補全
-
風格遷移
-
...
本節將介紹一個經典的圖像翻譯網絡及其改進。圖像翻譯可以分為有監督圖像翻譯和無監督圖像翻譯:
-
有監督圖像翻譯:原始域與目標域存在一一對應數據
-
無監督圖像翻譯:原始域與目標域不存在一一對應數據
7.3.2 有監督圖像翻譯:pix2pix
在這篇paper里面,作者提出的框架十分簡潔優雅(好用的算法總是簡潔優雅的)。相比以往算法的大量專家知識,手工復雜的loss。這篇paper非常粗暴,使用CGAN處理了一系列的轉換問題。下面是一些轉換示例:
上面展示了許多有趣的結果,比如分割圖\(\longrightarrow\)街景圖,邊緣圖\(\longrightarrow\)真實圖。對於第一次看到的時候還是很驚艷的,那么這個是怎么做到的呢?我們可以設想一下,如果是我們,我們自己會如何設計這個網絡?
直觀的想法?
最直接的想法就是,設計一個CNN網絡,直接建立輸入-輸出的映射,就像圖像去噪問題一樣。可是對於上面的問題,這樣做會帶來一個問題。生成圖像質量不清晰。
拿左上角的分割圖\(\longrightarrow\)街景圖為例,語義分割圖的每個標簽比如“汽車”可能對應不同樣式,顏色的汽車。那么模型學習到的會是所有不同汽車的評均,這樣會造成模糊。
如何解決生成圖像的模糊問題?
這里作者想了一個辦法,即加入GAN的Loss去懲罰模型。GAN相比於傳統生成式模型可以較好的生成高分辨率圖片。思路也很簡單,在上述直觀想法的基礎上加入一個判別器,判斷輸入圖片是否是真實樣本。模型示意圖如下:
上圖模型和CGAN有所不同,但它是一個CGAN,只不過輸入只有一個,這個輸入就是條件信息。原始的CGAN需要輸入隨機噪聲,以及條件。這里之所有沒有輸入噪聲信息,是因為在實際實驗中,如果輸入噪聲和條件,噪聲往往被淹沒在條件C當中,所以這里直接省去了。
7.3.3 其他圖像翻譯的tricks
從上面兩點可以得到最終的Loss由兩部分構成:
-
輸出和標簽信息的L1 Loss。
-
GAN Loss
-
測試也使用Dropout,以使輸出多樣化
\[G^*=arg\mathop {\min }\limits_G \mathop {\max }\limits_D \Gamma_{cGAN}(GD)+\lambda\Gamma_{L1}(G) \]采用L1 Loss而不是L2 Loss的理由很簡單,L1 Loss相比於L2 Loss保邊緣(L2 Loss基於高斯先驗,L1 Loss基於拉普拉斯先驗)。 GAN Loss為LSGAN的最小二乘Loss,並使用PatchGAN(進一步保證生成圖像的清晰度)。PatchGAN將圖像換分成很多個Patch,並對每一個Patch使用判別器進行判別(實際代碼實現有更取巧的辦法),將所有Patch的Loss求平均作為最終的Loss。
7.3.4 如何生成高分辨率圖像和高分辨率視頻?
pix2pix提出了一個通用的圖像翻譯框架。對於高分辨率的圖像生成以及高分辨率的視頻生成,則需要利用更好的網絡結構以及更多的先驗只是。pix2pixHD提出了一種多尺度的生成器以及判別器等方式從而生成高分辨率圖像。Vid2Vid則在pix2pixHD的基礎上利用光流,時序約束生成了高分辨率視頻。
7.3.5 有監督的圖像翻譯的缺點?
許多圖像翻譯算法如前面提及的pix2pix系列,需要一一對應的圖像。可是在許多應用場景下,往往沒有這種一一對應的強監督信息。比如說以下一些應用場景:
以第一排第一幅圖為例,要找到這種一一配對的數據是不現實的。因此,無監督圖像翻譯算法就被引入了。
7.3.6 無監督圖像翻譯:CycleGAN
模型結構
總體思路如下,假設有兩個域的數據,記為A,B。對於上圖第一排第一幅圖A域就是普通的馬,B域就是斑馬。由於A->B的轉換缺乏監督信息,於是,作者提出采用如下方法進行轉換:
a. A->fake_B->rec_A
b. B->fake_A->rec_B
對於A域的所有圖像,學習一個網絡G_B,該網絡可以生成B。對於B域的所有圖像,也學習一個網絡G_A,該網絡可以生成G_B。
訓練過程分成兩步,首先對於A域的某張圖像,送入G_B生成fake_B,然后對fake_B送入G_A,得到重構后的A圖像rec_A。對於B域的某一張圖像也是類似。重構后的圖像rec_A/rec_B可以和原圖A/B做均方誤差,實現了有監督的訓練。此處值得注意的是A->fake_B(B->fake_A)和fake_A->rec_B(fake_B->rec_A)的網絡是一模一樣的。下圖是形象化的網絡結構圖:
cycleGAN的生成器采用U-Net,判別器采用LS-GAN。
Loss設計
總的Loss就是X域和Y域的GAN Loss,以及Cycle consistency loss:
整個過程End to end訓練,效果非常驚艷,利用這一框架可以完成非常多有趣的任務
7.3.7 多領域的無監督圖像翻譯:StarGAN
cycleGAN模型較好的解決了無監督圖像轉換問題,可是這種單一域的圖像轉換還存在一些問題:
-
要針對每一個域訓練一個模型,效率太低。舉例來說,我希望可以將橘子轉換為紅蘋果和青蘋果。對於cycleGAN而言,需要針對紅蘋果,青蘋果分別訓練一個模型。
-
對於每一個域都需要搜集大量數據,太麻煩。還是以橘子轉換為紅蘋果和青蘋果為例。不管是紅蘋果還是青蘋果,都是蘋果,只是顏色不一樣而已。這兩個任務信息是可以共享的,沒必要分別訓練兩個模型。而且針對紅蘋果,青蘋果分別取搜集大量數據太費事。
starGAN則提出了一個多領域的無監督圖像翻譯框架,實現了多個領域的圖像轉換,且對於不同領域的數據可以混合在一起訓練,提高了數據利用率
7.4 GAN的應用(文本生成)
7.4.1 GAN為什么不適合文本任務?
GAN在2014年被提出之后,在圖像生成領域取得了廣泛的研究應用。然后在文本領域卻一直沒有很驚艷的效果。主要在於文本數據是離散數據,而GAN在應用於離散數據時存在以下幾個問題:
-
GAN的生成器梯度來源於判別器對於正負樣本的判別。然而,對於文本生成問題,RNN輸出的是一個概率序列,然后取argmax。這會導致生成器Loss不可導。還可以站在另一個角度理解,由於是argmax,所以參數更新一點點並不會改變argmax的結果,這也使得GAN不適合離散數據。
-
GAN只能評估整個序列的loss,但是無法評估半句話,或者是當前生成單詞對后續結果好壞的影響。
-
如果不加argmax,那么由於生成器生成的都是浮點數值,而ground truth都是one-hot encoding,那么判別器只要判別生成的結果是不是0/1序列組成的就可以了。這容易導致訓練崩潰。
7.4.2 seqGAN用於文本生成
seqGAN在GAN的框架下,結合強化學習來做文本生成。 模型示意圖如下:
在文本生成任務,seqGAN相比較於普通GAN區別在以下幾點:
- 生成器不取argmax。
- 每生成一個單詞,則根據當前的詞語序列進行蒙特卡洛采樣生成完成的句子。然后將句子送入判別器計算reward。
- 根據得到的reward進行策略梯度下降優化模型。
7.5 GAN在其他領域的應用
7.5.1 數據增廣
GAN的良好生成特性近年來也開始被用於數據增廣。以行人重識別為例,有許多GAN用於數據增廣的工作[1-4]。行人重識別問題一個難點在於不同攝像頭下拍攝的人物環境,角度差別非常大,導致存在較大的Domain gap。因此,可以考慮使用GAN來產生不同攝像頭下的數據進行數據增廣。以論文[1]為例,本篇paper提出了一個cycleGAN用於數據增廣的方法。具體模型結構如下:
對於每一對攝像頭都訓練一個cycleGAN,這樣就可以實現將一個攝像頭下的數據轉換成另一個攝像頭下的數據,但是內容(人物)保持不變。
7.5.2 圖像超分辨與圖像補全
圖像超分辨與補全均可以作為圖像翻譯問題,該類問題的處理辦法也大都是訓練一個端到端的網絡,輸入是原始圖片,輸出是超分辨率后的圖片,或者是補全后的圖片。文獻[5]利用GAN作為判別器,使得超分辨率模型輸出的圖片更加清晰,更符合人眼主管感受。日本早稻田大學研究人員[6]提出一種全局+局部一致性的GAN實現圖像補全,使得修復后的圖像不僅細節清晰,且具有整體一致性。
7.5.3 語音領域
相比於圖像領域遍地開花,GAN在語音領域則應用相對少了很多。這里零碎的找一些GAN在語音領域進行應用的例子作為介紹。文獻[7]提出了一種音頻去噪的SEGAN,緩解了傳統方法支持噪聲種類稀少,泛化能力不強的問題。Donahue利用GAN進行語音增強,提升了ASR系統的識別率。