1.加載預訓練模型:
只加載模型,不加載預訓練參數:resnet18 = models.resnet18(pretrained=False)
print resnet18 打印模型結構
resnet18.load_state_dict(torch.load('resnet18-5c106cde.pth'))加載預先下載好的預訓練參數到resnet18
print resnet18 打印的還是模型結構
note: cnn = resnet18.load_state_dict(torch.load('resnet18-5c106cde.pth'))是錯誤的,這樣cnn將是nonetype
pre_dict = resnet18.state_dict()按鍵值對將模型參數加載到pre_dict
print for k, v in pre_dict.items(): 打印模型參數
for k, v in pre_dict.items():
print k
打印模型每層命名
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
note:model是自己定義好的模型,將pretrained_dict和model_dict中命名一致的層加入pretrained_dict(包括參數)
加載模型和預訓練參數:resnet34 = models.resnet34(pretrained=True)
reference:
1.
http://blog.csdn.net/VictoriaW/article/details/72821329
2.
vgg16 = models.vgg16(pretrained=True)
pretrained_dict = vgg16.state_dict()
model_dict = model.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)