pytorch快速加載預訓練模型參數的方式
針對的預訓練模型是通用的模型,也可以是自定義模型,大多是vgg16 , resnet50 , resnet101 , 等,從官網加載太慢
直接修改源碼,改為本地地址
1.直接使用默認程序里的下載方式,往往比較慢;
2.通過修改源代碼,使得模型加載已經下載好的參數,修改地方如下:
通過查找自己代碼里所調用網絡的類,使用pycharm自帶的函數查找功能(ctrl+鼠標左鍵),查看此網絡的加載方法,修改model.load_state_dict()函數。
例如:已經下載好的resnet50的參數文件:放在model_urls里面,這樣就可以提前下載直接使用。
model_urls = {
'resnet50': '/home/huihua/NewDisk1/pretrain_parameter/resnet50-19c8e357.pth',
}
把模型權重下載至torch的緩存文件夾
在本用戶目錄下,linux和win有不同
cd .cache/torch/checkpoints
cd /home/team/.torch/models
兩種方式,常常是用第二種作為torch模型的緩存文件夾, /home/team/是用戶的文件夾
進入文件夾把所需模型權重放入即可自動加載,相比第一種方法簡單點。