[PyTorch] Generating Handwritten Digits with DCGAN on the MNIST Dataset





1. Learning Objectives

This tutorial shows how to train a DCGAN on the MNIST dataset to generate handwritten digits.

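2. Environment Setup

2.1. Python

Follow the installation instructions on the official website.

2.2. PyTorch

Follow the installation instructions on the official website.

2.3. Jupyter Notebook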

pip install jupyter

2.4. Matplotlib

pip install matplotlib

3. Implementation

3.1. Import Modules

import os
import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import utils, datasets, transforms
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

3.2. Set the Random Seed

# Set the random seed so that experimental results can be reproduced.
torch.manual_seed(0)
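`torch.manual_seed` seeds the random number generators on all devices, but GPU runs can still differ slightly between runs because some cuDNN kernels are nondeterministic. If exact repeatability matters more than speed, a stricter setup is sketched below. The helper name `seed_everything` is our own, not a PyTorch API:

```python
import random
import torch

def seed_everything(seed: int = 0) -> None:
    """Hypothetical helper: seed Python and PyTorch RNGs and request deterministic cuDNN."""
    random.seed(seed)
    torch.manual_seed(seed)            # seeds CPU and all CUDA devices
    torch.cuda.manual_seed_all(seed)   # explicit; silently ignored without a GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # benchmark mode may choose different kernels per run

# Same seed, same draws
seed_everything(0)
a = torch.randn(3)
seed_everything(0)
b = torch.randn(3)
print(torch.equal(a, b))
```

Disabling `cudnn.benchmark` can make training somewhat slower, which is the trade-off for repeatability.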

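3.3. Hyperparameter Configuration

  • dataroot: path to the folder where the dataset is stored
  • workers: number of worker processes the DataLoader uses to load data
  • batch_size: batch size used during training
  • image_size: spatial size of the training images, 64x64 by default. Using a different size requires changing the structures of D and G accordingly
  • nc: number of channels in the input images (3 for color images, 1 for MNIST)
  • nz: length of the latent vector z
  • ngf: sets the depth of the feature maps carried through the generator
  • ndf: sets the depth of the feature maps propagated through the discriminator
  • num_epochs: total number of training epochs. More epochs may lead to better results but also take longer
  • lr: learning rate. The DCGAN paper uses 0.0002
  • beta1: the beta1 parameter of the Adam optimizers. The paper uses 0.5
  • ngpu: number of available GPUs. With 0 the code runs in CPU mode; with a value greater than 0 (and CUDA available) it runs on the GPU, using DataParallel across that many GPUs when ngpu > 1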
# Root directory for dataset
dataroot = "data/mnist"

# Number of workers for dataloader
workers = 12

# Batch size during training
batch_size = 100

# Spatial size of training images. All images will be resized to this size using a transform.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 1

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 5

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

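3.4. Dataset

We use the MNIST dataset, which has 60,000 training images and 10,000 test images. Since this is a generative task with a GAN rather than a classification task, there is no need to keep a separate test split, so all 70,000 images are used for training.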

mnist_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_data = datasets.MNIST(
    root=dataroot,
    train=True,
    transform=mnist_transform,
    download=True
)
test_data = datasets.MNIST(
    root=dataroot,
    train=False,
    transform=mnist_transform
)
dataset = train_data+test_data
print(f'Total Size of Dataset: {len(dataset)}')

Output:

Total Size of Dataset: 70000
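The `Normalize((0.1307,), (0.3081,))` step uses MNIST's mean and standard deviation, the usual choice for classifiers. For a GAN it is worth checking the resulting pixel range, because the generator ends in `Tanh` and can only output values in [-1, 1]. `ToTensor` maps pixels to [0, 1] and `Normalize` computes `(x - mean) / std`, so a quick calculation (the helper `normalized_range` is just for illustration) gives:

```python
def normalized_range(mean: float, std: float) -> tuple:
    # ToTensor() maps pixels to [0, 1]; Normalize() then computes (x - mean) / std
    return (0.0 - mean) / std, (1.0 - mean) / std

lo, hi = normalized_range(0.1307, 0.3081)
print(f"MNIST mean/std: [{lo:.3f}, {hi:.3f}]")          # roughly [-0.424, 2.821]
print("mean=0.5, std=0.5:", normalized_range(0.5, 0.5))  # (-1.0, 1.0)
```

Much of the real-data range lies outside what `Tanh` can produce, so the discriminator may be able to separate real from fake by intensity alone. A common alternative, used for example in the official PyTorch DCGAN tutorial, is `transforms.Normalize((0.5,), (0.5,))`, which maps pixels exactly onto [-1, 1]. Treat this as something to experiment with; the results in this post were produced with the original setting.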

3.5. DataLoader

dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers
)
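With `drop_last=False` (the DataLoader default), the number of batches per epoch is the dataset size divided by the batch size, rounded up. That is where the `Step: [.../700]` counter in the training log comes from. A small sketch of the arithmetic (the function name is hypothetical):

```python
import math

def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    # Mirrors how DataLoader counts batches for a map-style dataset
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)

print(batches_per_epoch(70000, 100))  # 700
print(batches_per_epoch(70000, 128))  # 547: the last batch holds only 112 images
```

The second case is why the training loop reads the batch size from the data with `b_size = real_cpu.size(0)` instead of assuming every batch is full.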

3.6. Choose the Training Device

Check whether CUDA is available: if it is (and ngpu > 0), train on the GPU; otherwise fall back to the CPU.

device = torch.device('cuda:0' if (torch.cuda.is_available() and ngpu > 0) else 'cpu')

3.7. Visualizing the Training Data

inputs = next(iter(dataloader))[0]
plt.figure(figsize=(10,10), dpi=100)
plt.title("Training Images")
plt.axis('off')
inputs = utils.make_grid(inputs[:100], nrow=10, normalize=True)  # rescale to [0, 1] so imshow does not clip
plt.imshow(inputs.permute(1, 2, 0))

[Figure: Training Images]
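Because the dataset is normalized with MNIST's mean and standard deviation, the tensors coming out of the DataLoader are not in [0, 1]. To display digits with their original pixel intensities, one option is to invert `Normalize` before plotting. The helper `denormalize` below is our own sketch, not a torchvision function:

```python
import torch

def denormalize(x: torch.Tensor, mean: float = 0.1307, std: float = 0.3081) -> torch.Tensor:
    """Hypothetical helper: invert Normalize(mean, std) and clamp to the displayable [0, 1] range."""
    return (x * std + mean).clamp(0.0, 1.0)

# Round-trip check on a random 1x28x28 "image" with values in [0, 1]
original = torch.rand(1, 28, 28)
normalized = (original - 0.1307) / 0.3081   # what transforms.Normalize computes
restored = denormalize(normalized)
print(torch.allclose(original, restored, atol=1e-5))
```

In the visualization above it could be used as `plt.imshow(utils.make_grid(denormalize(inputs[:100]), nrow=10).permute(1, 2, 0))`.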

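3.8. Weight Initialization

In the DCGAN paper, the authors specify that all model weights should be randomly initialized from a normal distribution with mean 0 and standard deviation 0.02. I do not recommend it here, though: in my own experiments the results were much worse with this initialization, so the calls that apply it are left commented out below.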

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
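If you do want to try the paper's initialization, it is applied with `model.apply(weights_init)`, which calls the function on the model and every submodule. The sketch below applies it to two example layers (sizes chosen only for illustration) and inspects the resulting statistics:

```python
import torch.nn as nn

def weights_init(m):
    # Same rule as above: N(0, 0.02) for conv weights, N(1, 0.02) for BN scale, 0 for BN shift.
    # 'Conv' also matches ConvTranspose2d, so the generator is covered too.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Example layers for illustration only
block = nn.Sequential(nn.Conv2d(64, 128, 4, 2, 1, bias=False), nn.BatchNorm2d(128))
block.apply(weights_init)  # visits the Sequential itself (no match) and each child

conv, bn = block[0], block[1]
print(f"conv weight: mean={conv.weight.mean().item():.4f}, std={conv.weight.std().item():.4f}")
print(f"bn weight:   mean={bn.weight.mean().item():.4f}")
print(f"bn bias all zero: {bool((bn.bias == 0).all())}")
```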

3.9. Generator

Generator architecture:
[Figure: generator architecture]
Define the generator class:

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)
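The `state size` comments follow from the output-size formula for `nn.ConvTranspose2d`. With `dilation=1` and `output_padding=0` (the defaults used here), it reduces to H_out = (H_in - 1) * stride - 2 * padding + kernel_size. A short sketch tracing the generator's five layers:

```python
def convtranspose2d_out(h_in: int, kernel: int, stride: int, padding: int) -> int:
    # PyTorch's formula with dilation=1 and output_padding=0
    return (h_in - 1) * stride - 2 * padding + kernel

# (kernel, stride, padding) of each ConvTranspose2d in Generator.main
layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

h = 1  # the latent vector z enters as nz x 1 x 1
sizes = []
for k, s, p in layers:
    h = convtranspose2d_out(h, k, s, p)
    sizes.append(h)
print(sizes)  # [4, 8, 16, 32, 64]
```

Each stride-2 layer doubles the spatial size, which is why `image_size = 64` is tied to this architecture: a different output size needs a different number of layers or different kernel and stride settings.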

Instantiate the generator:

# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-gpu if desired
if device.type == 'cuda' and ngpu > 1:
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights to mean=0, stdev=0.02.
# netG.apply(weights_init)

3.10. Discriminator

Define the discriminator class:

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            # state size. (1) x 1 x 1
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
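The discriminator mirrors the generator. For `nn.Conv2d` with `dilation=1`, H_out = floor((H_in + 2 * padding - kernel_size) / stride) + 1, so the 64 x 64 input shrinks step by step to a single score. A sketch of that trace:

```python
def conv2d_out(h_in: int, kernel: int, stride: int, padding: int) -> int:
    # PyTorch's formula with dilation=1 (integer division implements the floor)
    return (h_in + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) of each Conv2d in Discriminator.main
layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

h = 64
sizes = []
for k, s, p in layers:
    h = conv2d_out(h, k, s, p)
    sizes.append(h)
print(sizes)  # [32, 16, 8, 4, 1]
```

The final 1 x 1 map is why the training loop can call `.view(-1)` on the discriminator's output to get one probability per image.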

Instantiate the discriminator:

# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-gpu if desired
if device.type == 'cuda' and ngpu > 1:
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights to mean=0, stdev=0.02.
# netD.apply(weights_init)

3.11. Optimizers and Loss Function

# Initialize BCELoss function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize the progression of the generator
fixed_noise = torch.randn(100, nz, 1, 1, device=device)
# print(f'Size of Latent Vector: {fixed_noise.size()}')

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D
optimizerD = torch.optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
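`nn.BCELoss` expects probabilities (hence the `Sigmoid` at the end of the discriminator) and computes -[y log p + (1 - y) log(1 - p)], averaged over the batch. A quick numerical check with made-up discriminator outputs:

```python
import math
import torch
import torch.nn as nn

criterion = nn.BCELoss()

p = torch.tensor([0.9, 0.2])   # hypothetical D outputs for two images
real = torch.ones(2)           # real_label = 1.
fake = torch.zeros(2)          # fake_label = 0.

loss_real = criterion(p, real).item()
loss_fake = criterion(p, fake).item()

# Same values computed by hand: mean of -log(p) and mean of -log(1 - p)
manual_real = -(math.log(0.9) + math.log(0.2)) / 2
manual_fake = -(math.log(0.1) + math.log(0.8)) / 2
print(round(loss_real, 4), round(manual_real, 4))
print(round(loss_fake, 4), round(manual_fake, 4))
```

This is why the real and fake passes in the training loop share one `criterion` and only swap the label tensor: real labels turn the loss into -log D(x), and fake labels turn it into -log(1 - D(G(z))).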

3.12. Training

# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
D_x_list = []
D_z_list = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    beg_time = time.time()
    # For each batch in the dataloader
    for i, data in enumerate(dataloader):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)  # actual size of this batch (at most batch_size)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)

        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)  # output.size() = [b_size]

        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        end_time = time.time()
        run_time = round(end_time-beg_time)
        print(
            f'Epoch: [{epoch+1:0>{len(str(num_epochs))}}/{num_epochs}]',
            f'Step: [{i+1:0>{len(str(len(dataloader)))}}/{len(dataloader)}]',
            f'Loss-D: {errD.item():.4f}',
            f'Loss-G: {errG.item():.4f}',
            f'D(x): {D_x:.4f}',
            f'D(G(z)): [{D_G_z1:.4f}/{D_G_z2:.4f}]',
            f'Time: {int(run_time/60)}m{run_time%60}s',
            end='\r'
        )

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        
        # Save D(X) and D(G(z)) for plotting later
        D_x_list.append(D_x)
        D_z_list.append(D_G_z2)
        

        # Check how the generator is doing by saving G's output on fixed_noise
        iters += 1
        if (iters % 100 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(utils.make_grid(fake, nrow=10))
    print()

Output:

Starting Training Loop...
Epoch: [1/5] Step: [700/700] Loss-D: 8.8328 Loss-G: 5.1051 D(x): 0.9995 D(G(z)): [0.9989/0.0200] Time: 1m6s
Epoch: [2/5] Step: [700/700] Loss-D: 2.5174 Loss-G: 0.7627 D(x): 0.1362 D(G(z)): [0.0006/0.5085] Time: 1m8s
Epoch: [3/5] Step: [700/700] Loss-D: 0.0355 Loss-G: 4.4222 D(x): 0.9767 D(G(z)): [0.0113/0.0163] Time: 1m8s
Epoch: [4/5] Step: [700/700] Loss-D: 0.9482 Loss-G: 1.9022 D(x): 0.6590 D(G(z)): [0.3345/0.1798] Time: 1m8s
Epoch: [5/5] Step: [700/700] Loss-D: 0.0939 Loss-G: 3.1018 D(x): 0.9168 D(G(z)): [0.0025/0.0698] Time: 1m8s
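In step (2) the generator is trained with `label.fill_(real_label)`, so it minimizes -log D(G(z)) instead of directly minimizing log(1 - D(G(z))). This is the non-saturating loss suggested in the original GAN paper (Goodfellow et al., 2014), and the motivation is gradient strength when D confidently rejects the fakes. A small autograd sketch, assuming D(G(z)) = 0.01 and looking at the gradient with respect to the discriminator's pre-sigmoid logit:

```python
import torch

# Logit a such that sigmoid(a) = 0.01 (an assumed value: D confidently labels the fake as fake)
d_target = torch.tensor(0.01)
a = torch.log(d_target / (1 - d_target)).requires_grad_()

# Saturating objective: G minimizes log(1 - D(G(z))); d/da = -sigmoid(a)
saturating = torch.log(1 - torch.sigmoid(a))
(grad_sat,) = torch.autograd.grad(saturating, a)

# Non-saturating objective (what the loop uses): G minimizes -log D(G(z)); d/da = -(1 - sigmoid(a))
non_saturating = -torch.log(torch.sigmoid(a))
(grad_nonsat,) = torch.autograd.grad(non_saturating, a)

print(f"|grad| saturating:     {grad_sat.abs().item():.4f}")    # about 0.01
print(f"|grad| non-saturating: {grad_nonsat.abs().item():.4f}")  # about 0.99
```

Both objectives push D(G(z)) upward, but only the non-saturating form keeps a strong gradient while the discriminator is winning. In the log above, D(G(z)) is reported twice: the first value is measured before D's update in that step and the second after it.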

3.13. Loss During Training

plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses[::100], label="G")
plt.plot(D_losses[::100], label="D")
plt.xlabel("iterations (x100)")
plt.ylabel("Loss")
plt.legend()
plt.show()
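Plotting `G_losses[::100]` keeps only one value out of every 100 iterations, so each point on the curve is a single noisy sample. An alternative is to average each window of 100 iterations before plotting, sketched here with a hypothetical `chunk_means` helper:

```python
def chunk_means(values, window=100):
    """Hypothetical helper: average consecutive, non-overlapping windows of `window` values."""
    return [
        sum(values[i:i + window]) / len(values[i:i + window])
        for i in range(0, len(values), window)
    ]

print(chunk_means([1.0, 2.0, 3.0, 4.0, 5.0], window=2))  # [1.5, 3.5, 5.0]
```

Usage: `plt.plot(chunk_means(G_losses), label="G")`. The x-axis is still in units of 100 iterations, as with `[::100]`, but each point now summarizes the whole window.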

[Figure: Generator and Discriminator Loss During Training]

3.14. D(x) and D(G(z)) During Training

plt.figure(figsize=(10, 5))
plt.title("D(x) and D(G(z)) During Training")
plt.plot(D_x_list[::100], label="D(x)")
plt.plot(D_z_list[::100], label="D(G(z))")
plt.xlabel("iterations (x100)")
plt.ylabel("Probability")
plt.legend()
plt.show()

[Figure: D(x) and D(G(z)) During Training]

3.15. Visualizing G's Training Progress

fig = plt.figure(figsize=(10, 10), dpi=100)
plt.axis("off")
ims = [[plt.imshow(item.permute(1, 2, 0), animated=True)] for item in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
HTML(ani.to_jshtml())

[Figure: animation of the generator's output on fixed_noise during training]
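`HTML(ani.to_jshtml())` only renders inside a notebook. Outside Jupyter, the same `ArtistAnimation` can be written to a GIF with Matplotlib's Pillow writer. In the sketch below, the random frames standing in for `img_list` and the output file name are assumptions for illustration:

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # render off-screen; no display needed
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch

# Stand-in for img_list: three random 3 x 64 x 64 "grids" with values in [0, 1]
img_list = [torch.rand(3, 64, 64) for _ in range(3)]

fig = plt.figure(figsize=(4, 4))
plt.axis("off")
ims = [[plt.imshow(item.permute(1, 2, 0), animated=True)] for item in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

# Example output path (assumed); Pillow is a Matplotlib dependency
out_path = os.path.join(tempfile.gettempdir(), "dcgan_progress.gif")
ani.save(out_path, writer="pillow")
plt.close(fig)
print(out_path, os.path.getsize(out_path) > 0)
```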

4. Real Images vs. Fake Images

# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(20,10), dpi=300)
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(utils.make_grid(real_batch[0][:100], nrow=10).permute(1, 2, 0))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(transforms.Normalize((0.1307,), (0.3081,))(img_list[-1]).permute(1, 2, 0))
# Save only after both panels have been drawn, otherwise the fake panel is missing from the file
plt.savefig('comparation.jpg')

[Figure: Real Images vs. Fake Images]

(Left: real images from the dataset; right: fake images produced by the generator.)

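5. Tips

This tutorial was run on a single GTX 1080 Ti, and one epoch takes roughly 1m8s. Although our lab has 8 GPUs, there is no need to use them all: in my tests, multi-GPU training was actually slower. Note that this refers to data parallelism with DataParallel. Distributed training (DistributedDataParallel) should be considerably faster, but I do not recommend it for beginners because the setup is more involved.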

6. Complete Code

train.ipynb (after downloading, open it in Jupyter Notebook)

7. References

https://blog.csdn.net/qq_42951560/article/details/110308336

