A layout to apply to a tensor.
tf.keras.distribution.TensorLayout(
axes, device_mesh=None
)
This API is aligned with jax.sharding.NamedSharding and tf.dtensor.Layout.
See more details in jax.sharding.NamedSharding
and tf.dtensor.Layout.
Args |
axes | Tuple of strings that should map to the axis_names in a DeviceMesh. For any dimension that doesn't need sharding, None can be used as a placeholder.
device_mesh | Optional DeviceMesh that will be used to create the layout. The actual mapping of tensor to physical device is not known until the mesh is specified.
Attributes |
axes |
device_mesh |
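
The following is a minimal pure-Python sketch of how an axes tuple pairs tensor dimensions with mesh axes. It is not the Keras implementation and does not call the Keras API. The helper `shard_shape` and the `mesh_axis_sizes` dictionary are hypothetical names introduced for illustration. The sketch assumes the usual semantics of NamedSharding-style layouts: a dimension mapped to a mesh axis of size n is split into n equal parts, and each sharded dimension divides evenly.

```python
from typing import Dict, Optional, Sequence, Tuple


def shard_shape(
    tensor_shape: Sequence[int],
    axes: Sequence[Optional[str]],
    mesh_axis_sizes: Dict[str, int],
) -> Tuple[int, ...]:
    """Return the per-device shard shape implied by an axes tuple.

    Hypothetical helper for illustration only; not part of Keras.
    """
    if len(tensor_shape) != len(axes):
        raise ValueError("axes needs one entry per tensor dimension")
    shard = []
    for dim, axis in zip(tensor_shape, axes):
        if axis is None:
            # None is the placeholder: this dimension is not sharded.
            shard.append(dim)
            continue
        if axis not in mesh_axis_sizes:
            raise ValueError(f"unknown mesh axis: {axis!r}")
        size = mesh_axis_sizes[axis]
        if dim % size != 0:
            raise ValueError(f"dimension {dim} is not divisible by {size}")
        # Split this dimension evenly across the named mesh axis.
        shard.append(dim // size)
    return tuple(shard)


# A hypothetical 2 x 4 mesh with axis names "batch" and "model".
mesh_axis_sizes = {"batch": 2, "model": 4}

# Shard dimension 0 over "batch"; leave dimension 1 whole.
print(shard_shape((8, 16), ("batch", None), mesh_axis_sizes))     # (4, 16)
# Shard dimension 0 over "batch" and dimension 1 over "model".
print(shard_shape((8, 16), ("batch", "model"), mesh_axis_sizes))  # (4, 4)
```

Note that the axes tuple alone does not determine the shard shape: the same `("batch", None)` layout yields different shards on meshes of different sizes, which is why the mapping to physical devices is unknown until a mesh is specified.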