在诸如医疗决策制定和自动驾驶等安全性至关重要的 AI 应用中,或者在数据存在固有噪声的情况(例如自然语言理解)下,深度分类器必须能够可靠地量化其不确定性。深度分类器应能感知到自身的局限性,并能意识到何时应将控制权交给人类专家。本教程展示了如何使用称为谱归一化神经高斯过程 (SNGP{.external}) 的技术来提高深度分类器量化不确定性的能力。
SNGP 的核心思想是通过对网络应用简单的修改来提高深度分类器的距离感知。模型的距离感知是衡量其预测概率能否准确反映测试样本与训练数据间距离的指标。这是黄金标准下的概率模型(例如,具有 RBF 内核的高斯过程{.external})常需的属性,但在具有深度神经网络的模型中却缺少这一属性。SNGP 提供了一种将这种高斯过程行为注入到深度分类器中,同时能够保持其预测准确率的简单方式。
本教程在 scikit-learn 的双月{.external} 数据集上实现了一个基于深度残差网络 (ResNet) 的 SNGP 模型,并将其不确定性表面与其他两种热门不确定性方式的不确定性表面进行比较:蒙特卡罗随机失活{.external} 和深度集成{.external}。
本教程例举了基于小型 2D 数据集的 SNGP 模型。有关使用 BERT-base 将 SNGP 应用于现实世界自然语言理解任务的示例,请参阅 SNGP-BERT 教程。有关基于各种基准数据集(例如 CIFAR-100、ImageNet、Jigsaw 恶意检测等)的 SNGP 模型(和许多其他不确定性方法)的高质量实现方式,请参阅不确定性基线{.external} 基准。
关于 SNGP
SNGP 是一种可提高深度分类器的不确定性质量,同时能够保持相似准确率和延迟水平的简单方式。给定一个深度残差网络,SNGP 即会对模型进行两项简单更改:
- 将谱归一化应用于隐藏的残差层。
- 将 Dense 输出层替换为高斯过程层。
与其他不确定性方式(例如蒙特卡罗随机失活或深度集成)相比,SNGP 具有以下几项优点:
- 适用于各种最先进的基于残差的架构(例如 (Wide) ResNet、DenseNet 或 BERT)。
- 是一种单模型方法(不依赖于集合平均)。因此,SNGP 具有与单一确定性网络相似的延迟水平,并且可以轻松扩展至大型数据集,如 ImageNet{.external} 和 Jigsaw 恶意评论分类{.external}。
- 距离感知属性使之具有强大的域外检测性能。
这种方法的缺点为:
SNGP 的预测不确定性是使用拉普拉斯近似{.external}计算的。因此在理论上,SNGP 的后验不确定性与精确高斯过程的后验不确定性不同。
SNGP 训练需要在新周期开始时进行协方差重置步骤。这会对训练流水线额外增添些许复杂性。本教程展示了一种使用 Keras 回调实现此功能的简单方式。
安装
pip install -U -q --use-deprecated=legacy-resolver tf-models-official tensorflow
# refresh pkg_resources so it takes the changes into account.
import pkg_resources
import importlib
importlib.reload(pkg_resources)
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import sklearn.datasets
import numpy as np
import tensorflow as tf
import official.nlp.modeling.layers as nlp_layers
定义呈现宏
plt.rcParams['figure.dpi'] = 140
DEFAULT_X_RANGE = (-3.5, 3.5)
DEFAULT_Y_RANGE = (-2.5, 2.5)
DEFAULT_CMAP = colors.ListedColormap(["#377eb8", "#ff7f00"])
DEFAULT_NORM = colors.Normalize(vmin=0, vmax=1,)
DEFAULT_N_GRID = 100
双月数据集
从 scikit-learn 双月数据集{.external} 创建训练数据集和评估数据集。
def make_training_data(sample_size=500):
"""Create two moon training dataset."""
train_examples, train_labels = sklearn.datasets.make_moons(
n_samples=2 * sample_size, noise=0.1)
# Adjust data position slightly.
train_examples[train_labels == 0] += [-0.1, 0.2]
train_examples[train_labels == 1] += [0.1, -0.2]
return train_examples, train_labels
评估模型在整个二维输入空间上的预测行为。
def make_testing_data(x_range=DEFAULT_X_RANGE, y_range=DEFAULT_Y_RANGE, n_grid=DEFAULT_N_GRID):
"""Create a mesh grid in 2D space."""
# testing data (mesh grid over data space)
x = np.linspace(x_range[0], x_range[1], n_grid)
y = np.linspace(y_range[0], y_range[1], n_grid)
xv, yv = np.meshgrid(x, y)
return np.stack([xv.flatten(), yv.flatten()], axis=-1)
要评估模型不确定性,请添加属于第三类的域外 (OOD) 数据集。该模型在训练期间从不观测这些 OOD 样本。
def make_ood_data(sample_size=500, means=(2.5, -1.75), vars=(0.01, 0.01)):
return np.random.multivariate_normal(
means, cov=np.diag(vars), size=sample_size)
# Load the train, test and OOD datasets.
train_examples, train_labels = make_training_data(
sample_size=500)
test_examples = make_testing_data()
ood_examples = make_ood_data(sample_size=500)
# Visualize
pos_examples = train_examples[train_labels == 0]
neg_examples = train_examples[train_labels == 1]
plt.figure(figsize=(7, 5.5))
plt.scatter(pos_examples[:, 0], pos_examples[:, 1], c="#377eb8", alpha=0.5)
plt.scatter(neg_examples[:, 0], neg_examples[:, 1], c="#ff7f00", alpha=0.5)
plt.scatter(ood_examples[:, 0], ood_examples[:, 1], c="red", alpha=0.1)
plt.legend(["Positive", "Negative", "Out-of-Domain"])
plt.ylim(DEFAULT_Y_RANGE)
plt.xlim(DEFAULT_X_RANGE)
plt.show()
这里,蓝色和橙色代表正负类,红色代表 OOD 数据。能够准确量化不确定性的模型在接近训练数据时(即 接近 0 或 1)应达到较高的置信度,而在远离训练数据区域时(即 接近 0.5)则不确定性较高。
确定性模型
定义模型
从(基线)确定性模型开始:具有随机失活正则化的多层残差网络 (ResNet)。
class DeepResNet(tf.keras.Model):
"""Defines a multi-layer residual network."""
def __init__(self, num_classes, num_layers=3, num_hidden=128,
dropout_rate=0.1, **classifier_kwargs):
super().__init__()
# Defines class meta data.
self.num_hidden = num_hidden
self.num_layers = num_layers
self.dropout_rate = dropout_rate
self.classifier_kwargs = classifier_kwargs
# Defines the hidden layers.
self.input_layer = tf.keras.layers.Dense(self.num_hidden, trainable=False)
self.dense_layers = [self.make_dense_layer() for _ in range(num_layers)]
# Defines the output layer.
self.classifier = self.make_output_layer(num_classes)
def call(self, inputs):
# Projects the 2d input data to high dimension.
hidden = self.input_layer(inputs)
# Computes the ResNet hidden representations.
for i in range(self.num_layers):
resid = self.dense_layers[i](hidden)
resid = tf.keras.layers.Dropout(self.dropout_rate)(resid)
hidden += resid
return self.classifier(hidden)
def make_dense_layer(self):
"""Uses the Dense layer as the hidden layer."""
return tf.keras.layers.Dense(self.num_hidden, activation="relu")
def make_output_layer(self, num_classes):
"""Uses the Dense layer as the output layer."""
return tf.keras.layers.Dense(
num_classes, **self.classifier_kwargs)
本教程使用具有 128 个隐藏单元的六层 ResNet。
resnet_config = dict(num_classes=2, num_layers=6, num_hidden=128)
resnet_model = DeepResNet(**resnet_config)
resnet_model.build((None, 2))
resnet_model.summary()
训练模型
配置训练参数以使用 SparseCategoricalCrossentropy
作为损失函数和 Adam 优化器。
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = tf.keras.metrics.SparseCategoricalAccuracy(),
optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=1e-4)
train_config = dict(loss=loss, metrics=metrics, optimizer=optimizer)
以 128 为批次大小对模型训练 100 个周期。
fit_config = dict(batch_size=128, epochs=100)
resnet_model.compile(**train_config)
resnet_model.fit(train_examples, train_labels, **fit_config)
呈现不确定性
def plot_uncertainty_surface(test_uncertainty, ax, cmap=None):
"""Visualizes the 2D uncertainty surface.
For simplicity, assume these objects already exist in the memory:
test_examples: Array of test examples, shape (num_test, 2).
train_labels: Array of train labels, shape (num_train, ).
train_examples: Array of train examples, shape (num_train, 2).
Arguments:
test_uncertainty: Array of uncertainty scores, shape (num_test,).
ax: A matplotlib Axes object that specifies a matplotlib figure.
cmap: A matplotlib colormap object specifying the palette of the
predictive surface.
Returns:
pcm: A matplotlib PathCollection object that contains the palette
information of the uncertainty plot.
"""
# Normalize uncertainty for better visualization.
test_uncertainty = test_uncertainty / np.max(test_uncertainty)
# Set view limits.
ax.set_ylim(DEFAULT_Y_RANGE)
ax.set_xlim(DEFAULT_X_RANGE)
# Plot normalized uncertainty surface.
pcm = ax.imshow(
np.reshape(test_uncertainty, [DEFAULT_N_GRID, DEFAULT_N_GRID]),
cmap=cmap,
origin="lower",
extent=DEFAULT_X_RANGE + DEFAULT_Y_RANGE,
vmin=DEFAULT_NORM.vmin,
vmax=DEFAULT_NORM.vmax,
interpolation='bicubic',
aspect='auto')
# Plot training data.
ax.scatter(train_examples[:, 0], train_examples[:, 1],
c=train_labels, cmap=DEFAULT_CMAP, alpha=0.5)
ax.scatter(ood_examples[:, 0], ood_examples[:, 1], c="red", alpha=0.1)
return pcm
现在,呈现确定性模型的预测。首先绘制类概率:
resnet_logits = resnet_model(test_examples)
resnet_probs = tf.nn.softmax(resnet_logits, axis=-1)[:, 0] # Take the probability for class 0.
_, ax = plt.subplots(figsize=(7, 5.5))
pcm = plot_uncertainty_surface(resnet_probs, ax=ax)
plt.colorbar(pcm, ax=ax)
plt.title("Class Probability, Deterministic Model")
plt.show()
在此图中,黄色和紫色为这两个类的预测概率。确定性模型在利用非线性决策边界对两个已知类(蓝色和橙色)进行分类方面表现优秀。然而,它并不具备距离感知,并且会以较高的置信度将从未观测到的红色域外 (OOD) 样本分类为橙色类。
通过计算预测方差来呈现模型的不确定性:\(var(x) = p(x) * (1 - p(x))\)
resnet_uncertainty = resnet_probs * (1 - resnet_probs)
_, ax = plt.subplots(figsize=(7, 5.5))
pcm = plot_uncertainty_surface(resnet_uncertainty, ax=ax)
plt.colorbar(pcm, ax=ax)
plt.title("Predictive Uncertainty, Deterministic Model")
plt.show()
在此图中,黄色表示高不确定性,紫色表示低不确定性。确定性 ResNet 的不确定性仅取决于测试样本与决策边界之间的距离。这会导致模型在超出训练域时出现置信度过度的问题。下一部分将展示 SNGP 在此数据集上的行为方式有何不同。
SNGP 模型
定义 SNGP 模型
现在,我们来实现 SNGP 模型。SNGP 组件 SpectralNormalization
和 RandomFeatureGaussianProcess
均在 tensorflow_model 的内置层中可用。
让我们更加详细地检查这两个组件。(您也可以跳转到完整 SNGP 模型部分以了解 SNGP 的实现方法。)
SpectralNormalization
封装容器
SpectralNormalization
{.external} 是 Keras 层封装容器。它能够以如下方式应用于现有的 Dense 层:
dense = tf.keras.layers.Dense(units=10)
dense = nlp_layers.SpectralNormalization(dense, norm_multiplier=0.9)
谱归一化会通过将其谱范数(即 的最大特征值)朝目标值 norm_multiplier
逐渐引导来正则化隐藏权重 。
注:通常情况下,最好将 norm_multiplier
设置为小于 1 的值。但在实践中,也可以将其放宽为更大的值,以确保深度网络具有足够的表达力。
高斯过程 (GP) 层
RandomFeatureGaussianProcess
{.external} 可对能够通过深度神经网络进行端到端训练的高斯过程模型实现基于随机特征的近似{.external}。从底层来看,高斯过程层实现了一个两层网络:
$$logits(x) = \Phi(x) \beta, \quad \Phi(x)=\sqrt{\frac{2}{M} } * cos(Wx + b)$$
Here, is the input, and and are frozen weights initialized randomly from Gaussian and Uniform distributions, respectively. (Therefore, are called "random features".) is the learnable kernel weight similar to that of a Dense layer.
batch_size = 32
input_dim = 1024
num_classes = 10
gp_layer = nlp_layers.RandomFeatureGaussianProcess(units=num_classes,
num_inducing=1024,
normalize_input=False,
scale_random_features=True,
gp_cov_momentum=-1)
GP 层的主要参数包括:
units
:输出 logit 的维度。num_inducing
:隐藏权重 的维度 。默认值为 1024。normalize_input
:是否对输入 应用层归一化。scale_random_features
:是否将缩放 应用于隐藏输出。
注:对于对学习率较为敏感的深度神经网络(例如 ResNet-50 和 ResNet-110),一般建议设置 normalize_input=True
以提高训练稳定性,并设置 scale_random_features=False
以避免在通过 GP 层时学习率被意外修改。
gp_cov_momentum
可以控制如何计算模型协方差。如果设置为正值(例如0.999
),将使用基于动量的移动平均值更新(类似于批归一化)计算协方差矩阵。如果设置为-1
,则协方差矩阵将在无动量情况下更新。
注:基于动量的更新方法可能会对批次大小较为敏感。因此,通常建议设置 gp_cov_momentum=-1
以精确计算协方差。为了使其正常工作,协方差矩阵 estimator 需要在新周期开始时重置,以避免重复计算相同的数据。对于 RandomFeatureGaussianProcess
,这可以通过调用其 reset_covariance_matrix()
来实现。下一部分展示了使用 Keras 内置 API 的简单实现。
给定一个形状为 (batch_size, input_dim)
的批次输入,GP 层会返回 logits
张量(形状为 (batch_size, num_classes)
)用于预测;以及 covmat
张量(形状为 (batch_size, batch_size)
),它是批次 logit 的后验协方差矩阵。
embedding = tf.random.normal(shape=(batch_size, input_dim))
logits, covmat = gp_layer(embedding)
注:请注意,在 SNGP 模型的这种实现方式下,所有类的预测 logit 都会共享相同的协方差矩阵 ,后者描述了 与训练数据之间的距离。
理论上讲,可以扩展算法来为不同类计算不同的方差值(如原始 SNGP 论文{.external}中所介绍)。但是,这很难扩展到具有大输出空间的问题(例如使用 ImageNet 或语言建模的分类)。
完整 SNGP 模型
给定基类 DeepResNet
,即可通过修改残差网络的隐藏层和输出层来轻松实现 SNGP 模型。为了与 Keras model.fit()
API 兼容,还需修改模型的 call()
方法,使其仅在训练期间输出 logits
。
class DeepResNetSNGP(DeepResNet):
def __init__(self, spec_norm_bound=0.9, **kwargs):
self.spec_norm_bound = spec_norm_bound
super().__init__(**kwargs)
def make_dense_layer(self):
"""Applies spectral normalization to the hidden layer."""
dense_layer = super().make_dense_layer()
return nlp_layers.SpectralNormalization(
dense_layer, norm_multiplier=self.spec_norm_bound)
def make_output_layer(self, num_classes):
"""Uses Gaussian process as the output layer."""
return nlp_layers.RandomFeatureGaussianProcess(
num_classes,
gp_cov_momentum=-1,
**self.classifier_kwargs)
def call(self, inputs, training=False, return_covmat=False):
# Gets logits and a covariance matrix from the GP layer.
logits, covmat = super().call(inputs)
# Returns only logits during training.
if not training and return_covmat:
return logits, covmat
return logits
使用与确定性模型相同的架构。
resnet_config
sngp_model = DeepResNetSNGP(**resnet_config)
sngp_model.build((None, 2))
sngp_model.summary()
class ResetCovarianceCallback(tf.keras.callbacks.Callback):
def on_epoch_begin(self, epoch, logs=None):
"""Resets covariance matrix at the beginning of the epoch."""
if epoch > 0:
self.model.classifier.reset_covariance_matrix()
将此回调添加到 DeepResNetSNGP
模型类。
class DeepResNetSNGPWithCovReset(DeepResNetSNGP):
def fit(self, *args, **kwargs):
"""Adds ResetCovarianceCallback to model callbacks."""
kwargs["callbacks"] = list(kwargs.get("callbacks", []))
kwargs["callbacks"].append(ResetCovarianceCallback())
return super().fit(*args, **kwargs)
训练模型
使用 tf.keras.model.fit
训练模型。
sngp_model = DeepResNetSNGPWithCovReset(**resnet_config)
sngp_model.compile(**train_config)
sngp_model.fit(train_examples, train_labels, **fit_config)
呈现不确定性
首先,计算预测 logit 和方差。
sngp_logits, sngp_covmat = sngp_model(test_examples, return_covmat=True)
sngp_variance = tf.linalg.diag_part(sngp_covmat)[:, None]
现在,计算后验预测概率。计算概率模型预测概率的经典方法是使用蒙特卡罗采样法,即:
$$E(p(x)) = \frac{1}{M} \sum_{m=1}^M logit_m(x), $$
其中 为样本大小, 为来自 SNGP 后验 (sngp_logits
,sngp_covmat
) 的随机样本。但是,这种方式对于延迟敏感型应用(例如自动驾驶或实时竞价)而言,速度可能较慢。相反,您可以使用平均场法{.external}来逼近 :
$$E(p(x)) \approx softmax(\frac{logit(x)}{\sqrt{1+ \lambda * \sigma^2(x)} })$$
where is the SNGP variance, and is often chosen as or .
sngp_logits_adjusted = sngp_logits / tf.sqrt(1. + (np.pi / 8.) * sngp_variance)
sngp_probs = tf.nn.softmax(sngp_logits_adjusted, axis=-1)[:, 0]
注:除了将 限定为固定值之外,您还可以将其视为超参数,并对其进行调整以优化模型的校准性能。这在深度学习不确定性文献中被称为温度缩放{.external}。
这种平均场方法以内置函数 layers.gaussian_process.mean_field_logits
形式实现:
def compute_posterior_mean_probability(logits, covmat, lambda_param=np.pi / 8.):
# Computes uncertainty-adjusted logits using the built-in method.
logits_adjusted = nlp_layers.gaussian_process.mean_field_logits(
logits, covmat, mean_field_factor=lambda_param)
return tf.nn.softmax(logits_adjusted, axis=-1)[:, 0]
sngp_logits, sngp_covmat = sngp_model(test_examples, return_covmat=True)
sngp_probs = compute_posterior_mean_probability(sngp_logits, sngp_covmat)
SNGP 摘要
def plot_predictions(pred_probs, model_name=""):
"""Plot normalized class probabilities and predictive uncertainties."""
# Compute predictive uncertainty.
uncertainty = pred_probs * (1. - pred_probs)
# Initialize the plot axes.
fig, axs = plt.subplots(1, 2, figsize=(14, 5))
# Plots the class probability.
pcm_0 = plot_uncertainty_surface(pred_probs, ax=axs[0])
# Plots the predictive uncertainty.
pcm_1 = plot_uncertainty_surface(uncertainty, ax=axs[1])
# Adds color bars and titles.
fig.colorbar(pcm_0, ax=axs[0])
fig.colorbar(pcm_1, ax=axs[1])
axs[0].set_title(f"Class Probability, {model_name}")
axs[1].set_title(f"(Normalized) Predictive Uncertainty, {model_name}")
plt.show()
您现在可以将所有内容归总到一起。整个过程(训练、评估和不确定性计算)只需五行即可完成:
def train_and_test_sngp(train_examples, test_examples):
sngp_model = DeepResNetSNGPWithCovReset(**resnet_config)
sngp_model.compile(**train_config)
sngp_model.fit(train_examples, train_labels, verbose=0, **fit_config)
sngp_logits, sngp_covmat = sngp_model(test_examples, return_covmat=True)
sngp_probs = compute_posterior_mean_probability(sngp_logits, sngp_covmat)
return sngp_probs
sngp_probs = train_and_test_sngp(train_examples, test_examples)
呈现 SNGP 模型的类概率(左)和预测不确定性(右)。
plot_predictions(sngp_probs, model_name="SNGP")
请记住,在类概率图(左)中,黄色和紫色为类概率。当接近训练数据域时,SNGP 会以较高的置信度正确分类样本(即,分配接近 0 或 1 的概率)。当远离训练数据时,SNGP 的置信度会逐渐下降,其预测概率接近 0.5,而(归一化)模型不确定性上升到 1。
将此与确定性模型的不确定性表面进行比较:
plot_predictions(resnet_probs, model_name="Deterministic")
如上文所述,确定性模型不具备距离感知。它的不确定性会由测试样本与决策边界之间的距离定义。这会导致模型对域外样本(红色)产生置信度过高的预测。
与其他不确定性方式的比较
本部分将对 SNGP 的不确定性与蒙特卡罗随机失活{.external}和深度集成{.external}进行比较。
这两种方法均基于确定性模型多个前向传递的蒙特卡罗平均算法。首先,设置集合大小 。
num_ensemble = 10
蒙特卡罗随机失活
给定具有随机失活层的经训练的神经网络,蒙特卡罗随机失活会计算平均预测概率
$$E(p(x)) = \frac{1}{M}\sum_{m=1}^M softmax(logit_m(x))$$
by averaging over multiple Dropout-enabled forward passes .
def mc_dropout_sampling(test_examples):
# Enable dropout during inference.
return resnet_model(test_examples, training=True)
# Monte Carlo dropout inference.
dropout_logit_samples = [mc_dropout_sampling(test_examples) for _ in range(num_ensemble)]
dropout_prob_samples = [tf.nn.softmax(dropout_logits, axis=-1)[:, 0] for dropout_logits in dropout_logit_samples]
dropout_probs = tf.reduce_mean(dropout_prob_samples, axis=0)
dropout_probs = tf.reduce_mean(dropout_prob_samples, axis=0)
plot_predictions(dropout_probs, model_name="MC Dropout")
深度集成
深度集成是一种用于深度学习不确定性的最先进(但耗费算力)的方法。要训练深度集成,首先需要训练 个集合成员。
# Deep ensemble training
resnet_ensemble = []
for _ in range(num_ensemble):
resnet_model = DeepResNet(**resnet_config)
resnet_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
resnet_model.fit(train_examples, train_labels, verbose=0, **fit_config)
resnet_ensemble.append(resnet_model)
收集 logit 并计算平均预测概率 。
# Deep ensemble inference
ensemble_logit_samples = [model(test_examples) for model in resnet_ensemble]
ensemble_prob_samples = [tf.nn.softmax(logits, axis=-1)[:, 0] for logits in ensemble_logit_samples]
ensemble_probs = tf.reduce_mean(ensemble_prob_samples, axis=0)
plot_predictions(ensemble_probs, model_name="Deep ensemble")
蒙特卡罗随机失活和深度集成方法都会通过降低决策边界的确定性来提高模型的不确定性能力。然而,二者均继承了确定性深度网络在缺乏距离感知方面的局限性。
总结
在本教程中,您已:
- 在深度分类器上实现了 SNGP 模型以提高其距离感知能力。
- 使用 Keras
Model.fit
API 端到端地训练了 SNGP 模型。 - 呈现了 SNGP 的不确定性行为。
- 比较了 SNGP、蒙特卡罗随机失活和深度集成模型之间的不确定性行为。
资源和延伸阅读
- 请参阅 SNGP-BERT 教程以查看在 BERT 模型上应用 SNGP 以实现不确定性感知型自然语言理解的示例。
- 请转到不确定性基线 GitHub 仓库{.external}以查看在各种基准数据集(例如,CIFAR、ImageNet、Jigsaw 恶意检测等)上实现 SNGP 模型(和许多其他不确定性方法)的方式。
- 如需更深入地了解 SNGP 方法,请参阅题为 Simple and Principled Uncertainty Estimation with Deterministic Deep Learning via Distance Awareness{.external} 的论文。