torch保存加載模型


三個核心函數

torch.save() 
torch.load()
torch.nn.Module.load_state_dict()

狀態字典定義

狀態字典本質上就是普通的python字典。

  • 對於具有可學習參數的網絡層來說,狀態字典的鍵就是網絡層,值就是對應的參數張量。
    大概如下圖所示,網絡層的可學習參數包括權重和偏置等。

    當然batchnorm層也有需要保存的參數,比如running_mean。
  • 對於優化器對象也有自己的狀態字典。其中包含了優化器狀態信息和超參數。優化器的狀態字典一般只有斷點訓練的時候才使用,畢竟推理也用不到優化器。

只保存/加載模型參數(推薦做法)

# 保存模型參數
torch.save(model.state_dict(), PATH)  
# 加載模型參數並用於推理
model = MyModel()
model.load_static_dict(torch.load(PATH))
model.eval()
  • torch.save()保存的文件后綴通常是 .pt 或 .pth
  • 保存模型參數的對象model和加載模型參數的對象model應該是同一個類的實例。
  • load_static_dict()方法的參數是一個字典,必須先用torch.load()把保存的參數轉化成python字典。
  • 進行推理之前,必須先用model.eval()把dropout和BN層置為驗證模式。

保存/加載整個模型

# 保存整個模型
torch.save(model, PATH)
# 加載整個模型
model = torch.load(PATH)
model.eval()

斷點訓練checkpoint使用

# 保存斷點狀態,保存的文件后綴一般是.tar。
torch.save({
  'epoch': epoch,
  'model_state_dict': model.state_dict(),
  'loss': loss,
  ...
}, PATH)

# 加載斷點
model = MyModel()
optimizer = MyOptimizer()

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.train()
# model.eval() # 恢復斷點之后直接推理也是可以的

同一個文件中保存多個模型

# 其實本質上跟checkpoint的使用是一樣的
torch.save({
            'modelA_state_dict': modelA.state_dict(),
            'modelB_state_dict': modelB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            ...
            }, PATH)
# 加載多個模型,本質上跟checkpoint也是一樣的,保存文件后綴名也是.tar
modelA = MyModel()
modelB = MyModel()
optimizerA = MyOptimizer()
optimizerB = MyOptimizer()

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

用一個模型的參數來初始化另一個不同模型

# 保存模型參數
torch.save(modelA.state_dict(), PATH)
# 加載模型參數
modelB = MyModel()
modelB.load_state_dict(torch.load(PATH), strict=False)
  • load_state_dict()方法中strict=False表示忽略不匹配的網絡層,畢竟兩個網絡不一樣

不同設備保存/加載模型

  • 保存時候沒區別,反正都是保存到磁盤上
    torch.save(model.state_dict(), PATH)
  • 加載模型到cpu上
    device = torch.device('cpu')
    model = MyModel()
    model.load_state_dict(torch.load(PATH, map_location=device))
    
  • 加載模型到GPU上
    # 有點奇怪,為啥不用map_location參數,而要先加載再轉移到GPU上
    device = torch.device('cuda')
    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH)
    model.to(device)
    


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM