Pytroch网络模型:修改参数值,修改参数名,添加参数层,删除参数层
修改参数值
方法1
dict的类型是collecitons.OrderedDict,是一个有序字典
,直接将新参数名称和初始值作为键值对插入,然后保存即可。
#修改前 dict = torch.load('./ckpt_dir//model_0.pth') net.load_state_dict(dict) for name,param in net.named_parameters(): print(name,param) #按参数名修改权重 dict["forward1.0.weight"] = torch.ones((1,1,3,3,3)) dict["forward1.0.bias"] = torch.ones(1) torch.save(dict, './ckpt_dir//model_0_.pth') #验证修改是否成功 net.load_state_dict(torch.load('./ckpt_dir//model_0_.pth')) for param_tensor in net.state_dict(): print(net.state_dict()[param_tensor])
方法2(按条件修改)
net.load_state_dict(torch.load('./ckpt_dir//model_0.pth')) for param_tensor in net.state_dict(): print(net.state_dict()[param_tensor]) #按条件修改权重 for param in net.parameters(): new = torch.zeros_like(param.data) param.data = torch.where(0, param.data, new) #验证是否真的修改了权重值。 for param_tensor in net.state_dict(): print(net.state_dict()[param_tensor])
修改参数名
dict = torch.load(model_dir) older_val = dict['旧名'] # 修改参数名 dict['新名'] = dict.pop('旧名') torch.save(dict, './model_changed.pth') #验证修改是否成功 changed_dict = torch.load('./model_changed.pth') print(old_val) print(changed_dict['新名'])
添加参数层
dict = torch.load('./ckpt_dir//model_0.pth') print(dict) dict['forward1.0.weight1'] = None #把OrderedDict类型的dict当作普通字典使用即可 print(dict)
删除参数层
pre_model = "./results/model_2-9.pth" dict = torch.load(pre_model) for key in list(dict.keys()): if key.startswith('decoder1'): del dict[key] torch.save(dict, './model_deleted.pth') # # #验证修改是否成功 changed_dict = torch.load('./model_deleted.pth') for key in dict.keys(): print(key)
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。