Visualizza su TensorFlow.org | Esegui in Google Colab | Visualizza l'origine su GitHub | Scarica quaderno |
Panoramica
GPU e TPU possono ridurre radicalmente il tempo necessario per eseguire un singolo passaggio di addestramento. Il raggiungimento delle massime prestazioni richiede una pipeline di input efficiente che fornisca i dati per il passaggio successivo prima che il passaggio corrente sia terminato. L'API tf.data
aiuta a creare pipeline di input flessibili ed efficienti. Questo documento mostra come utilizzare l'API tf.data
per creare pipeline di input TensorFlow ad alte prestazioni.
Prima di continuare, consulta la guida alle pipeline di input Build TensorFlow per informazioni su come utilizzare l'API tf.data
.
Risorse
- Crea pipeline di input TensorFlow
-
tf.data.Dataset
- Analizza le prestazioni di
tf.data
con TF Profiler
Impostare
import tensorflow as tf
import time
In questa guida, eseguirai l'iterazione su un set di dati e misurerai le prestazioni. Creare parametri di riferimento delle prestazioni riproducibili può essere difficile. Diversi fattori che influenzano la riproducibilità includono:
- Il carico attuale della CPU
- Il traffico di rete
- Meccanismi complessi, come la cache
Per ottenere un benchmark riproducibile, costruirai un esempio artificiale.
Il set di dati
Inizia con la definizione di una classe che eredita da tf.data.Dataset
chiamata ArtificialDataset
. Questo set di dati:
- Genera
num_samples
campioni (il valore predefinito è 3) - Dorme per qualche tempo prima del primo elemento per simulare l'apertura di un file
- Dorme per un po' prima di produrre ogni elemento per simulare la lettura dei dati da un file
class ArtificialDataset(tf.data.Dataset):
def _generator(num_samples):
# Opening the file
time.sleep(0.03)
for sample_idx in range(num_samples):
# Reading data (line, record) from the file
time.sleep(0.015)
yield (sample_idx,)
def __new__(cls, num_samples=3):
return tf.data.Dataset.from_generator(
cls._generator,
output_signature = tf.TensorSpec(shape = (1,), dtype = tf.int64),
args=(num_samples,)
)
Questo set di dati è simile a quello tf.data.Dataset.range
, aggiungendo un ritardo fisso all'inizio e tra ogni campione.
Il ciclo di formazione
Quindi, scrivi un ciclo di addestramento fittizio che misuri il tempo necessario per eseguire l'iterazione su un set di dati. Il tempo di allenamento è simulato.
def benchmark(dataset, num_epochs=2):
start_time = time.perf_counter()
for epoch_num in range(num_epochs):
for sample in dataset:
# Performing a training step
time.sleep(0.01)
print("Execution time:", time.perf_counter() - start_time)
Ottimizza le prestazioni
Per mostrare come è possibile ottimizzare le prestazioni, migliorerai le prestazioni di ArtificialDataset
.
L'approccio ingenuo
Inizia con una pipeline ingenua senza trucchi, iterando sul set di dati così com'è.
benchmark(ArtificialDataset())
Execution time: 0.26497629899995445
Sotto il cofano, ecco come è stato speso il tuo tempo di esecuzione:
La trama mostra che l'esecuzione di una fase di formazione comporta:
- Apertura di un file se non è stato ancora aperto
- Recupero di una voce di dati dal file
- Utilizzo dei dati per l'allenamento
Tuttavia, in un'implementazione sincrona ingenua come qui, mentre la tua pipeline sta recuperando i dati, il tuo modello è inattivo. Al contrario, mentre il modello è in fase di training, la pipeline di input è inattiva. Il tempo della fase di allenamento è quindi la somma dei tempi di apertura, lettura e formazione.
Le sezioni successive si basano su questa pipeline di input, illustrando le migliori pratiche per la progettazione di pipeline di input TensorFlow ad alte prestazioni.
Prelettura
Il precaricamento si sovrappone alla preelaborazione e all'esecuzione del modello di una fase di addestramento. Mentre il modello esegue la fase di addestramento s
, la pipeline di input legge i dati per la fase s+1
. In questo modo si riduce il tempo di passaggio al massimo (anziché alla somma) dell'allenamento e il tempo necessario per estrarre i dati.
L'API tf.data
fornisce la trasformazione tf.data.Dataset.prefetch
. Può essere utilizzato per disaccoppiare il momento in cui i dati vengono prodotti dal momento in cui i dati vengono consumati. In particolare, la trasformazione utilizza un thread in background e un buffer interno per precaricare gli elementi dal dataset di input prima del momento in cui vengono richiesti. Il numero di elementi da precaricare deve essere uguale (o possibilmente maggiore) al numero di batch consumati da un singolo passaggio di addestramento. È possibile ottimizzare manualmente questo valore o impostarlo su tf.data.AUTOTUNE
, che richiederà al runtime tf.data
di ottimizzare il valore in modo dinamico in fase di esecuzione.
Si noti che la trasformazione di precaricamento offre vantaggi ogni volta che si presenta l'opportunità di sovrapporre il lavoro di un "produttore" con il lavoro di un "consumatore".
benchmark(
ArtificialDataset()
.prefetch(tf.data.AUTOTUNE)
)
Execution time: 0.21731788600027357
Ora, come mostra il grafico del tempo di esecuzione dei dati, mentre la fase di addestramento è in esecuzione per l'esempio 0, la pipeline di input legge i dati per l'esempio 1 e così via.
Parallelizzare l'estrazione dei dati
In un ambiente reale, i dati di input possono essere archiviati in remoto (ad esempio, su Google Cloud Storage o HDFS). Una pipeline di set di dati che funziona bene durante la lettura dei dati in locale potrebbe subire un collo di bottiglia sull'I/O durante la lettura dei dati in remoto a causa delle seguenti differenze tra archiviazione locale e remota:
- Time-to-first-byte : la lettura del primo byte di un file dall'archiviazione remota può richiedere ordini di grandezza più lunghi rispetto all'archiviazione locale.
- Velocità effettiva di lettura : sebbene l'archiviazione remota offra in genere un'ampia larghezza di banda aggregata, la lettura di un singolo file potrebbe essere in grado di utilizzare solo una piccola parte di questa larghezza di banda.
Inoltre, una volta che i byte grezzi sono stati caricati in memoria, potrebbe essere necessario deserializzare e/o decrittare i dati (es. protobuf ), che richiede un calcolo aggiuntivo. Questo sovraccarico è presente indipendentemente dal fatto che i dati siano archiviati localmente o in remoto, ma può essere peggiore nel caso remoto se i dati non vengono precaricati in modo efficace.
Per mitigare l'impatto dei vari costi di estrazione dei dati, la trasformazione tf.data.Dataset.interleave
può essere utilizzata per parallelizzare la fase di caricamento dei dati, intercalando il contenuto di altri set di dati (come lettori di file di dati). Il numero di dataset da sovrapporre può essere specificato dall'argomento cycle_length
, mentre il livello di parallelismo può essere specificato dall'argomento num_parallel_calls
. Simile alla trasformazione di prefetch
, la trasformazione di interleave
supporta tf.data.AUTOTUNE
, che delegherà la decisione sul livello di parallelismo da usare al runtime tf.data
.
Interfoglio sequenziale
Gli argomenti predefiniti della trasformazione tf.data.Dataset.interleave
fanno in modo che intercalino singoli campioni da due set di dati in sequenza.
benchmark(
tf.data.Dataset.range(2)
.interleave(lambda _: ArtificialDataset())
)
Execution time: 0.4987426460002098
Questo grafico del tempo di esecuzione dei dati consente di mostrare il comportamento della trasformazione di interleave
, prelevando campioni alternativamente dai due set di dati disponibili. Tuttavia, qui non è coinvolto alcun miglioramento delle prestazioni.
Interfoglio parallelo
Ora, usa l'argomento num_parallel_calls
della trasformazione interleave
. Questo carica più set di dati in parallelo, riducendo il tempo di attesa per l'apertura dei file.
benchmark(
tf.data.Dataset.range(2)
.interleave(
lambda _: ArtificialDataset(),
num_parallel_calls=tf.data.AUTOTUNE
)
)
Execution time: 0.283668874000341
Questa volta, come mostra il grafico del tempo di esecuzione dei dati, la lettura dei due set di dati è parallela, riducendo il tempo globale di elaborazione dei dati.
Parallelizzare la trasformazione dei dati
Durante la preparazione dei dati, potrebbe essere necessario pre-elaborare gli elementi di input. A tal fine, l'API tf.data
offre la trasformazione tf.data.Dataset.map
, che applica una funzione definita dall'utente a ciascun elemento del set di dati di input. Poiché gli elementi di input sono indipendenti l'uno dall'altro, la pre-elaborazione può essere parallelizzata su più core della CPU. Per renderlo possibile, analogamente alle trasformazioni di prefetch
e interleave
, la trasformazione della map
fornisce l'argomento num_parallel_calls
per specificare il livello di parallelismo.
La scelta del valore migliore per l'argomento num_parallel_calls
dipende dall'hardware, dalle caratteristiche dei dati di addestramento (come dimensioni e forma), dal costo della funzione della mappa e da quale altra elaborazione sta avvenendo contemporaneamente sulla CPU. Una semplice euristica consiste nell'usare il numero di core CPU disponibili. Tuttavia, come per la trasformazione di prefetch
e interleave
, la trasformazione della map
supporta tf.data.AUTOTUNE
che delegherà la decisione su quale livello di parallelismo utilizzare al runtime tf.data
.
def mapped_function(s):
# Do some hard pre-processing
tf.py_function(lambda: time.sleep(0.03), [], ())
return s
Mappatura sequenziale
Inizia usando la trasformazione della map
senza parallelismo come esempio di base.
benchmark(
ArtificialDataset()
.map(mapped_function)
)
Execution time: 0.4505277170001136
Per quanto riguarda l' approccio ingenuo , qui, come mostra la trama, i tempi spesi per l'apertura, la lettura, la pre-elaborazione (mappatura) e le fasi di addestramento si sommano per un'unica iterazione.
Mappatura parallela
Ora, usa la stessa funzione di pre-elaborazione ma applicala in parallelo su più campioni.
benchmark(
ArtificialDataset()
.map(
mapped_function,
num_parallel_calls=tf.data.AUTOTUNE
)
)
Execution time: 0.2839677860001757
Come dimostra il grafico dei dati, i passaggi di pre-elaborazione si sovrappongono, riducendo il tempo complessivo per una singola iterazione.
Memorizzazione nella cache
La trasformazione tf.data.Dataset.cache
può memorizzare nella cache un set di dati, in memoria o nell'archiviazione locale. Ciò salverà alcune operazioni (come l'apertura di file e la lettura dei dati) dall'esecuzione durante ogni epoca.
benchmark(
ArtificialDataset()
.map( # Apply time consuming operations before cache
mapped_function
).cache(
),
5
)
Execution time: 0.3848854380003104
Qui, il grafico del tempo di esecuzione dei dati mostra che quando si memorizza nella cache un set di dati, le trasformazioni prima di quella della cache
(come l'apertura del file e la lettura dei dati) vengono eseguite solo durante la prima epoca. Le epoche successive riutilizzeranno i dati memorizzati nella cache
.
Se la funzione definita dall'utente passata alla trasformazione della map
è costosa, applicare la trasformazione della cache
dopo la trasformazione della map
, purché il set di dati risultante possa ancora rientrare nella memoria o nell'archiviazione locale. Se la funzione definita dall'utente aumenta lo spazio necessario per archiviare il set di dati oltre la capacità della cache, applicala dopo la trasformazione della cache
o considera la pre-elaborazione dei dati prima del processo di addestramento per ridurre l'utilizzo delle risorse.
Mappatura vettorizzatrice
Il richiamo di una funzione definita dall'utente passata alla trasformazione della map
comporta un sovraccarico relativo alla pianificazione e all'esecuzione della funzione definita dall'utente. Vettorizza la funzione definita dall'utente (ovvero, falla funzionare su un batch di input contemporaneamente) e applica la trasformazione batch
prima della trasformazione della map
.
Per illustrare questa buona pratica, il tuo set di dati artificiale non è adatto. Il ritardo di pianificazione è di circa 10 microsecondi (10e-6 secondi), molto inferiore alle decine di millisecondi utilizzati ArtificialDataset
, e quindi il suo impatto è difficile da vedere.
Per questo esempio, usa la funzione di base tf.data.Dataset.range
e semplifica il ciclo di addestramento nella sua forma più semplice.
fast_dataset = tf.data.Dataset.range(10000)
def fast_benchmark(dataset, num_epochs=2):
start_time = time.perf_counter()
for _ in tf.data.Dataset.range(num_epochs):
for _ in dataset:
pass
tf.print("Execution time:", time.perf_counter() - start_time)
def increment(x):
return x+1
Mappatura scalare
fast_benchmark(
fast_dataset
# Apply function one item at a time
.map(increment)
# Batch
.batch(256)
)
Execution time: 0.2712608739998359
Il grafico sopra illustra cosa sta succedendo (con meno campioni) usando il metodo di mappatura scalare. Mostra che la funzione mappata viene applicata a ciascun campione. Sebbene questa funzione sia molto veloce, ha un sovraccarico che influisce sulle prestazioni temporali.
Mappatura vettorizzata
fast_benchmark(
fast_dataset
.batch(256)
# Apply function on a batch of items
# The tf.Tensor.__add__ method already handle batches
.map(increment)
)
Execution time: 0.02737950600021577
Questa volta, la funzione mappata viene chiamata una volta e si applica a un batch di campioni. Come mostra il grafico del tempo di esecuzione dei dati, mentre la funzione potrebbe richiedere più tempo per l'esecuzione, l'overhead appare solo una volta, migliorando le prestazioni complessive del tempo.
Riduzione dell'ingombro di memoria
Numerose trasformazioni, tra cui interleave
, prefetch
e shuffle
, mantengono un buffer interno di elementi. Se la funzione definita dall'utente passata alla trasformazione della map
cambia la dimensione degli elementi, l'ordine della trasformazione della mappa e le trasformazioni che buffer gli elementi influiscono sull'utilizzo della memoria. In generale, scegli l'ordine che si traduce in un footprint di memoria inferiore, a meno che non sia auspicabile un ordine diverso per le prestazioni.
Memorizzazione nella cache di calcoli parziali
Si consiglia di memorizzare nella cache il set di dati dopo la trasformazione della map
a meno che questa trasformazione non renda i dati troppo grandi per essere inseriti nella memoria. È possibile ottenere un compromesso se la funzione mappata può essere suddivisa in due parti: una che richiede tempo e una parte che richiede memoria. In questo caso, puoi concatenare le tue trasformazioni come di seguito:
dataset.map(time_consuming_mapping).cache().map(memory_consuming_mapping)
In questo modo, la parte che richiede tempo viene eseguita solo durante la prima epoca ed eviti di utilizzare troppo spazio nella cache.
Riepilogo delle migliori pratiche
Di seguito è riportato un riepilogo delle migliori pratiche per la progettazione di pipeline di input TensorFlow performanti:
- Utilizzare la trasformazione di
prefetch
per sovrapporre il lavoro di un produttore e di un consumatore - Parallelizzare la trasformazione della lettura dei dati utilizzando la trasformazione
interleave
- Parallelizza la trasformazione della
map
impostando l'argomentonum_parallel_calls
- Utilizzare la
cache
della cache per memorizzare nella cache i dati durante la prima epoca - Vettorizza le funzioni definite dall'utente passate alla trasformazione della
map
- Riduci l'utilizzo della memoria quando si applicano le trasformazioni
interleave
,prefetch
eshuffle
Riproduzione delle figure
Per approfondire la comprensione dell'API tf.data.Dataset
, puoi giocare con le tue pipeline. Di seguito è riportato il codice utilizzato per tracciare le immagini di questa guida. Può essere un buon punto di partenza, mostrando alcune soluzioni alternative per difficoltà comuni come:
- Riproducibilità del tempo di esecuzione
- Esecuzione desiderosa di funzioni mappate
- trasformazione
interleave
richiamabile
import itertools
from collections import defaultdict
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
Il set di dati
Simile ArtificialDataset
, puoi creare un set di dati che restituisca il tempo trascorso in ogni passaggio.
class TimeMeasuredDataset(tf.data.Dataset):
# OUTPUT: (steps, timings, counters)
OUTPUT_TYPES = (tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32)
OUTPUT_SHAPES = ((2, 1), (2, 2), (2, 3))
_INSTANCES_COUNTER = itertools.count() # Number of datasets generated
_EPOCHS_COUNTER = defaultdict(itertools.count) # Number of epochs done for each dataset
def _generator(instance_idx, num_samples):
epoch_idx = next(TimeMeasuredDataset._EPOCHS_COUNTER[instance_idx])
# Opening the file
open_enter = time.perf_counter()
time.sleep(0.03)
open_elapsed = time.perf_counter() - open_enter
for sample_idx in range(num_samples):
# Reading data (line, record) from the file
read_enter = time.perf_counter()
time.sleep(0.015)
read_elapsed = time.perf_counter() - read_enter
yield (
[("Open",), ("Read",)],
[(open_enter, open_elapsed), (read_enter, read_elapsed)],
[(instance_idx, epoch_idx, -1), (instance_idx, epoch_idx, sample_idx)]
)
open_enter, open_elapsed = -1., -1. # Negative values will be filtered
def __new__(cls, num_samples=3):
return tf.data.Dataset.from_generator(
cls._generator,
output_types=cls.OUTPUT_TYPES,
output_shapes=cls.OUTPUT_SHAPES,
args=(next(cls._INSTANCES_COUNTER), num_samples)
)
Questo set di dati fornisce campioni di forma [[2, 1], [2, 2], [2, 3]]
e di tipo [tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32]
. Ogni campione è:
(
[("Open"), ("Read")],
[(t0, d), (t0, d)],
[(i, e, -1), (i, e, s)]
)
Dove:
-
Open
eRead
sono identificatori di passaggi -
t0
è il timestamp di inizio del passaggio corrispondente -
d
è il tempo trascorso nel passaggio corrispondente -
i
è l'indice di istanza -
e
è l'epoca index (numero di volte in cui il set di dati è stato iterato) -
s
è l'indice del campione
Il ciclo di iterazione
Rendi il ciclo di iterazione un po' più complicato per aggregare tutti i tempi. Questo funzionerà solo con set di dati che generano campioni come descritto sopra.
def timelined_benchmark(dataset, num_epochs=2):
# Initialize accumulators
steps_acc = tf.zeros([0, 1], dtype=tf.dtypes.string)
times_acc = tf.zeros([0, 2], dtype=tf.dtypes.float32)
values_acc = tf.zeros([0, 3], dtype=tf.dtypes.int32)
start_time = time.perf_counter()
for epoch_num in range(num_epochs):
epoch_enter = time.perf_counter()
for (steps, times, values) in dataset:
# Record dataset preparation informations
steps_acc = tf.concat((steps_acc, steps), axis=0)
times_acc = tf.concat((times_acc, times), axis=0)
values_acc = tf.concat((values_acc, values), axis=0)
# Simulate training time
train_enter = time.perf_counter()
time.sleep(0.01)
train_elapsed = time.perf_counter() - train_enter
# Record training informations
steps_acc = tf.concat((steps_acc, [["Train"]]), axis=0)
times_acc = tf.concat((times_acc, [(train_enter, train_elapsed)]), axis=0)
values_acc = tf.concat((values_acc, [values[-1]]), axis=0)
epoch_elapsed = time.perf_counter() - epoch_enter
# Record epoch informations
steps_acc = tf.concat((steps_acc, [["Epoch"]]), axis=0)
times_acc = tf.concat((times_acc, [(epoch_enter, epoch_elapsed)]), axis=0)
values_acc = tf.concat((values_acc, [[-1, epoch_num, -1]]), axis=0)
time.sleep(0.001)
tf.print("Execution time:", time.perf_counter() - start_time)
return {"steps": steps_acc, "times": times_acc, "values": values_acc}
Il metodo di tracciatura
Infine, definisci una funzione in grado di tracciare una timeline dati i valori restituiti dalla funzione timelined_benchmark
.
def draw_timeline(timeline, title, width=0.5, annotate=False, save=False):
# Remove invalid entries (negative times, or empty steps) from the timelines
invalid_mask = np.logical_and(timeline['times'] > 0, timeline['steps'] != b'')[:,0]
steps = timeline['steps'][invalid_mask].numpy()
times = timeline['times'][invalid_mask].numpy()
values = timeline['values'][invalid_mask].numpy()
# Get a set of different steps, ordered by the first time they are encountered
step_ids, indices = np.stack(np.unique(steps, return_index=True))
step_ids = step_ids[np.argsort(indices)]
# Shift the starting time to 0 and compute the maximal time value
min_time = times[:,0].min()
times[:,0] = (times[:,0] - min_time)
end = max(width, (times[:,0]+times[:,1]).max() + 0.01)
cmap = mpl.cm.get_cmap("plasma")
plt.close()
fig, axs = plt.subplots(len(step_ids), sharex=True, gridspec_kw={'hspace': 0})
fig.suptitle(title)
fig.set_size_inches(17.0, len(step_ids))
plt.xlim(-0.01, end)
for i, step in enumerate(step_ids):
step_name = step.decode()
ax = axs[i]
ax.set_ylabel(step_name)
ax.set_ylim(0, 1)
ax.set_yticks([])
ax.set_xlabel("time (s)")
ax.set_xticklabels([])
ax.grid(which="both", axis="x", color="k", linestyle=":")
# Get timings and annotation for the given step
entries_mask = np.squeeze(steps==step)
serie = np.unique(times[entries_mask], axis=0)
annotations = values[entries_mask]
ax.broken_barh(serie, (0, 1), color=cmap(i / len(step_ids)), linewidth=1, alpha=0.66)
if annotate:
for j, (start, width) in enumerate(serie):
annotation = "\n".join([f"{l}: {v}" for l,v in zip(("i", "e", "s"), annotations[j])])
ax.text(start + 0.001 + (0.001 * (j % 2)), 0.55 - (0.1 * (j % 2)), annotation,
horizontalalignment='left', verticalalignment='center')
if save:
plt.savefig(title.lower().translate(str.maketrans(" ", "_")) + ".svg")
Usa i wrapper per la funzione mappata
Per eseguire la funzione mappata in un contesto desideroso, devi racchiuderli all'interno di una chiamata tf.py_function
.
def map_decorator(func):
def wrapper(steps, times, values):
# Use a tf.py_function to prevent auto-graph from compiling the method
return tf.py_function(
func,
inp=(steps, times, values),
Tout=(steps.dtype, times.dtype, values.dtype)
)
return wrapper
Confronto tubazioni
_batch_map_num_items = 50
def dataset_generator_fun(*args):
return TimeMeasuredDataset(num_samples=_batch_map_num_items)
Ingenuo
@map_decorator
def naive_map(steps, times, values):
map_enter = time.perf_counter()
time.sleep(0.001) # Time consuming step
time.sleep(0.0001) # Memory consuming step
map_elapsed = time.perf_counter() - map_enter
return (
tf.concat((steps, [["Map"]]), axis=0),
tf.concat((times, [[map_enter, map_elapsed]]), axis=0),
tf.concat((values, [values[-1]]), axis=0)
)
naive_timeline = timelined_benchmark(
tf.data.Dataset.range(2)
.flat_map(dataset_generator_fun)
.map(naive_map)
.batch(_batch_map_num_items, drop_remainder=True)
.unbatch(),
5
)
WARNING:tensorflow:From /tmp/ipykernel_23983/64197174.py:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_types is deprecated and will be removed in a future version. Instructions for updating: Use output_signature instead WARNING:tensorflow:From /tmp/ipykernel_23983/64197174.py:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_shapes is deprecated and will be removed in a future version. Instructions for updating: Use output_signature instead Execution time: 13.13538893499981
Ottimizzato
@map_decorator
def time_consuming_map(steps, times, values):
map_enter = time.perf_counter()
time.sleep(0.001 * values.shape[0]) # Time consuming step
map_elapsed = time.perf_counter() - map_enter
return (
tf.concat((steps, tf.tile([[["1st map"]]], [steps.shape[0], 1, 1])), axis=1),
tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
)
@map_decorator
def memory_consuming_map(steps, times, values):
map_enter = time.perf_counter()
time.sleep(0.0001 * values.shape[0]) # Memory consuming step
map_elapsed = time.perf_counter() - map_enter
# Use tf.tile to handle batch dimension
return (
tf.concat((steps, tf.tile([[["2nd map"]]], [steps.shape[0], 1, 1])), axis=1),
tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
)
optimized_timeline = timelined_benchmark(
tf.data.Dataset.range(2)
.interleave( # Parallelize data reading
dataset_generator_fun,
num_parallel_calls=tf.data.AUTOTUNE
)
.batch( # Vectorize your mapped function
_batch_map_num_items,
drop_remainder=True)
.map( # Parallelize map transformation
time_consuming_map,
num_parallel_calls=tf.data.AUTOTUNE
)
.cache() # Cache data
.map( # Reduce memory usage
memory_consuming_map,
num_parallel_calls=tf.data.AUTOTUNE
)
.prefetch( # Overlap producer and consumer works
tf.data.AUTOTUNE
)
.unbatch(),
5
)
Execution time: 6.723691489999965
draw_timeline(naive_timeline, "Naive", 15)
draw_timeline(optimized_timeline, "Optimized", 15)