推薦的幾個開源實現
- znxlwm 使用InfoGAN的結構,卷積反卷積
- eriklindernoren 把mnist轉成1維,label用了embedding
- wiseodd 直接從tensorflow代碼轉換過來的,數據集居然還用tf的數據集。。
- Yangyangii 轉1維向量,全連接
- FangYang970206 提供了多標簽作為條件的實現思路
- znxlwm 專門針對MNIST數據集的一個實現方法,轉1維,比較接近原paper的實現方法
訓練過程
- 簡述
# z - 隨機噪聲
# X - 輸入數據
# c - 輸入的label
# ===== 訓練判別器D =====
# 真數據輸入到D中
D_real = D(X, c)
# 真數據D的判斷結果應盡可能接近1
D_loss_real = nn.binary_cross_entropy(D_real, ones_label)
# 生成隨機噪聲
z = torch.rand((batch_size, self.z_dim))
# G生成的偽數據,這一步的c可以用已知的,也可以重新隨機生成一些label,但總之這些c所生成的數據都是偽的
G_sample = G(z, c)
# 偽數據輸入到D中
D_fake = D(G_sample , c)
# 偽數據D的判斷結果應盡可能接近0
D_loss_fake = nn.binary_cross_entropy(D_fake, zeros_label)
# D的loss定義為上面兩部分之和,即真數據要盡可能接近1,偽數據要盡可能接近0
D_loss = D_loss_real + D_loss_fake
# 更新D的參數
D_loss.backward()
D_solver.step()
# 在訓練G之前把梯度清零,也可以不這么做
reset_grad()
# ===== 訓練生成器G =====
# 這里可以選擇,有的實現是直接用上面的z
z = Variable(torch.randn(mb_size, Z_dim))
# 這里可以選擇用已知的c,或者重新采樣
c = 重新隨機一些label
# 用G生成偽數據
G_sample = G(z, c)
# 偽數據輸入到D中
D_fake = D(G_sample, c)
# 此時計算的是G的Loss,偽數據D的判斷結果應盡可能接近1,因為G要試圖騙過D
G_loss = nn.binary_cross_entropy(D_fake, ones_label)
# 更新G的參數
G_loss.backward()
G_solver.step()
一些坑
- 計算D和G的loss時最好分別用不同的隨機噪聲,否則有可能訓練過程不會收斂,而且結果差
- 注意,訓練的時候隨機噪聲的分布應該要保持和測試時的分布一致,不要一個用均勻分布,一個用正態分布
初步結果
哈哈哈看到終於訓練出來像樣的數字,還是有點小成就的