Source code: https://github.com/aitorzip/PyTorch-CycleGAN
The training code lives in train.py. First, define the networks: two generators (A2B, B2A) and two discriminators (A, B), plus their optimizers. The optimizer setup guarantees that each step updates only the generators or only the discriminators, so the two sides never interfere with each other.
```python
###### Definition of variables ######
# Networks
netG_A2B = Generator(opt.input_nc, opt.output_nc)
netG_B2A = Generator(opt.output_nc, opt.input_nc)
netD_A = Discriminator(opt.input_nc)
netD_B = Discriminator(opt.output_nc)

# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                               lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))
```
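The claim about the optimizers is easy to verify: because optimizer_G is built over the chained parameters of both generators only, optimizer_G.step() never touches the discriminators even though the discriminator received gradients during backward(). A toy sketch (plain nn.Linear modules stand in for the real Generator/Discriminator classes):

```python
import itertools
import torch
import torch.nn as nn

# Toy stand-ins for the repo's Generator/Discriminator networks
g1 = nn.Linear(2, 2)
g2 = nn.Linear(2, 2)
d = nn.Linear(2, 1)

# One optimizer over both "generators", a separate one for the "discriminator"
opt_g = torch.optim.Adam(itertools.chain(g1.parameters(), g2.parameters()), lr=0.1)
opt_d = torch.optim.Adam(d.parameters(), lr=0.1)

d_before = d.weight.clone()
g1_before = g1.weight.clone()

loss = d(g2(g1(torch.randn(4, 2)))).mean()
loss.backward()   # gradients flow into g1, g2 *and* d ...
opt_g.step()      # ... but this step only updates g1 and g2

assert torch.equal(d.weight, d_before)          # discriminator untouched
assert not torch.equal(g1.weight, g1_before)    # generator was updated
```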
Next comes the data:
```python
# Dataset loader
transforms_ = [ transforms.Resize(int(opt.size*1.12), Image.BICUBIC),
                transforms.RandomCrop(opt.size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_, unaligned=True),
                        batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu)
```
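Note the unaligned=True flag: CycleGAN is trained on unpaired data, so the repo's ImageDataset returns an independently drawn B image for each A image rather than a fixed pairing. A simplified sketch of that indexing logic, with plain integers standing in for image files (the repo's actual class in datasets.py reads images from disk and may differ in detail):

```python
import random
from torch.utils.data import Dataset

class UnalignedDataset(Dataset):
    """Simplified sketch: domains A and B may have different sizes,
    and unaligned=True breaks any fixed A-B correspondence."""
    def __init__(self, items_A, items_B, unaligned=True):
        self.items_A, self.items_B = items_A, items_B
        self.unaligned = unaligned

    def __getitem__(self, index):
        a = self.items_A[index % len(self.items_A)]
        if self.unaligned:
            # sample a random partner from domain B instead of a fixed pair
            b = self.items_B[random.randint(0, len(self.items_B) - 1)]
        else:
            b = self.items_B[index % len(self.items_B)]
        return {'A': a, 'B': b}

    def __len__(self):
        # one epoch covers the larger of the two domains
        return max(len(self.items_A), len(self.items_B))

ds = UnalignedDataset(list(range(3)), list(range(100, 105)))
sample = ds[4]   # index wraps around the smaller domain: 4 % 3 == 1
```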
With everything in place, each iteration computes the losses, backpropagates the gradients, and updates the networks: the generators are updated first, then the two discriminators are each updated separately.
Generators: total loss = identity loss + adversarial (GAN) loss + cycle-consistency loss
```python
###### Generators A2B and B2A ######
optimizer_G.zero_grad()

# Identity loss
# G_A2B(B) should equal B if real B is fed
same_B = netG_A2B(real_B)
loss_identity_B = criterion_identity(same_B, real_B)*5.0
# G_B2A(A) should equal A if real A is fed
same_A = netG_B2A(real_A)
loss_identity_A = criterion_identity(same_A, real_A)*5.0

# GAN loss
fake_B = netG_A2B(real_A)
pred_fake = netD_B(fake_B)
loss_GAN_A2B = criterion_GAN(pred_fake, target_real)

fake_A = netG_B2A(real_B)
pred_fake = netD_A(fake_A)
loss_GAN_B2A = criterion_GAN(pred_fake, target_real)

# Cycle loss
recovered_A = netG_B2A(fake_B)
loss_cycle_ABA = criterion_cycle(recovered_A, real_A)*10.0

recovered_B = netG_A2B(fake_A)
loss_cycle_BAB = criterion_cycle(recovered_B, real_B)*10.0

# Total loss
loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
loss_G.backward()

optimizer_G.step()
```
Discriminator A: total loss = real-sample classification loss + fake-sample classification loss
```python
###### Discriminator A ######
optimizer_D_A.zero_grad()

# Real loss
pred_real = netD_A(real_A)
loss_D_real = criterion_GAN(pred_real, target_real)

# Fake loss
fake_A = fake_A_buffer.push_and_pop(fake_A)
pred_fake = netD_A(fake_A.detach())
loss_D_fake = criterion_GAN(pred_fake, target_fake)

# Total loss
loss_D_A = (loss_D_real + loss_D_fake)*0.5
loss_D_A.backward()

optimizer_D_A.step()
###################################
```
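One detail worth unpacking is fake_A_buffer.push_and_pop(fake_A): instead of always judging the newest fakes, the discriminator sees a mix of current and previously generated images, a trick from Shrivastava et al. that the CycleGAN paper adopts to stabilize training. A minimal sketch of what such a buffer might look like (the repo's actual ReplayBuffer lives in utils.py and may differ in detail):

```python
import random
import torch

class ReplayBuffer:
    """Sketch of the image history buffer used as fake_A_buffer / fake_B_buffer.
    Holds up to max_size previously generated images; each pushed image either
    passes through directly or is swapped with a stored older image."""
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, batch):
        out = []
        for element in batch:
            element = element.unsqueeze(0)
            if len(self.data) < self.max_size:
                # buffer not yet full: store the new fake and return it
                self.data.append(element)
                out.append(element)
            elif random.random() > 0.5:
                # 50%: return an old fake, replace it in the buffer with the new one
                i = random.randint(0, self.max_size - 1)
                out.append(self.data[i].clone())
                self.data[i] = element
            else:
                # 50%: return the new fake unchanged
                out.append(element)
        return torch.cat(out)

buf = ReplayBuffer(max_size=2)
x = buf.push_and_pop(torch.zeros(4, 3))  # a batch of 4 toy "images"
```

The output batch always has the same size as the input, so the rest of the training loop is unaffected; only the *age* of the fakes the discriminator sees changes.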
Discriminator B: total loss = real-sample classification loss + fake-sample classification loss
```python
###### Discriminator B ######
optimizer_D_B.zero_grad()

# Real loss
pred_real = netD_B(real_B)
loss_D_real = criterion_GAN(pred_real, target_real)

# Fake loss
fake_B = fake_B_buffer.push_and_pop(fake_B)
pred_fake = netD_B(fake_B.detach())
loss_D_fake = criterion_GAN(pred_fake, target_fake)

# Total loss
loss_D_B = (loss_D_real + loss_D_fake)*0.5
loss_D_B.backward()

optimizer_D_B.step()
###################################
```
Note that in both discriminator losses, the fake samples fake_A and fake_B are passed through detach(), which cuts them out of the generators' computation graph. Backpropagating the discriminator loss therefore computes no gradients for the generator parameters, avoiding unnecessary computation (the generators get their gradients only from loss_G above).
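The effect of detach() can be checked directly with toy modules (again, nn.Linear stand-ins rather than the real networks):

```python
import torch
import torch.nn as nn

g = nn.Linear(2, 2)   # stand-in "generator"
d = nn.Linear(2, 1)   # stand-in "discriminator"

fake = g(torch.randn(4, 2))

# Backprop the discriminator loss through the *detached* fake:
# the graph is cut at detach(), so gradients never reach g.
d(fake.detach()).mean().backward()

assert g.weight.grad is None        # no gradient computed for the generator
assert d.weight.grad is not None    # the discriminator did get gradients
```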