Deep Language Learning com reconhecimento de incertezas com BERT-SNGP

Veja no TensorFlow.org Executar no Google Colab Ver no GitHub Baixar caderno Veja o modelo do TF Hub

No tutorial do SNGP , você aprendeu como construir o modelo SNGP sobre uma rede residual profunda para melhorar sua capacidade de quantificar sua incerteza. Neste tutorial, você aplicará o SNGP a uma tarefa de compreensão de linguagem natural (NLU) construindo-a em cima de um codificador BERT profundo para melhorar a capacidade do modelo NLU profundo de detectar consultas fora do escopo.

Especificamente, você irá:

  • Construa o BERT-SNGP, um modelo de BERT aumentado pelo SNGP.
  • Carregue o conjunto de dados de detecção de intenção CLINC fora do escopo (OOS) .
  • Treine o modelo BERT-SNGP.
  • Avalie o desempenho do modelo BERT-SNGP na calibração de incerteza e detecção fora de domínio.

Além do CLINC OOS, o modelo SNGP foi aplicado a conjuntos de dados de grande escala, como a detecção de toxicidade Jigsaw , e a conjuntos de dados de imagem, como CIFAR-100 e ImageNet . Para resultados de benchmark do SNGP e outros métodos de incerteza, bem como implementação de alta qualidade com scripts de treinamento/avaliação de ponta a ponta, você pode conferir o benchmark Uncertainty Baselines .

Configurar

pip uninstall -y tensorflow tf-text
pip install -U tensorflow-text-nightly
pip install -U tf-nightly
pip install -U tf-models-nightly
import matplotlib.pyplot as plt

import sklearn.metrics
import sklearn.calibration

import tensorflow_hub as hub
import tensorflow_datasets as tfds

import numpy as np
import tensorflow as tf

import official.nlp.modeling.layers as layers
import official.nlp.optimization as optimization
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_addons/utils/ensure_tf_install.py:43: UserWarning: You are currently using a nightly version of TensorFlow (2.9.0-dev20220203). 
TensorFlow Addons offers no support for the nightly versions of TensorFlow. Some things might work, some other might not. 
If you encounter a bug, do not file an issue on GitHub.
  UserWarning,

Este tutorial precisa que a GPU seja executada com eficiência. Verifique se a GPU está disponível.

tf.__version__
'2.9.0-dev20220203'
gpus = tf.config.list_physical_devices('GPU')
gpus
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
assert gpus, """
  No GPU(s) found! This tutorial will take many hours to run without a GPU.

  You may hit this error if the installed tensorflow package is not
  compatible with the CUDA and CUDNN versions."""

Primeiro implemente um classificador BERT padrão seguindo o tutorial classificar texto com BERT . Usaremos o codificador baseado em BERT e o ClassificationHead integrado como classificador.

Modelo BERT padrão

Construir modelo SNGP

Para implementar um modelo BERT-SNGP, você só precisa substituir o ClassificationHead pelo GaussianProcessClassificationHead integrado. A normalização espectral já está pré-empacotada neste cabeçalho de classificação. Como no tutorial do SNGP , adicione um retorno de chamada de redefinição de covariância ao modelo, para que o modelo redefina automaticamente o estimador de covariância no início de uma nova época para evitar contar os mesmos dados duas vezes.

class ResetCovarianceCallback(tf.keras.callbacks.Callback):

  def on_epoch_begin(self, epoch, logs=None):
    """Resets covariance matrix at the begining of the epoch."""
    if epoch > 0:
      self.model.classifier.reset_covariance_matrix()
class SNGPBertClassifier(BertClassifier):

  def make_classification_head(self, num_classes, inner_dim, dropout_rate):
    return layers.GaussianProcessClassificationHead(
        num_classes=num_classes, 
        inner_dim=inner_dim,
        dropout_rate=dropout_rate,
        gp_cov_momentum=-1,
        temperature=30.,
        **self.classifier_kwargs)

  def fit(self, *args, **kwargs):
    """Adds ResetCovarianceCallback to model callbacks."""
    kwargs['callbacks'] = list(kwargs.get('callbacks', []))
    kwargs['callbacks'].append(ResetCovarianceCallback())

    return super().fit(*args, **kwargs)

Carregar conjunto de dados CLINC OOS

Agora carregue o conjunto de dados de detecção de intenção CLINC OOS . Esse conjunto de dados contém 15.000 consultas faladas do usuário coletadas em 150 classes de intenção, ele também contém 1.000 sentenças fora do domínio (OOD) que não são cobertas por nenhuma das classes conhecidas.

(clinc_train, clinc_test, clinc_test_oos), ds_info = tfds.load(
    'clinc_oos', split=['train', 'test', 'test_oos'], with_info=True, batch_size=-1)

Faça o trem e teste os dados.

train_examples = clinc_train['text']
train_labels = clinc_train['intent']

# Makes the in-domain (IND) evaluation data.
ind_eval_data = (clinc_test['text'], clinc_test['intent'])

Crie um conjunto de dados de avaliação OOD. Para isso, combine os dados de teste no domínio clinc_test e os dados fora do domínio clinc_test_oos . Também atribuiremos o rótulo 0 aos exemplos no domínio e o rótulo 1 aos exemplos fora do domínio.

test_data_size = ds_info.splits['test'].num_examples
oos_data_size = ds_info.splits['test_oos'].num_examples

# Combines the in-domain and out-of-domain test examples.
oos_texts = tf.concat([clinc_test['text'], clinc_test_oos['text']], axis=0)
oos_labels = tf.constant([0] * test_data_size + [1] * oos_data_size)

# Converts into a TF dataset.
ood_eval_dataset = tf.data.Dataset.from_tensor_slices(
    {"text": oos_texts, "label": oos_labels})

Treinar e avaliar

Primeiro, configure as configurações básicas de treinamento.

TRAIN_EPOCHS = 3
TRAIN_BATCH_SIZE = 32
EVAL_BATCH_SIZE = 256

optimizer = bert_optimizer(learning_rate=1e-4)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = tf.metrics.SparseCategoricalAccuracy()
fit_configs = dict(batch_size=TRAIN_BATCH_SIZE,
                   epochs=TRAIN_EPOCHS,
                   validation_batch_size=EVAL_BATCH_SIZE, 
                   validation_data=ind_eval_data)
sngp_model = SNGPBertClassifier()
sngp_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
sngp_model.fit(train_examples, train_labels, **fit_configs)
Epoch 1/3
469/469 [==============================] - 219s 427ms/step - loss: 1.0725 - sparse_categorical_accuracy: 0.7870 - val_loss: 0.4358 - val_sparse_categorical_accuracy: 0.9380
Epoch 2/3
469/469 [==============================] - 198s 422ms/step - loss: 0.0885 - sparse_categorical_accuracy: 0.9797 - val_loss: 0.2424 - val_sparse_categorical_accuracy: 0.9518
Epoch 3/3
469/469 [==============================] - 199s 424ms/step - loss: 0.0259 - sparse_categorical_accuracy: 0.9951 - val_loss: 0.1927 - val_sparse_categorical_accuracy: 0.9642
<keras.callbacks.History at 0x7fe24c0a7090>

Avalie o desempenho OOD

Avalie quão bem o modelo pode detectar as consultas fora do domínio desconhecidas. Para uma avaliação rigorosa, use o conjunto de dados de avaliação OOD ood_eval_dataset criado anteriormente.

Calcula as probabilidades OOD como \(1 - p(x)\), onde \(p(x)=softmax(logit(x))\) é a probabilidade preditiva.

sngp_probs, ood_labels = oos_predict(sngp_model, ood_eval_dataset)
ood_probs = 1 - sngp_probs

Agora avalie quão bem a pontuação de incerteza do modelo ood_probs prevê o rótulo fora do domínio. Primeiro, calcule a área sob a curva de recuperação de precisão (AUPRC) para probabilidade de OOD versus precisão de detecção de OOD.

precision, recall, _ = sklearn.metrics.precision_recall_curve(ood_labels, ood_probs)
auprc = sklearn.metrics.auc(recall, precision)
print(f'SNGP AUPRC: {auprc:.4f}')
SNGP AUPRC: 0.9039

Isso corresponde ao desempenho do SNGP relatado no benchmark CLINC OOS em Uncertainty Baselines .

Em seguida, examine a qualidade do modelo na calibração da incerteza , ou seja, se a probabilidade preditiva do modelo corresponde à sua precisão preditiva. Um modelo bem calibrado é considerado confiável, pois, por exemplo, sua probabilidade preditiva \(p(x)=0.8\) significa que o modelo está correto 80% das vezes.

prob_true, prob_pred = sklearn.calibration.calibration_curve(
    ood_labels, ood_probs, n_bins=10, strategy='quantile')
plt.plot(prob_pred, prob_true)

plt.plot([0., 1.], [0., 1.], c='k', linestyle="--")
plt.xlabel('Predictive Probability')
plt.ylabel('Predictive Accuracy')
plt.title('Calibration Plots, SNGP')

plt.show()

png

Recursos e leitura adicional