Distribuições de junta em lote automático: um tutorial suave

Ver no TensorFlow.org Executar no Google Colab Ver fonte no GitHub Baixar caderno

Introdução

TensorFlow Probabilidade (TFP) oferece uma série de JointDistribution abstrações que fazem inferência probabilística mais fácil, permitindo que um usuário facilmente expressar um modelo gráfico probabilística de uma forma quase matemática; a abstração gera métodos para amostragem do modelo e avaliação da probabilidade de log de amostras do modelo. Neste tutorial, revisamos variantes "autobatched", que foram desenvolvidos após os originais JointDistribution abstrações. Em relação às abstrações originais não autobatched, as versões autobatched são mais simples de usar e mais ergonômicas, permitindo que muitos modelos sejam expressos com menos clichês. Neste colab, exploramos um modelo simples em detalhes (talvez tediosos), deixando claro os problemas que o autobatching resolve e (espero) ensinando o leitor mais sobre os conceitos de forma da TFP ao longo do caminho.

Antes da introdução de autobatching, houve algumas variantes diferentes de JointDistribution , correspondendo a diferentes estilos sintáticas para expressar modelos probabilísticos: JointDistributionSequential , JointDistributionNamed e JointDistributionCoroutine . Auobatching existe como um mixin, então agora temos AutoBatched variantes de todos estes. Neste tutorial, vamos explorar as diferenças entre JointDistributionSequential e JointDistributionSequentialAutoBatched ; no entanto, tudo o que fazemos aqui é aplicável às outras variantes, essencialmente sem alterações.

Dependências e pré-requisitos

Importar e configurar

Pré-requisito: Um Problema de Regressão Bayesiana

Vamos considerar um cenário de regressão bayesiana muito simples:

\[ \begin{align*} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{align*} \]

Neste modelo, m e b são retirados de indivíduos normais padrão, e as observações Y são desenhados a partir de uma distribuição normal com média depende das variáveis aleatórias m e b , e alguns (não aleatória, conhecido) covariáveis X . (Para simplificar, neste exemplo, assumimos que a escala de todas as variáveis ​​aleatórias é conhecida.)

Para realizar inferência neste modelo, precisaríamos saber ambas as covariáveis X e as observações Y , mas para os fins deste tutorial, nós só precisa de X , então definimos um simples manequim X :

X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])

Desiderata

Na inferência probabilística, geralmente queremos realizar duas operações básicas:

  • sample : a tiragem de amostras a partir do modelo.
  • log_prob : Calculando a probabilidade log de uma amostra a partir do modelo.

A contribuição fundamental de da TFP JointDistribution abstrações (bem como de muitas outras abordagens para a programação probabilística) é permitir aos usuários escrever um modelo de uma vez e ter acesso a ambas as sample e log_prob cálculos.

Notando que temos 7 pontos em nosso conjunto de dados ( X.shape = (7,) ), podemos agora afirmar os desideratos para um excelente JointDistribution :

  • sample() deve produzir uma lista dos Tensors que têm forma [(), (), (7,) ], correspondente à inclinação escalar, viés escalar, e observações vector, respectivamente.
  • log_prob(sample()) deverá produzir um escalar: a probabilidade de log de uma determinada inclinação, de polarização, e observações.
  • sample([5, 3]) deve produzir uma lista dos Tensors que têm forma [(5, 3), (5, 3), (5, 3, 7)] , que representa um (5, 3) - em lotes de amostras a partir de o modelo.
  • log_prob(sample([5, 3])) deve produzir uma Tensor com forma (5, 3).

Vamos agora olhar para uma sucessão de JointDistribution modelos, ver como alcançar os desideratos acima, e espero aprender um pouco mais sobre TFP molda ao longo do caminho.

Alerta de spoiler: A abordagem que satisfaz os desideratos acima sem clichê adicionado é autobatching .

Primeira tentativa; JointDistributionSequential

jds = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

Esta é mais ou menos uma tradução direta do modelo em código. O declive m e viés b são simples. Y é definida usando um lambda -função: o padrão geral é que um lambda -função de \(k\) argumentos em um JointDistributionSequential (JDS) utiliza os anteriores \(k\) distribuições no modelo. Observe a ordem "reversa".

Vamos chamar sample_distributions , que retorna tanto uma amostra e as subjacentes "sub-distribuições" que foram usados para gerar a amostra. (Poderíamos ter produzido apenas a amostra chamando sample , mais tarde no tutorial, será conveniente ter as distribuições também.) A amostra que produzimos é muito bem:

dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

Mas log_prob produz um resultado com uma forma indesejada:

jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=
array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684,
       -4.4368567, -4.480562 ], dtype=float32)>

E a amostragem múltipla não funciona:

try:
  jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Vamos tentar entender o que está errado.

Uma breve revisão: forma de lote e evento

Em TFP, uma (não um comum JointDistribution ) distribuição de probabilidade tem uma forma evento e uma forma de lote, e compreender a diferença é crucial para o uso eficaz da TFP:

  • A forma do evento descreve a forma de um único desenho da distribuição; o desenho pode ser dependente das dimensões. Para distribuições escalares, a forma do evento é []. Para um MultivariateNormal 5-dimensional, a forma do evento é [5].
  • A forma de lote descreve desenhos independentes, não distribuídos de forma idêntica, também conhecido como "lote" de distribuições. Representar um lote de distribuições em um único objeto Python é uma das principais maneiras pelas quais o TFP alcança eficiência em escala.

Para nossos propósitos, um fato fundamental para manter em mente é que se nós chamamos log_prob em uma única amostra de uma distribuição, o resultado terá sempre uma forma que jogos (ou seja, tem como dimensões mais à direita) a forma de lote.

Para uma discussão mais aprofundada das formas, consulte "Compreender TensorFlow Distribuições Shapes" tutorial .

Por que não log_prob(sample()) Produzir um escalar?

Vamos usar o nosso conhecimento da forma de lote e evento para explorar o que está acontecendo com log_prob(sample()) . Aqui está nosso exemplo novamente:

sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

E aqui estão nossas distribuições:

dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]

A probabilidade de log é calculada somando as probabilidades de log das sub-distribuições nos elementos (combinados) das partes:

log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014,
        -0.9897899, -1.0334952], dtype=float32)>]
np.sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>

Assim, um nível de explicação é que o cálculo de probabilidade log está retornando a 7 Tensor porque a terceira subcomponente do log_prob_parts fica a 7 Tensor. Mas por que?

Bem, vemos que o último elemento de dists , o que corresponde a nossa distribuição de mais de Y na formulação mathematial, tem uma batch_shape de [7] . Em outras palavras, a distribuição ao longo Y é um lote de 7 normais independentes (com diferentes meios e, neste caso, a mesma escala).

Compreendemos agora o que há de errado: na JDS, a distribuição ao longo do Y tem batch_shape=[7] , uma amostra do JDS representa escalares para m e b e um "lote" de 7 normais independentes. e log_prob calcula 7 log-probabilidades separadas, cada uma das quais representa a probabilidade de registo de tiragem m e b e uma única observação Y[i] em algum X[i] .

Fixação log_prob(sample()) com Independent

Recorde-se que dists[2] tem event_shape=[] e batch_shape=[7] :

dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>

Através da utilização de PTF Independent metadistribuição, que converte as dimensões de lote para dimensões de eventos, que pode converter este em uma distribuição com event_shape=[7] e batch_shape=[] (que vai mudar o nome y_dist_i porque é uma distribuição em Y , com o _i permanente in para o nosso Independent embalagem):

y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>

Agora, a log_prob de um 7-vector é um escalar:

y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>

Debaixo das cobertas, Independent somas sobre o lote:

y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

E, de fato, podemos usar isso para construir um novo jds_i (o i novamente significa Independent ), onde log_prob retorna um escalar:

jds_i = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m*X + b, scale=1.),
        reinterpreted_batch_ndims=1)
])

jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>

Algumas notas:

  • jds_i.log_prob(s) não é o mesmo que tf.reduce_sum(jds.log_prob(s)) . O primeiro produz o log de probabilidade "correto" da distribuição conjunta. Os últimos somas sobre um 7-Tensor, cada elemento de que é a soma da probabilidade de log de m , b , e um único elemento da probabilidade log de Y , de modo que overcounts m e b . ( log_prob(m) + log_prob(b) + log_prob(Y) retorna um resultado em vez de gerar uma excepção porque TFP segue TF e regras de radiodifusão de Numpy;. Adicionando um escalar para um vector produz um resultado do tamanho do vetor-)
  • Neste caso particular, poderíamos ter resolvido o problema e obteve o mesmo resultado usando MultivariateNormalDiag vez de Independent(Normal(...)) . MultivariateNormalDiag é uma distribuição vectorial (isto é, que já tem vector de evento-forma). Indeeed MultivariateNormalDiag poderia ser (mas não é) implementado como uma composição Independent e Normal . Vale a pena lembrar que dado um vector de V , a partir de amostras n1 = Normal(loc=V) , e n2 = MultivariateNormalDiag(loc=V) são indistinguíveis; a diferença beween estas distribuições é que n1.log_prob(n1.sample()) é um vector e n2.log_prob(n2.sample()) é um escalar.

Amostras múltiplas?

Desenhar várias amostras ainda não funciona:

try:
  jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Vamos pensar no porquê. Quando chamamos jds_i.sample([5, 3]) , que vai primeiro extrair amostras para m e b , cada um com forma (5, 3) . Em seguida, vamos tentar construir uma Normal distribuição via:

tfd.Normal(loc=m*X + b, scale=1.)

Mas se m tem forma (5, 3) e X tem forma 7 , não podemos multiplicá-los juntos, e na verdade este é o erro que está batendo:

m = tfd.Normal(0., 1.).sample([5, 3])
try:
  m * X
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Para resolver esse problema, vamos pensar sobre o que propriedades a distribuição ao longo Y tem que ter. Se nós chamamos jds_i.sample([5, 3]) , então sabemos m e b vai ambos têm forma (5, 3) . Que forma deve uma chamada para sample na Y produtos distribuição? A resposta óbvia é (5, 3, 7) : para cada ponto de lote, queremos uma amostra com o mesmo tamanho que X . Podemos conseguir isso usando os recursos de transmissão do TensorFlow, adicionando dimensões extras:

m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])

Adicionando um eixo de ambos m e b , pode-se definir uma nova JDS que suporta múltiplas amostras:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-1.1133379 ,  0.16390413, -0.24177533],
        [-1.1312429 , -0.6224666 , -1.8182136 ],
        [-0.31343174, -0.32932565,  0.5164407 ],
        [-0.0119963 , -0.9079621 ,  2.3655841 ],
        [-0.26293617,  0.8229698 ,  0.31098196]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-0.02876974,  1.0872147 ,  1.0138507 ],
        [ 0.27367726, -1.331534  , -0.09084719],
        [ 1.3349475 , -0.68765205,  1.680652  ],
        [ 0.75436825,  1.3050154 , -0.9415123 ],
        [-1.2502679 , -0.25730947,  0.74611956]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy=
 array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00,
          -4.8197951e+00, -5.2986512e+00, -6.6931367e+00],
         [ 3.6438566e-01,  1.0067395e+00,  1.4542470e+00,  8.1155670e-01,
           1.8868095e+00,  2.3877139e+00,  1.0195159e+00],
         [-8.3624744e-01,  1.2518480e+00,  1.0943471e+00,  1.3052304e+00,
          -4.5756745e-01, -1.0668410e-01, -7.0669651e-02]],
 
        [[-3.1788960e-01,  9.2615485e-03, -3.0963073e+00, -2.2846246e+00,
          -3.2269263e+00, -6.0213070e+00, -7.4806519e+00],
         [-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00,
          -4.5065498e+00, -5.6719379e+00, -4.8012795e+00],
         [ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00,
          -7.1365709e+00, -9.6198196e+00, -9.7951422e+00]],
 
        [[ 2.0621397e+00,  3.4639853e-01,  7.0252883e-01, -1.4311566e+00,
           3.3790007e+00,  1.1619035e+00, -8.9105040e-01],
         [-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00,
          -2.7150445e+00, -2.4633870e+00, -2.1841538e+00],
         [ 7.7627432e-01,  2.2401071e+00,  3.7601702e+00,  2.4245868e+00,
           4.0690269e+00,  4.0605016e+00,  5.1753912e+00]],
 
        [[ 1.4275590e+00,  3.3346462e+00,  1.5374103e+00, -2.2849756e-01,
           9.1219616e-01, -3.1220305e-01, -3.2643962e-01],
         [-3.1910419e-02, -3.8848895e-01,  9.9946201e-02, -2.3619974e+00,
          -1.8507402e+00, -3.6830821e+00, -5.4907336e+00],
         [-7.1941972e-02,  2.1602919e+00,  4.9575748e+00,  4.2317696e+00,
           9.3528280e+00,  1.0526063e+01,  1.5262107e+01]],
 
        [[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00,
          -3.2361765e+00, -3.3434000e+00, -2.6849220e+00],
         [ 1.5006512e-02, -1.9866472e-01,  7.6781356e-01,  1.6228745e+00,
           1.4191239e+00,  2.6655579e+00,  4.4663467e+00],
         [ 2.6599693e+00,  1.2663836e+00,  1.7162113e+00,  1.4839669e+00,
           2.0559487e+00,  2.5976877e+00,  2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.483114 , -10.139662 , -11.514159 ],
       [-11.656767 , -17.201958 , -12.132455 ],
       [-17.838818 ,  -9.474525 , -11.24898  ],
       [-13.95219  , -12.490049 , -17.123957 ],
       [-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>

Como uma verificação extra, verificaremos se a probabilidade de log para um único ponto de lote corresponde ao que tínhamos antes:

(jds_ia.log_prob(shaped_sample)[3, 1] -
 jds_i.log_prob([shaped_sample[0][3, 1],
                 shaped_sample[1][3, 1],
                 shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

AutoBatching For The Win

Excelente! Temos agora uma versão do JointDistribution que lida com toda a nossa desideratos: log_prob retorna um escalares graças ao uso de tfd.Independent e várias amostras de trabalhar agora que fixa transmitindo adicionando eixos extras.

E se eu dissesse que existe uma maneira mais fácil e melhor? Há, e é chamado JointDistributionSequentialAutoBatched (JDSAB):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.191533 , -10.43885  , -16.371655 ],
       [-13.292994 , -11.97949  , -16.788685 ],
       [-15.987699 , -13.435732 , -10.6029   ],
       [-10.184758 , -11.969714 , -14.275676 ],
       [-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)>

Como é que isso funciona? Enquanto você poderia tentar ler o código para uma compreensão profunda, vamos dar uma visão breve o que é suficiente para a maioria dos casos de uso:

  • Recorde-se que o nosso primeiro problema é que a nossa distribuição para Y tinha batch_shape=[7] e event_shape=[] , e utilizou-se Independent para converter a dimensão do lote para um evento de dimensão. O JDSAB ignora as formas de lote das distribuições de componentes; em vez disso, trata forma de lote como uma propriedade geral do modelo, o qual é assumido como sendo [] (a menos que especificado de outra forma pela definição batch_ndims > 0 ). O efeito é equivalente a usar tfd.Independent para converter todas as dimensões de lote de distribuições de componentes em dimensões de eventos, como fizemos manualmente acima.
  • A segunda foi um problema a necessidade de massagem as formas de m e b , para que pudessem difundir adequadamente com X quando a criação de várias amostras. Com JDSAB, você escreve um modelo para gerar uma única amostra, e nós "levantar" todo o modelo para gerar várias amostras usando de TensorFlow vectorized_map . (Este recurso é análogo ao do JAX VMAP .)

Explorando a questão de forma lote com mais detalhes, podemos comparar as formas de lote do nosso "mau" original de distribuição conjuntas jds , nossas distribuições fixo-lote jds_i e jds_ia , e nossa autobatched jds_ab :

jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])

Vemos que os originais jds tem subdistributions com diferentes formas de lote. jds_i e jds_ia corrigir isto criando subdistributions com a mesma forma de lote (vazio). jds_ab tem apenas uma única forma de lote (vazio).

É importante notar que JointDistributionSequentialAutoBatched oferece alguma generalidade adicional de graça. Suponha-se que fazer o co-variáveis X (e, implicitamente, as observações Y ) bidimensional:

X = np.arange(14).reshape((2, 7))
X
array([[ 0,  1,  2,  3,  4,  5,  6],
       [ 7,  8,  9, 10, 11, 12, 13]])

Nossa JointDistributionSequentialAutoBatched funciona sem alterações (precisamos redefinir o modelo porque a forma de X é armazenada em cache pelo jds_ab.log_prob ):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.1813647 , -0.85994506,  0.27593774],
        [-0.73323774,  1.1153806 ,  0.8841938 ],
        [ 0.5127983 , -0.29271227,  0.63733214],
        [ 0.2362284 , -0.919168  ,  1.6648189 ],
        [ 0.26317367,  0.73077047,  2.5395133 ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.09636458,  2.0138032 , -0.5054413 ],
        [ 0.63941646, -1.0785882 , -0.6442188 ],
        [ 1.2310615 , -0.3293852 ,  0.77637213],
        [ 1.2115169 , -0.98906034, -0.07816773],
        [-1.1318136 ,  0.510014  ,  1.036522  ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy=
 array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01,
            8.5992378e-01, -5.3123581e-01,  3.1584005e+00,
            2.9044402e+00],
          [-2.5645006e-01,  3.1554163e-01,  3.1186538e+00,
            1.4272424e+00,  1.2843871e+00,  1.2266440e+00,
            1.2798605e+00]],
 
         [[ 1.5973477e+00, -5.3631151e-01,  6.8143606e-03,
           -1.4910895e+00, -2.1568544e+00, -2.0513713e+00,
           -3.1663666e+00],
          [-4.9448099e+00, -2.8385928e+00, -6.9027486e+00,
           -5.6543546e+00, -7.2378774e+00, -8.1577444e+00,
           -9.3582869e+00]],
 
         [[-2.1233239e+00,  5.8853775e-02,  1.2024102e+00,
            1.6622503e+00, -1.9197327e-01,  1.8647723e+00,
            6.4322817e-01],
          [ 3.7549341e-01,  1.5853541e+00,  2.4594500e+00,
            2.1952972e+00,  1.7517658e+00,  2.9666045e+00,
            2.5468128e+00]]],
 
 
        [[[ 8.9906776e-01,  6.7375046e-01,  7.3354661e-01,
           -9.9894643e-01, -3.4606690e+00, -3.4810467e+00,
           -4.4315586e+00],
          [-3.0670738e+00, -6.3628020e+00, -6.2538433e+00,
           -6.8091092e+00, -7.7134805e+00, -8.6319380e+00,
           -8.6904278e+00]],
 
         [[-2.2462025e+00, -3.3060855e-01,  1.8974400e-01,
            3.1422038e+00,  4.1483402e+00,  3.5642972e+00,
            4.8709240e+00],
          [ 4.7880130e+00,  5.8790064e+00,  9.6695948e+00,
            7.8112822e+00,  1.2022618e+01,  1.2411858e+01,
            1.4323385e+01]],
 
         [[-1.0189297e+00, -7.8115642e-01,  1.6466728e+00,
            8.2378983e-01,  3.0765080e+00,  3.0170646e+00,
            5.1899948e+00],
          [ 6.5285158e+00,  7.8038850e+00,  6.4155884e+00,
            9.0899811e+00,  1.0040427e+01,  9.1404457e+00,
            1.0411951e+01]]],
 
 
        [[[ 4.5557004e-01,  1.4905317e+00,  1.4904103e+00,
            2.9777462e+00,  2.8620450e+00,  3.4745665e+00,
            3.8295493e+00],
          [ 3.9977460e+00,  5.7173767e+00,  7.8421035e+00,
            6.3180594e+00,  6.0838981e+00,  8.2257290e+00,
            9.6548376e+00]],
 
         [[-7.0750320e-01, -3.5972297e-01,  4.3136525e-01,
           -2.3301599e+00, -5.0374687e-01, -2.8338656e+00,
           -3.4453444e+00],
          [-3.1258626e+00, -3.4687450e+00, -1.2045374e+00,
           -4.0196013e+00, -5.8831010e+00, -4.2965469e+00,
           -4.1388311e+00]],
 
         [[ 2.1969774e+00,  2.4614549e+00,  2.2314475e+00,
            1.8392437e+00,  2.8367062e+00,  4.8600502e+00,
            4.2273531e+00],
          [ 6.1879644e+00,  5.1792760e+00,  6.1141996e+00,
            5.6517797e+00,  8.9979610e+00,  7.5938139e+00,
            9.7918644e+00]]],
 
 
        [[[ 1.5249090e+00,  1.1388919e+00,  8.6903995e-01,
            3.0762129e+00,  1.5128503e+00,  3.5204377e+00,
            2.4760864e+00],
          [ 3.4166217e+00,  3.5930209e+00,  3.1694956e+00,
            4.5797420e+00,  4.5271711e+00,  2.8774328e+00,
            4.7288942e+00]],
 
         [[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00,
           -3.8594103e+00, -4.9681158e+00, -6.4256043e+00,
           -5.5345035e+00],
          [-6.4306297e+00, -7.0924540e+00, -8.4075985e+00,
           -1.0417805e+01, -1.1727266e+01, -1.1196255e+01,
           -1.1333830e+01]],
 
         [[-7.0419472e-01,  1.4568675e+00,  3.7946482e+00,
            4.8489718e+00,  6.6498446e+00,  9.0224218e+00,
            1.1153137e+01],
          [ 1.0060651e+01,  1.1998097e+01,  1.5326431e+01,
            1.7957514e+01,  1.8323889e+01,  2.0160881e+01,
            2.1269085e+01]]],
 
 
        [[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01,
            2.3558271e-01, -1.0381399e+00,  1.9387857e+00,
           -3.3694571e-01],
          [ 1.6015106e-01,  1.5284677e+00, -4.8567140e-01,
           -1.7770648e-01,  2.1919653e+00,  1.3015286e+00,
            1.3877077e+00]],
 
         [[ 1.3688663e+00,  2.6602898e+00,  6.6657305e-01,
            4.6554832e+00,  5.7781887e+00,  4.9115267e+00,
            4.8446012e+00],
          [ 5.1983776e+00,  6.2297459e+00,  6.3848300e+00,
            8.4291229e+00,  7.1309576e+00,  1.0395646e+01,
            8.5736713e+00]],
 
         [[ 1.2675294e+00,  5.2844582e+00,  5.1331611e+00,
            8.9993315e+00,  1.0794343e+01,  1.4039831e+01,
            1.5731170e+01],
          [ 1.9084715e+01,  2.2191265e+01,  2.3481146e+01,
            2.5803375e+01,  2.8632090e+01,  3.0234968e+01,
            3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-28.90071 , -23.052422, -19.851362],
       [-19.775568, -25.894997, -20.302256],
       [-21.10754 , -23.667885, -20.973007],
       [-19.249458, -20.87892 , -20.573763],
       [-22.351208, -25.457762, -24.648403]], dtype=float32)>

Por outro lado, a nossa cuidadosamente elaborado JointDistributionSequential não funciona mais:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

try:
  jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]

Para corrigir isso, teríamos que adicionar um segundo tf.newaxis a ambos m e b combinar a forma e aumentar reinterpreted_batch_ndims a 2 na chamada para Independent . Nesse caso, deixar a máquina de lote automático lidar com os problemas de formato é mais curto, mais fácil e mais ergonômico.

Mais uma vez, notamos que, enquanto este notebook explorado JointDistributionSequentialAutoBatched , as outras variantes de JointDistribution tem equivalente AutoBatched . (Para usuários de JointDistributionCoroutine , JointDistributionCoroutineAutoBatched tem a vantagem adicional de que você não precisa mais especificar Root nós, se você nunca usou JointDistributionCoroutine . Você pode seguramente ignorar esta declaração)

Pensamentos Finais

Neste caderno, introduzimos JointDistributionSequentialAutoBatched e trabalhou através de um exemplo simples em detalhe. Espero que você tenha aprendido algo sobre formas TFP e autobatching!