Oryx, JAX के शीर्ष पर निर्मित संभाव्य प्रोग्रामिंग और गहन शिक्षण के लिए एक लाइब्रेरी है।

import oryx
import jax.numpy as jnp
ppl = oryx.core.ppl
tfd = oryx.distributions

# Define sampling function
def sample(key):
  x = ppl.random_variable(tfd.Normal(0., 1.))(key)
  return jnp.exp(x / 2.) + 2.

# Transform sampling function into a log-density function
ppl.log_prob(sample)(1.)  # ==> -0.9189
ओरिक्स का दृष्टिकोण फ़ंक्शन परिवर्तनों के एक सेट को उजागर करना है जो JAX के मौजूदा परिवर्तनों के साथ बनता और एकीकृत होता है। Oryx को स्थापित करने के लिए, आप चला सकते हैं:
 pip install --upgrade oryx