問題
在用pytorch跑生成對抗網絡的時候,出現錯誤Runtime Error: one of the variables needed for gradient computation has been modified by an inplace operation
,特記錄排坑記錄。
環境配置
windows10 2004
python 3.7.4
pytorch 1.7.0 + cpu
解決過程
- 嘗試一
這段錯誤代碼看上去不難理解,意思為:計算梯度所需的某變量已被一就地操作修改。什么是就地操作呢,舉個例子如x += 1
就是典型的就地操作,可將其改為y = x + 1
。但很遺憾,這樣並沒有解決我的問題,這種方法的介紹如下。
在網上搜了很多相關博客,大多原因如下:
由於0.4.0把Varible和Tensor融合為一個Tensor,inplace操作,之前對Varible能用,但現在對Tensor,就會出錯了。
所以解決方案很簡單:將所有inplace操作轉換為非inplace操作。如將x += 1
換為y = x + 1
。
仍然有一個問題,即如何找到inplace操作,這里提供一個小trick:分階段調用y.backward()
,若報錯,則說明這之前有問題;反之則說明錯誤在該行之后。
- 嘗試二
在我的代碼里根本就沒有找到任何inplace操作,因此上面這種方法行不通。自己盯着代碼,debug,啥也看不出來,好久......
忽然有了新idea。我的訓練階段的代碼如下:
for epoch in range(1, epochs + 1):
for idx, (lr, hr) in enumerate(traindata_loader):
lrs = lr.to(device)
hrs = hr.to(device)
# update the discriminator
netD.zero_grad()
logits_fake = netD(netG(lrs).detach())
logits_real = netD(hrs)
# Label smoothing
real = (torch.rand(logits_real.size()) * 0.25 + 0.85).clone().detach().to(device)
fake = (torch.rand(logits_fake.size()) * 0.15).clone().detach().to(device)
d_loss = bce(logits_real, real) + bce(logits_fake, fake)
d_loss.backward(retain_graph=True)
optimizerD.step()
# update the generator
netG.zero_grad()
# !!!問題出錯行
g_loss = contentLoss(netG(lrs), hrs) + adversarialLoss(logits_fake)
g_loss.backward()
optimizerG.step()
判別器loss的backward是正常的,生成器loss的backward有問題。觀察到g_loss由兩項組成,所以很自然的想法就是刪掉其中一項看是否正常。結果為:只保留第一項程序正常運行;g_loss中包含第二項程序就出錯。
因此去看了adversarialLoss
的代碼:
class AdversarialLoss(nn.Module):
def __init__(self):
super(AdversarialLoss, self).__init__()
self.bec_loss = nn.BCELoss()
def forward(self, logits_fake):
# Adversarial Loss
# !!! 問題在這,logits_fake加上detach后就可以正常運行
adversarial_loss = self.bec_loss(logits_fake, torch.ones_like(logits_fake))
return 0.001 * adversarial_loss
看不出來任何問題,只能挨個試。這里只有兩個變量:logits_fake
和torch.ones_like(logits_fake)
。后者為常量,所以試着固定logits_fake
,不讓其參與訓練,程序竟能運行了!
class AdversarialLoss(nn.Module):
def __init__(self):
super(AdversarialLoss, self).__init__()
self.bec_loss = nn.BCELoss()
def forward(self, logits_fake):
# Adversarial Loss
# !!! 問題在這,logits_fake加上detach后就可以正常運行
adversarial_loss = self.bec_loss(logits_fake.detach(), torch.ones_like(logits_fake))
return 0.001 * adversarial_loss
由此知道了被修改的變量是logits_fake。盡管程序可以運行了,但這樣做不一定合理。類AdversarialLoss
中沒有對logits_fake
進行修改,所以返回剛才的訓練程序中。
for epoch in range(1, epochs + 1):
for idx, (lr, hr) in enumerate(traindata_loader):
lrs = lr.to(device)
hrs = hr.to(device)
# update the discriminator
netD.zero_grad()
logits_fake = netD(netG(lrs).detach())
logits_real = netD(hrs)
# Label smoothing
real = (torch.rand(logits_real.size()) * 0.25 + 0.85).clone().detach().to(device)
fake = (torch.rand(logits_fake.size()) * 0.15).clone().detach().to(device)
d_loss = bce(logits_real, real) + bce(logits_fake, fake)
d_loss.backward(retain_graph=True)
# 這里進行的更新操作
optimizerD.step()
# update the generator
netG.zero_grad()
# !!!問題出錯行
g_loss = contentLoss(netG(lrs), hrs) + adversarialLoss(logits_fake)
g_loss.backward()
optimizerG.step()
注意到Discriminator在出錯行之前進行了更新操作,因此真相呼之欲出————optimizerD.step()
對logits_fake
進行了修改。直接將其挪到倒數第二行即可,修改后代碼為:
for epoch in range(1, epochs + 1):
for idx, (lr, hr) in enumerate(traindata_loader):
lrs = lr.to(device)
hrs = hr.to(device)
# update the discriminator
netD.zero_grad()
logits_fake = netD(netG(lrs).detach())
logits_real = netD(hrs)
# Label smoothing
real = (torch.rand(logits_real.size()) * 0.25 + 0.85).clone().detach().to(device)
fake = (torch.rand(logits_fake.size()) * 0.15).clone().detach().to(device)
d_loss = bce(logits_real, real) + bce(logits_fake, fake)
d_loss.backward(retain_graph=True)
# update the generator
netG.zero_grad()
g_loss = contentLoss(netG(lrs), hrs) + adversarialLoss(logits_fake)
g_loss.backward()
optimizerD.step()
optimizerG.step()
程序終於正常運行了,耶( •̀ ω •́ )y!
總結
原因:在計算生成器網絡梯度之前先對判別器進行更新,修改了某些值,導致Generator網絡的梯度計算失敗。
解決方法:將Discriminator的更新步驟放到Generator的梯度計算步驟后面。