使用DataParallel進行並行化時的結構如下:
在上圖第一行第四個步驟中,GPU-1 其實匯集了所有 GPU 的運算結果。這個對於多分類問題還好,但如果是自然語言處理模型就會出現問題,導致 GPU-1 匯集的梯度過大,直接爆掉。
那么就要想辦法實現多 GPU 的負載均衡,方法就是讓 GPU-1 不匯集梯度,而是保存在各個 GPU 上。這個方法的關鍵就是要分布化我們的損失函數,讓梯度在各個 GPU 上單獨計算和反向傳播。這里又一個開源的實現:https://github.com/zhanghang1989/PyTorch-Encoding。這里是一個修改版,可以直接在我們的代碼里調用:地址。實例:
from parallel import DataParallelModel, DataParallelCriterion parallel_model = DataParallelModel(model) # 並行化model parallel_loss = DataParallelCriterion(loss_function) # 並行化損失函數 predictions = parallel_model(inputs) # 並行前向計算 # "predictions"是多個gpu的結果的元組 loss = parallel_loss(predictions, labels) # 並行計算損失函數 loss.backward() # 計算梯度 optimizer.step() # 反向傳播 predictions = parallel_model(inputs)
如果你的網絡輸出是多個,可以這樣分解:
output_1, output_2 = zip(*predictions)
如果有時候不想進行分布式損失函數計算,可以這樣手動匯集所有結果:
gathered_predictions = parallel.gather(predictions)
下圖展示了負載均衡以后的原理: