首先不妨假設最簡單的一種情況:
假設$G$和$D$的損失函數:
那么計算梯度有:
第一種正確的方式:
import torch from torch import nn def set_requires_grad(net: nn.Module, mode=True): for p in net.parameters(): p.requires_grad_(mode) print(f"Pytorch version: {torch.__version__} \n") X = torch.ones(size=[1, 1, 1, 1]).requires_grad_(False) G = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False) G.weight.data.fill_(0.5) G_optim = torch.optim.SGD(G.parameters(), lr=1.0) D = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False) D.weight.data.fill_(0.7) D_optim = torch.optim.SGD(D.parameters(), lr=1.0)print(f"Init grad: {G.weight.grad} {D.weight.grad}") print(f"Init weight: {G.weight.detach()} {D.weight.detach()} \n") # Zero gradient of 2 layers. G_optim.zero_grad() D_optim.zero_grad() # Forward pass. Y = G(X) # Calculate D loss. D_loss = D(Y.detach()) ** 2 # Calculate G loss. G_loss = D(Y) ** 2 # Backward D loss. D_loss.backward(retain_graph=True) print(f"Checkpoint 1 grad: {G.weight.grad} {D.weight.grad}") print(f"Checkpoint 1 weight: {G.weight.detach()} {D.weight.detach()} \n") # Backward G loss. set_requires_grad(D, False) # Turn off D's grad to avoid redundant gradient accumulation on D. G_loss.backward() set_requires_grad(D, True) print(f"Checkpoint 2 grad: {G.weight.grad} {D.weight.grad}") print(f"Checkpoint 2 weight: {G.weight.detach()} {D.weight.detach()} \n") # Update G. G_optim.step() print(f"Checkpoint 3 grad: {G.weight.grad} {D.weight.grad}") print(f"Checkpoint 3 weight: {G.weight.detach()} {D.weight.detach()} \n") # Update D. D_optim.step() print(f"Checkpoint 4 grad: {G.weight.grad} {D.weight.grad}") print(f"Checkpoint 4 weight: {G.weight.detach()} {D.weight.detach()} \n")
運行結果:
Pytorch version: 1.9.0 Init grad: None None Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) Checkpoint 1 grad: None tensor([[[[0.3500]]]]) Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) Checkpoint 2 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]]) Checkpoint 2 weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) Checkpoint 3 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]]) Checkpoint 3 weight: tensor([[[[0.0100]]]]) tensor([[[[0.7000]]]]) Checkpoint 4 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]]) Checkpoint 4 weight: tensor([[[[0.0100]]]]) tensor([[[[0.3500]]]])
分析:
此時,$x = 1.0, y = 0.5, z = 0.7, \theta_G = 0.5, \theta_D = 0.7$,
首先checkpoint 1處,D loss的梯度反傳到D網絡上得到了 $2 y^2 \cdot \theta_D = 2 \times 0.25 \times 0.7 = 0.35$,沒有反傳到G網絡。
其次checkpoint 2處,G loss的梯度反傳,D網絡梯度不受影響(因為所有網絡參數的requires_grad := False),在G網絡上得到了 $2 \times 0.5 \times 0.7^2 \times 1.0 = 0.49$。注意,這里的D網絡參數 $\theta_D = 0.7$,因為盡管此時D loss已經反傳,但是沒有D optimizer的step()就還沒有更新D網絡。
最后checkpoint 3和4處,就是兩個optimizer的step()分別更新G網絡和D網絡,這兩個step()之間的先后順序對最終的網絡更新結果沒什么影響。
注意,這種做法更新G網絡時,對應的是更新前的D網絡。
在Pytorch 1.2上的結果:
Pytorch version: 1.2.0 Init grad: None None Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) Checkpoint 1 grad: None tensor([[[[0.3500]]]]) Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) Checkpoint 2 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]]) Checkpoint 2 weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) Checkpoint 3 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]]) Checkpoint 3 weight: tensor([[[[0.0100]]]]) tensor([[[[0.7000]]]]) Checkpoint 4 grad: tensor([[[[0.4900]]]]) tensor([[[[0.3500]]]]) Checkpoint 4 weight: tensor([[[[0.0100]]]]) tensor([[[[0.3500]]]])
可以看到,也是一樣的。
錯誤的做法:
在 G_loss.backward() 前后不進行對D網絡的網絡參數的requires_grad的關和開,使得G loss反傳了多余的梯度到D網絡上。
第二種正確的方式:
import torch from torch import nn
print(f"Pytorch version: {torch.__version__} \n") X = torch.ones(size=[1, 1, 1, 1]).requires_grad_(False) G = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False) G.weight.data.fill_(0.5) G_optim = torch.optim.SGD(G.parameters(), lr=1.0) D = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False) D.weight.data.fill_(0.7) D_optim = torch.optim.SGD(D.parameters(), lr=1.0) print(f"Init grad: {G.weight.grad} {D.weight.grad}") print(f"Init weight: {G.weight.detach()} {D.weight.detach()} \n") # Forward pass. Y = G(X) # Zero gradient of D. D_optim.zero_grad() # Calculate D loss. D_loss = D(Y.detach()) ** 2 # Backward D loss. D_loss.backward() # Update D. D_optim.step() print(f"Checkpoint 1 grad: {G.weight.grad} {D.weight.grad}") print(f"Checkpoint 1 weight: {G.weight.detach()} {D.weight.detach()} \n") # Zero gradient of G. G_optim.zero_grad() # Calculate G loss. G_loss = D(Y) ** 2 # Backward G loss. G_loss.backward() # Update G. G_optim.step() print(f"Checkpoint 2 grad: {G.weight.grad} {D.weight.grad}") print(f"Checkpoint 2 weight: {G.weight.detach()} {D.weight.detach()} \n")
分析:
這種方式就很明了,更新D網絡和更新G網絡完全分開。
此時,$x = 1.0, y = 0.5, z = 0.7, \theta_G = 0.5, \theta_D = 0.7$,
首先checkpoint 1處,D loss的梯度反傳到D網絡上得到了 $2 y^2 \cdot \theta_D = 2 \times 0.25 \times 0.7 = 0.35$,沒有反傳到G網絡。
其次checkpoint 2處,G loss的梯度同時反傳到了G網絡和D網絡上,但是由於只更新G網絡,D網絡上的梯度會在下一個iteration中被zero_grad()清零。G網絡上的梯度是 $2 \times 0.5 \times 0.35^2 \times 1.0 = 0.1225$,注意此時的D網絡參數已經從 $0.7$ 更新為 $0.7 - 1.0 \times 0.35 = 0.35$(梯度下降:原參數減去學習率乘梯度,得新參數)。
注意,這種做法更新G網絡時,對應的是已經更新后的D網絡。事實上,我認為這種做法更正確,同時在邏輯上也更加清晰、更好理解。
運行結果:
Pytorch version: 1.9.0 Init grad: None None Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) Checkpoint 1 grad: None tensor([[[[0.3500]]]]) Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.3500]]]]) Checkpoint 2 grad: tensor([[[[0.1225]]]]) tensor([[[[0.5250]]]]) Checkpoint 2 weight: tensor([[[[0.3775]]]]) tensor([[[[0.3500]]]])
Pytorch version: 1.2.0 Init grad: None None Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) Checkpoint 1 grad: None tensor([[[[0.3500]]]]) Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.3500]]]]) Checkpoint 2 grad: tensor([[[[0.1225]]]]) tensor([[[[0.5250]]]]) Checkpoint 2 weight: tensor([[[[0.3775]]]]) tensor([[[[0.3500]]]])
一種錯誤的方式:
import torch from torch import nn print(f"Pytorch version: {torch.__version__} \n") X = torch.ones(size=[1, 1, 1, 1]).requires_grad_(False) G = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False) G.weight.data.fill_(0.5) G_optim = torch.optim.SGD(G.parameters(), lr=1.0) D = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=False) D.weight.data.fill_(0.7) D_optim = torch.optim.SGD(D.parameters(), lr=1.0) print(f"Init grad: {G.weight.grad} {D.weight.grad}") print(f"Init weight: {G.weight.detach()} {D.weight.detach()} \n") # Forward pass. Y = G(X) # Zero gradient of G & D. G_optim.zero_grad() D_optim.zero_grad() # Calculate D loss. D_loss = D(Y.detach()) ** 2 # Calculate G loss. G_loss = D(Y) ** 2 # Backward D loss. D_loss.backward() # Update D. D_optim.step() print(f"Checkpoint 1 grad: {G.weight.grad} {D.weight.grad}") print(f"Checkpoint 1 weight: {G.weight.detach()} {D.weight.detach()} \n") # Backward G loss. G_loss.backward() # Update G. G_optim.step() print(f"Checkpoint 2 grad: {G.weight.grad} {D.weight.grad}") print(f"Checkpoint 2 weight: {G.weight.detach()} {D.weight.detach()} \n")
運行結果:
Pytorch version: 1.2.0 Init grad: None None Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) Checkpoint 1 grad: None tensor([[[[0.3500]]]]) Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.3500]]]]) Checkpoint 2 grad: tensor([[[[0.2450]]]]) tensor([[[[0.7000]]]]) Checkpoint 2 weight: tensor([[[[0.2550]]]]) tensor([[[[0.3500]]]])
Pytorch version: 1.9.0 Init grad: None None Init weight: tensor([[[[0.5000]]]]) tensor([[[[0.7000]]]]) Checkpoint 1 grad: None tensor([[[[0.3500]]]]) Checkpoint 1 weight: tensor([[[[0.5000]]]]) tensor([[[[0.3500]]]]) Traceback (most recent call last): File "D:/Program/PycharmProjects/Test/test.py", line 65, in <module> G_loss.backward() File "D:\Program\Anaconda\envs\py38_torch19\lib\site-packages\torch\_tensor.py", line 255, in backward torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs) File "D:\Program\Anaconda\envs\py38_torch19\lib\site-packages\torch\autograd\__init__.py", line 147, in backward Variable._execution_engine.run_backward( RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1, 1, 1, 1]] is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
分析:
這種做法的錯誤核心就在於,你在計算G loss時(前向傳播時)使用的是更新前的D網絡,但是你在G loss反向傳播時D網絡已經變成了更新后的,
這個錯誤在較低版本(1.2.0)的Pytorch上並沒有報錯,我們可以看到它在計算G網絡的梯度時,似乎用了更新前后的 $\theta_D$ 相乘 $2 \times 0.5 \times (0.7 \times 0.35) \times 1.0 = 0.245$,而非單純更新前的 ${\theta_D}^2 = 0.7^2$,或者單純更新后的 ${\theta_D}^2 = 0.35^2$.
至於在較高版本(1.9.0)的Pytorch上則直接報錯了,估計是因為step()更新了D網絡之后,G loss對應的計算圖被破壞了,因此直接報了一個 "inplace operation" 錯誤。
因此,使用低版本的Pytorch時千萬一定要注意這種比較隱蔽的錯誤寫法!!!!