Đào tạo mạng nơ-ron trên MNIST với Keras

Ví dụ đơn giản này trình bày cách cắm Bộ dữ liệu TensorFlow (TFDS) vào một mô hình Keras.

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép
import tensorflow as tf
import tensorflow_datasets as tfds

Bước 1: Tạo đường dẫn đầu vào của bạn

Bắt đầu bằng cách xây dựng một đường dẫn đầu vào hiệu quả bằng cách sử dụng lời khuyên từ:

Tải tập dữ liệu

Tải tập dữ liệu MNIST với các đối số sau:

  • shuffle_files=True : Dữ liệu MNIST chỉ được lưu trữ trong một tệp duy nhất, nhưng đối với các bộ dữ liệu lớn hơn có nhiều tệp trên đĩa, bạn nên xáo trộn chúng khi huấn luyện.
  • as_supervised=True : Trả về một tuple (img, label) thay vì từ điển {'image': img, 'label': label} .
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
2022-02-07 04:05:46.671689: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

Xây dựng một đường dẫn đào tạo

Áp dụng các phép biến đổi sau:

  • tf.data.Dataset.map : TFDS cung cấp hình ảnh kiểu tf.uint8 , trong khi mô hình mong đợi tf.float32 . Do đó, bạn cần bình thường hóa hình ảnh.
  • tf.data.Dataset.cache Khi bạn vừa với tập dữ liệu trong bộ nhớ, hãy lưu vào bộ nhớ cache trước khi xáo trộn để có hiệu suất tốt hơn.
    Lưu ý: Các phép biến đổi ngẫu nhiên nên được áp dụng sau khi lưu vào bộ nhớ đệm.
  • tf.data.Dataset.shuffle : Để thực sự ngẫu nhiên, hãy đặt bộ đệm xáo trộn thành kích thước tập dữ liệu đầy đủ.
    Lưu ý: Đối với các tập dữ liệu lớn không thể vừa trong bộ nhớ, hãy sử dụng buffer_size=1000 nếu hệ thống của bạn cho phép.
  • tf.data.Dataset.batch : Hàng loạt phần tử của tập dữ liệu sau khi xáo trộn để có được các lô duy nhất tại mỗi kỷ nguyên.
  • tf.data.Dataset.prefetch : Bạn nên kết thúc đường dẫn bằng cách tìm nạp trước để đạt được hiệu suất .
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

Xây dựng quy trình đánh giá

Quy trình thử nghiệm của bạn tương tự như quy trình đào tạo với những khác biệt nhỏ:

  • Bạn không cần gọi tf.data.Dataset.shuffle .
  • Lưu vào bộ nhớ đệm được thực hiện sau khi theo lô vì các lô có thể giống nhau giữa các kỷ nguyên.
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

Bước 2: Tạo và đào tạo mô hình

Cắm đường ống đầu vào TFDS vào một mô hình Keras đơn giản, biên dịch mô hình và đào tạo nó.

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
Epoch 1/6
469/469 [==============================] - 5s 4ms/step - loss: 0.3503 - sparse_categorical_accuracy: 0.9053 - val_loss: 0.1979 - val_sparse_categorical_accuracy: 0.9415
Epoch 2/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1668 - sparse_categorical_accuracy: 0.9524 - val_loss: 0.1392 - val_sparse_categorical_accuracy: 0.9595
Epoch 3/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1216 - sparse_categorical_accuracy: 0.9657 - val_loss: 0.1120 - val_sparse_categorical_accuracy: 0.9653
Epoch 4/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0939 - sparse_categorical_accuracy: 0.9726 - val_loss: 0.0960 - val_sparse_categorical_accuracy: 0.9704
Epoch 5/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0757 - sparse_categorical_accuracy: 0.9781 - val_loss: 0.0928 - val_sparse_categorical_accuracy: 0.9717
Epoch 6/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0625 - sparse_categorical_accuracy: 0.9818 - val_loss: 0.0851 - val_sparse_categorical_accuracy: 0.9728
<keras.callbacks.History at 0x7f77b42cd910>