Cet exemple simple montre comment connecter des ensembles de données TensorFlow (TFDS) à un modèle Keras.
Voir sur TensorFlow.org | Exécuter dans Google Colab | Voir la source sur GitHub | Télécharger le cahier |
import tensorflow as tf
import tensorflow_datasets as tfds
Étape 1 : Créer votre pipeline d'entrée
Commencez par créer un pipeline d'entrée efficace en utilisant les conseils de :
- Le guide des conseils sur les performances
- Les meilleures performances avec le guide de l'API
tf.data
Charger un jeu de données
Chargez l'ensemble de données MNIST avec les arguments suivants :
-
shuffle_files=True
: les données MNIST ne sont stockées que dans un seul fichier, mais pour les ensembles de données plus volumineux avec plusieurs fichiers sur le disque, il est recommandé de les mélanger lors de la formation. -
as_supervised=True
: Retourne un tuple(img, label)
au lieu d'un dictionnaire{'image': img, 'label': label}
.
(ds_train, ds_test), ds_info = tfds.load(
'mnist',
split=['train', 'test'],
shuffle_files=True,
as_supervised=True,
with_info=True,
)
2022-02-07 04:05:46.671689: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Créer un pipeline de formation
Appliquez les transformations suivantes :
-
tf.data.Dataset.map
: TFDS fournit des images de typetf.uint8
, tandis que le modèle attendtf.float32
. Par conséquent, vous devez normaliser les images. -
tf.data.Dataset.cache
Au fur et à mesure que vous adaptez l'ensemble de données en mémoire, mettez-le en cache avant de le mélanger pour de meilleures performances.
Remarque : Les transformations aléatoires doivent être appliquées après la mise en cache. -
tf.data.Dataset.shuffle
: pour un véritable caractère aléatoire, définissez le tampon de mélange sur la taille complète de l'ensemble de données.
Remarque : Pour les ensembles de données volumineux qui ne tiennent pas en mémoire, utilisezbuffer_size=1000
si votre système le permet. -
tf.data.Dataset.batch
: Regroupez les éléments de l'ensemble de données après mélange pour obtenir des lots uniques à chaque époque. -
tf.data.Dataset.prefetch
: il est recommandé de terminer le pipeline en effectuant une prélecture pour des raisons de performances .
def normalize_img(image, label):
"""Normalizes images: `uint8` -> `float32`."""
return tf.cast(image, tf.float32) / 255., label
ds_train = ds_train.map(
normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)
Créer un pipeline d'évaluation
Votre pipeline de test est similaire au pipeline de formation avec de petites différences :
- Vous n'avez pas besoin d'appeler
tf.data.Dataset.shuffle
. - La mise en cache est effectuée après le traitement par lot car les lots peuvent être les mêmes d'une époque à l'autre.
ds_test = ds_test.map(
normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)
Étape 2 : Créer et entraîner le modèle
Branchez le pipeline d'entrée TFDS dans un modèle Keras simple, compilez le modèle et entraînez-le.
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)
])
model.compile(
optimizer=tf.keras.optimizers.Adam(0.001),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
model.fit(
ds_train,
epochs=6,
validation_data=ds_test,
)
Epoch 1/6 469/469 [==============================] - 5s 4ms/step - loss: 0.3503 - sparse_categorical_accuracy: 0.9053 - val_loss: 0.1979 - val_sparse_categorical_accuracy: 0.9415 Epoch 2/6 469/469 [==============================] - 1s 2ms/step - loss: 0.1668 - sparse_categorical_accuracy: 0.9524 - val_loss: 0.1392 - val_sparse_categorical_accuracy: 0.9595 Epoch 3/6 469/469 [==============================] - 1s 2ms/step - loss: 0.1216 - sparse_categorical_accuracy: 0.9657 - val_loss: 0.1120 - val_sparse_categorical_accuracy: 0.9653 Epoch 4/6 469/469 [==============================] - 1s 2ms/step - loss: 0.0939 - sparse_categorical_accuracy: 0.9726 - val_loss: 0.0960 - val_sparse_categorical_accuracy: 0.9704 Epoch 5/6 469/469 [==============================] - 1s 2ms/step - loss: 0.0757 - sparse_categorical_accuracy: 0.9781 - val_loss: 0.0928 - val_sparse_categorical_accuracy: 0.9717 Epoch 6/6 469/469 [==============================] - 1s 2ms/step - loss: 0.0625 - sparse_categorical_accuracy: 0.9818 - val_loss: 0.0851 - val_sparse_categorical_accuracy: 0.9728 <keras.callbacks.History at 0x7f77b42cd910>