pytorch加載模型


1.加載全部模型:

net.load_state_dict(torch.load(net_para_pth))

2.加載部分模型

net_para_pth = './result/5826.pth'
pretrained_dict = torch.load(net_para_pth)
model_dict = net.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
net.load_state_dict(model_dict)

3.改變某一層參數后加載

將該層名稱改一下,然后用2中方法加載,比如,要將conv5的out_channels由256改為512。

將conv_5改為conv_5_chg,就可以順利加載了,不改會報錯喲

 

算是權宜之計了,還有什么好方法,希望多多指教

 
         
         
       


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM