パフォーマンスに関するヒント

このドキュメントでは、TensorFlow Datasets (TFDS) 固有のパフォーマンスに関するヒントを提供します。TFDS はデータセットを tf.data.Dataset として提供するため、tf.data ガイドのアドバイスが適用されます。

データセットのベンチマークを作成する

tfds.benchmark(ds) を使用すると、あらゆる tf.data.Dataset オブジェクトのベンチマークを作成できます。

結果を正規化するために、batch_size= を必ず指定してください(100 iter/sec -> 3200 ex/sec など)。これは、任意のイテラブル(tfds.benchmark(tfds.as_numpy(ds)) など)で機能します。

ds = tfds.load('mnist', split='train').batch(32).prefetch()
# Display some benchmark statistics
tfds.benchmark(ds, batch_size=32)
# Second iteration is much faster, due to auto-caching
tfds.benchmark(ds, batch_size=32)

小規模なデータセット (1 GB 未満)

すべての TFDS データセットは、TFRecord 形式でディスク上にデータを保存します。小規模なデータセット (MNIST、CIFAR-10/-100 など) の場合、.tfrecord から読み取ると、著しいオーバーヘッドが追加されます。

これらのデータセットはメモリに収まるため、データセットをキャッシュするか事前に読み込むことで、パフォーマンスを大幅に改善することが可能です。TFDS は小規模なデータセットを自動的にキャッシュすることに注意してください(詳細は次のセクションをご覧ください)。

データセットのキャッシュ

次は、画像を正規化した後にデータセットを明示的にキャッシュするデータパイプラインの例です。

def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label


ds, ds_info = tfds.load(
    'mnist',
    split='train',
    as_supervised=True,  # returns `(img, label)` instead of dict(image=, ...)
    with_info=True,
)
# Applying normalization before `ds.cache()` to re-use it.
# Note: Random transformations (e.g. images augmentations) should be applied
# after both `ds.cache()` (to avoid caching randomness) and `ds.batch()` (for
# vectorization [1]).
ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()
# For true randomness, we set the shuffle buffer to the full dataset size.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
# Batch after shuffling to get unique batches at each epoch.
ds = ds.batch(128)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

このデータセットをイテレートする際、キャッシュによって、2 回目のイテレーションは最初のイテレーションよりはるかに高速に行われます。

自動キャッシュ

デフォルトでは、TFDS は (ds.cache() で) 次の制限を満たすデータセットを自動的にキャッシュします。

  • 合計データセットサイズ(全分割)が定義されており、250 MiB 未満である
  • shuffle_files が無効化されているか、単一のシャードのみが読み取られる

tfds.loadtfds.ReadConfigtry_autocaching=False を渡して、自動キャッシュをオプトアウトすることができまます。特定のデータセットで自動キャッシュが使用されるかどうかを確認するには、データセットカタログドキュメントをご覧ください。

単一テンソルとしての全データの読み込み

データセットがメモリに収まる場合は、全データセットを単一のテンソルまたは Numpy 配列として読み込むこともできます。これは、batch_size=-1 を設定して、単一の tf.Tensor 内のすべてのサンプルをバッチ処理し、tfds.as_numpy を使用して tf.Tensor から np.array に変換することで行えます。

(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
    'mnist',
    split=['train', 'test'],
    batch_size=-1,
    as_supervised=True,
))

大規模なデータセット

大規模なデータセットはシャーディングされており(複数のファイルに分割)、 通常、メモリに収まらないため、キャッシュすることはできません。

シャッフルとトレーニング

トレーニング中はデータを十分にシャッフルすることが重要です。適切にシャッフルされていないデータでは、トレーニング精度が低下してしまいます。

ds.shuffle を使用してレコードをシャッフルするほかに、shuffle_files=True を設定して、複数のファイルシャーディングされている大規模なデータセット向けに、十分なシャッフル動作を得る必要があります。シャッフルが十分でない場合、エポックは、同じ順でシャードを読み取ってしまい、データが実質的にランダム化されません。

ds = tfds.load('imagenet2012', split='train', shuffle_files=True)

さらに、shuffle_files=True である場合、TFDS は options.deterministic を無効化するため、パフォーマンスがわずかに向上することがあります。確定的なシャッフルを行うには、tfds.ReadConfig を使って read_config.shuffle_seed を設定するか、read_config.options.deterministic を上書きすることで、この機能をオプトアウトすることができます。

ワーカー間でデータを自動シャーディングする(TF)

複数のワーカーでトレーニングを実施する場合、tfds.ReadConfiginput_context 引数を使用して、各ワーカーがデータのサブセットを読み取るようにすることができます。

input_context = tf.distribute.InputContext(
    input_pipeline_id=1,  # Worker id
    num_input_pipelines=4,  # Total number of workers
)
read_config = tfds.ReadConfig(
    input_context=input_context,
)
ds = tfds.load('dataset', split='train', read_config=read_config)

これは subsplit API を補完するものです。まず、subplit API が適用され(train[:50%] が読み取るファイルのリストに変換されます)、それらのファイルに対して ds.shard() op が適用されます。たとえば、num_input_pipelines=2train[:50%] を使用する場合、2 つの各ワーカーはデータの 4 分の 1 を読み取るようになります。

shuffle_files=True である場合、ファイルは 1 つのワーカー内でシャッフルされますが、ワーカー全体ではシャッフルされません。各ワーカーはエポックごとにファイルの同じサブセットを読み取ります。

注意: tf.distribute.Strategy を使用すると、distribute_datasets_from_functioninput_context を自動的に作成することができます。

ワーカー間でデータを自動シャーディングする(Jax)

Jax の場合、tfds.split_for_jax_process または tfds.even_splits API を使用すると、ワーカー間でデータを分散させることができます。API の分割ガイドをご覧ください。

split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)

tfds.split_for_jax_process は以下の単純なエイリアスです。

# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]

画像のデコードの高速化

デフォルトでは、TFDS は自動的に画像をデコードするようになっていますが、tfds.decode.SkipDecoding を使って画像のデコードをスキップし、手動で tf.io.decode_image 演算を適用する方がパフォーマンスが高くなる場合もあります。

上記の両方の例のコードは、decode ガイドをご覧ください。

未使用の特徴量をスキップする

特徴量のサブセットのみを使用している場合は、一部の特徴量を完全にスキップすることが可能です。データセットに未使用の特徴量が多数ある場合は、それらをデコードしないことでパフォーマンスが大幅に改善されます。https://www.tensorflow.org/datasets/decode#only_decode_a_sub-set_of_the_features をご覧ください。