Pytorch之Dataparallel源碼解析


之前對Pytorch 1.0 的Dataparallel的使用方法一直似懂非懂,總是會碰到各種莫名其妙的問題,今天就好好從源頭梳理一下,更好地理解它的原理或者說說下步驟。

源碼地址: https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/data_parallel.py

初始化

首先我們一行一行地來看一下Dataparallel是如何初始化的。

  • super就是繼承torch.nn.Module父類,這里不做解釋
  • 第一個if判斷語句:檢查是否有可用GPU
  • 第二個if判斷語句:如果沒有指定GPU,則默認使用所有可用的GPU
  • 第三個if判斷語句:output_device表示輸出到哪一個GPU上,默認是第一個GPU,注意這個第一個device_ids列表上的第一個,所以如果你有三個GPU,而你在將model復制到cuda上時寫的代碼是model.cuda(1)或者model.cuda(2),則會報錯,因為device_ids是[0,1,2].其第一個元素是0。這一點可以在后面的forward函數中看到。
  • emm,后面每行代碼的作用很清楚,就不再一一解釋了。
def __init__(self, module, device_ids=None, output_device=None, dim=0):
	super(DataParallel, self).__init__()

	if not torch.cuda.is_available():
		self.module = module
		self.device_ids = []
		return

	if device_ids is None:
		device_ids = list(range(torch.cuda.device_count()))
	if output_device is None:
		output_device = device_ids[0]

	self.dim = dim
	self.module = module
	self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
	self.output_device = _get_device_index(output_device, True)
	self.src_device_obj = torch.device("cuda:{}".format(self.device_ids[0]))

	_check_balance(self.device_ids)

	if len(self.device_ids) == 1:
		self.module.cuda(device_ids[0])

前向傳播

下面進入到重頭戲:Dataparallel的forward函數。

def forward(self, *inputs, **kwargs):
	if not self.device_ids:
		return self.module(*inputs, **kwargs)

	for t in chain(self.module.parameters(), self.module.buffers()):
		if t.device != self.src_device_obj:
			raise RuntimeError("module must have its parameters and buffers "
							   "on device {} (device_ids[0]) but found one of "
							   "them on device: {}".format(self.src_device_obj, t.device))

	inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
	if len(self.device_ids) == 1:
		return self.module(*inputs[0], **kwargs[0])
	replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
	outputs = self.parallel_apply(replicas, inputs, kwargs)
	return self.gather(outputs, self.output_device)
  • 第一個if判斷語句:如果沒有可用的GPU設備,則使用原來的module進行計算。
  • for循環就是對應了前面提到的問題,用於檢查model和input是不是放在第一個GPU上
  • 之后下一步就是將將input平均划分到每個GPU上,用到的是下面的scatter函數
def scatter(inputs, target_gpus, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        return [obj for targets in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        res = scatter_map(inputs)
    finally:
        scatter_map = None
    return res
  • 數據划分之后呢,再判斷一下有幾個可用的GPU(前面是判斷有沒有,這里是判斷有幾個),如果只有一個GPU,那就不用進入到下一步了。
  • 如果有多個GPU,那么就需要用到replica函數,這個函數比較復雜,就不解釋了,感興趣的可以閱讀一下源碼:https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/replicate.py 。不過它的主要作用就是將模型復制到多個GPU上。
  • 下一步中的parallel_apply作用就是並行地在多個GPU上計算模型,每個模型是一樣的,只不過輸入數據是不一樣的,因為前面將數據平均划分了。例如你有兩個GPU,一個batch大小是64,那么兩個GPU分別處理batch大小為32的數據。
  • 最后就是將輸出值gather到一起,傳送到output_device,即第一個GPU設備上。


微信公眾號:AutoML機器學習
MARSGGBO原創
如有意合作或學術討論歡迎私戳聯系~
郵箱:marsggbo@foxmail.com





2019-6-2




免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM