PyTorch之Checkpoint機制解析


PyTorch之Checkpoint機制解析

本文已授權極市平台, 並首發於極市平台公眾號. 未經允許不得二次轉載.

原文鏈接:https://www.yuque.com/lart/ugkv9f/azvnyg

PyTorch 提供了一種非常方便的節省顯存的方式,就是 Checkpoint 機制。這篇文章的目的在於更透徹的了解其內在的機制。

Checkpoint 機制

該技術的核心是一種使用時間換空間的策略。在現有的許多方法中被大量使用,例如 DenseNet、Swin Transformer 源碼中都可以看到它的身影。

為了了解它的工作原理,我們先得弄明白的一個問題是,PyTorch 模型在訓練過程中顯存占用主要是用來存儲什么?

關於這一點,Connolly 的文章 《PyTorch 顯存機制分析》 介紹的非常詳細:

開門見山的說,PyTorch 在進行深度學習訓練的時候,有 4 大部分的顯存開銷,分別是模型參數(parameters)模型參數的梯度(gradients)優化器狀態(optimizer states) 以及 中間激活值(intermediate activations) 或者叫中間結果(intermediate results)。

而通過 Checkpoint 技術,我們可以通過一種取巧的方式,使用 PyTorch 提供的 “no-grad”no_grad())模式來避免將這部分運算被autograd記錄到反向圖“backward graph”中,從而避免了對於中間激活值的存儲需求。

個人理解(歡迎指出錯誤):

前向傳播時 autograd 記錄各個操作反向傳播需要的一些信息和中間變量。反向傳播之后,用於計算梯度的中間結果會被釋放。也就是說,模型參數、優化器狀態和參數梯度是始終在占用着存儲空間的,中間激活值在反向傳播之后就自動被清空了。具體顯存占用變化可見 PyTorch 顯存占用分析 ,這里我簡單修改了 《PyTorch 顯存機制分析》 中給出的例子 進行了一下驗證

這里實際上會引申出另一個問題,為什么自定義 Function 一般情況下會減少顯存占用?(在 Vision Longformer 中各種實現的對比里可以明顯看到這一現象)

我覺得主要是因為自定義 Function 的時候,我們可以從一整個模塊的角度來更有針對性的在 ctx 中存儲中間變量,而自動求導引擎可能關注的太細了,導致存儲許多不必要的中間變量。關於這一點暫時不知道如何驗證。

這可以避免存儲模型特定層中間運算結果,從而有效降低了前向傳播中顯存的占用。 這些中間結果會在反向傳播的時候被即時重新計算一次。要注意,被 checkpoint 包裹的層反向傳播時仍然會在第一次反向傳播的時候開辟存儲梯度的空間。

因為 checkpoint 是在 torch.no_grad() 模式下計算的目標操作的前向函數,這並不會修改原本的葉子結點的狀態,有梯度的還會保持。只是關聯這些葉子結點的臨時生成的中間變量會被設置為不需要梯度,因此梯度鏈式關系會被斷開。

通過這樣的方式,雖然延長了反向傳播的時間,但是卻也在一定程度上緩解了存儲大量中間變量帶來的顯存占用。

源碼解析

以下代碼來自 PyTorch v1.10.1 版本:https://github.com/pytorch/pytorch/blob/v1.10.1/torch/utils/checkpoint.py。最新的版本中補充了一些新的內容,待其最終發布后再說吧,下面的內容本身已經將 checkpoint 的核心介紹了。

輔助函數

這部分代碼中首先構造了數個輔助函數,主要是用來做一些針對輸入的檢查和處理,同時也要處理好隨機種子的問題。

def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue
            
            # 直接detach(),從inp所在的計算圖中剝離,默認會自動將requires_grad置為False
            x = inp.detach()
            # 但是這里的實際需求中,仍需要保持其自身的需要記錄梯度的屬性,且其梯度變為None
            x.requires_grad = inp.requires_grad
            # 因為只有需要保存梯度的參數才能夠構建梯度的傳播路徑
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)

def check_backward_validity(inputs: Iterable[Any]) -> None:
    """檢查輸入參數是否至少有一個需要記錄梯度的Tensor,這樣才能確保輸出也有梯度。"""
    if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")

由於需要重復計算,所以隨機狀態的一致性是需要重視的。由於前向傳播的部分在反向過程中仍會計算一次,所以如果不使用原始的隨機狀態的話,會導致重新計算和原本正常計算過程中的隨機狀態不同,而影響模型的行為。

另外在這段代碼的注釋中提到了一點有趣的地方:

由於無法獲悉被 checkpoint 處理的操作是否在運算中間會將一些參數移動到不同的設備上,這可能需要手動保存這些設備對應的隨機狀態。當前的實現直接保存了所有可見設備上的隨機狀態,但是這樣有時可能是不必要的,但是目前尚沒有較好的解決策略。

所以按照文檔的意思,就是在說如果沒有這樣的移動,那就可以不用保存隨機狀態咯?這一點其實有些令人疑惑。

# We can't know if the run_fn will internally move some args to different devices,
# which would require logic to preserve rng states for those devices as well.
# We could paranoically stash and restore ALL the rng states for all visible devices,
# but that seems very wasteful for most cases.  Compromise:  Stash the RNG state for
# the device of all Tensor args.
#
# To consider:  maybe get_device_states and set_device_states should reside in torch/random.py?
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
    """獲取不同輸入對應的GPU設備的隨機數生成器的狀態"""
    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
    # the conditionals short-circuit.
    fwd_gpu_devices = list(set(arg.get_device() for arg in args
                               if isinstance(arg, torch.Tensor) and arg.is_cuda))

    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())

    return fwd_gpu_devices, fwd_gpu_states

def set_device_states(devices, states) -> None:
    """針對不同的設備設置隨機數生成器的狀態"""
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)

核心 Function

可以看到,這里的 Checkpoint 本身就是基於 PyTorch 的 PyTorch 自定義算子之 Function 實現的一個擴展算子,所以該部分代碼也涉及到了Function的諸多功能。

閱讀它既可以幫助我們同時復習一下相關的知識,又能進一步了解更復雜的處理邏輯該如何搭建。

class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        # 暫存前向傳播函數
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # 用來保存當前模型的混合精度的狀態,以用在反向傳播中
        ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
        if preserve_rng_state:  # 保存目標模塊前向傳播之前,此時CPU和GPU的隨機數生成器的狀態
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:  
                # PyTorch提供的一個內部變量,用於判定CUDA狀態是否已經被初始化了
                # torch.cuda.is_initialized中就用到了該變量
                ctx.had_cuda_in_fwd = True
                # 保存輸入變量涉及的各個GPU設備的隨機狀態
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        # save_for_backward()中保存反向傳播中需要用到的輸入和輸出tensor量。
        # 由於在反向傳播中需要重新計算記錄梯度的output,所以就不要保存output了。
        # 並且后面的計算也不需要在梯度模式下計算。
        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():  
            # 不保存梯度的前向傳播操作,也就是說這里的output是不會記錄中間變量,無法直接計算梯度的。
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument.")
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors # 獲取前向傳播中保存的輸入tensor

        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward.  Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        
        # 使用之前前向傳播開始之前保存的隨機數生成器的狀態來進行一次一模一樣的前向傳播過程
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            # 使用上下文管理器保護原始的隨機數生成器的狀態,內部處理后在進行復原
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            # 這里將inputs從計算圖中剝離開,但是其屬性requires_grad和原來是一樣的,這么做的目的是為了截斷反向傳播的路徑。
            # 從整個操作目的來看,由於我們需要重新計算輸出,並將梯度回傳到輸入上,所以輸入本身需要可以記錄梯度。
            # 但是這里的回傳不可以影響到checkpoint之外更靠前的那些操作,
            # backward之后會將之前保存的中間變量釋放掉,而我們僅僅是為了計算當前一小塊結構,所以梯度回傳需要截斷。
            detached_inputs = detach_variable(tuple(inputs))  # 會變成葉子結點,grad和grad_fn均重置為None
            # 處理完隨機狀態之后,就該准備着手重新前向傳播了。
            # 這次前向傳播是在梯度模式(torch.enable_grad())下執行的。此時會保存中間變量。
            with torch.enable_grad(), torch.cuda.amp.autocast(ctx.had_autocast_in_fwd):
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # run backward() with only tensor that requires grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            # 記錄需要計算梯度的輸出outputs[i]以及對應的回傳回來的有效梯度args[i]
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        # 檢查需要計算梯度的輸出,如果沒有輸出需要計算梯度,那么實際上就說明這個模塊是不參與梯度計算的,
        # 也就是說,該模塊不需要使用checkpoint來調整。
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of output has requires_grad=True,"
                " this checkpoint() is not necessary")
        # 該操作對被包裹的目標操作計算反向傳播,即計算回傳到輸入detached_inputs上的梯度。
        # 由於輸入的tensor已被從整體梯度圖中剝離,所以可以看做是一個葉子結點,可以在反向傳播之后獲得其梯度,並且中間變量也會隨之釋放。
        # 另外這里反傳計算梯度也不會導致將更靠前的結構中暫時保存來計算梯度的參數給釋放掉。
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        # 如果前面不執行detach(),這里的inp.grad會被直接釋放並置為None,這並不符合預期
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in detached_inputs)

        # 這里返回的梯度與當前類的forward輸入一一對應,
        # 由於這里的forward包含着本不需要梯度的兩個參數run_function、preserve_rng_state,故對應回傳None即可。
        return (None, None) + grads

這里實際上就是在原始的操作和整體的計算圖之間添加了一個中間層,用於信息的交互:

  1. 原始模型的數據傳輸到被包裹的目標層的時候,數據進入 checkpoint 的 forward() 中,被 checkpoint 進行檢查和記錄后,再送入目標層中;
  2. 目標層在非梯度模式下執行前向傳播。該模式下,新創建的 tensor 都是不會記錄梯度信息的;
  3. 目標層的結果通過 checkpoint 的前向傳播輸出,送入模型后續的其他結構中;
  4. 執行反向傳播,損失求導,鏈式回傳,計算梯度;
  5. 回傳回來的對應於 checkpoint 輸出的梯度被送入其對應的反向傳播函數,即 checkpoint 的 backward()
  6. 梯度送入 checkpoint 中后,需要進一步將梯度回傳到目標層的輸入上。由於在 checkpoint 的 forward 中目標層本身前向傳播是處於非梯度狀態下,所以回傳路徑上缺少目標層中操作的梯度子圖。於是為了獲取這部分信息,需要先梯度狀態下對目標層進行一次前向傳播,通過將回傳回來的梯度和目標層的輸出一起執行 torch.autograd.backward(outputs_with_grad, args_with_grad),從而獲得對應輸入的梯度信息。
  7. 將對應目標操作輸入的梯度信息按照 checkpoint 本身 Function 的 backward 的需求,使用 None 對其他輔助參數的梯度占位后進行返回。
  8. 返回的對應於其他模塊的輸出量的梯度,被沿着反向傳播的路徑送入對應操作的 backward 中,一層一層回傳累加到各個葉子節點上。

定義好操作后,進行一個簡單的包裝,同時處理一下默認參數,補充了更細致的文檔:

def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
    r"""Checkpoint a model or part of the model
    
    Checkpointing works by trading compute for memory. Rather than storing all
    intermediate activations of the entire computation graph for computing
    backward, the checkpointed part does **not** save intermediate activations,
    and instead recomputes them in backward pass. It can be applied on any part
    of a model.
    
    Specifically, in the forward pass, :attr:`function` will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. Instead, the forward pass saves the inputs tuple and the
    :attr:`function` parameter. In the backwards pass, the saved inputs and
    :attr:`function` is retrieved, and the forward pass is computed on
    :attr:`function` again, now tracking the intermediate activations, and then
    the gradients are calculated using these activation values.
    這一段詳細介紹了checkpoint的核心技術,也就是在非梯度模式下執行目標操作的前向傳播,只保留輸入和結構參數,省去了中間激活的保存。反向傳播時在梯度模式下重新計算這些激活,重建這部分反向圖,進而實現了梯度的正常回傳。
    
    The output of :attr:`function` can contain non-Tensor values and gradient
    recording is only performed for the Tensor values. Note that if the output
    consists of nested structures (ex: custom objects, lists, dicts etc.)
    consisting of Tensors, these Tensors nested in custom structures will not
    be considered as part of autograd.
    因為checkpoint的backward實現的邏輯中,直接遍歷目標操作的輸出(會被自定轉換成元組類型)並確定那些需要回流梯度的輸出。如果輸出中包含其他的非tensor結構,就會導致在遍歷過程中這些輸出被忽略掉。不過也確實,這樣直接簡化處理雖然使得靈活性下降,但是卻也避免了代碼過於復雜。
    
    .. warning::
        Checkpointing currently only supports :func:`torch.autograd.backward`
        and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
        is not supported.
    
    .. warning::
        If :attr:`function` invocation during backward does anything different
        than the one during forward, e.g., due to some global variable, the
        checkpointed version won't be equivalent, and unfortunately it can't be
        detected.
        盡量保證目標操作在反向計算期間和前向期間的操作的一致性。
        因為在checkpoint會在反向中重新計算一次前向,這可能會帶來一些由於無法檢測到的不確定因素而造成的與常規版本的差異。
        
    .. warning::
        If checkpointed segment contains tensors detached from the computational
        graph by `detach()` or `torch.no_grad()`, the backward pass will raise an
        error. This is because `checkpoint` makes all the outputs require
        gradients which causes issues when a tensor is defined to have no
        gradient in the model. To circumvent this, detach the tensors outside of
        the `checkpoint` function.
        不要在目標操作中包含detach或者非梯度模式的處理。
        **在我的實際測試中似乎並沒有這個問題?**或許這里應該看一下pytorch提供的測試案例。
        
    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients. At least one of the outputs needs to have
        :code:`requires_grad=True` as well.
        要保證至少有一個輸入是requires_grad的,這樣才可以保證這部分操作可以被記錄梯度。
        也要保證輸出至少有一個需要計算梯度。

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional, default=True):  Omit stashing and restoring
            the RNG state during each checkpoint.
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
    """
    # Hack to mix *args with **kwargs in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    return CheckpointFunction.apply(function, preserve, *args)

應用案例

Checkpoint for Sequential

PyTorch 源碼中給了一個很直接的應用案例,就是將 checkpoint 應用於 Sequential 搭建起來的模型。按照分段數 segments 指定的,將模型划分為多段。

def checkpoint_sequential(functions, segments, input, **kwargs):
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a model in various segments
    and checkpoint each segment. All segments except the last will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. The inputs of each checkpointed segment will be saved for
    re-running the segment in the backward pass.

    See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.

    .. warning::
        Checkpointing currently only supports :func:`torch.autograd.backward`
        and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
        is not supported.

    .. warning:
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    .. warning:
        Since PyTorch 1.4, it allows only one Tensor as the input and
        intermediate outputs, just like :class:`torch.nn.Sequential`.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model
        input: A Tensor that is input to :attr:`functions`
        preserve_rng_state(bool, optional, default=True):  Omit stashing and restoring
            the RNG state during each checkpoint.

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`*inputs`

    Example:
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """
    # Hack for keyword-only parameter in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    def run_function(start, end, functions):
        def forward(input):
            for j in range(start, end + 1):
                input = functions[j](input)
            return input
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children()) 
        # 獲取Sequential的子模塊,這里使用children方法,僅獲取最外層

    segment_size = len(functions) // segments
    # the last chunk has to be non-volatile (為什么?似乎加上也是可以的)
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        # 迭代式的將各個子模塊集合使用checkpoint包裝並前向傳播。
        input = checkpoint(run_function(start, end, functions), input,
                           preserve_rng_state=preserve)
    # 剩余的結構不再使用checkpoint
    return run_function(end + 1, len(functions) - 1, functions)(input)

參考鏈接


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM