Prawdopodobieństwo TensorFlow na JAX

Zobacz na TensorFlow.org Uruchom w Google Colab Wyświetl źródło na GitHub Pobierz notatnik

TensorFlow Prawdopodobieństwo (TFP) jest biblioteką do probabilistycznego rozumowania i analizy statystycznej, która teraz działa również na JAX ! Dla tych, którzy nie są zaznajomieni, JAX jest biblioteką do przyspieszonych obliczeń numerycznych opartych na transformacjach funkcji komponowalnych.

TFP w JAX obsługuje wiele najbardziej przydatnych funkcji zwykłego TFP, zachowując jednocześnie abstrakcje i interfejsy API, z którymi wielu użytkowników TFP jest teraz zadowolonych.

Ustawiać

TFP na JAX nie zależy TensorFlow; odinstalujmy całkowicie TensorFlow z tego Colab.

pip uninstall tensorflow -y -q

Możemy zainstalować TFP na JAX z najnowszymi nocnymi kompilacjami TFP.

pip install -Uq tfp-nightly[jax] > /dev/null

Zaimportujmy kilka przydatnych bibliotek Pythona.

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn import datasets
sns.set(style='white')
/usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.
  import pandas.util.testing as tm

Zaimportujmy również kilka podstawowych funkcji JAX.

import jax.numpy as jnp
from jax import grad
from jax import jit
from jax import random
from jax import value_and_grad
from jax import vmap

Importowanie TFP na JAX

Aby korzystać TFP na JAX, wystarczy zaimportować jax „podłoża” i używać go jak zwykle będzie tfp :

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels

Demo: Bayesowska regresja logistyczna

Aby zademonstrować, co możemy zrobić z backendem JAX, zaimplementujemy bayesowską regresję logistyczną stosowaną do klasycznego zestawu danych Iris.

Najpierw zaimportujmy zestaw danych Iris i wyodrębnijmy kilka metadanych.

iris = datasets.load_iris()
features, labels = iris['data'], iris['target']

num_features = features.shape[-1]
num_classes = len(iris.target_names)

Możemy zdefiniować model używając tfd.JointDistributionCoroutine . Umieścimy standardowe normalne prawdopodobieństwa a priori na obu wag i określenia polaryzacji następnie napisać target_log_prob funkcję szpilki objętych próbą etykiet do danych.

Root = tfd.JointDistributionCoroutine.Root
def model():
  w = yield Root(tfd.Sample(tfd.Normal(0., 1.),
                            sample_shape=(num_features, num_classes)))
  b = yield Root(
      tfd.Sample(tfd.Normal(0., 1.), sample_shape=(num_classes,)))
  logits = jnp.dot(features, w) + b
  yield tfd.Independent(tfd.Categorical(logits=logits),
                        reinterpreted_batch_ndims=1)


dist = tfd.JointDistributionCoroutine(model)
def target_log_prob(*params):
  return dist.log_prob(params + (labels,))

Mamy próbki z dist do wytworzenia stanu początkowego dla MCMC. Następnie możemy zdefiniować funkcję, która przyjmuje losowy klucz i stan początkowy, i tworzy 500 próbek z próbnika bez zawracania (NUTS). Należy pamiętać, że możemy użyć transformacje JAX jak jit skompilować nasz NUTS próbnika za pomocą XLA.

init_key, sample_key = random.split(random.PRNGKey(0))
init_params = tuple(dist.sample(seed=init_key)[:-1])

@jit
def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-3)
  return tfp.mcmc.sample_chain(500,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      num_burnin_steps=500,
      seed=key)

states, log_probs = run_chain(sample_key, init_params)
plt.figure()
plt.plot(log_probs)
plt.ylabel('Target Log Prob')
plt.xlabel('Iterations of NUTS')
plt.show()

png

Użyjmy naszych próbek, aby wykonać uśrednianie modelu bayesowskiego (BMA) przez uśrednienie przewidywanych prawdopodobieństw każdego zestawu wag.

Najpierw napiszmy funkcję, która dla danego zestawu parametrów wygeneruje prawdopodobieństwa dla każdej klasy. Możemy użyć dist.sample_distributions do uzyskania ostatecznego rozkładu w modelu.

def classifier_probs(params):
  dists, _ = dist.sample_distributions(seed=random.PRNGKey(0),
                                       value=params + (None,))
  return dists[-1].distribution.probs_parameter()

Możemy vmap(classifier_probs) na zestaw próbek, aby uzyskać przewidywane prawdopodobieństwo klasy dla każdego z naszych próbek. Następnie obliczamy średnią dokładność w każdej próbce oraz dokładność z uśredniania modelu bayesowskiego.

all_probs = jit(vmap(classifier_probs))(states)
print('Average accuracy:', jnp.mean(all_probs.argmax(axis=-1) == labels))
print('BMA accuracy:', jnp.mean(all_probs.mean(axis=0).argmax(axis=-1) == labels))
Average accuracy: 0.96952
BMA accuracy: 0.97999996

Wygląda na to, że BMA zmniejsza nasz wskaźnik błędów o prawie jedną trzecią!

Podstawy

TFP na JAX ma identyczną API TF gdzie zamiast przyjmowania przedmiotów TF jak tf.Tensor Ś przyjmuje analog JAX. Na przykład, gdziekolwiek tf.Tensor był używany jako wejście API się oczekuje JAX DeviceArray . Zamiast zwracania tf.Tensor metody TFP powróci DeviceArray s. TFP na JAX współpracuje również z zagnieżdżonych struktur obiektów JAX, takich jak listy lub słownika DeviceArray s.

Dystrybucje

Większość dystrybucji TFP jest obsługiwana w JAX z bardzo podobną semantyką do ich odpowiedników TF. Są one również zarejestrowany jako JAX Pytrees , więc mogą być wejścia i wyjścia z JAX-przekształconych funkcji.

Podstawowe dystrybucje

log_prob metoda dystrybucji działa tak samo.

dist = tfd.Normal(0., 1.)
print(dist.log_prob(0.))
-0.9189385

Pobieranie próbek z rozkładu wymaga wyraźnie przechodzącą w PRNGKey (lub listę liczb całkowitych) jako seed argumentu słowa kluczowego. Niepowodzenie jawnego przekazania nasion spowoduje błąd.

tfd.Normal(0., 1.).sample(seed=random.PRNGKey(0))
DeviceArray(-0.20584226, dtype=float32)

Semantyka kształt rozkładów pozostają takie same w JAX, gdzie każdy będzie dystrybucje mieć event_shape i batch_shape i rysowania wielu próbek doda dodatkowe sample_shape wymiary.

Na przykład, tfd.MultivariateNormalDiag parametry wektora będzie mieć kształt zdarzeń wektora i pusty kształt wsadu.

dist = tfd.MultivariateNormalDiag(
    loc=jnp.zeros(5),
    scale_diag=jnp.ones(5)
)
print('Event shape:', dist.event_shape)
print('Batch shape:', dist.batch_shape)
Event shape: (5,)
Batch shape: ()

Z drugiej strony, tfd.Normal parametryzowane wektorami będą miały kształt skalarne i wektorowe wydarzenie partii kształt.

dist = tfd.Normal(
    loc=jnp.ones(5),
    scale=jnp.ones(5),
)
print('Event shape:', dist.event_shape)
print('Batch shape:', dist.batch_shape)
Event shape: ()
Batch shape: (5,)

Semantyka biorąc log_prob próbek działa tak samo w JAX też.

dist =  tfd.Normal(jnp.zeros(5), jnp.ones(5))
s = dist.sample(sample_shape=(10, 2), seed=random.PRNGKey(0))
print(dist.log_prob(s).shape)

dist =  tfd.Independent(tfd.Normal(jnp.zeros(5), jnp.ones(5)), 1)
s = dist.sample(sample_shape=(10, 2), seed=random.PRNGKey(0))
print(dist.log_prob(s).shape)
(10, 2, 5)
(10, 2)

Ponieważ JAX DeviceArray s są kompatybilne z bibliotekami jak NumPy i matplotlib, możemy nakarmić próbek bezpośrednio do funkcji kreślenia.

sns.distplot(tfd.Normal(0., 1.).sample(1000, seed=random.PRNGKey(0)))
plt.show()

png

Distribution metody są zgodne z przekształceń JAX.

sns.distplot(jit(vmap(lambda key: tfd.Normal(0., 1.).sample(seed=key)))(
    random.split(random.PRNGKey(0), 2000)))
plt.show()

png

x = jnp.linspace(-5., 5., 100)
plt.plot(x, jit(vmap(grad(tfd.Normal(0., 1.).prob)))(x))
plt.show()

png

Ponieważ dystrybucje TFP są zarejestrowane jako JAX węzłów pytree, możemy napisać funkcje z rozkładami jak wejść lub wyjść i przekształcić je za pomocą jit , ale nie są jeszcze obsługiwane jako argumenty vmap funkcje -ed.

@jit
def random_distribution(key):
  loc_key, scale_key = random.split(key)
  loc, log_scale = random.normal(loc_key), random.normal(scale_key)
  return tfd.Normal(loc, jnp.exp(log_scale))
random_dist = random_distribution(random.PRNGKey(0))
print(random_dist.mean(), random_dist.variance())
0.14389051 0.081832744

Przekształcone dystrybucje

Przekształcone dystrybucje tj dystrybucje, których próbki są przepuszczane przez Bijector również pracować z pudełka (bijectors działa też! Patrz poniżej).

dist = tfd.TransformedDistribution(
    tfd.Normal(0., 1.),
    tfb.Sigmoid()
)
sns.distplot(dist.sample(1000, seed=random.PRNGKey(0)))
plt.show()

png

Wspólne dystrybucje

TFP oferuje JointDistribution s, aby umożliwić łączenie rozkładu składników w pojedynczym rozkładzie na wielu zmiennych losowych. Obecnie oferuje trzy podstawowe warianty TFP ( JointDistributionSequential , JointDistributionNamed i JointDistributionCoroutine ), z których wszystkie są obsługiwane w JAX. W AutoBatched warianty są również wszystkie obsługiwane.

dist = tfd.JointDistributionSequential([
  tfd.Normal(0., 1.),
  lambda x: tfd.Normal(x, 1e-1)
])
plt.scatter(*dist.sample(1000, seed=random.PRNGKey(0)), alpha=0.5)
plt.show()

png

joint = tfd.JointDistributionNamed(dict(
    e=             tfd.Exponential(rate=1.),
    n=             tfd.Normal(loc=0., scale=2.),
    m=lambda n, e: tfd.Normal(loc=n, scale=e),
    x=lambda    m: tfd.Sample(tfd.Bernoulli(logits=m), 12),
))
joint.sample(seed=random.PRNGKey(0))
{'e': DeviceArray(3.376818, dtype=float32),
 'm': DeviceArray(2.5449684, dtype=float32),
 'n': DeviceArray(-0.6027825, dtype=float32),
 'x': DeviceArray([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)}
Root = tfd.JointDistributionCoroutine.Root
def model():
  e = yield Root(tfd.Exponential(rate=1.))
  n = yield Root(tfd.Normal(loc=0, scale=2.))
  m = yield tfd.Normal(loc=n, scale=e)
  x = yield tfd.Sample(tfd.Bernoulli(logits=m), 12)

joint = tfd.JointDistributionCoroutine(model)

joint.sample(seed=random.PRNGKey(0))
StructTuple(var0=DeviceArray(0.17315261, dtype=float32), var1=DeviceArray(-3.290489, dtype=float32), var2=DeviceArray(-3.1949058, dtype=float32), var3=DeviceArray([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32))

Inne dystrybucje

Procesy Gaussa działają również w trybie JAX!

k1, k2, k3 = random.split(random.PRNGKey(0), 3)
observation_noise_variance = 0.01
f = lambda x: jnp.sin(10*x[..., 0]) * jnp.exp(-x[..., 0]**2)
observation_index_points = random.uniform(
    k1, [50], minval=-1.,maxval= 1.)[..., jnp.newaxis]
observations = f(observation_index_points) + tfd.Normal(
    loc=0., scale=jnp.sqrt(observation_noise_variance)).sample(seed=k2)

index_points = jnp.linspace(-1., 1., 100)[..., jnp.newaxis]

kernel = tfpk.ExponentiatedQuadratic(length_scale=0.1)

gprm = tfd.GaussianProcessRegressionModel(
    kernel=kernel,
    index_points=index_points,
    observation_index_points=observation_index_points,
    observations=observations,
    observation_noise_variance=observation_noise_variance)

samples = gprm.sample(10, seed=k3)
for i in range(10):
  plt.plot(index_points, samples[i], alpha=0.5)
plt.plot(observation_index_points, observations, marker='o', linestyle='')
plt.show()

png

Obsługiwane są również ukryte modele Markowa.

initial_distribution = tfd.Categorical(probs=[0.8, 0.2])
transition_distribution = tfd.Categorical(probs=[[0.7, 0.3],
                                                 [0.2, 0.8]])

observation_distribution = tfd.Normal(loc=[0., 15.], scale=[5., 10.])

model = tfd.HiddenMarkovModel(
    initial_distribution=initial_distribution,
    transition_distribution=transition_distribution,
    observation_distribution=observation_distribution,
    num_steps=7)

print(model.mean())
print(model.log_prob(jnp.zeros(7)))
print(model.sample(seed=random.PRNGKey(0)))
[3.       6.       7.5      8.249999 8.625001 8.812501 8.90625 ]
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/substrates/jax/distributions/hidden_markov_model.py:483: UserWarning: HiddenMarkovModel.log_prob in TFP versions < 0.12.0 had a bug in which the transition model was applied prior to the initial step. This bug has been fixed. You may observe a slight change in behavior.
  'HiddenMarkovModel.log_prob in TFP versions < 0.12.0 had a bug '
-19.855635
[ 1.3641367  0.505798   1.3626463  3.6541772  2.272286  15.10309
 22.794212 ]

Kilka dystrybucje takie jak PixelCNN nie są jeszcze obsługiwane z powodu ścisłych zależnościach na TensorFlow lub XLA niezgodności.

Bijektory

Większość bijektorów TFP jest już obsługiwana w JAX!

tfb.Exp().inverse(1.)
DeviceArray(0., dtype=float32)
bij = tfb.Shift(1.)(tfb.Scale(3.))
print(bij.forward(jnp.ones(5)))
print(bij.inverse(jnp.ones(5)))
[4. 4. 4. 4. 4.]
[0. 0. 0. 0. 0.]
b = tfb.FillScaleTriL(diag_bijector=tfb.Exp(), diag_shift=None)
print(b.forward(x=[0., 0., 0.]))
print(b.inverse(y=[[1., 0], [.5, 2]]))
[[1. 0.]
 [0. 1.]]
[0.6931472 0.5       0.       ]
b = tfb.Chain([tfb.Exp(), tfb.Softplus()])
# or:
# b = tfb.Exp()(tfb.Softplus())
print(b.forward(-jnp.ones(5)))
[1.3678794 1.3678794 1.3678794 1.3678794 1.3678794]

Bijectors są kompatybilne z przekształceń JAX jak jit , grad i vmap .

jit(vmap(tfb.Exp().inverse))(jnp.arange(4.))
DeviceArray([     -inf, 0.       , 0.6931472, 1.0986123], dtype=float32)
x = jnp.linspace(0., 1., 100)
plt.plot(x, jit(grad(lambda x: vmap(tfb.Sigmoid().inverse)(x).sum()))(x))
plt.show()

png

Niektóre bijectors, jak RealNVP i FFJORD nie są jeszcze obsługiwane.

MCMC

Mamy przeniesiony tfp.mcmc Jax, tak więc możemy uruchomić algorytmy jak Hamiltona Monte Carlo (HMC) i No-U-Turn-Sampler (NUTS) w JAX.

target_log_prob = tfd.MultivariateNormalDiag(jnp.zeros(2), jnp.ones(2)).log_prob

W przeciwieństwie TFP na TF, jesteśmy zobowiązani przekazać PRNGKey do sample_chain pomocą seed argumentu słowa kluczowego.

def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-1)
  return tfp.mcmc.sample_chain(1000,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      seed=key)
states, log_probs = jit(run_chain)(random.PRNGKey(0), jnp.zeros(2))
plt.figure()
plt.scatter(*states.T, alpha=0.5)
plt.figure()
plt.plot(log_probs)
plt.show()

png

png

Aby uruchomić wiele łańcuchów, możemy albo przekazać partii państw do sample_chain lub wykorzystanie vmap (choć jeszcze nie zbadane różnice w wynikach pomiędzy dwoma podejściami).

states, log_probs = jit(run_chain)(random.PRNGKey(0), jnp.zeros([10, 2]))
plt.figure()
for i in range(10):
  plt.scatter(*states[:, i].T, alpha=0.5)
plt.figure()
for i in range(10):
  plt.plot(log_probs[:, i], alpha=0.5)
plt.show()

png

png

Optymalizatory

TFP w JAX obsługuje kilka ważnych optymalizatorów, takich jak BFGS i L-BFGS. Skonfigurujmy prostą, skalowaną kwadratową funkcję straty.

minimum = jnp.array([1.0, 1.0])  # The center of the quadratic bowl.
scales = jnp.array([2.0, 3.0])  # The scales along the two axes.

# The objective function and the gradient.
def quadratic_loss(x):
  return jnp.sum(scales * jnp.square(x - minimum))

start = jnp.array([0.6, 0.8])  # Starting point for the search.

BFGS może znaleźć minimum tej straty.

optim_results = tfp.optimizer.bfgs_minimize(
    value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

# Check that the search converged
assert(optim_results.converged)
# Check that the argmin is close to the actual value.
np.testing.assert_allclose(optim_results.position, minimum)
# Print out the total number of function evaluations it took. Should be 5.
print("Function evaluations: %d" % optim_results.num_objective_evaluations)
Function evaluations: 5

Tak samo L-BFGS.

optim_results = tfp.optimizer.lbfgs_minimize(
    value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

# Check that the search converged
assert(optim_results.converged)
# Check that the argmin is close to the actual value.
np.testing.assert_allclose(optim_results.position, minimum)
# Print out the total number of function evaluations it took. Should be 5.
print("Function evaluations: %d" % optim_results.num_objective_evaluations)
Function evaluations: 5

Aby vmap L-BFGS, Ustawmy funkcji, która optymalizuje stratę jednego punktu początkowego.

def optimize_single(start):
  return tfp.optimizer.lbfgs_minimize(
      value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

all_results = jit(vmap(optimize_single))(
    random.normal(random.PRNGKey(0), (10, 2)))
assert all(all_results.converged)
for i in range(10):
  np.testing.assert_allclose(optim_results.position[i], minimum)
print("Function evaluations: %s" % all_results.num_objective_evaluations)
Function evaluations: [6 6 9 6 6 8 6 8 5 9]

Zastrzeżenia

Istnieją pewne podstawowe różnice między TF i JAX, niektóre zachowania TFP będą się różnić między dwoma podłożami i nie wszystkie funkcje są obsługiwane. Na przykład,

  • TFP na JAX nie obsługuje niczego podobnego tf.Variable ponieważ nic tak jak to występuje w JAX. Oznacza to także narzędzia takie jak tfp.util.TransformedVariable nie są obsługiwane albo.
  • tfp.layers nie jest obsługiwana w backend jeszcze, ze względu na uzależnienie od Keras i tf.Variable s.
  • tfp.math.minimize nie działa w TFP na JAX powodu jego uzależnienia od tf.Variable .
  • W TFP na JAX, kształty tensorów są zawsze konkretnymi wartościami całkowitymi i nigdy nie są nieznane/dynamiczne, jak w TFP na TF.
  • Pseudolosowość jest inaczej traktowana w TF i JAX (patrz załącznik).
  • Biblioteki w tfp.experimental nie są gwarantowane istnieć w podłożu JAX.
  • Zasady promocji Dtype są różne w TF i JAX. TFP na JAX próbuje wewnętrznie respektować semantykę dtype TF, aby zapewnić spójność.
  • Bijektory nie zostały jeszcze zarejestrowane jako pytrees JAX.

Aby zobaczyć pełną listę tego, co jest obsługiwana w TFP na JAX, należy zapoznać się z dokumentacją API .

Wniosek

Przenieśliśmy wiele funkcji TFP do JAX i nie możemy się doczekać, co zbudują wszyscy. Niektóre funkcje nie są jeszcze obsługiwane; jeśli straciłeś coś ważnego do ciebie (lub jeśli znajdziesz błąd!), skontaktuj się z nami - można wysłać tfprobability@tensorflow.org lub złożyć sprawę na naszej repo GitHub .

Dodatek: pseudolosowość w JAX

Model Jax generowania liczb pseudolosowych jest (PRNG) jest bezpaństwowcem. W przeciwieństwie do modelu stanowego, nie ma zmiennego stanu globalnego, który ewoluuje po każdym losowym losowaniu. W modelu Jax jest, zaczynamy z kluczem PRNG, który działa jak para 32-bitowych liczb całkowitych. Możemy skonstruować za pomocą tych klawiszy jax.random.PRNGKey .

key = random.PRNGKey(0)  # Creates a key with value [0, 0]
print(key)
[0 0]

Losowo wybrane funkcje w JAX spożywać klucza do deterministycznie produkować losową variate, co oznacza, że nie powinny być używane ponownie. Na przykład, możemy użyć key do próbki o rozkładzie normalnym wartość, ale nie powinno się używać key ponownie w innym miejscu. Ponadto, podając tę samą wartość na random.normal spowoduje taką samą wartość.

print(random.normal(key))
-0.20584226

Jak więc pobrać wiele próbek z jednego klucza? Odpowiedź jest kluczem łupania. Podstawowym założeniem jest to, że możemy podzielić PRNGKey na wiele, a każdy z nowych kluczy mogą być traktowane jako niezależne źródło losowości.

key1, key2 = random.split(key, num=2)
print(key1, key2)
[4146024105  967050713] [2718843009 1272950319]

Podział kluczy jest deterministyczny, ale chaotyczny, więc każdy nowy klucz może być teraz używany do losowania odrębnej próby.

print(random.normal(key1), random.normal(key2))
0.14389051 -1.2515389

Więcej informacji na temat modelu deterministycznego łupania klucz Jax, patrz tej instrukcji .