عرض على TensorFlow.org | تشغيل في Google Colab | عرض المصدر على جيثب | تحميل دفتر |
مقدمة
TensorFlow الاحتمالية (TFP) يوفر عددا من JointDistribution
التجريدية التي تجعل الاستدلال احتمالي أسهل عن طريق السماح للمستخدم بسهولة التعبير عن نموذج رسومية احتمالي في شكل شبه الرياضية. يولد الاستخراج طرقًا لأخذ العينات من النموذج وتقييم احتمالية تسجيل العينات من النموذج. في هذا البرنامج التعليمي، نستعرض "autobatched" المتغيرات، والتي تم تطويرها بعد الأصلي JointDistribution
التجريد. بالنسبة إلى التجريدات الأصلية غير المطابقة تلقائيًا ، فإن الإصدارات التلقائية هي أبسط في الاستخدام وأكثر راحة ، مما يسمح بالتعبير عن العديد من النماذج باستخدام نموذج معياري أقل. في هذا الكولاب ، نستكشف نموذجًا بسيطًا بالتفصيل (ربما يكون مملاً) ، مع توضيح المشكلات التي تحلها تلقائيًا ، و (نأمل) تعليم القارئ المزيد حول مفاهيم شكل TFP على طول الطريق.
قبل إدخال autobatching، كان هناك عدد قليل من أنواع مختلفة من JointDistribution
الموافق الأنماط النحوية المختلفة للتعبير عن النماذج الاحتمالية: JointDistributionSequential
، JointDistributionNamed
، و JointDistributionCoroutine
. موجود Auobatching باعتباره mixin، لذلك لدينا الآن AutoBatched
المتغيرات من كل هذه. في هذا البرنامج التعليمي، واستكشاف الفروق بين JointDistributionSequential
و JointDistributionSequentialAutoBatched
. ومع ذلك ، فإن كل ما نقوم به هنا ينطبق على المتغيرات الأخرى مع عدم وجود تغييرات في الأساس.
التبعيات والمتطلبات
الاستيراد والإعداد
import functools
import numpy as np
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions
المتطلب السابق: مشكلة انحدار بايزي
سننظر في سيناريو انحدار بايزي بسيط للغاية:
\[ \begin{align*} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{align*} \]
في هذا النموذج، m
و b
مستمدة من المعدلات القياسية، والملاحظات Y
مستمدة من التوزيع الطبيعي الذي يعني يعتمد على المتغيرات العشوائية m
و b
، وبعض (اعشوائي، المعروف) المتغيرات المشاركة X
. (من أجل التبسيط ، في هذا المثال ، نفترض أن مقياس جميع المتغيرات العشوائية معروف.)
لأداء الاستدلال في هذا النموذج، وكنا بحاجة إلى معرفة كل من المتغيرات المشاركة X
والملاحظات Y
، ولكن لأغراض هذا البرنامج التعليمي، سوف نحتاج فقط X
، لذلك نحدد دمية بسيطة X
:
X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])
Desiderata
في الاستدلال الاحتمالي ، غالبًا ما نريد إجراء عمليتين أساسيتين:
-
sample
: رسم عينات من الطراز. -
log_prob
: الحوسبة احتمال سجل لعينة من النموذج.
المساهمة الرئيسية TFP في JointDistribution
تجريدات (وكذلك العديد من الطرق الأخرى لبرمجة الاحتمالية) هو السماح للمستخدمين الكتابة نموذجا مرة واحدة والحصول على كل sample
و log_prob
الحسابية.
مشيرا إلى أن لدينا 7 نقاط في مجموعة البيانات لدينا ( X.shape = (7,)
)، يمكننا الآن القول والأمنيات للممتازة JointDistribution
:
-
sample()
يجب أن ينتج قائمةTensors
وجود شكل[(), (), (7,)
]، الموافق المنحدر العددية، والتحيز العددية، والملاحظات متجه، على التوالي. -
log_prob(sample())
ينبغي أن تنتج العددية: احتمال سجل من منحدر معينة، والتحيز، والملاحظات. -
sample([5, 3])
ينبغي أن تصدر قائمةTensors
وجود شكل[(5, 3), (5, 3), (5, 3, 7)]
، وهو ما يمثل(5, 3)
- مجموعة من عينات من الموديل. -
log_prob(sample([5, 3]))
يجب أن تنتجTensor
مع شكل (5، 3).
سنقوم الآن ننظر في خلافة JointDistribution
النماذج، معرفة كيفية تحقيق الأمنيات المذكورة أعلاه، ونأمل أن تتعلم أكثر من ذلك بقليل عن TFP الأشكال على طول الطريق.
تنبيه المفسد: إن النهج الذي يرضي الأمنيات المذكورة أعلاه دون النمطي المضافة و autobatching .
المحاولة الأولى؛ JointDistributionSequential
jds = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
هذه ترجمة مباشرة للنموذج إلى حد ما إلى رمز. المنحدر m
والتحيز b
هي واضحة. Y
يعرف باستخدام lambda
الوظائف: النمط العام هو أن lambda
-function من \(k\) الحجج في JointDistributionSequential
(JDS) يستخدم السابقة \(k\) التوزيعات في النموذج. لاحظ الترتيب "العكسي".
وسوف ندعو sample_distributions
، الذي يعود على حد سواء العينة والكامنة "، توزيعات الفرعية" التي تم استخدامها لتوليد العينة. (ونحن يمكن أن تنتج فقط العينة عن طريق الاتصال sample
، في وقت لاحق في البرنامج التعليمي سوف تكون مريحة لديها توزيعات كذلك.) العينة التي ننتجها هي على ما يرام:
dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>, <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>, <tf.Tensor: shape=(7,), dtype=float32, numpy= array([ 0.18573815, -1.79962 , -1.8106272 , -3.5971394 , -6.6625295 , -7.308844 , -9.832693 ], dtype=float32)>]
لكن log_prob
تنتج نتيجة لذلك مع شكل غير مرغوب فيها:
jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy= array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684, -4.4368567, -4.480562 ], dtype=float32)>
ولا يعمل أخذ العينات المتعددة:
try:
jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
دعنا نحاول فهم الخطأ الذي يحدث.
مراجعة موجزة: شكل الدفعة والحدث
في TFP، عادية (وليس JointDistribution
) التوزيع الاحتمالي لديه شكل الحدث وشكل دفعة، وفهم الفرق أمر حاسم لفعالية استخدام TFP:
- يصف شكل الحدث شكل السحب الفردي من التوزيع ؛ قد يعتمد السحب على الأبعاد. بالنسبة للتوزيعات العددية ، يكون شكل الحدث هو []. بالنسبة إلى متعدد المتغيرات خماسي الأبعاد عادي ، يكون شكل الحدث هو [5].
- يصف شكل الدُفعات عمليات سحب مستقلة وغير موزعة بشكل متماثل ، ويعرف أيضًا باسم "دفعة" من التوزيعات. يمثل تمثيل مجموعة من التوزيعات في كائن Python واحد إحدى الطرق الرئيسية التي يحقق بها TFP الكفاءة على نطاق واسع.
لأغراضنا، وهذه حقيقة حرج أن نأخذ في الاعتبار هو أنه إذا كان نسميه log_prob
على عينة واحدة من توزيع، والنتيجة سوف يكون دائما على الشكل الذي مباريات (أي له كما أبعاد أقصى اليمين) على شكل دفعة واحدة.
للاطلاع على مناقشة أكثر تعمقا من الأشكال، انظر إلى "فهم TensorFlow التوزيعات الأشكال" البرنامج التعليمي .
لماذا لا log_prob(sample())
إنتاج مجموعة عددي؟
دعونا نستخدم معرفتنا دفعة وحدث شكل لاستكشاف ما يحدث مع log_prob(sample())
. ها هي عينتنا مرة أخرى:
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>, <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>, <tf.Tensor: shape=(7,), dtype=float32, numpy= array([ 0.18573815, -1.79962 , -1.8106272 , -3.5971394 , -6.6625295 , -7.308844 , -9.832693 ], dtype=float32)>]
وهنا توزيعاتنا:
dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>, <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>, <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]
يتم حساب احتمالية السجل عن طريق جمع احتمالات سجل التوزيعات الفرعية في العناصر (المتطابقة) للأجزاء:
log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>, <tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>, <tf.Tensor: shape=(7,), dtype=float32, numpy= array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014, -0.9897899, -1.0334952], dtype=float32)>]
np.sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>
لذلك، مستوى واحد من التفسير هو أن حساب سجل احتمال بإرجاع 7 التنسور لأن مكون الثالث من log_prob_parts
هو 7 التنسور. لكن لماذا؟
حسنا، ونحن نرى أن العنصر الأخير من dists
، والتي تتطابق مع التوزيع لدينا أكثر من Y
في صياغة mathematial، لديه batch_shape
من [7]
. وبعبارة أخرى، توزيعنا على Y
هو مجموعة من 7 المعدلات مستقلة (مع وسائل مختلفة، وفي هذه الحالة، نفس النطاق).
ونحن نفهم الآن ما هو الخطأ: في JDS، وتوزيع أكثر من Y
ديه batch_shape=[7]
، وشمل عينة من JDS تمثل سكالارس ل m
و b
و "دفعة" من 7 المعدلات مستقلة. و log_prob
يحسب 7 منفصلة سجل-الاحتمالات، يمثل كل منها احتمال سجل من رسم m
و b
والمراقبة واحدة Y[i]
في بعض X[i]
.
تحديد log_prob(sample())
مع Independent
يذكر أن dists[2]
له event_shape=[]
و batch_shape=[7]
:
dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>
باستخدام TFP و Independent
metadistribution، والذي يحول أبعاد دفعة لأبعاد الحدث، يمكننا تحويل هذا إلى توزيع مع event_shape=[7]
و batch_shape=[]
(وسوف تسميته y_dist_i
لانها التوزيع على Y
، مع _i
مكانة في لدينا Independent
التفاف):
y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>
الآن، و log_prob
من ناقلات 7 هو العددية:
y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>
تحت الأغطية، Independent
مبالغ تزيد على المبلغ دفعة:
y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
وفي الواقع، يمكننا استخدام هذه لبناء جديد jds_i
(و i
تقف مرة أخرى ل Independent
) حيث log_prob
إرجاع العددية:
jds_i = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m*X + b, scale=1.),
reinterpreted_batch_ndims=1)
])
jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>
بضع ملاحظات:
-
jds_i.log_prob(s)
ليست هي نفسها كماtf.reduce_sum(jds.log_prob(s))
. الأول ينتج احتمالية السجل "الصحيح" للتوزيع المشترك. المبالغ الأخيرة أكثر من 7 التنسور، كل عنصر منها هو مجموع احتمال سجل منm
،b
، وعنصر واحد من احتمال سجلY
، لذلك overcountsm
وb
. (log_prob(m) + log_prob(b) + log_prob(Y)
بإرجاع نتيجة لذلك بدلا من رمي استثناء لأن TFP يلي TF وقواعد البث نمباي؛ و. إضافة العددية لناقلات تنتج نتيجة ناقلات الحجم) - في هذه الحالة بالذات، فإننا يمكن أن تحل المشكلة وحقق نفس النتيجة باستخدام
MultivariateNormalDiag
بدلا منIndependent(Normal(...))
.MultivariateNormalDiag
هو توزيع قيمتها ناقلات (أي أنها لديها بالفعل ناقلات الحدث شكل). IndeeedMultivariateNormalDiag
يمكن أن يكون (ولكن لا) تنفيذها باعتبارها تكوينIndependent
وNormal
. أنه من المفيد أن نتذكر أنه نظرا متجهV
، وعينات منn1 = Normal(loc=V)
، وn2 = MultivariateNormalDiag(loc=V)
لا يمكن تمييزها. الفرق beween هذه التوزيعات هو أنn1.log_prob(n1.sample())
هو متجه وn2.log_prob(n2.sample())
هو العددية.
عينات متعددة؟
لا يزال رسم عينات متعددة لا يعمل:
try:
jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
دعونا نفكر لماذا. عندما ندعو jds_i.sample([5, 3])
، سنقوم أولا أخذ عينات ل m
و b
، ولكل منها شكل (5, 3)
. المقبل، ونحن ذاهبون لمحاولة بناء Normal
التوزيع عن طريق:
tfd.Normal(loc=m*X + b, scale=1.)
ولكن إذا m
له شكل (5, 3)
و X
لديها شكل 7
، لا نستطيع ضرب بعضهم البعض، بل وهذا هو الخطأ نحن ضرب:
m = tfd.Normal(0., 1.).sample([5, 3])
try:
m * X
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
لحل هذه المشكلة، دعونا نفكر في ما خصائص التوزيع على Y
أن يكون. وإذا كنا قد دعا jds_i.sample([5, 3])
، ثم نحن نعرف m
و b
سيكون شكل على حد سواء (5, 3)
. ماهية طبيعة دعوة لل sample
على Y
المنتجات التوزيع؟ الجواب الواضح هو (5, 3, 7)
: لكل نقطة دفعة، ونحن نريد عينة مع نفس حجم X
. يمكننا تحقيق ذلك باستخدام إمكانيات البث في TensorFlow ، مع إضافة أبعاد إضافية:
m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])
إضافة محور لكلا m
و b
، يمكننا تحديد JDS الجديد الذي يدعم عينات متعددة:
jds_ia = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
reinterpreted_batch_ndims=1)
])
shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-1.1133379 , 0.16390413, -0.24177533], [-1.1312429 , -0.6224666 , -1.8182136 ], [-0.31343174, -0.32932565, 0.5164407 ], [-0.0119963 , -0.9079621 , 2.3655841 ], [-0.26293617, 0.8229698 , 0.31098196]], dtype=float32)>, <tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-0.02876974, 1.0872147 , 1.0138507 ], [ 0.27367726, -1.331534 , -0.09084719], [ 1.3349475 , -0.68765205, 1.680652 ], [ 0.75436825, 1.3050154 , -0.9415123 ], [-1.2502679 , -0.25730947, 0.74611956]], dtype=float32)>, <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy= array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00, -4.8197951e+00, -5.2986512e+00, -6.6931367e+00], [ 3.6438566e-01, 1.0067395e+00, 1.4542470e+00, 8.1155670e-01, 1.8868095e+00, 2.3877139e+00, 1.0195159e+00], [-8.3624744e-01, 1.2518480e+00, 1.0943471e+00, 1.3052304e+00, -4.5756745e-01, -1.0668410e-01, -7.0669651e-02]], [[-3.1788960e-01, 9.2615485e-03, -3.0963073e+00, -2.2846246e+00, -3.2269263e+00, -6.0213070e+00, -7.4806519e+00], [-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00, -4.5065498e+00, -5.6719379e+00, -4.8012795e+00], [ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00, -7.1365709e+00, -9.6198196e+00, -9.7951422e+00]], [[ 2.0621397e+00, 3.4639853e-01, 7.0252883e-01, -1.4311566e+00, 3.3790007e+00, 1.1619035e+00, -8.9105040e-01], [-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00, -2.7150445e+00, -2.4633870e+00, -2.1841538e+00], [ 7.7627432e-01, 2.2401071e+00, 3.7601702e+00, 2.4245868e+00, 4.0690269e+00, 4.0605016e+00, 5.1753912e+00]], [[ 1.4275590e+00, 3.3346462e+00, 1.5374103e+00, -2.2849756e-01, 9.1219616e-01, -3.1220305e-01, -3.2643962e-01], [-3.1910419e-02, -3.8848895e-01, 9.9946201e-02, -2.3619974e+00, -1.8507402e+00, -3.6830821e+00, -5.4907336e+00], [-7.1941972e-02, 2.1602919e+00, 4.9575748e+00, 4.2317696e+00, 9.3528280e+00, 1.0526063e+01, 1.5262107e+01]], [[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00, -3.2361765e+00, -3.3434000e+00, -2.6849220e+00], [ 1.5006512e-02, -1.9866472e-01, 7.6781356e-01, 1.6228745e+00, 1.4191239e+00, 2.6655579e+00, 4.4663467e+00], [ 2.6599693e+00, 1.2663836e+00, 1.7162113e+00, 1.4839669e+00, 2.0559487e+00, 2.5976877e+00, 2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-12.483114 , -10.139662 , -11.514159 ], [-11.656767 , -17.201958 , -12.132455 ], [-17.838818 , -9.474525 , -11.24898 ], [-13.95219 , -12.490049 , -17.123957 ], [-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>
كتحقق إضافي ، سنتحقق من أن احتمالية السجل لنقطة دفعة واحدة تتطابق مع ما كان لدينا من قبل:
(jds_ia.log_prob(shaped_sample)[3, 1] -
jds_i.log_prob([shaped_sample[0][3, 1],
shaped_sample[1][3, 1],
shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
الدفعة التلقائية من أجل الفوز
ممتاز! لدينا الآن نسخة من JointDistribution الذي يعالج جميع لدينا الأمنيات: log_prob
عائدات بفضل العددية لاستخدام tfd.Independent
، وعينات متعددة تعمل الآن أننا الثابتة البث بإضافة محاور إضافية.
ماذا لو أخبرتك أن هناك طريقة أسهل وأفضل؟ هناك، وانه دعا JointDistributionSequentialAutoBatched
(JDSAB):
jds_ab = tfd.JointDistributionSequentialAutoBatched([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-12.191533 , -10.43885 , -16.371655 ], [-13.292994 , -11.97949 , -16.788685 ], [-15.987699 , -13.435732 , -10.6029 ], [-10.184758 , -11.969714 , -14.275676 ], [-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], dtype=float32)>
كيف يعمل هذا؟ في الوقت الذي يمكن أن محاولة قراءة رمز لفهم عميق، وسوف نقدم لمحة موجزة وهو ما يكفي لمعظم حالات الاستخدام:
- يذكر أن مشكلتنا الأولى هي أن لدينا توزيع
Y
زيارتهاbatch_shape=[7]
وevent_shape=[]
، وكناIndependent
لتحويل البعد دفعة إلى البعد الحدث. يتجاهل JDSAB الأشكال الدفعية لتوزيعات المكونات ؛ بدلا من ذلك فإنه يعامل شكل دفعة كخاصية العامة للنموذج، الذي يفترض أن يكون[]
(ما لم ينص على خلاف ذلك عن طريق وضعbatch_ndims > 0
). تأثير ما يعادل باستخدام tfd.Independent لتحويل جميع أبعاد دفعة من توزيعات المكونة في أبعاد الحدث، كما فعلنا يدويا أعلاه. - وكان لدينا مشكلة ثانية حاجة لتدليك أشكال
m
وb
حتى يتمكنوا من بث مناسب معX
عند إنشاء عينات متعددة. مع JDSAB، أن تكتب نموذج لتوليد عينة واحدة، ونحن "رفع" نموذج كامل لتوليد نماذج متعددة باستخدام TensorFlow في vectorized_map . (هذه الميزة analagous إلى JAX في vmap ).
استكشاف قضية شكل دفعة في مزيد من التفاصيل، يمكن لنا أن نقارن الأشكال دفعة من لدينا "سيئة" الأصلي التوزيع المشتركة jds
، لدينا ثابتة دفعة توزيعات jds_i
و jds_ia
، ولدينا autobatched jds_ab
:
jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])
ونحن نرى أن أصلي jds
ديها subdistributions مع الأشكال دفعة مختلفة. jds_i
و jds_ia
إصلاح ذلك عن طريق إنشاء subdistributions مع نفسه (فارغة) شكل دفعة واحدة. jds_ab
ليس لديها سوى واحد (فارغ) شكل دفعة واحدة.
ومن الجدير بالذكر أن JointDistributionSequentialAutoBatched
يقدم بعض عمومية إضافية مجانا. لنفترض يمكننا أن نجعل من المتغيرات المشاركة X
(وضمنيا، الملاحظات Y
) ثنائي الأبعاد:
X = np.arange(14).reshape((2, 7))
X
array([[ 0, 1, 2, 3, 4, 5, 6], [ 7, 8, 9, 10, 11, 12, 13]])
لدينا JointDistributionSequentialAutoBatched
يعمل مع أي تغييرات (نحن بحاجة إلى إعادة تعريف نموذج لشكل X
هو مؤقتا بواسطة jds_ab.log_prob
):
jds_ab = tfd.JointDistributionSequentialAutoBatched([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[ 0.1813647 , -0.85994506, 0.27593774], [-0.73323774, 1.1153806 , 0.8841938 ], [ 0.5127983 , -0.29271227, 0.63733214], [ 0.2362284 , -0.919168 , 1.6648189 ], [ 0.26317367, 0.73077047, 2.5395133 ]], dtype=float32)>, <tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[ 0.09636458, 2.0138032 , -0.5054413 ], [ 0.63941646, -1.0785882 , -0.6442188 ], [ 1.2310615 , -0.3293852 , 0.77637213], [ 1.2115169 , -0.98906034, -0.07816773], [-1.1318136 , 0.510014 , 1.036522 ]], dtype=float32)>, <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy= array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01, 8.5992378e-01, -5.3123581e-01, 3.1584005e+00, 2.9044402e+00], [-2.5645006e-01, 3.1554163e-01, 3.1186538e+00, 1.4272424e+00, 1.2843871e+00, 1.2266440e+00, 1.2798605e+00]], [[ 1.5973477e+00, -5.3631151e-01, 6.8143606e-03, -1.4910895e+00, -2.1568544e+00, -2.0513713e+00, -3.1663666e+00], [-4.9448099e+00, -2.8385928e+00, -6.9027486e+00, -5.6543546e+00, -7.2378774e+00, -8.1577444e+00, -9.3582869e+00]], [[-2.1233239e+00, 5.8853775e-02, 1.2024102e+00, 1.6622503e+00, -1.9197327e-01, 1.8647723e+00, 6.4322817e-01], [ 3.7549341e-01, 1.5853541e+00, 2.4594500e+00, 2.1952972e+00, 1.7517658e+00, 2.9666045e+00, 2.5468128e+00]]], [[[ 8.9906776e-01, 6.7375046e-01, 7.3354661e-01, -9.9894643e-01, -3.4606690e+00, -3.4810467e+00, -4.4315586e+00], [-3.0670738e+00, -6.3628020e+00, -6.2538433e+00, -6.8091092e+00, -7.7134805e+00, -8.6319380e+00, -8.6904278e+00]], [[-2.2462025e+00, -3.3060855e-01, 1.8974400e-01, 3.1422038e+00, 4.1483402e+00, 3.5642972e+00, 4.8709240e+00], [ 4.7880130e+00, 5.8790064e+00, 9.6695948e+00, 7.8112822e+00, 1.2022618e+01, 1.2411858e+01, 1.4323385e+01]], [[-1.0189297e+00, -7.8115642e-01, 1.6466728e+00, 8.2378983e-01, 3.0765080e+00, 3.0170646e+00, 5.1899948e+00], [ 6.5285158e+00, 7.8038850e+00, 6.4155884e+00, 9.0899811e+00, 1.0040427e+01, 9.1404457e+00, 1.0411951e+01]]], [[[ 4.5557004e-01, 1.4905317e+00, 1.4904103e+00, 2.9777462e+00, 2.8620450e+00, 3.4745665e+00, 3.8295493e+00], [ 3.9977460e+00, 5.7173767e+00, 7.8421035e+00, 6.3180594e+00, 6.0838981e+00, 8.2257290e+00, 9.6548376e+00]], [[-7.0750320e-01, -3.5972297e-01, 4.3136525e-01, -2.3301599e+00, -5.0374687e-01, -2.8338656e+00, -3.4453444e+00], [-3.1258626e+00, -3.4687450e+00, -1.2045374e+00, -4.0196013e+00, -5.8831010e+00, -4.2965469e+00, -4.1388311e+00]], [[ 2.1969774e+00, 2.4614549e+00, 2.2314475e+00, 1.8392437e+00, 2.8367062e+00, 4.8600502e+00, 4.2273531e+00], [ 6.1879644e+00, 5.1792760e+00, 6.1141996e+00, 5.6517797e+00, 8.9979610e+00, 7.5938139e+00, 9.7918644e+00]]], [[[ 1.5249090e+00, 1.1388919e+00, 8.6903995e-01, 3.0762129e+00, 1.5128503e+00, 3.5204377e+00, 2.4760864e+00], [ 3.4166217e+00, 3.5930209e+00, 3.1694956e+00, 4.5797420e+00, 4.5271711e+00, 2.8774328e+00, 4.7288942e+00]], [[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00, -3.8594103e+00, -4.9681158e+00, -6.4256043e+00, -5.5345035e+00], [-6.4306297e+00, -7.0924540e+00, -8.4075985e+00, -1.0417805e+01, -1.1727266e+01, -1.1196255e+01, -1.1333830e+01]], [[-7.0419472e-01, 1.4568675e+00, 3.7946482e+00, 4.8489718e+00, 6.6498446e+00, 9.0224218e+00, 1.1153137e+01], [ 1.0060651e+01, 1.1998097e+01, 1.5326431e+01, 1.7957514e+01, 1.8323889e+01, 2.0160881e+01, 2.1269085e+01]]], [[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01, 2.3558271e-01, -1.0381399e+00, 1.9387857e+00, -3.3694571e-01], [ 1.6015106e-01, 1.5284677e+00, -4.8567140e-01, -1.7770648e-01, 2.1919653e+00, 1.3015286e+00, 1.3877077e+00]], [[ 1.3688663e+00, 2.6602898e+00, 6.6657305e-01, 4.6554832e+00, 5.7781887e+00, 4.9115267e+00, 4.8446012e+00], [ 5.1983776e+00, 6.2297459e+00, 6.3848300e+00, 8.4291229e+00, 7.1309576e+00, 1.0395646e+01, 8.5736713e+00]], [[ 1.2675294e+00, 5.2844582e+00, 5.1331611e+00, 8.9993315e+00, 1.0794343e+01, 1.4039831e+01, 1.5731170e+01], [ 1.9084715e+01, 2.2191265e+01, 2.3481146e+01, 2.5803375e+01, 2.8632090e+01, 3.0234968e+01, 3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-28.90071 , -23.052422, -19.851362], [-19.775568, -25.894997, -20.302256], [-21.10754 , -23.667885, -20.973007], [-19.249458, -20.87892 , -20.573763], [-22.351208, -25.457762, -24.648403]], dtype=float32)>
من ناحية أخرى، لدينا بعناية JointDistributionSequential
لم يعد يعمل:
jds_ia = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
reinterpreted_batch_ndims=1)
])
try:
jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]
لحل هذه المشكلة، سيكون لدينا لإضافة ثاني tf.newaxis
لكلا m
و b
تتناسب مع الشكل، وزيادة reinterpreted_batch_ndims
إلى 2 في استدعاء Independent
. في هذه الحالة ، يكون السماح لآلة الخلط التلقائي بالتعامل مع مشكلات الشكل أقصر وأسهل وأكثر راحة.
مرة أخرى، نلاحظ أنه في حين أن هذا الكمبيوتر الدفتري استكشاف JointDistributionSequentialAutoBatched
، والمتغيرات الأخرى من JointDistribution
لها ما يعادل AutoBatched
. (لمستخدمي JointDistributionCoroutine
، JointDistributionCoroutineAutoBatched
له فائدة إضافية التي لم تعد بحاجة لتحديد Root
العقد، وإذا كنت تستخدم أبدا JointDistributionCoroutine
. يمكنك بشكل آمن تجاهل هذا البيان)
أفكار ختامية
في هذه المفكرة، قدمنا JointDistributionSequentialAutoBatched
وعملت من خلال مثال بسيط في التفاصيل. نأمل أن تكون قد تعلمت شيئًا عن أشكال TFP وحول المزج التلقائي!