Pytorch加載預訓練模型的坑


保存模型:

def save(model, model_path):
  torch.save(model.state_dict(), model_path)

加載模型:

def load(model, model_path):
  model.load_state_dict(torch.load(model_path))

這樣會出現一個問題,即明明指定了某張卡,但總有一個模型的顯存多出來,占到另一張卡上,很煩人,看到知乎有個方法可以解決

https://www.zhihu.com/question/67209417/answer/355059967

說是把模型的數據放在CPU上就可以解決,等試一下效果

def load(model, model_path):
  model.load_state_dict(torch.load(model_path,map_location=torch.device('cpu')))
 
         
         
       


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM