pytorch:修改预训练模型


torchvision中提供了很多训练好的模型,这些模型是在1000类,224*224的imagenet中训练得到的,很多时候不适合我们自己的数据,可以根据需要进行修改。

1、类别不同

    # coding=UTF-8 
    import torchvision.models as models #调用模型 
    model = models.resnet50(pretrained=True) #提取fc层中固定的参数 
    fc_features = model.fc.in_features #修改类别为9 
    model.fc = nn.Linear(fc_features, 9)  

2、添加层后,加载部分参数

model = ... model_dict = model.state_dict() # 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} # 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict) # 3. load the new state dict
model.load_state_dict(model_dict)

参考:https://blog.csdn.net/u012494820/article/details/79068625

          https://blog.csdn.net/whut_ldz/article/details/78845947


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM