[源碼解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 構建Reducer和Join操作


[源碼解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 構建Reducer和Join操作

0x00 摘要

因為前文已經圍繞Reducer相關的各種成員變量做了相關分析,所以本文開始做動態邏輯分析,目的是:把前面幾篇文章串聯起來,為后面分析前向傳播和反向傳播設定基礎。

本系列其他文章如下:

深度學習利器之自動微分(1)

深度學習利器之自動微分(2)

[源碼解析]深度學習利器之自動微分(3) --- 示例解讀

[源碼解析]PyTorch如何實現前向傳播(1) --- 基礎類(上)

[源碼解析]PyTorch如何實現前向傳播(2) --- 基礎類(下)

[源碼解析] PyTorch如何實現前向傳播(3) --- 具體實現

[源碼解析] Pytorch 如何實現后向傳播 (1)---- 調用引擎

[源碼解析] Pytorch 如何實現后向傳播 (2)---- 引擎靜態結構

[源碼解析] Pytorch 如何實現后向傳播 (3)---- 引擎動態邏輯

[源碼解析] PyTorch 如何實現后向傳播 (4)---- 具體算法

[源碼解析] PyTorch 分布式(1)------歷史和概述

[源碼解析] PyTorch 分布式(2) ----- DataParallel(上)

[源碼解析] PyTorch 分布式(3) ----- DataParallel(下)

[源碼解析] PyTorch 分布式(4)------分布式應用基礎概念

[源碼解析] PyTorch分布式(5) ------ DistributedDataParallel 總述&如何使用

[源碼解析] PyTorch分布式(6) ---DistributedDataParallel -- 初始化&store

[源碼解析] PyTorch 分布式(7) ----- DistributedDataParallel 之進程組

[源碼解析] PyTorch 分布式(8) -------- DistributedDataParallel之論文篇

[源碼解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化

[源碼解析] PyTorch 分布式(10)------DistributedDataParallel 之 Reducer靜態架構

0x01 引論

為了更好的分析,我們還是需要看看如何調用。

1.1 調用

Reducer 的創建代碼如下,是在_ddp_init_helper 之中。

        # Note: reverse list of buckets because we want to approximate the
        # order in which their gradients are produced, and assume they
        # are used in the forward pass in the order they are defined.
        self.reducer = dist.Reducer(
            parameters, # parameters[0]是張量列表
            list(reversed(bucket_indices)), # 桶信息
            self.process_group,
            expect_sparse_gradient,
            self.bucket_bytes_cap,
            self.find_unused_parameters,
            self.gradient_as_bucket_view,
            param_to_name_mapping,
        )

1.2 參數說明

調用的 parameters 舉例如下, parameters[0] 就是 rank 0 上模型的 parameters,可以看到其只有 [0] 元素有意義,這個 [0] 原始本身包括 20 個元素:

parameters = {list: 1} 
0 = {list: 4}           
 0 = {Parameter: 10} Parameter containing:\ntensor([[-4.0381e-02,  3.8828e-02, 1  )   
 1 = {Parameter: 10} Parameter containing:\ntensor([-0.0438, -0.2033,  0.2771,  0.0721,  ) 
 2 = {Parameter: 5} Parameter containing:\ntensor([[-0.0094, -0.1319,  0.0713,  0.3155,  )
 3 = {Parameter: 5} Parameter containing:\ntensor([-0.0008,  0.0582, -0.1245, -0.2538, )
 ...
 20 = {Parameter: 5} Parameter containing:\ntensor([-0.0008,  0.0582, -0.1245, -0.2538, )                                                   
 __len__ = {int} 20
__len__ = {int} 1

bucket_indices 舉例如下:

關於 tensor indices,就是給所有的tensor一個index,從0開始遞增,一直到 tensors.size()。假如模型的 parameters 一共有20個張量,則 tensor index 從 0 到 19,分成 6 個buckets,則在這6個buckets之中,每個 tensor index 都是唯一不重復的。

+-----------------------------------------------------------------------+
|                                                                       |
|  <tensor index 0, tensor index 1, tensor index 2, tensor index 3>     |
|                                                                       |
|                                                                       |
|  <tensor index 4, tensor index 5, tensor 6>                           |
|                                                                       |
|                                                                       |
|  ......                                                               |
|                                                                       |
|                                                                       |
|  <tensor index 16, tensor index 17, tensor index 18, tensor index 19> |
|                                                                       |
+-----------------------------------------------------------------------+

接下來,我們就看看如何進行初始化 Reducer。

0x02 Reducer 初始化

代碼位於:torch/lib/c10d/reducer.h 和 torch/lib/c10d/reducer.cpp

2.1 構造函數

具體邏輯如下:

  • 看看本模塊是不是多設備模塊,具體是: 遍歷張量,得到張量的設備,把設備插入到一個set結構之中,如果set內的設備多於一個,是多設備
  • 如果 expect_sparse_gradients沒有設置,就把expect_sparse_gradients_初始化為false。
  • 調用 initialize_buckets 初始化 buckets 並盡可能按照逆序將 parameters 分配到 buckets 之中,這樣按桶通信就可以提高效率。后續在運行時候也可能再次重新初始化桶。
  • 為每個 parameter 加上 grad_accumulator,它們在 backward 時負責梯度同步。
    • 因為這些variables是autograd圖的葉子張量,所以它們的grad_fn都被設置為 gradient accumulation function。
    • Reducer保存了指向這些functions的指針,這樣Reducer就可以知道它們在autograd傳播之中是否被使用,如果沒有使用,那么就把這些functions的梯度張量(grad tensors)設置為規約就緒狀態。
    • 遍歷張量,為每個張量生成一個類型為VariableIndex的變量index。
    • 得到Variable::AutogradMeta的grad_accumulator_,即用於累加葉子 Variable 的梯度累加器。
    • 把reducer的autograd_hook函數添加進去每個grad_accumulator_之中,變量index是hook的參數。這個 hook 掛在 autograd graph 之上,在 backward 時負責梯度同步。grad_accumulator 執行完后,autograd_hook 就會運行。
  • gradAccToVariableMap_ 存了grad_accumulator & index 的對應關系(函數指針和參數張量的對應關系),這樣以后在 autograd graph 遍歷尋找 unused parameters 就方便了。
  • 初始化 backward_stats_。
  • 調用 initialize_local_used_map 初始化各種 unused map。
// The constructor takes a list of variables for every model replica.
// The bucket assignment for this reducer is specified as a list of
// buckets, each of which is specified as a list of indices into the
// variables list for **a single replica** (i.e. `variables[0]`).
Reducer::Reducer(
    std::vector<std::vector<at::Tensor>> replicas, // 張量
    std::vector<std::vector<size_t>> bucket_indices, // 桶信息
    c10::intrusive_ptr<c10d::ProcessGroup> process_group,
    std::vector<std::vector<bool>> expect_sparse_gradients,
    int64_t bucket_bytes_cap,
    bool find_unused_parameters,
    bool gradient_as_bucket_view,
    std::unordered_map<size_t, std::string> paramNames)
    : replicas_(std::move(replicas)),
      process_group_(std::move(process_group)),
      expect_sparse_gradients_(std::move(expect_sparse_gradients)),
      expect_autograd_hooks_(false),
      require_finalize_(false),
      next_bucket_(0),
      has_marked_unused_parameters_(false),
      find_unused_parameters_(find_unused_parameters),
      gradient_as_bucket_view_(gradient_as_bucket_view),
      local_used_maps_reduced_(false),
      num_iterations_(0),
      num_buckets_ready_(0),
      has_rebuilt_bucket_(false),
      bucket_bytes_cap_(bucket_bytes_cap),
      divFactor_(kUnsetDivFactor),
      static_graph_(false),
      comm_hook_(nullptr),
      thread_local_state_(at::ThreadLocalState()),
      ddp_debug_level_(parseDistDebugLevel()),
      param_names_(std::move(paramNames)) {

  // Check whether the module is multi_device_module
  // 看看本模塊是不是多設備模塊
  {
    std::set<int> unique_devices;
    for (const auto& v : replicas_[0]) { // 遍歷張量
      auto device_idx = int(v.device().index()); // 得到張量的設備
      if (unique_devices.find(device_idx) == unique_devices.end()) {
        unique_devices.insert(device_idx); // 把設備插入到一個set結構之中
        if (unique_devices.size() > 1) { // 如果set內的設備多於一個,是多設備
          is_multi_device_module_ = true; 
          break;
        }
      }
    }
  }

  // If `expect_sparse_gradients` is not specified, initialize it such that
  // we do not expect sparse gradients for any parameter.
  if (expect_sparse_gradients_.empty()) {
    expect_sparse_gradients_ = std::vector<std::vector<bool>>(
        replicas_.size(), std::vector<bool>(replicas_[0].size(), false));
  }

  // Initialize variable bucketing.
  // This can be reinitialized later after capturing runtime information.
  {
    std::lock_guard<std::mutex> lock(mutex_);
    initialize_buckets(std::move(bucket_indices)); //初始化桶
  }

  // All variables are expected to have their `grad_fn` set to the gradient
  // accumulation function (since they are leafs in the autograd graph).
  // We store pointers to these functions such that we can check if they are
  // used in an autograd pass. If they are not, we know their grad tensors
  // can be marked as ready for reduction.
  {
    const auto replica_count = replicas_.size();
    grad_accumulators_.resize(replica_count);
    for (size_t replica_index = 0; replica_index < replica_count; // 只有replicas_[0]有意義
         replica_index++) {
      const auto variable_count = replicas_[replica_index].size(); //張量數目
      grad_accumulators_[replica_index].resize(variable_count); // 給grad_accumulators_分配內存
        
      for (size_t variable_index = 0; variable_index < variable_count;
           variable_index++) { // 遍歷張量,variable_index 就是張量的index
        auto& variable = replicas_[replica_index][variable_index]; //得到具體的張量
        const auto index = VariableIndex(replica_index, variable_index); //每個張量生成一個VariableIndex

        // The gradient accumulator function is lazily initialized once.
        // Therefore we can use its presence in the autograd graph as
        // evidence that the parameter has participated in an iteration.
        auto grad_accumulator =
            torch::autograd::impl::grad_accumulator(variable); // 得到Variable::AutogradMeta的grad_accumulator_,即,用於累加葉子 Variable 的梯度累加器

#ifndef _WIN32
        using torch::distributed::autograd::ThreadLocalDistAutogradContext;
#endif
        // Hook to execute after the gradient accumulator has executed.
        hooks_.emplace_back(
            // 累加器添加hook,這個 hook 掛在 autograd graph 之上,在 backward 時負責梯度同步。
            // grad_accumulator 執行完后,autograd_hook 就會運行
            grad_accumulator->add_post_hook(
                torch::make_unique<torch::autograd::utils::LambdaPostHook>(
                    [=](const torch::autograd::variable_list& outputs,
                        const torch::autograd::variable_list& /* unused */) {
#ifndef _WIN32
                      this->rpc_context_.set(
                          ThreadLocalDistAutogradContext::getContextPtr());
#endif
                      this->autograd_hook(index); // 把reducer的autograd_hook函數添加進去
                      return outputs;
                    })),
            grad_accumulator);

        // Map raw function pointer to replica index and parameter index.
        // This is used later on when the autograd graph is traversed
        // to check for parameters for which no gradient is computed, if
        // find_unused_parameters=True.
        // Note that the mapping of gradient accumulator to variable should be
        // one to one as we deduplicate shared parameters before constructing
        // Reducer.
          
        // gradAccToVariableMap_ 存了grad_accumulator & index 的對應關系(函數指針和參數張量的對應關系),這樣以后在 autograd graph 遍歷尋找 unused parameters 就方便了
        if (find_unused_parameters_) {
          gradAccToVariableMap_[grad_accumulator.get()] = index;
        }

        numGradHooksTriggeredMap_[index] = 0;

        // The gradient accumulator is stored as weak_ptr in the autograd
        // metadata of the variable, so we have to keep it alive here for
        // the raw pointer to be valid.
        TORCH_CHECK(
            grad_accumulators_[replica_index][variable_index] == nullptr,
            c10::str(
                "Reducer tried to register duplicate grad accumulator for replica ",
                replica_index,
                " variable ",
                variable_index));
        grad_accumulators_[replica_index][variable_index] =
            std::move(grad_accumulator);
      }
    }
  }

  // Initialize backward stats vector.
  {
    const auto replica_count = replicas_.size();
    backward_stats_.resize(replica_count);
    const auto variable_count = replicas_[0].size();
    std::for_each(
        backward_stats_.begin(),
        backward_stats_.end(),
        [=](std::vector<int64_t>& v) { v.resize(variable_count); });
  }

  // See Note [Skip allreducing local_used_maps_dev]
  if (find_unused_parameters_) {
    initialize_local_used_map();
  }
}

我們接下來具體分析每一個部分。

2.2 初始化桶

initialize_buckets方法用來初始化桶,具體邏輯是對於每一個桶,添加其模型副本,對於每一個模型副本,添加張量列表:

  • 用分布式上下文設置 rpc_context_。

    • 如果在DDP構造函數內調用initialize_bucket,則 rpc上下文指針(rpc context ptr)是否為null 無關緊要,因為grad不會發生變化。
    • 如果在訓練循環期間調用initialize_bucket,例如在rebuild_bucket 內部,因為grad可能會發生改變並指向bucket_view,那么它需要檢查rpc context ptr是否為null。
    • 如果rpc context ptr是null,則改變 variable.grad(),否則,在rpc上下文中改變梯度。
  • 清空buckets_ 和 variable_locators_。

  • 重置variable_locators_的尺寸,這樣每個variable都有一個bucket index。

  • 利用如下得到所有桶的個數和每個桶中副本個數:bucket_count = bucket_indices.size(); replica_count = replicas_.size();

  • 從0開始遞增到 bucket_count,逐一初始化 Bucket。

    • 生成一個 Bucket bucket
    • 如果bucket_indices[bucket_index].size() == 1,說明這個桶期待一個single sparse gradient,則設置 bucket.expect_sparse_gradient = true。
    • 從0開始遞增到replica_count,逐一初始化 BucketReplica。
      • 生成一個 BucketReplica replica
      • 如果這個桶期待一個single sparse gradient,則
        • 利用bucket_indices[bucket_index].front()取出向量第一個元素,設置為 variable_index。
        • 利用 variable_index 得到副本之中對應的variable。
        • 設置副本replica的變量列表,代碼為replica.variables = {variable},這個副本只包括一個variable。
      • 否則說明是dense gradient,則
        • 遍歷桶的variable,即利用 replicas_[replica_index][variable_index] 得到variable。
        • 設置variable的設備和數據類型
        • 給副本設置其variables,代碼為:replica.variables.push_back(variable)。
        • 設置replica 的一些關於variable的元信息,這些元信息是flat contents相關的,比如offsets存儲了各個張量在flat bucket contents中的offset。
        • 給relica.contents分配內存
        • 利用 initialize_bucket_views(replica, replica.contents) 初始化 cotnents 和 views。
        • 利用 bucket.replicas.push_back(std::move(replica)) 把這個 replica 加入到 bucket。
    • 遍歷桶中的variable,代碼為 bucket_indices[bucket_index]。
      • 設置 Reducer.variable_locators_,這樣 Reducer 就知道如何在 bucket 之中確定一個varaible。bucket_indexbuckets_列表的位置,表示 buckets_ 之上的一個bucket。intra_bucket_index 是在 bucket replica 之中 vector 域的 variable index。
    • 設置桶的變量,bucket.variable_indices = std::move(bucket_indices[bucket_index]);
    • 利用 buckets_.push_back(std::move(bucket)) 把bucket這個桶加入到 Reducer之中。

具體代碼是:

void Reducer::initialize_buckets(
    std::vector<std::vector<size_t>> bucket_indices) {
  // If initialize_buckets is called inside DDP constructor, then
  // it does not matter rpc context ptr is nullptr or not, as grad
  // will not be mutated.
  // If initialize_buckets is called during training loop, e.g, inside
  // rebuild_buckets(), since grad could be mutated and be pointed to
  // bucket_view, then it needs to check rpc context ptr is nullptr or not,
  // If rpc context ptr is nullptr, mutate variable.grad(); otherwise,
  // mutate grad in rpc context.
#ifndef _WIN32
  using torch::distributed::autograd::ThreadLocalDistAutogradContext;
  this->rpc_context_.set(ThreadLocalDistAutogradContext::getContextPtr());
#endif

  // This shouldn't be called if we're expecting autograd hooks to fire.
  TORCH_CHECK(
      !expect_autograd_hooks_,
      "`initialize_buckets` must NOT be called during autograd execution.");

  // Clear current bucket assignment.
  buckets_.clear();
  variable_locators_.clear();

  // Ensure we have a bucket index for every variable.
  variable_locators_.resize(replicas_[0].size());

  // Iterate over buckets.
  const auto bucket_count = bucket_indices.size();
  const auto replica_count = replicas_.size();
  buckets_.reserve(bucket_count);
  // 從0開始遞增到bucket_count
  for (size_t bucket_index = 0; bucket_index < bucket_count; bucket_index++) {
    Bucket bucket; // 生成一個桶

    // TODO(@pietern): Validate indices.
    // Must be non-empty, unique, and unique across buckets.
    TORCH_CHECK(
        bucket_indices[bucket_index].size() > 0, "Empty bucket specified.");

    // Variables that expect sparse gradients must have their own bucket.
    if (bucket_indices[bucket_index].size() == 1) {
      // 說明這個桶期待一個single sparse gradient
      const auto variable_index = bucket_indices[bucket_index].front();
      bucket.expect_sparse_gradient =
          expect_sparse_gradients_[0][variable_index];
    } else {
      for (const auto variable_index : bucket_indices[bucket_index]) {
        TORCH_CHECK(
            !expect_sparse_gradients_[0][variable_index],
            "Buckets with more than one variable cannot include variables ",
            "that expect a sparse gradient.");
      }
    }

    // Iterate over model replicas. 從0開始遞增到replica_count,遍歷模型副本數目,為每一個模型副本都要做同樣設置
    for (size_t replica_index = 0; replica_index < replica_count;
         replica_index++) {
      BucketReplica replica; // 生成一個副本

      if (bucket.expect_sparse_gradient) {
        // 說明這個桶期待一個single sparse gradient
        const auto variable_index = bucket_indices[bucket_index].front(); // 得到張量的index
        const auto& variable = replicas_[replica_index][variable_index]; // 得到張量
        TORCH_INTERNAL_ASSERT(bucket_indices[bucket_index].size() == 1);
        replica.variables = {variable}; // 這個副本只包括一個variable
      } else {
        at::TensorOptions options;
        // The start index of the variable in the flattened tensor.
        size_t offset = 0;

        // Reserve enough space for the per-variable fields stored in bucket
        // replica for efficiency.
        const size_t num_variables = bucket_indices[bucket_index].size();
        replica.variables.reserve(num_variables); 
        replica.offsets.reserve(num_variables);
        replica.lengths.reserve(num_variables);
        replica.sizes_vec.reserve(num_variables);

        // Iterate over bucket variables.
        for (const auto variable_index : bucket_indices[bucket_index]) { //遍歷桶中的variable
          TORCH_CHECK(
              variable_index < replicas_[replica_index].size(),
              "Out of range variable index specified.");
          const auto& variable = replicas_[replica_index][variable_index];
          if (!options.has_device()) {
            options = options.device(variable.device());
          } else {
            TORCH_CHECK(
                variable.device() == options.device(),
                "All parameters in a bucket must be ",
                "placed on the same device.");
          }
          if (!options.has_dtype()) {
            options = options.dtype(variable.dtype());
          } else {
            TORCH_CHECK(
                variable.dtype() == options.dtype(),
                "All parameters in a bucket must have the same dtype.");
          }
          
          const auto length = variable.numel();
          // 給副本設置其variables
          replica.variables.push_back(variable); // 這里添加了一個新變量,所以最終能知道該桶中的變量數目
          // 設置replica 的一些關於variable的元信息
          replica.offsets.push_back(offset);
          replica.lengths.push_back(length);
          replica.sizes_vec.push_back(variable.sizes());
          offset += length;
        }

        // Allocate bucket contents tensor.
        replica.contents = at::empty({static_cast<long>(offset)}, options);

        initialize_bucket_views(replica, replica.contents); // 初始化cotents和views
      }

      // Add bucket replica to enclosing bucket.
      bucket.replicas.push_back(std::move(replica)); // 桶的副本列表中添加一個新副本
    }

    // Map participating variables to this bucket.
    // This is identical across replicas so we only need to do this once.
    size_t intra_bucket_index = 0;
    for (const auto variable_index : bucket_indices[bucket_index]) { // 遍歷桶中的variable
      TORCH_CHECK(
          variable_index < variable_locators_.size(),
          "Out of range variable index specified.");
      variable_locators_[variable_index] = // 這樣 Reducer 就知道如何在 bucket 之中確定一個varaible
          VariableLocator(bucket_index, intra_bucket_index++);
    }
    bucket.variable_indices = std::move(bucket_indices[bucket_index]);

    buckets_.push_back(std::move(bucket)); // 把桶插入Reducer
  }
}

2.3 初始化視圖

initialize_bucket_views 這里是設置 Replica 的contents 和 views。

// (see Note:  "Gradient Layout Contract" in initialize_buckets).
void Reducer::initialize_bucket_views(
    Reducer::BucketReplica& replica,
    at::Tensor& contents) {
  for (size_t i = 0; i < replica.variables.size(); i++) {
    auto& v = replica.variables[i];
    const auto offset = replica.offsets[i];
    const auto length = replica.lengths[i];
    if (v.is_non_overlapping_and_dense()) { // Dense類型的張量
      // If the param's memory is dense, match its layout, anticipating
      // the autograd engine (AccumulateGrad) will also create gradients
      // matching its layout.
      replica.bucket_views_in.push_back( // replica.bucket_views_in里面都是視圖
          contents.as_strided(v.sizes(), v.strides(), offset));
    } else { // Sparse類型的張量
      // Fall back to a C-style contiguous view, again anticipating
      // AccumulateGrad will do the same when stashing grads for non-dense
      // params.
      replica.bucket_views_in.push_back( // replica.bucket_views_in里面都是視圖
          contents.narrow(0, offset, length).view(v.sizes()));
    }
    // By default `bucket_views_out` and `bucket_views_in` are
    // essentially the same thing.
    replica.bucket_views_out = replica.bucket_views_in; // out也是視圖

    // If gradient_as_bucket_view_ is set as true, then there are two cases to
    // handle: initialize_bucket_views could be called inside initialize_buckets
    // when rebuild_buckets, if grad has already been defined/calculated in
    // previous iteration, old grad needs to be copied into new bucket_view and
    // let grad point to the new bucket_view, initialize_bucket_views could also
    // be called inside initialize_buckets during construction. Grads are not
    // defined during construction time, in this case, do not let grad point to
    // bucket_view, because grads should be kept as being undefined for globally
    // unused parameters.
    if (gradient_as_bucket_view_) {
      auto& bucket_view = replica.bucket_views_in.back();
      runGradCallbackForVariable(v, [&](auto& grad) {
        if (grad.defined() && !grad.is_alias_of(bucket_view)) {
          bucket_view.copy_(grad);
          grad = bucket_view; // 梯度被修改了,需要回寫
          // The grad is modefied and needs to be written back.
          return true;
        }
        // The grad is not modified and does not need to be written back.
        return false; // 不需要回寫,因為沒有被修改
      });
    }
  }
}

2.3.1 BucketReplica成員變量

我們先回憶一下BucketReplica的幾個成員變量。

  • at::Tensor contents :把桶的內容展平的結果,即Flattened (1 dimensional) 之后的結果。
  • std::vector<at::Tensor> bucket_views_in :提供了從輸入角度在 contents 之中查看具體梯度的方法。
  • std::vector<at::Tensor> bucket_views_out :提供了從輸入角度在 contents 之中查看具體梯度的方法。

關於 std::vector<at::Tensor> bucket_views_instd::vector<at::Tensor> bucket_views_out 的進一步說明:

  • 這兩個變量提供在 contents 之中操作具體梯度的方法,或者說,它們提供了視圖(views),該視圖可以操作contents 之中每個張量的梯度。用戶把這兩個變量作為入口點來把每個梯度的數據從 content 之中移入和移出。
  • 在 PyTorch 之中,視圖是指創建一個方便查看的東西,視圖與原數據共享內存,它只是將原有的數據進行整理,直接顯示其中部分內容或者進行重排序后再顯示出來。

也需要對幾個 PyTorch 函數進行說明。

  • as_strided :依據現有tensor以及給定的步長來創建一個視圖(類型仍然為tensor),需要注意,這里的結果是視圖,所以這個張量依然和原始張量共享內存。
  • narrow :返回一個新的張量,其是原來張量的縮小版,但是這個張量依然和原始張量共享內存。

BucketReplica 邏輯具體如下圖:

+------------------------------------------+
| BucketReplica                            |
|                                          |
|       vector<Tensor> bucket_views_in +--------------------+
|                                          |                |
|                                          |                |
|       vector<Tensor> bucket_views_out +--------------+    |
|                                          |           |    |
|                                          |           |    |
|                                          |           v    v
|                                          |     +-----+----+--------------------------+
|       Tensor contents  +---------------------> |Flattened (Tensor1, Tensor2, Tensor3)|
|                                          |     +-------------------------------------+
|                                          |
|                                          |
|       vector<Tensor> variables  +------------>  [Tensor1,Tensor2,Tensor3]
|                                          |
|                                          |
|                                          |
+------------------------------------------+

2.3.2 調用

如何調用?如果gradient_as_bucket_view_設置為true,則有兩種情況需要處理:

  • rebuild_buckets 之中可以在initialize_bucket內調用initialize_bucket_view,如果grad在上一次迭代中已經定義/計算過,則需要將舊的grad復制到新的bucket_view中,並讓grad指向新的bucket_view,
  • 在構造過程中,也可以在initialize_bucket中調用initialize_bucket_views。在構造期間不會定義梯度,在這種情況下,不要讓梯度指向bucket_view,因為對於全局未使用的參數,梯度應保持為未定義。

2.4 初始化本地使用變量

initialize_local_used_map此處是初始化 local_used_maps_,我們回憶一下論文內容,local_used_maps_ 就是用來查找全局未使用參數(Globally Unused Parameters):

全局未使用參數(Globally Unused Parameters)的梯度在向前和向后過程中應保持不變。檢測未使用的參數需要全局信息,因為在一個DDP過程中,一個參數可能在一次操作中不存在,但可能在另一個過程的同一次迭代中參與訓練。因此DDP在位圖中維護本地未使用的參數信息,並啟動額外的AllReduce以收集全局位圖。由於位圖比張量尺寸小得多,因此模型中的所有參數共享同一位圖,而不是創建每桶位圖(per-bucket bitmaps)。位圖位於CPU上,以避免為每次更新啟動專用CUDA內核。但是,某些ProcessGroup后端可能無法在CPU 張量上運行AllReduce。例如,ProcessGroupNCCL僅支持CUDA張量。此外,由於DDP應該與任何定制的ProcessGroup后端一起工作,它不能假設所有后端都支持CPU張量。為了解決這個問題,DDP在同一設備上維護另一個位圖作為第一個模型參數,並調用非阻塞拷貝操作(non-blocking copy)將CPU位圖移動到設備位圖以進行集合通信

具體代碼如下:

void Reducer::initialize_local_used_map() {
  const auto replica_count = replicas_.size();
  const auto variable_count = replicas_[0].size();
  local_used_maps_.resize(replica_count);
  local_used_maps_dev_.resize(replica_count);

  for (size_t i = 0; i < replica_count; i++) {
    at::TensorOptions options;
    options = options.dtype(at::kInt);

    // Deliberately don't pin the memory even if local_used_maps_dev_ will
    // be cuda. See Note [local_used_maps_ -> local_used_maps_dev copying]
    local_used_maps_[i] =
        at::zeros({static_cast<long>(variable_count)}, options);

    // This tensor needs to be on the same device as replica because backend
    // such as NCCL may not support CPU tensors, and hence it might not work
    // if we always put it on CPU.
    options = options.device(replicas_[i][0].device());
    local_used_maps_dev_[i] =
        at::empty({static_cast<long>(variable_count)}, options);
  }
}

初始化流程大致如下:

                                    +
                                    |
                                    |
                                    v
                  rpc_context_ = ThreadLocalDistAutogradContext
                                    +
                                    |
                                    |
                                    v
                  buckets_ & variable_locators_ (clear & resize)
                                    +
                                    |
                                    |
                                    v
+----------------------->  from 0 ~ bucket_count :  +--------------------------->
|                                                                                +
|                                                                                |
|      +-------------------------------------------------------------------+     |
|      | init Bucket          set bucket_indices                           |     |
|      |                            +                                      |     |
|      |                            |                                      |     |
|      |                            |                                      |     |
|      |                            v                                      |     |
|      |   ^ +------------> from 0 ~ replica_count : +----------------->   |     |
|      |   |                                                           |   |     |
|      |   |  +---------------------------------------------------+    |   |     |
|      |   |  | init BucketReplica                                |    |   |     |
|      |   |  |                                                   |    |   |     |
<----+ |   +--+                                                   | <--+   | <---+
       |      |    bucket.replicas.push_back(std::move(replica))  |        |
       |      |                                                   |        |
       |      +----------------------+----------------------------+        |
       |                             |                                     |
       |                             |                                     |
       |                             v                                     |
       |             buckets_.push_back(std::move(bucket))                 |
       |                             +                                     |
       +-------------------------------------------------------------------+
                                     |
                                     v

得到的 Reducer 大致如下,這里需要注意的是 ,BucketReplica 每個桶只有一個:

            +----------------------------------------+                 +------------------+
            |tensor index 4, tensor index 5, tensor 6| <------+        | index 2, index 3 |
            +----------------------------------------+        |        +--------------+---+
                                                              |                       ^
                                                              |                       |
+---------------------------+   +---------------------------------------------------------+
| Reducer                   |   | +----------------------------------+     +------------+ |
|                           |   | |Bucket                     |      |     |Bucket    | | |
|                           |   | |                           +      |     |          | | |
| vector<Bucket> buckets_ +---> | | vector<size_t> variable_indices  |     | indices ++ | |
|                           |   | |                                  |     |            | |
|                           |   | |  vector<BucketReplica> replicas  | ... | replicas   | |
|                           |   | |                         +        |     |   +        | |
|                           |   | |                         |        |     |   |        | |
|                           |   | +----------------------------------+     +------------+ |
|                           |   |                           |                  |          |
+---------------------------+   +---------------------------------------------------------+
                                                            |                  |
                                                            |                  |
                                                            v                  v
                          +---------------------------------------+   +-------------------+
                          |  +----------------------------------+ |   | +---------------+ |
                          |  | BucketReplica                    | |   | | BucketReplica | |
                          |  |                                  | |   | |               | |
                          |  |                                  | |   | |               | |
                          |  |  vector<Tensor> bucket_views_in  | |   | |   views_in    | |
                          |  |                                  | |   | |               | |
                          |  |  vector<Tensor> bucket_views_out | |   | |   views_out   | |
                          |  |                                  | |   | |               | |
                          |  |  Tensor contents                 | |   | |   contents    | |
                          |  |                                  | |   | |               | |
                          |  |  vector<Tensor> variables        | |   | |   variables   | |
                          |  |                     +            | |   | |      +        | |
                          |  +----------------------------------+ |   | +---------------+ |
                          +---------------------------------------+   +-------------------+
                                                   |                           |
                                                   |                           |
                                                   v                           v
                                   +---------------+------------+    +---------+----------+
                                   |Tensor 4, Tensor 5, Tensor 6|    | Tensor 2, Tensor 3 |
                                   +----------------------------+    +--------------------+

0x03 靜態圖

3.1 緣由

雖然 PyTorch 是動態圖,但是用戶可以明確地讓DDP知道訓練圖是靜態的,有如下情況時候可以設定:

  1. 已使用和未使用的參數集在整個訓練循環中不變,在這種情況下,用戶是否將find_unsued_parameters設置為true並不重要。

  2. 圖形的訓練方式在整個訓練循環過程中不會改變(意味着不存在依賴於迭代的控制流)。當圖被設置為靜態時,DDP將支持以前不支持的case,比如:

    1. 可重入的反向傳播。
    2. 多次activation checkpointing。
    3. activation checkpointing 並且find_unused_parameters = true。
    4. 並不是所有的輸出張量都用於損失計算。。
    5. 在前向函數之外有一個模型參數。
    6. 當find_unsued_parameters=true時或者存在未使用的參數,可能會提高性能,因為DDP在每個迭代之內不會搜索網絡來檢查未使用的參數。

3.2 使用

_set_static_graph 可以配置靜態圖,此API應在DistributedDataParallel構造之后,並且在訓練循環開始之前調用。並且,也應該以同樣的方式對所有的rank 進行調用。例如:

ddp_model = DistributedDataParallel(model)
ddp_model._set_static_graph()
for i in range(n):

_set_static_graph 代碼為:

def _set_static_graph(self):
    """
    Users can explicitly let DDP know the trained graph is static,
    when 1) the set of used and unused parameters will not change
    during the whole training loop; in this case, it does not matter
    whether users set find_unsued_parameters = true or not.
    2) how the graph is trained will not change during the whole training
    loop (meaning there is no control flow depending on iterations).
    When graph is set to be static, DDP will support cases that can not
    be supported in the past: 1) reentrant backwards
    2) activation checkpointing multiple times 3)
    activation checkpointing with find_unused_parameters = true.
    4) not all output tensors are used in loss calculation.
    5) there is model parameter that is outside of forward function.
    6) potentially improve performance when find_unsued_parameters = true
    or there are unused parameters, as DDP will not search graph in each
    iteraton to detect unused parameters when static_graph is set to be True.

    This API should be called after DistributedDataParallel construction, and
    before training loops starts. Also it should be called in the same way for
    all ranks. For example:
        ddp_model = DistributedDataParallel(model)
        ddp_model._set_static_graph()
        for i in range(n):
            .....
    """
    self.static_graph = True
    self.reducer._set_static_graph() # 調用 Reducer 進行配置
    self.logger._set_static_graph()
    if self.find_unused_parameters:
        warnings.warn(
            "You passed find_unused_parameters=true to DistributedDataParallel, "
            "`_set_static_graph` will detect unused parameters automatically, so "
            "you do not need to set find_unused_parameters=true, just be sure these "
            "unused parameters will not change during training loop while calling "
            "`_set_static_graph`."
        )

3.2 Reducer

Reducer 只有在第一次迭代之后才能生成靜態圖,因為畢竟PyTorch還是動態的,無論如何也得走一步動態生成。

void Reducer::set_static_graph() {
  std::lock_guard<std::mutex> lock(mutex_);
  TORCH_CHECK(
      num_iterations_ == 0,
      "set_static_graph() should be called before training loop starts "
      "and after DistributedDataParallel is constructed.");
  static_graph_ = true;
  // when static_graph_ is set as true, always initialize_local_used_map
  // and detect the global unused parameters in the first iteration.
  initialize_local_used_map();
}

0x04 重建桶

4.1 為何要重建

因為 PyTorch 是動態生成計算圖,所以需要相應重建桶。但是只有設置了靜態圖 並且 第一次迭代之后才會重建,如果設置 find_unused_parameters_,就不重建。

  // Returns true if we should rebuild buckets, else false. We only rebuild
  // buckets once after the first iteration and never rebuild them if
  // find_unused_parameters_.
  inline bool should_rebuild_buckets() const {
    return (static_graph_ || !find_unused_parameters_) && !has_rebuilt_bucket_;
  }

4.2 准備重建

我們首先看看重建之前的一些准備。

push_rebuilt_params 就是插入一個重建參數列表。

void Reducer::push_rebuilt_params(const VariableIndex& index) {
  rebuilt_params_.push_back(
      replicas_[index.replica_index][index.variable_index]);
  rebuilt_param_indices_.push_back(index.variable_index);
}

其次,push_rebuilt_params_for_all_indices 會遍歷每個 replica,針對 replica 之中的每個 variable 進行設置。

void Reducer::push_rebuilt_params_for_all_indices() {
  std::lock_guard<std::mutex> lock(mutex_);
  if (!should_rebuild_buckets() || !rebuilt_param_indices_.empty()) {
    return;
  }
  const auto replica_count = replicas_.size();
  for (size_t replica_index = 0; replica_index < replica_count;
       ++replica_index) {
    const auto variable_count = replicas_[replica_index].size();
    for (size_t variable_index = 0; variable_index < variable_count;
         ++variable_index) {
      const auto index = VariableIndex(replica_index, variable_index);
      push_rebuilt_params(index);
    }
  }
}

4.3 重建

我們接下來看看重建機制。

DDP 根據張量在后向傳播中接收梯度的時間,使用 rebuilt_params_ 和 rebuilt_param_indices_ 來重建存儲桶。

rebuild_buckets 函數進行廣播通信調用,並且可以與下一個forward()調用重疊,因此它可以是異步的。

  • 在find_unused_parameters=true情況下重建bucket 就是異步操作,因為我們可以多次重建bucket,其中子圖經過訓練,參數索引順序可能會更頻繁地更改。
  • 對於find_unused_parameters=false的情況,bucket只重建一次,性能成本可以忽略不計。如果已重建存儲桶, rebuild_buckets 則返回true。
bool Reducer::rebuild_buckets() {
  // Ensure reduction for previous backwards pass is finished. If user's model
  // has unused parameters for example, this will raise an error recommending to
  // run with find_unused_parameters=True, instead of the size mismatch
  // exception below.
  std::lock_guard<std::mutex> lock(mutex_);
  ensure_prior_reduction_finished();
  if (!should_rebuild_buckets() || rebuilt_params_.empty()) {
    return false;
  }

  std::vector<std::vector<size_t>> rebuilt_bucket_indices;
  std::vector<size_t> bucket_size_limits;
  bucket_size_limits.push_back(kDefaultFirstBucketBytes);
  bucket_size_limits.push_back(bucket_bytes_cap_);
  rebuilt_bucket_indices = compute_bucket_assignment_by_size(
      rebuilt_params_,
      bucket_size_limits,
      expect_sparse_gradients_[0],
      rebuilt_param_indices_);

  // For rebuilt bucket indices, it needs to be synced across all ranks.
  // Broadcast the newly rebuilt bucket indices from rank 0 in default.
  // After syncing up rebuilt bucket indices, initialize buckets for reducer.
  sync_bucket_indices(rebuilt_bucket_indices);

  has_rebuilt_bucket_ = true; // 只重建一次
  rebuilt_params_.clear();
  rebuilt_param_indices_.clear();

  initialize_buckets(std::move(rebuilt_bucket_indices));
  return true;
}

4.4 何時設定重建

重建僅在以下情況進行設定:

  1. 第一次重建存儲桶

  2. static_graph_ is true 或 find_unused_parameters_ is false

  3. 此反向傳播過程需要運行allreduce。

在這里,我們只需基於梯度到達順序將張量及其參數索引轉儲到rebuilt_params_rebuilt_param_indices_。然后在finalize_backward() 結束時,將基於rebuilt_params_rebuilt_param_indices_重建存儲桶,然后廣播和初始化存儲桶。

此外,我們只需要轉儲一個副本的張量和參數索引。

以 mark_variable_ready 為例,其中就會調用 push_rebuilt_params(index) 來插入列表。

void Reducer::mark_variable_ready(VariableIndex index) {
  // Rebuild bucket only if 1) it is the first time to rebuild bucket 2)
  // static_graph_ is true or find_unused_parameters_ is false,
  // 3) this backward pass needs to run allreduce.
  // Here, we just dump tensors and their parameter indices into
  // rebuilt_params_ and rebuilt_param_indices_ based on gradient arriving
  // order, and then at the end of finalize_backward(), buckets will be
  // rebuilt based on rebuilt_params_ and rebuilt_param_indices_, and then
  // will be broadcasted and initialized. Also we only need to dump tensors
  // and parameter indices of one replica.
  if (should_rebuild_buckets()) {
    push_rebuilt_params(index); // 插入列表
  }

  const auto replica_index = index.replica_index;
  const auto variable_index = index.variable_index;

  if (replica_index == 0) {
    checkAndRaiseMarkedTwiceError(variable_index);
    perIterationReadyParams_.insert(variable_index);
  }
  backward_stats_[replica_index][variable_index] =
      current_time_in_nanos() - cpu_timer_.backward_compute_start_time;

  // Any time we mark a variable ready (be it in line due to unused parameters,
  // or via an autograd hook), we require a call to the finalize function. If
  // this doesn't happen before the next iteration (or call to
  // `prepare_for_backwards`), we know something is wrong.
  require_finalize_ = true;

  const auto& bucket_index = variable_locators_[variable_index];
  auto& bucket = buckets_[bucket_index.bucket_index];
  auto& replica = bucket.replicas[replica_index];

  set_divide_factor();

  if (bucket.expect_sparse_gradient) {
    mark_variable_ready_sparse(index);
  } else {
    mark_variable_ready_dense(index);
  }

  // TODO(@pietern): Make this work for both CPU/CUDA tensors.
  // When using CPU tensors we don't need to do this.
  // // Record event so that we can wait for all of them.
  // auto& event = replica.events[bucket_index.intra_bucket_index];
  // event.record();

  // Check if this was the final gradient for this bucket.
  if (--replica.pending == 0) {
    // Kick off reduction if all replicas for this bucket are ready.
    if (--bucket.pending == 0) {
      mark_bucket_ready(bucket_index.bucket_index);
    }
  }

  // Run finalizer function and kick off reduction for local_used_maps once the
  // final bucket was marked ready.
  if (next_bucket_ == buckets_.size()) {

    if (dynamic_graph_find_unused()) {
      all_reduce_local_used_map();
    }

    // The autograd engine uses the default stream when running callbacks, so we
    // pass in the current CUDA stream in case it is not the default.
    const c10::Stream currentStream = get_current_stream();
    torch::autograd::Engine::get_default_engine().queue_callback([=] {
      std::lock_guard<std::mutex> lock(this->mutex_);
      // Run callback with the current stream
      c10::OptionalStreamGuard currentStreamGuard{currentStream};
      if (should_collect_runtime_stats()) {
        record_backward_compute_end_time();
      }
      // Check that all buckets were completed and had their work kicked off.
      TORCH_INTERNAL_ASSERT(next_bucket_ == buckets_.size());
      this->finalize_backward();
    });
  }
}

4.5 直接調用

_rebuild_buckets 函數也可以直接調用,比如如下情況,就是在整個訓練期間內在 forward 調用了一次。

def forward(self, *inputs, **kwargs):
    with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
        self.reducer.save_thread_local_state()
        if torch.is_grad_enabled() and self.require_backward_grad_sync:
            self.num_iterations += 1
            self.reducer.prepare_for_forward()
        if self.ddp_uneven_inputs_config.ddp_join_enabled:
            ones = torch.ones(1, device=self.device)
            work = dist.all_reduce(ones, group=self.process_group, async_op=True)
            if self.ddp_uneven_inputs_config.ddp_join_throw_on_early_termination:
                # Active ranks schedule an allreduce with zeros, inactive
                # ranks schedule them with 1. If the result != 0 it
                # indicates at least one rank has terminated and we should
                # throw.
                zeros = torch.zeros(1, device=self.device)
                dist.all_reduce(zeros, group=self.process_group)
                should_throw_stop_iteration = zeros.item()
                if should_throw_stop_iteration:
                    raise RuntimeError(
                        "Detected at least one rank that exhausted inputs. Throwing across all ranks."
                    )
            else:
                self.reducer._set_forward_pass_work_handle(
                    work,
                    self.ddp_uneven_inputs_config.ddp_join_divide_by_initial_world_size,
                )

        # Calling _rebuild_buckets before forward compuation,
        # It may allocate new buckets before deallocating old buckets
        # inside _rebuild_buckets. To save peak memory usage,
        # call _rebuild_buckets before the peak memory usage increases
        # during forward computation.
        # This should be called only once during whole training period.
        
        # 在這里進行直接調用
        if torch.is_grad_enabled() and self.reducer._rebuild_buckets(): # 設定
            logging.info("Reducer buckets have been rebuilt in this iteration.")

再比如 Join 方法也可以直接調用進行重建。

@contextmanager
def join(
    self,
    divide_by_initial_world_size=True,
    enable=True,
    throw_on_early_termination=False,
):
  
  									# 忽略其他代碼
    
                    else:
                        # Some DDP process still needs to be joined.
                        if self.ddp_uneven_inputs_config.ddp_join_throw_on_early_termination:
                            # Schedule allreduce telling active ranks to terminate
                            ones = torch.ones(1, device=self.device)
                            dist.all_reduce(ones, group=self.process_group)
                            # Raising StopIteration doesn't throw error in python 3.6
                            # and throws RuntimeError in 3.7+ (PEP 479), so just
                            # raise RuntimeError here.
                            raise RuntimeError(
                                f"Rank {self._distributed_rank} exhausted all inputs."
                            )
                        if is_last_joiner:
                            is_last_joiner = False
                        # It will rebuild buckets only once during training period
                        
                        # 這里進行調用。
                        self.reducer._rebuild_buckets()
                        # Schedule a corresponding broadcast if we are syncing module
                        # buffers in the forward pass.
                        self._check_and_sync_module_buffers()   

既然提到了 Join,我們接下來就看看這個概念。

0x05 Join

Join 是為了解決訓練數據不均勻的問題,就是允許某些輸入較少的worker(其已經完成Join操作)可以繼續和那些尚未結束的worker繼續執行集合通信,就是一個欺騙操作(Shadow)。

5.1 緣起

支撐DDP背后的是幾個集合通信庫的all-reduce操作,其完成了各個worker之間的梯度同步。而當訓練數據在 ranks 之間的輸入是不均勻(uneven)的,就會導致DDP會掛起。因為集合通信要求在進程組中的所有rank都參與,因此如果一個rank的輸入少,其他ranks會hang或者報錯(取決於后端),而且任何類在執行同步集合通信時,在每次迭代都會遇到這個問題。

因此,DDP 給出了一個 "Join" API,Join是一個上下文管理器,在每個rank的訓練循環之中使用。數據量少的 rank 會提前耗盡輸入,這時它將給集合通信一個假象,從而會構建一個虛擬(dummy)的 all-reduce,以便在數據不足時候與其他 ranks 匹配。具體如何制造這個假象是由注冊hook指定。

其大致思路如下:

                +----------------------------+
                |             Data           |
                |   +--------+   +--------+  |
                |   |        |   | Empty  |  |
                |   |        |   |        |  |
                |   +-----+--+   +--------+  |
                |         |                  |
                |         |                  |
                +----------------------------+
                          |
                          |
        +------------+    |               +------------+
        |            |    |               |            |
+---->  |    Model   |    |               |   Model    | <-----+
|       |            |    |               |            |       |
|       +------+-----+    |               +------+-----+       |
|              |          |                      |             |
|              |          |                      |             |
|              v          |                      v             |
|       +------+-----+    |             +--------+----------+  |
|       |  Forward   +<---+             | _JoinHook         |  |
|       |  (local)   |                  |                   |  |
|       +------+-----+                  |                   |  |
|              |                        |                   |  |
|              |                        |                   |  |
|              v                        | +---------------+ |  |
|       +------+-----+                  | | main_hook     | |  |
|       |  Backward  |                  | |               | |  |
|       |  (local)   |                  | |               | |  |
|       +------+-----+                  | |               | |  |
|              |                        | |               | |  |
|              |                        | |               | |  |
|              v                        | |               | |  |
|       +------+-----+                  | |               | |  |
|       | All-Reduce |     Sync grads   | |   All-Reduce  | |  |
|       |            | <--------------> | |   (Dummy)     | |  |
|       +------+-----+                  | |               | |  |
|              |                        | +---------------+ |  |
|              |                        +-------------------+  |
|              v                                 |             |
|     +--------+-------+                         |             |
|     | Update Weights |                         |             |
|     |                |                         |             |
|     +--------+-------+                         |             |
|              |                                 |             |
|              |                                 |             |
+--------------+                                 +-------------+

5.2 使用

5.2.1 DistributedDataParallel

Join 可以和 DistributedDataParallel 一起使用,比如下面的例子之中,會啟動兩個worker,分別是 rank 0 和 rank 1,rank 0 會得到5個輸入,rank 1會得到6個輸入,這就是輸入不均衡。

如果沒有使用 Join,則 rank 1 會在處理第6個輸入時候死掉掛起,因為rank 0沒有相關輸入,所以rank 1只能等待。如果使用了 Join,則不會出現這種問題,可以順利結束。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    with Join([model]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

這將產生以下輸出(其中print來自 0 級和 1 級的 ranks,可以任意排序):

Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!

5.2.2 ZeroRedundancyOptimizer

Join上下文不僅是和一個類合作,也可以和多個類一起,比如PyTorch 的ZeroRedundancyOptimizer

from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    optim = ZeRO(model.parameters(), Adam, lr=0.01)
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    # Pass both `model` and `optim` into `Join()`
    with Join([model, optim]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
            optim.step()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

這將產生與以前相同的輸出。顯着的變化是需要另外將ZeroRedundancyOptimizer實例傳入 Join()

后續會對ZeroRedundancyOptimizer等機制也進行分析。

5.3 原理

在最新文檔 https://pytorch.org/tutorials/advanced/generic_join.html 之中,PyTorch 給出了一定解釋,我們翻譯如下。

為了更好的使用,我們將介紹Join類以及支持類JoinableJoinHook

備注:這部分在 v1.10.0 版本代碼之中。

5.3.1 Joinable

首先,與Join上下文管理器兼容的類必須繼承抽象基類Joinable。特別的,Joinable必須實現:

  • join_hook(self, **kwargs) -> JoinHook

這將返回 的JoinHook實例Joinable,用來確定加入的進程應如何影響由Joinable 執行的每次迭代集體通信。

  • join_device(self) -> torch.device

這將返回Join上下文管理器用來執行集體通信的設備,例如torch.device("cuda:0")torch.device("cpu")

  • join_process_group(self) -> ProcessGroup

這將返回Join上下文管理器用於執行集體通信的進程組。

概括一下,JoinHook負責具體行為,join_device 和 join_process_group 負責具體集合通信

需要注意的是,join_devicejoin_process_group是必需的屬性,他們可以確保上下文管理器能夠安排"加入"和"未加入"進程之間的集體通信。一種用法是使用 all-reduce 計算每次迭代中"未加入"進程的數量。另一種用法是實現 throw_on_early_termination=True所需的機制,我們將在下面解釋。

DistributedDataParallelZeroRedundancyOptimizer已經繼承Joinable並實現了上面的方法,這就是為什么我們可以在前面的例子中直接使用它們。

class DistributedDataParallel(Module, Joinable):

class ZeroRedundancyOptimizer(Optimizer, Joinable):

DDP 涉及到提供數據,所以繼承Joinable可以理解,ZeroRedundancyOptimizer 為何也需要繼承?這是因為 ZeroRedundancyOptimizer 可以和 DDP 一起合作,並且 ZeroRedundancyOptimizer 內部也有集合操作,所以需要被 Join 一起管理。

Joinable類應該確保調用Joinable構造函數,因為它初始化了一個JoinConfig實例,上下文管理器在內部使用JoinConfig來確保正確性。JoinConfig將在每個Joinable _join_config字段中保存。

5.3.2JoinHook

接下來,讓我們分解一下JoinHook類。JoinHook提供了兩個進入上下文管理器的入口點:

  • main_hook(self) -> None

當存在尚未加入(Join)的 rank 時,每個加入(Join)的 rank 都會重復調用此鈎子。它目的是在每次訓練迭代(例如,在一次前向傳遞,反向傳遞和優化器步驟)之中,隱藏由Joinable所執行的集體通信,即已經Join的rank 如何與未Join的rank執行集合通信

  • post_hook(self, is_last_joiner: bool) -> None

一旦所有 ranks 都加入,這個鈎子就會被調用。它傳遞了一個額外的 bool參數is_last_joiner,其表明此 rank 是否是最后加入的 rank 之一。該參數可能對同步有用。

5.3.2.1 ZeroRedundancyOptimizer

我們以 內置的 ZeroRedundancyOptimizer main hook 來給出一個鈎子的具體例子:因為加入的 rank 仍然負責更新和同步其參數分片,所以 main hook 依然執行優化器步驟。

class _ZeROJoinHook(_JoinHook):
    def __init__(self, zero):
        assert isinstance(zero, ZeroRedundancyOptimizer), \
            "ZeRO join hook requires passing in a ZeroRedundancyOptimizer " \
            "instance as the state"
        self.zero = zero
        super().__init__()

    def main_hook(self):
        """
        Performs an optimizer step, which updates the joined process's shard of
        the parameters and broadcasts those parameters.
        """
        self.zero.step()

step函數簡略如下:

def step(
    self,
    closure: Optional[Callable[[], float]] = None,
    **kwargs: Any,
) -> Optional[float]:
    _Join.notify_join_context(self) # 這里會通知
    # Check if the model trainability has changed
    is_trainable_mask = self._get_is_trainable_mask()
    if is_trainable_mask != self._is_trainable_mask:
        self._build_param_buckets()
        self._is_trainable_mask = is_trainable_mask

    # Sync the exposed `param_groups` attributes to the local optimizer in
    # case they have been updated
    self._sync_param_groups(self.param_groups, self.optim.param_groups)

    # Run the optimizer step on this shard only
    if closure is not None:
        loss = self.optim.step(closure=closure, **kwargs)  # type: ignore[call-arg]
    else:
        loss = self.optim.step(**kwargs)

    # Sync all of the updated parameter shards across the ranks
    self._sync_parameters()

    # Sync any updated attributes in the local optimizer to the exposed
    # `param_groups`
    self._sync_param_groups(self.optim.param_groups, self.param_groups)

    return loss

再來看看DistributedDataParallel

  • main_hook 依然會做相關的一系列操作來欺騙其他rank。
  • post-hook 會從最后加入的rank之一來廣播最終更新的模型,以確保模型在所有rank中都是相同的。
class _DDPJoinHook(_JoinHook):
    def __init__(self, ddp, divide_by_initial_world_size):
        """
        Sets config variables for internal usage.
        """
        ddp.logger._set_uneven_input_join()
        self.ddp = ddp
        self.ddp._divide_by_initial_world_size = divide_by_initial_world_size
        super().__init__()

    def main_hook(self):
        """
        Shadows the DDP collective communication operations in the forward and
        backward passes.
        """
        ddp = self.ddp
        # Buckets are rebuilt only once during a training period
        ddp.reducer._rebuild_buckets()

        # Schedule a broadcast if we are syncing module buffers in the
        # forward pass
        ddp._check_and_sync_module_buffers()

        # Check if need to sync in the backward pass
        work = ddp._check_global_requires_backward_grad_sync(is_joined_rank=True)
        work.wait()
        should_sync_backwards = work.result()[0].item() != 0
        # Forward parameter sync is disabled in the next iteration if we
        # are skipping gradient sync this iteration, so set
        # `require_forward_param_sync` accordingly
        ddp.require_forward_param_sync = should_sync_backwards
        if not should_sync_backwards:
            return

        # Schedule one allreduce per gradient bucket to match the backward
        # pass allreduce
        ddp._match_all_reduce_for_bwd_pass()

        # Check if we need to allreduce locally unused parameters
        if ddp.find_unused_parameters:
            ddp._match_unused_params_allreduce()

        # Rebuilt parameters are pushed only once during a training period
        ddp.reducer._push_all_rebuilt_params()

    def post_hook(self, is_last_joiner: bool):
        """
        Syncs the final model to ensure that the model is the same across all
        processes.
        """
        self.ddp._sync_final_model(is_last_joiner)

_sync_final_model 這里會廣播最新的模型。

# When running in join model, agrees upon a common rank and broadcast model
# parameters to all other ranks.
def _sync_final_model(self, is_last_joiner):
    # Agree upon the process that will be the authoritative model copy.
    # The current rank is a candidate for being the authoritative copy if
    # is_last_joiner=True. We break ties via picking the larger rank.
    self._authoritative_rank = self._find_common_rank(
        self._distributed_rank, is_last_joiner
    )
    self._sync_params_and_buffers(authoritative_rank=self._authoritative_rank)

5.3.3 Join

最后,讓我們看看這些基礎類是如何適應Join類本身的。

  • __init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)

正如我們在前面的例子中看到的,構造函數接收一個參與訓練循環的Joinable列表 。這些應該是在每次迭代中執行集體通信的類。

enablebool類型,如果您知道不會有不均勻的輸入,則可以設置為 False,在這種情況下,上下文管理器變得類似於contextlib.nullcontext(). 這也可能會在參與Joinable列表之中禁用join-related計算。

throw_on_early_terminationbool類型,其可以設置為True,以便讓每個等級在檢測到不均勻輸入時引發異常。這對於不符合上下文管理器要求的情況很有用,這通常是當來自不同類的集體通信可以任意交錯(interleaved)時,例如DistributedDataParallel與具有SyncBatchNorm層的模型一起使用時 。在這種情況下,應將此參數設置為 True以便應用程序邏輯可以捕獲異常並確定如何繼續。

  • 核心邏輯出現在該__exit__()方法中,該方法在存在未加入的 rank 時會進行循環調用每個 Joinable的主鈎子,然后一旦所有rank加入,就調用它們的 post 鈎子。主鈎子和后鈎子都按照Joinables 傳入的順序進行迭代。
  • 上下文管理器需要來自未加入進程的心跳。因此,每個Joinable類都應該在每次迭代的集體通信之前調用Join.notify_join_context() 。上下文管理器將確保只有第一個傳入的Joinable實際發送心跳。

5.4 例子

我們通過一個例子來具體看看。下面代碼之中,每個rank會打印(1)在Join之前看到的所有rank的輸入數量,以及(2)所有rank的輸入總數。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

class CounterJoinHook(JoinHook):
    r"""
    Join hook for :class:`Counter`.

    Arguments:
        counter (Counter): the :class:`Counter` object using this hook.
        sync_max_count (bool): whether to sync the max count once all ranks
            join.
    """
    def __init__(
        self,
        counter,
        sync_max_count
    ):
        self.counter = counter
        self.sync_max_count = sync_max_count

    def main_hook(self):
        r"""
        Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
        """
        t = torch.zeros(1, device=self.counter.device)
        dist.all_reduce(t)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Synchronizes the max count across all :class:`Counter` s if
        ``sync_max_count=True``.
        """
        if not self.sync_max_count:
            return
        rank = dist.get_rank(self.counter.process_group)
        common_rank = self.counter.find_common_rank(rank, is_last_joiner)
        if rank == common_rank:
            self.counter.max_count = self.counter.count.detach().clone()
        dist.broadcast(self.counter.max_count, src=common_rank)

class Counter(Joinable):
    r"""
    Example :class:`Joinable` that counts the number of training iterations
    that it participates in.
    """
    def __init__(self, device, process_group):
        super(Counter, self).__init__()
        self.device = device
        self.process_group = process_group
        self.count = torch.tensor([0], device=device).float()
        self.max_count = torch.tensor([0], device=device).float()

    def __call__(self):
        r"""
        Counts the number of inputs processed on this iteration by all ranks
        by all-reducing a dim-1 one tensor; increments its own internal count.
        """
        Join.notify_join_context(self)
        t = torch.ones(1, device=self.device).float()
        dist.all_reduce(t)
        self.count += t

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a join hook that shadows the all-reduce in :meth:`__call__`.

        This join hook supports the following keyword arguments:
            sync_max_count (bool, optional): whether to synchronize the maximum
                count across all ranks once all ranks join; default is ``False``.
        """
        sync_max_count = kwargs.get("sync_max_count", False)
        return CounterJoinHook(self, sync_max_count)

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self):
        return self.process_group

    # 確定最后join的rank,由於后加入的rank可能不止一個,所以選擇rank最大的rank來同步  
    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
        """
        common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        return common_rank

def worker(rank):
    assert torch.cuda.device_count() >= WORLD_SIZE
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    with Join([counter], sync_max_count=True):
        for _ in inputs:
            counter()

    print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
    print(f"{int(counter.max_count.item())} inputs processed across all ranks!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

由於rank 0看到5個輸入,rank 1看到6個,因此產生輸出:

10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!

需要強調的一些要點:

  • Counter實例在每次迭代中執行一個all reduce操作,因此:
    • 對於已經Join的rank,其 main hook 也執行單個all reduce來對整體通信進行蒙騙操作( shadow it),注意這個 all-reduce是調用一個為0的tensor,所以對整體結果不影響。
    • 其他未 Join 的 rank 會以為這依然是一個正確的滿員的集合操作。
    • 這樣就處理了不均勻輸入。
  • Counter類在其 __call__()方法的開頭調用 Join.notify_join_context() ,因為這是每次集合操作(all-reduce)的地方,需要在這里通知上下文管理器,本示例還沒有Join(已經結束的rank不會調用到這里)。
  • 'is_last_joiner'參數用於確定post-hooks中的廣播源。
  • 我們將 sync_max_count 關鍵字參數傳遞給上下文管理器,上下文管理器會將其轉發給'Counter'的join hook。
  • post-hooks之中,會對 self.counter.max_count 進行廣播。

0xFF 參考

pytorch分布式系列3——分布式訓練時,torch.utils.data.distributed.DistributedSampler做了什么?

pytorch分布式系列1——搞清torch.distributed.launch相關的環境變量

pytorch分布式系列2——DistributedDataParallel是如何做同步的?

pytorch(分布式)數據並行個人實踐總結——DataParallel/DistributedDataParallel

Pytorch的nn.DataParallel

https://discuss.pytorch.org/t/dataparallel-imbalanced-memory-usage/22551/20

https://pytorch.org/docs/stable/distributed.html

PyTorch 源碼解讀之分布式訓練了解一下?

實操教程|PyTorch AutoGrad C++層實現

PYTORCH 自動微分(一)

PyTorch如何加速數據並行訓練?分布式秘籍大揭秘

pytorch分布式訓練(二init_process_group)

https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

https://pytorch.org/docs/master/notes/ddp.html

https://pytorch.org/tutorials/intermediate/dist_tuto.html

PyTorch 源碼解讀之 DP & DDP:模型並行和分布式訓練解析

Pytorch模型中的parameter與buffer

【PyTorch開發者日 2020】PyTorch分布式數據並行(DDP)

[中文字幕] 深入理解 PyTorch 中的 Hook 機制

[中文字幕] 深入解讀 Pytorch AutoGrad

DISTRIBUTED TRAINING WITH UNEVEN INPUTS USING THE JOIN CONTEXT MANAGER

談談torch1.10中的ZeroRedundancyOptimizer和Join


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM