Represents the layout information of a DTensor.
tf.experimental.dtensor.Layout(
sharding_specs: List[str],
mesh: tf.experimental.dtensor.Mesh
)
A layout describes how a distributed tensor is partitioned across a mesh (and thus across devices). For each axis of the tensor, the corresponding sharding spec names the mesh dimension that axis is sharded over. The special sharding spec UNSHARDED indicates that the axis is replicated on all devices of the mesh.
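To make the spec-to-mesh mapping concrete, here is a minimal pure-Python model (not the DTensor API) of how a layout determines the per-device shape of a tensor. The helper local_shape and the dict form of the mesh dimensions are hypothetical; UNSHARDED is assumed to equal the string "unsharded", and the model assumes every sharded axis divides evenly by the size of its mesh dimension.

```python
UNSHARDED = "unsharded"  # assumed to mirror tf.experimental.dtensor.UNSHARDED


def local_shape(global_shape, sharding_specs, mesh_dims):
    """Per-device shape of a tensor under a layout (pure-Python model).

    mesh_dims maps each mesh dimension name to its size, e.g. {"x": 3, "y": 2}.
    """
    if len(global_shape) != len(sharding_specs):
        raise ValueError("need exactly one sharding spec per tensor axis")
    shape = []
    for size, spec in zip(global_shape, sharding_specs):
        if spec == UNSHARDED:
            # Replicated axis: every device holds the whole axis.
            shape.append(size)
        else:
            # Sharded axis: split evenly across the named mesh dimension.
            shards = mesh_dims[spec]
            if size % shards:
                raise ValueError(f"axis of size {size} is not divisible by {shards}")
            shape.append(size // shards)
    return shape


print(local_shape([3, 2], ["x", UNSHARDED], {"x": 3, "y": 2}))  # [1, 2]
print(local_shape([6], [UNSHARDED], {"x": 6}))  # [6]
```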
For example, let's consider a 1-D mesh:
Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"], [("x", 6)])
This mesh arranges 6 TPU devices into a 1-D array. Layout([UNSHARDED], mesh) is a layout for a rank-1 tensor that is replicated on all 6 devices.
For another example, let's consider a 2-D mesh:
Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"],
[("x", 3), ("y", 2)])
This mesh arranges 6 TPU devices into a 3x2 2-D array.
Layout(["x", UNSHARDED], mesh) is a layout for a rank-2 tensor whose first axis is sharded on mesh dimension "x" and whose second axis is replicated. Because the layout does not use mesh dimension "y", the two devices that share an "x" coordinate hold identical components. If we place np.arange(6).reshape((3, 2)) using this layout, the individual component tensors would look like:
Device | Component
---|---
TPU:0 | [[0, 1]]
TPU:1 | [[0, 1]]
TPU:2 | [[2, 3]]
TPU:3 | [[2, 3]]
TPU:4 | [[4, 5]]
TPU:5 | [[4, 5]]
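The table can be reproduced with a small pure-Python sketch. It assumes devices are numbered in row-major order over the mesh dimensions (the last dimension varies fastest), which matches the table, and that each sharded axis is split into equal contiguous blocks. device_coords, axis_range, and component_2d are hypothetical helpers, not DTensor functions, and they handle rank-2 tensors only.

```python
UNSHARDED = "unsharded"  # assumed to mirror tf.experimental.dtensor.UNSHARDED


def device_coords(device_index, mesh_dims):
    """Row-major mesh coordinates of a device (last mesh dimension varies fastest)."""
    coords = {}
    for name, size in reversed(mesh_dims):
        coords[name] = device_index % size
        device_index //= size
    return coords


def axis_range(length, spec, sizes, coords):
    """Indices of one tensor axis held by the device at `coords`."""
    if spec == UNSHARDED:
        return range(length)  # replicated axis: the whole range
    block = length // sizes[spec]  # assumes even division
    start = coords[spec] * block
    return range(start, start + block)


def component_2d(tensor, sharding_specs, mesh_dims, device_index):
    """Component of a rank-2 nested-list tensor stored on one device."""
    sizes = dict(mesh_dims)
    coords = device_coords(device_index, mesh_dims)
    rows = axis_range(len(tensor), sharding_specs[0], sizes, coords)
    cols = axis_range(len(tensor[0]), sharding_specs[1], sizes, coords)
    return [[tensor[r][c] for c in cols] for r in rows]


mesh_dims = [("x", 3), ("y", 2)]
tensor = [[0, 1], [2, 3], [4, 5]]  # np.arange(6).reshape((3, 2)) as nested lists
for i in range(6):
    print(f"TPU:{i}", component_2d(tensor, ["x", UNSHARDED], mesh_dims, i))
```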
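Attributes | |
---|---|
mesh | The mesh the layout is defined over.
rank | The number of tensor axes (the length of sharding_specs).
shape |
sharding_specs | The sharding spec of each tensor axis: a mesh dimension name, or UNSHARDED for a replicated axis.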
Methods
as_proto
as_proto()
as_proto(self: tensorflow.python._pywrap_dtensor_device.Layout) -> tensorflow::dtensor::LayoutProto
Returns the LayoutProto protobuf message.
batch_sharded
@classmethod
batch_sharded(mesh: tf.experimental.dtensor.Mesh, batch_dim: str, rank: int, axis: int = 0) -> 'Layout'
Returns a layout sharded on the batch dimension.
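As a rough model of batch_sharded, the sketch below builds only the sharding specs: every axis is UNSHARDED except axis, which is sharded over the mesh dimension batch_dim. The real classmethod also takes a mesh and returns a Layout; batch_sharded_specs is a hypothetical stand-in, and its range check on axis is this sketch's own choice.

```python
UNSHARDED = "unsharded"  # assumed to mirror tf.experimental.dtensor.UNSHARDED


def batch_sharded_specs(batch_dim, rank, axis=0):
    """Specs for a rank-`rank` tensor sharded on `axis` over mesh dim `batch_dim`."""
    if not 0 <= axis < rank:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    specs = [UNSHARDED] * rank  # start fully replicated
    specs[axis] = batch_dim  # then shard the chosen axis
    return specs


print(batch_sharded_specs("x", rank=3))          # ['x', 'unsharded', 'unsharded']
print(batch_sharded_specs("x", rank=3, axis=1))  # ['unsharded', 'x', 'unsharded']
```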
delete
delete(dims: List[int]) -> 'Layout'
Returns the layout with the given dimensions deleted.
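A minimal sketch of what delete does to the sharding specs, assuming it simply drops the specs at the listed tensor axes, keeps the remaining specs in order, and leaves the mesh unchanged. delete_specs is a hypothetical helper operating on a plain list, not the Layout method.

```python
def delete_specs(sharding_specs, dims):
    """Drop the sharding specs of the given tensor axes, keeping the rest in order."""
    if isinstance(dims, int):
        dims = [dims]  # accept a single axis for convenience
    drop = set(dims)
    return [spec for i, spec in enumerate(sharding_specs) if i not in drop]


print(delete_specs(["x", "unsharded", "y"], [1]))  # ['x', 'y']
```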
from_device
@classmethod
from_device(device: str) -> 'Layout'
Constructs a single-device layout for the given device, via a single-device mesh.
from_proto
@classmethod
from_proto(layout_proto: layout_pb2.LayoutProto) -> 'Layout'
Creates an instance from a LayoutProto.
from_single_device_mesh
@classmethod
from_single_device_mesh(mesh: tf.experimental.dtensor.Mesh) -> 'Layout'
Constructs a single device layout from a single device mesh.
from_string
@classmethod
from_string(layout_str: str) -> 'Layout'
Creates an instance from a human-readable string.
inner_sharded
@classmethod
inner_sharded(mesh: tf.experimental.dtensor.Mesh, inner_dim: str, rank: int) -> 'Layout'
Returns a layout sharded on the innermost (last) tensor axis.
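inner_sharded can be thought of as batch sharding applied to the last tensor axis. The hypothetical helper below models only the resulting sharding specs under that assumption; the real classmethod takes a mesh and returns a Layout.

```python
UNSHARDED = "unsharded"  # assumed to mirror tf.experimental.dtensor.UNSHARDED


def inner_sharded_specs(inner_dim, rank):
    """Specs that shard only the last tensor axis over mesh dim `inner_dim`."""
    if rank < 1:
        raise ValueError("rank must be at least 1")
    return [UNSHARDED] * (rank - 1) + [inner_dim]


print(inner_sharded_specs("y", rank=2))  # ['unsharded', 'y']
```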
is_batch_parallel
is_batch_parallel()
is_batch_parallel(self: tensorflow.python._pywrap_dtensor_device.Layout) -> bool
is_fully_replicated
is_fully_replicated()
is_fully_replicated(self: tensorflow.python._pywrap_dtensor_device.Layout) -> bool
Returns True if all tensor axes are replicated.
is_single_device
is_single_device()
is_single_device(self: tensorflow.python._pywrap_dtensor_device.Layout) -> bool
Returns True if the Layout represents a non-distributed device.
num_shards
num_shards()
num_shards(self: tensorflow.python._pywrap_dtensor_device.Layout, idx: int) -> int
Returns the number of shards for tensor dimension idx.
offset_to_shard
offset_to_shard()
Mapping from offset in a flattened list to shard index.
offset_tuple_to_global_index
offset_tuple_to_global_index(offset_tuple)
Mapping from offset to index in global tensor.
replicated
@classmethod
replicated(mesh: tf.experimental.dtensor.Mesh, rank: int) -> 'Layout'
Returns a fully replicated layout of the given rank.
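A sketch of the sharding specs produced by replicated, together with the condition is_fully_replicated describes (all tensor axes replicated), modeled on plain lists. replicated_specs and is_fully_replicated_specs are hypothetical helpers, and UNSHARDED is assumed to equal the string "unsharded".

```python
UNSHARDED = "unsharded"  # assumed to mirror tf.experimental.dtensor.UNSHARDED


def replicated_specs(rank):
    """Specs for a rank-`rank` tensor replicated on every device of the mesh."""
    return [UNSHARDED] * rank


def is_fully_replicated_specs(sharding_specs):
    """True if no tensor axis is sharded over a mesh dimension."""
    return all(spec == UNSHARDED for spec in sharding_specs)


print(replicated_specs(2))                             # ['unsharded', 'unsharded']
print(is_fully_replicated_specs(replicated_specs(2)))  # True
print(is_fully_replicated_specs(["x", UNSHARDED]))     # False
```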
to_string
to_string()
to_string(self: tensorflow.python._pywrap_dtensor_device.Layout) -> str
__eq__
__eq__()
__eq__(self: tensorflow.python._pywrap_dtensor_device.Layout, arg0: tensorflow.python._pywrap_dtensor_device.Layout) -> bool