RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation


問題

在用pytorch生成對抗網絡的時候,出現錯誤Runtime Error: one of the variables needed for gradient computation has been modified by an inplace operation,特記錄排坑記錄。

環境配置

windows10 2004
python 3.7.4
pytorch 1.7.0 + cpu

解決過程

  • 嘗試一

這段錯誤代碼看上去不難理解,意思為:計算梯度所需的某變量已被一就地操作修改。什么是就地操作呢,舉個例子如x += 1就是典型的就地操作,可將其改為y = x + 1。但很遺憾,這樣並沒有解決我的問題,這種方法的介紹如下。
在網上搜了很多相關博客,大多原因如下:

由於0.4.0把Varible和Tensor融合為一個Tensor,inplace操作,之前對Varible能用,但現在對Tensor,就會出錯了。

所以解決方案很簡單:將所有inplace操作轉換為非inplace操作。如將x += 1換為y = x + 1
仍然有一個問題,即如何找到inplace操作,這里提供一個小trick:分階段調用y.backward(),若報錯,則說明這之前有問題;反之則說明錯誤在該行之后。

  • 嘗試二

在我的代碼里根本就沒有找到任何inplace操作,因此上面這種方法行不通。自己盯着代碼,debug,啥也看不出來,好久......
忽然有了新idea。我的訓練階段的代碼如下:

for epoch in range(1, epochs + 1):
    for idx, (lr, hr) in enumerate(traindata_loader):
        lrs = lr.to(device)
        hrs = hr.to(device)

        # update the discriminator
        netD.zero_grad()
        logits_fake = netD(netG(lrs).detach())
        logits_real = netD(hrs)
        # Label smoothing
        real = (torch.rand(logits_real.size()) * 0.25 + 0.85).clone().detach().to(device)
        fake = (torch.rand(logits_fake.size()) * 0.15).clone().detach().to(device)
        d_loss = bce(logits_real, real) + bce(logits_fake, fake)
        d_loss.backward(retain_graph=True)
        optimizerD.step()

        # update the generator
        netG.zero_grad()
        # !!!問題出錯行
        g_loss = contentLoss(netG(lrs), hrs) + adversarialLoss(logits_fake)
        g_loss.backward()        
        optimizerG.step()

判別器loss的backward是正常的,生成器loss的backward有問題。觀察到g_loss由兩項組成,所以很自然的想法就是刪掉其中一項看是否正常。結果為:只保留第一項程序正常運行;g_loss中包含第二項程序就出錯。
因此去看了adversarialLoss的代碼:

class AdversarialLoss(nn.Module):
    def __init__(self):
        super(AdversarialLoss, self).__init__()
        self.bec_loss = nn.BCELoss()

    def forward(self, logits_fake):
        # Adversarial Loss
        # !!! 問題在這,logits_fake加上detach后就可以正常運行
        adversarial_loss = self.bec_loss(logits_fake, torch.ones_like(logits_fake))
        return 0.001 * adversarial_loss

看不出來任何問題,只能挨個試。這里只有兩個變量:logits_faketorch.ones_like(logits_fake)。后者為常量,所以試着固定logits_fake,不讓其參與訓練,程序竟能運行了!

class AdversarialLoss(nn.Module):
    def __init__(self):
        super(AdversarialLoss, self).__init__()
        self.bec_loss = nn.BCELoss()

    def forward(self, logits_fake):
        # Adversarial Loss
        # !!! 問題在這,logits_fake加上detach后就可以正常運行
        adversarial_loss = self.bec_loss(logits_fake.detach(), torch.ones_like(logits_fake))
        return 0.001 * adversarial_loss

由此知道了被修改的變量是logits_fake。盡管程序可以運行了,但這樣做不一定合理。類AdversarialLoss中沒有對logits_fake進行修改,所以返回剛才的訓練程序中。

for epoch in range(1, epochs + 1):
    for idx, (lr, hr) in enumerate(traindata_loader):
        lrs = lr.to(device)
        hrs = hr.to(device)

        # update the discriminator
        netD.zero_grad()
        logits_fake = netD(netG(lrs).detach())
        logits_real = netD(hrs)
        # Label smoothing
        real = (torch.rand(logits_real.size()) * 0.25 + 0.85).clone().detach().to(device)
        fake = (torch.rand(logits_fake.size()) * 0.15).clone().detach().to(device)
        d_loss = bce(logits_real, real) + bce(logits_fake, fake)
        d_loss.backward(retain_graph=True)
        # 這里進行的更新操作
        optimizerD.step()

        # update the generator
        netG.zero_grad()
        # !!!問題出錯行
        g_loss = contentLoss(netG(lrs), hrs) + adversarialLoss(logits_fake)
        g_loss.backward()        
        optimizerG.step()

注意到Discriminator在出錯行之前進行了更新操作,因此真相呼之欲出————optimizerD.step()logits_fake進行了修改。直接將其挪到倒數第二行即可,修改后代碼為:

for epoch in range(1, epochs + 1):
    for idx, (lr, hr) in enumerate(traindata_loader):
        lrs = lr.to(device)
        hrs = hr.to(device)

        # update the discriminator
        netD.zero_grad()
        logits_fake = netD(netG(lrs).detach())
        logits_real = netD(hrs)
        # Label smoothing
        real = (torch.rand(logits_real.size()) * 0.25 + 0.85).clone().detach().to(device)
        fake = (torch.rand(logits_fake.size()) * 0.15).clone().detach().to(device)
        d_loss = bce(logits_real, real) + bce(logits_fake, fake)
        d_loss.backward(retain_graph=True)
        

        # update the generator
        netG.zero_grad()
        g_loss = contentLoss(netG(lrs), hrs) + adversarialLoss(logits_fake)
        g_loss.backward()   
        optimizerD.step()     
        optimizerG.step()

程序終於正常運行了,耶( •̀ ω •́ )y!

總結

原因:在計算生成器網絡梯度之前先對判別器進行更新,修改了某些值,導致Generator網絡的梯度計算失敗。
解決方法:將Discriminator的更新步驟放到Generator的梯度計算步驟后面。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM