Oryx est une bibliothèque basée sur JAX, conçue pour la programmation probabiliste et le deep learning.
import oryx import jax.numpy as jnp ppl = oryx.core.ppl tfd = oryx.distributions # Define sampling function def sample(key): x = ppl.random_variable(tfd.Normal(0., 1.))(key) return jnp.exp(x / 2.) + 2. # Transform sampling function into a log-density function ppl.log_prob(sample)(1.) # ==> -0.9189
L'approche d'Oryx consiste à exposer un ensemble de transformations de fonction qui composent les transformations actuelles de JAX et s'intègrent dans celles-ci. Pour installer Oryx, vous pouvez exécuter la commande suivante :
pip install --upgrade oryx