oryx.core.interpreters.propagate.propagate

Propagates cells in a Jaxpr using a set of rules.

Args:
  cell_type: Used to instantiate literals into cells.
  rules: Maps JAX primitives to propagation rule functions.
  jaxpr: Used to construct the propagation graph.
  constcells: Used to populate the Jaxpr's constvars.
  incells: Used to populate the Jaxpr's invars.
  outcells: Used to populate the Jaxpr's outvars.
  reducer: An optional callable used to reduce over the state at each equation in the Jaxpr. reducer takes (env, eqn, state, new_state) as arguments and should return an updated state; the new_state value is provided by each equation.
  initial_state: The initial state value used by the reducer.
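The reducer and initial_state arguments can be pictured as a fold. The following is a minimal sketch that assumes only what is stated above: a reducer receives (env, eqn, state, new_state) and returns the updated state, and that state is threaded from initial_state through one reducer call per equation. The helpers count_reducer and fold_states, and the strings standing in for equations, are hypothetical and not part of Oryx. Inside propagate, env, eqn and new_state are supplied by the interpreter itself.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical reducer with the (env, eqn, state, new_state) -> state shape.
# It counts how many equations reported a truthy new_state.
def count_reducer(env: Dict[str, Any], eqn: str, state: int, new_state: Any) -> int:
    return state + (1 if new_state else 0)  # env and eqn are unused here

# Hypothetical driver: threads state from initial_state through one reducer
# call per (eqn, new_state) pair, mirroring a per-equation reduction.
def fold_states(reducer: Callable[[Dict[str, Any], str, Any, Any], Any],
                env: Dict[str, Any],
                per_eqn_states: List[Tuple[str, Any]],
                initial_state: Any) -> Any:
    state = initial_state
    for eqn, new_state in per_eqn_states:
        state = reducer(env, eqn, state, new_state)
    return state

result = fold_states(count_reducer, {}, [("mul", True), ("add", False), ("sin", True)],
                     initial_state=0)
print(result)  # 2
```

With no equations the reducer is never called, so the result is simply initial_state.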

Returns:
  The Jaxpr environment after propagation has terminated.
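To make "propagating cells" concrete, here is a toy, self-contained sketch of a fixed-point propagation loop in plain Python. It is a simplified illustration, not Oryx's implementation. Everything in it is hypothetical:

- KnownCell, a cell that is either unknown or holds a known number.
- The rule signature rule(incells, outcells) -> (new_incells, new_outcells).
- The Eqn and ToyJaxpr records that stand in for a real Jaxpr.
- propagate_sketch itself.

The reducer and initial_state arguments are omitted for brevity. Rules may refine cells in either direction, and the loop sweeps over all equations until a full pass changes nothing. It terminates because each variable can move from unknown to known only once.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass(frozen=True)
class KnownCell:
    """Hypothetical cell: a known number, or None when unknown."""
    value: Optional[float] = None

    def is_known(self):
        return self.value is not None

    def join(self, other):
        """Merge two views of one variable, keeping any known value."""
        if not self.is_known():
            return other
        if other.is_known() and other.value != self.value:
            raise ValueError(f"conflict: {self.value} vs {other.value}")
        return self

@dataclass
class Eqn:  # hypothetical stand-in for a Jaxpr equation
    primitive: str
    invars: list  # str entries are variables; int/float entries are literals
    outvars: list

@dataclass
class ToyJaxpr:  # hypothetical stand-in for a Jaxpr
    constvars: List[str]
    invars: List[str]
    outvars: List[str]
    eqns: List[Eqn]

def add_rule(incells, outcells):
    (x, y), (z,) = incells, outcells
    if x.is_known() and y.is_known():
        z = KnownCell(x.value + y.value)  # forward: z = x + y
    elif z.is_known() and x.is_known():
        y = KnownCell(z.value - x.value)  # backward: y = z - x
    elif z.is_known() and y.is_known():
        x = KnownCell(z.value - y.value)  # backward: x = z - y
    return [x, y], [z]

def mul_rule(incells, outcells):
    (x, y), (z,) = incells, outcells
    if x.is_known() and y.is_known():
        z = KnownCell(x.value * y.value)  # forward: z = x * y
    elif z.is_known() and x.is_known() and x.value != 0:
        y = KnownCell(z.value / x.value)  # backward: y = z / x
    elif z.is_known() and y.is_known() and y.value != 0:
        x = KnownCell(z.value / y.value)  # backward: x = z / y
    return [x, y], [z]

def propagate_sketch(cell_type, rules, jaxpr, constcells, incells, outcells):
    env: Dict[str, KnownCell] = {}

    def read(var):
        if isinstance(var, (int, float)):
            return cell_type(var)  # literal -> cell via cell_type
        return env.get(var, cell_type())  # unseen variable -> unknown cell

    def write(var, cell):
        """Join cell into env[var]; return True if the stored cell changed."""
        if isinstance(var, (int, float)):
            return False  # literals never change
        old = read(var)
        env[var] = old.join(cell)
        return env[var] != old

    # Seed the environment from the caller-supplied cells.
    for names, cells in ((jaxpr.constvars, constcells),
                         (jaxpr.invars, incells),
                         (jaxpr.outvars, outcells)):
        for var, cell in zip(names, cells):
            write(var, cell)

    # Sweep all equations until one full pass changes nothing.
    changed = True
    while changed:
        changed = False
        for eqn in jaxpr.eqns:
            ins = [read(v) for v in eqn.invars]
            outs = [read(v) for v in eqn.outvars]
            new_ins, new_outs = rules[eqn.primitive](ins, outs)
            for var, cell in zip(eqn.invars + eqn.outvars, new_ins + new_outs):
                changed |= write(var, cell)
    return env

# f(a, b) = a * 2 + b as two equations; only b and the output d are known.
jaxpr = ToyJaxpr(constvars=[], invars=["a", "b"], outvars=["d"],
                 eqns=[Eqn("mul", ["a", 2], ["c"]), Eqn("add", ["c", "b"], ["d"])])
env = propagate_sketch(KnownCell, {"mul": mul_rule, "add": add_rule}, jaxpr,
                       constcells=[], incells=[KnownCell(), KnownCell(4)],
                       outcells=[KnownCell(10)])
print(env["a"], env["c"])  # KnownCell(value=3.0) KnownCell(value=6)
```

Here only b and the output d are known at the start. Propagation first infers c = 6 from the add equation, then a = 3.0 from the mul equation, so information flows from outputs back to inputs. The returned dictionary plays the role of the environment after propagation has terminated.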