TFDS e determinismo

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza su GitHub Scarica taccuino

Questo documento spiega:

  • Le garanzie TFDS sul determinismo
  • In quale ordine TFDS legge gli esempi
  • Vari avvertimenti e trucchi

Impostare

Set di dati

È necessario un po' di contesto per capire come TFDS legge i dati.

Durante la generazione, TFDS scrivono i dati originali in standardizzati .tfrecord file. Per i grandi insiemi di dati, più .tfrecord file vengono creati, ciascuno dei quali contiene più esempi. Chiamiamo ogni .tfrecord il file una scheggia.

Questa guida utilizza imagenet che ha 1024 shard:

import re
import tensorflow_datasets as tfds

imagenet = tfds.builder('imagenet2012')

num_shards = imagenet.info.splits['train'].num_shards
num_examples = imagenet.info.splits['train'].num_examples
print(f'imagenet has {num_shards} shards ({num_examples} examples)')
imagenet has 1024 shards (1281167 examples)

Trovare gli ID degli esempi di set di dati

Puoi saltare alla sezione seguente se vuoi conoscere solo il determinismo.

Ogni esempio set di dati è identificato univocamente da un id (es 'imagenet2012-train.tfrecord-01023-of-01024__32' ). È possibile recuperare questa id passando read_config.add_tfds_id = True che aggiungerà un 'tfds_id' chiave nel dict dal tf.data.Dataset .

In questo tutorial, definiamo un piccolo programma di utilità che stamperà gli ID di esempio del set di dati (convertito in intero per essere più leggibile dall'uomo):

Determinismo durante la lettura

Questa sezione spiega garanzia deterministim di tfds.load .

Con shuffle_files=False (predefinito)

Con TFDS predefinite producono esempi deterministico ( shuffle_files=False )

# Same as: imagenet.as_dataset(split='train').take(20)
print_ex_ids(imagenet, split='train', take=20)
print_ex_ids(imagenet, split='train', take=20)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]

Per le prestazioni, TFDS leggere più frammenti allo stesso tempo utilizzando tf.data.Dataset.interleave . Vediamo in questo esempio che TFDS passare al frammento 2 dopo aver letto 16 esempi ( ..., 14, 15, 1251, 1252, ... ). Maggiori informazioni sull'interfoglio qui sotto.

Allo stesso modo, anche l'API subsplit è deterministica:

print_ex_ids(imagenet, split='train[67%:84%]', take=20)
print_ex_ids(imagenet, split='train[67%:84%]', take=20)
[858382, 858383, 858384, 858385, 858386, 858387, 858388, 858389, 858390, 858391, 858392, 858393, 858394, 858395, 858396, 858397, 859533, 859534, 859535, 859536]
[858382, 858383, 858384, 858385, 858386, 858387, 858388, 858389, 858390, 858391, 858392, 858393, 858394, 858395, 858396, 858397, 859533, 859534, 859535, 859536]

Se ti alleni per più di un epoca, la messa a punto di cui sopra non è raccomandata come tutte le epoche leggeranno i frammenti nello stesso ordine (così casualità è limitata ai ds = ds.shuffle(buffer) la dimensione del buffer).

Con shuffle_files=True

Con shuffle_files=True , frammenti vengono mescolate per ogni epoca, quindi la lettura non è più deterministica.

print_ex_ids(imagenet, split='train', shuffle_files=True, take=20)
print_ex_ids(imagenet, split='train', shuffle_files=True, take=20)
[568017, 329050, 329051, 329052, 329053, 329054, 329056, 329055, 568019, 568020, 568021, 568022, 568023, 568018, 568025, 568024, 568026, 568028, 568030, 568031]
[43790, 43791, 43792, 43793, 43796, 43794, 43797, 43798, 43795, 43799, 43800, 43801, 43802, 43803, 43804, 43805, 43806, 43807, 43809, 43810]

Vedere la ricetta di seguito per ottenere il rimescolamento deterministico dei file.

Avvertenza sul determinismo: interleave args

Cambiare read_config.interleave_cycle_length , read_config.interleave_block_length cambierà l'ordine di esempi.

TFDS si basa su tf.data.Dataset.interleave per caricare solo pochi frammenti in una sola volta, migliorando le prestazioni e riducendo l'utilizzo di memoria.

L'ordine di esempio è garantito solo per essere lo stesso per un valore fisso di argomenti interleave. Vedi doc interleave di capire cosa cycle_length e block_length corrispondono troppo.

  • cycle_length=16 , block_length=16 (predefinito, come sopra):
print_ex_ids(imagenet, split='train', take=20)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]
  • cycle_length=3 , block_length=2 :
read_config = tfds.ReadConfig(
    interleave_cycle_length=3,
    interleave_block_length=2,
)
print_ex_ids(imagenet, split='train', read_config=read_config, take=20)
[0, 1, 1251, 1252, 2502, 2503, 2, 3, 1253, 1254, 2504, 2505, 4, 5, 1255, 1256, 2506, 2507, 6, 7]

Nel secondo esempio, si vede che il set di dati letto 2 ( block_length=2 ) Esempi di un frammento, poi passare al successivo frammento. Ogni 2 * 3 ( cycle_length=3 ) esempi, risale al primo frammento ( shard0-ex0, shard0-ex1, shard1-ex0, shard1-ex1, shard2-ex0, shard2-ex1, shard0-ex2, shard0-ex3, shard1-ex2, shard1-ex3, shard2-ex2,... ).

Sottosuddivisione e ordine di esempio

Ogni esempio ha un id 0, 1, ..., num_examples-1 . L' API subsplit selezionare una fetta di esempi (es train[:x] selezionare 0, 1, ..., x-1 ).

Tuttavia, all'interno della suddivisione, gli esempi non vengono letti in ordine crescente di ID (a causa di frammenti e interleave).

Più in particolare, ds.take(x) e split='train[:x]' non sono equivalenti!

Questo può essere visto facilmente nell'esempio di interleave sopra in cui gli esempi provengono da frammenti diversi.

print_ex_ids(imagenet, split='train', take=25)  # tfds.load(..., split='train').take(25)
print_ex_ids(imagenet, split='train[:25]', take=-1)  # tfds.load(..., split='train[:25]')
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]

Dopo le 16 (block_length) esempi, .take(25) passa al successivo frammento mentre train[:25] Continue reading esempi dal primo frammento.

Ricette

Ottieni lo shuffling deterministico dei file

Ci sono 2 modi per avere un rimescolamento deterministico:

  1. Impostazione del shuffle_seed . Nota: ciò richiede la modifica del seme ad ogni epoca, altrimenti i frammenti verranno letti nello stesso ordine tra le epoche.
read_config = tfds.ReadConfig(
    shuffle_seed=32,
)

# Deterministic order, different from the default shuffle_files=False above
print_ex_ids(imagenet, split='train', shuffle_files=True, read_config=read_config, take=22)
print_ex_ids(imagenet, split='train', shuffle_files=True, read_config=read_config, take=22)
[176411, 176412, 176413, 176414, 176415, 176416, 176417, 176418, 176419, 176420, 176421, 176422, 176423, 176424, 176425, 176426, 710647, 710648, 710649, 710650, 710651, 710652]
[176411, 176412, 176413, 176414, 176415, 176416, 176417, 176418, 176419, 176420, 176421, 176422, 176423, 176424, 176425, 176426, 710647, 710648, 710649, 710650, 710651, 710652]
  1. Utilizzando experimental_interleave_sort_fn : Questo dà il pieno controllo su quali frammenti vengono letti e in quale ordine, piuttosto che basarsi su ds.shuffle ordine.
def _reverse_order(file_instructions):
  return list(reversed(file_instructions))

read_config = tfds.ReadConfig(
    experimental_interleave_sort_fn=_reverse_order,
)

# Last shard (01023-of-01024) is read first
print_ex_ids(imagenet, split='train', read_config=read_config, take=5)
[1279916, 1279917, 1279918, 1279919, 1279920]

Ottieni una pipeline prerilasciabile deterministica

Questo è più complicato. Non esiste una soluzione facile e soddisfacente.

  1. Senza ds.shuffle e con rimescolamento deterministica, in teoria dovrebbe essere possibile contare gli esempi che sono stati letti e dedurre che gli esempi sono stati letti dentro in ogni frammento (come funzione di cycle_length , block_length e ordine frammento). Poi il skip , take per ogni frammento potrebbe essere iniettato attraverso experimental_interleave_sort_fn .

  2. Con ds.shuffle è probabile impossibile senza riprodurre la pipeline di formazione completo. Sarebbe necessario salvare il ds.shuffle stato cuscinetto dedurre che gli esempi sono stati letti. Esempi possono essere non continua (es shard5_ex2 , shard5_ex4 letto ma non shard5_ex3 ).

  3. Con ds.shuffle , in un modo sarebbe quello di salvare tutti shards_ids / example_ids Read (dedotte da tfds_id ), quindi dedurre le istruzioni di file da questo.

Il caso più semplice per 1. è di avere .skip(x).take(y) partita train[x:x+y] partita. Richiede:

  • Set cycle_length=1 (in modo da schegge vengono letti in sequenza)
  • Set shuffle_files=False
  • Non utilizzare ds.shuffle

Dovrebbe essere utilizzato solo su un enorme set di dati in cui l'addestramento è solo 1 epoca. Gli esempi verrebbero letti nell'ordine casuale predefinito.

read_config = tfds.ReadConfig(
    interleave_cycle_length=1,  # Read shards sequentially
)

print_ex_ids(imagenet, split='train', read_config=read_config, skip=40, take=22)
# If the job get pre-empted, using the subsplit API will skip at most `len(shard0)`
print_ex_ids(imagenet, split='train[40:]', read_config=read_config, take=22)
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]

Trova quali frammenti/esempi vengono letti per una determinata suddivisione parziale

Con il tfds.core.DatasetInfo , si ha accesso diretto alle istruzioni di lettura.

imagenet.info.splits['train[44%:45%]'].file_instructions
[FileInstruction(filename='imagenet2012-train.tfrecord-00450-of-01024', skip=700, take=-1, num_examples=551),
 FileInstruction(filename='imagenet2012-train.tfrecord-00451-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00452-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00453-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00454-of-01024', skip=0, take=-1, num_examples=1252),
 FileInstruction(filename='imagenet2012-train.tfrecord-00455-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00456-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00457-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00458-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00459-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00460-of-01024', skip=0, take=1001, num_examples=1001)]