tf.train.experimental.ShardByTaskPolicy

Policy that splits tensors into shards based on their device spec task.

Inherits From: ShardingCallback

Attributes
description A text description of the sharding policy.

Methods

__call__

Callback to split tensors into shards based on their device spec task.

Args
shardable_tensors A list of ShardableTensors.

Returns
A list of shard dicts containing tensors, each of the form {checkpoint_key: {slice_spec: tensor}}.
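
The nested return structure can be easier to picture with a small example. The following is a minimal pure-Python sketch of grouping tensors into per-task shard dicts; it is not the library's implementation. `FakeShardableTensor`, its integer `task` field, the `shard_by_task` function, and the slice spec strings are all hypothetical stand-ins chosen for illustration (a real ShardableTensor carries a device spec rather than a bare task index).

```python
from collections import namedtuple

# Hypothetical stand-in for ShardableTensor, reduced to the fields used here.
# The integer `task` field is an assumption that replaces the device spec.
FakeShardableTensor = namedtuple(
    "FakeShardableTensor", ["checkpoint_key", "slice_spec", "tensor", "task"]
)


def shard_by_task(shardable_tensors):
    """Group tensors into one shard dict per task (illustrative only)."""
    shards_by_task = {}
    for st in shardable_tensors:
        # One shard dict per task: {checkpoint_key: {slice_spec: tensor}}.
        shard = shards_by_task.setdefault(st.task, {})
        shard.setdefault(st.checkpoint_key, {})[st.slice_spec] = st.tensor
    # Return the shards ordered by task index so the result is deterministic.
    return [shards_by_task[task] for task in sorted(shards_by_task)]


# Plain lists stand in for tensors; the slice spec strings are placeholders.
tensors = [
    FakeShardableTensor("dense/kernel", "", [1.0, 2.0], 0),
    FakeShardableTensor("dense/bias", "", [0.5], 1),
    FakeShardableTensor("embedding", "part-0", [[1, 2], [3, 4]], 0),
]
shards = shard_by_task(tensors)
```

In this sketch, `shards` holds two dicts, one per task, and each dict maps a checkpoint key to a `{slice_spec: tensor}` mapping, mirroring the shape given under Returns.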