先說結論,model.state_dict()是淺拷貝,返回的參數仍然會隨着網絡的訓練而變化。應該使用deepcopy(model.state_dict()),或將參數及時序列化到硬盤。
再講故事,前幾天在做一個模型的交叉驗證訓練時,通過model.state_dict()保存了每一組交叉驗證模型的參數,后根據效果選擇准確率最佳的模型load回去,結果每一次都是最后一個模型,從地址來看,每一個保存的state_dict()都具有不同的地址,但進一步發現state_dict()下的各個模型參數的地址是共享的,而我又使用了in-place的方式重置模型參數,進而導致了上述問題。