Oryx là một thư viện lập trình xác suất và học sâu được xây dựng dựa trên JAX.
import oryx import jax.numpy as jnp ppl = oryx.core.ppl tfd = oryx.distributions # Define sampling function def sample(key): x = ppl.random_variable(tfd.Normal(0., 1.))(key) return jnp.exp(x / 2.) + 2. # Transform sampling function into a log-density function ppl.log_prob(sample)(1.) # ==> -0.9189
Cách tiếp cận của Oryx là đưa ra một tập hợp các phép biến đổi hàm tạo nên và tích hợp với các phép biến đổi hiện có của JAX. Để cài đặt Oryx, bạn có thể chạy:
pip install --upgrade oryx