A simple example.
Note:
os.environ['MASTER_ADDR'] = 'xxx.xx.xx.xxx'  # fill in this machine's IP address
os.environ['MASTER_PORT'] = '29555'          # a free port
These two variables apparently must be set before the process group is created, because the initialization method chosen here is init_method="env://" (the default environment-variable method, which reads the master address and port from the environment).
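For illustration only (not part of the original example): if the process group is initialized with an explicit TCP address instead of "env://", the environment variables are not needed. A minimal sketch, reusing the placeholder IP and port from above:

# hypothetical alternative: pass the address directly instead of reading MASTER_ADDR/MASTER_PORT
dist.init_process_group(backend="gloo",
                        init_method="tcp://xxx.xx.xx.xxx:29555",
                        rank=local_rank, world_size=world_size)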
# Single-machine multi-GPU parallel training example
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
# https://pytorch.org/docs/stable/notes/ddp.html

def example(local_rank, world_size):  # local_rank is supplied automatically by mp.spawn
    # create the default process group
    dist.init_process_group(backend="gloo", init_method="env://",
                            rank=local_rank, world_size=world_size)
    # create the local model on this process's GPU
    model = nn.Linear(10, 10).cuda(local_rank)
    # construct the DDP model
    ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    for i in range(100):
        if local_rank == 0:  # print only from rank 0; without this guard each spawned process would print
            print(i)
        optimizer.zero_grad()
        # forward pass
        outputs = ddp_model(torch.randn(20, 10).cuda(local_rank))
        labels = torch.randn(20, 10).cuda(local_rank)
        # backward pass (DDP synchronizes gradients across processes here)
        loss_fn(outputs, labels).backward()
        # update parameters
        optimizer.step()
    dist.destroy_process_group()

def main():
    os.environ['MASTER_ADDR'] = 'xxx.xx.xx.xxx'  # fill in this machine's IP address
    os.environ['MASTER_PORT'] = '29555'          # a free port
    world_size = torch.cuda.device_count()
    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()
    print('Done!')
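As a side note (not from the original example): the same setup is often launched with torchrun instead of mp.spawn. torchrun exports MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE and LOCAL_RANK itself, so init_method="env://" works without setting anything manually. A minimal sketch, assuming the script is saved as ddp_example.py (an illustrative name):

# launch on one machine, one process per GPU:
#   torchrun --standalone --nproc_per_node=<num_gpus> ddp_example.py
import os
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    # torchrun has already set MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE/LOCAL_RANK,
    # so rank and world_size are read from the environment
    dist.init_process_group(backend="gloo", init_method="env://")
    local_rank = int(os.environ["LOCAL_RANK"])
    model = nn.Linear(10, 10).cuda(local_rank)
    ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)
    # ... same training loop as above ...
    dist.destroy_process_group()

if __name__ == "__main__":
    main()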
