本文分為兩部分,第一部分講如何保存模型參數,優化器參數等等,第二部分則講如何讀取。
假設網絡為model = Net(), optimizer = optim.Adam(model.parameters(), lr=args.lr), 假設在某個epoch,我們要保存模型參數,優化器參數以及epoch
一、
1. 先建立一個字典,保存三個參數:
state = {‘net':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
2.調用torch.save():
torch.save(state, dir)
其中dir表示保存文件的絕對路徑+保存文件名,如'/home/qinying/Desktop/modelpara.pth'
二、
當你想恢復某一階段的訓練(或者進行測試)時,那么就可以讀取之前保存的網絡模型參數等。
checkpoint = torch.load(dir)
model.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1