1.保存整個網絡結構信息和模型參數信息:
torch.save(model_object, './model.pth')
直接加載即可使用:
model = torch.load('./model.pth')
2.只保存網絡的模型參數-推薦使用
torch.save(model_object.state_dict(), './params.pth')
加載則要先從本地網絡模塊導入網絡,然后再加載參數:
from models import AgeModel model = AgeModel() model.load_state_dict(torch.load('./params.pth'))
