Lihat di TensorFlow.org | Jalankan di Google Colab | Lihat di GitHub | Unduh buku catatan | Lihat model TF Hub |
TensorFlow Hub adalah repositori model TensorFlow yang telah dilatih sebelumnya.
Tutorial ini menunjukkan cara:
- Gunakan model dari TensorFlow Hub dengan
tf.keras
. - Gunakan model klasifikasi gambar dari TensorFlow Hub.
- Lakukan transfer learning sederhana untuk menyempurnakan model untuk kelas gambar Anda sendiri.
Mempersiapkan
import numpy as np
import time
import PIL.Image as Image
import matplotlib.pylab as plt
import tensorflow as tf
import tensorflow_hub as hub
import datetime
%load_ext tensorboard
Pengklasifikasi ImageNet
Anda akan mulai dengan menggunakan model pengklasifikasi yang telah dilatih sebelumnya pada kumpulan data benchmark ImageNet —tidak diperlukan pelatihan awal!
Unduh pengklasifikasi
Pilih model terlatih MobileNetV2 dari TensorFlow Hub dan bungkus sebagai lapisan Keras dengan hub.KerasLayer
. Semua model pengklasifikasi gambar yang kompatibel dari TensorFlow Hub akan berfungsi di sini, termasuk contoh yang diberikan dalam tarik-turun di bawah.
mobilenet_v2 ="https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4"
inception_v3 = "https://tfhub.dev/google/imagenet/inception_v3/classification/5"
classifier_model = mobilenet_v2
IMAGE_SHAPE = (224, 224)
classifier = tf.keras.Sequential([
hub.KerasLayer(classifier_model, input_shape=IMAGE_SHAPE+(3,))
])
Jalankan pada satu gambar
Unduh satu gambar untuk mencoba model di:
grace_hopper = tf.keras.utils.get_file('image.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg')
grace_hopper = Image.open(grace_hopper).resize(IMAGE_SHAPE)
grace_hopper
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg 65536/61306 [================================] - 0s 0us/step 73728/61306 [====================================] - 0s 0us/step
grace_hopper = np.array(grace_hopper)/255.0
grace_hopper.shape
(224, 224, 3)
Tambahkan dimensi batch (dengan np.newaxis
) dan teruskan gambar ke model:
result = classifier.predict(grace_hopper[np.newaxis, ...])
result.shape
(1, 1001)
Hasilnya adalah vektor logit 1001 elemen, menilai probabilitas setiap kelas untuk gambar.
ID kelas atas dapat ditemukan dengan tf.math.argmax
:
predicted_class = tf.math.argmax(result[0], axis=-1)
predicted_class
<tf.Tensor: shape=(), dtype=int64, numpy=653>
Decode prediksi
Ambil ID predicted_class
(seperti 653
) dan ambil label kumpulan data ImageNet untuk memecahkan kode prediksi:
labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
imagenet_labels = np.array(open(labels_path).read().splitlines())
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt 16384/10484 [==============================================] - 0s 0us/step 24576/10484 [======================================================================] - 0s 0us/step
plt.imshow(grace_hopper)
plt.axis('off')
predicted_class_name = imagenet_labels[predicted_class]
_ = plt.title("Prediction: " + predicted_class_name.title())
Pembelajaran transfer sederhana
Tetapi bagaimana jika Anda ingin membuat pengklasifikasi khusus menggunakan dataset Anda sendiri yang memiliki kelas yang tidak disertakan dalam dataset ImageNet asli (dimana model pra-pelatihan dilatih)?
Untuk melakukannya, Anda dapat:
- Pilih model terlatih dari TensorFlow Hub; dan
- Latih kembali lapisan atas (terakhir) untuk mengenali kelas dari kumpulan data khusus Anda.
Himpunan data
Dalam contoh ini, Anda akan menggunakan kumpulan data bunga TensorFlow:
data_root = tf.keras.utils.get_file(
'flower_photos',
'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
untar=True)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz 228818944/228813984 [==============================] - 7s 0us/step 228827136/228813984 [==============================] - 7s 0us/step
Pertama, muat data ini ke dalam model menggunakan data gambar dari disk dengan tf.keras.utils.image_dataset_from_directory
, yang akan menghasilkan tf.data.Dataset
:
batch_size = 32
img_height = 224
img_width = 224
train_ds = tf.keras.utils.image_dataset_from_directory(
str(data_root),
validation_split=0.2,
subset="training",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size
)
val_ds = tf.keras.utils.image_dataset_from_directory(
str(data_root),
validation_split=0.2,
subset="validation",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size
)
Found 3670 files belonging to 5 classes. Using 2936 files for training. Found 3670 files belonging to 5 classes. Using 734 files for validation.
Dataset bunga memiliki lima kelas:
class_names = np.array(train_ds.class_names)
print(class_names)
['daisy' 'dandelion' 'roses' 'sunflowers' 'tulips']
Kedua, karena konvensi TensorFlow Hub untuk model gambar mengharapkan input float dalam kisaran [0, 1]
, gunakan lapisan prapemrosesan tf.keras.layers.Rescaling
untuk mencapai ini.
normalization_layer = tf.keras.layers.Rescaling(1./255)
train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y)) # Where x—images, y—labels.
val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y)) # Where x—images, y—labels.
Ketiga, selesaikan jalur input dengan menggunakan buffered prefetching dengan Dataset.prefetch
, sehingga Anda dapat menghasilkan data dari disk tanpa masalah pemblokiran I/O.
Ini adalah beberapa metode tf.data
terpenting yang harus Anda gunakan saat memuat data. Pembaca yang tertarik dapat mempelajari lebih lanjut tentang mereka, serta cara menyimpan data ke disk dan teknik lainnya, dalam kinerja yang lebih baik dengan panduan API tf.data .
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
for image_batch, labels_batch in train_ds:
print(image_batch.shape)
print(labels_batch.shape)
break
(32, 224, 224, 3) (32,) 2022-01-26 05:06:19.465331: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Jalankan classifier pada sekumpulan gambar
Sekarang, jalankan classifier pada kumpulan gambar:
result_batch = classifier.predict(train_ds)
predicted_class_names = imagenet_labels[tf.math.argmax(result_batch, axis=-1)]
predicted_class_names
array(['daisy', 'coral fungus', 'rapeseed', ..., 'daisy', 'daisy', 'birdhouse'], dtype='<U30')
Periksa bagaimana prediksi ini sejalan dengan gambar:
plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
plt.subplot(6,5,n+1)
plt.imshow(image_batch[n])
plt.title(predicted_class_names[n])
plt.axis('off')
_ = plt.suptitle("ImageNet predictions")
Hasilnya jauh dari sempurna, tetapi masuk akal mengingat ini bukan kelas yang dilatih modelnya (kecuali untuk "daisy").
Unduh model tanpa kepala
TensorFlow Hub juga mendistribusikan model tanpa lapisan klasifikasi teratas. Ini dapat digunakan untuk melakukan transfer learning dengan mudah.
Pilih model terlatih MobileNetV2 dari TensorFlow Hub . Semua model vektor fitur gambar yang kompatibel dari TensorFlow Hub akan berfungsi di sini, termasuk contoh dari menu tarik-turun.
mobilenet_v2 = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4"
inception_v3 = "https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4"
feature_extractor_model = mobilenet_v2
Buat ekstraktor fitur dengan membungkus model yang telah dilatih sebelumnya sebagai lapisan Keras dengan hub.KerasLayer
. Gunakan argumen trainable=False
untuk membekukan variabel, sehingga pelatihan hanya memodifikasi lapisan classifier baru:
feature_extractor_layer = hub.KerasLayer(
feature_extractor_model,
input_shape=(224, 224, 3),
trainable=False)
Ekstraktor fitur mengembalikan vektor sepanjang 1280 untuk setiap gambar (ukuran kumpulan gambar tetap pada 32 dalam contoh ini):
feature_batch = feature_extractor_layer(image_batch)
print(feature_batch.shape)
(32, 1280)
Lampirkan kepala klasifikasi
Untuk melengkapi model, bungkus lapisan ekstraktor fitur dalam model tf.keras.Sequential
dan tambahkan lapisan yang terhubung penuh untuk klasifikasi:
num_classes = len(class_names)
model = tf.keras.Sequential([
feature_extractor_layer,
tf.keras.layers.Dense(num_classes)
])
model.summary()
Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= keras_layer_1 (KerasLayer) (None, 1280) 2257984 dense (Dense) (None, 5) 6405 ================================================================= Total params: 2,264,389 Trainable params: 6,405 Non-trainable params: 2,257,984 _________________________________________________________________
predictions = model(image_batch)
predictions.shape
TensorShape([32, 5])
Latih modelnya
Gunakan Model.compile
untuk mengonfigurasi proses pelatihan dan menambahkan panggilan balik tf.keras.callbacks.TensorBoard
untuk membuat dan menyimpan log:
model.compile(
optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['acc'])
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=log_dir,
histogram_freq=1) # Enable histogram computation for every epoch.
Sekarang gunakan metode Model.fit
untuk melatih model.
Untuk mempersingkat contoh ini, Anda hanya akan berlatih selama 10 epoch. Untuk memvisualisasikan kemajuan pelatihan di TensorBoard nanti, buat dan simpan log panggilan balik TensorBoard .
NUM_EPOCHS = 10
history = model.fit(train_ds,
validation_data=val_ds,
epochs=NUM_EPOCHS,
callbacks=tensorboard_callback)
Epoch 1/10 92/92 [==============================] - 7s 42ms/step - loss: 0.7904 - acc: 0.7210 - val_loss: 0.4592 - val_acc: 0.8515 Epoch 2/10 92/92 [==============================] - 3s 33ms/step - loss: 0.3850 - acc: 0.8713 - val_loss: 0.3694 - val_acc: 0.8787 Epoch 3/10 92/92 [==============================] - 3s 33ms/step - loss: 0.3027 - acc: 0.9057 - val_loss: 0.3367 - val_acc: 0.8856 Epoch 4/10 92/92 [==============================] - 3s 33ms/step - loss: 0.2524 - acc: 0.9237 - val_loss: 0.3210 - val_acc: 0.8869 Epoch 5/10 92/92 [==============================] - 3s 33ms/step - loss: 0.2164 - acc: 0.9373 - val_loss: 0.3124 - val_acc: 0.8896 Epoch 6/10 92/92 [==============================] - 3s 33ms/step - loss: 0.1888 - acc: 0.9469 - val_loss: 0.3070 - val_acc: 0.8937 Epoch 7/10 92/92 [==============================] - 3s 33ms/step - loss: 0.1668 - acc: 0.9550 - val_loss: 0.3032 - val_acc: 0.9005 Epoch 8/10 92/92 [==============================] - 3s 33ms/step - loss: 0.1487 - acc: 0.9619 - val_loss: 0.3004 - val_acc: 0.9005 Epoch 9/10 92/92 [==============================] - 3s 33ms/step - loss: 0.1335 - acc: 0.9687 - val_loss: 0.2981 - val_acc: 0.9019 Epoch 10/10 92/92 [==============================] - 3s 33ms/step - loss: 0.1206 - acc: 0.9748 - val_loss: 0.2964 - val_acc: 0.9046
Mulai TensorBoard untuk melihat bagaimana metrik berubah dengan setiap zaman dan untuk melacak nilai skalar lainnya:
%tensorboard --logdir logs/fit
Cek prediksinya
Dapatkan daftar nama kelas yang diurutkan dari prediksi model:
predicted_batch = model.predict(image_batch)
predicted_id = tf.math.argmax(predicted_batch, axis=-1)
predicted_label_batch = class_names[predicted_id]
print(predicted_label_batch)
['roses' 'dandelion' 'tulips' 'sunflowers' 'dandelion' 'roses' 'dandelion' 'roses' 'tulips' 'dandelion' 'tulips' 'tulips' 'sunflowers' 'tulips' 'dandelion' 'roses' 'daisy' 'tulips' 'dandelion' 'dandelion' 'dandelion' 'tulips' 'sunflowers' 'roses' 'sunflowers' 'dandelion' 'tulips' 'roses' 'roses' 'sunflowers' 'tulips' 'sunflowers']
Plot prediksi model:
plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
plt.subplot(6,5,n+1)
plt.imshow(image_batch[n])
plt.title(predicted_label_batch[n].title())
plt.axis('off')
_ = plt.suptitle("Model predictions")
Ekspor dan muat ulang model Anda
Sekarang setelah Anda melatih modelnya, ekspor sebagai SavedModel untuk digunakan kembali nanti.
t = time.time()
export_path = "/tmp/saved_models/{}".format(int(t))
model.save(export_path)
export_path
2022-01-26 05:07:03.429901: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them. INFO:tensorflow:Assets written to: /tmp/saved_models/1643173621/assets INFO:tensorflow:Assets written to: /tmp/saved_models/1643173621/assets '/tmp/saved_models/1643173621'
Konfirmasikan bahwa Anda dapat memuat ulang SavedModel dan model dapat menampilkan hasil yang sama:
reloaded = tf.keras.models.load_model(export_path)
result_batch = model.predict(image_batch)
reloaded_result_batch = reloaded.predict(image_batch)
abs(reloaded_result_batch - result_batch).max()
0.0
reloaded_predicted_id = tf.math.argmax(reloaded_result_batch, axis=-1)
reloaded_predicted_label_batch = class_names[reloaded_predicted_id]
print(reloaded_predicted_label_batch)
['roses' 'dandelion' 'tulips' 'sunflowers' 'dandelion' 'roses' 'dandelion' 'roses' 'tulips' 'dandelion' 'tulips' 'tulips' 'sunflowers' 'tulips' 'dandelion' 'roses' 'daisy' 'tulips' 'dandelion' 'dandelion' 'dandelion' 'tulips' 'sunflowers' 'roses' 'sunflowers' 'dandelion' 'tulips' 'roses' 'roses' 'sunflowers' 'tulips' 'sunflowers']
plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
plt.subplot(6,5,n+1)
plt.imshow(image_batch[n])
plt.title(reloaded_predicted_label_batch[n].title())
plt.axis('off')
_ = plt.suptitle("Model predictions")
Langkah selanjutnya
Anda dapat menggunakan SavedModel untuk memuat inferensi atau mengonversinya menjadi model TensorFlow Lite (untuk pembelajaran mesin di perangkat) atau model TensorFlow.js (untuk pembelajaran mesin dalam JavaScript).
Temukan lebih banyak tutorial untuk mempelajari cara menggunakan model terlatih dari TensorFlow Hub pada tugas gambar, teks, audio, dan video.