GAN簡介
一、什么是GAN
GAN是一類由兩個同時訓練的模型組成的機器學習技術:一個是生成器,訓練其生成偽數據:另一個是判別器,訓練其從真實數據中識別偽數據。
- 生成(generative)一詞預示着模型的總目標——生成新數據。GAN通過學習生成的數據取決於所選擇的訓練集,例如,如果我們想用GAN合成一幅看起來像達・芬奇作品的畫作,就得用達·芬奇的作品作為訓練集。
- 對抗(adversarial)一詞則是指構成GAN框架的兩個動態博弈、競爭的模型:生成器和判別器。生成器的目標是生成與訓練集中的真實數據無法區分的偽數據——在剛才的示例中這就意味着能夠創作出和達・芬奇畫作一樣的繪畫作品。判別器的目標是能辨別出哪些是來自訓練集的真實數據,哪些是來自生成器的偽數據。也就是說,判別器充當着藝術品鑒定專家的角色,評估被認為是達·芬奇畫作的作品的真實性。這兩個網絡不斷新地“斗智斗勇”,試圖互相欺騙:生成器生成的偽數據越逼真,判別器辨別真偽的能力就要越強。
- 網絡(network)一詞表示最常用於生成器和判別器的一類機器學習模型:神經網絡。依據GAN實現的復雜程度,這些網絡包括從最簡單的前饋神經網絡到卷積神經網絡以及更為復雜的變體。
二、GAN是如何工作的
還有一個比喻經常用來形容GAN,假幣制造者(生成器)和試圖逮捕他的偵探(判別器)——假鈔看起來越真實,就需要越好的偵探才能辨別出他們,反之亦然。
用更專業的術語來說,生成器的目標是生成能最大程度有效捕捉訓練集特征的樣本,以至於生成出的樣本與訓練數據別無二致。生成器可以看作一個反向的對象識別模型——對象識別算法學習圖像中的模式,以期能夠識別圖像的內容。生成器不是去識別這些模式,而是要學會從頭開始學習創建它們,實際上,生成器的輸入通常不過是一個隨機數向量。
生成器通過從判別器的分類結果中接收反饋來不斷學習。判別器的目標是判斷一個特定的樣本是真的(來自訓練集)還是假的(由生成器生成)。因此,每當判別器“上當受騙”將假的圖像錯判為真實圖像時,生成器就會知道自己做得很好:相反,每當判別器正確地將生成器生成的假圖像辨別出來時,生成器就會收到需要繼續改進的反饋。
判別器也會不斷地改善,像其他分類器一樣,它會從預測標簽與真實標簽(真或假)之間的偏差中學習。所以隨着生成器能更好地生成更逼真的數據,判別器也能更好地辨別真假數據,兩個網絡都在同時不斷地改進着。
表1.1 生成器和判別器的關鍵信息
生成器 | 判別器 | |
---|---|---|
輸入 | 一個隨機數向量 | 判別器的輸入有兩個來源:來自訓練集的真實樣本和來自生成器的偽樣本 |
輸出 | 盡可能令人信服的偽樣本 | 預測輸入樣本是真實的概率 |
目標 | 生成與訓練集中數據別無二致的偽數據 | 區分來自生成器的偽樣本和來自訓練集的真實樣本 |
三、GAN的結構
假定我們的目標是教GAN生成逼真的手寫數字。GAN的核心結構如下圖所示。
讓我們看看其中的細節。
(1) 訓練數據集——包含真實樣本的數據集,是我們希望生成器能以近乎完美的質量去學習模仿的數據。在這個示例中,數據集由手寫數字的圖像組成。該數據集用作判別器網絡的輸入(\(x\))。
(2) 隨機噪聲向量——生成器網絡的初始輸入(z)。此輸入是一個由隨機數組成的向量,生成器將其用作合成偽樣本的起點。
(3) 生成器網絡——生成器接收隨機數向量(z)作為輸入並輸出偽樣本(x*)。它的目標是生成和訓練數據集中的真實樣本別無二致的偽樣本。(卷積神經網絡)
(4) 判別器網絡——判別器接收來自訓練集的真實樣本(x)或生成器生成的偽樣本(x*)作為輸入。對每個樣本,判別器會進行判定並輸出其為真實的概率。(反卷積神經網絡)
(5) 迭代訓練/調優——對於每個判別器的預測,我們會衡量它效果有多好——就像對常規的分類器一樣——並用結果反向傳播去迭代優化判別器網絡和生成網絡。
- 更新判別器的權重和偏置,以最大化其分類的精確度(最大化正確預測的概率:x為真,x*為假)。
- 更新生成器的權重和偏置,以最大化判別器將x*誤判為真的概率。
3.1 GAN的訓練
為了了解GAN各組件的用途,我們首先介紹GAN的訓練算法,其次演示訓練過程,以便我們能夠可以清楚的看到實際的框架圖。
GAN訓練算法
對於每次訓練迭代,執行如下操作。
(1)訓練判別器
a.從訓練集中隨機抽取真實樣本x。
b.獲取一個新的隨機噪聲向量z,用生成器網絡合成一個偽樣本x*。
c.用判別器網絡對x和x*進行分類。
d.計算分類誤差並反向傳播總誤差以更新判別器的可訓練參數,尋求最小化分類誤差。
(2)訓練生成器
a.獲取一個新的隨機噪聲向量z,用生成器網絡合成一個偽樣本x*。
b.用判別器網絡對x*進行分類。
c.計算分類誤差並反向傳播以更新生成器的可訓練參數,尋求最大化判別器誤差。
結束
GAN訓練過程可視化
GAN的訓練算法如下圖所示,其中的字母表示GAN訓練算法中的步驟。
子程序圖示說明
(1)訓練判別器
a. 從訓練集中隨機抽取真實樣本x。
b. 獲取一個新的隨機噪聲向量z,用生成器網結合成一個偽樣本x*。
c. 用判別器網絡對x和x*進行分類。
d. 計算分類誤差並反向傳播總誤差以更新判別器的權重和偏置,尋求最小化分類誤差。
(2)訓練生成器
a. 獲取一個新的隨機噪聲向量z,用生成器網絡合成一個偽樣本x*。
b.用判別器網絡對x*進行分類。
c.計算分類誤差並反向傳播以更新生成器的可訓練參數,尋求最大化判別器誤差。
3.2 達到平衡
對於一般的神經網絡,我們通常有一個明確的目標去實現以及用來衡量效果。例如,當訓練一個分類器時,我們度量在訓練集和驗證集上的分類誤差,一旦發現驗證集開始變壞,就停止進程(為了避免過擬合)。在GAN結構中,判別器網絡和生成器網絡有兩個互為競爭對手的目標:一個網絡越好,另一個就越差。那么我們如何決定何時停止進程呢?
這其實是一個零和博弈問題,即一方的收益等於另一方的損失。當一方提高一定程度時,另一方會惡化同樣的程度。零和博弈都有一個納什均衡點,那就是任何一方無論怎么努力都不能改善他們的處境或結果。
當滿足以下條件時,GAN達到納什均衡點。
(1)生成器生成的偽樣本與訓練集中的真實數據別無二致。
(2)判別器所能做的只是隨機猜測一個特定的樣本是真的還是假的(也就是說,猜測一個示例為真的概率是50%)。
讓我們來解釋為何會出現這種情況。當每一個偽樣本(x*)與來自訓練集的真實樣本無法區分時,判別器用任何手段都無法區分它們。因為判別器接收到的樣本有一半是真的,半是假的,所以它所能做的最有用的事情就是拋硬幣,以50%的概率把每個樣本分為真和假。
同樣,生成器也處於這樣一個點上,它不能從進一步的調優中獲得任何提高了。因為生成器生成的樣本早已和真實樣本無法區分了,以至於對隨機噪聲向量(z)轉換為偽樣本(x)的過程做出哪怕一丁點兒改變,也可能給判別器提供從真實樣本中辨別出偽樣本的機會,從而使生成器變得更糟。
當達到納什均衡時,GAN就被認為是收斂的。這是一個棘手的問題,在實踐中,由於在非凸博弈中實現收斂所涉及的巨大復雜性,幾乎不可能達到GAN的納什均衡。實際上,GAN的收斂仍是GAN研究中最重要的開放性問題之一。
四、小結
- GAN是一種利用兩個神經網絡之間的動態競爭來合成真實數據樣本的深度學習技術,例如能合成具有照片級真實感的虛假圖像。構成一個完整GAN的兩個網絡如下:
- 生成器,其目標是通過生成與訓練數據集別無二致的數據來欺騙判別器;
- 判別器,其目標是正確區分來自訓練數據集的真實數據和由生成器生成的偽數據。