pytorch保存模型等相關參數,利用torch.save(),以及讀取保存之后的文件


本文分為兩部分,第一部分講如何保存模型參數,優化器參數等等,第二部分則講如何讀取。

假設網絡為model = Net(), optimizer = optim.Adam(model.parameters(), lr=args.lr), 假設在某個epoch,我們要保存模型參數,優化器參數以及epoch

一、

1. 先建立一個字典,保存三個參數:

state = {‘net':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}

2.調用torch.save():

torch.save(state, dir)

其中dir表示保存文件的絕對路徑+保存文件名,如'/home/qinying/Desktop/modelpara.pth'

二、

當你想恢復某一階段的訓練(或者進行測試)時,那么就可以讀取之前保存的網絡模型參數等。

checkpoint = torch.load(dir)

model.load_state_dict(checkpoint['net'])

optimizer.load_state_dict(checkpoint['optimizer'])

start_epoch = checkpoint['epoch'] + 1

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM