Este sencillo ejemplo demuestra cómo conectar conjuntos de datos de TensorFlow (TFDS) en un modelo de Keras.
Ver en TensorFlow.org | Ejecutar en Google Colab | Ver fuente en GitHub | Descargar libreta |
import tensorflow as tf
import tensorflow_datasets as tfds
Paso 1: Cree su tubería de entrada
Comience por construir una tubería de entrada eficiente utilizando los consejos de:
Cargar un conjunto de datos
Cargue el conjunto de datos MNIST con los siguientes argumentos:
-
shuffle_files=True
: los datos de MNIST solo se almacenan en un solo archivo, pero para conjuntos de datos más grandes con varios archivos en el disco, es una buena práctica mezclarlos durante el entrenamiento. -
as_supervised=True
: Devuelve una tupla(img, label)
en lugar de un diccionario{'image': img, 'label': label}
.
(ds_train, ds_test), ds_info = tfds.load(
'mnist',
split=['train', 'test'],
shuffle_files=True,
as_supervised=True,
with_info=True,
)
2022-02-07 04:05:46.671689: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Cree una canalización de capacitación
Aplicar las siguientes transformaciones:
-
tf.data.Dataset.map
: TFDS proporciona imágenes de tipotf.uint8
, mientras que el modelo esperatf.float32
. Por lo tanto, necesita normalizar las imágenes. -
tf.data.Dataset.cache
A medida que ajusta el conjunto de datos en la memoria, colóquelo en caché antes de barajar para un mejor rendimiento.
Nota: Las transformaciones aleatorias deben aplicarse después del almacenamiento en caché. -
tf.data.Dataset.shuffle
: para una verdadera aleatoriedad, establezca el búfer de reproducción aleatoria en el tamaño completo del conjunto de datos.
Nota: Para grandes conjuntos de datos que no caben en la memoria, usebuffer_size=1000
si su sistema lo permite. -
tf.data.Dataset.batch
: Elementos de lote del conjunto de datos después de mezclar para obtener lotes únicos en cada época. -
tf.data.Dataset.prefetch
: es una buena práctica finalizar la canalización mediante la captación previa del rendimiento .
def normalize_img(image, label):
"""Normalizes images: `uint8` -> `float32`."""
return tf.cast(image, tf.float32) / 255., label
ds_train = ds_train.map(
normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)
Cree una canalización de evaluación
Su canalización de prueba es similar a la canalización de entrenamiento con pequeñas diferencias:
- No necesita llamar a
tf.data.Dataset.shuffle
. - El almacenamiento en caché se realiza después del procesamiento por lotes porque los lotes pueden ser los mismos entre épocas.
ds_test = ds_test.map(
normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)
Paso 2: crear y entrenar el modelo
Conecte la canalización de entrada de TFDS en un modelo Keras simple, compile el modelo y entrénelo.
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)
])
model.compile(
optimizer=tf.keras.optimizers.Adam(0.001),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
model.fit(
ds_train,
epochs=6,
validation_data=ds_test,
)
Epoch 1/6 469/469 [==============================] - 5s 4ms/step - loss: 0.3503 - sparse_categorical_accuracy: 0.9053 - val_loss: 0.1979 - val_sparse_categorical_accuracy: 0.9415 Epoch 2/6 469/469 [==============================] - 1s 2ms/step - loss: 0.1668 - sparse_categorical_accuracy: 0.9524 - val_loss: 0.1392 - val_sparse_categorical_accuracy: 0.9595 Epoch 3/6 469/469 [==============================] - 1s 2ms/step - loss: 0.1216 - sparse_categorical_accuracy: 0.9657 - val_loss: 0.1120 - val_sparse_categorical_accuracy: 0.9653 Epoch 4/6 469/469 [==============================] - 1s 2ms/step - loss: 0.0939 - sparse_categorical_accuracy: 0.9726 - val_loss: 0.0960 - val_sparse_categorical_accuracy: 0.9704 Epoch 5/6 469/469 [==============================] - 1s 2ms/step - loss: 0.0757 - sparse_categorical_accuracy: 0.9781 - val_loss: 0.0928 - val_sparse_categorical_accuracy: 0.9717 Epoch 6/6 469/469 [==============================] - 1s 2ms/step - loss: 0.0625 - sparse_categorical_accuracy: 0.9818 - val_loss: 0.0851 - val_sparse_categorical_accuracy: 0.9728 <keras.callbacks.History at 0x7f77b42cd910>