對抗生成網絡主要的原理,主要是使用生成器生成網絡,判別器進行判別
生成器損失值:
判別器判別生成圖片為真的BCE損失值
判別器損失值
判別真實圖片為真和判別生成圖片為假的BCE損失值
第一步: 使用argparse構造cmd輸入的參數函數, 包含batch_size, lr學習率 ,latent_dim表示噪音生成的維度
第二步: 構造mnist數據集的dataloaders,使用torchvison.dataset.MNIST數據集, 使用transforms.compose([])進行數據集的轉換, 使用torch.utils.data.Dataloaders構造batch_size數據集
第三步: 實例化生成網絡
生成網絡網絡結構:
構造block模塊,包含nn.Leanear, nn.BatchNormal1d(out_feats, 0.8) 和 nn.LeakyRelu(0.2) 表示對於小於0的數據乘以0.2,將其比例進行稀釋
第一層: 轉換為int(latent_dim, 128)
第二層: 轉換為(128, 256)
第三層:轉換為(256, 512)
第四層: 轉換為(512, 1024)
第五層: 轉換為(1024, int(np.prod(input_size)))
第六層: nn.Tanh()
實例化判別網絡
判別網絡網絡結構:
第一層: 轉換為(int(np.prod(input_size), 512))
第二層: 轉換為(512, 256)
第三層: 轉換為(256, 1)
第四層:轉換為nn.Sigmoid()
第四步: 進行網絡訓練操作
第一步: 使用torch.nn.BCELoss() 構造損失函數
第二步: 實例化判別網絡和生成網絡
第三步: 構造迭代優化器
第四步:進行網絡的訓練操作,構造全1的真實標簽valid和構造全0的虛假標簽fake, 將輸入的數據轉換為Variable的tensor類型,構造隨機的100維噪音數據,將噪音數據傳入生成生成圖片
第五步: 構造生成器的優化函數,構造判別器的優化函數, 生成器的損失值,使用判別器判別為真的BCE損失值 , 對於判別器的損失值,使用判別器判別為真實圖片為真,判別生成圖片為假的損失值
第六步: 打印數據,同時使用save_images進行數據集的保存
import argparse import time import torch import torch.utils.data from torchvision import transforms, datasets from torch import nn from torch.autograd import Variable import os import numpy as np from torch import optim from torchvision.utils import save_image parser = argparse.ArgumentParser() parser.add_argument('--n_epochs', type=int, default=20, help='迭代的次數') parser.add_argument('--batch_size', type=int, default=64, help='每個batch_size') parser.add_argument('--lr', type=int, default=0.0002, help='表示學習率') parser.add_argument('--b1', type=float, default=0.5, help='表示動量梯度下降第一個參數') parser.add_argument('--b2', type=float, default=0.99, help='動量梯度下降第二個參數') parser.add_argument('--n_cpu', type=int, default=8, help='表示cpu運行的個數') parser.add_argument('--latent_dim', type=int, default=100, help='表示噪音數據生成的維度') parser.add_argument('--image_size', type=int, default=28, help='表示輸入數據的維度') parser.add_argument('--channel', type=int, default=1, help='表示輸入數據的通道數') parser.add_argument('--sample_interval', type=int, default=400, help='表示保存圖片的迭代數') opt = parser.parse_args() # 表示輸入數據的尺寸 input_size = (opt.image_size, opt.image_size, opt.channel) os.makedirs('./data', exist_ok=True) os.makedirs('./data/mnist', exist_ok=True) # 進行數據集的准備 os.makedirs('./data/mnist', exist_ok=True) dataloaders = torch.utils.data.DataLoader( datasets.MNIST( './data/mnist', train = True, download=True, transform=transforms.Compose( [transforms.Resize(opt.image_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])] ), ), batch_size=opt.batch_size, shuffle=True ) cuda = True if torch.cuda.is_available() else False # 構建生成網絡 class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() def block(in_feats, out_feats, Normalize=True): layers = [nn.Linear(in_feats, out_feats)] if Normalize: layers.append(nn.BatchNorm1d(out_feats, 0.8)) layers.append(nn.LeakyReLU(0.2, inplace=True)) return layers self.model = nn.Sequential( *block(opt.latent_dim, 128, Normalize=False), *block(128, 256), *block(256, 512), *block(512, 1024), # 最后一層全連接層,不需要進行batchnomalize 和 relu操作 nn.Linear(1024, int(np.prod(input_size))), nn.Tanh(), ) def forward(self, x): output = self.model(x) return output class Discrimator(nn.Module): def __init__(self): super(Discrimator, self).__init__() self.model = nn.Sequential( nn.Linear(int(np.prod(input_size)), 512), nn.LeakyReLU(0.2, inplace=True), nn.Linear(512, 256), nn.LeakyReLU(0.2, inplace=True), nn.Linear(256, 1), nn.Sigmoid(), ) def forward(self, x): output = self.model(x) return output # 構造損失值函數 adversial_loss = torch.nn.BCELoss() generator = Generator() discrimator = Discrimator() # 將數據放在cuda上 if cuda: adversial_loss.cuda() generator.cuda() discrimator.cuda() optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) optimizer_D = torch.optim.Adam(discrimator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor for epoch in range(opt.n_epochs): for i, (image, _) in enumerate(dataloaders): # 構造標簽損失函數 valid = Variable(tensor(image.size(0), 1).fill_(1.0), requires_grad=False) fake = Variable(tensor(image.size(0), 1).fill_(0.0), requires_grad=False) # 構建真實的輸入值 real_image = torch.reshape(Variable(image.type(tensor)), (int(image.shape[0]), -1)) optimizer_G.zero_grad() # 對於生成器 z = Variable(tensor(np.random.normal(0, 1, (image.shape[0], opt.latent_dim)))) gen_images = generator(z) g_loss = adversial_loss(discrimator(gen_images), valid) g_loss.backward() optimizer_G.step() # # 構造判別器的損失函數 optimizer_D.zero_grad() real_loss = adversial_loss(discrimator(real_image), valid) fake_loss = adversial_loss(discrimator(gen_images.detach()), fake) d_loss = (real_loss + fake_loss) / 2 d_loss.backward() optimizer_D.step() print( '[Epoch %d / %d] Batch %d / %d [D loss: %f] [G loss: %f]' % (epoch, opt.n_epochs, i, len(dataloaders), d_loss.item(), g_loss.item()) ) batches_done = epoch * len(dataloaders) + i if batches_done % opt.sample_interval == 0: save_image(gen_images.data[:25], 'images/%d.png' % batches_done, nrow=5, normalize=True)