PyTorch Distributed Initialization


Initialization methods for PyTorch Distributed

References

https://pytorch.org/docs/master/distributed.html

Code
https://github.com/overfitover/pytorch-distributed
Stars are welcome.

Initialization

torch.distributed.init_process_group(backend, init_method='env://', **kwargs)

Parameters

  • backend (str): the backend to use; gloo, mpi, or nccl (older releases also shipped a tcp backend)
  • init_method (str, optional): URL specifying how to initialize the process group; the shared rendezvous mechanism that coordinates the processes
  • world_size (int, optional): number of processes participating in the job
  • rank (int, optional): rank of the current process
  • group_name (str, optional): name for this group of processes.
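
As a concrete illustration, a minimal single-process call might look like the sketch below; the backend, address, and port are placeholders, not values from this post.

import torch.distributed as dist

# Minimal sketch: a one-process group rendezvousing over TCP.
dist.init_process_group(backend='gloo',
                        init_method='tcp://127.0.0.1:29500',
                        rank=0,
                        world_size=1)
print(dist.get_world_size())   # -> 1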

init_method

There are three methods:

  • file:// shared file system
  • tcp:// an IP address and port reachable by every process
  • env:// environment variables (the default)
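
The three schemes differ only in the URL string passed as init_method; illustrative forms (paths and addresses are placeholders):

methods = [
    'file:///tmp/pytorch_init',   # a file on a shared file system
    'tcp://10.1.1.20:23456',      # IP address and port of the rank-0 host
    'env://',                     # read MASTER_ADDR / MASTER_PORT from the environment
]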

env

#!/usr/bin/env python
import os
import torch
import torch.distributed as dist
from torch.multiprocessing import Process
import time

def run(rank, size):
    """ Per-process work goes here; this demo only initializes. """
    pass


def init_processes(rank, size, fn, backend='gloo'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '162.128.0.22'   # address of the rank-0 host
    os.environ['MASTER_PORT'] = '29555'          # a free port on that host
    dist.init_process_group(backend, rank=rank, world_size=size)
    torch.cuda.manual_seed(1)
    fn(rank, size)
    # Inspect the freshly initialized process group.
    print(dist.get_rank())
    print(dist.get_world_size())
    print(dist.is_available())


def main():

    size = 2
    processes = []
    for i in range(size):
        p = Process(target=init_processes, args=(i, size, run))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

if __name__ == "__main__":
    start_time = time.time()
    main()

    end_time = time.time()
    print("耗時:", end_time-start_time)

Note
Replace 162.128.0.22 with your own IP address.
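
With env:// the rank and world size can also come from environment variables rather than function arguments; a minimal single-process sketch, assuming a loopback address for illustration:

import os
import torch.distributed as dist

# env:// falls back to the RANK and WORLD_SIZE variables when the
# corresponding arguments are omitted.
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29555'
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
dist.init_process_group(backend='gloo', init_method='env://')
print(dist.get_rank(), dist.get_world_size())   # -> 0 1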

tcp

import torch
import torch.distributed as dist
import argparse
from torch.multiprocessing import Process


def initialize(rank, world_size, ip, port):
    # The standalone 'tcp' backend has been removed from PyTorch; use gloo
    # (or nccl) together with a tcp:// init_method instead.
    dist.init_process_group(backend='gloo',
                            init_method='tcp://{}:{}'.format(ip, port),
                            rank=rank, world_size=world_size)
    print("initialized rank", rank)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--ip', type=str, default='162.128.0.22')
    parser.add_argument('--port', type=str, default='20000')
    parser.add_argument('--rank', '-r', type=int)
    parser.add_argument('--world-size', '-s', type=int)
    args = parser.parse_args()
    print(args)
    # For a multi-node run, call initialize directly with the parsed rank:
    # initialize(args.rank, args.world_size, args.ip, args.port)

    size = 2
    processes = []
    for i in range(size):
        p = Process(target=initialize, args=(i, size, args.ip, args.port))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()


if __name__ == '__main__':
    main()

Note
Replace 162.128.0.22 with your own IP address.
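
As written, the script spawns both ranks on the local machine, so a single invocation such as python tcp_demo.py --ip 127.0.0.1 --port 20000 (the file name here is hypothetical) is enough. For a genuine multi-node run, uncomment the direct initialize(args.rank, args.world_size, args.ip, args.port) call, drop the local Process loop, and launch the script once per node with the same --ip and --port, a common --world-size, and a distinct --rank on each node.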

file

import argparse
import torch.distributed as dist
from torch.multiprocessing import Process


def initialize(rank, world_size):
    # The file must live on a file system visible to every participating process.
    dist.init_process_group(backend='gloo',
                            init_method='file:///home/yxk/Documents/Deeplearningoflidar139/overfitover/share',
                            rank=rank, world_size=world_size)
    print("initialized rank", rank)

def main():

    size = 2
    processes = []
    for i in range(size):
        p = Process(target=initialize, args=(i, size))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()


if __name__ == '__main__':
    main()

Note
init_method must start with file:// and point to a file that does not yet exist, inside an existing directory on a shared file system. If the file does not exist, initialization creates it automatically, but it is never deleted; you must remove it yourself before the next init_process_group call that reuses the same path.
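
Because the rendezvous file is never cleaned up automatically, a small guard before initialization avoids stale-file errors; a sketch, assuming a placeholder path:

import os
import torch.distributed as dist

init_file = '/tmp/pytorch_dist_init'       # placeholder; use a path on your shared file system
if os.path.exists(init_file):
    os.remove(init_file)                   # file:// init will not delete it for you
dist.init_process_group(backend='gloo',
                        init_method='file://' + init_file,
                        rank=0, world_size=1)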

