Decorator to define a function with a custom gradient.
tf.keras.ops.custom_gradient(f)
This decorator allows fine-grained control over the gradients of a sequence of operations. This may be useful for multiple reasons, including providing a more efficient or numerically stable gradient than the one produced by automatic differentiation.
Returns:

A function h(*args) which returns the same value as f(*args)[0] and whose gradient is determined by f(*args)[1].
Examples:
- Backend-agnostic example.
from keras import ops

@ops.custom_gradient
def log1pexp(x):
    e = ops.exp(x)

    def grad(*args, upstream=None):
        # JAX/TensorFlow pass the upstream gradient positionally;
        # PyTorch passes it as the `upstream` keyword argument.
        if upstream is None:
            (upstream,) = args
        return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))

    return ops.log(1 + e), grad
Note that the grad function that returns the gradient computation requires args as well as an upstream keyword argument, depending on the backend being used. With the JAX and TensorFlow backends, grad receives only one positional argument, whereas the PyTorch backend also passes the upstream keyword argument. When working with the TensorFlow or JAX backend, grad(upstream) is sufficient. With PyTorch, the grad function requires *args as well as upstream, e.g. def grad(*args, upstream). Follow the previous example to use @ops.custom_gradient in a way that is compatible with all backends.
- Here's a JAX & TensorFlow-specific example:
@ops.custom_gradient
def log1pexp(x):
    e = ops.exp(x)

    def grad(upstream):
        return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))

    return ops.log(1 + e), grad
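For illustration, a minimal usage sketch assuming the TensorFlow backend is selected (e.g. KERAS_BACKEND="tensorflow"): the decorated function can be differentiated with tf.GradientTape, and at x = 100.0 the custom gradient evaluates to 1.0, whereas the gradient of a naive log(1 + exp(x)) would overflow to NaN.

import tensorflow as tf

x = tf.constant(100.0)
with tf.GradientTape() as tape:
    tape.watch(x)  # track the constant so the gradient can flow back to it
    y = log1pexp(x)
dy_dx = tape.gradient(y, x)  # 1.0, thanks to the custom gradient

Under the JAX backend, the same decorated function should be differentiable in the usual way, e.g. jax.grad(log1pexp)(100.0), which likewise yields 1.0.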
- Lastly, here's a PyTorch-specific example, using *args and upstream:
@ops.custom_gradient
def log1pexp(x):
    e = ops.exp(x)

    def grad(*args, upstream):
        return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))

    return ops.log(1 + e), grad
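For illustration, a minimal sketch assuming the PyTorch backend is selected (KERAS_BACKEND="torch"): the decorated function is expected to participate in torch autograd like any other differentiable op.

import torch

x = torch.tensor(100.0, requires_grad=True)
y = log1pexp(x)
y.backward()   # backpropagation invokes the custom grad function
print(x.grad)  # tensor(1.) instead of the NaN a naive gradient would give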