[源碼解析] 快手八卦 --- 機器學習分布式訓練新思路(3)


[源碼解析] 快手八卦 --- 機器學習分布式訓練新思路(3)

0x00 摘要

“Bagua“ 是快手和蘇黎世理工(ETH Zürich)聯合開發的分布式訓練框架。其專門針對分布式的場景設計特定的優化算法,實現算法和系統層面的聯合優化,力圖極致化分布式訓練的效率。其特點是:

  • 並行性能顯著提高;

  • 對網絡環境更魯棒;

  • “一鍵式”使用;

  • 分布式通訊算法易拓展性;

  • 可用於工業級場景大規模使用;

  • 安全、故障易排查;

本文以:

為基礎來分析學習。本文介紹去中心化和異步通信。

本系列前兩篇鏈接為:

[源碼解析] 快手八卦 --- 機器學習分布式訓練新思路(1)

[源碼解析] 快手八卦 --- 機器學習分布式訓練新思路(2)

0x02 去中心化

官方文章中是這樣介紹其設計思路的:

  • 中心化或是去中心化(Centralized or Decentralized):在中心化的通訊模式中,梯度或模型的同步過程需要所有的工作節點進行參與,因此,較高的網絡延時往往會導致訓練效率的降低。去中心化的通信模式 往往可以有效的解決這一問題:在該模式下,工作節點可以被連接成特定的拓撲結構(例如環),在通信過程中,每一個工作節點只與和它相鄰的節點進行通信。

以下結合 https://tutorials.baguasys.com/algorithms/decentralized 來學習。

2.1 示例用法

用戶可以在源碼之中找到運行去中心化 SGD 的完整示例,這里只是簡單介紹。

您需要初始化八卦算法:

from bagua.torch_api.algorithms import decentralized
algorithm = decentralized.DecentralizedAlgorithm()

然后用以下方法裝飾您的模型:

model = model.with_bagua([optimizer], algorithm)

2.2 去中心化培訓概述

Decentralized SGD 是一種數據並行的分布式學習算法,它消除了所有 worker 之間必有存在一個集中式全局模型的需求,這使得它在通信模式上與基於 Allreduce 或基於參數服務器的算法有很大不同。使用去中心化 SGD,每個 worker 只需要與一個或幾個特定的 worker 交換數據,而不是全局聚合數據。因此,去中心化通信的通信連接數比 Allreduce 少得多,通信開銷比 Parameter Server 更均衡。盡管去中心化 SGD 可能會導致每個 worker 的模型不同,但理論上已經證明,去中心化 SGD 算法的收斂速度與其對應中心化版本相同。

2.3 去中心化訓練算法

目前,不時有許多去中心化訓練算法被提出。這些令人驚嘆的工作集中在去中心化訓練的不同方面,如對等選擇(peer selection)、數據壓縮、異步等,並提供了許多遠見。到目前為止,八卦已經結合了兩種基本的去中心化算法,即去中心化 SGD和 低精度去中心化 SGD。憑借八卦對去中心化的自動系統支持,我們預計在不久的將來會實現越來越多的去中心化算法。

2.4 Decentralized SGD

現在我們將描述在八卦中實現的 Decentralized SGD 算法。讓我們假設worker 的數量是 n,worker上的模型參數 是:

\[x^{(i)} ,i∈ \{0,...,n−1\} \]

每個工作人員都能夠直接從任何其他工作人員發送或接收數據。在每次迭代 t 中,算法重復以下步驟:

  1. 迭代t 之中,每個worker 計算本地梯度 \(g^{(t)}_t\)

  2. 將本地模型與其選定的對等模型做平均:

    \[x_{t+\frac{1}{2}}^{(i)} = \frac{x^{(i)}_{t} + x_t^{(j)}}{2} \]

  3. 用局部梯度更新平均模型

    \[X^{(i)}_{t+1} = X^{(i)}_{t+\frac{1}{2}} - γg_t^{(i)} \]

在第 2 步中,我們采用一種策略為每次迭代中的每個 worker 選擇一個 peer,這樣所有 worker 都正確配對並且數據交換是有效的,因為每個 worker 可以在迭代之間與不同的 peer 交換數據。簡而言之,我們的策略將工作人員平均分成兩組,並在兩組之間動態配對 worker,每次迭代都不同。

2.5 通信開銷

去中心化 SGD 的通信開銷與網絡程度(degree of network)高度相關,即一個 worker 與其他 worker 的連接數。不同的拓撲或策略會導致不同程度的網絡。很明顯,我們之前描述的Decentralized SGD算法的網絡度為1。因此,在每次迭代中,一個worker只需要與一個worker建立一個連接來交換模型大小1倍的數據。我們比較了不同通信模式在最繁忙節點延遲和帶寬方面的通信復雜性。

算法 延遲復雜度 帶寬復雜度
Allreduce(環) O(n) O(1)
參數服務器 O(1) O(n)
八卦的Decentralized SGD O(1) O(1)

2.6 分析

前面官方教程之中,這部分是關鍵:

在第 2 步中,我們采用一種策略為每次迭代中的每個 worker 選擇一個 peer,這樣所有 worker 都正確配對並且數據交換是有效的,因為每個 worker 可以在迭代之間與不同的 peer 交換數據。簡而言之,我們的策略將工作人員平均分成兩組,並在兩組之間動態配對 worker,每次迭代都不同。

我們就以此出發來進行分析學習。

2.6.1 DecentralizedAlgorithmImpl

2.6.1.1 定義

參數 peer_selection_mode 可以有兩種選擇:

  • all表示在每個通信步驟中平均所有worker的權重。
  • shift_one 是指每個 worker 在每個通信步驟中選擇一個不同的對等點進行權重平均。
class DecentralizedAlgorithmImpl(AlgorithmImpl):
    def __init__(
        self,
        process_group: BaguaProcessGroup,
        hierarchical: bool = True,
        peer_selection_mode: str = "all",
        communication_interval: int = 1,
    ):
        """
        Implementation of the
        `Decentralized SGD <https://tutorials.baguasys.com/algorithms/decentralized>`_
        algorithm.

        Args:
            process_group (BaguaProcessGroup): The process group to work on.
            hierarchical (bool): Enable hierarchical communication.
            peer_selection_mode (str): Can be ``"all"`` or ``"shift_one"``. ``"all"`` means all workers'
                weights are averaged in each communication step. ``"shift_one"`` means each worker
                selects a different peer to do weights average in each communication step.
            communication_interval (int): Number of iterations between two communication steps.

        """
        super(DecentralizedAlgorithmImpl, self).__init__(process_group)
        self.hierarchical = hierarchical
        self.peer_selection_mode = peer_selection_mode
        self.communication_interval = communication_interval
        self.cuda_event = torch.cuda.Event()
2.6.1.2 初始化狀態

_init_states 方法把權重張量初始化到 bucket._peer_weight。

提一下,LowPrecisionDecentralizedAlgorithmImpl 是初始化了左右兩個 peer_weight,因為精力所限,本文不對其進行分析,有興趣的讀者可以自行深入。

def _init_states(self, bucket: BaguaBucket):
    weight_tensor = bucket.flattened_tensor()
    bucket._peer_weight = weight_tensor.to_bagua_tensor("peer_weight")
2.6.1.3 初始化操作

init_operations 使用 append_decentralized_synchronous_op 配置了 bucket 的 _decentralized_op 成員變量。

def init_operations(
    self,
    bagua_module: BaguaModule,
    bucket: BaguaBucket,
):
    self._init_states(bucket)
    torch.cuda.synchronize()
    bucket.clear_ops()
    decentralized_op = bucket.append_decentralized_synchronous_op( # 配置成員變量
        peer_weight=bucket._peer_weight,
        hierarchical=self.hierarchical,
        peer_selection_mode=self.peer_selection_mode,
        group=self.process_group,
    )
    bucket._decentralized_op = decentralized_op
2.6.1.4 Post操作

init_post_backward_hook 注冊了 post hook 操作,會把去中心化平均的結果拷貝回來,后面會在進行細化分析。

def init_post_backward_hook(self, bagua_module: BaguaModule):
    def hook():
        if self._should_communicate(bagua_module):
            bagua_module._bagua_backend.wait_pending_comm_ops()

            torch.cuda.current_stream().record_event(self.cuda_event)
            self.cuda_event.synchronize()
            for bucket in bagua_module.bagua_buckets:
                bucket._decentralized_op.copy_back_peer_weight( # 拷貝回來
                    bucket.backend_bucket
                )

    return hook

算法如下,append_decentralized_synchronous_op 用來通信,init_post_backward_hook 把去中心化平均的結果拷貝回來。

+--------------------------------------------------------------------+
|DecentralizedAlgorithmImpl                                          |
|                                                                    |
|     process_group                                                  |
|                                                                    |
|     decentralized_op = bucket.append_decentralized_synchronous_op  |
|                                                                    |
|     peer_selection_mode                                            |
|                                                                    |
|     init_post_backward_hook                                        |
|                                                                    |
+--------------------------------------------------------------------+

2.6.2 BaguaBucket

我們接下來進入 BaguaBucket,其是聚集了一系列 Bagua 張量,其最終調用 backend_bucket 進行處理,就是 rust 的 BaguaBucketPy。

class BaguaBucket:
    def __init__(
        self, tensors: List[BaguaTensor], name: str, flatten: bool, alignment: int = 1
    ) -> None:
        """
        Create a Bagua bucket with a list of Bagua tensors.
        """
        self.tensors = tensors
        """
        The tensors contained within the bucket.
        """
        self.bagua_module_name = tensors[0].bagua_module_name
        self._bagua_backend = get_backend(self.bagua_module_name)
        self.name = name
        """
        The bucket's name.
        """
        self.padding_tensor = None

        if alignment > 1:
            padding = sum(tensor.numel() for tensor in self.tensors) % alignment
            if padding > 0:
                padding = alignment - padding
                self.padding_tensor = torch.zeros(
                    padding, dtype=self.tensors[0].dtype, device=self.tensors[0].device
                ).to_bagua_tensor("bagua_padding_tensor_bucket_" + name)

        self._all_tensors = (
            self.tensors + [self.padding_tensor]
            if self.padding_tensor is not None
            else self.tensors
        )

        self.backend_tensor = None
        self.flatten = flatten
        if self.flatten:
            self._flatten_()
            torch.cuda.empty_cache()

        self.backend_bucket = B.BaguaBucketPy( # 底層實現
            name, [tensor._bagua_backend_tensor for tensor in self._all_tensors]
        )

        for tensor in self._all_tensors:
            tensor._bagua_bucket = self
2.6.2.1 append_decentralized_synchronous_op

append_decentralized_synchronous_op 是往桶添加了操作,當bucket中的所有張量都標記為ready時,該操作將由Bagua后端按照附加順序執行。參數 peer_weight 的意義是用於與對等模型求平均值的張量,應與桶張量的總大小相同。

append_decentralized_synchronous_op 不是 inplace 操作,這意味着桶權重首先復制到peer_weight,去中心化平均的結果放置在 peer_weight,然后使用op.copy_back_peer_weight(self) 將結果再拷貝回來。具體在前面 init_post_backward_hook 之中有拷貝回來的操作。

我們還可以注意到,如果采取了 hierarchical 模式,則傳入了 inter, intra 兩種communicator。

def append_decentralized_synchronous_op(
    self,
    peer_weight: BaguaTensor,
    hierarchical: bool = True,
    peer_selection_mode: str = "all",
    group: Optional[BaguaProcessGroup] = None,
):
    """
    Append a decentralized synchronous operation to a bucket. It will do gossipy style model averaging among workers.
    """
    if group is None:
        group = _get_default_group()

    if hierarchical:
        return self.backend_bucket.append_decentralized_synchronous_op(
            _bagua_backend_comm(group.get_inter_node_communicator()),
            _bagua_backend_comm(group.get_intra_node_communicator()),
            hierarchical=hierarchical,
            peer_selection_mode=peer_selection_mode,
            peer_weight=peer_weight._bagua_backend_tensor,
        )
    else:
        return self.backend_bucket.append_decentralized_synchronous_op(
            _bagua_backend_comm(group.get_global_communicator()),
            None,
            hierarchical=hierarchical,
            peer_selection_mode=peer_selection_mode,
            peer_weight=peer_weight._bagua_backend_tensor,
        )
2.6.2.2 BaguaBucket

我們來到了 Rust 世界,BaguaBucket 的 append_decentralized_synchronous_op 操作之中,如果是 "all" 或者 "shift_one",則會調用 DecentralizedFullPrecisionSynchronous。

pub fn append_decentralized_synchronous_op(
    &mut self,
    communicator_internode: Option<&BaguaSingleCommunicator>,
    communicator_intranode: Option<&BaguaSingleCommunicator>,
    hierarchical: bool,
    peer_selection_mode: String,
    peer_weight: BaguaTensor,
) -> Arc<DecentralizedFullPrecisionSynchronous> {
    let communicator =
        BaguaCommunicator::new(communicator_internode, communicator_intranode, hierarchical)
            .expect("cannot create communicator");
    let comm_op = Arc::new(DecentralizedFullPrecisionSynchronous {
        communicator,
        peer_selection_mode: match peer_selection_mode.as_str() {
            "all" => PeerSelectionMode::All,
            "shift_one" => PeerSelectionMode::ShiftOne,
            &_ => {
                unimplemented!("unsupported peer_selection_mode for decentralized algorithm (should be `all` or `shift_one`)")
            }
        },
        step: Default::default(),
        peer_weight,
    });

    self.inner
        .lock()
        .comm_ops
        .push(comm_op.clone() as Arc<dyn CommOpTrait + Send + Sync>);
    comm_op
}
2.6.2.3 DecentralizedFullPrecisionSynchronous

DecentralizedFullPrecisionSynchronous 位於 rust/bagua-core/bagua-core-internal/src/comm_ops/decentralized_full_precision_synchronous.rs 之中。

其定義如下:

pub struct DecentralizedFullPrecisionSynchronous {
    pub communicator: BaguaCommunicator,
    pub peer_selection_mode: PeerSelectionMode,
    pub step: Mutex<usize>,
    pub peer_weight: BaguaTensor,
}
2.6.2.3.1 發送

再回憶一下官方思路。

在第 2 步中,我們采用一種策略為每次迭代中的每個 worker 選擇一個 peer,這樣所有 worker 都正確配對並且數據交換是有效的,因為每個 worker 可以在迭代之間與不同的 peer 交換數據。簡而言之,我們的策略將工作人員平均分成兩組,並在兩組之間動態配對 worker,每次迭代都不同。

具體就是通過下面代碼實現的。關鍵點在函數的最后一句,通過調整step, 計算出下一個peer,這樣每次peer都不同

                    // 計算出下一個peer,關鍵點在函數的最后一句,通過調整step,每次peer都不同
                    let peer_rank = if c.rank < c.nranks / 2 {
                        ((step + rank) % ((nranks + 1) / 2)) + (nranks / 2)
                    } else {
                        (rank - (nranks / 2) - step).rem_euclid(nranks / 2)
                    } 
                    
										......
                            c.send(&t.raw, peer_rank); // 發送
                            c.recv(peer_tensor, peer_rank); // 接受
                    ......
                    
                    *self.step.lock() += 1; // 這里是關鍵點!遞增到下一個peer

全部代碼如下:

impl CommOpTrait for DecentralizedFullPrecisionSynchronous {
    fn execute_background_communication(
        &self,
        bucket: Arc<BaguaBucket>,
        comm_op_channels: &BaguaCommOpChannels,
    ) {
        let bucket_guard = bucket.inner.lock();
        let stream_ptr = self.communicator.stream_ptr();

        // 獲取不同的communicator
        let mut communication_tensor = match &self.communicator {
            BaguaCommunicator::SingleCommunicator(_) => {
                bucket_guard.get_communication_tensor(stream_ptr, false, false)
            }
            BaguaCommunicator::HierarchicalCommunicator(x) => match x {
                BaguaHierarchicalCommunicator::Leader(_) => {
                    bucket_guard.get_communication_tensor(stream_ptr, true, true)
                }
                BaguaHierarchicalCommunicator::Worker(_) => {
                    bucket_guard.get_communication_tensor(stream_ptr, false, false)
                }
            },
        };

        let peer_mode = &self.peer_selection_mode;
        let mut peer_guard = self.peer_weight.inner.write();
        let mut peer_tensor = peer_guard.raw.as_mut();
        let step = { *self.step.lock() } as i64;

        self.communicator.execute_communication( // 執行通信
            &mut communication_tensor,
            true,
            true,
            false,
            &mut |c, t| {
                match peer_mode {
                    PeerSelectionMode::All => {
                        // 做普通 allreduce
                        {
                            peer_tensor.clone_from(&t.raw, c.stream_ptr);
                            let _guard = NCCLGroupGuard::new();
                            c.allreduce_inplace(peer_tensor, BaguaReductionOp::AVG);
                        }
                    }
                    PeerSelectionMode::ShiftOne => { // shift_one 
                        let rank = c.rank as i64;
                        let nranks = c.nranks as i64;
                        // 計算出下一個peer,關鍵點在函數的最后一句,通過調整step,每次peer都不同
                        let peer_rank = if c.rank < c.nranks / 2 {
                            ((step + rank) % ((nranks + 1) / 2)) + (nranks / 2)
                        } else {
                            (rank - (nranks / 2) - step).rem_euclid(nranks / 2)
                        } as i32;
                        {
                            let _guard = NCCLGroupGuard::new();
                            c.send(&t.raw, peer_rank); // 發送
                            c.recv(peer_tensor, peer_rank); // 接受
                        }
                        peer_tensor.average_inplace(&t.raw, c.stream_ptr);
                    },
                    PeerSelectionMode::Ring => {
                        unimplemented!() // 沒有實現
                    },
                }
            },
        );

        *self.step.lock() += 1; // 這里是關鍵點!遞增到下一個pee
    }
}

沒有精力去研究rust,所以使用源碼中的測試代碼 tests/torch_api/test_decentralized.py 來看看,八卦在這方面真心做的不錯。

