Conseils sur l'amélioration des performances

Ce document fournit des conseils de performances spécifiques aux ensembles de données TensorFlow (TFDS). Notez que TFDS fournit des ensembles de données sous forme d'objets tf.data.Dataset , donc les conseils du guide tf.data s'appliquent toujours.

Ensembles de données de référence

Utilisez tfds.benchmark(ds) pour comparer n'importe quel objet tf.data.Dataset .

Assurez-vous d'indiquer batch_size= pour normaliser les résultats (par exemple 100 iter/sec -> 3200 ex/sec). Cela fonctionne avec n'importe quel itérable (par exemple tfds.benchmark(tfds.as_numpy(ds)) ).

ds = tfds.load('mnist', split='train').batch(32).prefetch()
# Display some benchmark statistics
tfds.benchmark(ds, batch_size=32)
# Second iteration is much faster, due to auto-caching
tfds.benchmark(ds, batch_size=32)

Petits ensembles de données (moins de 1 Go)

Tous les ensembles de données TFDS stockent les données sur disque au format TFRecord . Pour les petits ensembles de données (par exemple MNIST, CIFAR-10/-100), la lecture à partir de .tfrecord peut ajouter une surcharge importante.

Comme ces ensembles de données tiennent en mémoire, il est possible d’améliorer considérablement les performances en mettant en cache ou en préchargeant l’ensemble de données. Notez que TFDS met automatiquement en cache les petits ensembles de données (la section suivante contient les détails).

Mise en cache de l'ensemble de données

Voici un exemple de pipeline de données qui met explicitement en cache l'ensemble de données après avoir normalisé les images.

def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label


ds, ds_info = tfds.load(
    'mnist',
    split='train',
    as_supervised=True,  # returns `(img, label)` instead of dict(image=, ...)
    with_info=True,
)
# Applying normalization before `ds.cache()` to re-use it.
# Note: Random transformations (e.g. images augmentations) should be applied
# after both `ds.cache()` (to avoid caching randomness) and `ds.batch()` (for
# vectorization [1]).
ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()
# For true randomness, we set the shuffle buffer to the full dataset size.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
# Batch after shuffling to get unique batches at each epoch.
ds = ds.batch(128)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

Lors de l'itération sur cet ensemble de données, la deuxième itération sera beaucoup plus rapide que la première grâce à la mise en cache.

Mise en cache automatique

Par défaut, TFDS met automatiquement en cache (avec ds.cache() ) les ensembles de données qui satisfont aux contraintes suivantes :

  • La taille totale de l'ensemble de données (toutes les divisions) est définie et < 250 Mio
  • shuffle_files est désactivé ou un seul fragment est lu

Il est possible de désactiver la mise en cache automatique en passant try_autocaching=False à tfds.ReadConfig dans tfds.load . Jetez un œil à la documentation du catalogue d'ensembles de données pour voir si un ensemble de données spécifique utilisera le cache automatique.

Chargement des données complètes en tant que Tensor unique

Si votre ensemble de données tient dans la mémoire, vous pouvez également charger l'ensemble de données complet sous la forme d'un seul tableau Tensor ou NumPy. Il est possible de le faire en définissant batch_size=-1 pour regrouper tous les exemples dans un seul tf.Tensor . Utilisez ensuite tfds.as_numpy pour la conversion de tf.Tensor en np.array .

(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
    'mnist',
    split=['train', 'test'],
    batch_size=-1,
    as_supervised=True,
))

Grands ensembles de données

Les ensembles de données volumineux sont fragmentés (divisés en plusieurs fichiers) et ne tiennent généralement pas en mémoire. Ils ne doivent donc pas être mis en cache.

Mélange et entraînement

Pendant l'entraînement, il est important de bien mélanger les données. Des données mal mélangées peuvent entraîner une moindre précision de l'entraînement.

En plus d'utiliser ds.shuffle pour mélanger les enregistrements, vous devez également définir shuffle_files=True pour obtenir un bon comportement de lecture aléatoire pour les ensembles de données plus volumineux fragmentés en plusieurs fichiers. Sinon, les époques liront les fragments dans le même ordre et les données ne seront donc pas véritablement aléatoires.

ds = tfds.load('imagenet2012', split='train', shuffle_files=True)

De plus, lorsque shuffle_files=True , TFDS désactive options.deterministic , ce qui peut améliorer légèrement les performances. Pour obtenir un brassage déterministe, il est possible de désactiver cette fonctionnalité avec tfds.ReadConfig : soit en définissant read_config.shuffle_seed , soit en écrasant read_config.options.deterministic .

Partagez automatiquement vos données entre les collaborateurs (TF)

Lors de la formation sur plusieurs travailleurs, vous pouvez utiliser l'argument input_context de tfds.ReadConfig , afin que chaque travailleur lise un sous-ensemble des données.

input_context = tf.distribute.InputContext(
    input_pipeline_id=1,  # Worker id
    num_input_pipelines=4,  # Total number of workers
)
read_config = tfds.ReadConfig(
    input_context=input_context,
)
ds = tfds.load('dataset', split='train', read_config=read_config)

Ceci est complémentaire à l’API subsplit. Tout d'abord, l'API subplit est appliquée : train[:50%] est converti en une liste de fichiers à lire. Ensuite, une opération ds.shard() est appliquée sur ces fichiers. Par exemple, lors de l'utilisation train[:50%] avec num_input_pipelines=2 , chacun des 2 travailleurs lira 1/4 des données.

Lorsque shuffle_files=True , les fichiers sont mélangés au sein d'un seul travailleur, mais pas entre les travailleurs. Chaque travailleur lira le même sous-ensemble de fichiers entre les époques.

Partagez automatiquement vos données entre les collaborateurs (Jax)

Avec Jax, vous pouvez utiliser l'API tfds.split_for_jax_process ou tfds.even_splits pour distribuer vos données entre les nœuds de calcul. Consultez le guide de l'API divisée .

split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)

tfds.split_for_jax_process est un simple alias pour :

# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]

Décodage d'image plus rapide

Par défaut, TFDS décode automatiquement les images. Cependant, il existe des cas où il peut être plus performant d'ignorer le décodage de l'image avec tfds.decode.SkipDecoding et d'appliquer manuellement l'opération tf.io.decode_image :

Le code des deux exemples est disponible dans le guide de décodage .

Ignorer les fonctionnalités inutilisées

Si vous n'utilisez qu'un sous-ensemble de fonctionnalités, il est possible d'ignorer complètement certaines fonctionnalités. Si votre ensemble de données comporte de nombreuses fonctionnalités inutilisées, ne pas les décoder peut améliorer considérablement les performances. Voir https://www.tensorflow.org/datasets/decode#only_decode_a_sub-set_of_the_features

tf.data utilise toute ma RAM !

Si vous êtes limité en RAM ou si vous chargez de nombreux ensembles de données en parallèle tout en utilisant tf.data , voici quelques options qui peuvent vous aider :

Remplacer la taille du tampon

builder.as_dataset(
  read_config=tfds.ReadConfig(
    ...
    override_buffer_size=1024,  # Save quite a bit of RAM.
  ),
  ...
)

Cela remplace le buffer_size transmis à TFRecordDataset (ou équivalent) : https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset#args

Utilisez tf.data.Dataset.with_options pour arrêter les comportements magiques

https://www.tensorflow.org/api_docs/python/tf/data/Dataset#with_options

options = tf.data.Options()

# Stop magic stuff that eats up RAM:
options.autotune.enabled = False
options.experimental_distribute.auto_shard_policy = (
  tf.data.experimental.AutoShardPolicy.OFF)
options.experimental_optimization.inject_prefetch = False

data = data.with_options(options)