torch.distributed.barrier()


1、背景介紹

  在pytorch的多卡訓練中,通常有兩種方式,一種是單機多卡模式(存在一個節點,通過torch.nn.DataParallel(model)實現),一種是多機多卡模式(存在一個節點或者多個節點,通過torch.nn.parallel.DistributedDataParallel(model),在單機多卡環境下使用第二種分布式訓練模式具有更快的速度。pytorch在分布式訓練過程中,對於數據的讀取是采用主進程預讀取並緩存,然后其它進程從緩存中讀取,不同進程之間的數據同步具體通過torch.distributed.barrier()實現。

2、通俗理解torch.distributed.barrier()

def create_dataloader():
    #使用上下文管理器中實現的barrier函數確保分布式中的主進程首先處理數據,然后其它進程直接從緩存中讀取
    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels()
 
 
from contextlib import contextmanager
 
#定義的用於同步不同進程對數據讀取的上下文管理器
@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Decorator to make all processes in distributed training wait for each local_master to do something.
    """
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    yield   #中斷后執行上下文代碼,然后返回到此處繼續往下執行
    if local_rank == 0:
        torch.distributed.barrier()

  

(1)進程號rank理解
在多進程上下文中,我們通常假定rank 0是第一個進程或者主進程,其它進程分別具有0,1,2不同rank號,這樣總共具有4個進程。

(2)單一進程數據處理
通常有一些操作是沒有必要以並行的方式進行處理的,如數據讀取與處理操作,只需要一個進程進行處理並緩存,然后與其它進程共享緩存處理數據,但是由於不同進程是同步執行的,單一進程處理數據必然會導致進程之間出現不同步的現象,為此,torch中采用了barrier()函數對其它非主進程進行阻塞,來達到同步的目的。

(3)barrier()具體原理
在上面的代碼示例中,如果執行create_dataloader()函數的進程不是主進程,即rank不等於0或者-1,上下文管理器會執行相應的torch.distributed.barrier(),設置一個阻塞柵欄,讓此進程處於等待狀態,等待所有進程到達柵欄處(包括主進程數據處理完畢);如果執行create_dataloader()函數的進程是主進程,其會直接去讀取數據並處理,然后其處理結束之后會接着遇到torch.distributed.barrier(),此時,所有進程都到達了當前的柵欄處,這樣所有進程就達到了同步,並同時得到釋放。

原文鏈接:https://blog.csdn.net/weixin_41041772/article/details/109820870

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM