最近使用Pytorch在學習一個深度學習項目,在模型保存和加載過程中遇到了問題,最終通過在網卡查找資料得已解決,故以此記之,以備忘卻。
首先,是在使用多GPU進行模型訓練的過程中,在保存模型參數時,應該使用類似如下代碼進行保存:
對應的在加載模型參數時,使用如下代碼進行加載是沒有問題的:

請注意紅圈的地方缺了“module”關鍵字,導致在保存模型參數時,參數保存成了這樣(模型參數是以key-value的形式保存的),即stat_dict(key),對應的value每個值都多了一個module:
接下來在加載模型參數時,如果直接使用代碼 model.load_state_dict(torch.load('模型參數文件存放路徑')['state_dict'])就會出現問題。報錯如下:
好了,既然知道了出問題的原因在哪里,那就來考慮下如何處理了,兩種方案:
第一,修改保存模型的代碼(加上"module")后,把模型重新訓練一次,重新加載即可。但我們大家都知道,這樣的深度模型訓練,時間一般都是以小時或者天計的,我們等不了那么久。(如果時間允許,可以這么干。哈哈!)
第二,在加載模型參數之前,寫代碼將模型參數里的"module"關鍵字給去掉。比如可以這么寫:

實話實說,這個代碼並不是我的原創,網上給出這個解決方案的地方很多。但我這里有一點不同的時,我加了個“[state_dict]”,我看到的很多地方是沒有這個的,直接就是ckpt.items()。因為我並不知道他們保存模型參數的代碼是怎么寫的,所以也並不好評論對錯。但總之一句話,我們是要通過這段代碼,去掉狀態字典里的"module"關鍵字的所以大家可以通過debug,查看這里的k取到的是什么值,應該要是取到下圖所示紅色框里的值,然后通過“name=k[7:]”去掉前面的"module",然后再加載就可以了。
文中提到一個詞“[state_dict]”,大家不用太在意,有的人在保存模型參數時,用的是“model”,只要在保存和讀取的時候,保持一致就可以了。
歡迎大家對描述不清楚或者不准確的地方提出批評意見和建議!