torch.load()
的作用:從文件加載用torch.save()
保存的對象。
api:
torch.load(f, map_location=None, pickle_module=<module 'pickle' from '/opt/conda/lib/python3.6/pickle.py'>, **pickle_load_args)
參數:
- f: 類似文件的對象(必須實現read(),:meth ' readline ',:meth ' tell '和:meth ' seek '),或者是包含文件的字符串。
- map_location: 函數、torch.device或者字典指明如何重新映射存儲位置。
- pickle_module : 用於unpickling元數據和對象的模塊(必須匹配用於序列化文件的pickle_module)。
- pickle_load_args: 傳遞給pickle_module.load()和pickle_module.Unpickler()的可選關鍵字參數。
使用
默認加載方式,使用cpu加載cpu訓練得出的模型或者用gpu調用gpu訓練的模型:
torch.load('tensors.pt')
將全部Tensor
全部加載到cpu
上:
torch.load('tensors.pt', map_location=torch.device('cpu'))
使用函數將所有張量加載到CPU
(適用在GPU
訓練的模型在CPU
上加載):
torch.load('tensors.pt', map_location=lambda storage, loc: storage)
將所有張量加載到第一塊GPU
(在CPU
訓練在GPU
加載):
torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
將張量從GPU 1
映射到GPU 0
(第一塊GPU
訓練,第二塊GPU
加載):
torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
根據你的設備,將張量加載到你當前設備上:
torch.load('modelparameters.pth', map_location = device)
作者:Blureyes
鏈接:https://www.jianshu.com/p/939de37f73e7
著作權歸作者所有。商業轉載請聯系作者獲得授權,非商業轉載請注明出處。