Посмотреть на TensorFlow.org | Запускаем в Google Colab | Посмотреть исходный код на GitHub | Скачать блокнот |
Этот ноутбук переопределяет и расширяет «анализ точки Изменить» байесовский пример из документации pymc3 .
Предпосылки
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (15,8)
%config InlineBackend.figure_format = 'retina'
import numpy as np
import pandas as pd
Набор данных
Набор данных из здесь . Обратите внимание, есть другой вариант этого примера плавает , но он « не хватает» данных - в этом случае вы должны были бы вменить пропущенные значения. (В противном случае ваша модель никогда не оставит свои начальные параметры, потому что функция правдоподобия будет неопределенной.)
disaster_data = np.array([ 4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
2, 2, 3, 4, 2, 1, 3, 2, 2, 1, 1, 1, 1, 3, 0, 0,
1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
3, 3, 1, 1, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1])
years = np.arange(1851, 1962)
plt.plot(years, disaster_data, 'o', markersize=8);
plt.ylabel('Disaster count')
plt.xlabel('Year')
plt.title('Mining disaster data set')
plt.show()
Вероятностная модель
Модель предполагает «точку переключения» (например, год, в течение которого менялись правила безопасности) и распределенную по Пуассону частоту бедствий с постоянными (но потенциально разными) скоростями до и после этой точки переключения.
Фактическое количество стихийных бедствий фиксировано (наблюдается); любой образец этой модели должен будет указать как точку переключения, так и «раннюю» и «позднюю» частоту бедствий.
Оригинальная модель из примера pymc3 документации :
\[ \begin{align*} (D_t|s,e,l)&\sim \text{Poisson}(r_t), \\ & \,\quad\text{with}\; r_t = \begin{cases}e & \text{if}\; t < s\\l &\text{if}\; t \ge s\end{cases} \\ s&\sim\text{Discrete Uniform}(t_l,\,t_h) \\ e&\sim\text{Exponential}(r_e)\\ l&\sim\text{Exponential}(r_l) \end{align*} \]
Тем не менее, средняя скорость аварийного \(r_t\) имеет разрыв в точке переключения \(s\), что делает его не дифференцируемо. Таким образом, не дает градиент сигнала алгоритму гамильтонов Монт - Карло (HMC) - а потому , что \(s\) перед непрерывен, запасным вариантом HMC для случайного блуждания является достаточно хорошим , чтобы найти области с высокой вероятностью массы в этом примере.
В качестве второй модели, модифицировать исходную модель с использованием сигмовидного «переключателя» между е и л , чтобы сделать переход дифференцируем, и использовать непрерывное равномерное распределение для точки переключения \(s\). (Можно утверждать, что эта модель более соответствует действительности, поскольку «переключение» средней скорости, вероятно, будет растянутым на несколько лет.) Новая модель, таким образом:
\[ \begin{align*} (D_t|s,e,l)&\sim\text{Poisson}(r_t), \\ & \,\quad \text{with}\; r_t = e + \frac{1}{1+\exp(s-t)}(l-e) \\ s&\sim\text{Uniform}(t_l,\,t_h) \\ e&\sim\text{Exponential}(r_e)\\ l&\sim\text{Exponential}(r_l) \end{align*} \]
При отсутствии дополнительной информации , которую мы предполагаем \(r_e = r_l = 1\) в качестве параметров для настоятелей. Мы запустим обе модели и сравним их результаты вывода.
def disaster_count_model(disaster_rate_fn):
disaster_count = tfd.JointDistributionNamed(dict(
e=tfd.Exponential(rate=1.),
l=tfd.Exponential(rate=1.),
s=tfd.Uniform(0., high=len(years)),
d_t=lambda s, l, e: tfd.Independent(
tfd.Poisson(rate=disaster_rate_fn(np.arange(len(years)), s, l, e)),
reinterpreted_batch_ndims=1)
))
return disaster_count
def disaster_rate_switch(ys, s, l, e):
return tf.where(ys < s, e, l)
def disaster_rate_sigmoid(ys, s, l, e):
return e + tf.sigmoid(ys - s) * (l - e)
model_switch = disaster_count_model(disaster_rate_switch)
model_sigmoid = disaster_count_model(disaster_rate_sigmoid)
Приведенный выше код определяет модель через распределения JointDistributionSequential. В disaster_rate
функций вызываются с массивом [0, ..., len(years)-1]
для получения вектора len(years)
случайные величины - года перед порогом switchpoint
являются early_disaster_rate
, те , после late_disaster_rate
( по модулю сигмовидный переход).
Вот проверка работоспособности функции проверки целевого журнала:
def target_log_prob_fn(model, s, e, l):
return model.log_prob(s=s, e=e, l=l, d_t=disaster_data)
models = [model_switch, model_sigmoid]
print([target_log_prob_fn(m, 40., 3., .9).numpy() for m in models]) # Somewhat likely result
print([target_log_prob_fn(m, 60., 1., 5.).numpy() for m in models]) # Rather unlikely result
print([target_log_prob_fn(m, -10., 1., 1.).numpy() for m in models]) # Impossible result
[-176.94559, -176.28717] [-371.3125, -366.8816] [-inf, -inf]
HMC делает байесовский вывод
Мы определяем количество результатов и требуемых шагов приработки; код в основном по образцу документации tfp.mcmc.HamiltonianMonteCarlo . Он использует адаптивный размер шага (в противном случае результат очень чувствителен к выбранному значению размера шага). Мы используем значение, равное единице, в качестве начального состояния цепочки.
Однако это не вся история. Если вы вернетесь к приведенному выше определению модели, вы заметите, что некоторые распределения вероятностей не определены четко на всей прямой действительных чисел. Поэтому мы ограничиваем пространство , которое HMC должен изучить обертывание ядра HMC с TransformedTransitionKernel , которое указует вперед bijectors для преобразования вещественных чисел на области , что распределение вероятностей определяются на (см комментариев в коде ниже).
num_results = 10000
num_burnin_steps = 3000
@tf.function(autograph=False, jit_compile=True)
def make_chain(target_log_prob_fn):
kernel = tfp.mcmc.TransformedTransitionKernel(
inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
target_log_prob_fn=target_log_prob_fn,
step_size=0.05,
num_leapfrog_steps=3),
bijector=[
# The switchpoint is constrained between zero and len(years).
# Hence we supply a bijector that maps the real numbers (in a
# differentiable way) to the interval (0;len(yers))
tfb.Sigmoid(low=0., high=tf.cast(len(years), dtype=tf.float32)),
# Early and late disaster rate: The exponential distribution is
# defined on the positive real numbers
tfb.Softplus(),
tfb.Softplus(),
])
kernel = tfp.mcmc.SimpleStepSizeAdaptation(
inner_kernel=kernel,
num_adaptation_steps=int(0.8*num_burnin_steps))
states = tfp.mcmc.sample_chain(
num_results=num_results,
num_burnin_steps=num_burnin_steps,
current_state=[
# The three latent variables
tf.ones([], name='init_switchpoint'),
tf.ones([], name='init_early_disaster_rate'),
tf.ones([], name='init_late_disaster_rate'),
],
trace_fn=None,
kernel=kernel)
return states
switch_samples = [s.numpy() for s in make_chain(
lambda *args: target_log_prob_fn(model_switch, *args))]
sigmoid_samples = [s.numpy() for s in make_chain(
lambda *args: target_log_prob_fn(model_sigmoid, *args))]
switchpoint, early_disaster_rate, late_disaster_rate = zip(
switch_samples, sigmoid_samples)
Запустите обе модели параллельно:
Визуализируйте результат
Мы визуализируем результат в виде гистограмм выборок апостериорного распределения для ранней и поздней частоты бедствий, а также точки переключения. Гистограммы перекрываются сплошной линией, представляющей медианное значение выборки, а также границы 95% вероятного интервала в виде пунктирных линий.
def _desc(v):
return '(median: {}; 95%ile CI: $[{}, {}]$)'.format(
*np.round(np.percentile(v, [50, 2.5, 97.5]), 2))
for t, v in [
('Early disaster rate ($e$) posterior samples', early_disaster_rate),
('Late disaster rate ($l$) posterior samples', late_disaster_rate),
('Switch point ($s$) posterior samples', years[0] + switchpoint),
]:
fig, ax = plt.subplots(nrows=1, ncols=2, sharex=True)
for (m, i) in (('Switch', 0), ('Sigmoid', 1)):
a = ax[i]
a.hist(v[i], bins=50)
a.axvline(x=np.percentile(v[i], 50), color='k')
a.axvline(x=np.percentile(v[i], 2.5), color='k', ls='dashed', alpha=.5)
a.axvline(x=np.percentile(v[i], 97.5), color='k', ls='dashed', alpha=.5)
a.set_title(m + ' model ' + _desc(v[i]))
fig.suptitle(t)
plt.show()