[python][pytorch]多GPU下的模型保存與加載


說明

在模型訓練的時候,往往使用的是多GPU的環境;但是在模型驗證或者推理階段,往往使用單GPU甚至CPU進行運算。那么中間有個保存和加載的過程。下面來總結一下。

多GPU進行訓練

首先設置可見的GPU數量,有兩種方式可以聲明:

  1. 在shell腳本中聲明:
export CUDA_VISIBLE_DEVICES=0,1,2,3
  1. 在py文件中聲明
os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda

推薦使用前者進行聲明,因為后者可能會出現失效的情況。

多GPU模型加載

其次,要將模型分發到不同的GPU。

model = Model(args)
if torch.cuda.is_available() and args.use_gpu:
    model= model.cuda()
    model = torch.nn.DataParallel(model)

當然這里只是涉及到一個簡單的模型並行加載,里面還埋着其他的坑,如果是小數據集且顯存夠用,完全不用優化,但是如果不夠用,我們后面會詳細深挖並行中出現坑。

模型保存

等到訓練完成之后,需要將模型保存起來。需要注意的是,模型此時保存的是計算圖+參數是並行的,但是參數是單GPU的。

state = {
    'epoch': epoch,
    'model': args.model,
    'dataset': args.dataset,
    'state_dict': net.module.state_dict() if isinstance(net, nn.DataParallel) else net.state_dict(),
    'acc': top1.avg,
    'optimizer': optimizer.state_dict(),
}
torch.save(state, filename)

如果服務器環境變化不大,或者和訓練時候是同一個GPU環境,直接加載model就不會出現問題,否則建議直接使用參數加載。

模型加載

由於模型訓練和部署情況的多樣性,大致可以分為以下幾種情況:

  1. 單卡訓練,單卡加載部署,單CPU和GPU統一放到這一類。舉例:在GPU上訓練,在CPU上加載。或者在GPU上訓練,在GPU上加載。
    這類情況最簡單,簡單粗暴直接寫就行。
model = Model(args)
ckpt = torch.load(args.pretrained_model, map_location='cpu')
state = ckpt['state_dict']
net.load_state_dict(state)

注意map_location的參數,如果在gpu上進行加載,則聲明map_location='cuda:0'。如果不聲明,可能會報錯,input和weight的類型不一致。

  1. 多卡訓練,單卡加載部署。舉例:在多GPU上並行訓練,在單GPU或CPU上加載。
    這種情況要防止參數保存的時候沒有加module,那么保存的參數名稱是module.conv1.weight,而單卡的參數名稱是conv1.weight,這時就會報錯,找不到相應的字典的錯誤。
    此時可以通過手動的方式刪減掉模型中前幾位的名稱,然后重新加載。
kwargs={'map_location':lambda storage, loc: storage.cuda(gpu_id)}
def load_GPUS(model,model_path,kwargs):
    state_dict = torch.load(model_path,**kwargs)
    # create new OrderedDict that does not contain `module.`
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] # remove `module.`
        new_state_dict[name] = v
    # load params
    model.load_state_dict(new_state_dict)
    return model
  1. 單卡訓練,多卡加載部署。舉例:多見於暴發戶的情況,一開始只能單卡跑,后來有了多卡,但是單卡的參數有不想浪費。
    此時唯有記住一點,因為參數是沒有module的,而加載后的參數是有module的,因此需要保證參數加載在模型分發之前。
    即保證:
    net.load_state_dict(state)model = torch.nn.DataParallel(model)之前。

  2. 多卡訓練,多卡加載部署。環境如果沒有變化,則可以直接加載,如果環境有變化,則可以拆解成第2種情況,然后再分發模型。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM