Ver en TensorFlow.org | Ejecutar en Google Colab | Ver fuente en GitHub | Descargar libreta |
Descripción general
Las GPU y TPU pueden reducir radicalmente el tiempo necesario para ejecutar un solo paso de entrenamiento. Lograr el máximo rendimiento requiere una canalización de entrada eficiente que entregue datos para el siguiente paso antes de que finalice el paso actual. La API tf.data
ayuda a construir canalizaciones de entrada flexibles y eficientes. Este documento demuestra cómo usar la API tf.data
para crear canalizaciones de entrada de TensorFlow de alto rendimiento.
Antes de continuar, consulta la guía Construir canalizaciones de entrada de TensorFlow para aprender a usar la API tf.data
.
Recursos
- Cree canalizaciones de entrada de TensorFlow
- API
tf.data.Dataset
- Analice el rendimiento de
tf.data
con TF Profiler
Configuración
import tensorflow as tf
import time
A lo largo de esta guía, iterará a través de un conjunto de datos y medirá el rendimiento. Hacer puntos de referencia de rendimiento reproducibles puede ser difícil. Los diferentes factores que afectan la reproducibilidad incluyen:
- La carga actual de la CPU
- el trafico de la red
- Mecanismos complejos, como caché
Para obtener un punto de referencia reproducible, construirá un ejemplo artificial.
el conjunto de datos
Comience con la definición de una clase heredada de tf.data.Dataset
denominada ArtificialDataset
. Este conjunto de datos:
- Genera
num_samples
muestras (el valor predeterminado es 3) - Duerme durante algún tiempo antes del primer elemento para simular la apertura de un archivo
- Duerme durante un tiempo antes de producir cada elemento para simular la lectura de datos de un archivo
class ArtificialDataset(tf.data.Dataset):
def _generator(num_samples):
# Opening the file
time.sleep(0.03)
for sample_idx in range(num_samples):
# Reading data (line, record) from the file
time.sleep(0.015)
yield (sample_idx,)
def __new__(cls, num_samples=3):
return tf.data.Dataset.from_generator(
cls._generator,
output_signature = tf.TensorSpec(shape = (1,), dtype = tf.int64),
args=(num_samples,)
)
Este conjunto de datos es similar al tf.data.Dataset.range
, agregando un retraso fijo al comienzo y entre cada muestra.
El circuito de entrenamiento
A continuación, escriba un bucle de entrenamiento ficticio que mida cuánto tiempo se tarda en iterar sobre un conjunto de datos. Se simula el tiempo de entrenamiento.
def benchmark(dataset, num_epochs=2):
start_time = time.perf_counter()
for epoch_num in range(num_epochs):
for sample in dataset:
# Performing a training step
time.sleep(0.01)
print("Execution time:", time.perf_counter() - start_time)
Optimizar el rendimiento
Para mostrar cómo se puede optimizar el rendimiento, mejorará el rendimiento de ArtificialDataset
.
El enfoque ingenuo
Comience con una canalización ingenua sin trucos, iterando sobre el conjunto de datos tal como está.
benchmark(ArtificialDataset())
Execution time: 0.26497629899995445
Bajo el capó, así es como se gastó su tiempo de ejecución:
El gráfico muestra que realizar un paso de entrenamiento implica:
- Abrir un archivo si aún no se ha abierto
- Obtener una entrada de datos del archivo
- Uso de los datos para el entrenamiento.
Sin embargo, en una implementación síncrona ingenua como aquí, mientras su canalización obtiene los datos, su modelo está inactivo. Por el contrario, mientras su modelo se está entrenando, la canalización de entrada está inactiva. El tiempo del paso de entrenamiento es, por lo tanto, la suma de los tiempos de apertura, lectura y entrenamiento.
Las siguientes secciones se basan en esta canalización de entrada e ilustran las mejores prácticas para diseñar canalizaciones de entrada de TensorFlow de alto rendimiento.
captación previa
La captación previa superpone el preprocesamiento y la ejecución del modelo de un paso de entrenamiento. Mientras el modelo ejecuta el paso de entrenamiento s
, la canalización de entrada lee los datos del paso s+1
. Al hacerlo, se reduce el tiempo de paso al máximo (a diferencia de la suma) del entrenamiento y el tiempo que se tarda en extraer los datos.
La API tf.data
proporciona la transformación tf.data.Dataset.prefetch
. Se puede utilizar para desacoplar el momento en que se producen los datos del momento en que se consumen. En particular, la transformación utiliza un subproceso en segundo plano y un búfer interno para obtener elementos del conjunto de datos de entrada antes de que se soliciten. El número de elementos para precargar debe ser igual (o posiblemente mayor que) el número de lotes consumidos por un solo paso de entrenamiento. Puede ajustar manualmente este valor o establecerlo en tf.data.AUTOTUNE
, lo que hará que el tiempo de ejecución de tf.data
ajuste el valor dinámicamente en el tiempo de ejecución.
Tenga en cuenta que la transformación de captación previa proporciona beneficios cada vez que existe la oportunidad de superponer el trabajo de un "productor" con el trabajo de un "consumidor".
benchmark(
ArtificialDataset()
.prefetch(tf.data.AUTOTUNE)
)
Execution time: 0.21731788600027357
Ahora, como muestra el gráfico de tiempo de ejecución de datos, mientras se ejecuta el paso de entrenamiento para la muestra 0, la canalización de entrada lee los datos para la muestra 1, y así sucesivamente.
Paralelización de la extracción de datos
En una configuración del mundo real, los datos de entrada se pueden almacenar de forma remota (por ejemplo, en Google Cloud Storage o HDFS). Una canalización de conjunto de datos que funciona bien al leer datos localmente puede convertirse en un cuello de botella en la E/S al leer datos de forma remota debido a las siguientes diferencias entre el almacenamiento local y remoto:
- Tiempo hasta el primer byte : la lectura del primer byte de un archivo desde el almacenamiento remoto puede tardar varios órdenes de magnitud más que desde el almacenamiento local.
- Rendimiento de lectura : si bien el almacenamiento remoto generalmente ofrece un gran ancho de banda agregado, es posible que la lectura de un solo archivo solo pueda utilizar una pequeña fracción de este ancho de banda.
Además, una vez que los bytes sin procesar se cargan en la memoria, también puede ser necesario deserializar y/o descifrar los datos (por ejemplo, protobuf ), lo que requiere un cálculo adicional. Esta sobrecarga está presente independientemente de si los datos se almacenan local o remotamente, pero puede ser peor en el caso remoto si los datos no se obtienen previamente de manera efectiva.
Para mitigar el impacto de los diversos gastos generales de extracción de datos, se puede usar la transformación tf.data.Dataset.interleave
para paralelizar el paso de carga de datos, intercalando el contenido de otros conjuntos de datos (como lectores de archivos de datos). El número de conjuntos de datos que se superpondrán se puede especificar con el argumento cycle_length
, mientras que el nivel de paralelismo se puede especificar con el argumento num_parallel_calls
. De forma similar a la transformación de tf.data.AUTOTUNE
, la transformación de interleave
admite prefetch
, que delegará la decisión sobre qué nivel de paralelismo usar en el tiempo de ejecución de tf.data
.
intercalado secuencial
Los argumentos predeterminados de la transformación tf.data.Dataset.interleave
hacen que intercale muestras individuales de dos conjuntos de datos de forma secuencial.
benchmark(
tf.data.Dataset.range(2)
.interleave(lambda _: ArtificialDataset())
)
Execution time: 0.4987426460002098
Este gráfico de tiempo de ejecución de datos permite exhibir el comportamiento de la transformación interleave
, obteniendo muestras alternativamente de los dos conjuntos de datos disponibles. Sin embargo, aquí no se trata de una mejora del rendimiento.
intercalado paralelo
Ahora, use el argumento num_parallel_calls
de la transformación interleave
. Esto carga varios conjuntos de datos en paralelo, lo que reduce el tiempo de espera para abrir los archivos.
benchmark(
tf.data.Dataset.range(2)
.interleave(
lambda _: ArtificialDataset(),
num_parallel_calls=tf.data.AUTOTUNE
)
)
Execution time: 0.283668874000341
Esta vez, como muestra el diagrama de tiempo de ejecución de datos, la lectura de los dos conjuntos de datos se realiza en paralelo, lo que reduce el tiempo de procesamiento de datos global.
Transformación de datos en paralelo
Al preparar los datos, es posible que sea necesario preprocesar los elementos de entrada. Con este fin, la API tf.data
ofrece la transformación tf.data.Dataset.map
, que aplica una función definida por el usuario a cada elemento del conjunto de datos de entrada. Debido a que los elementos de entrada son independientes entre sí, el preprocesamiento se puede paralelizar en varios núcleos de CPU. Para que esto sea posible, de manera similar a las transformaciones de num_parallel_calls
e interleave
, la transformación de map
proporciona el argumento prefetch
para especificar el nivel de paralelismo.
Elegir el mejor valor para el argumento num_parallel_calls
depende de su hardware, las características de sus datos de entrenamiento (como su tamaño y forma), el costo de su función de mapa y qué otro procesamiento está ocurriendo en la CPU al mismo tiempo. Una heurística simple es usar la cantidad de núcleos de CPU disponibles. Sin embargo, en cuanto a la transformación de búsqueda tf.data.AUTOTUNE
e interleave
, la transformación de map
admite prefetch
, que delegará la decisión sobre qué nivel de paralelismo usar en el tiempo de ejecución de tf.data
.
def mapped_function(s):
# Do some hard pre-processing
tf.py_function(lambda: time.sleep(0.03), [], ())
return s
Mapeo secuencial
Comience utilizando la transformación de map
sin paralelismo como ejemplo de referencia.
benchmark(
ArtificialDataset()
.map(mapped_function)
)
Execution time: 0.4505277170001136
En cuanto al enfoque ingenuo , aquí, como muestra la trama, los tiempos dedicados a la apertura, lectura, preprocesamiento (mapeo) y pasos de entrenamiento se suman para una sola iteración.
Mapeo paralelo
Ahora, use la misma función de preprocesamiento pero aplíquela en paralelo en varias muestras.
benchmark(
ArtificialDataset()
.map(
mapped_function,
num_parallel_calls=tf.data.AUTOTUNE
)
)
Execution time: 0.2839677860001757
Como demuestra el gráfico de datos, los pasos de preprocesamiento se superponen, lo que reduce el tiempo total para una única iteración.
almacenamiento en caché
La transformación tf.data.Dataset.cache
puede almacenar en caché un conjunto de datos, ya sea en la memoria o en el almacenamiento local. Esto evitará que algunas operaciones (como la apertura de archivos y la lectura de datos) se ejecuten durante cada época.
benchmark(
ArtificialDataset()
.map( # Apply time consuming operations before cache
mapped_function
).cache(
),
5
)
Execution time: 0.3848854380003104
Aquí, el gráfico de tiempo de ejecución de datos muestra que cuando almacena en caché un conjunto de datos, las transformaciones anteriores a la cache
(como la apertura del archivo y la lectura de datos) se ejecutan solo durante la primera época. Las próximas épocas reutilizarán los datos almacenados en caché por la transformación de cache
.
Si la función definida por el usuario que se pasa a la transformación del map
es costosa, aplique la transformación de cache
después de la transformación del map
, siempre que el conjunto de datos resultante aún quepa en la memoria o en el almacenamiento local. Si la función definida por el usuario aumenta el espacio requerido para almacenar el conjunto de datos más allá de la capacidad de la memoria caché, aplíquela después de la transformación de la cache
o considere procesar previamente sus datos antes de su trabajo de entrenamiento para reducir el uso de recursos.
Mapeo de vectorización
La invocación de una función definida por el usuario pasada a la transformación del map
tiene una sobrecarga relacionada con la programación y ejecución de la función definida por el usuario. Vectorice la función definida por el usuario (es decir, haga que opere sobre un lote de entradas a la vez) y aplique la transformación por batch
antes de la transformación del map
.
Para ilustrar esta buena práctica, su conjunto de datos artificial no es adecuado. El retraso de la programación es de alrededor de 10 microsegundos (10e-6 segundos), mucho menos que las decenas de milisegundos que se usan en el ArtificialDataset
y, por lo tanto, su impacto es difícil de ver.
Para este ejemplo, use la función base tf.data.Dataset.range
y simplifique el ciclo de entrenamiento a su forma más simple.
fast_dataset = tf.data.Dataset.range(10000)
def fast_benchmark(dataset, num_epochs=2):
start_time = time.perf_counter()
for _ in tf.data.Dataset.range(num_epochs):
for _ in dataset:
pass
tf.print("Execution time:", time.perf_counter() - start_time)
def increment(x):
return x+1
Mapeo escalar
fast_benchmark(
fast_dataset
# Apply function one item at a time
.map(increment)
# Batch
.batch(256)
)
Execution time: 0.2712608739998359
La gráfica anterior ilustra lo que está sucediendo (con menos muestras) usando el método de mapeo escalar. Muestra que la función mapeada se aplica para cada muestra. Si bien esta función es muy rápida, tiene algunos gastos generales que afectan el rendimiento del tiempo.
Mapeo vectorizado
fast_benchmark(
fast_dataset
.batch(256)
# Apply function on a batch of items
# The tf.Tensor.__add__ method already handle batches
.map(increment)
)
Execution time: 0.02737950600021577
Esta vez, la función asignada se llama una vez y se aplica a un lote de muestra. Como muestra el diagrama de tiempo de ejecución de datos, si bien la función podría tardar más en ejecutarse, la sobrecarga aparece solo una vez, lo que mejora el rendimiento general del tiempo.
Reducción de la huella de memoria
Varias transformaciones, incluidas interleave
, prefetch
y shuffle
, mantienen un búfer interno de elementos. Si la función definida por el usuario que se pasa a la transformación del map
cambia el tamaño de los elementos, el orden de la transformación del mapa y las transformaciones que almacenan elementos en el búfer afectan el uso de la memoria. En general, elija el orden que resulte en un uso de memoria más bajo, a menos que se desee un orden diferente para el rendimiento.
Almacenamiento en caché de cálculos parciales
Se recomienda almacenar en caché el conjunto de datos después de la transformación del map
, excepto si esta transformación hace que los datos sean demasiado grandes para caber en la memoria. Se puede lograr una compensación si su función mapeada se puede dividir en dos partes: una parte que consume tiempo y otra que consume memoria. En este caso, puede encadenar sus transformaciones como se muestra a continuación:
dataset.map(time_consuming_mapping).cache().map(memory_consuming_mapping)
De esta manera, la parte que consume mucho tiempo solo se ejecuta durante la primera época y evita usar demasiado espacio de caché.
Resumen de mejores prácticas
Aquí hay un resumen de las mejores prácticas para diseñar canalizaciones de entrada TensorFlow de alto rendimiento:
- Utilice la
prefetch
de captación previa para superponer el trabajo de un productor y un consumidor - Paralelice la transformación de lectura de datos mediante la transformación
interleave
- Paralelice la transformación del
map
configurando el argumentonum_parallel_calls
- Use la
cache
de caché para almacenar en caché los datos en la memoria durante la primera época - Vectorizar funciones definidas por el usuario pasadas a la transformación del
map
- Reduzca el uso de la memoria al aplicar las transformaciones
interleave
,prefetch
yshuffle
reproduciendo las figuras
Para profundizar en la comprensión de la API tf.data.Dataset
, puede jugar con sus propias canalizaciones. A continuación se muestra el código utilizado para trazar las imágenes de esta guía. Puede ser un buen punto de partida, mostrando algunas soluciones para dificultades comunes como:
- Reproducibilidad del tiempo de ejecución
- Funciones mapeadas ejecución ansiosa
- transformación
interleave
invocable
import itertools
from collections import defaultdict
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
el conjunto de datos
De forma similar al ArtificialDataset
, puede crear un conjunto de datos que devuelva el tiempo empleado en cada paso.
class TimeMeasuredDataset(tf.data.Dataset):
# OUTPUT: (steps, timings, counters)
OUTPUT_TYPES = (tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32)
OUTPUT_SHAPES = ((2, 1), (2, 2), (2, 3))
_INSTANCES_COUNTER = itertools.count() # Number of datasets generated
_EPOCHS_COUNTER = defaultdict(itertools.count) # Number of epochs done for each dataset
def _generator(instance_idx, num_samples):
epoch_idx = next(TimeMeasuredDataset._EPOCHS_COUNTER[instance_idx])
# Opening the file
open_enter = time.perf_counter()
time.sleep(0.03)
open_elapsed = time.perf_counter() - open_enter
for sample_idx in range(num_samples):
# Reading data (line, record) from the file
read_enter = time.perf_counter()
time.sleep(0.015)
read_elapsed = time.perf_counter() - read_enter
yield (
[("Open",), ("Read",)],
[(open_enter, open_elapsed), (read_enter, read_elapsed)],
[(instance_idx, epoch_idx, -1), (instance_idx, epoch_idx, sample_idx)]
)
open_enter, open_elapsed = -1., -1. # Negative values will be filtered
def __new__(cls, num_samples=3):
return tf.data.Dataset.from_generator(
cls._generator,
output_types=cls.OUTPUT_TYPES,
output_shapes=cls.OUTPUT_SHAPES,
args=(next(cls._INSTANCES_COUNTER), num_samples)
)
Este conjunto de datos proporciona muestras de forma [[2, 1], [2, 2], [2, 3]]
y de tipo [tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32]
. Cada muestra es:
(
[("Open"), ("Read")],
[(t0, d), (t0, d)],
[(i, e, -1), (i, e, s)]
)
Donde:
-
Open
yRead
son identificadores de pasos -
t0
es la marca de tiempo cuando comenzó el paso correspondiente -
d
es el tiempo empleado en el paso correspondiente -
i
es el índice de instancia -
e
es el índice de época (número de veces que se ha iterado el conjunto de datos) -
s
es el índice de la muestra
El bucle de iteración
Haga que el ciclo de iteración sea un poco más complicado para agregar todos los tiempos. Esto solo funcionará con conjuntos de datos que generen muestras como se detalla anteriormente.
def timelined_benchmark(dataset, num_epochs=2):
# Initialize accumulators
steps_acc = tf.zeros([0, 1], dtype=tf.dtypes.string)
times_acc = tf.zeros([0, 2], dtype=tf.dtypes.float32)
values_acc = tf.zeros([0, 3], dtype=tf.dtypes.int32)
start_time = time.perf_counter()
for epoch_num in range(num_epochs):
epoch_enter = time.perf_counter()
for (steps, times, values) in dataset:
# Record dataset preparation informations
steps_acc = tf.concat((steps_acc, steps), axis=0)
times_acc = tf.concat((times_acc, times), axis=0)
values_acc = tf.concat((values_acc, values), axis=0)
# Simulate training time
train_enter = time.perf_counter()
time.sleep(0.01)
train_elapsed = time.perf_counter() - train_enter
# Record training informations
steps_acc = tf.concat((steps_acc, [["Train"]]), axis=0)
times_acc = tf.concat((times_acc, [(train_enter, train_elapsed)]), axis=0)
values_acc = tf.concat((values_acc, [values[-1]]), axis=0)
epoch_elapsed = time.perf_counter() - epoch_enter
# Record epoch informations
steps_acc = tf.concat((steps_acc, [["Epoch"]]), axis=0)
times_acc = tf.concat((times_acc, [(epoch_enter, epoch_elapsed)]), axis=0)
values_acc = tf.concat((values_acc, [[-1, epoch_num, -1]]), axis=0)
time.sleep(0.001)
tf.print("Execution time:", time.perf_counter() - start_time)
return {"steps": steps_acc, "times": times_acc, "values": values_acc}
El método de trazado
Finalmente, defina una función capaz de trazar una línea de tiempo dados los valores devueltos por la función timelined_benchmark
.
def draw_timeline(timeline, title, width=0.5, annotate=False, save=False):
# Remove invalid entries (negative times, or empty steps) from the timelines
invalid_mask = np.logical_and(timeline['times'] > 0, timeline['steps'] != b'')[:,0]
steps = timeline['steps'][invalid_mask].numpy()
times = timeline['times'][invalid_mask].numpy()
values = timeline['values'][invalid_mask].numpy()
# Get a set of different steps, ordered by the first time they are encountered
step_ids, indices = np.stack(np.unique(steps, return_index=True))
step_ids = step_ids[np.argsort(indices)]
# Shift the starting time to 0 and compute the maximal time value
min_time = times[:,0].min()
times[:,0] = (times[:,0] - min_time)
end = max(width, (times[:,0]+times[:,1]).max() + 0.01)
cmap = mpl.cm.get_cmap("plasma")
plt.close()
fig, axs = plt.subplots(len(step_ids), sharex=True, gridspec_kw={'hspace': 0})
fig.suptitle(title)
fig.set_size_inches(17.0, len(step_ids))
plt.xlim(-0.01, end)
for i, step in enumerate(step_ids):
step_name = step.decode()
ax = axs[i]
ax.set_ylabel(step_name)
ax.set_ylim(0, 1)
ax.set_yticks([])
ax.set_xlabel("time (s)")
ax.set_xticklabels([])
ax.grid(which="both", axis="x", color="k", linestyle=":")
# Get timings and annotation for the given step
entries_mask = np.squeeze(steps==step)
serie = np.unique(times[entries_mask], axis=0)
annotations = values[entries_mask]
ax.broken_barh(serie, (0, 1), color=cmap(i / len(step_ids)), linewidth=1, alpha=0.66)
if annotate:
for j, (start, width) in enumerate(serie):
annotation = "\n".join([f"{l}: {v}" for l,v in zip(("i", "e", "s"), annotations[j])])
ax.text(start + 0.001 + (0.001 * (j % 2)), 0.55 - (0.1 * (j % 2)), annotation,
horizontalalignment='left', verticalalignment='center')
if save:
plt.savefig(title.lower().translate(str.maketrans(" ", "_")) + ".svg")
Usar envoltorios para la función asignada
Para ejecutar la función asignada en un contexto ansioso, debe envolverlos dentro de una llamada tf.py_function
.
def map_decorator(func):
def wrapper(steps, times, values):
# Use a tf.py_function to prevent auto-graph from compiling the method
return tf.py_function(
func,
inp=(steps, times, values),
Tout=(steps.dtype, times.dtype, values.dtype)
)
return wrapper
Comparación de tuberías
_batch_map_num_items = 50
def dataset_generator_fun(*args):
return TimeMeasuredDataset(num_samples=_batch_map_num_items)
Ingenuo
@map_decorator
def naive_map(steps, times, values):
map_enter = time.perf_counter()
time.sleep(0.001) # Time consuming step
time.sleep(0.0001) # Memory consuming step
map_elapsed = time.perf_counter() - map_enter
return (
tf.concat((steps, [["Map"]]), axis=0),
tf.concat((times, [[map_enter, map_elapsed]]), axis=0),
tf.concat((values, [values[-1]]), axis=0)
)
naive_timeline = timelined_benchmark(
tf.data.Dataset.range(2)
.flat_map(dataset_generator_fun)
.map(naive_map)
.batch(_batch_map_num_items, drop_remainder=True)
.unbatch(),
5
)
WARNING:tensorflow:From /tmp/ipykernel_23983/64197174.py:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_types is deprecated and will be removed in a future version. Instructions for updating: Use output_signature instead WARNING:tensorflow:From /tmp/ipykernel_23983/64197174.py:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_shapes is deprecated and will be removed in a future version. Instructions for updating: Use output_signature instead Execution time: 13.13538893499981
optimizado
@map_decorator
def time_consuming_map(steps, times, values):
map_enter = time.perf_counter()
time.sleep(0.001 * values.shape[0]) # Time consuming step
map_elapsed = time.perf_counter() - map_enter
return (
tf.concat((steps, tf.tile([[["1st map"]]], [steps.shape[0], 1, 1])), axis=1),
tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
)
@map_decorator
def memory_consuming_map(steps, times, values):
map_enter = time.perf_counter()
time.sleep(0.0001 * values.shape[0]) # Memory consuming step
map_elapsed = time.perf_counter() - map_enter
# Use tf.tile to handle batch dimension
return (
tf.concat((steps, tf.tile([[["2nd map"]]], [steps.shape[0], 1, 1])), axis=1),
tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
)
optimized_timeline = timelined_benchmark(
tf.data.Dataset.range(2)
.interleave( # Parallelize data reading
dataset_generator_fun,
num_parallel_calls=tf.data.AUTOTUNE
)
.batch( # Vectorize your mapped function
_batch_map_num_items,
drop_remainder=True)
.map( # Parallelize map transformation
time_consuming_map,
num_parallel_calls=tf.data.AUTOTUNE
)
.cache() # Cache data
.map( # Reduce memory usage
memory_consuming_map,
num_parallel_calls=tf.data.AUTOTUNE
)
.prefetch( # Overlap producer and consumer works
tf.data.AUTOTUNE
)
.unbatch(),
5
)
Execution time: 6.723691489999965
draw_timeline(naive_timeline, "Naive", 15)
draw_timeline(optimized_timeline, "Optimized", 15)