성능 팁

이 문서에서는 TensorFlow Datasets(TFDS) 관련 성능 팁을 제공합니다. TFDS는 데이터 세트를 tf.data.Dataset 객체로 제공하므로 tf.data 가이드 의 조언이 계속 적용됩니다.

벤치마크 데이터 세트

tfds.benchmark(ds) 사용하여 tf.data.Dataset 객체를 벤치마킹하세요.

결과를 정규화하려면 batch_size= 지정해야 합니다(예: 100 iter/sec -> 3200 ex/sec). 이는 모든 반복 가능한 항목(예: tfds.benchmark(tfds.as_numpy(ds)) 에서 작동합니다.

ds = tfds.load('mnist', split='train').batch(32).prefetch()
# Display some benchmark statistics
tfds.benchmark(ds, batch_size=32)
# Second iteration is much faster, due to auto-caching
tfds.benchmark(ds, batch_size=32)

소규모 데이터 세트(1GB 미만)

모든 TFDS 데이터 세트는 TFRecord 형식으로 디스크에 데이터를 저장합니다. 소규모 데이터 세트(예: MNIST, CIFAR-10/-100)의 경우 .tfrecord 에서 읽으면 상당한 오버헤드가 추가될 수 있습니다.

이러한 데이터 세트는 메모리에 적합하므로 데이터 세트를 캐싱하거나 미리 로드하면 성능을 크게 향상시킬 수 있습니다. TFDS는 작은 데이터세트를 자동으로 캐시합니다(자세한 내용은 다음 섹션에 있음).

데이터 세트 캐싱

다음은 이미지를 정규화한 후 데이터세트를 명시적으로 캐시하는 데이터 파이프라인의 예입니다.

def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label


ds, ds_info = tfds.load(
    'mnist',
    split='train',
    as_supervised=True,  # returns `(img, label)` instead of dict(image=, ...)
    with_info=True,
)
# Applying normalization before `ds.cache()` to re-use it.
# Note: Random transformations (e.g. images augmentations) should be applied
# after both `ds.cache()` (to avoid caching randomness) and `ds.batch()` (for
# vectorization [1]).
ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()
# For true randomness, we set the shuffle buffer to the full dataset size.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
# Batch after shuffling to get unique batches at each epoch.
ds = ds.batch(128)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

이 데이터 세트를 반복할 때 캐싱 덕분에 두 번째 반복은 첫 번째 반복보다 훨씬 빠릅니다.

자동 캐싱

기본적으로 TFDS는 다음 제약 조건을 충족하는 데이터 세트( ds.cache() 사용)를 자동 캐시합니다.

  • 총 데이터 세트 크기(모든 분할)가 정의되고 < 250MiB
  • shuffle_files 비활성화되었거나 단일 샤드만 읽습니다.

tfds.loadtfds.ReadConfigtry_autocaching=False 전달하여 자동 캐싱을 옵트아웃할 수 있습니다. 특정 데이터세트가 자동 캐시를 사용하는지 확인하려면 데이터세트 카탈로그 문서를 살펴보세요.

전체 데이터를 단일 Tensor로 로드

데이터세트가 메모리에 적합하다면 전체 데이터세트를 단일 Tensor 또는 NumPy 배열로 로드할 수도 있습니다. 단일 tf.Tensor 에서 모든 예제를 일괄 처리하도록 batch_size=-1 설정하면 그렇게 할 수 있습니다. 그런 다음 tf.Tensor 에서 np.array 로 변환하려면 tfds.as_numpy 사용하세요.

(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
    'mnist',
    split=['train', 'test'],
    batch_size=-1,
    as_supervised=True,
))

대규모 데이터 세트

대규모 데이터 세트는 샤딩(여러 파일로 분할)되어 있으며 일반적으로 메모리에 맞지 않으므로 캐시하면 안 됩니다.

셔플 및 트레이닝

훈련 중에는 데이터를 잘 섞는 것이 중요합니다. 데이터를 제대로 섞지 않으면 훈련 정확도가 낮아질 수 있습니다.

ds.shuffle 사용하여 레코드를 섞는 것 외에도 shuffle_files=True 설정하여 여러 파일로 샤딩되는 대규모 데이터 세트에 대해 좋은 섞기 동작을 얻으려면도 설정해야 합니다. 그렇지 않으면 시대가 동일한 순서로 샤드를 읽으므로 데이터가 실제로 무작위화되지 않습니다.

ds = tfds.load('imagenet2012', split='train', shuffle_files=True)

또한 shuffle_files=True 인 경우 TFDS는 options.deterministic 비활성화하여 약간의 성능 향상을 제공할 수 있습니다. 결정론적 셔플링을 얻으려면 tfds.ReadConfig 사용하여 이 기능을 옵트아웃할 수 있습니다. read_config.shuffle_seed 설정하거나 read_config.options.deterministic 을 덮어쓰면 됩니다.

작업자 간에 데이터 자동 샤딩(TF)

여러 작업자를 훈련할 때 tfds.ReadConfiginput_context 인수를 사용할 수 있으므로 각 작업자는 데이터의 하위 집합을 읽습니다.

input_context = tf.distribute.InputContext(
    input_pipeline_id=1,  # Worker id
    num_input_pipelines=4,  # Total number of workers
)
read_config = tfds.ReadConfig(
    input_context=input_context,
)
ds = tfds.load('dataset', split='train', read_config=read_config)

이는 하위 분할 API를 보완합니다. 먼저 subplit API가 적용됩니다. train[:50%] 읽을 파일 목록으로 변환됩니다. 그런 다음 ds.shard() 작업이 해당 파일에 적용됩니다. 예를 들어, num_input_pipelines=2 와 함께 train[:50%] 사용하면 2명의 워커 각각이 데이터의 1/4을 읽습니다.

shuffle_files=True 인 경우 파일은 한 작업자 내에서 섞이지 않고 여러 작업자 간에 섞이지 않습니다. 각 작업자는 시대 간에 동일한 파일 하위 집합을 읽습니다.

작업자 간에 데이터 자동 샤딩(Jax)

Jax를 사용하면 tfds.split_for_jax_process 또는 tfds.even_splits API를 사용하여 작업자 간에 데이터를 배포할 수 있습니다. 분할 API 가이드를 참조하세요.

split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)

tfds.split_for_jax_process 는 다음에 대한 간단한 별칭입니다.

# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]

더 빠른 이미지 디코딩

기본적으로 TFDS는 이미지를 자동으로 디코딩합니다. 그러나 tfds.decode.SkipDecoding 사용하여 이미지 디코딩을 건너뛰고 tf.io.decode_image 작업을 수동으로 적용하는 것이 더 성능이 좋은 경우가 있습니다.

두 예제의 코드는 디코드 가이드 에서 확인할 수 있습니다.

사용하지 않는 기능 건너뛰기

기능의 하위 집합만 사용하는 경우 일부 기능을 완전히 건너뛸 수 있습니다. 데이터 세트에 사용되지 않는 기능이 많이 있는 경우 해당 기능을 디코딩하지 않으면 성능이 크게 향상될 수 있습니다. https://www.tensorflow.org/datasets/decode#only_decode_a_sub-set_of_the_features 를 참조하세요.

tf.data는 내 RAM을 모두 사용합니다!

RAM이 제한되어 있거나 tf.data 사용하는 동안 많은 데이터 세트를 병렬로 로드하는 경우 도움이 될 수 있는 몇 가지 옵션은 다음과 같습니다.

버퍼 크기 재정의

builder.as_dataset(
  read_config=tfds.ReadConfig(
    ...
    override_buffer_size=1024,  # Save quite a bit of RAM.
  ),
  ...
)

이는 TFRecordDataset (또는 이에 상응하는 것)에 전달된 buffer_size 재정의합니다: https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset#args

tf.data.Dataset.with_options를 사용하여 매직 동작을 중지하세요.

https://www.tensorflow.org/api_docs/python/tf/data/Dataset#with_options

options = tf.data.Options()

# Stop magic stuff that eats up RAM:
options.autotune.enabled = False
options.experimental_distribute.auto_shard_policy = (
  tf.data.experimental.AutoShardPolicy.OFF)
options.experimental_optimization.inject_prefetch = False

data = data.with_options(options)