1. Introduction
This tutorial introduces DCGANs through a concrete example. We will train a generative adversarial network (GAN) to produce new celebrities after showing it photos of many real celebrities. Most of the code here comes from the dcgan implementation in pytorch/examples; this article walks through that implementation in detail and explains how and why the model works. You do not need prior knowledge of GANs, but newcomers may want to spend some time reasoning about what is actually happening under the hood. Having a GPU or two will also help save training time. Let's get started.
2. Overview
2.1. What is a GAN (Generative Adversarial Network)?
GANs are a framework for teaching a deep learning model to capture the distribution of the training data, so that we can generate new data from that same distribution. GANs were invented by Ian Goodfellow in 2014 and first described in the paper Generative Adversarial Nets.
They consist of two distinct models: a generator and a discriminator. The generator's job is to produce fake images that look like the training images. The discriminator's job is to decide whether a given image is a real training image or a fake image from the generator. During training, the generator tries to fool the discriminator by producing ever-better fakes, while the discriminator works to become a better detective that correctly classifies real and fake images.
The equilibrium of this game is reached when the generator produces fakes that look as if they came straight from the training data, and the discriminator is left to guess at 50% confidence whether the generator's output is real or fake.
Now, starting with the discriminator, let us define some notation used throughout the tutorial. Let $x$ be image data and $D(x)$ the probability, output by the discriminator network, that $x$ came from the training data rather than the generator. Since we are working with images, the input to $D(x)$ is an image of CHW size 3x64x64. Intuitively, $D(x)$ should be high when $x$ comes from the training data and low when $x$ comes from the generator. $D(x)$ can also be viewed as a traditional binary classifier.
For the generator's notation, let $z$ be a latent-space vector sampled from a standard normal distribution, and let $G(z)$ denote the generator function that maps the latent vector $z$ to data space. The goal of $G$ is to estimate the distribution the training data comes from ($p_{data}$) so that it can generate fake samples from its estimated distribution ($p_g$).
Therefore, $D(G(z))$ is the probability (a scalar) that the output of the generator $G$ is a real image. As described in Goodfellow's paper, $D$ and $G$ play a minimax game in which $D$ tries to maximize the probability of correctly classifying reals and fakes ($\log D(x)$), while $G$ tries to minimize the probability that $D$ predicts its output is fake ($\log(1 - D(G(z)))$). From the paper, the GAN loss function is:
$$\underset{G}{\min}\ \underset{D}{\max}\ V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}\big[\log D(x)\big] + \mathbb{E}_{z\sim p_{z}(z)}\big[\log(1 - D(G(z)))\big]$$
In theory, the solution to this minimax game is $p_g = p_{data}$, at which point the discriminator guesses randomly whether an input is real or fake. However, the convergence theory of GANs is still being actively researched, and in practice models do not always train to this point.
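As a quick numerical check of that theoretical endpoint (an aside, not part of the original tutorial): when $p_g = p_{data}$, the optimal discriminator outputs 0.5 everywhere, so both expectation terms in $V(D,G)$ collapse to $\log(0.5)$:

```python
import math

# At the theoretical equilibrium p_g = p_data, the optimal discriminator
# outputs D(x) = 0.5 for every input, so both expectation terms in
# V(D, G) reduce to log(0.5).
d_star = 0.5
value = math.log(d_star) + math.log(1.0 - d_star)  # E[log D(x)] + E[log(1 - D(G(z)))]
print(value)  # -2 * log(2) ≈ -1.3863
assert abs(value - (-2.0 * math.log(2.0))) < 1e-12
```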
2.2. What is a DCGAN (Deep Convolutional GAN)?
A DCGAN is a direct extension of the GAN described above, except that it explicitly uses convolutional layers in the discriminator and transposed-convolutional layers in the generator. It was first described by Radford et al. in the paper Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks.
The discriminator is made up of convolutional layers, batch-normalization layers, and LeakyReLU activations. Its input is a 3x64x64 image, and its output is a scalar probability that the input came from the real data distribution.
The generator is made up of transposed-convolutional layers, batch-normalization layers, and ReLU activations. Its input is a latent vector $z$ drawn from a standard normal distribution, and its output is a 3x64x64 RGB image. The transposed-convolutional layers transform the latent vector into a volume with the same dimensions as a real image. In the paper, the authors also give some tips on how to set up the optimizers, how to compute the loss functions, and how to initialize the model weights, all of which will be explained in the sections that follow.
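The spatial-size bookkeeping behind this latent-to-image mapping can be traced with the standard transposed-convolution output formula. The sketch below (my addition, assuming the layer settings used later in this tutorial: a first layer with kernel 4, stride 1, padding 0, followed by four layers with kernel 4, stride 2, padding 1) follows a 1x1 latent map up to 64x64:

```python
def convtranspose2d_out(size, kernel, stride, padding):
    # Standard output-size formula for nn.ConvTranspose2d
    # (with dilation=1 and output_padding=0).
    return (size - 1) * stride - 2 * padding + kernel

# The five generator layers used in this tutorial: (kernel, stride, padding)
layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]
size = 1  # the latent vector enters as a 1x1 spatial map
sizes = [size]
for k, s, p in layers:
    size = convtranspose2d_out(size, k, s, p)
    sizes.append(size)
print(sizes)  # [1, 4, 8, 16, 32, 64]
assert sizes == [1, 4, 8, 16, 32, 64]
```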
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
Output:
Random Seed: 999
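A minimal illustration of why the seed matters (an aside, not from the original tutorial): re-seeding with the same value makes PyTorch's random draws identical, which is what makes runs of this tutorial reproducible:

```python
import torch

# Re-seeding with the same value before each draw produces
# exactly the same random tensor.
torch.manual_seed(999)
a = torch.randn(4)
torch.manual_seed(999)
b = torch.randn(4)
assert torch.equal(a, b)
```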
3. Inputs
- dataroot: the path to the root of the dataset folder
- workers: the number of worker threads for loading data with the data loader
- batch_size: the batch size used in training. The DCGAN paper uses 128
- image_size: the spatial size of the training images. This implementation defaults to 64x64. If another size is desired, the structures of $D$ and $G$ must be changed; see here for details
- nc: the number of channels in the input images. For color images this is 3
- nz: the length of the latent vector
- ngf: relates to the depth of the feature maps carried through the generator
- ndf: sets the depth of the feature maps propagated through the discriminator
- num_epochs: the number of training epochs to run. Training for longer will probably lead to better results, but will also take longer
- lr: the learning rate. The DCGAN paper uses 0.0002
- beta1: the beta1 hyperparameter for the Adam optimizers. As described in the paper, this is 0.5
- ngpu: the number of GPUs available. If 0, the code runs in CPU mode; if greater than 0, it runs on that number of GPUs
# Root directory for dataset
dataroot = "data/celeba"
# Number of workers for dataloader
workers = 2
# Batch size during training
batch_size = 128
# Spatial size of training images. All images will be resized to this
# size using a transformer.
image_size = 64
# Number of channels in the training images. For color images this is 3
nc = 3
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Size of feature maps in generator
ngf = 64
# Size of feature maps in discriminator
ndf = 64
# Number of training epochs
num_epochs = 5
# Learning rate for optimizers
lr = 0.0002
# Beta1 hyperparam for Adam optimizers
beta1 = 0.5
# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1
4. Data
In this tutorial we will use the Celeb-A Faces dataset, which can be downloaded from the linked site or from Google Drive. The dataset downloads as a zip file named img_align_celeba.zip. Once downloaded, create a directory named celeba and extract the zip file into it. Then, set dataroot to the directory you just created. The resulting directory structure should be:
/path/to/celeba
    -> img_align_celeba
        -> 188242.jpg
        -> 173822.jpg
        -> 284702.jpg
        -> 537394.jpg
        ...
This is an important step because we will be using the ImageFolder dataset class, which requires there to be subdirectories in the dataset's root folder. Now we can create the dataset and the dataloader, set the device to run on, and finally visualize some of the training data.
# We can use an image folder dataset the way we have it setup.
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)
# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
5. Implementation
With our input parameters set and the dataset prepared, we can now get into the implementation. We will start with the weight-initialization strategy, then discuss the generator, discriminator, loss functions, and training loop in detail.
5.1. Weight Initialization
In the DCGAN paper, the authors specify that all model weights should be randomly initialized from a normal distribution with mean 0 and standard deviation 0.02. The weights_init function takes an initialized model as input and re-initializes all convolutional, transposed-convolutional, and batch-normalization layers to meet this criterion. The function is applied to a model immediately after initialization.
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
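As a quick sanity check (my addition, not part of the tutorial), we can apply weights_init to a reasonably large Conv2d layer and confirm the sample mean and standard deviation land near 0 and 0.02:

```python
import torch
import torch.nn as nn

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

torch.manual_seed(0)
conv = nn.Conv2d(64, 128, 4)  # a large layer, so the sample statistics are tight
conv.apply(weights_init)
w = conv.weight.data
assert abs(w.mean().item()) < 5e-3          # sample mean ≈ 0
assert abs(w.std().item() - 0.02) < 5e-3    # sample std ≈ 0.02
```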
5.2. Generator
The generator $G$ is designed to map the latent-space vector $z$ to data space. Since our data are images, converting $z$ to data space ultimately means creating an RGB image with the same size as the training images (i.e. 3x64x64).
In practice, this is accomplished through a series of strided two-dimensional transposed-convolutional layers, each paired with a batch-normalization layer and a ReLU activation. The output of the generator is fed through a tanh function to return it to the input data range of [-1, 1].
It is worth noting the batch-normalization functions after the transposed-convolutional layers, as this is a key contribution of the DCGAN paper. These layers help with the flow of gradients during training. An illustration of the generator from the DCGAN paper is shown below.
Notice how the inputs we set in the Inputs section (nz, ngf, and nc) influence the generator architecture in the code. nz is the length of the input vector $z$, ngf relates to the size of the feature maps propagated through the generator, and nc is the number of channels in the output image (3 for RGB images). Below is the code for the generator.
# Generator Code
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)
Now we can instantiate the generator and apply the weights_init function. Check out the printed model to see how the generator object is structured.
# Create the generator
netG = Generator(ngpu).to(device)
# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))
# Apply the weights_init function to randomly initialize all weights
# to mean=0, stdev=0.02.
netG.apply(weights_init)
# Print the model
print(netG)
Output:
Generator(
(main): Sequential(
(0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU(inplace=True)
(12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(13): Tanh()
)
)
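To confirm the shape progression annotated in the comments above, here is a self-contained sketch (my addition) that stacks the same five ConvTranspose2d layers, with BatchNorm omitted since it does not affect spatial shapes, and checks the output size:

```python
import torch
import torch.nn as nn

nz, ngf, nc = 100, 64, 3

# The same five ConvTranspose2d layers as the tutorial's Generator;
# BatchNorm is left out here because it does not change spatial shapes.
g = nn.Sequential(
    nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False), nn.ReLU(True),
    nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), nn.ReLU(True),
    nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), nn.ReLU(True),
    nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), nn.ReLU(True),
    nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), nn.Tanh(),
)
out = g(torch.randn(2, nz, 1, 1))
assert out.shape == (2, 3, 64, 64)          # one 3x64x64 image per latent vector
assert out.min() >= -1 and out.max() <= 1   # Tanh keeps outputs in [-1, 1]
```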
5.3. Discriminator
As mentioned, the discriminator $D$ is a binary classification network that takes an image as input and outputs a scalar probability that the image is real (as opposed to fake).
Here, $D$ takes a 3x64x64 image as input, processes it through a series of Conv2d, BatchNorm2d, and LeakyReLU layers, and outputs the final probability through a Sigmoid activation function. This architecture can be extended with more layers if necessary for the problem, but there is significance to the use of strided convolution, BatchNorm, and LeakyReLU. The DCGAN paper mentions that it is a good practice to use strided convolution rather than pooling to downsample, because it lets the network learn its own pooling function. The batch-norm and leaky-relu functions also promote healthy gradient flow, which is critical for the learning process of both $G$ and $D$.
Discriminator code:
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
Now, as with the generator, we can create the discriminator, apply the weights_init function, and print the model's structure.
# Create the Discriminator
netD = Discriminator(ngpu).to(device)
# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))
# Apply the weights_init function to randomly initialize all weights
# to mean=0, stdev=0.02.
netD.apply(weights_init)
# Print the model
print(netD)
Output:
Discriminator(
(main): Sequential(
(0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(10): LeakyReLU(negative_slope=0.2, inplace=True)
(11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
(12): Sigmoid()
)
)
5.4. Loss Functions and Optimizers
With $D$ and $G$ set up, we can specify how they learn through the loss functions and optimizers. We will use the binary cross-entropy loss (BCELoss) function defined in PyTorch:
$$\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = -\big[y_n \cdot \log x_n + (1 - y_n) \cdot \log(1 - x_n)\big]$$
Notice how this function provides the calculation of both log components in the objective function (i.e. $\log(D(x))$ and $\log(1 - D(G(z)))$). We can specify which part of the BCE equation to use via the $y$ input. This is done in the upcoming training loop, but it is important to understand how we can choose which component to compute just by changing $y$ (i.e. the ground-truth labels).
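A tiny pure-Python sketch of that label-selection mechanism (mirroring the per-element BCE formula above, not PyTorch's actual implementation):

```python
import math

def bce(x, y):
    # The per-element binary cross-entropy computed by nn.BCELoss.
    return -(y * math.log(x) + (1 - y) * math.log(1 - x))

d_out = 0.8  # a hypothetical discriminator output

# y = 1 keeps only the -log(x) term: the "real" half of the objective.
assert abs(bce(d_out, 1) - (-math.log(d_out))) < 1e-12
# y = 0 keeps only the -log(1 - x) term: the "fake" half.
assert abs(bce(d_out, 0) - (-math.log(1 - d_out))) < 1e-12
```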
Next, we define the real label as 1 and the fake label as 0. These labels will be used when computing the losses of $D$ and $G$, and this is also the convention used in the original GAN paper.
Finally, we set up two separate optimizers, one for $D$ and one for $G$. As specified in the DCGAN paper, both are Adam optimizers with lr of 0.0002 and beta1 of 0.5. To keep track of the generator's learning progress, we will generate a fixed batch of latent vectors drawn from a Gaussian distribution (i.e. fixed_noise). In the training loop, we will periodically feed this fixed_noise into $G$, and over the iterations we will see images form out of the noise.
# Initialize BCELoss function
criterion = nn.BCELoss()
# Create batch of latent vectors that we will use to visualize
# the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)
# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.
# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
5.5. Training
Finally, now that we have all parts of the GAN framework defined, we can train it. Be mindful that training GANs is somewhat of an art form, as incorrect hyperparameter settings can lead to mode collapse with little explanation of what went wrong.
Here, we will closely follow Algorithm 1 from Goodfellow's paper, while abiding by some of the best practices shown in ganhacks. Namely, we will construct different mini-batches for real and fake images, and also adjust $G$'s objective function to maximize $\log D(G(z))$. Training is split into two main parts: Part 1 updates the discriminator and Part 2 updates the generator.
5.5.1. Part 1 - Train the Discriminator
Recall that the goal of training the discriminator is to maximize the probability of correctly classifying a given input as real or fake. In Goodfellow's words, we wish to "update the discriminator by ascending its stochastic gradient."
Practically, we want to maximize $\log(D(x)) + \log(1 - D(G(z)))$. Due to the separate mini-batch suggestion from ganhacks, we will compute this in two steps. First, we construct a batch of real samples from the training set, forward-pass it through $D$, calculate the loss ($\log(D(x))$), then backpropagate to compute the gradients. Second, we construct a batch of fake samples with the current generator, forward-pass that batch through $D$, calculate the loss ($\log(1 - D(G(z)))$), and backpropagate to accumulate the gradients. Now, with the gradients accumulated from both the all-real and the all-fake batches, we call a step of the discriminator's optimizer.
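The "accumulate, then step" pattern relies on the fact that PyTorch adds gradients across backward() calls until they are explicitly zeroed; a minimal illustration:

```python
import torch

# Two backward() calls without an intervening zero_grad() accumulate
# gradients, which is how the real-batch and fake-batch losses combine
# before a single optimizer step.
w = torch.tensor(2.0, requires_grad=True)
(3 * w).backward()   # d(3w)/dw = 3
(4 * w).backward()   # gradients add: 3 + 4
assert w.grad.item() == 7.0
```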
5.5.2. Part 2 - Train the Generator
As stated in the original paper, we want to train the generator by minimizing $\log(1 - D(G(z)))$ so that it produces better fakes. As mentioned, however, Goodfellow showed this does not provide sufficient gradients, especially early in the learning process. As a fix, we instead maximize $\log(D(G(z)))$.
In the code, we accomplish this by: classifying the generator output from Part 1 with the discriminator, computing $G$'s loss using the real labels as the ground truth, computing $G$'s gradients in a backward pass, and finally updating $G$'s parameters with an optimizer step. It may seem counter-intuitive to use the real labels as the ground truth for the loss, but this allows us to use the $\log(x)$ part of BCELoss (rather than the $\log(1 - x)$ part), which is exactly what we want.
Finally, we will do some statistic reporting, and at the end of each epoch we will push our fixed_noise batch through the generator to visually track the progress of $G$'s training. The training statistics reported are:
- Loss_D - discriminator loss, calculated as the sum of the losses for the all-real and all-fake batches ($\log(D(x)) + \log(1 - D(G(z)))$).
- Loss_G - generator loss, calculated as $\log(D(G(z)))$.
- D(x) - the average output (across the batch) of the discriminator for the all-real batch. This should start close to 1, then theoretically converge to 0.5 as $G$ gets better. Think about why this is.
- D(G(z)) - the average output of the discriminator for the all-fake batch. The first number is before $D$ is updated and the second is after. These numbers should start near 0 and converge to 0.5 as $G$ gets better. Think about why this is.
Note: this step may take a while, depending on how many epochs you run and whether you removed some data from the dataset.
# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1
Output:
Starting Training Loop...
[0/5][0/1583] Loss_D: 1.9847 Loss_G: 5.5914 D(x): 0.6004 D(G(z)): 0.6680 / 0.0062
[0/5][50/1583] Loss_D: 0.4017 Loss_G: 17.8778 D(x): 0.8368 D(G(z)): 0.0000 / 0.0000
[0/5][100/1583] Loss_D: 2.8508 Loss_G: 22.8236 D(x): 0.9634 D(G(z)): 0.8460 / 0.0000
[0/5][150/1583] Loss_D: 0.2360 Loss_G: 5.4596 D(x): 0.8440 D(G(z)): 0.0308 / 0.0090
[0/5][200/1583] Loss_D: 1.6425 Loss_G: 4.7064 D(x): 0.3414 D(G(z)): 0.0079 / 0.0176
[0/5][250/1583] Loss_D: 0.2731 Loss_G: 4.4791 D(x): 0.9431 D(G(z)): 0.1680 / 0.0225
[0/5][300/1583] Loss_D: 0.6051 Loss_G: 4.6251 D(x): 0.8278 D(G(z)): 0.2424 / 0.0230
[0/5][350/1583] Loss_D: 0.7070 Loss_G: 1.6842 D(x): 0.6204 D(G(z)): 0.0824 / 0.2560
[0/5][400/1583] Loss_D: 0.6758 Loss_G: 4.0679 D(x): 0.9354 D(G(z)): 0.3946 / 0.0288
[0/5][450/1583] Loss_D: 0.5348 Loss_G: 5.7453 D(x): 0.9625 D(G(z)): 0.3514 / 0.0083
[0/5][500/1583] Loss_D: 0.6896 Loss_G: 7.8784 D(x): 0.9364 D(G(z)): 0.4080 / 0.0012
[0/5][550/1583] Loss_D: 0.4377 Loss_G: 8.1336 D(x): 0.9425 D(G(z)): 0.2840 / 0.0007
[0/5][600/1583] Loss_D: 1.8797 Loss_G: 2.5577 D(x): 0.3201 D(G(z)): 0.0123 / 0.1258
[0/5][650/1583] Loss_D: 1.3832 Loss_G: 10.6947 D(x): 0.9770 D(G(z)): 0.7006 / 0.0001
[0/5][700/1583] Loss_D: 0.3195 Loss_G: 3.7833 D(x): 0.8474 D(G(z)): 0.0844 / 0.0789
[0/5][750/1583] Loss_D: 0.2142 Loss_G: 4.1755 D(x): 0.8942 D(G(z)): 0.0813 / 0.0232
[0/5][800/1583] Loss_D: 1.4535 Loss_G: 2.3077 D(x): 0.4024 D(G(z)): 0.0111 / 0.1806
[0/5][850/1583] Loss_D: 0.4109 Loss_G: 6.3312 D(x): 0.9002 D(G(z)): 0.2153 / 0.0048
[0/5][900/1583] Loss_D: 2.7930 Loss_G: 4.5548 D(x): 0.1428 D(G(z)): 0.0022 / 0.0240
[0/5][950/1583] Loss_D: 0.3493 Loss_G: 5.5976 D(x): 0.8767 D(G(z)): 0.1498 / 0.0080
[0/5][1000/1583] Loss_D: 0.6749 Loss_G: 5.0457 D(x): 0.6349 D(G(z)): 0.0215 / 0.0194
[0/5][1050/1583] Loss_D: 0.4009 Loss_G: 4.5791 D(x): 0.7669 D(G(z)): 0.0484 / 0.0260
[0/5][1100/1583] Loss_D: 0.3453 Loss_G: 2.7277 D(x): 0.8885 D(G(z)): 0.1408 / 0.1219
[0/5][1150/1583] Loss_D: 0.2484 Loss_G: 5.0396 D(x): 0.8727 D(G(z)): 0.0595 / 0.0174
[0/5][1200/1583] Loss_D: 0.6760 Loss_G: 3.2315 D(x): 0.7052 D(G(z)): 0.1756 / 0.0688
[0/5][1250/1583] Loss_D: 0.5845 Loss_G: 3.1392 D(x): 0.7576 D(G(z)): 0.2018 / 0.0673
[0/5][1300/1583] Loss_D: 0.2762 Loss_G: 4.9311 D(x): 0.8666 D(G(z)): 0.0933 / 0.0136
[0/5][1350/1583] Loss_D: 0.4753 Loss_G: 4.7346 D(x): 0.8595 D(G(z)): 0.2228 / 0.0170
[0/5][1400/1583] Loss_D: 0.3764 Loss_G: 5.9964 D(x): 0.7758 D(G(z)): 0.0109 / 0.0098
[0/5][1450/1583] Loss_D: 0.4025 Loss_G: 3.8804 D(x): 0.8158 D(G(z)): 0.1413 / 0.0320
[0/5][1500/1583] Loss_D: 0.6678 Loss_G: 2.7302 D(x): 0.6980 D(G(z)): 0.1486 / 0.1040
[0/5][1550/1583] Loss_D: 0.6062 Loss_G: 3.1664 D(x): 0.7235 D(G(z)): 0.1305 / 0.0783
[1/5][0/1583] Loss_D: 0.6615 Loss_G: 8.0512 D(x): 0.9412 D(G(z)): 0.3797 / 0.0007
[1/5][50/1583] Loss_D: 0.8057 Loss_G: 2.1089 D(x): 0.5929 D(G(z)): 0.0869 / 0.1893
[1/5][100/1583] Loss_D: 0.4206 Loss_G: 3.3245 D(x): 0.7409 D(G(z)): 0.0554 / 0.0640
[1/5][150/1583] Loss_D: 0.6361 Loss_G: 4.0774 D(x): 0.7830 D(G(z)): 0.2605 / 0.0256
[1/5][200/1583] Loss_D: 1.7394 Loss_G: 7.5861 D(x): 0.9685 D(G(z)): 0.7499 / 0.0014
[1/5][250/1583] Loss_D: 0.4597 Loss_G: 3.1064 D(x): 0.7053 D(G(z)): 0.0265 / 0.0844
[1/5][300/1583] Loss_D: 0.4190 Loss_G: 2.2869 D(x): 0.7942 D(G(z)): 0.1163 / 0.1660
[1/5][350/1583] Loss_D: 0.4724 Loss_G: 4.3673 D(x): 0.8292 D(G(z)): 0.2106 / 0.0213
[1/5][400/1583] Loss_D: 0.2877 Loss_G: 4.3217 D(x): 0.8823 D(G(z)): 0.1125 / 0.0225
[1/5][450/1583] Loss_D: 0.8508 Loss_G: 0.8635 D(x): 0.5397 D(G(z)): 0.0390 / 0.5324
[1/5][500/1583] Loss_D: 0.4317 Loss_G: 3.1585 D(x): 0.7646 D(G(z)): 0.0931 / 0.0767
[1/5][550/1583] Loss_D: 0.8256 Loss_G: 6.1484 D(x): 0.9395 D(G(z)): 0.4563 / 0.0051
[1/5][600/1583] Loss_D: 0.9765 Loss_G: 1.5017 D(x): 0.4807 D(G(z)): 0.0076 / 0.2843
[1/5][650/1583] Loss_D: 1.8020 Loss_G: 8.8270 D(x): 0.9480 D(G(z)): 0.7248 / 0.0003
[1/5][700/1583] Loss_D: 0.3680 Loss_G: 3.7401 D(x): 0.7991 D(G(z)): 0.0949 / 0.0404
[1/5][750/1583] Loss_D: 0.5763 Loss_G: 2.0559 D(x): 0.6739 D(G(z)): 0.0851 / 0.1882
[1/5][800/1583] Loss_D: 0.7773 Loss_G: 5.0999 D(x): 0.9399 D(G(z)): 0.4335 / 0.0142
[1/5][850/1583] Loss_D: 0.3901 Loss_G: 3.4356 D(x): 0.8537 D(G(z)): 0.1744 / 0.0491
[1/5][900/1583] Loss_D: 0.7268 Loss_G: 6.5356 D(x): 0.9635 D(G(z)): 0.4428 / 0.0027
[1/5][950/1583] Loss_D: 0.4570 Loss_G: 3.8893 D(x): 0.8707 D(G(z)): 0.2376 / 0.0304
[1/5][1000/1583] Loss_D: 1.3551 Loss_G: 7.2447 D(x): 0.9333 D(G(z)): 0.6422 / 0.0030
[1/5][1050/1583] Loss_D: 0.3905 Loss_G: 3.3360 D(x): 0.8183 D(G(z)): 0.1462 / 0.0537
[1/5][1100/1583] Loss_D: 1.3858 Loss_G: 0.9796 D(x): 0.3336 D(G(z)): 0.0259 / 0.4584
[1/5][1150/1583] Loss_D: 0.5776 Loss_G: 2.6197 D(x): 0.6443 D(G(z)): 0.0532 / 0.1051
[1/5][1200/1583] Loss_D: 0.5647 Loss_G: 3.5713 D(x): 0.8026 D(G(z)): 0.2450 / 0.0428
[1/5][1250/1583] Loss_D: 0.4568 Loss_G: 3.6666 D(x): 0.8934 D(G(z)): 0.2581 / 0.0403
[1/5][1300/1583] Loss_D: 0.7197 Loss_G: 1.8175 D(x): 0.6211 D(G(z)): 0.1035 / 0.2184
[1/5][1350/1583] Loss_D: 0.5255 Loss_G: 3.2736 D(x): 0.8141 D(G(z)): 0.2233 / 0.0574
[1/5][1400/1583] Loss_D: 0.8241 Loss_G: 3.0776 D(x): 0.7807 D(G(z)): 0.3659 / 0.0743
[1/5][1450/1583] Loss_D: 0.4302 Loss_G: 3.3777 D(x): 0.9058 D(G(z)): 0.2518 / 0.0519
[1/5][1500/1583] Loss_D: 0.4173 Loss_G: 2.5610 D(x): 0.7916 D(G(z)): 0.1358 / 0.1058
[1/5][1550/1583] Loss_D: 0.7993 Loss_G: 5.1228 D(x): 0.8527 D(G(z)): 0.4162 / 0.0104
[2/5][0/1583] Loss_D: 0.4844 Loss_G: 2.2263 D(x): 0.7645 D(G(z)): 0.1510 / 0.1426
[2/5][50/1583] Loss_D: 0.6756 Loss_G: 2.4608 D(x): 0.5915 D(G(z)): 0.0657 / 0.1248
[2/5][100/1583] Loss_D: 0.4391 Loss_G: 3.0181 D(x): 0.7901 D(G(z)): 0.1486 / 0.0744
[2/5][150/1583] Loss_D: 0.5683 Loss_G: 1.8918 D(x): 0.7083 D(G(z)): 0.1411 / 0.1858
[2/5][200/1583] Loss_D: 0.5932 Loss_G: 3.3342 D(x): 0.9111 D(G(z)): 0.3576 / 0.0522
[2/5][250/1583] Loss_D: 0.7331 Loss_G: 2.3817 D(x): 0.6635 D(G(z)): 0.1665 / 0.1397
[2/5][300/1583] Loss_D: 0.5493 Loss_G: 2.3824 D(x): 0.7491 D(G(z)): 0.1742 / 0.1196
[2/5][350/1583] Loss_D: 0.6197 Loss_G: 1.8560 D(x): 0.6443 D(G(z)): 0.1018 / 0.1972
[2/5][400/1583] Loss_D: 0.6172 Loss_G: 3.0777 D(x): 0.8482 D(G(z)): 0.3251 / 0.0621
[2/5][450/1583] Loss_D: 0.5047 Loss_G: 3.2941 D(x): 0.9174 D(G(z)): 0.3116 / 0.0566
[2/5][500/1583] Loss_D: 0.7335 Loss_G: 1.2796 D(x): 0.5676 D(G(z)): 0.0575 / 0.3470
[2/5][550/1583] Loss_D: 0.7716 Loss_G: 1.9450 D(x): 0.5513 D(G(z)): 0.0580 / 0.1922
[2/5][600/1583] Loss_D: 0.4425 Loss_G: 2.0531 D(x): 0.8015 D(G(z)): 0.1640 / 0.1686
[2/5][650/1583] Loss_D: 1.0964 Loss_G: 4.4602 D(x): 0.9096 D(G(z)): 0.5833 / 0.0163
[2/5][700/1583] Loss_D: 0.4745 Loss_G: 2.8636 D(x): 0.8492 D(G(z)): 0.2403 / 0.0770
[2/5][750/1583] Loss_D: 0.4947 Loss_G: 3.6931 D(x): 0.8803 D(G(z)): 0.2732 / 0.0364
[2/5][800/1583] Loss_D: 0.9355 Loss_G: 4.3906 D(x): 0.9120 D(G(z)): 0.5168 / 0.0195
[2/5][850/1583] Loss_D: 0.9213 Loss_G: 1.6006 D(x): 0.4645 D(G(z)): 0.0339 / 0.2467
[2/5][900/1583] Loss_D: 0.5337 Loss_G: 3.7601 D(x): 0.9101 D(G(z)): 0.3310 / 0.0314
[2/5][950/1583] Loss_D: 1.2562 Loss_G: 4.9530 D(x): 0.9432 D(G(z)): 0.6244 / 0.0144
[2/5][1000/1583] Loss_D: 0.4187 Loss_G: 2.4701 D(x): 0.8454 D(G(z)): 0.1945 / 0.1129
[2/5][1050/1583] Loss_D: 0.5796 Loss_G: 2.3732 D(x): 0.7714 D(G(z)): 0.2253 / 0.1216
[2/5][1100/1583] Loss_D: 0.6325 Loss_G: 2.5824 D(x): 0.8307 D(G(z)): 0.3235 / 0.0939
[2/5][1150/1583] Loss_D: 0.7639 Loss_G: 3.9487 D(x): 0.9031 D(G(z)): 0.4398 / 0.0291
[2/5][1200/1583] Loss_D: 0.7040 Loss_G: 3.3561 D(x): 0.8073 D(G(z)): 0.3403 / 0.0500
[2/5][1250/1583] Loss_D: 1.0567 Loss_G: 4.7122 D(x): 0.9292 D(G(z)): 0.5656 / 0.0155
[2/5][1300/1583] Loss_D: 0.5431 Loss_G: 2.4260 D(x): 0.7628 D(G(z)): 0.2028 / 0.1116
[2/5][1350/1583] Loss_D: 0.7633 Loss_G: 4.1670 D(x): 0.9257 D(G(z)): 0.4404 / 0.0237
[2/5][1400/1583] Loss_D: 2.1958 Loss_G: 0.5288 D(x): 0.1539 D(G(z)): 0.0147 / 0.6404
[2/5][1450/1583] Loss_D: 0.6991 Loss_G: 1.8573 D(x): 0.5818 D(G(z)): 0.0621 / 0.1980
[2/5][1500/1583] Loss_D: 0.8286 Loss_G: 3.6899 D(x): 0.8805 D(G(z)): 0.4440 / 0.0364
[2/5][1550/1583] Loss_D: 0.5100 Loss_G: 2.5931 D(x): 0.7721 D(G(z)): 0.1862 / 0.0989
[3/5][0/1583] Loss_D: 0.7136 Loss_G: 2.6315 D(x): 0.8178 D(G(z)): 0.3462 / 0.1034
[3/5][50/1583] Loss_D: 0.6472 Loss_G: 2.6359 D(x): 0.7572 D(G(z)): 0.2460 / 0.0962
[3/5][100/1583] Loss_D: 0.5211 Loss_G: 1.7793 D(x): 0.7275 D(G(z)): 0.1402 / 0.2050
[3/5][150/1583] Loss_D: 0.9620 Loss_G: 4.0717 D(x): 0.9423 D(G(z)): 0.5500 / 0.0243
[3/5][200/1583] Loss_D: 0.5469 Loss_G: 2.1994 D(x): 0.7581 D(G(z)): 0.1972 / 0.1359
[3/5][250/1583] Loss_D: 0.3941 Loss_G: 2.7071 D(x): 0.7281 D(G(z)): 0.0401 / 0.0902
[3/5][300/1583] Loss_D: 0.6482 Loss_G: 1.4858 D(x): 0.6275 D(G(z)): 0.1085 / 0.2802
[3/5][350/1583] Loss_D: 1.2781 Loss_G: 4.7393 D(x): 0.9594 D(G(z)): 0.6587 / 0.0120
[3/5][400/1583] Loss_D: 0.5942 Loss_G: 2.8406 D(x): 0.7861 D(G(z)): 0.2579 / 0.0784
[3/5][450/1583] Loss_D: 0.5395 Loss_G: 1.9849 D(x): 0.6755 D(G(z)): 0.0854 / 0.1764
[3/5][500/1583] Loss_D: 0.7941 Loss_G: 2.5871 D(x): 0.7891 D(G(z)): 0.3784 / 0.1006
[3/5][550/1583] Loss_D: 0.6556 Loss_G: 3.9228 D(x): 0.9328 D(G(z)): 0.4053 / 0.0254
[3/5][600/1583] Loss_D: 0.6489 Loss_G: 3.2773 D(x): 0.8385 D(G(z)): 0.3419 / 0.0490
[3/5][650/1583] Loss_D: 0.9217 Loss_G: 1.3858 D(x): 0.4992 D(G(z)): 0.0854 / 0.3095
[3/5][700/1583] Loss_D: 0.4947 Loss_G: 2.2791 D(x): 0.7948 D(G(z)): 0.2035 / 0.1332
[3/5][750/1583] Loss_D: 0.9676 Loss_G: 1.6087 D(x): 0.4641 D(G(z)): 0.0363 / 0.2599
[3/5][800/1583] Loss_D: 0.5918 Loss_G: 1.8852 D(x): 0.7019 D(G(z)): 0.1637 / 0.1948
[3/5][850/1583] Loss_D: 0.7856 Loss_G: 3.4243 D(x): 0.8672 D(G(z)): 0.4219 / 0.0512
[3/5][900/1583] Loss_D: 0.5023 Loss_G: 2.7348 D(x): 0.8372 D(G(z)): 0.2416 / 0.0851
[3/5][950/1583] Loss_D: 0.9028 Loss_G: 1.8348 D(x): 0.5362 D(G(z)): 0.1219 / 0.2110
[3/5][1000/1583] Loss_D: 0.8118 Loss_G: 3.9327 D(x): 0.9092 D(G(z)): 0.4586 / 0.0306
[3/5][1050/1583] Loss_D: 0.8709 Loss_G: 3.1103 D(x): 0.8752 D(G(z)): 0.4686 / 0.0639
[3/5][1100/1583] Loss_D: 0.4286 Loss_G: 2.9141 D(x): 0.8379 D(G(z)): 0.1912 / 0.0741
[3/5][1150/1583] Loss_D: 0.6005 Loss_G: 1.8091 D(x): 0.7044 D(G(z)): 0.1727 / 0.2042
[3/5][1200/1583] Loss_D: 0.7432 Loss_G: 3.8108 D(x): 0.9088 D(G(z)): 0.4344 / 0.0297
[3/5][1250/1583] Loss_D: 0.6872 Loss_G: 1.8717 D(x): 0.7355 D(G(z)): 0.2731 / 0.1789
[3/5][1300/1583] Loss_D: 0.5740 Loss_G: 3.4426 D(x): 0.8874 D(G(z)): 0.3380 / 0.0422
[3/5][1350/1583] Loss_D: 0.5689 Loss_G: 2.0738 D(x): 0.6823 D(G(z)): 0.0966 / 0.1621
[3/5][1400/1583] Loss_D: 0.5023 Loss_G: 3.1107 D(x): 0.9225 D(G(z)): 0.3231 / 0.0565
[3/5][1450/1583] Loss_D: 0.7466 Loss_G: 3.1208 D(x): 0.8441 D(G(z)): 0.3891 / 0.0634
[3/5][1500/1583] Loss_D: 0.7135 Loss_G: 2.8145 D(x): 0.8924 D(G(z)): 0.4117 / 0.0765
[3/5][1550/1583] Loss_D: 0.7881 Loss_G: 4.0945 D(x): 0.9332 D(G(z)): 0.4717 / 0.0258
[4/5][0/1583] Loss_D: 0.6309 Loss_G: 2.2672 D(x): 0.7764 D(G(z)): 0.2761 / 0.1311
[4/5][50/1583] Loss_D: 0.8068 Loss_G: 1.4844 D(x): 0.5595 D(G(z)): 0.1015 / 0.2795
[4/5][100/1583] Loss_D: 0.4912 Loss_G: 2.0030 D(x): 0.7526 D(G(z)): 0.1516 / 0.1674
[4/5][150/1583] Loss_D: 3.0392 Loss_G: 0.6172 D(x): 0.0896 D(G(z)): 0.0134 / 0.6503
[4/5][200/1583] Loss_D: 0.6768 Loss_G: 2.5170 D(x): 0.7543 D(G(z)): 0.2852 / 0.0986
[4/5][250/1583] Loss_D: 1.2451 Loss_G: 0.9252 D(x): 0.3817 D(G(z)): 0.0554 / 0.4569
[4/5][300/1583] Loss_D: 0.5916 Loss_G: 1.7704 D(x): 0.6588 D(G(z)): 0.1113 / 0.2144
[4/5][350/1583] Loss_D: 1.3058 Loss_G: 0.6935 D(x): 0.3416 D(G(z)): 0.0394 / 0.5486
[4/5][400/1583] Loss_D: 0.6206 Loss_G: 3.0787 D(x): 0.8405 D(G(z)): 0.3261 / 0.0609
[4/5][450/1583] Loss_D: 0.5866 Loss_G: 1.4752 D(x): 0.6981 D(G(z)): 0.1565 / 0.2718
[4/5][500/1583] Loss_D: 0.5616 Loss_G: 3.0459 D(x): 0.8869 D(G(z)): 0.3223 / 0.0650
[4/5][550/1583] Loss_D: 0.6073 Loss_G: 3.2580 D(x): 0.7503 D(G(z)): 0.2344 / 0.0500
[4/5][600/1583] Loss_D: 0.6905 Loss_G: 3.0939 D(x): 0.8591 D(G(z)): 0.3762 / 0.0589
[4/5][650/1583] Loss_D: 0.5836 Loss_G: 1.7048 D(x): 0.6781 D(G(z)): 0.1227 / 0.2282
[4/5][700/1583] Loss_D: 0.8543 Loss_G: 3.7586 D(x): 0.8876 D(G(z)): 0.4712 / 0.0337
[4/5][750/1583] Loss_D: 0.8484 Loss_G: 2.3787 D(x): 0.6606 D(G(z)): 0.2724 / 0.1192
[4/5][800/1583] Loss_D: 0.5562 Loss_G: 2.1677 D(x): 0.7446 D(G(z)): 0.1887 / 0.1533
[4/5][850/1583] Loss_D: 0.7600 Loss_G: 1.4960 D(x): 0.5447 D(G(z)): 0.0559 / 0.2722
[4/5][900/1583] Loss_D: 0.5677 Loss_G: 3.0179 D(x): 0.8308 D(G(z)): 0.2804 / 0.0664
[4/5][950/1583] Loss_D: 0.5381 Loss_G: 2.9582 D(x): 0.7989 D(G(z)): 0.2345 / 0.0711
[4/5][1000/1583] Loss_D: 0.8333 Loss_G: 2.8499 D(x): 0.7720 D(G(z)): 0.3700 / 0.0786
[4/5][1050/1583] Loss_D: 0.5125 Loss_G: 1.8930 D(x): 0.7287 D(G(z)): 0.1387 / 0.1848
[4/5][1100/1583] Loss_D: 0.4527 Loss_G: 3.0039 D(x): 0.8639 D(G(z)): 0.2413 / 0.0614
[4/5][1150/1583] Loss_D: 0.7072 Loss_G: 0.8361 D(x): 0.5589 D(G(z)): 0.0563 / 0.4846
[4/5][1200/1583] Loss_D: 0.8619 Loss_G: 4.9323 D(x): 0.9385 D(G(z)): 0.4880 / 0.0112
[4/5][1250/1583] Loss_D: 0.6864 Loss_G: 2.4925 D(x): 0.7232 D(G(z)): 0.2431 / 0.1152
[4/5][1300/1583] Loss_D: 0.5835 Loss_G: 3.1599 D(x): 0.8430 D(G(z)): 0.3018 / 0.0644
[4/5][1350/1583] Loss_D: 0.9119 Loss_G: 4.7225 D(x): 0.9409 D(G(z)): 0.5082 / 0.0154
[4/5][1400/1583] Loss_D: 0.3856 Loss_G: 3.1007 D(x): 0.8980 D(G(z)): 0.2238 / 0.0584
[4/5][1450/1583] Loss_D: 1.3314 Loss_G: 5.1061 D(x): 0.9395 D(G(z)): 0.6621 / 0.0094
[4/5][1500/1583] Loss_D: 0.5882 Loss_G: 1.7242 D(x): 0.6443 D(G(z)): 0.0785 / 0.2306
[4/5][1550/1583] Loss_D: 0.5792 Loss_G: 2.0347 D(x): 0.7582 D(G(z)): 0.2143 / 0.1594
6. Results
Finally, let's check out how we did. Here, we will look at three different results. First, we will see how the losses of $D$ and $G$ changed during training. Second, we will visualize $G$'s output on the fixed_noise batch for every epoch. And third, we will look at a batch of real data next to a batch of fake data from $G$.
6.1. Loss versus Training Iteration
Below is a plot of the losses of $D$ and $G$ versus training iterations.
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
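Because the raw per-iteration GAN losses are quite noisy, a simple trailing moving average (a hypothetical helper of my own, not part of the tutorial) can make the trends easier to read, e.g. `plt.plot(moving_average(G_losses, 100))`:

```python
def moving_average(xs, window):
    # Trailing moving average over a list of loss values;
    # early entries average over however many values are available.
    out = []
    total = 0.0
    for i, x in enumerate(xs):
        total += x
        if i >= window:
            total -= xs[i - window]
        out.append(total / min(i + 1, window))
    return out

print(moving_average([1, 2, 3, 4], 2))  # [1.0, 1.5, 2.5, 3.5]
assert moving_average([1, 2, 3, 4], 2) == [1.0, 1.5, 2.5, 3.5]
```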
6.2. Visualization of G's Progression
Remember how we saved the generator's output on the fixed_noise batch after every epoch of training? Now we can visualize $G$'s training progression with an animation.
#%%capture
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
HTML(ani.to_jshtml())
6.3. Real Images vs. Fake Images
Finally, let's take a look at some real images and fake images side by side.
# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))
# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))
# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()
7. Where to Go Next
We have reached the end of this tutorial, but if you want to dig deeper into GANs, you could, for example, train for more epochs, try a different dataset (possibly changing the image size and model architecture accordingly), or experiment with the hyperparameters.
Total running time of the script: (28 minutes 38.953 seconds)
8. Original
https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html