TensorFlow Distribution (Data Reading and Training in Distributed Mode)


Purpose of This Article

When it comes to distributed training with Estimator, the official documentation has become inconsistent with the actual interface because of version changes. Specifically: when a dataset is used as the input to distributed Estimator training, in version 1.12 the training consumes exactly the dataset's data, i.e. all devices together make a single pass over the data.

In version 2.0, however, the training data is the dataset's data multiplied by the number of devices: each device makes a full pass over the entire dataset.
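To make the scale of the difference concrete, here is a minimal sketch of the accounting (plain Python; the dataset size and device count are made-up numbers for illustration, not values from the documentation):

dataset_size = 8   # total examples produced by the input_fn's dataset
num_devices = 2    # e.g. two GPUs under MirroredStrategy

# 1.12 behaviour: the dataset is split across devices, so together the
# devices make exactly one pass over the data.
examples_per_device_v1 = dataset_size // num_devices      # 4
total_examples_v1 = examples_per_device_v1 * num_devices  # 8, one pass in total

# 2.0 behaviour (as described above): every device runs the full dataset,
# so the total work is multiplied by the number of devices.
examples_per_device_v2 = dataset_size                     # 8
total_examples_v2 = dataset_size * num_devices            # 16, one pass per device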

Data Reading in Version 1.12

1. Creating the graph in the main thread

In the code below, the client calls the input function and obtains an iterator. This is the code invoked by the Estimator's distributed train path:

with ops.Graph().as_default() as g:
  # We want to create the iterations variable outside the distribution scope
  # as that is just stored on the host and mainly used to drive the loop
  # and doesn't need to be a Mirrored/Device variable.
  if is_tpu_strategy:
    steps_per_run_variable = training.get_or_create_steps_per_run_variable()
  with self._train_distribution.scope():
    random_seed.set_random_seed(self._config.tf_random_seed)
    iterator, input_hooks = self._get_iterator_from_input_fn(
        input_fn, model_fn_lib.ModeKeys.TRAIN, self._train_distribution)
_get_iterator_from_input_fn: this function produces the iterator from which the subsequent training steps read data.

  def _get_iterator_from_input_fn(self, input_fn, mode, distribution=None):
    if distribution is not None:
      result = distribution.distribute_dataset(
          lambda: self._call_input_fn(input_fn, mode))
    else:
      result = self._call_input_fn(input_fn, mode)

    iterator = result.make_initializable_iterator()
    input_hooks = [estimator_util._DatasetInitializerHook(iterator)]  # pylint: disable=protected-access
    return iterator, input_hooks

Here distribute_dataset is called to create the dataset.
Stepping further into it, we can see that a PerDeviceDataset is created:

class PerDeviceDataset(object):
  """Like `tf.data.Dataset` split devices, producing `PerDevice` data."""

  def __init__(self, dataset, devices, prefetch_on_device=None):
    self._devices = devices

    # Default to using prefetching in graph mode, unless specified.
    # TODO(priyag): Enable prefetching in eager mode.
    self._prefetch_on_device = prefetch_on_device
    if self._prefetch_on_device is None:
      self._prefetch_on_device = not context.executing_eagerly()
    assert not (self._prefetch_on_device and context.executing_eagerly()), (
        "Prefetching is only supported in graph mode currently")

    if self._prefetch_on_device:
      self._dataset = dataset.apply(
          prefetching_ops_v2.prefetch_to_devices(self._devices))
    else:
      # TODO(priyag): If dropping remainder is not appropriate, find another
      # approach to distributing the dataset when not possible to divide evenly.
      # Possibly not an issue when we start using PartitionedDataset.
      self._dataset = dataset.batch(len(devices), drop_remainder=True)

From the last line we can see that another layer of batching is wrapped around the original dataset, splitting the data according to the number of devices.
The iterator created afterwards is likewise wrapped as a PerDeviceDataIterator, forming a dictionary mapping from device to data: each device gets a different piece of data, selected by its index within the batch.
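The splitting can be sketched with plain tf.data, without any DistributionStrategy; the device names and the dictionary below are assumptions for illustration, mimicking what PerDeviceDataset / PerDeviceDataIterator do (1.12-era session API):

import tensorflow as tf  # sketched against the 1.12 API

devices = ["/device:GPU:0", "/device:GPU:1"]  # assumed device list

# Each element of this dataset stands for one per-device batch produced
# by the user's input_fn.
dataset = tf.data.Dataset.range(8)

# Same trick as the last line of PerDeviceDataset.__init__: group
# len(devices) consecutive elements so each group can be split by index.
dataset = dataset.batch(len(devices), drop_remainder=True)

iterator = dataset.make_initializable_iterator()
next_group = iterator.get_next()

# Simplified stand-in for PerDeviceDataIterator: map each device to its
# slice of the group, selected by position.
per_device = {d: next_group[i] for i, d in enumerate(devices)}

with tf.Session() as sess:
  sess.run(iterator.initializer)
  print(sess.run(per_device))  # e.g. {'/device:GPU:0': 0, '/device:GPU:1': 1}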

Distributed Training

Training in version 1.12 is fairly straightforward. For MirroredStrategy, one thread is created per device. A drawback is that the threads are created anew on every run; a TODO in the code indicates this should be optimized away later.
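For reference, this code path is reached when the Estimator is configured with a distribution strategy. Below is a hedged sketch of such a setup against the 1.12-era contrib API (the model_fn and input_fn are placeholders, and two GPUs are assumed to be available):

import tensorflow as tf  # sketched against the TF 1.12 API

def input_fn():
  # Toy dataset of (features, label) pairs; each batch goes to one tower.
  dataset = tf.data.Dataset.range(100).map(lambda x: ({"x": x}, x))
  return dataset.batch(4)

def model_fn(features, labels, mode):
  # Minimal linear model so the optimizer has a variable to update.
  w = tf.get_variable("w", [], initializer=tf.zeros_initializer())
  preds = w * tf.to_float(features["x"])
  loss = tf.reduce_mean(tf.square(preds - tf.to_float(labels)))
  train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
      loss, global_step=tf.train.get_or_create_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

# In 1.12 MirroredStrategy lives under tf.contrib.distribute; passing it as
# train_distribute is what routes Estimator.train through the code above.
strategy = tf.contrib.distribute.MirroredStrategy(num_gpus=2)
config = tf.estimator.RunConfig(train_distribute=strategy)

estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)
estimator.train(input_fn=input_fn, steps=10)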

Below is the client code that pulls data from the iterator and hands it to each device for computation via self._train_distribution.call_for_each_tower:

features, labels = estimator_util.parse_iterator_result(
    iterator.get_next())
grouped_estimator_spec = self._train_distribution.call_for_each_tower(
    self._call_model_fn,
    features,
    labels,  # although this will be None it seems
    model_fn_lib.ModeKeys.TRAIN,
    self.config)
loss = self._train_distribution.unwrap(
    self._train_distribution.reduce(
        distribute_lib.get_loss_reduction(),
        grouped_estimator_spec.loss,
        destinations='/device:CPU:0'))[0]
distributed_train_op = grouped_estimator_spec.train_op

call_for_each_tower is the interface through which each device runs its training:

def _call_for_each_tower(distribution, fn, *args, **kwargs):
  """Run `fn` in separate threads, once per tower/worker device.
  run_concurrently = kwargs.pop("run_concurrently", True)
  if not context.executing_eagerly():
    # Lots of TF library code isn't thread-safe in graph mode, and
    # there is little to be gained by turning on multithreading when
    # constructing a graph.
    run_concurrently = False
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()
  elif run_concurrently is None:
    run_concurrently = True

  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))

  shared_variable_store = {}

  # TODO(isaprykin): Create these threads once instead of during every run()
  # call.
  threads = []
  for index, d in enumerate(distribution.worker_devices):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    t = MirroredStrategy._MirroredTowerThread(  # pylint: disable=protected-access
        distribution, coord, d, variable_creator_fn, fn,
        *values.select_device(d, args), **values.select_device(d, kwargs))
    threads.append(t)

  for t in threads:
    t.start()

Here, select_device simply looks up the value stored under the corresponding device key. This completes the distributed training flow.
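A minimal sketch of that lookup (plain Python, not the real values.select_device, which also handles nested structures and PerDevice wrappers):

def select_device(device, value):
  # Per-device ("per-tower") values are kept in a device-keyed mapping;
  # each tower thread picks the entry for the device it runs on, while
  # plain (non-per-device) values pass through unchanged.
  if isinstance(value, dict) and device in value:
    return value[device]
  return value

per_device_features = {
    "/device:GPU:0": {"x": [0, 2, 4, 6]},
    "/device:GPU:1": {"x": [1, 3, 5, 7]},
}

# Each _MirroredTowerThread effectively does this for its own device:
for d in ("/device:GPU:0", "/device:GPU:1"):
  print(d, select_device(d, per_device_features))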

