توزیع استنباط با JAX

مشاهده در TensorFlow.org در Google Colab اجرا شود مشاهده منبع در GitHubدانلود دفترچه یادداشت

TensorFlow Probability (TFP) در JAX اکنون ابزارهایی برای محاسبات عددی توزیع شده دارد. برای مقیاس‌بندی به تعداد زیادی شتاب‌دهنده، ابزارها حول نوشتن کد با استفاده از پارادایم «تک برنامه‌ای چندگانه» یا به اختصار SPMD ساخته شده‌اند.

در این نوت بوک، به نحوه "فکر کردن در SPMD" و معرفی انتزاعات TFP جدید برای مقیاس بندی به پیکربندی هایی مانند TPU pods یا خوشه های GPU خواهیم پرداخت. اگر خودتان این کد را اجرا می کنید، مطمئن شوید که زمان اجرای TPU را انتخاب کنید.

ابتدا آخرین نسخه‌های TFP، JAX و TF را نصب می‌کنیم.

نصب می کند

ما تعدادی کتابخانه عمومی را به همراه برخی از ابزارهای JAX وارد خواهیم کرد.

راه اندازی و واردات

INFO:tensorflow:Enabling eager execution
INFO:tensorflow:Enabling v2 tensorshape
INFO:tensorflow:Enabling resource variables
INFO:tensorflow:Enabling tensor equality
INFO:tensorflow:Enabling control flow v2

ما همچنین چند نام مستعار TFP مفید را تنظیم خواهیم کرد. انتزاعی جدید در حال حاضر در ارائه tfp.experimental.distribute و tfp.experimental.mcmc .

tfd = tfp.distributions
tfb = tfp.bijectors
tfm = tfp.mcmc
tfed = tfp.experimental.distribute
tfde = tfp.experimental.distributions
tfem = tfp.experimental.mcmc

Root = tfed.JointDistributionCoroutine.Root

برای اتصال نوت بوک به TPU از کمکی زیر از JAX استفاده می کنیم. برای تأیید اتصال ما، تعداد دستگاه‌ها را چاپ می‌کنیم که باید هشت عدد باشد.

from jax.tools import colab_tpu
colab_tpu.setup_tpu()
print(f'Found {jax.device_count()} devices')
Found 8 devices

معرفی سریع به jax.pmap

پس از اتصال به یک TPU، ما دسترسی به هشت دستگاه داشته باشد. با این حال، زمانی که کد JAX را مشتاقانه اجرا می کنیم، JAX به طور پیش فرض محاسبات را روی یک مورد اجرا می کند.

ساده ترین راه برای اجرای یک محاسبات در بسیاری از دستگاه ها، ترسیم یک تابع است که هر دستگاه یک شاخص از نقشه را اجرا می کند. JAX فراهم می کند jax.pmap ( "نقشه موازی") تحول که تبدیل یک تابع را به یک است که نقشه های عملکرد را در چندین دستگاه.

در مثال زیر، آرایه‌ای به اندازه 8 ایجاد می‌کنیم (برای مطابقت با تعداد دستگاه‌های موجود) و تابعی را ترسیم می‌کنیم که 5 عدد به آن اضافه می‌کند.

xs = jnp.arange(8.)
out = jax.pmap(lambda x: x + 5.)(xs)
print(type(out), out)
<class 'jax.interpreters.pxla.ShardedDeviceArray'> [ 5.  6.  7.  8.  9. 10. 11. 12.]

توجه داشته باشید که ما دریافت ShardedDeviceArray نوع تماس، نشان می دهد که آرایه خروجی به صورت فیزیکی در سراسر دستگاه تقسیم می شود.

jax.pmap عمل می کند معنایی مانند یک نقشه، اما دارای چند گزینه مهم است که رفتار خود را تغییر دهید. به طور پیش فرض، pmap فرض تمام ورودی به تابع در حال نقشه برداری بیش از، اما ما می توانیم این رفتار را با تغییر in_axes استدلال است.

xs = jnp.arange(8.)
y = 5.
# Map over the 0-axis of `xs` and don't map over `y`
out = jax.pmap(lambda x, y: x + y, in_axes=(0, None))(xs, y)
print(out)
[ 5.  6.  7.  8.  9. 10. 11. 12.]

شبیه از out_axes آرگومان به pmap تعیین اینکه آیا یا نه برای بازگشت به ارزش در هر دستگاه. تنظیم out_axes به None طور خودکار ارزش بر روی دستگاه 1 را برمی گرداند و تنها باید مورد استفاده قرار گیرد اگر ما با اعتماد به نفس ارزش همان در هر دستگاه هستند.

xs = jnp.ones(8) # Value is the same on each device
out = jax.pmap(lambda x: x + 1, out_axes=None)(xs)
print(out)
2.0

چه اتفاقی می‌افتد وقتی کاری که می‌خواهیم انجام دهیم به‌راحتی به‌عنوان یک تابع خالص نگاشت‌شده قابل بیان نباشد؟ به عنوان مثال، اگر بخواهیم در محوری که روی آن نقشه برداری می کنیم، جمع آوری کنیم، چطور؟ JAX "مجموعه ها" را ارائه می دهد، عملکردهایی که در بین دستگاه ها ارتباط برقرار می کنند تا نوشتن برنامه های توزیع شده جالب تر و پیچیده تر را امکان پذیر کند. برای درک اینکه دقیقا چگونه کار می کنند، SPMD را معرفی می کنیم.

SPMD چیست؟

داده های چندگانه تک برنامه ای (SPMD) یک مدل برنامه نویسی همزمان است که در آن یک برنامه واحد (یعنی همان کد) به طور همزمان در دستگاه ها اجرا می شود، اما ورودی های هر یک از برنامه های در حال اجرا می تواند متفاوت باشد.

اگر برنامه ما یک تابع ساده از ورودی های آن (یعنی چیزی شبیه به این است x + 5 )، اجرای یک برنامه در SPMD فقط نقشه برداری آن داده ها در طول های مختلف، مانند ما با انجام jax.pmap پیش از آن. با این حال، ما می‌توانیم کارهایی بیش از یک تابع انجام دهیم. JAX "مجموعه ها" را ارائه می دهد، که عملکردهایی هستند که بین دستگاه ها ارتباط برقرار می کنند.

به عنوان مثال، شاید ما بخواهیم مجموع یک مقدار را در تمام دستگاه های خود در نظر بگیریم. قبل از انجام این کار، ما نیاز به اختصاص یک نام به ما محور نقشه برداری هستید بیش از در pmap . پس از آن ما با استفاده از lax.psum ( "مجموع موازی") تابع برای انجام یک جمع در سراسر دستگاه، حصول اطمینان از ما شناسایی به نام محور ما در حال جمع.

def f(x):
  out = lax.psum(x, axis_name='i')
  return out
xs = jnp.arange(8.) # Length of array matches number of devices
jax.pmap(f, axis_name='i')(xs)
ShardedDeviceArray([28., 28., 28., 28., 28., 28., 28., 28.], dtype=float32)

psum مصالح جمعی ارزش x در هر دستگاه و هماهنگ ارزش خود را در سراسر نقشه یعنی out است 28. در هر دستگاه. ما دیگر یک «نقشه» ساده را انجام نمی‌دهیم، اما در حال اجرای یک برنامه SPMD هستیم که در آن محاسبات هر دستگاه اکنون می‌تواند با همان محاسبات روی دستگاه‌های دیگر تعامل داشته باشد، البته به روشی محدود با استفاده از مجموعه‌ها. در این سناریو، ما می توانیم استفاده out_axes = None ، چون psum خواهد شد که ارزش همگام سازی کنید.

def f(x):
  out = lax.psum(x, axis_name='i')
  return out
jax.pmap(f, axis_name='i', out_axes=None)(jnp.arange(8.))
ShardedDeviceArray(28., dtype=float32)

SPMD ما را قادر می سازد یک برنامه بنویسیم که بر روی هر دستگاهی در هر پیکربندی TPU به طور همزمان اجرا شود. همان کدی که برای انجام یادگیری ماشینی روی 8 هسته TPU استفاده می شود، می تواند روی یک پاد TPU که ممکن است صدها تا هزاران هسته داشته باشد، استفاده شود! برای یک آموزش مفصل تر در مورد jax.pmap و SPMD، شما می توانید به مراجعه JAX 101 آموزش .

MCMC در مقیاس

در این دفترچه، ما بر روی استفاده از روش‌های مونت کارلو زنجیره مارکوف (MCMC) برای استنتاج بیزی تمرکز می‌کنیم. ممکن است راه‌هایی وجود داشته باشد که از دستگاه‌های زیادی برای MCMC استفاده کنیم، اما در این نوت‌بوک، روی دو مورد تمرکز می‌کنیم:

  1. اجرای زنجیره های مارکوف مستقل در دستگاه های مختلف. این مورد نسبتاً ساده است و انجام آن با وانیل TFP امکان پذیر است.
  2. به اشتراک گذاری یک مجموعه داده در بین دستگاه ها. این مورد کمی پیچیده تر است و به ماشین آلات TFP اخیراً اضافه شده نیاز دارد.

زنجیر مستقل

فرض کنید می‌خواهیم با استفاده از MCMC استنتاج بیزی روی یک مشکل انجام دهیم و می‌خواهیم چندین زنجیره را به صورت موازی در چندین دستگاه اجرا کنیم (مثلاً 2 در هر دستگاه). معلوم می‌شود که این برنامه‌ای است که ما می‌توانیم فقط در دستگاه‌ها «نقشه‌برداری» کنیم، یعنی برنامه‌ای که نیازی به جمع ندارد. برای اطمینان از اینکه هر برنامه یک زنجیره مارکوف متفاوت را اجرا می کند (برخلاف اجرای یک زنجیره)، مقدار متفاوتی را برای دانه تصادفی به هر دستگاه ارسال می کنیم.

بیایید آن را روی یک مشکل اسباب بازی نمونه برداری از توزیع گاوسی دو بعدی امتحان کنیم. ما می توانیم از عملکرد موجود MCMC TFP خارج از جعبه استفاده کنیم. به طور کلی، ما سعی می کنیم بیشتر منطق را در عملکرد نقشه برداری خود قرار دهیم تا به وضوح بین آنچه در همه دستگاه ها در حال اجرا است در مقابل دستگاه اول تمایز قائل شویم.

def run(seed):
  target_log_prob = tfd.Sample(tfd.Normal(0., 1.), 2).log_prob

  initial_state = jnp.zeros([2, 2]) # 2 chains
  kernel = tfm.HamiltonianMonteCarlo(target_log_prob, 1e-1, 10)
  def trace_fn(state, pkr):
    return target_log_prob(state)

  states, log_prob = tfm.sample_chain(
    num_results=1000,
    num_burnin_steps=1000,
    kernel=kernel,
    current_state=initial_state,
    trace_fn=trace_fn,
    seed=seed
  )
  return states, log_prob

به خودی خود، run تابع طول می کشد در یک دانه تصادفی بدون تابعیت (تا ببینید که چگونه بدون تابعیت کار اتفاقی، شما می توانید به عنوان خوانده شده TFP در JAX نوت بوک یا دیدن آموزش JAX 101 ). نقشه برداری run بیش از دانه های مختلف در حال اجرا چند زنجیره مارکوف مستقل خواهد شد.

states, log_probs = jax.pmap(run)(random.split(random.PRNGKey(0), 8))
print(states.shape, log_probs.shape)
# states is (8 devices, 1000 samples, 2 chains, 2 dimensions)
# log_prob is (8 devices, 1000 samples, 2 chains)
(8, 1000, 2, 2) (8, 1000, 2)

توجه داشته باشید که چگونه ما اکنون یک محور اضافی مربوط به هر دستگاه داریم. می توانیم ابعاد را دوباره مرتب کنیم و آنها را صاف کنیم تا یک محور برای 16 زنجیره بدست آوریم.

states = states.transpose([0, 2, 1, 3]).reshape([-1, 1000, 2])
log_probs = log_probs.transpose([0, 2, 1]).reshape([-1, 1000])
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].plot(log_probs.T, alpha=0.4)
ax[1].scatter(*states.reshape([-1, 2]).T, alpha=0.1)
plt.show()

png

هنگامی که در حال اجرا زنجیره مستقل در بسیاری از دستگاه، آن را به عنوان آسان به عنوان pmap -ing بیش از یک تابع که با استفاده از tfp.mcmc ، حصول اطمینان از ما عبور مقادیر مختلف برای دانه تصادفی به هر دستگاه.

اشتراک گذاری داده ها

وقتی MCMC را انجام می‌دهیم، توزیع هدف اغلب یک توزیع پسین است که با شرطی کردن یک مجموعه داده به دست می‌آید، و محاسبه چگالی ورود به سیستم غیرعادی شامل جمع کردن احتمالات برای هر داده مشاهده‌شده است.

با مجموعه داده های بسیار بزرگ، حتی اجرای یک زنجیره روی یک دستگاه می تواند بسیار گران باشد. با این حال، وقتی به چندین دستگاه دسترسی داریم، می‌توانیم مجموعه داده را در بین دستگاه‌ها تقسیم کنیم تا از محاسباتی که در دسترس داریم بهتر استفاده کنیم.

اگر ما می خواهم به انجام MCMC با یک مجموعه داده sharded، ما نیاز به اطمینان از unnormalized ورود به سیستم با چگالی ما در هر دستگاه محاسبه نشان دهنده مجموع، یعنی چگالی بیش از همه داده ها، در غیر این صورت هر یک از دستگاه انجام خواهد MCMC با هدف نادرست خود را توزیع برای این منظور، بهره وری کل عوامل در حال حاضر ابزارهای جدید (یعنی tfp.experimental.distribute و tfp.experimental.mcmc ) است که قادر به محاسبه "sharded" احتمال ورود به سیستم و انجام MCMC با آنها.

توزیع های خرد شده

بر TFP انتزاع هسته در حال حاضر برای محاسبه probabiliities ورود به سیستم sharded است فراهم می کند Sharded متا توزیع، که طول می کشد یک توزیع به عنوان ورودی و یک توزیع جدید که دارای خواص خاص زمانی که در یک زمینه SPMD اعدام می گرداند. Sharded زندگی در tfp.experimental.distribute .

به طور مستقیم، یک Sharded مربوط توزیع به مجموعه ای از متغیرهای تصادفی که "تقسیم" در سراسر دستگاه بوده است. در هر دستگاه، آنها نمونه های مختلفی تولید می کنند و می توانند به صورت جداگانه دارای دانسیته ورود به سیستم متفاوتی باشند. روش دیگر، یک Sharded مربوط توزیع به "بشقاب" در مدل اصطلاح گرافیکی، که در آن اندازه صفحه تعداد دستگاه است.

نمونه برداری یک Sharded توزیع

اگر ما از یک نمونه Normal توزیع در یک برنامه در حال pmap اد با استفاده از دانه های مشابه بر روی هر دستگاه، ما را به نمونه های مشابه در هر دستگاه را دریافت کنید. می‌توانیم تابع زیر را به عنوان نمونه‌برداری از یک متغیر تصادفی که در بین دستگاه‌ها همگام‌سازی شده است، در نظر بگیریم.

# `pmap` expects at least one value to be mapped over, so we provide a dummy one
def f(seed, _):
  return tfd.Normal(0., 1.).sample(seed=seed)
jax.pmap(f, in_axes=(None, 0))(random.PRNGKey(0), jnp.arange(8.))
ShardedDeviceArray([-0.20584236, -0.20584236, -0.20584236, -0.20584236,
                    -0.20584236, -0.20584236, -0.20584236, -0.20584236],                   dtype=float32)

اگر ما بسته بندی tfd.Normal(0., 1.) با tfed.Sharded ، ما در حال حاضر منطقی هشت متغیر تصادفی مختلف (یکی در هر دستگاه) و در نتیجه تولید یک نمونه مختلف برای هر یک، با وجود عبور در دانه همان .

def f(seed, _):
  return tfed.Sharded(tfd.Normal(0., 1.), shard_axis_name='i').sample(seed=seed)
jax.pmap(f, in_axes=(None, 0), axis_name='i')(random.PRNGKey(0), jnp.arange(8.))
ShardedDeviceArray([ 1.2152631 ,  0.7818249 ,  0.32549605,  0.6828047 ,
                     1.3973192 , -0.57830244,  0.37862757,  2.7706041 ],                   dtype=float32)

یک نمایش معادل از این توزیع در یک دستگاه تنها 8 نمونه عادی مستقل است. هر چند ارزش از نمونه متفاوت خواهد بود ( tfed.Sharded کند شبه تصادفی تولید اعداد کمی متفاوت)، هر دو آنها نشان دهنده توزیع یکسان.

dist = tfd.Sample(tfd.Normal(0., 1.), jax.device_count())
dist.sample(seed=random.PRNGKey(0))
DeviceArray([ 0.08086783, -0.38624594, -0.3756545 ,  1.668957  ,
             -1.2758069 ,  2.1192007 , -0.85821325,  1.1305912 ],            dtype=float32)

با توجه به ورود چگالی یک Sharded توزیع

بیایید ببینیم چه اتفاقی می‌افتد وقتی که چگالی لگ یک نمونه را از یک توزیع معمولی در یک زمینه SPMD محاسبه می‌کنیم.

def f(seed, _):
  dist = tfd.Normal(0., 1.)
  x = dist.sample(seed=seed)
  return x, dist.log_prob(x)
jax.pmap(f, in_axes=(None, 0))(random.PRNGKey(0), jnp.arange(8.))
(ShardedDeviceArray([-0.20584236, -0.20584236, -0.20584236, -0.20584236,
                     -0.20584236, -0.20584236, -0.20584236, -0.20584236],                   dtype=float32),
 ShardedDeviceArray([-0.94012403, -0.94012403, -0.94012403, -0.94012403,
                     -0.94012403, -0.94012403, -0.94012403, -0.94012403],                   dtype=float32))

هر نمونه در هر دستگاه یکسان است، بنابراین ما چگالی یکسان را در هر دستگاه نیز محاسبه می کنیم. به طور شهودی، در اینجا ما فقط یک توزیع بر روی یک متغیر معمولی توزیع شده داریم.

با Sharded توزیع، ما باید یک توزیع بیش از 8 متغیرهای تصادفی، بنابراین زمانی که ما محاسبه log_prob از یک نمونه، ما خلاصه، در سراسر دستگاه، بیش از هر یک از تراکم ورود به سیستم منحصر به فرد. (شاید متوجه شوید که این مقدار کل log_prob بزرگتر از log_prob singleton محاسبه شده در بالا است.)

def f(seed, _):
  dist = tfed.Sharded(tfd.Normal(0., 1.), shard_axis_name='i')
  x = dist.sample(seed=seed)
  return x, dist.log_prob(x)
sample, log_prob = jax.pmap(f, in_axes=(None, 0), axis_name='i')(
    random.PRNGKey(0), jnp.arange(8.))
print('Sample:', sample)
print('Log Prob:', log_prob)
Sample: [ 1.2152631   0.7818249   0.32549605  0.6828047   1.3973192  -0.57830244
  0.37862757  2.7706041 ]
Log Prob: [-13.7349205 -13.7349205 -13.7349205 -13.7349205 -13.7349205 -13.7349205
 -13.7349205 -13.7349205]

توزیع معادل، "نشقه نشده" همان چگالی ورود را تولید می کند.

dist = tfd.Sample(tfd.Normal(0., 1.), jax.device_count())
dist.log_prob(sample)
DeviceArray(-13.7349205, dtype=float32)

Sharded توزیع مقادیر مختلف تولید از sample در هر دستگاه است، اما همان مقدار برای log_prob در هر دستگاه. اینجا چه خبره؟ Sharded توزیع می کند psum داخلی برای اطمینان از log_prob ارزش ها در هماهنگی در سراسر دستگاه می باشد. چرا این رفتار را می خواهیم؟ اگر ما در حال اجرا با همان سند MCMC در هر دستگاه، ما می خواهم target_log_prob به همان در سراسر هر دستگاه، حتی اگر برخی از متغیرهای تصادفی در محاسبه در سراسر دستگاه sharded.

علاوه بر این، Sharded تضمین می کند توزیع که شیب در سراسر دستگاه درست باشد، به اطمینان حاصل شود که الگوریتم های مانند شورای عالی رسانه ها، که به شیب تابع ورود به سیستم با چگالی به عنوان بخشی از تابع انتقال، تولید نمونه مناسب.

Sharded JointDistribution بازدید کنندگان

ما می توانیم مدل با متعدد ایجاد Sharded متغیرهای تصادفی با استفاده از JointDistribution بازدید کنندگان (JDS). متاسفانه، Sharded توزیع نمی تواند با خیال راحت با وانیل استفاده tfd.JointDistribution ، اما tfp.experimental.distribute صادرات "وصله" JDS که مانند رفتار خواهد شد Sharded توزیع.

def f(seed, _):
  dist = tfed.JointDistributionSequential([
    tfd.Normal(0., 1.),
    tfed.Sharded(tfd.Normal(0., 1.), shard_axis_name='i'),
  ])
  x = dist.sample(seed=seed)
  return x, dist.log_prob(x)
jax.pmap(f, in_axes=(None, 0), axis_name='i')(random.PRNGKey(0), jnp.arange(8.))
([ShardedDeviceArray([1.6121525, 1.6121525, 1.6121525, 1.6121525, 1.6121525,
                      1.6121525, 1.6121525, 1.6121525], dtype=float32),
  ShardedDeviceArray([ 0.8690128 , -0.83167845,  1.2209264 ,  0.88412696,
                       0.76478404, -0.66208494, -0.0129658 ,  0.7391483 ],                   dtype=float32)],
 ShardedDeviceArray([-12.214451, -12.214451, -12.214451, -12.214451,
                     -12.214451, -12.214451, -12.214451, -12.214451],                   dtype=float32))

این JDS sharded می توانید هر دو Sharded توزیع و وانیل بهره وری کل عوامل به عنوان اجزای. برای توزیع‌های تکه‌ای نشده، نمونه‌های مشابهی را در هر دستگاه به‌دست می‌آوریم و برای توزیع‌های خرد شده، نمونه‌های متفاوتی دریافت می‌کنیم. log_prob در هر دستگاه است و همچنین هماهنگ شده است.

MCMC با Sharded توزیع

چگونه ما در مورد فکر می کنم Sharded توزیع در زمینه MCMC؟ اگر ما یک مدل مولد است که می تواند به عنوان یک ابراز کرده اند JointDistribution ، ما می توانیم برخی محور که مدل به "سفال" در سراسر انتخاب کنید. به طور معمول، یک متغیر تصادفی در مدل با داده‌های مشاهده‌شده مطابقت دارد، و اگر مجموعه داده‌ای بزرگ داشته باشیم که می‌خواهیم آن را در دستگاه‌ها تقسیم کنیم، می‌خواهیم متغیرهایی که به نقاط داده مرتبط هستند نیز تقسیم شوند. همچنین ممکن است متغیرهای تصادفی «محلی» داشته باشیم که با مشاهداتی که به اشتراک می‌گذاریم، یک به یک هستند، بنابراین باید آن متغیرهای تصادفی را نیز برش دهیم.

ما بیش از نمونه هایی از استفاده از رفتن Sharded توزیع با بهره وری کل عوامل MCMC در این بخش. ما با یک بیزی مثال رگرسیون لجستیک ساده شروع و نتیجه گیری با یک مثال فاکتور ماتریس، با هدف نشان دادن برخی استفاده از موارد برای distribute کتابخانه.

مثال: رگرسیون لجستیک بیزی برای MNIST

ما می خواهیم رگرسیون لجستیک بیزی را روی یک مجموعه داده بزرگ انجام دهیم. این مدل یک قبل \(p(\theta)\) بیش از وزن های رگرسیون، و احتمال \(p(y_i | \theta, x_i)\) است که بیش از همه داده ها خلاصه \(\{x_i, y_i\}_{i = 1}^N\) برای به دست آوردن کل چگالی ورود به سیستم مشترک. اگر ما داده های ما سفال، ما می خواهم متغیرهای تصادفی مشاهده سفال \(x_i\) و \(y_i\) در مدل ما.

ما از مدل رگرسیون لجستیک بیزی زیر برای طبقه بندی MNIST استفاده می کنیم:

\[ \begin{align*} w &\sim \mathcal{N}(0, 1) \\ b &\sim \mathcal{N}(0, 1) \\ y_i | w, b, x_i &\sim \textrm{Categorical}(w^T x_i + b) \end{align*} \]

بیایید MNIST را با استفاده از مجموعه داده های TensorFlow بارگیری کنیم.

mnist = tfds.as_numpy(tfds.load('mnist', batch_size=-1))
raw_train_images, train_labels = mnist['train']['image'], mnist['train']['label']
train_images = raw_train_images.reshape([raw_train_images.shape[0], -1]) / 255.

raw_test_images, test_labels = mnist['test']['image'], mnist['test']['label']
test_images = raw_test_images.reshape([raw_test_images.shape[0], -1]) / 255.
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...
WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.
HBox(children=(FloatProgress(value=0.0, description='Dl Completed...', max=4.0, style=ProgressStyle(descriptio…
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.

ما 60000 تصویر آموزشی داریم اما بیایید از 8 هسته موجود خود استفاده کنیم و آن را به 8 روش تقسیم کنیم. ما این دستی را استفاده می کنید shard تابع مطلوبیت.

def shard_value(x):
  x = x.reshape((jax.device_count(), -1, *x.shape[1:]))
  return jax.pmap(lambda x: x)(x) # pmap will physically place values on devices

shard = functools.partial(jax.tree_map, shard_value)
sharded_train_images, sharded_train_labels = shard((train_images, train_labels))
print(sharded_train_images.shape, sharded_train_labels.shape)
(8, 7500, 784) (8, 7500)

قبل از ادامه، بیایید به سرعت در مورد دقت TPU و تأثیر آن بر HMC بحث کنیم. TPU ها اجرا ضرب ماتریس با استفاده از کم bfloat16 دقت برای سرعت. bfloat16 ضرب ماتریس اغلب برای بسیاری از برنامه های کاربردی یادگیری عمیق کافی است، اما زمانی که با شورای عالی رسانه ها استفاده می شود، ما به طور تجربی در بر داشت با دقت کمتر می توانید به واگرا مدار، باعث رد شود. ما می‌توانیم از ضرب‌های ماتریس با دقت بالاتر، به قیمت محاسبات اضافی استفاده کنیم.

برای افزایش دقت matmul ما، ما می توانید استفاده jax.default_matmul_precision دکوراتور با "tensorfloat32" دقت (برای دقت و حتی بالاتر ما می تواند استفاده "float32" دقت).

اکنون بیایید ما را تعریف run تابع، که در یک دانه تصادفی خواهد (که همان خواهد شد در هر دستگاه) و تکه های شکسته MNIST. تابع مدل فوق را پیاده سازی می کند و سپس از عملکرد وانیلی MCMC TFP برای اجرای یک زنجیره استفاده می کنیم. ما مطمئن شوید که برای تزئین run با jax.default_matmul_precision دکوراتور مطمئن شوید ضرب ماتریس است که با دقت بالاتر اجرا شود، هر چند در مثال خاص زیر، ما فقط به عنوان به خوبی می تواند با استفاده از jnp.dot(images, w, precision=lax.Precision.HIGH) .

# We can use `out_axes=None` in the `pmap` because the results will be the same
# on every device. 
@functools.partial(jax.pmap, axis_name='data', in_axes=(None, 0), out_axes=None)
@jax.default_matmul_precision('tensorfloat32')
def run(seed, data):
  images, labels = data # a sharded dataset
  num_examples, dim = images.shape
  num_classes = 10

  def model_fn():
    w = yield Root(tfd.Sample(tfd.Normal(0., 1.), [dim, num_classes]))
    b = yield Root(tfd.Sample(tfd.Normal(0., 1.), [num_classes]))
    logits = jnp.dot(images, w) + b
    yield tfed.Sharded(tfd.Independent(tfd.Categorical(logits=logits), 1),
                       shard_axis_name='data')
  model = tfed.JointDistributionCoroutine(model_fn)

  init_seed, sample_seed = random.split(seed)

  initial_state = model.sample(seed=init_seed)[:-1] # throw away `y`

  def target_log_prob(*state):
    return model.log_prob((*state, labels))

  def accuracy(w, b):
    logits = images.dot(w) + b
    preds = logits.argmax(axis=-1)
    # We take the average accuracy across devices by using `lax.pmean`
    return lax.pmean((preds == labels).mean(), 'data')

  kernel = tfm.HamiltonianMonteCarlo(target_log_prob, 1e-2, 100)
  kernel = tfm.DualAveragingStepSizeAdaptation(kernel, 500)
  def trace_fn(state, pkr):
    return (
        target_log_prob(*state),
        accuracy(*state),
        pkr.new_step_size)
  states, trace = tfm.sample_chain(
    num_results=1000,
    num_burnin_steps=1000,
    current_state=initial_state,
    kernel=kernel,
    trace_fn=trace_fn,
    seed=sample_seed
  )
  return states, trace

jax.pmap شامل یک کامپایل JIT اما تابع وارد شده است بعد از اولین تماس ذخیره سازی. ما پاسخ run و چشم پوشی از خروجی به کش تلفیقی.

%%time
output = run(random.PRNGKey(0), (sharded_train_images, sharded_train_labels))
jax.tree_map(lambda x: x.block_until_ready(), output)
CPU times: user 24.5 s, sys: 48.2 s, total: 1min 12s
Wall time: 1min 54s

حال به سراغ پاسخ run دوباره به ببینید که چه مدت طول می کشد اجرای واقعی.

%%time
states, trace = run(random.PRNGKey(0), (sharded_train_images, sharded_train_labels))
jax.tree_map(lambda x: x.block_until_ready(), trace)
CPU times: user 13.1 s, sys: 45.2 s, total: 58.3 s
Wall time: 1min 43s

ما در حال اجرای 200000 مرحله جهشی هستیم که هر یک از آنها یک گرادیان را در کل مجموعه داده محاسبه می کند. تقسیم محاسبات بر روی 8 هسته ما را قادر می سازد معادل 200000 دوره آموزشی را در حدود 95 ثانیه محاسبه کنیم، یعنی حدود 2100 دوره در ثانیه!

بیایید log-density هر نمونه و دقت هر نمونه را رسم کنیم:

fig, ax = plt.subplots(1, 3, figsize=(15, 5))
ax[0].plot(trace[0])
ax[0].set_title('Log Prob')
ax[1].plot(trace[1])
ax[1].set_title('Accuracy')
ax[2].plot(trace[2])
ax[2].set_title('Step Size')
plt.show()

png

اگر نمونه ها را جمع آوری کنیم، می توانیم میانگین مدل بیزی را برای بهبود عملکرد خود محاسبه کنیم.

@functools.partial(jax.pmap, axis_name='data', in_axes=(0, None), out_axes=None)
def bayesian_model_average(data, states):
  images, labels = data
  logits = jax.vmap(lambda w, b: images.dot(w) + b)(*states)
  probs = jax.nn.softmax(logits, axis=-1)
  bma_accuracy = (probs.mean(axis=0).argmax(axis=-1) == labels).mean()
  avg_accuracy = (probs.argmax(axis=-1) == labels).mean()
  return lax.pmean(bma_accuracy, axis_name='data'), lax.pmean(avg_accuracy, axis_name='data')

sharded_test_images, sharded_test_labels = shard((test_images, test_labels))
bma_acc, avg_acc = bayesian_model_average((sharded_test_images, sharded_test_labels), states)
print(f'Average Accuracy: {avg_acc}')
print(f'BMA Accuracy: {bma_acc}')
print(f'Accuracy Improvement: {bma_acc - avg_acc}')
Average Accuracy: 0.9188529253005981
BMA Accuracy: 0.9264000058174133
Accuracy Improvement: 0.0075470805168151855

میانگین مدل بیزی دقت ما را تقریباً 1٪ افزایش می دهد!

مثال: سیستم توصیه MovieLens

اکنون بیایید با مجموعه داده توصیه های MovieLens که مجموعه ای از کاربران و رتبه بندی آنها از فیلم های مختلف است، استنتاج کنیم. به طور خاص، ما می توانیم MovieLens به عنوان یک نماینده \(N \times M\) ماتریس ساعت \(W\) که در آن \(N\) است که تعداد کاربران و \(M\) تعداد فیلم است. ما انتظار داریم \(N > M\). ورودی های \(W_{ij}\) یک بولی که نشان میدهد آیا کاربران \(i\) تماشا فیلم \(j\). توجه داشته باشید که MovieLens رتبه‌بندی‌های کاربران را ارائه می‌کند، اما برای ساده‌تر کردن مشکل، آن‌ها را نادیده می‌گیریم.

ابتدا مجموعه داده را بارگذاری می کنیم. ما از نسخه با 1 میلیون رتبه استفاده خواهیم کرد.

movielens = tfds.as_numpy(tfds.load('movielens/1m-ratings', batch_size=-1))
GENRES = ['Action', 'Adventure', 'Animation', 'Children', 'Comedy',
          'Crime', 'Documentary', 'Drama', 'Fantasy', 'Film-Noir',
          'Horror', 'IMAX', 'Musical', 'Mystery', 'Romance', 'Sci-Fi',
          'Thriller', 'Unknown', 'War', 'Western', '(no genres listed)']
Downloading and preparing dataset movielens/1m-ratings/0.1.0 (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0...
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Shuffling and writing examples to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0.incompleteYKA3TG/movielens-train.tfrecord
HBox(children=(FloatProgress(value=0.0, max=1000209.0), HTML(value='')))
Dataset movielens downloaded and prepared to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0. Subsequent calls will reuse this data.

ما برخی از پیش پردازش از مجموعه داده را انجام دهد به دست آوردن ماتریس ساعت \(W\).

raw_movie_ids = movielens['train']['movie_id']
raw_user_ids = movielens['train']['user_id']
genres = movielens['train']['movie_genres']

movie_ids, movie_labels = pd.factorize(movielens['train']['movie_id'])
user_ids, user_labels = pd.factorize(movielens['train']['user_id'])

num_movies = movie_ids.max() + 1
num_users = user_ids.max() + 1

movie_titles = dict(zip(movielens['train']['movie_id'],
                        movielens['train']['movie_title']))
movie_genres = dict(zip(movielens['train']['movie_id'],
                        genres))
movie_id_to_title = [movie_titles[movie_labels[id]].decode('utf-8')
                     for id in range(num_movies)]
movie_id_to_genre = [GENRES[movie_genres[movie_labels[id]][0]] for id in range(num_movies)]

watch_matrix = np.zeros((num_users, num_movies), bool)
watch_matrix[user_ids, movie_ids] = True
print(watch_matrix.shape)
(6040, 3706)

ما می توانیم یک مدل مولد برای تعریف \(W\)، با استفاده از یک مدل فاکتور ماتریس احتمالی ساده است. فرض کنیم یک نهفته \(N \times D\) ماتریس کاربران \(U\) و نهان \(M \times D\) ماتریس فیلم \(V\)، که هنگامی که ضرب ماتریس برای تولید ساعت مچی logits یک برنولی \(W\). ما همچنین شامل یک بردار تعصب برای کاربران و فیلم، \(u\) و \(v\).

\[ \begin{align*} U &\sim \mathcal{N}(0, 1) \quad u \sim \mathcal{N}(0, 1)\\ V &\sim \mathcal{N}(0, 1) \quad v \sim \mathcal{N}(0, 1)\\ W_{ij} &\sim \textrm{Bernoulli}\left(\sigma\left(\left(UV^T\right)_{ij} + u_i + v_j\right)\right) \end{align*} \]

این یک ماتریس بسیار بزرگ است. 6040 کاربر و 3706 فیلم منجر به ماتریسی با بیش از 22 میلیون ورودی در آن می شود. چگونه به اشتراک گذاری این مدل نزدیک شویم؟ خب، اگر فرض کنیم که \(N > M\) (یعنی کاربران بیش از فیلم وجود دارد)، و سپس آن را حس به سفال ماتریس ساعت در سراسر محور کاربران، بنابراین هر دستگاه را یک تکه از ماتریس ساعت مربوط به یک زیر مجموعه از کاربران . بر خلاف مثال قبلی، با این حال، ما نیز به سفال تا دارند \(U\) ماتریس، از آن است که تعبیه برای هر کاربر، بنابراین هر یک از دستگاه مسئول تکه های شکسته خواهد \(U\) و تکه های شکسته \(W\). از سوی دیگر، \(V\) unsharded خواهد بود و توان در سراسر دستگاه هماهنگ شده است.

sharded_watch_matrix = shard(watch_matrix)

قبل از اینکه ما ارسال ما run ، اجازه دهید به سرعت بحث در مورد چالش های اضافی با sharding محلی متغیر تصادفی \(U\). هنگامی که در حال اجرا HMC، وانیل tfp.mcmc.HamiltonianMonteCarlo هسته نمونه خواهد تکانه برای هر عنصر از دولت زنجیره ای است. پیش از این، تنها متغیرهای تصادفی تقسیم نشده بخشی از آن حالت بودند و لحظه لحظه در هر دستگاه یکسان بود. هنگامی که ما در حال حاضر sharded دارند \(U\)، ما نیاز به نمونه تکانه های مختلف در هر دستگاه برای \(U\)، در حالی که نمونه برداری از تکانه برای \(V\). برای انجام این کار، ما می توانیم با استفاده از tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo با Sharded توزیع حرکت. همانطور که ما به محاسبات موازی درجه یک ادامه می دهیم، ممکن است این را ساده کنیم، به عنوان مثال با گرفتن یک نشانگر تیرگی به هسته HMC.

def make_run(*,
             axis_name,
             dim=20,
             num_chains=2,
             prior_variance=1.,
             step_size=1e-2,
             num_leapfrog_steps=100,
             num_burnin_steps=1000,
             num_results=500,
             ):
  @functools.partial(jax.pmap, in_axes=(None, 0), axis_name=axis_name)
  @jax.default_matmul_precision('tensorfloat32')
  def run(key, watch_matrix):
    num_users, num_movies = watch_matrix.shape

    Sharded = functools.partial(tfed.Sharded, shard_axis_name=axis_name)

    def prior_fn():
      user_embeddings = yield Root(Sharded(tfd.Sample(tfd.Normal(0., 1.), [num_users, dim]), name='user_embeddings'))
      user_bias = yield Root(Sharded(tfd.Sample(tfd.Normal(0., 1.), [num_users]), name='user_bias'))
      movie_embeddings = yield Root(tfd.Sample(tfd.Normal(0., 1.), [num_movies, dim], name='movie_embeddings'))
      movie_bias = yield Root(tfd.Sample(tfd.Normal(0., 1.), [num_movies], name='movie_bias'))
      return (user_embeddings, user_bias, movie_embeddings, movie_bias)
    prior = tfed.JointDistributionCoroutine(prior_fn)

    def model_fn():
      user_embeddings, user_bias, movie_embeddings, movie_bias = yield from prior_fn()
      logits = (jnp.einsum('...nd,...md->...nm', user_embeddings, movie_embeddings)
                + user_bias[..., :, None] + movie_bias[..., None, :])
      yield Sharded(tfd.Independent(tfd.Bernoulli(logits=logits), 2), name='watch')
    model = tfed.JointDistributionCoroutine(model_fn)

    init_key, sample_key = random.split(key)
    initial_state = prior.sample(seed=init_key, sample_shape=num_chains)

    def target_log_prob(*state):
      return model.log_prob((*state, watch_matrix))

    momentum_distribution = tfed.JointDistributionSequential([
      Sharded(tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_users, dim]), 1.), 2)),
      Sharded(tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_users]), 1.), 1)),
      tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_movies, dim]), 1.), 2),
      tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_movies]), 1.), 1),
    ])

    # We pass in momentum_distribution here to ensure that the momenta for 
    # user_embeddings and user_bias are also sharded
    kernel = tfem.PreconditionedHamiltonianMonteCarlo(target_log_prob, step_size,
                                                      num_leapfrog_steps,
                                                      momentum_distribution=momentum_distribution)

    num_adaptation_steps = int(0.8 * num_burnin_steps)
    kernel = tfm.DualAveragingStepSizeAdaptation(kernel, num_adaptation_steps)

    def trace_fn(state, pkr):
      return {
        'log_prob': target_log_prob(*state),
        'log_accept_ratio': pkr.inner_results.log_accept_ratio,
      }
    return tfm.sample_chain(
        num_results, initial_state,
        kernel=kernel,
        num_burnin_steps=num_burnin_steps,
        trace_fn=trace_fn,
        seed=sample_key)
  return run

ما دوباره آن را اجرا کنید یک بار به کش وارد run .

%%time
run = make_run(axis_name='data')
output = run(random.PRNGKey(0), sharded_watch_matrix)
jax.tree_map(lambda x: x.block_until_ready(), output)
CPU times: user 56 s, sys: 1min 24s, total: 2min 20s
Wall time: 3min 35s

اکنون آن را دوباره بدون سربار کامپایل اجرا می کنیم.

%%time
states, trace = run(random.PRNGKey(0), sharded_watch_matrix)
jax.tree_map(lambda x: x.block_until_ready(), trace)
CPU times: user 28.8 s, sys: 1min 16s, total: 1min 44s
Wall time: 3min 1s

به نظر می رسد که ما حدود 150000 مرحله جهشی را در حدود 3 دقیقه تکمیل کرده ایم، بنابراین حدود 83 مرحله جهشی در ثانیه! بیایید نسبت پذیرش و چگالی لگ نمونه های خود را رسم کنیم.

fig, axs = plt.subplots(1, len(trace), figsize=(5 * len(trace), 5))
for ax, (key, val) in zip(axs, trace.items()):
  ax.plot(val[0]) # Indexing into a sharded array, each element is the same
  ax.set_title(key);

png

اکنون که چند نمونه از زنجیره مارکوف خود داریم، بیایید از آنها برای پیش‌بینی استفاده کنیم. ابتدا اجازه دهید هر یک از اجزا را استخراج کنیم. به یاد داشته باشید که user_embeddings و user_bias هستند تقسیم در دستگاه، بنابراین ما نیاز به الحاق ما ShardedArray به همه آنها را به دست آورد. از سوی دیگر، movie_embeddings و movie_bias همان در هر دستگاه است، بنابراین ما فقط می توانید مقدار از سفال اولین انتخاب کنید. ما به طور منظم استفاده numpy برای کپی از مقادیر تماس TPU ها را به پردازنده.

user_embeddings = np.concatenate(np.array(states.user_embeddings, np.float32), axis=2)
user_bias = np.concatenate(np.array(states.user_bias, np.float32), axis=2)
movie_embeddings = np.array(states.movie_embeddings[0], dtype=np.float32)
movie_bias = np.array(states.movie_bias[0], dtype=np.float32)
samples = (user_embeddings, user_bias, movie_embeddings, movie_bias)
print(f'User embeddings: {user_embeddings.shape}')
print(f'User bias: {user_bias.shape}')
print(f'Movie embeddings: {movie_embeddings.shape}')
print(f'Movie bias: {movie_bias.shape}')
User embeddings: (500, 2, 6040, 20)
User bias: (500, 2, 6040)
Movie embeddings: (500, 2, 3706, 20)
Movie bias: (500, 2, 3706)

بیایید سعی کنیم یک سیستم توصیه گر ساده بسازیم که از عدم قطعیت ثبت شده در این نمونه ها استفاده کند. بیایید ابتدا تابعی بنویسیم که فیلم ها را بر اساس احتمال تماشا رتبه بندی می کند.

@jax.jit
def recommend(sample, user_id):
  user_embeddings, user_bias, movie_embeddings, movie_bias = sample
  movie_logits = (
      jnp.einsum('d,md->m', user_embeddings[user_id], movie_embeddings)
      + user_bias[user_id] + movie_bias)
  return movie_logits.argsort()[::-1]

اکنون می‌توانیم تابعی بنویسیم که روی همه نمونه‌ها حلقه بزند و برای هر یک، فیلم رتبه‌بندی شده برتر را که کاربر قبلاً تماشا نکرده است، انتخاب کند. سپس می‌توانیم تعداد فیلم‌های پیشنهادی را در بین نمونه‌ها ببینیم.

def get_recommendations(user_id): 
  movie_ids = []
  already_watched = set(jnp.arange(num_movies)[watch_matrix[user_id] == 1])
  for i in range(500):
    for j in range(2):
      sample = jax.tree_map(lambda x: x[i, j], samples)
      ranking = recommend(sample, user_id)
      for movie_id in ranking:
        if int(movie_id) not in already_watched:
          movie_ids.append(movie_id)
          break
  return movie_ids

def plot_recommendations(movie_ids, ax=None):
  titles = collections.Counter([movie_id_to_title[i] for i in movie_ids])
  ax = ax or plt.gca()
  names, counts = zip(*sorted(titles.items(), key=lambda x: -x[1]))
  ax.bar(names, counts)
  ax.set_xticklabels(names, rotation=90)

بیایید کاربری را که بیشترین فیلم را دیده در مقابل شخصی که کمترین فیلم را دیده است، در نظر بگیریم.

user_watch_counts = watch_matrix.sum(axis=1)
user_most = user_watch_counts.argmax()
user_least = user_watch_counts.argmin()
print(user_watch_counts[user_most], user_watch_counts[user_least])
2314 20

ما امیدواریم که سیستم ما است اطمینان بیشتری در مورد user_most از user_least ، با توجه به اینکه ما باید اطلاعات بیشتری در مورد اینکه چه نوع از فیلم user_most بیشتر احتمال دارد به تماشا.

fig, ax = plt.subplots(1, 2, figsize=(20, 10))
most_recommendations = get_recommendations(user_most)
plot_recommendations(most_recommendations, ax=ax[0])
ax[0].set_title('Recommendation for user_most')
least_recommendations = get_recommendations(user_least)
plot_recommendations(least_recommendations, ax=ax[1])
ax[1].set_title('Recommendation for user_least');

png

ما می بینیم این است که واریانس در توصیه های ما برای وجود دارد user_least منعکس عدم قطعیت اضافی ما در تنظیمات سازمان دیده بان خود.

همچنین می‌توانیم نگاهی به ژانرهای فیلم‌های پیشنهادی داشته باشیم.

most_genres = collections.Counter([movie_id_to_genre[i] for i in most_recommendations])
least_genres = collections.Counter([movie_id_to_genre[i] for i in least_recommendations])
fig, ax = plt.subplots(1, 2, figsize=(20, 10))
ax[0].bar(most_genres.keys(), most_genres.values())
ax[0].set_title('Genres recommended for user_most')
ax[1].bar(least_genres.keys(), least_genres.values())
ax[1].set_title('Genres recommended for user_least');

png

user_most دیده است بسیاری از فیلم ها و ژانرهای است طاقچه مانند رمز و راز و جرم و جنایت توصیه شده است در حالی که user_least است بسیاری از فیلم را تماشا نیست و فیلم بیشتر جریان اصلی، که کمدی و اکشن چوله توصیه شده است.