pytorch模型存儲的兩種方式


1.保存整個網絡結構信息和模型參數信息:

torch.save(model_object, './model.pth')

直接加載即可使用:

model = torch.load('./model.pth')

 

2.只保存網絡的模型參數-推薦使用

torch.save(model_object.state_dict(), './params.pth')

加載則要先從本地網絡模塊導入網絡,然后再加載參數:

from models import AgeModel
model = AgeModel()
model.load_state_dict(torch.load('./params.pth'))

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM