這篇筆記基於上一篇《關於GAN的一些筆記》。
1 GAN的缺陷
由於 $P_G$ 和 $P_{data}$ 它們實際上是 high-dim space 中的 low-dim manifold,因此 $P_G$ 和 $P_{data}$ 之間幾乎是沒有重疊的

正如我們之前說的,如果兩個分布 $P,Q$ 完全沒有重疊,那么 JS divergence 是一個常數 $\log(2)$。
由於最優的 generator 是
![]()
我們在普通的 GAN 中,最小化的是 $P_{data}$ 和 $P_G$ 之間的 JS divergence,那么由於 $P_G$ 和 $P_{data}$ 之間幾乎是沒有重疊的,所以往往會導致 $P_G$ 和 $P_{data}$ 之間的 JS divergence 接近於 $\log(2)$。
由於無法判別到底那種情況下兩個分布更加接近,這就意味着有時候普通的 GAN 很難訓練,甚至沒法訓練。
而如果我們采用實際代碼實現中的 NSGAN,即把 generator 的 loss 改成
![]()
首先請注意,我們訓練 generator 時,discriminator 是固定的,不妨記作 $D^{*}$,而 $D^{*} = P_{data}(x) / (P_{data}(x) + P_G(x))$,這里的 $P_G$ 是還未更新的 generator $G$ 所對應的 distribution。
由於我們已知(詳細的推導可以參見《關於GAN的一些筆記》)
![]()
類似的我們也可以把 KL divergence 寫成

所以

注意到對於后兩項,一項是常數項,一項是更改 $G$ 無法影響的(當你訓練 $G$ 時,$D$ 是固定的,同時 $P_{data}$ 顯然也是不會變的)。所以,你如果把 generator 的 loss 改成了 $V = E_{x \sim P_G}[-\log D(x)]$,那么你就相當於在尋找最優的 generator
![]()
這顯然在理論上是站不住腳的,一邊想使得兩個分布的 KL divergence 盡量小,一邊又想要使得兩個分布的 JS divergence 盡量大,這是矛盾的。這在數值上則會導致梯度不穩定,這就是后面那個 JS divergence 所帶來的問題。
而且另外一個問題是 KL divergence 是非對稱的,會帶來以下問題:
首先寫出 $D_{KL}(P_G \parallel P_{data}) = \int_{x} P_G(x)\log \frac{P_G(x)}{P_{data}(x)}dx$,我們分兩種情況考慮 generator $G$ 會犯的錯誤:
① 對於某處的 $x$,$P_G(x)$ 是高概率(接近 $1$)而 $P_{data}(x)$ 是低概率(接近 $0$),那么此時 $P_G(x)\log \frac{P_G(x)}{P_{data}(x)}$ 接近於正無窮,對於 $D_{KL}(P_G \parallel P_{data})$ 產生了巨大的貢獻。
② 對於某處的 $x$,$P_G(x)$ 是低概率(接近 $0$)而 $P_{data}(x)$ 是高概率(接近 $1$),那么此時 $P_G(x)\log \frac{P_G(x)}{P_{data}(x)}$ 接近於 $0$,對於 $D_{KL}(P_G \parallel P_{data})$ 產生了微乎其微的貢獻。
這就導致了,對於錯誤①(generator 生成了不符合 $P_{data}$ 的錯誤圖片)懲罰巨大,而對於錯誤② (generator 沒有盡可能生成符合 $P_{data}$ 的正確圖片)懲罰很小。這就是的 generator $G$ 會多生成一些重復的但是符合 $P_{data}$ 的正確圖片,而不願意去生成多樣性的樣本,因為那樣就很容易產生錯誤①,會受到巨大的懲罰。這種現象就是大家常說的 collapse mode。這應該就是《關於GAN的一些筆記》中生成結果中有大量的“$1$”的原因。
2 WGAN
之前在《關於GAN的一些筆記》中寫到了 Wasserstein distance 相較於 JS/KL divergence 的優越性。就算 $P_G, P_{data}$ 之間沒有重疊也可以衡量兩個分布的距離。
當然,$W(P,Q) = \inf\limits_{\gamma \in \Pi(P_{data},P_G)} E_{(x,y) \sim \gamma}[\left \| x-y \right \|]$ 這種形式沒法直接變換得到objective function。但是可以用一個定理將其變換成如下形式

這里需要用到的一個知識是 Lipschitz 連續,它對一個函數 $f$ 施加一個限制,要求存在一個常數 $K$ 使得 $f$ 的定義域內任意的兩個元素 $x_1, x_2$ 都滿足
![]()
形象一點的描述就是迫使函數不能過分陡峭,此時成函數 $f$ 的 Lipschitz 常數為 $K$。
所以,變換后的 Wasserstein distance 的意思就是在要求函數 $f$ 的 Lipschitz 常數 $\left \| f \right \|_{L}$ 不超過 $K$ 的條件下,對所有可能滿足條件的 $f$ 取到 $E_{x \sim P_{data}}[f(x)] - E_{x \sim P_G}[f(x)]$ 的上界,然后再除以 $K$。假設我們有一組參數 $w$ 來定義函數 $f_w$,那么 Wasserstein distance 可以近似表達成

回到 GAN 本身,我們知道訓練 generator $G$ 的目的是減小 $P_{data},P_G$ 之間的距離,而訓練 discriminator $D$ 的目的是量出 $P_{data},P_G$ 之間的距離。那么對於 generator $G$ 有
![]()
而 discriminator $D$ 就是要在給定 $G$ 的條件下,量取此時的 $W(P_{data}, P_G)$,參考上面 Wasserstein distance 的近似式,以及 network 強大的函數擬合能力(由於現在 $D$ 做的是近似擬合 Wasserstein distance 屬回歸任務,而非分類任務,所以要把最后一層的sigmoid拿掉),我們的 discriminator $D$ 自然而然就是令
![]()
盡可能地取到最大值,此時的 $V$ 即約等於 $W(P_{data}, P_G)$。
需要注意的點是,對於函數 $D(x)$ 是有限制的,即要存在一個常數 $K$ 使得 $\left \| D \right \|_{L} \leq K$, 這其實很簡單,我們只要使得 network $D$ 的任意一個參數 $w_i$ 都在一個區間 $[-c,c]$ 以內, 此時肯定會使得梯度 $\nabla_{x}D(x)$ 不會大於某一個常數,也就使得 $D$ 滿足了 $\left \| D \right \|_{L} \leq K$。而在具體實現中,只需要在更新完 $D$ 的參數后,做一個weight clipping。即若 $w_i > c$ 則 $w_i := c$,若 $w_i < -c$ 則 $w_i := -c$。
所以綜上,對於 $D$ 有loss function
![]()
加負號是因為loss function一般是越小越好。
而對於 $G$ 有loss function
![]()
可以去掉第一項是因為 $E_{x\sim P_{data}}[D(x)]$ 不受 $G$ 的變動影響。
最后總結,WGAN與原始GAN的區別就以下四點
- discriminator 最后一層去掉 $\rm{sigmoid}$;
- generator 和 discriminator 的 loss 不取 $\log$;
- 每次更新 discriminator 的參數之后把它們的絕對值截斷至不超過一個固定常數 $c$;
- 不要用基於動量的優化算法(包括 momentum 和 Adam),推薦 RMSProp,SGD 也行(這點是作者從實驗中發現的,屬於trick。作者發現如果使用 Adam,discriminator 的 loss 有時候會崩掉,當它崩掉時,Adam 給出的更新方向與梯度方向夾角的 $\cos$ 值就變成負數,更新方向與梯度方向南轅北轍,這意味着 discriminator 的 loss 梯度是不穩定的,所以不適合用Adam這類基於動量的優化算法)。
代碼
這個代碼是來自https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/wgan/wgan.py
import argparse import os import numpy as np import math import sys import torchvision.transforms as transforms from torchvision.utils import save_image from torch.utils.data import DataLoader from torchvision import datasets from torch.autograd import Variable import torch.nn as nn import torch.nn.functional as F import torch os.makedirs("images", exist_ok=True) parser = argparse.ArgumentParser() parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training") parser.add_argument("--batch_size", type=int, default=64, help="size of the batches") parser.add_argument("--lr", type=float, default=0.00005, help="learning rate") parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation") parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space") parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension") parser.add_argument("--channels", type=int, default=1, help="number of image channels") parser.add_argument("--n_critic", type=int, default=5, help="number of training steps for discriminator per iter") parser.add_argument("--clip_value", type=float, default=0.01, help="lower and upper clip value for disc. weights") parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples") opt = parser.parse_args() print(opt) img_shape = (opt.channels, opt.img_size, opt.img_size) cuda = True if torch.cuda.is_available() else False print('CUDA is available: ', cuda) class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() def block(in_feat, out_feat, normalize=True): layers = [nn.Linear(in_feat, out_feat)] if normalize: layers.append(nn.BatchNorm1d(out_feat, 0.8)) layers.append(nn.LeakyReLU(0.2, inplace=True)) return layers self.model = nn.Sequential( *block(opt.latent_dim, 128, normalize=False), *block(128, 256), *block(256, 512), *block(512, 1024), nn.Linear(1024, int(np.prod(img_shape))), nn.Tanh() ) def forward(self, z): img = self.model(z) return img.view(img.shape[0], *img_shape) class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.model = nn.Sequential( nn.Linear(int(np.prod(img_shape)), 512), nn.LeakyReLU(0.2, inplace=True), nn.Linear(512, 256), nn.LeakyReLU(0.2, inplace=True), nn.Linear(256, 1) ) def forward(self, img): img_flat = img.view(img.shape[0], -1) return self.model(img_flat) # Initialize generator and discriminator G = Generator() D = Discriminator() if cuda: G.cuda() D.cuda() # Configure data loader os.makedirs("../../data/mnist", exist_ok=True) dataloader = torch.utils.data.DataLoader( datasets.MNIST( "../../data/mnist", train=True, download=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]), ), batch_size=opt.batch_size, shuffle=True, ) # Optimizers optimizer_G = torch.optim.RMSprop(G.parameters(), lr=opt.lr) optimizer_D = torch.optim.RMSprop(D.parameters(), lr=opt.lr) Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor batches_done = 0 for epoch in range(opt.n_epochs): for i, (imgs, _) in enumerate(dataloader): # Configure input real_imgs = imgs.type(Tensor) # --------------------- # Train Discriminator # --------------------- # Sample noise as generator input z = Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))) # Generate a batch of images fake_imgs = G(z) # Adversarial loss loss_D = -torch.mean(D(real_imgs)) + torch.mean(D(fake_imgs)) optimizer_D.zero_grad() loss_D.backward() optimizer_D.step() # Clip weights of discriminator for p in D.parameters(): p.data.clamp_(-opt.clip_value, opt.clip_value) # Train the generator every n_critic iterations if i % opt.n_critic == 0: # ----------------- # Train Generator # ----------------- # Generate a batch of images fake_imgs = G(z) # Adversarial loss loss_G = -torch.mean(D(fake_imgs)) optimizer_G.zero_grad() loss_G.backward() optimizer_G.step() print( "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch + 1, opt.n_epochs, i, len(dataloader), loss_D.item(), loss_G.item()) ) if batches_done % opt.sample_interval == 0: save_image(fake_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True) batches_done += opt.n_critic
運行結果
看起來似乎不是很好。
3 WGAN的進一步優化
3.1 WGAN存在的問題
WGAN-GP 是針對 WGAN 的存在的問題提出來的,WGAN 在真實的實驗過程中依舊存在着訓練困難、收斂速度慢的問題,相比較傳統GAN在實驗上提升不是很明顯。
WGAN-GP 在文章中指出了 WGAN 存在問題的原因,那就是 WGAN 在處理 Lipschitz 限制條件時直接采用了 weight clipping。通過在訓練過程中保證 discriminator 的所有參數處於 $[-c,c]$ 的范圍內,保證了 discriminator 不能對兩個略微不同的樣本在判別上差異過大,從而間接實現 Lipschitz 限制。
實際訓練中 discriminator 希望盡可能拉大真假樣本的分數差,然而 weight clipping 獨立地限制每一個網絡參數的取值范圍,在這種情況下最優的策略就是盡可能讓所有參數走極端,要么取最大值($c$)要么取最小值($-c$),文章通過實驗驗證了猜測如下圖所示判別器的參數幾乎都集中在最大值和最小值上。

另一個問題就是 weight clipping 會很容易導致梯度消失或者梯度爆炸。原因是 discriminator 是一個多層網絡,如果把 weight clipping threshold 設得稍微小了一點,每經過一層網絡,梯度就變小一點點,多層之后就會指數衰減;反之,如果設得稍微大了一點,每經過一層網絡,梯度變大一點點,多層之后就會指數爆炸。
只有設得不大不小,才能讓生成器獲得恰到好處的回傳梯度,然而在實際應用中這個平衡區域可能很狹窄,就會給調參工作帶來麻煩。文章也通過實驗展示了這個問題,下圖中橫軸代表判別器從低到高第幾層,縱軸代表梯度回傳到這一層之后的尺度大小

3.2 WGAN-GP
針對以上問題,WGAN-GP 作者提出了解決方案,即 gradient penalty。Lipschitz 限制是要求 discriminator 的梯度不超過 $K$,gradient penalty 就是給 loss 添加一個額外的懲罰項來控制梯度與 $K$ 之間的關系,這就是 gradient penalty 的核心所在。
首先將 Wasserstein distance 的 WGAN 的近似表達式

變成

因為 $D \in 1-Lipschitz$ 等價於對於 $\forall x$ 都有 $\left \| \nabla_x D(x) \right \| \leq 1$,所以上式中的懲罰項就是對於 $\left \| \nabla_x D(x) \right \| > 1$ 的情況進行懲罰。
但顯然我們依然不可能檢查所有的 $x$ 是否 $\left \| \nabla_x D(x) \right \| > 1$,因此繼續進行近似
![]()
我們既然不可能檢查所有的 $x$,那我們只檢查服從分布 $P_{penalty}$(一個事先確定好的分布)的 $x$ 總可以吧。我們盡量讓這部分的 $x$ 的 $\left \| \nabla_x D(x) \right \| \leq 1$。
而我們如何去從 $P_{penalty}$ 中采樣 $x$ 呢,做法是,對任意的服從 $P_{data}$ 的 $x$ 和服從 $P_G$ 的 $x$ 之間連一條邊,在這條邊上隨機采樣,即作為服從 $P_{penalty}$ 的 $x$

換句話說,我們只限制 $P_{data}$ 和 $P_G$ 之間的區域上 $x$ 的梯度,因為隨着訓練進行 $P_G$ 是逐漸靠近 $P_{data}$ 的。
然后文章的作者通過實驗發現,在實際實現中,如下近似效果更好:

原本是僅僅懲罰 $\left \| \nabla_x D(x) \right \| > 1$ 的情況,現在是 $\left \| \nabla_x D(x) \right \| < 1$ 以及 $\left \| \nabla_x D(x) \right \| > 1$ 都懲罰。
所以,最終的 loss function是

代碼
這個代碼來自https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/wgan_gp/wgan_gp.py
import argparse import os import numpy as np import math import sys import torchvision.transforms as transforms from torchvision.utils import save_image from torch.utils.data import DataLoader from torchvision import datasets from torch.autograd import Variable import torch.nn as nn import torch.nn.functional as F import torch.autograd as autograd import torch os.makedirs("images", exist_ok=True) parser = argparse.ArgumentParser() parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training") parser.add_argument("--batch_size", type=int, default=64, help="size of the batches") parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate") parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient") parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient") parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation") parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space") parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension") parser.add_argument("--channels", type=int, default=1, help="number of image channels") parser.add_argument("--n_critic", type=int, default=5, help="number of training steps for discriminator per iter") parser.add_argument("--clip_value", type=float, default=0.01, help="lower and upper clip value for disc. weights") parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples") opt = parser.parse_args() print(opt) img_shape = (opt.channels, opt.img_size, opt.img_size) cuda = True if torch.cuda.is_available() else False class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() def block(in_feat, out_feat, normalize=True): layers = [nn.Linear(in_feat, out_feat)] if normalize: layers.append(nn.BatchNorm1d(out_feat, 0.8)) layers.append(nn.LeakyReLU(0.2, inplace=True)) return layers self.model = nn.Sequential( *block(opt.latent_dim, 128, normalize=False), *block(128, 256), *block(256, 512), *block(512, 1024), nn.Linear(1024, int(np.prod(img_shape))), nn.Tanh() ) def forward(self, z): img = self.model(z) img = img.view(img.shape[0], *img_shape) return img class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.model = nn.Sequential( nn.Linear(int(np.prod(img_shape)), 512), nn.LeakyReLU(0.2, inplace=True), nn.Linear(512, 256), nn.LeakyReLU(0.2, inplace=True), nn.Linear(256, 1), ) def forward(self, img): img_flat = img.view(img.shape[0], -1) validity = self.model(img_flat) return validity # Loss weight for gradient penalty lambda_gp = 10 # Initialize generator and discriminator G = Generator() D = Discriminator() if cuda: G.cuda() D.cuda() # Configure data loader os.makedirs("../../data/mnist", exist_ok=True) dataloader = torch.utils.data.DataLoader( datasets.MNIST( "../../data/mnist", train=True, download=True, transform=transforms.Compose( [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])] ), ), batch_size=opt.batch_size, shuffle=True, ) # Optimizers optimizer_G = torch.optim.Adam(G.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) optimizer_D = torch.optim.Adam(D.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor def compute_gradient_penalty(D, real_samples, fake_samples): """Calculates the gradient penalty loss for WGAN GP""" # Random weight term for interpolation between real and fake samples alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1))) # Get random interpolation between real and fake samples interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True) d_interpolates = D(interpolates) fake = Variable(Tensor(real_samples.shape[0], 1).fill_(1.0), requires_grad=False) # Get gradient w.r.t. interpolates gradients = autograd.grad( outputs=d_interpolates, inputs=interpolates, grad_outputs=fake, create_graph=True, retain_graph=True, only_inputs=True, )[0] gradients = gradients.view(gradients.size(0), -1) gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() return gradient_penalty # ---------- # Training # ---------- batches_done = 0 for epoch in range(opt.n_epochs): for i, (imgs, _) in enumerate(dataloader): # Configure input real_imgs = imgs.type(Tensor) # --------------------- # Train Discriminator # --------------------- # Sample noise as generator input z = Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))) # Generate a batch of images fake_imgs = G(z) # Gradient penalty gradient_penalty = compute_gradient_penalty(D, real_imgs.data, fake_imgs.data) # Adversarial loss d_loss = -torch.mean(D(real_imgs)) + torch.mean(D(fake_imgs)) + lambda_gp * gradient_penalty optimizer_D.zero_grad() d_loss.backward() optimizer_D.step() # Train the generator every n_critic steps if i % opt.n_critic == 0: # ----------------- # Train Generator # ----------------- # Generate a batch of images fake_imgs = G(z) # Loss measures generator's ability to fool the discriminator # Train on fake images g_loss = -torch.mean(D(fake_imgs)) optimizer_G.zero_grad() g_loss.backward() optimizer_G.step() print( "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch + 1, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()) ) if batches_done % opt.sample_interval == 0: save_image(fake_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True) batches_done += opt.n_critic
運行結果

看起來是比 WGAN 要好。
