每次機器模型訓練完成后,都直接退出了。
沒有仔細的研究模型中各個參數到底是怎么樣的
直到前幾天看到大神將10層CNN每一步都展示出來的Github, 驚為天人那https://poloclub.github.io/cnn-explainer/
於是我也想看看,首先就是將模型中的參數保存下來
官網推薦了兩種方法
1. 只保存模型參數
保存:
torch.save(the_model.state_dict(), PATH)
重新加載:由於只保存了參數,重新加載時,需要創造一個新的模型框架來裝參數
restore_model = TheModelClass(*args, **kwargs)
restore_model.load_state_dict(torch.load(PATH))
2. 保存整個模型
保存:
torch.save(the_model, PATH)
重新加載:保存了整個模型,不需要創造新模型
restore_model = torch.load(PATH)
最后,查看模型參數
restore_model.state_dict()