pytorch預訓練模型


1.加載預訓練模型:

只加載模型,不加載預訓練參數:resnet18 = models.resnet18(pretrained=False)

print resnet18 打印模型結構

resnet18.load_state_dict(torch.load('resnet18-5c106cde.pth'))加載預先下載好的預訓練參數到resnet18

print resnet18 打印的還是模型結構

note: cnn = resnet18.load_state_dict(torch.load('resnet18-5c106cde.pth'))是錯誤的,這樣cnn將是nonetype

pre_dict = resnet18.state_dict()按鍵值對將模型參數加載到pre_dict

print for k, v in pre_dict.items(): 打印模型參數

for k, v in pre_dict.items():

  print k

打印模型每層命名

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 

note:model是自己定義好的模型,將pretrained_dict和model_dict中命名一致的層加入pretrained_dict(包括參數)

加載模型和預訓練參數:resnet34 = models.resnet34(pretrained=True)

 

reference:

1.

http://blog.csdn.net/VictoriaW/article/details/72821329

 

2.

vgg16 = models.vgg16(pretrained=True)

pretrained_dict = vgg16.state_dict()

model_dict = model.state_dict()

# 1. filter out unnecessary keys

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

# 2. overwrite entries in the existing state dict

model_dict.update(pretrained_dict)

# 3. load the new state dict

model.load_state_dict(model_dict)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM