Глубокое обучение с учетом неопределенности с помощью SNGP

Посмотреть на TensorFlow.org Запустить в Google Colab Посмотреть на GitHub Скачать блокнот

В приложениях ИИ, которые критичны для безопасности (например, принятие медицинских решений и автономное вождение) или где данные по своей природе зашумлены (например, понимание естественного языка), важно, чтобы глубокий классификатор надежно определял свою неопределенность. Глубокий классификатор должен быть в состоянии осознавать свои собственные ограничения и когда он должен передать управление экспертам-людям. В этом руководстве показано, как улучшить возможности глубокого классификатора в количественной оценке неопределенности с помощью метода, называемого спектрально-нормализованным нейронным гауссовским процессом ( SNGP ) .

Основная идея SNGP состоит в том, чтобы улучшить понимание расстояния глубоким классификатором путем применения простых модификаций к сети. Осведомленность модели о расстоянии — это мера того, как ее прогностическая вероятность отражает расстояние между тестовым примером и обучающими данными. Это желательное свойство, которое характерно для вероятностных моделей золотого стандарта (например, гауссовский процесс с ядрами RBF), но отсутствует в моделях с глубокими нейронными сетями. SNGP предоставляет простой способ внедрить это поведение гауссовского процесса в глубокий классификатор, сохраняя при этом его точность прогнозирования.

В этом учебном пособии реализуется модель SNGP на основе глубокой остаточной сети (ResNet) для набора данных о двух лунах и сравнивается ее поверхность неопределенности с двумя другими популярными подходами к неопределенности — отсев методом Монте-Карло и глубокий ансамбль ).

В этом руководстве показана модель SNGP на игрушечном 2D-наборе данных. Пример применения SNGP к реальной задаче понимания естественного языка с использованием базы BERT см. в руководстве по SNGP-BERT . Для высококачественных реализаций модели SNGP (и многих других методов определения неопределенности) в самых разных наборах эталонных данных (например, CIFAR-100 , ImageNet , обнаружение токсичности Jigsaw и т. д.) ознакомьтесь с эталонным тестом Uncertainty Baselines .

О СНГП

Спектрально-нормированный нейронный гауссовский процесс (SNGP) — это простой подход к улучшению качества неопределенности глубокого классификатора при сохранении аналогичного уровня точности и задержки. Учитывая глубокую остаточную сеть, SNGP вносит в модель два простых изменения:

  • Он применяет спектральную нормализацию к скрытым остаточным слоям.
  • Он заменяет плотный выходной слой на слой процесса Гаусса.

СНГП

По сравнению с другими подходами к определению неопределенности (например, методом исключения методом Монте-Карло или глубоким ансамблем) SNGP имеет несколько преимуществ:

  • Он работает для широкого спектра современных архитектур на основе остаточных данных (например, (Wide) ResNet, DenseNet, BERT и т. д.).
  • Это метод одной модели (т. е. не основанный на усреднении по ансамблю). Таким образом, SNGP имеет тот же уровень задержки, что и единая детерминированная сеть, и может быть легко масштабирована для больших наборов данных, таких как классификация ImageNet и Jigsaw Toxic Comment .
  • Он обладает высокой производительностью обнаружения вне домена благодаря свойству распознавания расстояния .

Недостатками этого метода являются:

  • Прогностическая неопределенность SNGP вычисляется с использованием приближения Лапласа . Поэтому теоретически апостериорная неопределенность SNGP отличается от неопределенности точного гауссовского процесса.

  • Обучение SNGP требует шага сброса ковариации в начале новой эпохи. Это может немного усложнить конвейер обучения. В этом руководстве показан простой способ реализовать это с помощью обратных вызовов Keras.

Настраивать

pip install --use-deprecated=legacy-resolver tf-models-official
# refresh pkg_resources so it takes the changes into account.
import pkg_resources
import importlib
importlib.reload(pkg_resources)
<module 'pkg_resources' from '/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/pkg_resources/__init__.py'>
import matplotlib.pyplot as plt
import matplotlib.colors as colors

import sklearn.datasets

import numpy as np
import tensorflow as tf

import official.nlp.modeling.layers as nlp_layers

Определение макросов визуализации

plt.rcParams['figure.dpi'] = 140

DEFAULT_X_RANGE = (-3.5, 3.5)
DEFAULT_Y_RANGE = (-2.5, 2.5)
DEFAULT_CMAP = colors.ListedColormap(["#377eb8", "#ff7f00"])
DEFAULT_NORM = colors.Normalize(vmin=0, vmax=1,)
DEFAULT_N_GRID = 100

Набор данных о двух лунах

Создайте наборы данных для обучения и оценки из набора данных двух лун .

def make_training_data(sample_size=500):
  """Create two moon training dataset."""
  train_examples, train_labels = sklearn.datasets.make_moons(
      n_samples=2 * sample_size, noise=0.1)

  # Adjust data position slightly.
  train_examples[train_labels == 0] += [-0.1, 0.2]
  train_examples[train_labels == 1] += [0.1, -0.2]

  return train_examples, train_labels

Оцените прогностическое поведение модели по всему двумерному входному пространству.

def make_testing_data(x_range=DEFAULT_X_RANGE, y_range=DEFAULT_Y_RANGE, n_grid=DEFAULT_N_GRID):
  """Create a mesh grid in 2D space."""
  # testing data (mesh grid over data space)
  x = np.linspace(x_range[0], x_range[1], n_grid)
  y = np.linspace(y_range[0], y_range[1], n_grid)
  xv, yv = np.meshgrid(x, y)
  return np.stack([xv.flatten(), yv.flatten()], axis=-1)

Чтобы оценить неопределенность модели, добавьте набор данных вне предметной области (OOD), который принадлежит к третьему классу. Модель никогда не видит эти примеры OOD во время обучения.

def make_ood_data(sample_size=500, means=(2.5, -1.75), vars=(0.01, 0.01)):
  return np.random.multivariate_normal(
      means, cov=np.diag(vars), size=sample_size)
# Load the train, test and OOD datasets.
train_examples, train_labels = make_training_data(
    sample_size=500)
test_examples = make_testing_data()
ood_examples = make_ood_data(sample_size=500)

# Visualize
pos_examples = train_examples[train_labels == 0]
neg_examples = train_examples[train_labels == 1]

plt.figure(figsize=(7, 5.5))

plt.scatter(pos_examples[:, 0], pos_examples[:, 1], c="#377eb8", alpha=0.5)
plt.scatter(neg_examples[:, 0], neg_examples[:, 1], c="#ff7f00", alpha=0.5)
plt.scatter(ood_examples[:, 0], ood_examples[:, 1], c="red", alpha=0.1)

plt.legend(["Postive", "Negative", "Out-of-Domain"])

plt.ylim(DEFAULT_Y_RANGE)
plt.xlim(DEFAULT_X_RANGE)

plt.show()

png

Здесь синий и оранжевый представляют положительные и отрицательные классы, а красный цвет представляет данные OOD. Ожидается, что модель, которая дает количественную оценку неопределенности, будет надежной, когда она близка к обучающим данным (т. е \(p(x_{test})\) близок к 0 или 1), и будет неопределенной, когда она находится далеко от областей обучающих данных (т. е \(p(x_{test})\) близок к 0,5). ).

Детерминированная модель

Определить модель

Начните с (базовой) детерминированной модели: многоуровневой остаточной сети (ResNet) с регуляризацией отсева.

В этом руководстве используется 6-слойная сеть ResNet со 128 скрытыми единицами.

resnet_config = dict(num_classes=2, num_layers=6, num_hidden=128)
resnet_model = DeepResNet(**resnet_config)
resnet_model.build((None, 2))
resnet_model.summary()
Model: "deep_res_net"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               multiple                  384       
                                                                 
 dense_1 (Dense)             multiple                  16512     
                                                                 
 dense_2 (Dense)             multiple                  16512     
                                                                 
 dense_3 (Dense)             multiple                  16512     
                                                                 
 dense_4 (Dense)             multiple                  16512     
                                                                 
 dense_5 (Dense)             multiple                  16512     
                                                                 
 dense_6 (Dense)             multiple                  16512     
                                                                 
 dense_7 (Dense)             multiple                  258       
                                                                 
=================================================================
Total params: 99,714
Trainable params: 99,330
Non-trainable params: 384
_________________________________________________________________

Модель поезда

Настройте параметры обучения для использования SparseCategoricalCrossentropy в качестве функции потерь и оптимизатора Адама.

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = tf.keras.metrics.SparseCategoricalAccuracy(),
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

train_config = dict(loss=loss, metrics=metrics, optimizer=optimizer)

Обучите модель на 100 эпох с размером пакета 128.

