需求
- 對基於pytorch的深度學習模型進行多卡訓練以加速訓練過程
- 由於顯卡版本過於老舊,安裝配置NCCL工程量過於龐大,希望使用簡單的pytorch代碼實現單機多卡訓練,不考慮多機多卡的顯卡通信
- 訓練完成后保存的checkpoint需要能夠在任何設備上進行加載、推理
實現
訓練
- pytorch提供了簡單的單機多卡訓練api,只需要在初始化模型之后執行下列語句將模型復制到多卡上
# initiate multi-gpu training
model = nn.DataParallel(model, device_ids=<ids of the gpus you want to use>)
- 其他操作與單卡訓練完全一致
加載checkpoint
- 上述操作后保存的checkpoint如果按照常規方法直接進行加載會報錯
RuntimeError: Error(s) in loading state_dict for <ModelName>:
Missing key(s) in state_dict:...
- debug遍歷后發現其實其狀態字典是完全一致的,只是因為我們在訓練過程中將模型定義為了多卡並行模型。這里只需要按照訓練過程中轉換為多卡模型的代碼初始化當前模型結構即可,即執行:
# initiate multi-gpu training
model = nn.DataParallel(model, device_ids=<ids of the gpus you want to use>)
- 其他操作與征程推理完全一致,若不想使用多卡/只想使用cpu,只需要按照常規將
device = torch.device("<cpu/cuda:id>")即可
Note:查閱資料過程中發現有解答建議使用參數強行忽略模型加載的錯誤torch.load(<checkpoint>, strict=False),經測試,這樣加載的模型啥也不是...不知道為什么pytorch官方要提供這個接口
