pytorch:修改預訓練模型


torchvision中提供了很多訓練好的模型,這些模型是在1000類,224*224的imagenet中訓練得到的,很多時候不適合我們自己的數據,可以根據需要進行修改。

1、類別不同

    # coding=UTF-8 
    import torchvision.models as models #調用模型 
    model = models.resnet50(pretrained=True) #提取fc層中固定的參數 
    fc_features = model.fc.in_features #修改類別為9 
    model.fc = nn.Linear(fc_features, 9)  

2、添加層后,加載部分參數

model = ... model_dict = model.state_dict() # 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} # 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict) # 3. load the new state dict
model.load_state_dict(model_dict)

參考:https://blog.csdn.net/u012494820/article/details/79068625

          https://blog.csdn.net/whut_ldz/article/details/78845947


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM