This colab demonstrates how to:
- Load BERT models from TensorFlow Hub that have been trained on different tasks including MNLI, SQuAD, and PubMed
- Use a matching preprocessing model to tokenize raw text and convert it to ids
- Generate the pooled and sequence output from the token input ids using the loaded model
- Look at the semantic similarity of the pooled outputs of different sentences
Note: This colab should be run with a GPU runtime.
Set up and imports
pip install --quiet "tensorflow-text==2.11.*"
import seaborn as sns
from sklearn.metrics import pairwise
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text # Imports TF ops for preprocessing.
Configure the model
BERT_MODEL = "https://tfhub.dev/google/experts/bert/wiki_books/2" # @param {type: "string"} ["https://tfhub.dev/google/experts/bert/wiki_books/2", "https://tfhub.dev/google/experts/bert/wiki_books/mnli/2", "https://tfhub.dev/google/experts/bert/wiki_books/qnli/2", "https://tfhub.dev/google/experts/bert/wiki_books/qqp/2", "https://tfhub.dev/google/experts/bert/wiki_books/squad2/2", "https://tfhub.dev/google/experts/bert/wiki_books/sst2/2", "https://tfhub.dev/google/experts/bert/pubmed/2", "https://tfhub.dev/google/experts/bert/pubmed/squad2/2"]
# Preprocessing must match the model, but all the models above use the same preprocessing.
PREPROCESS_MODEL = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
Sentences
Let's take some sentences from Wikipedia to run through the model.
sentences = [
"Here We Go Then, You And I is a 1999 album by Norwegian pop artist Morten Abel. It was Abel's second CD as a solo artist.",
"The album went straight to number one on the Norwegian album chart, and sold to double platinum.",
"Among the singles released from the album were the songs \"Be My Lover\" and \"Hard To Stay Awake\".",
"Riccardo Zegna is an Italian jazz musician.",
"Rajko Maksimović is a composer, writer, and music pedagogue.",
"One of the most significant Serbian composers of our time, Maksimović has been and remains active in creating works for different ensembles.",
"Ceylon spinach is a common name for several plants and may refer to: Basella alba Talinum fruticosum",
"A solar eclipse occurs when the Moon passes between Earth and the Sun, thereby totally or partly obscuring the image of the Sun for a viewer on Earth.",
"A partial solar eclipse occurs in the polar regions of the Earth when the center of the Moon's shadow misses the Earth.",
]
Run the model
We'll load the BERT model from TF-Hub, tokenize our sentences using the matching preprocessing model from TF-Hub, then feed in the tokenized sentences to the model. To keep this colab fast and simple, we recommend running on GPU.
Go to Runtime → Change runtime type to make sure that GPU is selected.
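Before loading the models, it can help to confirm that TensorFlow actually sees a GPU. The following is a minimal sketch, not part of the original notebook: `gpu_available` is a hypothetical helper name, and the `importlib` guard is only there so the snippet also runs in an environment where TensorFlow is not installed.

```python
import importlib.util


def gpu_available():
    """Return True if TensorFlow is installed and reports at least one GPU.

    Hypothetical helper for illustration; the notebook itself does not
    define it.
    """
    # Guard so the check degrades gracefully when TensorFlow is missing.
    if importlib.util.find_spec("tensorflow") is None:
        return False
    import tensorflow as tf

    # list_physical_devices returns a (possibly empty) list of devices.
    return len(tf.config.list_physical_devices("GPU")) > 0


if gpu_available():
    print("GPU runtime detected.")
else:
    print("No GPU detected; the model will still run, but more slowly on CPU.")
```

If this prints the CPU message in Colab, switch the runtime type as described above and rerun the setup cells.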
preprocess = hub.load(PREPROCESS_MODEL)
bert = hub.load(BERT_MODEL)
inputs = preprocess(sentences)
outputs = bert(inputs)
print("Sentences:")
print(sentences)
print("\nBERT inputs:")
print(inputs)
print("\nPooled embeddings:")
print(outputs["pooled_output"])
print("\nPer token embeddings:")
print(outputs["sequence_output"])
Sentences:
["Here We Go Then, You And I is a 1999 album by Norwegian pop artist Morten Abel. It was Abel's second CD as a solo artist.", 'The album went straight to number one on the Norwegian album chart, and sold to double platinum.', 'Among the singles released from the album were the songs "Be My Lover" and "Hard To Stay Awake".', 'Riccardo Zegna is an Italian jazz musician.', 'Rajko Maksimović is a composer, writer, and music pedagogue.', 'One of the most significant Serbian composers of our time, Maksimović has been and remains active in creating works for different ensembles.', 'Ceylon spinach is a common name for several plants and may refer to: Basella alba Talinum fruticosum', 'A solar eclipse occurs when the Moon passes between Earth and the Sun, thereby totally or partly obscuring the image of the Sun for a viewer on Earth.', "A partial solar eclipse occurs in the polar regions of the Earth when the center of the Moon's shadow misses the Earth."]

BERT inputs:
{'input_word_ids': <tf.Tensor: shape=(9, 128), dtype=int32, numpy=
array([[ 101, 2182, 2057, ..., 0, 0, 0],
       [ 101, 1996, 2201, ..., 0, 0, 0],
       [ 101, 2426, 1996, ..., 0, 0, 0],
       ...,
       [ 101, 16447, 6714, ..., 0, 0, 0],
       [ 101, 1037, 5943, ..., 0, 0, 0],
       [ 101, 1037, 7704, ..., 0, 0, 0]], dtype=int32)>, 'input_mask': <tf.Tensor: shape=(9, 128), dtype=int32, numpy=
array([[1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       ...,
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0]], dtype=int32)>, 'input_type_ids': <tf.Tensor: shape=(9, 128), dtype=int32, numpy=
array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=int32)>}

Pooled embeddings:
tf.Tensor(
[[ 0.7975988 -0.48580432 0.49781707 ... -0.34488383 0.39727724 -0.2063961 ]
 [ 0.5712041 -0.412053 0.7048913 ... -0.35185084 0.19032447 -0.40419042]
 [-0.6993834 0.1586692 0.06569834 ... -0.06232304 -0.8155023 -0.0792362 ]
 ...
 [-0.35727224 0.77089787 0.15756573 ... 0.44185558 -0.86448216 0.04505127]
 [ 0.9107703 0.4150136 0.56063557 ... -0.49263772 0.39640468 -0.05036155]
 [ 0.90502924 -0.15505382 0.7267205 ... -0.3473441 0.50526446 -0.19542791]], shape=(9, 768), dtype=float32)

Per token embeddings:
tf.Tensor(
[[[ 1.09197736e+00 -5.30553758e-01 5.46399832e-01 ... -3.59625190e-01 4.20411706e-01 -2.09404156e-01]
  [ 1.01438344e+00 7.80789793e-01 8.53757977e-01 ... 5.52821755e-01 -1.12458134e+00 5.60274601e-01]
  [ 7.88627028e-01 7.77751431e-02 9.51507509e-01 ... -1.90752178e-01 5.92061162e-01 6.19106948e-01]
  ...
  [-3.22031379e-01 -4.25213337e-01 -1.28237948e-01 ... -3.90948206e-01 -7.90973783e-01 4.22366500e-01]
  [-3.10375020e-02 2.39857227e-01 -2.19943717e-01 ... -1.14400357e-01 -1.26804805e+00 -1.61363304e-01]
  [-4.20635521e-01 5.49729943e-01 -3.24445844e-01 ... -1.84786245e-01 -1.13429761e+00 -5.89753427e-02]]

 [[ 6.49308205e-01 -4.38081563e-01 8.76956284e-01 ... -3.67554575e-01 1.92673832e-01 -4.28647608e-01]
  [-1.12487352e+00 2.99315214e-01 1.17996371e+00 ... 4.87293720e-01 5.34003437e-01 2.28362054e-01]
  [-2.70572275e-01 3.23551893e-02 1.04256880e+00 ... 5.89936972e-01 1.53679097e+00 5.84257543e-01]
  ...
  [-1.47625208e+00 1.82393655e-01 5.58771193e-02 ... -1.67332029e+00 -6.73989058e-01 -7.24497080e-01]
  [-1.51381421e+00 5.81847310e-01 1.61419123e-01 ... -1.26408207e+00 -4.02722836e-01 -9.71971810e-01]
  [-4.71528202e-01 2.28173852e-01 5.27762115e-01 ... -7.54835010e-01 -9.09029961e-01 -1.69545799e-01]]

 [[-8.66092384e-01 1.60021260e-01 6.57931119e-02 ... -6.24039248e-02 -1.14323997e+00 -7.94026628e-02]
  [ 7.71179259e-01 7.08046913e-01 1.13500684e-01 ... 7.88312197e-01 -3.14382076e-01 -9.74871039e-01]
  [-4.40023631e-01 -3.00597101e-01 3.54795575e-01 ... 7.97389075e-02 -4.73936439e-01 -1.10018480e+00]
  ...
  [-1.02052820e+00 2.69382656e-01 -4.73106146e-01 ... -6.63192511e-01 -1.45799482e+00 -3.46653670e-01]
  [-9.70035076e-01 -4.50126976e-02 -5.97797990e-01 ... -3.05263996e-01 -1.27442646e+00 -2.80517220e-01]
  [-7.31441319e-01 1.76993489e-01 -4.62580770e-01 ... -1.60622299e-01 -1.63460863e+00 -3.20605874e-01]]

 ...

 [[-3.73755515e-01 1.02253711e+00 1.58889472e-01 ... 4.74534214e-01 -1.31081784e+00 4.50817905e-02]
  [-4.15890545e-01 5.00191152e-01 -4.58437860e-01 ... 4.14820373e-01 -6.20655239e-01 -7.15548456e-01]
  [-1.25043976e+00 5.09365320e-01 -5.71035028e-01 ... 3.54912102e-01 2.43686497e-01 -2.05771971e+00]
  ...
  [ 1.33937567e-01 1.18591630e+00 -2.21703053e-01 ... -8.19470763e-01 -1.67373240e+00 -3.96922529e-01]
  [-3.36620927e-01 1.65561998e+00 -3.78125042e-01 ... -9.67455268e-01 -1.48010266e+00 -8.33307445e-01]
  [-2.26492420e-01 1.61784339e+00 -6.70442462e-01 ... -4.90783870e-01 -1.45356917e+00 -7.17070937e-01]]

 [[ 1.53202403e+00 4.41652954e-01 6.33759618e-01 ... -5.39537430e-01 4.19376075e-01 -5.04041985e-02]
  [ 8.93777490e-01 8.93953502e-01 3.06282490e-02 ... 5.90415969e-02 -2.06496209e-01 -8.48113477e-01]
  [-1.85601413e-02 1.04790521e+00 -1.33296037e+00 ... -1.38697088e-01 -3.78795743e-01 -4.90686715e-01]
  ...
  [ 1.42756391e+00 1.06969677e-01 -4.06351686e-02 ... -3.17769870e-02 -4.14600194e-01 7.00368643e-01]
  [ 1.12866282e+00 1.45479128e-01 -6.13726497e-01 ... 4.74917591e-01 -3.98524106e-01 4.31243986e-01]
  [ 1.43932939e+00 1.80306792e-01 -4.28543597e-01 ... -2.50226915e-01 -1.00005698e+00 3.59853417e-01]]

 [[ 1.49934173e+00 -1.56314656e-01 9.21741903e-01 ... -3.62420291e-01 5.56350231e-01 -1.97974458e-01]
  [ 1.11105156e+00 3.66512716e-01 3.55057359e-01 ... -5.42976081e-01 1.44713879e-01 -3.16760272e-01]
  [ 2.40488604e-01 3.81156802e-01 -5.91828823e-01 ... 3.74109328e-01 -5.98297358e-01 -1.01662636e+00]
  ...
  [ 1.01586139e+00 5.02605081e-01 1.07370526e-01 ... -9.56424594e-01 -4.10394579e-01 -2.67600328e-01]
  [ 1.18489230e+00 6.54797316e-01 1.01533532e-03 ... -8.61542225e-01 -8.80386233e-02 -3.06369245e-01]
  [ 1.26691031e+00 4.77678001e-01 6.62612915e-03 ... -1.15857971e+00 -7.06757903e-02 -1.86785877e-01]]], shape=(9, 128, 768), dtype=float32)
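To make the three BERT input tensors easier to read, here is a toy, pure-Python sketch of how a single sentence is packed into a fixed-length row. It is an illustration only and not the preprocessing model's actual implementation: `pack_single_segment` is a hypothetical helper, and it assumes the special-token ids of the uncased English BERT vocabulary (101 for [CLS], 102 for [SEP], 0 for padding). The ids 1996 and 2201 are the first two word ids of the second sentence in the output above.

```python
# Special-token ids in the uncased English BERT vocabulary (assumed here).
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0


def pack_single_segment(token_ids, seq_length=128):
    """Toy illustration of the three tensors the preprocessing step emits."""
    # Reserve two positions for [CLS] and [SEP], truncating if necessary.
    body = list(token_ids)[: seq_length - 2]
    word_ids = [CLS_ID] + body + [SEP_ID]
    n_real = len(word_ids)
    n_pad = seq_length - n_real
    return {
        # Token ids, padded with zeros up to seq_length.
        "input_word_ids": word_ids + [PAD_ID] * n_pad,
        # 1 marks a real token, 0 marks padding.
        "input_mask": [1] * n_real + [0] * n_pad,
        # A single-sentence input uses segment id 0 everywhere.
        "input_type_ids": [0] * seq_length,
    }


example = pack_single_segment([1996, 2201], seq_length=8)
for name, row in example.items():
    print(name, row)
```

Each row starts with 101 and ends in zeros, which matches the pattern visible in the printed `input_word_ids` and `input_mask` tensors of shape (9, 128).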
Semantic similarity
Now let's take a look at the pooled_output embeddings of our sentences and compare how similar they are across sentences.
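As a rough sketch of what "similar" means here, cosine similarity is the dot product of two vectors divided by the product of their lengths. The pure-Python version below is for illustration only; the function names and the toy two-dimensional vectors are assumptions made for this example, and the real embeddings have 768 dimensions.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def similarity_matrix(vectors):
    """Pairwise cosine similarity between all rows of `vectors`."""
    return [[cosine_similarity(u, v) for v in vectors] for u in vectors]


# Toy vectors, not model outputs: same direction, orthogonal, opposite.
toy = [[1.0, 0.0], [2.0, 0.0], [0.0, 3.0], [-1.0, 0.0]]
matrix = similarity_matrix(toy)
print(matrix[0])  # similarities of the first vector to all four vectors
```

Vectors pointing the same way score 1 regardless of length, orthogonal vectors score 0, and opposite vectors score -1.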
Helper functions
def plot_similarity(features, labels):
  """Plot a similarity matrix of the embeddings."""
  cos_sim = pairwise.cosine_similarity(features)
  sns.set(font_scale=1.2)
  cbar_kws=dict(use_gridspec=False, location="left")
  g = sns.heatmap(
      cos_sim, xticklabels=labels, yticklabels=labels,
      vmin=0, vmax=1, cmap="Blues", cbar_kws=cbar_kws)
  g.tick_params(labelright=True, labelleft=False)
  g.set_yticklabels(labels, rotation=0)
  g.set_title("Semantic Textual Similarity")

plot_similarity(outputs["pooled_output"], sentences)
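A heatmap is easy to scan, but a ranked list of the most similar pairs can be handier. The sketch below is one possible way to produce it; `rank_pairs` is a hypothetical helper, and the similarity values and labels are made-up toy numbers, not results from the model.

```python
def rank_pairs(sim, labels, top_k=3):
    """Return the top_k most similar distinct pairs as (score, label_i, label_j)."""
    pairs = []
    n = len(labels)
    for i in range(n):
        # Only look above the diagonal: the matrix is symmetric and the
        # diagonal compares each item with itself.
        for j in range(i + 1, n):
            pairs.append((sim[i][j], labels[i], labels[j]))
    pairs.sort(key=lambda pair: pair[0], reverse=True)
    return pairs[:top_k]


# Made-up toy similarity matrix and labels, for illustration only.
toy_sim = [
    [1.0, 0.9, 0.2],
    [0.9, 1.0, 0.4],
    [0.2, 0.4, 1.0],
]
toy_labels = ["eclipse A", "eclipse B", "spinach"]
top = rank_pairs(toy_sim, toy_labels, top_k=2)
for score, left, right in top:
    print(f"{score:.2f}  {left} <-> {right}")
```

In the notebook, the same helper could be called as `rank_pairs(pairwise.cosine_similarity(outputs["pooled_output"]), sentences)`, since it only relies on `sim[i][j]` indexing.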
Learn more
- Find more BERT models on TensorFlow Hub
- This notebook demonstrates simple inference with BERT. You can find a more advanced tutorial about fine-tuning BERT at tensorflow.org/official_models/fine_tuning_bert
- We used just one GPU chip to run the model. You can learn more about how to load models using tf.distribute at tensorflow.org/tutorials/distribute/save_and_load