В этом документе представлены советы по производительности, специфичные для наборов данных TensorFlow (TFDS). Обратите внимание, что TFDS предоставляет наборы данных как объекты tf.data.Dataset
, поэтому рекомендации из руководства tf.data
по-прежнему применимы.
Контрольные наборы данных
Используйте tfds.benchmark(ds)
для сравнения любого объекта tf.data.Dataset
.
Обязательно укажите batch_size=
чтобы нормализовать результаты (например, 100 итер/сек -> 3200 итер/сек). Это работает с любой итерацией (например tfds.benchmark(tfds.as_numpy(ds))
).
ds = tfds.load('mnist', split='train').batch(32).prefetch()
# Display some benchmark statistics
tfds.benchmark(ds, batch_size=32)
# Second iteration is much faster, due to auto-caching
tfds.benchmark(ds, batch_size=32)
Небольшие наборы данных (менее 1 ГБ)
Все наборы данных TFDS хранят данные на диске в формате TFRecord
. Для небольших наборов данных (например, MNIST, CIFAR-10/-100) чтение из .tfrecord
может привести к значительным накладным расходам.
Поскольку эти наборы данных помещаются в память, можно значительно повысить производительность за счет кэширования или предварительной загрузки набора данных. Обратите внимание, что TFDS автоматически кэширует небольшие наборы данных (подробности приведены в следующем разделе).
Кэширование набора данных
Вот пример конвейера данных, который явно кэширует набор данных после нормализации изображений.
def normalize_img(image, label):
"""Normalizes images: `uint8` -> `float32`."""
return tf.cast(image, tf.float32) / 255., label
ds, ds_info = tfds.load(
'mnist',
split='train',
as_supervised=True, # returns `(img, label)` instead of dict(image=, ...)
with_info=True,
)
# Applying normalization before `ds.cache()` to re-use it.
# Note: Random transformations (e.g. images augmentations) should be applied
# after both `ds.cache()` (to avoid caching randomness) and `ds.batch()` (for
# vectorization [1]).
ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()
# For true randomness, we set the shuffle buffer to the full dataset size.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
# Batch after shuffling to get unique batches at each epoch.
ds = ds.batch(128)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
При переборе этого набора данных вторая итерация будет намного быстрее первой благодаря кэшированию.
Автокэширование
По умолчанию TFDS автоматически кэширует (с помощью ds.cache()
) наборы данных, которые удовлетворяют следующим ограничениям:
- Общий размер набора данных (все разделения) определен и составляет < 250 МБ.
-
shuffle_files
отключен или читается только один осколок
Можно отказаться от автоматического кэширования, передав try_autocaching=False
в tfds.ReadConfig
в tfds.load
. Посмотрите документацию по каталогу наборов данных, чтобы узнать, будет ли конкретный набор данных использовать автоматическое кэширование.
Загрузка полных данных в виде одного тензора
Если ваш набор данных помещается в память, вы также можете загрузить полный набор данных в виде одного массива Tensor или NumPy. Это можно сделать, установив batch_size=-1
для пакетной обработки всех примеров в одном tf.Tensor
. Затем используйте tfds.as_numpy
для преобразования из tf.Tensor
в np.array
.
(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
'mnist',
split=['train', 'test'],
batch_size=-1,
as_supervised=True,
))
Большие наборы данных
Большие наборы данных сегментируются (разбиваются на несколько файлов) и обычно не помещаются в памяти, поэтому их не следует кэшировать.
Перестановка и тренировка
Во время обучения важно хорошо перетасовать данные: плохо перетасованные данные могут привести к снижению точности обучения.
Помимо использования ds.shuffle
для перетасовки записей, вам также следует установить shuffle_files=True
чтобы обеспечить хорошее поведение при перетасовке больших наборов данных, разбитых на несколько файлов. В противном случае эпохи будут считывать фрагменты в одном и том же порядке, и данные не будут по-настоящему рандомизированы.
ds = tfds.load('imagenet2012', split='train', shuffle_files=True)
Кроме того, когда shuffle_files=True
, TFDS отключает options.deterministic
, что может немного повысить производительность. Чтобы получить детерминированное перетасовывание, можно отказаться от этой функции с помощью tfds.ReadConfig
: либо установив read_config.shuffle_seed
, либо перезаписав read_config.options.deterministic
.
Автоматическое разделение ваших данных между работниками (TF)
При обучении нескольких рабочих процессов вы можете использовать аргумент input_context
tfds.ReadConfig
, поэтому каждый рабочий процесс будет читать подмножество данных.
input_context = tf.distribute.InputContext(
input_pipeline_id=1, # Worker id
num_input_pipelines=4, # Total number of workers
)
read_config = tfds.ReadConfig(
input_context=input_context,
)
ds = tfds.load('dataset', split='train', read_config=read_config)
Это дополняет API-интерфейс subsplit. Сначала применяется API subplit: train[:50%]
преобразуется в список файлов для чтения. Затем к этим файлам применяется операция ds.shard()
. Например, при использовании train[:50%]
с num_input_pipelines=2
каждый из двух воркеров будет читать 1/4 данных.
Если shuffle_files=True
файлы перемешиваются внутри одного работника, но не между работниками. Каждый рабочий процесс будет читать одно и то же подмножество файлов между эпохами.
Автоматическое разделение ваших данных между работниками (Jax)
С Jax вы можете использовать API tfds.split_for_jax_process
или tfds.even_splits
для распределения ваших данных между работниками. См. руководство по разделенному API .
split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)
tfds.split_for_jax_process
— это простой псевдоним для:
# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]
Более быстрое декодирование изображений
По умолчанию TFDS автоматически декодирует изображения. Однако в некоторых случаях может быть более эффективно пропустить декодирование изображения с помощью tfds.decode.SkipDecoding
и вручную применить операцию tf.io.decode_image
:
- При фильтрации примеров (с помощью
tf.data.Dataset.filter
) для декодирования изображений после фильтрации примеров. - При обрезке изображений используйте функцию Fused
tf.image.decode_and_crop_jpeg
.
Код для обоих примеров доступен в руководстве по декодированию .
Пропустить неиспользуемые функции
Если вы используете только часть функций, некоторые функции можно полностью пропустить. Если в вашем наборе данных много неиспользуемых функций, отказ от их декодирования может значительно улучшить производительность. См. https://www.tensorflow.org/datasets/decode#only_decode_a_sub-set_of_the_features .
tf.data использует всю мою оперативную память!
Если вы ограничены в оперативной памяти или загружаете много наборов данных параллельно при использовании tf.data
, вот несколько вариантов, которые могут помочь:
Переопределить размер буфера
builder.as_dataset(
read_config=tfds.ReadConfig(
...
override_buffer_size=1024, # Save quite a bit of RAM.
),
...
)
Это переопределяет buffer_size
, переданный в TFRecordDataset
(или его эквивалент): https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset#args
Используйте tf.data.Dataset.with_options, чтобы остановить магическое поведение.
https://www.tensorflow.org/api_docs/python/tf/data/Dataset#with_options
options = tf.data.Options()
# Stop magic stuff that eats up RAM:
options.autotune.enabled = False
options.experimental_distribute.auto_shard_policy = (
tf.data.experimental.AutoShardPolicy.OFF)
options.experimental_optimization.inject_prefetch = False
data = data.with_options(options)