torch.load()的作用


torch.load()的作用:從文件加載用torch.save()保存的對象。

api:

torch.load(f, map_location=None, pickle_module=<module 'pickle' from '/opt/conda/lib/python3.6/pickle.py'>, **pickle_load_args) 

參數:

  • f: 類似文件的對象(必須實現read(),:meth ' readline ',:meth ' tell '和:meth ' seek '),或者是包含文件的字符串。
  • map_location: 函數、torch.device或者字典指明如何重新映射存儲位置。
  • pickle_module : 用於unpickling元數據和對象的模塊(必須匹配用於序列化文件的pickle_module)。
  • pickle_load_args: 傳遞給pickle_module.load()和pickle_module.Unpickler()的可選關鍵字參數。

使用

默認加載方式,使用cpu加載cpu訓練得出的模型或者用gpu調用gpu訓練的模型:

torch.load('tensors.pt') 

將全部Tensor全部加載到cpu上:

torch.load('tensors.pt', map_location=torch.device('cpu')) 

使用函數將所有張量加載到CPU(適用在GPU訓練的模型在CPU上加載):

torch.load('tensors.pt', map_location=lambda storage, loc: storage) 

將所有張量加載到第一塊GPU(在CPU訓練在GPU加載):

torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1)) 

將張量從GPU 1映射到GPU 0(第一塊GPU訓練,第二塊GPU加載):

torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'}) 

根據你的設備,將張量加載到你當前設備上:

torch.load('modelparameters.pth', map_location = device)


作者:Blureyes
鏈接:https://www.jianshu.com/p/939de37f73e7
著作權歸作者所有。商業轉載請聯系作者獲得授權,非商業轉載請注明出處。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM