關於pytorch下GAN loss的backward和step等注意事項


首先不妨假設最簡單的一種情況:

假設$G$和$D$的損失函數:

那么計算梯度有:

 

 

 

第一種正確的方式:

import torch
from torch import nn


def set_requires_grad(net: nn.Module, mode=True):
    for p in net.parameters():
        p.requires_grad_(mode)


print(f"Pytorch version: {torch.__version__} \n")

X = torch.ones(size=[1, 1, 1, 1]).requires_grad_(False)

G = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False)
G.weight.data.fill_(0.5)
G_optim = torch.optim.SGD(G.parameters(), lr=1.0)

D = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False)
D.weight.data.fill_(0.7)
D_optim = torch.optim.SGD(D.parameters(), lr=1.0)print(f"Init grad: {G.weight.grad} {D.weight.grad}")
print(f"Init weight: {G.weight.detach()} {D.weight.detach()} \n")

# Zero gradient of 2 layers.
G_optim.zero_grad()
D_optim.zero_grad()

# Forward pass.
Y = G(X)

# Calculate D loss.
D_loss = D(Y.detach()) ** 2

# Calculate G loss.
G_loss = D(Y) ** 2

# Backward D loss.
D_loss.backward(retain_graph=True)

print(f"Checkpoint 1 grad: {G.weight.grad} {D.weight.grad}")
print(f"Checkpoint 1 weight: {G.weight.detach()} {D.weight.detach()} \n")

# Backward G loss.
set_requires_grad(D, False)  # Turn off D's grad to avoid redundant gradient accumulation on D.
G_loss.backward()
set_requires_grad(D, True)

print(f"Checkpoint 2 grad: {G.weight.grad} {D.weight.grad}")
print(f"Checkpoint 2 weight: {G.weight.detach()} {D.weight.detach()} \n")

# Update G.
G_optim.step()

print(f"Checkpoint 3 grad: {G.weight.grad} {D.weight.grad}")
print(f"Checkpoint 3 weight: {G.weight.detach()} {D.weight.detach()} \n")

# Update D.
D_optim.step()

print(f"Checkpoint 4 grad: {G.weight.grad} {D.weight.grad}")
print(f"Checkpoint 4 weight: {G.weight.detach()} {D.weight.detach()} \n")

 

運行結果:

Pytorch version: 1.9.0 

Init grad: None None
Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) 

Checkpoint 1 grad: None tensor([[[[0.3500]]]])
Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) 

Checkpoint 2 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]])
Checkpoint 2 weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) 

Checkpoint 3 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]])
Checkpoint 3 weight: tensor([[[[0.0100]]]]) tensor([[[[0.7000]]]]) 

Checkpoint 4 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]])
Checkpoint 4 weight: tensor([[[[0.0100]]]]) tensor([[[[0.3500]]]]) 

 

分析:

此時,$x = 1.0, y = 0.5, z = 0.7, \theta_G = 0.5, \theta_D = 0.7$,

首先checkpoint 1處,D loss的梯度反傳到D網絡上得到了 $2 y^2 \cdot \theta_D = 2 \times 0.25 \times 0.7 = 0.35$,沒有反傳到G網絡。

其次checkpoint 2處,G loss的梯度反傳,D網絡梯度不受影響(因為所有網絡參數的requires_grad := False),在G網絡上得到了 $2 \times 0.5 \times 0.7^2 \times 1.0 = 0.49$。注意,這里的D網絡參數 $\theta_D = 0.7$,因為盡管此時D loss已經反傳,但是沒有D optimizer的step()就還沒有更新D網絡。

最后checkpoint 3和4處,就是兩個optimizer的step()分別更新G網絡和D網絡,這兩個step()之間的先后順序對最終的網絡更新結果沒什么影響。

 

注意,這種做法更新G網絡時,對應的是更新前的D網絡。

 

在Pytorch 1.2上的結果:

Pytorch version: 1.2.0 

Init grad: None None
Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) 

Checkpoint 1 grad: None tensor([[[[0.3500]]]])
Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) 

Checkpoint 2 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]])
Checkpoint 2 weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) 

Checkpoint 3 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]])
Checkpoint 3 weight: tensor([[[[0.0100]]]]) tensor([[[[0.7000]]]]) 

Checkpoint 4 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]])
Checkpoint 4 weight: tensor([[[[0.0100]]]]) tensor([[[[0.3500]]]]) 

可以看到,也是一樣的。

 

錯誤的做法:

在 G_loss.backward() 前后不進行對D網絡的網絡參數的requires_grad的關和開,使得G loss反傳了多余的梯度到D網絡上。

 


 

第二種正確的方式:

import torch
from torch import nn


print(f"Pytorch version: {torch.__version__} \n") X = torch.ones(size=[1, 1, 1, 1]).requires_grad_(False) G = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False) G.weight.data.fill_(0.5) G_optim = torch.optim.SGD(G.parameters(), lr=1.0) D = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False) D.weight.data.fill_(0.7) D_optim = torch.optim.SGD(D.parameters(), lr=1.0) print(f"Init grad: {G.weight.grad} {D.weight.grad}") print(f"Init weight: {G.weight.detach()} {D.weight.detach()} \n") # Forward pass. Y = G(X) # Zero gradient of D. D_optim.zero_grad() # Calculate D loss. D_loss = D(Y.detach()) ** 2 # Backward D loss. D_loss.backward() # Update D. D_optim.step() print(f"Checkpoint 1 grad: {G.weight.grad} {D.weight.grad}") print(f"Checkpoint 1 weight: {G.weight.detach()} {D.weight.detach()} \n") # Zero gradient of G. G_optim.zero_grad() # Calculate G loss. G_loss = D(Y) ** 2 # Backward G loss. G_loss.backward() # Update G. G_optim.step() print(f"Checkpoint 2 grad: {G.weight.grad} {D.weight.grad}") print(f"Checkpoint 2 weight: {G.weight.detach()} {D.weight.detach()} \n")

 

