使用Pytorch在多GPU下保存和加載訓練模型參數遇到的問題


  最近使用Pytorch在學習一個深度學習項目,在模型保存和加載過程中遇到了問題,最終通過在網卡查找資料得已解決,故以此記之,以備忘卻。

  首先,是在使用多GPU進行模型訓練的過程中,在保存模型參數時,應該使用類似如下代碼進行保存:

  torch.save({
                'epoch': epoch,
                'state_dict': model.module.state_dict(),
                'optimizer': optimizer.state_dict()
            }, 'results/checkpoint_net.pth')

  對應的在加載模型參數時,使用如下代碼進行加載是沒有問題的:

checkpoint = torch.load('./results/checkpoint_net.pth')
model.load_state_dict(checkpoint['model'])
  一般情況下,在保存模型時我們不會發現會有什么不對,而是在需要加載模型參數時,才發現加載報錯了。比如:
 
  這時我們需要回頭檢查我們在保存模型參數時,是否有哪里不對。比如我這次就是這樣的,寫代碼的時候並沒有考慮到多GPU的情況,所以保存代碼如下:
  
  torch.save({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }, 'results/checkpoint_net.pth')
  

 

    請注意紅圈的地方缺了“module”關鍵字,導致在保存模型參數時,參數保存成了這樣(模型參數是以key-value的形式保存的),即stat_dict(key),對應的value每個值都多了一個module:

 

  接下來在加載模型參數時,如果直接使用代碼 model.load_state_dict(torch.load('模型參數文件存放路徑')['state_dict'])就會出現問題。報錯如下:

 

 

  好了,既然知道了出問題的原因在哪里,那就來考慮下如何處理了,兩種方案:

  第一,修改保存模型的代碼(加上"module")后,把模型重新訓練一次,重新加載即可。但我們大家都知道,這樣的深度模型訓練,時間一般都是以小時或者天計的,我們等不了那么久。(如果時間允許,可以這么干。哈哈!)

  第二,在加載模型參數之前,寫代碼將模型參數里的"module"關鍵字給去掉。比如可以這么寫:

  

 實話實說,這個代碼並不是我的原創,網上給出這個解決方案的地方很多。但我這里有一點不同的時,我加了個“[state_dict]”,我看到的很多地方是沒有這個的,直接就是ckpt.items()。因為我並不知道他們保存模型參數的代碼是怎么寫的,所以也並不好評論對錯。但總之一句話,我們是要通過這段代碼,去掉狀態字典里的"module"關鍵字的所以大家可以通過debug,查看這里的k取到的是什么值,應該要是取到下圖所示紅色框里的值,然后通過“name=k[7:]”去掉前面的"module",然后再加載就可以了。

 

  文中提到一個詞“[state_dict]”,大家不用太在意,有的人在保存模型參數時,用的是“model”,只要在保存和讀取的時候,保持一致就可以了。

 歡迎大家對描述不清楚或者不准確的地方提出批評意見和建議!


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM