A simple example.

Note:
os.environ['MASTER_ADDR'] = 'xxx.xx.xx.xxx'  # fill in the machine's IP address here
os.environ['MASTER_PORT'] = '29555'          # a free port
These two variables apparently have to be set in advance, because the chosen initialization method is init_method="env://" (the default environment-variable method), which reads the master address and port from the environment.
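For comparison, TCP initialization passes the rendezvous address directly in init_method instead of reading it from the environment, so MASTER_ADDR and MASTER_PORT are not needed in that case. A minimal sketch, where the helper name init_with_tcp and the address 10.1.1.20:23456 are placeholders of my own:

import torch.distributed as dist

def init_with_tcp(rank, world_size):
    # tcp:// initialization: the rendezvous host:port is given explicitly,
    # so MASTER_ADDR / MASTER_PORT do not have to be set in the environment
    dist.init_process_group(backend="gloo",
                            init_method="tcp://10.1.1.20:23456",  # placeholder address
                            rank=rank,
                            world_size=world_size)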
# Single-machine multi-GPU parallel training example
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

# https://pytorch.org/docs/stable/notes/ddp.html


def example(local_rank, world_size):
    # local_rank is supplied automatically by mp.spawn
    # create default process group
    dist.init_process_group(backend="gloo",
                            init_method="env://",
                            rank=local_rank,
                            world_size=world_size)
    # create local model
    model = nn.Linear(10, 10).cuda(local_rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    for i in range(100):
        if local_rank == 0:
            # only rank 0 prints; without this guard the line would be
            # printed once by every spawned process
            print(i)
        # forward pass
        outputs = ddp_model(torch.randn(20, 10).cuda(local_rank))
        labels = torch.randn(20, 10).cuda(local_rank)
        # backward pass
        optimizer.zero_grad()  # clear gradients from the previous iteration
        loss_fn(outputs, labels).backward()
        # update parameters
        optimizer.step()

    dist.destroy_process_group()  # clean up the process group


def main():
    os.environ['MASTER_ADDR'] = 'xxx.xx.xx.xxx'  # fill in the machine's IP address here
    os.environ['MASTER_PORT'] = '29555'          # a free port
    world_size = torch.cuda.device_count()
    mp.spawn(example,
             args=(world_size,),
             nprocs=world_size,
             join=True)


if __name__ == "__main__":
    main()
    print('Done!')
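Running this script on a machine with N visible GPUs makes mp.spawn start N processes, one per GPU; only rank 0 prints the loop index, so the expected output is 0 through 99 followed by 'Done!'. The example keeps the gloo backend from the original; for CUDA tensors the nccl backend is generally the recommended choice in the PyTorch documentation.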