TensorFlow.org'da görüntüleyin | Google Colab'da çalıştırın | Kaynağı GitHub'da görüntüleyin | Not defterini indir |
genel bakış
Bu not defteri, TensorFlow Eklentilerinde TripletSemiHardLoss işlevinin nasıl kullanılacağını gösterecektir.
Kaynaklar:
- FaceNet: Yüz Tanıma ve Kümeleme için Birleşik Bir Gömme
- Oliver Moindrot'un blogu, algoritmayı ayrıntılı olarak açıklamak için mükemmel bir iş çıkarıyor
Üçlü Kaybı
İlk olarak FaceNet belgesinde tanıtıldığı gibi, TripletLoss, bir sinir ağını aynı sınıfın özelliklerini yakından yerleştirmek ve farklı sınıfların yerleştirmeleri arasındaki mesafeyi en üst düzeye çıkarmak için eğiten bir kayıp işlevidir. Bunu yapmak için bir negatif ve bir pozitif örnekle birlikte bir çapa seçilir.
Kayıp fonksiyonu bir Öklid uzaklık fonksiyonu olarak tanımlanır:
A bizim çapa girdimiz olduğunda, P pozitif örnek girdidir, N negatif örnek girdidir ve alfa, bir üçlü çok "kolay" hale geldiğinde ve artık bundan ağırlıkları ayarlamak istemediğinizde belirtmek için kullandığınız bir miktar marjdır. .
SemiHard Çevrimiçi Öğrenme
Makalede gösterildiği gibi, en iyi sonuçlar "Yarı Sert" olarak bilinen üçüzlerden alınmıştır. Bunlar, negatifin çapadan pozitiften daha uzak olduğu, ancak yine de pozitif bir kayıp ürettiği üçlüler olarak tanımlanır. Bu üçüzleri verimli bir şekilde bulmak için çevrimiçi öğrenmeyi kullanırsınız ve her grupta yalnızca Yarı Zor örneklerden eğitim alırsınız.
Kurmak
pip install -q -U tensorflow-addons
import io
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
Verileri Hazırlayın
def _normalize_img(img, label):
img = tf.cast(img, tf.float32) / 255.
return (img, label)
train_dataset, test_dataset = tfds.load(name="mnist", split=['train', 'test'], as_supervised=True)
# Build your input pipelines
train_dataset = train_dataset.shuffle(1024).batch(32)
train_dataset = train_dataset.map(_normalize_img)
test_dataset = test_dataset.batch(32)
test_dataset = test_dataset.map(_normalize_img)
Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1... Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
Modeli Oluştur
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(filters=64, kernel_size=2, padding='same', activation='relu', input_shape=(28,28,1)),
tf.keras.layers.MaxPooling2D(pool_size=2),
tf.keras.layers.Dropout(0.3),
tf.keras.layers.Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'),
tf.keras.layers.MaxPooling2D(pool_size=2),
tf.keras.layers.Dropout(0.3),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(256, activation=None), # No activation on final dense layer
tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1)) # L2 normalize embeddings
])
Eğitin ve Değerlendirin
# Compile the model
model.compile(
optimizer=tf.keras.optimizers.Adam(0.001),
loss=tfa.losses.TripletSemiHardLoss())
# Train the network
history = model.fit(
train_dataset,
epochs=5)
Epoch 1/5 1875/1875 [==============================] - 21s 5ms/step - loss: 0.6983 Epoch 2/5 1875/1875 [==============================] - 8s 4ms/step - loss: 0.4723 Epoch 3/5 1875/1875 [==============================] - 8s 4ms/step - loss: 0.4298 Epoch 4/5 1875/1875 [==============================] - 8s 4ms/step - loss: 0.4139 Epoch 5/5 1875/1875 [==============================] - 8s 4ms/step - loss: 0.3938
# Evaluate the network
results = model.predict(test_dataset)
# Save test embeddings for visualization in projector
np.savetxt("vecs.tsv", results, delimiter='\t')
out_m = io.open('meta.tsv', 'w', encoding='utf-8')
for img, labels in tfds.as_numpy(test_dataset):
[out_m.write(str(x) + "\n") for x in labels]
out_m.close()
try:
from google.colab import files
files.download('vecs.tsv')
files.download('meta.tsv')
except:
pass
Gömme Projektör
Vektör ve meta veri dosyaları yüklenir ve burada görüntülenebilmektedir: https://projector.tensorflow.org/
UMAP ile görselleştirildiğinde gömülü test verilerimizin sonuçlarını görebilirsiniz: