Pytorch-修改預訓練參數


我自己改進的模型為model(model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)),原模型為resnet50。

1.查看模型參數

現模型:

1 model_dict = model.state_dict()
2 for k,v in model_dict.items():
3     print(k)

預訓練模型參數

1 pretrained_dict = model_zoo.load_url(model_urls['resnet50'])
2 for k,v in pretrained_dict.items():
3     print(k)

2.將預訓練參數賦給自己改進的模型

改進的模型參數和原模型參數一致時:

1 import torch.utils.model_zoo as model_zoo
2 
3 model_urls = {
4     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
5 }
6 
7 model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)   

Tip:如果兩個模型參數完全一致的話,strict=True,如果兩個模型參數不一致的話,當strict=False預訓練模型會把具有相同參數名稱的值賦給改進的參數,不相同的則不賦值。

改進的模型參數和原模型參數不一致時,使用部分預訓練模型參數初始化網絡 :

1 model_dict = model.state_dict()          #取出自己模型的網絡參數 
2 pretrained_dict = model_zoo.load_url(model_urls['resnet50'])
3 
4 model_dict['classifiers.3.fc.weight'] = pretrained_dict['fc.weight'][:2]
5 model_dict['classifiers.3.fc.bias'] = pretrained_dict['fc.bias'][:2]


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM