Oryx, JAX के शीर्ष पर निर्मित संभाव्य प्रोग्रामिंग और गहन शिक्षण के लिए एक लाइब्रेरी है।
import oryx import jax.numpy as jnp ppl = oryx.core.ppl tfd = oryx.distributions # Define sampling function def sample(key): x = ppl.random_variable(tfd.Normal(0., 1.))(key) return jnp.exp(x / 2.) + 2. # Transform sampling function into a log-density function ppl.log_prob(sample)(1.) # ==> -0.9189
ओरिक्स का दृष्टिकोण फ़ंक्शन परिवर्तनों के एक सेट को उजागर करना है जो JAX के मौजूदा परिवर्तनों के साथ बनता और एकीकृत होता है। Oryx को स्थापित करने के लिए, आप चला सकते हैं:
pip install --upgrade oryx