آموزش شبکه عصبی در MNIST با Keras

این مثال ساده نحوه اتصال TensorFlow Datasets (TFDS) را به یک مدل Keras نشان می دهد.

مشاهده در TensorFlow.org در Google Colab اجرا شود مشاهده منبع در GitHub دانلود دفترچه یادداشت
import tensorflow as tf
import tensorflow_datasets as tfds

مرحله 1: خط لوله ورودی خود را ایجاد کنید

با ایجاد یک خط لوله ورودی کارآمد با استفاده از توصیه های زیر شروع کنید:

یک مجموعه داده را بارگیری کنید

مجموعه داده MNIST را با آرگومان های زیر بارگیری کنید:

  • shuffle_files=True : داده‌های MNIST فقط در یک فایل ذخیره می‌شوند، اما برای مجموعه داده‌های بزرگتر با چندین فایل روی دیسک، بهتر است هنگام آموزش، آنها را به هم بزنند.
  • as_supervised=True : یک تاپل (img, label) را به جای دیکشنری {'image': img, 'label': label} برمی گرداند.
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
2022-02-07 04:05:46.671689: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

یک خط لوله آموزشی بسازید

تبدیل های زیر را اعمال کنید:

  • tf.data.Dataset.map : TFDS تصاویری از نوع tf.uint8 را ارائه می دهد، در حالی که مدل انتظار tf.float32 را دارد. بنابراین، شما باید تصاویر را عادی کنید.
  • tf.data.Dataset.cache همانطور که مجموعه داده را در حافظه جا می دهید، قبل از به هم زدن برای عملکرد بهتر، آن را در حافظه پنهان ذخیره کنید.
    توجه: تبدیل های تصادفی باید پس از ذخیره سازی اعمال شوند.
  • tf.data.Dataset.shuffle : برای تصادفی بودن واقعی، بافر shuffle را روی اندازه کامل مجموعه داده تنظیم کنید.
    توجه: برای مجموعه داده های بزرگی که نمی توانند در حافظه جا شوند، اگر سیستم شما اجازه می دهد از buffer_size=1000 استفاده کنید.
  • tf.data.Dataset.batch : عناصر دسته ای مجموعه داده پس از به هم زدن برای به دست آوردن دسته های منحصر به فرد در هر دوره.
  • tf.data.Dataset.prefetch : پایان دادن به خط لوله با واکشی اولیه برای عملکرد ، تمرین خوبی است.
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

یک خط لوله ارزیابی بسازید

خط لوله آزمایشی شما مشابه خط لوله آموزشی با تفاوت های کوچک است:

  • نیازی نیست با tf.data.Dataset.shuffle تماس بگیرید.
  • ذخیره سازی پس از بچینگ انجام می شود زیرا دسته ها می توانند بین دوره ها یکسان باشند.
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

مرحله 2: مدل را ایجاد و آموزش دهید

خط لوله ورودی TFDS را به یک مدل ساده Keras وصل کنید، مدل را کامپایل کنید و آن را آموزش دهید.

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
Epoch 1/6
469/469 [==============================] - 5s 4ms/step - loss: 0.3503 - sparse_categorical_accuracy: 0.9053 - val_loss: 0.1979 - val_sparse_categorical_accuracy: 0.9415
Epoch 2/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1668 - sparse_categorical_accuracy: 0.9524 - val_loss: 0.1392 - val_sparse_categorical_accuracy: 0.9595
Epoch 3/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1216 - sparse_categorical_accuracy: 0.9657 - val_loss: 0.1120 - val_sparse_categorical_accuracy: 0.9653
Epoch 4/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0939 - sparse_categorical_accuracy: 0.9726 - val_loss: 0.0960 - val_sparse_categorical_accuracy: 0.9704
Epoch 5/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0757 - sparse_categorical_accuracy: 0.9781 - val_loss: 0.0928 - val_sparse_categorical_accuracy: 0.9717
Epoch 6/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0625 - sparse_categorical_accuracy: 0.9818 - val_loss: 0.0851 - val_sparse_categorical_accuracy: 0.9728
<keras.callbacks.History at 0x7f77b42cd910>