Ver en TensorFlow.org | Ejecutar en Google Colab | Ver en GitHub | Descargar libreta | Ver modelo TF Hub |
En el tutorial de SNGP , aprendió a construir un modelo SNGP sobre una red residual profunda para mejorar su capacidad de cuantificar su incertidumbre. En este tutorial, aplicará SNGP a una tarea de comprensión del lenguaje natural (NLU) construyéndolo sobre un codificador BERT profundo para mejorar la capacidad del modelo NLU profundo para detectar consultas fuera del alcance.
Específicamente, usted:
- Cree BERT-SNGP, un modelo BERT aumentado con SNGP.
- Cargue el conjunto de datos de detección de intenciones fuera de alcance (OOS) de CLINC .
- Entrena el modelo BERT-SNGP.
- Evalúe el rendimiento del modelo BERT-SNGP en calibración de incertidumbre y detección fuera de dominio.
Más allá de CLINC OOS, el modelo SNGP se ha aplicado a conjuntos de datos a gran escala, como la detección de toxicidad de Jigsaw , y a conjuntos de datos de imágenes, como CIFAR-100 e ImageNet . Para obtener resultados de referencia de SNGP y otros métodos de incertidumbre, así como una implementación de alta calidad con scripts de capacitación/evaluación de extremo a extremo, puede consultar la referencia de referencia de incertidumbre .
Configuración
pip uninstall -y tensorflow tf-text
pip install -U tensorflow-text-nightly
pip install -U tf-nightly
pip install -U tf-models-nightly
import matplotlib.pyplot as plt
import sklearn.metrics
import sklearn.calibration
import tensorflow_hub as hub
import tensorflow_datasets as tfds
import numpy as np
import tensorflow as tf
import official.nlp.modeling.layers as layers
import official.nlp.optimization as optimization
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_addons/utils/ensure_tf_install.py:43: UserWarning: You are currently using a nightly version of TensorFlow (2.9.0-dev20220203). TensorFlow Addons offers no support for the nightly versions of TensorFlow. Some things might work, some other might not. If you encounter a bug, do not file an issue on GitHub. UserWarning,
Este tutorial necesita la GPU para ejecutarse de manera eficiente. Compruebe si la GPU está disponible.
tf.__version__
'2.9.0-dev20220203'
gpus = tf.config.list_physical_devices('GPU')
gpus
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
assert gpus, """
No GPU(s) found! This tutorial will take many hours to run without a GPU.
You may hit this error if the installed tensorflow package is not
compatible with the CUDA and CUDNN versions."""
Primero implemente un clasificador BERT estándar siguiendo el tutorial de clasificación de texto con BERT . Usaremos el codificador basado en BERT y el ClassificationHead
integrado como clasificador.
Modelo BERT estándar
PREPROCESS_HANDLE = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'
MODEL_HANDLE = 'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3'
class BertClassifier(tf.keras.Model):
def __init__(self,
num_classes=150, inner_dim=768, dropout_rate=0.1,
**classifier_kwargs):
super().__init__()
self.classifier_kwargs = classifier_kwargs
# Initiate the BERT encoder components.
self.bert_preprocessor = hub.KerasLayer(PREPROCESS_HANDLE, name='preprocessing')
self.bert_hidden_layer = hub.KerasLayer(MODEL_HANDLE, trainable=True, name='bert_encoder')
# Defines the encoder and classification layers.
self.bert_encoder = self.make_bert_encoder()
self.classifier = self.make_classification_head(num_classes, inner_dim, dropout_rate)
def make_bert_encoder(self):
text_inputs = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
encoder_inputs = self.bert_preprocessor(text_inputs)
encoder_outputs = self.bert_hidden_layer(encoder_inputs)
return tf.keras.Model(text_inputs, encoder_outputs)
def make_classification_head(self, num_classes, inner_dim, dropout_rate):
return layers.ClassificationHead(
num_classes=num_classes,
inner_dim=inner_dim,
dropout_rate=dropout_rate,
**self.classifier_kwargs)
def call(self, inputs, **kwargs):
encoder_outputs = self.bert_encoder(inputs)
classifier_inputs = encoder_outputs['sequence_output']
return self.classifier(classifier_inputs, **kwargs)
Construir modelo SNGP
Para implementar un modelo BERT-SNGP, solo necesita reemplazar el ClassificationHead
con el GaussianProcessClassificationHead
incorporado. La normalización espectral ya está preempaquetada en este cabezal de clasificación. Al igual que en el tutorial de SNGP , agregue una devolución de llamada de restablecimiento de covarianza al modelo, de modo que el modelo restablezca automáticamente el estimador de covarianza al comienzo de una nueva época para evitar contar los mismos datos dos veces.
class ResetCovarianceCallback(tf.keras.callbacks.Callback):
def on_epoch_begin(self, epoch, logs=None):
"""Resets covariance matrix at the begining of the epoch."""
if epoch > 0:
self.model.classifier.reset_covariance_matrix()
class SNGPBertClassifier(BertClassifier):
def make_classification_head(self, num_classes, inner_dim, dropout_rate):
return layers.GaussianProcessClassificationHead(
num_classes=num_classes,
inner_dim=inner_dim,
dropout_rate=dropout_rate,
gp_cov_momentum=-1,
temperature=30.,
**self.classifier_kwargs)
def fit(self, *args, **kwargs):
"""Adds ResetCovarianceCallback to model callbacks."""
kwargs['callbacks'] = list(kwargs.get('callbacks', []))
kwargs['callbacks'].append(ResetCovarianceCallback())
return super().fit(*args, **kwargs)
Cargar conjunto de datos CLINC OOS
Ahora cargue el conjunto de datos de detección de intenciones de CLINC OOS . Este conjunto de datos contiene 15000 consultas habladas de usuarios recopiladas en más de 150 clases de intención, también contiene 1000 oraciones fuera del dominio (OOD) que no están cubiertas por ninguna de las clases conocidas.
(clinc_train, clinc_test, clinc_test_oos), ds_info = tfds.load(
'clinc_oos', split=['train', 'test', 'test_oos'], with_info=True, batch_size=-1)
Haz el tren y prueba los datos.
train_examples = clinc_train['text']
train_labels = clinc_train['intent']
# Makes the in-domain (IND) evaluation data.
ind_eval_data = (clinc_test['text'], clinc_test['intent'])
Cree un conjunto de datos de evaluación OOD. Para esto, combine los datos de prueba en el dominio clinc_test
y los datos fuera del dominio clinc_test_oos
. También asignaremos la etiqueta 0 a los ejemplos dentro del dominio y la etiqueta 1 a los ejemplos fuera del dominio.
test_data_size = ds_info.splits['test'].num_examples
oos_data_size = ds_info.splits['test_oos'].num_examples
# Combines the in-domain and out-of-domain test examples.
oos_texts = tf.concat([clinc_test['text'], clinc_test_oos['text']], axis=0)
oos_labels = tf.constant([0] * test_data_size + [1] * oos_data_size)
# Converts into a TF dataset.
ood_eval_dataset = tf.data.Dataset.from_tensor_slices(
{"text": oos_texts, "label": oos_labels})
Formar y evaluar
Primero configure las configuraciones básicas de entrenamiento.
TRAIN_EPOCHS = 3
TRAIN_BATCH_SIZE = 32
EVAL_BATCH_SIZE = 256
def bert_optimizer(learning_rate,
batch_size=TRAIN_BATCH_SIZE, epochs=TRAIN_EPOCHS,
warmup_rate=0.1):
"""Creates an AdamWeightDecay optimizer with learning rate schedule."""
train_data_size = ds_info.splits['train'].num_examples
steps_per_epoch = int(train_data_size / batch_size)
num_train_steps = steps_per_epoch * epochs
num_warmup_steps = int(warmup_rate * num_train_steps)
# Creates learning schedule.
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=learning_rate,
decay_steps=num_train_steps,
end_learning_rate=0.0)
return optimization.AdamWeightDecay(
learning_rate=lr_schedule,
weight_decay_rate=0.01,
epsilon=1e-6,
exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])
optimizer = bert_optimizer(learning_rate=1e-4)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = tf.metrics.SparseCategoricalAccuracy()
fit_configs = dict(batch_size=TRAIN_BATCH_SIZE,
epochs=TRAIN_EPOCHS,
validation_batch_size=EVAL_BATCH_SIZE,
validation_data=ind_eval_data)
sngp_model = SNGPBertClassifier()
sngp_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
sngp_model.fit(train_examples, train_labels, **fit_configs)
Epoch 1/3 469/469 [==============================] - 219s 427ms/step - loss: 1.0725 - sparse_categorical_accuracy: 0.7870 - val_loss: 0.4358 - val_sparse_categorical_accuracy: 0.9380 Epoch 2/3 469/469 [==============================] - 198s 422ms/step - loss: 0.0885 - sparse_categorical_accuracy: 0.9797 - val_loss: 0.2424 - val_sparse_categorical_accuracy: 0.9518 Epoch 3/3 469/469 [==============================] - 199s 424ms/step - loss: 0.0259 - sparse_categorical_accuracy: 0.9951 - val_loss: 0.1927 - val_sparse_categorical_accuracy: 0.9642 <keras.callbacks.History at 0x7fe24c0a7090>
Evaluar el rendimiento de OOD
Evalúe qué tan bien el modelo puede detectar las consultas desconocidas fuera del dominio. Para una evaluación rigurosa, utilice el conjunto de datos de evaluación OOD ood_eval_dataset
creado anteriormente.
def oos_predict(model, ood_eval_dataset, **model_kwargs):
oos_labels = []
oos_probs = []
ood_eval_dataset = ood_eval_dataset.batch(EVAL_BATCH_SIZE)
for oos_batch in ood_eval_dataset:
oos_text_batch = oos_batch["text"]
oos_label_batch = oos_batch["label"]
pred_logits = model(oos_text_batch, **model_kwargs)
pred_probs_all = tf.nn.softmax(pred_logits, axis=-1)
pred_probs = tf.reduce_max(pred_probs_all, axis=-1)
oos_labels.append(oos_label_batch)
oos_probs.append(pred_probs)
oos_probs = tf.concat(oos_probs, axis=0)
oos_labels = tf.concat(oos_labels, axis=0)
return oos_probs, oos_labels
Calcula las probabilidades OOD como \(1 - p(x)\), donde \(p(x)=softmax(logit(x))\) es la probabilidad predictiva.
sngp_probs, ood_labels = oos_predict(sngp_model, ood_eval_dataset)
ood_probs = 1 - sngp_probs
Ahora evalúe qué tan bien la puntuación de incertidumbre del modelo ood_probs
predice la etiqueta fuera del dominio. Primero calcule el área bajo la curva de recuperación de precisión (AUPRC) para la probabilidad de OOD frente a la precisión de detección de OOD.
precision, recall, _ = sklearn.metrics.precision_recall_curve(ood_labels, ood_probs)
auprc = sklearn.metrics.auc(recall, precision)
print(f'SNGP AUPRC: {auprc:.4f}')
SNGP AUPRC: 0.9039
Esto coincide con el rendimiento de SNGP informado en el punto de referencia CLINC OOS en las líneas de base de incertidumbre .
A continuación, examine la calidad del modelo en la calibración de la incertidumbre , es decir, si la probabilidad predictiva del modelo se corresponde con su precisión predictiva. Un modelo bien calibrado se considera digno de confianza, ya que, por ejemplo, su probabilidad predictiva \(p(x)=0.8\) significa que el modelo es correcto el 80% de las veces.
prob_true, prob_pred = sklearn.calibration.calibration_curve(
ood_labels, ood_probs, n_bins=10, strategy='quantile')
plt.plot(prob_pred, prob_true)
plt.plot([0., 1.], [0., 1.], c='k', linestyle="--")
plt.xlabel('Predictive Probability')
plt.ylabel('Predictive Accuracy')
plt.title('Calibration Plots, SNGP')
plt.show()
Recursos y lecturas adicionales
- Consulte el tutorial de SNGP para ver un recorrido detallado sobre la implementación de SNGP desde cero.
- Consulte las líneas de base de incertidumbre para la implementación del modelo SNGP (y muchos otros métodos de incertidumbre) en una amplia variedad de conjuntos de datos de referencia (p. ej., CIFAR , ImageNet , detección de toxicidad de rompecabezas, etc.).
- Para una comprensión más profunda del método SNGP, consulte el documento Estimación de incertidumbre simple y basada en principios con aprendizaje profundo determinista a través de la conciencia de la distancia .