pytorch快速加載預訓練模型參數的方式
https://github.com/pytorch/vision/tree/master/torchvision/models
常用預訓練模型在這里面
總結下各種模型的下載地址:
1 Resnet: 2 3 model_urls = { 4 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 5 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 6 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 7 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 8 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 9 } 10 11 inception: 12 13 model_urls = { 14 # Inception v3 ported from TensorFlow 15 'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', 16 } 17 18 Densenet: 19 20 model_urls = { 21 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 22 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 23 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 24 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 25 } 26 27 28 29 Alexnet: 30 31 model_urls = { 32 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', 33 } 34 35 vggnet: 36 37 model_urls = { 38 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 39 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 40 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 41 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 42 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 43 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 44 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 45 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 46 }
解決下載速度慢的方法:
1.換移動網絡,有些公司網、校園網對於pytorch網站有很大的限速。
2.翻牆(有時不翻牆也可)先下載下來,放入文件夾中,方法如下兩種(推薦第二種)
針對的預訓練模型是通用的模型,也可以是自定義模型,大多是vgg16 , resnet50 , resnet101 , 等,從官網加載太慢
1.直接修改源碼,改為本地地址
直接使用默認程序里的下載方式,往往比較慢;
通過修改源代碼,使得模型加載已經下載好的參數,修改地方如下:
通過查找自己代碼里所調用網絡的類,使用pycharm自帶的函數查找功能(ctrl+鼠標左鍵),查看此網絡的加載方法,修改model.load_state_dict()函數。
例如:已經下載好的resnet50的參數文件:放在model_urls里面,這樣就可以提前下載直接使用。
model_urls = {
'resnet50': '/home/huihua/NewDisk1/pretrain_parameter/resnet50-19c8e357.pth',
}
2.把模型權重下載至torch的緩存文件夾
由於torch在加載模型時候首先檢查本地緩存是否已經存在模型,所以在本用戶目錄下,預先下載放入可快速加載模型。
cd .cache/torch/checkpoints
cd /home/team/.torch/models
兩種方式,常常是用第二種作為torch模型的緩存文件夾
進入文件夾把所需模型權重放入即可自動加載,相比第一種方法簡單點。