PyTorch模型加載與保存的最佳實踐


一般來說PyTorch有兩種保存和讀取模型參數的方法。但這篇文章我記錄了一種最佳實踐,可以在加載模型時避免掉一些問題。

傳統方案:

第一種方案是保存整個模型:

torch.save(model_object, 'model.pth')

第二種方法是保存模型網絡參數:

torch.save(model_object.state_dict(), 'params.pth')

加載的時候分別這樣加載:

model = torch.load('model.pth')

以及:

model_object.load_state_dict(torch.load('params.pth'))

改進的方案

注意到這個方案是因為模型在加載之后,loss會飆升之后再慢慢降回來。查閱有關分析之后,判定是優化器optimizer的問題。

如果模型的保存是為了恢復訓練狀態,那么可以考慮同時保存優化器optimizer的參數:

state = {
    'epoch': epoch,
    'net': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

然后這樣加載:

checkpoint = torch.load(model_path)
model.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch =  checkpoint['epoch'] + 1

如果模型的保存是為了方便以后進行validation和test,可以在加載完之后制定model.eval()固定dropout和BN層

 

https://ldzhangyx.github.io/2018/11/19/pytorch-1119/


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM