Ver en TensorFlow.org | Ejecutar en Google Colab | Ver fuente en GitHub | Descargar cuaderno |
Introducción
TensorFlow Probabilidad (PTF) ofrece una serie de JointDistribution
abstracciones que hacen que la inferencia probabilística más fácil al permitir que un usuario de manera fácil expresar un modelo gráfico probabilístico en una forma casi matemática; la abstracción genera métodos para tomar muestras del modelo y evaluar la probabilidad logarítmica de las muestras del modelo. En este tutorial, se revisan variantes "autobatched", que fueron desarrollados después de los originales JointDistribution
abstracciones. En relación con las abstracciones originales, no autobatched, las versiones autobatched son más sencillas de usar y más ergonómicas, lo que permite que muchos modelos se expresen con menos repetición. En esta colab, exploramos un modelo simple con (quizás tedioso) detalle, dejando en claro los problemas que resuelve el autobatching y (con suerte) enseñándole al lector más sobre los conceptos de formas de TFP en el camino.
Antes de la introducción de autobatching, había unas pocas variantes diferentes de JointDistribution
, correspondientes a diferentes estilos sintácticas para expresar modelos probabilísticos: JointDistributionSequential
, JointDistributionNamed
, y JointDistributionCoroutine
. Auobatching existe como un mixin, así que ahora tenemos AutoBatched
variantes de todos ellos. En este tutorial, exploramos las diferencias entre JointDistributionSequential
y JointDistributionSequentialAutoBatched
; sin embargo, todo lo que hacemos aquí es aplicable a las otras variantes sin esencialmente cambios.
Dependencias y requisitos previos
Importación y configuraciones
import functools
import numpy as np
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions
Requisito previo: un problema de regresión bayesiana
Consideraremos un escenario de regresión bayesiana muy simple:
\[ \begin{align*} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{align*} \]
En este modelo, m
y b
se han extraído de las normales estándar, y de las observaciones Y
se dibujan de una distribución normal cuya media depende de las variables aleatorias m
y b
, y algunos (no aleatoria, conocido) covariables X
. (Para simplificar, en este ejemplo, asumimos que se conoce la escala de todas las variables aleatorias).
Para llevar a cabo la inferencia en este modelo, tendríamos que conocer tanto las covariables X
y las observaciones Y
, pero para los propósitos de este tutorial, sólo necesitaremos X
, por lo que definimos un simple maniquí X
:
X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])
Desiderata
En la inferencia probabilística, a menudo queremos realizar dos operaciones básicas:
-
sample
: Dibujo muestras a partir del modelo. -
log_prob
: Cálculo de la probabilidad de registro de una muestra a partir del modelo.
La contribución clave de la PTF de JointDistribution
abstracciones (así como de muchos otros enfoques de programación probabilística) es permitir a los usuarios escribir un modelo de una vez y tener acceso a ambas sample
y log_prob
cálculos.
Tomando nota de que tenemos 7 puntos en nuestro conjunto de datos ( X.shape = (7,)
), ahora podemos afirmar los desiderata para una excelente JointDistribution
:
-
sample()
debe producir una lista deTensors
que tienen forma[(), (), (7,)
], que corresponde a la pendiente escalar, el sesgo de escalar, y las observaciones del vector, respectivamente. -
log_prob(sample())
debe producir un escalar: la probabilidad de registro de una determinada pendiente, sesgo, y observaciones. -
sample([5, 3])
debe producir una lista deTensors
que tienen forma[(5, 3), (5, 3), (5, 3, 7)]
, lo que representa un(5, 3)
- lote de muestras de el modelo. -
log_prob(sample([5, 3]))
debe producir unaTensor
con forma (5, 3).
Ahora vamos a ver una sucesión de JointDistribution
modelos, vemos cómo lograr los desiderata anteriormente, y es de esperar aprender un poco más sobre la PTF da forma a lo largo del camino.
Alerta de spoiler: El enfoque que satisface los deseos anteriormente sin repetitivo añadido se autobatching .
Primer intento; JointDistributionSequential
jds = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
Esta es más o menos una traducción directa del modelo en código. La pendiente m
y el sesgo b
son sencillos. Y
se define utilizando un lambda
-Función: el patrón general es que un lambda
-Función de \(k\) argumentos en una JointDistributionSequential
(JDS) utiliza los anteriores \(k\) distribuciones en el modelo. Tenga en cuenta el orden "inverso".
Llamaremos sample_distributions
, que devuelve la vez una muestra y los subyacentes "sub-distribuciones" que se utilizaron para generar la muestra. (Podríamos haber producido solo la muestra llamando a sample
, y más tarde en el tutorial que será conveniente tener las distribuciones también.) La muestra que producimos está bien:
dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>, <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>, <tf.Tensor: shape=(7,), dtype=float32, numpy= array([ 0.18573815, -1.79962 , -1.8106272 , -3.5971394 , -6.6625295 , -7.308844 , -9.832693 ], dtype=float32)>]
Pero log_prob
produce un resultado con una forma no deseada:
jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy= array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684, -4.4368567, -4.480562 ], dtype=float32)>
Y el muestreo múltiple no funciona:
try:
jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
Tratemos de entender qué va mal.
Una breve reseña: forma de lote y evento
En la PTF, una (no una ordinaria JointDistribution
) distribución de probabilidad tiene una forma evento y una forma por lotes, y la comprensión de la diferencia es crucial para el uso eficaz de la PTF:
- La forma del evento describe la forma de un solo dibujo de la distribución; el dibujo puede depender de las dimensiones. Para distribuciones escalares, la forma del evento es []. Para un MultivariateNormal de 5 dimensiones, la forma del evento es [5].
- La forma de lote describe los sorteos independientes, no distribuidos de forma idéntica, también conocido como un "lote" de distribuciones. Representar un lote de distribuciones en un solo objeto de Python es una de las formas clave en que TFP logra eficiencia a escala.
Para nuestros propósitos, un hecho importante a tener en cuenta es que si llamamos log_prob
en una sola muestra de una distribución, el resultado siempre tendrá una forma que partidos (es decir, tiene como dimensiones más a la derecha) la forma de lotes.
Para una discusión más a fondo de las formas, vea el tutorial "Descripción de TensorFlow Distribuciones formas" .
¿Por qué no log_prob(sample())
Producir un escalar?
Vamos a utilizar nuestro conocimiento de lotes y el evento forma de explorar lo que está pasando con log_prob(sample())
. Aquí está nuestra muestra nuevamente:
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>, <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>, <tf.Tensor: shape=(7,), dtype=float32, numpy= array([ 0.18573815, -1.79962 , -1.8106272 , -3.5971394 , -6.6625295 , -7.308844 , -9.832693 ], dtype=float32)>]
Y aquí están nuestras distribuciones:
dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>, <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>, <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]
La probabilidad logarítmica se calcula sumando las probabilidades logarítmicas de las subdistribuciones en los elementos (emparejados) de las partes:
log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>, <tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>, <tf.Tensor: shape=(7,), dtype=float32, numpy= array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014, -0.9897899, -1.0334952], dtype=float32)>]
np.sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>
Por lo tanto, un nivel de explicación es que el cálculo de probabilidades de registro está regresando a 7 Tensor porque el tercer subcomponente de log_prob_parts
está a 7 tensor. ¿Pero por qué?
Así, vemos que el último elemento de dists
, que corresponde a nuestra distribución en Y
en la formulación mathematial, tiene una batch_shape
de [7]
. En otras palabras, nuestra distribución en Y
es un lote de 7 normales independientes (con diferentes medios y, en este caso, la misma escala).
Ahora entendemos lo que está mal: en JDS, la distribución en Y
tiene batch_shape=[7]
, una muestra de la JDS representa escalares para m
y b
y un "lote" de 7 normales independientes. y log_prob
calcula 7 log-probabilidades separadas, cada una de las cuales representa la probabilidad de registro de dibujo m
y b
y una sola observación Y[i]
en algún X[i]
.
La fijación de log_prob(sample())
con Independent
Recordemos que dists[2]
tiene event_shape=[]
y batch_shape=[7]
:
dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>
Mediante el uso de la PTF Independent
metadistribución, que convierte dimensiones de lotes a las dimensiones de eventos, podemos convertir esto en una distribución con event_shape=[7]
y batch_shape=[]
(vamos a renombramos y_dist_i
porque es una distribución en Y
, con el _i
de pie en nuestro Independent
envoltura):
y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>
Ahora, el log_prob
de un 7-vector es un escalar:
y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>
Bajo las sábanas, Independent
sumas sobre el lote:
y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
Y, de hecho, podemos usar esto para construir un nuevo jds_i
(la i
de nuevo es sinónimo de Independent
), donde log_prob
devuelve un escalar:
jds_i = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m*X + b, scale=1.),
reinterpreted_batch_ndims=1)
])
jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>
Un par de notas:
-
jds_i.log_prob(s)
no es el mismo quetf.reduce_sum(jds.log_prob(s))
. El primero produce la probabilidad logarítmica "correcta" de la distribución conjunta. Estas últimas sumas más de un 7-Tensor, cada elemento de los cuales es la suma de la probabilidad de registro dem
,b
, y un único elemento de la probabilidad de registro deY
, por lo que overcountsm
yb
. (log_prob(m) + log_prob(b) + log_prob(Y)
devuelve un resultado en lugar de lanzar una excepción porque TFP sigue TF y las normas de radiodifusión de NumPy;. La adición de un escalar a un vector produce un resultado-vector de tamaño) - En este caso particular, podríamos haber resuelto el problema y ha logrado el mismo resultado utilizando
MultivariateNormalDiag
en lugar deIndependent(Normal(...))
.MultivariateNormalDiag
es una distribución de valor vectorial (es decir, que ya tiene vector evento de forma). IndeeedMultivariateNormalDiag
podría ser (pero no se), implementado como una composición deIndependent
yNormal
. Vale la pena recordar que dado un vectorV
, las muestras den1 = Normal(loc=V)
, yn2 = MultivariateNormalDiag(loc=V)
son indistinguibles; la diferencia beween estas distribuciones es quen1.log_prob(n1.sample())
es un vector yn2.log_prob(n2.sample())
es un escalar.
¿Varias muestras?
Dibujar varias muestras todavía no funciona:
try:
jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
Pensemos por qué. Cuando llamamos jds_i.sample([5, 3])
, nos primer sorteo muestras para m
y b
, cada una con forma (5, 3)
. A continuación, vamos a tratar de construir una Normal
distribución a través de:
tfd.Normal(loc=m*X + b, scale=1.)
Pero si m
tiene la forma (5, 3)
y X
tiene forma de 7
, no se pueden multiplicar entre sí, y de hecho este es el error que está golpeando:
m = tfd.Normal(0., 1.).sample([5, 3])
try:
m * X
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
Para resolver este problema, vamos a pensar en lo que las propiedades de la distribución a través de Y
tiene que tener. Si hemos llamado jds_i.sample([5, 3])
, entonces sabemos m
y b
ambos tienen la forma (5, 3)
. ¿Qué forma debe una llamada a sample
en la Y
productos de distribución? La respuesta obvia es (5, 3, 7)
: para cada punto de lote, queremos una muestra con el mismo tamaño que X
. Podemos lograr esto utilizando las capacidades de transmisión de TensorFlow, agregando dimensiones adicionales:
m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])
Adición de un eje a ambos m
y b
, podemos definir un nuevo JDS que soporta múltiples muestras:
jds_ia = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
reinterpreted_batch_ndims=1)
])
shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-1.1133379 , 0.16390413, -0.24177533], [-1.1312429 , -0.6224666 , -1.8182136 ], [-0.31343174, -0.32932565, 0.5164407 ], [-0.0119963 , -0.9079621 , 2.3655841 ], [-0.26293617, 0.8229698 , 0.31098196]], dtype=float32)>, <tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-0.02876974, 1.0872147 , 1.0138507 ], [ 0.27367726, -1.331534 , -0.09084719], [ 1.3349475 , -0.68765205, 1.680652 ], [ 0.75436825, 1.3050154 , -0.9415123 ], [-1.2502679 , -0.25730947, 0.74611956]], dtype=float32)>, <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy= array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00, -4.8197951e+00, -5.2986512e+00, -6.6931367e+00], [ 3.6438566e-01, 1.0067395e+00, 1.4542470e+00, 8.1155670e-01, 1.8868095e+00, 2.3877139e+00, 1.0195159e+00], [-8.3624744e-01, 1.2518480e+00, 1.0943471e+00, 1.3052304e+00, -4.5756745e-01, -1.0668410e-01, -7.0669651e-02]], [[-3.1788960e-01, 9.2615485e-03, -3.0963073e+00, -2.2846246e+00, -3.2269263e+00, -6.0213070e+00, -7.4806519e+00], [-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00, -4.5065498e+00, -5.6719379e+00, -4.8012795e+00], [ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00, -7.1365709e+00, -9.6198196e+00, -9.7951422e+00]], [[ 2.0621397e+00, 3.4639853e-01, 7.0252883e-01, -1.4311566e+00, 3.3790007e+00, 1.1619035e+00, -8.9105040e-01], [-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00, -2.7150445e+00, -2.4633870e+00, -2.1841538e+00], [ 7.7627432e-01, 2.2401071e+00, 3.7601702e+00, 2.4245868e+00, 4.0690269e+00, 4.0605016e+00, 5.1753912e+00]], [[ 1.4275590e+00, 3.3346462e+00, 1.5374103e+00, -2.2849756e-01, 9.1219616e-01, -3.1220305e-01, -3.2643962e-01], [-3.1910419e-02, -3.8848895e-01, 9.9946201e-02, -2.3619974e+00, -1.8507402e+00, -3.6830821e+00, -5.4907336e+00], [-7.1941972e-02, 2.1602919e+00, 4.9575748e+00, 4.2317696e+00, 9.3528280e+00, 1.0526063e+01, 1.5262107e+01]], [[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00, -3.2361765e+00, -3.3434000e+00, -2.6849220e+00], [ 1.5006512e-02, -1.9866472e-01, 7.6781356e-01, 1.6228745e+00, 1.4191239e+00, 2.6655579e+00, 4.4663467e+00], [ 2.6599693e+00, 1.2663836e+00, 1.7162113e+00, 1.4839669e+00, 2.0559487e+00, 2.5976877e+00, 2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-12.483114 , -10.139662 , -11.514159 ], [-11.656767 , -17.201958 , -12.132455 ], [-17.838818 , -9.474525 , -11.24898 ], [-13.95219 , -12.490049 , -17.123957 ], [-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>
Como verificación adicional, verificaremos que la probabilidad de registro para un solo punto de lote coincida con lo que teníamos antes:
(jds_ia.log_prob(shaped_sample)[3, 1] -
jds_i.log_prob([shaped_sample[0][3, 1],
shaped_sample[1][3, 1],
shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
AutoBatching para la victoria
¡Excelente! Ahora tenemos una versión de JointDistribution que maneja todos nuestros desiderata: log_prob
vuelve un escalar gracias a la utilización de tfd.Independent
, y múltiples muestras de trabajo ahora que hemos fijado la difusión mediante la adición de ejes adicionales.
¿Y si te dijera que hay una manera mejor y más fácil? No existe, y se llama JointDistributionSequentialAutoBatched
(JDSAB):
jds_ab = tfd.JointDistributionSequentialAutoBatched([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-12.191533 , -10.43885 , -16.371655 ], [-13.292994 , -11.97949 , -16.788685 ], [-15.987699 , -13.435732 , -10.6029 ], [-10.184758 , -11.969714 , -14.275676 ], [-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], dtype=float32)>
¿Como funciona esto? Mientras que usted podría tratar de leer el código para una comprensión más profunda, vamos a dar una visión general breve que es suficiente para la mayoría de los casos de uso:
- Recordemos que nuestro primer problema era que nuestra distribución para
Y
teníabatch_shape=[7]
yevent_shape=[]
, y hemos utilizadoIndependent
para convertir la dimensión lote a una dimensión evento. JDSAB ignora las formas de lote de distribuciones de componentes; en lugar en que trata la forma de proceso por lotes como una propiedad general del modelo, que se supone que es[]
(a menos que se especifique lo contrario mediante el establecimiento debatch_ndims > 0
). El efecto es equivalente a usar tfd.Independent para convertir todas las dimensiones de los lotes de las distribuciones de componentes en dimensiones de eventos, como lo hicimos manualmente anteriormente. - Nuestro segundo problema era una necesidad para dar masajes a las formas de
m
yb
para que pudieran transmitir adecuadamente conX
cuando la creación de múltiples muestras. Con JDSAB, se escribe un modelo para generar una sola muestra, y "levantar" la totalidad del modelo para generar múltiples muestras usando de TensorFlow vectorized_map . (Esta función es análoga a la de JAX vmap ).
Explorar la cuestión de forma de lotes con más detalle, podemos comparar las formas de lote de nuestra original "mala" distribución conjunta jds
, nuestras distribuciones de lote fijo jds_i
y jds_ia
, y nuestra autobatched jds_ab
:
jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])
Vemos que los originales jds
tiene subdistributions con formas diferentes lotes. jds_i
y jds_ia
solucionar este problema mediante la creación de subdistributions con el mismo (vacío) forma de lotes. jds_ab
tiene sólo un único (vacío) forma de lotes.
Vale la pena señalar que JointDistributionSequentialAutoBatched
ofrece cierta generalidad adicional gratis. Supongamos que hacemos la covariables X
(e, implícitamente, las observaciones Y
) de dos dimensiones:
X = np.arange(14).reshape((2, 7))
X
array([[ 0, 1, 2, 3, 4, 5, 6], [ 7, 8, 9, 10, 11, 12, 13]])
Nuestra JointDistributionSequentialAutoBatched
funciona sin cambios (es necesario redefinir el modelo porque la forma de X
se almacena en caché por jds_ab.log_prob
):
jds_ab = tfd.JointDistributionSequentialAutoBatched([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[ 0.1813647 , -0.85994506, 0.27593774], [-0.73323774, 1.1153806 , 0.8841938 ], [ 0.5127983 , -0.29271227, 0.63733214], [ 0.2362284 , -0.919168 , 1.6648189 ], [ 0.26317367, 0.73077047, 2.5395133 ]], dtype=float32)>, <tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[ 0.09636458, 2.0138032 , -0.5054413 ], [ 0.63941646, -1.0785882 , -0.6442188 ], [ 1.2310615 , -0.3293852 , 0.77637213], [ 1.2115169 , -0.98906034, -0.07816773], [-1.1318136 , 0.510014 , 1.036522 ]], dtype=float32)>, <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy= array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01, 8.5992378e-01, -5.3123581e-01, 3.1584005e+00, 2.9044402e+00], [-2.5645006e-01, 3.1554163e-01, 3.1186538e+00, 1.4272424e+00, 1.2843871e+00, 1.2266440e+00, 1.2798605e+00]], [[ 1.5973477e+00, -5.3631151e-01, 6.8143606e-03, -1.4910895e+00, -2.1568544e+00, -2.0513713e+00, -3.1663666e+00], [-4.9448099e+00, -2.8385928e+00, -6.9027486e+00, -5.6543546e+00, -7.2378774e+00, -8.1577444e+00, -9.3582869e+00]], [[-2.1233239e+00, 5.8853775e-02, 1.2024102e+00, 1.6622503e+00, -1.9197327e-01, 1.8647723e+00, 6.4322817e-01], [ 3.7549341e-01, 1.5853541e+00, 2.4594500e+00, 2.1952972e+00, 1.7517658e+00, 2.9666045e+00, 2.5468128e+00]]], [[[ 8.9906776e-01, 6.7375046e-01, 7.3354661e-01, -9.9894643e-01, -3.4606690e+00, -3.4810467e+00, -4.4315586e+00], [-3.0670738e+00, -6.3628020e+00, -6.2538433e+00, -6.8091092e+00, -7.7134805e+00, -8.6319380e+00, -8.6904278e+00]], [[-2.2462025e+00, -3.3060855e-01, 1.8974400e-01, 3.1422038e+00, 4.1483402e+00, 3.5642972e+00, 4.8709240e+00], [ 4.7880130e+00, 5.8790064e+00, 9.6695948e+00, 7.8112822e+00, 1.2022618e+01, 1.2411858e+01, 1.4323385e+01]], [[-1.0189297e+00, -7.8115642e-01, 1.6466728e+00, 8.2378983e-01, 3.0765080e+00, 3.0170646e+00, 5.1899948e+00], [ 6.5285158e+00, 7.8038850e+00, 6.4155884e+00, 9.0899811e+00, 1.0040427e+01, 9.1404457e+00, 1.0411951e+01]]], [[[ 4.5557004e-01, 1.4905317e+00, 1.4904103e+00, 2.9777462e+00, 2.8620450e+00, 3.4745665e+00, 3.8295493e+00], [ 3.9977460e+00, 5.7173767e+00, 7.8421035e+00, 6.3180594e+00, 6.0838981e+00, 8.2257290e+00, 9.6548376e+00]], [[-7.0750320e-01, -3.5972297e-01, 4.3136525e-01, -2.3301599e+00, -5.0374687e-01, -2.8338656e+00, -3.4453444e+00], [-3.1258626e+00, -3.4687450e+00, -1.2045374e+00, -4.0196013e+00, -5.8831010e+00, -4.2965469e+00, -4.1388311e+00]], [[ 2.1969774e+00, 2.4614549e+00, 2.2314475e+00, 1.8392437e+00, 2.8367062e+00, 4.8600502e+00, 4.2273531e+00], [ 6.1879644e+00, 5.1792760e+00, 6.1141996e+00, 5.6517797e+00, 8.9979610e+00, 7.5938139e+00, 9.7918644e+00]]], [[[ 1.5249090e+00, 1.1388919e+00, 8.6903995e-01, 3.0762129e+00, 1.5128503e+00, 3.5204377e+00, 2.4760864e+00], [ 3.4166217e+00, 3.5930209e+00, 3.1694956e+00, 4.5797420e+00, 4.5271711e+00, 2.8774328e+00, 4.7288942e+00]], [[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00, -3.8594103e+00, -4.9681158e+00, -6.4256043e+00, -5.5345035e+00], [-6.4306297e+00, -7.0924540e+00, -8.4075985e+00, -1.0417805e+01, -1.1727266e+01, -1.1196255e+01, -1.1333830e+01]], [[-7.0419472e-01, 1.4568675e+00, 3.7946482e+00, 4.8489718e+00, 6.6498446e+00, 9.0224218e+00, 1.1153137e+01], [ 1.0060651e+01, 1.1998097e+01, 1.5326431e+01, 1.7957514e+01, 1.8323889e+01, 2.0160881e+01, 2.1269085e+01]]], [[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01, 2.3558271e-01, -1.0381399e+00, 1.9387857e+00, -3.3694571e-01], [ 1.6015106e-01, 1.5284677e+00, -4.8567140e-01, -1.7770648e-01, 2.1919653e+00, 1.3015286e+00, 1.3877077e+00]], [[ 1.3688663e+00, 2.6602898e+00, 6.6657305e-01, 4.6554832e+00, 5.7781887e+00, 4.9115267e+00, 4.8446012e+00], [ 5.1983776e+00, 6.2297459e+00, 6.3848300e+00, 8.4291229e+00, 7.1309576e+00, 1.0395646e+01, 8.5736713e+00]], [[ 1.2675294e+00, 5.2844582e+00, 5.1331611e+00, 8.9993315e+00, 1.0794343e+01, 1.4039831e+01, 1.5731170e+01], [ 1.9084715e+01, 2.2191265e+01, 2.3481146e+01, 2.5803375e+01, 2.8632090e+01, 3.0234968e+01, 3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-28.90071 , -23.052422, -19.851362], [-19.775568, -25.894997, -20.302256], [-21.10754 , -23.667885, -20.973007], [-19.249458, -20.87892 , -20.573763], [-22.351208, -25.457762, -24.648403]], dtype=float32)>
Por otro lado, nuestra mano cuidadosamente JointDistributionSequential
ya no funciona:
jds_ia = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
reinterpreted_batch_ndims=1)
])
try:
jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]
Para solucionar esto, tendríamos que añadir un segundo tf.newaxis
tanto m
y b
adaptarse a la forma, y aumentar las reinterpreted_batch_ndims
a 2 en la llamada a Independent
. En este caso, dejar que la maquinaria de procesamiento por lotes automático maneje los problemas de forma es más corto, más fácil y más ergonómico.
Una vez más, observamos que mientras que este cuaderno exploró JointDistributionSequentialAutoBatched
, las otras variantes de JointDistribution
tienen equivalente AutoBatched
. (Para los usuarios de JointDistributionCoroutine
, JointDistributionCoroutineAutoBatched
tiene la ventaja adicional de que ya no hay necesidad de especificar Root
nodos, si usted nunca ha usado JointDistributionCoroutine
. Puede ignorar esta declaración)
Pensamientos concluyentes
En este cuaderno, introdujimos JointDistributionSequentialAutoBatched
y trabajamos a través de un ejemplo sencillo en detalle. ¡Ojalá hayas aprendido algo sobre las formas de TFP y sobre el lote automático!