Voir sur TensorFlow.org | Exécuter dans Google Colab | Voir la source sur GitHub | Télécharger le cahier |
Aperçu
Les GPU et les TPU peuvent réduire considérablement le temps nécessaire pour exécuter une seule étape de formation. L'obtention de performances optimales nécessite un pipeline d'entrée efficace qui fournit des données pour l'étape suivante avant la fin de l'étape en cours. L'API tf.data
permet de créer des pipelines d'entrée flexibles et efficaces. Ce document montre comment utiliser l'API tf.data
pour créer des pipelines d'entrée TensorFlow hautement performants.
Avant de continuer, consultez le guide Build TensorFlow input pipelines pour savoir comment utiliser l'API tf.data
.
Ressources
- Créer des pipelines d'entrée TensorFlow
- API
tf.data.Dataset
- Analysez les performances de
tf.data
avec le TF Profiler
Installer
import tensorflow as tf
import time
Tout au long de ce guide, vous allez parcourir un ensemble de données et mesurer les performances. Faire des benchmarks de performance reproductibles peut être difficile. Les différents facteurs affectant la reproductibilité comprennent :
- La charge CPU actuelle
- Le trafic réseau
- Mécanismes complexes, tels que le cache
Pour obtenir un benchmark reproductible, vous allez construire un exemple artificiel.
Le jeu de données
Commencez par définir une classe héritant de tf.data.Dataset
appelée ArtificialDataset
. Cet ensemble de données :
- Génère
num_samples
échantillons (la valeur par défaut est 3) - Dort pendant un certain temps avant le premier élément pour simuler l'ouverture d'un fichier
- Dort pendant un certain temps avant de produire chaque élément pour simuler la lecture de données à partir d'un fichier
class ArtificialDataset(tf.data.Dataset):
def _generator(num_samples):
# Opening the file
time.sleep(0.03)
for sample_idx in range(num_samples):
# Reading data (line, record) from the file
time.sleep(0.015)
yield (sample_idx,)
def __new__(cls, num_samples=3):
return tf.data.Dataset.from_generator(
cls._generator,
output_signature = tf.TensorSpec(shape = (1,), dtype = tf.int64),
args=(num_samples,)
)
Ce jeu de données est similaire à celui de tf.data.Dataset.range
, ajoutant un délai fixe au début et entre chaque échantillon.
La boucle d'entraînement
Ensuite, écrivez une boucle d'entraînement factice qui mesure le temps nécessaire pour itérer sur un ensemble de données. Le temps de formation est simulé.
def benchmark(dataset, num_epochs=2):
start_time = time.perf_counter()
for epoch_num in range(num_epochs):
for sample in dataset:
# Performing a training step
time.sleep(0.01)
print("Execution time:", time.perf_counter() - start_time)
Optimiser les performances
Pour montrer comment les performances peuvent être optimisées, vous allez améliorer les performances de ArtificialDataset
.
L'approche naïve
Commencez avec un pipeline naïf sans aucune astuce, itérant sur l'ensemble de données tel quel.
benchmark(ArtificialDataset())
Execution time: 0.26497629899995445
Sous le capot, voici comment votre temps d'exécution a été dépensé :
Le graphique montre que l'exécution d'une étape d'entraînement implique :
- Ouvrir un fichier s'il n'a pas encore été ouvert
- Récupération d'une entrée de données à partir du fichier
- Utilisation des données pour la formation
Cependant, dans une implémentation synchrone naïve comme ici, pendant que votre pipeline récupère les données, votre modèle est inactif. Inversement, pendant que votre modèle s'entraîne, le pipeline d'entrée est inactif. Le temps de pas d'apprentissage est donc la somme des temps d'ouverture, de lecture et d'apprentissage.
Les sections suivantes s'appuient sur ce pipeline d'entrée, illustrant les bonnes pratiques pour la conception de pipelines d'entrée TensorFlow performants.
Prélecture
La prélecture chevauche le prétraitement et l'exécution du modèle d'une étape d'apprentissage. Pendant que le modèle exécute l'étape d'apprentissage s
, le pipeline d'entrée lit les données pour l'étape s+1
. Cela réduit le temps de pas au maximum (par opposition à la somme) de la formation et le temps nécessaire pour extraire les données.
L'API tf.data
fournit la transformation tf.data.Dataset.prefetch
. Il peut être utilisé pour découpler le moment où les données sont produites du moment où les données sont consommées. En particulier, la transformation utilise un thread d'arrière-plan et un tampon interne pour préextraire les éléments de l'ensemble de données d'entrée avant le moment où ils sont demandés. Le nombre d'éléments à prérécupérer doit être égal (ou éventuellement supérieur) au nombre de lots consommés par une seule étape d'apprentissage. Vous pouvez soit ajuster manuellement cette valeur, soit la définir sur tf.data.AUTOTUNE
, ce qui demandera à l'environnement d'exécution tf.data
d'ajuster dynamiquement la valeur au moment de l'exécution.
Notez que la transformation de prélecture offre des avantages chaque fois qu'il y a une opportunité de chevaucher le travail d'un "producteur" avec le travail d'un "consommateur".
benchmark(
ArtificialDataset()
.prefetch(tf.data.AUTOTUNE)
)
Execution time: 0.21731788600027357
Maintenant, comme le montre le tracé du temps d'exécution des données, pendant que l'étape d'apprentissage s'exécute pour l'échantillon 0, le pipeline d'entrée lit les données pour l'échantillon 1, et ainsi de suite.
Parallélisation de l'extraction de données
Dans un environnement réel, les données d'entrée peuvent être stockées à distance (par exemple, sur Google Cloud Storage ou HDFS). Un pipeline d'ensemble de données qui fonctionne bien lors de la lecture de données localement peut présenter un goulot d'étranglement sur les E/S lors de la lecture de données à distance en raison des différences suivantes entre le stockage local et le stockage distant :
- Time-to-first-byte : la lecture du premier octet d'un fichier à partir d'un stockage distant peut prendre des ordres de grandeur plus longs qu'à partir d'un stockage local.
- Débit de lecture : alors que le stockage distant offre généralement une large bande passante agrégée, la lecture d'un seul fichier peut n'utiliser qu'une petite fraction de cette bande passante.
De plus, une fois les octets bruts chargés en mémoire, il peut également être nécessaire de désérialiser et/ou de déchiffrer les données (par exemple protobuf ), ce qui nécessite des calculs supplémentaires. Cette surcharge est présente indépendamment du fait que les données soient stockées localement ou à distance, mais peut être pire dans le cas distant si les données ne sont pas prélues efficacement.
Pour atténuer l'impact des divers frais généraux d'extraction de données, la transformation tf.data.Dataset.interleave
peut être utilisée pour paralléliser l'étape de chargement des données, en entrelaçant le contenu d'autres ensembles de données (tels que les lecteurs de fichiers de données). Le nombre d'ensembles de données à chevaucher peut être spécifié par l'argument cycle_length
, tandis que le niveau de parallélisme peut être spécifié par l'argument num_parallel_calls
. Semblable à la transformation de prefetch
, la transformation d' interleave
prend en charge tf.data.AUTOTUNE
, qui délègue la décision sur le niveau de parallélisme à utiliser au runtime tf.data
.
Entrelacement séquentiel
Les arguments par défaut de la transformation tf.data.Dataset.interleave
lui permettent d'entrelacer séquentiellement des échantillons uniques de deux ensembles de données.
benchmark(
tf.data.Dataset.range(2)
.interleave(lambda _: ArtificialDataset())
)
Execution time: 0.4987426460002098
Ce tracé du temps d'exécution des données permet d'exposer le comportement de la transformation interleave
, en récupérant alternativement des échantillons à partir des deux ensembles de données disponibles. Cependant, aucune amélioration des performances n'est impliquée ici.
Entrelacement parallèle
Maintenant, utilisez l'argument num_parallel_calls
de la transformation interleave
. Cela charge plusieurs ensembles de données en parallèle, ce qui réduit le temps d'attente pour l'ouverture des fichiers.
benchmark(
tf.data.Dataset.range(2)
.interleave(
lambda _: ArtificialDataset(),
num_parallel_calls=tf.data.AUTOTUNE
)
)
Execution time: 0.283668874000341
Cette fois, comme le montre le tracé du temps d'exécution des données, la lecture des deux jeux de données est parallélisée, ce qui réduit le temps de traitement global des données.
Paralléliser la transformation des données
Lors de la préparation des données, les éléments d'entrée peuvent devoir être prétraités. À cette fin, l'API tf.data
propose la transformation tf.data.Dataset.map
, qui applique une fonction définie par l'utilisateur à chaque élément de l'ensemble de données d'entrée. Étant donné que les éléments d'entrée sont indépendants les uns des autres, le prétraitement peut être parallélisé sur plusieurs cœurs de processeur. Pour rendre cela possible, de la même manière que les transformations prefetch
et interleave
, la transformation map
fournit l'argument num_parallel_calls
pour spécifier le niveau de parallélisme.
Le choix de la meilleure valeur pour l'argument num_parallel_calls
dépend de votre matériel, des caractéristiques de vos données d'entraînement (telles que leur taille et leur forme), du coût de votre fonction de carte et des autres traitements qui se produisent sur le CPU en même temps. Une heuristique simple consiste à utiliser le nombre de cœurs de processeur disponibles. Cependant, comme pour la transformation de prefetch
et d' interleave
, la transformation de map
prend en charge tf.data.AUTOTUNE
qui déléguera la décision sur le niveau de parallélisme à utiliser au runtime tf.data
.
def mapped_function(s):
# Do some hard pre-processing
tf.py_function(lambda: time.sleep(0.03), [], ())
return s
Cartographie séquentielle
Commencez par utiliser la transformation de map
sans parallélisme comme exemple de référence.
benchmark(
ArtificialDataset()
.map(mapped_function)
)
Execution time: 0.4505277170001136
Quant à l' approche naïve , ici, comme le montre l'intrigue, les temps passés pour les étapes d'ouverture, de lecture, de pré-traitement (cartographie) et d'apprentissage s'additionnent pour une seule itération.
Cartographie parallèle
Maintenant, utilisez la même fonction de prétraitement mais appliquez-la en parallèle sur plusieurs échantillons.
benchmark(
ArtificialDataset()
.map(
mapped_function,
num_parallel_calls=tf.data.AUTOTUNE
)
)
Execution time: 0.2839677860001757
Comme le montre le diagramme de données, les étapes de prétraitement se chevauchent, ce qui réduit le temps global pour une seule itération.
Mise en cache
La transformation tf.data.Dataset.cache
peut mettre en cache un jeu de données, soit en mémoire, soit sur le stockage local. Cela évitera que certaines opérations (telles que l'ouverture de fichiers et la lecture de données) ne soient exécutées à chaque époque.
benchmark(
ArtificialDataset()
.map( # Apply time consuming operations before cache
mapped_function
).cache(
),
5
)
Execution time: 0.3848854380003104
Ici, le tracé du temps d'exécution des données montre que lorsque vous mettez en cache un ensemble de données, les transformations avant celle du cache
(comme l'ouverture du fichier et la lecture des données) ne sont exécutées que pendant la première époque. Les époques suivantes réutiliseront les données mises en cache par la transformation du cache
.
Si la fonction définie par l'utilisateur transmise à la transformation de map
est coûteuse, appliquez la transformation de cache
après la transformation de map
tant que l'ensemble de données résultant peut toujours tenir dans la mémoire ou le stockage local. Si la fonction définie par l'utilisateur augmente l'espace requis pour stocker l'ensemble de données au-delà de la capacité du cache, appliquez-la après la transformation du cache
ou envisagez de prétraiter vos données avant votre tâche d'entraînement afin de réduire l'utilisation des ressources.
Cartographie de vectorisation
L'appel d'une fonction définie par l'utilisateur transmise à la transformation de map
entraîne une surcharge liée à la planification et à l'exécution de la fonction définie par l'utilisateur. Vectorisez la fonction définie par l'utilisateur (c'est-à-dire faites-la fonctionner sur un lot d'entrées à la fois) et appliquez la transformation batch
avant la transformation de la map
.
Pour illustrer cette bonne pratique, votre jeu de données artificiel n'est pas adapté. Le délai de planification est d'environ 10 microsecondes (10e-6 secondes), bien inférieur aux dizaines de millisecondes utilisées dans le ArtificialDataset
, et son impact est donc difficile à voir.
Pour cet exemple, utilisez la fonction de base tf.data.Dataset.range
et simplifiez la boucle de formation dans sa forme la plus simple.
fast_dataset = tf.data.Dataset.range(10000)
def fast_benchmark(dataset, num_epochs=2):
start_time = time.perf_counter()
for _ in tf.data.Dataset.range(num_epochs):
for _ in dataset:
pass
tf.print("Execution time:", time.perf_counter() - start_time)
def increment(x):
return x+1
Cartographie scalaire
fast_benchmark(
fast_dataset
# Apply function one item at a time
.map(increment)
# Batch
.batch(256)
)
Execution time: 0.2712608739998359
Le graphique ci-dessus illustre ce qui se passe (avec moins d'échantillons) en utilisant la méthode de cartographie scalaire. Il montre que la fonction mappée est appliquée pour chaque échantillon. Bien que cette fonction soit très rapide, elle a une surcharge qui a un impact sur les performances temporelles.
Cartographie vectorisée
fast_benchmark(
fast_dataset
.batch(256)
# Apply function on a batch of items
# The tf.Tensor.__add__ method already handle batches
.map(increment)
)
Execution time: 0.02737950600021577
Cette fois, la fonction mappée est appelée une fois et s'applique à un lot d'échantillon. Comme le montre le tracé du temps d'exécution des données, bien que la fonction puisse prendre plus de temps à s'exécuter, la surcharge n'apparaît qu'une seule fois, ce qui améliore les performances globales en matière de temps.
Réduction de l'empreinte mémoire
Un certain nombre de transformations, notamment interleave
, prefetch
et shuffle
, maintiennent un tampon interne d'éléments. Si la fonction définie par l'utilisateur transmise à la transformation de map
modifie la taille des éléments, l'ordre de la transformation de carte et des transformations qui mettent les éléments en mémoire tampon affecte l'utilisation de la mémoire. En général, choisissez l'ordre qui réduit l'empreinte mémoire, à moins qu'un ordre différent ne soit souhaitable pour les performances.
Mise en cache des calculs partiels
Il est recommandé de mettre en cache le jeu de données après la transformation de la map
, sauf si cette transformation rend les données trop volumineuses pour tenir en mémoire. Un compromis peut être atteint si votre fonction mappée peut être divisée en deux parties : une partie qui prend du temps et une partie qui consomme de la mémoire. Dans ce cas, vous pouvez enchaîner vos transformations comme ci-dessous :
dataset.map(time_consuming_mapping).cache().map(memory_consuming_mapping)
De cette façon, la partie chronophage n'est exécutée que pendant la première époque, et vous évitez d'utiliser trop d'espace de cache.
Résumé des meilleures pratiques
Voici un résumé des bonnes pratiques pour concevoir des pipelines d'entrée TensorFlow performants :
- Utiliser la transformation
prefetch
pour superposer le travail d'un producteur et d'un consommateur - Paralléliser la transformation de lecture de données à l'aide de la transformation d'
interleave
- Parallélisez la transformation de la
map
en définissant l'argumentnum_parallel_calls
- Utilisez la
cache
de cache pour mettre en cache les données en mémoire pendant la première époque - Vectoriser les fonctions définies par l'utilisateur transmises à la transformation de
map
- Réduisez l'utilisation de la mémoire lors de l'application des transformations
interleave
,prefetch
etshuffle
Reproduisant les chiffres
Pour approfondir la compréhension de l'API tf.data.Dataset
, vous pouvez jouer avec vos propres pipelines. Vous trouverez ci-dessous le code utilisé pour tracer les images de ce guide. Cela peut être un bon point de départ, montrant quelques solutions de contournement pour les difficultés courantes telles que :
- Reproductibilité du temps d'exécution
- Fonctions mappées exécutées avec impatience
- transformation
interleave
appelable
import itertools
from collections import defaultdict
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
Le jeu de données
Semblable au ArtificialDataset
, vous pouvez créer un ensemble de données renvoyant le temps passé à chaque étape.
class TimeMeasuredDataset(tf.data.Dataset):
# OUTPUT: (steps, timings, counters)
OUTPUT_TYPES = (tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32)
OUTPUT_SHAPES = ((2, 1), (2, 2), (2, 3))
_INSTANCES_COUNTER = itertools.count() # Number of datasets generated
_EPOCHS_COUNTER = defaultdict(itertools.count) # Number of epochs done for each dataset
def _generator(instance_idx, num_samples):
epoch_idx = next(TimeMeasuredDataset._EPOCHS_COUNTER[instance_idx])
# Opening the file
open_enter = time.perf_counter()
time.sleep(0.03)
open_elapsed = time.perf_counter() - open_enter
for sample_idx in range(num_samples):
# Reading data (line, record) from the file
read_enter = time.perf_counter()
time.sleep(0.015)
read_elapsed = time.perf_counter() - read_enter
yield (
[("Open",), ("Read",)],
[(open_enter, open_elapsed), (read_enter, read_elapsed)],
[(instance_idx, epoch_idx, -1), (instance_idx, epoch_idx, sample_idx)]
)
open_enter, open_elapsed = -1., -1. # Negative values will be filtered
def __new__(cls, num_samples=3):
return tf.data.Dataset.from_generator(
cls._generator,
output_types=cls.OUTPUT_TYPES,
output_shapes=cls.OUTPUT_SHAPES,
args=(next(cls._INSTANCES_COUNTER), num_samples)
)
Cet ensemble de données fournit des échantillons de forme [[2, 1], [2, 2], [2, 3]]
et de type [tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32]
. Chaque échantillon est :
(
[("Open"), ("Read")],
[(t0, d), (t0, d)],
[(i, e, -1), (i, e, s)]
)
Où:
-
Open
etRead
sont des identifiants d'étapes -
t0
est l'horodatage du démarrage de l'étape correspondante -
d
est le temps passé dans l'étape correspondante -
i
est l'indice d'instance -
e
est l'indice d'époque (nombre de fois où l'ensemble de données a été itéré) -
s
est l'indice d'échantillon
La boucle d'itération
Rendez la boucle d'itération un peu plus compliquée pour agréger tous les timings. Cela ne fonctionnera qu'avec des ensembles de données générant des échantillons comme indiqué ci-dessus.
def timelined_benchmark(dataset, num_epochs=2):
# Initialize accumulators
steps_acc = tf.zeros([0, 1], dtype=tf.dtypes.string)
times_acc = tf.zeros([0, 2], dtype=tf.dtypes.float32)
values_acc = tf.zeros([0, 3], dtype=tf.dtypes.int32)
start_time = time.perf_counter()
for epoch_num in range(num_epochs):
epoch_enter = time.perf_counter()
for (steps, times, values) in dataset:
# Record dataset preparation informations
steps_acc = tf.concat((steps_acc, steps), axis=0)
times_acc = tf.concat((times_acc, times), axis=0)
values_acc = tf.concat((values_acc, values), axis=0)
# Simulate training time
train_enter = time.perf_counter()
time.sleep(0.01)
train_elapsed = time.perf_counter() - train_enter
# Record training informations
steps_acc = tf.concat((steps_acc, [["Train"]]), axis=0)
times_acc = tf.concat((times_acc, [(train_enter, train_elapsed)]), axis=0)
values_acc = tf.concat((values_acc, [values[-1]]), axis=0)
epoch_elapsed = time.perf_counter() - epoch_enter
# Record epoch informations
steps_acc = tf.concat((steps_acc, [["Epoch"]]), axis=0)
times_acc = tf.concat((times_acc, [(epoch_enter, epoch_elapsed)]), axis=0)
values_acc = tf.concat((values_acc, [[-1, epoch_num, -1]]), axis=0)
time.sleep(0.001)
tf.print("Execution time:", time.perf_counter() - start_time)
return {"steps": steps_acc, "times": times_acc, "values": values_acc}
La méthode de traçage
Enfin, définissez une fonction capable de tracer une chronologie en fonction des valeurs renvoyées par la fonction timelined_benchmark
.
def draw_timeline(timeline, title, width=0.5, annotate=False, save=False):
# Remove invalid entries (negative times, or empty steps) from the timelines
invalid_mask = np.logical_and(timeline['times'] > 0, timeline['steps'] != b'')[:,0]
steps = timeline['steps'][invalid_mask].numpy()
times = timeline['times'][invalid_mask].numpy()
values = timeline['values'][invalid_mask].numpy()
# Get a set of different steps, ordered by the first time they are encountered
step_ids, indices = np.stack(np.unique(steps, return_index=True))
step_ids = step_ids[np.argsort(indices)]
# Shift the starting time to 0 and compute the maximal time value
min_time = times[:,0].min()
times[:,0] = (times[:,0] - min_time)
end = max(width, (times[:,0]+times[:,1]).max() + 0.01)
cmap = mpl.cm.get_cmap("plasma")
plt.close()
fig, axs = plt.subplots(len(step_ids), sharex=True, gridspec_kw={'hspace': 0})
fig.suptitle(title)
fig.set_size_inches(17.0, len(step_ids))
plt.xlim(-0.01, end)
for i, step in enumerate(step_ids):
step_name = step.decode()
ax = axs[i]
ax.set_ylabel(step_name)
ax.set_ylim(0, 1)
ax.set_yticks([])
ax.set_xlabel("time (s)")
ax.set_xticklabels([])
ax.grid(which="both", axis="x", color="k", linestyle=":")
# Get timings and annotation for the given step
entries_mask = np.squeeze(steps==step)
serie = np.unique(times[entries_mask], axis=0)
annotations = values[entries_mask]
ax.broken_barh(serie, (0, 1), color=cmap(i / len(step_ids)), linewidth=1, alpha=0.66)
if annotate:
for j, (start, width) in enumerate(serie):
annotation = "\n".join([f"{l}: {v}" for l,v in zip(("i", "e", "s"), annotations[j])])
ax.text(start + 0.001 + (0.001 * (j % 2)), 0.55 - (0.1 * (j % 2)), annotation,
horizontalalignment='left', verticalalignment='center')
if save:
plt.savefig(title.lower().translate(str.maketrans(" ", "_")) + ".svg")
Utiliser des wrappers pour la fonction mappée
Pour exécuter une fonction mappée dans un contexte impatient, vous devez les envelopper dans un appel tf.py_function
.
def map_decorator(func):
def wrapper(steps, times, values):
# Use a tf.py_function to prevent auto-graph from compiling the method
return tf.py_function(
func,
inp=(steps, times, values),
Tout=(steps.dtype, times.dtype, values.dtype)
)
return wrapper
Comparaison des pipelines
_batch_map_num_items = 50
def dataset_generator_fun(*args):
return TimeMeasuredDataset(num_samples=_batch_map_num_items)
Naïve
@map_decorator
def naive_map(steps, times, values):
map_enter = time.perf_counter()
time.sleep(0.001) # Time consuming step
time.sleep(0.0001) # Memory consuming step
map_elapsed = time.perf_counter() - map_enter
return (
tf.concat((steps, [["Map"]]), axis=0),
tf.concat((times, [[map_enter, map_elapsed]]), axis=0),
tf.concat((values, [values[-1]]), axis=0)
)
naive_timeline = timelined_benchmark(
tf.data.Dataset.range(2)
.flat_map(dataset_generator_fun)
.map(naive_map)
.batch(_batch_map_num_items, drop_remainder=True)
.unbatch(),
5
)
WARNING:tensorflow:From /tmp/ipykernel_23983/64197174.py:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_types is deprecated and will be removed in a future version. Instructions for updating: Use output_signature instead WARNING:tensorflow:From /tmp/ipykernel_23983/64197174.py:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_shapes is deprecated and will be removed in a future version. Instructions for updating: Use output_signature instead Execution time: 13.13538893499981
Optimisé
@map_decorator
def time_consuming_map(steps, times, values):
map_enter = time.perf_counter()
time.sleep(0.001 * values.shape[0]) # Time consuming step
map_elapsed = time.perf_counter() - map_enter
return (
tf.concat((steps, tf.tile([[["1st map"]]], [steps.shape[0], 1, 1])), axis=1),
tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
)
@map_decorator
def memory_consuming_map(steps, times, values):
map_enter = time.perf_counter()
time.sleep(0.0001 * values.shape[0]) # Memory consuming step
map_elapsed = time.perf_counter() - map_enter
# Use tf.tile to handle batch dimension
return (
tf.concat((steps, tf.tile([[["2nd map"]]], [steps.shape[0], 1, 1])), axis=1),
tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
)
optimized_timeline = timelined_benchmark(
tf.data.Dataset.range(2)
.interleave( # Parallelize data reading
dataset_generator_fun,
num_parallel_calls=tf.data.AUTOTUNE
)
.batch( # Vectorize your mapped function
_batch_map_num_items,
drop_remainder=True)
.map( # Parallelize map transformation
time_consuming_map,
num_parallel_calls=tf.data.AUTOTUNE
)
.cache() # Cache data
.map( # Reduce memory usage
memory_consuming_map,
num_parallel_calls=tf.data.AUTOTUNE
)
.prefetch( # Overlap producer and consumer works
tf.data.AUTOTUNE
)
.unbatch(),
5
)
Execution time: 6.723691489999965
draw_timeline(naive_timeline, "Naive", 15)
draw_timeline(optimized_timeline, "Optimized", 15)