GAN
The concept was introduced by Ian Goodfellow in 2014 and quickly became a very hot research topic; there are well over a thousand GAN variants by now. Yann LeCun, one of the pioneers of deep learning, once called GANs and their variants "the most interesting idea in the last ten years in machine learning." So what is a GAN? What is it used for? How does it work? And how do you implement one? This article addresses each question in turn. The outline is as follows:
- 1. What is a GAN?
- 2. Applications of GANs
- 3. How GANs work
- 4. Implementing DCGAN [GitHub link]
- 5. Implementing WGAN [GitHub link]
- 6. Implementing Conditional GAN [GitHub link]
- 7. GAN training tips
- 8. References
- 9. To be continued (more GANs will be added later)
What is a GAN?
GAN stands for Generative Adversarial Network. It consists of two parts: a generator and a discriminator, which stands in an adversarial relationship to the generator. Now that we have a first impression of GANs and know they are made of two modules, let's build intuition for where these two modules come from through a couple of examples.
The adversarial idea: the Bobo bird and the dead-leaf butterfly
In biological evolution, prey gradually evolve traits that deceive their predators, while predators adjust how they recognize their prey; the two co-evolve. The Bobo bird and the dead-leaf butterfly in the figure above are exactly such a pair: the generator plays the dead-leaf butterfly and the discriminator plays the Bobo bird. Their adversarial dynamic resembles a GAN, but a GAN differs in one important way.
The GAN idea: learning to draw
What makes GANs different is that, unlike evolution in nature, the discriminator already knows what the real target looks like; it just does not know what fakes look like. It penalizes the fake targets the generator produces and rewards the real ones, so it learns which targets are bad fakes and which are good, real ones. The generator, in turn, tries to evolve and produce better fakes than last time so that the discriminator penalizes it less. That is one round. In the next round, the discriminator learns from the previous round's evolved fakes and the real targets and sharpens its penalty on fakes, while the generator stubbornly evolves again, and so on until the fakes pass for real and match the real targets, at which point evolution stops.
Take the figure above as an example: when we first draw a portrait we only know the rough shape of a head, with eyes and a nose and so on, but the drawing is crude; by studying with a teacher we draw better and better, until we draw as well as a teacher who specializes in portraits. Here, we are the generator, improving step by step (the generator's successive levels), and the teacher is the discriminator. (This is only an analogy: a real-world teacher is already a mature discriminator and does not need fake samples to learn; the point is just the spirit of it.)
Zero-sum game
Anyone who has played cards knows the winner's joy is built on the loser's pain: gains and losses always sum to zero. The generator and the discriminator are locked in the same kind of game. When the discriminator penalizes the generator, the discriminator gains and the generator loses; when the generator evolves so that the discriminator penalizes it less, the generator gains and the discriminator loses.
Summary
So, what is a GAN? A GAN consists of a generator and a discriminator. The generator's goal is to produce fake targets that completely fool the discriminator; the discriminator learns from real and fake targets to improve its ability to tell them apart. The two evolve against each other in a game where one side's gain is the other's loss, until the fakes are so close to the real targets that evolution stops.
Applications of GANs
First, some context: structured learning (Structured Learning). GANs are a form of structured learning. Like classification and regression, structured learning looks for a mapping \(X \rightarrow Y\), but its inputs and outputs are far more varied: sequence to sequence, sequence to matrix, matrix to graph, graph to tree, and so on. This makes GANs very widely applicable. For example, machine translation can be done with a GAN, as shown below.
The same goes for speech recognition and chat-bots.
On the image side, there is image-to-image translation, colorization, and text-to-image synthesis.
And that is far from all: there are fun applications like face swapping, automatic mosaicing, generating multiple facial expressions, face aging, and more. Plenty of cool applications await your discovery!
How GANs work
The ultimate goal of a GAN is a generator that produces targets good enough to pass for real. But do we really need a GAN for that? Could the generator be trained by itself? Could the discriminator produce targets by itself? Let's look at these two questions first and then dig into GANs proper.
Can the generator train on its own?
The answer is yes. The familiar autoencoder (Auto-Encoder) and variational autoencoder (Variational Auto-Encoder) are both typical generators. The input is encoded into a code by the Encoder, and the Decoder reconstructs the original image from the code; the autoencoder's Decoder is precisely a generator, and feeding it different random codes produces different outputs.
The structure of an autoencoder is shown below:
The structure of a variational autoencoder is shown below:
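To make the decoder-as-generator idea concrete, here is a minimal PyTorch sketch (illustrative only; the 28x28 image size and the layer widths are assumptions, not from this article's code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 32))  # image -> code
decoder = nn.Sequential(nn.Linear(32, 28 * 28), nn.Tanh())     # code -> image

x = torch.randn(16, 1, 28, 28)                     # a batch of "images"
recon = decoder(encoder(x)).view(16, 1, 28, 28)    # reconstruction
loss = F.mse_loss(recon, x)                        # train by minimizing this

z = torch.randn(1, 32)                             # a random code
generated = decoder(z).view(1, 1, 28, 28)          # the decoder alone generates
```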
However, autoencoders have problems. Look at the figure below.
The generator's problem: an autoencoder is trained to make the reconstruction error ever smaller, but as the figure shows, a 1-pixel error can look fine to the autoencoder while looking wrong to us, and a 6-pixel error can look acceptable to us while the autoencoder rejects it. Where an error occurs matters, and the generator by itself has no way to know that: an autoencoder lacks the ability to understand the spatial correlations between pixels. On top of that, autoencoder outputs tend to be blurry; they cannot produce really sharp images, as shown below.
So, for now, a generator alone has a hard time producing really high-quality images.
Can the discriminator train on its own?
The answer is also yes. Given an input, a discriminator outputs a confidence score in [0, 1]: the closer to 1, the more confident it is that the input is real; the closer to 0, the less confident, as shown below:
The discriminator's strength is that it easily captures correlations between elements; the pixel problem that trips up the autoencoder does not trip up a discriminator. As the figure shows, a single convolutional filter takes care of it.
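As a toy illustration of that point (a hand-rolled sketch, not code from this article's repositories), a single 3x3 convolution can flag exactly the isolated-pixel artifact that an element-wise reconstruction loss barely notices:

```python
import torch
import torch.nn.functional as F

# A 3x3 "isolated pixel" detector: it fires when a pixel disagrees with all
# of its neighbours, i.e. when local spatial correlation is violated.
kernel = torch.tensor([[-1., -1., -1.],
                       [-1.,  8., -1.],
                       [-1., -1., -1.]]).view(1, 1, 3, 3)

img = torch.zeros(1, 1, 5, 5)
img[0, 0, 2, 2] = 1.0                          # one stray pixel
response = F.conv2d(img, kernel, padding=1)
print(response[0, 0, 2, 2])                    # large activation flags the artifact
```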
Now, how would a discriminator generate samples by itself? See the figure below:
We first generate negative samples at random and train the discriminator on them together with the real samples. In each iteration we solve a maximization problem to pick the best negative samples found so far, then train on those together with the real samples again. This looks more or less like GAN training and seems fine, but there are problems hiding in it. Look at the next figure:
The discriminator's problem: training pushes scores up on real samples (the green curve in the figure) and down on negatives (blue). To train a good discriminator this way, the sampling procedure would have to cover essentially all fake samples outside the real distribution, so that the discriminator scores high only on the real distribution and low everywhere else. In a high-dimensional space, such negative sampling is very hard to do. Worse, generating a sample means enumerating huge numbers of candidates just to find one that fits the real distribution, then solving that maximization problem to pick the best one, which is far too expensive.
Strengths and weaknesses: generator, discriminator, and GAN
From the discussion above we already have a feel for their strengths and weaknesses; the slide below lays them out side by side:
The generator's and the discriminator's strengths and weaknesses are complementary, and that is exactly the advantage of a GAN (generator + discriminator). The figure below presents GAN's advantages from two angles:
- From the discriminator's point of view, the generator now produces the samples, i.e. it solves the maximization problem for it.
- From the generator's point of view, samples are still produced element by element, but through the discriminator it gains a global view.
Of course, GANs have weaknesses too. A GAN is a latent-variable model, so it is less interpretable than a plain generator or discriminator, and GANs are hard to train: while training my DCGAN I managed to drive the discriminator's loss to 0, at which point no gradients could be backpropagated to update the weights.
The theory behind GANs
The generator's goal is to learn the distribution of the real samples, so that it can randomly generate samples that pass for real, as shown below.
How do we learn the real distribution? With maximum likelihood estimation (Maximum Likelihood Estimation). Look at the figure below first.
We sample data from the real distribution at random and learn the parameter \(\theta\) of \(P(x;\theta)\), hoping to bring \(P(x;\theta)\) as close as possible to \(P_{data}(x)\). Each sampled \(x\) has high probability under \(P_{data}(x)\), so making \(P(x;\theta)\) close to \(P_{data}(x)\) is equivalent to maximizing each \(P_{G}(x^{i};\theta)\), i.e. maximizing \(\prod_{i=1}^{m}P_{G}(x^{i};\theta)\). In fact, maximum likelihood estimation is equivalent to minimizing a KL divergence; the derivation is in the figure below. First take the \(\log\) (monotonically increasing, so the optimum is unchanged) to turn the product into a sum; this becomes the expectation of \(\log P_{G}(x;\theta)\) under \(P_{data}\), which can be written as an integral. Then append the term \(\int_{x}P_{data}(x)\log P_{data}(x)\,dx\), which is a constant containing no \(\theta\) and so does not affect the solution. With that term added, the original problem becomes minimizing the KL divergence between \(P_{data}\) and \(P_{G}\).
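Written out, the chain of equivalences described above is:

\[
\begin{aligned}
\theta^{*} &= \arg\max_{\theta}\prod_{i=1}^{m}P_{G}(x^{i};\theta)
            = \arg\max_{\theta}\sum_{i=1}^{m}\log P_{G}(x^{i};\theta)
            \approx \arg\max_{\theta}\,\mathbb{E}_{x\sim P_{data}}\big[\log P_{G}(x;\theta)\big] \\
           &= \arg\max_{\theta}\left[\int_{x}P_{data}(x)\log P_{G}(x;\theta)\,dx-\int_{x}P_{data}(x)\log P_{data}(x)\,dx\right]
            = \arg\min_{\theta}\,\mathrm{KL}\big(P_{data}\,\Vert\,P_{G}\big)
\end{aligned}
\]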
We now know that the generator's job is \(\arg\min_{G}\mathrm{Div}(P_{data},P_{G})\), where \(P_{G}\) is what we optimize. But even though we have real samples, we do not know the distribution \(P_{G}\) in closed form, and we also do not know how to compute the divergence \(\mathrm{Div}(P_{data},P_{G})\) quantitatively. This is where the discriminator comes in.
Although we do not know the distributions \(P_G\) and \(P_{data}\) themselves, we can sample from both, as shown below:
And we know the discriminator's goal is to reward real samples and penalize fake ones. As the figure below shows, this yields the objective function the discriminator optimizes; the discriminator wants to maximize it, i.e. \(\arg\max_{D}V(D,G)\). Note that \(G\) is held fixed here.
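Concretely, the objective in the figure is the standard GAN value function:

\[
V(G,D)=\mathbb{E}_{x\sim P_{data}}\big[\log D(x)\big]+\mathbb{E}_{x\sim P_{G}}\big[\log\big(1-D(x)\big)\big]
\]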
Now let's solve this problem for the optimal \(D^{*}\). The next steps are fairly mathematical: given the objective function, find its maximum. The details are in the figure below.
The derivation is quite detailed, and the punchline is remarkable: maximizing \(V(D,G)\) turns out to equal a constant plus the JS divergence between \(P_{G}\) and \(P_{data}\) (the JS divergence is similar to the KL divergence and does not change the solution). This is exactly the quantity we wanted in the generator section but could not compute; the discriminator computes it for us.
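For reference, the result of that derivation (a standard result from the original GAN paper) is:

\[
D^{*}(x)=\frac{P_{data}(x)}{P_{data}(x)+P_{G}(x)},\qquad
\max_{D}V(G,D)=-2\log 2+2\,\mathrm{JSD}\big(P_{data}\,\Vert\,P_{G}\big)
\]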
So the generator's original optimization problem \(\arg\min_{G}\mathrm{Div}(P_{G},P_{data})\) can be turned into \(\arg\min_{G}\max_{D}V(G,D)\). How do we solve this min-max problem? The figure above already gives the answer: fix one player and optimize the other, then swap. Concretely, as shown below:
The full practical procedure (i.e. how a GAN is actually trained) is shown below; after the chain of explanations above, it should now be clear why GANs are trained this way.
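As a compact summary, here is a minimal sketch of that alternating procedure in PyTorch (`D`, `G`, `opt_D`, `opt_G`, `real_loader` and `z_dim` are placeholders, not objects from this article's repositories; the full DCGAN implementation follows below):

```python
import torch

def train_one_epoch(D, G, opt_D, opt_G, real_loader, z_dim=100, k=1):
    eps = 1e-8                                         # numerical safety for log
    for real in real_loader:
        for _ in range(k):                             # k discriminator steps
            z = torch.randn(real.size(0), z_dim, 1, 1)
            fake = G(z).detach()                       # hold G fixed while updating D
            d_loss = -(torch.log(D(real) + eps) +
                       torch.log(1 - D(fake) + eps)).mean()
            opt_D.zero_grad(); d_loss.backward(); opt_D.step()
        z = torch.randn(real.size(0), z_dim, 1, 1)     # then one generator step
        g_loss = torch.log(1 - D(G(z)) + eps).mean()   # hold D fixed while updating G
        opt_G.zero_grad(); g_loss.backward(); opt_G.step()
```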
That wraps up the theory of GANs.
Implementing DCGAN
The dataset used here is the Anime face dataset from Prof. Hung-yi Lee's GAN course at NTU; click the link to download it. First, let's look at the DCGAN architecture, shown below.
This is the generator's architecture; the discriminator's is roughly its mirror image. DCGAN differs from a vanilla GAN in a few ways:
- Both networks in DCGAN are fully convolutional
- The generator uses batchnorm in every layer except the last; the discriminator uses batchnorm in every layer except the first
- The discriminator uses leaky_relu activations with a negative slope of 0.2
- The generator uses relu activations, with tanh on the output layer
- Adam is used as the optimizer, with learning rate 0.0002 and beta1=0.5
model.py
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class Generate(nn.Module):
def __init__(self, input_dim=100):
super(Generate, self).__init__()
channel = [512, 256, 128, 64, 3]
kernel_size = 4
stride = 2
padding = 1
        self.convtrans1_block = self.__convtrans_block(input_dim, channel[0], 6, padding=0, stride=stride)
        self.convtrans2_block = self.__convtrans_block(channel[0], channel[1], kernel_size, padding, stride)
        self.convtrans3_block = self.__convtrans_block(channel[1], channel[2], kernel_size, padding, stride)
        self.convtrans4_block = self.__convtrans_block(channel[2], channel[3], kernel_size, padding, stride)
        self.convtrans5_block = self.__convtrans_block(channel[3], channel[4], kernel_size, padding, stride, layer="last_layer")
    def __convtrans_block(self, in_channel, out_channel, kernel_size, padding, stride, layer=None):
if layer == "last_layer":
convtrans = nn.ConvTranspose2d(in_channel, out_channel, kernel_size, stride, padding, bias=False)
tanh = nn.Tanh()
return nn.Sequential(convtrans, tanh)
else:
convtrans = nn.ConvTranspose2d(in_channel, out_channel, kernel_size, stride, padding, bias=False)
batch_norm = nn.BatchNorm2d(out_channel)
relu = nn.ReLU(True)
return nn.Sequential(convtrans, batch_norm, relu)
def forward(self, inp):
x = self.convtrans1_block(inp)
x = self.convtrans2_block(x)
x = self.convtrans3_block(x)
x = self.convtrans4_block(x)
x = self.convtrans5_block(x)
return x
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator,self).__init__()
channels = [3, 64, 128, 256, 512]
kernel_size = 4
stride = 2
padding = 1
        self.conv_block1 = self.__conv_block(channels[0], channels[1], kernel_size, stride, padding, "first_layer")
        self.conv_block2 = self.__conv_block(channels[1], channels[2], kernel_size, stride, padding)
        self.conv_block3 = self.__conv_block(channels[2], channels[3], kernel_size, stride, padding)
        self.conv_block4 = self.__conv_block(channels[3], channels[4], kernel_size, stride, padding)
        self.conv_block5 = self.__conv_block(channels[4], 1, kernel_size+1, stride, 0, "last_layer")
def __conv_block(self, inchannel, outchannel, kernel_size, stride, padding, layer=None):
if layer == "first_layer":
conv = nn.Conv2d(inchannel, outchannel, kernel_size, stride, padding, bias=False)
leakrelu = nn.LeakyReLU(0.2, inplace=True)
return nn.Sequential(conv, leakrelu)
elif layer == "last_layer":
conv = nn.Conv2d(inchannel, outchannel, kernel_size, stride, padding, bias=False)
sigmoid = nn.Sigmoid()
return nn.Sequential(conv, sigmoid)
else:
conv = nn.Conv2d(inchannel, outchannel, kernel_size, stride, padding, bias=False)
batchnorm = nn.BatchNorm2d(outchannel)
leakrelu = nn.LeakyReLU(0.2, inplace=True)
return nn.Sequential(conv, batchnorm, leakrelu)
def forward(self,inp):
        x = self.conv_block1(inp)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        x = self.conv_block5(x)
return x
def weight_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0,0.01)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0,0.01)
m.bias.data.fill_(0)
if __name__ == "__main__":
model1 = Generate()
x = torch.randn(10,100,1,1)
y = model1.forward(x)
print(y.size())
model2 = Discriminator()
a = torch.randn(10,3,96,96)
b = model2.forward(a)
print(b.size())
```
AnimeDataset.py
```python
import torch,torch.utils.data
import numpy as np
import scipy.misc, os
class AnimeDataset(torch.utils.data.Dataset):
def __init__(self, directory, dataset, size_per_dataset):
self.directory = directory
self.dataset = dataset
self.size_per_dataset = size_per_dataset
self.data_files = []
data_path = os.path.join(directory, dataset)
for i in range(size_per_dataset):
self.data_files.append(os.path.join(data_path,"{}.jpg".format(i)))
def __getitem__(self, ind):
path = self.data_files[ind]
img = scipy.misc.imread(path)
        img = (img.transpose(2, 0, 1) - 127.5) / 127.5  # HWC -> CHW, scale to [-1, 1]
return img
def __len__(self):
return len(self.data_files)
if __name__ == "__main__":
dataset = AnimeDataset(os.getcwd(),"faces",100)
loader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True,num_workers=4)
for i, inp in enumerate(loader):
print(i,inp.size())
```
utils.py
```python
import os, imageio,scipy.misc
import matplotlib.pyplot as plt
def creat_gif(gif_name, img_path, duration=0.3):
frames = []
    img_names = sorted(os.listdir(img_path), key=lambda name: int(name.split('.')[0]))  # order frames by epoch
img_list = [os.path.join(img_path, img_name) for img_name in img_names]
for img_name in img_list:
frames.append(imageio.imread(img_name))
imageio.mimsave(gif_name, frames, 'GIF', duration=duration)
def visualize_loss(generate_txt_path, discriminator_txt_path):
with open(generate_txt_path, 'r') as f:
G_list_str = f.readlines()
with open(discriminator_txt_path, 'r') as f:
D_list_str = f.readlines()
D_list_float, G_list_float = [], []
for D_item, G_item in zip(D_list_str, G_list_str):
D_list_float.append(float(D_item.strip().split(':')[-1]))
G_list_float.append(float(G_item.strip().split(':')[-1]))
list_epoch = list(range(len(D_list_float)))
full_path = os.path.join(os.getcwd(), "saved/logging.png")
plt.figure()
plt.plot(list_epoch, G_list_float, label="generate", color='g')
plt.plot(list_epoch, D_list_float, label="discriminator", color='b')
plt.legend()
plt.title("DCGAN_Anime")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.savefig(full_path)
```
main.py
```python
import torch
import torch.nn as nn
from torch.optim import Adam
from torchvision.utils import make_grid
from model import Generate,Discriminator,weight_init
from AnimeDataset import AnimeDataset
import matplotlib.pyplot as plt
import numpy as np
import scipy.misc
import os, argparse
from tqdm import tqdm
from utils import creat_gif, visualize_loss
def main():
parse = argparse.ArgumentParser()
parse.add_argument("--lr", type=float, default=0.0001,
help="learning rate of generate and discriminator")
parse.add_argument("--beta1", type=float, default=0.5,
help="adam optimizer parameter")
parse.add_argument("--batch_size", type=int, default=64,
help="number of dataset in every train or test iteration")
parse.add_argument("--dataset", type=str, default="anime",
help="base path for dataset")
parse.add_argument("--epochs", type=int, default=500,
help="number of training epochs")
parse.add_argument("--loaders", type=int, default=4,
help="number of parallel data loading processing")
parse.add_argument("--size_per_dataset", type=int, default=30000,
help="number of training data")
parse.add_argument("--pre_train", type=bool, default=False,
help="whether load pre_train model")
args = parse.parse_args()
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
if not os.path.exists("saved"):
os.mkdir("saved")
if not os.path.exists("saved/img"):
os.mkdir("saved/img")
if os.path.exists("faces"):
pass
else:
print("Don't find the dataset directory, please copy the link in website ,download and extract faces.tar.gz .\n \
https://drive.google.com/drive/folders/1mCsY5LEsgCnc0Txv0rpAUhKVPWVkbw5I \n ")
exit()
if args.pre_train:
generate = torch.load("saved/generate.t7").to(device)
discriminator = torch.load("saved/discriminator.t7").to(device)
else:
generate = Generate().to(device)
discriminator = Discriminator().to(device)
generate.apply(weight_init)
discriminator.apply(weight_init)
dataset = AnimeDataset(os.getcwd(), args.dataset, args.size_per_dataset)
dataload = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
criterion = nn.BCELoss().to(device)
optimizer_G = Adam(generate.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
optimizer_D = Adam(discriminator.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
fixed_noise = torch.randn(64, 100, 1, 1).to(device)
for epoch in range(args.epochs):
print("Main epoch{}:".format(epoch))
progress = tqdm(total=len(dataload.dataset))
loss_d, loss_g = 0, 0
for i, inp in enumerate(dataload):
# train discriminator
real_data = inp.float().to(device)
real_label = torch.ones(inp.size()[0]).to(device)
noise = torch.randn(inp.size()[0], 100, 1, 1).to(device)
fake_data = generate(noise)
fake_label = torch.zeros(fake_data.size()[0]).to(device)
optimizer_D.zero_grad()
real_output = discriminator(real_data)
real_loss = criterion(real_output.squeeze(), real_label)
real_loss.backward()
fake_output = discriminator(fake_data)
fake_loss = criterion(fake_output.squeeze(), fake_label)
fake_loss.backward()
loss_D = real_loss + fake_loss
optimizer_D.step()
#train generate
optimizer_G.zero_grad()
fake_data = generate(noise)
fake_label = torch.ones(fake_data.size()[0]).to(device)
fake_output = discriminator(fake_data)
loss_G = criterion(fake_output.squeeze(), fake_label)
loss_G.backward()
optimizer_G.step()
progress.update(dataload.batch_size)
progress.set_description("D:{}, G:{}".format(loss_D.item(), loss_G.item()))
loss_g += loss_G.item()
loss_d += loss_D.item()
loss_g /= (i+1)
loss_d /= (i+1)
with open("generate_loss.txt", 'a+') as f:
f.write("loss_G:{} \n".format(loss_G.item()))
with open("discriminator_loss.txt", 'a+') as f:
f.write("loss_D:{} \n".format(loss_D.item()))
if epoch % 20 == 0:
torch.save(generate, os.path.join(os.getcwd(), "saved/generate.t7"))
torch.save(discriminator, os.path.join(os.getcwd(), "saved/discriminator.t7"))
img = generate(fixed_noise).to("cpu").detach().numpy()
display_grid = np.zeros((8*96,8*96,3))
for j in range(int(64/8)):
for k in range(int(64/8)):
display_grid[j*96:(j+1)*96,k*96:(k+1)*96,:] = (img[k+8*j].transpose(1, 2, 0)+1)/2
img_save_path = os.path.join(os.getcwd(),"saved/img/{}.png".format(epoch))
scipy.misc.imsave(img_save_path, display_grid)
creat_gif("evolution.gif", os.path.join(os.getcwd(),"saved/img"))
visualize_loss("generate_loss.txt", "discriminator_loss.txt")
if __name__ == "__main__":
main()
```
To run the code, see the readme on GitHub. The results after 500 epochs are shown below.
Implementing WGAN
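Before the code, a quick recap of what WGAN changes relative to the vanilla GAN (this summary follows the WGAN paper, not anything specific to the implementation below): the discriminator becomes a critic with no sigmoid output, trained to maximize

\[
\mathbb{E}_{x\sim P_{data}}\big[D(x)\big]-\mathbb{E}_{x\sim P_{G}}\big[D(x)\big]
\]

the logs are dropped from both losses, the critic's weights are clipped into a small box such as \([-0.01, 0.01]\) after every update to roughly enforce a Lipschitz constraint, and a momentum-free optimizer such as RMSprop with a small learning rate is recommended. All of these show up in the Keras code.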
My PyTorch version of WGAN has had a bug I haven't tracked down yet, so I implemented a Keras version instead; the code is below (read the readme before running it):
```python
import os, scipy.misc
import keras.backend as K
from keras.models import Sequential, Model
from keras.layers import Conv2D, ZeroPadding2D, BatchNormalization, Input
from keras.layers import Conv2DTranspose, Reshape, Activation, Cropping2D, Flatten
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import RMSprop
from keras.activations import relu
from keras.initializers import RandomNormal
from keras.preprocessing.image import ImageDataGenerator
import numpy as np

conv_init = RandomNormal(0, 0.02)
gamma_init = RandomNormal(1., 0.02)
os.environ['KERAS_BACKEND']='tensorflow'
os.environ['TENSORFLOW_FLAGS']='floatX=float32,device=cuda'
def DCGAN_D(isize, nc, ndf):
inputs = Input(shape=(isize, isize, nc))
x = ZeroPadding2D()(inputs)
x = Conv2D(ndf, kernel_size=4, strides=2, use_bias=False, kernel_initializer=conv_init)(x)
x = LeakyReLU(alpha=0.2)(x)
for _ in range(4):
x = ZeroPadding2D()(x)
x = Conv2D(ndf*2, kernel_size=4, strides=2, use_bias=False, kernel_initializer=conv_init)(x)
x = BatchNormalization(epsilon=1.01e-5, gamma_init=gamma_init)(x, training=1)
x = LeakyReLU(alpha=0.2)(x)
ndf *= 2
x = Conv2D(1, kernel_size=3, strides=1, use_bias=False, kernel_initializer=conv_init)(x)
outputs = Flatten()(x)
return Model(inputs=inputs, outputs=outputs)
def DCGAN_G(isize, nz, ngf):
inputs = Input(shape=(nz,))
x = Reshape((1, 1, nz))(inputs)
x = Conv2DTranspose(filters=ngf, kernel_size=3, strides=2, use_bias=False,
kernel_initializer = conv_init)(x)
for _ in range(4):
x = Conv2DTranspose(filters=int(ngf/2), kernel_size=4, strides=2, use_bias=False,
kernel_initializer = conv_init)(x)
x = Cropping2D(cropping=1)(x)
x = BatchNormalization(epsilon=1.01e-5, gamma_init=gamma_init)(x, training=1)
x = Activation("relu")(x)
ngf = int(ngf/2)
x = Conv2DTranspose(filters=3, kernel_size=4, strides=2, use_bias=False,
kernel_initializer = conv_init)(x)
x = Cropping2D(cropping=1)(x)
outputs = Activation("tanh")(x)
return Model(inputs=inputs, outputs=outputs)
nc = 3
nz = 100
ngf = 1024
ndf = 64
imageSize = 96
batchSize = 64
lrD = 0.00005
lrG = 0.00005
clamp_lower, clamp_upper = -0.01, 0.01
netD = DCGAN_D(imageSize, nc, ndf)
netD.summary()
netG = DCGAN_G(imageSize, nz, ngf)
netG.summary()
clamp_updates = [K.update(v, K.clip(v, clamp_lower, clamp_upper))
for v in netD.trainable_weights]
netD_clamp = K.function([],[], clamp_updates)
netD_real_input = Input(shape=(imageSize, imageSize, nc))
noisev = Input(shape=(nz,))
loss_real = K.mean(netD(netD_real_input))
loss_fake = K.mean(netD(netG(noisev)))
loss = loss_fake - loss_real
training_updates = RMSprop(lr=lrD).get_updates(netD.trainable_weights,[], loss)
netD_train = K.function([netD_real_input, noisev],
[loss_real, loss_fake],
training_updates)
loss = -loss_fake
training_updates = RMSprop(lr=lrG).get_updates(netG.trainable_weights,[], loss)
netG_train = K.function([noisev], [loss], training_updates)
fixed_noise = np.random.normal(size=(batchSize, nz)).astype('float32')
datagen = ImageDataGenerator(
# featurewise_center=True,
# featurewise_std_normalization=True,
rotation_range=20,
rescale=1./255
)
train_generate = datagen.flow_from_directory("faces/", target_size=(96,96), batch_size=64,
shuffle=True, class_mode=None, save_format='jpg')
for step in range(100000):
for _ in range(5):
real_data = (np.array(train_generate.next())*2-1)
noise = np.random.normal(size=(batchSize, nz))
errD_real, errD_fake = netD_train([real_data, noise])
errD = errD_real - errD_fake
netD_clamp([])
noise = np.random.normal(size=(batchSize, nz))
errG, = netG_train([noise])
print('[%d] Loss_D: %f Loss_G: %f Loss_D_real: %f Loss_D_fake %f' % (step, errD, errG, errD_real, errD_fake))
if step%1000==0:
netD.save("discriminator.h5")
netG.save("generate.h5")
fake = netG.predict(fixed_noise)
display_grid = np.zeros((8*96,8*96,3))
for j in range(int(64/8)):
for k in range(int(64/8)):
display_grid[j*96:(j+1)*96,k*96:(k+1)*96,:] = fake[k+8*j]
img_save_path = os.path.join(os.getcwd(),"saved/img/{}.png".format(step))
scipy.misc.imsave(img_save_path, display_grid)
```
To run the code, see the readme on GitHub. The result after 100,000 steps (wgan_keras_result.png):
<h2>
<a id="F">
實現ConditionalGAN
</a>
</h2>
For details on how to run it, see the readme on GitHub.
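The conditional part means both networks see the label: the generator receives \((z, y)\) and the discriminator scores \((x, y)\) pairs, so the value function becomes

\[
V(G,D)=\mathbb{E}_{x,y\sim P_{data}}\big[\log D(x,y)\big]+\mathbb{E}_{z\sim P_{z},\,y\sim P_{data}}\big[\log\big(1-D(G(z,y),y)\big)\big]
\]

In the code below, \(y\) is a 23-dimensional hair/eye-color tag vector; it is concatenated with \(z\) in the generator and tiled across the feature map in the discriminator.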
<h3>
<a id="F1">
CGAN.py
</a>
</h3>
```python
import torch,os,scipy.misc,random
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.optim import Adam
from utils import load_Anime,test_Anime
class Reshape(nn.Module):
def __init__(self, *args):
super(Reshape, self).__init__()
self.shape = args
def forward(self, x):
return x.view(self.shape)
class Generate(nn.Module):
def __init__(self, z_dim, y_dim, image_height, image_width):
super(Generate, self).__init__()
self.conv_trans = nn.Sequential(
nn.Linear(z_dim+y_dim, (image_height//16)*(image_width//16)*384),
nn.BatchNorm1d((image_height//16)*(image_width//16)*384,
eps=1e-5, momentum=0.9, affine=True),
Reshape(-1, 384, image_height//16, image_width//16),
nn.ConvTranspose2d(384, 256, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(256, eps=1e-5, momentum=0.9, affine=True),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(128, eps=1e-5, momentum=0.9, affine=True),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(64, eps=1e-5, momentum=0.9, affine=True),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
nn.Tanh()
)
def forward(self, z, y):
z = torch.cat((z,y), dim=-1)
z = self.conv_trans(z)
return z
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(256, 384, kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
)
self.conv1 = nn.Sequential(
nn.Conv2d(407, 384, kernel_size=1, stride=1, padding=0, bias=False),
nn.LeakyReLU(0.2, inplace=True)
)
self.linear = nn.Linear(4*4*384, 1)
def forward(self, x, y):
x = self.conv(x)
y = torch.unsqueeze(y, 2)
y = torch.unsqueeze(y, 3)
y = y.expand(y.size()[0], y.size()[1], x.size()[2], x.size()[3])
x = torch.cat((x,y), dim=1)
x = self.conv1(x)
x = x.view(x.size()[0], -1)
x = self.linear(x)
x = x.squeeze()
x = F.sigmoid(x)
return x
def weight_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0,0.01)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0,0.01)
m.bias.data.fill_(0)
class CGAN(object):
def __init__(self, dataset_path, save_path, epochs, batchsize, z_dim, device, mode):
self.dataset_path = dataset_path
self.save_path = save_path
self.epochs = epochs
self.batch_size = batchsize
self.mode = mode
self.image_height = 64
self.image_width = 64
self.learning_rate = 0.0001
self.z_dim = z_dim
self.y_dim = 23
self.iters_d = 2
self.iters_g = 1
self.device = device
self.criterion = nn.BCELoss().to(device)
if mode == "train":
self.X, self.Y = load_Anime(self.dataset_path)
self.batch_nums = len(self.X)//self.batch_size
def train(self):
generate = Generate(self.z_dim, self.y_dim, self.image_height, self.image_width).to(self.device)
discriminator = Discriminator().to(self.device)
generate.apply(weight_init)
discriminator.apply(weight_init)
optimizer_G = Adam(generate.parameters(), lr=self.learning_rate)
optimizer_D = Adam(discriminator.parameters(), lr=self.learning_rate)
step = 0
for epoch in range(self.epochs):
print("Main epoch:{}".format(epoch))
for i in range(self.batch_nums):
step += 1
batch_images = torch.from_numpy(np.asarray(self.X[i*self.batch_size:(i+1)*self.batch_size]).astype(np.float32)).to(self.device)
batch_labels = torch.from_numpy(np.asarray(self.Y[i*self.batch_size:(i+1)*self.batch_size]).astype(np.float32)).to(self.device)
batch_images_wrong = torch.from_numpy(np.asarray(self.X[random.sample(range(len(self.X)), len(batch_images))]).astype(np.float32)).to(self.device)
batch_labels_wrong = torch.from_numpy(np.asarray(self.Y[random.sample(range(len(self.Y)), len(batch_images))]).astype(np.float32)).to(self.device)
batch_z = torch.from_numpy(np.random.normal(0, np.exp(-1 / np.pi), [self.batch_size, self.z_dim]).astype(np.float32)).to(self.device)
# discriminator twice, generate once
for _ in range(self.iters_d):
optimizer_D.zero_grad()
d_loss_real = self.criterion(discriminator(batch_images, batch_labels), torch.ones(self.batch_size).to(self.device))
d_loss_fake = (self.criterion(discriminator(batch_images, batch_labels_wrong), torch.zeros(self.batch_size).to(self.device)) \
+ self.criterion(discriminator(batch_images_wrong, batch_labels), torch.zeros(self.batch_size).to(self.device)) \
+ self.criterion(discriminator(generate(batch_z, batch_labels), batch_labels), torch.zeros(self.batch_size).to(self.device)))/3
d_loss = d_loss_real + d_loss_fake
d_loss.backward()
optimizer_D.step()
for _ in range(self.iters_g):
optimizer_G.zero_grad()
g_loss = self.criterion(discriminator(generate(batch_z, batch_labels), batch_labels), torch.ones(self.batch_size).to(self.device))
g_loss.backward()
optimizer_G.step()
print("epoch:{}, step:{}, d_loss:{}, g_loss:{}".format(epoch, step, d_loss.item(), g_loss.item()))
#show result and save model
if (step)%5000 == 0:
z, y = test_Anime()
image = generate(torch.from_numpy(z).float().to(self.device),torch.from_numpy(y).float().to(self.device)).to("cpu").detach().numpy()
display_grid = np.zeros((5*64,5*64,3))
for j in range(5):
for k in range(5):
display_grid[j*64:(j+1)*64,k*64:(k+1)*64,:] = image[k+5*j].transpose(1, 2, 0)
img_save_path = os.path.join(self.save_path,"training_img/{}.png".format(step))
scipy.misc.imsave(img_save_path, display_grid)
torch.save(generate, os.path.join(self.save_path, "generate.t7"))
torch.save(discriminator, os.path.join(self.save_path, "discriminator.t7"))
def infer(self):
z, y = test_Anime()
generate = torch.load(os.path.join(self.save_path, "generate.t7")).to(self.device)
image = generate(torch.from_numpy(z).float().to(self.device),torch.from_numpy(y).float().to(self.device)).to("cpu").detach().numpy()
display_grid = np.zeros((5*64,5*64,3))
for j in range(5):
for k in range(5):
display_grid[j*64:(j+1)*64,k*64:(k+1)*64,:] = image[k+5*j].transpose(1, 2, 0)
img_save_path = os.path.join(self.save_path,"testing_img/test.png")
scipy.misc.imsave(img_save_path, display_grid)
print("infer ended, look the result in the save/testing_img/")
```
utils.py
```python
# most code from https://github.com/JasonYao81000/MLDS2018SPRING/blob/master/hw3/hw3_2/
import numpy as np
import cv2
import os

def test_Anime():
np.random.seed(999)
z = np.random.normal(0, np.exp(-1 / np.pi), [25, 62])
tag_dict = ['orange hair', 'white hair', 'aqua hair', 'gray hair', 'green hair', 'red hair', 'purple hair',
'pink hair', 'blue hair', 'black hair', 'brown hair', 'blonde hair',
'gray eyes', 'black eyes', 'orange eyes', 'pink eyes', 'yellow eyes',
'aqua eyes', 'purple eyes', 'green eyes', 'brown eyes', 'red eyes', 'blue eyes']
tag_txt = open("test.txt", 'r').readlines()
labels = []
for line in tag_txt:
label = np.zeros(len(tag_dict))
for i in range(len(tag_dict)):
if tag_dict[i] in line:
label[i] = 1
labels.append(label)
for i in range(len(tag_txt)):
for j in range(4):
labels.insert(5*i+j, labels[5*i])
return z, np.array(labels)
def load_Anime(dataset_filepath):
tag_csv_filename = dataset_filepath.replace('images/', 'tags.csv')
tag_dict = ['orange hair', 'white hair', 'aqua hair', 'gray hair', 'green hair', 'red hair', 'purple hair',
'pink hair', 'blue hair', 'black hair', 'brown hair', 'blonde hair',
'gray eyes', 'black eyes', 'orange eyes', 'pink eyes', 'yellow eyes',
'aqua eyes', 'purple eyes', 'green eyes', 'brown eyes', 'red eyes', 'blue eyes']
tag_csv = open(tag_csv_filename, 'r').readlines()
id_label = []
for line in tag_csv:
id, tags = line.split(',')
label = np.zeros(len(tag_dict))
for i in range(len(tag_dict)):
if tag_dict[i] in tags:
label[i] = 1
# Keep images with hair or eyes.
if np.sum(label) == 2 or np.sum(label) == 1:
id_label.append((id, label))
# Load file name of images.
image_file_list = []
for image_id, _ in id_label:
image_file_list.append(image_id + '.jpg')
# Resize image to 64x64.
image_height = 64
image_width = 64
image_channel = 3
# Allocate memory space of images and labels.
images = np.zeros((len(image_file_list), image_channel, image_width, image_height))
labels = np.zeros((len(image_file_list), len(tag_dict)))
print ('images.shape: ', images.shape)
print ('labels.shape: ', labels.shape)
print ('Loading images to numpy array...')
data_dir = dataset_filepath
for index, filename in enumerate(image_file_list):
images[index] = cv2.cvtColor(
cv2.resize(
cv2.imread(os.path.join(data_dir, filename), cv2.IMREAD_COLOR),
(image_width, image_height)),
cv2.COLOR_BGR2RGB).transpose(2,0,1)
        labels[index] = id_label[index][1]  # second element of the (id, label) tuple
print ('Random shuffling images and labels...')
np.random.seed(9487)
indice = np.array(range(len(image_file_list)))
np.random.shuffle(indice)
images = images[indice]
labels = labels[indice]
print ('[Tip 1] Normalize the images between -1 and 1.')
# Tip 1. Normalize the inputs
# Normalize the images between -1 and 1.
# Tanh as the last layer of the generator output.
return (images / 127.5) - 1, labels
def check_folder(log_dir):
if not os.path.exists(log_dir):
os.makedirs(log_dir)
<h3>
<a id="F3">
main.py
</a>
</h3>
```python
import argparse
from CGAN import CGAN
from utils import check_folder
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=100, help='The number of epochs to run')
parser.add_argument('--batch_size', type=int, default=64, help='The size of batch')
parser.add_argument('--z_dim', type=int, default=62, help='Dimension of noise vector')
parser.add_argument('--dataset_path', type=str, default='./images/',
help='Directory name to save the checkpoints')
parser.add_argument('--save_path', type=str, default='./save/',
help='Directory name to save the generated images')
parser.add_argument('--mode', type=str, default='train',
help='train or infer')
parser.add_argument('--device', type=str, default='cuda',
help='train on GPU or CPU')
parser.add_argument('--save_training_img_path', type=str, default='./save/training_img/',
help='Directory name to save the training images')
parser.add_argument('--save_testing_img_path', type=str, default='./save/testing_img/',
help='Directory name to save the training images')
return parser.parse_args()
def main():
args = parse_args()
check_folder(args.dataset_path)
check_folder(args.save_path)
gan = CGAN(args.dataset_path,
args.save_path,
args.epochs,
args.batch_size,
args.z_dim,
args.device,
args.mode)
if args.mode == "train":
check_folder(args.save_training_img_path)
gan.train()
else:
check_folder(args.save_testing_img_path)
gan.infer()
if __name__ == "__main__":
main()
```
The result after 55,000 steps:
GAN training tips
1. Normalize the real images to the same range as the generated ones, i.e. [-1, 1].
2. Use Gaussian noise rather than uniform noise for the latent vector: in code, torch.randn, not torch.rand.
3. Initializing the weights properly matters; see the weight_init function in model.py.
4. The noise used for the discriminator step should be the same noise the generator is then updated with; this matters. At first I sampled fresh noise for the discriminator and fresh noise for the generator, and training went badly.
5. During training, the discriminator's loss can easily hit 0, meaning the discriminator has become too strong (I tried lowering its learning rate and it still happened; my guess is that on some batch the discriminator happens to separate the generated images from the real ones perfectly, giving a loss of 0), which makes the generator collapse (vanishing gradients). So save the model every N epochs and resume from a checkpoint when it happens. I suspect data augmentation and a larger batch size would make this less likely, but I haven't tried it.
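A minimal sketch of tips 1 and 2 (illustrative values, not code from the repositories above):

```python
import torch

# Tip 1: scale images from [0, 255] to [-1, 1], matching the generator's tanh output.
imgs = torch.randint(0, 256, (64, 3, 96, 96)).float()
imgs = (imgs - 127.5) / 127.5

# Tip 2: sample the latent vector from a Gaussian, not a uniform distribution.
noise = torch.randn(64, 100, 1, 1)   # torch.randn, not torch.rand
```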
References
1. Hung-yi Lee's GAN course and slides
2. The DCGAN paper
3. chenyuntc