Pytroch網絡模型:修改參數值,修改參數名,添加參數層,刪除參數層
修改參數值
方法1
dict的類型是collecitons.OrderedDict,是一個有序字典
,直接將新參數名稱和初始值作為鍵值對插入,然后保存即可。
#修改前 dict = torch.load('./ckpt_dir//model_0.pth') net.load_state_dict(dict) for name,param in net.named_parameters(): print(name,param) #按參數名修改權重 dict["forward1.0.weight"] = torch.ones((1,1,3,3,3)) dict["forward1.0.bias"] = torch.ones(1) torch.save(dict, './ckpt_dir//model_0_.pth') #驗證修改是否成功 net.load_state_dict(torch.load('./ckpt_dir//model_0_.pth')) for param_tensor in net.state_dict(): print(net.state_dict()[param_tensor])
方法2(按條件修改)
net.load_state_dict(torch.load('./ckpt_dir//model_0.pth')) for param_tensor in net.state_dict(): print(net.state_dict()[param_tensor]) #按條件修改權重 for param in net.parameters(): new = torch.zeros_like(param.data) param.data = torch.where(0, param.data, new) #驗證是否真的修改了權重值。 for param_tensor in net.state_dict(): print(net.state_dict()[param_tensor])
修改參數名
dict = torch.load(model_dir) older_val = dict['舊名'] # 修改參數名 dict['新名'] = dict.pop('舊名') torch.save(dict, './model_changed.pth') #驗證修改是否成功 changed_dict = torch.load('./model_changed.pth') print(old_val) print(changed_dict['新名'])
添加參數層
dict = torch.load('./ckpt_dir//model_0.pth') print(dict) dict['forward1.0.weight1'] = None #把OrderedDict類型的dict當作普通字典使用即可 print(dict)
刪除參數層
pre_model = "./results/model_2-9.pth" dict = torch.load(pre_model) for key in list(dict.keys()): if key.startswith('decoder1'): del dict[key] torch.save(dict, './model_deleted.pth') # # #驗證修改是否成功 changed_dict = torch.load('./model_deleted.pth') for key in dict.keys(): print(key)
版權聲明:本文為博主原創文章,遵循 CC 4.0 BY-SA 版權協議,轉載請附上原文出處鏈接和本聲明。