[源碼解析] PyTorch 分布式 Autograd (1) ---- 設計
0x00 摘要
本文以幾篇PyTorch官方文檔為基礎來了解分布式 autograd 的設計和內部結構,在翻譯時並沒有逐字翻譯,其中加入了自己的部分理解。分布式 autograd 后續文章的分析也會基於本文進行。
PyTorch分布式其他文章如下:
[源碼解析]PyTorch如何實現前向傳播(1) --- 基礎類(上)
[源碼解析]PyTorch如何實現前向傳播(2) --- 基礎類(下)
[源碼解析] PyTorch如何實現前向傳播(3) --- 具體實現
[源碼解析] Pytorch 如何實現后向傳播 (1)---- 調用引擎
[源碼解析] Pytorch 如何實現后向傳播 (2)---- 引擎靜態結構
[源碼解析] Pytorch 如何實現后向傳播 (3)---- 引擎動態邏輯
[源碼解析] PyTorch 如何實現后向傳播 (4)---- 具體算法
[源碼解析] PyTorch 分布式(1)------歷史和概述
[源碼解析] PyTorch 分布式(2) ----- DataParallel(上)
[源碼解析] PyTorch 分布式(3) ----- DataParallel(下)
[源碼解析] PyTorch 分布式(4)------分布式應用基礎概念
[源碼解析] PyTorch分布式(5) ------ DistributedDataParallel 總述&如何使用
[源碼解析] PyTorch分布式(6) ---DistributedDataParallel -- 初始化&store
[源碼解析] PyTorch 分布式(7) ----- DistributedDataParallel 之進程組
[源碼解析] PyTorch 分布式(8) -------- DistributedDataParallel之論文篇
[源碼解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化
[源碼解析] PyTorch 分布式(10)------DistributedDataParallel 之 Reducer靜態架構
[源碼解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 構建Reducer和Join操作
[源碼解析] PyTorch 分布式(12) ----- DistributedDataParallel 之 前向傳播
[源碼解析] PyTorch 分布式(13) ----- DistributedDataParallel 之 反向傳播
0x01 分布式RPC框架
本文主要以 https://pytorch.org/docs/master/rpc/distributed_autograd.html 為基准,但是原文檔要求用戶熟悉 Autograd 機制和分布式 RPC 框架,因為我們已經分析過 Autograd 機制,所以我們先研究一下 分布式 RPC 框架。
1.1 RPC 框架
RPC(Remote Procedure Call)是一種設計或者技術思想,而不是協議或者規范。
對於 RPC 最簡單的理解就是一個節點請求另外一個節點所提供的服務,但是對於用戶代碼來說需要維護一個"本地調用"的感覺,即,對於遠程函數調用需要像調用本地的函數一樣,遠程服務或者代碼看起來像運行在本地。
RPC 需要解決幾個問題:
- 如何通訊:即如何在調用者和服務提供者之間建立連接。
- 如何尋址:即調用者如何找到服務提供者,怎么知道其中有什么服務。
- 如何發送參數:調用者發起遠程調用時候,方法的參數需要通過 TCP 等協議傳輸到服務器,參數如何序列化?
- 如何接受參數:服務提供者收到參數之后如何反序列化,如何調用。
- 如何返回:服務提供者調用本地提供的服務之后,如何把返回值發送給調用者。
1.2 PyTorch RPC 四大支柱
以下翻譯自官方文檔 https://pytorch.org/docs/master/rpc.html。
分布式 RPC 框架通過一組原語提供了多機模型訓練機制以允許遠程通信,以及一個更高級別的 API 來自動區分拆分到多台機器上的模型。分布式 RPC 框架使遠程運行函數變得容易,支持引用遠程對象而無需復制真實數據,並提供 autograd 和優化器 API 以透明地向后運行和跨 RPC 邊界更新參數。這些功能可以分為四組 API。
- **遠程過程調用 (RPC) ** 支持使用給定的參數在指定的worker上運行函數並獲取返回值或創建對返回值的引用。有三個主要的 RPC API:
rpc_sync()(同步)、rpc_async()(異步)和remote()(異步並返回對遠程返回值的引用)。如果用戶代碼在沒有返回值的情況下無法繼續,請使用同步 API。否則,使用異步 API 獲取 Future,並在調用者需要返回值時等待 Future。remote()API 在需要遠程創建某些內容但從不需要將其獲取給調用者時很有用。想象一下driver進程設置參數服務器和訓練器的情況。Driver 可以在參數服務器上創建嵌入表,然后與訓練器共享嵌入表的引用,但其本身永遠不會在本地使用嵌入表。在這種情況下,rpc_sync()和rpc_async()已不再適用,因為他們總是意味着立即或在將來把返回值發給調用者。 - 遠程引用 (RRef)用作指向本地或遠程對象的分布式共享指針。它可以與其他 worker 共享,並且引用計數將被透明處理。每個 RRef 只有一個所有者,並且對象只存在於該所有者之中。持有 RRef 的非所有者worker 可以通過明確請求從所有者那里獲取對象的副本。當 worker 需要訪問某個數據對象,但它本身既不是對象的創建者
remote()函數的調用者也不是對象的所有者時,這很有用。分布式優化器就是此類用例的一個示例。 - Distributed Autograd將所有參與前向傳播 worker的本地 autograd 引擎縫合在一起,並在后向傳播期間自動聯系他們以計算梯度。在進行前向傳遞如果需要跨越多台機器時,這尤其有用,例如分布式模型並行訓練、參數服務器訓練等。 有了這個特性,用戶代碼不再需要擔心如何跨 RPC 邊界發送梯度和應該以什么順序啟動本地 autograd 引擎,如果前向傳遞中有嵌套和相互依賴的 RPC 調用,這可能會變得非常復雜。
- 分布優化器的構造需要一個
Optimizer()(例如,SGD(),Adagrad()等)和一個RRefs的參數列表。即,在每個不同的Ref所有者之上創建一個Optimizer()實例,然后運行step()相應更新參數。當用戶進行分布式前向和后向傳播時,參數和梯度將分散在多個 worker 中,因此需要對每個相關 worker 進行優化。Distributed Optimizer 將所有這些本地優化器合而為一,並提供了簡潔的構造函數和step()API。
1.3 RRef
下面我們以 https://pytorch.org/docs/master/rpc/rref.html 為基准來學習遠程引用協議的基本概念和部分設計細節。
RRef 是遠程參考(Remote REFerence)的縮寫。 它是位於本地或遠程工作worker上對象的引用,並且透明地在內部進行引用計數。 從概念上講,它可以被視為一個分布式共享指針。 應用程序可以調用 remote() 創建 一個RRef。 每個 RRef 都被 remote() 的調用者(即所有者)所擁有,並且可以由多個用戶使用。 所有者存儲實際數據,並跟蹤全局參考計數。 每個 RRef 可以由全局RRefId唯一標識,該全局RRefId在創建時由 remote() 調用者分配。
在所有者worker中,只有一個OwnerRRef實例包含真實數據,而在用戶worker之中,可以根據需要包含任意數量的UserRRefs,UserRRef不保存數據。當使用 RRP 時,所有者將使用全局唯一的RRefId來獲取唯一的OwnerRRef實例。 在 rpc_sync() , rpc_async() 或 remote() 調用中,所有者創建一個UserRRef,並將其用作參數或返回值。所有者將被通知並且相應更新參考計數。 如果全局沒有UserRRef實例,並且所有者上也沒有對OwnerRRef的引用,則OwnerRRef及其數據將被刪除。
1.3.1 假設條件
RRef 協議的設計基於以下假設。
- 瞬態網絡故障(Transient Network Failures):RRef 設計旨在通過重試消息來應對瞬態網絡故障。 RRef不能處理節點崩潰或永久性網絡分區,當這些事件發生時,應用程序應該關閉所有worker,還原到先前的checkpoint,然后恢復訓練。
- 非冪等 UDF (Non-idempotent UDFs):我們假設提供給
rpc_sync(),rpc_async()或remote()的用戶函數(UDF)不是冪等的,因此無法重試。 但是,內部 RRef 控制消息是冪等且消息失敗時可重試。 - 消息傳遞無序(Out of Order Message Delivery):我們不會對一對節點之間的消息傳遞順序做假設,因為發送者和接收者都使用多個線程,所以無法保證首先處理哪個消息。
接下來我們只是大致講解如何使用,具體大家可以參閱 https://pytorch.org/docs/master/rpc.html#distributed-rpc-framework。
1.3.2 同步調用
如下是同步調用API,該方法在 worker to 之上執行一個阻塞 RPC 調用來運行func。RPC 消息的發送和接收與 Python 代碼的執行並行。此方法是線程安全的。
torch.distributed.rpc.rpc_sync( to , func , args = None , kwargs = None , timeout = - 1.0 )
具體參數如下:
- to – 目標worker的name/rank/WorkerInfo。
- func (callable) – 一個可調用函數,例如 Python callables、內置運算符(例如add())和帶注釋的 TorchScript 函數。
- args –
func調用的參數元組。 - kwargs –
func調用關鍵字參數的字典。 - timeout – 用於此 RPC 的超時時間(以秒為單位)
返回值就是使用args and kwargs運行 func 的結果。
樣例:
確保 MASTER_ADDR and MASTER_PORT 已經在兩個worker之上設置。
export MASTER_ADDR=localhost
export MASTER_PORT=5678
然后在兩個不同的進程中運行以下代碼
>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
>>> rpc.shutdown()
>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
1.3.2 異步調用
如下是異步調用API,該方法在 worker to 之上執行一個非阻塞 RPC 調用來運行func。RPC 消息的發送和接收與 Python 代碼的執行並行。此方法是線程安全的。該方法立刻返回一個可以被等待的Future。
torch.distributed.rpc.rpc_async(to, func, args=None, kwargs=None, timeout=- 1.0)
具體參數如下:
- to – 目標worker的name/rank/
WorkerInfo。 - func (callable) – 一個可調用函數,例如 Python callables、內置運算符(例如add())和帶注釋的 TorchScript 函數。
- args –
func調用的參數元組。 - kwargs – 是
func調用關鍵字參數的字典。 - timeout – 用於此 RPC 的超時時間(以秒為單位)
返回一個可等待的Future對象。完成后,可以從 對象中檢索出func的返回值。
樣例:
確保 MASTER_ADDR and MASTER_PORT 已經在兩個worker之上設置。
>>> export MASTER_ADDR=localhost
>>> export MASTER_PORT=5678
然后在兩個不同的進程中運行以下代碼
>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
>>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
>>> result = fut1.wait() + fut2.wait()
>>> rpc.shutdown()
>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
0x02 示例
我們接下來以 https://pytorch.org/docs/master/rpc/distributed_autograd.html 為基礎進行學習。
假設您有兩個節點和一個跨兩個節點分區的非常簡單的模型。這可以使用torch.distributed.rpc如下實現。
分布式 autograd 背后的主要動機是在這種分布式模型上運行反向傳播loss,我們已經計算並記錄了所有需要梯度的張量的梯度。
import torch
import torch.distributed.rpc as rpc
def my_add(t1, t2):
return torch.add(t1, t2)
# On worker 0:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)
# Compute some loss.
loss = t5.sum()
0x03 前向傳播期間的 Autograd 記錄
PyTorch 在前向傳播期間構建 autograd 圖,該圖用於執行后向傳播。有關更多詳細信息,請參閱 autograd 如何編碼歷史記錄。
對於分布式 autograd,我們需要在前向傳播期間跟蹤所有 RPC,以確保正確執行后向傳播。為此,當執行 RPC 時候,我們把 send和recv functions 附加到autograd圖之上。
- 該
send函數附加到 RPC 的發起源節點之上,其輸出邊指向 RPC 輸入張量的 autograd 函數。在向后傳播期間,send函數的輸入是從目標接收的,是對應recv函數的輸出。 - 該
recv函數附加到 RPC 的接受目標節點之上,其輸入從某些運算符得到,這些運算符使用輸入張量在RPC接受目標上執行。在后向傳播期間,recv函數的輸出梯度將被發送到源節點之上,並且作為send方法的輸入。 - 每
send-recv對被分配一個全局唯一的autograd_message_id以唯一地標識該send-recv對。這對於在向后傳播期間查找遠程節點上的相應函數很有用。 - 對於RRef,每當我們調用
torch.distributed.rpc.RRef.to_here()時,我們都為涉及的張量添加了一個適當的send-recv對。
例如,這就是我們上面示例的 autograd 圖的樣子(為簡單起見,t5.sum() 被排除在外)。
我們可以看到,send方法在前向傳播中是發送者,但是在反向傳播之中就是接受者。

0x04 分布式 Autograd 上下文
每個使用分布式 autograd 的前向和后向傳播都被分配了一個唯一的torch.distributed.autograd.context,並且這個上下文具有一個全局唯一的autograd_context_id 。如果有需要,在每個節點上都會創建上下文。
上下文的作用如下:
- 運行分布式反向傳播的多個節點可能會在同一個張量上累積梯度並且存儲在張量的
.grad之上。在我們運行優化器之前,張量的.grad可能累積了來自各種分布式反向傳播的梯度。這類似於把torch.autograd.backward()在本地進行多次調用。為了提供一種把每個反向傳播梯度分離開的方法,在每個反向傳播過程里,梯度將被累積在torch.distributed.autograd.context之中。 - 在前向傳播期間,我們在上下文中存儲每個 autograd 傳播的
send和recv函數。這確保我們在 autograd 圖中保存對適當節點的引用以使其保持活動狀態。除此之外,這也使得在向后傳播期間很容易查找到對應的send和recv函數。 - 一般來說,我們也使用這個上下文來存儲每個分布式 autograd 傳播的一些元數據。
從用戶的角度來看,autograd 上下文設置如下:
import torch.distributed.autograd as dist_autograd
with dist_autograd.context() as context_id:
loss = model.forward()
dist_autograd.backward(context_id, loss)
需要注意的是,模型的前向傳播必須在分布式autograd上下文管理器中調用,因為需要一個有效的上下文來確保:所有的send和recv方法被存儲起來,並且在所有參與節點之上執行后向傳播。
0x05 分布式反向傳播
在本節中,我們將概述在分布式反向傳播期間准確計算依賴關系所遇到的挑戰,並且也講述幾種如何執行分布式反向傳播的算法(算法內部有權衡)。
5.1 計算依賴關系
首先,考慮在單台機器上運行以下代碼
import torch
a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)
d = a + b
e = b * c
d.sum.().backward()
下圖就是上面代碼對應的 autograd 圖。
作為反向傳播的一部分,autograd 引擎執行的第一步是計算 autograd 圖中每個節點的依賴項數量。這有助於 autograd 引擎知道圖中的節點何時准備好了可以執行。括號內為數字add(1)和mul(0)表示依賴關系的數量。如您所見,這意味着在向后傳播期間,add 節點需要 1 個輸入,mul節點不需要任何輸入(換句話說,不需要執行)。本地 autograd 引擎通過從根節點(在本例中是d)遍歷圖來計算這些依賴關系。
實際上,Autograd 圖中的某些節點可能不會在向后傳播中執行。這一事實對分布式 autograd 提出了挑戰。考慮這段使用 RPC 的代碼。
import torch
import torch.distributed.rpc as rpc
a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)
d = rpc.rpc_sync("worker1", torch.add, args=(a, b))
e = rpc.rpc_sync("worker1", torch.mul, args=(b, c))
loss = d.sum()
上面代碼的關聯 autograd 圖將是:

計算此分布式 autograd 圖的依賴項更具挑戰性,並且需要一些開銷(在計算或網絡通信方面)。
對於性能敏感的應用,我們可以通過假設每個send和recv函數都是反向傳播的有效成分來避免大量開銷(大多數應用不會執行未使用的 RPC)。這簡化了分布式 autograd 算法並且效率更高,但代價是應用程序需要了解這些限制。這種算法稱為FAST模式算法,下面詳細介紹。
在一般情況下, 作為向后傳播的一部分,可能不需要每個send和recv函數都是有效的。為了解決這個問題,我們提出了一種SMART 模式算法,此算法將在后面的部分中描述。請注意,目前僅實現了FAST模式算法。
5.2 FAST模式算法
該算法的關鍵假設是:當我們運行反向傳播時,每個send函數的依賴為 1。換句話說,我們假設我們會從另一個節點通過 RPC 接收梯度。
算法如下:
- 我們從具有反向傳播根的worker開始(所有根都必須是本地的)。
- 查找當前Distributed Autograd Context 的所有
send函數 。 - 從提供的根和我們檢索到的所有
send函數開始,我們在本地計算依賴項 。 - 計算依賴項后,使用提供的根來啟動本地 autograd 引擎。
- 當 autograd 引擎執行該
recv函數時,該recv函數通過 RPC 將輸入梯度發送到適當的worker。每個recv函數都知道目標 worker id,因為它被記錄為前向傳播的一部分。通過autograd_context_id和autograd_message_id該recv函數被發送到遠程主機。 - 當遠程主機收到這個請求時,我們使用
autograd_context_id和autograd_message_id來查找適當的send函數。 - 如果這是worker第一次收到對給定
autograd_context_id的請求,它將按照上面的第 1-3 點所述在本地計算依賴項。 - 然后將在第6點接受到的
send方法插入隊列,以便在該worker的本地 autograd 引擎上執行。 - 最后,我們不是在 Tensor的
.grad之上累積梯度,而是在每個Distributed Autograd Context之上分別累積梯度 。梯度存儲在Dict[Tensor, Tensor]之中 ,Dict[Tensor, Tensor]基本上是從 Tensor 到其關聯梯度的映射,並且可以使用 get_gradients() API檢索該映射 。
例如,分布式 autograd 的完整代碼如下:
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
def my_add(t1, t2):
return torch.add(t1, t2)
# On worker 0:
# Setup the autograd context. Computations that take
# part in the distributed backward pass must be within
# the distributed autograd context manager.
with dist_autograd.context() as context_id:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)
# Compute some loss.
loss = t5.sum()
# Run the backward pass.
dist_autograd.backward(context_id, [loss])
# Retrieve the gradients from the context.
dist_autograd.get_gradients(context_id)
具有依賴關系的分布式 autograd 圖如下(為簡單起見,t5.sum() 被排除在外):

應用於上述示例的FAST 模式算法如下:
- 在
Worker 0上,我們從根loss和send1開始計算依賴關系。 結果,send1對Worker 0的依賴數為 1,mul對Worker 0的依賴數為 1。 - 現在,我們在
Worker 0上啟動本地 autograd 引擎。 我們首先執行mul函數,將其輸出作為t4的梯度,累積存儲在 autograd 上下文中。 然后,我們執行recv2,它將這些梯度發送到Worker 1。 - 由於這是
Worker 1第一次知道有關此反向傳播的信息,因此它將進行依賴關系計算,並且相應地標記send2,add和recv1的依賴性。 - 接下來,在
Worker 1的本地autograd引擎上將send2插入隊列,該引擎將依次執行add和recv1。 - 當執行
recv1時,它將梯度發送到Worker 0。 - 由於
Worker 0已經計算了此向后傳播的依賴性,因此它僅僅在本地將send1插入隊列並且執行。 - 最后,
t1,t2和t4的梯度會累積在分布式 Autograd 上下文中。
5.3 SMART模式算法
該算法的全部細節仍在研究中,但對於總體思路,您可以參考RFC中的分布式 Autograd 算法智能模式部分 。
0x06 分布式優化器
該DistributedOptimizer操作如下:
- 獲取要優化的遠程參數(
RRef)列表。這些參數也可以是包含在本地RRef的本地參數。 - 將一個
Optimizer類作為本地優化器,該優化器將在所有不同的RRef擁有者之上運行。 - 分布式優化器在每個工作節點上創建一個本地
Optimizer實例,並且對於每一個Optimizer保存一個RRef。 - 當調用
torch.distributed.optim.DistributedOptimizer.step()時,分布式優化器使用 RPC 在適當的遠程工作者上遠程執行所有本地優化器。必須為torch.distributed.optim.DistributedOptimizer.step()提供一個分布式autogradcontext_id。 本地優化器使用context_id在相應上下文中存儲梯度。 - 如果多個並發分布式優化器正在更新一個 worker 上的同一批參數,這些更新將通過鎖來進行序列操作。
0x07 簡單的端到端示例
綜上所述,以下是一個使用分布式 autograd 和分布式優化器的簡單端到端示例。如果將代碼放入名為“dist_autograd_simple.py”的文件中,則可以使用以下命令運行 :MASTER_ADDR="localhost" MASTER_PORT=29500 python dist_autograd_simple.py
import torch
import torch.multiprocessing as mp
import torch.distributed.autograd as dist_autograd
from torch.distributed import rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer
def random_tensor():
return torch.rand((3, 3), requires_grad=True)
def _run_process(rank, dst_rank, world_size):
name = "worker{}".format(rank)
dst_name = "worker{}".format(dst_rank)
# Initialize RPC.
rpc.init_rpc(
name=name,
rank=rank,
world_size=world_size
)
# Use a distributed autograd context.
with dist_autograd.context() as context_id:
# Forward pass (create references on remote nodes).
rref1 = rpc.remote(dst_name, random_tensor)
rref2 = rpc.remote(dst_name, random_tensor)
loss = rref1.to_here() + rref2.to_here()
# Backward pass (run distributed autograd).
dist_autograd.backward(context_id, [loss.sum()])
# Build DistributedOptimizer.
dist_optim = DistributedOptimizer(
optim.SGD,
[rref1, rref2],
lr=0.05,
)
# Run the distributed optimizer step.
dist_optim.step(context_id)
def run_process(rank, world_size):
dst_rank = (rank + 1) % world_size
_run_process(rank, dst_rank, world_size)
rpc.shutdown()
if __name__ == '__main__':
# Run world_size workers
world_size = 2
mp.spawn(run_process, args=(world_size,), nprocs=world_size)
0xFF 參考
https://pytorch.org/docs/master/rpc/distributed_autograd.html#distributed-autograd-design
https://pytorch.org/docs/master/rpc.html#distributed-autograd-framework
https://pytorch.org/docs/master/rpc/rref.html
https://pytorch.org/docs/master/rpc.html#distributed-rpc-framework

