什么是GAN
生成對抗網絡(GAN)是一種由生成網絡和判別網絡組成的深度神經網絡架構。通過在生成和判別之間的多次循環,兩個網絡相互對抗,繼而兩者性能逐步提升。
生成網絡
生成網絡(Generator Network)借助現有的數據來生成新數據,比如使用從隨機產生的一組數字向量(稱為潛在空間 latent space)中生成數據(圖像、音頻等)。所以在構建的時候你首先要明確生成目標,然后將生成結果交給判別網絡做下一步的處理。
判別網絡
判別網絡(Discriminator Network)試圖區分接收的數據屬於真實數據還是由生成網絡生成的數據,它需要基於實現定於的類別對其進行分類。通常來說,GAN使用在二分類問題上。判別結果為0~1之間的數字,用來表示本次輸入被認為是真實數據的可能性。當判定結果為1時,則認為它來自真實數據,反之則屬於生成數據。
訓練過程
兩個網絡之間相互競爭,在訓練過程中,通常是固定其中一方的參數不改變,然后提升另一方的性能,循環往復。比如說大部分書籍或者視頻中提到的藝術品鑒賞。
我們使用下圖來簡要說明訓練過程:

- 將給定的D維噪聲向量\(Z\) 作為輸入投入生成網絡中進行訓練,試圖生成形似藝術品實物的作品\(G(z)\)
- 判別網絡主要任務是一件藝術品是真品還是贗品,所以它的輸入包括了真實圖片和模擬圖片。其判定結果將作為一個label輸出。
- 生成網絡在不斷迭代中生成看起來更加真實的藝術品,試圖騙過判別網絡,讓它相信這些生成的贗品是真品。
- 判別網絡不斷優化區分真假的標准,試圖識別出每一張由生成網絡制造的贗品。
- 在每一輪的迭代中,它們都會把自己所做的調整中的成功嘗試反饋給對方。
- 最終在判別網絡的幫助下,生成網絡已經訓練到可以讓判別網絡無法正確判斷真品和贗品的區分時,就可以停止迭代過程了。此時網絡進入到一種“納什均衡”的狀態。
GAN的具體架構
GAN主要由兩部分構成,分別是:生成網絡和判別網絡。它們可以是任何一類神經網絡,比如說普通的人工神經網絡、卷積神經網絡、循環神經網絡等。而判別網絡另外需要一些全連接層,然后以分類器收尾。
下面以一個簡單的GAN結構為例:
生成網絡架構
本次GAN的生成網絡是一個5層的簡單前饋神經網絡,輸出層、3個隱藏層及一個輸出層。如下圖所示:

該前饋神經網絡通過正向傳播處理信息的過程如下:
- 輸入層從正態分布采樣一個100維的向量,不做任何修改,直接傳遞給第一個隱藏層。
- 3個隱藏層分別是具有500、500和784個單元的全連接層。第1個隱藏層將一個形狀為(batch_size, 100)的張量變換成(batch_size, 500)。
- 第2個隱藏層(基於上一層輸出的結果)將張量形狀變換為(batch_size, 500)。
- 第3個隱藏層繼續將張量形狀變換為(batch_size, 784)。
- 最后的輸出層將張量的形狀從(batch_size, 784)變換為(batch_size, 28, 28)。這意味着該神經網絡會生成一批圖像,其中每張圖像的形狀為(28, 28)。
判別網絡架構
判別網絡是一個5層的前饋型神經網絡,包括一個輸入層、一個輸出層以及3個全連接層。其實它就是一個分類器,其輸出結果的實際意義是判定該輸入屬於哪個類別。

判別網絡在訓練過程中利用正向傳播來處理數據的過程如下:
- 首先讀取一個形狀為28×28的張量輸入。
- 輸入層接收形狀為(batch_size, 28, 28)的輸入張量,不做任何修改,直接傳遞給第一個隱藏層(扁平化層)
- 扁平化層將該張量轉換成784維,然后將其傳遞給第一個隱藏全連接層。經過前兩個隱藏層的處理,張量轉換成了500維。
- 最后一層是輸出層,也是全連接層,只有一個單元(神經元),使用sigmoid激活函數。它只輸出0或1:輸出0意味着判別網絡認為輸入圖像是假的;輸出1意味着判別網絡認為輸入圖像是真的。
GAN的訓練過程
在GAN的訓練中,有個問題我困惑了非常久:為什么輸入數據是從噪聲數據分布中隨機采樣出來的?
后來發現,因為簡單模型無法模擬概率分布函數\(P_{model}(x, \theta)\),所以需要使用神經網絡來實現,即經過生成網絡中的神經網絡后,可以映射成為幾乎任何的復雜分布,所以最開始我們可以使用高斯分布下的數值來模擬(之后再調整\(\theta\)參數)。
所以原來的\(P_{model}(x, \theta)\)可以被繞過,變成\(x = G(z, \theta_g)\),且$z $是符合噪聲分布的一組向量。

