pytorch 如何手动修改模型文件model.pth和model.weights


Pytroch网络模型:修改参数值,修改参数名,添加参数层,删除参数层

修改参数值

方法1

dict的类型是collecitons.OrderedDict,是一个有序字典,直接将新参数名称和初始值作为键值对插入,然后保存即可。

#修改前 dict = torch.load('./ckpt_dir//model_0.pth') net.load_state_dict(dict) for name,param in net.named_parameters(): print(name,param) #按参数名修改权重 dict["forward1.0.weight"] = torch.ones((1,1,3,3,3)) dict["forward1.0.bias"] = torch.ones(1) torch.save(dict, './ckpt_dir//model_0_.pth') #验证修改是否成功 net.load_state_dict(torch.load('./ckpt_dir//model_0_.pth')) for param_tensor in net.state_dict(): print(net.state_dict()[param_tensor]) 

方法2(按条件修改)

net.load_state_dict(torch.load('./ckpt_dir//model_0.pth')) for param_tensor in net.state_dict(): print(net.state_dict()[param_tensor]) #按条件修改权重 for param in net.parameters(): new = torch.zeros_like(param.data) param.data = torch.where(0, param.data, new) #验证是否真的修改了权重值。 for param_tensor in net.state_dict(): print(net.state_dict()[param_tensor]) 

修改参数名

dict = torch.load(model_dir) older_val = dict['旧名'] # 修改参数名 dict['新名'] = dict.pop('旧名') torch.save(dict, './model_changed.pth') #验证修改是否成功 changed_dict = torch.load('./model_changed.pth') print(old_val) print(changed_dict['新名']) 

添加参数层

dict = torch.load('./ckpt_dir//model_0.pth') print(dict) dict['forward1.0.weight1'] = None #把OrderedDict类型的dict当作普通字典使用即可 print(dict) 

删除参数层

pre_model = "./results/model_2-9.pth" dict = torch.load(pre_model) for key in list(dict.keys()): if key.startswith('decoder1'): del dict[key] torch.save(dict, './model_deleted.pth') # # #验证修改是否成功 changed_dict = torch.load('./model_deleted.pth') for key in dict.keys(): print(key) 
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接: https://blog.csdn.net/weixin_44058333/article/details/99682848


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM