Formation d'un réseau de neurones sur MNIST avec Keras

Cet exemple simple montre comment connecter des ensembles de données TensorFlow (TFDS) à un modèle Keras.

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Télécharger le cahier
import tensorflow as tf
import tensorflow_datasets as tfds

Étape 1 : Créer votre pipeline d'entrée

Commencez par créer un pipeline d'entrée efficace en utilisant les conseils de :

Charger un jeu de données

Chargez l'ensemble de données MNIST avec les arguments suivants :

  • shuffle_files=True : les données MNIST ne sont stockées que dans un seul fichier, mais pour les ensembles de données plus volumineux avec plusieurs fichiers sur le disque, il est recommandé de les mélanger lors de la formation.
  • as_supervised=True : Retourne un tuple (img, label) au lieu d'un dictionnaire {'image': img, 'label': label} .
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
2022-02-07 04:05:46.671689: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

Créer un pipeline de formation

Appliquez les transformations suivantes :

  • tf.data.Dataset.map : TFDS fournit des images de type tf.uint8 , tandis que le modèle attend tf.float32 . Par conséquent, vous devez normaliser les images.
  • tf.data.Dataset.cache Au fur et à mesure que vous adaptez l'ensemble de données en mémoire, mettez-le en cache avant de le mélanger pour de meilleures performances.
    Remarque : Les transformations aléatoires doivent être appliquées après la mise en cache.
  • tf.data.Dataset.shuffle : pour un véritable caractère aléatoire, définissez le tampon de mélange sur la taille complète de l'ensemble de données.
    Remarque : Pour les ensembles de données volumineux qui ne tiennent pas en mémoire, utilisez buffer_size=1000 si votre système le permet.
  • tf.data.Dataset.batch : Regroupez les éléments de l'ensemble de données après mélange pour obtenir des lots uniques à chaque époque.
  • tf.data.Dataset.prefetch : il est recommandé de terminer le pipeline en effectuant une prélecture pour des raisons de performances .
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

Créer un pipeline d'évaluation

Votre pipeline de test est similaire au pipeline de formation avec de petites différences :

  • Vous n'avez pas besoin d'appeler tf.data.Dataset.shuffle .
  • La mise en cache est effectuée après le traitement par lot car les lots peuvent être les mêmes d'une époque à l'autre.
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

Étape 2 : Créer et entraîner le modèle

Branchez le pipeline d'entrée TFDS dans un modèle Keras simple, compilez le modèle et entraînez-le.

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
Epoch 1/6
469/469 [==============================] - 5s 4ms/step - loss: 0.3503 - sparse_categorical_accuracy: 0.9053 - val_loss: 0.1979 - val_sparse_categorical_accuracy: 0.9415
Epoch 2/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1668 - sparse_categorical_accuracy: 0.9524 - val_loss: 0.1392 - val_sparse_categorical_accuracy: 0.9595
Epoch 3/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1216 - sparse_categorical_accuracy: 0.9657 - val_loss: 0.1120 - val_sparse_categorical_accuracy: 0.9653
Epoch 4/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0939 - sparse_categorical_accuracy: 0.9726 - val_loss: 0.0960 - val_sparse_categorical_accuracy: 0.9704
Epoch 5/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0757 - sparse_categorical_accuracy: 0.9781 - val_loss: 0.0928 - val_sparse_categorical_accuracy: 0.9717
Epoch 6/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0625 - sparse_categorical_accuracy: 0.9818 - val_loss: 0.0851 - val_sparse_categorical_accuracy: 0.9728
<keras.callbacks.History at 0x7f77b42cd910>