一 sum
一個模型所占的顯存無非是這兩種:
- 模型權重參數
- 模型所儲存的中間變量
其實權重參數一般來說並不會占用很多的顯存空間,主要占用顯存空間的還是計算時產生的中間變量,當我們定義了一個model之后,我們可以通過以下代碼簡單計算出這個模型權重參數所占用的數據量:
import numpy as np # model是我們在pytorch定義的神經網絡層 # model.parameters()取出這個model所有的權重參數 para = sum([np.prod(list(p.size())) for p in model.parameters()])
# 下面的type_size是4,因為我們的參數是float32也就是4B,4個字節 print('Model {} : params: {:4f}M'.format(model._get_name(), para * type_size / 1000 / 1000))
對上述含義的說明:https://oldpan.me/archives/how-to-use-memory-pytorch
二 torchsummary
1.pip install torchsummary安裝
2.
import torch from torchsummary import summary # 需要使用device來指定網絡在GPU還是CPU運行 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') netG_A2B = Generator(3, 3).to(device) # 這里需要網絡,Generator報錯 summary(netG_A2B, input_size=(3, 256, 256))
Total params這三項比較好理解,因為有可能固定param。
input size也比較好理解:3*256*256/1024/1024*4=0.75(最后一個4表示存儲是需要4字節,float32類型)
Params size也比較好計算:138,357,544/1024/1024*4=527.79
Forward/backward pass size (MB)的計算:(10*24*24+20*8*8+20*8*8+50+10)/1024/1024*4*2=0.064
(注意最有還有個2)
https://blog.csdn.net/csdnxiekai/article/details/110517751
三 pytorch-model-summary
1.pip install pytorch-model-summary
2.
# show input shape print(summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=True)) # show output shape print(summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=False)) # show output shape and hierarchical view of net print(summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=False, show_hierarchical=True))
原文鏈接:https://blog.csdn.net/csdnxiekai/article/details/110517751