Represents a dataset distributed among devices and machines.

A tf.distribute.DistributedDataset can be thought of as a "distributed" dataset. When you use the tf.distribute API to scale training to multiple devices or machines, you also need to distribute the input data. This produces a tf.distribute.DistributedDataset instance instead of the tf.data.Dataset instance you would use in the non-distributed case. In TF 2.x, tf.distribute.DistributedDataset objects are Python iterables.
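The following minimal sketch contrasts the two types. It assumes that tf.distribute.MirroredStrategy() with no device list mirrors across all visible GPUs, or falls back to the CPU when there are none, so it runs on a single machine. The variable names are illustrative only.

```python
import collections.abc

import tensorflow as tf

# Assumption: no device list, so all visible GPUs are used (or the CPU if none).
strategy = tf.distribute.MirroredStrategy()

dataset = tf.data.Dataset.range(8).batch(4)  # a regular, non-distributed dataset
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# The distributed dataset is not itself a tf.data.Dataset...
is_tf_dataset = isinstance(dist_dataset, tf.data.Dataset)
# ...but it is a Python iterable, so a plain for-loop can consume it.
is_iterable = isinstance(dist_dataset, collections.abc.Iterable)
num_steps = sum(1 for _ in dist_dataset)  # one step per global batch
print(is_tf_dataset, is_iterable, num_steps)
```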
There are two APIs to create a tf.distribute.DistributedDataset object: tf.distribute.Strategy.experimental_distribute_dataset(dataset) and tf.distribute.Strategy.distribute_datasets_from_function(dataset_fn).
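As a rough sketch, the two creation APIs can be called like this. It assumes a single machine where MirroredStrategy() picks the devices, and that the number of replicas divides the global batch size of 4 (get_per_replica_batch_size raises an error otherwise). The function name dataset_fn is only a conventional, illustrative name.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumption: default local devices
GLOBAL_BATCH_SIZE = 4

# API 1: distribute an existing dataset that is batched with the GLOBAL batch size.
dataset = tf.data.Dataset.range(8).batch(GLOBAL_BATCH_SIZE)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# API 2: TensorFlow calls dataset_fn with a tf.distribute.InputContext.
def dataset_fn(input_context):
    # Per-replica batch size = global batch size / number of replicas in sync.
    batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
    return tf.data.Dataset.range(8).batch(batch_size)

dist_dataset_from_fn = strategy.distribute_datasets_from_function(dataset_fn)

# Both objects are iterables that yield one distributed step at a time.
steps_api_1 = sum(1 for _ in dist_dataset)
steps_api_2 = sum(1 for _ in dist_dataset_from_fn)
print(steps_api_1, steps_api_2)
```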
When to use which? When you have a tf.data.Dataset instance, and the regular batch splitting (i.e., re-batching the input tf.data.Dataset instance with a new batch size equal to the global batch size divided by the number of replicas in sync) and autosharding (i.e., the tf.data.experimental.AutoShardPolicy options) work for you, use the former API. Otherwise, if you are not using a canonical tf.data.Dataset instance, or you would like to customize the batch splitting or sharding, you can wrap this logic in a dataset_fn and use the latter API. Both APIs handle prefetching to the device for the user. For more details and examples, follow the links to the APIs.
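The sketch below shows one way to customize each path: changing the autoshard policy through tf.data.Options for the former API, and sharding by hand inside dataset_fn for the latter. It assumes a single worker (where each pipeline still sees every element) and a replica count that divides 4. The helper collect is hypothetical and only used to gather the per-replica values for inspection.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumption: single worker, local devices
GLOBAL_BATCH_SIZE = 4

# Former API: keep the regular batch splitting but pick an autoshard policy.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.DATA)
dataset = tf.data.Dataset.range(16).batch(GLOBAL_BATCH_SIZE).with_options(options)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Latter API: do sharding and per-replica batching yourself.
def dataset_fn(input_context):
    batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
    ds = tf.data.Dataset.range(16)
    # Each input pipeline keeps every num_input_pipelines-th element.
    ds = ds.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)
    return ds.batch(batch_size)

dist_dataset_from_fn = strategy.distribute_datasets_from_function(dataset_fn)

def collect(distributed):
    """Hypothetical helper: flatten all per-replica values into a sorted list."""
    values = []
    for step in distributed:
        for tensor in strategy.experimental_local_results(step):
            values.extend(tensor.numpy().tolist())
    return sorted(values)

print(collect(dist_dataset))
print(collect(dist_dataset_from_fn))
```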
There are two main usages of a DistributedDataset object:

1. Iterate over it to generate the input for a single device or multiple devices, which is a tf.distribute.DistributedValues instance. To do this, you can:

- use a Pythonic for-loop construct:
```python
import tensorflow as tf

global_batch_size = 4
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(4).batch(global_batch_size)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def train_step(input):
    features, labels = input
    return labels - 0.3 * features

for x in dist_dataset:
    # train_step trains the model using the dataset elements
    loss = strategy.run(train_step, args=(x,))
    print("Loss is", loss)
```

```
Loss is PerReplica:{
  0: tf.Tensor(
[[0.7]
 [0.7]], shape=(2, 1), dtype=float32),
  1: tf.Tensor(
[[0.7]
 [0.7]], shape=(2, 1), dtype=float32)
}
```
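The PerReplica value printed above can be unpacked or combined with strategy methods. This sketch assumes default local devices via MirroredStrategy(), so the number of replicas may differ from the two-GPU example. It uses experimental_local_results to get one tensor per local replica and reduce with axis=0 to sum over both replicas and the batch dimension.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumption: default local devices
global_batch_size = 4
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(4).batch(global_batch_size)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def train_step(inputs):
    features, labels = inputs
    return labels - 0.3 * features

for x in dist_dataset:
    per_replica = strategy.run(train_step, args=(x,))
    # A tuple with one tensor per local replica (a 1-tuple for a single replica).
    local_results = strategy.experimental_local_results(per_replica)
    # SUM with axis=0 reduces across replicas and the batch dimension -> scalar.
    total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=0)
    print(len(local_results), float(total))
```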
Placing the loop inside a tf.function will give a performance boost. However, break and return are currently not supported if the loop is placed inside a tf.function. We also don't support placing the loop inside a tf.function when using tf.distribute.experimental.MultiWorkerMirroredStrategy or tf.distribute.experimental.TPUStrategy with multiple workers.

- use __iter__ to create an explicit iterator, which is of type tf.distribute.DistributedIterator:
```python
import tensorflow as tf

global_batch_size = 4
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
train_dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(50).batch(global_batch_size)
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)

@tf.function
def distributed_train_step(dataset_inputs):
    def train_step(input):
        loss = tf.constant(0.1)
        return loss
    per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

EPOCHS = 2
STEPS = 3
for epoch in range(EPOCHS):
    total_loss = 0.0
    num_batches = 0
    dist_dataset_iterator = iter(train_dist_dataset)
    for _ in range(STEPS):
        total_loss += distributed_train_step(next(dist_dataset_iterator))
        num_batches += 1
    average_train_loss = total_loss / num_batches
    template = "Epoch {}, Loss: {:.4f}"
    print(template.format(epoch + 1, average_train_loss))
```

```
Epoch 1, Loss: 0.2000
Epoch 2, Loss: 0.2000
```
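A sketch of the multi-step pattern described in the next paragraph: the explicit iterator is passed into a tf.function, and a loop over tf.range runs several strategy.run calls per function call. It assumes default local devices via MirroredStrategy() and a replica count that divides 4. The names train_multiple_steps and STEPS are illustrative.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumption: default local devices
global_batch_size = 4
train_dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(50).batch(global_batch_size)
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)

STEPS = 3

@tf.function
def train_multiple_steps(iterator, steps):
    def train_step(inputs):
        features, labels = inputs
        return tf.reduce_sum(labels - 0.3 * features)

    total = tf.constant(0.0)
    # AutoGraph converts this loop over tf.range into a tf.while_loop.
    for _ in tf.range(steps):
        per_replica = strategy.run(train_step, args=(next(iterator),))
        total += strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
    return total

iterator = iter(train_dist_dataset)
total_loss = train_multiple_steps(iterator, tf.constant(STEPS))
print(float(total_loss))
```

Keeping several steps inside one tf.function call reduces Python overhead, at the cost of the flexibility limits described below.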
To achieve a performance improvement, you can also wrap the strategy.run call with a tf.range inside a tf.function. This runs multiple steps in a single tf.function call; AutoGraph converts the loop into a tf.while_loop on the worker. However, this is less flexible than running a single step inside a tf.function. For example, you cannot run things eagerly or execute arbitrary Python code within the steps.

2. Inspect the tf.TypeSpec of the data generated by DistributedDataset.

tf.distribute.DistributedDataset generates tf.distribute.DistributedValues as input to the devices. If you pass the input to a tf.function and would like to specify the shape and type of each Tensor argument to the function, you can pass a tf.TypeSpec object to the input_signature argument of the tf.function. To get the tf.TypeSpec of the input, you can use the element_spec property of the tf.distribute.DistributedDataset or tf.distribute.DistributedIterator object.

For example:
```python
import tensorflow as tf

global_batch_size = 4
epochs = 1
steps_per_epoch = 1
mirrored_strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensors([2.]).repeat(100).batch(global_batch_size)
dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)

@tf.function(input_signature=[dist_dataset.element_spec])
def train_step(per_replica_inputs):
    def step_fn(inputs):
        return tf.square(inputs)
    return mirrored_strategy.run(step_fn, args=(per_replica_inputs,))

for _ in range(epochs):
    iterator = iter(dist_dataset)
    for _ in range(steps_per_epoch):
        output = train_step(next(iterator))
        print(output)
```

```
PerReplica:{
  0: tf.Tensor(
[[4.]
 [4.]], shape=(2, 1), dtype=float32),
  1: tf.Tensor(
[[4.]
 [4.]], shape=(2, 1), dtype=float32)
}
```
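Since the text notes that element_spec is also available on the iterator, here is a sketch that takes the input signature from a tf.distribute.DistributedIterator instead of the dataset. It assumes default local devices via MirroredStrategy() and a replica count that divides 4. The names spec and results are illustrative.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumption: default local devices
global_batch_size = 4
dataset = tf.data.Dataset.from_tensors([2.]).repeat(100).batch(global_batch_size)
dist_dataset = strategy.experimental_distribute_dataset(dataset)
iterator = iter(dist_dataset)

# The iterator exposes element_spec too, so it can feed input_signature directly.
spec = iterator.element_spec

@tf.function(input_signature=[spec])
def train_step(per_replica_inputs):
    def step_fn(inputs):
        return tf.square(inputs)
    return strategy.run(step_fn, args=(per_replica_inputs,))

output = train_step(next(iterator))
# One tensor per local replica; together they cover the global batch of 4.
results = strategy.experimental_local_results(output)
print([r.numpy().ravel().tolist() for r in results])
```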
Visit the tutorial on distributed input for more examples and caveats.
Attributes | |
---|---|
element_spec | The type specification of an element of this tf.distribute.DistributedDataset. |
Methods
__iter__
__iter__()
Creates an iterator for the tf.distribute.DistributedDataset.
The returned iterator implements the Python Iterator protocol.
Example usage:
```python
import tensorflow as tf

global_batch_size = 4
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4]).repeat().batch(global_batch_size)
distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))
print(next(distributed_iterator))
```

```
PerReplica:{
  0: tf.Tensor([1 2], shape=(2,), dtype=int32),
  1: tf.Tensor([3 4], shape=(2,), dtype=int32)
}
```
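Because the iterator follows the Python Iterator protocol, calling next() on an exhausted iterator raises StopIteration in eager mode. This sketch drops .repeat() so the data runs out, and assumes default local devices via MirroredStrategy() with a replica count that divides 4. The step counter is illustrative.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumption: default local devices
global_batch_size = 4
# A finite dataset (no .repeat()) so the iterator eventually runs out.
dataset = tf.data.Dataset.range(8).batch(global_batch_size)
distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))

steps = 0
while True:
    try:
        batch = next(distributed_iterator)
    except StopIteration:
        # Raised once every global batch has been consumed.
        break
    steps += 1
print("steps:", steps)
```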
Returns | |
---|---|
A tf.distribute.DistributedIterator instance for the given tf.distribute.DistributedDataset object to enumerate over the distributed data. |