在pytorch環境下,有兩個計算FLOPs和參數量的包thop和ptflops,結果基本是一致的。
thop
參考https://github.com/Lyken17/pytorch-OpCounter
安裝方法:pip install thop
使用方法:
from torchvision.models import resnet18 from thop import profile model = resnet18() input = torch.randn(1, 3, 224, 224) #模型輸入的形狀,batch_size=1 flops, params = profile(model, inputs=(input, )) print(flops/1e9,params/1e6) #flops單位G,para單位M
用來測試3d resnet18的FLOPs:
model =C3D_Hash_Model(48) input = torch.randn(1, 3,10, 112, 112) #視頻取10幀 flops, params = profile(model, inputs=(input, )) print(flops/1e9,params/1e6)
ptflops
參考https://github.com/sovrasov/flops-counter.pytorch
安裝方法:pip install ptflops
或者 pip install git+https://github.com/sovrasov/flops-counter.pytorch.git
import torchvision.models as models
import torch
from ptflops import get_model_complexity_info
with torch.cuda.device(0):
net = models.resnet18()
flops, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, print_per_layer_stat=True) #不用寫batch_size大小,默認batch_size=1
print('Flops: ' + flops)
print('Params: ' + params)
用來測試3d resnet18的FLOPs:
import torch
from ptflops.flops_counter import get_model_complexity_info
with torch.cuda.device(0):
net = C3D_Hash_Model(48)
flops, params = get_model_complexity_info(net, (3,10, 112, 112), as_strings=True, print_per_layer_stat=True)
print('Flops: ' + flops)
print('Params: ' + params)
如果安裝ptflops出問題,可以直接到https://github.com/sovrasov/flops-counter.pytorch.git下載代碼,然后直接把目錄ptflops復制到項目代碼中,通過from ptflops.flops_counter import get_model_complexity_info來調用函數計算FLOPs
python計時程序運行時間
import time
time_start=time.time()
#在這里運行模型
time_end=time.time()
print('totally cost',time_end-time_start)