def get_peer_rank(peer_selection_mode, rank, nranks, step, communication_interval):
    comm_step = step // communication_interval
    if peer_selection_mode == "shift_one":
        if rank < nranks // 2:
            return ((comm_step + rank) % ((nranks + 1) // 2)) + (nranks // 2)
        else:
            return (rank - (nranks // 2) - comm_step) % (nranks // 2)
    else:
        ValueError("Unsupported `peer_selection_mode`")

step = 1
for i in range(6):
    print("iteration : ", i)
    print("peer is : ", get_peer_rank("shift_one", 1, 5, step, 1))
    step += 1
    
"""
iteration :  0
peer is :  4
iteration :  1
peer is :  2
iteration :  2
peer is :  3
iteration :  3
peer is :  4
iteration :  4
peer is :  2
iteration :  5
peer is :  3
"""

整理出圖如下,worker 1 每次分別和 worker 4, worker 2,worker 3 進行交換。

                              +--------------+
                              |              |
                              |   Worker 0   |
                              |              |
                              |              |
                              +--------------+

                              +--------------+
                              |              |
                   +------->  |   Worker 2   |
+--------------+   | peer 2   |              |
|              |   |          |              |
|   Worker 1   |   |          +--------------+
|              +---+
|              |   |          +--------------+
+--------------+   |          |              |
                   |          |   Worker 3   |
                   +------->  |              |
                   | peer 3   |              |
                   |          +--------------+
                   |
                   |          +--------------+
                   |          |              |
                   +--------> |   Worker 4   |
                     peer 1   |              |
                              |              |
                              +--------------+
2.6.2.3.2 拷貝回來

copy_back_peer_weight 就是前面提到的回拷貝操作。

impl DecentralizedFullPrecisionSynchronous {
  
    pub fn copy_back_peer_weight(&self, bucket: Arc<BaguaBucket>) { // 拷貝回去
        let bucket_guard = bucket.inner.lock();
        let stream_ptr = self.communicator.stream_ptr();

        let mut communication_tensor =
            bucket_guard.get_communication_tensor(stream_ptr, false, false);

        self.communicator.execute_communication(
            &mut communication_tensor,
            false,
            false,
            true,
            &mut |c, t| {
                t.raw
                    .clone_from(self.peer_weight.inner.read().raw.as_ref(), c.stream_ptr);
            },
        );
    }
}

我們再給出一個示意圖。

+---------------------------------------------------------------------+
|DecentralizedAlgorithmImpl                                           |
|                                                                     |
|     process_group                                                   |
|                                                                     |
|     decentralized_op = bucket.append_decentralized_synchronous_op   |
|                                                 +                   |
|     peer_selection_mode                         |                   |
|                                                 |                   |
|     init_post_backward_hook                     |                   |
|              ^                                  |                   |
|              |                                  |                   |
|              |                                  |                   |
+---------------------------------------------------------------------+
               |                                  |
               |                                  |
+-----------------------------------------------------------+         +----------+
| BaguaBucket  |                                  |         |         | Worker 0 |
|              |                                  |         |         +----------+
|              |                                  v         |
|              |                                            |         +----------+
|              |    DecentralizedFullPrecisionSynchronous { |         | Worker 1 |
|              |                                            |         +----------+
|              |         PeerSelectionMode::ShiftOne {      |
|              |                                            |   peer2 +----------+
|              |            c.send(&t.raw, peer_rank);+--------+----> | Worker 2 |
|              |            c.recv(peer_tensor, peer_rank)  |  |      +----------+
|              |         }                                  |  |
|              |    }                                       |  |peer3 +----------+
|              |                                            |  +----> | Worker 3 |
|              |                                            |  |      +----------+
|              |                                            |  |
|              +--+ copy_back_peer_weight                   |  |peer4 +----------+
|                                                           |  +----> | Worker 4 |
+-----------------------------------------------------------+         +----------+

0x03 異步

關於異步通信,官方文檔思路如下:

  • 同步或是異步(Synchronous or Asynchronous):同步模式中,在每一次迭代過程中,所有工作節點都需要進行通信,並且下一步迭代必須等待當前迭代的通信完成才能開始。反之,異步式分布算法 [2] 則不需要等待時間:當某個節點完成計算后就可直接傳遞本地梯度,進行模型更新。

我們接下來用 https://tutorials.baguasys.com/algorithms/async-model-average 結合代碼來分析學習。

3.1 示例用法

首先初始化八卦算法:

from bagua.torch_api.algorithms import async_model_average
algorithm = async_model_average.AsyncModelAverageAlgorithm()

然后對模型使用算法

model = model.with_bagua([optimizer], algorithm)

與運行同步算法不同,您需要在訓練過程完成時(例如,當您要運行測試時)明確停止通信線程:

model.bagua_algorithm.abort(model)

要在再次開始訓練時恢復通信線程,請執行以下操作:

model.bagua_algorithm.resume(model)

3.2 異步模型平均

在Gradient AllReduce 等同步通信算法中,同一迭代中每個 worker 都需要以鎖步(lock-step)方式運作。當系統中沒有落后者(straggler)時,這種同步算法相當有效,並可以提供更容易推理的確定性訓練結果。然而,當系統中存在落后者時,使用同步算法時,更快的 worker 必須在每次迭代中等待最慢的 worker,這會極大地損害整個系統的性能。為了處理掉隊者,我們可以使用異步算法,其中 worker 不需要同步。八卦提供的異步模型平均算法就是這樣的異步算法。

3.3 算法

異步模式平均算法可以被描述為如下:

每個 worker 都維護一個本地模型 X. 第 i 個 worker 維護 $ x^{(i)}$ ,每個 worker 並行運行兩個線程。第一個線程進行梯度計算(稱為計算線程),另一個線程進行通信(稱為通信線程)。對於每個 worker i, 有一個鎖 \(m_i\),控制對其模型的訪問。

第 i 個 worker 上的計算線程重復以下步驟:

  1. 獲取鎖 \(m_i\)
  2. 在一批輸入數據上計算局部梯度 $∇ F (x^{(i)}) $。
  3. 釋放鎖 \(m_i\).
  4. 用局部梯度更新模型,$x^{(i)} = x^{(i)} - γ∇ F (x^{(i)}) $。

第 i 個 worker 上的通信線程重復以下步驟::

  1. 獲取鎖 \(m_i\)
  2. 與所有其他 worker 的模型通信以平均本地模型\(X^{(i)}\)\(X^{(i)} = \frac{1}{n} \sum^n_{j=1}X^{(j)}\)
  3. 釋放鎖 \(m_i\).

每個 worker 獨立並發地運行這兩個線程。

3.4 分析

大家可以看到,本質上就是計算線程和通信線程都是自己操作,但是依賴鎖進行彼此協調,達到了異步的目的。

3.4.1 異步通信實現

AsyncModelAverageAlgorithmImpl 是異步通信的實現。

class AsyncModelAverageAlgorithmImpl(AlgorithmImpl):
    def __init__(
        self,
        process_group: BaguaProcessGroup,
        peer_selection_mode: str = "all",
        sync_interval_ms: int = 500,
        warmup_steps: int = 0,
    ):
        """
        Implementation of the
        `AsyncModelAverage <https://tutorials.baguasys.com/algorithms/async-model-average.html>`_
        algorithm.

        The asynchronous implementation is experimental, and imposes some restrictions.
        With such asynchronous algorithm, the number of iterations on each worker are different. Therefore
        the current implementation assumes that the dataset is an endless stream, and all workers continuously
        synchronize between each other.

        Users should call :meth:`abort` to manually stop the algorithm's continuous synchronization process.
        For example, for a model wrapped with `.with_bagua(...)`, you can abort with `model.bagua_algorithm.abort(model)`,
        and resume with `model.bagua_algorithm.resume(model)`.

        Args:
            process_group (BaguaProcessGroup): The process group to work on.
            peer_selection_mode (str): The way how workers communicate with each other. Currently ``"all"`` is supported.
                ``"all"`` means all workers' weights are synchronized during each communication.
            sync_interval_ms (int): Number of milliseconds between model synchronizations.
            warmup_steps (int): Number of steps to warm up by doing gradient allreduce before doing asynchronous
                model averaging. Use 0 to disable.
        """

        super(AsyncModelAverageAlgorithmImpl, self).__init__(process_group)
        self.peer_selection_mode = peer_selection_mode
        self.sync_interval_ms = sync_interval_ms
        self.step_id = 0
        self.warmup_steps = warmup_steps
        self.cuda_event = torch.cuda.Event()
        self.abort_event = threading.Event()
        self.dummy_tensor = torch.Tensor([0]).byte().cuda()

        # 線程池
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        self.scheduled = False

        process_ranks = list(_pg_group_ranks[self.process_group])
        self.thread_group = new_group(
            process_ranks, stream=torch.cuda.Stream(priority=-1)
        )

3.4.2 初始化操作

init_operations 的 這部分調用是在 _bagua_reset_algorithm_buckets 之中,每個 BaguaModule 都會做設置,主要是設置:熱身時期是同步操作/其他時間是異步操作,這里忽略了大部分代碼。

def _bagua_reset_algorithm_buckets(self):
    self._bagua_cleanup_algorithm()
    raw_buckets = self._bagua_autotune_get_buckets()
    self.bagua_buckets.extend(self.bagua_algorithm.tensors_to_buckets(raw_buckets))

    for name, param in self.named_parameters():
        # 忽略 real_hook_factory 定義
        if param.requires_grad:
            param_tmp = param.expand_as(param)
            grad_acc = param_tmp.grad_fn.next_functions[0][0]
            hook = grad_acc.register_hook(real_hook_factory(name, param))
            hook.grad_acc = grad_acc
            self._bagua_algorithm_hooks.append(hook)

    optimizer_hook = self.bagua_algorithm.init_post_optimizer_step_hook(self)

    for optimizer in self.bagua_optimizers:
        if not hasattr(optimizer, "_bagua_original_step"):
            optimizer._bagua_original_step = optimizer.step
        # 忽略 new_step_factory 定義
        optimizer.step = new_step_factory(optimizer)

    for bucket in self.bagua_buckets:
        self.bagua_algorithm.init_operations( # 這里調用對算法的初始化操作
            self,
            bucket,
        )
    self._bagua_backend.register_ordered_buckets(
        [bucket.backend_bucket for bucket in self.bagua_buckets]
    )

就是對於除了熱身期間之外,每個桶都設定了異步通信

def init_operations(
    self,
    bagua_module: BaguaModule,
    bucket: BaguaBucket,
):
    bagua_module._bagua_backend.wait_pending_comm_ops()
    bucket.clear_ops()

    if self.step_id < self.warmup_steps:
        bucket.append_centralized_synchronous_op( # 熱身時期是同步操作
            hierarchical=False,
            average=True,
            group=self.process_group,
        )
    else:
        # 其他時間是異步操作
        async_op = bucket.append_asynchronous_model_average_op(
            peer_selection_mode=self.peer_selection_mode, group=self.thread_group
        )
        bucket._async_op = async_op

3.4.3 加鎖解鎖

我們接下來看看加鎖釋放鎖的基礎操作。bagua/torch_api/algorithms/async_model_average.py 之中有:

def _lock_model(self, bagua_module: BaguaModule):
    torch.cuda.current_stream().record_event(self.cuda_event)
    self.cuda_event.synchronize() # CUDA同步操作

    for bucket in bagua_module.bagua_buckets:
        bucket._async_op.lock_weight() # 加鎖操作

def _unlock_model(self, bagua_module: BaguaModule):
    torch.cuda.current_stream().record_event(self.cuda_event)
    self.cuda_event.synchronize() # CUDA同步操作

    for bucket in bagua_module.bagua_buckets:
        bucket._async_op.unlock_weight() # 釋放鎖

lock_weight 和 unlock_weight 的實現在 rust 代碼之中。

impl DecentralizedFullPrecisionAsynchronous {
    pub fn lock_weight(&self) {
        let raw_mutex = unsafe { self.weight_mutex.raw() };
        raw_mutex.lock();
    }

    pub fn unlock_weight(&self) {
        unsafe {
            let raw_mutex = self.weight_mutex.raw();
            raw_mutex.unlock();
        };
    }
}

3.4.4 計算線程

計算線程之中,和加鎖解鎖關鍵步驟如下:

3.4.4.1 前向傳播

前向傳播時候,先進行加鎖,如果異步循環通信線程沒有啟動,則會進行啟動。

def init_forward_pre_hook(self, bagua_module: BaguaModule):
    def hook(input):
        if (
            self.step_id > self.warmup_steps
            and self.sync_interval_ms > 0  # noqa: W503
        ):
            self._lock_model(bagua_module) # 枷鎖

            if not hasattr(self, "future"):
                self.future = self.executor.submit(
                    self._run_async_loop, bagua_module # 啟動異步循環通信線程
                )
                self.scheduled = True

    return hook
3.4.4.2 后向傳播

后向傳播結束之后,會對鎖進行釋放,就是說,前向傳播時候加鎖啟動線程,后向傳播時候解鎖,這期間進行計算

def init_backward_hook(self, bagua_module: BaguaModule):
    def hook(parameter_name, parameter):
        if self.step_id <= self.warmup_steps:
            parameter._bagua_grad.bagua_mark_communication_ready() # 通知后端可以通信

    return hook

def init_post_backward_hook(self, bagua_module: BaguaModule):
    def hook():
        if self.step_id <= self.warmup_steps:
            bagua_module._bagua_backend.wait_pending_comm_ops() # 等待
        else:
            self._unlock_model(bagua_module) # 解鎖

    return hook

此時邏輯如下:

+---------------------------------------------------------------------------+
| AsyncModelAverageAlgorithmImpl                                            |
|                                                                           |
|  +-----------------------------+                 +----------------------+ |
|  | Computation thread          |                 | BaguaBucket          | |
|  |                             | set async_op    |  +----------------+  | |
|  |    init_operations   +----------------------> |  | _async_op      |  | |
|  |                             |                 |  |                |  | |
|  |                             | lock_weight()   |  |                |  | |
|  |    init_forward_pre_hook +------------------> |  |                |  | |
|  |                             | unlock_weight() |  |                |  | |
|  |    init_post_backward_hook+-----------------> |  |                |  | |
|  |                             |                 |  |                |  | |
|  |                             |                 |  +----------------+  | |
|  +-----------------------------+                 +----------------------+ |
|                                                                           |
|  +-----------------------------+                                          |
|  | Communation thread          |                                          |
|  |                             |                                          |
|  | _run_async_loop             |                                          |
|  |                             |                                          |
|  |                             |                                          |
|  +-----------------------------+                                          |
|                                                                           |
+---------------------------------------------------------------------------+

3.4.5 通信線程

通信線程主循環如下,主要是通知后端,進行通信

def _run_async_loop(self, bagua_module: BaguaModule):
    comm_step = 0
    while True:
        state = self._negotiate()
        if state == _AsyncInternalState.ABORT:
            break

        start_time = time.time()
        for bucket in bagua_module.bagua_buckets: # 遍歷桶
            for tensor in bucket.tensors: # 遍歷張量
                # 通知后端,進行通信
                tensor.bagua_mark_communication_ready_without_synchronization() 

        bagua_module._bagua_backend.wait_pending_comm_ops()
        duration = (time.time() - start_time) * 1000

        comm_step += 1
        time.sleep(self.sync_interval_ms / 1000)
3.4.5.1通知后端
Python

bagua_mark_communication_ready_without_synchronization 的實現如下,調用后端的 mark_communication_ready。

def bagua_mark_communication_ready_without_synchronization(self):
    """
    Mark a Bagua tensor ready immediately, without `CUDA event <https://pytorch.org/docs/stable/generated/torch.cuda.Event.html?highlight=event#torch.cuda.Event>`_ synchronization.
    """
    self.bagua_backend.mark_communication_ready(
        self._bagua_backend_tensor,
        0,
    )
Rust

mark_communication_ready 的實現在 rust 之中。位置是 rust/bagua-core/bagua-core-py/src/lib.rs。

pub fn mark_communication_ready(
    &mut self,
    tensor: PyRef<BaguaTensorPy>,
    ready_cuda_event_ptr: u64,
    py: Python,
) -> PyResult<()> {
    let inner = &tensor.inner;
    py.allow_threads(|| {
        self.inner
            .mark_communication_ready(inner, ready_cuda_event_ptr)
    })
    .map_err(|e| PyRuntimeError::new_err(format!("{:?}", e)))
}

rust/bagua-core/bagua-core-internal/src/lib.rs 之中有:

pub fn mark_communication_ready(
    &mut self,
    tensor: &BaguaTensor,
    ready_cuda_event_ptr: u64,
) -> Result<(), BaguaCoreError> {
    let tracer = global::tracer("bagua-core");
    let mut span = tracer.start("tensor_ready");
    span.set_attribute(KeyValue::new("tensor_name", tensor.name()));

    tensor.mark_comm_ready(ready_cuda_event_ptr);
    while self.should_schedule()? {
        let bucket = self.ordered_buckets.pop_front().unwrap();
        bucket.reset_comm_ready();
        let bucket_clone = bucket.clone();
        self.ordered_buckets.push_back(bucket);
        self.schedule_comm(bucket_clone)?;
    }
    Ok(())
}

schedule_comm 在 rust/bagua-core/bagua-core-internal/src/lib.rs 之中。

pub fn schedule_comm(&self, bucket: Arc<BaguaBucket>) -> Result<(), BaguaCoreError> {
    let event_channel = BaguaEventChannel::new("comm_op");
    self.channels
        .schedule_channel_sender
        .send(BaguaScheduledCommOp {
            name: format!("comm op for bucket {}", bucket.name),
            ops: {
                let guard = bucket.inner.lock();
                guard.comm_ops.clone() // 獲取bucket的op,進行調用
            },
            bucket,
            event_channel: event_channel.clone(),
        })
        .map_err(|e| BaguaCoreError::InternalChannelError(format!("{:?}", e)))?;
    Ok(self
        .channels
        .not_waited_events_sender
        .send(event_channel)
        .map_err(|e| BaguaCoreError::InternalChannelError(format!("{:?}", e)))?)
}

發送了一個 BaguaScheduledCommOp。

pub struct BaguaScheduledCommOp {
    pub name: String,
    pub bucket: Arc<BaguaBucket>,
    pub ops: Vec<Arc<dyn CommOpTrait + Send + Sync>>,
    pub event_channel: BaguaEventChannel,
}

邏輯如下:

+---------------------------------------------------+    +----------------------------+
| AsyncModelAverageAlgorithmImpl                    |    | BaguaBucket                |
|                                                   |    | +------------------------+ |
|  +-----------------------------+                  |    | | _async_op              | |
|  | Computation thread          |                  |    | |                        | |
|  |                             |    set async_op  |    | |                        | |
|  |    init_operations   +----------------------------> | |                        | |
|  |                             |                  |    | |                        | |
|  |                             |    lock_weight() |    | |                        | |
|  |    init_forward_pre_hook +------------------------> | |                        | |
|  |                             |   unlock_weight()|    | |                        | |
|  |    init_post_backward_hook+-----------------------> | |                        | |
|  |                             |                  |    | +------------------------+ |
|  |                             |                  |    +----------------------------+
|  +-----------------------------+                  |
|  +---------------------------------+              |
|  | Communation thread              |              |    +----------------------------+
|  | +-----------------------------+ |              |    | BaguaCommBackendPy         |
|  | |                             | |              |    |                            |
|  | | _run_async_loop    +----------------------------> |   mark_communication_ready |
|  | |                             | |              |    |            +               |
|  | +-----------------------------+ |              |    |            |               |
|  +---------------------------------+              |    |            v               |
+---------------------------------------------------+    |      schedule_comm         |
                                                         |                            |
                                                         +----------------------------+
3.4.5.2 歸並

schedule_comm 最終會調用到 bucket.comm_ops,該變量在初始化時候被配置為 DecentralizedFullPrecisionAsynchronous,所以我們需要回頭來一步一步看看如何歸並。

前面初始化操作時候有使用 bucket.append_asynchronous_model_average_op 進行配置。

def init_operations(
    self,
    bagua_module: BaguaModule,
    bucket: BaguaBucket,
):
    bagua_module._bagua_backend.wait_pending_comm_ops()
    bucket.clear_ops()

    if self.step_id < self.warmup_steps:
        bucket.append_centralized_synchronous_op( # 熱身時期是同步操作
            hierarchical=False,
            average=True,
            group=self.process_group,
        )
    else:
        # 其他時間是異步操作
        async_op = bucket.append_asynchronous_model_average_op( # 進行歸並配置
            peer_selection_mode=self.peer_selection_mode, group=self.thread_group
        )
        bucket._async_op = async_op
Python

append_asynchronous_model_average_op 代碼在 bagua/torch_api/bucket.py。其作用是:

  • 將異步模型歸並操作附加到bucket。此操作將在訓練模型時啟用 worker 之間的連續模型平均。當bucket中的所有張量都標記為ready時,操作將由Bagua后端按照追加的順序執行。

  • 此操作旨在與計算過程並行運行。它返回對op的引用。op具有獨占訪問模型的鎖。調用op.lock_weight()獲取鎖,調用op.unlock_weight()釋放鎖。

  • 重點在於,張量 ready 之后進行操作。

def append_asynchronous_model_average_op(
    self, peer_selection_mode: str, group: Optional[BaguaProcessGroup] = None
):

    """
    Append an asynchronous model average operation to a bucket. This operation will enable continuous
    model averaging between workers while training a model.

    The operations will be executed by the Bagua backend in the order they are appended
    when all the tensors within the bucket are marked ready.

    This operation is intended to run in parallel with the computation process. It returns a reference
    to the op. The op features a lock to exclusively access the model. Call ``op.lock_weight()`` to
    acquire the lock and ``op.unlock_weight()`` to release it.

    Args:
        peer_selection_mode (str): The way how workers communicate with each otehr. Currently ``"all"`` is supported.
            ``"all"`` means all workers' weights are averaged during each communication.
        group: The process group to work on. If ``None``, the default process group will be used.
    Returns:
        The asynchronous model average operation itself.
    """
    if group is None:
        group = _get_default_group()

    return self.backend_bucket.append_decentralized_asynchronous_op(
        _bagua_backend_comm(group.get_global_communicator()),
        None,
        peer_selection_mode=peer_selection_mode,
        torch_stream=torch.cuda.current_stream().cuda_stream,
    )
Rust

append_decentralized_asynchronous_op 函數在 rust 之中,其調用了 DecentralizedFullPrecisionAsynchronous,就是往 bucket.comm_ops 之上添加了一個 DecentralizedFullPrecisionAsynchronous。

    pub fn append_decentralized_asynchronous_op(
        &mut self,
        communicator_internode: Option<&BaguaSingleCommunicator>,
        communicator_intranode: Option<&BaguaSingleCommunicator>,
        peer_selection_mode: String,
        torch_stream: u64,
    ) -> Arc<DecentralizedFullPrecisionAsynchronous> {
        let communicator =
            BaguaCommunicator::new(communicator_internode, communicator_intranode, false)
                .expect("cannot create communicator");

        let comm_op = Arc::new(DecentralizedFullPrecisionAsynchronous {
            communicator,
            peer_selection_mode: match peer_selection_mode.as_str() {
                "all" => PeerSelectionMode::All,
                &_ => {
                    unimplemented!("unsupported peer_selection_mode for decentralized asynchronous algorithm (should be `all`)")
                }
            },
            torch_stream,
            weight_mutex: Arc::new(Mutex::new(true)),
        });

        self.inner
            .lock()
            .comm_ops // 插入到 bucket 的 comm_ops
            .push(comm_op.clone() as Arc<dyn CommOpTrait + Send + Sync>);

        comm_op
    }

DecentralizedFullPrecisionAsynchronous 里面有加鎖,釋放鎖,CUDA 同步操作等等,恰好與前面提到的前向傳播/后向傳播對應。

impl CommOpTrait for DecentralizedFullPrecisionAsynchronous {
    fn execute_background_communication(
        &self,
        bucket: Arc<BaguaBucket>,
        comm_op_channels: &BaguaCommOpChannels,
    ) {
        let bucket_guard = bucket.inner.lock();

        let comm_stream = self.communicator.stream_ptr();

        let mut communication_tensor = match &self.communicator {
            BaguaCommunicator::SingleCommunicator(_) => {
                bucket_guard.get_communication_tensor(comm_stream, false, false)
            }
            BaguaCommunicator::HierarchicalCommunicator(x) => {
                panic!("asynchronous op only accepts non-hierarchical communicator");
            }
        };

        let peer_mode = &self.peer_selection_mode;

        let torch_stream = self.torch_stream;

        self.communicator.execute_communication(
            &mut communication_tensor,
            false,
            false,
            false,
            &mut |c, t| {
                let start_time = std::time::Instant::now();
   
                let temp_buf = CUDA_DEVICE_MEMORY_POOL[t.raw.device_id()]
                    .try_pull(t.raw.num_elements_allocated() * t.raw.dtype().bytes())
                    .expect("cannot allocate cuda memory");

                let mut temp_tensor = BaguaTensorRaw {
                    ptr: temp_buf.ptr,
                    num_elem_allocated: t.raw.num_elements_allocated(),
                    dtype: t.raw.dtype().clone(),
                    num_elem: t.raw.num_elements(),
                    device_id: t.raw.device_id(),
                    pool_allocations: vec![Arc::new(temp_buf)],
                };

                let reduced_buf = CUDA_DEVICE_MEMORY_POOL[t.raw.device_id()]
                    .try_pull(t.raw.num_elements_allocated() * t.raw.dtype().bytes())
                    .expect("cannot allocate cuda memory");

                let mut reduced_tensor = BaguaTensorRaw {
                    ptr: reduced_buf.ptr,
                    num_elem_allocated: t.raw.num_elements_allocated(),
                    dtype: t.raw.dtype().clone(),
                    num_elem: t.raw.num_elements(),
                    device_id: t.raw.device_id(),
                    pool_allocations: vec![Arc::new(reduced_buf)],
                };

                let src_ready_event = CUDA_EVENT_POOL.take().event;

                // use default stream to copy weights
                temp_tensor.clone_from(&t.raw, torch_stream as u64);

                unsafe {
                    cpp::cpp!([
                        src_ready_event as "cudaEvent_t",
                        comm_stream as "cudaStream_t",
                        torch_stream as "cudaStream_t"]
                    {
                        CUDACHECK(cudaEventRecord(src_ready_event, torch_stream));
                        CUDACHECK(cudaStreamWaitEvent(comm_stream, src_ready_event , 0));
                    });
                }

                match peer_mode {
                    PeerSelectionMode::All => {
                        c.allreduce(&temp_tensor, &mut reduced_tensor, BaguaReductionOp::SUM);
                    }
                    PeerSelectionMode::Ring => {
                        unimplemented!()
                    }
                    PeerSelectionMode::ShiftOne => {
                        unimplemented!()
                    }
                };

                {
                    // 獲取 ready event
                    let ready_event = CUDA_EVENT_POOL.take().event;
                    unsafe {
                        cpp::cpp!([
                            ready_event as "cudaEvent_t",
                            comm_stream as "cudaStream_t"]
                        {
                            // CUDA 同步操作
                            CUDACHECK(cudaEventRecord(ready_event, comm_stream));
                            CUDACHECK(cudaEventSynchronize(ready_event));
                        });
                    }

                    self.lock_weight(); // 加鎖
                  
                    t.raw.async_model_average(
                        &reduced_tensor,
                        &temp_tensor,
                        c.nranks as f32,
                        comm_stream,
                    );

                    unsafe {
                        cpp::cpp!([
                            ready_event as "cudaEvent_t",
                            comm_stream as "cudaStream_t"]
                        {
                            // 對CUDA進行操作
                            CUDACHECK(cudaEventRecord(ready_event, comm_stream));
                            CUDACHECK(cudaEventSynchronize(ready_event));
                        });
                    }
                    self.unlock_weight(); // 解鎖
                }

                tracing::debug!(
                    "#{} async model average update cost: {:?}",
                    c.rank,
                    start_time.elapsed()
                );
            },
        );
    }
}

在 rust/bagua-core/bagua-core-internal/kernels/bagua_kernels.cu 之中有最終操作。

__global__ void async_model_average(float *tensor, const float *reduced_tensor_copy, 
      const float *tensor_copy, const float nranks, const int N) {
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {  
   tensor[i] += reduced_tensor_copy[i] / nranks - tensor_copy[i];
    }
}

我們總結邏輯如下:

  • (1)init_operations 會進行一系列調用,生成了一個DecentralizedFullPrecisionAsynchronous,賦值在bucket 的 comm_ops 和 aysnc_op 之上。

計算線程之中做如下操作:

  • (2)計算線程之中,在前向傳播之前設置了hook,其中會 lock weight。
  • (3)計算線程之中,在后向傳播之前設置了hook,其中會 unlock weight。

通訊線程之中做如下操作:

  • (4)會調用 mark_communication_ready 進行通信設置。
  • (5)mark_communication_ready 最終調用到 schedule_comm,其會啟動 bucket.comm_ops,bucket.comm_ops 就是 DecentralizedFullPrecisionAsynchronous。
  • DecentralizedFullPrecisionAsynchronous 之中會:
    • (6)lock weight。
    • (7)會進行異步模型歸並。
    • (8)會 unlock weight。
  +---------------------------------------------------+   +----------------------+    +----------------------------------------+
  | AsyncModelAverageAlgorithmImpl                    |   |  BaguaBucket         |    | DecentralizedFullPrecisionAsynchronous |
  |                                                   |   |                 1    |    |                                        |
  |  +-----------------------------+                  |   |       comm_ops +--------> |  6   self.lock_weight()                |
  |  | Computation thread          |  1 set async_op  |   |                      |    |                                        |
  |  |                             |                  |   |    +--------------+  |    |                                        |
  |  |    init_operations   +---------------------------->+    | _async_op  1 |  |    |  7   t.raw.async_model_average(        |
  |  |                             |                  |   |    |           +--------> |                &reduced_tensor,        |
  |  |                             |                  |   |    |              |  |    |                &temp_tensor,           |
  |  |                             |                  |   |    |              |  |    |                c.nranks as f32,        |
  |  |                             |                  |   |    |              |  |    |                comm_stream,            |
  |  |                             |  2 lock_weight() |   |    |              |  |    |            );                          |
  |  |    init_forward_pre_hook +----------------------------> |              |  |    |                                        |
  |  |                             | 3 unlock_weight()|   |    |              |  |    |                                        |
  |  |    init_post_backward_hook+---------------------------> |              |  |    |  8   self.unlock_weight()              |
  |  |                             |                  |   |    +--------------+  |    |                                        |
  |  |                             |                  |   |                      |    +--------+-------------------------------+
  |  +-----------------------------+                  |   +----------------------+             ^
  |                                                   |                                        |
+--------------------------------------------------------------------------------------------------------------------------------+
  |                                                   |                                        |
  |  +---------------------------------+              |                                        |
  |  | Communation thread              |              |   +-----------------------------+      |
  |  | +-----------------------------+ |              |   |  BaguaCommBackendPy         |      |
  |  | |                             | |     4        |   |                             |      |
  |  | | _run_async_loop    +--------------------------------> mark_communication_ready |      |
  |  | |                             | |              |   |             +               |      | 5
  |  | +-----------------------------+ |              |   |             |               |      |
  |  +---------------------------------+              |   |             v               |      |
  +---------------------------------------------------+   |       schedule_comm         |      |
                                                          |             +               |      |
                                                          |             |               |      |
                                                          |             v               |      |
                                                          |       bucket.comm_ops  +-----------+
                                                          |                             |
                                                          +-----------------------------+

手機如下:

或者我們換一個角度來看,就是左右兩個線程都操作桶,通過鎖來協調競爭,特色除了鎖之外,就在DecentralizedFullPrecisionAsynchronous 之中。這里需要注意的是,數值 1 的意義是設置,就是 bucket 的 _async_op 和 comm_ops 都配置成 DecentralizedFullPrecisionAsynchronous,最后通訊線程之中(4)會調用 mark_communication_ready 進行通信設置。

                                                                                                                             +-------------------------+
                                                 +----------------------+                                                    | Communation thread      |
                                                 |  BaguaBucket         |                                                    | +---------------------+ |
                                                 |                      | 1                                                  | |                     | |
+---------------------------+                    |       comm_ops +--------------------------------+                         | | _run_async_loop     | |
| Computation thread        |  1 set async_op    |                      |                          |                         | |          +          | |
|                           |                    |    +--------------+  |                          |                         | |          |          | |
|  init_operations   +-------------------------->+    | _async_op    |  | 1                        |                         | +---------------------+ |
|                           |                    |    |           +------------------+             |                         +-------------------------+
|                           |                    |    |              |  |            |             |                                      |
|                           |                    |    |              |  |            |             |                                      |
|                           |                    |    |              |  |            v             v                                      v
|                           |  2 lock_weight()   |    |              |  |     +------+-------------+-------------------+    +-------------+---------------+
|  init_forward_pre_hook +--------------------------> |              |  |     | DecentralizedFullPrecisionAsynchronous |    |  BaguaCommBackendPy         |
|                           |                    |    |              |  | 6   |                                        |    |                             |
|                           |                    |    |              +<------------+ self.lock_weight()                |    |    mark_communication_ready |
|                           |                    |    |              |  |     |                                        |    |             +               |
|                           |                    |    |              |  |     |  7   t.raw.async_model_average(        |    |             |               |
|                           |                    |    |              |  |     |                &reduced_tensor,        |    |             v               |
|                           |                    |    |              |  |     |                &temp_tensor,           |    |       schedule_comm         |
|                           |                    |    |              |  |     |                c.nranks as f32,        |    |             +               |
|                           |                    |    |              |  |     |                comm_stream,            |    |             |               |
|                           |                    |    |              |  |     |            );                          |  4 |             v               |
|                           |                    |    |              |  | 8   |                                        +<--------+  bucket.comm_ops       |
|                           | 3 unlock_weight()  |    |              +<-----------+  self.unlock_weight()              |    |                             |
|  init_post_backward_hook+-------------------------> |              |  |     |                                        |    +-----------------------------+
|                           |                    |    |              |  |     +----------------------------------------+
|                           |                    |    +--------------+  |
|                           |                    |                      |
+---------------------------+                    +----------------------+

手機如下:

至此,八卦框架分析完畢,這個框架無論是論文,代碼,文檔,介紹網站,PPT都非常給力,推薦有興趣的朋友繼續深入研究。

0xFF 參考

PyTorch internals

快手八卦!突破 TensorFlow、PyTorch 並行瓶頸的開源分布式訓練框架來了!

https://arxiv.org/pdf/2107.01499.pdf

https://tutorials.baguasys.com/algorithms/decentralized

[1] Dean, Jeffrey, Greg S. Corrado, Rajat Monga, Kai Chen, Matthieu Devin, Quoc V. Le, Mark Z. Mao et al. “Large scale distributed deep networks.” (2012).

[2] Zhengyuan Zhou, Panayotis Mertikopoulos, Nicholas Bambos, Peter Glynn, Yinyu Ye, Li-Jia Li, and Li Fei-Fei. 2018. Distributed asynchronous optimization with unbounded delays: How slow can you go?. In International Conference on Machine Learning. PMLR, 5970–5979.

[3] DanAlistarh, DemjanGrubic, JerryLi, RyotaTomioka, and MilanVojnovic. 2016. QSGD: Communication-efficient SGD via gradient quantization and encoding. arXiv preprint arXiv:1610.02132 (2016).

[4] Dan Alistarh, Torsten Hoefler, Mikael Johansson, Sarit Khirirat, Nikola Konstanti- nov, and Cédric Renggli. 2018. The convergence of sparsified gradient methods. In Proceedings of the 32nd International Conference on Neural Information Processing Systems. 5977–5987.

[5] Anastasia Koloskova, Sebastian Stich, and Martin Jaggi. 2019. Decentralized stochastic optimization and gossip algorithms with compressed communication. In International Conference on Machine Learning. PMLR, 3478–3487.

[6] Xiangru Lian, Ce Zhang, Huan Zhang, Cho-Jui Hsieh, Wei Zhang, and Ji Liu. 2017. Can decentralized algorithms outperform centralized algorithms? a case study for decentralized parallel stochastic gradient descent. In Proceedings of the 31st International Conference on Neural Information Processing Systems. 5336–5346.

[7] Christopher De Sa, Matthew Feldman, Christopher Ré, and Kunle Olukotun. 2017. Understanding and optimizing asynchronous low-precision stochastic gradient descent. In Proceedings of the 44th Annual International Symposium on Computer Architecture. 561–574.

[8] Xiangru Lian, Wei Zhang, Ce Zhang, and Ji Liu. 2018. Asynchronous decentral- ized parallel stochastic gradient descent. In International Conference on Machine Learning. PMLR, 3043–3052.

[9] Hanlin Tang, Shaoduo Gan, Ce Zhang, Tong Zhang, and Ji Liu. 2018. Com- munication compression for decentralized training. In Proceedings of the 32nd International Conference on Neural Information Processing Systems. 7663–7673.

[10] Ji Liu, Ce Zhang, et al. 2020. Distributed Learning Systems with First-Order Methods. Foundations and Trends® in Databases 9, 1 (2020), 1–100.


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM