使用Pytorch進行單機多卡分布式訓練


一. torch.nn.DataParallel ?

pytorch單機多卡最簡單的實現方法就是使用nn.DataParallel類,其幾乎僅使用一行代碼net = torch.nn.DataParallel(net)就可讓模型同時在多張GPU上訓練,它大致的工作過程如下圖所示:
DataParallel前傳與反傳工作過程

每一個IterationForward過程中,nn.DataParallel都自動將輸入按照gpu_batch進行split,然后復制模型參數到各個GPU上,分別進行前傳后將得到網絡輸出,最后將結果concat到一起送往0號卡中。

Backward過程中,先由0號卡計算loss函數,通過loss.backward()得到損失函數相於各個gpu輸出結果的梯度grad_l1 ... gradln,接下來0號卡將所有的grad_l送回對應的GPU中,然后GPU們分別進行backward得到各個GPU上面的模型參數梯度值gradm1 ... gradmn,最后所有參數的梯度匯總到GPU0卡進行update。

注:DataParallel的整個並行訓練過程利用python多線程實現

由以上工作過程分析可知,nn.DataParallel有着這樣幾個無法避免的問題:

  1. 負載不均衡問題。gpu0所承擔的任務明顯要重於其他gpu
  2. 速度問題。每個iteration都需要復制模型且均從GPU0卡向其他GPU復制,通訊任務重且效率低;python多線程GIL鎖導致的線程顛簸(thrashing)問題。
  3. 只能單機運行。由於單進程的約束導致。
  4. 只能切分batch到多GPU,而無法讓一個model分布在多個GPU上。當一個模型過大,設置batchsize=1時其顯存占用仍然大於單張顯卡顯存,此時就無法使用DataParallel類進行訓練。

因此官方推薦使用torch.nn.DistributedDataParallel替代nn.DataParallel.


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM