pytorch快速加載預訓練模型參數的方式


pytorch快速加載預訓練模型參數的方式

針對的預訓練模型是通用的模型,也可以是自定義模型,大多是vgg16 ,  resnet50 , resnet101 , 等,從官網加載太慢

直接修改源碼,改為本地地址

1.直接使用默認程序里的下載方式,往往比較慢;

2.通過修改源代碼,使得模型加載已經下載好的參數,修改地方如下:

通過查找自己代碼里所調用網絡的類,使用pycharm自帶的函數查找功能(ctrl+鼠標左鍵),查看此網絡的加載方法,修改model.load_state_dict()函數。

例如:已經下載好的resnet50的參數文件:放在model_urls里面,這樣就可以提前下載直接使用。

model_urls = {
'resnet50': '/home/huihua/NewDisk1/pretrain_parameter/resnet50-19c8e357.pth',
}

 

把模型權重下載至torch的緩存文件夾

在本用戶目錄下,linux和win有不同

cd .cache/torch/checkpoints
cd /home/team/.torch/models
兩種方式,常常是用第二種作為torch模型的緩存文件夾, /home/team/是用戶的文件夾

進入文件夾把所需模型權重放入即可自動加載,相比第一種方法簡單點。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM