1. 导入库:
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
2. 进程初始化:
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
添加必要参数
local_rank:系统自动赋予的进程编号,可以利用该编号控制打印输出以及设置device
torch.distributed.init_process_group(backend="nccl", init_method='file://shared/sharedfile', rank=local_rank, world_size=world_size)
world_size:所创建的进程数,也就是所使用的GPU数量
(初始化设置详见参考文档)
3. 数据分发:
dataset = datasets.ImageFolder(dataPath)
data_sampler = DistributedSampler(dataset, rank=local_rank, num_replicas=world_size)
使用DistributedSampler来为各个进程分发数据,其中num_replicas与world_size保持一致,用于将数据集等分成不重叠的数个子集
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=1, drop_last=True, pin_memory=True, sampler=data_sampler)
在Dataloader中指定sampler时,其中的shuffle必须为False,而DistributedSampler中的shuffle项默认为True,因此训练过程默认执行shuffle
4. 网络模型:
torch.cuda.set_device(local_rank)
device = torch.device('cuda:'+f'{local_rank}')
设置每个进程对应的GPU设备
D = Model()
D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(D).to(device)
由于在训练过程中各卡的前向后向传播均独立进行,因此无法进行统一的批归一化,如果想要将各卡的输出统一进行批归一化,需要将模型中的BN转换成SyncBN
D = torch.nn.parallel.DistributedDataParallel(D, find_unused_parameters=True, device_ids=[local_rank], output_device=local_rank)
如果有forward的返回值如果不在计算loss的计算图里,那么需要find_unused_parameters=True,即返回值不进入backward去算grad,也不需要在不同进程之间进行通信。
5. 迭代:
data_sampler.set_epoch(epoch)
每个epoch需要为sampler设置当前epoch
6. 加载:
dist.barrier()
D.load_state_dict(torch.load('D.pth'), map_location=torch.device('cpu'))
dist.barrier()
加载模型前后用dist.barrier()来同步不同进程间的快慢
7. 启动:
CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.launch --nproc_per_node=2 train.py --epochs 15000 --batchsize 10 --world_size 2
用-m torch.distributed.launch启动,nproc_per_node为所使用的卡数,batchsize设置为每张卡各自的批大小
Reference:
https://github.com/GoldenRaven/Pytorch_DistributedParallel_GPU_test
https://www.cnblogs.com/yh-blog/p/12877922.html
https://zhuanlan.zhihu.com/p/86441879
https://zhuanlan.zhihu.com/p/98535650
https://blog.csdn.net/lgzlgz3102/article/details/107054314
https://blog.csdn.net/baidu_19518247/article/details/89635181
https://blog.csdn.net/m0_38008956/article/details/86559432?utm_source=blogxgwz4