توزيعات مشتركة مجمعة تلقائيًا: برنامج تعليمي لطيف

عرض على TensorFlow.org تشغيل في Google Colab عرض المصدر على جيثب تحميل دفتر

مقدمة

TensorFlow الاحتمالية (TFP) يوفر عددا من JointDistribution التجريدية التي تجعل الاستدلال احتمالي أسهل عن طريق السماح للمستخدم بسهولة التعبير عن نموذج رسومية احتمالي في شكل شبه الرياضية. يولد الاستخراج طرقًا لأخذ العينات من النموذج وتقييم احتمالية تسجيل العينات من النموذج. في هذا البرنامج التعليمي، نستعرض "autobatched" المتغيرات، والتي تم تطويرها بعد الأصلي JointDistribution التجريد. بالنسبة إلى التجريدات الأصلية غير المطابقة تلقائيًا ، فإن الإصدارات التلقائية هي أبسط في الاستخدام وأكثر راحة ، مما يسمح بالتعبير عن العديد من النماذج باستخدام نموذج معياري أقل. في هذا الكولاب ، نستكشف نموذجًا بسيطًا بالتفصيل (ربما يكون مملاً) ، مع توضيح المشكلات التي تحلها تلقائيًا ، و (نأمل) تعليم القارئ المزيد حول مفاهيم شكل TFP على طول الطريق.

قبل إدخال autobatching، كان هناك عدد قليل من أنواع مختلفة من JointDistribution الموافق الأنماط النحوية المختلفة للتعبير عن النماذج الاحتمالية: JointDistributionSequential ، JointDistributionNamed ، و JointDistributionCoroutine . موجود Auobatching باعتباره mixin، لذلك لدينا الآن AutoBatched المتغيرات من كل هذه. في هذا البرنامج التعليمي، واستكشاف الفروق بين JointDistributionSequential و JointDistributionSequentialAutoBatched . ومع ذلك ، فإن كل ما نقوم به هنا ينطبق على المتغيرات الأخرى مع عدم وجود تغييرات في الأساس.

التبعيات والمتطلبات

الاستيراد والإعداد

المتطلب السابق: مشكلة انحدار بايزي

سننظر في سيناريو انحدار بايزي بسيط للغاية:

\[ \begin{align*} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{align*} \]

في هذا النموذج، m و b مستمدة من المعدلات القياسية، والملاحظات Y مستمدة من التوزيع الطبيعي الذي يعني يعتمد على المتغيرات العشوائية m و b ، وبعض (اعشوائي، المعروف) المتغيرات المشاركة X . (من أجل التبسيط ، في هذا المثال ، نفترض أن مقياس جميع المتغيرات العشوائية معروف.)

لأداء الاستدلال في هذا النموذج، وكنا بحاجة إلى معرفة كل من المتغيرات المشاركة X والملاحظات Y ، ولكن لأغراض هذا البرنامج التعليمي، سوف نحتاج فقط X ، لذلك نحدد دمية بسيطة X :

X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])

Desiderata

في الاستدلال الاحتمالي ، غالبًا ما نريد إجراء عمليتين أساسيتين:

  • sample : رسم عينات من الطراز.
  • log_prob : الحوسبة احتمال سجل لعينة من النموذج.

المساهمة الرئيسية TFP في JointDistribution تجريدات (وكذلك العديد من الطرق الأخرى لبرمجة الاحتمالية) هو السماح للمستخدمين الكتابة نموذجا مرة واحدة والحصول على كل sample و log_prob الحسابية.

مشيرا إلى أن لدينا 7 نقاط في مجموعة البيانات لدينا ( X.shape = (7,) )، يمكننا الآن القول والأمنيات للممتازة JointDistribution :

  • sample() يجب أن ينتج قائمة Tensors وجود شكل [(), (), (7,) ]، الموافق المنحدر العددية، والتحيز العددية، والملاحظات متجه، على التوالي.
  • log_prob(sample()) ينبغي أن تنتج العددية: احتمال سجل من منحدر معينة، والتحيز، والملاحظات.
  • sample([5, 3]) ينبغي أن تصدر قائمة Tensors وجود شكل [(5, 3), (5, 3), (5, 3, 7)] ، وهو ما يمثل (5, 3) - مجموعة من عينات من الموديل.
  • log_prob(sample([5, 3])) يجب أن تنتج Tensor مع شكل (5، 3).

سنقوم الآن ننظر في خلافة JointDistribution النماذج، معرفة كيفية تحقيق الأمنيات المذكورة أعلاه، ونأمل أن تتعلم أكثر من ذلك بقليل عن TFP الأشكال على طول الطريق.

تنبيه المفسد: إن النهج الذي يرضي الأمنيات المذكورة أعلاه دون النمطي المضافة و autobatching .

المحاولة الأولى؛ JointDistributionSequential

jds = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

هذه ترجمة مباشرة للنموذج إلى حد ما إلى رمز. المنحدر m والتحيز b هي واضحة. Y يعرف باستخدام lambda الوظائف: النمط العام هو أن lambda -function من \(k\) الحجج في JointDistributionSequential (JDS) يستخدم السابقة \(k\) التوزيعات في النموذج. لاحظ الترتيب "العكسي".

وسوف ندعو sample_distributions ، الذي يعود على حد سواء العينة والكامنة "، توزيعات الفرعية" التي تم استخدامها لتوليد العينة. (ونحن يمكن أن تنتج فقط العينة عن طريق الاتصال sample ، في وقت لاحق في البرنامج التعليمي سوف تكون مريحة لديها توزيعات كذلك.) العينة التي ننتجها هي على ما يرام:

dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

لكن log_prob تنتج نتيجة لذلك مع شكل غير مرغوب فيها:

jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=
array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684,
       -4.4368567, -4.480562 ], dtype=float32)>

ولا يعمل أخذ العينات المتعددة:

try:
  jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

دعنا نحاول فهم الخطأ الذي يحدث.

مراجعة موجزة: شكل الدفعة والحدث

في TFP، عادية (وليس JointDistribution ) التوزيع الاحتمالي لديه شكل الحدث وشكل دفعة، وفهم الفرق أمر حاسم لفعالية استخدام TFP:

  • يصف شكل الحدث شكل السحب الفردي من التوزيع ؛ قد يعتمد السحب على الأبعاد. بالنسبة للتوزيعات العددية ، يكون شكل الحدث هو []. بالنسبة إلى متعدد المتغيرات خماسي الأبعاد عادي ، يكون شكل الحدث هو [5].
  • يصف شكل الدُفعات عمليات سحب مستقلة وغير موزعة بشكل متماثل ، ويعرف أيضًا باسم "دفعة" من التوزيعات. يمثل تمثيل مجموعة من التوزيعات في كائن Python واحد إحدى الطرق الرئيسية التي يحقق بها TFP الكفاءة على نطاق واسع.

لأغراضنا، وهذه حقيقة حرج أن نأخذ في الاعتبار هو أنه إذا كان نسميه log_prob على عينة واحدة من توزيع، والنتيجة سوف يكون دائما على الشكل الذي مباريات (أي له كما أبعاد أقصى اليمين) على شكل دفعة واحدة.

للاطلاع على مناقشة أكثر تعمقا من الأشكال، انظر إلى "فهم TensorFlow التوزيعات الأشكال" البرنامج التعليمي .

لماذا لا log_prob(sample()) إنتاج مجموعة عددي؟

دعونا نستخدم معرفتنا دفعة وحدث شكل لاستكشاف ما يحدث مع log_prob(sample()) . ها هي عينتنا مرة أخرى:

sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

وهنا توزيعاتنا:

dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]

يتم حساب احتمالية السجل عن طريق جمع احتمالات سجل التوزيعات الفرعية في العناصر (المتطابقة) للأجزاء:

log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014,
        -0.9897899, -1.0334952], dtype=float32)>]
np.sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>

لذلك، مستوى واحد من التفسير هو أن حساب سجل احتمال بإرجاع 7 التنسور لأن مكون الثالث من log_prob_parts هو 7 التنسور. لكن لماذا؟

حسنا، ونحن نرى أن العنصر الأخير من dists ، والتي تتطابق مع التوزيع لدينا أكثر من Y في صياغة mathematial، لديه batch_shape من [7] . وبعبارة أخرى، توزيعنا على Y هو مجموعة من 7 المعدلات مستقلة (مع وسائل مختلفة، وفي هذه الحالة، نفس النطاق).

ونحن نفهم الآن ما هو الخطأ: في JDS، وتوزيع أكثر من Y ديه batch_shape=[7] ، وشمل عينة من JDS تمثل سكالارس ل m و b و "دفعة" من 7 المعدلات مستقلة. و log_prob يحسب 7 منفصلة سجل-الاحتمالات، يمثل كل منها احتمال سجل من رسم m و b والمراقبة واحدة Y[i] في بعض X[i] .

تحديد log_prob(sample()) مع Independent

يذكر أن dists[2] له event_shape=[] و batch_shape=[7] :

dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>

باستخدام TFP و Independent metadistribution، والذي يحول أبعاد دفعة لأبعاد الحدث، يمكننا تحويل هذا إلى توزيع مع event_shape=[7] و batch_shape=[] (وسوف تسميته y_dist_i لانها التوزيع على Y ، مع _i مكانة في لدينا Independent التفاف):

y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>

الآن، و log_prob من ناقلات 7 هو العددية:

y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>

تحت الأغطية، Independent مبالغ تزيد على المبلغ دفعة:

y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

وفي الواقع، يمكننا استخدام هذه لبناء جديد jds_ii تقف مرة أخرى ل Independent ) حيث log_prob إرجاع العددية:

jds_i = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m*X + b, scale=1.),
        reinterpreted_batch_ndims=1)
])

jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>

بضع ملاحظات:

  • jds_i.log_prob(s) ليست هي نفسها كما tf.reduce_sum(jds.log_prob(s)) . الأول ينتج احتمالية السجل "الصحيح" للتوزيع المشترك. المبالغ الأخيرة أكثر من 7 التنسور، كل عنصر منها هو مجموع احتمال سجل من m ، b ، وعنصر واحد من احتمال سجل Y ، لذلك overcounts m و b . ( log_prob(m) + log_prob(b) + log_prob(Y) بإرجاع نتيجة لذلك بدلا من رمي استثناء لأن TFP يلي TF وقواعد البث نمباي؛ و. إضافة العددية لناقلات تنتج نتيجة ناقلات الحجم)
  • في هذه الحالة بالذات، فإننا يمكن أن تحل المشكلة وحقق نفس النتيجة باستخدام MultivariateNormalDiag بدلا من Independent(Normal(...)) . MultivariateNormalDiag هو توزيع قيمتها ناقلات (أي أنها لديها بالفعل ناقلات الحدث شكل). Indeeed MultivariateNormalDiag يمكن أن يكون (ولكن لا) تنفيذها باعتبارها تكوين Independent و Normal . أنه من المفيد أن نتذكر أنه نظرا متجه V ، وعينات من n1 = Normal(loc=V) ، و n2 = MultivariateNormalDiag(loc=V) لا يمكن تمييزها. الفرق beween هذه التوزيعات هو أن n1.log_prob(n1.sample()) هو متجه و n2.log_prob(n2.sample()) هو العددية.

عينات متعددة؟

لا يزال رسم عينات متعددة لا يعمل:

try:
  jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

دعونا نفكر لماذا. عندما ندعو jds_i.sample([5, 3]) ، سنقوم أولا أخذ عينات ل m و b ، ولكل منها شكل (5, 3) . المقبل، ونحن ذاهبون لمحاولة بناء Normal التوزيع عن طريق:

tfd.Normal(loc=m*X + b, scale=1.)

ولكن إذا m له شكل (5, 3) و X لديها شكل 7 ، لا نستطيع ضرب بعضهم البعض، بل وهذا هو الخطأ نحن ضرب:

m = tfd.Normal(0., 1.).sample([5, 3])
try:
  m * X
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

لحل هذه المشكلة، دعونا نفكر في ما خصائص التوزيع على Y أن يكون. وإذا كنا قد دعا jds_i.sample([5, 3]) ، ثم نحن نعرف m و b سيكون شكل على حد سواء (5, 3) . ماهية طبيعة دعوة لل sample على Y المنتجات التوزيع؟ الجواب الواضح هو (5, 3, 7) : لكل نقطة دفعة، ونحن نريد عينة مع نفس حجم X . يمكننا تحقيق ذلك باستخدام إمكانيات البث في TensorFlow ، مع إضافة أبعاد إضافية:

m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])

إضافة محور لكلا m و b ، يمكننا تحديد JDS الجديد الذي يدعم عينات متعددة:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-1.1133379 ,  0.16390413, -0.24177533],
        [-1.1312429 , -0.6224666 , -1.8182136 ],
        [-0.31343174, -0.32932565,  0.5164407 ],
        [-0.0119963 , -0.9079621 ,  2.3655841 ],
        [-0.26293617,  0.8229698 ,  0.31098196]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-0.02876974,  1.0872147 ,  1.0138507 ],
        [ 0.27367726, -1.331534  , -0.09084719],
        [ 1.3349475 , -0.68765205,  1.680652  ],
        [ 0.75436825,  1.3050154 , -0.9415123 ],
        [-1.2502679 , -0.25730947,  0.74611956]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy=
 array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00,
          -4.8197951e+00, -5.2986512e+00, -6.6931367e+00],
         [ 3.6438566e-01,  1.0067395e+00,  1.4542470e+00,  8.1155670e-01,
           1.8868095e+00,  2.3877139e+00,  1.0195159e+00],
         [-8.3624744e-01,  1.2518480e+00,  1.0943471e+00,  1.3052304e+00,
          -4.5756745e-01, -1.0668410e-01, -7.0669651e-02]],
 
        [[-3.1788960e-01,  9.2615485e-03, -3.0963073e+00, -2.2846246e+00,
          -3.2269263e+00, -6.0213070e+00, -7.4806519e+00],
         [-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00,
          -4.5065498e+00, -5.6719379e+00, -4.8012795e+00],
         [ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00,
          -7.1365709e+00, -9.6198196e+00, -9.7951422e+00]],
 
        [[ 2.0621397e+00,  3.4639853e-01,  7.0252883e-01, -1.4311566e+00,
           3.3790007e+00,  1.1619035e+00, -8.9105040e-01],
         [-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00,
          -2.7150445e+00, -2.4633870e+00, -2.1841538e+00],
         [ 7.7627432e-01,  2.2401071e+00,  3.7601702e+00,  2.4245868e+00,
           4.0690269e+00,  4.0605016e+00,  5.1753912e+00]],
 
        [[ 1.4275590e+00,  3.3346462e+00,  1.5374103e+00, -2.2849756e-01,
           9.1219616e-01, -3.1220305e-01, -3.2643962e-01],
         [-3.1910419e-02, -3.8848895e-01,  9.9946201e-02, -2.3619974e+00,
          -1.8507402e+00, -3.6830821e+00, -5.4907336e+00],
         [-7.1941972e-02,  2.1602919e+00,  4.9575748e+00,  4.2317696e+00,
           9.3528280e+00,  1.0526063e+01,  1.5262107e+01]],
 
        [[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00,
          -3.2361765e+00, -3.3434000e+00, -2.6849220e+00],
         [ 1.5006512e-02, -1.9866472e-01,  7.6781356e-01,  1.6228745e+00,
           1.4191239e+00,  2.6655579e+00,  4.4663467e+00],
         [ 2.6599693e+00,  1.2663836e+00,  1.7162113e+00,  1.4839669e+00,
           2.0559487e+00,  2.5976877e+00,  2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.483114 , -10.139662 , -11.514159 ],
       [-11.656767 , -17.201958 , -12.132455 ],
       [-17.838818 ,  -9.474525 , -11.24898  ],
       [-13.95219  , -12.490049 , -17.123957 ],
       [-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>

كتحقق إضافي ، سنتحقق من أن احتمالية السجل لنقطة دفعة واحدة تتطابق مع ما كان لدينا من قبل:

(jds_ia.log_prob(shaped_sample)[3, 1] -
 jds_i.log_prob([shaped_sample[0][3, 1],
                 shaped_sample[1][3, 1],
                 shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

الدفعة التلقائية من أجل الفوز

ممتاز! لدينا الآن نسخة من JointDistribution الذي يعالج جميع لدينا الأمنيات: log_prob عائدات بفضل العددية لاستخدام tfd.Independent ، وعينات متعددة تعمل الآن أننا الثابتة البث بإضافة محاور إضافية.

ماذا لو أخبرتك أن هناك طريقة أسهل وأفضل؟ هناك، وانه دعا JointDistributionSequentialAutoBatched (JDSAB):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.191533 , -10.43885  , -16.371655 ],
       [-13.292994 , -11.97949  , -16.788685 ],
       [-15.987699 , -13.435732 , -10.6029   ],
       [-10.184758 , -11.969714 , -14.275676 ],
       [-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)>

كيف يعمل هذا؟ في الوقت الذي يمكن أن محاولة قراءة رمز لفهم عميق، وسوف نقدم لمحة موجزة وهو ما يكفي لمعظم حالات الاستخدام:

  • يذكر أن مشكلتنا الأولى هي أن لدينا توزيع Y زيارتها batch_shape=[7] و event_shape=[] ، وكنا Independent لتحويل البعد دفعة إلى البعد الحدث. يتجاهل JDSAB الأشكال الدفعية لتوزيعات المكونات ؛ بدلا من ذلك فإنه يعامل شكل دفعة كخاصية العامة للنموذج، الذي يفترض أن يكون [] (ما لم ينص على خلاف ذلك عن طريق وضع batch_ndims > 0 ). تأثير ما يعادل باستخدام tfd.Independent لتحويل جميع أبعاد دفعة من توزيعات المكونة في أبعاد الحدث، كما فعلنا يدويا أعلاه.
  • وكان لدينا مشكلة ثانية حاجة لتدليك أشكال m و b حتى يتمكنوا من بث مناسب مع X عند إنشاء عينات متعددة. مع JDSAB، أن تكتب نموذج لتوليد عينة واحدة، ونحن "رفع" نموذج كامل لتوليد نماذج متعددة باستخدام TensorFlow في vectorized_map . (هذه الميزة analagous إلى JAX في vmap ).

استكشاف قضية شكل دفعة في مزيد من التفاصيل، يمكن لنا أن نقارن الأشكال دفعة من لدينا "سيئة" الأصلي التوزيع المشتركة jds ، لدينا ثابتة دفعة توزيعات jds_i و jds_ia ، ولدينا autobatched jds_ab :

jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])

ونحن نرى أن أصلي jds ديها subdistributions مع الأشكال دفعة مختلفة. jds_i و jds_ia إصلاح ذلك عن طريق إنشاء subdistributions مع نفسه (فارغة) شكل دفعة واحدة. jds_ab ليس لديها سوى واحد (فارغ) شكل دفعة واحدة.

ومن الجدير بالذكر أن JointDistributionSequentialAutoBatched يقدم بعض عمومية إضافية مجانا. لنفترض يمكننا أن نجعل من المتغيرات المشاركة X (وضمنيا، الملاحظات Y ) ثنائي الأبعاد:

X = np.arange(14).reshape((2, 7))
X
array([[ 0,  1,  2,  3,  4,  5,  6],
       [ 7,  8,  9, 10, 11, 12, 13]])

لدينا JointDistributionSequentialAutoBatched يعمل مع أي تغييرات (نحن بحاجة إلى إعادة تعريف نموذج لشكل X هو مؤقتا بواسطة jds_ab.log_prob ):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.1813647 , -0.85994506,  0.27593774],
        [-0.73323774,  1.1153806 ,  0.8841938 ],
        [ 0.5127983 , -0.29271227,  0.63733214],
        [ 0.2362284 , -0.919168  ,  1.6648189 ],
        [ 0.26317367,  0.73077047,  2.5395133 ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.09636458,  2.0138032 , -0.5054413 ],
        [ 0.63941646, -1.0785882 , -0.6442188 ],
        [ 1.2310615 , -0.3293852 ,  0.77637213],
        [ 1.2115169 , -0.98906034, -0.07816773],
        [-1.1318136 ,  0.510014  ,  1.036522  ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy=
 array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01,
            8.5992378e-01, -5.3123581e-01,  3.1584005e+00,
            2.9044402e+00],
          [-2.5645006e-01,  3.1554163e-01,  3.1186538e+00,
            1.4272424e+00,  1.2843871e+00,  1.2266440e+00,
            1.2798605e+00]],
 
         [[ 1.5973477e+00, -5.3631151e-01,  6.8143606e-03,
           -1.4910895e+00, -2.1568544e+00, -2.0513713e+00,
           -3.1663666e+00],
          [-4.9448099e+00, -2.8385928e+00, -6.9027486e+00,
           -5.6543546e+00, -7.2378774e+00, -8.1577444e+00,
           -9.3582869e+00]],
 
         [[-2.1233239e+00,  5.8853775e-02,  1.2024102e+00,
            1.6622503e+00, -1.9197327e-01,  1.8647723e+00,
            6.4322817e-01],
          [ 3.7549341e-01,  1.5853541e+00,  2.4594500e+00,
            2.1952972e+00,  1.7517658e+00,  2.9666045e+00,
            2.5468128e+00]]],
 
 
        [[[ 8.9906776e-01,  6.7375046e-01,  7.3354661e-01,
           -9.9894643e-01, -3.4606690e+00, -3.4810467e+00,
           -4.4315586e+00],
          [-3.0670738e+00, -6.3628020e+00, -6.2538433e+00,
           -6.8091092e+00, -7.7134805e+00, -8.6319380e+00,
           -8.6904278e+00]],
 
         [[-2.2462025e+00, -3.3060855e-01,  1.8974400e-01,
            3.1422038e+00,  4.1483402e+00,  3.5642972e+00,
            4.8709240e+00],
          [ 4.7880130e+00,  5.8790064e+00,  9.6695948e+00,
            7.8112822e+00,  1.2022618e+01,  1.2411858e+01,
            1.4323385e+01]],
 
         [[-1.0189297e+00, -7.8115642e-01,  1.6466728e+00,
            8.2378983e-01,  3.0765080e+00,  3.0170646e+00,
            5.1899948e+00],
          [ 6.5285158e+00,  7.8038850e+00,  6.4155884e+00,
            9.0899811e+00,  1.0040427e+01,  9.1404457e+00,
            1.0411951e+01]]],
 
 
        [[[ 4.5557004e-01,  1.4905317e+00,  1.4904103e+00,
            2.9777462e+00,  2.8620450e+00,  3.4745665e+00,
            3.8295493e+00],
          [ 3.9977460e+00,  5.7173767e+00,  7.8421035e+00,
            6.3180594e+00,  6.0838981e+00,  8.2257290e+00,
            9.6548376e+00]],
 
         [[-7.0750320e-01, -3.5972297e-01,  4.3136525e-01,
           -2.3301599e+00, -5.0374687e-01, -2.8338656e+00,
           -3.4453444e+00],
          [-3.1258626e+00, -3.4687450e+00, -1.2045374e+00,
           -4.0196013e+00, -5.8831010e+00, -4.2965469e+00,
           -4.1388311e+00]],
 
         [[ 2.1969774e+00,  2.4614549e+00,  2.2314475e+00,
            1.8392437e+00,  2.8367062e+00,  4.8600502e+00,
            4.2273531e+00],
          [ 6.1879644e+00,  5.1792760e+00,  6.1141996e+00,
            5.6517797e+00,  8.9979610e+00,  7.5938139e+00,
            9.7918644e+00]]],
 
 
        [[[ 1.5249090e+00,  1.1388919e+00,  8.6903995e-01,
            3.0762129e+00,  1.5128503e+00,  3.5204377e+00,
            2.4760864e+00],
          [ 3.4166217e+00,  3.5930209e+00,  3.1694956e+00,
            4.5797420e+00,  4.5271711e+00,  2.8774328e+00,
            4.7288942e+00]],
 
         [[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00,
           -3.8594103e+00, -4.9681158e+00, -6.4256043e+00,
           -5.5345035e+00],
          [-6.4306297e+00, -7.0924540e+00, -8.4075985e+00,
           -1.0417805e+01, -1.1727266e+01, -1.1196255e+01,
           -1.1333830e+01]],
 
         [[-7.0419472e-01,  1.4568675e+00,  3.7946482e+00,
            4.8489718e+00,  6.6498446e+00,  9.0224218e+00,
            1.1153137e+01],
          [ 1.0060651e+01,  1.1998097e+01,  1.5326431e+01,
            1.7957514e+01,  1.8323889e+01,  2.0160881e+01,
            2.1269085e+01]]],
 
 
        [[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01,
            2.3558271e-01, -1.0381399e+00,  1.9387857e+00,
           -3.3694571e-01],
          [ 1.6015106e-01,  1.5284677e+00, -4.8567140e-01,
           -1.7770648e-01,  2.1919653e+00,  1.3015286e+00,
            1.3877077e+00]],
 
         [[ 1.3688663e+00,  2.6602898e+00,  6.6657305e-01,
            4.6554832e+00,  5.7781887e+00,  4.9115267e+00,
            4.8446012e+00],
          [ 5.1983776e+00,  6.2297459e+00,  6.3848300e+00,
            8.4291229e+00,  7.1309576e+00,  1.0395646e+01,
            8.5736713e+00]],
 
         [[ 1.2675294e+00,  5.2844582e+00,  5.1331611e+00,
            8.9993315e+00,  1.0794343e+01,  1.4039831e+01,
            1.5731170e+01],
          [ 1.9084715e+01,  2.2191265e+01,  2.3481146e+01,
            2.5803375e+01,  2.8632090e+01,  3.0234968e+01,
            3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-28.90071 , -23.052422, -19.851362],
       [-19.775568, -25.894997, -20.302256],
       [-21.10754 , -23.667885, -20.973007],
       [-19.249458, -20.87892 , -20.573763],
       [-22.351208, -25.457762, -24.648403]], dtype=float32)>

من ناحية أخرى، لدينا بعناية JointDistributionSequential لم يعد يعمل:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

try:
  jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]

لحل هذه المشكلة، سيكون لدينا لإضافة ثاني tf.newaxis لكلا m و b تتناسب مع الشكل، وزيادة reinterpreted_batch_ndims إلى 2 في استدعاء Independent . في هذه الحالة ، يكون السماح لآلة الخلط التلقائي بالتعامل مع مشكلات الشكل أقصر وأسهل وأكثر راحة.

مرة أخرى، نلاحظ أنه في حين أن هذا الكمبيوتر الدفتري استكشاف JointDistributionSequentialAutoBatched ، والمتغيرات الأخرى من JointDistribution لها ما يعادل AutoBatched . (لمستخدمي JointDistributionCoroutine ، JointDistributionCoroutineAutoBatched له فائدة إضافية التي لم تعد بحاجة لتحديد Root العقد، وإذا كنت تستخدم أبدا JointDistributionCoroutine . يمكنك بشكل آمن تجاهل هذا البيان)

أفكار ختامية

في هذه المفكرة، قدمنا JointDistributionSequentialAutoBatched وعملت من خلال مثال بسيط في التفاصيل. نأمل أن تكون قد تعلمت شيئًا عن أشكال TFP وحول المزج التلقائي!