torchvision中提供了很多訓練好的模型,這些模型是在1000類,224*224的imagenet中訓練得到的,很多時候不適合我們自己的數據,可以根據需要進行修改。
1、類別不同
# coding=UTF-8
import torchvision.models as models #調用模型
model = models.resnet50(pretrained=True) #提取fc層中固定的參數
fc_features = model.fc.in_features #修改類別為9
model.fc = nn.Linear(fc_features, 9)
2、添加層后,加載部分參數
model = ... model_dict = model.state_dict() # 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} # 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict) # 3. load the new state dict
model.load_state_dict(model_dict)
參考:https://blog.csdn.net/u012494820/article/details/79068625
https://blog.csdn.net/whut_ldz/article/details/78845947