fit_config = dict(batch_size=128, epochs=100)
resnet_model.compile(**train_config)
resnet_model.fit(train_examples, train_labels, **fit_config)
Epoch 1/100
8/8 [==============================] - 1s 4ms/step - loss: 1.1251 - sparse_categorical_accuracy: 0.5050
Epoch 2/100
8/8 [==============================] - 0s 3ms/step - loss: 0.5538 - sparse_categorical_accuracy: 0.6920
Epoch 3/100
8/8 [==============================] - 0s 3ms/step - loss: 0.2881 - sparse_categorical_accuracy: 0.9160
Epoch 4/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1923 - sparse_categorical_accuracy: 0.9370
Epoch 5/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1550 - sparse_categorical_accuracy: 0.9420
Epoch 6/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1403 - sparse_categorical_accuracy: 0.9450
Epoch 7/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1269 - sparse_categorical_accuracy: 0.9430
Epoch 8/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1208 - sparse_categorical_accuracy: 0.9460
Epoch 9/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1158 - sparse_categorical_accuracy: 0.9510
Epoch 10/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1103 - sparse_categorical_accuracy: 0.9490
Epoch 11/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1051 - sparse_categorical_accuracy: 0.9510
Epoch 12/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1053 - sparse_categorical_accuracy: 0.9510
Epoch 13/100
8/8 [==============================] - 0s 3ms/step - loss: 0.1013 - sparse_categorical_accuracy: 0.9450
Epoch 14/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0967 - sparse_categorical_accuracy: 0.9500
Epoch 15/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0991 - sparse_categorical_accuracy: 0.9530
Epoch 16/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0984 - sparse_categorical_accuracy: 0.9500
Epoch 17/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0982 - sparse_categorical_accuracy: 0.9480
Epoch 18/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0918 - sparse_categorical_accuracy: 0.9510
Epoch 19/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0903 - sparse_categorical_accuracy: 0.9500
Epoch 20/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0883 - sparse_categorical_accuracy: 0.9510
Epoch 21/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0870 - sparse_categorical_accuracy: 0.9530
Epoch 22/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0884 - sparse_categorical_accuracy: 0.9560
Epoch 23/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0850 - sparse_categorical_accuracy: 0.9540
Epoch 24/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0808 - sparse_categorical_accuracy: 0.9580
Epoch 25/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0773 - sparse_categorical_accuracy: 0.9560
Epoch 26/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0801 - sparse_categorical_accuracy: 0.9590
Epoch 27/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0779 - sparse_categorical_accuracy: 0.9580
Epoch 28/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0807 - sparse_categorical_accuracy: 0.9580
Epoch 29/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0820 - sparse_categorical_accuracy: 0.9570
Epoch 30/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0730 - sparse_categorical_accuracy: 0.9600
Epoch 31/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0782 - sparse_categorical_accuracy: 0.9590
Epoch 32/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0704 - sparse_categorical_accuracy: 0.9600
Epoch 33/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0709 - sparse_categorical_accuracy: 0.9610
Epoch 34/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0758 - sparse_categorical_accuracy: 0.9580
Epoch 35/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0702 - sparse_categorical_accuracy: 0.9610
Epoch 36/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0688 - sparse_categorical_accuracy: 0.9600
Epoch 37/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0675 - sparse_categorical_accuracy: 0.9630
Epoch 38/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0636 - sparse_categorical_accuracy: 0.9690
Epoch 39/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0677 - sparse_categorical_accuracy: 0.9610
Epoch 40/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0702 - sparse_categorical_accuracy: 0.9650
Epoch 41/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0614 - sparse_categorical_accuracy: 0.9690
Epoch 42/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0663 - sparse_categorical_accuracy: 0.9680
Epoch 43/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0626 - sparse_categorical_accuracy: 0.9740
Epoch 44/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0590 - sparse_categorical_accuracy: 0.9760
Epoch 45/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0573 - sparse_categorical_accuracy: 0.9780
Epoch 46/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0568 - sparse_categorical_accuracy: 0.9770
Epoch 47/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0595 - sparse_categorical_accuracy: 0.9780
Epoch 48/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0482 - sparse_categorical_accuracy: 0.9840
Epoch 49/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0515 - sparse_categorical_accuracy: 0.9820
Epoch 50/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0525 - sparse_categorical_accuracy: 0.9830
Epoch 51/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0507 - sparse_categorical_accuracy: 0.9790
Epoch 52/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0433 - sparse_categorical_accuracy: 0.9850
Epoch 53/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0511 - sparse_categorical_accuracy: 0.9820
Epoch 54/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0501 - sparse_categorical_accuracy: 0.9820
Epoch 55/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0440 - sparse_categorical_accuracy: 0.9890
Epoch 56/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0438 - sparse_categorical_accuracy: 0.9850
Epoch 57/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0438 - sparse_categorical_accuracy: 0.9880
Epoch 58/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0416 - sparse_categorical_accuracy: 0.9860
Epoch 59/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0479 - sparse_categorical_accuracy: 0.9860
Epoch 60/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0434 - sparse_categorical_accuracy: 0.9860
Epoch 61/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0414 - sparse_categorical_accuracy: 0.9880
Epoch 62/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0402 - sparse_categorical_accuracy: 0.9870
Epoch 63/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0376 - sparse_categorical_accuracy: 0.9890
Epoch 64/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0337 - sparse_categorical_accuracy: 0.9900
Epoch 65/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0309 - sparse_categorical_accuracy: 0.9910
Epoch 66/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0336 - sparse_categorical_accuracy: 0.9910
Epoch 67/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0389 - sparse_categorical_accuracy: 0.9870
Epoch 68/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0333 - sparse_categorical_accuracy: 0.9920
Epoch 69/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0331 - sparse_categorical_accuracy: 0.9890
Epoch 70/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0346 - sparse_categorical_accuracy: 0.9900
Epoch 71/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0367 - sparse_categorical_accuracy: 0.9880
Epoch 72/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0283 - sparse_categorical_accuracy: 0.9920
Epoch 73/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0315 - sparse_categorical_accuracy: 0.9930
Epoch 74/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0271 - sparse_categorical_accuracy: 0.9900
Epoch 75/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0257 - sparse_categorical_accuracy: 0.9920
Epoch 76/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0289 - sparse_categorical_accuracy: 0.9900
Epoch 77/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0264 - sparse_categorical_accuracy: 0.9900
Epoch 78/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0272 - sparse_categorical_accuracy: 0.9910
Epoch 79/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0336 - sparse_categorical_accuracy: 0.9880
Epoch 80/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0249 - sparse_categorical_accuracy: 0.9900
Epoch 81/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0216 - sparse_categorical_accuracy: 0.9930
Epoch 82/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0279 - sparse_categorical_accuracy: 0.9890
Epoch 83/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0261 - sparse_categorical_accuracy: 0.9920
Epoch 84/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0235 - sparse_categorical_accuracy: 0.9920
Epoch 85/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0236 - sparse_categorical_accuracy: 0.9930
Epoch 86/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0219 - sparse_categorical_accuracy: 0.9920
Epoch 87/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0196 - sparse_categorical_accuracy: 0.9920
Epoch 88/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0215 - sparse_categorical_accuracy: 0.9900
Epoch 89/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0223 - sparse_categorical_accuracy: 0.9900
Epoch 90/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0200 - sparse_categorical_accuracy: 0.9950
Epoch 91/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0250 - sparse_categorical_accuracy: 0.9900
Epoch 92/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0160 - sparse_categorical_accuracy: 0.9940
Epoch 93/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0203 - sparse_categorical_accuracy: 0.9930
Epoch 94/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0203 - sparse_categorical_accuracy: 0.9930
Epoch 95/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0172 - sparse_categorical_accuracy: 0.9960
Epoch 96/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0209 - sparse_categorical_accuracy: 0.9940
Epoch 97/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0179 - sparse_categorical_accuracy: 0.9920
Epoch 98/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0195 - sparse_categorical_accuracy: 0.9940
Epoch 99/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0165 - sparse_categorical_accuracy: 0.9930
Epoch 100/100
8/8 [==============================] - 0s 3ms/step - loss: 0.0170 - sparse_categorical_accuracy: 0.9950
<keras.callbacks.History at 0x7ff7ac5c8fd0>

