Questo documento fornisce suggerimenti sulle prestazioni specifici di TensorFlow Datasets (TFDS). Tieni presente che TFDS fornisce set di dati come oggetti tf.data.Dataset
, quindi i consigli della guida tf.data
sono ancora validi.
Set di dati di riferimento
Utilizza tfds.benchmark(ds)
per confrontare qualsiasi oggetto tf.data.Dataset
.
Assicurati di indicare batch_size=
per normalizzare i risultati (ad esempio 100 iter/sec -> 3200 ex/sec). Funziona con qualsiasi iterabile (ad esempio tfds.benchmark(tfds.as_numpy(ds))
).
ds = tfds.load('mnist', split='train').batch(32).prefetch()
# Display some benchmark statistics
tfds.benchmark(ds, batch_size=32)
# Second iteration is much faster, due to auto-caching
tfds.benchmark(ds, batch_size=32)
Piccoli set di dati (meno di 1 GB)
Tutti i set di dati TFDS memorizzano i dati su disco nel formato TFRecord
. Per set di dati di piccole dimensioni (ad esempio MNIST, CIFAR-10/-100), la lettura da .tfrecord
può aggiungere un sovraccarico significativo.
Poiché questi set di dati si adattano alla memoria, è possibile migliorare significativamente le prestazioni memorizzando nella cache o precaricando il set di dati. Tieni presente che TFDS memorizza automaticamente nella cache piccoli set di dati (la sezione seguente contiene i dettagli).
Memorizzazione nella cache del set di dati
Ecco un esempio di una pipeline di dati che memorizza esplicitamente nella cache il set di dati dopo aver normalizzato le immagini.
def normalize_img(image, label):
"""Normalizes images: `uint8` -> `float32`."""
return tf.cast(image, tf.float32) / 255., label
ds, ds_info = tfds.load(
'mnist',
split='train',
as_supervised=True, # returns `(img, label)` instead of dict(image=, ...)
with_info=True,
)
# Applying normalization before `ds.cache()` to re-use it.
# Note: Random transformations (e.g. images augmentations) should be applied
# after both `ds.cache()` (to avoid caching randomness) and `ds.batch()` (for
# vectorization [1]).
ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()
# For true randomness, we set the shuffle buffer to the full dataset size.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
# Batch after shuffling to get unique batches at each epoch.
ds = ds.batch(128)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
Quando si esegue l'iterazione su questo set di dati, la seconda iterazione sarà molto più veloce della prima grazie alla memorizzazione nella cache.
Memorizzazione nella cache automatica
Per impostazione predefinita, TFDS memorizza automaticamente nella cache (con ds.cache()
) i set di dati che soddisfano i seguenti vincoli:
- La dimensione totale del set di dati (tutte le suddivisioni) è definita e < 250 MiB
-
shuffle_files
è disabilitato o viene letto solo un singolo frammento
È possibile disattivare la memorizzazione nella cache automatica passando try_autocaching=False
a tfds.ReadConfig
in tfds.load
. Dai un'occhiata alla documentazione del catalogo dei set di dati per vedere se un set di dati specifico utilizzerà la cache automatica.
Caricamento dei dati completi come un singolo tensore
Se il set di dati rientra nella memoria, puoi anche caricare l'intero set di dati come un singolo array Tensor o NumPy. È possibile farlo impostando batch_size=-1
per raggruppare tutti gli esempi in un unico tf.Tensor
. Quindi utilizzare tfds.as_numpy
per la conversione da tf.Tensor
a np.array
.
(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
'mnist',
split=['train', 'test'],
batch_size=-1,
as_supervised=True,
))
Set di dati di grandi dimensioni
I set di dati di grandi dimensioni vengono suddivisi in partizioni (divisi in più file) e in genere non entrano nella memoria, quindi non devono essere memorizzati nella cache.
Shuffle e allenamento
Durante l'allenamento, è importante mescolare bene i dati: i dati mescolati in modo inadeguato possono comportare una minore precisione dell'allenamento.
Oltre a utilizzare ds.shuffle
per mescolare i record, dovresti anche impostare shuffle_files=True
per ottenere un buon comportamento di mescolamento per set di dati più grandi suddivisi in più file. Altrimenti, le epoche leggeranno i frammenti nello stesso ordine e quindi i dati non saranno veramente randomizzati.
ds = tfds.load('imagenet2012', split='train', shuffle_files=True)
Inoltre, quando shuffle_files=True
, TFDS disabilita options.deterministic
, il che può fornire un leggero aumento delle prestazioni. Per ottenere un mescolamento deterministico, è possibile disattivare questa funzionalità con tfds.ReadConfig
: impostando read_config.shuffle_seed
o sovrascrivendo read_config.options.deterministic
.
Suddividi automaticamente i tuoi dati tra i lavoratori (TF)
Durante l'addestramento su più lavoratori, puoi utilizzare l'argomento input_context
di tfds.ReadConfig
, in modo che ogni lavoratore leggerà un sottoinsieme di dati.
input_context = tf.distribute.InputContext(
input_pipeline_id=1, # Worker id
num_input_pipelines=4, # Total number of workers
)
read_config = tfds.ReadConfig(
input_context=input_context,
)
ds = tfds.load('dataset', split='train', read_config=read_config)
Questo è complementare all'API subsplit. Per prima cosa viene applicata l'API subplit: train[:50%]
viene convertito in un elenco di file da leggere. Quindi, a tali file viene applicata un'operazione ds.shard()
. Ad esempio, quando si utilizza train[:50%]
con num_input_pipelines=2
, ciascuno dei 2 lavoratori leggerà 1/4 dei dati.
Quando shuffle_files=True
, i file vengono mescolati all'interno di un lavoratore, ma non tra lavoratori. Ogni lavoratore leggerà lo stesso sottoinsieme di file tra le epoche.
Suddividi automaticamente i tuoi dati tra i lavoratori (Jax)
Con Jax, puoi utilizzare l'API tfds.split_for_jax_process
o tfds.even_splits
per distribuire i tuoi dati tra i lavoratori. Consulta la guida all'API divisa .
split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)
tfds.split_for_jax_process
è un semplice alias per:
# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]
Decodifica delle immagini più veloce
Per impostazione predefinita, TFDS decodifica automaticamente le immagini. Tuttavia, ci sono casi in cui può essere più efficace saltare la decodifica dell'immagine con tfds.decode.SkipDecoding
e applicare manualmente l'operazione tf.io.decode_image
:
- Quando si filtrano gli esempi (con
tf.data.Dataset.filter
), per decodificare le immagini dopo che gli esempi sono stati filtrati. - Quando si ritagliano le immagini, per utilizzare il fuso
tf.image.decode_and_crop_jpeg
op.
Il codice per entrambi gli esempi è disponibile nella guida alla decodifica .
Salta le funzionalità non utilizzate
Se utilizzi solo un sottoinsieme delle funzionalità, è possibile ignorarne completamente alcune. Se il tuo set di dati ha molte funzionalità inutilizzate, non decodificarle può migliorare significativamente le prestazioni. Vedi https://www.tensorflow.org/datasets/decode#only_decode_a_sub-set_of_the_features
tf.data utilizza tutta la mia RAM!
Se hai una RAM limitata o se stai caricando molti set di dati in parallelo mentre usi tf.data
, ecco alcune opzioni che possono aiutarti:
Sostituisci la dimensione del buffer
builder.as_dataset(
read_config=tfds.ReadConfig(
...
override_buffer_size=1024, # Save quite a bit of RAM.
),
...
)
Ciò sovrascrive buffer_size
passato a TFRecordDataset
(o equivalente): https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset#args
Utilizza tf.data.Dataset.with_options per interrompere comportamenti magici
https://www.tensorflow.org/api_docs/python/tf/data/Dataset#with_options
options = tf.data.Options()
# Stop magic stuff that eats up RAM:
options.autotune.enabled = False
options.experimental_distribute.auto_shard_policy = (
tf.data.experimental.AutoShardPolicy.OFF)
options.experimental_optimization.inject_prefetch = False
data = data.with_options(options)