Keras ile MNIST üzerinde bir sinir ağı eğitimi

Bu basit örnek, TensorFlow Veri Kümelerinin (TFDS) bir Keras modeline nasıl ekleneceğini gösterir.

TensorFlow.org'da görüntüleyin Google Colab'da çalıştırın Kaynağı GitHub'da görüntüleyin Not defterini indir
import tensorflow as tf
import tensorflow_datasets as tfds

1. Adım: Giriş işlem hattınızı oluşturun

Aşağıdaki tavsiyeleri kullanarak verimli bir girdi hattı oluşturarak başlayın:

Bir veri kümesi yükleyin

MNIST veri kümesini aşağıdaki bağımsız değişkenlerle yükleyin:

  • shuffle_files=True : MNIST verileri yalnızca tek bir dosyada saklanır, ancak diskte birden çok dosya bulunan daha büyük veri kümeleri için, bunları eğitim sırasında karıştırmak iyi bir uygulamadır.
  • as_supervised=True : Sözlük {'image': img, 'label': label} yerine bir demet (img, label) döndürür.
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
tutucu2 l10n-yer
2022-02-07 04:05:46.671689: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

Bir eğitim hattı oluşturun

Aşağıdaki dönüşümleri uygulayın:

  • tf.data.Dataset.map : TFDS , tf.uint8 türünde görüntüler sağlarken model tf.float32 bekler . Bu nedenle, görüntüleri normalleştirmeniz gerekir.
  • tf.data.Dataset.cache Veri kümesini belleğe sığdırırken, daha iyi bir performans için karıştırmadan önce önbelleğe alın.
    Not: Önbelleğe alma işleminden sonra rastgele dönüşümler uygulanmalıdır.
  • tf.data.Dataset.shuffle : Gerçek rastgelelik için karıştırma arabelleğini tam veri kümesi boyutuna ayarlayın.
    Not: Belleğe sığamayan büyük veri kümeleri için, sisteminiz izin veriyorsa, buffer_size=1000 kullanın.
  • tf.data.Dataset.batch : Her çağda benzersiz gruplar elde etmek için karıştırmadan sonra veri kümesinin toplu öğeleri.
  • tf.data.Dataset.prefetch : Performans için önceden getirme yoluyla ardışık düzeni sonlandırmak iyi bir uygulamadır.
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

Bir değerlendirme ardışık düzeni oluşturun

Test hattınız, küçük farklarla eğitim hattına benzer:

  • tf.data.Dataset.shuffle çağırmanız gerekmez.
  • Toplu işlemden sonra önbelleğe alma yapılır, çünkü yığınlar dönemler arasında aynı olabilir.
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

2. Adım: Modeli oluşturun ve eğitin

TFDS giriş hattını basit bir Keras modeline takın, modeli derleyin ve eğitin.

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
tutucu6 l10n-yer
Epoch 1/6
469/469 [==============================] - 5s 4ms/step - loss: 0.3503 - sparse_categorical_accuracy: 0.9053 - val_loss: 0.1979 - val_sparse_categorical_accuracy: 0.9415
Epoch 2/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1668 - sparse_categorical_accuracy: 0.9524 - val_loss: 0.1392 - val_sparse_categorical_accuracy: 0.9595
Epoch 3/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1216 - sparse_categorical_accuracy: 0.9657 - val_loss: 0.1120 - val_sparse_categorical_accuracy: 0.9653
Epoch 4/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0939 - sparse_categorical_accuracy: 0.9726 - val_loss: 0.0960 - val_sparse_categorical_accuracy: 0.9704
Epoch 5/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0757 - sparse_categorical_accuracy: 0.9781 - val_loss: 0.0928 - val_sparse_categorical_accuracy: 0.9717
Epoch 6/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0625 - sparse_categorical_accuracy: 0.9818 - val_loss: 0.0851 - val_sparse_categorical_accuracy: 0.9728
<keras.callbacks.History at 0x7f77b42cd910>