عرض على TensorFlow.org | تشغيل في Google Colab | عرض المصدر على جيثب | تحميل دفتر |
pip install -q -U jax jaxlib
pip install -q -Uq oryx -I
pip install -q tfp-nightly --upgrade
from functools import partial
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style='white')
import jax
import jax.numpy as jnp
from jax import jit, vmap, grad
from jax import random
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
import oryx
البرمجة الاحتمالية هي فكرة أنه يمكننا التعبير عن النماذج الاحتمالية باستخدام ميزات من لغة البرمجة. ثم يتم تقديم مهام مثل الاستدلال البايزي أو التهميش كميزات لغوية ويمكن أن تكون آلية.
يوفر Oryx نظام برمجة احتمالي يتم فيه التعبير عن البرامج الاحتمالية فقط كوظائف Python ؛ يتم تحويل هذه البرامج بعد ذلك عبر تحويلات دالة قابلة للتكوين مثل تلك الموجودة في JAX! تكمن الفكرة في البدء ببرامج بسيطة (مثل أخذ عينات من عشوائي عادي) وتكوينها معًا لتشكيل نماذج (مثل شبكة بايز العصبونية). نقطة مهمة من تصميم PPL المها في تمكين البرامج لتبدو وكأنها وظائف كنت إرسال بالفعل واستخدامها في JAX، ولكن توضع حواشي لجعل التحولات على علم بها.
لنقم أولاً باستيراد وظائف Oryx الأساسية PPL.
from oryx.core.ppl import random_variable
from oryx.core.ppl import log_prob
from oryx.core.ppl import joint_sample
from oryx.core.ppl import joint_log_prob
from oryx.core.ppl import block
from oryx.core.ppl import intervene
from oryx.core.ppl import conditional
from oryx.core.ppl import graph_replace
from oryx.core.ppl import nest
ما هي البرامج الاحتمالية في أوريكس؟
في Oryx ، تعد البرامج الاحتمالية مجرد وظائف Python نقية تعمل على قيم JAX ومفاتيح شبه عشوائية وتعيد عينة عشوائية. حسب التصميم، فهي متوافقة مع التحولات مثل jit
و vmap
. ومع ذلك، يوفر نظام البرمجة احتمالي المها الأدوات التي تمكنك من تعليم وظائف بطرق مفيدة.
وبعد الفلسفة JAX وظائف نقية، برنامج احتمالي أوريكس هي وظيفة بيثون أن يأخذ JAX PRNGKey
كأول حجتها وأي عدد من الحجج تكييف اللاحقة. ويسمى الناتج من وظيفة "عينة" ونفس القيود التي تنطبق على jit
افتتاحية و vmap
تنطبق ظائف افتتاحية لبرامج الاحتمالية (على سبيل المثال لا تعتمد على تدفق البيانات السيطرة، وليس له آثار جانبية، الخ). هذا يختلف عن العديد من أنظمة البرمجة الاحتمالية الحتمية التي تكون فيها "العينة" هي تتبع التنفيذ بالكامل ، بما في ذلك القيم الداخلية لتنفيذ البرنامج. سنرى لاحقا كيف يمكن الوصول المها القيم الداخلية باستخدام joint_sample
، ناقش أدناه.
Program :: PRNGKey -> ... -> Sample
هنا هو برنامج "مرحبا العالم" ان عينات من توزيع السجل العادي .
def log_normal(key):
return jnp.exp(random_variable(tfd.Normal(0., 1.))(key))
print(log_normal(random.PRNGKey(0)))
sns.distplot(jit(vmap(log_normal))(random.split(random.PRNGKey(0), 10000)))
plt.show()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.) 0.8139614 /home/kbuilder/.local/lib/python3.6/site-packages/seaborn/distributions.py:2551: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms). warnings.warn(msg, FutureWarning)
و log_normal
الوظيفة هي غلاف رقيق حول Tensorflow الاحتمالية (TFP) التوزيع، ولكن بدلا من استدعاء tfd.Normal(0., 1.).sample
، لقد استخدمت random_variable
بدلا من ذلك. كما سنرى لاحقا، random_variable
تمكننا من تحويل الأشياء إلى برامج الاحتمالية، جنبا إلى جنب مع وظيفة أخرى مفيدة.
يمكننا تحويل log_normal
إلى وظيفة سجل الكثافة باستخدام log_prob
التحول:
print(log_prob(log_normal)(1.))
x = jnp.linspace(0., 5., 1000)
plt.plot(x, jnp.exp(vmap(log_prob(log_normal))(x)))
plt.show()
-0.9189385
لأننا قد المشروح وظيفة مع random_variable
، log_prob
يدرك أن هناك دعوة ل tfd.Normal(0., 1.).sample
ويستخدم tfd.Normal(0., 1.).log_prob
لحساب توزيع قاعدة مشكلة السجل. للتعامل مع jnp.exp
، ppl.log_prob
يحسب تلقائيا الكثافة من خلال وظائف bijective، تتبع التغييرات في الحجم في حساب التغير من بين متغير.
في المها، ويمكن أن نأخذ البرامج وتحويلها باستخدام التحولات وظيفة - على سبيل المثال، jax.jit
أو log_prob
. ومع ذلك ، لا تستطيع Oryx القيام بذلك مع أي برنامج ؛ يتطلب وظائف أخذ العينات التي سجلت وظيفة كثافة اللوغاريتمات الخاصة بها مع Oryx. لحسن الحظ، المها يسجل تلقائيا TensorFlow احتمال توزيعات (TFP) في نظامها.
أدوات البرمجة الاحتمالية الخاصة بـ Oryx
لدى Oryx العديد من التحولات الوظيفية الموجهة نحو البرمجة الاحتمالية. سنستعرض معظمها ونقدم بعض الأمثلة. في النهاية ، سنجمع كل ذلك معًا في دراسة حالة MCMC. يمكنك أيضا الرجوع إلى وثائق core.ppl.transformations
لمزيد من التفاصيل.
random_variable
random_variable
فقد اثنين من القطع الرئيسية من وظائف، وتركز على كل من التأشير ظائف بيثون مع المعلومات التي يمكن استخدامها في التحولات.
random_variable
"تعمل بوصفها وظيفة هوية افتراضيا، ولكن يمكن استخدام التسجيلات نوع محددة إلى كائنات تحويل إلى programs.` احتماليلأنواع للاستدعاء (وظائف بيثون، lambdas،
functools.partial
الصورة، الخ) والتعسفيobject
الصورة (مثل JAXDeviceArray
ق) سيكون مجرد العودة مدخلاته.random_variable(x: object) == x random_variable(f: Callable[...]) == f
المها تلقائيا بتسجيل TensorFlow الاحتمالية (TFP) التوزيعات، والتي يتم تحويلها إلى برامج الاحتمالية التي تستدعي توزيع ل
sample
الأسلوب.random_variable(tfd.Normal(0., 1.))(random.PRNGKey(0)) # ==> -0.20584235
يقوم Oryx أيضًا بتضمين معلومات حول توزيع TFP في آثار JAX التي تتيح حساب كثافات السجل تلقائيًا.
random_variable
القيم يمكن العلامة مع أسماء، مما يجعلها مفيدة لتحولات المصب، من خلال توفير اختياريname
الحجة الكلمة لrandom_variable
. عندما كنا تمرير صفيف إلىrandom_variable
جنبا إلى جنب معname
(على سبيل المثالrandom_variable(x, name='x')
)، انها مجرد علامات قيمة وعوائد ذلك. إذا نحن نمر في استدعاء أو توزيع TFP،random_variable
عوائد برنامج أن علامات عينة انتاجها معname
.
هذه التعليقات لا تتغير دلالات البرنامج عند تنفيذه، ولكن فقط عندما تتحول (أي البرنامج سيعود نفس القيمة مع أو بدون استخدام random_variable
).
لنستعرض مثالاً نستخدم فيه كلتا الوظيفتين معًا.
def latent_normal(key):
z_key, x_key = random.split(key)
z = random_variable(tfd.Normal(0., 1.), name='z')(z_key)
return random_variable(tfd.Normal(z, 1e-1), name='x')(x_key)
في هذا البرنامج قمنا المعلمة وسيطة z
و x
، مما يجعل التحولات joint_sample
، intervene
، conditional
و graph_replace
على بينة من الأسماء 'z'
و 'x'
. سنستعرض بالضبط كيف يستخدم كل تحويل الأسماء لاحقًا.
log_prob
و log_prob
تحول وظيفة تحويل برنامج المها احتمالي إلى وظيفة سجل-كثافته. تأخذ دالة كثافة السجل هذه عينة محتملة من البرنامج كمدخلات وترجع كثافة اللوغاريتمات الخاصة بها ضمن توزيع أخذ العينات الأساسي.
log_prob :: Program -> (Sample -> LogDensity)
مثل random_variable
، وأنها تعمل من خلال التسجيل من أنواع حيث سجلت توزيعات TFP تلقائيا، لذلك log_prob(tfd.Normal(0., 1.))
يدعو tfd.Normal(0., 1.).log_prob
. لوظائف بيثون، ومع ذلك، log_prob
يتتبع البرنامج باستخدام JAX، ويتطلع لأخذ عينات البيانات. و log_prob
التحول يعمل على معظم البرامج التي ترجع المتغيرات العشوائية، مباشرة أو عن طريق التحولات للانعكاس ولكن ليس على البرامج التي القيم عينة داخليا التي لم يتم إرجاع. إذا كان لا يمكن عكس العمليات اللازمة في البرنامج، log_prob
سوف رمي خطأ.
وفيما يلي بعض الأمثلة على log_prob
تطبيقها على مختلف البرامج.
-
log_prob
يعمل على البرامج التي عينة مباشرة من توزيعات TFP (أو أنواع أخرى مسجلة) والعودة قيمهم.
def normal(key):
return random_variable(tfd.Normal(0., 1.))(key)
print(log_prob(normal)(0.))
-0.9189385
-
log_prob
غير قادرة على حساب سجل-كثافة عينات من البرامج التي تحول variates عشوائية باستخدام وظائف bijective (على سبيل المثالjnp.exp
،jnp.tanh
،jnp.split
).
def log_normal(key):
return 2 * jnp.exp(random_variable(tfd.Normal(0., 1.))(key))
print(log_prob(log_normal)(1.))
-1.159165
من أجل حساب عينة من log_normal
الصورة سجل الكثافة، نحتاج أولا إلى عكس exp
، مع الأخذ في log
من العينة، ثم قم بإضافة تصحيح حجم التغيير باستخدام السجل-ديت معكوس مصفوفه جاكوبي من exp
(انظر التغيير من متغير الصيغة من ويكيبيديا).
-
log_prob
يعمل مع البرامج التي هياكل الانتاج من عينات مثل والقواميس بيثون أو الصفوف.
def normal_2d(key):
x = random_variable(
tfd.MultivariateNormalDiag(jnp.zeros(2), jnp.ones(2)))(key)
x1, x2 = jnp.split(x, 2, 0)
return dict(x1=x1, x2=x2)
sample = normal_2d(random.PRNGKey(0))
print(sample)
print(log_prob(normal_2d)(sample))
{'x1': DeviceArray([-0.7847661], dtype=float32), 'x2': DeviceArray([0.8564447], dtype=float32)} -2.5125546
-
log_prob
يمشي الرسم البياني حساب تتبع وظيفة، حساب كل من القيم إلى الأمام والعكسية (وسجل-ديت لهم Jacobians) عند الضرورة في محاولة لتوصيل القيم التي يتم إرجاعها مع قاعدة قيمهم عينات عن طريق تغيير محددة جيدا من المتغيرات. خذ المثال التالي من البرنامج:
def complex_program(key):
k1, k2 = random.split(key)
z = random_variable(tfd.Normal(0., 1.))(k1)
x = random_variable(tfd.Normal(jax.nn.relu(z), 1.))(k2)
return jnp.exp(z), jax.nn.sigmoid(x)
sample = complex_program(random.PRNGKey(0))
print(sample)
print(log_prob(complex_program)(sample))
(DeviceArray(1.1547576, dtype=float32), DeviceArray(0.24830955, dtype=float32)) -1.0967848
في هذا البرنامج، ونحن عينة x
مشروط على z
، وهذا يعني أننا في حاجة إلى قيمة z
قبل أن نتمكن من حساب سجل الكثافة من x
. ومع ذلك، من أجل حساب z
، علينا أولا أن عكس jnp.exp
تطبيقها على z
. وهكذا، وذلك لحساب سجل-كثافة x
و z
، log_prob
احتياجات لأول قلب إخراج الأولى، ومن ثم تمريرها إلى الأمام من خلال jax.nn.relu
لحساب متوسط p(x | z)
.
لمزيد من المعلومات حول log_prob
، يمكنك الرجوع إلى core.interpreters.log_prob
. في التنفيذ، log_prob
يستند بشكل وثيق الخروج من inverse
تحول JAX. لمعرفة المزيد حول inverse
، انظر core.interpreters.inverse
.
joint_sample
لتحديد برامج أكثر تعقيدًا وإثارة للاهتمام ، سنستخدم بعض المتغيرات العشوائية الكامنة ، مثل المتغيرات العشوائية بقيم غير ملحوظة. دعونا تشير إلى latent_normal
البرنامج الذي عينات قيمة عشوائي z
أن يتم استخدام متوسط قيمة أخرى عشوائي x
.
def latent_normal(key):
z_key, x_key = random.split(key)
z = random_variable(tfd.Normal(0., 1.), name='z')(z_key)
return random_variable(tfd.Normal(z, 1e-1), name='x')(x_key)
في هذا البرنامج، z
هو ذلك الكامن إذا كان لنا أن مجرد دعوة latent_normal(random.PRNGKey(0))
ونحن لن نعرف قيمة الفعلية لل z
هي المسؤولة عن توليد x
.
joint_sample
هو التحول الذي يحول البرنامج إلى برنامج آخر أن عائدات القاموس أسماء سلسلة رسم الخرائط (العلامات) إلى قيمها. من أجل العمل ، نحتاج إلى التأكد من وضع علامة على المتغيرات الكامنة للتأكد من ظهورها في إخراج الوظيفة المحولة.
joint_sample(latent_normal)(random.PRNGKey(0))
{'x': DeviceArray(0.01873656, dtype=float32), 'z': DeviceArray(0.14389044, dtype=float32)}
علما بأن joint_sample
التحويلات برنامج إلى برنامج آخر أن عينات توزيع مشترك على القيم الكامنة، حتى نتمكن من مواصلة تحويلها. بالنسبة للخوارزميات مثل MCMC و VI ، من الشائع حساب احتمالية سجل التوزيع المشترك كجزء من إجراء الاستدلال. log_prob(latent_normal)
لا يعمل لأنه يتطلب تهميش من z
، ولكن يمكننا استخدام log_prob(joint_sample(latent_normal))
.
print(log_prob(joint_sample(latent_normal))(dict(x=0., z=1.)))
print(log_prob(joint_sample(latent_normal))(dict(x=0., z=-10.)))
-50.03529 -5049.535
لأن هذا هو مثل هذا النمط الشائع، وقد أوريكس أيضا joint_log_prob
التحول الذي هو مجرد تكوين log_prob
و joint_sample
.
print(joint_log_prob(latent_normal)(dict(x=0., z=1.)))
print(joint_log_prob(latent_normal)(dict(x=0., z=-10.)))
-50.03529 -5049.535
block
و block
يأخذ التحول في برنامج وسلسلة من أسماء وإرجاع البرنامج الذي يتصرف مماثل إلا أنه في التحولات المصب (مثل joint_sample
)، يتم تجاهل الأسماء المقدمة. مثال حيث block
هو مفيد وتحويل وتوزيع مشترك إلى قبل أكثر من المتغيرات الكامنة التي كتبها "عرقلة" القيم عينات في احتمال. على سبيل المثال، واتخاذ latent_normal
، الذي يستمد أولا z ~ N(0, 1)
بعد ذلك x | z ~ N(z, 1e-1)
. block(latent_normal, names=['x'])
هو برنامج يخفي x
اسم، لذلك اذا لم نفعل joint_sample(block(latent_normal, names=['x']))
، نحصل على قاموس مع عادل z
فيه .
blocked = block(latent_normal, names=['x'])
joint_sample(blocked)(random.PRNGKey(0))
{'z': DeviceArray(0.14389044, dtype=float32)}
intervene
و intervene
clobbers التحول العينات في برنامج احتمالي مع القيم من الخارج. وبالعودة الى لدينا latent_normal
البرنامج، دعونا تقول ونحن مهتمون في تشغيل نفس البرنامج ولكن أراد z
لتكون ثابتة إلى 4. وبدلا من كتابة برنامج جديد، يمكننا استخدام intervene
لتجاوز قيمة z
.
intervened = intervene(latent_normal, z=4.)
sns.distplot(vmap(intervened)(random.split(random.PRNGKey(0), 10000)))
plt.show();
/home/kbuilder/.local/lib/python3.6/site-packages/seaborn/distributions.py:2551: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms). warnings.warn(msg, FutureWarning)
و intervened
عينات وظيفة من p(x | do(z = 4))
الذي هو مجرد التوزيع الطبيعي القياسي تركزت في 4. ونحن عندما intervene
على قيمة معينة، لم يعد يعتبر أن قيمة المتغير العشوائي. يعني ذلك أن z
لن الموسومة قيمة أثناء تنفيذ intervened
.
conditional
conditional
التحويلات البرنامج الذي عينات الكامنة القيم في واحدة أن الظروف على تلك القيم الكامنة. العودة إلى موقعنا latent_normal
البرنامج الذي عينات p(x)
مع الكامنة z
، يمكننا تحويله إلى برنامج مشروط p(x | z)
.
cond_program = conditional(latent_normal, 'z')
print(cond_program(random.PRNGKey(0), 100.))
print(cond_program(random.PRNGKey(0), 50.))
sns.distplot(vmap(lambda key: cond_program(key, 1.))(random.split(random.PRNGKey(0), 10000)))
sns.distplot(vmap(lambda key: cond_program(key, 2.))(random.split(random.PRNGKey(0), 10000)))
plt.show()
99.87485 49.874847 /home/kbuilder/.local/lib/python3.6/site-packages/seaborn/distributions.py:2551: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms). warnings.warn(msg, FutureWarning) /home/kbuilder/.local/lib/python3.6/site-packages/seaborn/distributions.py:2551: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms). warnings.warn(msg, FutureWarning)
nest
عندما نبدأ في إنشاء برامج احتمالية لبناء برامج أكثر تعقيدًا ، فمن الشائع إعادة استخدام الوظائف التي لها بعض المنطق المهم. على سبيل المثال، إذا كنا نرغب في بناء الشبكة العصبية النظرية الافتراضية، قد يكون هناك المهم dense
البرنامج الذي عينات الأوزان وتنفيذ تمريرة إلى الأمام.
إذا كان لنا أن إعادة وظائف، ومع ذلك، فإننا قد ينتهي مع القيم الموسومة مكررة في البرنامج النهائي، الذي ألغي من قبل التحولات مثل joint_sample
. يمكننا استخدام nest
لإنشاء سمة "نطاقات"، حيث سيتم إدراج أي عينات من داخل نطاق اسمه في القاموس متداخلة.
def f(key):
return random_variable(tfd.Normal(0., 1.), name='x')(key)
def g(key):
k1, k2 = random.split(key)
return nest(f, scope='x1')(k1) + nest(f, scope='x2')(k2)
joint_sample(g)(random.PRNGKey(0))
{'x1': {'x': DeviceArray(0.14389044, dtype=float32)}, 'x2': {'x': DeviceArray(-1.2515389, dtype=float32)} }
دراسة حالة: شبكة بايزي العصبية
دعونا نحاول يدنا في تدريب الشبكة العصبية النظرية الافتراضية لتصنيف الكلاسيكية فيشر ايريس البيانات. إنه صغير نسبيًا ومنخفض الأبعاد ، لذا يمكننا محاولة أخذ عينات مباشرة مع MCMC.
أولاً ، دعنا نستورد مجموعة البيانات وبعض الأدوات المساعدة الإضافية من Oryx.
from sklearn import datasets
iris = datasets.load_iris()
features, labels = iris['data'], iris['target']
num_features = features.shape[-1]
num_classes = len(iris.target_names)
from oryx.experimental import mcmc
from oryx.util import summary, get_summaries
نبدأ بتنفيذ طبقة كثيفة ، والتي سيكون لها مقدمات عادية على الأوزان والتحيز. للقيام بذلك، علينا أولا تحديد dense
أعلى وظيفة من أجل أن يأخذ في البعد الانتاج وتفعيل الوظيفة المطلوبة. و dense
ترجع برنامج الاحتمالي الذي يمثل توزيع مشروط p(h | x)
حيث h
هو إخراج طبقة كثيفة و x
هو مدخلاته. هذه أول عينات الوزن والتحيز ومن ثم تنطبق عليهم x
.
def dense(dim_out, activation=jax.nn.relu):
def forward(key, x):
dim_in = x.shape[-1]
w_key, b_key = random.split(key)
w = random_variable(
tfd.Sample(tfd.Normal(0., 1.), sample_shape=(dim_out, dim_in)),
name='w')(w_key)
b = random_variable(
tfd.Sample(tfd.Normal(0., 1.), sample_shape=(dim_out,)),
name='b')(b_key)
return activation(jnp.dot(w, x) + b)
return forward
أن يؤلف عدة dense
طبقات معا، ونحن سوف تنفيذ mlp
(متعدد الطبقات المستقبلات) أعلى وظيفة النظام الذي يأخذ في قائمة أحجام الخفية وعدد من الطبقات. تقوم بإرجاع البرنامج الذي يدعو مرارا وتكرارا dense
استخدام المناسب hidden_size
وأخيرا يعود logits لكل فئة في الطبقة النهائية. لاحظ استخدام nest
الذي يخلق نطاقات اسم كل طبقة.
def mlp(hidden_sizes, num_classes):
num_hidden = len(hidden_sizes)
def forward(key, x):
keys = random.split(key, num_hidden + 1)
for i, (subkey, hidden_size) in enumerate(zip(keys[:-1], hidden_sizes)):
x = nest(dense(hidden_size), scope=f'layer_{i + 1}')(subkey, x)
logits = nest(dense(num_classes, activation=lambda x: x),
scope=f'layer_{num_hidden + 1}')(keys[-1], x)
return logits
return forward
لتنفيذ النموذج الكامل ، سنحتاج إلى نمذجة التسميات كمتغيرات عشوائية فئوية. سنقوم بتعريف predict
وظيفة التي تأخذ في مجموعة بيانات من xs
(ملامح) التي مرت بعد ذلك إلى mlp
باستخدام vmap
. عندما نستخدم vmap(partial(mlp, mlp_key))
، ونحن تذوق مجموعة واحدة من الأوزان، ولكن الخريطة الممر إلى الأمام على جميع مدخلات xs
. هذا وتنتج مجموعة من logits
التي parameterizes توزيعات القاطع مستقلة.
def predict(mlp):
def forward(key, xs):
mlp_key, label_key = random.split(key)
logits = vmap(partial(mlp, mlp_key))(xs)
return random_variable(
tfd.Independent(tfd.Categorical(logits=logits), 1), name='y')(label_key)
return forward
هذا هو النموذج الكامل! دعنا نستخدم MCMC لأخذ عينات لاحقة من أوزان BNN المعطاة للبيانات ؛ أولا نحن بناء BNN "قالب" باستخدام mlp
.
bnn = mlp([200, 200], num_classes)
لإنشاء نقطة الانطلاق لدينا سلسلة ماركوف، يمكننا استخدام joint_sample
مع المدخلات وهمية.
weights = joint_sample(bnn)(random.PRNGKey(0), jnp.ones(num_features))
print(weights.keys())
dict_keys(['layer_1', 'layer_2', 'layer_3'])
يعد حساب احتمال سجل التوزيع المشترك كافياً للعديد من خوارزميات الاستدلال. دعونا يقول الآن نلاحظ x
وتريد لأخذ عينات من الخلفي p(z | x)
. لتوزيعات معقدة، ونحن لن تكون قادرة على تهميش من x
(على الرغم من ل latent_normal
نستطيع) ولكن يمكننا حساب لunnormalized كثافة سجل log p(z, x)
حيث x
هو ثابت لقيمة معينة. يمكننا استخدام احتمالية السجل غير الطبيعي مع MCMC لأخذ عينات لاحقة. دعونا نكتب هذه الوظيفة "المثبتة" السجل.
def target_log_prob(weights):
return joint_log_prob(predict(bnn))(dict(weights, y=labels), features)
الآن يمكننا استخدام tfp.mcmc
لأخذ عينات من الخلفي باستخدام لدينا unnormalized دالة الكثافة السجل. علما بأن علينا أن استخدام "المسطحة" نسخة من الأوزان المتداخلة لدينا قاموس لتكون متوافقة مع tfp.mcmc
، لذلك نستخدم المرافق شجرة JAX للشد وunflatten.
@jit
def run_chain(key, weights):
flat_state, sample_tree = jax.tree_flatten(weights)
def flat_log_prob(*states):
return target_log_prob(jax.tree_unflatten(sample_tree, states))
def trace_fn(_, results):
return results.inner_results.accepted_results.target_log_prob
flat_states, log_probs = tfp.mcmc.sample_chain(
1000,
num_burnin_steps=9000,
kernel=tfp.mcmc.DualAveragingStepSizeAdaptation(
tfp.mcmc.HamiltonianMonteCarlo(flat_log_prob, 1e-3, 100),
9000, target_accept_prob=0.7),
trace_fn=trace_fn,
current_state=flat_state,
seed=key)
samples = jax.tree_unflatten(sample_tree, flat_states)
return samples, log_probs
posterior_weights, log_probs = run_chain(random.PRNGKey(0), weights)
plt.plot(log_probs)
plt.show()
يمكننا استخدام عيناتنا لأخذ تقدير متوسط نموذج بايزي (BMA) لدقة التدريب. لحساب ذلك، يمكننا استخدام intervene
مع bnn
إلى "حقن" الخلفي الأوزان بدلا من تلك التي يتم أخذ عينات من المفتاح. لحساب logits لكل نقطة بيانات لكل عينة الخلفية، يمكننا مضاعفة vmap
على posterior_weights
و features
.
output_logits = vmap(lambda weights: vmap(lambda x: intervene(bnn, **weights)(
random.PRNGKey(0), x))(features))(posterior_weights)
output_probs = jax.nn.softmax(output_logits)
print('Average sample accuracy:', (
output_probs.argmax(axis=-1) == labels[None]).mean())
print('BMA accuracy:', (
output_probs.mean(axis=0).argmax(axis=-1) == labels[None]).mean())
Average sample accuracy: 0.9874067 BMA accuracy: 0.99333334
استنتاج
في Oryx ، تعد البرامج الاحتمالية مجرد وظائف JAX التي تأخذ العشوائية (الزائفة) كمدخل. بسبب تكامل Oryx الضيق مع نظام تحويل الوظائف في JAX ، يمكننا كتابة ومعالجة البرامج الاحتمالية كما نكتب كود JAX. ينتج عن هذا نظام بسيط ولكنه مرن لبناء نماذج معقدة والقيام بالاستدلال.