Pytorch 模型參數保存 + 如何查看模型參數


每次機器模型訓練完成后,都直接退出了。

沒有仔細的研究模型中各個參數到底是怎么樣的

直到前幾天看到大神將10層CNN每一步都展示出來的Github, 驚為天人那https://poloclub.github.io/cnn-explainer/

於是我也想看看,首先就是將模型中的參數保存下來

pytorch模型參數保存

官網推薦了兩種方法

1. 只保存模型參數

 保存:  

torch.save(the_model.state_dict(), PATH) 

 重新加載:由於只保存了參數,重新加載時,需要創造一個新的模型框架來裝參數

  

restore_model = TheModelClass(*args, **kwargs)

restore_model.load_state_dict(torch.load(PATH))

 

2. 保存整個模型

  保存:  

torch.save(the_model, PATH)

  重新加載:保存了整個模型,不需要創造新模型

restore_model = torch.load(PATH)

 

最后,查看模型參數  

restore_model.state_dict()

 

 

 

 
 

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM