一般來說PyTorch有兩種保存和讀取模型參數的方法。但這篇文章我記錄了一種最佳實踐,可以在加載模型時避免掉一些問題。
傳統方案:
第一種方案是保存整個模型:
torch.save(model_object, 'model.pth')
第二種方法是保存模型網絡參數:
torch.save(model_object.state_dict(), 'params.pth')
加載的時候分別這樣加載:
model = torch.load('model.pth')
以及:
model_object.load_state_dict(torch.load('params.pth'))
改進的方案
注意到這個方案是因為模型在加載之后,loss會飆升之后再慢慢降回來。查閱有關分析之后,判定是優化器optimizer的問題。
如果模型的保存是為了恢復訓練狀態,那么可以考慮同時保存優化器optimizer的參數:
state = { 'epoch': epoch, 'net': model.state_dict(), 'optimizer': optimizer.state_dict(), ... } torch.save(state, filepath)
然后這樣加載:
checkpoint = torch.load(model_path) model.load_state_dict(checkpoint['net']) optimizer.load_state_dict(checkpoint['optimizer']) start_epoch = checkpoint['epoch'] + 1
如果模型的保存是為了方便以后進行validation和test,可以在加載完之后制定model.eval()固定dropout和BN層。