Overview
This notebook will demonstrate how to use the TripletSemiHardLoss function in TensorFlow Addons.
Resources:
- FaceNet: A Unified Embedding for Face Recognition and Clustering
- Oliver Moindrot's blog does an excellent job of describing the algorithm in detail
TripletLoss
As described in the FaceNet paper, TripletLoss is a loss function that trains a neural network to embed examples of the same class close together while maximizing the distance between embeddings of different classes. To do this, an anchor is chosen along with one positive sample (same class as the anchor) and one negative sample (different class).
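The loss function is described as a Euclidean distance function:

$$L(A, P, N) = \max\left(\lVert f(A) - f(P) \rVert^2 - \lVert f(A) - f(N) \rVert^2 + \alpha,\ 0\right)$$

where A is the anchor input, P is the positive sample input, N is the negative sample input, f is the embedding produced by the network, and alpha is a margin: once the negative is farther from the anchor than the positive by at least alpha, the triplet is "easy", its loss is zero, and it no longer contributes to the weight updates.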
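As a minimal sketch, the triplet loss max(‖f(A) − f(P)‖² − ‖f(A) − f(N)‖² + α, 0) can be evaluated directly for a single triplet with NumPy. The embeddings below are hypothetical 2-D vectors standing in for f(A), f(P) and f(N), `triplet_loss` is an illustrative helper (not a TensorFlow Addons function), and α = 0.2 is an arbitrary value chosen for the example.

```python
import numpy as np

def triplet_loss(f_a, f_p, f_n, alpha=0.2):
    """Triplet loss for one triplet of embeddings, using squared Euclidean distances."""
    d_ap = np.sum((f_a - f_p) ** 2)  # squared distance anchor -> positive
    d_an = np.sum((f_a - f_n) ** 2)  # squared distance anchor -> negative
    return float(max(d_ap - d_an + alpha, 0.0))

anchor = np.array([0.0, 1.0])
positive = np.array([0.0, 0.9])
close_negative = np.array([0.3, 1.0])  # farther than the positive, but within the margin
far_negative = np.array([1.0, 0.0])    # farther than the positive by more than the margin

print(triplet_loss(anchor, positive, close_negative))  # positive loss: margin violated
print(triplet_loss(anchor, positive, far_negative))    # zero loss: an "easy" triplet
```

The first triplet is exactly the "semi-hard" kind discussed next: the negative is already farther away than the positive, yet the loss is still positive.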
SemiHard Online Learning
As shown in the paper, the best results come from triplets known as "Semi-Hard". These are triplets where the negative is farther from the anchor than the positive, but still close enough (within the margin alpha) to produce a positive loss. To find these triplets efficiently, they are mined online: triplets are formed from the examples within each batch, and training uses only the Semi-Hard examples found in that batch.
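The online mining step can be sketched for one batch in NumPy. For every anchor-positive pair in the batch, the sketch picks the closest negative that is still farther from the anchor than the positive; if no such negative exists, it falls back to the farthest negative in the batch. The helper name `semi_hard_triplet_loss`, the squared-Euclidean distance, the margin of 0.2, the fallback rule and averaging over anchor-positive pairs are all assumptions of this illustration, not a description of TensorFlow Addons' internals.

```python
import numpy as np

def semi_hard_triplet_loss(embeddings, labels, margin=0.2):
    """Batch triplet loss with online semi-hard negative mining (illustrative sketch)."""
    # Pairwise squared Euclidean distances, shape (batch, batch).
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    dist = np.sum(diff ** 2, axis=-1)
    same = labels[:, None] == labels[None, :]
    losses = []
    n = len(labels)
    for a in range(n):
        neg_dists = dist[a][~same[a]]  # distances from anchor a to every negative
        if neg_dists.size == 0:
            continue  # no negatives in this batch for this anchor
        for p in range(n):
            if p == a or not same[a, p]:
                continue  # skip the anchor itself and all non-positives
            d_ap = dist[a, p]
            # Negatives farther from the anchor than the positive.
            candidates = neg_dists[neg_dists > d_ap]
            # Closest such negative; otherwise fall back to the farthest negative.
            d_an = candidates.min() if candidates.size else neg_dists.max()
            losses.append(max(d_ap - d_an + margin, 0.0))
    return float(np.mean(losses)) if losses else 0.0

# Toy batch: two examples of class 0 and two of class 1.
embeddings = np.array([[0.0, 0.0], [0.0, 0.1], [0.2, 0.0], [1.0, 1.0]])
labels = np.array([0, 0, 1, 1])
print(semi_hard_triplet_loss(embeddings, labels))
```

Because triplets are formed only inside a batch, an anchor contributes nothing unless the batch also contains at least one positive and one negative for it; the sketch simply skips such anchors.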
Setup
pip install -U tensorflow-addons
import io
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
Prepare the Data
def _normalize_img(img, label):
    img = tf.cast(img, tf.float32) / 255.
    return (img, label)
train_dataset, test_dataset = tfds.load(name="mnist", split=['train', 'test'], as_supervised=True)
# Build your input pipelines
train_dataset = train_dataset.shuffle(1024).batch(32)
train_dataset = train_dataset.map(_normalize_img)
test_dataset = test_dataset.batch(32)
test_dataset = test_dataset.map(_normalize_img)
Build the Model
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(filters=64, kernel_size=2, padding='same', activation='relu', input_shape=(28,28,1)),
    tf.keras.layers.MaxPooling2D(pool_size=2),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=2),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation=None), # No activation on final dense layer
    tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1)) # L2 normalize embeddings
])
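The final Lambda layer projects every embedding onto the unit hypersphere. A small NumPy sketch illustrates what that buys: after normalization every embedding has length 1, so the Euclidean distance between any two embeddings lies in [0, 2] and the squared distance equals 2 − 2·cos θ. This keeps distances on a fixed scale relative to the loss margin. The random vectors are placeholders for the 256-d Dense outputs, not real model outputs, and the division mirrors `tf.math.l2_normalize` except for its small epsilon guard against zero vectors.

```python
import numpy as np

rng = np.random.default_rng(0)
raw = rng.normal(size=(8, 256))  # placeholder for a batch of 256-d Dense outputs

# Row-wise L2 normalization, as tf.math.l2_normalize(x, axis=1) does (minus its epsilon).
normalized = raw / np.linalg.norm(raw, axis=1, keepdims=True)

norms = np.linalg.norm(normalized, axis=1)
dists = np.linalg.norm(normalized[:, None, :] - normalized[None, :, :], axis=-1)
cosine = normalized @ normalized.T

print(np.allclose(norms, 1.0))                  # every embedding has unit length
print(dists.max() <= 2.0)                       # distances bounded by the sphere's diameter
print(np.allclose(dists ** 2, 2 - 2 * cosine))  # squared distance = 2 - 2 cos(theta)
```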
Train and Evaluate
# Compile the model
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tfa.losses.TripletSemiHardLoss())
# Train the network
history = model.fit(
    train_dataset,
    epochs=5)
Epoch 1/5
1875/1875 [==============================] - 13s 4ms/step - loss: 0.5952
Epoch 2/5
1875/1875 [==============================] - 6s 3ms/step - loss: 0.4562
Epoch 3/5
1875/1875 [==============================] - 7s 4ms/step - loss: 0.4207
Epoch 4/5
1875/1875 [==============================] - 6s 3ms/step - loss: 0.4040
Epoch 5/5
1875/1875 [==============================] - 6s 3ms/step - loss: 0.3964
# Compute embeddings for the test set (predict returns embeddings, not a score)
results = model.predict(test_dataset)
313/313 [==============================] - 1s 1ms/step
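`model.predict` only produces the embeddings; it does not score them. One quick numeric check, which is not part of the original notebook, is leave-one-out nearest-neighbor accuracy: the fraction of embeddings whose closest other embedding has the same label. The helper `nearest_neighbor_accuracy` below is hypothetical and is demonstrated on toy data; applying it to `results` is shown only as commented-out lines that assume the variables defined above.

```python
import numpy as np

def nearest_neighbor_accuracy(embeddings, labels):
    """Leave-one-out 1-NN accuracy of a set of embeddings (illustrative helper)."""
    embeddings = np.asarray(embeddings, dtype=np.float64)
    labels = np.asarray(labels)
    # Squared Euclidean distances via ||x||^2 - 2 x.y + ||y||^2.
    sq_norms = np.sum(embeddings ** 2, axis=1)
    dist = sq_norms[:, None] - 2.0 * embeddings @ embeddings.T + sq_norms[None, :]
    np.fill_diagonal(dist, np.inf)  # a point may not be its own nearest neighbor
    nearest = np.argmin(dist, axis=1)
    return float(np.mean(labels[nearest] == labels))

# Toy example: two well-separated clusters.
toy_embeddings = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
print(nearest_neighbor_accuracy(toy_embeddings, np.array([0, 0, 1, 1])))  # 1.0
print(nearest_neighbor_accuracy(toy_embeddings, np.array([0, 1, 0, 1])))  # 0.0

# On the notebook's outputs (assumes `results`, `test_dataset` and `tfds` from above);
# a subset keeps the N x N distance matrix small:
# test_labels = np.concatenate([y for _, y in tfds.as_numpy(test_dataset)])
# print(nearest_neighbor_accuracy(results[:2000], test_labels[:2000]))
```

The full 10,000-image test set would need a 10,000 × 10,000 distance matrix, which is why the commented usage takes a subset.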
# Save test embeddings for visualization in projector
np.savetxt("vecs.tsv", results, delimiter='\t')
with io.open('meta.tsv', 'w', encoding='utf-8') as out_m:
    # test_dataset is not shuffled, so the labels line up with the rows of vecs.tsv
    for img, labels in tfds.as_numpy(test_dataset):
        for x in labels:
            out_m.write(str(x) + "\n")
try:
    # google.colab is only importable when running in Google Colab
    from google.colab import files
    files.download('vecs.tsv')
    files.download('meta.tsv')
except ImportError:
    pass
Embedding Projector
The vector and metadata files can be loaded and visualized here: https://projector.tensorflow.org/
In the projector, you can view the embedded test data with the UMAP visualization.