TensorFlow Eklenti Kayıpları: TripletSemiHardLoss

TensorFlow.org'da görüntüleyin Google Colab'da çalıştırın Kaynağı GitHub'da görüntüleyin Not defterini indir

genel bakış

Bu not defteri, TensorFlow Eklentilerinde TripletSemiHardLoss işlevinin nasıl kullanılacağını gösterecektir.

Kaynaklar:

Üçlü Kaybı

İlk olarak FaceNet belgesinde tanıtıldığı gibi, TripletLoss, bir sinir ağını aynı sınıfın özelliklerini yakından yerleştirmek ve farklı sınıfların yerleştirmeleri arasındaki mesafeyi en üst düzeye çıkarmak için eğiten bir kayıp işlevidir. Bunu yapmak için bir negatif ve bir pozitif örnekle birlikte bir çapa seçilir. Şekil 3

Kayıp fonksiyonu bir Öklid uzaklık fonksiyonu olarak tanımlanır:

işlev

A bizim çapa girdimiz olduğunda, P pozitif örnek girdidir, N negatif örnek girdidir ve alfa, bir üçlü çok "kolay" hale geldiğinde ve artık bundan ağırlıkları ayarlamak istemediğinizde belirtmek için kullandığınız bir miktar marjdır. .

SemiHard Çevrimiçi Öğrenme

Makalede gösterildiği gibi, en iyi sonuçlar "Yarı Sert" olarak bilinen üçüzlerden alınmıştır. Bunlar, negatifin çapadan pozitiften daha uzak olduğu, ancak yine de pozitif bir kayıp ürettiği üçlüler olarak tanımlanır. Bu üçüzleri verimli bir şekilde bulmak için çevrimiçi öğrenmeyi kullanırsınız ve her grupta yalnızca Yarı Zor örneklerden eğitim alırsınız.

Kurmak

pip install -q -U tensorflow-addons
import io
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds

Verileri Hazırlayın

def _normalize_img(img, label):
    img = tf.cast(img, tf.float32) / 255.
    return (img, label)

train_dataset, test_dataset = tfds.load(name="mnist", split=['train', 'test'], as_supervised=True)

# Build your input pipelines
train_dataset = train_dataset.shuffle(1024).batch(32)
train_dataset = train_dataset.map(_normalize_img)

test_dataset = test_dataset.batch(32)
test_dataset = test_dataset.map(_normalize_img)
Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1...
Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.

Modeli Oluştur

incir. 2

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(filters=64, kernel_size=2, padding='same', activation='relu', input_shape=(28,28,1)),
    tf.keras.layers.MaxPooling2D(pool_size=2),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=2),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation=None), # No activation on final dense layer
    tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1)) # L2 normalize embeddings

])

Eğitin ve Değerlendirin

# Compile the model
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tfa.losses.TripletSemiHardLoss())
# Train the network
history = model.fit(
    train_dataset,
    epochs=5)
Epoch 1/5
1875/1875 [==============================] - 21s 5ms/step - loss: 0.6983
Epoch 2/5
1875/1875 [==============================] - 8s 4ms/step - loss: 0.4723
Epoch 3/5
1875/1875 [==============================] - 8s 4ms/step - loss: 0.4298
Epoch 4/5
1875/1875 [==============================] - 8s 4ms/step - loss: 0.4139
Epoch 5/5
1875/1875 [==============================] - 8s 4ms/step - loss: 0.3938
# Evaluate the network
results = model.predict(test_dataset)
# Save test embeddings for visualization in projector
np.savetxt("vecs.tsv", results, delimiter='\t')

out_m = io.open('meta.tsv', 'w', encoding='utf-8')
for img, labels in tfds.as_numpy(test_dataset):
    [out_m.write(str(x) + "\n") for x in labels]
out_m.close()


try:
  from google.colab import files
  files.download('vecs.tsv')
  files.download('meta.tsv')
except:
  pass

Gömme Projektör

Vektör ve meta veri dosyaları yüklenir ve burada görüntülenebilmektedir: https://projector.tensorflow.org/

UMAP ile görselleştirildiğinde gömülü test verilerimizin sonuçlarını görebilirsiniz: gömme