PyTorch single-machine multi-GPU parallel computing example


A simple example.

Note:

os.environ['MASTER_ADDR'] = 'xxx.xx.xx.xxx' # the machine's IP address
os.environ['MASTER_PORT'] = '29555' # a free port

These two environment variables apparently must be set in advance, because the chosen initialization method is init_method="env://" (the default environment-variable method), which reads the rendezvous address from them.
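If you would rather not rely on environment variables, the rendezvous address can also be passed directly through a tcp:// URL when creating the process group. A minimal sketch of that alternative, inside the example() function shown below (127.0.0.1 and 29555 are placeholder values, not from the original post):

# alternative to env://: pass the address and port explicitly
# (replace 127.0.0.1 and 29555 with your own address and a free port)
dist.init_process_group(
    backend="gloo",
    init_method="tcp://127.0.0.1:29555",
    rank=local_rank,
    world_size=world_size,
)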

 

# Single-machine multi-GPU parallel computing example

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


# https://pytorch.org/docs/stable/notes/ddp.html


def example(local_rank, world_size): # local_rank is supplied automatically by mp.spawn
    # create default process group
    dist.init_process_group(backend="gloo", init_method="env://", rank=local_rank, world_size=world_size)
    # create local model
    model = nn.Linear(10, 10).cuda(local_rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # training loop: forward, backward, update
    for i in range(100):
        if local_rank == 0: # print progress from rank 0 only (without this check, every spawned process would print)
            print(i)
        outputs = ddp_model(torch.randn(20, 10).cuda(local_rank))
        labels = torch.randn(20, 10).cuda(local_rank)
        # clear gradients accumulated from the previous iteration
        optimizer.zero_grad()
        # backward pass
        loss_fn(outputs, labels).backward()
        # update parameters
        optimizer.step()
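    # (optional, not in the original example) tear down the process group once training is done
    dist.destroy_process_group()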


def main():
    os.environ['MASTER_ADDR'] = 'xxx.xx.xx.xxx' # the machine's IP address
    os.environ['MASTER_PORT'] = '29555' # a free port
    world_size = torch.cuda.device_count()
    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)



if __name__=="__main__":
    main()
    print('Done!')
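For reference, because init_method="env://" reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from the environment, the same kind of script can also be started with the torchrun launcher, which exports those variables for you instead of using mp.spawn. Below is a minimal sketch of that variant (the file name ddp_torchrun.py and the single forward pass are illustrative placeholders, not part of the original example):

# launch with:  torchrun --standalone --nproc_per_node=<num_gpus> ddp_torchrun.py
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    # torchrun exports LOCAL_RANK, RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT,
    # so init_process_group can pick everything up from the environment
    local_rank = int(os.environ["LOCAL_RANK"])
    dist.init_process_group(backend="gloo", init_method="env://")
    model = nn.Linear(10, 10).cuda(local_rank)
    ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)
    outputs = ddp_model(torch.randn(20, 10).cuda(local_rank))
    if local_rank == 0:
        print(outputs.shape)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()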

  

 

