অটো-ব্যাচড যৌথ বিতরণ: একটি মৃদু টিউটোরিয়াল

TensorFlow.org এ দেখুন Google Colab-এ চালান GitHub-এ উৎস দেখুন নোটবুক ডাউনলোড করুন

ভূমিকা

TensorFlow সম্ভাব্যতা (TFP) একটি নম্বর প্রস্তাব JointDistribution সহজেই কোনো ব্যবহারকারী একটি কাছাকাছি-গাণিতিক আকারে সম্ভাব্য গ্রাফিকাল মডেল প্রকাশ অনুমতি দিয়ে সহজ বিমূর্ত যে সম্ভাব্য অনুমান করা; বিমূর্ততা মডেল থেকে স্যাম্পলিং এবং মডেল থেকে নমুনার লগ সম্ভাব্যতা মূল্যায়ন করার পদ্ধতি তৈরি করে। এই টিউটোরিয়াল, আমরা REVIEW "autobatched" রূপগুলো, যা মূল পর বিকশিত হয়েছে JointDistribution বিমূর্ত। আসল, নন-অটোব্যাচড অ্যাবস্ট্রাকশনের সাথে আপেক্ষিক, স্বতঃব্যাচড সংস্করণগুলি ব্যবহার করা সহজ এবং আরও এর্গোনমিক, যা অনেক মডেলকে কম বয়লারপ্লেটের সাথে প্রকাশ করার অনুমতি দেয়। এই কোল্যাবে, আমরা একটি সাধারণ মডেল (সম্ভবত ক্লান্তিকর) বিশদে অন্বেষণ করি, স্বয়ংক্রিয় ব্যাচিং সমস্যার সমাধান করে তা পরিষ্কার করে এবং (আশা করি) পাঠককে TFP আকৃতির ধারণা সম্পর্কে আরও শিখিয়েছি।

Autobatching প্রবর্তনের পূর্বে সেখানে কয়েক বিভিন্ন রূপগুলো ছিল JointDistribution , সম্ভাব্য মডেল প্রকাশ করার জন্য বিভিন্ন অন্বিত শৈলী সংশ্লিষ্ট: JointDistributionSequential , JointDistributionNamed এবং JointDistributionCoroutine । Auobatching একটি mixin যেমন বিদ্যমান, তাই এখন আমরা আছে AutoBatched এই সব রুপভেদ। এই টিউটোরিয়াল, আমরা মধ্যে পার্থক্য অন্বেষণ JointDistributionSequential এবং JointDistributionSequentialAutoBatched ; যাইহোক, আমরা এখানে যা কিছু করি তা অন্যান্য ভেরিয়েন্টের ক্ষেত্রে প্রযোজ্য যা মূলত কোন পরিবর্তন ছাড়াই।

নির্ভরতা এবং পূর্বশর্ত

আমদানি এবং সেট আপ

পূর্বশর্ত: একটি Bayesian রিগ্রেশন সমস্যা

আমরা একটি খুব সাধারণ Bayesian রিগ্রেশন দৃশ্যকল্প বিবেচনা করব:

\[ \begin{align*} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{align*} \]

এই মডেলে m এবং b মান লম্ব থেকে টানা হয়, এবং পর্যবেক্ষণ Y একটি সাধারণ বন্টনের যার গড় র্যান্ডম ভেরিয়েবল উপর নির্ভর করে থেকে টানা হয় m এবং b , এবং কিছু (nonrandom, পরিচিত) covariates X । (সরলতার জন্য, এই উদাহরণে, আমরা ধরে নিই যে সমস্ত র্যান্ডম ভেরিয়েবলের স্কেল পরিচিত।)

এই মডেলের অনুমান সঞ্চালন করার জন্য, আমরা উভয় covariates জানা প্রয়োজন চাই X ও পর্যবেক্ষণের Y , কিন্তু এই টিউটোরিয়ালের উদ্দেশ্যে, আমরা কেবল প্রয়োজন হবে X , তাই আমরা একটি সহজ ডামি সংজ্ঞায়িত X :

X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])

ডেসিডারটা

সম্ভাব্য অনুমানে, আমরা প্রায়শই দুটি মৌলিক অপারেশন করতে চাই:

  • sample : মডেল থেকে অঙ্কন নমুনা।
  • log_prob : মডেল থেকে একটি নমুনা লগ সম্ভাব্যতা গণনা করা হচ্ছে।

TFP এর মূল অবদান JointDistribution বিমূর্ত (সেইসাথে এর সম্ভাব্য প্রোগ্রামিং অনেক অন্যান্য পন্থা) ব্যবহারকারীদের একবার একটি মডেল লিখে উভয় অ্যাক্সেস করতে অনুমতি দেওয়া sample এবং log_prob কম্পিউটেশন।

বুঝেই আমরা আমাদের ডেটা সেট (7 পয়েন্ট আছে X.shape = (7,) ), আমরা এখন একটি চমৎকার জন্য desiderata রাষ্ট্র পারে JointDistribution :

  • sample() একটি তালিকা উত্পাদন করা উচিত Tensors আকৃতি থাকার [(), (), (7,) ] স্কালে ঢাল, স্কালে পক্ষপাত, এবং ভেক্টর পর্যবেক্ষণের সাথে সঙ্গতিপূর্ণ যথাক্রমে।
  • log_prob(sample()) একটি নির্দিষ্ট ঢাল, পক্ষপাত, এবং পর্যবেক্ষণ লগ সম্ভাব্যতা: স্কেলের উত্পাদন করা উচিত।
  • sample([5, 3]) একটি তালিকা উত্পাদন করা উচিত Tensors আকৃতি থাকার [(5, 3), (5, 3), (5, 3, 7)] , একটি প্রতিনিধিত্বমূলক (5, 3) - থেকে নমুনা ব্যাচ মডেলটি.
  • log_prob(sample([5, 3])) একটি উত্পাদন করা উচিত Tensor আকৃতি (5, 3) সঙ্গে।

আমরা এখন একটি উত্তরাধিকার তাকান করব JointDistribution , মডেল কিভাবে উপরে desiderata অর্জন করা দেখুন, এবং আশা একটু বেশি শিখতে সম্পর্কে TFP পথ ধরে আকার।

ভক্ষক সতর্কতা: পদ্ধতির মাফিক যোগ boilerplate, ছাড়া উপরে desiderata হয় autobatching

প্রথম প্রচেষ্টা; JointDistributionSequential

jds = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

এটি কমবেশি কোডে মডেলের সরাসরি অনুবাদ। ঢাল m এবং পক্ষপাত b সহজবোধ্য হয়। Y একটি ব্যবহার সংজ্ঞায়িত করা হয় lambda -function: সাধারণ প্যাটার্ন যে একটি হল lambda এর -function \(k\) একটি আর্গুমেন্ট JointDistributionSequential (JDS) পূর্ববর্তী ব্যবহার \(k\) মডেল ডিস্ট্রিবিউশন। "বিপরীত" ক্রম নোট করুন।

আমরা ডাকবো sample_distributions , যা আয় উভয় একটি নমুনা এবং অন্তর্নিহিত "উপ-ডিস্ট্রিবিউশন" যে নমুনা জেনারেট করতে ব্যবহার করা হয়েছে। (আমরা কল করে শুধু নমুনা উত্পাদিত থাকতে পারে sample : নমুনা আমরা উত্পাদন জরিমানা; পাশাপাশি পরবর্তী টিউটোরিয়ালে এটা ডিস্ট্রিবিউশন আছে সুবিধাজনক হতে হবে।)

dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

কিন্তু log_prob একটি অবাঞ্ছিত আকৃতি সঙ্গে ফলে উৎপন্ন:

jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=
array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684,
       -4.4368567, -4.480562 ], dtype=float32)>

এবং একাধিক নমুনা কাজ করে না:

try:
  jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

আসুন বোঝার চেষ্টা করি কী ভুল হচ্ছে।

একটি সংক্ষিপ্ত পর্যালোচনা: ব্যাচ এবং ইভেন্ট আকার

TFP, একটি সাধারণ (ক না JointDistribution ) সম্ভাব্যতা বিতরণের একটি ইভেন্ট আকৃতি এবং একটি ব্যাচ আকৃতি, এবং পার্থক্য বুঝতে TFP কার্যকর ব্যবহার অত্যন্ত গুরুত্বপূর্ণ আছে:

  • ইভেন্ট আকৃতি বন্টন থেকে একটি একক ড্র এর আকৃতি বর্ণনা করে; ড্র মাত্রা জুড়ে নির্ভরশীল হতে পারে. স্কেলার ডিস্ট্রিবিউশনের জন্য, ইভেন্ট আকৃতি হল []। একটি 5-মাত্রিক মাল্টিভেরিয়েট নরমালের জন্য, ঘটনার আকৃতি হল [5]।
  • ব্যাচের আকৃতি স্বাধীন, অভিন্নভাবে বিতরণ করা ড্র নয়, ওরফে বিতরণের একটি "ব্যাচ" বর্ণনা করে। একটি একক পাইথন অবজেক্টে ডিস্ট্রিবিউশনের ব্যাচের প্রতিনিধিত্ব করা TFP স্কেলে দক্ষতা অর্জন করার অন্যতম প্রধান উপায়।

আমাদের উদ্দেশ্যে, মনে রাখতে হবে একটি গুরুত্বপূর্ণ সত্য যে আমরা যদি বলি log_prob একটি বিতরণ থেকে একটি একক নমুনার উপর, ফলে সবসময় একটি আকৃতি যে ম্যাচ (অর্থাত, ডানদিকের মাত্রা যেমন আছে) ব্যাচ আকৃতি হবে।

আকার একটি আরো বিস্তারিত আলোচনার জন্য দেখুন, "বুঝতে TensorFlow ডিস্ট্রিবিউশন আকার" টিউটোরিয়ালটি

নেই কেন log_prob(sample()) স্কেলের উত্পাদন?

আসুন ব্যাচ এবং ইভেন্ট আকৃতি আমাদের জ্ঞান ব্যবহার করেন সেটি কী দিয়ে ঘটছে অন্বেষণ করতে log_prob(sample()) এখানে আবার আমাদের নমুনা:

sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

এবং এখানে আমাদের বিতরণ আছে:

dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]

অংশগুলির (মিলিত) উপাদানগুলিতে সাব-ডিস্ট্রিবিউশনগুলির লগ সম্ভাব্যতাগুলিকে যোগ করে লগ সম্ভাব্যতা গণনা করা হয়:

log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014,
        -0.9897899, -1.0334952], dtype=float32)>]
np.sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>

সুতরাং, ব্যাখ্যার এক স্তর যে লগ সম্ভাব্যতা হিসাব 7 টেন্সর ফেরার কারণ তৃতীয় subcomponent হয় log_prob_parts 7 টেন্সর হয়। কিন্তু কেন?

ভাল, আমরা দেখতে যে শেষ উপাদান dists , যার উপর আমাদের ডিস্ট্রিবিউশন অনুরূপ Y mathematial তৈয়ার, একটি হয়েছে batch_shape এর [7] । অন্য কথায়, ওভার আমাদের ডিস্ট্রিবিউশন Y (এই ক্ষেত্রে, একই স্কেল বিভিন্ন উপায়ে সঙ্গে এবং,) 7 স্বাধীন লম্ব একটি ব্যাচ হয়।

এখন আমরা বুঝতে পারি কী হয়েছে: JDS মধ্যে, ওভার বন্টন Y হয়েছে batch_shape=[7] , JDS থেকে একটি নমুনা জন্য scalars প্রতিনিধিত্ব করে m এবং b আর 7 স্বাধীন লম্ব এর "ব্যাচ"। এবং log_prob 7 পৃথক লগ-সম্ভাব্যতা, প্রতিটি যা আঁকার লগ সম্ভাব্যতা প্রতিনিধিত্ব করে নির্ণয় m এবং b ও একটি একক পর্যবেক্ষণ Y[i] কিছু X[i]

স্থির log_prob(sample()) সঙ্গে Independent

পুনরাহ্বান যে dists[2] হয়েছে event_shape=[] এবং batch_shape=[7] :

dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>

TFP এর ব্যবহারের Independent metadistribution, যা ঘটনা আকারে ব্যাচ মাত্রা পরিবর্তন করে, আমরা সঙ্গে একটি বন্টন এই রূপান্তর করতে পারেন event_shape=[7] এবং batch_shape=[] (আমরা এটিতে নামান্তর করব y_dist_i কারণ এটি একটি ডিস্ট্রিবিউশনের Y সঙ্গে _i স্থায়ী আমাদের জন্য Independent মোড়ানো):

y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>

এখন, log_prob 7 ভেক্টরের স্কেলের হল:

y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>

আন্ডার কভার, Independent ব্যাচ উপর ধরছে:

y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

এবং প্রকৃতপক্ষে, আমরা একটি নতুন গঠন করা এই ব্যবহার করতে পারেন jds_i ( i আবার ঘোরা Independent ) যেখানে log_prob স্কেলের ফেরৎ:

jds_i = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m*X + b, scale=1.),
        reinterpreted_batch_ndims=1)
])

jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>

একটি দম্পতি নোট:

  • jds_i.log_prob(s) হিসাবে একই নয় tf.reduce_sum(jds.log_prob(s)) প্রাক্তনটি যৌথ বিতরণের "সঠিক" লগ সম্ভাবনা তৈরি করে। 7 টেন্সর উপর আধুনিক অঙ্কের, প্রতিটি উপাদান যার লগ সম্ভাব্যতা এর সমষ্টি m , b , এবং লগ সম্ভাব্যতা একটি একক উপাদান Y , তাই এটি overcounts m এবং b । ( log_prob(m) + log_prob(b) + log_prob(Y) একটি ব্যতিক্রম নিক্ষেপ কারণ TFP TF এবং NumPy এর সম্প্রচার নিয়ম অনুসরণ করে বদলে ফলে ফেরৎ;। একটি ভেক্টর করার জন্য একটি স্কেলার যোগ করার সময় একটি ভেক্টর আকারের ফলাফলের উত্পাদন করে)
  • এই বিশেষ ক্ষেত্রে আমরা সমস্যার সমাধান করতে পারে এবং ব্যবহার একই ফলাফল অর্জন MultivariateNormalDiag পরিবর্তে Independent(Normal(...)) MultivariateNormalDiag একটি ভেক্টর-মূল্যবান বন্টন (অর্থাত, এটি আগে থেকেই ভেক্টর ঘটনা-আকৃতি আছে)। Indeeed MultivariateNormalDiag হতে পারে (কিন্তু নয়) একটি রচনা হিসাবে প্রয়োগ করা Independent এবং Normal । এটা যে একটি ভেক্টর দেওয়া স্মরণ করা উপযুক্ত V থেকে নমুনা n1 = Normal(loc=V) , এবং n2 = MultivariateNormalDiag(loc=V) আলাদা করে চেনা হয়; এই ডিস্ট্রিবিউশন beween পার্থক্য যে n1.log_prob(n1.sample()) একটি ভেক্টর এবং n2.log_prob(n2.sample()) স্কেলের হয়।

একাধিক নমুনা?

একাধিক নমুনা আঁকা এখনও কাজ করে না:

try:
  jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

এর কারণ সম্পর্কে চিন্তা করা যাক. আমরা যখন কল jds_i.sample([5, 3]) , তাই আমরা প্রথমেই নমুনার জন্য আঁকা করব m এবং b , আকৃতি সঙ্গে প্রতিটি (5, 3) । এর পরে, আমরা গঠন করা চেষ্টা করতে যাচ্ছেন Normal মাধ্যমে বন্টন:

tfd.Normal(loc=m*X + b, scale=1.)

কিন্তু যদি m আকৃতি আছে (5, 3) এবং X আকৃতি আছে 7 , আমরা সংখ্যাবৃদ্ধি তাদের একসঙ্গে করতে পারে না এবং প্রকৃতপক্ষে এই ত্রুটি করছি আঘাত করা হয়:

m = tfd.Normal(0., 1.).sample([5, 3])
try:
  m * X
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

এটি সমস্যাটির সমাধান করার জন্য, এর বৈশিষ্ট্য উপর বন্টন কী মনে করেন Y হয়েছে আছে। আমরা নামক থাকেন তাহলে jds_i.sample([5, 3]) , তারপর আমরা জানি m এবং b উভয় আকৃতি হবে (5, 3) । কি আকৃতি একটি কল করা উচিত sample উপর Y বন্টন উত্পাদন? সুস্পষ্ট জবাব (5, 3, 7) : প্রতিটি ব্যাচ বিন্দু জন্য, আমরা হিসাবে একই আকার সঙ্গে একটি নমুনা চান X । আমরা TensorFlow এর সম্প্রচার ক্ষমতা ব্যবহার করে, অতিরিক্ত মাত্রা যোগ করে এটি অর্জন করতে পারি:

m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])

উভয় একটি অক্ষ যোগ করার পদ্ধতি m এবং b , আমরা একটি নতুন JDS সংজ্ঞায়িত করতে পারে সমর্থন একাধিক নমুনা:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-1.1133379 ,  0.16390413, -0.24177533],
        [-1.1312429 , -0.6224666 , -1.8182136 ],
        [-0.31343174, -0.32932565,  0.5164407 ],
        [-0.0119963 , -0.9079621 ,  2.3655841 ],
        [-0.26293617,  0.8229698 ,  0.31098196]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-0.02876974,  1.0872147 ,  1.0138507 ],
        [ 0.27367726, -1.331534  , -0.09084719],
        [ 1.3349475 , -0.68765205,  1.680652  ],
        [ 0.75436825,  1.3050154 , -0.9415123 ],
        [-1.2502679 , -0.25730947,  0.74611956]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy=
 array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00,
          -4.8197951e+00, -5.2986512e+00, -6.6931367e+00],
         [ 3.6438566e-01,  1.0067395e+00,  1.4542470e+00,  8.1155670e-01,
           1.8868095e+00,  2.3877139e+00,  1.0195159e+00],
         [-8.3624744e-01,  1.2518480e+00,  1.0943471e+00,  1.3052304e+00,
          -4.5756745e-01, -1.0668410e-01, -7.0669651e-02]],
 
        [[-3.1788960e-01,  9.2615485e-03, -3.0963073e+00, -2.2846246e+00,
          -3.2269263e+00, -6.0213070e+00, -7.4806519e+00],
         [-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00,
          -4.5065498e+00, -5.6719379e+00, -4.8012795e+00],
         [ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00,
          -7.1365709e+00, -9.6198196e+00, -9.7951422e+00]],
 
        [[ 2.0621397e+00,  3.4639853e-01,  7.0252883e-01, -1.4311566e+00,
           3.3790007e+00,  1.1619035e+00, -8.9105040e-01],
         [-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00,
          -2.7150445e+00, -2.4633870e+00, -2.1841538e+00],
         [ 7.7627432e-01,  2.2401071e+00,  3.7601702e+00,  2.4245868e+00,
           4.0690269e+00,  4.0605016e+00,  5.1753912e+00]],
 
        [[ 1.4275590e+00,  3.3346462e+00,  1.5374103e+00, -2.2849756e-01,
           9.1219616e-01, -3.1220305e-01, -3.2643962e-01],
         [-3.1910419e-02, -3.8848895e-01,  9.9946201e-02, -2.3619974e+00,
          -1.8507402e+00, -3.6830821e+00, -5.4907336e+00],
         [-7.1941972e-02,  2.1602919e+00,  4.9575748e+00,  4.2317696e+00,
           9.3528280e+00,  1.0526063e+01,  1.5262107e+01]],
 
        [[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00,
          -3.2361765e+00, -3.3434000e+00, -2.6849220e+00],
         [ 1.5006512e-02, -1.9866472e-01,  7.6781356e-01,  1.6228745e+00,
           1.4191239e+00,  2.6655579e+00,  4.4663467e+00],
         [ 2.6599693e+00,  1.2663836e+00,  1.7162113e+00,  1.4839669e+00,
           2.0559487e+00,  2.5976877e+00,  2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.483114 , -10.139662 , -11.514159 ],
       [-11.656767 , -17.201958 , -12.132455 ],
       [-17.838818 ,  -9.474525 , -11.24898  ],
       [-13.95219  , -12.490049 , -17.123957 ],
       [-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>

একটি অতিরিক্ত চেক হিসাবে, আমরা যাচাই করব যে একটি একক ব্যাচ পয়েন্টের লগ সম্ভাব্যতা আমাদের আগে যা ছিল তার সাথে মেলে:

(jds_ia.log_prob(shaped_sample)[3, 1] -
 jds_i.log_prob([shaped_sample[0][3, 1],
                 shaped_sample[1][3, 1],
                 shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

জয়ের জন্য অটোব্যাচিং

চমৎকার! আমরা এখন JointDistribution একটি সংস্করণ আছে হ্যান্ডলগুলি সব আমাদের desiderata: log_prob আয় ব্যবহার করার একটা ঝোঁক স্কালে ধন্যবাদ tfd.Independent , এবং একাধিক নমুনা এখন কাজ যে আমরা অতিরিক্ত অক্ষ যোগ করে সম্প্রচার স্থির করেছি।

যদি আমি আপনাকে বলি যে একটি সহজ, ভাল উপায় আছে? নেই, এবং এটি বলা হচ্ছে JointDistributionSequentialAutoBatched (JDSAB):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.191533 , -10.43885  , -16.371655 ],
       [-13.292994 , -11.97949  , -16.788685 ],
       [-15.987699 , -13.435732 , -10.6029   ],
       [-10.184758 , -11.969714 , -14.275676 ],
       [-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)>

কিভাবে কাজ করে? আপনাকে প্রচেষ্টা পারে যদিও কোড পড়া একটি গভীর বোঝার জন্য, আমরা একটি সংক্ষিপ্ত যা সর্বাধিক ব্যবহারের ক্ষেত্রে জন্য যথেষ্ট দেব:

  • রিকল আমাদের প্রথম সমস্যা যে জন্য আমাদের ডিস্ট্রিবিউশন ছিল Y ছিল batch_shape=[7] এবং event_shape=[] , এবং আমরা ব্যবহার Independent একটি ইভেন্ট মাত্রার ব্যাচ মাত্রা রূপান্তর করবে। JDSAB উপাদান বিতরণের ব্যাচ আকার উপেক্ষা করে; এর পরিবর্তে এটি মডেল, যা গণ্য করা হয় একটি সামগ্রিক সম্পত্তি হিসেবে ব্যাচ আকৃতি একইরূপে [] (যদি না সেটিং দ্বারা অন্যথায় নিদিষ্ট batch_ndims > 0 )। প্রভাব tfd.Independent ব্যবহার ঘটনা মাত্রা মধ্যে উপাদান ডিস্ট্রিবিউশন সব ব্যাচ মাত্রা রূপান্তর করতে, যেমন আমরা ম্যানুয়ালি উপরে করেনি দেওয়ার সমতুল্য।
  • আমাদের দ্বিতীয় সমস্যার আকার ম্যাসেজ করার প্রয়োজন ছিল m এবং b , যাতে তারা সঙ্গে উপযুক্তভাবে সম্প্রচার পারে X যখন একাধিক নমুনা তৈরি করা। JDSAB সঙ্গে, আপনি একটি মডেল একটি একক নমুনা জেনারেট করতে লিখুন, এবং আমরা "লিফট" সমগ্র মডেল TensorFlow এর ব্যবহার করে অনেকগুলি নমুনা উত্পাদন করতে পারে vectorized_map । (এই বৈশিষ্ট্যটি Jax এর analagous হয় vmap ।)

আরো বিস্তারিত ব্যাচ আকৃতি ইস্যু এক্সপ্লোরিং, আমরা আমাদের মূল "খারাপ" যৌথ বিতরণের ব্যাচ আকার তুলনা করতে পারবেন jds , আমাদের ব্যাচ-সংশোধন ডিস্ট্রিবিউশন jds_i এবং jds_ia , এবং আমাদের autobatched jds_ab :

jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])

আমরা দেখি যে মূল jds বিভিন্ন ব্যাচ আকার সঙ্গে subdistributions হয়েছে। jds_i এবং jds_ia একই (খালি) ব্যাচ আকৃতি সঙ্গে subdistributions তৈরির মাধ্যমে এই সমাধান করুন। jds_ab শুধুমাত্র একটি একক (খালি) ব্যাচ আকৃতি হয়েছে।

এটা তোলে এর মূল্য লক্ষ করেন, JointDistributionSequentialAutoBatched বিনামূল্যে জন্য কিছু অতিরিক্ত সাধারণত্ব উপলব্ধ করা হয়। আমরা covariates করা ধরুন X (এবং, পরোক্ষভাবে, পর্যবেক্ষণ Y ) দ্বি-মাত্রিক:

X = np.arange(14).reshape((2, 7))
X
array([[ 0,  1,  2,  3,  4,  5,  6],
       [ 7,  8,  9, 10, 11, 12, 13]])

আমাদের JointDistributionSequentialAutoBatched কোন পরিবর্তন (আমরা মডেল পুনরায় সংজ্ঞায়িত কারণ আকৃতি প্রয়োজন সঙ্গে কাজ করে X দ্বারা সংরক্ষিত করা হয়েছে jds_ab.log_prob ):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.1813647 , -0.85994506,  0.27593774],
        [-0.73323774,  1.1153806 ,  0.8841938 ],
        [ 0.5127983 , -0.29271227,  0.63733214],
        [ 0.2362284 , -0.919168  ,  1.6648189 ],
        [ 0.26317367,  0.73077047,  2.5395133 ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.09636458,  2.0138032 , -0.5054413 ],
        [ 0.63941646, -1.0785882 , -0.6442188 ],
        [ 1.2310615 , -0.3293852 ,  0.77637213],
        [ 1.2115169 , -0.98906034, -0.07816773],
        [-1.1318136 ,  0.510014  ,  1.036522  ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy=
 array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01,
            8.5992378e-01, -5.3123581e-01,  3.1584005e+00,
            2.9044402e+00],
          [-2.5645006e-01,  3.1554163e-01,  3.1186538e+00,
            1.4272424e+00,  1.2843871e+00,  1.2266440e+00,
            1.2798605e+00]],
 
         [[ 1.5973477e+00, -5.3631151e-01,  6.8143606e-03,
           -1.4910895e+00, -2.1568544e+00, -2.0513713e+00,
           -3.1663666e+00],
          [-4.9448099e+00, -2.8385928e+00, -6.9027486e+00,
           -5.6543546e+00, -7.2378774e+00, -8.1577444e+00,
           -9.3582869e+00]],
 
         [[-2.1233239e+00,  5.8853775e-02,  1.2024102e+00,
            1.6622503e+00, -1.9197327e-01,  1.8647723e+00,
            6.4322817e-01],
          [ 3.7549341e-01,  1.5853541e+00,  2.4594500e+00,
            2.1952972e+00,  1.7517658e+00,  2.9666045e+00,
            2.5468128e+00]]],
 
 
        [[[ 8.9906776e-01,  6.7375046e-01,  7.3354661e-01,
           -9.9894643e-01, -3.4606690e+00, -3.4810467e+00,
           -4.4315586e+00],
          [-3.0670738e+00, -6.3628020e+00, -6.2538433e+00,
           -6.8091092e+00, -7.7134805e+00, -8.6319380e+00,
           -8.6904278e+00]],
 
         [[-2.2462025e+00, -3.3060855e-01,  1.8974400e-01,
            3.1422038e+00,  4.1483402e+00,  3.5642972e+00,
            4.8709240e+00],
          [ 4.7880130e+00,  5.8790064e+00,  9.6695948e+00,
            7.8112822e+00,  1.2022618e+01,  1.2411858e+01,
            1.4323385e+01]],
 
         [[-1.0189297e+00, -7.8115642e-01,  1.6466728e+00,
            8.2378983e-01,  3.0765080e+00,  3.0170646e+00,
            5.1899948e+00],
          [ 6.5285158e+00,  7.8038850e+00,  6.4155884e+00,
            9.0899811e+00,  1.0040427e+01,  9.1404457e+00,
            1.0411951e+01]]],
 
 
        [[[ 4.5557004e-01,  1.4905317e+00,  1.4904103e+00,
            2.9777462e+00,  2.8620450e+00,  3.4745665e+00,
            3.8295493e+00],
          [ 3.9977460e+00,  5.7173767e+00,  7.8421035e+00,
            6.3180594e+00,  6.0838981e+00,  8.2257290e+00,
            9.6548376e+00]],
 
         [[-7.0750320e-01, -3.5972297e-01,  4.3136525e-01,
           -2.3301599e+00, -5.0374687e-01, -2.8338656e+00,
           -3.4453444e+00],
          [-3.1258626e+00, -3.4687450e+00, -1.2045374e+00,
           -4.0196013e+00, -5.8831010e+00, -4.2965469e+00,
           -4.1388311e+00]],
 
         [[ 2.1969774e+00,  2.4614549e+00,  2.2314475e+00,
            1.8392437e+00,  2.8367062e+00,  4.8600502e+00,
            4.2273531e+00],
          [ 6.1879644e+00,  5.1792760e+00,  6.1141996e+00,
            5.6517797e+00,  8.9979610e+00,  7.5938139e+00,
            9.7918644e+00]]],
 
 
        [[[ 1.5249090e+00,  1.1388919e+00,  8.6903995e-01,
            3.0762129e+00,  1.5128503e+00,  3.5204377e+00,
            2.4760864e+00],
          [ 3.4166217e+00,  3.5930209e+00,  3.1694956e+00,
            4.5797420e+00,  4.5271711e+00,  2.8774328e+00,
            4.7288942e+00]],
 
         [[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00,
           -3.8594103e+00, -4.9681158e+00, -6.4256043e+00,
           -5.5345035e+00],
          [-6.4306297e+00, -7.0924540e+00, -8.4075985e+00,
           -1.0417805e+01, -1.1727266e+01, -1.1196255e+01,
           -1.1333830e+01]],
 
         [[-7.0419472e-01,  1.4568675e+00,  3.7946482e+00,
            4.8489718e+00,  6.6498446e+00,  9.0224218e+00,
            1.1153137e+01],
          [ 1.0060651e+01,  1.1998097e+01,  1.5326431e+01,
            1.7957514e+01,  1.8323889e+01,  2.0160881e+01,
            2.1269085e+01]]],
 
 
        [[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01,
            2.3558271e-01, -1.0381399e+00,  1.9387857e+00,
           -3.3694571e-01],
          [ 1.6015106e-01,  1.5284677e+00, -4.8567140e-01,
           -1.7770648e-01,  2.1919653e+00,  1.3015286e+00,
            1.3877077e+00]],
 
         [[ 1.3688663e+00,  2.6602898e+00,  6.6657305e-01,
            4.6554832e+00,  5.7781887e+00,  4.9115267e+00,
            4.8446012e+00],
          [ 5.1983776e+00,  6.2297459e+00,  6.3848300e+00,
            8.4291229e+00,  7.1309576e+00,  1.0395646e+01,
            8.5736713e+00]],
 
         [[ 1.2675294e+00,  5.2844582e+00,  5.1331611e+00,
            8.9993315e+00,  1.0794343e+01,  1.4039831e+01,
            1.5731170e+01],
          [ 1.9084715e+01,  2.2191265e+01,  2.3481146e+01,
            2.5803375e+01,  2.8632090e+01,  3.0234968e+01,
            3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-28.90071 , -23.052422, -19.851362],
       [-19.775568, -25.894997, -20.302256],
       [-21.10754 , -23.667885, -20.973007],
       [-19.249458, -20.87892 , -20.573763],
       [-22.351208, -25.457762, -24.648403]], dtype=float32)>

অন্যদিকে, আমাদের সাবধানে crafted JointDistributionSequential আর কাজ করে:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

try:
  jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]

এই সমস্যার সমাধান করতে, আমরা একটি দ্বিতীয় যোগ আছে চাই tf.newaxis উভয় m এবং b আকৃতি, এবং বৃদ্ধি মেলে reinterpreted_batch_ndims করার কলে 2 Independent । এই ক্ষেত্রে, স্বয়ংক্রিয়-ব্যাচিং মেশিনারিগুলিকে আকৃতির সমস্যাগুলি পরিচালনা করতে দেওয়া ছোট, সহজ এবং আরও ergonomic।

আবার, আমরা লক্ষ করুন যে, যখন এই নোটবুক অন্বেষণ JointDistributionSequentialAutoBatched , অন্যান্য রূপগুলো JointDistribution সমতুল্য আছে AutoBatched । (ব্যবহারকারীদের জন্য JointDistributionCoroutine , JointDistributionCoroutineAutoBatched বাড়তি সুবিধা আছে যা আপনাকে তা নির্দিষ্ট করার আর প্রয়োজন Root নোড; যদি তোমরা আর কখনও ব্যবহৃত থাকেন JointDistributionCoroutine । আপনি নিরাপদে এই বিবৃতি উপেক্ষা করতে পারেন)

সমাপ্তি চিন্তা

এই নোটবুক, আমরা চালু JointDistributionSequentialAutoBatched এবং বিস্তারিতভাবে একটি সহজ উদাহরণ দিয়ে কাজ করেন। আশা করি আপনি TFP আকার এবং অটোব্যাচিং সম্পর্কে কিছু শিখেছেন!