pytorch的state_dict()拷貝問題


先說結論,model.state_dict()是淺拷貝,返回的參數仍然會隨着網絡的訓練而變化。應該使用deepcopy(model.state_dict()),或將參數及時序列化到硬盤。

再講故事,前幾天在做一個模型的交叉驗證訓練時,通過model.state_dict()保存了每一組交叉驗證模型的參數,后根據效果選擇准確率最佳的模型load回去,結果每一次都是最后一個模型,從地址來看,每一個保存的state_dict()都具有不同的地址,但進一步發現state_dict()下的各個模型參數的地址是共享的,而我又使用了in-place的方式重置模型參數,進而導致了上述問題。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM