1.加載全部模型:
net.load_state_dict(torch.load(net_para_pth))
2.加載部分模型
net_para_pth = './result/5826.pth'
pretrained_dict = torch.load(net_para_pth)
model_dict = net.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
net.load_state_dict(model_dict)
3.改變某一層參數后加載
將該層名稱改一下,然后用2中方法加載,比如,要將conv5的out_channels由256改為512。
將conv_5改為conv_5_chg,就可以順利加載了,不改會報錯喲
算是權宜之計了,還有什么好方法,希望多多指教