הצג באתר TensorFlow.org | הפעל בגוגל קולאב | צפה במקור ב-GitHub | הורד מחברת |
מבוא
TensorFlow הסתברות (פריון כולל) מציעה מספר JointDistribution
ההפשטות שהופכות היסק הסתברותי קל בכך שהוא מאפשר למשתמש בקלות להביע מודל גרפי הסתברותית בצורה כמעט מתמטית; ההפשטה מייצרת שיטות לדגימה מהמודל ולהערכת ההסתברות ביומן של דגימות מהמודל. במדריך זה, אנו בודקים "autobatched" גרסאות, אשר פותחו לאחר המקורי JointDistribution
הפשטות. ביחס להפשטות המקוריות, הלא-אוטומטיות, הגרסאות האוטומטיות פשוטות יותר לשימוש וארגונומיות יותר, מה שמאפשר לדגמים רבים לבוא לידי ביטוי עם פחות לוח. בקולאב זה, אנו חוקרים מודל פשוט בפרטי פרטים (אולי מייגעים), מבהירים את הבעיות שהאצט האוטומטי פותר, ו(בתקווה) מלמדים את הקורא יותר על מושגי צורת TFP לאורך הדרך.
לקראת כניסתה של autobatching, היו כמה גרסאות שונות של JointDistribution
, המתאימים לסגנונות תחבירי שונה להבעת מודלים הסתברותיים: JointDistributionSequential
, JointDistributionNamed
, ו JointDistributionCoroutine
. Auobatching קיים בתור Mixin, כך שיש לנו עכשיו AutoBatched
וריאנטים של כל אלה. במדריך זה, אנו לחקור את ההבדלים בין JointDistributionSequential
ו JointDistributionSequentialAutoBatched
; עם זאת, כל מה שאנו עושים כאן ישים על הווריאציות האחרות ללא שינויים למעשה.
תלות ותנאים מוקדמים
ייבוא והגדרות
import functools
import numpy as np
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions
תנאי מוקדם: בעיית רגרסיה בייסיאנית
נשקול תרחיש רגרסיה בייסיאנית פשוטה מאוד:
\[ \begin{align*} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{align*} \]
במודל זה, m
ו b
שאובים הנורמלים סטנדרטי, ואת התצפיות Y
שאובים התפלגות נורמלית אשר מתכוון תלוי במשתנים אקראיים m
ו b
, וחלקם (אקראי, הידוע) למשתנים X
. (למען הפשטות, בדוגמה זו, אנו מניחים שקנה המידה של כל המשתנים האקראיים ידוע).
כדי לבצע היקש במודל זה, היינו צריכים לדעת הן את המשתנים: X
ואת התצפיות Y
, אבל לצורך מדריך זה, נצטרך רק X
, אז אנחנו מגדירים דמה פשוט X
:
X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])
Desiderata
בהסקה הסתברותית, לעתים קרובות אנו רוצים לבצע שתי פעולות בסיסיות:
-
sample
: ציור דגימות מהמודל. -
log_prob
: חישוב הסתברות היומן של מדגם מהמודל.
תרומת המפתח של של הפריון הכולל JointDistribution
פשטות (כמו גם גישות אחרות רבות לתכנות הסתברותית) היא לאפשר למשתמשים לכתוב מודל אחת לגשת לשני sample
ו log_prob
חישובים.
וציין כי יש לנו 7 נקודות בערכת הנתונים שלנו ( X.shape = (7,)
), אנו יכולים כעת לקבוע הדבר המבוקש עבור מצוין JointDistribution
:
-
sample()
צריך לייצר רשימה שלTensors
שיש צורה[(), (), (7,)
], המקביל למדרון סקלר, הטיה סקלר, ותצפיות וקטור, בהתאמה. -
log_prob(sample())
צריך לייצר סקלר: הסתברות היומן של מדרון, הטיה מסוימת, ותצפיות. -
sample([5, 3])
צריכים לייצר רשימה שלTensors
שיש צורה[(5, 3), (5, 3), (5, 3, 7)]
, מייצג(5, 3)
- קבוצה של דגימות המודל. -
log_prob(sample([5, 3]))
צריך לייצרTensor
עם צורה (5, 3).
כעת אנו נבחנו שורה של JointDistribution
מודלים, לראות כיצד להשיג דבר המבוקש לעיל, ואנו מקווים ללמוד קצת יותר על הפריון הכולל צורות לאורך הדרך.
התראה ספוילר: הגישה המספק דבר המבוקש הנ"ל בלא מוכן מראש מוסף autobatching .
נסיון ראשון; JointDistributionSequential
jds = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
זה פחות או יותר תרגום ישיר של המודל לקוד. השיפוע m
והטיית b
הוא פשוט. Y
מוגדר באמצעות lambda
-function: הדפוס הכללי הוא כי lambda
-function של \(k\) ויכוחים בתוך JointDistributionSequential
(JDS) משתמשת הקודם \(k\) הפצות במודל. שימו לב לסדר ה"הפוך".
אנחנו נתקשר sample_distributions
, אשר מחזיר הן מדגם ואת "תת-הפצות" שבבסיס ששימשו להפקת המדגם. (אנחנו יכולים הניבו רק המדגם על ידי התקשרות sample
; בהמשך הדרכה זה יהיה נוח לקבל הפצות כמו גם.) המדגם שאנו מייצרים הוא בסדר:
dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>, <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>, <tf.Tensor: shape=(7,), dtype=float32, numpy= array([ 0.18573815, -1.79962 , -1.8106272 , -3.5971394 , -6.6625295 , -7.308844 , -9.832693 ], dtype=float32)>]
אבל log_prob
מייצר תוצאה עם צורה רצויה:
jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy= array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684, -4.4368567, -4.480562 ], dtype=float32)>
ודגימה מרובה לא עובדת:
try:
jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
בואו ננסה להבין מה לא בסדר.
סקירה קצרה: אצווה וצורת אירוע
בשנת הפריון כוללת, סתם אזרח (לא JointDistribution
) התפלגות הסתברות בעל צורת אירוע צורה אצווה, והבנת ההבדל היא קריטית כדי שימוש יעיל של פריון כולל:
- צורת אירוע מתארת את הצורה של ציור בודד מההפצה; ההגרלה עשויה להיות תלויה במידות שונות. עבור התפלגויות סקלריות, צורת האירוע היא []. עבור MultivariateNormal 5 ממדי, צורת האירוע היא [5].
- צורת אצווה מתארת שרטוטים עצמאיים, לא מבוזרים זהים, הלא היא "אצווה" של הפצות. ייצוג אצווה של הפצות באובייקט Python יחיד היא אחת הדרכים המרכזיות שבהן TFP משיג יעילות בקנה מידה.
לענייננו, עובדה קריטית לזכור היא שאם אנחנו קוראים log_prob
על מדגם יחיד מחלוק, התוצאה תמיד תהיה צורה כי גפרורים (כלומר, יש גם ממדים ימניים ביותר) בצורה תצווה.
לדיון מעמיק יותר של צורות, לראות את המדריך "TensorFlow הפצות צורות הבנה" .
מדוע לא log_prob(sample())
לייצר סקלר?
בואו להשתמש בידע שלנו בכושר אצווה ואירוע לחקור מה קורה עם log_prob(sample())
. הנה שוב המדגם שלנו:
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>, <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>, <tf.Tensor: shape=(7,), dtype=float32, numpy= array([ 0.18573815, -1.79962 , -1.8106272 , -3.5971394 , -6.6625295 , -7.308844 , -9.832693 ], dtype=float32)>]
והנה ההפצות שלנו:
dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>, <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>, <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]
ההסתברות ביומן מחושבת על ידי סיכום הסתברויות היומן של תת-התפלגויות ברכיבים (המתואמים) של החלקים:
log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>, <tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>, <tf.Tensor: shape=(7,), dtype=float32, numpy= array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014, -0.9897899, -1.0334952], dtype=float32)>]
np.sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>
אז, אחד ברמה של הסבר היא כי חישוב הסתברות יומן מחזיר 7-טנסור כי רכיב המשנה השלישי של log_prob_parts
הוא 7-מותח. אבל למה?
ובכן, אנו רואים כי האלמנט האחרון של dists
, אשר תואמת את ההפצה שלנו מעל Y
בניסוח mathematial, יש batch_shape
של [7]
. במילות אחרות, חלוק שלנו מעל Y
היא קבוצה של 7 הנורמלים עצמאיים (עם אמצעים שונים, במקרה זה, באותו הסולם).
כעת אנו מבינים מה לא בסדר: ב JDS, חלוק מעל Y
יש batch_shape=[7]
, מדגם מן JDS מייצג scalars עבור m
ו b
ו "תצווה" של 7 הנורמלים עצמאיים. ו log_prob
מחשב 7 יומן-הסתברויות נפרדת, שכל אחת מהן מייצגת את יומן ההסתברות לשליפת m
ו b
ו תצפית אחת Y[i]
על כמה X[i]
.
תיקון log_prob(sample())
עם Independent
נזכיר כי dists[2]
יש event_shape=[]
ו batch_shape=[7]
:
dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>
באמצעות של הפריון הכולל Independent
metadistribution, אשר ממיר ממדים יצוו לממדי אירוע, אנחנו יכולים להמיר את זה לתוך הפצה עם event_shape=[7]
ו batch_shape=[]
(נצטרך לשנות את זה y_dist_i
בגלל שזה הפצה על Y
, עם _i
העמיד עבור שלנו Independent
הגלישה):
y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>
עכשיו, log_prob
של וקטור 7 הוא סקלר:
y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>
מתחת לשמיכות, Independent
סכומים מעל אצווה:
y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
ואכן, אנו יכולים להשתמש בזה כדי לבנות חדש jds_i
(את i
שוב מייצג Independent
) שבו log_prob
חוזר סקלר:
jds_i = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m*X + b, scale=1.),
reinterpreted_batch_ndims=1)
])
jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>
כמה הערות:
-
jds_i.log_prob(s)
הוא לא אותו הדבר כמוtf.reduce_sum(jds.log_prob(s))
. הראשון מייצר את הסתברות היומן ה"נכונה" של ההתפלגות המשותפת. הסכומים האחרונים מעל 7-מותח, כל אלמנט של אשר הוא סכום ההסתברות יומן שלm
,b
, ו אלמנט בודד של הסתברות יומן שלY
, אז זה overcountsm
וb
. (log_prob(m) + log_prob(b) + log_prob(Y)
מחזירה תוצאה ולא לזרוק חריג כי הפריון הכולל כדלקמן TF ותקנון שידור של numpy;. הוספת סקלר כדי וקטור מייצר תוצאה וקטור בגודל) - במקרה הספציפי הזה, היינו יכולים לפתור את הבעיה והשיג אותה תוצאה באמצעות
MultivariateNormalDiag
במקוםIndependent(Normal(...))
.MultivariateNormalDiag
היא הפצה מוערכת וקטור (כלומר, זה כבר יש וקטור צורת אירוע). IndeeedMultivariateNormalDiag
יכול להיות (אך לא) מיושם בתור הרכבIndependent
וNormal
. כדאי לזכור כי נתון וקטורV
, דגימותn1 = Normal(loc=V)
, וn2 = MultivariateNormalDiag(loc=V)
הם נבדלים; ההבדל beween הפצות אלה היא כיn1.log_prob(n1.sample())
הוא וקטורn2.log_prob(n2.sample())
הוא סקלר.
מספר דוגמאות?
ציור של מספר דוגמאות עדיין לא עובד:
try:
jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
בואו נחשוב למה. כאשר אנו קוראים jds_i.sample([5, 3])
, ראשית ניצור לצייר דגימות m
ו b
, כל אחד עם צורה (5, 3)
. בשלב הבא, אנחנו הולכים לנסות לבנות Normal
הפצה באמצעות:
tfd.Normal(loc=m*X + b, scale=1.)
אבל אם m
יש צורה (5, 3)
ו X
יש צורה 7
, אנחנו לא יכולים להכפיל אותם יחד, ואכן זו היא השגיאה אנו להכות במקרים הבאים:
m = tfd.Normal(0., 1.).sample([5, 3])
try:
m * X
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
כדי לפתור בעיה זו, בואו נחשוב על מה המאפיינים חלוק מעל Y
יש לקיים. אם אנחנו כבר נקרא jds_i.sample([5, 3])
, אז אנחנו יודעים m
ו b
יהיה הן צורה (5, 3)
. מה צורה צריכה קריאה sample
על Y
תוצרת ההפצה? התשובה המתבקשת היא (5, 3, 7)
: עבור כל נקודה אצווה, אנחנו רוצים מדגם עם בגודל זהה X
. אנו יכולים להשיג זאת על ידי שימוש ביכולות השידור של TensorFlow, הוספת ממדים נוספים:
m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])
הוספת ציר לשני m
ו b
, אנו יכולים להגדיר JDS חדש תומך במספר דוגמאות:
jds_ia = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
reinterpreted_batch_ndims=1)
])
shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-1.1133379 , 0.16390413, -0.24177533], [-1.1312429 , -0.6224666 , -1.8182136 ], [-0.31343174, -0.32932565, 0.5164407 ], [-0.0119963 , -0.9079621 , 2.3655841 ], [-0.26293617, 0.8229698 , 0.31098196]], dtype=float32)>, <tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-0.02876974, 1.0872147 , 1.0138507 ], [ 0.27367726, -1.331534 , -0.09084719], [ 1.3349475 , -0.68765205, 1.680652 ], [ 0.75436825, 1.3050154 , -0.9415123 ], [-1.2502679 , -0.25730947, 0.74611956]], dtype=float32)>, <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy= array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00, -4.8197951e+00, -5.2986512e+00, -6.6931367e+00], [ 3.6438566e-01, 1.0067395e+00, 1.4542470e+00, 8.1155670e-01, 1.8868095e+00, 2.3877139e+00, 1.0195159e+00], [-8.3624744e-01, 1.2518480e+00, 1.0943471e+00, 1.3052304e+00, -4.5756745e-01, -1.0668410e-01, -7.0669651e-02]], [[-3.1788960e-01, 9.2615485e-03, -3.0963073e+00, -2.2846246e+00, -3.2269263e+00, -6.0213070e+00, -7.4806519e+00], [-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00, -4.5065498e+00, -5.6719379e+00, -4.8012795e+00], [ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00, -7.1365709e+00, -9.6198196e+00, -9.7951422e+00]], [[ 2.0621397e+00, 3.4639853e-01, 7.0252883e-01, -1.4311566e+00, 3.3790007e+00, 1.1619035e+00, -8.9105040e-01], [-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00, -2.7150445e+00, -2.4633870e+00, -2.1841538e+00], [ 7.7627432e-01, 2.2401071e+00, 3.7601702e+00, 2.4245868e+00, 4.0690269e+00, 4.0605016e+00, 5.1753912e+00]], [[ 1.4275590e+00, 3.3346462e+00, 1.5374103e+00, -2.2849756e-01, 9.1219616e-01, -3.1220305e-01, -3.2643962e-01], [-3.1910419e-02, -3.8848895e-01, 9.9946201e-02, -2.3619974e+00, -1.8507402e+00, -3.6830821e+00, -5.4907336e+00], [-7.1941972e-02, 2.1602919e+00, 4.9575748e+00, 4.2317696e+00, 9.3528280e+00, 1.0526063e+01, 1.5262107e+01]], [[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00, -3.2361765e+00, -3.3434000e+00, -2.6849220e+00], [ 1.5006512e-02, -1.9866472e-01, 7.6781356e-01, 1.6228745e+00, 1.4191239e+00, 2.6655579e+00, 4.4663467e+00], [ 2.6599693e+00, 1.2663836e+00, 1.7162113e+00, 1.4839669e+00, 2.0559487e+00, 2.5976877e+00, 2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-12.483114 , -10.139662 , -11.514159 ], [-11.656767 , -17.201958 , -12.132455 ], [-17.838818 , -9.474525 , -11.24898 ], [-13.95219 , -12.490049 , -17.123957 ], [-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>
כבדיקה נוספת, נוודא שההסתברות ביומן עבור נקודת אצווה בודדת תואמת למה שהיה לנו קודם:
(jds_ia.log_prob(shaped_sample)[3, 1] -
jds_i.log_prob([shaped_sample[0][3, 1],
shaped_sample[1][3, 1],
shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
אצווה אוטומטית בשביל הניצחון
מְעוּלֶה! עכשיו יש לנו גרסה של JointDistribution שמטפל בכל דברים רצויים שלנו: log_prob
מחזיר תודה סקלר לשימוש tfd.Independent
, ודוגמאות רבות לעבוד עכשיו כי אנחנו קבוע לשדר ידי הוספת צירים מיותרים.
מה אם הייתי אומר לך שיש דרך קלה וטובה יותר? יש, וזה נקרא JointDistributionSequentialAutoBatched
(JDSAB):
jds_ab = tfd.JointDistributionSequentialAutoBatched([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-12.191533 , -10.43885 , -16.371655 ], [-13.292994 , -11.97949 , -16.788685 ], [-15.987699 , -13.435732 , -10.6029 ], [-10.184758 , -11.969714 , -14.275676 ], [-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], dtype=float32)>
איך זה עובד? בעוד אתה יכול לנסות לקרוא את הקוד עבור הבנה עמוקה, אנו נספק סקירה קצרה אשר מספקים ברוב מקרי שימוש:
- נזכיר כי הבעיה הראשונה שלנו היתה כי ההפצה שלנו עבור
Y
היוbatch_shape=[7]
וevent_shape=[]
, והיינוIndependent
כדי להמיר את הממד אצווה כדי מימד אירוע. JDSAB מתעלם מצורות האצווה של הפצות רכיבים; במקום זה מתייחס לצורה אצווה כנכס הכוללת של המודל, אשר ההנחה היא להיות[]
(אלא אם צוין אחרת על ידי הגדרתbatch_ndims > 0
). ההשפעה כמוהו כשימוש tfd.Independent להמיר את כל הממדים אצווה של הפצות רכיב לממדים האירוע, כפי שעשינו למעלה באופן ידני. - שני הבעיה שלנו הייתה צורך לעסות את הצורות של
m
וb
כדי שיוכלו לשדר כראוי עםX
בעת יצירת דגימות מרובות. עם JDSAB, אתה כותב מודל ליצירת מדגם יחיד, ואנחנו "להרים" את המודל כולו כדי ליצור דגימות מרובות באמצעות של TensorFlow vectorized_map . (תכונה זו אנלוגית של JAX vmap .)
היכרות עם הנושא בצורה אצווה ביתר פירוט, אנחנו יכולים להשוות את הצורות אצווה של "רע" המקורי שלנו הפצה משותפת jds
, הפצות אצווה קבוע שלנו jds_i
ו jds_ia
, ו autobatched שלנו jds_ab
:
jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])
אנו רואים כי המקורית jds
יש subdistributions עם צורות שונות יצווה. jds_i
ו jds_ia
לתקן זאת על ידי יצירת subdistributions עם אותה צורה אצווה (ריק). jds_ab
יש רק צורה אצווה אחת (ריק).
שווה זה וציין כי JointDistributionSequentialAutoBatched
מציעה הכללה נוספת בחינם. נניח שאנחנו עושים את למשתנים X
(ו, במשתמע, התצפיות Y
) דו-ממדית:
X = np.arange(14).reshape((2, 7))
X
array([[ 0, 1, 2, 3, 4, 5, 6], [ 7, 8, 9, 10, 11, 12, 13]])
שלנו JointDistributionSequentialAutoBatched
עובד ללא שינויים (אנחנו צריכים להגדיר מחדש את המודל בגלל בצורת X
במטמון, על ידי jds_ab.log_prob
):
jds_ab = tfd.JointDistributionSequentialAutoBatched([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[ 0.1813647 , -0.85994506, 0.27593774], [-0.73323774, 1.1153806 , 0.8841938 ], [ 0.5127983 , -0.29271227, 0.63733214], [ 0.2362284 , -0.919168 , 1.6648189 ], [ 0.26317367, 0.73077047, 2.5395133 ]], dtype=float32)>, <tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[ 0.09636458, 2.0138032 , -0.5054413 ], [ 0.63941646, -1.0785882 , -0.6442188 ], [ 1.2310615 , -0.3293852 , 0.77637213], [ 1.2115169 , -0.98906034, -0.07816773], [-1.1318136 , 0.510014 , 1.036522 ]], dtype=float32)>, <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy= array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01, 8.5992378e-01, -5.3123581e-01, 3.1584005e+00, 2.9044402e+00], [-2.5645006e-01, 3.1554163e-01, 3.1186538e+00, 1.4272424e+00, 1.2843871e+00, 1.2266440e+00, 1.2798605e+00]], [[ 1.5973477e+00, -5.3631151e-01, 6.8143606e-03, -1.4910895e+00, -2.1568544e+00, -2.0513713e+00, -3.1663666e+00], [-4.9448099e+00, -2.8385928e+00, -6.9027486e+00, -5.6543546e+00, -7.2378774e+00, -8.1577444e+00, -9.3582869e+00]], [[-2.1233239e+00, 5.8853775e-02, 1.2024102e+00, 1.6622503e+00, -1.9197327e-01, 1.8647723e+00, 6.4322817e-01], [ 3.7549341e-01, 1.5853541e+00, 2.4594500e+00, 2.1952972e+00, 1.7517658e+00, 2.9666045e+00, 2.5468128e+00]]], [[[ 8.9906776e-01, 6.7375046e-01, 7.3354661e-01, -9.9894643e-01, -3.4606690e+00, -3.4810467e+00, -4.4315586e+00], [-3.0670738e+00, -6.3628020e+00, -6.2538433e+00, -6.8091092e+00, -7.7134805e+00, -8.6319380e+00, -8.6904278e+00]], [[-2.2462025e+00, -3.3060855e-01, 1.8974400e-01, 3.1422038e+00, 4.1483402e+00, 3.5642972e+00, 4.8709240e+00], [ 4.7880130e+00, 5.8790064e+00, 9.6695948e+00, 7.8112822e+00, 1.2022618e+01, 1.2411858e+01, 1.4323385e+01]], [[-1.0189297e+00, -7.8115642e-01, 1.6466728e+00, 8.2378983e-01, 3.0765080e+00, 3.0170646e+00, 5.1899948e+00], [ 6.5285158e+00, 7.8038850e+00, 6.4155884e+00, 9.0899811e+00, 1.0040427e+01, 9.1404457e+00, 1.0411951e+01]]], [[[ 4.5557004e-01, 1.4905317e+00, 1.4904103e+00, 2.9777462e+00, 2.8620450e+00, 3.4745665e+00, 3.8295493e+00], [ 3.9977460e+00, 5.7173767e+00, 7.8421035e+00, 6.3180594e+00, 6.0838981e+00, 8.2257290e+00, 9.6548376e+00]], [[-7.0750320e-01, -3.5972297e-01, 4.3136525e-01, -2.3301599e+00, -5.0374687e-01, -2.8338656e+00, -3.4453444e+00], [-3.1258626e+00, -3.4687450e+00, -1.2045374e+00, -4.0196013e+00, -5.8831010e+00, -4.2965469e+00, -4.1388311e+00]], [[ 2.1969774e+00, 2.4614549e+00, 2.2314475e+00, 1.8392437e+00, 2.8367062e+00, 4.8600502e+00, 4.2273531e+00], [ 6.1879644e+00, 5.1792760e+00, 6.1141996e+00, 5.6517797e+00, 8.9979610e+00, 7.5938139e+00, 9.7918644e+00]]], [[[ 1.5249090e+00, 1.1388919e+00, 8.6903995e-01, 3.0762129e+00, 1.5128503e+00, 3.5204377e+00, 2.4760864e+00], [ 3.4166217e+00, 3.5930209e+00, 3.1694956e+00, 4.5797420e+00, 4.5271711e+00, 2.8774328e+00, 4.7288942e+00]], [[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00, -3.8594103e+00, -4.9681158e+00, -6.4256043e+00, -5.5345035e+00], [-6.4306297e+00, -7.0924540e+00, -8.4075985e+00, -1.0417805e+01, -1.1727266e+01, -1.1196255e+01, -1.1333830e+01]], [[-7.0419472e-01, 1.4568675e+00, 3.7946482e+00, 4.8489718e+00, 6.6498446e+00, 9.0224218e+00, 1.1153137e+01], [ 1.0060651e+01, 1.1998097e+01, 1.5326431e+01, 1.7957514e+01, 1.8323889e+01, 2.0160881e+01, 2.1269085e+01]]], [[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01, 2.3558271e-01, -1.0381399e+00, 1.9387857e+00, -3.3694571e-01], [ 1.6015106e-01, 1.5284677e+00, -4.8567140e-01, -1.7770648e-01, 2.1919653e+00, 1.3015286e+00, 1.3877077e+00]], [[ 1.3688663e+00, 2.6602898e+00, 6.6657305e-01, 4.6554832e+00, 5.7781887e+00, 4.9115267e+00, 4.8446012e+00], [ 5.1983776e+00, 6.2297459e+00, 6.3848300e+00, 8.4291229e+00, 7.1309576e+00, 1.0395646e+01, 8.5736713e+00]], [[ 1.2675294e+00, 5.2844582e+00, 5.1331611e+00, 8.9993315e+00, 1.0794343e+01, 1.4039831e+01, 1.5731170e+01], [ 1.9084715e+01, 2.2191265e+01, 2.3481146e+01, 2.5803375e+01, 2.8632090e+01, 3.0234968e+01, 3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-28.90071 , -23.052422, -19.851362], [-19.775568, -25.894997, -20.302256], [-21.10754 , -23.667885, -20.973007], [-19.249458, -20.87892 , -20.573763], [-22.351208, -25.457762, -24.648403]], dtype=float32)>
מצד השני, שלנו מעוצב בקפידה JointDistributionSequential
כבר לא עובד:
jds_ia = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
reinterpreted_batch_ndims=1)
])
try:
jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]
כדי לתקן זאת, היינו צריכים להוסיף שנייה tf.newaxis
לשני m
ו b
להתאים את הצורה, ואת הגידול reinterpreted_batch_ndims
כדי 2 ב הקריאה Independent
. במקרה זה, מתן אפשרות למכונות האצווה האוטומטית לטפל בבעיות הצורה היא קצרה יותר, קלה יותר וארגונומית יותר.
שוב, נציין כי בעוד המחברת הזאת בחנו JointDistributionSequentialAutoBatched
, גרסאות אחרות של JointDistribution
יש מקבילה AutoBatched
. (עבור משתמשי JointDistributionCoroutine
, JointDistributionCoroutineAutoBatched
יתרון נוסף כי אתה כבר לא צריך לציין Root
צמתים; אם מעולם לא השתמשת JointDistributionCoroutine
. אתה רשאי להתעלם הצהרה זו)
מחשבות מסכמות
במחברת זו, הצגנו JointDistributionSequentialAutoBatched
ועבד באמצעות דוגמה פשוטה בפירוט. אני מקווה שלמדת משהו על צורות TFP ועל אצווה אוטומטית!