pytorch使用DataParallel並行化負載不均衡問題


使用DataParallel進行並行化時的結構如下:

 

 

在上圖第一行第四個步驟中,GPU-1 其實匯集了所有 GPU 的運算結果。這個對於多分類問題還好,但如果是自然語言處理模型就會出現問題,導致 GPU-1 匯集的梯度過大,直接爆掉。

那么就要想辦法實現多 GPU 的負載均衡,方法就是讓 GPU-1 不匯集梯度,而是保存在各個 GPU 上。這個方法的關鍵就是要分布化我們的損失函數,讓梯度在各個 GPU 上單獨計算和反向傳播。這里又一個開源的實現:https://github.com/zhanghang1989/PyTorch-Encoding。這里是一個修改版,可以直接在我們的代碼里調用:地址。實例:

from parallel import DataParallelModel, DataParallelCriterion
 
parallel_model = DataParallelModel(model)             # 並行化model
parallel_loss  = DataParallelCriterion(loss_function) # 並行化損失函數
 
predictions = parallel_model(inputs)      # 並行前向計算
                                          # "predictions"是多個gpu的結果的元組
loss = parallel_loss(predictions, labels) # 並行計算損失函數
loss.backward()                           # 計算梯度
optimizer.step()                          # 反向傳播
predictions = parallel_model(inputs)

如果你的網絡輸出是多個,可以這樣分解:

output_1, output_2 = zip(*predictions)

 

如果有時候不想進行分布式損失函數計算,可以這樣手動匯集所有結果:

gathered_predictions = parallel.gather(predictions)

下圖展示了負載均衡以后的原理:

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM