Apprendimento approfondito delle lingue consapevole dell'incertezza con BERT-SNGP

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza su GitHub Scarica quaderno Vedi modello TF Hub

Nel tutorial SNGP , hai imparato come costruire il modello SNGP su una rete residua profonda per migliorare la sua capacità di quantificare la sua incertezza. In questo tutorial, applicherai SNGP a un'attività di comprensione del linguaggio naturale (NLU) costruendolo su un codificatore BERT profondo per migliorare la capacità del modello NLU profondo di rilevare query fuori ambito.

Nello specifico dovrai:

  • Costruisci BERT-SNGP, un modello BERT potenziato con SNGP.
  • Carica il set di dati di rilevamento dell'intento fuori ambito (OOS) CLINC .
  • Allena il modello BERT-SNGP.
  • Valuta le prestazioni del modello BERT-SNGP nella calibrazione dell'incertezza e nel rilevamento fuori dominio.

Oltre a CLINC OOS, il modello SNGP è stato applicato a set di dati su larga scala come il rilevamento della tossicità di Jigsaw e ai set di dati di immagini come CIFAR-100 e ImageNet . Per i risultati del benchmark di SNGP e altri metodi di incertezza, nonché per un'implementazione di alta qualità con script di formazione/valutazione end-to-end, puoi consultare il benchmark Uncertainty Baselines .

Impostare

pip uninstall -y tensorflow tf-text
pip install -U tensorflow-text-nightly
pip install -U tf-nightly
pip install -U tf-models-nightly
import matplotlib.pyplot as plt

import sklearn.metrics
import sklearn.calibration

import tensorflow_hub as hub
import tensorflow_datasets as tfds

import numpy as np
import tensorflow as tf

import official.nlp.modeling.layers as layers
import official.nlp.optimization as optimization
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_addons/utils/ensure_tf_install.py:43: UserWarning: You are currently using a nightly version of TensorFlow (2.9.0-dev20220203). 
TensorFlow Addons offers no support for the nightly versions of TensorFlow. Some things might work, some other might not. 
If you encounter a bug, do not file an issue on GitHub.
  UserWarning,

Questo tutorial richiede che la GPU funzioni in modo efficiente. Controlla se la GPU è disponibile.

tf.__version__
'2.9.0-dev20220203'
gpus = tf.config.list_physical_devices('GPU')
gpus
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
assert gpus, """
  No GPU(s) found! This tutorial will take many hours to run without a GPU.

  You may hit this error if the installed tensorflow package is not
  compatible with the CUDA and CUDNN versions."""

Innanzitutto implementa un classificatore BERT standard seguendo il testo di classificazione con il tutorial BERT . Useremo l'encoder BERT-base e il ClassificationHead integrato come classificatore.

Modello BERT standard

Costruisci il modello SNGP

Per implementare un modello BERT-SNGP, devi solo sostituire ClassificationHead con GaussianProcessClassificationHead integrato. La normalizzazione spettrale è già preconfezionata in questa testata di classificazione. Come nel tutorial SNGP , aggiungi un callback di ripristino della covarianza al modello, in modo che il modello reimposti automaticamente lo stimatore di covarianza all'inizio di una nuova epoca per evitare di contare due volte gli stessi dati.

class ResetCovarianceCallback(tf.keras.callbacks.Callback):

  def on_epoch_begin(self, epoch, logs=None):
    """Resets covariance matrix at the begining of the epoch."""
    if epoch > 0:
      self.model.classifier.reset_covariance_matrix()
class SNGPBertClassifier(BertClassifier):

  def make_classification_head(self, num_classes, inner_dim, dropout_rate):
    return layers.GaussianProcessClassificationHead(
        num_classes=num_classes, 
        inner_dim=inner_dim,
        dropout_rate=dropout_rate,
        gp_cov_momentum=-1,
        temperature=30.,
        **self.classifier_kwargs)

  def fit(self, *args, **kwargs):
    """Adds ResetCovarianceCallback to model callbacks."""
    kwargs['callbacks'] = list(kwargs.get('callbacks', []))
    kwargs['callbacks'].append(ResetCovarianceCallback())

    return super().fit(*args, **kwargs)

Carica il set di dati CLINC OOS

Ora carica il set di dati di rilevamento dell'intento CLINC OOS . Questo set di dati contiene 15000 query vocali dell'utente raccolte su 150 classi di intenti, contiene anche 1000 frasi fuori dominio (OOD) che non sono coperte da nessuna delle classi note.

(clinc_train, clinc_test, clinc_test_oos), ds_info = tfds.load(
    'clinc_oos', split=['train', 'test', 'test_oos'], with_info=True, batch_size=-1)

Crea il treno e prova i dati.

train_examples = clinc_train['text']
train_labels = clinc_train['intent']

# Makes the in-domain (IND) evaluation data.
ind_eval_data = (clinc_test['text'], clinc_test['intent'])

Crea un set di dati di valutazione OOD. A tale scopo, combina i dati di test interni al dominio clinc_test e i dati esterni al dominio clinc_test_oos . Assegneremo anche l'etichetta 0 agli esempi all'interno del dominio e l'etichetta 1 agli esempi al di fuori del dominio.

test_data_size = ds_info.splits['test'].num_examples
oos_data_size = ds_info.splits['test_oos'].num_examples

# Combines the in-domain and out-of-domain test examples.
oos_texts = tf.concat([clinc_test['text'], clinc_test_oos['text']], axis=0)
oos_labels = tf.constant([0] * test_data_size + [1] * oos_data_size)

# Converts into a TF dataset.
ood_eval_dataset = tf.data.Dataset.from_tensor_slices(
    {"text": oos_texts, "label": oos_labels})

Allenati e valuta

Per prima cosa, imposta le configurazioni di formazione di base.

TRAIN_EPOCHS = 3
TRAIN_BATCH_SIZE = 32
EVAL_BATCH_SIZE = 256

optimizer = bert_optimizer(learning_rate=1e-4)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = tf.metrics.SparseCategoricalAccuracy()
fit_configs = dict(batch_size=TRAIN_BATCH_SIZE,
                   epochs=TRAIN_EPOCHS,
                   validation_batch_size=EVAL_BATCH_SIZE, 
                   validation_data=ind_eval_data)
sngp_model = SNGPBertClassifier()
sngp_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
sngp_model.fit(train_examples, train_labels, **fit_configs)
Epoch 1/3
469/469 [==============================] - 219s 427ms/step - loss: 1.0725 - sparse_categorical_accuracy: 0.7870 - val_loss: 0.4358 - val_sparse_categorical_accuracy: 0.9380
Epoch 2/3
469/469 [==============================] - 198s 422ms/step - loss: 0.0885 - sparse_categorical_accuracy: 0.9797 - val_loss: 0.2424 - val_sparse_categorical_accuracy: 0.9518
Epoch 3/3
469/469 [==============================] - 199s 424ms/step - loss: 0.0259 - sparse_categorical_accuracy: 0.9951 - val_loss: 0.1927 - val_sparse_categorical_accuracy: 0.9642
<keras.callbacks.History at 0x7fe24c0a7090>

Valuta le prestazioni OOD

Valuta quanto bene il modello è in grado di rilevare le query fuori dominio sconosciute. Per una valutazione rigorosa, usa il set di dati di valutazione OOD ood_eval_dataset creato in precedenza.

Calcola le probabilità OOD come \(1 - p(x)\), dove \(p(x)=softmax(logit(x))\) è la probabilità predittiva.

sngp_probs, ood_labels = oos_predict(sngp_model, ood_eval_dataset)
ood_probs = 1 - sngp_probs

Ora valuta quanto bene il punteggio di incertezza del modello ood_probs predice l'etichetta fuori dominio. Per prima cosa calcola l'area sotto la curva di richiamo di precisione (AURPC) per la probabilità OOD rispetto all'accuratezza del rilevamento OOD.

precision, recall, _ = sklearn.metrics.precision_recall_curve(ood_labels, ood_probs)
auprc = sklearn.metrics.auc(recall, precision)
print(f'SNGP AUPRC: {auprc:.4f}')
SNGP AUPRC: 0.9039

Ciò corrisponde alle prestazioni SNGP riportate al benchmark CLINC OOS in base alle linee di base dell'incertezza .

Successivamente, esaminare la qualità del modello nella calibrazione dell'incertezza , ovvero se la probabilità predittiva del modello corrisponde alla sua accuratezza predittiva. Un modello ben calibrato è considerato affidabile, poiché, ad esempio, la sua probabilità predittiva \(p(x)=0.8\) significa che il modello è corretto l'80% delle volte.

prob_true, prob_pred = sklearn.calibration.calibration_curve(
    ood_labels, ood_probs, n_bins=10, strategy='quantile')
plt.plot(prob_pred, prob_true)

plt.plot([0., 1.], [0., 1.], c='k', linestyle="--")
plt.xlabel('Predictive Probability')
plt.ylabel('Predictive Accuracy')
plt.title('Calibration Plots, SNGP')

plt.show()

png

Risorse e ulteriori letture