pytorch conditional GAN 調試筆記


推薦的幾個開源實現

  1. znxlwm 使用InfoGAN的結構,卷積反卷積
  2. eriklindernoren 把mnist轉成1維,label用了embedding
  3. wiseodd 直接從tensorflow代碼轉換過來的,數據集居然還用tf的數據集。。
  4. Yangyangii 轉1維向量,全連接
  5. FangYang970206 提供了多標簽作為條件的實現思路
  6. znxlwm 專門針對MNIST數據集的一個實現方法,轉1維,比較接近原paper的實現方法

訓練過程

  • 簡述
    # z - 隨機噪聲
    # X - 輸入數據
    # c - 輸入的label

    # ===== 訓練判別器D =====

    # 真數據輸入到D中
    D_real = D(X, c) 
    # 真數據D的判斷結果應盡可能接近1  
    D_loss_real = nn.binary_cross_entropy(D_real, ones_label)   

    # 生成隨機噪聲
    z = torch.rand((batch_size, self.z_dim)) 
    # G生成的偽數據,這一步的c可以用已知的,也可以重新隨機生成一些label,但總之這些c所生成的數據都是偽的
    G_sample = G(z, c)  
    # 偽數據輸入到D中
    D_fake = D(G_sample , c)    
    # 偽數據D的判斷結果應盡可能接近0
    D_loss_fake = nn.binary_cross_entropy(D_fake, zeros_label)     
   
    # D的loss定義為上面兩部分之和,即真數據要盡可能接近1,偽數據要盡可能接近0
    D_loss = D_loss_real + D_loss_fake 

    # 更新D的參數
    D_loss.backward()
    D_solver.step()

    # 在訓練G之前把梯度清零,也可以不這么做
    reset_grad()
    
    # ===== 訓練生成器G =====

    # 這里可以選擇,有的實現是直接用上面的z
    z = Variable(torch.randn(mb_size, Z_dim))  
    # 這里可以選擇用已知的c,或者重新采樣
    c = 重新隨機一些label  
    # 用G生成偽數據
    G_sample = G(z, c) 
    # 偽數據輸入到D中              
    D_fake = D(G_sample, c)     
    # 此時計算的是G的Loss,偽數據D的判斷結果應盡可能接近1,因為G要試圖騙過D
    G_loss = nn.binary_cross_entropy(D_fake, ones_label)  
    
    # 更新G的參數
    G_loss.backward()
    G_solver.step()

一些坑

  1. 計算D和G的loss時最好分別用不同的隨機噪聲,否則有可能訓練過程不會收斂,而且結果差
  2. 注意,訓練的時候隨機噪聲的分布應該要保持和測試時的分布一致,不要一個用均勻分布,一個用正態分布

初步結果

哈哈哈看到終於訓練出來像樣的數字,還是有點小成就的


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM