[Source Code Analysis] TensorFlow Distributed Variables
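In TensorFlow, a distributed variable is a variable created on multiple devices. Mirrored variables and SyncOnRead variables are two examples. This article analyzes distributed variables, guided by a series of questions:
- How does variable creation end up calling into the Strategy?
- How is a MirroredVariable generated?
- How are tensors distributed to the individual devices?
- How is a single, unified view presented to the outside?
- How are the copies kept consistent with one another?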
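As before, I recommend the work of two experts:
[TensorFlow Internals] (https://github.com/horance-liu/tensorflow-internals): although it does not analyze the latest code, I recommend it to anyone interested in how TF works internally. You will certainly learn a great deal.
https://home.cnblogs.com/u/deep-learning-stacks/ 西門宇少: his writing is not limited to TensorFlow. His public account also covers many other fields and the industry frontier.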
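Other articles in this series:
[Translation] TensorFlow Distributed: the paper "Implementation of Control Flow in TensorFlow"
[Source Code Analysis] TensorFlow Distributed Environment (1) --- Overall Architecture
[Source Code Analysis] TensorFlow Distributed Environment (2) --- Master Static Logic
[Source Code Analysis] TensorFlow Distributed Environment (3) --- Worker Static Logic
[Source Code Analysis] TensorFlow Distributed Environment (4) --- WorkerCache
[Source Code Analysis] TensorFlow Distributed Environment (5) --- Session
[Source Code Analysis] TensorFlow Distributed Environment (7) --- Worker Dynamic Logic
[Source Code Analysis] TensorFlow Distributed Environment (8) --- Communication Mechanism
[Source Code Analysis] TensorFlow Distributed DistributedStrategy: Basics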
1. MirroredVariable
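tf.distribute.MirroredStrategy supports synchronous distributed training on multiple GPUs in a single machine. The strategy creates one replica per GPU device. Every variable in the model is mirrored across all replicas. Together, these copies form a single conceptual variable called a MirroredVariable. The copies stay in sync with one another because the same updates are applied to each of them.
Figure 1 MirroredVariable
A usage example: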
import tensorflow as tf
# Requires two visible GPUs ("GPU:0" and "GPU:1").
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
# Variable created inside scope:
with strategy.scope():
  mirrored_variable = tf.Variable(1.)
# Variable created outside scope:
regular_variable = tf.Variable(1.)
The printed output is:
>>> mirrored_variable
MirroredVariable:{
  0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
  1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
}
>>> regular_variable
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
Alternatively, see the example in tensorflow/python/module/module_test.py:
def test_supports_distributed_variables(self):
  mirrored = distributed_values.MirroredVariable(
      None, [variables.Variable(1.)], variables.VariableAggregation.SUM)
  tpu = tpu_values.TPUMirroredVariable(
      strategy=None, values=[variables.Variable(42.)], aggregation=None)
  aggregating = ps_values.AggregatingVariable(
      strategy=None, v=variables.Variable(1.), aggregation=None)
  m = module.Module()
  m.a = mirrored
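1.1 Definition
The MirroredVariable docstring states its purpose: it holds a map from replica to variables whose values are kept in sync. The class adds no new member variables. It only implements a few member functions.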
class MirroredVariable(DistributedVariable, Mirrored):
  """Holds a map from replica to variables whose values are kept in sync."""
  def _update_replica(self, update_fn, value, **kwargs):
    return _on_write_update_replica(self, update_fn, value, **kwargs)
  def scatter_min(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(*args, **kwargs)
    return super(MirroredVariable, self).scatter_min(*args, **kwargs)
  def scatter_max(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(*args, **kwargs)
    return super(MirroredVariable, self).scatter_max(*args, **kwargs)
  def scatter_update(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():  # non-distributed (saving) case
      # operate directly on the primary (local) component
      return self._primary.scatter_update(*args, **kwargs)
    # otherwise, take the distributed path
    return super(MirroredVariable, self).scatter_update(*args, **kwargs)
  def _get_cross_replica(self):
    # Return identity, to avoid directly exposing the variable to the user and
    # allowing it to be modified by mistake.
    return array_ops.identity(Mirrored._get_cross_replica(self))
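Take scatter_update as an example. When saving in non-distributed mode (values_util.is_saving_non_distributed()), it calls _primary directly. Otherwise it calls the base-class method.

During an update, _update_replica calls _on_write_update_replica to keep the replicas synchronized, and _on_write_update_replica chooses how to update based on the current context:
- If the aggregation is NONE, it updates only the local copy.
- Otherwise, it aggregates the per-replica values, either directly in replica context or through merge_call in cross-replica context, and then applies the update to all copies.

It is defined in tensorflow/python/distribute/values.py.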
def _on_write_update_replica(var, update_fn, value, **kwargs):
  """Updates variables with ON_WRITE synchronization in replica context."""
  if var.aggregation == vs.VariableAggregation.NONE:
    return update_fn(var._get_on_device_or_primary(), value, **kwargs)
  if not ds_context.get_strategy().extended._use_merge_call():
    aggregated_value = apply_aggregation_replica_context(
        value, var.aggregation, var)
    values_util.mark_as_unsaveable()
    return ds_context.get_replica_context()._update(
        var,
        update_fn,
        args=(aggregated_value,),
        kwargs=kwargs,
        group=True)
  else:
    def merge_fn(strategy, value, **kwargs):
      """Aggregate values and update all variables in cross replica context."""
      v = values_util.apply_aggregation(strategy, value, var.aggregation, var)
      return var._update_cross_replica(update_fn, v, **kwargs)
    return ds_context.get_replica_context().merge_call(
        merge_fn, args=(value,), kwargs=kwargs)
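The aggregation step can be pictured with a small pure-Python model that involves no TensorFlow. It is a sketch under stated assumptions:
- `ToyMirroredVariable` and `assign_add_from_replicas` are hypothetical names.
- Each replica contributes one delta, and the deltas are combined according to the aggregation mode.
- The single aggregated update is then applied to every copy.
- The NONE path, where each replica just updates its local copy, is not modeled.

```python
from typing import List


class ToyMirroredVariable:
    """Hypothetical stand-in for a MirroredVariable: one copy per replica."""

    def __init__(self, initial: float, num_replicas: int, aggregation: str):
        self.copies: List[float] = [initial] * num_replicas
        self.aggregation = aggregation

    def _aggregate(self, per_replica_deltas: List[float]) -> float:
        # Roughly what values_util.apply_aggregation does: combine one value
        # per replica into a single value.
        if self.aggregation == "SUM":
            return sum(per_replica_deltas)
        if self.aggregation == "MEAN":
            return sum(per_replica_deltas) / len(per_replica_deltas)
        if self.aggregation == "ONLY_FIRST_REPLICA":
            return per_replica_deltas[0]
        raise ValueError(f"unsupported aggregation: {self.aggregation}")

    def assign_add_from_replicas(self, per_replica_deltas: List[float]) -> None:
        # Every replica must contribute exactly one delta.
        if len(per_replica_deltas) != len(self.copies):
            raise ValueError("need one delta per replica")
        delta = self._aggregate(per_replica_deltas)
        # The same aggregated update goes to every copy, so they stay in sync.
        self.copies = [c + delta for c in self.copies]
```

For example, take a variable that starts at 1.0 with MEAN aggregation and two replicas proposing 0.5 and 1.5. Both copies end at 2.0. Because every copy receives the identical aggregated update, the copies never drift apart.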
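These member methods alone do not give a clear picture of MirroredVariable, so we need to start from its class hierarchy.
1.2 Related Classes
1.2.1 Class Hierarchy
The MirroredVariable class hierarchy is shown below. We analyze each class in turn and then put everything together at the end.
Figure 2 MirroredVariable class hierarchy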
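1.2.2 DistributedValues
Let us first look at DistributedValues.
Figure 3 DistributedValues
Distributed values are represented by the base class tf.distribute.DistributedValues. This concept suits values that live on multiple devices: it contains a map from replica ID to value.
tf.distribute.DistributedValues contains one value per replica. Depending on the subclass, these values may be synced on update, synced on demand, or never synced. A tf.distribute.DistributedValues can be used in three ways:
- reduced to obtain a single value across replicas,
- passed as input into tf.distribute.Strategy.run,
- inspected per replica with tf.distribute.Strategy.experimental_local_results.
DistributedValues is a base class and should never be instantiated directly. Instead, instances of its subclasses are created within a distribution strategy in one of three ways:
- when variables are created inside the strategy,
- when iterating over a tf.distribute.DistributedDataset,
- through tf.distribute.Strategy.run.
Two representative types of tf.distribute.DistributedValues are "PerReplica" and "Mirrored" values.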
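- "PerReplica" values exist on the worker devices, with a different value for each replica. They are produced by iterating over the distributed datasets returned by tf.distribute.Strategy.experimental_distribute_dataset and tf.distribute.Strategy.distribute_datasets_from_function. They are also the typical result returned by tf.distribute.Strategy.run.
- "Mirrored" values are like "PerReplica" values, except that the value is identical on all replicas. A "Mirrored" value can therefore be read safely in cross-replica context by using the value on any replica.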
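Definition
DistributedValues has two important members: the _values member variable and the _primary property. The values passed to the constructor are stored in the _values tuple, and _primary simply returns the first element of that tuple.
Derived classes rely on several member functions of DistributedValues, so we analyze them here:
- _get_on_device_or_primary: in a replica context, returns the value of the current replica. Otherwise, returns the value that lives on the current device, falling back to the _primary value.
- _get_cross_replica: returns the cross-replica value. Derived classes must implement it.
- _get: if there is no current replica ID (cross-replica context), calls _get_cross_replica and returns the cross-replica value. Otherwise, returns the local value of the current replica.
The conceptual diagram is:
Figure 4 DistributedValues
The DistributedValues code is as follows: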
@tf_export("distribute.DistributedValues", v1=[])
class DistributedValues(object):
  """Base class for representing distributed values.
  A subclass instance of tf.distribute.DistributedValues is created when
  creating variables within a distribution strategy, iterating a
  tf.distribute.DistributedDataset or through tf.distribute.Strategy.run.
  This base class should never be instantiated directly.
  tf.distribute.DistributedValues contains a value per replica. Depending on
  the subclass, the values could either be synced on update, synced on demand,
  or never synced.
  tf.distribute.DistributedValues can be reduced to obtain single value across
  replicas, as input into tf.distribute.Strategy.run or the per-replica values
  inspected using tf.distribute.Strategy.experimental_local_results.
  """
  def __init__(self, values):
    """Should only be called by subclass __init__."""
    self._values = tuple(values)
  def _get(self):
    """Returns the value for the current device or raises a ValueError."""
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      return self._get_cross_replica()  # cross-replica context: cross-replica value
    else:
      return self._values[replica_id]  # replica context: this replica's local value
  def _get_cross_replica(self):
    raise NotImplementedError(
        "DistributedValues._get_cross_replica should be implemented by "
        "sub-classes which support cross-replica accesses.")
  def _get_on_device_or_primary(self):
    """Returns value in same replica or device if possible, else the _primary."""
    # get the current replica id
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:  # no replica id: look among the devices on this machine
      # Try to find a value on the current device.
      # current_device is the canonicalized name (a string) of the current device
      current_device = device_util.canonicalize(device_util.current())
      for value in self._values:  # iterate over the components
        if device_util.canonicalize(value.device) == current_device:
          return value  # found the component on the current device
      return self._primary  # otherwise fall back to _primary
    else:
      # return the value belonging to this replica
      return self._values[replica_id]
  @property
  def _primary(self):
    """Returns a representative component."""
    return self._values[0]
  @property
  def _devices(self):
    return tuple(v.device for v in self._values)
The code above makes heavy use of get_current_replica_id_as_int. This function is defined in tensorflow/python/distribute/values_util.py and returns the current replica ID.
def get_current_replica_id_as_int():
  """Returns the current replica ID as an integer, or None."""
  replica_context = ds_context.get_replica_context()
  if replica_context:
    replica_id = replica_context._replica_id
    if not isinstance(replica_id, int):
      replica_id = tensor_util.constant_value(replica_id)
  else:
    replica_id = distribute_lib.get_update_replica_id()
  return replica_id
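The dispatch in _get can be modeled without TensorFlow. In this sketch:
- A hypothetical module-level `_current_replica_id` stands in for what get_current_replica_id_as_int() would return. None means cross-replica context.
- The hypothetical `toy_replica_context` temporarily sets that ID.
- `ToyDistributedValues` keeps only the _values, _primary, _get and _get_cross_replica logic. The device lookup of _get_on_device_or_primary is omitted.

```python
import contextlib
from typing import Any, Optional, Sequence

# Hypothetical stand-in for the replica id; None means cross-replica context.
_current_replica_id: Optional[int] = None


@contextlib.contextmanager
def toy_replica_context(replica_id: int):
    """Pretend the enclosed code runs inside replica `replica_id`."""
    global _current_replica_id
    previous = _current_replica_id
    _current_replica_id = replica_id
    try:
        yield
    finally:
        _current_replica_id = previous


class ToyDistributedValues:
    def __init__(self, values: Sequence[Any]):
        self._values = tuple(values)

    @property
    def _primary(self):
        # Like DistributedValues._primary: the first component.
        return self._values[0]

    def _get_cross_replica(self):
        # Left to subclasses, as in DistributedValues.
        raise NotImplementedError("cross-replica access not supported")

    def _get(self):
        if _current_replica_id is None:
            return self._get_cross_replica()       # cross-replica context
        return self._values[_current_replica_id]   # this replica's value
```

Inside `toy_replica_context(1)`, `_get()` returns the second component. Outside any replica context, the base class raises because it defines no cross-replica value.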
Usage
Below are some usage examples taken from the source. All of them use MirroredStrategy to obtain DistributedValues.
import tensorflow as tf
# All examples require two visible GPUs ("GPU:0" and "GPU:1").
# 1. Created from a tf.distribute.DistributedDataset:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_values = next(dataset_iterator)
# 2. Returned by run:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
@tf.function
def run():
  ctx = tf.distribute.get_replica_context()
  return ctx.replica_id_in_sync_group
distributed_values = strategy.run(run)
# 3. As input into run:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_values = next(dataset_iterator)
@tf.function
def run(input):
  return input + 1.0
updated_value = strategy.run(run, args=(distributed_values,))
# 4. Reduce value:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_values = next(dataset_iterator)
reduced_value = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                distributed_values,
                                axis=0)
# 5. Inspect local replica values:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_values = next(dataset_iterator)
per_replica_values = strategy.experimental_local_results(distributed_values)
print(per_replica_values)
# Output:
# (<tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>,
#  <tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.], dtype=float32)>)
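1.2.3 DistributedDelegate
Next, let us look at DistributedDelegate.
Figure 5 DistributedDelegate
DistributedDelegate adds computation on top of DistributedValues, so that a distributed value behaves like its underlying value type:
- Operators such as __add__ call _get_as_operand, which calls the base-class DistributedValues._get to obtain the value for the current context and then computes on that value.
- Attribute lookups that fail on the object itself are forwarded to self._get() through __getattr__.
Figure 6 How computation works
DistributedDelegate is defined as follows (part of the code is omitted).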
class DistributedDelegate(DistributedValues):
  """A map from device to values; acts as the same type as the values."""
  def __getattr__(self, name):
    # The '_use_resource_variables' and the attrs starts with '_self' are used
    # for restoring the saved_model proto, and '_attribute_sentinel' is used for
    # Layer tracking. At the point these attrs are queried, the variable has not
    # been initialized. Thus it should not query those of the underlying
    # components.
    if name.startswith("_self_") or name in ("_use_resource_variables",
                                             "_attribute_sentinel",
                                             "_distributed_container"):
      return super(DistributedDelegate, self).__getattr__(name)
    # This allows copy.copy(DistributedDelegate). When copying an object,
    # copy.copy doesn't invoke its __init__ method, instead it makes a new
    # empty object, then copies the attributes over. copy.copy looks for
    # attributes like "__getstate__" in case the object implements its custom
    # copying. Since DistributedDelegate doesn't have those attributes defined,
    # __getattr__ will be invoked, which tries to access "_values" attributes,
    # but that doesn't exist either because this is an empty object, and again
    # __getattr__ is invoked, leading to an infinite recursion.
    if name == "_values":
      raise AttributeError()
    # TODO(priyag): This needs to be made robust against pitfalls from mix use
    # __getattr__ and @property. See b/120402273.
    return getattr(self._get(), name)
  @property
  def values(self):
    """Returns the per replica values."""
    return self._values
  def _get_as_operand(self):
    """Returns the value for operations for the current device.
    Some implementations, e.g. TPUMirroredVariable, are not able to return the
    value type within a replica context. They can, however, return a value that
    can be used by the operations below.
    """
    return self._get()
  def __add__(self, o):
    return self._get_as_operand() + o
  def __radd__(self, o):
    return o + self._get_as_operand()
  def __sub__(self, o):
    return self._get_as_operand() - o
  def __rsub__(self, o):
    return o - self._get_as_operand()
  # Most of the code is omitted
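The delegation pattern is easy to reproduce in plain Python. The hypothetical `ToyDelegate` below behaves like DistributedDelegate in three ways:
- Arithmetic is forwarded through _get_as_operand.
- Unknown attributes are forwarded through __getattr__.
- `_values` is never resolved through __getattr__, so copy.copy() on a not-yet-initialized instance cannot recurse forever.

A fixed `current` index stands in for the replica context.

```python
import copy
from typing import Any, Sequence


class ToyDelegate:
    """Hypothetical sketch of DistributedDelegate-style forwarding."""

    def __init__(self, values: Sequence[Any], current: int = 0):
        self._values = tuple(values)
        self._current = current  # stands in for the current replica id

    def _get(self):
        return self._values[self._current]

    def __getattr__(self, name):
        # Only called when normal lookup fails. Refusing "_values" here stops
        # the endless recursion on an empty copy made by copy.copy().
        if name == "_values":
            raise AttributeError(name)
        return getattr(self._get(), name)

    def _get_as_operand(self):
        return self._get()

    def __add__(self, o):
        return self._get_as_operand() + o

    def __radd__(self, o):
        return o + self._get_as_operand()


d = ToyDelegate([10.0, 20.0], current=1)
print(d + 1, 1 + d)      # 21.0 21.0
print(d.is_integer())    # True: forwarded to the float 20.0
print(copy.copy(d) + 0)  # 20.0: copying works without infinite recursion
```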
1.2.4 PerReplica
PerReplica holds a map from replica to unsynchronized values.
class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
  """Holds a map from replica to unsynchronized values."""
  @property
  def _type_spec(self):
    return PerReplicaSpec(
        *(type_spec.type_spec_from_value(v) for v in self._values))
  @property
  def values(self):
    """Returns the per replica values."""
    return self._values
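1.2.5 Mirrored
Next we come to Mirrored.
Figure 7 Mirrored
Mirrored represents variables created on multiple devices and keeps them in sync by applying the same update to every replica. Mirrored variables are created with tf.Variable(..., synchronization=tf.VariableSynchronization.ON_WRITE, ...). They are typically used only for synchronous training.
Recall that DistributedValues holds a map from replica to values but does not implement _get_cross_replica. Mirrored, by contrast, is meant to be usable directly in cross-replica mode, so it implements _get_cross_replica. The implementation simply calls the base-class DistributedValues._get_on_device_or_primary method (see the corresponding subsection), which returns the value of the current replica or device, or else the _primary value. Because all copies hold the same value, any one of them is a valid cross-replica result.
The conceptual diagram is:
Figure 8 How Mirrored computes
Mirrored is defined as follows: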
# Note that unlike PerReplica, Mirrored values inherit from
# DistributedDelegate and so can be used directly in cross-replica mode.
class Mirrored(DistributedDelegate):
  """Holds a map from replica to values which are kept in sync."""
  def _get_cross_replica(self):
    return self._get_on_device_or_primary()  # base-class DistributedValues method
  def _as_graph_element(self):
    obj = self._get()  # base-class DistributedValues method
    conv_fn = getattr(obj, "_as_graph_element", None)
    if conv_fn and callable(conv_fn):
      return conv_fn()
    return obj
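The practical difference between PerReplica and Mirrored shows up on a cross-replica read. In this hypothetical sketch, `ToyPerReplica` has no meaningful single value and raises, while `ToyMirrored` returns its primary copy. The real Mirrored uses _get_on_device_or_primary, which may prefer the copy on the current device. That device lookup is simplified away here.

```python
from typing import Any, Sequence


class ToyValues:
    def __init__(self, values: Sequence[Any]):
        self._values = tuple(values)

    @property
    def _primary(self):
        return self._values[0]

    def _get_cross_replica(self):
        raise NotImplementedError("no single cross-replica value")


class ToyPerReplica(ToyValues):
    """Unsynchronized values: replicas may disagree, so no cross-replica read."""


class ToyMirrored(ToyValues):
    """All copies are equal, so any copy (here the primary) is a valid read."""

    def _get_cross_replica(self):
        return self._primary
```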
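1.2.6 Policy
Next, let us look at the distribution policies.
Figure 9 Distribution policies
VariablePolicy
VariablePolicy is the base class of the policies and defines how a distributed variable is synchronized and aggregated. When a variable is created within tf.distribute scope, tf.distribute uses the synchronization and aggregation parameters set on the tf.Variable to create an appropriate policy object, and assigns that object to the distributed variable. All variable operations are then delegated to the corresponding policy object.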
class VariablePolicy(object):
  """Policy defining synchronization and aggregation of a distributed variable.
  Given synchronization and aggregation parameters set on a tf.Variable
  during variable creation within tf.distribute scope, tf.distribute creates
  an appropriate policy object and assigns it to the distributed variable. All
  variable operations are delegated to the respective policy object.
  """
  def __init__(self, aggregation):
    self._aggregation = aggregation
  def value(self):
    raise NotImplementedError(
        "VariablePolicy.value should be overriden by sub-classes.")
  def _is_mirrored(self):
    raise NotImplementedError(
        "VariablePolicy._is_mirrored should be overriden by sub-classes.")
  def _as_graph_element(self, _):
    raise NotImplementedError(
        "VariablePolicy._as_graph_element should be overriden by sub-classes.")
  def _get_cross_replica(self, var):
    raise NotImplementedError(
        "VariablePolicy._get_cross_replica should be overriden by sub-classes.")
  def _update_replica(self, var, update_fn, value, **kwargs):
    raise NotImplementedError(
        "VariablePolicy._update_replica should be overriden by sub-classes.")
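The "create a policy from the synchronization setting, then delegate" idea can be sketched as follows. The class names and string keys are hypothetical. The mapping of ON_WRITE and AUTO to an on-write policy, and of ON_READ to an on-read policy, follows the policy docstrings quoted in this section. The on-read policy sums, for brevity.

```python
class ToyOnWritePolicy:
    def read_cross_replica(self, values):
        # ON_WRITE: copies are kept identical, so the primary copy suffices.
        return values[0]


class ToyOnReadPolicy:
    def read_cross_replica(self, values):
        # ON_READ: copies may differ; aggregate at read time (SUM here).
        return sum(values)


# Hypothetical mapping from synchronization setting to policy class.
_POLICY_FOR_SYNC = {
    "AUTO": ToyOnWritePolicy,
    "ON_WRITE": ToyOnWritePolicy,
    "ON_READ": ToyOnReadPolicy,
}


class ToyDistributedVariable:
    def __init__(self, values, synchronization="AUTO"):
        self._values = list(values)
        # The policy object is chosen once, at creation time.
        self._policy = _POLICY_FOR_SYNC[synchronization]()

    def read_cross_replica(self):
        # All behavior is delegated to the policy object.
        return self._policy.read_cross_replica(self._values)
```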
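OnReadPolicy
OnReadPolicy is the read (ON_READ) policy. For example, its member method _get_cross_replica calls var.distribute_strategy.reduce to perform the read. The exception is ONLY_FIRST_REPLICA aggregation, which simply reads replica 0.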
class OnReadPolicy(VariablePolicy):
  """Policy defined for tf.VariableSynchronization.ON_READ synchronization.
  This policy is created when synchronization is set to
  tf.VariableSynchronization.ON_READ and aggregation is set to any of the
  values allowed by the tf.VariableAggregation enum such as NONE, SUM,
  MEAN or ONLY_FIRST_REPLICA when creating a tf.Variable in tf.distribute
  scope.
  """
  def _is_mirrored(self):
    return False
  def value(self, var):
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
          return var._get_replica(0).value()
        return var._get_cross_replica()
      else:
        return var._get_on_device_or_primary().value()
  def _as_graph_element(self, var):
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if ds_context.in_cross_replica_context():
        return ops.convert_to_tensor(var._get_cross_replica())
    return var._get()._as_graph_element()
  def _get_cross_replica(self, var):
    if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
      return var._get_replica(0)  # read from the first replica
    if self._aggregation == vs.VariableAggregation.SUM:
      values_util.mark_as_unsaveable()  # mark as not saveable
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      # let distribute_strategy perform the reduction
      return var.distribute_strategy.reduce(
          reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
          var,
          axis=None)
  def _update_replica(self, var, update_fn, value, **kwargs):
    return update_fn(var._get_on_device_or_primary(), value, **kwargs)
  def assign_add(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    """Adds a value to this variable."""
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_add_cross_replica(
            var, value, read_value=read_value)
      else:
        return values_util.on_write_assign_add(
            var,
            value,
            use_locking=use_locking,
            name=name,
            read_value=read_value)
  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_cross_replica(
            var, value, read_value=read_value)
      else:
        return values_util.on_write_assign(
            var,
            value,
            use_locking=use_locking,
            name=name,
            read_value=read_value)
  # Most of the code is omitted
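The cross-replica read of OnReadPolicy boils down to choosing a reduction from the aggregation mode. Below is a minimal sketch on plain floats, with a hypothetical `Aggregation` enum standing in for tf.VariableAggregation:
- ONLY_FIRST_REPLICA reads replica 0.
- SUM and MEAN reduce over all replicas.
- NONE raises, which is a choice of this sketch.

```python
from enum import Enum
from typing import Sequence


class Aggregation(Enum):  # hypothetical mirror of tf.VariableAggregation
    NONE = "none"
    SUM = "sum"
    MEAN = "mean"
    ONLY_FIRST_REPLICA = "only_first_replica"


def on_read_cross_replica(values: Sequence[float],
                          aggregation: Aggregation) -> float:
    """Sketch of OnReadPolicy._get_cross_replica on plain floats."""
    if aggregation is Aggregation.ONLY_FIRST_REPLICA:
        return values[0]                   # like var._get_replica(0)
    if aggregation is Aggregation.SUM:
        return sum(values)                 # reduce with SUM
    if aggregation is Aggregation.MEAN:
        return sum(values) / len(values)   # reduce with MEAN
    raise ValueError("this sketch has no reduction for NONE")
```

With per-replica values 1.0, 2.0 and 3.0, the three modes give 1.0, 6.0 and 2.0 respectively.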
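OnWritePolicy
The OnWritePolicy class implements the write (ON_WRITE) policy. It mostly relies on var._get_on_device_or_primary() to carry out its operations. For example, _get_cross_replica returns an identity of var._get_on_device_or_primary(). It also uses various basic operations from values_util.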
class OnWritePolicy(VariablePolicy):
  """Policy defined for tf.VariableSynchronization.ON_WRITE synchronization.
  This policy is created when the following synchronization and aggregation
  parameters are specified when creating a tf.Variable in tf.distribute
  scope and synchronization is equal to tf.VariableSynchronization.ON_WRITE
  or tf.VariableSynchronization.AUTO.
  """
  def _is_mirrored(self):
    return True
  def value(self, var):
    return var._get_on_device_or_primary().value()
  def _as_graph_element(self, var):
    return var._get_on_device_or_primary()._as_graph_element()
  def _get_cross_replica(self, var):
    # Return identity, to avoid directly exposing the variable to the user and
    # allowing it to be modified by mistake.
    return array_ops.identity(var._get_on_device_or_primary())
  # uses update_fn or _on_write_update_replica to perform the update
  def _update_replica(self, var, update_fn, value, **kwargs):
    if var.aggregation == variables_lib.VariableAggregation.NONE:
      return update_fn(var._get_on_device_or_primary(), value, **kwargs)
    return _on_write_update_replica(var, update_fn, value, **kwargs)
  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    return values_util.on_write_assign(
        var, value, use_locking=use_locking, name=name, read_value=read_value)
  def assign_add(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    # delegate the work to values_util
    return values_util.on_write_assign_add(
        var, value, use_locking=use_locking, name=name, read_value=read_value)
  # discussed later
  def scatter_update(self, var, sparse_delta, use_locking=False, name=None):
    return values_util.scatter_update(
        var, sparse_delta, use_locking=use_locking, name=name)
  def get_saveable(self, var, primary_var, name):
    """Saveable ops for AUTO variables."""
    return values_util.get_on_write_saveable(var, primary_var, name)
  def get_restore_ops(self, var, tensor):
    return values_util.get_on_write_restore_ops(var, tensor)
  # Most of the code is omitted
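values_util
Both policies above use on_write_assign_add. OnReadPolicy calls it in replica context, and OnWritePolicy calls it always. It is defined in tensorflow/python/distribute/values_util.py.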
def on_write_assign_add(var, value, use_locking=False, name=None,
                        read_value=True):
  assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
  return var._update(
      update_fn=assign_add_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)
OnWritePolicy also uses the scatter_update defined in values_util, which likewise ends up calling back into var._update.
def scatter_update(var, sparse_delta, use_locking=False, name=None):
  scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
  return var._update(
      update_fn=scatter_update_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
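Both helpers follow one pattern: wrap the per-component method in a lambda and hand it to the container's _update, which decides where to apply it. The sketch below shows that pattern with hypothetical names. `ToyContainer._update` always applies the function to every copy, which roughly corresponds to the cross-replica path, whereas the real _update dispatches on context.

```python
from typing import Any, Callable, List


class ToyVar:
    """Hypothetical per-replica component with a plain-Python assign_add."""

    def __init__(self, value: float):
        self.value = value

    def assign_add(self, delta: float) -> float:
        self.value += delta
        return self.value


class ToyContainer:
    """Hypothetical distributed container; _update applies update_fn per copy."""

    def __init__(self, copies: List[ToyVar]):
        self.copies = list(copies)

    def _update(self, update_fn: Callable[..., Any], value: Any, **kwargs):
        return [update_fn(c, value, **kwargs) for c in self.copies]


def on_write_assign_add(var, value, **kwargs):
    # Same shape as values_util.on_write_assign_add: the lambda names the
    # per-component operation; the container chooses where to run it.
    assign_add_fn = lambda v, *a, **kw: v.assign_add(*a, **kw)
    return var._update(update_fn=assign_add_fn, value=value, **kwargs)
```

Because the operation travels as a function, one helper such as `_update` can serve assign, assign_add and scatter_update alike.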
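1.2.7 DistributedVariable
Following the class hierarchy, we finally arrive at DistributedVariable, which is where most of MirroredVariable's functionality actually lives.
Figure 10 DistributedVariable
DistributedVariable holds a map from replica to variables. When a variable policy is in use for an ON_WRITE (mirrored) variable, self._policy is an OnWritePolicy, and the actual variable updates are carried out through _policy.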
class DistributedVariable(DistributedDelegate, variables_lib.Variable,
                          core.Tensor):
  """Holds a map from replica to variables."""
  def __init__(self, strategy, values, aggregation, var_policy=None):
    if (aggregation == variables_lib.VariableAggregation.MEAN and
        not values[0].dtype.is_floating):
      raise ValueError(
          "creating distributed tf.Variable with aggregation=MEAN and a "
          "non-floating dtype is not supported, please use a different "
          "aggregation or dtype")
    self._distribute_strategy = strategy
    self._aggregation = aggregation
    super(DistributedVariable, self).__init__(values)
    self._common_name = self._primary.name.split(":")[0]
    # Use a weakref to make it easy to map from the contained values
    # to the container without introducing a reference cycle.
    for v in values:
      v._distributed_container = weakref.ref(self)  # pylint: disable=protected-access
    # Packed variable is used to reduce the overhead of function execution.
    # For a DistributedVariable, only one variable handle is captured into a
    # function graph. It's only supported in eager mode.
    if ops.executing_eagerly_outside_functions() and getattr(
        strategy, "_enable_packed_variable_in_eager_mode", False):
      name = "%s/packed/" % self._common_name
      self._packed_var = packed.PackedDistributedVariable(values, name=name)
    else:
      self._packed_var = None
    # tf.keras keeps track of variables initialized using this attribute. When
    # tf.keras gets the default session, it initializes all uninitialized vars.
    # We need to make _keras_initialized a member of DistributedVariable because
    # without this it will use __getattr__ which will delegate to a component
    # variable.
    self._keras_initialized = False
    # Typically, a DistributedVariable's initializer is composed of the
    # initializers of the components variables. However, in some cases, such as
    # when restoring from a checkpoint, we may set the _initializer_op
    # property on the entire DistributedVariable.
    self._initializer_op = None
    # Set a VariablePolicy which decides how we replicate/aggregate the given
    # variable.
    self._policy = var_policy
How each operation is handled depends on the situation, but in the end everything comes down to strategy or strategy.extended.
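The constructor's weakref trick, where each component points back to its container through weakref.ref, can be tried in isolation. In this sketch `ToyComponent` and `ToyDistributed` are hypothetical. The point is that the back-reference lets a component find its container without keeping the container alive.

```python
import gc
import weakref


class ToyComponent:
    """Hypothetical per-device variable."""

    def __init__(self, value):
        self.value = value
        self._distributed_container = None


class ToyDistributed:
    """Hypothetical container holding its components strongly."""

    def __init__(self, components):
        self._values = tuple(components)
        for c in self._values:
            # Weak back-reference component -> container: no strong cycle.
            c._distributed_container = weakref.ref(self)


comps = [ToyComponent(1.0), ToyComponent(1.0)]
container = ToyDistributed(comps)
print(comps[0]._distributed_container() is container)  # True
del container
gc.collect()
print(comps[0]._distributed_container())  # None: the container was freed
```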
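Reading
A read calls _get_cross_replica, which delegates to the policy. With OnReadPolicy, the policy in turn calls distribute_strategy.reduce to perform the reduction. With OnWritePolicy, it returns an identity of the local (or primary) copy.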
def _get_cross_replica(self):
  if values_util.is_saving_non_distributed():
    return self._primary  # saving in non-distributed mode: return the primary directly
  if self._policy:
    # return the cross-replica value computed by the policy
    return self._policy._get_cross_replica(self)
  raise NotImplementedError(
      "DistributedVariable._get_cross_replica requires a valid "
      "VariablePolicy. Please set the policy via the var_policy argument "
      "in the constructor, or override this method in sub-classes which "
      "support cross-replica accesses.")
As shown below:
Figure 11 Reading a DistributedVariable
scatter_update
Similarly, scatter_update also calls _policy to perform the update.
def scatter_update(self, sparse_delta, use_locking=False, name=None):
  if values_util.is_saving_non_distributed():
    return self._primary.scatter_update(sparse_delta, use_locking, name)
  if self._policy:
    return self._policy.scatter_update(
        self, sparse_delta, use_locking=use_locking, name=name)
  return values_util.scatter_update(
      self, sparse_delta, use_locking=use_locking, name=name)
As discussed under OnWritePolicy, scatter_update eventually calls back into DistributedVariable's own _update method.
def scatter_update(var, sparse_delta, use_locking=False, name=None):
  scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
  return var._update(
      update_fn=scatter_update_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
var._update has several execution paths, and we analyze only some of them.
def _update(self, update_fn, value, **kwargs):
  """Applies updates depending on the context.
  The method calls _update_replica in replica context,
  _update_cross_replica in cross replica context, and update_fn in update
  context.
  If read_value is True, the method returns the updated Variable. If
  read_value is False, the method returns the update tf.Operation.
  Args:
    update_fn: A callable to pass to strategy.extended.update to update the
      variable. It should have the same signature as Variable.assign().
    value: value to be passed to update_fn.
    **kwargs: keyword arguments to update_fn.
  Returns:
    Updated variable or tf.Operation.
  """
  if values_util.is_saving_non_distributed():
    return update_fn(self._primary, value, **kwargs)  # non-distributed saving case
  with ds_context.enter_or_assert_strategy(self.distribute_strategy):
    if ds_context.in_cross_replica_context():
      update_replica_id = distribute_lib.get_update_replica_id()
      if update_replica_id is not None:
        replica_value = self._get_replica(update_replica_id)
        return update_fn(replica_value, value, **kwargs)
      return self._update_cross_replica(update_fn, value, **kwargs)  # cross-replica update
    else:
      values_util.assert_replica_context(self.distribute_strategy)
      return self._update_replica(update_fn, value, **kwargs)
In cross-replica context (outside an update context), it then calls _update_cross_replica to update across replicas.
def _update_cross_replica(self, update_fn, value, **kwargs):
  """Applies updates across replicas.
  Args:
    update_fn: A callable to pass to strategy.extended.update to update the
      variable. It should have the same signature as Variable.assign().
    value: value to be passed to update_fn.
    **kwargs: remaining arguments to update_fn.
  Returns:
    Updated variable or tf.Operation.
  """
  values_util.mark_as_unsaveable()
  return self.distribute_strategy.extended.update(
      self, update_fn, args=(value,), kwargs=kwargs, group=True)
This is illustrated below:
Figure 12 Updating a DistributedVariable
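The three paths of _update can be summarized in a hypothetical pure-Python model.
- `ToyDistVar` takes explicit flags instead of consulting a real distribution context.
- Its replica-context path only updates the local copy. A real MirroredVariable would aggregate across replicas through _on_write_update_replica at that point.
- Each method returns the name of the path it took, so the dispatch is easy to check.

```python
from typing import Callable, List, Optional

UpdateFn = Callable[[float, float], float]  # (old_value, value) -> new_value


class ToyDistVar:
    """Hypothetical model of DistributedVariable._update's dispatch."""

    def __init__(self, copies: List[float]):
        self.copies = list(copies)

    def _update(self, update_fn: UpdateFn, value: float, *,
                cross_replica: bool,
                update_replica_id: Optional[int] = None,
                replica_id: int = 0) -> str:
        if cross_replica:
            if update_replica_id is not None:
                # Update context: already updating one specific replica,
                # so update_fn is applied to that replica's copy only.
                i = update_replica_id
                self.copies[i] = update_fn(self.copies[i], value)
                return "update_fn"
            return self._update_cross_replica(update_fn, value)
        return self._update_replica(update_fn, value, replica_id)

    def _update_cross_replica(self, update_fn: UpdateFn, value: float) -> str:
        # Stands in for strategy.extended.update(..., group=True).
        self.copies = [update_fn(c, value) for c in self.copies]
        return "_update_cross_replica"

    def _update_replica(self, update_fn: UpdateFn, value: float,
                        replica_id: int) -> str:
        # Simplification: no cross-replica aggregation in this sketch.
        self.copies[replica_id] = update_fn(self.copies[replica_id], value)
        return "_update_replica"
```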
1.2.8 Saving
Next, let us look at how a MirroredVariable is saved. Its _saveable_factory uses _MirroredSaveable to implement saving.
class MirroredVariable(DistributedVariable, Mirrored):
  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.
    This allows both name-based and object-based save and restore of
    MirroredVariables.
    Returns:
      A dictionary mapping attribute names to SaveableObject factories.
    """
    def _saveable_factory(name=self._common_name):
      return _MirroredSaveable(self, self._primary, name)
    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
_MirroredSaveable defines how a MirroredVariable is saved and restored.
class _MirroredSaveable(saveable_object.SaveableObject):
  """Class for defining how to restore a MirroredVariable."""
  def __init__(self, mirrored_variable, primary_variable, name):
    self._mirrored_variable = mirrored_variable
    # get_on_write_saveable is called here
    tensor, spec = values_util.get_on_write_saveable(self._mirrored_variable,
                                                     primary_variable, name)
    super(_MirroredSaveable, self).__init__(tensor, spec, name)
  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    tensor, = restored_tensors
    return values_util.get_on_write_restore_ops(self._mirrored_variable, tensor)
The code of get_on_write_saveable is:
def get_on_write_saveable(var, primary_var, name):
  """Return saveable spec for AUTO and ON_WRITE variables."""
  # We use a callable so that we don't have to evaluate this expression
  # in the case where we are trying to restore instead of save.
  def tensor():
    if context.executing_eagerly() and not primary_var.is_initialized():
      # A SaveSpec tensor value of None indicates that the variable is
      # uninitialized.
      return None
    strategy = var.distribute_strategy
    return strategy.extended.read_var(var)  # read the value through the strategy
  spec = saveable_object.SaveSpec(
      tensor=tensor,
      slice_spec="",
      name=name,
      dtype=var.dtype,
      device=primary_var.device)
  return tensor, [spec]
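The read_var method in tensorflow/python/distribute/mirrored_strategy.py performs the read. For a sync-on-read variable, it returns the aggregated cross-replica value. Otherwise, for example for a MirroredVariable, it returns an identity of the value obtained from _get().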
def read_var(self, replica_local_var):
  """Read the aggregate value of a replica-local variable."""
  if distribute_utils.is_sync_on_read(replica_local_var):
    return replica_local_var._get_cross_replica()
  return array_ops.identity(replica_local_var._get())
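Saving and restoring a mirrored variable reduces to two rules. On save, one value is written for the whole variable. On restore, that same value is put into every copy. Below is a hypothetical sketch. The dict key "value" and all names are assumptions, and SUM is assumed as the aggregation for the sync-on-read case.

```python
from typing import Dict, List


class ToyMirroredVar:
    """Hypothetical variable with one copy per replica."""

    def __init__(self, copies: List[float], sync_on_read: bool = False):
        self.copies = list(copies)
        self.sync_on_read = sync_on_read


def toy_read_var(var: ToyMirroredVar) -> float:
    # Like read_var: sync-on-read -> aggregated value; otherwise one copy.
    if var.sync_on_read:
        return sum(var.copies)
    return var.copies[0]


def toy_save(var: ToyMirroredVar) -> Dict[str, float]:
    # A single saved value for the whole variable.
    return {"value": toy_read_var(var)}


def toy_restore(var: ToyMirroredVar, checkpoint: Dict[str, float]) -> None:
    # Like _MirroredSaveable.restore: the same value goes into every copy.
    var.copies = [checkpoint["value"]] * len(var.copies)
```

A checkpoint written this way does not depend on the replica count, so it can be restored into a variable with a different number of copies.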
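1.2.9 Summary
The analysis above yields the following annotated version of the MirroredVariable inheritance hierarchy. Much of its functionality is ultimately carried out by tf.distribute.Strategy.
Figure 13 Annotated MirroredVariable inheritance hierarchy
1.3 Building Variables
A variable created under MirroredStrategy is a MirroredVariable. If no devices are specified in the strategy's constructor arguments, the strategy uses all available GPUs. If no GPU is found, it uses the available CPU. Note that TensorFlow treats all CPUs on a machine as a single device and uses threads internally for parallelism. Next, let us look at how a MirroredVariable is built.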
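The default device choice described above can be written as a small function. This is a hypothetical sketch. In real TensorFlow the GPU list would come from the runtime (for instance tf.config.list_logical_devices("GPU")), whereas here it is passed in explicitly. "/device:CPU:0" stands for the single CPU device.

```python
from typing import List, Optional, Sequence


def choose_devices(requested: Optional[Sequence[str]],
                   available_gpus: Sequence[str]) -> List[str]:
    """Hypothetical sketch of MirroredStrategy's default device selection."""
    if requested:
        return list(requested)        # devices passed to the constructor win
    if available_gpus:
        return list(available_gpus)   # otherwise use every available GPU
    return ["/device:CPU:0"]          # no GPU: all CPUs count as one device
```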
1.3.1 StrategyBase
First, tensorflow/python/distribute/distribute_lib.py contains the following code. It shows that scope() is actually implemented by _extended.
def scope(self):
  """Returns a context manager selecting this Strategy as current.
  Inside a with strategy.scope(): code block, this thread
  will use a variable creator set by strategy, and will
  enter its "cross-replica context".
  Returns:
    A context manager.
  """
  return self._extended._scope(self)
1.3.2 StrategyExtendedV2
This brings us to StrategyExtendedV2. Here, creator_with_resource_vars provides the mechanism for creating variables, and it internally calls the derived class's _create_variable to build each variable.
def _scope(self, strategy):
  """Implementation of tf.distribute.Strategy.scope()."""
  def creator_with_resource_vars(next_creator, **kwargs):
    """Variable creator to use in _CurrentDistributionContext."""
    _require_strategy_scope_extended(self)
    kwargs["use_resource"] = True
    kwargs["distribute_strategy"] = strategy
    # Unwrap initial_value if it is a CheckpointInitialValue to avoid
    # dereferencing a Tensor that is without a name. We still need to
    # propagate the metadata it's holding.
    if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue):
      checkpoint_restore_uid = kwargs[
          "initial_value"].checkpoint_position.restore_uid
      kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
    elif isinstance(kwargs["initial_value"],
                    trackable.CheckpointInitialValueCallable):
      checkpoint_restore_uid = kwargs[
          "initial_value"].checkpoint_position.restore_uid
    elif (isinstance(kwargs["initial_value"], functools.partial) and
          isinstance(kwargs["initial_value"].func,
                     trackable.CheckpointInitialValueCallable)):
      # Some libraries (e.g, Keras) create partial function out of initializer
      # to bind shape/dtype, for example:
      #  initial_val = functools.partial(initializer, shape, dtype=dtype)
      # Therefore to get the restore_uid we need to examine the "func" of
      # the partial function.
      checkpoint_restore_uid = kwargs[
          "initial_value"].func.checkpoint_position.restore_uid
    else:
      checkpoint_restore_uid = None
    created = self._create_variable(next_creator, **kwargs)
    if checkpoint_restore_uid is not None:
      # Let the checkpointing infrastructure know that the variable was
      # already restored so it doesn't waste memory loading the value again.
      # In this case of CheckpointInitialValueCallable this may already be
      # done by the final variable creator, but it doesn't hurt to do it
      # again.
      created._maybe_initialize_trackable()
      created._update_uid = checkpoint_restore_uid
    return created
  def distributed_getter(getter, *args, **kwargs):
    return getter(*args, **kwargs)
  # creator_with_resource_vars is used here
  return _CurrentDistributionContext(
      strategy,
      variable_scope.variable_creator_scope(creator_with_resource_vars),  # configures how variables are created
      variable_scope.variable_scope(
          variable_scope.get_variable_scope(),
          custom_getter=distributed_getter), self._default_device)
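The logic is as follows. After the setup above, _scope returns a _CurrentDistributionContext, which performs a further series of operations when the scope is entered. We look at it next.
Figure 14 How variables are created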
1.3.3 _CurrentDistributionContext
_CurrentDistributionContext maintains strategy-related information, enters the various scopes, and returns the strategy from __enter__.
class _CurrentDistributionContext(object):
  """Context manager setting the current tf.distribute.Strategy.

  Also: overrides the variable creator and optionally the current device.
  """

  def __init__(self,
               strategy,
               var_creator_scope,
               var_scope=None,
               resource_creator_scope=None,
               default_device=None):
    self._context = distribution_strategy_context._CrossReplicaThreadMode(
        strategy)
    self._var_creator_scope = var_creator_scope
    self._var_scope = var_scope
    self._resource_creator_scope = resource_creator_scope
    if default_device:
      self._device_scope = ops.device(default_device)
    else:
      self._device_scope = None
    self._same_scope_again_count = 0

  def __enter__(self):
    # Allow this scope to be entered if this strategy is already in scope.
    if distribution_strategy_context.has_strategy():
      _require_cross_replica_or_default_context_extended(
          self._context.strategy.extended)
      self._same_scope_again_count += 1
    else:
      _push_per_thread_mode(self._context)
      if self._var_scope:
        self._var_scope.__enter__()
      self._var_creator_scope.__enter__()
      if self._resource_creator_scope:
        nest.map_structure(lambda scope: scope.__enter__(),
                           self._resource_creator_scope)
      if self._device_scope:
        self._device_scope.__enter__()
    return self._context.strategy

  def __exit__(self, exception_type, exception_value, traceback):
    if self._same_scope_again_count > 0:
      self._same_scope_again_count -= 1
      return
    if self._device_scope:
      try:
        self._device_scope.__exit__(exception_type, exception_value, traceback)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Device scope nesting error: move call to "
                         "tf.distribute.set_strategy() out of with scope."),
            e)

    try:
      self._var_creator_scope.__exit__(
          exception_type, exception_value, traceback)
    except RuntimeError as e:
      six.raise_from(
          RuntimeError("Variable creator scope nesting error: move call to "
                       "tf.distribute.set_strategy() out of with scope."),
          e)

    if self._resource_creator_scope:
      try:
        if isinstance(self._resource_creator_scope, list):
          reversed_resource_creator_scope = self._resource_creator_scope[::-1]
          nest.map_structure(
              lambda scope: scope.__exit__(exception_type, exception_value,
                                           traceback),
              reversed_resource_creator_scope)
        else:
          self._resource_creator_scope.__exit__(exception_type, exception_value,
                                                traceback)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Resource creator scope nesting error: move call "
                         "to tf.distribute.set_strategy() out of with "
                         "scope."), e)

    if self._var_scope:
      try:
        self._var_scope.__exit__(exception_type, exception_value, traceback)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Variable scope nesting error: move call to "
                         "tf.distribute.set_strategy() out of with scope."),
            e)
    _pop_per_thread_mode()
1.3.4 MirroredStrategy
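From the analysis above, we know that when a Strategy is in use, variables are ultimately created by the Strategy's _create_variable.
_create_variable does the actual work. It uses self._devices (unless colocate_with says otherwise) and then calls distribute_utils.create_mirrored_variable, passing it _real_mirrored_creator, VARIABLE_CLASS_MAPPING and VARIABLE_POLICY_MAPPING to build the variable. _real_mirrored_creator loops over the devices and creates one component variable per device under with ops.device(d), which is what pins each copy to its device. It also names the copies: the one on the first device keeps the original name, while the copies on later devices get /replica_<device index> appended to the original name so they can be told apart. The copies on later devices take their initial value from the first (primary) variable, via _get_variable_creator_initial_value.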
def _create_variable(self, next_creator, **kwargs):
  """Create a mirrored variable. See DistributionStrategy.scope."""
  colocate_with = kwargs.pop("colocate_with", None)
  if colocate_with is None:
    devices = self._devices
  elif isinstance(colocate_with, numpy_dataset.SingleDevice):
    with ops.device(colocate_with.device):
      return next_creator(**kwargs)
  else:
    devices = colocate_with._devices

  def _real_mirrored_creator(**kwargs):
    value_list = []
    for i, d in enumerate(devices):
      with ops.device(d):
        kwargs["initial_value"] = self._get_variable_creator_initial_value(
            replica_id=i,
            device=d,
            primary_var=value_list[0] if value_list else None,
            **kwargs)
        if i > 0:
          # Give replicas meaningful distinct names:
          var0name = value_list[0].name.split(":")[0]
          # We append a / to variable names created on replicas with id > 0 to
          # ensure that we ignore the name scope and instead use the given
          # name as the absolute name of the variable.
          kwargs["name"] = "%s/replica_%d/" % (var0name, i)
        with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
          # Don't record operations (e.g. other variable reads) during
          # variable creation.
          with tape.stop_recording():
            v = next_creator(**kwargs)
        assert not isinstance(v, values.DistributedVariable)
        value_list.append(v)
    return value_list

  return distribute_utils.create_mirrored_variable(
      self._container_strategy(), _real_mirrored_creator,
      distribute_utils.VARIABLE_CLASS_MAPPING,
      distribute_utils.VARIABLE_POLICY_MAPPING, **kwargs)
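The naming and placement rules of _real_mirrored_creator can be sketched without TensorFlow. In this simplified version, ComponentVariable and real_mirrored_creator are hypothetical names. The device is recorded from the loop variable, as ops.device(d) does in the real code, and copying the primary's value is an assumption standing in for _get_variable_creator_initial_value.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ComponentVariable:
    """Hypothetical stand-in for one per-device component variable."""
    name: str
    device: str
    value: float


def real_mirrored_creator(devices: List[str], name: str,
                          initial_value: float) -> List[ComponentVariable]:
    """Sketch of the loop in _real_mirrored_creator: one component per device."""
    value_list: List[ComponentVariable] = []
    for i, d in enumerate(devices):
        primary: Optional[ComponentVariable] = value_list[0] if value_list else None
        # Assumption: replicas take their initial value from the primary copy,
        # mirroring the primary_var argument of _get_variable_creator_initial_value.
        init = primary.value if primary is not None else initial_value
        if i > 0:
            var0name = value_list[0].name.split(":")[0]
            var_name = "%s/replica_%d" % (var0name, i)
        else:
            var_name = name
        # Placement comes from the device string (like `with ops.device(d)`),
        # not from the variable name.
        value_list.append(ComponentVariable(var_name + ":0", d, init))
    return value_list


replicas = real_mirrored_creator(["GPU:0", "GPU:1", "GPU:2"], "Variable", 1.0)
```

The resulting names follow the same Variable:0 / Variable/replica_1:0 pattern as the mirrored_variable printout in section 1.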
VARIABLE_CLASS_MAPPING determines which variable class is created; VARIABLE_POLICY_MAPPING determines which policy handles read/write synchronization.
# The following mapping indicates the policy that you must use for a given
# variable synchronization and aggregation pair.
# OnWritePolicy is used for:
# (synchronization=Auto, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# (synchronization=ON_WRITE, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# OnReadPolicy is used for:
# (synchronization=ON_READ, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.ON_WRITE: values_lib.OnWritePolicy,
    vs.VariableSynchronization.ON_READ: values_lib.OnReadPolicy,
}

VARIABLE_CLASS_MAPPING = {
    "VariableClass": values_lib.DistributedVariable,
    vs.VariableSynchronization.ON_WRITE: values_lib.MirroredVariable,  # the one we focus on
    vs.VariableSynchronization.ON_READ: values_lib.SyncOnReadVariable,
}
1.3.5 distribute_utils
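create_mirrored_variable in tensorflow/python/distribute/distribute_utils.py builds the variable. When the variable policy is not in use (use_var_policy is False), the class taken from class_mapping for our example (synchronization AUTO, which is converted to ON_WRITE) is values_lib.MirroredVariable.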
def create_mirrored_variable(strategy, real_mirrored_creator, class_mapping,
                             policy_mapping, **kwargs):
  """Create distributed variables with given synchronization and aggregation."""
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  synchronization = _validate_synchronization(kwargs)
  # Update synchronization in kwargs in case it's AUTO, which is converted to
  # ON_WRITE.
  kwargs["synchronization"] = synchronization
  aggregation = _validate_aggregation(kwargs)
  use_var_policy = getattr(strategy.extended, "_use_var_policy", False)

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  with tape.stop_recording():
    # Build the list of per-device component variables
    value_list = real_mirrored_creator(**kwargs)
    # MirroredVariable is recreated during saved_model loading, and its
    # component variables (value_list) will have None initializer. We
    # set their initializers to no_op so that consumer like
    # global_variables_initializer wouldn't complain, as it groups all
    # variables' initializers thus all variables have to have initializers.
    for v in value_list:
      if hasattr(v, "_initializer_op") and v._initializer_op is None:
        v._initializer_op = control_flow_ops.no_op()
    if use_var_policy:
      # Look up the policy, pick the class, and build the distributed variable
      var_policy_cls = policy_mapping.get(synchronization)
      var_policy = var_policy_cls(aggregation=aggregation)
      var_cls = class_mapping.get("VariableClass")
      result = var_cls(strategy, value_list, aggregation, var_policy=var_policy)
    else:
      var_cls = class_mapping.get(synchronization)
      result = var_cls(strategy, value_list, aggregation)

  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for value in value_list:
        for i, trainable_variable in enumerate(l):
          if value is trainable_variable:
            del l[i]
            break

    g.add_to_collections(var_collections, result)
  elif ops.GraphKeys.GLOBAL_STEP in var_collections:
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

  return result
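The class and policy selection at the end of create_mirrored_variable boils down to two dictionary lookups. The sketch below reproduces just that branch in plain Python. Sync is a hypothetical stand-in for tf.VariableSynchronization (NONE is left out), and the class and policy names are plain strings rather than the real values_lib classes.

```python
import enum


class Sync(enum.Enum):
    """Hypothetical mirror of tf.VariableSynchronization (NONE omitted)."""
    AUTO = 0
    ON_WRITE = 1
    ON_READ = 2


# Class and policy names as strings, standing in for the values_lib classes.
VARIABLE_CLASS_MAPPING = {
    "VariableClass": "DistributedVariable",
    Sync.ON_WRITE: "MirroredVariable",
    Sync.ON_READ: "SyncOnReadVariable",
}
VARIABLE_POLICY_MAPPING = {
    Sync.ON_WRITE: "OnWritePolicy",
    Sync.ON_READ: "OnReadPolicy",
}


def choose_variable_class(synchronization, use_var_policy):
    """Return (class name, policy name) following create_mirrored_variable's branches."""
    if synchronization == Sync.AUTO:
        # Per the comment in create_mirrored_variable, AUTO becomes ON_WRITE.
        synchronization = Sync.ON_WRITE
    if use_var_policy:
        return (VARIABLE_CLASS_MAPPING["VariableClass"],
                VARIABLE_POLICY_MAPPING[synchronization])
    return VARIABLE_CLASS_MAPPING[synchronization], None
```

With the policy path, one generic class plus a policy object replaces the per-synchronization subclasses.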
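The final construction logic: the _var_creator_scope member of _CurrentDistributionContext wraps creator_with_resource_vars. When a variable is created, creator_with_resource_vars is invoked and the call goes down layer by layer (_create_variable -> create_mirrored_variable -> _real_mirrored_creator) until a MirroredVariable is produced.
Figure 15: Creating variables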
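1.4 Summary
So far, we can answer the earlier questions as follows:
- How does a call reach the Strategy?
  - Reads and writes of variables are ultimately dispatched to strategy or strategy.extended.
- How is a Mirrored Variable generated?
  - Inside the scope the user gets a context (_CurrentDistributionContext), and that context supplies the variable-creation method (creator_with_resource_vars), so any variable created inside it is naturally a Mirrored Variable.
- How are tensors distributed to each device?
  - When a Strategy is in use, variables are created by the Strategy's _create_variable, which ultimately calls _real_mirrored_creator.
  - _real_mirrored_creator creates each component variable under with ops.device(d) for its device, and gives each one a distinct name: the copy on the first device keeps the original name, while copies on later devices get /replica_<device index> appended so they can be told apart from the original variable.
  - During placement, each variable is then assigned to the device specified for it and ends up on that device.
- How is a unified view presented to the outside?
  - Inside the context the user gets a Mirrored Variable, which hides its component variables and presents a unified view. For example, a cross-replica read calls _get_cross_replica, which delegates to the Policy, and the Policy calls distribute_strategy to perform the reduction.
- How do the variables stay consistent?
  - From the earlier scatter_update analysis we know that an update goes through strategy.extended, where the component variables are kept consistent by mechanisms such as All-Reduce; we analyze this in detail in a later article.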
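Let's illustrate this with a diagram. Suppose there is a MirroredVariable A made up of 3 tensors. Each Worker believes it is updating MirroredVariable A, but in fact each one updates a different component variable, and the components are kept consistent by mechanisms such as All-Reduce.
Figure 16: How updates work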
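The consistency argument can be illustrated numerically. This is a plain-Python sketch, not TensorFlow's cross-device implementation: all_reduce_mean is a naive stand-in for an All-Reduce with MEAN aggregation, and the three list entries play the role of the three component variables of MirroredVariable A. Applying the same reduced gradient everywhere keeps the copies identical, while applying only local gradients lets them drift apart.

```python
def all_reduce_mean(values):
    """Naive all-reduce: every replica receives the mean of all inputs."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


def mirrored_sgd_step(replica_values, replica_grads, lr):
    """Each replica updates its own copy with the same all-reduced gradient."""
    reduced = all_reduce_mean(replica_grads)
    return [v - lr * g for v, g in zip(replica_values, reduced)]


def local_sgd_step(replica_values, replica_grads, lr):
    """Without the all-reduce, each replica applies only its own gradient."""
    return [v - lr * g for v, g in zip(replica_values, replica_grads)]


values = [1.0, 1.0, 1.0]   # three component variables of "MirroredVariable A"
grads = [0.5, 1.0, 1.5]    # each replica saw a different batch
synced = mirrored_sgd_step(values, grads, lr=0.1)
diverged = local_sgd_step(values, grads, lr=0.1)
```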
2. ShardedVariable
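In machine learning training, if a variable is too large to fit on a single device (for example, a large embedding), it may need to be sharded across multiple devices. In TensorFlow, the concept that corresponds to this idea is ShardedVariable.
Figure 17: ShardedVariable
Variable sharding means splitting one variable into several smaller variables, called shards. A ShardedVariable can be viewed as a container whose "variables" should be treated as shards. The ShardedVariable class maintains a list of smaller variables that can be stored independently on different devices (for example, multiple parameter servers), and it takes care of saving and restoring them as if they were one larger variable. Variable sharding helps spread the network load of accessing these shards, and it is also useful for distributing the computation and storage of an ordinary variable across multiple parameter servers.
Figure 18: The ShardedVariable container
A ShardedVariable object can be saved with a given number of shards and then restored from a checkpoint into a different number of shards. A SavedModel containing it can be used by programs such as the TF serving APIs, but loading it with tf.saved_model.load is not supported. Because a ShardedVariable can be saved and later restored into a different number of shards depending on the restore environment (for example, the TF serving APIs restore it into a single shard for serving efficiency), code that uses a ShardedVariable inside a tf.function should generally not assume it has the same number of shards at save time and at load time.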
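Saving with one shard count and restoring with another only works because the saved data can be reassembled into the full logical value. The sketch below illustrates that idea with plain lists. split_rows, save and restore are hypothetical helpers, the even-split rule is an assumption, and the sketch does not reproduce TensorFlow's actual checkpoint format.

```python
def split_rows(rows, num_shards):
    """Split rows along axis 0 into num_shards contiguous pieces.

    Assumption for this sketch: when the split is uneven, the earlier shards
    get one extra row each. In TensorFlow the shard sizes come from the
    partitioner in use.
    """
    base, extra = divmod(len(rows), num_shards)
    shards, start = [], 0
    for i in range(num_shards):
        size = base + (1 if i < extra else 0)
        shards.append(rows[start:start + size])
        start += size
    return shards


def save(shards):
    """'Checkpoint' the full logical value, independent of the shard count."""
    return [row for shard in shards for row in shard]


def restore(saved_rows, num_shards):
    """Re-split the saved value for whatever shard count the loader uses."""
    return split_rows(saved_rows, num_shards)


training_shards = split_rows([[0], [1], [2], [3], [4]], 3)
checkpoint = save(training_shards)
serving_shards = restore(checkpoint, 1)  # e.g. a single shard for serving
```

This is also why code inside a tf.function should not hard-code the shard count: the same checkpoint can come back as one, two or three shards.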
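2.1 Questions
For ShardedVariable, we again use a few questions to guide the analysis:
- How are parameters stored on the parameter servers?
- How is sharded storage of parameters implemented?
- How is computation (the operations that apply gradients to parameters) placed on the parameter servers? (Analyzed in a later section.)
- Does the Coordinator assign computation randomly? (Analyzed in a later section.)
2.2 Definition
There is not much in the definition of ShardedVariable itself; most of the substance lives in the base class ShardedVariableMixin, which we analyze shortly.
Figure 19: ShardedVariable definition
The definition is as follows: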
class ShardedVariable(ShardedVariableMixin, composite_tensor.CompositeTensor):
  """A container for Variables that should be treated as shards.
  """

  @property
  def _type_spec(self):
    return ShardedVariableSpec(
        *(resource_variable_ops.VariableSpec(v.shape, v.dtype)
          for v in self._variables))

  @classmethod
  def _overload_all_operators(cls):
    """Register overloads for all operators."""
    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
      # ShardedVariable implements its own __getitem__, so skip it here.
      if operator == '__getitem__':
        continue
      cls._overload_operator(operator)

  @classmethod
  def _overload_operator(cls, operator):
    """Delegate an operator overload to ops.Tensor."""
    tensor_operator = getattr(ops.Tensor, operator)

    def _operator(v, *args, **kwargs):
      return tensor_operator(_var_to_tensor(v), *args, **kwargs)

    setattr(cls, operator, _operator)
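The operator-overloading trick is generic Python. In the sketch below, Vec is a hypothetical stand-in for ops.Tensor and ShardedVec for ShardedVariable. As in the code above, every operator except __getitem__ is installed on the class with setattr, and each installed function first converts the container into a single value (the role _var_to_tensor plays) before delegating.

```python
class Vec:
    """Hypothetical stand-in for ops.Tensor with a few overloadable operators."""

    OVERLOADABLE_OPERATORS = ("__add__", "__mul__", "__neg__", "__getitem__")

    def __init__(self, data):
        self.data = list(data)

    def __add__(self, other):
        return Vec(a + b for a, b in zip(self.data, other.data))

    def __mul__(self, k):
        return Vec(a * k for a in self.data)

    def __neg__(self):
        return Vec(-a for a in self.data)

    def __getitem__(self, i):
        return self.data[i]

    def __eq__(self, other):
        return isinstance(other, Vec) and self.data == other.data


def _to_vec(sharded):
    """Concatenate the shards into one Vec (the role of _var_to_tensor)."""
    return Vec(x for shard in sharded.shards for x in shard)


class ShardedVec:
    """Container of shards whose operators delegate to the concatenated Vec."""

    def __init__(self, shards):
        self.shards = [list(s) for s in shards]

    def __getitem__(self, i):
        # Own indexing logic, so the generic overload is skipped below.
        return _to_vec(self)[i]

    @classmethod
    def _overload_all_operators(cls):
        for op in Vec.OVERLOADABLE_OPERATORS:
            if op == "__getitem__":
                continue
            cls._overload_operator(op)

    @classmethod
    def _overload_operator(cls, op):
        vec_operator = getattr(Vec, op)

        # Each call creates a fresh scope, so this closure keeps its own vec_operator.
        def _operator(v, *args, **kwargs):
            return vec_operator(_to_vec(v), *args, **kwargs)

        setattr(cls, op, _operator)


ShardedVec._overload_all_operators()
```

Building each wrapper inside its own _overload_operator call matters: if the closures were created directly in the loop, they would all see the last operator the loop variable pointed to.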
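2.3 How partitioning works
Partitioning is one of the key parts of ShardedVariable, so let's look at how it works. Note that ShardedVariable only supports partitioning along the first dimension.
2.3.1 Base class
The base class Partitioner contains very little; subclasses must implement __call__.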
@tf_export('distribute.experimental.partitioners.Partitioner', v1=[])
class Partitioner(object):
  """Partitioner base class: all partitioners inherit from this class.

  Partitioners should implement a __call__ method with the following
  signature:

  ```python
  def __call__(self, shape, dtype, axis=0):
    # Partitions the given shape and returns the partition results.
    # See docstring of __call__ method for the format of partition results.
  ```
  """

  def __call__(self, shape, dtype, axis=0):
    """Partitions the given shape and returns the partition results.

    Examples of a partitioner that allocates a fixed number of shards:

    ```python
    partitioner = FixedShardsPartitioner(num_shards=2)
    partitions = partitioner(tf.TensorShape([10, 3]), tf.float32, axis=0)
    print(partitions) # [2, 1]
    ```

    Args:
      shape: a tf.TensorShape, the shape to partition.
      dtype: a tf.dtypes.Dtype indicating the type of the partition value.
      axis: The axis to partition along. Default: outermost axis.

    Returns:
      A list of integers representing the number of partitions on each axis,
      where i-th value corresponds to i-th axis.
    """
    raise NotImplementedError
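2.3.2 Fixed partitioning
FixedShardsPartitioner splits a variable into a fixed number of shards. Its docstring contains a usage example: with axis = 0, min(self._num_shards, shape.dims[axis].value) = min(2, 10) = 2, so the variable is split into two shards and the result is [2, 1].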
@tf_export('distribute.experimental.partitioners.FixedShardsPartitioner', v1=[])
class FixedShardsPartitioner(Partitioner):
  """Partitioner that allocates a fixed number of shards.

  Examples:

  >>> # standalone usage:
  >>> partitioner = FixedShardsPartitioner(num_shards=2)
  >>> partitions = partitioner(tf.TensorShape([10, 3]), tf.float32)
  >>> [2, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, num_shards):
    """Creates a new FixedShardsPartitioner.

    Args:
      num_shards: int, number of shards to partition.
    """
    self._num_shards = num_shards

  def __call__(self, shape, dtype, axis=0):
    del dtype
    result = [1] * len(shape)
    result[axis] = min(self._num_shards, shape.dims[axis].value)
    return result
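The arithmetic is small enough to check by hand. This plain-Python mirror of FixedShardsPartitioner.__call__ (fixed_shards_partition is a hypothetical name) takes a tuple instead of a tf.TensorShape and ignores the dtype, just as the real method does.

```python
def fixed_shards_partition(shape, num_shards, axis=0):
    """Plain-Python mirror of FixedShardsPartitioner.__call__ on a tuple shape."""
    result = [1] * len(shape)
    # Never ask for more shards than there are entries along `axis`.
    result[axis] = min(num_shards, shape[axis])
    return result
```

For a [10, 3] variable and two shards this gives [2, 1]; a variable with a single row cannot be split, so it stays at [1, 1].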
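2.3.3 Minimum-size partitioning
MinSizePartitioner allocates a minimum size per shard. It ensures each shard has at least min_shard_bytes and tries to allocate as many shards as possible, i.e. it keeps each shard as small as possible. The maximum number of such shards (the upper bound) is given by max_shards.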
@tf_export('distribute.experimental.partitioners.MinSizePartitioner', v1=[])
class MinSizePartitioner(Partitioner):
  """Partitioner that allocates a minimum size per shard.

  This partitioner ensures each shard has at least min_shard_bytes, and tries
  to allocate as many shards as possible, i.e., keeping shard size as small as
  possible. The maximum number of such shards (upper bound) is given by
  max_shards.

  Examples:

  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [2, 1]
  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=10)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [6, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self,
               min_shard_bytes=256 << 10,
               max_shards=1,
               bytes_per_string=16):
    """Creates a new MinSizePartitioner.

    Args:
      min_shard_bytes: Minimum bytes of each shard. Defaults to 256K.
      max_shards: Upper bound on the number of shards. Defaults to 1.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.
    """
    self._min_shard_bytes = min_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    return partitioned_variables.min_max_variable_partitioner(
        max_partitions=self._max_shards,
        axis=axis,
        min_slice_size=self._min_shard_bytes,
        bytes_per_string_element=self._bytes_per_string)(shape, dtype)
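min_max_variable_partitioner does the actual work. It returns a partitioner that partitions a variable of the given shape and dtype such that each partition holds a slice of at least min_slice_size of the variable. The maximum number of such partitions (the upper bound) is given by max_partitions.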
@tf_export(v1=["min_max_variable_partitioner"])
def min_max_variable_partitioner(max_partitions=1, axis=0,
                                 min_slice_size=256 << 10,
                                 bytes_per_string_element=16):
  """Partitioner to allocate minimum size per slice.

  Returns a partitioner that partitions the variable of given shape and dtype
  such that each partition has a minimum of min_slice_size slice of the
  variable. The maximum number of such partitions (upper bound) is given by
  max_partitions.

  Args:
    max_partitions: Upper bound on the number of partitions. Defaults to 1.
    axis: Axis along which to partition the variable. Defaults to 0.
    min_slice_size: Minimum size of the variable slice per partition. Defaults
      to 256K.
    bytes_per_string_element: If the Variable is of type string, this provides
      an estimate of how large each scalar in the Variable is.

  Returns:
    A partition function usable as the partitioner argument to
    variable_scope and get_variable.
  """
  def _partitioner(shape, dtype):
    """Partitioner that partitions list for a variable of given shape and type.

    Ex: Consider partitioning a variable of type float32 with
      shape=[1024, 1024].
      If max_partitions >= 16, this function would return
        [(1024 * 1024 * 4) / (256 * 1024), 1] = [16, 1].
      If max_partitions < 16, this function would return
        [max_partitions, 1].

    Args:
      shape: Shape of the variable.
      dtype: Type of the variable.

    Returns:
      List of partitions for each axis (currently only one axis can be
      partitioned).

    Raises:
      ValueError: If axis to partition along does not exist for the variable.
    """
    if axis >= len(shape):
      raise ValueError("Can not partition variable along axis %d when shape is "
                       "only %s" % (axis, shape))
    if dtype.base_dtype == dtypes.string:
      bytes_per_element = bytes_per_string_element
    else:
      bytes_per_element = dtype.size
    total_size_bytes = shape.num_elements() * bytes_per_element
    partitions = total_size_bytes / min_slice_size
    partitions_list = [1] * len(shape)
    # We can not partition the variable beyond what its shape or
    # max_partitions allows.
    partitions_list[axis] = max(1, min(shape.dims[axis].value,
                                       max_partitions,
                                       int(math.ceil(partitions))))
    return partitions_list
  return _partitioner
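The following plain-Python sketch repeats the arithmetic of min_max_variable_partitioner on a tuple shape, with the element size passed in directly instead of being read from a dtype (min_max_partitions is a hypothetical name). It reproduces the [16, 1] result from the docstring as well as the MinSizePartitioner examples.

```python
import math


def min_max_partitions(shape, bytes_per_element, max_partitions=1, axis=0,
                       min_slice_size=256 << 10):
    """Plain-Python version of the arithmetic inside min_max_variable_partitioner."""
    if axis >= len(shape):
        raise ValueError("Can not partition variable along axis %d when shape is "
                         "only %s" % (axis, shape))
    total_size_bytes = math.prod(shape) * bytes_per_element
    partitions = total_size_bytes / min_slice_size
    partitions_list = [1] * len(shape)
    # Bounded by the axis length, by max_partitions, and by the size ratio.
    partitions_list[axis] = max(1, min(shape[axis], max_partitions,
                                       int(math.ceil(partitions))))
    return partitions_list
```

For a float32 [1024, 1024] variable, 4 MiB / 256 KiB = 16, so the result is [16, 1] as long as max_partitions allows it.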
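2.3.4 Maximum-size partitioning
This partitioner ensures each shard is at most max_shard_bytes and tries to allocate as few shards as possible, i.e. it keeps each shard as large as possible. If the partitioner hits the max_shards limit, each shard may end up larger than max_shard_bytes. By default max_shards is None, meaning the number of shards is not limited.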
@tf_export('distribute.experimental.partitioners.MaxSizePartitioner', v1=[])
class MaxSizePartitioner(Partitioner):
  """Partitioner that keeps shards below max_shard_bytes.

  This partitioner ensures each shard has at most max_shard_bytes, and tries
  to allocate as few shards as possible, i.e., keeping shard size as large
  as possible.

  If the partitioner hits the max_shards limit, then each shard may end up
  larger than max_shard_bytes. By default max_shards equals None and no
  limit on the number of shards is enforced.

  Examples:

  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [6, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [2, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=1024)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [1, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, max_shard_bytes, max_shards=None, bytes_per_string=16):
    """Creates a new MaxSizePartitioner.

    Args:
      max_shard_bytes: The maximum size any given shard is allowed to be.
      max_shards: The maximum number of shards in int created taking
        precedence over max_shard_bytes.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.
    """
    if max_shard_bytes < 1:
      raise ValueError('max_shard_bytes must be positive, got: %r' %
                       max_shard_bytes)
    if max_shards and max_shards < 1:
      raise ValueError('max_shards must be positive, got: %r' % max_shards)
    if bytes_per_string < 1:
      raise ValueError('bytes_per_string must be positive, got: %r' %
                       bytes_per_string)

    self._max_shard_bytes = max_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    return partitioned_variables.variable_axis_size_partitioner(
        max_shard_bytes=self._max_shard_bytes,
        max_shards=self._max_shards,
        bytes_per_string_element=self._bytes_per_string,
        axis=axis)(shape, dtype)
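variable_axis_size_partitioner does the actual work. It shards a variable along one axis, trying to keep the largest shard below max_shard_bytes. If the partitioner hits the max_shards limit, each shard may end up larger than max_shard_bytes. By default max_shards is None, meaning the number of shards is not limited.
A reasonable value for max_shard_bytes is (64 << 20) - 1, i.e. just under 64MB, which keeps each shard below the protobuf byte limit.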
@tf_export(v1=["variable_axis_size_partitioner"])
def variable_axis_size_partitioner(
    max_shard_bytes, axis=0, bytes_per_string_element=16, max_shards=None):
  """Get a partitioner for VariableScope to keep shards below max_shard_bytes.

  This partitioner will shard a Variable along one axis, attempting to keep
  the maximum shard size below max_shard_bytes. In practice, this is not
  always possible when sharding along only one axis. When this happens,
  this axis is sharded as much as possible (i.e., every dimension becomes
  a separate shard).

  If the partitioner hits the max_shards limit, then each shard may end up
  larger than max_shard_bytes. By default max_shards equals None and no
  limit on the number of shards is enforced.

  One reasonable value for max_shard_bytes is (64 << 20) - 1, or almost
  64MB, to keep below the protobuf byte limit.

  Args:
    max_shard_bytes: The maximum size any given shard is allowed to be.
    axis: The axis to partition along. Default: outermost axis.
    bytes_per_string_element: If the Variable is of type string, this provides
      an estimate of how large each scalar in the Variable is.
    max_shards: The maximum number of shards in int created taking precedence
      over max_shard_bytes.

  Returns:
    A partition function usable as the partitioner argument to
    variable_scope and get_variable.

  Raises:
    ValueError: If any of the byte counts are non-positive.
  """
  def _partitioner(shape, dtype):
    """Partitioner that partitions shards to have max_shard_bytes total size.

    Args:
      shape: A TensorShape.
      dtype: A DType.

    Returns:
      A tuple representing how much to slice each axis in shape.

    Raises:
      ValueError: If shape is not a fully defined TensorShape or dtype is not
        a DType.
    """
    if dtype.base_dtype == dtypes.string:
      element_size = bytes_per_string_element
    else:
      element_size = dtype.size

    partitions = [1] * shape.ndims
    bytes_per_slice = 1.0 * (
        shape.num_elements() / shape.dims[axis].value) * element_size
    # How many slices can we fit on one shard of size at most max_shard_bytes?
    # At least one slice is required.
    slices_per_shard = max(1, math.floor(max_shard_bytes / bytes_per_slice))
    # How many shards do we need for axis given that each shard fits
    # slices_per_shard slices from a total of shape[axis] slices?
    axis_shards = int(math.ceil(
        1.0 * shape.dims[axis].value / slices_per_shard))
    if max_shards:
      axis_shards = min(max_shards, axis_shards)

    partitions[axis] = axis_shards

    return partitions
  return _partitioner
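Likewise, the slice-counting logic of variable_axis_size_partitioner can be reproduced in plain Python (axis_size_partitions is a hypothetical name; the element size is passed in directly). Besides the MaxSizePartitioner docstring examples, the last check below uses the suggested (64 << 20) - 1 byte limit on a hypothetical 1,000,000 x 64 float32 variable: each row is 256 bytes, one shard can hold 262,143 rows, so four shards are needed.

```python
import math


def axis_size_partitions(shape, element_size, max_shard_bytes, axis=0,
                         max_shards=None):
    """Plain-Python version of the arithmetic inside variable_axis_size_partitioner."""
    partitions = [1] * len(shape)
    # Bytes in one slice along `axis` (one row when axis == 0).
    bytes_per_slice = 1.0 * (math.prod(shape) / shape[axis]) * element_size
    # At least one slice per shard, even if a single slice exceeds the limit.
    slices_per_shard = max(1, math.floor(max_shard_bytes / bytes_per_slice))
    axis_shards = int(math.ceil(1.0 * shape[axis] / slices_per_shard))
    if max_shards:
        axis_shards = min(max_shards, axis_shards)
    partitions[axis] = axis_shards
    return partitions
```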
2.4 ShardedVariableMixin
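As mentioned earlier, ShardedVariableMixin is where the core logic lives, so let's analyze it next. Its main member variables are:
- _variables: the partitioned variables (the shards).
- _var_offsets: the offset of each shard within the ShardedVariableMixin. Treating _variables as one whole, the offsets are used to locate the corresponding data inside it.
- _shape: the shape of the ShardedVariableMixin.
- _name: the name of the ShardedVariableMixin.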
class ShardedVariableMixin(trackable.Trackable):
  """Mixin for ShardedVariable."""

  def __init__(self,
               variables: Sequence[variables_lib.Variable],
               name='ShardedVariable'):
    """Treats variables as shards of a larger Variable.

    Args:
      variables: A list of ResourceVariables that comprise this sharded
        variable. Variables should not be shared between different
        ShardedVariableMixin objects.
      name: String. Name of this container. Defaults to "ShardedVariable".
    """
    super(ShardedVariableMixin, self).__init__()
    self._variables = variables
    self._name = name

    # (The checks that consume var_dtypes, higher_dim_shapes and
    # save_slice_info are not shown in this excerpt.)
    var_dtypes = {v.dtype for v in variables}
    first_var = variables[0]
    self._dtype = first_var.dtype

    # All variables must have the same shape for axes > 0.
    # Compute the overall shape
    higher_dim_shapes = {tuple(v.shape.as_list()[1:]) for v in variables}
    first_dim = sum(int(v.shape.as_list()[0]) for v in variables)
    self._shape = tensor_shape.TensorShape([first_dim] +
                                           first_var.shape.as_list()[1:])

    # Compute the offset of each shard within the whole variable
    self._var_offsets = [
        [0 for _ in range(len(first_var.shape))] for _ in range(len(variables))
    ]
    for i in range(1, len(variables)):
      # Always partition on the first axis. Offsets on other axes are 0.
      self._var_offsets[i][0] += (
          self._var_offsets[i - 1][0] + variables[i - 1].shape.as_list()[0])

    save_slice_info = [v._get_save_slice_info() for v in variables]

    # We create an uninitialized saving_variable with the full shape, which can
    # be later captured in signatures so that the signatures can treat this
    # ShardedVariable as one single variable.
    self._saving_variable = resource_variable_ops.UninitializedVariable(
        shape=self._shape, dtype=self._dtype, name=self._name)
2.4.1 Usage
Let's look at how it is used with the following example.
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute.sharded_variable import ShardedVariableMixin

variables = [
    tf.Variable(np.array([[3, 2]]), shape=(1, 2), dtype=tf.float32),
    tf.Variable(np.array([[3, 2], [0, 1]]), shape=(2, 2), dtype=tf.float32),
    tf.Variable(np.array([[3, 2]]), shape=(1, 2), dtype=tf.float32)
]
sharded_variable = ShardedVariableMixin(variables)
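Printing the internal members of sharded_variable gives the following. As you can see, _var_offsets treats all the parameter shards as one whole and records where each shard starts inside it.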
_shape = {TensorShape: 2} (4, 2)
_var_offsets = {list: 3} [[0, 0], [1, 0], [3, 0]]
first_dim = {int} 4
For example, the three variables above, packed together, form the following 4 x 2 matrix, and the offsets are used to look up data in it:
[[3, 2], [3, 2], [0, 1], [3, 2]]
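The offset bookkeeping can be reproduced without TensorFlow. The sketch below is a hypothetical helper, compute_shape_and_offsets, that works on plain shape tuples. It repeats the arithmetic of ShardedVariableMixin.__init__ and adds an explicit check for the "same shape for axes > 0" rule, since the excerpt above computes higher_dim_shapes without showing the check.

```python
def compute_shape_and_offsets(shard_shapes):
    """Recompute _shape and _var_offsets from the shards' shapes (tuples)."""
    if len({tuple(s[1:]) for s in shard_shapes}) > 1:
        raise ValueError("All shards must have the same shape for axes > 0")
    first_dim = sum(s[0] for s in shard_shapes)
    shape = (first_dim,) + tuple(shard_shapes[0][1:])
    offsets = [[0] * len(shard_shapes[0]) for _ in shard_shapes]
    for i in range(1, len(shard_shapes)):
        # Only the first axis is partitioned; offsets on other axes stay 0.
        offsets[i][0] = offsets[i - 1][0] + shard_shapes[i - 1][0]
    return shape, offsets


shape, offsets = compute_shape_and_offsets([(1, 2), (2, 2), (1, 2)])
```

This matches the printed _shape (4, 2) and _var_offsets [[0, 0], [1, 0], [3, 0]].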
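Let's look at another illustration. Suppose the parameters have 4 shards:
Figure 20: Partitions
If the variables are all placed on parameter servers, it looks like this:
Figure 21: Partitions and parameter servers
2.4.2 Getting a partition
Next, let's see how a partition is retrieved, i.e. how a specified region of the sharded variable is extracted as a single tensor. The logic is: inspect the incoming slice_spec and, depending on its type (boolean mask, slice, Ellipsis, newaxis or integer index), extract the corresponding region from the shards.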
def __getitem__(self, slice_spec):
  """Extracts the specified region as a Tensor from the sharded variable.

  The API contract is identical to Tensor.__getitem__. Assignment to the
  sliced range is not yet supported.

  Args:
    slice_spec: The arguments to __getitem__, specifying the global slicing of
      the sharded variable.

  Returns:
    The appropriate slice of tensor based on slice_spec.

  Raises:
    IndexError: If a slice index is out of bound.
    TypeError: If slice_spec contains Tensor.
  """

  # Inspect the partition spec; boolean masks apply to the whole tensor
  if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                       slice_spec.dtype == dtypes.bool) or
      (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool)):
    tensor = _var_to_tensor(self)
    return array_ops.boolean_mask(tensor=tensor, mask=slice_spec)

  if not isinstance(slice_spec, (list, tuple)):
    slice_spec = (slice_spec,)

  s = slice_spec[0]
  if isinstance(s, slice):
    # For a slice, decompose it into per-shard slices
    first_dim_slice_specs = self._decompose_slice_spec(s)
    values = []
    for i, var in enumerate(self._variables):
      if first_dim_slice_specs[i] is not None:
        all_dim_slice_spec = (first_dim_slice_specs[i],) + slice_spec[1:]
        values.append(var[all_dim_slice_spec])
    if s.step is not None and s.step < 0:
      values.reverse()
    if not values:
      return constant_op.constant([],
                                  dtype=self._dtype,
                                  shape=((0,) + self._shape[1:]))
    return array_ops.concat(values, axis=0)
  elif s is Ellipsis:
    return array_ops.concat([var[slice_spec] for var in self._variables],
                            axis=0)
  elif s is array_ops.newaxis:
    return array_ops.concat([var[slice_spec[1:]] for var in self._variables],
                            axis=0)[array_ops.newaxis]
  else:
    if isinstance(s, ops.Tensor):
      raise TypeError(
          'ShardedVariable: using Tensor for indexing is not allowed.')
    if s < 0:
      s += self._shape[0]
    # Walk the shards and use the offsets to extract the row.
    # `>=` (not `>`): a row at a shard's start offset belongs to that shard.
    for i in range(len(self._variables)):
      if i == len(self._variables) - 1 or (s >= self._var_offsets[i][0] and
                                           s < self._var_offsets[i + 1][0]):
        return self._variables[i][(s - self._var_offsets[i][0],) +
                                  slice_spec[1:]]
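The integer-index branch scans the shards linearly. As an illustration only, the same lookup can be written as a binary search over the first-axis offsets. This bisect-based helper, locate_row, is a hypothetical alternative rather than what TensorFlow does, and it also performs the bounds check that the docstring's IndexError refers to.

```python
import bisect


def locate_row(offsets, total_rows, index):
    """Map a global row index to (shard_id, local_row) via binary search."""
    if index < 0:
        index += total_rows
    if not 0 <= index < total_rows:
        raise IndexError("index %d out of bounds for %d rows" % (index, total_rows))
    # The owning shard is the last one whose start offset is <= index.
    shard = bisect.bisect_right(offsets, index) - 1
    return shard, index - offsets[shard]
```

With the first-axis offsets [0, 1, 3] from the usage example (4 rows in total), row 2 lives in shard 1 at local row 1, and row -1 maps to shard 2, local row 0.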
What does a spec usually look like? The following example explains it clearly.
For example, given component variables:
v0 = [0, 1, 2]
v1 = [3, 4, 5]
v2 = [6, 7, 8, 9]
If slice_spec is slice(start=None, stop=None, step=None), we will have:
v0[returned[0]] = [0, 1, 2]
v1[returned[1]] = [3, 4, 5]
v2[returned[2]] = [6, 7, 8, 9]
If slice_spec is slice(start=2, stop=8, step=3), we will have:
v0[returned[0]] = [2]
v1[returned[1]] = [5]
returned[2] == None
If slice_spec is slice(start=9, stop=3, step=-2), we will have:
returned[0] == None
v1[returned[1]] = [5]
v2[returned[2]] = [9, 7]
The code that parses the spec is as follows:
def _decompose_slice_spec(self, slice_spec):
  """Decompose a global slice_spec into a list of per-variable slice_spec.

  ShardedVariable only supports first dimension partitioning, thus
  slice_spec must be for first dimension.

  Args:
    slice_spec: A python slice object that specifies the global slicing.

  Returns:
    A list of python slice objects or None specifying the local slicing for
    each component variable. None means no slicing.
  """
  result = []
  # Normalize start, end and stop.
  slice_step = slice_spec.step if slice_spec.step is not None else 1
  if slice_step == 0:
    raise ValueError('slice step cannot be zero')
  slice_start = slice_spec.start
  if slice_start is None:
    slice_start = 0 if slice_step > 0 else self._shape[0] - 1
  elif slice_start < 0:
    slice_start += self._shape[0]
  slice_end = slice_spec.stop
  if slice_end is None:
    # After the normalization, we no longer interpret negative index, thus
    # "-1" conceptually refers to the element before the first one, which
    # doesn't exist. This is to ease the decomposition code.
    slice_end = self._shape[0] if slice_step > 0 else -1
  elif slice_end < 0:
    slice_end += self._shape[0]

  # To find the local slice_spec of each component variable, we start from
  # the start of the global slice, and iterate through each variable.
  # When iterating on a variable, we move the cursor (cur) to the first
  # index that falls into the variable's range, which becomes the start of
  # the variable's local slice_spec. The end of the local_spec is determined
  # by using whatever is smaller between global slice end and variable range
  # end.
  cur = slice_start
  if slice_step > 0:
    for i in range(len(self._var_offsets)):
      var_start = self._var_offsets[i][0]
      var_end = (
          self._var_offsets[i + 1][0]
          if i < len(self._var_offsets) - 1 else self._shape[0])
      if cur < var_start:
        cur += slice_step * int(math.ceil((var_start - cur) / slice_step))
      if cur >= var_end or cur >= slice_end:
        result.append(None)
      else:
        start = cur - var_start
        end = min(slice_end, var_end) - var_start
        result.append(slice(start, end, slice_step))
  else:  # slice_step < 0
    for i in range(len(self._var_offsets) - 1, -1, -1):
      var_start = self._var_offsets[i][0]
      var_end = (
          self._var_offsets[i + 1][0]
          if i < len(self._var_offsets) - 1 else self._shape[0])
      if cur >= var_end:
        cur += slice_step * int(math.ceil((var_end - cur - 1) / slice_step))
      if cur < var_start or cur <= slice_end:
        result.append(None)
      else:
        start = cur - var_start
        if slice_end >= var_start:
          end = slice_end - var_start
        else:
          end = None  # no explicit end: slice until hitting the boundary.
        result.append(slice(start, end, slice_step))

    result.reverse()

  return result
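To check the decomposition against the worked example (v0 = [0, 1, 2], v1 = [3, 4, 5], v2 = [6, 7, 8, 9]), here is a plain-Python port of the same algorithm. decompose_slice and gather are hypothetical names. decompose_slice takes the first-axis offsets and the total length instead of reading self._var_offsets and self._shape, and gather applies the local slices and reverses them for a negative step, as __getitem__ does.

```python
import math


def decompose_slice(offsets, total, spec):
    """Split a global first-axis slice into per-shard local slices (or None)."""
    step = spec.step if spec.step is not None else 1
    if step == 0:
        raise ValueError("slice step cannot be zero")
    start = spec.start
    if start is None:
        start = 0 if step > 0 else total - 1
    elif start < 0:
        start += total
    end = spec.stop
    if end is None:
        end = total if step > 0 else -1  # -1 means "before the first element"
    elif end < 0:
        end += total

    result, cur, n = [], start, len(offsets)
    if step > 0:
        for i in range(n):
            var_start = offsets[i]
            var_end = offsets[i + 1] if i < n - 1 else total
            if cur < var_start:  # jump the cursor to the first hit in this shard
                cur += step * int(math.ceil((var_start - cur) / step))
            if cur >= var_end or cur >= end:
                result.append(None)
            else:
                result.append(slice(cur - var_start, min(end, var_end) - var_start, step))
    else:
        for i in range(n - 1, -1, -1):
            var_start = offsets[i]
            var_end = offsets[i + 1] if i < n - 1 else total
            if cur >= var_end:  # jump the cursor down into this shard
                cur += step * int(math.ceil((var_end - cur - 1) / step))
            if cur < var_start or cur <= end:
                result.append(None)
            else:
                local_end = end - var_start if end >= var_start else None
                result.append(slice(cur - var_start, local_end, step))
        result.reverse()
    return result


def gather(shards, offsets, spec):
    """Apply the local slices and concatenate the pieces in global order."""
    parts = decompose_slice(offsets, sum(len(s) for s in shards), spec)
    values = [s[p] for s, p in zip(shards, parts) if p is not None]
    if spec.step is not None and spec.step < 0:
        values.reverse()
    return [x for v in values for x in v]


shards = [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
offsets = [0, 3, 6]
```

For slice(2, 8, 3) this yields [slice(2, 3, 3), slice(2, 3, 3), None], i.e. v0 contributes [2] and v1 contributes [5], matching the worked example; gathering reproduces ordinary Python slicing on the concatenated list.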
2.4.3 Embedding
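Next, let's look at embedding lookup. The override simply forwards params.variables (the list of shards) together with partition_strategy, name, validate_indices, max_norm and the other arguments to embedding_ops.embedding_lookup. The default partition strategy here is 'mod'.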
# Override the behavior of embedding_lookup(sharded_variable, ...)
@dispatch.dispatch_for_types(embedding_ops.embedding_lookup, ShardedVariable)
def embedding_lookup(params,
                     ids,
                     partition_strategy='mod',
                     name=None,
                     validate_indices=True,
                     max_norm=None):
  if isinstance(params, list):
    params = params[0]
  return embedding_ops.embedding_lookup(params.variables, ids,
                                        partition_strategy, name,
                                        validate_indices, max_norm)
The flow then reaches embedding_lookup in tensorflow/python/ops/embedding_ops.py, which (for non-ragged ids) calls _embedding_lookup_and_transform, so that is what we look at next.
@tf_export(v1=["nn.embedding_lookup"])
@dispatch.add_dispatch_support
def embedding_lookup(
    params,
    ids,
    partition_strategy="mod",
    name=None,
    validate_indices=True,  # pylint: disable=unused-argument
    max_norm=None):
  """Looks up embeddings for the given ids from a list of tensors.

  This function is used to perform parallel lookups on the list of tensors in
  params. It is a generalization of tf.gather, where params is
  interpreted as a partitioning of a large embedding tensor. params may be
  a PartitionedVariable as returned by using tf.compat.v1.get_variable()
  with a partitioner.

  If len(params) > 1, each element id of ids is partitioned between
  the elements of params according to the partition_strategy.
  In all strategies, if the id space does not evenly divide the number of
  partitions, each of the first (max_id + 1) % len(params) partitions will
  be assigned one more id.

  If the input ids are ragged tensors, partition variables are not supported and
  the partition strategy and the max_norm are ignored.
  The results of the lookup are concatenated into a dense
  tensor. The returned tensor has shape shape(ids) + shape(params)[1:].

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors. Alternatively, a
      PartitionedVariable, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given partition_strategy.
    ids: A Tensor or a 'RaggedTensor' with type int32 or int64 containing
      the ids to be looked up in params.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if len(params) > 1. Currently "div" and "mod" are supported. Default
      is "mod".
    name: A name for the operation (optional).
    validate_indices: DEPRECATED. If this operation is assigned to CPU, values
      in indices are always validated to be within range. If assigned to GPU,
      out-of-bound indices result in safe but unspecified behavior, which may
      include raising an error.
    max_norm: If not None, each embedding is clipped if its l2-norm is larger
      than this value.

  Returns:
    A Tensor or a 'RaggedTensor', depending on the input, with the same type
    as the tensors in params.

  Raises:
    ValueError: If params is empty.
  """
  if isinstance(ids, ragged_tensor.RaggedTensor):
    return embedding_lookup_ragged(params, ids,
                                   partition_strategy=partition_strategy,
                                   max_norm=max_norm,
                                   name=name)

  return _embedding_lookup_and_transform(
      params=params,
      ids=ids,
      partition_strategy=partition_strategy,
      name=name,
      max_norm=max_norm,
      transform_fn=None)
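_embedding_lookup_and_transform contains the actual partitioning logic. Let's first illustrate the two strategies with an example:
- If partition_strategy is "mod", each id is assigned to partition p = id % len(params). For example, 13 ids split across 5 partitions give: [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]].
- If partition_strategy is "div", ids are assigned to partitions contiguously. In the same example, 13 ids split across 5 partitions give: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]].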
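The two layouts above can be reproduced with a few lines of plain Python. This is only a sketch for checking the examples; mod_partition and div_partition are hypothetical helper names, not TensorFlow functions.

```python
def mod_partition(num_ids, num_partitions):
    """Hypothetical helper: "mod" strategy, id goes to partition id % num_partitions."""
    parts = [[] for _ in range(num_partitions)]
    for i in range(num_ids):
        parts[i % num_partitions].append(i)
    return parts


def div_partition(num_ids, num_partitions):
    """Hypothetical helper: "div" strategy, ids are assigned contiguously.

    The first num_ids % num_partitions partitions each get one extra id.
    """
    base, extra = divmod(num_ids, num_partitions)
    parts, start = [], 0
    for p in range(num_partitions):
        size = base + (1 if p < extra else 0)
        parts.append(list(range(start, start + size)))
        start += size
    return parts


print(mod_partition(13, 5))  # [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]
print(div_partition(13, 5))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]
```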
The code is as follows:
def _embedding_lookup_and_transform(params,
                                    ids,
                                    partition_strategy="mod",
                                    name=None,
                                    max_norm=None,
                                    transform_fn=None):
  """Helper function for embedding_lookup and _compute_sampled_logits.

  This function is a generalization of embedding_lookup that optionally
  applies a caller-specified transformation to each embedding. This is
  done through the `transform_fn` argument. If provided, the function is
  applied to each partitioned tensor of retrieved embeddings, colocated
  with the embeddings. This function will be called with a single `Tensor`
  argument of the same type as the `params` tensor and should return a
  `Tensor`. The shape of the argument will be the same as `params` except
  for the size of the first dimension. The first dimension of the result's
  shape must be the same size as the argument's.

  Args:
    params: See embedding_lookup.
    ids: See embedding_lookup.
    partition_strategy: See embedding_lookup.
    name: See embedding_lookup.
    max_norm: See embedding_lookup.
    transform_fn: An optional function to apply to each retrieved embedding. If
      max_norm is provided, transform_fn is applied to the norm-limited
      embeddings.

  Returns:
    See embedding_lookup for details.
  Raises:
    ValueError: If `params` is empty.
  """
  with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
    # ... code omitted (np, used below, is the number of partitions, len(params)) ...
    else:
      # Flatten the ids. There are two cases where we need to do this.
      # - There is more than one params tensor.
      # - There is a transform_fn and ids is not statically known to be 1-D.
      #   We must flatten in this case because transform_fn expects a flat
      #   tensor of embeddings.
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = tensor_shape.Dimension(
            tensor_shape.dimension_value(params[0].get_shape()[0]))
        for p in xrange(1, np):
          dim_0_size += tensor_shape.Dimension(
              tensor_shape.dimension_value(params[p].get_shape()[0]))
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            param_p_dim = tensor_shape.dimension_value(params[p].get_shape()[0])
            if param_p_dim is not None:
              dim_0_sizes.append(param_p_dim)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1),
                                         (flat_ids - extras) //
                                         ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        new_ids = array_ops.where(p_assignments < extras,
                                  flat_ids % (ids_per_partition + 1),
                                  (flat_ids - extras) % ids_per_partition)
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

      # ... remaining code omitted ...
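To see that the "div" formulas above really produce the contiguous layout, here is a NumPy sketch that copies the p_assignments / new_ids arithmetic and then gathers from each shard and stitches the rows back into id order. The gather-and-stitch part is a simplified stand-in, assumed for illustration, for the omitted remainder of the function; sharded_lookup is a hypothetical name, and the sketch assumes NumPy and at least as many rows as shards.

```python
import numpy as np


def sharded_lookup(shards, ids, partition_strategy="div"):
    """Hypothetical NumPy re-implementation of the partition arithmetic above.

    Routes each id to a shard, gathers locally with new_ids, and writes the
    rows back in the original id order.
    """
    num_shards = len(shards)
    flat_ids = np.asarray(ids).reshape(-1)
    if partition_strategy == "mod":
        p_assignments = flat_ids % num_shards
        new_ids = flat_ids // num_shards
    elif partition_strategy == "div":
        num_total_ids = sum(s.shape[0] for s in shards)
        ids_per_partition = num_total_ids // num_shards
        extras = num_total_ids % num_shards
        # Same closed-form expressions as the TensorFlow code above.
        p_assignments = np.maximum(flat_ids // (ids_per_partition + 1),
                                   (flat_ids - extras) // ids_per_partition)
        new_ids = np.where(p_assignments < extras,
                           flat_ids % (ids_per_partition + 1),
                           (flat_ids - extras) % ids_per_partition)
    else:
        raise ValueError("Unrecognized partition strategy: " + partition_strategy)

    emb_shape = shards[0].shape[1:]
    out = np.empty((flat_ids.size,) + emb_shape, dtype=shards[0].dtype)
    for p in range(num_shards):
        mask = p_assignments == p
        # Each shard only sees its local row indices (new_ids).
        out[mask] = shards[p][new_ids[mask]]
    return out.reshape(np.shape(ids) + emb_shape)


table = np.arange(26, dtype=np.float32).reshape(13, 2)
div_shards = np.split(table, [3, 6, 9, 11])  # contiguous rows, sizes 3, 3, 3, 2, 2
print(sharded_lookup(div_shards, np.array([7, 9, 12]), "div"))  # same rows as table[[7, 9, 12]]
```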
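How is this used? The following example, taken from a docstring in the TensorFlow source, builds a ShardedVariable with two shards; the model looks values up by calling tf.nn.embedding_lookup on the list of shards (sharded_variable.variables).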
>>> class Model(tf.Module):
...   def __init__(self):
...     self.sharded_variable = ShardedVariable([
...       tf.Variable([3.0], dtype=tf.float32),
...       tf.Variable([2.0], dtype=tf.float32)
...     ])
...
...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
...   def fn(self, x):
...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
...
...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
...   def serve_fn(self, x):
...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
>>>
>>> model = Model()
>>> model.fn(1).numpy()
2.0
>>> tf.saved_model.save(model, export_dir='/tmp/saved_model',
...   signatures=model.serve_fn)
In the figure below, the worker fetches embeddings from the two parameter servers in parallel.
Figure 22 Processing embeddings
2.5 Construction
For how a ShardedVariable is constructed, we go straight to the construction process in ParameterServerStrategyV2.
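2.5.1 Variable sharding
To enable variable sharding, pass a variable_partitioner when constructing the ParameterServerStrategy object. The variable_partitioner is invoked every time a variable is created and is expected to return the number of shards along each dimension of the variable. Some out-of-the-box partitioners are provided, such as tf.distribute.experimental.partitioners.MinSizePartitioner. A size-based partitioner such as MinSizePartitioner is recommended, so that small variables are not partitioned, because partitioning them can hurt model training speed.
When a variable_partitioner is passed in and you create a variable directly under strategy.scope(), the variable becomes a container type with a variables property that gives access to the list of shards. In most cases this container is automatically converted to a tensor by concatenating all the shards, so it can be used as a normal variable. On the other hand, some TensorFlow methods, such as tf.nn.embedding_lookup, provide an efficient implementation for this container type that avoids the automatic concatenation.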
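As an illustration of what a size-based partitioner computes, the sketch below builds a callable with the same call shape that _create_variable uses, partitioner(shape=..., dtype=...), returning one shard count per axis. It is a hypothetical toy, not the implementation of MinSizePartitioner; the rounding rule (total bytes // min_shard_bytes, capped by max_shards and by the first dimension) is an assumption made for clarity, and dtype is interpreted with NumPy instead of TensorFlow.

```python
import numpy as np


def min_size_partitioner(min_shard_bytes=256 << 10, max_shards=1):
    """Hypothetical size-based partitioner factory (a toy, not TF's code)."""

    def partitioner(shape, dtype):
        # Total size of the variable in bytes.
        total_bytes = np.dtype(dtype).itemsize
        for dim in shape:
            total_bytes *= dim
        # Each shard should hold at least min_shard_bytes, there should be at
        # most max_shards shards, and never more shards than rows on axis 0.
        num_shards = max(1, min(max_shards, total_bytes // min_shard_bytes,
                                shape[0]))
        # Only axis 0 is split; every other axis gets a single partition.
        return [num_shards] + [1] * (len(shape) - 1)

    return partitioner


p = min_size_partitioner(min_shard_bytes=1024, max_shards=4)
print(p(shape=[10, 4], dtype="float32"))    # [1, 1]: 160 bytes, too small to split
print(p(shape=[1000, 64], dtype="float32"))  # [4, 1]: capped by max_shards
```

Because a result of 1 on the first axis makes _create_variable fall back to a plain round-robin variable, small variables such as biases stay unsharded under this kind of rule.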
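2.5.2 Initialization
When ParameterServerStrategyV2Extended is initialized, it stores the incoming variable_partitioner in _variable_partitioner and reads the number of parameter servers (_num_ps) and workers (_num_workers) from the cluster spec. _variable_count, the counter later used for round-robin placement, starts at 0.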
class ParameterServerStrategyV2Extended(
    parameter_server_strategy.ParameterServerStrategyExtended):
  """Extended class for ParameterServerStrategyV2.

  Please see `tf.distribute.StrategyExtended` doc for more information.
  """

  def __init__(self, container_strategy, cluster_resolver,
               variable_partitioner):
    """Initialization of ParameterServerStrategyV2Extended."""
    super(ParameterServerStrategyV2Extended, self).__init__(container_strategy)
    self._num_ps = len(cluster_resolver.cluster_spec().as_dict().get("ps", []))
    self._num_workers = len(cluster_resolver.cluster_spec().as_dict().get(
        "worker", []))
    self._variable_count = 0

    self._variable_partitioner = variable_partitioner
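2.5.3 Creating the variable
Next we look at the creation process, that is, how a variable is split into shards placed on different parameter servers. The logic is:
- Variables with a colocate_with argument are never partitioned; they are created colocated with the given variable.
- If no partitioner is configured, the whole variable is placed on a parameter server with a round-robin (RR) policy via _create_variable_round_robin.
- If a partitioner is configured:
  - Rank-0 (scalar) variables are not partitioned.
  - The number of partitions is obtained from _variable_partitioner; if it is 1 along the first axis, the variable is not partitioned.
  - The number of partitions is capped at the size of the first dimension, so that no shard is empty.
  - The row offset of each shard is computed ("div" layout: the first shape[0] % num_partitions shards get one extra row).
  - One smaller variable (shard) is created per partition, each initialized from its slice of the initial value.
  - Each shard is created through _create_variable_round_robin, which places it on a parameter server.
  - The list of shards is used to build the ShardedVariable.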
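The offset arithmetic in these steps can be checked without TensorFlow. The following sketch (shard_plan is a hypothetical name) mirrors the capping and offset loop of _create_variable and returns the offsets and shard shapes; it reproduces the num_partitions=4, shape[0]=10 example given in the source comments of _create_variable.

```python
def shard_plan(shape, requested_partitions):
    """Mirror the offset arithmetic of _create_variable (a sketch, not TF code).

    Returns (offsets, shard_shapes) for a "div"-style split along axis 0, or
    None when the variable would not be partitioned.
    """
    shape = tuple(shape)
    if len(shape) == 0 or requested_partitions == 1:
        return None  # rank-0 or a single partition: no ShardedVariable
    # Never create more shards than there are rows on axis 0.
    num_partitions = min(requested_partitions, shape[0])
    base, extra = divmod(shape[0], num_partitions)
    offsets = [0]
    for i in range(1, num_partitions):
        # The first `extra` shards are one row larger.
        prev_shard_size = base + (1 if i - 1 < extra else 0)
        offsets.append(offsets[-1] + prev_shard_size)
    offsets.append(shape[0])
    shard_shapes = [(offsets[i + 1] - offsets[i],) + shape[1:]
                    for i in range(num_partitions)]
    return offsets, shard_shapes


print(shard_plan((10,), 4))      # ([0, 3, 6, 8, 10], [(3,), (3,), (2,), (2,)])
print(shard_plan((100, 10), 2))  # ([0, 50, 100], [(50, 10), (50, 10)])
```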
def _create_variable(self, next_creator, **kwargs):
  """Implements StrategyExtendedV2._create_variable.

  Creates a `Variable` or a `ShardedVariable`. A `ShardedVariable` will be
  created if satisfying all the following criteria:
    1. `self._variable_partitioner` results in more than one partition on the
       first axis.
    2. variable's rank is greater than 0.
    3. variable is not colocated with another variable.
  Otherwise a `Variable` will be created.

  Args:
    next_creator: See `variable_scope.variable_creator_scope`; the next
      creator in the chain.
    **kwargs: Passed through to the next creator.

  Returns:
    A `Variable` or `ShardedVariable`.
  """

  var_creator = self._create_var_creator(next_creator, **kwargs)
  if "colocate_with" in kwargs:  # Never partition colocated_with variables.
    colocate_with = kwargs["colocate_with"]
    # Clear the variable scope to avoid possible conflicts between device
    # scope and colocation scope.
    with ops.device(None):
      with ops.colocate_with(colocate_with):
        var = var_creator(**kwargs)
        return var

  # No partitioner configured: place the whole variable on a PS round-robin.
  if self._variable_partitioner is None:
    return self._create_variable_round_robin(var_creator, **kwargs)

  # From here on, a partitioner is configured.
  name = kwargs.get("name", None)
  initial_value = kwargs.get("initial_value", None)

  # Two cases where initial_value can be a callable:
  #   1. initial_value is passed as a callable, e.g, an `initializer` class.
  #   2. restoring from checkpoint, initial_value is a
  #     "CheckpointInitialValueCallable".
  init_from_fn = callable(initial_value)

  dtype = kwargs.get("dtype", None)
  shape = kwargs.get("shape", None)
  if init_from_fn and (shape is None or dtype is None):
    init_from_fn = False
    initial_value = initial_value()
  if not init_from_fn:
    # The initial_value is created on coordinator, it will need to be sent to
    # ps for variable initialization, which can be inefficient and can
    # potentially hit the 2GB limit on protobuf serialization.
    initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
    dtype = initial_value.dtype
    shape = initial_value.shape
  else:
    shape = tensor_shape.as_shape(shape)

  # Rank-0 (scalar) variables are never partitioned.
  if shape.rank == 0:  # Skip partitioning rank-0 variable.
    return self._create_variable_round_robin(var_creator, **kwargs)

  # Ask the partitioner how many shards to create along each axis.
  num_partitions = self._variable_partitioner(shape=shape, dtype=dtype)
  if num_partitions[0] == 1:  # no partition
    return self._create_variable_round_robin(var_creator, **kwargs)

  # Cap the number of partitions at the size of the first dimension.
  # Use "div" partition strategy to partition the variable.
  num_partitions = min(num_partitions[0], shape[0])
  base = shape[0] // num_partitions
  # Compute the starting row offset of each shard.
  extra = shape[0] % num_partitions
  # An example: num_partitions=4, shape[0]=10, partitions: [3, 3, 2, 2]
  # offsets: [0, 3, 6, 8, 10]
  offsets = []
  for i in range(num_partitions):
    if i == 0:
      offsets.append(0)
    else:
      prev_shard_size = base + (1 if i - 1 < extra else 0)
      offsets.append(offsets[i - 1] + prev_shard_size)
  offsets.append(shape[0])

  def init_shard_fn(shard_index):
    if not init_from_fn:
      return initial_value[offsets[shard_index]:offsets[shard_index + 1]]
    partition_shape = (offsets[shard_index + 1] -
                       offsets[shard_index],) + shape[1:]
    partition_offset = (offsets[shard_index],) + (0,) * len(shape[1:])
    arg_spec = tf_inspect.getfullargspec(initial_value)
    if ("shard_info" not in arg_spec.args and
        "shard_info" not in arg_spec.kwonlyargs):
      try:
        value = initial_value(
            partition_shape=partition_shape,
            partition_offset=partition_offset)
      except (TypeError, ValueError):
        # TypeError: Initializer doesn't accept kwargs
        # ValueError: Initializer doesn't accept partition kwargs
        # In both cases we go ahead creating the full value and then slice.
        value = initial_value()

      if value.shape == partition_shape:
        # Initializer supports partition: value is the partition value.
        return value
      else:
        # Initializer doesn't support partition: value is the full value
        # and needs to be sliced to get the partition value.
        return value[offsets[shard_index]:offsets[shard_index + 1]]
    else:
      # For compatibility with `CheckpointInitialValueCallable`.
      return initial_value(
          shard_info=trackable.ShardInfo(
              shape=tensor_shape.as_shape(partition_shape),
              offset=partition_offset))

  # Create one smaller variable (shard) per partition.
  var_list = []
  for i in range(num_partitions):
    kwargs["shape"] = (offsets[i + 1] - offsets[i],) + shape[1:]
    kwargs["initial_value"] = lambda: init_shard_fn(i)  # initializer for shard i
    if name is not None:
      kwargs["name"] = "{}/part_{}".format(name, i)
    # Each shard is placed on a PS by _create_variable_round_robin.
    var_list.append(self._create_variable_round_robin(var_creator, **kwargs))

  # Build the ShardedVariable from the list of shards.
  result = sharded_variable.ShardedVariable(var_list)
  return result
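init_shard_fn handles two kinds of callable initializers: ones that can produce a single partition directly and ones that only produce the full value, which is then sliced. The NumPy sketch below imitates that callable branch with two hypothetical initializers (partition_aware_init, full_only_init) and a simplified init_shard, and checks that either way the shards concatenate back to the full value. It leaves out the CheckpointInitialValueCallable branch.

```python
import numpy as np

offsets = [0, 3, 6, 8, 10]  # num_partitions=4, shape[0]=10, as in the source comment
full_value = np.arange(20, dtype=np.float32).reshape(10, 2)


def partition_aware_init(partition_shape=None, partition_offset=None):
    """Hypothetical initializer that can build just one partition."""
    if partition_shape is None:
        return full_value
    start = partition_offset[0]
    return full_value[start:start + partition_shape[0]]


def full_only_init():
    """Hypothetical initializer that only knows the full value."""
    return full_value


def init_shard(shard_index, initializer):
    """Simplified version of the callable branch of init_shard_fn."""
    shape = full_value.shape
    partition_shape = (offsets[shard_index + 1] - offsets[shard_index],) + shape[1:]
    partition_offset = (offsets[shard_index],) + (0,) * len(shape[1:])
    try:
        value = initializer(partition_shape=partition_shape,
                            partition_offset=partition_offset)
    except (TypeError, ValueError):
        # Initializer rejects the partition kwargs: build the full value.
        value = initializer()
    if value.shape == partition_shape:
        return value  # already the partition
    return value[offsets[shard_index]:offsets[shard_index + 1]]  # slice it


for init in (partition_aware_init, full_only_init):
    shards = [init_shard(i, init) for i in range(4)]
    print([s.shape for s in shards])  # [(3, 2), (3, 2), (2, 2), (2, 2)]
```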
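In the logic above, both branches (with and without a partitioner) create variables through _create_variable_round_robin, which decides placement with a round-robin policy. In effect it only attaches a device name to the variable, /job:ps/task:<i>/device:CPU:0 with i = _variable_count % _num_ps, and later placement is carried out according to that device name.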
def _create_variable_round_robin(self, next_creator, **kwargs):
  # Clear the colocation scope to avoid possible conflicts between device
  # scope and colocation scope.
  with ops.colocate_with(None, ignore_existing=True):
    # Explicitly set CPU:0 device for PS in case create variable is called
    # inside replica_fn and worker has with GPU:0 scope.
    with ops.device("/job:ps/task:%d/device:CPU:0" %
                    (self._variable_count % self._num_ps)):
      var = next_creator(**kwargs)
      logging.debug(
          "Creating variable (name:%s, shape:%r) on "
          "/job:ps/task:%d/device:CPU:0",
          var.name, var.shape, (self._variable_count % self._num_ps))
      self._variable_count += 1
      return var
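Because _variable_count is a single counter on the strategy, round-robin placement continues across variables rather than restarting for each one. The toy class below (RoundRobinPlacer is hypothetical) computes only the device strings that _create_variable_round_robin would use; it does not create any variables.

```python
class RoundRobinPlacer:
    """Toy stand-in for the device-string logic of _create_variable_round_robin."""

    def __init__(self, num_ps):
        self._num_ps = num_ps
        self._variable_count = 0  # shared by every variable and every shard

    def next_device(self):
        device = "/job:ps/task:%d/device:CPU:0" % (self._variable_count % self._num_ps)
        self._variable_count += 1
        return device


placer = RoundRobinPlacer(num_ps=2)
print(placer.next_device())                     # an unsharded variable -> task 0
print([placer.next_device() for _ in range(4)])  # 4 shards -> tasks 1, 0, 1, 0
```

So with two parameter servers, the shards of a variable do not necessarily start on task 0; it depends on how many variables were created before it.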
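The next_creator passed to _create_variable_round_robin is normally the var_creator returned by the method below (see the first line of _create_variable). When more than one replica is in sync, each newly created variable is wrapped first in a CachingVariable and then in an AggregatingVariable; with a single replica, only the CachingVariable wrapper is applied. These wrapped variables form var_list, from which the ShardedVariable is then built. We focus on AggregatingVariable.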
def _create_var_creator(self, next_creator, **kwargs):
  aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)

  def var_creator(**kwargs):
    """Create an AggregatingVariable."""
    # Create and wrap the variable.
    v = next_creator(**kwargs)
    wrapped_v = ps_values.CachingVariable(v)
    wrapped = ps_values.AggregatingVariable(self._container_strategy(),
                                            wrapped_v, aggregation)
    return wrapped

  if self._num_replicas_in_sync > 1:
    if aggregation not in (
        vs.VariableAggregation.NONE,
        vs.VariableAggregation.SUM,
        vs.VariableAggregation.MEAN,
        vs.VariableAggregation.ONLY_FIRST_REPLICA
    ):
      # Use format() rather than "+": aggregation is an enum, not a str.
      raise ValueError("Invalid variable aggregation mode: {} for variable: {}"
                       .format(aggregation, kwargs.get("name")))
    return var_creator
  else:
    def variable_creator_single_replica(**kwargs):
      v = next_creator(**kwargs)
      return ps_values.CachingVariable(v)
    return variable_creator_single_replica
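2.5.4 AggregatingVariable
AggregatingVariable wraps a variable so that updates made to it from multiple replicas are aggregated. Take _assign_func as an example: in cross-replica context it routes the assignment through _distribute_strategy.extended.update; in replica context it first uses merge_call to reduce the per-replica values according to the variable's aggregation mode, and then applies a single update through strategy.extended.update. An aggregation mode of NONE is rejected in replica context.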
# Variable used in PSStrategy TF 1, TF2 and CentralStorageStrategy.
class AggregatingVariable(resource_variable_ops.BaseResourceVariable,
                          core.Tensor):
  """A wrapper around a variable that aggregates updates across replicas."""

  def __init__(self, strategy, v, aggregation):
    self._distribute_strategy = strategy
    self._v = v
    # NOTE: We don't use "_distributed_container" here because we don't want
    # to trigger that code path in regroup().
    v._aggregating_container = weakref.ref(self)  # pylint: disable=protected-access
    self._aggregation = aggregation

  def __deepcopy__(self, memo):
    """Perform a deepcopy of the `AggregatingVariable`.

    Unlike the deepcopy of a regular tf.Variable, this keeps the original
    strategy and devices of the `AggregatingVariable`.  To avoid confusion
    with the behavior of deepcopy on a regular `Variable` (which does
    copy into new devices), we only allow a deepcopy of a `AggregatingVariable`
    within its originating strategy scope.

    Args:
      memo: The memoization object for `deepcopy`.

    Returns:
      A deep copy of the current `AggregatingVariable`.

    Raises:
      RuntimeError: If trying to deepcopy into a different strategy.
    """
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      v = copy.deepcopy(self._v, memo)

    copied_variable = type(self)(
        strategy=self._distribute_strategy,
        v=v,
        aggregation=self._aggregation)

    memo[id(self)] = copied_variable

    return copied_variable

  def get(self):
    return self._v

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  def __getattr__(self, name):
    return getattr(self._v, name)

  def _assign_func(self, *args, **kwargs):
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")
      if ds_context.in_cross_replica_context():
        if distribute_lib.get_update_replica_id() is not None:
          # We are calling an assign function in an update context.
          return f(self._v, *args, **kwargs)

        # We are calling an assign function in cross replica context, wrap it in
        # an update call.
        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
        replica_context = ds_context.get_replica_context()
        # We are calling an assign function in replica context.
        # We reduce the value we want to assign/add/sub. More details about how
        # we handle the different use cases can be found in the _reduce method.
        # We call the function with the reduced value.
        if self._aggregation == vs.VariableAggregation.NONE:
          raise ValueError(
              values_util.aggregation_error_msg.format(
                  variable_type="AggregatingVariable"))

        def merge_fn(strategy,
                     value,
                     use_locking=False,
                     name=None,
                     read_value=True):
          v = values_util.apply_aggregation(strategy, value, self._aggregation,
                                            self)
          if name and isinstance(name, values.PerReplica):
            name = name.values[0]
          return strategy.extended.update(
              self,
              f,
              args=(v,),
              kwargs={
                  "use_locking": use_locking,
                  "name": name,
                  "read_value": read_value
              })
        return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)
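The replica-context branch boils down to: reduce the per-replica values according to the aggregation mode, then apply one update. The pure-Python toy below mimics that flow for an assign_add-style update. Its apply_aggregation is a simplified stand-in for values_util.apply_aggregation, and ToyAggregatingVariable is a hypothetical class; neither is TensorFlow API.

```python
def apply_aggregation(per_replica_values, aggregation):
    """Toy reduction of per-replica values (not TF's values_util code)."""
    if aggregation == "NONE":
        raise ValueError("aggregation NONE cannot be used for updates in replica context")
    if aggregation == "SUM":
        return sum(per_replica_values)
    if aggregation == "MEAN":
        return sum(per_replica_values) / len(per_replica_values)
    if aggregation == "ONLY_FIRST_REPLICA":
        return per_replica_values[0]
    raise ValueError("Unknown aggregation: %s" % aggregation)


class ToyAggregatingVariable:
    """Hypothetical wrapper: reduce the replicas' deltas, then update once."""

    def __init__(self, value, aggregation):
        self.value = value
        self.aggregation = aggregation

    def assign_add_from_replicas(self, per_replica_deltas):
        # Plays the role of merge_fn: one reduced value, one update.
        delta = apply_aggregation(per_replica_deltas, self.aggregation)
        self.value += delta
        return self.value


print(ToyAggregatingVariable(10.0, "SUM").assign_add_from_replicas([1.0, 2.0, 3.0]))   # 16.0
print(ToyAggregatingVariable(10.0, "MEAN").assign_add_from_replicas([1.0, 2.0, 3.0]))  # 12.0
```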
2.6 Usage
The example below shows how a ShardedVariable is used. Dense creates a variable self.w of shape [100, 10]; with a FixedShardsPartitioner of 2 shards, it becomes a ShardedVariable made of two (50, 10) tensors.
import tensorflow as tf

class Dense(tf.Module):
  def __init__(self, name=None):
    super().__init__(name=name)
    self.w = tf.Variable(tf.random.normal([100, 10]), name='w')

  def __call__(self, x):
    return x * self.w

# Partition the dense layer into 2 shards.
variable_partitioner = (
    tf.distribute.experimental.partitioners.FixedShardsPartitioner(
        num_shards=2))
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...,
    variable_partitioner=variable_partitioner)
with strategy.scope():
  # Created inside the strategy scope, so w is split into 2 shards automatically.
  dense = Dense()
assert len(dense.variables) == 2
assert isinstance(dense.variables[0], tf.Variable)
assert isinstance(dense.variables[1], tf.Variable)
assert dense.variables[0].shape == (50, 10)
assert dense.variables[1].shape == (50, 10)
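ShardedVariable is also a form of model parallelism. For example, a matrix formed by stacking blocks A and B (the "AB" matrix) can be split across two parameter servers; each block is multiplied by C separately, and the partial results are then gathered on the worker and concatenated into the final result tensor.
Figure 23 Concatenating tensors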
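The identity behind Figure 23 is that multiplying a row-stacked matrix equals stacking the products of its blocks: [A; B]C = [AC; BC]. A small NumPy check, with arbitrary shapes chosen only for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 5))  # first block of rows, e.g. held by one PS
B = rng.normal(size=(2, 5))  # second block of rows, e.g. held by another PS
C = rng.normal(size=(5, 4))

# Each block is multiplied by C separately...
partial_results = [A @ C, B @ C]
# ...and the partial results are concatenated along axis 0.
result = np.concatenate(partial_results, axis=0)

full = np.concatenate([A, B], axis=0)  # the unsharded "AB" matrix
print(np.allclose(result, full @ C))   # True
```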