pytorch 中模型的保存與加載,增量訓練


 讓模型接着上次保存好的模型訓練,模型加載

#實例化模型、優化器、損失函數
model = MnistModel().to(config.device)
optimizer = optim.Adam(model.parameters(),lr=0.01)

if os.path.exists("./model/mnist_net.pt"):
    model.load_state_dict(torch.load("./model/mnist_net.pt"))
    optimizer.load_state_dict(torch.load("model/mnist_optimizer.pt"))

  模型保存

 

            torch.save(model.state_dict(),"model/mnist_net.pt")
            torch.save(optimizer.state_dict(),"model/mnist_optimizer.pt")

 

  

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM