pytorch保存和加載cpu,GPU,以及多GPU模型


https://www.jianshu.com/p/4905bf8e06e5

上面這個鏈接主要給出了PyTorch如何保存和加載模型

今天遇到了單GPU保存模型,然后多GPU加載模型出現錯誤的情況。在此記錄。

from collections import OrderedDict
def load_pretrainedmodel(modelname, model_):
    pre_model = torch.load(modelname, map_location=lambda storage, loc: storage)["model"]
    #print(pre_model)
    if cuda:
        state_dict = OrderedDict()
        for k in pre_model.state_dict():
            name = k
            if name[:7] != 'module' and torch.cuda.device_count() > 1: # loaded model is single GPU but we will train it in multiple GPUS!
                name = 'module.' + name #add 'module'
            elif name[:7] == 'module' and torch.cuda.device_count() == 1: # loaded model is multiple GPUs but we will train it in single GPU!
                name = k[7:]# remove `module.` 
            state_dict[name] = pre_model.state_dict()[k]
            #print(name)
        model_.load_state_dict(state_dict)
        #model_.load_state_dict(torch.load(modelname)['model'].state_dict())
    else:
        model_ = torch.load(modelname, map_location=lambda storage, loc: storage)["model"]
    return model_


由於多GPU的模型參數會多出‘module.’這個前綴,所以有時要加上有時要去掉。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM