Changes the layout of `tensor` to the same as `layout_tensor`.
tf.experimental.dtensor.relayout_like(
    tensor: tf.Tensor,
    layout_tensor: tf.Tensor,
    name: Optional[str] = None
) -> tf.Tensor
`relayout_like` is often used inside a `tf.function` to ensure that a tensor is placed on the same mesh, and with the same layout, as another tensor.
The backward gradient of a `relayout` is a `relayout_like` operation, which ensures that the backward tensor has the same layout as the forward input tensor:
@ops.RegisterGradient("Relayout")
def _relayout_gradient(op, grad):
  # Give the incoming gradient the layout of the forward op's input.
  return relayout_like(grad, layout_tensor=op.inputs[0])
Here is another illustrative example:
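import tensorflow as tf
from tensorflow.experimental import dtensor

# cpu_mesh and gpu_mesh are assumed to be dtensor.Mesh objects created beforehand.

@tf.function
def func(x):
  z = tf.ones(x.shape)
  # Give z the same mesh and layout as x.
  z = dtensor.relayout_like(z, x)
  return x + z

with dtensor.default_mesh(cpu_mesh):
  x = tf.ones((4, 4))

with dtensor.default_mesh(gpu_mesh):
  y = func(x)

# y would be on the cpu mesh, following the mesh of x.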
Returns | |
---|---|
A DTensor output from the RelayoutLike op. |