Distributions conjointes par lots automatiques : un didacticiel en douceur

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Télécharger le cahier

introduction

Tensorflow Probabilité (TFP) offre un certain nombre d' JointDistribution abstractions qui rendent plus facile inférence probabiliste en permettant à un utilisateur d'exprimer facilement un modèle graphique probabiliste sous une forme quasi-mathématique; l'abstraction génère des méthodes d'échantillonnage à partir du modèle et d'évaluation de la probabilité logarithmique des échantillons du modèle. Dans ce tutoriel, nous passons en revue « autobatched » variantes qui ont été développées après l'original JointDistribution abstractions. Par rapport aux abstractions originales non mises en lots automatiques, les versions mises en lots automatiques sont plus simples à utiliser et plus ergonomiques, permettant à de nombreux modèles d'être exprimés avec moins de passe-partout. Dans cette collaboration, nous explorons un modèle simple dans les détails (peut-être fastidieux), clarifiant les problèmes que l'autobatching résout et (espérons-le) en enseignant davantage au lecteur sur les concepts de forme TFP en cours de route.

Avant l'introduction de autobatching, il y avait quelques variantes différentes de JointDistribution , correspondant à différents styles syntaxiques pour exprimer des modèles probabilistes: JointDistributionSequential , JointDistributionNamed et JointDistributionCoroutine . Auobatching existe en tant que mixin, donc nous avons maintenant AutoBatched variantes de tous ces aspects . Dans ce tutoriel, nous explorons les différences entre JointDistributionSequential et JointDistributionSequentialAutoBatched ; cependant, tout ce que nous faisons ici est applicable aux autres variantes sans pratiquement aucun changement.

Dépendances et prérequis

Importer et configurer

Prérequis : un problème de régression bayésienne

Considérons un scénario de régression bayésienne très simple :

\[ \begin{align*} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{align*} \]

Dans ce modèle, m et b sont tirées à partir des normales standards, ainsi que des observations Y sont tirées d'une distribution normale dont la moyenne dépend des variables aléatoires m et b , et certains (non aléatoire, connue) covariables X . (Pour plus de simplicité, dans cet exemple, nous supposons que l'échelle de toutes les variables aléatoires est connue.)

Pour effectuer l' inférence dans ce modèle, nous aurions besoin de connaître les deux covariables X et les observations Y , mais pour les besoins de ce tutoriel, nous allons seulement besoin X , donc nous définissons simple mannequin X :

X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])

Desiderata

Dans l'inférence probabiliste, nous voulons souvent effectuer deux opérations de base :

  • sample : Les échantillons de dessin à partir du modèle.
  • log_prob : calculer la probabilité de journal d'un échantillon du modèle.

La clé de contribution de la PGF JointDistribution des abstractions (ainsi que de nombreuses autres approches de la programmation probabiliste) est de permettre aux utilisateurs d'écrire un modèle une fois et avoir accès aux deux sample et log_prob calculs.

Notant que nous avons 7 points dans notre ensemble de données ( X.shape = (7,) ), nous pouvons maintenant affirmer les desiderata pour un excellent JointDistribution :

  • sample() doit produire une liste de Tensors ayant une forme [(), (), (7,) ], correspondant à la pente scalaire, observations polarisation scalaire et vecteur, respectivement.
  • log_prob(sample()) devrait produire un scalaire: la probabilité logarithmique de la pente, notamment polarisation, et des observations.
  • sample([5, 3]) doit produire une liste de Tensors ayant une forme [(5, 3), (5, 3), (5, 3, 7)] , ce qui représente un (5, 3) - lot d'échantillons à partir de le modèle.
  • log_prob(sample([5, 3])) devrait produire un Tensor de forme (5, 3).

Nous allons voir maintenant à une succession de JointDistribution modèles, voir comment atteindre les desiderata ci - dessus, et nous espérons apprendre un peu plus sur les formes TFP le long du chemin.

SPOILER ALERT: L'approche qui satisfait les desiderata ci - dessus sans passe- partout est ajouté autobatching .

Premier essai; JointDistributionSequential

jds = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

Il s'agit plus ou moins d'une traduction directe du modèle en code. La pente m et le biais b sont simples. Y est défini en utilisant un lambda -fonction: la tendance générale est que le lambda -fonction de \(k\) arguments dans un JointDistributionSequential (JDS) utilise les précédentes \(k\) distributions dans le modèle. Notez l'ordre "inverse".

Nous appellerons sample_distributions , qui retourne à la fois un échantillon et les sous - jacents « sous-distributions » qui ont été utilisées pour générer l'échantillon. (Nous aurions pu produire tout l'échantillon en appelant sample , plus tard dans le tutoriel , il sera pratique d'avoir les distributions.) L'échantillon que nous produisons est très bien:

dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

Mais log_prob produit un résultat avec une forme non souhaitée:

jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=
array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684,
       -4.4368567, -4.480562 ], dtype=float32)>

Et l'échantillonnage multiple ne fonctionne pas :

try:
  jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Essayons de comprendre ce qui ne va pas.

Un bref examen : forme de lot et d'événement

En TFP, un ordinaire (pas JointDistribution distribution de probabilité) a une forme d'événement et une forme de traitement par lots, et la compréhension de la différence est essentielle pour une utilisation efficace de la PGF:

  • La forme de l'événement décrit la forme d'un seul tirage de la distribution ; le tirage peut dépendre d'une dimension à l'autre. Pour les distributions scalaires, la forme de l'événement est []. Pour un MultivariateNormal à 5 ​​dimensions, la forme de l'événement est [5].
  • La forme du lot décrit des tirages indépendants et non distribués de manière identique, c'est-à-dire un « lot » de distributions. Représenter un lot de distributions dans un seul objet Python est l'un des principaux moyens par lesquels TFP atteint l'efficacité à grande échelle.

Pour nos besoins, un fait essentiel de garder à l' esprit est que si nous appelons log_prob sur un seul échantillon d'une distribution, le résultat sera toujours une forme adaptée ( par exemple, a les dimensions les plus à droite) la forme de lot.

Pour une discussion plus approfondie des formes, voir le tutoriel « Comprendre tensorflow Distributions formes » .

Pourquoi ne pas log_prob(sample()) Produire un Scalar?

Utilisons notre connaissance de la forme de lots et d' événements pour explorer ce qui se passe avec log_prob(sample()) . Voici à nouveau notre exemple :

sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

Et voici nos distributions :

dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]

La probabilité de log est calculée en additionnant les probabilités de log des sous-distributions aux éléments (appariés) des parties :

log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014,
        -0.9897899, -1.0334952], dtype=float32)>]
np.sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>

Ainsi, un niveau d'explication est que le calcul de probabilité de journal retourne un 7-Tensor parce que le troisième sous - composant log_prob_parts est un 7-Tensor. Mais pourquoi?

Eh bien, nous voyons que le dernier élément de dists , ce qui correspond à notre distribution sur Y dans la formulation mathematial, a une batch_shape de [7] . En d' autres termes, notre distribution sur Y est un lot de 7 Normales indépendants (avec des moyens différents et, dans ce cas, la même échelle).

Nous comprenons maintenant ce qui est faux: dans JDS, la distribution sur Y a batch_shape=[7] , un échantillon de la JDS représente pour scalaires m et b et un « lot » de 7 Normales indépendants. et log_prob calcule 7 log-probabilités séparées, dont chacune représente la probabilité logarithmique de dessin m et b et une seule observation Y[i] à un certain X[i] .

Fixation log_prob(sample()) avec Independent

Rappelons que dists[2] a event_shape=[] et batch_shape=[7] :

dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>

En utilisant de TFP Independent métadistribution, qui convertit les dimensions des lots aux dimensions de l' événement, nous pouvons convertir en une distribution avec event_shape=[7] et batch_shape=[] (nous allons le renommer y_dist_i parce qu'il est une distribution sur Y , avec le _i debout dans notre Independent emballage):

y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>

Maintenant, le log_prob d'un 7-vecteur est un scalaire:

y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>

Sous les couvertures, Independent sommes sur le lot:

y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

Et en effet, nous pouvons l' utiliser pour construire une nouvelle jds_i (le i se distingue à nouveau pour Independent ) où log_prob renvoie un scalaire:

jds_i = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m*X + b, scale=1.),
        reinterpreted_batch_ndims=1)
])

jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>

Quelques remarques :

  • jds_i.log_prob(s) ne sont pas les mêmes que tf.reduce_sum(jds.log_prob(s)) . Le premier produit le log de probabilité « correct » de la distribution conjointe. Cette dernière somme sur un 7-Tensor, dont chaque élément correspond à la somme de la probabilité logarithmique de m , b , et un seul élément de la probabilité logarithmique de Y , de sorte qu'il overcounts m et b . ( log_prob(m) + log_prob(b) + log_prob(Y) renvoie un résultat plutôt que de lancer une exception parce TFP suit TF et les règles de diffusion de numpy;. En ajoutant un scalaire à un vecteur produit un résultat de vecteur de taille)
  • Dans ce cas particulier, nous aurions pu résoudre le problème et obtenir le même résultat en utilisant MultivariateNormalDiag au lieu de Independent(Normal(...)) . MultivariateNormalDiag est une distribution de valeurs vectorielles (elle a déjà vecteur événement en forme). Indeeed MultivariateNormalDiag pourrait être (mais pas) mis en œuvre en tant que composition Independent et Normal . Il vaut la peine de se rappeler que , étant donné un vecteur V , des échantillons de n1 = Normal(loc=V) , et n2 = MultivariateNormalDiag(loc=V) sont impossibles à distinguer; la différence beween ces distributions est que n1.log_prob(n1.sample()) est un vecteur et n2.log_prob(n2.sample()) est un scalaire.

Plusieurs échantillons ?

Le dessin de plusieurs échantillons ne fonctionne toujours pas :

try:
  jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Réfléchissons au pourquoi. Lorsque nous appelons jds_i.sample([5, 3]) , on va d' abord dessiner des échantillons pour m et b , chacun avec la forme (5, 3) . Ensuite, nous allons essayer de construire une Normal de distribution via:

tfd.Normal(loc=m*X + b, scale=1.)

Mais si m a une forme (5, 3) et X a une forme 7 , nous ne pouvons pas les multiplier ensemble, et en effet c'est l'erreur que nous sommes percutant:

m = tfd.Normal(0., 1.).sample([5, 3])
try:
  m * X
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Pour résoudre ce problème, nous allons réfléchir à ce que les propriétés de la distribution sur Y doit avoir. Si nous avons appelé jds_i.sample([5, 3]) , alors nous savons m et b seront tous les deux ont une forme (5, 3) . Quelle forme devrait un appel à sample sur le Y des produits de distribution? La réponse évidente est (5, 3, 7) : pour chaque point de traitement par lots, nous voulons un échantillon avec la même taille que X . Nous pouvons y parvenir en utilisant les capacités de diffusion de TensorFlow, en ajoutant des dimensions supplémentaires :

m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])

L' ajout d' un axe à la fois m et b , on peut définir un nouveau JDS qui prend en charge de multiples échantillons:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-1.1133379 ,  0.16390413, -0.24177533],
        [-1.1312429 , -0.6224666 , -1.8182136 ],
        [-0.31343174, -0.32932565,  0.5164407 ],
        [-0.0119963 , -0.9079621 ,  2.3655841 ],
        [-0.26293617,  0.8229698 ,  0.31098196]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-0.02876974,  1.0872147 ,  1.0138507 ],
        [ 0.27367726, -1.331534  , -0.09084719],
        [ 1.3349475 , -0.68765205,  1.680652  ],
        [ 0.75436825,  1.3050154 , -0.9415123 ],
        [-1.2502679 , -0.25730947,  0.74611956]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy=
 array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00,
          -4.8197951e+00, -5.2986512e+00, -6.6931367e+00],
         [ 3.6438566e-01,  1.0067395e+00,  1.4542470e+00,  8.1155670e-01,
           1.8868095e+00,  2.3877139e+00,  1.0195159e+00],
         [-8.3624744e-01,  1.2518480e+00,  1.0943471e+00,  1.3052304e+00,
          -4.5756745e-01, -1.0668410e-01, -7.0669651e-02]],
 
        [[-3.1788960e-01,  9.2615485e-03, -3.0963073e+00, -2.2846246e+00,
          -3.2269263e+00, -6.0213070e+00, -7.4806519e+00],
         [-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00,
          -4.5065498e+00, -5.6719379e+00, -4.8012795e+00],
         [ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00,
          -7.1365709e+00, -9.6198196e+00, -9.7951422e+00]],
 
        [[ 2.0621397e+00,  3.4639853e-01,  7.0252883e-01, -1.4311566e+00,
           3.3790007e+00,  1.1619035e+00, -8.9105040e-01],
         [-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00,
          -2.7150445e+00, -2.4633870e+00, -2.1841538e+00],
         [ 7.7627432e-01,  2.2401071e+00,  3.7601702e+00,  2.4245868e+00,
           4.0690269e+00,  4.0605016e+00,  5.1753912e+00]],
 
        [[ 1.4275590e+00,  3.3346462e+00,  1.5374103e+00, -2.2849756e-01,
           9.1219616e-01, -3.1220305e-01, -3.2643962e-01],
         [-3.1910419e-02, -3.8848895e-01,  9.9946201e-02, -2.3619974e+00,
          -1.8507402e+00, -3.6830821e+00, -5.4907336e+00],
         [-7.1941972e-02,  2.1602919e+00,  4.9575748e+00,  4.2317696e+00,
           9.3528280e+00,  1.0526063e+01,  1.5262107e+01]],
 
        [[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00,
          -3.2361765e+00, -3.3434000e+00, -2.6849220e+00],
         [ 1.5006512e-02, -1.9866472e-01,  7.6781356e-01,  1.6228745e+00,
           1.4191239e+00,  2.6655579e+00,  4.4663467e+00],
         [ 2.6599693e+00,  1.2663836e+00,  1.7162113e+00,  1.4839669e+00,
           2.0559487e+00,  2.5976877e+00,  2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.483114 , -10.139662 , -11.514159 ],
       [-11.656767 , -17.201958 , -12.132455 ],
       [-17.838818 ,  -9.474525 , -11.24898  ],
       [-13.95219  , -12.490049 , -17.123957 ],
       [-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>

À titre de contrôle supplémentaire, nous vérifierons que la probabilité de journal pour un seul point de lot correspond à ce que nous avions auparavant :

(jds_ia.log_prob(shaped_sample)[3, 1] -
 jds_i.log_prob([shaped_sample[0][3, 1],
                 shaped_sample[1][3, 1],
                 shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

AutoBatching pour la victoire

Excellent! Nous avons maintenant une version de JointDistribution qui gère toutes nos desiderata: log_prob renvoie un scalaire de grâce à l'utilisation de tfd.Independent , et plusieurs échantillons travaillent maintenant que nous avons fixé la diffusion en ajoutant des axes supplémentaires.

Et si je vous disais qu'il existe un moyen plus simple et meilleur ? Il y a, et il est appelé JointDistributionSequentialAutoBatched (JDSAB):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.191533 , -10.43885  , -16.371655 ],
       [-13.292994 , -11.97949  , -16.788685 ],
       [-15.987699 , -13.435732 , -10.6029   ],
       [-10.184758 , -11.969714 , -14.275676 ],
       [-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)>

Comment cela marche-t-il? Alors que vous pourriez essayer de lire le code pour une compréhension profonde, nous allons donner un bref aperçu qui est suffisant pour la plupart des cas d'utilisation:

  • Rappelons que notre premier problème était que notre distribution Y avait batch_shape=[7] et event_shape=[] , et nous avons utilisé Independent pour convertir la dimension du lot à une dimension de l' événement. JDSAB ignore les formes de lots des distributions de composants ; Au contraire , il traite la forme du lot comme une propriété globale du modèle, qui est supposée être [] (sauf indication contraire par la mise en batch_ndims > 0 ). L'effet est équivalent à utiliser tfd.Independent pour convertir toutes les dimensions des lots de distributions de composants dans des dimensions de l' événement, comme nous l'avons fait manuellement ci - dessus.
  • Notre deuxième problème était nécessaire de masser les formes de m et b afin qu'ils puissent diffuser de façon appropriée avec X lors de la création de plusieurs échantillons. Avec JDSAB, vous écrivez un modèle pour générer un échantillon unique, et nous « lever » l'ensemble du modèle pour générer plusieurs échantillons à l' aide de tensorflow de vectorized_map . (Cette fonction est analogue à Jax de vmap .)

Explorer la question de la forme de lots plus en détail, nous pouvons comparer les formes de traitement par lots de notre « mauvais » originale distribution conjointe jds , nos distributions de lots fixes jds_i et jds_ia , et notre autobatched jds_ab :

jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])

Nous voyons que l'original jds a différentes formes avec distribution secondaire de traitement par lots. jds_i et jds_ia résoudre ce problème en créant la forme avec distribution secondaire du lot même (vide). jds_ab ne dispose que d' une forme de lot unique (vide).

Il convient de noter que JointDistributionSequentialAutoBatched offre une certaine généralité supplémentaire gratuitement. Supposons que nous faisons covariables X (et, implicitement, les observations Y ) en deux dimensions:

X = np.arange(14).reshape((2, 7))
X
array([[ 0,  1,  2,  3,  4,  5,  6],
       [ 7,  8,  9, 10, 11, 12, 13]])

Notre JointDistributionSequentialAutoBatched fonctionne sans changement (nous devons redéfinir le modèle parce que la forme de X sont mises en cache par jds_ab.log_prob ):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.1813647 , -0.85994506,  0.27593774],
        [-0.73323774,  1.1153806 ,  0.8841938 ],
        [ 0.5127983 , -0.29271227,  0.63733214],
        [ 0.2362284 , -0.919168  ,  1.6648189 ],
        [ 0.26317367,  0.73077047,  2.5395133 ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.09636458,  2.0138032 , -0.5054413 ],
        [ 0.63941646, -1.0785882 , -0.6442188 ],
        [ 1.2310615 , -0.3293852 ,  0.77637213],
        [ 1.2115169 , -0.98906034, -0.07816773],
        [-1.1318136 ,  0.510014  ,  1.036522  ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy=
 array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01,
            8.5992378e-01, -5.3123581e-01,  3.1584005e+00,
            2.9044402e+00],
          [-2.5645006e-01,  3.1554163e-01,  3.1186538e+00,
            1.4272424e+00,  1.2843871e+00,  1.2266440e+00,
            1.2798605e+00]],
 
         [[ 1.5973477e+00, -5.3631151e-01,  6.8143606e-03,
           -1.4910895e+00, -2.1568544e+00, -2.0513713e+00,
           -3.1663666e+00],
          [-4.9448099e+00, -2.8385928e+00, -6.9027486e+00,
           -5.6543546e+00, -7.2378774e+00, -8.1577444e+00,
           -9.3582869e+00]],
 
         [[-2.1233239e+00,  5.8853775e-02,  1.2024102e+00,
            1.6622503e+00, -1.9197327e-01,  1.8647723e+00,
            6.4322817e-01],
          [ 3.7549341e-01,  1.5853541e+00,  2.4594500e+00,
            2.1952972e+00,  1.7517658e+00,  2.9666045e+00,
            2.5468128e+00]]],
 
 
        [[[ 8.9906776e-01,  6.7375046e-01,  7.3354661e-01,
           -9.9894643e-01, -3.4606690e+00, -3.4810467e+00,
           -4.4315586e+00],
          [-3.0670738e+00, -6.3628020e+00, -6.2538433e+00,
           -6.8091092e+00, -7.7134805e+00, -8.6319380e+00,
           -8.6904278e+00]],
 
         [[-2.2462025e+00, -3.3060855e-01,  1.8974400e-01,
            3.1422038e+00,  4.1483402e+00,  3.5642972e+00,
            4.8709240e+00],
          [ 4.7880130e+00,  5.8790064e+00,  9.6695948e+00,
            7.8112822e+00,  1.2022618e+01,  1.2411858e+01,
            1.4323385e+01]],
 
         [[-1.0189297e+00, -7.8115642e-01,  1.6466728e+00,
            8.2378983e-01,  3.0765080e+00,  3.0170646e+00,
            5.1899948e+00],
          [ 6.5285158e+00,  7.8038850e+00,  6.4155884e+00,
            9.0899811e+00,  1.0040427e+01,  9.1404457e+00,
            1.0411951e+01]]],
 
 
        [[[ 4.5557004e-01,  1.4905317e+00,  1.4904103e+00,
            2.9777462e+00,  2.8620450e+00,  3.4745665e+00,
            3.8295493e+00],
          [ 3.9977460e+00,  5.7173767e+00,  7.8421035e+00,
            6.3180594e+00,  6.0838981e+00,  8.2257290e+00,
            9.6548376e+00]],
 
         [[-7.0750320e-01, -3.5972297e-01,  4.3136525e-01,
           -2.3301599e+00, -5.0374687e-01, -2.8338656e+00,
           -3.4453444e+00],
          [-3.1258626e+00, -3.4687450e+00, -1.2045374e+00,
           -4.0196013e+00, -5.8831010e+00, -4.2965469e+00,
           -4.1388311e+00]],
 
         [[ 2.1969774e+00,  2.4614549e+00,  2.2314475e+00,
            1.8392437e+00,  2.8367062e+00,  4.8600502e+00,
            4.2273531e+00],
          [ 6.1879644e+00,  5.1792760e+00,  6.1141996e+00,
            5.6517797e+00,  8.9979610e+00,  7.5938139e+00,
            9.7918644e+00]]],
 
 
        [[[ 1.5249090e+00,  1.1388919e+00,  8.6903995e-01,
            3.0762129e+00,  1.5128503e+00,  3.5204377e+00,
            2.4760864e+00],
          [ 3.4166217e+00,  3.5930209e+00,  3.1694956e+00,
            4.5797420e+00,  4.5271711e+00,  2.8774328e+00,
            4.7288942e+00]],
 
         [[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00,
           -3.8594103e+00, -4.9681158e+00, -6.4256043e+00,
           -5.5345035e+00],
          [-6.4306297e+00, -7.0924540e+00, -8.4075985e+00,
           -1.0417805e+01, -1.1727266e+01, -1.1196255e+01,
           -1.1333830e+01]],
 
         [[-7.0419472e-01,  1.4568675e+00,  3.7946482e+00,
            4.8489718e+00,  6.6498446e+00,  9.0224218e+00,
            1.1153137e+01],
          [ 1.0060651e+01,  1.1998097e+01,  1.5326431e+01,
            1.7957514e+01,  1.8323889e+01,  2.0160881e+01,
            2.1269085e+01]]],
 
 
        [[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01,
            2.3558271e-01, -1.0381399e+00,  1.9387857e+00,
           -3.3694571e-01],
          [ 1.6015106e-01,  1.5284677e+00, -4.8567140e-01,
           -1.7770648e-01,  2.1919653e+00,  1.3015286e+00,
            1.3877077e+00]],
 
         [[ 1.3688663e+00,  2.6602898e+00,  6.6657305e-01,
            4.6554832e+00,  5.7781887e+00,  4.9115267e+00,
            4.8446012e+00],
          [ 5.1983776e+00,  6.2297459e+00,  6.3848300e+00,
            8.4291229e+00,  7.1309576e+00,  1.0395646e+01,
            8.5736713e+00]],
 
         [[ 1.2675294e+00,  5.2844582e+00,  5.1331611e+00,
            8.9993315e+00,  1.0794343e+01,  1.4039831e+01,
            1.5731170e+01],
          [ 1.9084715e+01,  2.2191265e+01,  2.3481146e+01,
            2.5803375e+01,  2.8632090e+01,  3.0234968e+01,
            3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-28.90071 , -23.052422, -19.851362],
       [-19.775568, -25.894997, -20.302256],
       [-21.10754 , -23.667885, -20.973007],
       [-19.249458, -20.87892 , -20.573763],
       [-22.351208, -25.457762, -24.648403]], dtype=float32)>

D'autre part, notre conçu avec soin JointDistributionSequential ne fonctionne plus:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

try:
  jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]

Pour résoudre ce problème, il faudrait ajouter un deuxième tf.newaxis à la fois m et b correspondre à la forme, et d' augmenter reinterpreted_batch_ndims à 2 dans l'appel à Independent . Dans ce cas, laisser les machines de dosage automatique gérer les problèmes de forme est plus court, plus facile et plus ergonomique.

Encore une fois, nous constatons que si ce bloc - notes exploré JointDistributionSequentialAutoBatched , les autres variantes de JointDistribution ont équivalent AutoBatched . (Pour les utilisateurs de JointDistributionCoroutine , JointDistributionCoroutineAutoBatched a l'avantage supplémentaire que vous ne avez plus besoin de spécifier Root nœuds, si vous ne l' avez jamais utilisé JointDistributionCoroutine . Vous pouvez ignorer en toute sécurité cette déclaration)

Pensées de conclusion

Dans ce cahier, nous avons introduit JointDistributionSequentialAutoBatched et travaillé par un exemple simple en détail. J'espère que vous avez appris quelque chose sur les formes TFP et sur l'autobatching !