Apprentissage des langues en profondeur sensible à l'incertitude avec BERT-SNGP

Voir sur TensorFlow.org Exécuter dans Google Colab Afficher sur GitHub Télécharger le cahier Voir le modèle TF Hub

Dans le didacticiel SNGP , vous avez appris à créer un modèle SNGP au-dessus d'un réseau résiduel profond pour améliorer sa capacité à quantifier son incertitude. Dans ce didacticiel, vous appliquerez SNGP à une tâche de compréhension du langage naturel (NLU) en le construisant au-dessus d'un encodeur BERT profond pour améliorer la capacité du modèle NLU profond à détecter les requêtes hors de portée.

Concrètement, vous allez :

  • Construisez BERT-SNGP, un modèle BERT augmenté de SNGP.
  • Chargez l'ensemble de données de détection d'intention CLINC Out-of-scope (OOS) .
  • Entraînez le modèle BERT-SNGP.
  • Évaluez les performances du modèle BERT-SNGP en matière d'étalonnage d'incertitude et de détection hors domaine.

Au-delà de CLINC OOS, le modèle SNGP a été appliqué à des ensembles de données à grande échelle tels que la détection de toxicité Jigsaw et à des ensembles de données d'images tels que CIFAR-100 et ImageNet . Pour les résultats de référence du SNGP et d'autres méthodes d'incertitude, ainsi qu'une mise en œuvre de haute qualité avec des scripts de formation/d'évaluation de bout en bout, vous pouvez consulter le benchmark Uncertainty Baselines .

Installer

pip uninstall -y tensorflow tf-text
pip install -U tensorflow-text-nightly
pip install -U tf-nightly
pip install -U tf-models-nightly
import matplotlib.pyplot as plt

import sklearn.metrics
import sklearn.calibration

import tensorflow_hub as hub
import tensorflow_datasets as tfds

import numpy as np
import tensorflow as tf

import official.nlp.modeling.layers as layers
import official.nlp.optimization as optimization
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_addons/utils/ensure_tf_install.py:43: UserWarning: You are currently using a nightly version of TensorFlow (2.9.0-dev20220203). 
TensorFlow Addons offers no support for the nightly versions of TensorFlow. Some things might work, some other might not. 
If you encounter a bug, do not file an issue on GitHub.
  UserWarning,

Ce tutoriel a besoin du GPU pour fonctionner efficacement. Vérifiez si le GPU est disponible.

tf.__version__
'2.9.0-dev20220203'
gpus = tf.config.list_physical_devices('GPU')
gpus
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
assert gpus, """
  No GPU(s) found! This tutorial will take many hours to run without a GPU.

  You may hit this error if the installed tensorflow package is not
  compatible with the CUDA and CUDNN versions."""

Implémentez d'abord un classificateur BERT standard en suivant le tutoriel de classification de texte avec BERT . Nous utiliserons l'encodeur de base BERT et le ClassificationHead intégré comme classificateur.

Modèle BERT standard

Créer un modèle SNGP

Pour implémenter un modèle BERT-SNGP, il vous suffit de remplacer le ClassificationHead par le GaussianProcessClassificationHead intégré. La normalisation spectrale est déjà pré-emballée dans cette tête de classification. Comme dans le didacticiel SNGP , ajoutez un rappel de réinitialisation de covariance au modèle, afin que le modèle réinitialise automatiquement l'estimateur de covariance au début d'une nouvelle époque pour éviter de compter deux fois les mêmes données.

class ResetCovarianceCallback(tf.keras.callbacks.Callback):

  def on_epoch_begin(self, epoch, logs=None):
    """Resets covariance matrix at the begining of the epoch."""
    if epoch > 0:
      self.model.classifier.reset_covariance_matrix()
class SNGPBertClassifier(BertClassifier):

  def make_classification_head(self, num_classes, inner_dim, dropout_rate):
    return layers.GaussianProcessClassificationHead(
        num_classes=num_classes, 
        inner_dim=inner_dim,
        dropout_rate=dropout_rate,
        gp_cov_momentum=-1,
        temperature=30.,
        **self.classifier_kwargs)

  def fit(self, *args, **kwargs):
    """Adds ResetCovarianceCallback to model callbacks."""
    kwargs['callbacks'] = list(kwargs.get('callbacks', []))
    kwargs['callbacks'].append(ResetCovarianceCallback())

    return super().fit(*args, **kwargs)

Charger le jeu de données CLINC OOS

Chargez maintenant l'ensemble de données de détection d'intention CLINC OOS . Cet ensemble de données contient 15 000 requêtes vocales d'utilisateurs collectées sur 150 classes d'intention, il contient également 1 000 phrases hors domaine (OOD) qui ne sont couvertes par aucune des classes connues.

(clinc_train, clinc_test, clinc_test_oos), ds_info = tfds.load(
    'clinc_oos', split=['train', 'test', 'test_oos'], with_info=True, batch_size=-1)

Faites le train et testez les données.

train_examples = clinc_train['text']
train_labels = clinc_train['intent']

# Makes the in-domain (IND) evaluation data.
ind_eval_data = (clinc_test['text'], clinc_test['intent'])

Créez un ensemble de données d'évaluation OOD. Pour cela, combinez les données de test dans le domaine clinc_test et les données hors domaine clinc_test_oos . Nous attribuerons également l'étiquette 0 aux exemples dans le domaine et l'étiquette 1 aux exemples hors domaine.

test_data_size = ds_info.splits['test'].num_examples
oos_data_size = ds_info.splits['test_oos'].num_examples

# Combines the in-domain and out-of-domain test examples.
oos_texts = tf.concat([clinc_test['text'], clinc_test_oos['text']], axis=0)
oos_labels = tf.constant([0] * test_data_size + [1] * oos_data_size)

# Converts into a TF dataset.
ood_eval_dataset = tf.data.Dataset.from_tensor_slices(
    {"text": oos_texts, "label": oos_labels})

Former et évaluer

Configurez d'abord les configurations de formation de base.

TRAIN_EPOCHS = 3
TRAIN_BATCH_SIZE = 32
EVAL_BATCH_SIZE = 256

optimizer = bert_optimizer(learning_rate=1e-4)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = tf.metrics.SparseCategoricalAccuracy()
fit_configs = dict(batch_size=TRAIN_BATCH_SIZE,
                   epochs=TRAIN_EPOCHS,
                   validation_batch_size=EVAL_BATCH_SIZE, 
                   validation_data=ind_eval_data)
sngp_model = SNGPBertClassifier()
sngp_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
sngp_model.fit(train_examples, train_labels, **fit_configs)
Epoch 1/3
469/469 [==============================] - 219s 427ms/step - loss: 1.0725 - sparse_categorical_accuracy: 0.7870 - val_loss: 0.4358 - val_sparse_categorical_accuracy: 0.9380
Epoch 2/3
469/469 [==============================] - 198s 422ms/step - loss: 0.0885 - sparse_categorical_accuracy: 0.9797 - val_loss: 0.2424 - val_sparse_categorical_accuracy: 0.9518
Epoch 3/3
469/469 [==============================] - 199s 424ms/step - loss: 0.0259 - sparse_categorical_accuracy: 0.9951 - val_loss: 0.1927 - val_sparse_categorical_accuracy: 0.9642
<keras.callbacks.History at 0x7fe24c0a7090>

Évaluer les performances OOD

Évaluez dans quelle mesure le modèle peut détecter les requêtes hors domaine inconnues. Pour une évaluation rigoureuse, utilisez le jeu de données d'évaluation OOD ood_eval_dataset créé précédemment.

Calcule les probabilités OOD sous la \(1 - p(x)\), où \(p(x)=softmax(logit(x))\) est la probabilité prédictive.

sngp_probs, ood_labels = oos_predict(sngp_model, ood_eval_dataset)
ood_probs = 1 - sngp_probs

Évaluez maintenant dans quelle mesure le score d'incertitude du modèle ood_probs prédit l'étiquette hors domaine. Calculez d'abord l'aire sous la courbe de rappel de précision (AUPRC) pour la probabilité OOD par rapport à la précision de détection OOD.

precision, recall, _ = sklearn.metrics.precision_recall_curve(ood_labels, ood_probs)
auprc = sklearn.metrics.auc(recall, precision)
print(f'SNGP AUPRC: {auprc:.4f}')
SNGP AUPRC: 0.9039

Cela correspond aux performances SNGP rapportées au benchmark CLINC OOS sous les lignes de base d'incertitude .

Ensuite, examinez la qualité du modèle dans l'étalonnage de l'incertitude , c'est-à-dire si la probabilité prédictive du modèle correspond à sa précision prédictive. Un modèle bien calibré est considéré comme digne de confiance puisque, par exemple, sa probabilité prédictive \(p(x)=0.8\) signifie que le modèle est correct 80 % du temps.

prob_true, prob_pred = sklearn.calibration.calibration_curve(
    ood_labels, ood_probs, n_bins=10, strategy='quantile')
plt.plot(prob_pred, prob_true)

plt.plot([0., 1.], [0., 1.], c='k', linestyle="--")
plt.xlabel('Predictive Probability')
plt.ylabel('Predictive Accuracy')
plt.title('Calibration Plots, SNGP')

plt.show()

png

Ressources et lectures complémentaires