pytorch-模型保存和加載
加載模型參數和選擇是由保存的模型數據結構決定,故先要確定保存模型模型的方法和數據結構
保存模型
# 模型權重參數
model.state_dict()
'''首先說一下 model.state_dict()
pytorch 中的 model.state_dict 是一個簡單的python的字典對象,將每一層與它的對應參數建立映射關系.(如model的每一層的weights及偏置等等)
只有那些參數可以訓練的layer才會被保存到模型的state_dict中,如卷積層,線性層等
state_dict是在定義了model或optimizer之后pytorch自動生成的
'''
# model.state_dict() 其實返回的是一個OrderDict,存儲了網絡結構的名字和對應的參數
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.linear1 = nn.Linear(1, 2)
self.linear2 = nn.Linear(2, 1)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
mode = Net()
print(mode.state_dict())
"""
OrderedDict([('linear1.weight', tensor([[ 0.8108],[-0.7968]])), ('linear1.bias', tensor([ 0.2680, -0.4772])), ('linear2.weight', tensor([[-0.7066, -0.3334]])), ('linear2.bias', tensor([0.4819]))])
"""
print(mode.state_dict().keys())
"""
odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias'])
"""
for param_tensor in model.state_dict():
#打印 key value字典
print(param_tensor,'\t',model.state_dict()[param_tensor].size())
"""
linear1.weight torch.Size([2, 1])
linear1.bias torch.Size([2])
linear2.weight torch.Size([1, 2])
linear2.bias torch.Size([1])
"""
# 保存模型
torch.save(obj, f, pickle_module,pickle_protocol )
"""輸入參數
obj 可以是單個值也可以字典、對象
f 要保存參數的文件路徑
pickle_module
pickle_protocol
"""
# 1、自定義保存-工程實踐中常常使用---推薦
state = {'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch }
torch.save(model_object, './model.pt')
# 2、僅僅是保存模型權重參數
torch.save(model.state_dict(), PATH)
# 3、直接保存整個模型和模型結構
torch.save(Net,PATH)
加載模型
參數的保存
torch.save(model_object.state_dict(), 'params.pth')
# 模型的加載有模型保存的數據結構決定
ckpt = torch.load(f, map_location=None)
"""輸入參數
f file模型文件
map_location torch.device, 動態地進行內存重映射,從不同的設備上讀取文件
pickle_module 用於unpickling元數據和對象的模塊
pickle_load_args 傳遞給pickle_module.load()
注釋: 如果多塊顯卡,map_location={'cuda:0':"cuda:1"},指定在2號顯卡,不使用1號顯卡
返回參數 字典d
由加載文件定義
默認情況,dict_keys(['epoch', 'state_dict', 'optimizer', 'best_pred'])
"""
# 1、針對第一種保存模型的加載方式
# 加載模型
model=Net()
# 加載模型參數
model_CKPT = torch.load(checkpoint_PATH)
# 參數各個屬性f
model.load_state_dict(model_CKPT['model'])
optimizer.load_state_dict(model_CKPT['optimizer'])
# 2、針對第二種保存模型的加載方式
model=Net() # 實例化網絡
model_CKPT = torch.load(checkpoint_PATH) # 加載模型參數
model.load_state_dict(model_CKPT)
# 針對第三種保存整個模型的加載方式
model = torch.load(mode_PATH)
部分權重的加載
# 關鍵自定義函數
def intersect_dicts(da, db, exclude=()):
"""輸入參數
da (state_dict) 加載權重的 state_dict
db (state_dict) 加載模型的 state_dict
exclude (list) 不想要的權重 keys()
返回參數
加載的部分權重 (state_dict)
"""
'''
print("exclude",exclude)
for k, v in da.items():
for x in exclude:
if x in k:
print('@ ',x ,k)
if v.shape != db[k].shape:
print('# ', x, k)
'''
return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
案例
# 加載模型
model = Net()
# 加載權重
ckpt=torch.load(weights, map_location=device)
state_dict=ckpt.state_dict()
# state_dict 是一個字典
# state_dict.keys()
# odict_keys(['0.model.0.conv.conv.weight', '0.model.0.conv.conv.bias', '0.model.1.conv.weight', .....])
# 權重取舍處理
state_dict=intersect_dicts(state_dict, model.state_dict(), exclude=exclude)
# 模型加載權重
model.load_state_dict(state_dict, strict=False)
# 最后可以輸出加載了多少個
print('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))
# output >>> Transferred 498/506 items from yolov5m.pt