[PyTorch 學習筆記] 8.3 GAN(生成對抗網絡)簡介


本章代碼:

這篇文章主要介紹了生成對抗網絡(Generative Adversarial Network),簡稱 GAN。

GAN 可以看作是一種可以生成特定分布數據的模型。

下面的代碼是使用 Generator 來生成人臉圖像,Generator 已經訓練好保存在 pkl 文件中,只需要加載參數即可。由於模型是在多 GPU 的機器上訓練的,因此加載參數后需要使用remove_module()函數來修改state_dict中的key

def remove_module(state_dict_g):
    # remove module.
    from collections import OrderedDict

    new_state_dict = OrderedDict()
    for k, v in state_dict_g.items():
        namekey = k[7:] if k.startswith('module.') else k
        new_state_dict[namekey] = v

    return new_state_dict

把隨機的高斯噪聲輸入到模型中,就可以得到人臉輸出,最后進行可視化。全部代碼如下:

import os
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
from common_tools import set_seed
from torch.utils.data import DataLoader
from my_dataset import CelebADataset
from dcgan import Discriminator, Generator
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def remove_module(state_dict_g):
    # remove module.
    from collections import OrderedDict

    new_state_dict = OrderedDict()
    for k, v in state_dict_g.items():
        namekey = k[7:] if k.startswith('module.') else k
        new_state_dict[namekey] = v

    return new_state_dict

set_seed(1)  # 設置隨機種子

# config
path_checkpoint = os.path.join(BASE_DIR, "gan_checkpoint_14_epoch.pkl")
image_size = 64
num_img = 64
nc = 3
nz = 100
ngf = 128
ndf = 128

d_transforms = transforms.Compose([transforms.Resize(image_size),
                   transforms.CenterCrop(image_size),
                   transforms.ToTensor(),
                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
               ])

# step 1: data
fixed_noise = torch.randn(num_img, nz, 1, 1, device=device)

flag = 0
# flag = 1
if flag:
    z_idx = 0
    single_noise = torch.randn(1, nz, 1, 1, device=device)
    for i in range(num_img):
        add_noise = single_noise
        add_noise = add_noise[0, z_idx, 0, 0] + i*0.01
        fixed_noise[i, ...] = add_noise

# step 2: model
net_g = Generator(nz=nz, ngf=ngf, nc=nc)
# net_d = Discriminator(nc=nc, ndf=ndf)
checkpoint = torch.load(path_checkpoint, map_location="cpu")

state_dict_g = checkpoint["g_model_state_dict"]
state_dict_g = remove_module(state_dict_g)
net_g.load_state_dict(state_dict_g)
net_g.to(device)
# net_d.load_state_dict(checkpoint["d_model_state_dict"])
# net_d.to(device)

# step3: inference
with torch.no_grad():
    fake_data = net_g(fixed_noise).detach().cpu()
img_grid = vutils.make_grid(fake_data, padding=2, normalize=True).numpy()
img_grid = np.transpose(img_grid, (1, 2, 0))
plt.imshow(img_grid)
plt.show()

輸出如下:


下面對 GAN 的網絡結構進行講解

Generator 接受隨機噪聲 $z$ 作為輸入,輸出生成的數據 $G(z)$。Generator 的目標是讓生成數據和真實數據的分布越接近。Discriminator 接收 $G(z)$ 和隨機選取的真實數據 $x$,目標是分類真實數據和生成數據,屬於 2 分類問題。Discriminator 的目標是把它們二者之間分開。這里體現了對抗的思想,也就是 Generator 要欺騙 Discriminator,而 Discriminator 要識別 Generator。

GAN 的訓練和監督學習訓練模式的差異

在監督學習的訓練模式中,訓練數經過模型得到輸出值,然后使用損失函數計算輸出值與標簽之間的差異,根據差異值進行反向傳播,更新模型的參數,如下圖所示。


在 GAN 的訓練模式中,Generator 接收隨機數得到輸出值,目標是讓輸出值的分布與訓練數據的分布接近,但是這里不是使用人為定義的損失函數來計算輸出值與訓練數據分布之間的差異,而是使用 Discriminator 來計算這個差異。需要注意的是這個差異不是單個數字上的差異,而是分布上的差異。如下圖所示。

# GAN 的訓練
  1. 首先固定 Generator,訓練 Discriminator。

    • 輸入:真實數據 $x$,Generator 生成的數據 $G(z)$
    • 輸出:二分類概率

    從噪聲分布中隨機采樣噪聲 $z$,經過 Generator 生成 $G(z)$。$G(z)$ 和 $x$ 輸入到 Discriminator 得到 $D(x)$ 和 $D(G(z))$,損失函數為 $\frac{1}{m} \sum_{i=1}^{m}\left[\log D\left(\boldsymbol{x}^{(i)}\right)+\log \left(1-D\left(G\left(\boldsymbol{z}^{(i)}\right)\right)\right)\right]$,這里是最大化損失函數,因此使用梯度上升法更新參數:$$\nabla_{\theta_{d}} \frac{1}{m} \sum_{i=1}^{m}\left[\log D\left(\boldsymbol{x}^{(i)}\right)+\log \left(1-D\left(G\left(\boldsymbol{z}^{(i)}\right)\right)\right)\right]$$。

  2. 固定 Discriminator,訓練 Generator。

    • 輸入:隨機噪聲 $z$
    • 輸出:分類概率 $D(G(z))$,目的是使 $D(G(z))=1$

    從噪聲分布中重新隨機采樣噪聲 $z$,經過 Generator 生成 $G(z)$。$G(z)$ 輸入到 Discriminator 得到 $D(G(z))$,損失函數為 $\frac{1}{m} \sum_{i=1}^{m} \log \left(1-D\left(G\left(z^{(i)}\right)\right)\right)$,這里是最小化損失函數,使用梯度下降法更新參數:$\nabla_{\theta_{g}} \frac{1}{m} \sum_{i=1}^{m} \log \left(1-D\left(G\left(z^{(i)}\right)\right)\right)$。

下面是 DCGAN 的例子,DC 的含義是 Deep Convolution,指 Generator 和 Discriminator 都是卷積神經網絡。

Generator 的網絡結構如下圖左邊,使用的是 transpose convolution,輸入是 100 維的隨機噪聲 $z$,形狀是 $(1,100,1,1)$,看作是 100 個 channel,每個特征圖寬高是 $1 \times 1$;輸出是 $(3,64,64)$ 的圖片 $G(z)$。

Generator 的網絡結構如下圖右邊,使用的是 convolution,輸入是 $G(z)$ 或者真實圖片 $x$,輸出是 2 分類概率。


使用數據集來源於 CelebA 人臉數據:http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html,有 22 萬張人臉圖片數據。

但是由於人臉在圖片中的角度、位置、所占區域大小等都不一樣。如下圖所示。


需要對關鍵點檢測算法對人臉在圖片中的位置和大小等進行矯正。下圖是矯正后的數據。

人臉矯正數據的下載地址:https://pan.baidu.com/s/1OhE_ITg3Je4ETECm74VfRA,提取碼:yarv。

在對圖片進行標准化時,經過toTensor()轉換到 $[0,1]$ 后,把transforms.Normalize()的均值和標准差均設置為 0.5,這樣就把數據轉換為到 $[-1,1]$ 區間,因為 $((0,1)-0.5)/0.5=(-1,1)$。

DCGAN 的定義如下:

from collections import OrderedDict
import torch
import torch.nn as nn


class Generator(nn.Module):
    def __init__(self, nz=100, ngf=128, nc=3):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

    def initialize_weights(self, w_mean=0., w_std=0.02, b_mean=1, b_std=0.02):
        for m in self.modules():
            classname = m.__class__.__name__
            if classname.find('Conv') != -1:
                nn.init.normal_(m.weight.data, w_mean, w_std)
            elif classname.find('BatchNorm') != -1:
                nn.init.normal_(m.weight.data, b_mean, b_std)
                nn.init.constant_(m.bias.data, 0)


class Discriminator(nn.Module):
    def __init__(self, nc=3, ndf=128):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

    def initialize_weights(self, w_mean=0., w_std=0.02, b_mean=1, b_std=0.02):
        for m in self.modules():
            classname = m.__class__.__name__
            if classname.find('Conv') != -1:
                nn.init.normal_(m.weight.data, w_mean, w_std)
            elif classname.find('BatchNorm') != -1:
                nn.init.normal_(m.weight.data, b_mean, b_std)
                nn.init.constant_(m.bias.data, 0)

其中nz是輸入的通道數 100,ngf表示最后輸出的圖片寬高,這里設置為 64,會有一個倍數關系, nc是最后的輸出通道數 3。

在迭代訓練時,首先根據 DataLoader 獲得的真實數據的 batch size,構造真實的標簽 1;

然后隨機生成噪聲,構造生成數據的標簽 0。把噪聲輸入到 Generator 中,得到生成數據。

分別把生成數據和真實數據輸入到 Discriminator,得到兩個 Loss,分別求取梯度相加,然后使用 Discriminator 的優化器更新 Discriminator 的參數。

然后生成數據的標簽改為 1,輸入到 Generator,求取梯度,這次使用 Generator 的優化器更新 Generator 的參數。

這里只使用了 2000 張圖片來訓練 20 個 epoch,下圖是每個 epoch 生成的數據的可視化。

效果不是很好,可以使用幾個 trick 來提升生成圖片的效果。

  • 使用 22 萬張圖片
  • ngf設置為 128
  • 標簽平滑:真實數據的標簽設置為 0.9,生成數據的標簽設置為 0.1

GAN 的一些應用

  • 生成人的不同姿態


  • CycleGAN:對一個風格的圖片轉換為另一個風格


  • PixelDTGAN:通過一件衣服生成相近的衣服


  • SRGAN:根據模糊圖像生成超分辨率的圖像


  • Progressive GAN:生成高分辨率的人臉圖像


  • StackGAN:根據文字生成圖片


  • Context Encoders:補全圖片中缺失的部分


  • Pix2Pix:也屬於圖像風格遷移


  • IcGAN:控制生成人臉的條件,如生成的人臉的頭發顏色、是否戴眼鏡等


參考資料


如果你覺得這篇文章對你有幫助,不妨點個贊,讓我有更多動力寫出好文章。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM