Decorates/wraps Python functions containing JAX code as TFF computations.
tff.jax.computation(*args, **kwargs)
This wrapper can be used in a similar manner to tff.tensorflow.computation, with the following exceptions:
- The code in the wrapped Python function must be JAX code that can be compiled to XLA (e.g., code that one would expect to be able to annotate with @jax.jit).
- The inputs and outputs must be tensors, or (possibly recursively) nested structures of tensors. Sequences are currently not supported.
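One way to check the first requirement before wrapping a function is to run it under jax.jit directly. The following is a minimal sketch that assumes only jax and numpy are installed. The functions add_pair and python_abs are hypothetical illustrations, not part of TFF. The first takes a nested structure (a dict of int32 tensors) and compiles. The second branches in Python on a traced value, which jax.jit rejects.

```python
import jax
import jax.numpy as jnp
import numpy as np


def add_pair(pair):
  """Hypothetical jit-compatible function over a nested structure of tensors."""
  return {
      "sum": jnp.add(pair["a"], pair["b"]),
      "doubled": pair["a"] * np.int32(2),
  }


def python_abs(x):
  """Hypothetical function that is NOT jit-compatible."""
  # A Python `if` on a traced array needs a concrete value, so tracing fails.
  if x > 0:
    return x
  return -x


# Compiles and runs: the dict input and output are handled as pytrees.
out = jax.jit(add_pair)({"a": np.int32(3), "b": np.int32(4)})
print(int(out["sum"]), int(out["doubled"]))  # prints: 7 6

# Fails at trace time with a concretization error.
try:
  jax.jit(python_abs)(np.int32(-5))
  jit_ok = True
except jax.errors.ConcretizationTypeError:
  jit_ok = False
print(jit_ok)  # prints: False
```

Passing jax.jit only exercises the XLA-compilability requirement. It does not by itself show that tff.jax.computation will accept the function, since the wrapper also needs the parameter types to be declared.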
Example:
import jax
import numpy as np
import tensorflow_federated as tff
@tff.jax.computation(np.int32)
def comp(x):
  return jax.numpy.add(x, np.int32(10))  # Adds 10 to the int32 input.