[源碼解析] 深度學習流水線並行 GPipe(3) ----重計算
0x00 摘要
GPipe是一個基於 Lingvo (Lingvo 是 Google 基於 TensorFlow 二次開發的重點針對序列模型的框架)開發的,支持超大規模模型的神經網絡訓練並行庫,本文介紹其重計算功能,同時可以和其他實現一起印證。
本系列其他文章如下:
[源碼解析] 深度學習流水線並行Gpipe(1)---流水線基本實現
[源碼解析] 深度學習流水線並行GPipe (2) ----- 梯度累積
0x01 概述
1.1 前文回顧
前文提到,目前分布式模型訓練有幾個必要並行技術:
- 流水並行,尤其是如何自動設定流水;
- 梯度累加(Gradient Accumulation);
- 后向重計算;
- 1F1B 策略(我們將采用PipeDream分析);
在前文中,我們介紹了Gpipe如何實施流水線並行技術,以及梯度累積。
流水並行存在一個問題:顯存占用太大。如果每個 micro-batch 前向計算的中間結果(activation)被后向計算所消費,則需要在顯存中緩存 n 份(梯度累加的次數)完整的前向 activation。這時就不得不用另一項重要的技術:重計算(Checkpointing)。
本文以論文"Training deep nets with sublinear memory cost"為基礎,對於 pytorch 和 Gpipe 源碼 進行分析,期望可以對 “Gradient checkpointing”技術有一個具體的理解。
1.2 Gradient checkpointing
2016年,陳天奇團隊提出了亞線性內存優化相關的 "gradient/activation checkpointing(后向重計算)"等技術,旨在降低深度學習訓練過程中的中間激活(activation)帶來的顯存占用。Checkpointing技術屬於亞線性內存優化的一種,除此之外還有CPU offload等技術(CPU offload在微軟Deepspeed框架中被廣泛使用)。
梯度檢查點是一種減少深度神經網絡訓練時內存消耗的系統性方法,具體是在反向傳播中,針對每個設定為檢查點的段,通過重新運行前向傳播段來實現的:
- 梯度檢查點方法集中在減少降低存儲中間結果(特征圖)和梯度的內存開銷,因為在許多常見的深度網絡之中,與模型參數相比,中間結果要大得多。
- 梯度檢查點是一種以時間(算力)換空間(顯存)的方法,通過減少保存的激活值壓縮模型占用空間,但是在計算梯度時必須重新計算沒有存儲的激活值,即需要花兩倍的前向傳播計算時間。
- 具體來說,就是設置一些梯度檢查點,檢查點之外的中間結果先釋放掉,將來在反向傳播的過程中如果發現前向結果不在顯存中,就找到最近的梯度檢查點再進行前向計算,恢復出被釋放的張量。
0x02 背景知識
2.1 求導如何工作
此處借鑒了 訓練時顯存優化技術——OP合並與gradient checkpoint 的思路。
DNN模型由一系列不同類型的層組成(例如卷積層,全連接層,池化層)。
反向傳播的關鍵是“自動鏈式求導”,但實際上BP在這個基礎上也加入了一點動態規划機制。一般的BP包含以下兩個步驟:
- 前向傳導。以圖像分類為例,當前模型首先對一小部分訓練樣本(也稱為minibatch)進行預測。這個過程被稱為向前傳導。
- 為了進行預測,來自小批量的輸入數據被輸入到模型的第一層。
- 然后,每一層在其輸入上計算一個函數,為下一層生成輸出。前向傳導記錄以下兩個值:中間結點的輸出值,輸出值關於輸入值的梯度。
- 最后一層的輸出是類預測。基於模型的預測標簽和每個圖像的實際標簽,輸出層計算損失(或錯誤)。
- 反向傳播梯度計算。反向傳播就是一個計算網絡最終輸出值關於本層輸出的梯度的過程。即,從輸出開始,反向傳播梯度值,計算輸出值對於每一個中間變量的梯度,並保存。每層計算 前一層的誤差,和 所有相關層的權重更新(損失梯度),這將使模型的預測朝着所需的輸出移動。
在梯度回傳的過程中需要用到節點的輸出值,但是在反向傳播進行梯度計算的時候,BP不會進行重復計算,其原因就是在前向傳導的時候,進行了中間變量的存儲,也就是每個中間節點的輸出值。BP不斷地反向傳播梯度,並保存中間梯度,直到計算圖的所有中間值以及初始值的梯度被求解完畢。
我們看看反向傳播是如何工作的。
所謂自動求導框架實際上是“半自動”的:它並非直接求出一個復雜函數導數的解析形式,而是通過構建計算圖和預先寫好的基礎函數的求導規則,結合鏈式求導法則實現的自動求導。
我們假設一個函數為例進行說明,其表達式如下:
f(x) = x * (x + 1)
通過簡單的數學推導得到其梯度的解析式為f'(x) = x + 1 + x
;先把這個結果放一邊,看看自動求導框架是如何一步步求出這個結果的,畫出計算圖如下:
+---------+
| |
+------>+ x + 1 +----+
| | | | 3
2 | +---------+ |
| |
| v
+-----+--+ ++------+
| | | |
+------> | x +----------------> | + +---------->
| | 1 | |
+--------+ +-------+
在計算圖上,反向傳播先經過乘法運算,根據上面的求導規則:
- 路徑1上的梯度為
x + 1
; - 路徑3上的梯度為
x
; - 路徑3再反向傳播要經過路徑2,除了其梯度為
x + 1
之外,還要乘上 路徑2的梯度1
, - 路徑2和路徑1匯聚到一起,所以最終的梯度為
x + 1(路徑1)+ 1 * x(路徑2)= x + 1 + x
,剛好等於我們用數學公式計算出來的結果;
自動求導框架正是依靠這些基礎的規則和鏈式求導法則在高效准確的運作。
在絕大多數神經網絡的訓練過程中,在計算反向傳播時,前向傳播過程中得到的一些中間變量非常有用(為了方便求導)。在實際操作中,最好代碼實現對於這些中間變量的緩存,這樣在反向傳播的時候也能用上它們。於是顯存占用的大頭就是中間結果,也就是所謂的“特征圖”。對於本文,x 就是前一層輸出的中間結果(特征圖)。
在適用乘法的求導規則時,要求我們要事先保留下中間結果 x 和 x+1。注意框架定義的乘法及其求導規則是通用規則,乘法的左右兩邊完全可能是不相關的兩個值,所以必須同時保留下來。就是說,x + 1 在其他函數中,可能是 ( x + y ) + z ....,也可能包含其他輸入變量,所以無法通過 + 1 這樣簡單的算式由一個輸入 x 計算出來。
在不考慮框架自身優化的情況下,顯存占用就包括了一個 x 和 一個 x + 1,注意x可不是一個單獨的數值,而是類似32x32x128
這樣大小的特征圖。
2.2 梯度Checkpoint
如前一節所述,神經網絡的原始方式中:
- 在forward函數中,每層的激活函數值計算之后需要保存下來,因為它們需要在后向傳播的計算中被消費。
- 在backward時,根據損失函數值和該層對應的激活函數值計算梯度。
- 因此,我們需要在顯存中緩存 n 份(梯度累加的次數)完整的前向 activation。也就是說,這種情況下顯存占用與 層數成正比。
因此,目前流水並行存在一個問題:顯存占用太大。
是否可以不存儲激活值?比如在backward時,需要激活函數值的時候重新進行forward就可以了。
假如我們一個都不存儲,都通過forward重新計算?那么在大模型中這樣消耗的時間太大。所以我們可以選用折中的方式,比如只存部分層的激活函數值。當backward需要激活函數值的時候,取最近的激活值就行。所以就引入了一項重要的技術:重計算(Checkpointing)。
2.3 論文內容
2.3.1 主要論文
Gpipe 的 Checkpointing 主要思路來自以下兩篇論文:
- Andreas Griewank and Andrea Walther. Algorithm 799: revolve: an implementation of check- pointing for the reverse or adjoint mode of computational differentiation. ACM Transactions on Mathematical Software (TOMS), 26(1):19–45, 2000.
- Tianqi Chen, Bing Xu, Chiyuan Zhang, and Carlos Guestrin. Training deep nets with sublinear memory cost. arXiv preprint arXiv:1604.06174, 2016.
主要思路是用算力換內存(計算換顯存,反向求導時需要的中間結果從 checkpoint 重新計算),以及用帶寬換顯存。
2.3.2 論文 Training Deep Nets with Sublinear Memory Cost
2.3.2.1 主要思路
我們主要來看這篇論文。
Checkpointing 是陳天奇在2016年發表的論文 Training Deep Nets with Sublinear Memory Cost 中提到的,也稱之為亞線性內存優化。亞線性內存優化有兩種思路,Checkpointing 和 CPU offload:
- Checkpointing 的核心思想 是在前向網絡中標記少量的 Tensor (被 Checkpointing 的 Tensor ),前向計算就只會保留這些被標記的 Tensor, 其余的前向的 activation,會通過在反向傳播中根據 Checkpointing 的 Tensor 臨時重新計算一遍前向得到。這樣就使得大量的 activation 不需要一直保存到后向計算,有效減少了大量 Tensor 的生命周期,使得內存復用效率大幅提升。
- CPU offload 的思路類比於計算機操作系統中的“虛擬內存”技術(將不常用的內存臨時換入換出到磁盤上,從而增加內存總量),在深度學習中,GPU 顯存(Device Memory)的特點是昂貴、高速且容量小,而 CPU 主存(Host Memory)的特點是便宜、相對低速和大容量;那么將前向計算中的一些暫時用不到的 activation 臨時換出到 CPU 主存上,等到反向計算需要時再換入到 GPU 顯存里,通過這種方式也可以節省顯存。
兩種亞線性內存優化通過不同的方式達到了顯存優化:Checkpointing 是通過額外的計算開銷換顯存, CPU offload 通過額外的傳輸開銷換顯存。
2.3.2.2 Checkpointing 優化
上圖展示了做 Checkpointing 之前和之后的計算圖對比。
左面灰色的是網絡配置。
中間 Normal Gradient Graph 是普通網絡的前向后向傳播流程。
右面 Memory Optimized Gradient Graph 就是應用了 gradient-checkpoint 的結果。為了進一步減少內存,會刪除一些中間結果,並在需要時從額外的前向計算中恢復它們。
- 首先,神經網絡分為幾個部分(右面圖中就分成了三段),該算法只記住每一段的輸出,並在每一段中刪除所有中間結果。
- 其次,在反向傳播階段,我們可以通過從最近的記錄結果向前運行來重新計算丟棄的中間結果。
- 因此,我們只需支付存儲每個段的輸出的內存成本加上在每個段上進行反向傳播的最大內存成本。
所以gradient-checkpoint
就是並非是不需要中間結果,而是有辦法在求導過程中實時的計算出之前被舍棄掉的中間結果。
重計算並不是單獨為流水並行設計的,並且之前大多使用在單卡或者數據並行場景下。但這個優化在流水並行下就非常關鍵,因為它使得前向不需要緩存所有的 activation,而只需要緩存非常少個數的(比如一層 Transformer Layer 只會緩存一個 )、被 checkpoint 的特定 Tensor ,從而大大節省了流水並行下的顯存開銷。
0x03 OpenAI
在OpenAI 提出的gradient-checkpoint
就是論文Training Deep Nets with Sublinear Memory Cost思路的實現,因為其文檔比較齊全(https://github.com/openai/gradient-checkpointing),我們可以學習借鑒下。
總體思路是:在神經網絡中間設置若干個檢查點(checkpoint),對於中間結果feature map,每隔 sqrt(n)保留一個檢查點。檢查點以外的中間結果全部舍棄,反向傳播求導數的時間,需要某個中間結果時,從最近的檢查點開始計算,這樣既節省了顯存,又避免了從頭計算的繁瑣過程。
3.1 計算圖
對一個簡單的 n 層前饋神經網絡,獲取梯度的計算圖如下所示:
具體如下:
- 神經網絡的層級激活值對應於 f 標記的節點,且在正向傳播過程中,所有這些節點需要按順序計算。
- 損失函數對激活值和這些層級參數的梯度使用 b 節點標記,且在反向傳播過程中,所有這些節點需要按逆序計算。
- 計算 f 節點的激活值是進一步計算 b 節點梯度的前提要求,因此 f 節點在前向傳播后會保留在內存中。
- 只有當反向傳播執行地足夠遠以令計算對應的梯度不再需要使用后面層級的激活值或 f 的子節點時,這些激活值才能從內存中清除。這意味着簡單的反向傳播要求內存與神經網絡的層級數成線性增長關系。
3.2 重計算
簡單的反向傳播已經是計算最優的了,因為每個節點只需要計算一次。然而,如果我們願意重新計算節點,那么我們可以節省大量的內存。當我們需要節點的激活值時,我們可以簡單地重計算前向傳播的節點激活值。我們可以按順序執行計算,直到計算出需要使用激活值進行反向傳播的節點。
使用這一策略,需要令計算梯度的內存在神經網絡層的數量 n 上是穩定的,且 n 在內存方面是最優的。但是要注意,節點的計算數量現在擴展了 n^2,相比於之前的 n。n 個節點中的每一個被再計算 n 次。因此計算圖變得很慢以計算深度網絡,使得這一方法不適用於深度學習。
3.3 策略
為了在內存與計算之間取得平衡,我們需要一個策略允許節點被再計算,但是這種再計算不會發生很頻繁。這里我們使用的策略是把神經網絡激活的一個子集標記為一個節點。紫色的節點表示在給定的時間內需要儲存在內存中。
這些檢查點節點在前向傳播后保留在內存中,而其余節點最多只會重新計算一次。在重新計算后,非檢查點節點將保留在內存中,直到不再需要它們來執行反向傳播。對於簡單的前饋神經網絡,所有神經元的激活節點都是由正向傳播定義的連接點或圖的分離點。這意味着我們在反向傳播過程中只需要重計算 b 節點和最后檢查點之間的節點,當反向傳播達到了我們保存的檢查點節點,那么所有從該節點開始重計算的節點在內存中都能夠移除。
3.4 過程
首先,我們設定了兩個checkpoint,圖上第一行左面兩個紫色,注意,右面第一個紫色是輸入。
其次,正向傳播已經完成,開始反向傳播,就是從下面一行紫色1號開始反向傳播。
第三,來到了下面一行的紫色2號,它依賴於上面的紫色3號來計算(回憶一下,后向傳播計算需要前向計算的輸出),此紫色3號是checkpoint,在內存中存在,所以正常執行反向傳播
第四,來到了下面一行的白色 4 號,它依賴於上面的紫色 5 號來計算,5 號不是一個checkpoint,不在內存之中,需要重它前面的checkpoint開始計算,即從紫色 7 號開始計算。計算出來一個新的checkpoint,同時可以刪除上面一行原有紫色 5 號,因為不需要了。
第五,計算出下面的新紫色 4 號,從而繼續后向計算。
因為涉及到自動生成checkpoint,OpenAI這部分代碼比較晦澀鬼畜,所以這里不進行分析,如果有興趣的同學可以自行學習。
0x04 Pytorch 實現
我們接下來用Pyorch來看看。
4.1 基礎知識
4.1.1 Variable & Function
在PyTorch中,autograd是所有神經網絡的核心內容,為Tensor所有操作提供自動求導方法。它是一個按運行方式定義的框架,這意味着backprop是由代碼的運行方式定義的。
autograd.Variable 是autograd中最核心的類。 它包裝了一個Tensor,並且幾乎支持所有在其上定義的操作。一旦完成了你的運算,你可以調用 .backward()來自動計算出所有的梯度。
另一個對autograd的實現非常重要的類是Function,Function簡單說就是對Variable的運算,如加減乘除,relu,pool等。但它不僅僅是簡單的運算。與普通Python或者numpy的運算不同,Function是針對計算圖,需要計算反向傳播的梯度。因此他不僅需要進行該運算(forward過程),還需要利用cache保留前向傳播的輸入(為計算梯度),並支持反向傳播計算梯度。
Pytorch是利用Variable與Function來構建計算圖的。回顧下Variable,Variable就像是計算圖中的節點,保存計算結果(包括前向傳播的激活值,反向傳播的梯度),而Function就像計算圖中的邊,實現Variable的計算,並輸出新的Variable。
總結,Function與Variable構成了pytorch的自動求導機制,它定義的是各個Variable之間的計算關系。
備注:最新 PyTorch 代碼之中,已經用把 Function 修改為 Node 類,應該是為了更好的表示計算圖中節點的概念。
4.1.2 Function進一步理解
我們可以使用autograd.Function類來自定義一個模型、一個層、一個激活函數、一個損失函數,就更加好理解了,實際上本質上來說都是一個函數,只分這個函數是簡單還是復雜。
4.2 普通模式
這部分代碼位於torch/utils/checkpoint.py。pytorch是需要用戶指定checkpoint,因此實現相對簡單很多。
4.2.1 封裝
在 torch/utils/checkpoint.py 之中,對checkpoint有了一個封裝,該注釋非常值得我們閱讀,我們深入學習一下。
-
Checkpointing 本質就是用計算換內存。
-
Checkpointing 不存儲用於后向計算所需要的整個計算圖的全部中間激活值,而是在反向傳播中重新計算它們。
-
在前向傳播過程中,Checkpointing 參數 function 是運行在 torch.no_grad 模式,這樣就不會計算中間激活值了。相反,向前傳遞保存輸入元組和
function
參數。 -
在向后傳遞中,保存的輸入和
function
被取出,function
將再次被計算,這次會跟蹤中間激活值,然后使用這些激活值計算梯度。
def checkpoint(function, *args, **kwargs):
r"""Checkpoint a model or part of the model
Checkpointing works by trading compute for memory. Rather than storing all
intermediate activations of the entire computation graph for computing
backward, the checkpointed part does **not** save intermediate activations,
and instead recomputes them in backward pass. It can be applied on any part
of a model.
Specifically, in the forward pass, :attr:`function` will run in
:func:`torch.no_grad` manner, i.e., not storing the intermediate
activations. Instead, the forward pass saves the inputs tuple and the
:attr:`function` parameter. In the backwards pass, the saved inputs and
:attr:`function` is retrieved, and the forward pass is computed on
:attr:`function` again, now tracking the intermediate activations, and then
the gradients are calculated using these activation values.
The output of :attr:`function` can contain non-Tensor values and gradient
recording is only performed for the Tensor values. Note that if the output
consists of nested structures (ex: custom objects, lists, dicts etc.)
consisting of Tensors, these Tensors nested in custom structures will not
be considered as part of autograd.
Args:
function: describes what to run in the forward pass of the model or
part of the model. It should also know how to handle the inputs
passed as the tuple. For example, in LSTM, if user passes
``(activation, hidden)``, :attr:`function` should correctly use the
first input as ``activation`` and the second input as ``hidden``
preserve_rng_state(bool, optional, default=True): Omit stashing and restoring
the RNG state during each checkpoint.
args: tuple containing inputs to the :attr:`function`
Returns:
Output of running :attr:`function` on :attr:`*args`
"""
# Hack to mix *args with **kwargs in a python 2.7-compliant way
preserve = kwargs.pop('preserve_rng_state', True)
if kwargs:
raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
return CheckpointFunction.apply(function, preserve, *args)
4.2.2 處理設備
因為pytorch無法知道向前傳播函數是否會把一些參數移動到不同的設備上,這就需要一些邏輯來保存為這些設備保存RNG狀態。雖然可以為所有可見設備保存/恢復所有的RNG狀態,但是這樣在大多數情況下是一種浪費,因此作為折中,pytorch只是針對所有的張量參數的設備進行保存RNG狀態。
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
# This will not error out if "arg" is a CPU tensor or a non-tensor type because
# the conditionals short-circuit.
fwd_gpu_devices = list(set(arg.get_device() for arg in args
if isinstance(arg, torch.Tensor) and arg.is_cuda))
fwd_gpu_states = []
for device in fwd_gpu_devices:
with torch.cuda.device(device):
fwd_gpu_states.append(torch.cuda.get_rng_state())
return fwd_gpu_devices, fwd_gpu_states
def set_device_states(devices, states) -> None:
for device, state in zip(devices, states):
with torch.cuda.device(device):
torch.cuda.set_rng_state(state)
4.2.3 核心邏輯
CheckpointFunction 繼承了torch.autograd.Function。
我們可以對Function進行拓展,使其滿足我們自己的需要,而拓展就需要自定義Function的forward運算,以及對應的backward運算,同時在forward中需要通過保存輸入值用於backward。
-
forward函輸入tensor,計算輸出tensor。
- 在前向傳播過程中,Checkpointing 參數 function 是運行在 torch.no_grad 模式,這樣就不會計算中間激活值了。
- 向前傳遞保存輸入元組和
function
參數。 - 對於CheckpointFunction來說,還是需要在forward之中存儲一些另外的信息(就是上面說的 rng 信息),以供后向傳播時候計算使用。
- 進行前向傳播返回激活值。
-
backward函數接收相對於某個標量值的輸出張量的梯度,並且計算關於該相同標量值的輸入張量的梯度。
- 在向后傳遞中,保存的輸入和
function
被取出。 function
將再次被計算,這次會跟蹤中間激活值,然后使用這些激活值計算梯度。
- 在向后傳遞中,保存的輸入和
"""
我們可以通過建立torch.autograd的子類來實現我們自定義的autograd函數,
並完成張量的正向和反向傳播。
"""
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
"""
在forward函數中,接收包含輸入的Tensor並返回包含輸出的Tensor。
ctx是環境變量,用於提供反向傳播是需要的信息。我們可以使用上下文對象來緩存對象,以便在反向傳播中使用。可通過ctx.save_for_backward方法緩存數據,save_for_backward只能傳入Variable或是Tensor的變量。
"""
check_backward_validity(args)
# 保存前向傳播函數
ctx.run_function = run_function
ctx.preserve_rng_state = preserve_rng_state
ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
if preserve_rng_state:
ctx.fwd_cpu_state = torch.get_rng_state()
# Don't eagerly initialize the cuda context by accident.
# (If the user intends that the context is initialized later, within their
# run_function, we SHOULD actually stash the cuda state here. Unfortunately,
# we have no way to anticipate this will happen before we run the function.)
# 存儲前向傳播時候的設備狀態
ctx.had_cuda_in_fwd = False
if torch.cuda._initialized:
ctx.had_cuda_in_fwd = True
ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
# Save non-tensor inputs in ctx, keep a placeholder None for tensors
# to be filled out during the backward.
ctx.inputs = []
ctx.tensor_indices = []
tensor_inputs = []
for i, arg in enumerate(args): # 存儲輸入數值
if torch.is_tensor(arg):
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
else:
ctx.inputs.append(arg)
# `saved_for_backward`是會保留此input的全部信息, 並避免in-place操作導致的input在backward被修改的情況. 它是將函數的輸入參數保存起來以便后面在求導時候再使用,起前向反向傳播中協調作用。
ctx.save_for_backward(*tensor_inputs)
with torch.no_grad():
outputs = run_function(*args) # 進行前向傳播
return outputs
"""
在反向傳播中,我們接收到上下文對象和一個張量,
其包含了相對於正向傳播過程中產生的輸出的損失的梯度。
我們可以從上下文對象中檢索緩存的數據,
並且必須計算並返回與正向傳播的輸入相關的損失的梯度。
"""
# 自動求導是根據每個op的backward創建的graph來進行的
@staticmethod
def backward(ctx, *args):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError(
"Checkpointing is not compatible with .grad() or when an `inputs` parameter"
" is passed to .backward(). Please use .backward() and do not pass its `inputs`"
" argument.")
# Copy the list to avoid modifying original list.
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
tensors = ctx.saved_tensors # 獲取前面保存的參數,也可以使用self.saved_variables
# Fill in inputs with appropriate saved tensors.
for i, idx in enumerate(tensor_indices): # 利用存儲的張量重新設置input
inputs[idx] = tensors[i]
# Stash the surrounding rng state, and mimic the state that was
# present at this time during forward. Restore the surrounding state
# when we're done.
# 存儲目前rng狀態,模擬前向傳播狀態,最后恢復目前狀態
rng_devices = []
if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
rng_devices = ctx.fwd_gpu_devices
with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
if ctx.preserve_rng_state:
torch.set_rng_state(ctx.fwd_cpu_state) # 恢復前向傳播時候的設備狀態
if ctx.had_cuda_in_fwd:
set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
detached_inputs = detach_variable(tuple(inputs))
with torch.enable_grad(), torch.cuda.amp.autocast(ctx.had_autocast_in_fwd):
# 利用前向傳播函數再次計算
outputs = ctx.run_function(*detached_inputs)
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
# run backward() with only tensor that requires grad
outputs_with_grad = [] # 激活值
args_with_grad = [] # 梯度
# 從前向傳播計算的結果中篩選需要傳播的張量
for i in range(len(outputs)):
if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
outputs_with_grad.append(outputs[i])
args_with_grad.append(args[i])
if len(outputs_with_grad) == 0:
raise RuntimeError(
"none of output has requires_grad=True,"
" this checkpoint() is not necessary")
# 開始后向傳播
torch.autograd.backward(outputs_with_grad, args_with_grad)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
for inp in detached_inputs)
return (None, None) + grads
4.3 Pipeline模式
我們接下來看看 流水線模式如何進行 Checkpoint。
Pytorch 流水型並行模式是受到了GPipe的啟發,在其注釋之中有提到。
通過CheckpointFunction
,pytorch可以做到把重計算和遞歸反向傳播合並到一個自動求導函數中,因此當梯度到達時,重計算就會開始。但是在流水線模式中,為了縮減GPU idle時間,重計算需要發生在梯度到達之前進行(因為重計算其實和梯度無關,重計算可以在梯度到來之前進行以獲得激活值,等后向傳播的梯度來了之后,再集合激活值進行自己的梯度計算)。
為了解決這個問題,pytorch引入了兩個自動求導函數:class:Recompute
and class:Checkpoint
,分別代表重計算和遞歸反向傳播,就是把普通模式下的 CheckpointFunction 分離成兩個階段,這樣用這兩個函數就可以控制自動求導引擎和CUDA。具體說就是在class:Recompute
and class:Checkpoint
之間插入CUDA同步,這樣把class:Checkpoint
推遲到梯度完全拷貝結束。
分開段,就可以多個流水線stage並行了。
4.3.1 樣例
我們可以先看看 test/distributed/pipeline/sync/test_checkpoint.py 這個代碼。
其通過log的巧妙打印,可以讓我們看出來運行時候,checkpoint在前向后向傳播之中的使用。
timeline 最后結果是 ["a:forward", "b:forward", "b:forward", "b:backward", "a:forward", "a:backward"],
其中兩兩一組,分別對應了 forward pass ,Checkpoint(Log[b]),Checkpoint(Log[a])。
@pytest.mark.parametrize("device", devices)
def test_serial_checkpoints(device):
# Copied from https://github.com/pytorch/pytorch/pull/18568.
timeline = []
class Log(torch.autograd.Function):
@staticmethod
def forward(ctx, name, x):
ctx.name = name
timeline.append(f"{name}:forward")
return x.detach()
@staticmethod
def backward(ctx, grad_output):
name = ctx.name
timeline.append(f"{name}:backward")
return None, grad_output
a = torch.rand(1, device=device, requires_grad=True)
b = torch.rand(1, device=device, requires_grad=True)
# Increase the next function sequence number.
_ = a + 1 + 2 + 3 + 4 + 5
# 這里意味着最后 backward 實際會運行"a:forward", "a:backward"
a = checkpoint(partial(Log.apply, "a"), a)
a, phony = fork(a)
b = join(b, phony)
# 這里意味着最后 backward 實際會運行"b:forward", "b:backward"
b = checkpoint(partial(Log.apply, "b"), b)
c = torch.cat((a, b))
out = c.sum()
# +--> {a} --Checkpoint(Log)--> {a}
# {out} --Sum--> {c} --Cat ^-----------------------------+
# +--> {b} --Checkpoint(Log)--> {b} --First--> {b}
out.backward()
assert timeline == ["a:forward", "b:forward", "b:forward", "b:backward", "a:forward", "a:backward"]
# |----------------------| |-----------------------| |-----------------------|
# forward pass Checkpoint(Log[b]) Checkpoint(Log[a])
4.3.2 共享變量
class:Recompute
and class:Checkpoint
之間具體是通過Context這個上下文來進行共享變量的保存。
# Types for shared memory between Checkpoint and Recompute.
Recomputed = Tuple[TensorOrTensors, Tensors] # (output, input_leaf)
RNGStates = Tuple[Tensor, Optional[Tensor]] # (cpu_rng_state, gpu_rng_state)
class Context:
"""The common interface between the :class:`Checkpoint` and
:class:`Recompute` context.
"""
recomputed: Deque[Recomputed]
rng_states: Deque[RNGStates]
function: Function
input_atomic: bool
saved_tensors: Tuple[Tensor, ...]
def save_for_backward(self, *tensors: Tensor) -> None: # pragma: no cover
pass
4.3.3 rng state
根據運行時的不同,RNG狀態可能會產生不同的性能影響,所以需要在每個檢查點期間存儲當前設備的RNG狀態,在重計算之前恢復當前設備的RNG狀態。
save_rng_states 和 restore_rng_states 兩個方法分別用來存取 RNG 狀態。
def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None:
""":meth:`Checkpoint.forward` captures the current PyTorch's random number
generator states at CPU and GPU to reuse in :meth:`Recompute.backward`.
.. seealso:: :ref:`Referential Transparency`
"""
cpu_rng_state = torch.get_rng_state()
gpu_rng_state: Optional[Tensor]
if device.type == "cuda":
gpu_rng_state = torch.cuda.get_rng_state(device)
else:
gpu_rng_state = None
rng_states.append((cpu_rng_state, gpu_rng_state))
@contextmanager
def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]:
""":meth:`Recompute.backward` restores the random number generator states
captured by :func:`save_rng_states` within its context.
.. seealso:: :ref:`Referential Transparency`
"""
cpu_rng_state, gpu_rng_state = rng_states.pop()
gpu_devices: List[torch.device] = []
if device.type == "cuda":
gpu_devices.append(device)
with torch.random.fork_rng(gpu_devices):
torch.set_rng_state(cpu_rng_state)
if gpu_rng_state is not None:
torch.cuda.set_rng_state(gpu_rng_state, device)
yield
4.3.4 Checkpoint
Checkpoint 和下面的 Recompute 就是把普通模式下的 checkpoint 代碼分離成兩個階段(forward函數被分成兩段,backward 函數也被分成兩段),從而可以更好的利用流水線。
class Checkpoint(torch.autograd.Function):
@staticmethod
# type: ignore[override]
def forward(
ctx: Context,
phony: Tensor,
recomputed: Deque[Recomputed],
rng_states: Deque[RNGStates],
function: Function,
input_atomic: bool,
*input: Tensor,
) -> TensorOrTensors:
ctx.recomputed = recomputed
ctx.rng_states = rng_states
# 存RNG狀態
save_rng_states(input[0].device, ctx.rng_states)
ctx.function = function
ctx.input_atomic = input_atomic
# 為BP做准備,其實目前沒有實現
ctx.save_for_backward(*input)
# 進行前向計算
with torch.no_grad(), enable_checkpointing():
output = function(input[0] if input_atomic else input)
return output
@staticmethod
def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: # pragma: no cover
# 從保存的重計算變量中彈出所需變量
output, input_leaf = ctx.recomputed.pop()
if isinstance(output, tuple):
tensors = output
else:
tensors = (output,)
if any(y.requires_grad for y in tensors):
tensors = tuple([x for x in tensors if x.requires_grad])
# 進行自動微分
torch.autograd.backward(tensors, grad_output)
grad_input: List[Optional[Tensor]] = [None, None, None, None, None]
grad_input.extend(x.grad for x in input_leaf)
return tuple(grad_input)
4.3.5 Recompute
Recompute 就是依據保存的信息,重新計算中間變量。
class Recompute(torch.autograd.Function):
@staticmethod
# type: ignore[override]
def forward(
ctx: Context,
phony: Tensor,
recomputed: Deque[Recomputed],
rng_states: Deque[RNGStates],
function: Function,
input_atomic: bool,
*input: Tensor,
) -> Tensor:
ctx.recomputed = recomputed
ctx.rng_states = rng_states
ctx.function = function
ctx.input_atomic = input_atomic
ctx.save_for_backward(*input)
return phony
@staticmethod
def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]:
input = ctx.saved_tensors
input_leaf = tuple(x.detach().requires_grad_(x.requires_grad) for x in input)
# 取出保存的RNG狀態,進行前向計算,得到中間變量
with restore_rng_states(input[0].device, ctx.rng_states):
with torch.enable_grad(), enable_recomputing():
output = ctx.function(input_leaf[0] if ctx.input_atomic else input_leaf)
# 保存變量,為Checkpoint使用
ctx.recomputed.append((output, input_leaf))
grad_input: List[None] = [None, None, None, None, None]
grad_input.extend(None for _ in ctx.saved_tensors)
return tuple(grad_input)
4.3.6 Pipeline
4.3.6.1 Task
我們首先要看看 Task 類。代碼位於:torch/distributed/pipeline/sync/worker.py。
由注釋可知,Task 就是用來在一個分區上計算一個micro-batch。
compute
可以在worker線程內被並行執行。
finalize
應該在compute
結束之后被執行。
class Task:
"""A task represents how to compute a micro-batch on a partition.
It consists of two parts: :meth:`compute` and :meth:`finalize`.
:meth:`compute` should be executed in worker threads concurrently.
:meth:`finalize` should be executed after when worker threads complete to
execute :meth:`compute`.
:meth:`compute` might be boosted by worker threads. Because it produces
several CUDA API calls by user code. In PyTorch, parallel CUDA API calls
are not serialized through GIL. So more than one CUDA API call can be
produced at the same time.
"""
def __init__(
self, stream: AbstractStream, *, compute: Callable[[], Batch], finalize: Optional[Callable[[Batch], None]],
) -> None:
self.stream = stream
self._compute = compute
self._finalize = finalize
self._grad_enabled = torch.is_grad_enabled()
def compute(self) -> Batch:
with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled):
return self._compute()
def finalize(self, batch: Batch) -> None:
if self._finalize is None:
return
with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled):
self._finalize(batch)
4.3.6.2 compute
這里說的是 Pipeline 類的 compute 函數。
Pipeline 的邏輯如其注釋所示(PyTorch的注釋真的很翔實)。重點是 Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
這里設置了如何進行checkpoint。
可以看到,這里會將 recompute 方法設置為 Task 的 finalize 方法,然后會計划重計算。
class Pipeline:
"""The pipeline parallelism for Pipe."""
def compute(
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
) -> None:
"""Runs tasks with synchronization to copy streams."""
partitions = self.partitions
devices = self.devices
copy_streams = self.copy_streams
checkpoint_stop = self.checkpoint_stop
# Disable checkpointing if in eval mode.
if not self.partitions[0].training:
checkpoint_stop = 0
n = len(partitions)
streams = [current_stream(d) for d in devices]
exc_info: Optional[ExcInfo] = None
# With checkpointing, the autograd graph looks like this diagram:
# ┌─────┸──────┐
# │ Copy │
# └─────┰──────┘ (fence)
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┃ (compute)
# ┌─────┸──────┐
# │ Wait │ [1] Synchronize the current stream with the copy stream.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Checkpoint │ [2] Compute a partition within checkpointing.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Wait │ [3] Synchronize the copy stream with the current stream.
# └─────┰──────┘
# ┠ ─ ─ ─ ┐
# ┃ ┌─────┴─────┐
# ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
# ┃ └─────┬─────┘
# ┠ ─ ─ ─ ┘
# ┃
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┌─────┸──────┐ (fence)
# │ Copy │
# └─────┰──────┘
for i, j in schedule:
batch = batches[i]
partition = partitions[j]
# Synchronize with the copied input. ([1] in the diagram)
if j != 0:
_wait(batch, copy_streams[j][i], streams[j])
# Determine whether checkpointing or not.
checkpoint = i < checkpoint_stop
if checkpoint:
def function(
input: TensorOrTensors,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> TensorOrTensors:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return partition(input)
# 這里進行處理
chk = Checkpointing(function, batch)
# 分別設置了chk.checkpoint 和 chk.recompute
task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
del function, chk
else:
def compute(
batch: Batch = batch,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> Batch:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return batch.call(partition)
task = Task(streams[j], compute=compute, finalize=None)
del compute
# Compute tasks in parallel. ([2] in the diagram)
self.in_queues[j].put(task) # 將task插入到 pipeline的queue,這樣可以並行。
for i, j in schedule:
ok, payload = self.out_queues[j].get()
# Hold the first exception.
if exc_info is not None:
continue
elif not ok:
exc_info = cast(ExcInfo, payload)
continue
# 取出 task
task, batch = cast(Tuple[Task, Batch], payload)
# The copy stream synchronizes to copy the output. ([3] in the
# diagram)
if j != n - 1:
_wait(batch, streams[j], copy_streams[j][i])
# Finalize tasks. If checkpointing is enabled, here the
# recomputation is scheduled at backpropagation. ([4] in the
# diagram)
with use_device(devices[j]):
task.finalize(batch) # 計划進行重計算
batches[i] = batch
# Fail at the first exception.
if exc_info is not None:
raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
關於 PyTorch 的 Pipeline,后續會有專門系列進行分析。
0x05 Gpipe實現
Gpipe 在反向傳播的時候,可以在第 k-th 個 accelerator 上重新計算前向傳播函數 F_k。
5.1 API函數 _Rematerialize
首先,我們看看API方法。
在 builder.py 之中有 _Rematerialize 函數,可以用來包裝一個需要重新計算的層。
def _Rematerialize(self, name, body):
"""Forces rematerialization on FProp of the body layer."""
return builder_layers.RematerializationLayer.Params().Set(
name=name, body=body)
5.2 包裝層 RematerializationLayer
RematerializationLayer 是包裝層,其中有:
FProp 就是把被封裝層 包裝為一個函數 Fn,然后調用 py_utils.RematerializeFn 把 Fn 與 輸入變量一起傳入。
class RematerializationLayer(base_layer.BaseLayer):
"""A wrapper layer with rematerialization."""
@classmethod
def Params(cls):
p = super().Params()
p.Define('body', None,
'The main layer whose FProp will be wrapped by RematerializeFn.')
return p
def __init__(self, params):
super().__init__(params)
self.CreateChild('body', self.params.body)
def FProp(self, theta, *xs):
input_list = theta.body.Flatten() # 得到theta
theta_len = len(input_list)
input_list += list(xs) # 得到輸入參數
input_len = len(input_list)
def Fn(*args): # 包裝函數,會調用被封裝層的 FProp
body_theta = theta.body.Pack(args[:theta_len])
return self.body.FProp(body_theta, *args[theta_len:input_len])
return py_utils.RematerializeFn(Fn, *input_list) # 調用,執行FProp,並且做Gradient checking
@classmethod
def FPropMeta(cls, p, *args): # 就是傳播被封裝層的信息
py_utils.CheckShapes(args)
return p.body.cls.FPropMeta(p.body, *args)
3.2.3 tensorflow gradients 函數
RematerializeFn 調用了 tensorflow gradients 函數 來計算梯度,所以我們需要解釋下。
在tensorflow中,gradients 函數可以自動計算函數的梯度。我們只需要設計我們的函數,然后去調用 tf.gradients
函數就可以了。
tf.gradients()的參數如下,其中
tf.gradients()
實現ys
對xs
求導grad_ys
也是一個list,其長度等於len(ys)
。這個參數的意義在於對xs
中的每個元素的求導權重。
tf.gradients(ys, xs,
grad_ys=None,
name='gradients',
colocate_gradients_with_ops=False,
gate_gradients=False,
aggregation_method=None,
stop_gradients=None)
5.4 功能函數 RematerializeFn
RematerializeFn 是最終功能函數,就是調用 fn,並且在反向傳播過程中進行rematerializes fn。
def RematerializeFn(fn, *xs):
"""Calls fn and rematerializes fn in the backward pass.
`fn(*xs) -> ys`, where xs and ys can be a single tensor or a tuple of tensors.
Args:
fn: A python function to be rematerialized in the backprop pass.
*xs: A single tensor or a list/tuple of tensors. `xs` are input args to the
fn function.
Returns:
`fn(*xs)`
"""
initial_step_seed = GetStepSeed()
final_step_seed = MaybeGenerateSeedFromScope()
def Backward(fwd_xs, fwd_ys, d_fwd_ys):
"""The backward function that rematerializes forward outputs."""
del fwd_ys # 去掉傳入的參數,因為在內部需要用備份的Checkpoint來處理
always_true = tf.random.uniform([]) < 2.0
# Alternatively, can do this:
# tf.where(tf.math.is_nan(x),
# tf.constant(float('nan'), dtype=x.dtype) * tf.ones_like(x),
# x)
bak_xs = [tf.where(always_true, x, tf.zeros_like(x)) for x in fwd_xs.xs] # 依據Checkpoint來生成 bak_xs
for dst, src in zip(bak_xs, xs):
dst.set_shape(src.shape)
ResetStepSeed(initial_step_seed)
ys = fn(*bak_xs) # 依據Checkpoint來重新生成ys
MaybeResetStepSeed(final_step_seed)
dxs = tf.gradients(ys, bak_xs, grad_ys=d_fwd_ys) # ys 對 bak_xs 求導
dxs_final = [] # 聚合
for dx, x in zip(dxs, bak_xs):
if dx is None:
dxs_final.append(tf.zeros_like(x))
else:
dxs_final.append(dx)
assert len(dxs_final) == len(bak_xs)
return NestedMap(
initial_step_seed=tf.zeros_like(initial_step_seed), xs=dxs_final)
ys_shapes = []
# TODO(huangyp, yonghui): Check Forward doesn't use any stateful random ops.
def Forward(fwd_xs):
"""Forward function plus sanity checks."""
for dst, src in zip(fwd_xs.xs, xs):
dst.set_shape(src.shape)
ResetStepSeed(fwd_xs.initial_step_seed)
ys = fn(*fwd_xs.xs) # 正常計算
# Some sanity check.
assert not GetExtraInputs()
assert not GetExtraArgs()
assert not GetExtraVars()
if isinstance(ys, tuple):
for y in ys:
assert isinstance(y, tf.Tensor)
ys_shapes.append(y.shape)
else:
assert isinstance(ys, tf.Tensor)
ys_shapes.append(ys.shape)
return ys
ys = CallDefun(
Forward,
NestedMap(initial_step_seed=initial_step_seed, xs=xs),
bak=Backward)
if isinstance(ys, tuple):
for y, s in zip(ys, ys_shapes):
y.set_shape(s)
else:
ys.set_shape(ys_shapes[0])
# TODO(b/129159299): The ResetStepSeed below is needed to work around this
# bug, which is a problem with global tensors being shared by different
# inference graphs. It should be replaced with the new step seed value
# returned from the Forward function when the bug is fixed.
MaybeResetStepSeed(final_step_seed)
return ys
CallDefun定義如下,就是把fwd, back封裝起來進行調用。其中,Function 的作用是依據一個callable 構建一個TensorFlow graph function
def CallDefun(fwd, args=None, bak=None, bak_as_function=False, device=None):
"""Wraps fwd in a defun with custom gradient bak and calls it with args.
Args:
fwd: A callable xs: Nested Structure -> ys: Nested Structure.
args: A Nested Structure of tf.Tensor or None.
bak: A callable xs, ys, dys: Nested Structure -> dxs[, dcapture]: Nested
Structure. The custom backprop function for fwd. bak needs to return
dcapture if fwd uses any implicitly captured tensors, whose gradients are
dcapture.
bak_as_function: Whether to create a TF graph function for bak.
device: the device on which to run fwd and bak.
Returns:
A Nested Structure equivalent to what fwd(args) computes.
"""
if args is not None:
args = Transform(tf.convert_to_tensor, args)
sigs = Function(
fwd_sig=TensorSpecs(args),
bak=bak,
bak_as_function=bak_as_function,
device=device)(
fwd=fwd)
if args is None:
return sigs()
else:
return sigs(args)
至此,GPipe 分析完畢,下一篇開始分析 PipeDream,敬請期待。
0xFF 參考
Tensorflow實現先累加多個minibatch計算的梯度,再反向傳播
PipeDream: Fast and Efficient Pipeline Parallel DNN Training
論文解讀系列第五篇:微軟斯坦福等PipeDream快速訓練大規模神經網絡
https://cs231n.github.io/neural-networks-3/#gradcheck
https://www.cnblogs.com/geekfx/p/14182048.html
訓練時顯存優化技術——OP合並與gradient checkpoint
Pytorch筆記04-自定義torch.autograd.Function
pytorch的自定義拓展之(三)——torch.autograd.Function的簡單定義與案例
pytorch的自定義拓展之(二)——torch.autograd.Function完成自定義層