PyTorch實現簡單的生成對抗網絡GAN


     生成對抗網絡是一個關於數據的生成模型:即給定訓練數據,GANs能夠估計數據的概率分布,基於這個概率分布產生數據樣本(這些樣本可能並沒有出現在訓練集中)。

   GAN中,兩個神經網絡互相競爭。給定訓練集X,假設是幾千張貓的圖片。將一個隨機向量輸入給生成器G(x),讓G(x)生成跟訓練集類似的圖片。判別器D(x)是一個二分類分類器,其試圖區分真實的貓圖片和生成器生成的假貓圖片。總的來說,生成器的目的是學習訓練數據的分布,生成盡可能真實的貓圖片,以確保判別器無法區分。判別器需要不斷地學習生成器的“造假圖片”,以防止自己被欺騙。

      判別器與生成器不斷“斗智斗勇”的過程中,生成器或多或少地學習到了訓練數據的真實分布,已經能生成一些以假亂真的圖片了;而判別器最終已經無法判斷貓的圖片是真實的,還是來自於生成器。從某種意義上來說,生成器和判別器都希望對方“失敗”,這個角度來看,不是很容易解釋。

     另外一個角度來說,判別器實際上是在指導生成器,告訴生成器: 真的貓圖片到底什么樣?模型訓練的最終結果是生成器能夠學習到數據的分布,最終可以生成近似真的貓圖片。GANs的訓練方法類似於博弈論中的MinMax算法,生成器和判別器最終達到了納什均衡。(摘自https://zhuanlan.zhihu.com/p/74663048

      生成對抗網絡(Generative Adversarial Network, GAN)包括生成網絡和對抗網絡兩部分。生成網絡像自動編碼器的解碼器,能夠生成數據,比如生成一張圖片。對抗網絡用來判斷數據的真假,比如是真圖片還是假圖片,真圖片是拍攝得到的,假圖片是生成網絡生成的。

       以下程序主要來自廖星宇的《深度學習之PyTorch》的第六章,本文對原代碼進行了改進:

import torch
from torch import nn
import torchvision.transforms as tfs
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import numpy as np
import matplotlib.pyplot as plt


def preprocess_img(x):
    x = tfs.ToTensor()(x)      # x (0., 1.)
    return (x - 0.5) / 0.5     # x (-1., 1.)


def deprocess_img(x):          # x (-1., 1.)
    return (x + 1.0) / 2.0     # x (0., 1.)


def discriminator():
    net = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
        )
    return net


def generator(noise_dim):
    net = nn.Sequential(
        nn.Linear(noise_dim, 1024),
        nn.ReLU(True),
        nn.Linear(1024, 1024),
        nn.ReLU(True),
        nn.Linear(1024, 784),
        nn.Tanh(),
    )
    return net


def discriminator_loss(logits_real, logits_fake):   # 判別器的loss
    size = logits_real.shape[0]
    true_labels = torch.ones(size, 1).float()
    false_labels = torch.zeros(size, 1).float()
    bce_loss = nn.BCEWithLogitsLoss()
    loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
    return loss


def generator_loss(logits_fake):  # 生成器的 loss
    size = logits_fake.shape[0]
    true_labels = torch.ones(size, 1).float()
    bce_loss = nn.BCEWithLogitsLoss()
    loss = bce_loss(logits_fake, true_labels)   # 假圖與真圖的誤差。訓練的目的是減小誤差,即讓假圖接近真圖。
    return loss


# 使用 adam 來進行訓練,beta1 是 0.5, beta2 是 0.999
def get_optimizer(net, LearningRate):
    optimizer = torch.optim.Adam(net.parameters(), lr=LearningRate, betas=(0.5, 0.999))
    return optimizer


def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss,
                noise_size, num_epochs, num_img):

    f, a = plt.subplots(num_img, num_img, figsize=(num_img, num_img))
    plt.ion()  # Turn the interactive mode on, continuously plot

    for epoch in range(num_epochs):
        for iteration, (x, _)in enumerate(train_data):
            bs = x.shape[0]

            # 訓練判別網絡
            real_data = x.view(bs, -1)  # 真實數據
            logits_real = D_net(real_data)  # 判別網絡得分

            rand_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5  # -1 ~ 1 的均勻分布
            fake_images = G_net(rand_noise)  # 生成的假的數據
            logits_fake = D_net(fake_images)  # 判別網絡得分

            d_total_error = discriminator_loss(logits_real, logits_fake)  # 判別器的 loss
            D_optimizer.zero_grad()
            d_total_error.backward()
            D_optimizer.step()  # 優化判別網絡

            # 訓練生成網絡
            rand_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5  # -1 ~ 1 的均勻分布
            fake_images = G_net(rand_noise)  # 生成的假的數據

            gen_logits_fake = D_net(fake_images)
            g_error = generator_loss(gen_logits_fake)  # 生成網絡的 loss
            G_optimizer.zero_grad()
            g_error.backward()
            G_optimizer.step()  # 優化生成網絡

            if iteration % 20 == 0:
                print('Epoch: {:2d} | Iter: {:<4d} | D: {:.4f} | G:{:.4f}'.format(epoch,
                                                                                  iteration,
                                                                                  d_total_error.data.numpy(),
                                                                                  g_error.data.numpy()))
                imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
                for i in range(num_img ** 2):
                    a[i // num_img][i % num_img].imshow(np.reshape(imgs_numpy[i], (28, 28)), cmap='gray')
                    a[i // num_img][i % num_img].set_xticks(())
                    a[i // num_img][i % num_img].set_yticks(())
                plt.suptitle('epoch: {} iteration: {}'.format(epoch, iteration))
                plt.pause(0.01)

    plt.ioff()
    plt.show()


if __name__ == '__main__':

    EPOCH = 5
    BATCH_SIZE = 128
    LR = 5e-4
    NOISE_DIM = 96
    NUM_IMAGE = 4   # for showing images when training
    train_set = MNIST(root='/Users/wangpeng/Desktop/all/CS/Courses/Deep Learning/mofan_PyTorch/mnist/',
                      train=True,
                      download=False,
                      transform=preprocess_img)
    train_data = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)

    D = discriminator()
    G = generator(NOISE_DIM)

    D_optim = get_optimizer(D, LR)
    G_optim = get_optimizer(G, LR)

    train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss, NOISE_DIM, EPOCH, NUM_IMAGE)

效果:

 

程序的理解:

訓練Discriminator:

 

訓練Generatord:

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM