pytorch 如何手動修改模型文件model.pth和model.weights


Pytroch網絡模型:修改參數值,修改參數名,添加參數層,刪除參數層

修改參數值

方法1

dict的類型是collecitons.OrderedDict,是一個有序字典,直接將新參數名稱和初始值作為鍵值對插入,然后保存即可。

#修改前 dict = torch.load('./ckpt_dir//model_0.pth') net.load_state_dict(dict) for name,param in net.named_parameters(): print(name,param) #按參數名修改權重 dict["forward1.0.weight"] = torch.ones((1,1,3,3,3)) dict["forward1.0.bias"] = torch.ones(1) torch.save(dict, './ckpt_dir//model_0_.pth') #驗證修改是否成功 net.load_state_dict(torch.load('./ckpt_dir//model_0_.pth')) for param_tensor in net.state_dict(): print(net.state_dict()[param_tensor]) 

方法2(按條件修改)

net.load_state_dict(torch.load('./ckpt_dir//model_0.pth')) for param_tensor in net.state_dict(): print(net.state_dict()[param_tensor]) #按條件修改權重 for param in net.parameters(): new = torch.zeros_like(param.data) param.data = torch.where(0, param.data, new) #驗證是否真的修改了權重值。 for param_tensor in net.state_dict(): print(net.state_dict()[param_tensor]) 

修改參數名

dict = torch.load(model_dir) older_val = dict['舊名'] # 修改參數名 dict['新名'] = dict.pop('舊名') torch.save(dict, './model_changed.pth') #驗證修改是否成功 changed_dict = torch.load('./model_changed.pth') print(old_val) print(changed_dict['新名']) 

添加參數層

dict = torch.load('./ckpt_dir//model_0.pth') print(dict) dict['forward1.0.weight1'] = None #把OrderedDict類型的dict當作普通字典使用即可 print(dict) 

刪除參數層

pre_model = "./results/model_2-9.pth" dict = torch.load(pre_model) for key in list(dict.keys()): if key.startswith('decoder1'): del dict[key] torch.save(dict, './model_deleted.pth') # # #驗證修改是否成功 changed_dict = torch.load('./model_deleted.pth') for key in dict.keys(): print(key) 
版權聲明:本文為博主原創文章,遵循 CC 4.0 BY-SA 版權協議,轉載請附上原文出處鏈接和本聲明。
本文鏈接: https://blog.csdn.net/weixin_44058333/article/details/99682848


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM