Migliori prestazioni con l'API tf.data

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza l'origine su GitHub Scarica quaderno

Panoramica

GPU e TPU possono ridurre radicalmente il tempo necessario per eseguire un singolo passaggio di addestramento. Il raggiungimento delle massime prestazioni richiede una pipeline di input efficiente che fornisca i dati per il passaggio successivo prima che il passaggio corrente sia terminato. L'API tf.data aiuta a creare pipeline di input flessibili ed efficienti. Questo documento mostra come utilizzare l'API tf.data per creare pipeline di input TensorFlow ad alte prestazioni.

Prima di continuare, consulta la guida alle pipeline di input Build TensorFlow per informazioni su come utilizzare l'API tf.data .

Risorse

Impostare

import tensorflow as tf

import time

In questa guida, eseguirai l'iterazione su un set di dati e misurerai le prestazioni. Creare parametri di riferimento delle prestazioni riproducibili può essere difficile. Diversi fattori che influenzano la riproducibilità includono:

  • Il carico attuale della CPU
  • Il traffico di rete
  • Meccanismi complessi, come la cache

Per ottenere un benchmark riproducibile, costruirai un esempio artificiale.

Il set di dati

Inizia con la definizione di una classe che eredita da tf.data.Dataset chiamata ArtificialDataset . Questo set di dati:

  • Genera num_samples campioni (il valore predefinito è 3)
  • Dorme per qualche tempo prima del primo elemento per simulare l'apertura di un file
  • Dorme per un po' prima di produrre ogni elemento per simulare la lettura dei dati da un file
class ArtificialDataset(tf.data.Dataset):
    def _generator(num_samples):
        # Opening the file
        time.sleep(0.03)

        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            time.sleep(0.015)

            yield (sample_idx,)

    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_signature = tf.TensorSpec(shape = (1,), dtype = tf.int64),
            args=(num_samples,)
        )

Questo set di dati è simile a quello tf.data.Dataset.range , aggiungendo un ritardo fisso all'inizio e tra ogni campione.

Il ciclo di formazione

Quindi, scrivi un ciclo di addestramento fittizio che misuri il tempo necessario per eseguire l'iterazione su un set di dati. Il tempo di allenamento è simulato.

def benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        for sample in dataset:
            # Performing a training step
            time.sleep(0.01)
    print("Execution time:", time.perf_counter() - start_time)

Ottimizza le prestazioni

Per mostrare come è possibile ottimizzare le prestazioni, migliorerai le prestazioni di ArtificialDataset .

L'approccio ingenuo

Inizia con una pipeline ingenua senza trucchi, iterando sul set di dati così com'è.

benchmark(ArtificialDataset())
Execution time: 0.26497629899995445

Sotto il cofano, ecco come è stato speso il tuo tempo di esecuzione:

Grafico del tempo di esecuzione dei dati: un metodo ingenuo

La trama mostra che l'esecuzione di una fase di formazione comporta:

  • Apertura di un file se non è stato ancora aperto
  • Recupero di una voce di dati dal file
  • Utilizzo dei dati per l'allenamento

Tuttavia, in un'implementazione sincrona ingenua come qui, mentre la tua pipeline sta recuperando i dati, il tuo modello è inattivo. Al contrario, mentre il modello è in fase di training, la pipeline di input è inattiva. Il tempo della fase di allenamento è quindi la somma dei tempi di apertura, lettura e formazione.

Le sezioni successive si basano su questa pipeline di input, illustrando le migliori pratiche per la progettazione di pipeline di input TensorFlow ad alte prestazioni.

Prelettura

Il precaricamento si sovrappone alla preelaborazione e all'esecuzione del modello di una fase di addestramento. Mentre il modello esegue la fase di addestramento s , la pipeline di input legge i dati per la fase s+1 . In questo modo si riduce il tempo di passaggio al massimo (anziché alla somma) dell'allenamento e il tempo necessario per estrarre i dati.

L'API tf.data fornisce la trasformazione tf.data.Dataset.prefetch . Può essere utilizzato per disaccoppiare il momento in cui i dati vengono prodotti dal momento in cui i dati vengono consumati. In particolare, la trasformazione utilizza un thread in background e un buffer interno per precaricare gli elementi dal dataset di input prima del momento in cui vengono richiesti. Il numero di elementi da precaricare deve essere uguale (o possibilmente maggiore) al numero di batch consumati da un singolo passaggio di addestramento. È possibile ottimizzare manualmente questo valore o impostarlo su tf.data.AUTOTUNE , che richiederà al runtime tf.data di ottimizzare il valore in modo dinamico in fase di esecuzione.

Si noti che la trasformazione di precaricamento offre vantaggi ogni volta che si presenta l'opportunità di sovrapporre il lavoro di un "produttore" con il lavoro di un "consumatore".

benchmark(
    ArtificialDataset()
    .prefetch(tf.data.AUTOTUNE)
)
Execution time: 0.21731788600027357

Grafico del tempo di esecuzione dei dati - metodo di prelettura

Ora, come mostra il grafico del tempo di esecuzione dei dati, mentre la fase di addestramento è in esecuzione per l'esempio 0, la pipeline di input legge i dati per l'esempio 1 e così via.

Parallelizzare l'estrazione dei dati

In un ambiente reale, i dati di input possono essere archiviati in remoto (ad esempio, su Google Cloud Storage o HDFS). Una pipeline di set di dati che funziona bene durante la lettura dei dati in locale potrebbe subire un collo di bottiglia sull'I/O durante la lettura dei dati in remoto a causa delle seguenti differenze tra archiviazione locale e remota:

  • Time-to-first-byte : la lettura del primo byte di un file dall'archiviazione remota può richiedere ordini di grandezza più lunghi rispetto all'archiviazione locale.
  • Velocità effettiva di lettura : sebbene l'archiviazione remota offra in genere un'ampia larghezza di banda aggregata, la lettura di un singolo file potrebbe essere in grado di utilizzare solo una piccola parte di questa larghezza di banda.

Inoltre, una volta che i byte grezzi sono stati caricati in memoria, potrebbe essere necessario deserializzare e/o decrittare i dati (es. protobuf ), che richiede un calcolo aggiuntivo. Questo sovraccarico è presente indipendentemente dal fatto che i dati siano archiviati localmente o in remoto, ma può essere peggiore nel caso remoto se i dati non vengono precaricati in modo efficace.

Per mitigare l'impatto dei vari costi di estrazione dei dati, la trasformazione tf.data.Dataset.interleave può essere utilizzata per parallelizzare la fase di caricamento dei dati, intercalando il contenuto di altri set di dati (come lettori di file di dati). Il numero di dataset da sovrapporre può essere specificato dall'argomento cycle_length , mentre il livello di parallelismo può essere specificato dall'argomento num_parallel_calls . Simile alla trasformazione di prefetch , la trasformazione di interleave supporta tf.data.AUTOTUNE , che delegherà la decisione sul livello di parallelismo da usare al runtime tf.data .

Interfoglio sequenziale

Gli argomenti predefiniti della trasformazione tf.data.Dataset.interleave fanno in modo che intercalino singoli campioni da due set di dati in sequenza.

benchmark(
    tf.data.Dataset.range(2)
    .interleave(lambda _: ArtificialDataset())
)
Execution time: 0.4987426460002098

Grafico del tempo di esecuzione dei dati - interleave sequenziale

Questo grafico del tempo di esecuzione dei dati consente di mostrare il comportamento della trasformazione di interleave , prelevando campioni alternativamente dai due set di dati disponibili. Tuttavia, qui non è coinvolto alcun miglioramento delle prestazioni.

Interfoglio parallelo

Ora, usa l'argomento num_parallel_calls della trasformazione interleave . Questo carica più set di dati in parallelo, riducendo il tempo di attesa per l'apertura dei file.

benchmark(
    tf.data.Dataset.range(2)
    .interleave(
        lambda _: ArtificialDataset(),
        num_parallel_calls=tf.data.AUTOTUNE
    )
)
Execution time: 0.283668874000341

Grafico del tempo di esecuzione dei dati - metodo di interleave parallelo

Questa volta, come mostra il grafico del tempo di esecuzione dei dati, la lettura dei due set di dati è parallela, riducendo il tempo globale di elaborazione dei dati.

Parallelizzare la trasformazione dei dati

Durante la preparazione dei dati, potrebbe essere necessario pre-elaborare gli elementi di input. A tal fine, l'API tf.data offre la trasformazione tf.data.Dataset.map , che applica una funzione definita dall'utente a ciascun elemento del set di dati di input. Poiché gli elementi di input sono indipendenti l'uno dall'altro, la pre-elaborazione può essere parallelizzata su più core della CPU. Per renderlo possibile, analogamente alle trasformazioni di prefetch e interleave , la trasformazione della map fornisce l'argomento num_parallel_calls per specificare il livello di parallelismo.

La scelta del valore migliore per l'argomento num_parallel_calls dipende dall'hardware, dalle caratteristiche dei dati di addestramento (come dimensioni e forma), dal costo della funzione della mappa e da quale altra elaborazione sta avvenendo contemporaneamente sulla CPU. Una semplice euristica consiste nell'usare il numero di core CPU disponibili. Tuttavia, come per la trasformazione di prefetch e interleave , la trasformazione della map supporta tf.data.AUTOTUNE che delegherà la decisione su quale livello di parallelismo utilizzare al runtime tf.data .

def mapped_function(s):
    # Do some hard pre-processing
    tf.py_function(lambda: time.sleep(0.03), [], ())
    return s

Mappatura sequenziale

Inizia usando la trasformazione della map senza parallelismo come esempio di base.

benchmark(
    ArtificialDataset()
    .map(mapped_function)
)
Execution time: 0.4505277170001136

Grafico del tempo di esecuzione dei dati - metodo di mappatura sequenziale

Per quanto riguarda l' approccio ingenuo , qui, come mostra la trama, i tempi spesi per l'apertura, la lettura, la pre-elaborazione (mappatura) e le fasi di addestramento si sommano per un'unica iterazione.

Mappatura parallela

Ora, usa la stessa funzione di pre-elaborazione ma applicala in parallelo su più campioni.

benchmark(
    ArtificialDataset()
    .map(
        mapped_function,
        num_parallel_calls=tf.data.AUTOTUNE
    )
)
Execution time: 0.2839677860001757

Tempo di esecuzione dei dati - mappatura parallela

Come dimostra il grafico dei dati, i passaggi di pre-elaborazione si sovrappongono, riducendo il tempo complessivo per una singola iterazione.

Memorizzazione nella cache

La trasformazione tf.data.Dataset.cache può memorizzare nella cache un set di dati, in memoria o nell'archiviazione locale. Ciò salverà alcune operazioni (come l'apertura di file e la lettura dei dati) dall'esecuzione durante ogni epoca.

benchmark(
    ArtificialDataset()
    .map(  # Apply time consuming operations before cache
        mapped_function
    ).cache(
    ),
    5
)
Execution time: 0.3848854380003104

Tempo di esecuzione dei dati: metodo del set di dati memorizzato nella cache

Qui, il grafico del tempo di esecuzione dei dati mostra che quando si memorizza nella cache un set di dati, le trasformazioni prima di quella della cache (come l'apertura del file e la lettura dei dati) vengono eseguite solo durante la prima epoca. Le epoche successive riutilizzeranno i dati memorizzati nella cache .

Se la funzione definita dall'utente passata alla trasformazione della map è costosa, applicare la trasformazione della cache dopo la trasformazione della map , purché il set di dati risultante possa ancora rientrare nella memoria o nell'archiviazione locale. Se la funzione definita dall'utente aumenta lo spazio necessario per archiviare il set di dati oltre la capacità della cache, applicala dopo la trasformazione della cache o considera la pre-elaborazione dei dati prima del processo di addestramento per ridurre l'utilizzo delle risorse.

Mappatura vettorizzatrice

Il richiamo di una funzione definita dall'utente passata alla trasformazione della map comporta un sovraccarico relativo alla pianificazione e all'esecuzione della funzione definita dall'utente. Vettorizza la funzione definita dall'utente (ovvero, falla funzionare su un batch di input contemporaneamente) e applica la trasformazione batch prima della trasformazione della map .

Per illustrare questa buona pratica, il tuo set di dati artificiale non è adatto. Il ritardo di pianificazione è di circa 10 microsecondi (10e-6 secondi), molto inferiore alle decine di millisecondi utilizzati ArtificialDataset , e quindi il suo impatto è difficile da vedere.

Per questo esempio, usa la funzione di base tf.data.Dataset.range e semplifica il ciclo di addestramento nella sua forma più semplice.

fast_dataset = tf.data.Dataset.range(10000)

def fast_benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for _ in tf.data.Dataset.range(num_epochs):
        for _ in dataset:
            pass
    tf.print("Execution time:", time.perf_counter() - start_time)

def increment(x):
    return x+1

Mappatura scalare

fast_benchmark(
    fast_dataset
    # Apply function one item at a time
    .map(increment)
    # Batch
    .batch(256)
)
Execution time: 0.2712608739998359

Tempo di esecuzione dei dati - metodo della mappa scalare

Il grafico sopra illustra cosa sta succedendo (con meno campioni) usando il metodo di mappatura scalare. Mostra che la funzione mappata viene applicata a ciascun campione. Sebbene questa funzione sia molto veloce, ha un sovraccarico che influisce sulle prestazioni temporali.

Mappatura vettorizzata

fast_benchmark(
    fast_dataset
    .batch(256)
    # Apply function on a batch of items
    # The tf.Tensor.__add__ method already handle batches
    .map(increment)
)
Execution time: 0.02737950600021577

Tempo di esecuzione dei dati - metodo della mappa vettoriale

Questa volta, la funzione mappata viene chiamata una volta e si applica a un batch di campioni. Come mostra il grafico del tempo di esecuzione dei dati, mentre la funzione potrebbe richiedere più tempo per l'esecuzione, l'overhead appare solo una volta, migliorando le prestazioni complessive del tempo.

Riduzione dell'ingombro di memoria

Numerose trasformazioni, tra cui interleave , prefetch e shuffle , mantengono un buffer interno di elementi. Se la funzione definita dall'utente passata alla trasformazione della map cambia la dimensione degli elementi, l'ordine della trasformazione della mappa e le trasformazioni che buffer gli elementi influiscono sull'utilizzo della memoria. In generale, scegli l'ordine che si traduce in un footprint di memoria inferiore, a meno che non sia auspicabile un ordine diverso per le prestazioni.

Memorizzazione nella cache di calcoli parziali

Si consiglia di memorizzare nella cache il set di dati dopo la trasformazione della map a meno che questa trasformazione non renda i dati troppo grandi per essere inseriti nella memoria. È possibile ottenere un compromesso se la funzione mappata può essere suddivisa in due parti: una che richiede tempo e una parte che richiede memoria. In questo caso, puoi concatenare le tue trasformazioni come di seguito:

dataset.map(time_consuming_mapping).cache().map(memory_consuming_mapping)

In questo modo, la parte che richiede tempo viene eseguita solo durante la prima epoca ed eviti di utilizzare troppo spazio nella cache.

Riepilogo delle migliori pratiche

Di seguito è riportato un riepilogo delle migliori pratiche per la progettazione di pipeline di input TensorFlow performanti:

Riproduzione delle figure

Per approfondire la comprensione dell'API tf.data.Dataset , puoi giocare con le tue pipeline. Di seguito è riportato il codice utilizzato per tracciare le immagini di questa guida. Può essere un buon punto di partenza, mostrando alcune soluzioni alternative per difficoltà comuni come:

  • Riproducibilità del tempo di esecuzione
  • Esecuzione desiderosa di funzioni mappate
  • trasformazione interleave richiamabile
import itertools
from collections import defaultdict

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

Il set di dati

Simile ArtificialDataset , puoi creare un set di dati che restituisca il tempo trascorso in ogni passaggio.

class TimeMeasuredDataset(tf.data.Dataset):
    # OUTPUT: (steps, timings, counters)
    OUTPUT_TYPES = (tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32)
    OUTPUT_SHAPES = ((2, 1), (2, 2), (2, 3))

    _INSTANCES_COUNTER = itertools.count()  # Number of datasets generated
    _EPOCHS_COUNTER = defaultdict(itertools.count)  # Number of epochs done for each dataset

    def _generator(instance_idx, num_samples):
        epoch_idx = next(TimeMeasuredDataset._EPOCHS_COUNTER[instance_idx])

        # Opening the file
        open_enter = time.perf_counter()
        time.sleep(0.03)
        open_elapsed = time.perf_counter() - open_enter

        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            read_enter = time.perf_counter()
            time.sleep(0.015)
            read_elapsed = time.perf_counter() - read_enter

            yield (
                [("Open",), ("Read",)],
                [(open_enter, open_elapsed), (read_enter, read_elapsed)],
                [(instance_idx, epoch_idx, -1), (instance_idx, epoch_idx, sample_idx)]
            )
            open_enter, open_elapsed = -1., -1.  # Negative values will be filtered


    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_types=cls.OUTPUT_TYPES,
            output_shapes=cls.OUTPUT_SHAPES,
            args=(next(cls._INSTANCES_COUNTER), num_samples)
        )

Questo set di dati fornisce campioni di forma [[2, 1], [2, 2], [2, 3]] e di tipo [tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32] . Ogni campione è:

(
  [("Open"), ("Read")],
  [(t0, d), (t0, d)],
  [(i, e, -1), (i, e, s)]
)

Dove:

  • Open e Read sono identificatori di passaggi
  • t0 è il timestamp di inizio del passaggio corrispondente
  • d è il tempo trascorso nel passaggio corrispondente
  • i è l'indice di istanza
  • e è l'epoca index (numero di volte in cui il set di dati è stato iterato)
  • s è l'indice del campione

Il ciclo di iterazione

Rendi il ciclo di iterazione un po' più complicato per aggregare tutti i tempi. Questo funzionerà solo con set di dati che generano campioni come descritto sopra.

def timelined_benchmark(dataset, num_epochs=2):
    # Initialize accumulators
    steps_acc = tf.zeros([0, 1], dtype=tf.dtypes.string)
    times_acc = tf.zeros([0, 2], dtype=tf.dtypes.float32)
    values_acc = tf.zeros([0, 3], dtype=tf.dtypes.int32)

    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        epoch_enter = time.perf_counter()
        for (steps, times, values) in dataset:
            # Record dataset preparation informations
            steps_acc = tf.concat((steps_acc, steps), axis=0)
            times_acc = tf.concat((times_acc, times), axis=0)
            values_acc = tf.concat((values_acc, values), axis=0)

            # Simulate training time
            train_enter = time.perf_counter()
            time.sleep(0.01)
            train_elapsed = time.perf_counter() - train_enter

            # Record training informations
            steps_acc = tf.concat((steps_acc, [["Train"]]), axis=0)
            times_acc = tf.concat((times_acc, [(train_enter, train_elapsed)]), axis=0)
            values_acc = tf.concat((values_acc, [values[-1]]), axis=0)

        epoch_elapsed = time.perf_counter() - epoch_enter
        # Record epoch informations
        steps_acc = tf.concat((steps_acc, [["Epoch"]]), axis=0)
        times_acc = tf.concat((times_acc, [(epoch_enter, epoch_elapsed)]), axis=0)
        values_acc = tf.concat((values_acc, [[-1, epoch_num, -1]]), axis=0)
        time.sleep(0.001)

    tf.print("Execution time:", time.perf_counter() - start_time)
    return {"steps": steps_acc, "times": times_acc, "values": values_acc}

Il metodo di tracciatura

Infine, definisci una funzione in grado di tracciare una timeline dati i valori restituiti dalla funzione timelined_benchmark .

def draw_timeline(timeline, title, width=0.5, annotate=False, save=False):
    # Remove invalid entries (negative times, or empty steps) from the timelines
    invalid_mask = np.logical_and(timeline['times'] > 0, timeline['steps'] != b'')[:,0]
    steps = timeline['steps'][invalid_mask].numpy()
    times = timeline['times'][invalid_mask].numpy()
    values = timeline['values'][invalid_mask].numpy()

    # Get a set of different steps, ordered by the first time they are encountered
    step_ids, indices = np.stack(np.unique(steps, return_index=True))
    step_ids = step_ids[np.argsort(indices)]

    # Shift the starting time to 0 and compute the maximal time value
    min_time = times[:,0].min()
    times[:,0] = (times[:,0] - min_time)
    end = max(width, (times[:,0]+times[:,1]).max() + 0.01)

    cmap = mpl.cm.get_cmap("plasma")
    plt.close()
    fig, axs = plt.subplots(len(step_ids), sharex=True, gridspec_kw={'hspace': 0})
    fig.suptitle(title)
    fig.set_size_inches(17.0, len(step_ids))
    plt.xlim(-0.01, end)

    for i, step in enumerate(step_ids):
        step_name = step.decode()
        ax = axs[i]
        ax.set_ylabel(step_name)
        ax.set_ylim(0, 1)
        ax.set_yticks([])
        ax.set_xlabel("time (s)")
        ax.set_xticklabels([])
        ax.grid(which="both", axis="x", color="k", linestyle=":")

        # Get timings and annotation for the given step
        entries_mask = np.squeeze(steps==step)
        serie = np.unique(times[entries_mask], axis=0)
        annotations = values[entries_mask]

        ax.broken_barh(serie, (0, 1), color=cmap(i / len(step_ids)), linewidth=1, alpha=0.66)
        if annotate:
            for j, (start, width) in enumerate(serie):
                annotation = "\n".join([f"{l}: {v}" for l,v in zip(("i", "e", "s"), annotations[j])])
                ax.text(start + 0.001 + (0.001 * (j % 2)), 0.55 - (0.1 * (j % 2)), annotation,
                        horizontalalignment='left', verticalalignment='center')
    if save:
        plt.savefig(title.lower().translate(str.maketrans(" ", "_")) + ".svg")

Usa i wrapper per la funzione mappata

Per eseguire la funzione mappata in un contesto desideroso, devi racchiuderli all'interno di una chiamata tf.py_function .

def map_decorator(func):
    def wrapper(steps, times, values):
        # Use a tf.py_function to prevent auto-graph from compiling the method
        return tf.py_function(
            func,
            inp=(steps, times, values),
            Tout=(steps.dtype, times.dtype, values.dtype)
        )
    return wrapper

Confronto tubazioni

_batch_map_num_items = 50

def dataset_generator_fun(*args):
    return TimeMeasuredDataset(num_samples=_batch_map_num_items)

Ingenuo

@map_decorator
def naive_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001)  # Time consuming step
    time.sleep(0.0001)  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, [["Map"]]), axis=0),
        tf.concat((times, [[map_enter, map_elapsed]]), axis=0),
        tf.concat((values, [values[-1]]), axis=0)
    )

naive_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .flat_map(dataset_generator_fun)
    .map(naive_map)
    .batch(_batch_map_num_items, drop_remainder=True)
    .unbatch(),
    5
)
WARNING:tensorflow:From /tmp/ipykernel_23983/64197174.py:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_types is deprecated and will be removed in a future version.
Instructions for updating:
Use output_signature instead
WARNING:tensorflow:From /tmp/ipykernel_23983/64197174.py:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_shapes is deprecated and will be removed in a future version.
Instructions for updating:
Use output_signature instead
Execution time: 13.13538893499981

Ottimizzato

@map_decorator
def time_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001 * values.shape[0])  # Time consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, tf.tile([[["1st map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


@map_decorator
def memory_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.0001 * values.shape[0])  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    # Use tf.tile to handle batch dimension
    return (
        tf.concat((steps, tf.tile([[["2nd map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


optimized_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .interleave(  # Parallelize data reading
        dataset_generator_fun,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .batch(  # Vectorize your mapped function
        _batch_map_num_items,
        drop_remainder=True)
    .map(  # Parallelize map transformation
        time_consuming_map,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .cache()  # Cache data
    .map(  # Reduce memory usage
        memory_consuming_map,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .prefetch(  # Overlap producer and consumer works
        tf.data.AUTOTUNE
    )
    .unbatch(),
    5
)
Execution time: 6.723691489999965
draw_timeline(naive_timeline, "Naive", 15)

png

draw_timeline(optimized_timeline, "Optimized", 15)

png