本篇文章為Goodfellow提出的GAN算法的開山之作"Generative Adversarial Nets"的學習筆記,若有錯誤,歡迎留言或私信指正。
1. Introduction
GAN模型解決的問題
作者在首段指出了本課題的意義——能夠避免深度生成模型中的兩個局限性:
(1)最大似然估計等相關策略中難以處理的概率計算;
(2)在生成環境中難以利用分段線性單元的優勢。
PS:深度生成模型是為了從原始的樣本數據中模擬出數據分布,進而產生符合這一分布的新的樣本。
GAN模型的構成
GAN模型主要分為兩個部分:生成模型\(G\)(generative model)和判別模型\(D\)(discriminative model)。作者將這兩個部分的關系類比為假幣制造商和警察。判別模型(警察)來判斷一個樣本究竟是來自於數據分布還是模型分布,而生成模型(假幣制造商)則是為了生成假的樣本來騙過判別模型(警察)。這樣產生的競爭驅使兩方都不斷更新自己的方法,直到假樣本與真樣本完全無法區分。框架中的生成模型通過將隨機噪聲輸入多層感知機進而得到假樣本,而判別模型則是將樣本輸入多層感知機來判斷樣本是否為真實樣本。作者指出,同時訓練這兩個模型的方法是使用反向傳播算法(backpropagation algorithm)和丟棄算法(應該是“隨機失活算法”,之前表述有所錯誤)(dropout algorithm)。
ps:多層感知機(Multi Layer Perceptron) 為如下圖所示的包括至少一個隱藏層(除去一個輸入層和一個輸出層以外)的多層神經元網絡。
2. Related work
作者在這一部分介紹了有關深度生成網絡的相關工作,由於還沒有系統掌握深度學習,暫且跳過。。。
3. Adversarial nets
作者定義了兩個函數:
- \(G\left(\boldsymbol{z} ; \theta_{g}\right)\)表示從噪音\(\boldsymbol{z}\)到數據樣本空間的映射函數,其中\(\theta_{g}\)表示多層感知機的參數。
- \(D(\boldsymbol{x};\theta_d)\)表示對於輸入樣本\(\boldsymbol{x}\)輸出判斷為真實數據的概率(為一個標量)
GAN的目標
作者將GAN的目標定義為下面這個優化公式:
這個公式由兩部分構成,前一部分 \(\mathbb{E}_{\mathbf{x}\sim p_{data}(x)}[\log D(x)]\) 可以理解為判別模型正確判定真實數據的對數期望,后一部分\(\mathbb{E}_{\mathbf{z}\sim p_z}[\log (1-D(G(z)))]\)可以理解為判別模型正確識別假樣本的對數期望。
相對熵?
后面重新看這個公式的時候突然想到一個問題:為什么公式中求期望時要在概率值的外面再套一個log函數呢?
還沒有完全想明白,感覺有點像是交叉熵和相對熵的概念,但是又無法將上述公式和相對熵的公式定義一一對應起來。
可以先可以參考這個鏈接:機器學習中的各種“熵”(這個博客頁面是真的卡)
后面作者在Theoretical Results部分果然用到了交叉熵的概念,用來證明這個公式的目標就是實現\(p_g=p_{data}\)
圖形理解
下面這張圖我看了一晚上沒看懂,今天早上才明白是什么意思。
圖中藍色的虛線表示判別模型的概率分布,黑色的點線表示原始數據的概率分布,綠色實線表示生成模型的生成概率分布。
- 圖(a)中表示訓練還未開始或剛開始一段時間時,判別模型和生成模型都還沒有經過大量的訓練,因此判別概率分布還有些波動,生成分布距離真實數據生成分布還有不小的距離。
- 圖(b)表示經過一段時間的訓練后,判別模型可以較好的判別出原始樣本和生成樣本,藍色虛線的高度表示當前位置對應橫坐標的樣本為真實數據樣本的概率,越高表示為真實樣本的概率越大。
- 圖(c)表示繼續訓練一段時間之后,原始樣本和生成樣本更加接近,判別模型還是有相對不錯的判別效果。
- 圖(d)表示經過足夠的訓練之后,原始樣本與生成樣本的概率分布特征基本一致,判別模型失去判別效果。
算法1
算法1原文描述
算法1中文python偽代碼描述(這是個什么鬼東西)
for i in range(訓練的迭代次數)):
for j in range(k步):
從噪音分布中取出m個噪音樣本
從數據分布中取出m個樣本
利用隨機梯度上升法更新判別器D
從噪音分布中取出m個噪音樣本
利用隨機梯度下降法更新生成器G
4. Theoretical Results
全局最優\(p_g=p_{data}\)
作者給出了第一個命題:當生成模型\(G\)固定時,最優判別器為:$$D^*G(x)=\frac{p{data}(x)}{p_{data}(x)+p_g(x)}$$
並且作者也給出了證明,證明的思路很簡單,形如\(y \rightarrow a \log (y)+b \log (1-y)\)的式子在\([0,1]\)區間內有固定的最大值為\(\frac{a}{a+b}\)。
這樣一來,原公式就可以替換成下式:
再利用KL散度(相對熵)公式:
可以將剛剛的式子替換為
進一步可以利用JS散度公式進行替換:
由JS散度公式的性質——即取值范圍為\([0,1]\)因此可知\(C(G)\)的最小值為\(-\log (4)\),此時JS散度公式取0,並且此時唯一解為\(p_g=p_{data}\)
5. Experiments
6. Advantages and disadvantages
作者指出,生成式對抗網絡有以下四個優勢:
- 根據實驗結果來看,比其他模型產生了更銳利、清晰的樣本。
- 生成式對抗網絡能夠被用來訓練任何一種生成器網絡。
- 不需要設計遵循任何種類的因式分解的模型。
- 無需利用馬爾可夫鏈反復采樣,回避了棘手的近似計算的概率問題。
GAN目前存在的主要問題是:
- 解決不收斂的問題