對於顯存不充足的煉丹研究者來說,弄清楚Pytorch顯存的分配機制是很有必要的。下面直接通過實驗來推出Pytorch顯存的分配過程。
實驗實驗代碼如下:
import torch from torch import cuda x = torch.zeros([3,1024,1024,256],requires_grad=True,device='cuda') print("1", cuda.memory_allocated()/1024**2) y = 5 * x print("2", cuda.memory_allocated()/1024**2) torch.mean(y).backward() print("3", cuda.memory_allocated()/1024**2) print(cuda.memory_summary())
輸出如下:

代碼首先分配3GB的顯存創建變量x,然后計算y,再用y進行反向傳播。可以看到,創建x后與計算y后分別占顯存3GB與6GB,這是合理的。另外,后面通過backward(),計算出x.grad,占存與x一致,所以最終一共占有顯存9GB,這也是合理的。但是,輸出顯示了顯存的峰值為12GB,這多出的3GB是怎么來的呢?首先畫出計算圖:
下面通過列表的形式來模擬Pytorch在運算時分配顯存的過程:

如上所示,由於需要保存反向傳播以前所有前向傳播的中間變量,所以有了12GB的峰值占存。
我們可以不存儲計算圖中的非葉子結點,達到節省顯存的目的,即可以把上面的代碼中的y=5*x與mean(y)寫成一步:
import torch from torch import cuda x = torch.zeros([3,1024,1024,256],requires_grad=True,device='cuda') print("1", cuda.memory_allocated()/1024**2) torch.mean(5*x).backward() print("2", cuda.memory_allocated()/1024**2) print(cuda.memory_summary())
占顯存量減少了3GB: