Kerugian Addons TensorFlow: TripletSemiHardLoss

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan

Gambaran

Notebook ini akan menunjukkan cara menggunakan fungsi TripletSemiHardLoss di TensorFlow Addons.

Sumber daya:

TripletRugi

Seperti yang pertama kali diperkenalkan di makalah FaceNet, TripletLoss adalah fungsi kerugian yang melatih jaringan saraf untuk menyematkan fitur dari kelas yang sama secara dekat sambil memaksimalkan jarak antara penyematan kelas yang berbeda. Untuk melakukan ini jangkar dipilih bersama dengan satu sampel negatif dan satu sampel positif. gambar3

Fungsi kerugian digambarkan sebagai fungsi jarak Euclidean:

fungsi

Di mana A adalah input jangkar kami, P adalah input sampel positif, N adalah input sampel negatif, dan alfa adalah beberapa margin yang Anda gunakan untuk menentukan kapan triplet menjadi terlalu "mudah" dan Anda tidak lagi ingin menyesuaikan bobot darinya .

Pembelajaran Online Semi Keras

Seperti yang ditunjukkan dalam makalah, hasil terbaik adalah dari kembar tiga yang dikenal sebagai "Semi-Hard". Ini didefinisikan sebagai kembar tiga di mana negatif lebih jauh dari jangkar daripada positif, tetapi masih menghasilkan kerugian positif. Untuk menemukan triplet ini secara efisien, Anda menggunakan pembelajaran online dan hanya berlatih dari contoh Semi-Hard di setiap batch.

Mempersiapkan

pip install -q -U tensorflow-addons
import io
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds

Siapkan Datanya

def _normalize_img(img, label):
    img = tf.cast(img, tf.float32) / 255.
    return (img, label)

train_dataset, test_dataset = tfds.load(name="mnist", split=['train', 'test'], as_supervised=True)

# Build your input pipelines
train_dataset = train_dataset.shuffle(1024).batch(32)
train_dataset = train_dataset.map(_normalize_img)

test_dataset = test_dataset.batch(32)
test_dataset = test_dataset.map(_normalize_img)
Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1...
Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.

Bangun Modelnya

gambar2

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(filters=64, kernel_size=2, padding='same', activation='relu', input_shape=(28,28,1)),
    tf.keras.layers.MaxPooling2D(pool_size=2),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=2),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation=None), # No activation on final dense layer
    tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1)) # L2 normalize embeddings

])

Latih dan Evaluasi

# Compile the model
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tfa.losses.TripletSemiHardLoss())
# Train the network
history = model.fit(
    train_dataset,
    epochs=5)
Epoch 1/5
1875/1875 [==============================] - 21s 5ms/step - loss: 0.6983
Epoch 2/5
1875/1875 [==============================] - 8s 4ms/step - loss: 0.4723
Epoch 3/5
1875/1875 [==============================] - 8s 4ms/step - loss: 0.4298
Epoch 4/5
1875/1875 [==============================] - 8s 4ms/step - loss: 0.4139
Epoch 5/5
1875/1875 [==============================] - 8s 4ms/step - loss: 0.3938
# Evaluate the network
results = model.predict(test_dataset)
# Save test embeddings for visualization in projector
np.savetxt("vecs.tsv", results, delimiter='\t')

out_m = io.open('meta.tsv', 'w', encoding='utf-8')
for img, labels in tfds.as_numpy(test_dataset):
    [out_m.write(str(x) + "\n") for x in labels]
out_m.close()


try:
  from google.colab import files
  files.download('vecs.tsv')
  files.download('meta.tsv')
except:
  pass

Menanamkan Proyektor

Vektor dan metadata file dapat dimuat dan divisualisasikan di sini: https://projector.tensorflow.org/

Anda dapat melihat hasil data pengujian tersemat kami saat divisualisasikan dengan UMAP: penyematan