pytorch顯存越來越多的一個自己沒注意的原因


optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss

參考:https://blog.csdn.net/qq_27292549/article/details/80250031

我和博主犯了一毛一樣的低級錯誤。。。。

 

下面是原博解釋:

運行着就發現顯存炸了

觀察了一下發現隨着每個batch顯存消耗在不斷增大..

參考了別人的代碼發現那句loss一般是這樣寫 

loss_sum += loss.data[0]

這是因為輸出的loss的數據類型是Variable。

而PyTorch的動態圖機制就是通過Variable來構建圖。主要是使用Variable計算的時候,會記錄下新產生的Variable的運算符號,在反向傳播求導的時候進行使用。

如果這里直接將loss加起來,系統會認為這里也是計算圖的一部分,也就是說網絡會一直延伸變大~那么消耗的顯存也就越來越大~~

總之使用Variable的數據時候要非常小心。不是必要的話盡量使用Tensor來進行計算...

 

補充:

用Tensor計算也是有坑的,要寫成:

 train_loss += loss.item()

不然顯存還是會炸。。。。。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM