使用pytorch的DistributedParallel進行單機多卡訓練


1. 導入庫:

import torch.distributed as dist

from torch.utils.data.distributed import DistributedSampler

2. 進程初始化:

parser = argparse.ArgumentParser()

parser.add_argument('--local_rank', type=int, default=-1)

添加必要參數
local_rank:系統自動賦予的進程編號,可以利用該編號控制打印輸出以及設置device

torch.distributed.init_process_group(backend="nccl", init_method='file://shared/sharedfile', rank=local_rank, world_size=world_size)

world_size:所創建的進程數,也就是所使用的GPU數量

(初始化設置詳見參考文檔)

3. 數據分發:

dataset = datasets.ImageFolder(dataPath)

data_sampler = DistributedSampler(dataset, rank=local_rank, num_replicas=world_size)

使用DistributedSampler來為各個進程分發數據,其中num_replicas與world_size保持一致,用於將數據集等分成不重疊的數個子集

dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=1, drop_last=True, pin_memory=True, sampler=data_sampler)

在Dataloader中指定sampler時,其中的shuffle必須為False,而DistributedSampler中的shuffle項默認為True,因此訓練過程默認執行shuffle

4. 網絡模型:

torch.cuda.set_device(local_rank)

device = torch.device('cuda:'+f'{local_rank}')

設置每個進程對應的GPU設備

D = Model()

D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(D).to(device)

由於在訓練過程中各卡的前向后向傳播均獨立進行,因此無法進行統一的批歸一化,如果想要將各卡的輸出統一進行批歸一化,需要將模型中的BN轉換成SyncBN

D = torch.nn.parallel.DistributedDataParallel(D, find_unused_parameters=True, device_ids=[local_rank], output_device=local_rank)

如果有forward的返回值如果不在計算loss的計算圖里,那么需要find_unused_parameters=True,即返回值不進入backward去算grad,也不需要在不同進程之間進行通信。

5. 迭代:

data_sampler.set_epoch(epoch)

每個epoch需要為sampler設置當前epoch

6. 加載:

dist.barrier()

D.load_state_dict(torch.load('D.pth'), map_location=torch.device('cpu'))

dist.barrier()

加載模型前后用dist.barrier()來同步不同進程間的快慢

7. 啟動:

CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.launch --nproc_per_node=2 train.py --epochs 15000 --batchsize 10 --world_size 2

用-m torch.distributed.launch啟動,nproc_per_node為所使用的卡數,batchsize設置為每張卡各自的批大小

Reference:

https://github.com/GoldenRaven/Pytorch_DistributedParallel_GPU_test

https://www.cnblogs.com/yh-blog/p/12877922.html

https://zhuanlan.zhihu.com/p/86441879

https://zhuanlan.zhihu.com/p/98535650

https://blog.csdn.net/lgzlgz3102/article/details/107054314

https://blog.csdn.net/baidu_19518247/article/details/89635181

https://blog.csdn.net/m0_38008956/article/details/86559432?utm_source=blogxgwz4


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM