一. torch.nn.DataParallel ?
pytorch單機多卡最簡單的實現方法就是使用nn.DataParallel類,其幾乎僅使用一行代碼net = torch.nn.DataParallel(net)
就可讓模型同時在多張GPU上訓練,它大致的工作過程如下圖所示:
在每一個Iteration的Forward過程中,nn.DataParallel都自動將輸入按照gpu_batch進行split,然后復制模型參數到各個GPU上,分別進行前傳后將得到網絡輸出,最后將結果concat到一起送往0號卡中。
在Backward過程中,先由0號卡計算loss函數,通過loss.backward()
得到損失函數相於各個gpu輸出結果的梯度grad_l1 ... gradln,接下來0號卡將所有的grad_l送回對應的GPU中,然后GPU們分別進行backward得到各個GPU上面的模型參數梯度值gradm1 ... gradmn,最后所有參數的梯度匯總到GPU0卡進行update。
注:DataParallel的整個並行訓練過程利用python多線程實現
由以上工作過程分析可知,nn.DataParallel有着這樣幾個無法避免的問題:
- 負載不均衡問題。gpu0所承擔的任務明顯要重於其他gpu
- 速度問題。每個iteration都需要復制模型且均從GPU0卡向其他GPU復制,通訊任務重且效率低;python多線程GIL鎖導致的線程顛簸(thrashing)問題。
- 只能單機運行。由於單進程的約束導致。
- 只能切分batch到多GPU,而無法讓一個model分布在多個GPU上。當一個模型過大,設置batchsize=1時其顯存占用仍然大於單張顯卡顯存,此時就無法使用DataParallel類進行訓練。
因此官方推薦使用torch.nn.DistributedDataParallel替代nn.DataParallel.