分析:

這種方式就很明了,更新D網絡和更新G網絡完全分開。

此時,$x = 1.0, y = 0.5, z = 0.7, \theta_G = 0.5, \theta_D = 0.7$,

首先checkpoint 1處,D loss的梯度反傳到D網絡上得到了 $2 y^2 \cdot \theta_D = 2 \times 0.25 \times 0.7 = 0.35$,沒有反傳到G網絡。

其次checkpoint 2處,G loss的梯度同時反傳到了G網絡和D網絡上,但是由於只更新G網絡,D網絡上的梯度會在下一個iteration中被zero_grad()清零。G網絡上的梯度是 $2 \times 0.5 \times 0.35^2 \times 1.0 = 0.1225$,注意此時的D網絡參數已經從 $0.7$ 更新為 $0.7 - 1.0 \times 0.35 = 0.35$(梯度下降:原參數減去學習率乘梯度,得新參數)。

 

注意,這種做法更新G網絡時,對應的是已經更新后的D網絡。事實上,我認為這種做法更正確,同時在邏輯上也更加清晰、更好理解。

 

運行結果:

Pytorch version: 1.9.0 

Init grad: None None
Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) 

Checkpoint 1 grad: None tensor([[[[0.3500]]]])
Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.3500]]]]) 

Checkpoint 2 grad: tensor([[[[0.1225]]]]) tensor([[[[0.5250]]]])
Checkpoint 2 weight: tensor([[[[0.3775]]]]) tensor([[[[0.3500]]]]) 
Pytorch version: 1.2.0 

Init grad: None None
Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) 

Checkpoint 1 grad: None tensor([[[[0.3500]]]])
Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.3500]]]]) 

Checkpoint 2 grad: tensor([[[[0.1225]]]]) tensor([[[[0.5250]]]])
Checkpoint 2 weight: tensor([[[[0.3775]]]]) tensor([[[[0.3500]]]]) 

 


 

 

一種錯誤的方式:

import torch
from torch import nn


print(f"Pytorch version: {torch.__version__} \n")

X = torch.ones(size=[1, 1, 1, 1]).requires_grad_(False)

G = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False)
G.weight.data.fill_(0.5)
G_optim = torch.optim.SGD(G.parameters(), lr=1.0)

D = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False)
D.weight.data.fill_(0.7)
D_optim = torch.optim.SGD(D.parameters(), lr=1.0)

print(f"Init grad: {G.weight.grad} {D.weight.grad}")
print(f"Init weight: {G.weight.detach()} {D.weight.detach()} \n")

# Forward pass.
Y = G(X)

# Zero gradient of G & D.
G_optim.zero_grad()
D_optim.zero_grad()

# Calculate D loss.
D_loss = D(Y.detach()) ** 2

# Calculate G loss.
G_loss = D(Y) ** 2

# Backward D loss.
D_loss.backward()

# Update D.
D_optim.step()

print(f"Checkpoint 1 grad: {G.weight.grad} {D.weight.grad}")
print(f"Checkpoint 1 weight: {G.weight.detach()} {D.weight.detach()} \n")

# Backward G loss.
G_loss.backward()

# Update G.
G_optim.step()

print(f"Checkpoint 2 grad: {G.weight.grad} {D.weight.grad}")
print(f"Checkpoint 2 weight: {G.weight.detach()} {D.weight.detach()} \n")

 

運行結果:

Pytorch version: 1.2.0 

Init grad: None None
Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) 

Checkpoint 1 grad: None tensor([[[[0.3500]]]])
Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.3500]]]]) 

Checkpoint 2 grad: tensor([[[[0.2450]]]]) tensor([[[[0.7000]]]])
Checkpoint 2 weight: tensor([[[[0.2550]]]]) tensor([[[[0.3500]]]]) 
Pytorch version: 1.9.0 

Init grad: None None
Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) 

Checkpoint 1 grad: None tensor([[[[0.3500]]]])
Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.3500]]]]) 

Traceback (most recent call last):
  File "D:/Program/PycharmProjects/Test/test.py", line 65, in <module>
    G_loss.backward()
  File "D:\Program\Anaconda\envs\py38_torch19\lib\site-packages\torch\_tensor.py", line 255, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "D:\Program\Anaconda\envs\py38_torch19\lib\site-packages\torch\autograd\__init__.py", line 147, in backward
    Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1, 1, 1, 1]] is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

 

分析:

這種做法的錯誤核心就在於,你在計算G loss時(前向傳播時)使用的是更新前的D網絡,但是你在G loss反向傳播時D網絡已經變成了更新后的,

這個錯誤在較低版本(1.2.0)的Pytorch上並沒有報錯,我們可以看到它在計算G網絡的梯度時,似乎用了更新前后的 $\theta_D$ 相乘 $2 \times 0.5 \times (0.7 \times 0.35) \times 1.0 = 0.245$,而非單純更新前的 ${\theta_D}^2 = 0.7^2$,或者單純更新后的 ${\theta_D}^2 = 0.35^2$.

至於在較高版本(1.9.0)的Pytorch上則直接報錯了,估計是因為step()更新了D網絡之后,G loss對應的計算圖被破壞了,因此直接報了一個 "inplace operation" 錯誤。

 

因此,使用低版本的Pytorch時千萬一定要注意這種比較隱蔽的錯誤寫法!!!!


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM