Distribuciones de juntas en lotes automáticos: un tutorial sencillo

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar cuaderno

Introducción

TensorFlow Probabilidad (PTF) ofrece una serie de JointDistribution abstracciones que hacen que la inferencia probabilística más fácil al permitir que un usuario de manera fácil expresar un modelo gráfico probabilístico en una forma casi matemática; la abstracción genera métodos para tomar muestras del modelo y evaluar la probabilidad logarítmica de las muestras del modelo. En este tutorial, se revisan variantes "autobatched", que fueron desarrollados después de los originales JointDistribution abstracciones. En relación con las abstracciones originales, no autobatched, las versiones autobatched son más sencillas de usar y más ergonómicas, lo que permite que muchos modelos se expresen con menos repetición. En esta colab, exploramos un modelo simple con (quizás tedioso) detalle, dejando en claro los problemas que resuelve el autobatching y (con suerte) enseñándole al lector más sobre los conceptos de formas de TFP en el camino.

Antes de la introducción de autobatching, había unas pocas variantes diferentes de JointDistribution , correspondientes a diferentes estilos sintácticas para expresar modelos probabilísticos: JointDistributionSequential , JointDistributionNamed , y JointDistributionCoroutine . Auobatching existe como un mixin, así que ahora tenemos AutoBatched variantes de todos ellos. En este tutorial, exploramos las diferencias entre JointDistributionSequential y JointDistributionSequentialAutoBatched ; sin embargo, todo lo que hacemos aquí es aplicable a las otras variantes sin esencialmente cambios.

Dependencias y requisitos previos

Importación y configuraciones

Requisito previo: un problema de regresión bayesiana

Consideraremos un escenario de regresión bayesiana muy simple:

\[ \begin{align*} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{align*} \]

En este modelo, m y b se han extraído de las normales estándar, y de las observaciones Y se dibujan de una distribución normal cuya media depende de las variables aleatorias m y b , y algunos (no aleatoria, conocido) covariables X . (Para simplificar, en este ejemplo, asumimos que se conoce la escala de todas las variables aleatorias).

Para llevar a cabo la inferencia en este modelo, tendríamos que conocer tanto las covariables X y las observaciones Y , pero para los propósitos de este tutorial, sólo necesitaremos X , por lo que definimos un simple maniquí X :

X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])

Desiderata

En la inferencia probabilística, a menudo queremos realizar dos operaciones básicas:

  • sample : Dibujo muestras a partir del modelo.
  • log_prob : Cálculo de la probabilidad de registro de una muestra a partir del modelo.

La contribución clave de la PTF de JointDistribution abstracciones (así como de muchos otros enfoques de programación probabilística) es permitir a los usuarios escribir un modelo de una vez y tener acceso a ambas sample y log_prob cálculos.

Tomando nota de que tenemos 7 puntos en nuestro conjunto de datos ( X.shape = (7,) ), ahora podemos afirmar los desiderata para una excelente JointDistribution :

  • sample() debe producir una lista de Tensors que tienen forma [(), (), (7,) ], que corresponde a la pendiente escalar, el sesgo de escalar, y las observaciones del vector, respectivamente.
  • log_prob(sample()) debe producir un escalar: la probabilidad de registro de una determinada pendiente, sesgo, y observaciones.
  • sample([5, 3]) debe producir una lista de Tensors que tienen forma [(5, 3), (5, 3), (5, 3, 7)] , lo que representa un (5, 3) - lote de muestras de el modelo.
  • log_prob(sample([5, 3])) debe producir una Tensor con forma (5, 3).

Ahora vamos a ver una sucesión de JointDistribution modelos, vemos cómo lograr los desiderata anteriormente, y es de esperar aprender un poco más sobre la PTF da forma a lo largo del camino.

Alerta de spoiler: El enfoque que satisface los deseos anteriormente sin repetitivo añadido se autobatching .

Primer intento; JointDistributionSequential

jds = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

Esta es más o menos una traducción directa del modelo en código. La pendiente m y el sesgo b son sencillos. Y se define utilizando un lambda -Función: el patrón general es que un lambda -Función de \(k\) argumentos en una JointDistributionSequential (JDS) utiliza los anteriores \(k\) distribuciones en el modelo. Tenga en cuenta el orden "inverso".

Llamaremos sample_distributions , que devuelve la vez una muestra y los subyacentes "sub-distribuciones" que se utilizaron para generar la muestra. (Podríamos haber producido solo la muestra llamando a sample , y más tarde en el tutorial que será conveniente tener las distribuciones también.) La muestra que producimos está bien:

dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

Pero log_prob produce un resultado con una forma no deseada:

jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=
array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684,
       -4.4368567, -4.480562 ], dtype=float32)>

Y el muestreo múltiple no funciona:

try:
  jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Tratemos de entender qué va mal.

Una breve reseña: forma de lote y evento

En la PTF, una (no una ordinaria JointDistribution ) distribución de probabilidad tiene una forma evento y una forma por lotes, y la comprensión de la diferencia es crucial para el uso eficaz de la PTF:

  • La forma del evento describe la forma de un solo dibujo de la distribución; el dibujo puede depender de las dimensiones. Para distribuciones escalares, la forma del evento es []. Para un MultivariateNormal de 5 dimensiones, la forma del evento es [5].
  • La forma de lote describe los sorteos independientes, no distribuidos de forma idéntica, también conocido como un "lote" de distribuciones. Representar un lote de distribuciones en un solo objeto de Python es una de las formas clave en que TFP logra eficiencia a escala.

Para nuestros propósitos, un hecho importante a tener en cuenta es que si llamamos log_prob en una sola muestra de una distribución, el resultado siempre tendrá una forma que partidos (es decir, tiene como dimensiones más a la derecha) la forma de lotes.

Para una discusión más a fondo de las formas, vea el tutorial "Descripción de TensorFlow Distribuciones formas" .

¿Por qué no log_prob(sample()) Producir un escalar?

Vamos a utilizar nuestro conocimiento de lotes y el evento forma de explorar lo que está pasando con log_prob(sample()) . Aquí está nuestra muestra nuevamente:

sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

Y aquí están nuestras distribuciones:

dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]

La probabilidad logarítmica se calcula sumando las probabilidades logarítmicas de las subdistribuciones en los elementos (emparejados) de las partes:

log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014,
        -0.9897899, -1.0334952], dtype=float32)>]
np.sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>

Por lo tanto, un nivel de explicación es que el cálculo de probabilidades de registro está regresando a 7 Tensor porque el tercer subcomponente de log_prob_parts está a 7 tensor. ¿Pero por qué?

Así, vemos que el último elemento de dists , que corresponde a nuestra distribución en Y en la formulación mathematial, tiene una batch_shape de [7] . En otras palabras, nuestra distribución en Y es un lote de 7 normales independientes (con diferentes medios y, en este caso, la misma escala).

Ahora entendemos lo que está mal: en JDS, la distribución en Y tiene batch_shape=[7] , una muestra de la JDS representa escalares para m y b y un "lote" de 7 normales independientes. y log_prob calcula 7 log-probabilidades separadas, cada una de las cuales representa la probabilidad de registro de dibujo m y b y una sola observación Y[i] en algún X[i] .

La fijación de log_prob(sample()) con Independent

Recordemos que dists[2] tiene event_shape=[] y batch_shape=[7] :

dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>

Mediante el uso de la PTF Independent metadistribución, que convierte dimensiones de lotes a las dimensiones de eventos, podemos convertir esto en una distribución con event_shape=[7] y batch_shape=[] (vamos a renombramos y_dist_i porque es una distribución en Y , con el _i de pie en nuestro Independent envoltura):

y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>

Ahora, el log_prob de un 7-vector es un escalar:

y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>

Bajo las sábanas, Independent sumas sobre el lote:

y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

Y, de hecho, podemos usar esto para construir un nuevo jds_i (la i de nuevo es sinónimo de Independent ), donde log_prob devuelve un escalar:

jds_i = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m*X + b, scale=1.),
        reinterpreted_batch_ndims=1)
])

jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>

Un par de notas:

  • jds_i.log_prob(s) no es el mismo que tf.reduce_sum(jds.log_prob(s)) . El primero produce la probabilidad logarítmica "correcta" de la distribución conjunta. Estas últimas sumas más de un 7-Tensor, cada elemento de los cuales es la suma de la probabilidad de registro de m , b , y un único elemento de la probabilidad de registro de Y , por lo que overcounts m y b . ( log_prob(m) + log_prob(b) + log_prob(Y) devuelve un resultado en lugar de lanzar una excepción porque TFP sigue TF y las normas de radiodifusión de NumPy;. La adición de un escalar a un vector produce un resultado-vector de tamaño)
  • En este caso particular, podríamos haber resuelto el problema y ha logrado el mismo resultado utilizando MultivariateNormalDiag en lugar de Independent(Normal(...)) . MultivariateNormalDiag es una distribución de valor vectorial (es decir, que ya tiene vector evento de forma). Indeeed MultivariateNormalDiag podría ser (pero no se), implementado como una composición de Independent y Normal . Vale la pena recordar que dado un vector V , las muestras de n1 = Normal(loc=V) , y n2 = MultivariateNormalDiag(loc=V) son indistinguibles; la diferencia beween estas distribuciones es que n1.log_prob(n1.sample()) es un vector y n2.log_prob(n2.sample()) es un escalar.

¿Varias muestras?

Dibujar varias muestras todavía no funciona:

try:
  jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Pensemos por qué. Cuando llamamos jds_i.sample([5, 3]) , nos primer sorteo muestras para m y b , cada una con forma (5, 3) . A continuación, vamos a tratar de construir una Normal distribución a través de:

tfd.Normal(loc=m*X + b, scale=1.)

Pero si m tiene la forma (5, 3) y X tiene forma de 7 , no se pueden multiplicar entre sí, y de hecho este es el error que está golpeando:

m = tfd.Normal(0., 1.).sample([5, 3])
try:
  m * X
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Para resolver este problema, vamos a pensar en lo que las propiedades de la distribución a través de Y tiene que tener. Si hemos llamado jds_i.sample([5, 3]) , entonces sabemos m y b ambos tienen la forma (5, 3) . ¿Qué forma debe una llamada a sample en la Y productos de distribución? La respuesta obvia es (5, 3, 7) : para cada punto de lote, queremos una muestra con el mismo tamaño que X . Podemos lograr esto utilizando las capacidades de transmisión de TensorFlow, agregando dimensiones adicionales:

m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])

Adición de un eje a ambos m y b , podemos definir un nuevo JDS que soporta múltiples muestras:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-1.1133379 ,  0.16390413, -0.24177533],
        [-1.1312429 , -0.6224666 , -1.8182136 ],
        [-0.31343174, -0.32932565,  0.5164407 ],
        [-0.0119963 , -0.9079621 ,  2.3655841 ],
        [-0.26293617,  0.8229698 ,  0.31098196]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-0.02876974,  1.0872147 ,  1.0138507 ],
        [ 0.27367726, -1.331534  , -0.09084719],
        [ 1.3349475 , -0.68765205,  1.680652  ],
        [ 0.75436825,  1.3050154 , -0.9415123 ],
        [-1.2502679 , -0.25730947,  0.74611956]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy=
 array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00,
          -4.8197951e+00, -5.2986512e+00, -6.6931367e+00],
         [ 3.6438566e-01,  1.0067395e+00,  1.4542470e+00,  8.1155670e-01,
           1.8868095e+00,  2.3877139e+00,  1.0195159e+00],
         [-8.3624744e-01,  1.2518480e+00,  1.0943471e+00,  1.3052304e+00,
          -4.5756745e-01, -1.0668410e-01, -7.0669651e-02]],
 
        [[-3.1788960e-01,  9.2615485e-03, -3.0963073e+00, -2.2846246e+00,
          -3.2269263e+00, -6.0213070e+00, -7.4806519e+00],
         [-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00,
          -4.5065498e+00, -5.6719379e+00, -4.8012795e+00],
         [ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00,
          -7.1365709e+00, -9.6198196e+00, -9.7951422e+00]],
 
        [[ 2.0621397e+00,  3.4639853e-01,  7.0252883e-01, -1.4311566e+00,
           3.3790007e+00,  1.1619035e+00, -8.9105040e-01],
         [-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00,
          -2.7150445e+00, -2.4633870e+00, -2.1841538e+00],
         [ 7.7627432e-01,  2.2401071e+00,  3.7601702e+00,  2.4245868e+00,
           4.0690269e+00,  4.0605016e+00,  5.1753912e+00]],
 
        [[ 1.4275590e+00,  3.3346462e+00,  1.5374103e+00, -2.2849756e-01,
           9.1219616e-01, -3.1220305e-01, -3.2643962e-01],
         [-3.1910419e-02, -3.8848895e-01,  9.9946201e-02, -2.3619974e+00,
          -1.8507402e+00, -3.6830821e+00, -5.4907336e+00],
         [-7.1941972e-02,  2.1602919e+00,  4.9575748e+00,  4.2317696e+00,
           9.3528280e+00,  1.0526063e+01,  1.5262107e+01]],
 
        [[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00,
          -3.2361765e+00, -3.3434000e+00, -2.6849220e+00],
         [ 1.5006512e-02, -1.9866472e-01,  7.6781356e-01,  1.6228745e+00,
           1.4191239e+00,  2.6655579e+00,  4.4663467e+00],
         [ 2.6599693e+00,  1.2663836e+00,  1.7162113e+00,  1.4839669e+00,
           2.0559487e+00,  2.5976877e+00,  2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.483114 , -10.139662 , -11.514159 ],
       [-11.656767 , -17.201958 , -12.132455 ],
       [-17.838818 ,  -9.474525 , -11.24898  ],
       [-13.95219  , -12.490049 , -17.123957 ],
       [-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>

Como verificación adicional, verificaremos que la probabilidad de registro para un solo punto de lote coincida con lo que teníamos antes:

(jds_ia.log_prob(shaped_sample)[3, 1] -
 jds_i.log_prob([shaped_sample[0][3, 1],
                 shaped_sample[1][3, 1],
                 shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

AutoBatching para la victoria

¡Excelente! Ahora tenemos una versión de JointDistribution que maneja todos nuestros desiderata: log_prob vuelve un escalar gracias a la utilización de tfd.Independent , y múltiples muestras de trabajo ahora que hemos fijado la difusión mediante la adición de ejes adicionales.

¿Y si te dijera que hay una manera mejor y más fácil? No existe, y se llama JointDistributionSequentialAutoBatched (JDSAB):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.191533 , -10.43885  , -16.371655 ],
       [-13.292994 , -11.97949  , -16.788685 ],
       [-15.987699 , -13.435732 , -10.6029   ],
       [-10.184758 , -11.969714 , -14.275676 ],
       [-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)>

¿Como funciona esto? Mientras que usted podría tratar de leer el código para una comprensión más profunda, vamos a dar una visión general breve que es suficiente para la mayoría de los casos de uso:

  • Recordemos que nuestro primer problema era que nuestra distribución para Y tenía batch_shape=[7] y event_shape=[] , y hemos utilizado Independent para convertir la dimensión lote a una dimensión evento. JDSAB ignora las formas de lote de distribuciones de componentes; en lugar en que trata la forma de proceso por lotes como una propiedad general del modelo, que se supone que es [] (a menos que se especifique lo contrario mediante el establecimiento de batch_ndims > 0 ). El efecto es equivalente a usar tfd.Independent para convertir todas las dimensiones de los lotes de las distribuciones de componentes en dimensiones de eventos, como lo hicimos manualmente anteriormente.
  • Nuestro segundo problema era una necesidad para dar masajes a las formas de m y b para que pudieran transmitir adecuadamente con X cuando la creación de múltiples muestras. Con JDSAB, se escribe un modelo para generar una sola muestra, y "levantar" la totalidad del modelo para generar múltiples muestras usando de TensorFlow vectorized_map . (Esta función es análoga a la de JAX vmap ).

Explorar la cuestión de forma de lotes con más detalle, podemos comparar las formas de lote de nuestra original "mala" distribución conjunta jds , nuestras distribuciones de lote fijo jds_i y jds_ia , y nuestra autobatched jds_ab :

jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])

Vemos que los originales jds tiene subdistributions con formas diferentes lotes. jds_i y jds_ia solucionar este problema mediante la creación de subdistributions con el mismo (vacío) forma de lotes. jds_ab tiene sólo un único (vacío) forma de lotes.

Vale la pena señalar que JointDistributionSequentialAutoBatched ofrece cierta generalidad adicional gratis. Supongamos que hacemos la covariables X (e, implícitamente, las observaciones Y ) de dos dimensiones:

X = np.arange(14).reshape((2, 7))
X
array([[ 0,  1,  2,  3,  4,  5,  6],
       [ 7,  8,  9, 10, 11, 12, 13]])

Nuestra JointDistributionSequentialAutoBatched funciona sin cambios (es necesario redefinir el modelo porque la forma de X se almacena en caché por jds_ab.log_prob ):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.1813647 , -0.85994506,  0.27593774],
        [-0.73323774,  1.1153806 ,  0.8841938 ],
        [ 0.5127983 , -0.29271227,  0.63733214],
        [ 0.2362284 , -0.919168  ,  1.6648189 ],
        [ 0.26317367,  0.73077047,  2.5395133 ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.09636458,  2.0138032 , -0.5054413 ],
        [ 0.63941646, -1.0785882 , -0.6442188 ],
        [ 1.2310615 , -0.3293852 ,  0.77637213],
        [ 1.2115169 , -0.98906034, -0.07816773],
        [-1.1318136 ,  0.510014  ,  1.036522  ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy=
 array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01,
            8.5992378e-01, -5.3123581e-01,  3.1584005e+00,
            2.9044402e+00],
          [-2.5645006e-01,  3.1554163e-01,  3.1186538e+00,
            1.4272424e+00,  1.2843871e+00,  1.2266440e+00,
            1.2798605e+00]],
 
         [[ 1.5973477e+00, -5.3631151e-01,  6.8143606e-03,
           -1.4910895e+00, -2.1568544e+00, -2.0513713e+00,
           -3.1663666e+00],
          [-4.9448099e+00, -2.8385928e+00, -6.9027486e+00,
           -5.6543546e+00, -7.2378774e+00, -8.1577444e+00,
           -9.3582869e+00]],
 
         [[-2.1233239e+00,  5.8853775e-02,  1.2024102e+00,
            1.6622503e+00, -1.9197327e-01,  1.8647723e+00,
            6.4322817e-01],
          [ 3.7549341e-01,  1.5853541e+00,  2.4594500e+00,
            2.1952972e+00,  1.7517658e+00,  2.9666045e+00,
            2.5468128e+00]]],
 
 
        [[[ 8.9906776e-01,  6.7375046e-01,  7.3354661e-01,
           -9.9894643e-01, -3.4606690e+00, -3.4810467e+00,
           -4.4315586e+00],
          [-3.0670738e+00, -6.3628020e+00, -6.2538433e+00,
           -6.8091092e+00, -7.7134805e+00, -8.6319380e+00,
           -8.6904278e+00]],
 
         [[-2.2462025e+00, -3.3060855e-01,  1.8974400e-01,
            3.1422038e+00,  4.1483402e+00,  3.5642972e+00,
            4.8709240e+00],
          [ 4.7880130e+00,  5.8790064e+00,  9.6695948e+00,
            7.8112822e+00,  1.2022618e+01,  1.2411858e+01,
            1.4323385e+01]],
 
         [[-1.0189297e+00, -7.8115642e-01,  1.6466728e+00,
            8.2378983e-01,  3.0765080e+00,  3.0170646e+00,
            5.1899948e+00],
          [ 6.5285158e+00,  7.8038850e+00,  6.4155884e+00,
            9.0899811e+00,  1.0040427e+01,  9.1404457e+00,
            1.0411951e+01]]],
 
 
        [[[ 4.5557004e-01,  1.4905317e+00,  1.4904103e+00,
            2.9777462e+00,  2.8620450e+00,  3.4745665e+00,
            3.8295493e+00],
          [ 3.9977460e+00,  5.7173767e+00,  7.8421035e+00,
            6.3180594e+00,  6.0838981e+00,  8.2257290e+00,
            9.6548376e+00]],
 
         [[-7.0750320e-01, -3.5972297e-01,  4.3136525e-01,
           -2.3301599e+00, -5.0374687e-01, -2.8338656e+00,
           -3.4453444e+00],
          [-3.1258626e+00, -3.4687450e+00, -1.2045374e+00,
           -4.0196013e+00, -5.8831010e+00, -4.2965469e+00,
           -4.1388311e+00]],
 
         [[ 2.1969774e+00,  2.4614549e+00,  2.2314475e+00,
            1.8392437e+00,  2.8367062e+00,  4.8600502e+00,
            4.2273531e+00],
          [ 6.1879644e+00,  5.1792760e+00,  6.1141996e+00,
            5.6517797e+00,  8.9979610e+00,  7.5938139e+00,
            9.7918644e+00]]],
 
 
        [[[ 1.5249090e+00,  1.1388919e+00,  8.6903995e-01,
            3.0762129e+00,  1.5128503e+00,  3.5204377e+00,
            2.4760864e+00],
          [ 3.4166217e+00,  3.5930209e+00,  3.1694956e+00,
            4.5797420e+00,  4.5271711e+00,  2.8774328e+00,
            4.7288942e+00]],
 
         [[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00,
           -3.8594103e+00, -4.9681158e+00, -6.4256043e+00,
           -5.5345035e+00],
          [-6.4306297e+00, -7.0924540e+00, -8.4075985e+00,
           -1.0417805e+01, -1.1727266e+01, -1.1196255e+01,
           -1.1333830e+01]],
 
         [[-7.0419472e-01,  1.4568675e+00,  3.7946482e+00,
            4.8489718e+00,  6.6498446e+00,  9.0224218e+00,
            1.1153137e+01],
          [ 1.0060651e+01,  1.1998097e+01,  1.5326431e+01,
            1.7957514e+01,  1.8323889e+01,  2.0160881e+01,
            2.1269085e+01]]],
 
 
        [[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01,
            2.3558271e-01, -1.0381399e+00,  1.9387857e+00,
           -3.3694571e-01],
          [ 1.6015106e-01,  1.5284677e+00, -4.8567140e-01,
           -1.7770648e-01,  2.1919653e+00,  1.3015286e+00,
            1.3877077e+00]],
 
         [[ 1.3688663e+00,  2.6602898e+00,  6.6657305e-01,
            4.6554832e+00,  5.7781887e+00,  4.9115267e+00,
            4.8446012e+00],
          [ 5.1983776e+00,  6.2297459e+00,  6.3848300e+00,
            8.4291229e+00,  7.1309576e+00,  1.0395646e+01,
            8.5736713e+00]],
 
         [[ 1.2675294e+00,  5.2844582e+00,  5.1331611e+00,
            8.9993315e+00,  1.0794343e+01,  1.4039831e+01,
            1.5731170e+01],
          [ 1.9084715e+01,  2.2191265e+01,  2.3481146e+01,
            2.5803375e+01,  2.8632090e+01,  3.0234968e+01,
            3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-28.90071 , -23.052422, -19.851362],
       [-19.775568, -25.894997, -20.302256],
       [-21.10754 , -23.667885, -20.973007],
       [-19.249458, -20.87892 , -20.573763],
       [-22.351208, -25.457762, -24.648403]], dtype=float32)>

Por otro lado, nuestra mano cuidadosamente JointDistributionSequential ya no funciona:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

try:
  jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]

Para solucionar esto, tendríamos que añadir un segundo tf.newaxis tanto m y b adaptarse a la forma, y aumentar las reinterpreted_batch_ndims a 2 en la llamada a Independent . En este caso, dejar que la maquinaria de procesamiento por lotes automático maneje los problemas de forma es más corto, más fácil y más ergonómico.

Una vez más, observamos que mientras que este cuaderno exploró JointDistributionSequentialAutoBatched , las otras variantes de JointDistribution tienen equivalente AutoBatched . (Para los usuarios de JointDistributionCoroutine , JointDistributionCoroutineAutoBatched tiene la ventaja adicional de que ya no hay necesidad de especificar Root nodos, si usted nunca ha usado JointDistributionCoroutine . Puede ignorar esta declaración)

Pensamientos concluyentes

En este cuaderno, introdujimos JointDistributionSequentialAutoBatched y trabajamos a través de un ejemplo sencillo en detalle. ¡Ojalá hayas aprendido algo sobre las formas de TFP y sobre el lote automático!