تدريب شبكة عصبية على MNIST مع Keras

يوضح هذا المثال البسيط كيفية توصيل مجموعات بيانات TensorFlow (TFDS) في نموذج Keras.

عرض على TensorFlow.org تشغيل في Google Colab عرض المصدر على جيثب تحميل دفتر
import tensorflow as tf
import tensorflow_datasets as tfds

الخطوة 1: قم بإنشاء خط أنابيب الإدخال الخاص بك

ابدأ ببناء خط أنابيب فعال باستخدام نصائح من:

قم بتحميل مجموعة بيانات

قم بتحميل مجموعة بيانات MNIST بالوسيطات التالية:

  • shuffle_files=True : يتم تخزين بيانات MNIST في ملف واحد فقط ، ولكن بالنسبة لمجموعات البيانات الأكبر التي تحتوي على ملفات متعددة على القرص ، فمن الجيد تبديلها أثناء التدريب.
  • as_supervised=True : إرجاع مجموعة (img, label) بدلاً من القاموس {'image': img, 'label': label} .
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
2022-02-07 04:05:46.671689: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

بناء خط أنابيب التدريب

قم بتطبيق التحولات التالية:

  • tf.data.Dataset.map : توفر TFDS صورًا من النوع tf.uint8 ، بينما يتوقع النموذج tf.float32 . لذلك ، تحتاج إلى تطبيع الصور.
  • tf.data.Dataset.cache عندما تلائم مجموعة البيانات في الذاكرة ، قم بتخزينها مؤقتًا قبل خلطها للحصول على أداء أفضل.
    ملاحظة: يجب تطبيق التحويلات العشوائية بعد التخزين المؤقت.
  • tf.data.Dataset.shuffle : للحصول على عشوائية حقيقية ، اضبط المخزن المؤقت العشوائي على الحجم الكامل لمجموعة البيانات.
    ملاحظة: بالنسبة لمجموعات البيانات الكبيرة التي لا تتسع في الذاكرة ، استخدم buffer_size=1000 إذا كان نظامك يسمح بذلك.
  • tf.data.Dataset.batch : مجموعة عناصر مجموعة البيانات بعد الخلط للحصول على دفعات فريدة في كل فترة.
  • tf.data.Dataset.prefetch : من الممارسات الجيدة إنهاء خط الأنابيب عن طريق الجلب المسبق للأداء .
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

بناء خط أنابيب التقييم

يشبه خط أنابيب الاختبار الخاص بك خط أنابيب التدريب مع وجود اختلافات صغيرة:

  • لا تحتاج إلى الاتصال بـ tf.data.Dataset.shuffle .
  • يتم التخزين المؤقت بعد التجميع لأن الدُفعات يمكن أن تكون هي نفسها بين العصور.
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

الخطوة 2: إنشاء النموذج وتدريبه

قم بتوصيل خط أنابيب إدخال TFDS في نموذج Keras البسيط ، وقم بتجميع النموذج ، وقم بتدريبه.

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
Epoch 1/6
469/469 [==============================] - 5s 4ms/step - loss: 0.3503 - sparse_categorical_accuracy: 0.9053 - val_loss: 0.1979 - val_sparse_categorical_accuracy: 0.9415
Epoch 2/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1668 - sparse_categorical_accuracy: 0.9524 - val_loss: 0.1392 - val_sparse_categorical_accuracy: 0.9595
Epoch 3/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1216 - sparse_categorical_accuracy: 0.9657 - val_loss: 0.1120 - val_sparse_categorical_accuracy: 0.9653
Epoch 4/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0939 - sparse_categorical_accuracy: 0.9726 - val_loss: 0.0960 - val_sparse_categorical_accuracy: 0.9704
Epoch 5/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0757 - sparse_categorical_accuracy: 0.9781 - val_loss: 0.0928 - val_sparse_categorical_accuracy: 0.9717
Epoch 6/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0625 - sparse_categorical_accuracy: 0.9818 - val_loss: 0.0851 - val_sparse_categorical_accuracy: 0.9728
<keras.callbacks.History at 0x7f77b42cd910>