保存模型:
def save(model, model_path): torch.save(model.state_dict(), model_path)
加載模型:
def load(model, model_path): model.load_state_dict(torch.load(model_path))
這樣會出現一個問題,即明明指定了某張卡,但總有一個模型的顯存多出來,占到另一張卡上,很煩人,看到知乎有個方法可以解決
https://www.zhihu.com/question/67209417/answer/355059967
說是把模型的數據放在CPU上就可以解決,等試一下效果
def load(model, model_path):
model.load_state_dict(torch.load(model_path,map_location=torch.device('cpu')))