pytorch獲得模型的參數信息,所占內存的大小


一 sum

一個模型所占的顯存無非是這兩種:

  • 模型權重參數
  • 模型所儲存的中間變量

其實權重參數一般來說並不會占用很多的顯存空間,主要占用顯存空間的還是計算時產生的中間變量,當我們定義了一個model之后,我們可以通過以下代碼簡單計算出這個模型權重參數所占用的數據量:

import numpy as np

# model是我們在pytorch定義的神經網絡層
# model.parameters()取出這個model所有的權重參數
para = sum([np.prod(list(p.size())) for p in model.parameters()])
# 下面的type_size是4,因為我們的參數是float32也就是4B,4個字節
 print('Model {} : params: {:4f}M'.format(model._get_name(), para * type_size / 1000 / 1000))

對上述含義的說明:https://oldpan.me/archives/how-to-use-memory-pytorch

二 torchsummary

1.pip install torchsummary安裝

2.

import torch
from torchsummary import summary

# 需要使用device來指定網絡在GPU還是CPU運行
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
netG_A2B = Generator(3, 3).to(device)   # 這里需要網絡,Generator報錯
summary(netG_A2B, input_size=(3, 256, 256))

Total params這三項比較好理解,因為有可能固定param。
input size也比較好理解:3*256*256/1024/1024*4=0.75(最后一個4表示存儲是需要4字節,float32類型)
Params size也比較好計算:138,357,544/1024/1024*4=527.79
Forward/backward pass size (MB)的計算:(10*24*24+20*8*8+20*8*8+50+10)/1024/1024*4*2=0.064(注意最有還有個2)
https://blog.csdn.net/csdnxiekai/article/details/110517751

三  pytorch-model-summary

1.pip install pytorch-model-summary

2.

# show input shape
print(summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=True))

# show output shape
print(summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=False))

# show output shape and hierarchical view of net
print(summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=False, show_hierarchical=True))
原文鏈接:https://blog.csdn.net/csdnxiekai/article/details/110517751


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM