JAX ile Dağıtılmış Çıkarım

TensorFlow.org'da görüntüleyin Google Colab'da çalıştırın Kaynağı GitHub'da görüntüleyinNot defterini indir

JAX üzerindeki TensorFlow Probability (TFP) artık dağıtılmış sayısal hesaplama için araçlara sahiptir. Çok sayıda hızlandırıcıya ölçeklendirmek için araçlar, "tek programlı çoklu veri" paradigmasını veya kısaca SPMD'yi kullanarak kod yazma etrafında oluşturulmuştur.

Bu not defterinde, "SPMD'de nasıl düşünüleceğini" gözden geçireceğiz ve TPU bölmeleri veya GPU kümeleri gibi yapılandırmalara ölçeklendirme için yeni TFP soyutlamalarını tanıtacağız. Bu kodu kendiniz çalıştırıyorsanız, bir TPU çalışma zamanı seçtiğinizden emin olun.

Önce en son TFP, JAX ve TF sürümlerini yükleyeceğiz.

Yüklemeler

Bazı JAX yardımcı programları ile birlikte bazı genel kitaplıkları içe aktaracağız.

Kurulum ve İçe Aktarmalar

INFO:tensorflow:Enabling eager execution
INFO:tensorflow:Enabling v2 tensorshape
INFO:tensorflow:Enabling resource variables
INFO:tensorflow:Enabling tensor equality
INFO:tensorflow:Enabling control flow v2

Ayrıca bazı kullanışlı TFP takma adları ayarlayacağız. Yeni soyutlamalar şu anda verilmektedir tfp.experimental.distribute ve tfp.experimental.mcmc .

tfd = tfp.distributions
tfb = tfp.bijectors
tfm = tfp.mcmc
tfed = tfp.experimental.distribute
tfde = tfp.experimental.distributions
tfem = tfp.experimental.mcmc

Root = tfed.JointDistributionCoroutine.Root

Notebook'u bir TPU'ya bağlamak için JAX'ın aşağıdaki yardımcısını kullanıyoruz. Bağlandığımızı doğrulamak için, sekiz olması gereken cihaz sayısını yazdırıyoruz.

from jax.tools import colab_tpu
colab_tpu.setup_tpu()
print(f'Found {jax.device_count()} devices')
Found 8 devices

Hızlı bir giriş jax.pmap

TPU bağlandıktan sonra, sekiz cihazlara erişimi vardır. Ancak, JAX kodunu hevesle çalıştırdığımızda, JAX varsayılan olarak yalnızca bir tanesinde hesaplamaları çalıştırmaya başlar.

Birçok cihaz arasında bir hesaplama yürütmenin en basit yolu, her cihazın haritanın bir dizinini yürütmesini sağlayarak bir işlevi eşlemektir. JAX sağlar jax.pmap çeşitli cihazlar üzerinde işlevini eşleyen birine bir fonksiyon döner ( "paralel haritası") dönüşümü.

Aşağıdaki örnekte, 8 boyutunda bir dizi oluşturuyoruz (mevcut cihazların sayısıyla eşleşmesi için) ve buna 5 ekleyen bir işlevi eşliyoruz.

xs = jnp.arange(8.)
out = jax.pmap(lambda x: x + 5.)(xs)
print(type(out), out)
<class 'jax.interpreters.pxla.ShardedDeviceArray'> [ 5.  6.  7.  8.  9. 10. 11. 12.]

Not biz almalarını ShardedDeviceArray çıktı dizisi fiziksel cihazlar arasında bölünmüş belirten tip geri.

jax.pmap semantik bir harita gibi davranır, ama onun davranışını değiştirmek birkaç önemli seçenek vardır. Varsayılan olarak, pmap işlevine tüm girişler üzerinde haritası çıkarılan varsayar, ancak bu davranışı değiştirebilirsiniz in_axes argüman.

xs = jnp.arange(8.)
y = 5.
# Map over the 0-axis of `xs` and don't map over `y`
out = jax.pmap(lambda x, y: x + y, in_axes=(0, None))(xs, y)
print(out)
[ 5.  6.  7.  8.  9. 10. 11. 12.]

Benzer şekilde, out_axes argüman pmap her cihazda değerler döndürmek için olup olmadığını belirler. Ayar out_axes için None otomatik 1 cihazda değer döndürür ve biz değerleri her cihazda aynı eminiz yalnızca kullanılmalıdır.

xs = jnp.ones(8) # Value is the same on each device
out = jax.pmap(lambda x: x + 1, out_axes=None)(xs)
print(out)
2.0

Yapmak istediğimiz şey, eşlenmiş bir saf işlev olarak kolayca ifade edilemezse ne olur? Örneğin, haritasını çıkardığımız eksen boyunca bir toplam yapmak istersek ne olur? JAX, daha ilginç ve karmaşık dağıtılmış programların yazılmasını sağlamak için cihazlar arasında iletişim kuran işlevler olan "kolektifler" sunar. Tam olarak nasıl çalıştıklarını anlamak için SPMD'yi tanıtacağız.

SPMD nedir?

Tek programlı çoklu veri (SPMD), cihazlar arasında aynı anda tek bir programın (yani aynı kodun) yürütüldüğü, ancak çalışan programların her birinin girdilerinin farklı olabileceği eşzamanlı bir programlama modelidir.

Program girişlerinden basit fonksiyonu (yani gibi bir şey ise x + 5 ) örneğinde yaptığımız gibi, SPMD bir program çalıştırarak sadece o farklı veri haritasını çıkarıyor jax.pmap erken. Ancak, bir işlevi "haritalamaktan" fazlasını yapabiliriz. JAX, cihazlar arasında iletişim kuran işlevler olan "kolektifler" sunar.

Örneğin, tüm cihazlarımızda bir miktarın toplamını almak isteyebiliriz. Bunu yapmadan önce, biz üzerinde Bizler eşleme ekseni bir ad atamanız gerekir pmap . Sonra kullanmak lax.psum biz üzerinde toplanmasıyla ediyoruz eksen biz adlandırılmış tespit sağlanması, cihazlar arasında bir miktar gerçekleştirmek için ( "paralel toplamı") fonksiyonu.

def f(x):
  out = lax.psum(x, axis_name='i')
  return out
xs = jnp.arange(8.) # Length of array matches number of devices
jax.pmap(f, axis_name='i')(xs)
ShardedDeviceArray([28., 28., 28., 28., 28., 28., 28., 28.], dtype=float32)

psum toplu agrega değerini x her aygıt ve harita üzerinde değeri yani senkronize out olduğu 28. Her cihazda. Artık basit bir "harita" gerçekleştirmiyoruz, ancak her bir cihazın hesaplamasının artık diğer cihazlarda aynı hesaplama ile kolektifler kullanarak sınırlı da olsa etkileşime girebildiği bir SPMD programı yürütüyoruz. Bu senaryoda, kullanabileceğimiz out_axes = None , çünkü psum değerini senkronize olacak.

def f(x):
  out = lax.psum(x, axis_name='i')
  return out
jax.pmap(f, axis_name='i', out_axes=None)(jnp.arange(8.))
ShardedDeviceArray(28., dtype=float32)

SPMD, aynı anda herhangi bir TPU konfigürasyonunda her cihazda çalışan bir program yazmamızı sağlar. 8 TPU çekirdeğinde makine öğrenimi yapmak için kullanılan kodun aynısı, yüzlerce ila binlerce çekirdeğe sahip olabilecek bir TPU bölmesinde kullanılabilir! Hakkında daha detaylı bir eğitim için jax.pmap ve SPMD, sen başvurabilirsiniz JAX 101 öğretici .

geniş ölçekte MCMC

Bu not defterinde, Bayes çıkarımı için Markov Zinciri Monte Carlo (MCMC) yöntemlerini kullanmaya odaklanıyoruz. MCMC için birçok cihazı kullanmamızın bir yolu olabilir, ancak bu not defterinde iki tanesine odaklanacağız:

  1. Farklı cihazlarda bağımsız Markov zincirleri çalıştırma. Bu durum oldukça basittir ve vanilya TFP ile yapmak mümkündür.
  2. Bir veri kümesini cihazlar arasında paylaşma. Bu durum biraz daha karmaşıktır ve yakın zamanda eklenen TFP makineleri gerektirir.

Bağımsız Zincirler

Diyelim ki MCMC kullanarak bir problem üzerinde Bayes çıkarımı yapmak istiyoruz ve birkaç zinciri paralel olarak birkaç cihazda çalıştırmak istiyoruz (her cihazda 2 diyelim). Bu, aygıtlar arasında yalnızca "eşleyebileceğimiz" bir program olduğu ortaya çıkıyor, yani kollektiflere ihtiyaç duymayan bir program. Her programın farklı bir Markov zinciri çalıştırdığından emin olmak için (aynı olanı çalıştırmak yerine), her cihaza rastgele tohum için farklı bir değer iletiyoruz.

Bunu bir 2-D Gauss dağılımından örneklemenin bir oyuncak problemi üzerinde deneyelim. TFP'nin mevcut MCMC işlevselliğini kutudan çıktığı gibi kullanabiliriz. Genel olarak, tüm cihazlarda çalışanla ilkini daha açık bir şekilde ayırt etmek için mantığın çoğunu haritalanmış işlevimizin içine koymaya çalışıyoruz.

def run(seed):
  target_log_prob = tfd.Sample(tfd.Normal(0., 1.), 2).log_prob

  initial_state = jnp.zeros([2, 2]) # 2 chains
  kernel = tfm.HamiltonianMonteCarlo(target_log_prob, 1e-1, 10)
  def trace_fn(state, pkr):
    return target_log_prob(state)

  states, log_prob = tfm.sample_chain(
    num_results=1000,
    num_burnin_steps=1000,
    kernel=kernel,
    current_state=initial_state,
    trace_fn=trace_fn,
    seed=seed
  )
  return states, log_prob

Kendi başına, run işlevi vatansız rasgele tohum alır (ne kadar vatansız rasgelelik iş görmek için, okuyabilir JAX üzerinde PFP'yi dizüstü ya bakınız JAX 101 öğretici ). Haritalama run farklı tohumlar üzerinde birçok bağımsız Markov zincirleri çalıştıran sonuçlanacaktır.

states, log_probs = jax.pmap(run)(random.split(random.PRNGKey(0), 8))
print(states.shape, log_probs.shape)
# states is (8 devices, 1000 samples, 2 chains, 2 dimensions)
# log_prob is (8 devices, 1000 samples, 2 chains)
(8, 1000, 2, 2) (8, 1000, 2)

Artık her bir cihaza karşılık gelen ekstra bir eksenimiz olduğuna dikkat edin. 16 zincir için bir eksen elde etmek için boyutları yeniden düzenleyebilir ve düzleştirebiliriz.

states = states.transpose([0, 2, 1, 3]).reshape([-1, 1000, 2])
log_probs = log_probs.transpose([0, 2, 1]).reshape([-1, 1000])
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].plot(log_probs.T, alpha=0.4)
ax[1].scatter(*states.reshape([-1, 2]).T, alpha=0.1)
plt.show()

png

Birçok cihazda bağımsız zincirleri çalıştırırken olarak, bu kolay olarak var pmap kullandığı bir işlevi üzerinde -ing tfp.mcmc , sağlanması her cihaza rastgele tohum için farklı değerler geçmektedir.

Verileri parçalama

MCMC yaptığımızda, hedef dağılım genellikle bir veri kümesi üzerinde koşullandırma yoluyla elde edilen sonsal bir dağılımdır ve normalleştirilmemiş bir günlük yoğunluğunun hesaplanması, gözlemlenen her veri için olasılıkların toplanmasını içerir.

Çok büyük veri kümeleriyle, tek bir aygıtta bir zinciri çalıştırmak bile aşırı derecede pahalı olabilir. Bununla birlikte, birden fazla cihaza erişimimiz olduğunda, elimizdeki bilgi işlemden daha iyi yararlanmak için veri kümesini cihazlar arasında bölebiliriz.

Biz kanatlı bir veri kümesi ile MCMC yapmak isterseniz, biz başka türlü her cihaz kendi yanlış hedefle MCMC yapıyor olacak, biz her cihazda hesaplamak normalize edilmemiş günlük yoğunluklu yani tüm veriler üzerinde yoğunluk toplam temsil sağlamak için gereken dağıtım. Bu amaçla, TFP şimdi (yani yeni araçlara sahiptir tfp.experimental.distribute ve tfp.experimental.mcmc ) o bilgisayar "kanatlı bir" günlük olasılıklarını etkinleştirip onlarla MCMC yapıyor.

Parçalı dağıtımlar

Çekirdek soyutlama TFP hemen kanatlı bir günlük probabiliities olan işlem sağlar Sharded giriş olarak bir dağıtım alır ve bir SPMD bağlamında yürütüldüğünde spesifik özelliklere sahip olan yeni bir dağılımını verir meta dağılımı. Sharded yaşıyor tfp.experimental.distribute .

Sezgisel bir Sharded cihaz üzerinden "bölünmüş" olmuştur rasgele değişkenlerin bir dizi dağıtım karşılık gelir. Her cihazda farklı örnekler üretecekler ve ayrı ayrı farklı log yoğunluklarına sahip olabilirler. Seçenek olarak ise, bir Sharded plaka boyutu cihazların sayısı olan grafik modeli dilinde bir "levhası", dağıtım karşılık gelir.

Bir örnekleme Sharded dağılım

Biz dan örnek ise Normal bir program varlık dağıtım pmap her cihazda aynı tohum kullanılarak Ed, her cihazda aynı numuneyi alacak. Aşağıdaki işlevi, cihazlar arasında senkronize edilen tek bir rastgele değişkeni örneklemek olarak düşünebiliriz.

# `pmap` expects at least one value to be mapped over, so we provide a dummy one
def f(seed, _):
  return tfd.Normal(0., 1.).sample(seed=seed)
jax.pmap(f, in_axes=(None, 0))(random.PRNGKey(0), jnp.arange(8.))
ShardedDeviceArray([-0.20584236, -0.20584236, -0.20584236, -0.20584236,
                    -0.20584236, -0.20584236, -0.20584236, -0.20584236],                   dtype=float32)

Biz sarın Eğer tfd.Normal(0., 1.) bir ile tfed.Sharded biz mantıksal olarak şimdi (her cihazda bir) sekiz farklı rasgele değişkenler var ve bu nedenle aynı tohumdan geçen rağmen her biri için farklı bir örnek üretecek .

def f(seed, _):
  return tfed.Sharded(tfd.Normal(0., 1.), shard_axis_name='i').sample(seed=seed)
jax.pmap(f, in_axes=(None, 0), axis_name='i')(random.PRNGKey(0), jnp.arange(8.))
ShardedDeviceArray([ 1.2152631 ,  0.7818249 ,  0.32549605,  0.6828047 ,
                     1.3973192 , -0.57830244,  0.37862757,  2.7706041 ],                   dtype=float32)

Bu dağılımın tek bir cihazda eşdeğer bir temsili sadece 8 bağımsız normal örnektir. Örneğin değeri (farklı olacaktır olsa tfed.Sharded biraz farklı bir yalancı rasgele sayı oluşturma yapar), aynı dağıtım temsil hem.

dist = tfd.Sample(tfd.Normal(0., 1.), jax.device_count())
dist.sample(seed=random.PRNGKey(0))
DeviceArray([ 0.08086783, -0.38624594, -0.3756545 ,  1.668957  ,
             -1.2758069 ,  2.1192007 , -0.85821325,  1.1305912 ],            dtype=float32)

Bir log-yoğunluk alınması Sharded dağıtım

SPMD bağlamında normal bir dağıtımdan bir örneğin günlük yoğunluğunu hesapladığımızda ne olduğunu görelim.

def f(seed, _):
  dist = tfd.Normal(0., 1.)
  x = dist.sample(seed=seed)
  return x, dist.log_prob(x)
jax.pmap(f, in_axes=(None, 0))(random.PRNGKey(0), jnp.arange(8.))
(ShardedDeviceArray([-0.20584236, -0.20584236, -0.20584236, -0.20584236,
                     -0.20584236, -0.20584236, -0.20584236, -0.20584236],                   dtype=float32),
 ShardedDeviceArray([-0.94012403, -0.94012403, -0.94012403, -0.94012403,
                     -0.94012403, -0.94012403, -0.94012403, -0.94012403],                   dtype=float32))

Her örnek her cihazda aynıdır, bu nedenle her cihazda aynı yoğunluğu hesaplarız. Sezgisel olarak, burada yalnızca tek bir normal dağılımlı değişken üzerinde bir dağılıma sahibiz.

Bir ile Sharded dağılımı, biz hesaplamak nedenle zaman, 8 rastgele değişkenler üzerindeki dağılımı o log_prob bir numunenin, bireysel giriş yoğunluğu üzerinden her cihazlar arasında, toplar. (Bu toplam log_prob değerinin yukarıda hesaplanan singleton log_prob değerinden daha büyük olduğunu fark edebilirsiniz.)

def f(seed, _):
  dist = tfed.Sharded(tfd.Normal(0., 1.), shard_axis_name='i')
  x = dist.sample(seed=seed)
  return x, dist.log_prob(x)
sample, log_prob = jax.pmap(f, in_axes=(None, 0), axis_name='i')(
    random.PRNGKey(0), jnp.arange(8.))
print('Sample:', sample)
print('Log Prob:', log_prob)
Sample: [ 1.2152631   0.7818249   0.32549605  0.6828047   1.3973192  -0.57830244
  0.37862757  2.7706041 ]
Log Prob: [-13.7349205 -13.7349205 -13.7349205 -13.7349205 -13.7349205 -13.7349205
 -13.7349205 -13.7349205]

Eşdeğer, "paylaşılmamış" dağıtım, aynı günlük yoğunluğunu üretir.

dist = tfd.Sample(tfd.Normal(0., 1.), jax.device_count())
dist.log_prob(sample)
DeviceArray(-13.7349205, dtype=float32)

Bir Sharded dağılımı farklı değerler üreten sample her bir cihaz, aynı için aynı değeri elde log_prob her cihazda. Burada neler oluyor? Bir Sharded dağılımı yapar psum sağlamak için içten log_prob değerleri cihazlar arasında senkronize bulunmaktadır. Bu davranışı neden isteyelim? Her cihazda aynı MCMC zincirini çalıştırıyorsanız, istediğimiz target_log_prob hesaplamasında bazı rasgele değişkenler cihazlar arasında kanatlı bir dahi, her cihaz için aynı olması.

Buna ek olarak, bir Sharded cihazlarda gradyanlar doğru olduğunu dağılımı sağlayan, geçiş işlevinin bir parçası olarak log-yoğunluk fonksiyonunun gradyanlar almak uygun örnekleri üretmek HMC gibi bu algoritmaları sağlamak.

Kanatlı bir JointDistribution s

Birden fazla olan modeller oluşturabilirsiniz Sharded kullanarak rastgele değişkenlerin JointDistribution ler (JDs). Ne yazık ki, Sharded dağılımlar güvenle vanilya ile kullanılamaz tfd.JointDistribution ler, ancak tfp.experimental.distribute ihracat gibi davranacaktır Jülyen "yamalı" Sharded dağılımları.

def f(seed, _):
  dist = tfed.JointDistributionSequential([
    tfd.Normal(0., 1.),
    tfed.Sharded(tfd.Normal(0., 1.), shard_axis_name='i'),
  ])
  x = dist.sample(seed=seed)
  return x, dist.log_prob(x)
jax.pmap(f, in_axes=(None, 0), axis_name='i')(random.PRNGKey(0), jnp.arange(8.))
([ShardedDeviceArray([1.6121525, 1.6121525, 1.6121525, 1.6121525, 1.6121525,
                      1.6121525, 1.6121525, 1.6121525], dtype=float32),
  ShardedDeviceArray([ 0.8690128 , -0.83167845,  1.2209264 ,  0.88412696,
                       0.76478404, -0.66208494, -0.0129658 ,  0.7391483 ],                   dtype=float32)],
 ShardedDeviceArray([-12.214451, -12.214451, -12.214451, -12.214451,
                     -12.214451, -12.214451, -12.214451, -12.214451],                   dtype=float32))

Bunlar kanatlı bir JDs ikisine de sahip olabilir Sharded bileşen olarak ve vanilyalı TFP dağılımları. Parçalanmamış dağılımlar için her cihazda aynı örneği, parçalı dağıtımlar için farklı örnekler elde ederiz. log_prob her cihazda de senkronize edilir.

İle MCMC Sharded dağılımları

Nasıl hakkında düşünüyorsunuz Sharded MCMC bağlamında dağılımları? Biz olarak ifade edilebilir bir üretken bir model varsa JointDistribution , karşımıza "kırığa" o modelin bazı eksenini alabilirsiniz. Tipik olarak, modeldeki bir rastgele değişken, gözlemlenen verilere karşılık gelir ve cihazlar arasında parçalamak istediğimiz büyük bir veri kümemiz varsa, veri noktalarıyla ilişkili değişkenlerin de paylaşılmasını isteriz. Ayrıca, parçaladığımız gözlemlerle bire bir olan "yerel" rastgele değişkenlere sahip olabiliriz, bu nedenle bu rastgele değişkenleri ek olarak parçalamamız gerekecek.

Biz kullanımının örneklerini ele alacağız Sharded bu bölümde TFP MCMC ile dağılımlar. Biz daha basit Bayes lojistik regresyon örnekle başlayalım ve bazı kullanım-örnekler de sergilemektedir hedefiyle, bir matris ayrıştırma örnekle bitireceğim distribute kütüphanesi.

Örnek: MNIST için Bayes lojistik regresyon

Büyük bir veri setinde Bayesian lojistik regresyon yapmak istiyoruz; Model, önceki sahip \(p(\theta)\) regresyon ağırlıkları üzerinde ve bir olabilirlik \(p(y_i | \theta, x_i)\) tüm veriler üzerinde toplanır \(\{x_i, y_i\}_{i = 1}^N\) total eklem günlük yoğunluğu elde edildi. Bizim veri shard, biz gözlenen rastgele değişkenler shard ediyorum \(x_i\) ve \(y_i\) bizim modelinde.

MNIST sınıflandırması için aşağıdaki Bayes lojistik regresyon modelini kullanıyoruz:

\[ \begin{align*} w &\sim \mathcal{N}(0, 1) \\ b &\sim \mathcal{N}(0, 1) \\ y_i | w, b, x_i &\sim \textrm{Categorical}(w^T x_i + b) \end{align*} \]

TensorFlow Veri Kümelerini kullanarak MNIST'i yükleyelim.

mnist = tfds.as_numpy(tfds.load('mnist', batch_size=-1))
raw_train_images, train_labels = mnist['train']['image'], mnist['train']['label']
train_images = raw_train_images.reshape([raw_train_images.shape[0], -1]) / 255.

raw_test_images, test_labels = mnist['test']['image'], mnist['test']['label']
test_images = raw_test_images.reshape([raw_test_images.shape[0], -1]) / 255.
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...
WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.
HBox(children=(FloatProgress(value=0.0, description='Dl Completed...', max=4.0, style=ProgressStyle(descriptio…
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.

60000 eğitim görüntümüz var, ancak mevcut 8 çekirdeğimizden yararlanalım ve onu 8 yola ayıralım. Biz bu kullanışlı kullanacağız shard fayda fonksiyonu.

def shard_value(x):
  x = x.reshape((jax.device_count(), -1, *x.shape[1:]))
  return jax.pmap(lambda x: x)(x) # pmap will physically place values on devices

shard = functools.partial(jax.tree_map, shard_value)
sharded_train_images, sharded_train_labels = shard((train_images, train_labels))
print(sharded_train_images.shape, sharded_train_labels.shape)
(8, 7500, 784) (8, 7500)

Devam etmeden önce, TPU'lardaki hassasiyeti ve bunun HMC üzerindeki etkisini hızlıca tartışalım. TPU düşük kullanılarak matris çarpımı çalıştırmak bfloat16 hızı hassas. bfloat16 matris çarpım genellikle çok derin öğrenme uygulamalar için yeterlidir, ancak HMC kullanıldığında, ampirik doğruluğu daha düşüktür rejeksiyonları neden yörüngeleri ayrılan yol açabilir bulduk. Bazı ek hesaplamalar pahasına daha yüksek hassasiyetli matris çarpımlarını kullanabiliriz.

Bizim matmul hassasiyetini artırmak için, kullanabileceğimiz jax.default_matmul_precision ile dekoratör "tensorfloat32" (hatta daha yüksek hassasiyet için yarar bir hassasiyetle "float32" kesinlik).

Şimdi bizim tanımlayalım run (her cihazda aynı olacaktır) rastgele tohumda alacak işlevi ve MNIST bir shard. İşlev, yukarıda belirtilen modeli uygulayacak ve ardından tek bir zinciri çalıştırmak için TFP'nin vanilya MCMC işlevselliğini kullanacağız. Emin süslemeleri için yapacağız run ile jax.default_matmul_precision aşağıda özel örnekte, sadece yanı kullanabilirsiniz olsa emin matris çarpım yüksek hassasiyetle çalışır hale getirmek için dekoratör jnp.dot(images, w, precision=lax.Precision.HIGH) .

# We can use `out_axes=None` in the `pmap` because the results will be the same
# on every device. 
@functools.partial(jax.pmap, axis_name='data', in_axes=(None, 0), out_axes=None)
@jax.default_matmul_precision('tensorfloat32')
def run(seed, data):
  images, labels = data # a sharded dataset
  num_examples, dim = images.shape
  num_classes = 10

  def model_fn():
    w = yield Root(tfd.Sample(tfd.Normal(0., 1.), [dim, num_classes]))
    b = yield Root(tfd.Sample(tfd.Normal(0., 1.), [num_classes]))
    logits = jnp.dot(images, w) + b
    yield tfed.Sharded(tfd.Independent(tfd.Categorical(logits=logits), 1),
                       shard_axis_name='data')
  model = tfed.JointDistributionCoroutine(model_fn)

  init_seed, sample_seed = random.split(seed)

  initial_state = model.sample(seed=init_seed)[:-1] # throw away `y`

  def target_log_prob(*state):
    return model.log_prob((*state, labels))

  def accuracy(w, b):
    logits = images.dot(w) + b
    preds = logits.argmax(axis=-1)
    # We take the average accuracy across devices by using `lax.pmean`
    return lax.pmean((preds == labels).mean(), 'data')

  kernel = tfm.HamiltonianMonteCarlo(target_log_prob, 1e-2, 100)
  kernel = tfm.DualAveragingStepSizeAdaptation(kernel, 500)
  def trace_fn(state, pkr):
    return (
        target_log_prob(*state),
        accuracy(*state),
        pkr.new_step_size)
  states, trace = tfm.sample_chain(
    num_results=1000,
    num_burnin_steps=1000,
    current_state=initial_state,
    kernel=kernel,
    trace_fn=trace_fn,
    seed=sample_seed
  )
  return states, trace

jax.pmap bir JIT derleme kapsar ancak derlenmiş işlev ilk çağrıdan sonra önbelleğe alınır. Biz arayacağım run ve derleme önbelleğe çıkışını görmezden.

%%time
output = run(random.PRNGKey(0), (sharded_train_images, sharded_train_labels))
jax.tree_map(lambda x: x.block_until_ready(), output)
CPU times: user 24.5 s, sys: 48.2 s, total: 1min 12s
Wall time: 1min 54s

Şimdi çağrı edeceğiz run yine gerçek yürütme ne kadar sürdüğünü görmek için.

%%time
states, trace = run(random.PRNGKey(0), (sharded_train_images, sharded_train_labels))
jax.tree_map(lambda x: x.block_until_ready(), trace)
CPU times: user 13.1 s, sys: 45.2 s, total: 58.3 s
Wall time: 1min 43s

Her biri tüm veri kümesi üzerinde bir gradyan hesaplayan 200.000 sıçrama adımı yürütüyoruz. Hesaplamayı 8 çekirdeğe bölmek, 200.000 eğitim dönemi eşdeğerini yaklaşık 95 saniyede, saniyede yaklaşık 2.100 dönem hesaplamamızı sağlar!

Her örneğin log yoğunluğunu ve her örneğin doğruluğunu çizelim:

fig, ax = plt.subplots(1, 3, figsize=(15, 5))
ax[0].plot(trace[0])
ax[0].set_title('Log Prob')
ax[1].plot(trace[1])
ax[1].set_title('Accuracy')
ax[2].plot(trace[2])
ax[2].set_title('Step Size')
plt.show()

png

Örnekleri birleştirirsek, performansımızı iyileştirmek için bir Bayes modeli ortalamasını hesaplayabiliriz.

@functools.partial(jax.pmap, axis_name='data', in_axes=(0, None), out_axes=None)
def bayesian_model_average(data, states):
  images, labels = data
  logits = jax.vmap(lambda w, b: images.dot(w) + b)(*states)
  probs = jax.nn.softmax(logits, axis=-1)
  bma_accuracy = (probs.mean(axis=0).argmax(axis=-1) == labels).mean()
  avg_accuracy = (probs.argmax(axis=-1) == labels).mean()
  return lax.pmean(bma_accuracy, axis_name='data'), lax.pmean(avg_accuracy, axis_name='data')

sharded_test_images, sharded_test_labels = shard((test_images, test_labels))
bma_acc, avg_acc = bayesian_model_average((sharded_test_images, sharded_test_labels), states)
print(f'Average Accuracy: {avg_acc}')
print(f'BMA Accuracy: {bma_acc}')
print(f'Accuracy Improvement: {bma_acc - avg_acc}')
Average Accuracy: 0.9188529253005981
BMA Accuracy: 0.9264000058174133
Accuracy Improvement: 0.0075470805168151855

Bir Bayes modeli ortalaması, doğruluğumuzu neredeyse %1 oranında artırır!

Örnek: MovieLens öneri sistemi

Şimdi, kullanıcıların ve çeşitli filmlerin derecelendirmelerinin bir koleksiyonu olan MovieLens önerileri veri kümesiyle çıkarım yapmayı deneyelim. Özellikle, bir şekilde MovieLens temsil edebilir \(N \times M\) izle matris \(W\) \(N\) kullanıcıları ve sayısıdır \(M\) film sayısıdır; beklediğimiz \(N > M\). Girdileri \(W_{ij}\) kullanıcı olup olmadığını belirten bir boolean olan \(i\) film izlenmeye \(j\). MovieLens'in kullanıcı derecelendirmeleri sağladığını, ancak sorunu basitleştirmek için bunları göz ardı ettiğimizi unutmayın.

İlk önce veri setini yükleyeceğiz. 1 milyon reytingli versiyonu kullanacağız.

movielens = tfds.as_numpy(tfds.load('movielens/1m-ratings', batch_size=-1))
GENRES = ['Action', 'Adventure', 'Animation', 'Children', 'Comedy',
          'Crime', 'Documentary', 'Drama', 'Fantasy', 'Film-Noir',
          'Horror', 'IMAX', 'Musical', 'Mystery', 'Romance', 'Sci-Fi',
          'Thriller', 'Unknown', 'War', 'Western', '(no genres listed)']
Downloading and preparing dataset movielens/1m-ratings/0.1.0 (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0...
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Shuffling and writing examples to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0.incompleteYKA3TG/movielens-train.tfrecord
HBox(children=(FloatProgress(value=0.0, max=1000209.0), HTML(value='')))
Dataset movielens downloaded and prepared to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0. Subsequent calls will reuse this data.

Biz seyretmek matrisi elde etmek veri kümesi bazı önişlemeyi yapacağım \(W\).

raw_movie_ids = movielens['train']['movie_id']
raw_user_ids = movielens['train']['user_id']
genres = movielens['train']['movie_genres']

movie_ids, movie_labels = pd.factorize(movielens['train']['movie_id'])
user_ids, user_labels = pd.factorize(movielens['train']['user_id'])

num_movies = movie_ids.max() + 1
num_users = user_ids.max() + 1

movie_titles = dict(zip(movielens['train']['movie_id'],
                        movielens['train']['movie_title']))
movie_genres = dict(zip(movielens['train']['movie_id'],
                        genres))
movie_id_to_title = [movie_titles[movie_labels[id]].decode('utf-8')
                     for id in range(num_movies)]
movie_id_to_genre = [GENRES[movie_genres[movie_labels[id]][0]] for id in range(num_movies)]

watch_matrix = np.zeros((num_users, num_movies), bool)
watch_matrix[user_ids, movie_ids] = True
print(watch_matrix.shape)
(6040, 3706)

Biz üretken bir modeli tanımlamak \(W\)basit bir olasılık matris ayrıştırma yöntemi kullanılarak. Biz gizli varsayalım \(N \times D\) kullanıcı matris \(U\) ve gizli \(M \times D\) film matris \(V\)izle matrisi için Bernoulli logits üretmek çarpılır, \(W\). Aynı zamanda, kullanıcı ve filmler için bir önyargı vektörleri ekleriz \(u\) ve \(v\).

\[ \begin{align*} U &\sim \mathcal{N}(0, 1) \quad u \sim \mathcal{N}(0, 1)\\ V &\sim \mathcal{N}(0, 1) \quad v \sim \mathcal{N}(0, 1)\\ W_{ij} &\sim \textrm{Bernoulli}\left(\sigma\left(\left(UV^T\right)_{ij} + u_i + v_j\right)\right) \end{align*} \]

Bu oldukça büyük bir matris; 6040 kullanıcı ve 3706 film, içinde 22 milyondan fazla giriş bulunan bir matrise yol açar. Bu modeli parçalamaya nasıl yaklaşırız? Eh, o varsayarsak \(N > M\) (filmlerinden daha kullanıcı var yani), o zaman her cihaz kullanıcılarının bir alt kümesine karşılık gelen izlemek matrisinin bir parça olurdu, böylece kullanıcı ekseni boyunca izle matrisi shard mantıklıdır . Önceki örneklerden farklı olarak, bununla birlikte, aynı zamanda yukarı shard gerekir \(U\) her kullanıcı için bir gömme sahip olduğu, her bir cihaz ile bir rastgele sorumlu olacak şekilde, matris \(U\) ve bir rastgele \(W\). Öte yandan, \(V\) unsharded olacak ve cihazlar arasında senkronize edilebilir.

sharded_watch_matrix = shard(watch_matrix)

Bizim yazmadan önce run , en hızlı yerel rasgele değişken Kırma işleminde ek zorlukları konuşalım \(U\). HMC, vanilya çalıştırırken tfp.mcmc.HamiltonianMonteCarlo zincir durumunun her bir eleman için, çekirdek örnek olacak momentumları. Önceden, yalnızca paylaşılmamış rastgele değişkenler bu durumun parçasıydı ve momentum her cihazda aynıydı. Şimdi bir kanatlı bir olduğunda \(U\)biz her cihaz üzerinde farklı momentumları örneklemek gerekir \(U\)aynı momentumları örneklerken, \(V\). Bunu gerçekleştirmek için, biz kullanabilirsiniz tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo bir ile Sharded ivme dağılımı. Paralel hesaplamayı birinci sınıf yapmaya devam ederken, örneğin HMC çekirdeğine bir parçalılık göstergesi alarak bunu basitleştirebiliriz.

def make_run(*,
             axis_name,
             dim=20,
             num_chains=2,
             prior_variance=1.,
             step_size=1e-2,
             num_leapfrog_steps=100,
             num_burnin_steps=1000,
             num_results=500,
             ):
  @functools.partial(jax.pmap, in_axes=(None, 0), axis_name=axis_name)
  @jax.default_matmul_precision('tensorfloat32')
  def run(key, watch_matrix):
    num_users, num_movies = watch_matrix.shape

    Sharded = functools.partial(tfed.Sharded, shard_axis_name=axis_name)

    def prior_fn():
      user_embeddings = yield Root(Sharded(tfd.Sample(tfd.Normal(0., 1.), [num_users, dim]), name='user_embeddings'))
      user_bias = yield Root(Sharded(tfd.Sample(tfd.Normal(0., 1.), [num_users]), name='user_bias'))
      movie_embeddings = yield Root(tfd.Sample(tfd.Normal(0., 1.), [num_movies, dim], name='movie_embeddings'))
      movie_bias = yield Root(tfd.Sample(tfd.Normal(0., 1.), [num_movies], name='movie_bias'))
      return (user_embeddings, user_bias, movie_embeddings, movie_bias)
    prior = tfed.JointDistributionCoroutine(prior_fn)

    def model_fn():
      user_embeddings, user_bias, movie_embeddings, movie_bias = yield from prior_fn()
      logits = (jnp.einsum('...nd,...md->...nm', user_embeddings, movie_embeddings)
                + user_bias[..., :, None] + movie_bias[..., None, :])
      yield Sharded(tfd.Independent(tfd.Bernoulli(logits=logits), 2), name='watch')
    model = tfed.JointDistributionCoroutine(model_fn)

    init_key, sample_key = random.split(key)
    initial_state = prior.sample(seed=init_key, sample_shape=num_chains)

    def target_log_prob(*state):
      return model.log_prob((*state, watch_matrix))

    momentum_distribution = tfed.JointDistributionSequential([
      Sharded(tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_users, dim]), 1.), 2)),
      Sharded(tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_users]), 1.), 1)),
      tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_movies, dim]), 1.), 2),
      tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_movies]), 1.), 1),
    ])

    # We pass in momentum_distribution here to ensure that the momenta for 
    # user_embeddings and user_bias are also sharded
    kernel = tfem.PreconditionedHamiltonianMonteCarlo(target_log_prob, step_size,
                                                      num_leapfrog_steps,
                                                      momentum_distribution=momentum_distribution)

    num_adaptation_steps = int(0.8 * num_burnin_steps)
    kernel = tfm.DualAveragingStepSizeAdaptation(kernel, num_adaptation_steps)

    def trace_fn(state, pkr):
      return {
        'log_prob': target_log_prob(*state),
        'log_accept_ratio': pkr.inner_results.log_accept_ratio,
      }
    return tfm.sample_chain(
        num_results, initial_state,
        kernel=kernel,
        num_burnin_steps=num_burnin_steps,
        trace_fn=trace_fn,
        seed=sample_key)
  return run

Derlenmiş önbelleğe kez Tekrar kaçıyorum run .

%%time
run = make_run(axis_name='data')
output = run(random.PRNGKey(0), sharded_watch_matrix)
jax.tree_map(lambda x: x.block_until_ready(), output)
CPU times: user 56 s, sys: 1min 24s, total: 2min 20s
Wall time: 3min 35s

Şimdi derleme yükü olmadan tekrar çalıştıracağız.

%%time
states, trace = run(random.PRNGKey(0), sharded_watch_matrix)
jax.tree_map(lambda x: x.block_until_ready(), trace)
CPU times: user 28.8 s, sys: 1min 16s, total: 1min 44s
Wall time: 3min 1s

Görünüşe göre yaklaşık 3 dakikada yaklaşık 150.000 birdirme adımı tamamladık, yani saniyede yaklaşık 83 sıçrayış adımı! Örneklerimizin kabul oranını ve log yoğunluğunu çizelim.

fig, axs = plt.subplots(1, len(trace), figsize=(5 * len(trace), 5))
for ax, (key, val) in zip(axs, trace.items()):
  ax.plot(val[0]) # Indexing into a sharded array, each element is the same
  ax.set_title(key);

png

Artık Markov zincirimizden bazı örneklere sahip olduğumuza göre, bunları bazı tahminlerde bulunmak için kullanalım. İlk önce, bileşenlerin her birini çıkaralım. Unutmayın user_embeddings ve user_bias bizim bitiştirmek gerekir, böylece cihazın karşısında bölünmüş olan ShardedArray hepsini elde etmek. Öte yandan, movie_embeddings ve movie_bias her cihazda aynı, bu yüzden sadece ilk kırığın değeri alabilirsiniz. Biz normal kullanacağız numpy CPU'ya TPU arkasından değerleri kopyalamak için.

user_embeddings = np.concatenate(np.array(states.user_embeddings, np.float32), axis=2)
user_bias = np.concatenate(np.array(states.user_bias, np.float32), axis=2)
movie_embeddings = np.array(states.movie_embeddings[0], dtype=np.float32)
movie_bias = np.array(states.movie_bias[0], dtype=np.float32)
samples = (user_embeddings, user_bias, movie_embeddings, movie_bias)
print(f'User embeddings: {user_embeddings.shape}')
print(f'User bias: {user_bias.shape}')
print(f'Movie embeddings: {movie_embeddings.shape}')
print(f'Movie bias: {movie_bias.shape}')
User embeddings: (500, 2, 6040, 20)
User bias: (500, 2, 6040)
Movie embeddings: (500, 2, 3706, 20)
Movie bias: (500, 2, 3706)

Bu örneklerde yakalanan belirsizliği kullanan basit bir öneri sistemi oluşturmaya çalışalım. Önce filmleri izlenme olasılığına göre sıralayan bir fonksiyon yazalım.

@jax.jit
def recommend(sample, user_id):
  user_embeddings, user_bias, movie_embeddings, movie_bias = sample
  movie_logits = (
      jnp.einsum('d,md->m', user_embeddings[user_id], movie_embeddings)
      + user_bias[user_id] + movie_bias)
  return movie_logits.argsort()[::-1]

Artık tüm örnekler üzerinde dolaşan ve her biri için kullanıcının daha önce izlemediği en üst sıradaki filmi seçen bir fonksiyon yazabiliriz. Daha sonra örnekler arasında önerilen tüm filmlerin sayısını görebiliriz.

def get_recommendations(user_id): 
  movie_ids = []
  already_watched = set(jnp.arange(num_movies)[watch_matrix[user_id] == 1])
  for i in range(500):
    for j in range(2):
      sample = jax.tree_map(lambda x: x[i, j], samples)
      ranking = recommend(sample, user_id)
      for movie_id in ranking:
        if int(movie_id) not in already_watched:
          movie_ids.append(movie_id)
          break
  return movie_ids

def plot_recommendations(movie_ids, ax=None):
  titles = collections.Counter([movie_id_to_title[i] for i in movie_ids])
  ax = ax or plt.gca()
  names, counts = zip(*sorted(titles.items(), key=lambda x: -x[1]))
  ax.bar(names, counts)
  ax.set_xticklabels(names, rotation=90)

En çok film izleyen kullanıcıyı en az izleyen kullanıcıya karşı alalım.

user_watch_counts = watch_matrix.sum(axis=1)
user_most = user_watch_counts.argmax()
user_least = user_watch_counts.argmin()
print(user_watch_counts[user_most], user_watch_counts[user_least])
2314 20

Biz sistemimiz hakkında daha fazla kesinlik vardır umut user_most daha user_least biz filmlerin sıralar ilgili daha fazla bilgiye sahip olduğu göz önüne alındığında, user_most izlemek için daha olasıdır.

fig, ax = plt.subplots(1, 2, figsize=(20, 10))
most_recommendations = get_recommendations(user_most)
plot_recommendations(most_recommendations, ax=ax[0])
ax[0].set_title('Recommendation for user_most')
least_recommendations = get_recommendations(user_least)
plot_recommendations(least_recommendations, ax=ax[1])
ax[1].set_title('Recommendation for user_least');

png

Biz sunduğumuz önerilerin daha varyans olduğunu görmek user_least onların izle tercihleri bölümünden ek belirsizliği yansıtmaktadır.

Ayrıca önerilen filmlerin türlerine de bakabiliriz.

most_genres = collections.Counter([movie_id_to_genre[i] for i in most_recommendations])
least_genres = collections.Counter([movie_id_to_genre[i] for i in least_recommendations])
fig, ax = plt.subplots(1, 2, figsize=(20, 10))
ax[0].bar(most_genres.keys(), most_genres.values())
ax[0].set_title('Genres recommended for user_most')
ax[1].bar(least_genres.keys(), least_genres.values())
ax[1].set_title('Genres recommended for user_least');

png

user_most bir sürü film gördü ve buna karşın gizem ve suç gibi daha niş türler tavsiye edilmiştir user_least çok film izledim değil ve hangi çarpık komedi ve aksiyon daha ana akım sinema, önerildi.