Dicas de desempenho

Este documento fornece dicas de desempenho específicas do TensorFlow Datasets (TFDS). Observe que o TFDS fornece conjuntos de dados como objetos tf.data.Dataset , portanto, o conselho do guia tf.data ainda se aplica.

Conjuntos de dados de referência

Use tfds.benchmark(ds) para avaliar qualquer objeto tf.data.Dataset .

Certifique-se de indicar batch_size= para normalizar os resultados (por exemplo, 100 iter/seg -> 3200 ex/seg). Isso funciona com qualquer iterável (por exemplo tfds.benchmark(tfds.as_numpy(ds)) ).

ds = tfds.load('mnist', split='train').batch(32).prefetch()
# Display some benchmark statistics
tfds.benchmark(ds, batch_size=32)
# Second iteration is much faster, due to auto-caching
tfds.benchmark(ds, batch_size=32)

Conjuntos de dados pequenos (menos de 1 GB)

Todos os conjuntos de dados TFDS armazenam os dados em disco no formato TFRecord . Para conjuntos de dados pequenos (por exemplo, MNIST, CIFAR-10/-100), a leitura de .tfrecord pode adicionar sobrecarga significativa.

À medida que esses conjuntos de dados cabem na memória, é possível melhorar significativamente o desempenho armazenando em cache ou pré-carregando o conjunto de dados. Observe que o TFDS armazena automaticamente em cache pequenos conjuntos de dados (a seção a seguir contém os detalhes).

Armazenando o conjunto de dados em cache

Aqui está um exemplo de pipeline de dados que armazena explicitamente o conjunto de dados em cache após normalizar as imagens.

def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label


ds, ds_info = tfds.load(
    'mnist',
    split='train',
    as_supervised=True,  # returns `(img, label)` instead of dict(image=, ...)
    with_info=True,
)
# Applying normalization before `ds.cache()` to re-use it.
# Note: Random transformations (e.g. images augmentations) should be applied
# after both `ds.cache()` (to avoid caching randomness) and `ds.batch()` (for
# vectorization [1]).
ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()
# For true randomness, we set the shuffle buffer to the full dataset size.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
# Batch after shuffling to get unique batches at each epoch.
ds = ds.batch(128)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

Ao iterar neste conjunto de dados, a segunda iteração será muito mais rápida que a primeira graças ao cache.

Cache automático

Por padrão, o TFDS armazena em cache automático (com ds.cache() ) conjuntos de dados que satisfazem as seguintes restrições:

  • O tamanho total do conjunto de dados (todas as divisões) é definido e <250 MiB
  • shuffle_files está desativado ou apenas um único fragmento é lido

É possível cancelar o cache automático passando try_autocaching=False para tfds.ReadConfig em tfds.load . Dê uma olhada na documentação do catálogo do conjunto de dados para ver se um conjunto de dados específico usará cache automático.

Carregando os dados completos como um único Tensor

Se o seu conjunto de dados couber na memória, você também pode carregar o conjunto de dados completo como um único array Tensor ou NumPy. É possível fazer isso definindo batch_size=-1 para agrupar todos os exemplos em um único tf.Tensor . Em seguida, use tfds.as_numpy para a conversão de tf.Tensor para np.array .

(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
    'mnist',
    split=['train', 'test'],
    batch_size=-1,
    as_supervised=True,
))

Grandes conjuntos de dados

Grandes conjuntos de dados são fragmentados (divididos em vários arquivos) e normalmente não cabem na memória, portanto, não devem ser armazenados em cache.

Embaralhamento e treinamento

Durante o treinamento, é importante embaralhar bem os dados – dados mal embaralhados podem resultar em menor precisão do treinamento.

Além de usar ds.shuffle para embaralhar registros, você também deve definir shuffle_files=True para obter um bom comportamento de embaralhamento para conjuntos de dados maiores que são fragmentados em vários arquivos. Caso contrário, as épocas lerão os fragmentos na mesma ordem e, portanto, os dados não serão verdadeiramente randomizados.

ds = tfds.load('imagenet2012', split='train', shuffle_files=True)

Além disso, quando shuffle_files=True , o TFDS desativa options.deterministic , o que pode proporcionar um ligeiro aumento de desempenho. Para obter embaralhamento determinístico, é possível desativar esse recurso com tfds.ReadConfig : configurando read_config.shuffle_seed ou substituindo read_config.options.deterministic .

Fragmente automaticamente seus dados entre trabalhadores (TF)

Ao treinar vários trabalhadores, você pode usar o argumento input_context de tfds.ReadConfig , para que cada trabalhador leia um subconjunto dos dados.

input_context = tf.distribute.InputContext(
    input_pipeline_id=1,  # Worker id
    num_input_pipelines=4,  # Total number of workers
)
read_config = tfds.ReadConfig(
    input_context=input_context,
)
ds = tfds.load('dataset', split='train', read_config=read_config)

Isso é complementar à API subsplit. Primeiro, a API subplit é aplicada: train[:50%] é convertido em uma lista de arquivos para leitura. Em seguida, uma operação ds.shard() é aplicada a esses arquivos. Por exemplo, ao usar train[:50%] com num_input_pipelines=2 , cada um dos 2 trabalhadores lerá 1/4 dos dados.

Quando shuffle_files=True , os arquivos são embaralhados dentro de um trabalhador, mas não entre trabalhadores. Cada trabalhador lerá o mesmo subconjunto de arquivos entre épocas.

Fragmente automaticamente seus dados entre trabalhadores (Jax)

Com Jax, você pode usar a API tfds.split_for_jax_process ou tfds.even_splits para distribuir seus dados entre trabalhadores. Consulte o guia da API dividida .

split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)

tfds.split_for_jax_process é um alias simples para:

# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]

Decodificação de imagem mais rápida

Por padrão, o TFDS decodifica imagens automaticamente. No entanto, há casos em que pode ser mais eficiente pular a decodificação da imagem com tfds.decode.SkipDecoding e aplicar manualmente a operação tf.io.decode_image :

O código para ambos os exemplos está disponível no guia de decodificação .

Ignorar recursos não utilizados

Se você estiver usando apenas um subconjunto de recursos, é possível ignorar completamente alguns recursos. Se o seu conjunto de dados tiver muitos recursos não utilizados, não decodificá-los pode melhorar significativamente o desempenho. Consulte https://www.tensorflow.org/datasets/decode#only_decode_a_sub-set_of_the_features

tf.data usa toda a minha RAM!

Se você tiver memória RAM limitada ou estiver carregando muitos conjuntos de dados em paralelo ao usar tf.data , aqui estão algumas opções que podem ajudar:

Substituir tamanho do buffer

builder.as_dataset(
  read_config=tfds.ReadConfig(
    ...
    override_buffer_size=1024,  # Save quite a bit of RAM.
  ),
  ...
)

Isso substitui o buffer_size passado para TFRecordDataset (ou equivalente): https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset#args

Use tf.data.Dataset.with_options para interromper comportamentos mágicos

https://www.tensorflow.org/api_docs/python/tf/data/Dataset#with_options

options = tf.data.Options()

# Stop magic stuff that eats up RAM:
options.autotune.enabled = False
options.experimental_distribute.auto_shard_policy = (
  tf.data.experimental.AutoShardPolicy.OFF)
options.experimental_optimization.inject_prefetch = False

data = data.with_options(options)