Pytorch GAN訓練時多次backward時出錯問題


轉載自https://www.daimajiaoliu.com/daima/479755892900406

和 https://oldpan.me/archives/pytorch-retain_graph-work

從一個錯誤說起:RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed

在深度學習中,有些場景需要進行兩次反向,比如Gan網絡,需要對D進行一次,還要對G進行一次,很多人都會遇到上面這個錯誤,這個錯誤的意思就是嘗試對一個計算圖進行第二次反向,但是計算圖已經釋放了。其實看簡單點和我們之前的backward一樣,當圖進行了一次梯度更新,就會把一些梯度的緩存給清空,為了避免下次疊加,但在Gan這種情形下,我們必須要二次更新,那怎么辦呢。有兩種方案:

方案一:

這是網上大多數給出的解決方案,在第一次反向時候加入一個 l2.backward(retain_graph=True) ,這樣就能避免釋放掉了。

這個參數的作用是什么,官方定義為:

retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.

大意是如果設置為False,計算圖中的中間變量在計算完后就會被釋放。但是在平時的使用中這個參數默認都為False從而提高效率,和creat_graph的值一樣。

 也就相當於,假如你有兩個Loss:

# 假如你有兩個Loss,先執行第一個的backward,再執行第二個backward
loss1.backward(retain_graph=True)
loss2.backward() # 執行完這個后,所有中間變量都會被釋放,以便下一次的循環
optimizer.step() # 更新參數

方案二:

上面的方案雖然解決了問題,但是並不優美,因為我們用Gan的時候,D和G兩者的更新並無聯系,二者的聯系僅僅是D里面用到了G的輸出,而這個輸出一般我們都是直接拿來用的,而問題就出現在這里。下面給一個模擬:

data = torch.randn(4,10)

model1 = torch.nn.Linear(10,2)
model2 = torch.nn.Linear(2,2)

optimizer1 = torch.optim.Adam(model1.parameters(), lr=0.001,betas=(0.5, 0.999))
optimizer2 = torch.optim.Adam(model2.parameters(), lr=0.001,betas=(0.5, 0.999))

loss = torch.nn.CrossEntropyLoss()
data = torch.randn(4,10)
label = torch.Tensor([0,1,1,0]).long()
for i in range(20):
    a = model1(data)
    b = model2(a)
    l1 = loss(a,label)
    l2 = loss(b,label)
    optimizer2.zero_grad()
    l2.backward()
    optimizer2.step()

    optimizer1.zero_grad()
    l1.backward()
    optimizer1.step()

解決方案可以是l2.backward(retain_graph=True)。除此之外我們還可以是 b = model2(a.detach()) ,這個就優美一點,a.detach()和a的區別你可以打印出來看一下,其實a.detach()是沒有梯度的,所以相當於一個單純的數字,和model1就脫離了聯系,這樣model2和model1就是完全分離開來的兩個圖,但是如果用的是a則model2和model1則仍然公用一個圖,所以導致了錯誤。可以看下面示意圖(這個是我猜測,幫助理解):

左邊相當於直接用a而右邊則用a.detach(),類似的在Gan網絡里面D的輸入可以改為G的輸出y_fake.detach()。

但有一點需要注意的是,兩個網絡一定沒有需要共同更新的 ,假如上面的optimizer2 = torch.optim.Adam(itertools.chain(model1.parameters(),model2.parameters()), lr=0.001,betas=(0.5, 0.999)),則還是用retain_graph=True保險,因為.detach則model2反向不會傳播到model1,導致不對model1里面參數更新。

方案一可見:https://github.com/growvv/GAN-Pytorch/blob/93b49bd7ce395c2035df1d036daad83a67a9c691/Simple-GAN/simple_gan.py

方案二可見:https://github.com/growvv/GAN-Pytorch/blob/257b267ea60af80212adc3dc5ad4cf28aeed00f6/CycleGAN/train.py


 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM