pytorch加载模型


1.加载全部模型:

net.load_state_dict(torch.load(net_para_pth))

2.加载部分模型

net_para_pth = './result/5826.pth'
pretrained_dict = torch.load(net_para_pth)
model_dict = net.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
net.load_state_dict(model_dict)

3.改变某一层参数后加载

将该层名称改一下,然后用2中方法加载,比如,要将conv5的out_channels由256改为512。

将conv_5改为conv_5_chg,就可以顺利加载了,不改会报错哟

 

算是权宜之计了,还有什么好方法,希望多多指教




					


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM