Oryx는 JAX를 기반으로 빌드된 확률적 프로그래밍과 딥 러닝을 위한 라이브러리입니다.
import oryx import jax.numpy as jnp ppl = oryx.core.ppl tfd = oryx.distributions # Define sampling function def sample(key): x = ppl.random_variable(tfd.Normal(0., 1.))(key) return jnp.exp(x / 2.) + 2. # Transform sampling function into a log-density function ppl.log_prob(sample)(1.) # ==> -0.9189
Oryx의 접근 방식은 JAX의 기존 변환을 작성하고 통합하는 함수 변환 집합을 노출하는 것입니다. Oryx를 설치하려면 다음을 실행합니다.
pip install --upgrade oryx