tf.distribute.ReplicaContext

A class with a collection of APIs that can be called in a replica context.

You can use tf.distribute.get_replica_context to get an instance of ReplicaContext. That function can only be called inside the function passed to tf.distribute.Strategy.run.

strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1'])
def func():
  replica_context = tf.distribute.get_replica_context()
  return replica_context.replica_id_in_sync_group
strategy.run(func)
PerReplica:{
  0: <tf.Tensor: shape=(), dtype=int32, numpy=0>,
  1: <tf.Tensor: shape=(), dtype=int32, numpy=1>
}

Args
strategy A tf.distribute.Strategy.
replica_id_in_sync_group An integer, a Tensor or None. Prefer an integer whenever possible to avoid issues with nested tf.function. It accepts a Tensor only to be compatible with tpu.replicate.

Attributes
devices Returns the devices this replica is to be executed on, as a tuple of strings. (deprecated)

num_replicas_in_sync Returns number of replicas that are kept in sync.
replica_id_in_sync_group Returns the id of the replica.

This identifies the replica among all replicas that are kept in sync. The value of the replica id can range from 0 to tf.distribute.ReplicaContext.num_replicas_in_sync - 1.

strategy The current tf.distribute.Strategy object.

Methods

all_gather

All-gathers value across all replicas along axis.

For all strategies except tf.distribute.TPUStrategy, the input value on different replicas must have the same rank, and their shapes must match in every dimension except the axis-th dimension. In other words, the shapes cannot differ in any dimension d where d is not equal to the axis argument. For example, given a tf.distribute.DistributedValues with component tensors of shape (1, 2, 3) and (1, 3, 3) on two replicas, you can call all_gather(..., axis=1, ...) on it, but not all_gather(..., axis=0, ...) or all_gather(..., axis=2, ...). With tf.distribute.TPUStrategy, however, all tensors must have exactly the same rank and the same shape.
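The shape rule above can be sketched as a small check in plain Python. This is only a model of the rule for strategies other than tf.distribute.TPUStrategy, not TensorFlow's implementation; the helper name can_all_gather is hypothetical, and the sketch assumes a non-negative axis.

```python
from typing import Sequence, Tuple


def can_all_gather(shapes: Sequence[Tuple[int, ...]], axis: int) -> bool:
    """Return True if the per-replica shapes differ only along `axis`.

    Hypothetical helper: models the documented rule, nothing more.
    """
    first = shapes[0]
    # `axis` must name a valid dimension (non-negative axis assumed).
    if not 0 <= axis < len(first):
        return False
    for shape in shapes[1:]:
        # All replicas must agree on rank.
        if len(shape) != len(first):
            return False
        # Every dimension other than `axis` must match exactly.
        for d, (a, b) in enumerate(zip(first, shape)):
            if d != axis and a != b:
                return False
    return True


shapes = [(1, 2, 3), (1, 3, 3)]
print(can_all_gather(shapes, axis=1))  # True
print(can_all_gather(shapes, axis=0))  # False
print(can_all_gather(shapes, axis=2))  # False
```

The three calls mirror the (1, 2, 3) and (1, 3, 3) example: only axis=1 is allowed because that is the only dimension in which the two shapes differ.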

You can pass in a single tensor to all-gather:

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
@tf.function
def gather_value():
  ctx = tf.distribute.get_replica_context()
  local_value = tf.constant([1, 2, 3])
  return ctx.all_gather(local_value, axis=0)
result = strategy.run(gather_value)
result
PerReplica:{
  0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
  1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
}
strategy.experimental_local_results(result)
(<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
dtype=int32)>,
<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
dtype=int32)>)

You can also pass in a nested structure of tensors to all-gather, say, a list:

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
@tf.function
def gather_nest():
  ctx = tf.distribute.get_replica_context()
  value_1 = tf.constant([1, 2, 3])
  value_2 = tf.constant([[1, 2], [3, 4]])
  # all_gather a nested structure (here, a list) of tensors
  return ctx.all_gather([value_1, value_2], axis=0)
result = strategy.run(gather_nest)
result
[PerReplica:{
  0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
  1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
}, PerReplica:{
  0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
array([[1, 2],
       [3, 4],
       [1, 2],
       [3, 4]], dtype=int32)>,
  1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
array([[1, 2],
       [3, 4],
       [1, 2],
       [3, 4]], dtype=int32)>
}]
strategy.experimental_local_results(result)
([<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
<tf.Tensor: shape=(4, 2), dtype=int32, numpy=
array([[1, 2],
       [3, 4],
       [1, 2],
       [3, 4]], dtype=int32)>],
       [<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
       <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
array([[1, 2],
       [3, 4],
       [1, 2],
       [3, 4]], dtype=int32)>])

What if you are all-gathering tensors with different shapes on different replicas? Consider the following example with two replicas, where you have value as a nested structure consisting of two items to all-gather, a and b.

On Replica 0, value is {'a': [0], 'b': [[0, 1]]}.

On Replica 1, value is {'a': [1], 'b': [[2, 3], [4, 5]]}.

Result for all_gather with axis=0 (on each of the replicas) is:

{'a': [0, 1], 'b': [[0, 1], [2, 3], [4, 5]]}
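To make the gathering order concrete, here is a plain-Python model of this two-replica case. It treats nested lists as stand-ins for tensors and concatenates each entry along axis 0 in replica order; the function name simulate_all_gather_axis0 is hypothetical and the sketch does not call TensorFlow.

```python
def simulate_all_gather_axis0(per_replica_values):
    """Concatenate each dict entry along axis 0, in replica order.

    Hypothetical model: nested Python lists stand in for tensors.
    """
    gathered = {}
    for key in per_replica_values[0]:
        combined = []
        for value in per_replica_values:
            # Appending rows in replica order is concatenation along axis 0.
            combined.extend(value[key])
        gathered[key] = combined
    # Every replica receives the same gathered structure.
    return [dict(gathered) for _ in per_replica_values]


replica_0 = {'a': [0], 'b': [[0, 1]]}
replica_1 = {'a': [1], 'b': [[2, 3], [4, 5]]}

results = simulate_all_gather_axis0([replica_0, replica_1])
print(results[0])  # {'a': [0, 1], 'b': [[0, 1], [2, 3], [4, 5]]}
print(results[0] == results[1])  # True
```

Note that 'b' has one row on Replica 0 and two rows on Replica 1. That is acceptable here because the shapes differ only along the gathered axis.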

Args
value a nested structure of tf.Tensor which tf.nest.flatten accepts, or a tf.distribute.DistributedValues instance. The structure of the tf.Tensor needs to be the same on all replicas. The underlying tensor constructs can only be dense tensors with non-zero rank, NOT tf.IndexedSlices.
axis 0-D int32 Tensor. Dimension along which to gather.
options a tf.distribute.experimental.CommunicationOptions. Options to perform collective operations. This overrides the default options if the tf.distribute.Strategy takes one in the constructor. See tf.distribute.experimental.CommunicationOptions for details of the options.

Returns
A nested structure of tf.Tensor with the gathered values. The structure is the same as value.

all_reduce

All-reduces value across all replicas.

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
def step_fn():
  ctx = tf.distribute.get_replica_context()
  value = tf.identity(1.)
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value)
strategy.experimental_local_results(strategy.run(step_fn))
(<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
 <tf.Tensor: shape=(), dtype=float32, numpy=2.0>)

It supports batched operations. You can pass a list of values and it attempts to batch them when possible. You can also specify options to indicate the desired batching behavior, e.g. batch the values into multiple packs so that they can better overlap with computations.

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
def step_fn():
  ctx = tf.distribute.get_replica_context()
  value1 = tf.identity(1.)
  value2 = tf.identity(2.)
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value1, value2])
strategy.experimental_local_results(strategy.run(step_fn))
([<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
<tf.Tensor: shape=(), dtype=float32, numpy=4.0>],
[<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
<tf.Tensor: shape=(), dtype=float32, numpy=4.0>])

Note that all replicas need to participate in the all-reduce; otherwise this operation hangs. If there are multiple all-reduces, they need to execute in the same order on all replicas. Dispatching an all-reduce based on a condition is usually error-prone, because replicas that evaluate the condition differently will not issue matching operations.
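The batched all-reduce described above can be modeled in plain Python to show what each replica receives. This is a sketch under simplifying assumptions, not TensorFlow's implementation: Python floats stand in for tensors, the function name simulate_all_reduce is hypothetical, and every replica is assumed to contribute the same number of values in the same order.

```python
def simulate_all_reduce(reduce_op, per_replica_values):
    """Reduce the i-th value of every replica and return it to all replicas.

    Hypothetical model: `per_replica_values` holds one list per replica.
    """
    num_replicas = len(per_replica_values)
    # zip(*...) groups the i-th value from every replica together, which
    # is why all replicas must supply their values in the same order.
    totals = [sum(group) for group in zip(*per_replica_values)]
    if reduce_op == "SUM":
        reduced = totals
    elif reduce_op == "MEAN":
        reduced = [total / num_replicas for total in totals]
    else:
        raise ValueError(f"Unsupported reduce_op in this sketch: {reduce_op}")
    # Every replica receives the same reduced list.
    return [list(reduced) for _ in range(num_replicas)]


per_replica = [[1.0, 2.0], [1.0, 2.0]]
print(simulate_all_reduce("SUM", per_replica))   # [[2.0, 4.0], [2.0, 4.0]]
print(simulate_all_reduce("MEAN", per_replica))  # [[1.0, 2.0], [1.0, 2.0]]
```

The "SUM" result matches the two-replica example with value1 and value2: each replica ends up with 2.0 and 4.0.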

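This API can currently be called only in the replica context. Other ways to reduce values across replicas, all called from the cross-replica context, include tf.distribute.Strategy.reduce, tf.distribute.StrategyExtended.reduce_to, and tf.distribute.StrategyExtended.batch_reduce_to.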

Args
reduce_op a tf.distribute.ReduceOp value specifying how values should be combined. Allows using string representation of the enum such as "SUM", "MEAN".
value a nested structure of tf.Tensor which tf.nest.flatten accepts. The structure and the shapes of the tf.Tensor need to be the same on all replicas.
options a tf.distribute.experimental.CommunicationOptions. Options to perform collective operations. This overrides the default options if the tf.distribute.Strategy takes one in the constructor. See tf.distribute.experimental.CommunicationOptions for details of the options.

Returns
A nested structure of tf.Tensor with the reduced values. The structure is the same as value.

merge_call

Merge args across replicas and run merge_fn in a cross-replica context.

This allows communication and coordination when there are multiple calls to the step_fn triggered by a call to strategy.run(step_fn, ...).

See tf.distribute.Strategy.run for an explanation.

If not inside a distributed scope, this is equivalent to the following pseudocode:

strategy = tf.distribute.get_strategy()
with cross-replica-context(strategy):
  return merge_fn(strategy, *args, **kwargs)

Args
merge_fn Function that joins arguments from threads that are given as PerReplica. It accepts tf.distribute.Strategy object as the first argument.
args List or tuple with positional per-thread arguments for merge_fn.
kwargs Dict with keyword per-thread arguments for merge_fn.

Returns
The return value of merge_fn, except for PerReplica values which are unpacked.