Lihat di TensorFlow.org | Jalankan di Google Colab | Lihat sumber di GitHub | Unduh buku catatan |
pengantar
TensorFlow Probabilitas (TFP) menawarkan sejumlah JointDistribution
abstraksi yang membuat inferensi probabilistik lebih mudah dengan memungkinkan pengguna untuk dengan mudah mengekspresikan model grafis probabilistik dalam bentuk matematika dekat-; abstraksi menghasilkan metode untuk pengambilan sampel dari model dan mengevaluasi probabilitas log sampel dari model. Dalam tutorial ini, kita meninjau "autobatched" varian, yang dikembangkan setelah asli JointDistribution
abstraksi. Dibandingkan dengan abstraksi asli, non-autobatched, versi autobatched lebih mudah digunakan dan lebih ergonomis, memungkinkan banyak model diekspresikan dengan lebih sedikit boilerplate. Dalam colab ini, kami mengeksplorasi model sederhana dalam detail (mungkin membosankan), memperjelas masalah yang diselesaikan autobatching, dan (semoga) mengajari pembaca lebih banyak tentang konsep bentuk TFP di sepanjang jalan.
Sebelum pengenalan autobatching, ada varian yang berbeda dari JointDistribution
, sesuai dengan gaya sintaksis yang berbeda untuk mengekspresikan model probabilistik: JointDistributionSequential
, JointDistributionNamed
, dan JointDistributionCoroutine
. Auobatching ada sebagai mixin, jadi kita sekarang memiliki AutoBatched
varian dari semua ini. Dalam tutorial ini, kita mengeksplorasi perbedaan antara JointDistributionSequential
dan JointDistributionSequentialAutoBatched
; namun, semua yang kami lakukan di sini berlaku untuk varian lain tanpa perubahan pada dasarnya.
Dependensi & Prasyarat
Impor dan set up
import functools
import numpy as np
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions
Prasyarat: Masalah Regresi Bayesian
Kami akan mempertimbangkan skenario regresi Bayesian yang sangat sederhana:
\[ \begin{align*} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{align*} \]
Dalam model ini, m
dan b
diambil dari normals standar, dan pengamatan Y
diambil dari distribusi normal yang rata-rata tergantung pada variabel-variabel acak m
dan b
, dan beberapa (nonrandom, dikenal) kovariat X
. (Untuk kesederhanaan, dalam contoh ini, kami menganggap skala semua variabel acak diketahui.)
Untuk melakukan inferensi dalam model ini, kami perlu tahu kedua kovariat X
dan pengamatan Y
, namun untuk tujuan tutorial ini, kita hanya perlu X
, jadi kita mendefinisikan dummy sederhana X
:
X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])
Desiderata
Dalam inferensi probabilistik, kita sering ingin melakukan dua operasi dasar:
-
sample
: Menggambar sampel dari model. -
log_prob
: Komputasi probabilitas log dari sampel dari model.
Kontribusi kunci dari TFP ini JointDistribution
abstraksi (serta dari banyak pendekatan lain untuk pemrograman probabilistik) adalah untuk memungkinkan pengguna untuk menulis sebuah model sekali dan memiliki akses ke kedua sample
dan log_prob
perhitungan.
Mencatat bahwa kita memiliki 7 poin di set data kami ( X.shape = (7,)
), kita sekarang dapat menyatakan desiderata untuk sangat baik JointDistribution
:
-
sample()
harus menghasilkan daftarTensors
memiliki bentuk[(), (), (7,)
], sesuai dengan kemiringan skalar, Bias skalar, dan pengamatan vektor, masing-masing. -
log_prob(sample())
harus menghasilkan skalar: probabilitas log tertentu lereng, bias, dan pengamatan. -
sample([5, 3])
harus menghasilkan daftarTensors
memiliki bentuk[(5, 3), (5, 3), (5, 3, 7)]
, mewakili(5, 3)
- batch sampel dari model. -
log_prob(sample([5, 3]))
harus menghasilkanTensor
dengan bentuk (5, 3).
Kita sekarang akan melihat suksesi JointDistribution
model, melihat bagaimana untuk mencapai desiderata di atas, dan mudah-mudahan belajar sedikit lebih banyak tentang TFP membentuk sepanjang jalan.
Spoiler alert: Pendekatan yang memenuhi yang desiderata atas tanpa menambahkan boilerplate adalah autobatching .
Percobaan pertama; JointDistributionSequential
jds = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
Ini kurang lebih merupakan terjemahan langsung dari model ke dalam kode. Kemiringan m
dan bias b
adalah mudah. Y
didefinisikan menggunakan sebuah lambda
-fungsi: pola umum adalah bahwa lambda
-fungsi dari \(k\) argumen dalam JointDistributionSequential
(JDS) menggunakan sebelumnya \(k\) distribusi dalam model. Perhatikan urutan "terbalik".
Kami akan memanggil sample_distributions
, yang kembali baik sampel dan mendasari "sub-distribusi" yang digunakan untuk menghasilkan sampel. (Kita bisa diproduksi hanya sampel dengan menelepon sample
; nanti di tutorial itu akan mudah untuk memiliki distribusi juga.) Sampel kami memproduksi baik-baik saja:
dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>, <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>, <tf.Tensor: shape=(7,), dtype=float32, numpy= array([ 0.18573815, -1.79962 , -1.8106272 , -3.5971394 , -6.6625295 , -7.308844 , -9.832693 ], dtype=float32)>]
Tapi log_prob
menghasilkan hasil dengan bentuk yang tidak diinginkan:
jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy= array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684, -4.4368567, -4.480562 ], dtype=float32)>
Dan beberapa pengambilan sampel tidak berfungsi:
try:
jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
Mari kita coba memahami apa yang salah.
Ulasan Singkat: Bentuk Batch dan Acara
Dalam TFP, yang biasa (bukan JointDistribution
) distribusi probabilitas memiliki bentuk acara dan bentuk batch, dan memahami perbedaan adalah penting untuk penggunaan efektif TFP:
- Bentuk acara menggambarkan bentuk undian tunggal dari distribusi; pengundian mungkin tergantung lintas dimensi. Untuk distribusi skalar, bentuk kejadiannya adalah []. Untuk MultivariatNormal 5 dimensi, bentuk kejadiannya adalah [5].
- Bentuk batch menggambarkan undian yang independen dan tidak terdistribusi secara identik, alias "kumpulan" distribusi. Mewakili sekumpulan distribusi dalam satu objek Python adalah salah satu cara utama TFP mencapai efisiensi dalam skala besar.
Untuk tujuan kita, fakta penting untuk diingat adalah bahwa jika kita sebut log_prob
pada sampel tunggal dari distribusi, hasilnya selalu akan memiliki bentuk yang cocok (yaitu, memiliki sebagai dimensi paling kanan) bentuk batch.
Untuk pembahasan lebih mendalam tentang bentuk, lihat yang "Memahami TensorFlow Distribusi Shapes" tutorial .
Mengapa Apakah tidak log_prob(sample())
Menghasilkan skalar?
Mari kita gunakan pengetahuan kita tentang batch dan acara bentuk untuk mengeksplorasi apa yang terjadi dengan log_prob(sample())
. Ini contoh kami lagi:
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>, <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>, <tf.Tensor: shape=(7,), dtype=float32, numpy= array([ 0.18573815, -1.79962 , -1.8106272 , -3.5971394 , -6.6625295 , -7.308844 , -9.832693 ], dtype=float32)>]
Dan berikut adalah distribusi kami:
dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>, <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>, <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]
Probabilitas log dihitung dengan menjumlahkan probabilitas log dari sub-distribusi pada elemen (yang cocok) dari bagian:
log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>, <tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>, <tf.Tensor: shape=(7,), dtype=float32, numpy= array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014, -0.9897899, -1.0334952], dtype=float32)>]
np.sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>
Jadi, satu tingkat dari penjelasan adalah bahwa perhitungan log probabilitas adalah mengembalikan 7-Tensor karena subkomponen ketiga log_prob_parts
adalah 7-Tensor. Tapi kenapa?
Nah, kita melihat bahwa elemen terakhir dari dists
, yang sesuai dengan distribusi kami lebih Y
dalam perumusan mathematial, memiliki batch_shape
dari [7]
. Dengan kata lain, distribusi kami lebih Y
adalah batch 7 normals independen (dengan cara yang berbeda dan, dalam hal ini, skala yang sama).
Kami sekarang mengerti apa yang salah: di JDS, distribusi lebih Y
memiliki batch_shape=[7]
, sampel dari JDS merupakan skalar untuk m
dan b
dan "batch" dari 7 normals independen. dan log_prob
menghitung 7 terpisah log-probabilitas, yang masing-masing mewakili probabilitas log menggambar m
dan b
dan pengamatan tunggal Y[i]
di beberapa X[i]
.
Memperbaiki log_prob(sample())
dengan Independent
Ingat bahwa dists[2]
memiliki event_shape=[]
dan batch_shape=[7]
:
dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>
Dengan menggunakan TFP ini Independent
metadistribusi, yang mengubah dimensi batch untuk dimensi acara, kita dapat mengkonversi ini menjadi distribusi dengan event_shape=[7]
dan batch_shape=[]
(kami akan mengganti nama y_dist_i
karena distribusi pada Y
, dengan _i
berdiri untuk kami Independent
pembungkus):
y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>
Sekarang, log_prob
dari 7-vektor adalah skalar:
y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>
Di bawah selimut, Independent
jumlah lebih batch:
y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
Dan memang, kita dapat menggunakan ini untuk membangun baru jds_i
(yang i
lagi singkatan Independent
) di mana log_prob
mengembalikan skalar:
jds_i = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m*X + b, scale=1.),
reinterpreted_batch_ndims=1)
])
jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>
Beberapa catatan:
-
jds_i.log_prob(s)
tidak sama dengantf.reduce_sum(jds.log_prob(s))
. Yang pertama menghasilkan probabilitas log yang "benar" dari distribusi gabungan. Jumlah yang terakhir selama 7-Tensor, setiap elemen yang merupakan jumlah dari probabilitas logm
,b
, dan satu elemen dari probabilitas logY
, sehingga overcountsm
danb
. (log_prob(m) + log_prob(b) + log_prob(Y)
mengembalikan hasilnya daripada membuang pengecualian karena TFP berikut TF dan aturan penyiaran NumPy ini;. Menambahkan skalar untuk vektor menghasilkan hasil vektor berukuran) - Dalam kasus ini, kita bisa memecahkan masalah dan mencapai hasil yang sama menggunakan
MultivariateNormalDiag
bukanIndependent(Normal(...))
.MultivariateNormalDiag
adalah distribusi vektor-dihargai (yaitu, sudah memiliki vektor-bentuk acara). Indeeed suatu berkatMultivariateNormalDiag
bisa (tetapi tidak) dilaksanakan sebagai komposisiIndependent
danNormal
. Hal ini bermanfaat untuk diingat bahwa diberi vektorV
, sampel darin1 = Normal(loc=V)
, dann2 = MultivariateNormalDiag(loc=V)
tidak dapat dibedakan; perbedaan beween distribusi ini adalah bahwan1.log_prob(n1.sample())
adalah vektor dann2.log_prob(n2.sample())
adalah skalar.
Beberapa Sampel?
Menggambar banyak sampel masih tidak berfungsi:
try:
jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
Mari kita pikirkan alasannya. Ketika kita sebut jds_i.sample([5, 3])
, kami akan pertama mengambil contoh untuk m
dan b
, masing-masing dengan bentuk (5, 3)
. Berikutnya, kita akan mencoba untuk membangun sebuah Normal
distribusi melalui:
tfd.Normal(loc=m*X + b, scale=1.)
Tetapi jika m
memiliki bentuk (5, 3)
dan X
memiliki bentuk 7
, kita tidak bisa berkembang biak mereka bersama-sama, dan memang ini adalah kesalahan kita memukul sedang:
m = tfd.Normal(0., 1.).sample([5, 3])
try:
m * X
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
Untuk mengatasi masalah ini, mari kita berpikir tentang sifat-sifat apa distribusi lebih Y
harus memiliki. Jika kita sudah menelepon jds_i.sample([5, 3])
, maka kita tahu m
dan b
keduanya akan memiliki bentuk (5, 3)
. Apa bentuk harus panggilan untuk sample
pada Y
menghasilkan distribusi? Jawaban yang jelas adalah (5, 3, 7)
: untuk setiap titik batch, kita ingin sampel dengan ukuran yang sama dengan X
. Kita dapat mencapainya dengan menggunakan kemampuan penyiaran TensorFlow, menambahkan dimensi ekstra:
m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])
Menambahkan sumbu kedua m
dan b
, kita dapat mendefinisikan JDS baru yang mendukung beberapa sampel:
jds_ia = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
reinterpreted_batch_ndims=1)
])
shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-1.1133379 , 0.16390413, -0.24177533], [-1.1312429 , -0.6224666 , -1.8182136 ], [-0.31343174, -0.32932565, 0.5164407 ], [-0.0119963 , -0.9079621 , 2.3655841 ], [-0.26293617, 0.8229698 , 0.31098196]], dtype=float32)>, <tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-0.02876974, 1.0872147 , 1.0138507 ], [ 0.27367726, -1.331534 , -0.09084719], [ 1.3349475 , -0.68765205, 1.680652 ], [ 0.75436825, 1.3050154 , -0.9415123 ], [-1.2502679 , -0.25730947, 0.74611956]], dtype=float32)>, <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy= array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00, -4.8197951e+00, -5.2986512e+00, -6.6931367e+00], [ 3.6438566e-01, 1.0067395e+00, 1.4542470e+00, 8.1155670e-01, 1.8868095e+00, 2.3877139e+00, 1.0195159e+00], [-8.3624744e-01, 1.2518480e+00, 1.0943471e+00, 1.3052304e+00, -4.5756745e-01, -1.0668410e-01, -7.0669651e-02]], [[-3.1788960e-01, 9.2615485e-03, -3.0963073e+00, -2.2846246e+00, -3.2269263e+00, -6.0213070e+00, -7.4806519e+00], [-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00, -4.5065498e+00, -5.6719379e+00, -4.8012795e+00], [ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00, -7.1365709e+00, -9.6198196e+00, -9.7951422e+00]], [[ 2.0621397e+00, 3.4639853e-01, 7.0252883e-01, -1.4311566e+00, 3.3790007e+00, 1.1619035e+00, -8.9105040e-01], [-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00, -2.7150445e+00, -2.4633870e+00, -2.1841538e+00], [ 7.7627432e-01, 2.2401071e+00, 3.7601702e+00, 2.4245868e+00, 4.0690269e+00, 4.0605016e+00, 5.1753912e+00]], [[ 1.4275590e+00, 3.3346462e+00, 1.5374103e+00, -2.2849756e-01, 9.1219616e-01, -3.1220305e-01, -3.2643962e-01], [-3.1910419e-02, -3.8848895e-01, 9.9946201e-02, -2.3619974e+00, -1.8507402e+00, -3.6830821e+00, -5.4907336e+00], [-7.1941972e-02, 2.1602919e+00, 4.9575748e+00, 4.2317696e+00, 9.3528280e+00, 1.0526063e+01, 1.5262107e+01]], [[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00, -3.2361765e+00, -3.3434000e+00, -2.6849220e+00], [ 1.5006512e-02, -1.9866472e-01, 7.6781356e-01, 1.6228745e+00, 1.4191239e+00, 2.6655579e+00, 4.4663467e+00], [ 2.6599693e+00, 1.2663836e+00, 1.7162113e+00, 1.4839669e+00, 2.0559487e+00, 2.5976877e+00, 2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-12.483114 , -10.139662 , -11.514159 ], [-11.656767 , -17.201958 , -12.132455 ], [-17.838818 , -9.474525 , -11.24898 ], [-13.95219 , -12.490049 , -17.123957 ], [-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>
Sebagai pemeriksaan tambahan, kami akan memverifikasi bahwa probabilitas log untuk satu titik batch cocok dengan yang kami miliki sebelumnya:
(jds_ia.log_prob(shaped_sample)[3, 1] -
jds_i.log_prob([shaped_sample[0][3, 1],
shaped_sample[1][3, 1],
shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
AutoBatching Untuk Kemenangan
Bagus sekali! Kami sekarang memiliki versi JointDistribution yang menangani semua kami desiderata: log_prob
kembali berkat skalar untuk penggunaan tfd.Independent
, dan beberapa sampel bekerja sekarang bahwa kita tetap penyiaran dengan menambahkan sumbu ekstra.
Bagaimana jika saya memberi tahu Anda bahwa ada cara yang lebih mudah dan lebih baik? Ada, dan itu disebut JointDistributionSequentialAutoBatched
(JDSAB):
jds_ab = tfd.JointDistributionSequentialAutoBatched([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-12.191533 , -10.43885 , -16.371655 ], [-13.292994 , -11.97949 , -16.788685 ], [-15.987699 , -13.435732 , -10.6029 ], [-10.184758 , -11.969714 , -14.275676 ], [-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], dtype=float32)>
Bagaimana cara kerjanya? Meskipun Anda bisa mencoba untuk membaca kode untuk pemahaman yang mendalam, kami akan memberikan gambaran singkat yang cukup untuk sebagian besar kasus penggunaan:
- Ingat bahwa masalah pertama kami adalah bahwa distribusi kami untuk
Y
memilikibatch_shape=[7]
danevent_shape=[]
, dan kami digunakanIndependent
untuk mengkonversi dimensi batch untuk dimensi acara. JDSAB mengabaikan bentuk batch dari distribusi komponen; bukannya memperlakukan bentuk batch properti keseluruhan model, yang diasumsikan[]
(kecuali ditentukan lain dengan menetapkanbatch_ndims > 0
). Efeknya adalah setara dengan menggunakan tfd.Independent untuk mengkonversi semua dimensi batch distribusi komponen ke dimensi acara, seperti yang kita lakukan secara manual di atas. - Masalah kedua kami adalah kebutuhan untuk memijat bentuk
m
danb
sehingga mereka bisa menyiarkan secara tepat denganX
saat membuat beberapa sampel. Dengan JDSAB, Anda menulis model untuk menghasilkan sampel tunggal, dan kami "mengangkat" seluruh model untuk menghasilkan beberapa sampel menggunakan TensorFlow ini vectorized_map . (Fitur ini analog dengan JAX ini VMAP .)
Menjelajahi masalah bentuk batch yang lebih detail, kita bisa membandingkan bentuk batch kami asli "buruk" distribusi gabungan jds
, kami distribusi batch-tetap jds_i
dan jds_ia
, dan autobatched kami jds_ab
:
jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])
Kita melihat bahwa asli jds
memiliki subdistributions dengan bentuk batch yang berbeda. jds_i
dan jds_ia
memperbaiki ini dengan menciptakan subdistributions dengan (kosong) bentuk batch yang sama. jds_ab
hanya memiliki satu (kosong) bentuk batch.
Itu perlu dicatat bahwa JointDistributionSequentialAutoBatched
menawarkan beberapa umum tambahan gratis. Misalkan kita membuat kovariat X
(dan, secara implisit, pengamatan Y
) dua dimensi:
X = np.arange(14).reshape((2, 7))
X
array([[ 0, 1, 2, 3, 4, 5, 6], [ 7, 8, 9, 10, 11, 12, 13]])
Kami JointDistributionSequentialAutoBatched
bekerja dengan tidak ada perubahan (kita perlu mendefinisikan kembali model karena bentuk X
-cache oleh jds_ab.log_prob
):
jds_ab = tfd.JointDistributionSequentialAutoBatched([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[ 0.1813647 , -0.85994506, 0.27593774], [-0.73323774, 1.1153806 , 0.8841938 ], [ 0.5127983 , -0.29271227, 0.63733214], [ 0.2362284 , -0.919168 , 1.6648189 ], [ 0.26317367, 0.73077047, 2.5395133 ]], dtype=float32)>, <tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[ 0.09636458, 2.0138032 , -0.5054413 ], [ 0.63941646, -1.0785882 , -0.6442188 ], [ 1.2310615 , -0.3293852 , 0.77637213], [ 1.2115169 , -0.98906034, -0.07816773], [-1.1318136 , 0.510014 , 1.036522 ]], dtype=float32)>, <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy= array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01, 8.5992378e-01, -5.3123581e-01, 3.1584005e+00, 2.9044402e+00], [-2.5645006e-01, 3.1554163e-01, 3.1186538e+00, 1.4272424e+00, 1.2843871e+00, 1.2266440e+00, 1.2798605e+00]], [[ 1.5973477e+00, -5.3631151e-01, 6.8143606e-03, -1.4910895e+00, -2.1568544e+00, -2.0513713e+00, -3.1663666e+00], [-4.9448099e+00, -2.8385928e+00, -6.9027486e+00, -5.6543546e+00, -7.2378774e+00, -8.1577444e+00, -9.3582869e+00]], [[-2.1233239e+00, 5.8853775e-02, 1.2024102e+00, 1.6622503e+00, -1.9197327e-01, 1.8647723e+00, 6.4322817e-01], [ 3.7549341e-01, 1.5853541e+00, 2.4594500e+00, 2.1952972e+00, 1.7517658e+00, 2.9666045e+00, 2.5468128e+00]]], [[[ 8.9906776e-01, 6.7375046e-01, 7.3354661e-01, -9.9894643e-01, -3.4606690e+00, -3.4810467e+00, -4.4315586e+00], [-3.0670738e+00, -6.3628020e+00, -6.2538433e+00, -6.8091092e+00, -7.7134805e+00, -8.6319380e+00, -8.6904278e+00]], [[-2.2462025e+00, -3.3060855e-01, 1.8974400e-01, 3.1422038e+00, 4.1483402e+00, 3.5642972e+00, 4.8709240e+00], [ 4.7880130e+00, 5.8790064e+00, 9.6695948e+00, 7.8112822e+00, 1.2022618e+01, 1.2411858e+01, 1.4323385e+01]], [[-1.0189297e+00, -7.8115642e-01, 1.6466728e+00, 8.2378983e-01, 3.0765080e+00, 3.0170646e+00, 5.1899948e+00], [ 6.5285158e+00, 7.8038850e+00, 6.4155884e+00, 9.0899811e+00, 1.0040427e+01, 9.1404457e+00, 1.0411951e+01]]], [[[ 4.5557004e-01, 1.4905317e+00, 1.4904103e+00, 2.9777462e+00, 2.8620450e+00, 3.4745665e+00, 3.8295493e+00], [ 3.9977460e+00, 5.7173767e+00, 7.8421035e+00, 6.3180594e+00, 6.0838981e+00, 8.2257290e+00, 9.6548376e+00]], [[-7.0750320e-01, -3.5972297e-01, 4.3136525e-01, -2.3301599e+00, -5.0374687e-01, -2.8338656e+00, -3.4453444e+00], [-3.1258626e+00, -3.4687450e+00, -1.2045374e+00, -4.0196013e+00, -5.8831010e+00, -4.2965469e+00, -4.1388311e+00]], [[ 2.1969774e+00, 2.4614549e+00, 2.2314475e+00, 1.8392437e+00, 2.8367062e+00, 4.8600502e+00, 4.2273531e+00], [ 6.1879644e+00, 5.1792760e+00, 6.1141996e+00, 5.6517797e+00, 8.9979610e+00, 7.5938139e+00, 9.7918644e+00]]], [[[ 1.5249090e+00, 1.1388919e+00, 8.6903995e-01, 3.0762129e+00, 1.5128503e+00, 3.5204377e+00, 2.4760864e+00], [ 3.4166217e+00, 3.5930209e+00, 3.1694956e+00, 4.5797420e+00, 4.5271711e+00, 2.8774328e+00, 4.7288942e+00]], [[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00, -3.8594103e+00, -4.9681158e+00, -6.4256043e+00, -5.5345035e+00], [-6.4306297e+00, -7.0924540e+00, -8.4075985e+00, -1.0417805e+01, -1.1727266e+01, -1.1196255e+01, -1.1333830e+01]], [[-7.0419472e-01, 1.4568675e+00, 3.7946482e+00, 4.8489718e+00, 6.6498446e+00, 9.0224218e+00, 1.1153137e+01], [ 1.0060651e+01, 1.1998097e+01, 1.5326431e+01, 1.7957514e+01, 1.8323889e+01, 2.0160881e+01, 2.1269085e+01]]], [[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01, 2.3558271e-01, -1.0381399e+00, 1.9387857e+00, -3.3694571e-01], [ 1.6015106e-01, 1.5284677e+00, -4.8567140e-01, -1.7770648e-01, 2.1919653e+00, 1.3015286e+00, 1.3877077e+00]], [[ 1.3688663e+00, 2.6602898e+00, 6.6657305e-01, 4.6554832e+00, 5.7781887e+00, 4.9115267e+00, 4.8446012e+00], [ 5.1983776e+00, 6.2297459e+00, 6.3848300e+00, 8.4291229e+00, 7.1309576e+00, 1.0395646e+01, 8.5736713e+00]], [[ 1.2675294e+00, 5.2844582e+00, 5.1331611e+00, 8.9993315e+00, 1.0794343e+01, 1.4039831e+01, 1.5731170e+01], [ 1.9084715e+01, 2.2191265e+01, 2.3481146e+01, 2.5803375e+01, 2.8632090e+01, 3.0234968e+01, 3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-28.90071 , -23.052422, -19.851362], [-19.775568, -25.894997, -20.302256], [-21.10754 , -23.667885, -20.973007], [-19.249458, -20.87892 , -20.573763], [-22.351208, -25.457762, -24.648403]], dtype=float32)>
Di sisi lain, kami hati-hati dibuat JointDistributionSequential
tidak lagi bekerja:
jds_ia = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
reinterpreted_batch_ndims=1)
])
try:
jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]
Untuk memperbaiki hal ini, kita harus menambahkan kedua tf.newaxis
untuk kedua m
dan b
sesuai dengan bentuk, dan peningkatan reinterpreted_batch_ndims
ke 2 dalam panggilan untuk Independent
. Dalam hal ini, membiarkan mesin batch otomatis menangani masalah bentuk lebih pendek, lebih mudah, dan lebih ergonomis.
Sekali lagi, kami mencatat bahwa sementara notebook ini dieksplorasi JointDistributionSequentialAutoBatched
, varian lain dari JointDistribution
memiliki setara AutoBatched
. (Untuk pengguna JointDistributionCoroutine
, JointDistributionCoroutineAutoBatched
memiliki manfaat tambahan yang Anda tidak perlu lagi untuk menentukan Root
node, jika Anda belum pernah menggunakan JointDistributionCoroutine
. Anda dapat dengan aman mengabaikan pernyataan ini)
Kesimpulan
Dalam notebook ini, kami memperkenalkan JointDistributionSequentialAutoBatched
dan bekerja melalui contoh sederhana secara rinci. Semoga Anda belajar sesuatu tentang bentuk TFP dan tentang autobatching!