GAN的基本訓練過程如下圖所示:

- 初始化判別網絡的參數 \(\theta_{d}\) 和生成器G的參數 \(\theta_g\)。
- 從真實樣本中采樣 m個樣本 $[x^1, x^2, ... x^m $;從噪聲數據分布中采樣 m個噪聲樣本 \([z^1, z^2, ...,z^m ]\)並通過生成網絡獲取m個生成樣本\(\widetilde x^1, \widetilde x^2, ..., \widetilde x^m\)
- 固定生成網絡參數,訓練判別網絡,使其盡可能好地准確判別真實樣本和生成樣本,盡可能大得區分正確樣本和生成的樣本。
- 循環k次更新判別器之后,使用較小的學習率來更新一次生成器的參數。固定判別網絡參數,使其盡可能地減小生成樣本與真實樣本之間的差距,相當於使得判別網絡分辨的准確率降低。
- 多次更新迭代之后,最終理想情況是使得判別器判別不出樣本來自於生成器的輸出還是真實的輸出。亦即最終樣本判別概率均為0.5。
之所以要訓練k次判別器,再訓練生成器,是因為要先擁有一個好的判別器,使得能夠教好地區分出真實樣本和生成樣本之后,才好更為准確地對生成器進行更新。更直觀的理解可以參考下圖:

圖中的黑色虛線表示真實的樣本的分布情況,藍色虛線表示判別器判別概率的分布情況,綠色實線表示生成樣本的分布。 \(Z\) 表示噪聲, \(Z\) 到 \(x\) 表示通過生成器之后的分布的映射情況。
我們的目標是使用生成樣本分布(綠色實線)去擬合真實的樣本分布(黑色虛線),來達到生成以假亂真樣本的目的。
-
可以看到在(a)狀態處於最初始的狀態的時候,生成器生成的分布和真實分布區別較大,並且判別器判別出樣本的概率不是很穩定,因此會先訓練判別器來更好地分辨樣本。
-
通過多次訓練判別器來達到(b)樣本狀態,此時判別樣本區分得非常顯著和良好。然后再對生成器進行訓練。
-
訓練生成器之后達到(c)樣本狀態,此時生成器分布相比之前,逼近了真實樣本分布。
-
經過多次反復訓練迭代之后,最終希望能夠達到(d)狀態,生成樣本分布擬合於真實樣本分布,並且判別器分辨不出樣本是生成的還是真實的(判別概率均為0.5)。也就是說我們這個時候就可以生成出非常真實的樣本啦,目的達到。
GAN的數學原理
概率分布函數
1. 真實數據的概率分配函數\(P_{data}(x):\)
對於真實訓練數據集,將定義一個概率分布函數\(P_{data}(x)\),其中\(x\)是一個高維向量,也就相當於真實數據集中的某個數據點。關於概率分布函數\(P_{data}(x)\)到底是個什么東西?接下來以二次元人臉生成舉例。

根據李宏毅教授的解釋,在高維空間中,僅有一部分的點集能夠正確表示人臉。所以現在我們將\(x\)定義成一個二維向量以方便展示,圖中的藍色區域則可以表示為\(P_{data}(x)\)。
那么可以發現,在藍色區域中隨機取樣的兩個點可以生成出比較清晰的人臉,因此說這兩個樣本點具有high probability。所以相應的,藍色區域外的點就是具有low probability。
我感覺可以這么理解:某個從真實訓練集中抽取的樣本點\(x\),比較大的概率是來自於圖中的藍色區域。
2. 生成模型的概率分配函數\(P_{model}(x; \theta)\):
為了逼近真實數據的概率分布,我們也會為生成模型定義一個概率分布函數\(P_{model}(x; \theta)\),這個分布函數是通過參數變量\(\theta\)定義的,在實際的計算過程中,我們希望改變該參數,從而使得\(P_{model}(x; \theta)\)逼近\(P_{data}(x)\)。
但是,實際上我們並不知道\(P_{data}(x)\)的形式,所以,逼近的唯一方式就是從真實數據中采樣大量的數據,再借助這些真實樣本,來計算生成模型的概率分布。
綜上所述,生成網絡的目標就是以真實采樣數據\(\lbrace x^1, x^2, .. \rbrace\) 為我們的訓練集,並通過不斷地參數調整,使得該模型接收到一組向量時,能夠輸出接近真實的數據結果。
吃了沒有概率論基礎的虧... 也就是說通過調整參數,使得采樣點盡可能地在生成模型的概率分配函數上。
最大似然估計
極大似然估計提供了一種給定觀察數據來評估模型參數的方法,即:“模型已定,參數未知”。
比如我們要統計全國人口的身高,首先假設這個身高服從服從正態分布,但是該分布的均值與方差未知。我們沒有人力與物力去統計全國每個人的身高,但是可以使用采樣的方法:獲取部分人的身高,然后通過最大似然估計來獲取上述假設中的正態分布的均值與方差。
極大似然估計中采樣需滿足一個很重要的假設:所有的采樣都是獨立同分布的。
假如已知某個隨機樣本滿足某種概率分布,但是其中具體的參數不清楚,參數估計就是通過若干次試驗,觀察其結果,利用結果推出參數的大概值。
極大似然原理是說如果我們已知某個參數能使這個樣本出現的概率最大,我們當然不會再去選擇其他小概率的樣本,所以干脆就把這個參數作為估計的真實值。
直觀來看,一個隨機試驗如果有若干個可能的結果A,B,C,…N,那么如果在僅僅作一次的試驗中,結果A出現,則一般認為試驗條件對A出現有利,也即A出現的概率很大。而事件A發生的概率與參數\(\theta\)相關,A發生的概率記為P(A,\(\theta\)),則θθ的估計應該使上述概率達到最大,這樣的\(\theta\)顧名思義就稱為極大似然估計。
所以我們可以根據之前在真實數據分布中取樣的\(\lbrace x^1, x^2, .. , x^m \rbrace\) 這m個樣本數據,來計算它們在生成模型中的概率如下,最大似然估計的目標是通過這個概率的式子,尋找出一個\(\theta^*\)使得\(L\)最大化。
這樣做的實際含義是,在給出真實訓練集的前提下,我們希望生成模型能夠在這些數據上具備最大的概率,這樣才說明我們的生成模型在給出的訓練集上能夠逼近真實數據的概率分布。
KL散度
KL散度,也稱相對熵,用於判定兩個概率分布之間的相似度。它可以測量一個概率分布\(p\)相對於另一個概率分布\(q\)的偏離。如下公式用於計算兩個概率分布\(p(x)\)和\(q(x)\)之間的KL散度:
如果\(p(x)\)和\(q(x)\)處處相等,則此時KL散度為0,達到最小值。
由於KL散度具有不對稱性,因此不用於測量兩個概率分布之間的距離,因此也不用作距離的度量(metric)。
JS散度
JS散度,也稱信息半徑(information radius, IRaD)或者平均值總偏離(total divergence to the average),是測量兩個概率分布之間相似度的另一種方法。它基於KL散度,但具有對稱性,可用於測量兩個概率分布之間的距離。對JS散度開平方即可得到JS距離,所以它是一種距離度量。
計算兩個概率分布p和q之間JS散度的公式如下。
其中,(p+q)/2是p和q的中點測度,\(D_{KL}\)是KL散度。
公式數學推導
判別網絡:
對於判別網絡,假設其輸入數據為\(x\),使用\(D(x)\)來表示該樣本被判斷為正樣本的概率。則有:
- 如果\(x\)來自\(P_{data}\),那么\(D(x)\)要越大越好,可以用\(log(D(x)) \uparrow\)表示。
- 如果\(x\)來自於\(P_{model}\),那么\(D(x)\)越小越好,而此時的\(x = G(z)\),帶入得到\(D(G(z))\),進而表示為\(log[1-D(G(z))] \uparrow\)。
因此需要最大化此公式:
生成網絡:
對於生成網絡,目標是使得自己的輸出結果被判定為正樣本的概率越高越好:
- 如果\(x\)來自於\(P_{model}\),那么\(D(x)\)越大越好,而此時的\(x = G(z)\),帶入得到\(D(G(z))\),進而表示為\(log[1-D(G(z))] \downarrow\)。
因此需要最小化此公式:
最后得到我們的總目標實際上是:
全局最優解:
固定生成網絡參數,求\(max_DV(D,G)\):
我們現在的目標是希望尋找一個D使得V最大,我們希望對於積分中的項\(f(x) =p_{data}(x)logD(x)+p_{model}(x)log(1-D(x))\),無論x取何值都能最大。其中,我們已知\(p_{data}\)是固定的,之前我們也假定生成器G固定,所以\(P_{model}\)也是固定的,所以我們可以很容易地求出D以使得f(x)最大。
我們假設x固定,f(x)對D(x)求導等於零,下面是求解D(x)的推導。
那么將\(D_G^*\)代入后,有:
然后轉換為前面介紹的JS散度:
所以當\(p_{data} = \frac{p_{data} + p_{model}}{2} = p_{model}\)時,“\(=\)”成立,故最后得到\(D^* = \frac{1}{2}\)。
這也證明了,通過上述min max的博弈過程,理想情況下會收斂於生成分布擬合於真實分布。
真不知道這些公式之后我還有沒有可能記得住。。
Pytorch代碼實現
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.utils import save_image
import os
import torch.nn.functional as F
# Hyper Parameters
batch_size = 100
epochs = 300
latent_size = 100
hidden_size = 256
image_size = 784
RealImage = torchvision.datasets.MNIST(
root='./mnist/',
train=True,
transform=torchvision.transforms.ToTensor(), # 轉換PIL.Image成Tensor
download=True,
)
RealLoader = DataLoader(dataset=RealImage, batch_size=batch_size, shuffle=True)
# 判別器: 輸入原始圖片,輸出判別的結果
class Discriminator(nn.Module):
def __init__(self, d_input_dim):
super(Discriminator, self).__init__()
self.fc1 = nn.Linear(d_input_dim, 1024)
self.fc2 = nn.Linear(1024, 512)
self.fc3 = nn.Linear(512, 256)
self.fc4 = nn.Linear(256, 1)
def forward(self, x):
x = F.leaky_relu(self.fc1(x), 0.2)
x = F.dropout(x, 0.3)
x = F.leaky_relu(self.fc2(x), 0.2)
x = F.dropout(x, 0.3)
x = F.leaky_relu(self.fc3(x), 0.2)
x = F.dropout(x, 0.3)
return torch.sigmoid(self.fc4(x))
# 生成器: 根據給定的分布,來生成一張圖片
class Generator(nn.Module):
def __init__(self, g_input_dim, g_output_dim):
super(Generator, self).__init__()
self.fc1 = nn.Linear(g_input_dim, 256) # 100 -> 256
self.fc2 = nn.Linear(256, 512) # 256 -> 512
self.fc3 = nn.Linear(512, 1024) # 512 -> 1024
self.fc4 = nn.Linear(1024, g_output_dim) # 1024 -> 784
def forward(self, x):
x = F.leaky_relu(self.fc1(x), 0.2)
x = F.leaky_relu(self.fc2(x), 0.2)
x = F.leaky_relu(self.fc3(x), 0.2)
return torch.tanh(self.fc4(x))
G = Generator(g_input_dim=latent_size, g_output_dim=image_size)
D = Discriminator(image_size)
loss = nn.BCELoss()
optimizer1 = optim.Adam(D.parameters(), lr=0.0003)
optimizer2 = optim.Adam(G.parameters(), lr=0.0003)
for epoch in range(epochs):
for step, (x, y) in enumerate(RealLoader):
images = x.reshape(-1,image_size) # 真圖像
real_labels = torch.ones(batch_size, 1).reshape(x.size(0)).type(torch.FloatTensor)
fake_labels = torch.zeros(batch_size, 1).reshape(x.size(0)).type(torch.FloatTensor)
# ================================================================== #
# 訓練判別器 #
# ================================================================== #
# 定義判別器的損失函數
outputs = D(images)
real_loss = loss(outputs, real_labels)
real_score = outputs
# 定義判別器對假圖像的損失函數
fack_digit = torch.randn(batch_size, latent_size)
fake_images = G(fack_digit)
outputs = D(fake_images)
fake_loss = loss(outputs, fake_labels)
fake_score = outputs
# 得到判別器的總損失
total_loss = real_loss + fake_loss
optimizer1.zero_grad()
total_loss.backward()
optimizer1.step()
# ================================================================== #
# 訓練生成器 #
# ================================================================== #
z = torch.randn(batch_size, latent_size)
fake_images = G(z)
outputs = D(fake_images)
g_loss = loss(outputs, real_labels)
optimizer2.zero_grad()
g_loss.backward()
optimizer2.step()
if (step+1) % 200 == 0:
print(
'Epoch [{}/{}], total_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' .format(
epoch, epochs, total_loss.item(),
g_loss.item(), real_score.mean().item(), fake_score.mean().item()))
# 保存真圖像
if (epoch + 1) == 1:
images = images.reshape(images.size(0), 1, 28, 28)
save_image(images, os.path.join('../img/mnist', 'real_images.png'))
# 保存假圖像
fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
save_image(fake_images, os.path.join('../img/mnist', 'fake_images-{}.png'.format(epoch+1)))
# 保存模型
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')
參考資料
- https://blog.csdn.net/xg123321123/article/details/52980581
- https://zhuanlan.zhihu.com/p/33752313
- 《機器學習——白板推導系列三十一》
- 《生成對抗網絡入門指南》
- 《生成對抗網絡項目實戰》