Module for transforming functions into FunctionModules.
To support init on functions, we define a Module subclass for them: the
FunctionModule. The FunctionModule encapsulates a Jaxpr that is evaluated to
execute the function, with special handling for keyword arguments. This is
useful for neural networks, where a keyword argument such as training may
change the semantics of a function. Keyword-dependent behavior is added by
registering a rule in the kwargs_rules dictionary, which is used by the
custom Jaxpr evaluator in FunctionModule. A rule in kwargs_rules lets a
primitive have an implementation that changes depending on the value of a
keyword argument. An example is a neural network layer like dropout, which
behaves differently during training than at inference time.
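The idea of keyword-dependent primitive rules can be illustrated with a toy evaluator. The sketch below is plain Python, not Oryx or JAX: the "program" is a hypothetical list of (primitive, inputs, output) equations standing in for a Jaxpr, and KWARGS_RULES, dropout_rule, and eval_program are illustrative names only. It shows the same mechanism described above: primitives without a rule use a fixed implementation, while a primitive with a registered rule receives the evaluator's keyword arguments (here training) and can change behavior accordingly.

```python
import random
from typing import Any, Callable, Dict, List, Tuple

# Toy staged program: each equation is (primitive_name, input_names, output_name).
# This stands in for a Jaxpr; it is not JAX's real representation.
Equation = Tuple[str, List[str], str]

# Default primitive implementations, used when no keyword rule is registered.
DEFAULT_IMPLS: Dict[str, Callable[..., Any]] = {
    "scale": lambda x, w: [xi * w for xi in x],
    "dropout": lambda x: x,
}


def dropout_rule(x, *, training=False, rate=0.5, seed=0, **_):
    """Keyword-dependent rule for the toy "dropout" primitive."""
    if not training:
        return x  # Inference: dropout is the identity.
    rng = random.Random(seed)
    # Training: zero each element with probability `rate`, rescale the rest.
    return [xi / (1.0 - rate) if rng.random() >= rate else 0.0 for xi in x]


# Hypothetical analogue of kwargs_rules: primitive name -> keyword-aware rule.
KWARGS_RULES: Dict[str, Callable[..., Any]] = {"dropout": dropout_rule}


def eval_program(eqns: List[Equation], invars: List[str], outvar: str,
                 args: List[Any], **kwargs: Any) -> Any:
    """Evaluates a toy program, routing keyword arguments to registered rules."""
    env = dict(zip(invars, args))
    for prim, ins, out in eqns:
        vals = [env[name] for name in ins]
        if prim in KWARGS_RULES:
            # Keyword-aware path: the rule sees the caller's kwargs.
            env[out] = KWARGS_RULES[prim](*vals, **kwargs)
        else:
            # Default path: kwargs are ignored.
            env[out] = DEFAULT_IMPLS[prim](*vals)
    return env[outvar]


# y = dropout(scale(x, w))
PROGRAM: List[Equation] = [("scale", ["x", "w"], "h"), ("dropout", ["h"], "y")]

if __name__ == "__main__":
    x, w = [1.0, 2.0, 3.0], 2.0
    print(eval_program(PROGRAM, ["x", "w"], "y", [x, w]))
    print(eval_program(PROGRAM, ["x", "w"], "y", [x, w], training=True))
```

Without training=True the program simply scales its input; with it, the same staged program routes dropout through the keyword-aware rule, so each element is either zeroed or rescaled. Rules accept **_ so that unrelated keyword arguments passed to the evaluator do not cause errors.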
We also register functions with api.init. When init is applied to a function,
it first inspects whether the function has an init_key keyword argument, and
harvests the function only if it does. This makes statefulness opt-in for
functions.
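The opt-in check can be sketched with Python's inspect module. This is a minimal illustration, not Oryx's actual implementation: takes_init_key, dense, and relu are hypothetical names, and the sketch assumes that "has a keyword argument init_key" means a parameter named init_key that can be passed by keyword. A function that only accepts **kwargs is treated as not opting in here; Oryx's real rule may differ.

```python
import inspect
from typing import Callable


def takes_init_key(f: Callable) -> bool:
    """Returns True if `f` declares `init_key` as a keyword-passable parameter."""
    try:
        params = inspect.signature(f).parameters
    except (TypeError, ValueError):
        return False  # Some builtins have no introspectable signature.
    param = params.get("init_key")
    return param is not None and param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )


def dense(x, *, init_key=None):  # Opts in: declares init_key.
    return x


def relu(x):  # Does not declare init_key, so it stays stateless.
    return max(x, 0.0)


if __name__ == "__main__":
    print(takes_init_key(dense), takes_init_key(relu))
```

Only dense passes the check, so under the behavior described above only dense would be harvested by init; relu would be treated as a stateless function.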
For documentation of init, spec, and call_and_update, along with an example,
see api.py.
Classes
class FunctionModule: Encapsulates a staged function.