Pytorch加载预训练模型的坑


保存模型:

def save(model, model_path):
  torch.save(model.state_dict(), model_path)

加载模型:

def load(model, model_path):
  model.load_state_dict(torch.load(model_path))

这样会出现一个问题,即明明指定了某张卡,但总有一个模型的显存多出来,占到另一张卡上,很烦人,看到知乎有个方法可以解决

https://www.zhihu.com/question/67209417/answer/355059967

说是把模型的数据放在CPU上就可以解决,等试一下效果

def load(model, model_path):
  model.load_state_dict(torch.load(model_path,map_location=torch.device('cpu')))



					


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM