保存模型總體來說有兩種:
第一種:保存訓練的模型,之后我們可以繼續訓練
(1)保存模型
state = { 'model': model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch': epoch } torch.save(state, path)
model.state_dict():模型參數
optimizer.state_dict():優化器
epoch:保存epoch,為了可以接着訓練
(2)恢復模型
checkpoint = torch.load(path) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) start_epoch = checkpoint['epoch']+1
第二種:保存測試的模型,一般保存准確率最高的
(1)保存模型
這時我們只需要保存模型參數就行了
torch.save(model.state_dict, path)
(2)恢復模型
model.load_state_dict(torch.load(path))