
Decorator to define a function with a custom gradient.

This decorator allows fine-grained control over the gradients of a sequence of operations. This can be useful for several reasons, including providing a more efficient or numerically stable gradient for a sequence of operations.

f: Function f(*args) that returns a tuple (output, grad_fn), where:

  • args is a sequence of (nested structures of) tensor inputs to the function.
  • output is a (nested structure of) tensor outputs of applying the operations in f to args.
  • grad_fn is a function with the signature grad_fn(*args, upstream) which returns a tuple of tensors with the same length as (flattened) args: the derivatives of the tensors in output with respect to the tensors in args. upstream is a tensor or sequence of tensors holding the initial value gradients for each tensor in output.

Returns a function h(*args) which returns the same value as f(*args)[0] and whose gradient is determined by f(*args)[1].
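The (output, grad_fn) contract described above can be sketched without any backend. The following is only an illustration under stated assumptions, not the library's implementation: plain Python floats stand in for tensors, the decorator is omitted so that the function can be called directly, and the name product is hypothetical. It shows grad_fn returning one derivative per input.

```python
def product(x, y):
    """Forward pass for x * y, returning (output, grad_fn)."""
    output = x * y

    def grad_fn(*args, upstream):
        # args is unused here: x and y are captured by the closure.
        # One derivative per input: d(xy)/dx = y and d(xy)/dy = x,
        # each scaled by the upstream gradient.
        return upstream * y, upstream * x

    return output, grad_fn


output, grad_fn = product(2.0, 3.0)
print(output)                 # 6.0
print(grad_fn(upstream=1.0))  # (3.0, 2.0)
```

The returned tuple has two entries because the function takes two inputs, matching the requirement that grad_fn return as many gradients as there are (flattened) args.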


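  1. Backend-agnostic example.
from keras import ops

@ops.custom_gradient
def log1pexp(x):
    e = ops.exp(x)

    def grad(*args, upstream=None):
        # JAX and TensorFlow pass the upstream gradient positionally;
        # PyTorch passes it through the `upstream` keyword argument.
        if upstream is None:
            (upstream,) = args
        return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))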

    return ops.log(1 + e), grad
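
The expression returned by grad follows from differentiating the forward output by hand. A short derivation, writing y for the output and g for upstream (both symbols are introduced here for illustration only):

```latex
\begin{aligned}
y &= \log\left(1 + e^{x}\right) \\
\frac{dy}{dx} &= \frac{e^{x}}{1 + e^{x}} = 1 - \frac{1}{1 + e^{x}} \\
\texttt{grad} &= g \cdot \left(1 - \frac{1}{1 + e^{x}}\right)
\end{aligned}
```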

Note that the signature of the grad function depends on the backend. With the JAX and TensorFlow backends, grad receives a single positional argument, the upstream gradient, so def grad(upstream) is sufficient. With the PyTorch backend, grad receives *args as well as an upstream keyword argument, e.g. def grad(*args, upstream).

Follow the previous example to use @ops.custom_gradient in a way that is compatible with all backends.

  2. Here's a JAX- and TensorFlow-specific example:
@ops.custom_gradient
def log1pexp(x):
    e = ops.exp(x)
    def grad(upstream):
        return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))
    return ops.log(1 + e), grad

  3. Lastly, here's a PyTorch-specific example, using *args and upstream:
@ops.custom_gradient
def log1pexp(x):
    e = ops.exp(x)
    def grad(*args, upstream):
        return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))
    return ops.log(1 + e), grad
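
To illustrate the numerical-stability motivation, the sketch below compares the straightforward form of the derivative, e / (1 + e), with the form used by grad in the examples, 1 - 1 / (1 + e). It uses only the Python standard library rather than a Keras backend. The helper names exp_saturating, naive_grad and stable_grad are hypothetical, and exp_saturating is an assumption made to mimic how floating-point tensor libraries overflow to inf instead of raising an error.

```python
import math


def exp_saturating(x):
    """exp(x) that overflows to inf instead of raising OverflowError."""
    try:
        return math.exp(x)
    except OverflowError:
        return math.inf


def naive_grad(x):
    e = exp_saturating(x)
    # Once e overflows, this evaluates inf / inf, which is nan.
    return e / (1.0 + e)


def stable_grad(x):
    e = exp_saturating(x)
    # Once e overflows, this evaluates 1 - 1 / inf = 1 - 0.
    return 1.0 - 1.0 / (1.0 + e)


print(naive_grad(1000.0))   # nan
print(stable_grad(1000.0))  # 1.0
```

For moderate inputs the two forms agree up to rounding; they differ only once the exponential overflows, which is the case a custom gradient can guard against.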