Keras के साथ MNIST पर एक तंत्रिका नेटवर्क का प्रशिक्षण

यह सरल उदाहरण दर्शाता है कि TensorFlow डेटासेट (TFDS) को Keras मॉडल में कैसे प्लग किया जाए।

TensorFlow.org पर देखें Google Colab में चलाएं GitHub पर स्रोत देखें नोटबुक डाउनलोड करें
import tensorflow as tf
import tensorflow_datasets as tfds

चरण 1: अपनी इनपुट पाइपलाइन बनाएं

निम्न की सलाह का उपयोग करके एक कुशल इनपुट पाइपलाइन बनाकर प्रारंभ करें:

डेटासेट लोड करें

MNIST डेटासेट को निम्न तर्कों के साथ लोड करें:

  • shuffle_files=True : MNIST डेटा केवल एक फ़ाइल में संग्रहीत होता है, लेकिन डिस्क पर एकाधिक फ़ाइलों वाले बड़े डेटासेट के लिए, प्रशिक्षण के दौरान उन्हें फेरबदल करना अच्छा अभ्यास है।
  • as_supervised=True : डिक्शनरी {'image': img, 'label': label} के बजाय एक टपल (img, label) देता है।
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
2022-02-07 04:05:46.671689: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

एक प्रशिक्षण पाइपलाइन बनाएँ

निम्नलिखित परिवर्तन लागू करें:

  • tf.data.Dataset.map : TFDS tf.uint8 प्रकार की छवियां प्रदान करता है, जबकि मॉडल tf.float32 की अपेक्षा करता है। इसलिए, आपको छवियों को सामान्य करने की आवश्यकता है।
  • tf.data.Dataset.cache जैसे ही आप डेटासेट को मेमोरी में फिट करते हैं, बेहतर प्रदर्शन के लिए फेरबदल करने से पहले इसे कैश करें।
    नोट: कैशिंग के बाद यादृच्छिक परिवर्तन लागू किया जाना चाहिए।
  • tf.data.Dataset.shuffle : सही यादृच्छिकता के लिए, फेरबदल बफर को पूर्ण डेटासेट आकार पर सेट करें।
    नोट: बड़े डेटासेट के लिए जो मेमोरी में फ़िट नहीं हो सकते हैं, यदि आपका सिस्टम इसकी अनुमति देता है तो buffer_size=1000 का उपयोग करें।
  • tf.data.Dataset.batch : प्रत्येक युग में अद्वितीय बैच प्राप्त करने के लिए फेरबदल के बाद डेटासेट के बैच तत्व।
  • tf.data.Dataset.prefetch : प्रदर्शन के लिए प्रीफ़ेच करके पाइपलाइन को समाप्त करना अच्छा अभ्यास है।
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

एक मूल्यांकन पाइपलाइन बनाएँ

आपकी परीक्षण पाइपलाइन छोटे अंतरों के साथ प्रशिक्षण पाइपलाइन के समान है:

  • आपको tf.data.Dataset.shuffle पर कॉल करने की आवश्यकता नहीं है।
  • बैचिंग के बाद कैशिंग की जाती है क्योंकि बैचों के बीच बैच समान हो सकते हैं।
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

चरण 2: मॉडल बनाएं और प्रशिक्षित करें

TFDS इनपुट पाइपलाइन को एक साधारण Keras मॉडल में प्लग करें, मॉडल को संकलित करें और उसे प्रशिक्षित करें।

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
Epoch 1/6
469/469 [==============================] - 5s 4ms/step - loss: 0.3503 - sparse_categorical_accuracy: 0.9053 - val_loss: 0.1979 - val_sparse_categorical_accuracy: 0.9415
Epoch 2/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1668 - sparse_categorical_accuracy: 0.9524 - val_loss: 0.1392 - val_sparse_categorical_accuracy: 0.9595
Epoch 3/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1216 - sparse_categorical_accuracy: 0.9657 - val_loss: 0.1120 - val_sparse_categorical_accuracy: 0.9653
Epoch 4/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0939 - sparse_categorical_accuracy: 0.9726 - val_loss: 0.0960 - val_sparse_categorical_accuracy: 0.9704
Epoch 5/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0757 - sparse_categorical_accuracy: 0.9781 - val_loss: 0.0928 - val_sparse_categorical_accuracy: 0.9717
Epoch 6/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0625 - sparse_categorical_accuracy: 0.9818 - val_loss: 0.0851 - val_sparse_categorical_accuracy: 0.9728
<keras.callbacks.History at 0x7f77b42cd910>