tf.keras.distribution.TensorLayout

A layout to apply to a tensor.

This API is aligned with jax.sharding.NamedSharding and tf.dtensor.Layout. See the documentation of those two classes for more details.

axes Tuple of strings that should map to the axis_names in a DeviceMesh. For any dimension that doesn't need sharding, None can be used as a placeholder.
device_mesh Optional DeviceMesh that will be used to create the layout. The actual mapping of the tensor to physical devices is not known until the mesh is specified.
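
The following is a conceptual sketch in plain Python of how an axes tuple relates tensor dimensions to mesh axes. It is not the Keras implementation: shard_shape is a hypothetical helper, and it assumes one axes entry per tensor dimension and an even split along every sharded dimension. The mesh is represented only by its shape and axis names.

```python
def shard_shape(global_shape, axes, mesh_shape, axis_names):
    """Return the per-device shard shape of a tensor.

    Hypothetical helper for illustration only; not part of Keras.
    `axes` plays the role of TensorLayout.axes, while `mesh_shape` and
    `axis_names` stand in for the shape and axis_names of a DeviceMesh.
    """
    if len(axes) != len(global_shape):
        raise ValueError("Expected one axes entry per tensor dimension.")
    mesh_size = dict(zip(axis_names, mesh_shape))
    local_shape = []
    for dim, axis in zip(global_shape, axes):
        if axis is None:
            # None is the placeholder: this dimension is not sharded.
            local_shape.append(dim)
            continue
        if axis not in mesh_size:
            raise ValueError(f"Unknown mesh axis: {axis!r}")
        if dim % mesh_size[axis] != 0:
            # Assumption of this sketch: only even splits are handled.
            raise ValueError(f"Dimension {dim} is not divisible by {mesh_size[axis]}.")
        local_shape.append(dim // mesh_size[axis])
    return tuple(local_shape)


# A 2 x 4 mesh (8 devices) with axes named "batch" and "model".
mesh_shape = (2, 4)
axis_names = ("batch", "model")

# Conceptually similar to TensorLayout(axes=("batch", None), device_mesh=mesh):
# dimension 0 is split over "batch", dimension 1 is left unsharded.
print(shard_shape((16, 128), ("batch", None), mesh_shape, axis_names))  # (8, 128)

# Dimension 1 split over "model" instead.
print(shard_shape((16, 128), (None, "model"), mesh_shape, axis_names))  # (16, 32)
```

The same axes tuple yields different shard shapes on different meshes, which is why the physical placement cannot be determined until a device_mesh is attached.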

axes

device_mesh