Визуализируйте неопределенность

Теперь визуализируйте предсказания детерминированной модели. Сначала постройте вероятность класса:

\[p(x) = softmax(logit(x))\]

resnet_logits = resnet_model(test_examples)
resnet_probs = tf.nn.softmax(resnet_logits, axis=-1)[:, 0]  # Take the probability for class 0.
_, ax = plt.subplots(figsize=(7, 5.5))

pcm = plot_uncertainty_surface(resnet_probs, ax=ax)

plt.colorbar(pcm, ax=ax)
plt.title("Class Probability, Deterministic Model")

plt.show()

png

На этом графике желтый и фиолетовый цвета — это прогностические вероятности для двух классов. Детерминированная модель хорошо справилась с классификацией двух известных классов (синих и оранжевых) с нелинейной границей решения. Тем не менее, он не распознает расстояние и уверенно классифицирует никогда не встречавшиеся красные примеры вне домена (OOD) как оранжевый класс.

Визуализируйте неопределенность модели путем вычисления предиктивной дисперсии :

\[var(x) = p(x) * (1 - p(x))\]

resnet_uncertainty = resnet_probs * (1 - resnet_probs)
_, ax = plt.subplots(figsize=(7, 5.5))

pcm = plot_uncertainty_surface(resnet_uncertainty, ax=ax)

plt.colorbar(pcm, ax=ax)
plt.title("Predictive Uncertainty, Deterministic Model")

plt.show()

png

На этом графике желтый цвет указывает на высокую неопределенность, а фиолетовый — на низкую неопределенность. Неопределенность детерминированного ResNet зависит только от расстояния тестовых примеров от границы решения. Это приводит к тому, что модель становится слишком самоуверенной, когда находится вне области обучения. В следующем разделе показано, как SNGP ведет себя по-разному в этом наборе данных.

Модель SNGP

Определить модель SNGP

Давайте теперь реализуем модель SNGP. Оба компонента SNGP, SpectralNormalization и RandomFeatureGaussianProcess , доступны на встроенных уровнях tensorflow_model.

СНГП

Рассмотрим эти два компонента более подробно. (Вы также можете перейти к разделу Модель SNGP, чтобы увидеть, как реализована полная модель.)

Оболочка спектральной нормализации

SpectralNormalization — это оболочка слоя Keras. Его можно применить к существующему плотному слою следующим образом:

dense = tf.keras.layers.Dense(units=10)
dense = nlp_layers.SpectralNormalization(dense, norm_multiplier=0.9)

Спектральная нормализация упорядочивает скрытый вес \(W\) , постепенно приближая его спектральную норму (т. е. наибольшее собственное значение \(W\)) к целевому значению norm_multiplier .

Слой Gaussian Process (GP)

RandomFeatureGaussianProcess реализует основанную на случайных признаках аппроксимацию модели гауссовского процесса, которую можно сквозно обучать с помощью глубокой нейронной сети. Под капотом уровень гауссовского процесса реализует двухуровневую сеть:

\[logits(x) = \Phi(x) \beta, \quad \Phi(x)=\sqrt{\frac{2}{M} } * cos(Wx + b)\]

Здесь \(x\) — входные данные, а \(W\) и \(b\) — замороженные веса, инициализированные случайным образом из гауссовского и равномерного распределений соответственно. (Поэтому \(\Phi(x)\) называются «случайными функциями».) \(\beta\) — это обучаемый вес ядра, аналогичный весу плотного слоя.

batch_size = 32
input_dim = 1024
num_classes = 10
gp_layer = nlp_layers.RandomFeatureGaussianProcess(units=num_classes,
                                               num_inducing=1024,
                                               normalize_input=False,
                                               scale_random_features=True,
                                               gp_cov_momentum=-1)

Основными параметрами слоев ГП являются:

  • units : размер выходных логитов.
  • num_inducing : Размер \(M\) скрытого веса \(W\). По умолчанию 1024.
  • normalize_input : применять ли нормализацию слоя ко входу \(x\).
  • scale_random_features : применять ли масштаб \(\sqrt{2/M}\) к скрытому выводу.
  • gp_cov_momentum управляет тем, как вычисляется ковариация модели. Если установлено положительное значение (например, 0,999), ковариационная матрица вычисляется с использованием обновления скользящего среднего на основе импульса (аналогично пакетной нормализации). Если установлено значение -1, ковариационная матрица обновляется без импульса.

Учитывая пакетный ввод с shape (batch_size, input_dim) , уровень GP возвращает тензор logits (shape (batch_size, num_classes) ) для прогнозирования, а также тензор covmat (shape (batch_size, batch_size) ), который является апостериорной ковариационной матрицей пакетные логиты.

embedding = tf.random.normal(shape=(batch_size, input_dim))

logits, covmat = gp_layer(embedding)

Теоретически можно расширить алгоритм для вычисления разных значений дисперсии для разных классов (как представлено в исходной статье SNGP ). Однако это трудно масштабировать для задач с большими объемами вывода (например, ImageNet или языковое моделирование).

Полная модель SNGP

Учитывая базовый класс DeepResNet , модель SNGP можно легко реализовать, изменив скрытый и выходной уровни остаточной сети. Для совместимости с model.fit() API также измените метод call() модели, чтобы он logits только во время обучения.

class DeepResNetSNGP(DeepResNet):
  def __init__(self, spec_norm_bound=0.9, **kwargs):
    self.spec_norm_bound = spec_norm_bound
    super().__init__(**kwargs)

  def make_dense_layer(self):
    """Applies spectral normalization to the hidden layer."""
    dense_layer = super().make_dense_layer()
    return nlp_layers.SpectralNormalization(
        dense_layer, norm_multiplier=self.spec_norm_bound)

  def make_output_layer(self, num_classes):
    """Uses Gaussian process as the output layer."""
    return nlp_layers.RandomFeatureGaussianProcess(
        num_classes, 
        gp_cov_momentum=-1,
        **self.classifier_kwargs)

  def call(self, inputs, training=False, return_covmat=False):
    # Gets logits and covariance matrix from GP layer.
    logits, covmat = super().call(inputs)

    # Returns only logits during training.
    if not training and return_covmat:
      return logits, covmat

    return logits

Используйте ту же архитектуру, что и детерминированная модель.

resnet_config
{'num_classes': 2, 'num_layers': 6, 'num_hidden': 128}
sngp_model = DeepResNetSNGP(**resnet_config)
sngp_model.build((None, 2))
sngp_model.summary()
Model: "deep_res_net_sngp"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_9 (Dense)             multiple                  384       
                                                                 
 spectral_normalization_1 (S  multiple                 16768     
 pectralNormalization)                                           
                                                                 
 spectral_normalization_2 (S  multiple                 16768     
 pectralNormalization)                                           
                                                                 
 spectral_normalization_3 (S  multiple                 16768     
 pectralNormalization)                                           
                                                                 
 spectral_normalization_4 (S  multiple                 16768     
 pectralNormalization)                                           
                                                                 
 spectral_normalization_5 (S  multiple                 16768     
 pectralNormalization)                                           
                                                                 
 spectral_normalization_6 (S  multiple                 16768     
 pectralNormalization)                                           
                                                                 
 random_feature_gaussian_pro  multiple                 1182722   
 cess (RandomFeatureGaussian                                     
 Process)                                                        
                                                                 
=================================================================
Total params: 1,283,714
Trainable params: 101,120
Non-trainable params: 1,182,594
_________________________________________________________________

Реализуйте обратный вызов Keras, чтобы сбросить ковариационную матрицу в начале новой эпохи.

class ResetCovarianceCallback(tf.keras.callbacks.Callback):

  def on_epoch_begin(self, epoch, logs=None):
    """Resets covariance matrix at the begining of the epoch."""
    if epoch > 0:
      self.model.classifier.reset_covariance_matrix()

Добавьте этот обратный вызов в класс модели DeepResNetSNGP .

class DeepResNetSNGPWithCovReset(DeepResNetSNGP):
  def fit(self, *args, **kwargs):
    """Adds ResetCovarianceCallback to model callbacks."""
    kwargs["callbacks"] = list(kwargs.get("callbacks", []))
    kwargs["callbacks"].append(ResetCovarianceCallback())

    return super().fit(*args, **kwargs)

Модель поезда

Используйте tf.keras.model.fit для обучения модели.

sngp_model = DeepResNetSNGPWithCovReset(**resnet_config)
sngp_model.compile(**train_config)
sngp_model.fit(train_examples, train_labels, **fit_config)
Epoch 1/100
8/8 [==============================] - 2s 5ms/step - loss: 0.6223 - sparse_categorical_accuracy: 0.9570
Epoch 2/100
8/8 [==============================] - 0s 4ms/step - loss: 0.5310 - sparse_categorical_accuracy: 0.9980
Epoch 3/100
8/8 [==============================] - 0s 4ms/step - loss: 0.4766 - sparse_categorical_accuracy: 0.9990
Epoch 4/100
8/8 [==============================] - 0s 5ms/step - loss: 0.4346 - sparse_categorical_accuracy: 0.9980
Epoch 5/100
8/8 [==============================] - 0s 5ms/step - loss: 0.4015 - sparse_categorical_accuracy: 0.9980
Epoch 6/100
8/8 [==============================] - 0s 5ms/step - loss: 0.3757 - sparse_categorical_accuracy: 0.9990
Epoch 7/100
8/8 [==============================] - 0s 4ms/step - loss: 0.3525 - sparse_categorical_accuracy: 0.9990
Epoch 8/100
8/8 [==============================] - 0s 4ms/step - loss: 0.3305 - sparse_categorical_accuracy: 0.9990
Epoch 9/100
8/8 [==============================] - 0s 5ms/step - loss: 0.3144 - sparse_categorical_accuracy: 0.9980
Epoch 10/100
8/8 [==============================] - 0s 5ms/step - loss: 0.2975 - sparse_categorical_accuracy: 0.9990
Epoch 11/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2832 - sparse_categorical_accuracy: 0.9990
Epoch 12/100
8/8 [==============================] - 0s 5ms/step - loss: 0.2707 - sparse_categorical_accuracy: 0.9990
Epoch 13/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2568 - sparse_categorical_accuracy: 0.9990
Epoch 14/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2470 - sparse_categorical_accuracy: 0.9970
Epoch 15/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2361 - sparse_categorical_accuracy: 0.9990
Epoch 16/100
8/8 [==============================] - 0s 5ms/step - loss: 0.2271 - sparse_categorical_accuracy: 0.9990
Epoch 17/100
8/8 [==============================] - 0s 5ms/step - loss: 0.2182 - sparse_categorical_accuracy: 0.9990
Epoch 18/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2097 - sparse_categorical_accuracy: 0.9990
Epoch 19/100
8/8 [==============================] - 0s 4ms/step - loss: 0.2018 - sparse_categorical_accuracy: 0.9990
Epoch 20/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1940 - sparse_categorical_accuracy: 0.9980
Epoch 21/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1892 - sparse_categorical_accuracy: 0.9990
Epoch 22/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1821 - sparse_categorical_accuracy: 0.9980
Epoch 23/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1768 - sparse_categorical_accuracy: 0.9990
Epoch 24/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1702 - sparse_categorical_accuracy: 0.9980
Epoch 25/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1664 - sparse_categorical_accuracy: 0.9990
Epoch 26/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1604 - sparse_categorical_accuracy: 0.9990
Epoch 27/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1565 - sparse_categorical_accuracy: 0.9990
Epoch 28/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1517 - sparse_categorical_accuracy: 0.9990
Epoch 29/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1469 - sparse_categorical_accuracy: 0.9990
Epoch 30/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1431 - sparse_categorical_accuracy: 0.9980
Epoch 31/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1385 - sparse_categorical_accuracy: 0.9980
Epoch 32/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1351 - sparse_categorical_accuracy: 0.9990
Epoch 33/100
8/8 [==============================] - 0s 5ms/step - loss: 0.1312 - sparse_categorical_accuracy: 0.9980
Epoch 34/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1289 - sparse_categorical_accuracy: 0.9990
Epoch 35/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1254 - sparse_categorical_accuracy: 0.9980
Epoch 36/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1223 - sparse_categorical_accuracy: 0.9980
Epoch 37/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1180 - sparse_categorical_accuracy: 0.9990
Epoch 38/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1167 - sparse_categorical_accuracy: 0.9990
Epoch 39/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1132 - sparse_categorical_accuracy: 0.9980
Epoch 40/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1110 - sparse_categorical_accuracy: 0.9990
Epoch 41/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1075 - sparse_categorical_accuracy: 0.9990
Epoch 42/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1067 - sparse_categorical_accuracy: 0.9990
Epoch 43/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1034 - sparse_categorical_accuracy: 0.9990
Epoch 44/100
8/8 [==============================] - 0s 4ms/step - loss: 0.1006 - sparse_categorical_accuracy: 0.9990
Epoch 45/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0991 - sparse_categorical_accuracy: 0.9990
Epoch 46/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0963 - sparse_categorical_accuracy: 0.9990
Epoch 47/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0943 - sparse_categorical_accuracy: 0.9980
Epoch 48/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0925 - sparse_categorical_accuracy: 0.9990
Epoch 49/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0905 - sparse_categorical_accuracy: 0.9990
Epoch 50/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0889 - sparse_categorical_accuracy: 0.9990
Epoch 51/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0863 - sparse_categorical_accuracy: 0.9980
Epoch 52/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0847 - sparse_categorical_accuracy: 0.9990
Epoch 53/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0831 - sparse_categorical_accuracy: 0.9980
Epoch 54/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0818 - sparse_categorical_accuracy: 0.9990
Epoch 55/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0799 - sparse_categorical_accuracy: 0.9990
Epoch 56/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0780 - sparse_categorical_accuracy: 0.9990
Epoch 57/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0768 - sparse_categorical_accuracy: 0.9990
Epoch 58/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0751 - sparse_categorical_accuracy: 0.9990
Epoch 59/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0748 - sparse_categorical_accuracy: 0.9990
Epoch 60/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0723 - sparse_categorical_accuracy: 0.9990
Epoch 61/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0712 - sparse_categorical_accuracy: 0.9990
Epoch 62/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0701 - sparse_categorical_accuracy: 0.9990
Epoch 63/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0701 - sparse_categorical_accuracy: 0.9990
Epoch 64/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0683 - sparse_categorical_accuracy: 0.9990
Epoch 65/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0665 - sparse_categorical_accuracy: 0.9990
Epoch 66/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0661 - sparse_categorical_accuracy: 0.9990
Epoch 67/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0636 - sparse_categorical_accuracy: 0.9990
Epoch 68/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0631 - sparse_categorical_accuracy: 0.9990
Epoch 69/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0620 - sparse_categorical_accuracy: 0.9990
Epoch 70/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0606 - sparse_categorical_accuracy: 0.9990
Epoch 71/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0601 - sparse_categorical_accuracy: 0.9980
Epoch 72/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0590 - sparse_categorical_accuracy: 0.9990
Epoch 73/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0586 - sparse_categorical_accuracy: 0.9990
Epoch 74/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0574 - sparse_categorical_accuracy: 0.9990
Epoch 75/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0565 - sparse_categorical_accuracy: 1.0000
Epoch 76/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0559 - sparse_categorical_accuracy: 0.9990
Epoch 77/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0549 - sparse_categorical_accuracy: 0.9990
Epoch 78/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0534 - sparse_categorical_accuracy: 1.0000
Epoch 79/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0532 - sparse_categorical_accuracy: 0.9990
Epoch 80/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0519 - sparse_categorical_accuracy: 1.0000
Epoch 81/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0511 - sparse_categorical_accuracy: 1.0000
Epoch 82/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0508 - sparse_categorical_accuracy: 0.9990
Epoch 83/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0499 - sparse_categorical_accuracy: 1.0000
Epoch 84/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0490 - sparse_categorical_accuracy: 1.0000
Epoch 85/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0490 - sparse_categorical_accuracy: 0.9990
Epoch 86/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0470 - sparse_categorical_accuracy: 1.0000
Epoch 87/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0468 - sparse_categorical_accuracy: 1.0000
Epoch 88/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0468 - sparse_categorical_accuracy: 1.0000
Epoch 89/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0453 - sparse_categorical_accuracy: 1.0000
Epoch 90/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0448 - sparse_categorical_accuracy: 1.0000
Epoch 91/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0441 - sparse_categorical_accuracy: 1.0000
Epoch 92/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0434 - sparse_categorical_accuracy: 1.0000
Epoch 93/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0431 - sparse_categorical_accuracy: 1.0000
Epoch 94/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0424 - sparse_categorical_accuracy: 1.0000
Epoch 95/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0420 - sparse_categorical_accuracy: 1.0000
Epoch 96/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0415 - sparse_categorical_accuracy: 1.0000
Epoch 97/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0409 - sparse_categorical_accuracy: 1.0000
Epoch 98/100
8/8 [==============================] - 0s 4ms/step - loss: 0.0401 - sparse_categorical_accuracy: 1.0000
Epoch 99/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0396 - sparse_categorical_accuracy: 1.0000
Epoch 100/100
8/8 [==============================] - 0s 5ms/step - loss: 0.0392 - sparse_categorical_accuracy: 1.0000
<keras.callbacks.History at 0x7ff7ac0f83d0>

Визуализируйте неопределенность

Сначала вычислите прогнозные логиты и отклонения.

sngp_logits, sngp_covmat = sngp_model(test_examples, return_covmat=True)
sngp_variance = tf.linalg.diag_part(sngp_covmat)[:, None]

Теперь вычислите апостериорную прогностическую вероятность. Классический метод вычисления предсказательной вероятности вероятностной модели заключается в использовании выборки Монте-Карло, т. е.

\[E(p(x)) = \frac{1}{M} \sum_{m=1}^M logit_m(x), \]

где \(M\) — размер выборки, а \(logit_m(x)\) — случайные выборки из апостериорного \(MultivariateNormal\)SNGP ( sngp_logits , sngp_covmat ). Однако этот подход может быть медленным для чувствительных к задержкам приложений, таких как автономное вождение или торги в реальном времени. Вместо этого можно аппроксимировать \(E(p(x))\) с помощью метода среднего поля :

\[E(p(x)) \approx softmax(\frac{logit(x)}{\sqrt{1+ \lambda * \sigma^2(x)} })\]

где \(\sigma^2(x)\) — это дисперсия SNGP, а \(\lambda\) часто выбирается как \(\pi/8\) или \(3/\pi^2\).

sngp_logits_adjusted = sngp_logits / tf.sqrt(1. + (np.pi / 8.) * sngp_variance)
sngp_probs = tf.nn.softmax(sngp_logits_adjusted, axis=-1)[:, 0]

Этот метод среднего поля реализован как встроенная layers.gaussian_process.mean_field_logits :

def compute_posterior_mean_probability(logits, covmat, lambda_param=np.pi / 8.):
  # Computes uncertainty-adjusted logits using the built-in method.
  logits_adjusted = nlp_layers.gaussian_process.mean_field_logits(
      logits, covmat, mean_field_factor=lambda_param)

  return tf.nn.softmax(logits_adjusted, axis=-1)[:, 0]
sngp_logits, sngp_covmat = sngp_model(test_examples, return_covmat=True)
sngp_probs = compute_posterior_mean_probability(sngp_logits, sngp_covmat)

Сводка SNGP

Соберите все вместе. Всю процедуру (обучение, оценку и вычисление неопределенности) можно выполнить всего в пяти строках:

def train_and_test_sngp(train_examples, test_examples):
  sngp_model = DeepResNetSNGPWithCovReset(**resnet_config)

  sngp_model.compile(**train_config)
  sngp_model.fit(train_examples, train_labels, verbose=0, **fit_config)

  sngp_logits, sngp_covmat = sngp_model(test_examples, return_covmat=True)
  sngp_probs = compute_posterior_mean_probability(sngp_logits, sngp_covmat)

  return sngp_probs
sngp_probs = train_and_test_sngp(train_examples, test_examples)

Визуализируйте вероятность класса (слева) и прогностическую неопределенность (справа) модели SNGP.

plot_predictions(sngp_probs, model_name="SNGP")

png

Помните, что на графике вероятности класса (слева) желтый и фиолетовый — это вероятности класса. При приближении к обучающей области данных SNGP правильно классифицирует примеры с высокой достоверностью (т. е. присваивая вероятность, близкую к 0 или 1). Вдали от обучающих данных SNGP постепенно становится менее уверенным, и его прогностическая вероятность становится близкой к 0,5, в то время как (нормализованная) неопределенность модели возрастает до 1.

Сравните это с поверхностью неопределенности детерминированной модели:

plot_predictions(resnet_probs, model_name="Deterministic")

png

Как упоминалось ранее, детерминированная модель не учитывает расстояния . Его неопределенность определяется расстоянием тестового примера от границы решения. Это приводит к тому, что модель выдает слишком самоуверенные прогнозы для примеров вне предметной области (красный).

Сравнение с другими подходами к оценке неопределенности

В этом разделе сравнивается неопределенность SNGP с отсевом методом Монте-Карло и глубоким ансамблем .

Оба этих метода основаны на усреднении методом Монте-Карло нескольких прямых проходов детерминированных моделей. Сначала установите размер ансамбля \(M\).

num_ensemble = 10

Отсев Монте-Карло

Учитывая обученную нейронную сеть со слоями отсева, отсев по методу Монте-Карло вычисляет среднюю прогностическую вероятность

\[E(p(x)) = \frac{1}{M}\sum_{m=1}^M softmax(logit_m(x))\]

путем усреднения нескольких проходов вперед с включенным Dropout \(\{logit_m(x)\}_{m=1}^M\).

def mc_dropout_sampling(test_examples):
  # Enable dropout during inference.
  return resnet_model(test_examples, training=True)
# Monte Carlo dropout inference.
dropout_logit_samples = [mc_dropout_sampling(test_examples) for _ in range(num_ensemble)]
dropout_prob_samples = [tf.nn.softmax(dropout_logits, axis=-1)[:, 0] for dropout_logits in dropout_logit_samples]
dropout_probs = tf.reduce_mean(dropout_prob_samples, axis=0)
dropout_probs = tf.reduce_mean(dropout_prob_samples, axis=0)
plot_predictions(dropout_probs, model_name="MC Dropout")

png

Глубокий ансамбль

Глубокий ансамбль — это современный (но дорогой) метод для глубокого обучения неопределенности. Чтобы обучить ансамбль Deep, сначала обучите членов ансамбля \(M\) .

# Deep ensemble training
resnet_ensemble = []
for _ in range(num_ensemble):
  resnet_model = DeepResNet(**resnet_config)
  resnet_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
  resnet_model.fit(train_examples, train_labels, verbose=0, **fit_config)  

  resnet_ensemble.append(resnet_model)

Соберите логиты и вычислите среднюю предсказательную вероятность \(E(p(x)) = \frac{1}{M}\sum_{m=1}^M softmax(logit_m(x))\).

# Deep ensemble inference
ensemble_logit_samples = [model(test_examples) for model in resnet_ensemble]
ensemble_prob_samples = [tf.nn.softmax(logits, axis=-1)[:, 0] for logits in ensemble_logit_samples]
ensemble_probs = tf.reduce_mean(ensemble_prob_samples, axis=0)
plot_predictions(ensemble_probs, model_name="Deep ensemble")

png

И MC Dropout, и глубокий ансамбль улучшают способность модели к неопределенности, делая границу решения менее определенной. Тем не менее, они оба наследуют ограничение детерминированной глубокой сети, заключающееся в отсутствии понимания расстояния.

Резюме

В этом уроке у вас есть:

  • Реализована модель SNGP для глубокого классификатора, чтобы улучшить его понимание расстояния.
  • Обучил модель SNGP от начала до конца, используя model.fit() API.
  • Визуализировано поведение неопределенности SNGP.
  • Сравнили поведение неопределенности между моделями SNGP, отсева методом Монте-Карло и моделями глубокого ансамбля.

Ресурсы и дополнительная литература