Split とスライス

すべての TFDS データセットは、カタログで詳しく見ることのできる様々なデータの Split('train''test' など)を公開します。Split 名には、all(すべての Split の結合に対応する予約語。以下参照)を除く任意の英字の文字列を使用できます。

TFDS では、「公式」のデータセットの Split に加え、Split のスライスやさまざまな組み合わせを選択することができます。

スライス API

スライスの指示は tfds.load または tfds.DatasetBuilder.as_datasetsplit= kwarg を介して指定されます。

ds = tfds.load('my_dataset', split='train[:75%]')
builder = tfds.builder('my_dataset')
ds = builder.as_dataset(split='test+train[:75%]')

Split には次のものがあります。

  • プレーンな Split 名'train''test' などの文字列): 選択された Split 内のすべての Example。
  • スライス: スライスのセマンティックはPython のスライス表記法と同じです。スライスには次のものがあります。
    • 絶対 ('train[123:450]'train[:4000]): (読み取り順序に関する警告については、以下の注意事項を参照してください)
    • パーセント率'train[:75%]''train[25%:75%]'): 全データを 100 個の均一なスライスに分けます。データを均等に割り切れない場合は、一部の 100 分の 1 のスライスに追加の Example が含まれる場合があります。小数点以下のパーセントはサポートされています。
    • シャードtrain[:4shard]train[4shard]): リクエストされたシャードのすべての Example を選択します。(Split のシャード数を取得するには、info.splits['train'].num_shards を確認します。)
  • Split の和集合'train+test''train[:25%]+test'): Split は共にインターリーブされます。
  • 完全なデータセット ('all'): 'all' すべての Split の和集合に対応する特別な Split 名です ( 'train+test+...' と同等)。
  • Split のリスト['train', 'test']): 次のようにして、複数の tf.data.Dataset が個別に返されます。
# Returns both train and test split separately
train_ds, test_ds = tfds.load('mnist', split=['train', 'test[:50%]'])

注意: シャードはインターリーブされるため、サブスプリット間での順序の一貫性は保証されません。つまり、test[0:100] の後に test[100:200] を読み取る場合と test[:200] を読み取る場合では、Example の順序が異なることがあります。TFDS が Example を読み取る順序についての詳細は、決定論ガイドを参照してください。

tfds.even_splits とマルチホストトレーニング

tfds.even_splits は、オーバーラップしない同じサイズのサブスプリットのリストを生成します。

# Divide the dataset into 3 even parts, each containing 1/3 of the data
split0, split1, split2 = tfds.even_splits('train', n=3)

ds = tfds.load('my_dataset', split=split2)

これは特に、ホストごとに元のデータのスライスを受け取る必要のある分散環境でのトレーニングに役立ちます。

Jax では、tfds.split_for_jax_process を使用してさらにこれを単純化できます。

split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)

tfds.split_for_jax_process は以下の単純なエイリアスです。

# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]

tfds.even_splitstfds.split_for_jax_process は任意の Split 値を入力として受け入れます(例: 'train[75%:]+test')。

スライスとメタデータ

Split/サブスプリットに関する追加情報(num_examplesfile_instructions など)は、データセットの info を使って取得することができます。

builder = tfds.builder('my_dataset')
builder.info.splits['train'].num_examples  # 10_000
builder.info.splits['train[:75%]'].num_examples  # 7_500 (also works with slices)
builder.info.splits.keys()  # ['train', 'test']

クロス検証

文字列 API を使った 10 段階クロス検証の例:

vals_ds = tfds.load('mnist', split=[
    f'train[{k}%:{k+10}%]' for k in range(0, 100, 10)
])
trains_ds = tfds.load('mnist', split=[
    f'train[:{k}%]+train[{k+10}%:]' for k in range(0, 100, 10)
])

検証データセットは、[0%:10%], [10%:20%], ..., [90%:100%] というように、それぞれ 10% になります。また、トレーニングデータセットは、[10%:100%](対応する検証セットの [0%:10%])、`[0%:10%] というようにそれぞれ補完する 90% になります。

  • [20%:100%](検証セットの[10%:20%]`),...

tfds.core.ReadInstruction と四捨五入

Split は、str ではなく、tfds.core.ReadInstruction として渡すことが可能です。

たとえば、split = 'train[50%:75%] + test' は次と同等です。

split = (
    tfds.core.ReadInstruction(
        'train',
        from_=50,
        to=75,
        unit='%',
    )
    + tfds.core.ReadInstruction('test')
)
ds = tfds.load('my_dataset', split=split)

unit は、次であることができます。

  • abs: 全体的スライス
  • %: パーセント率スライス
  • shard: シャードスライス

tfds.ReadInstruction には四捨五入の引数もあります。データセットの Example 数が均等に割り切れない場合は、次のようになります。

  • rounding='closest'(デフォルト): 残りの Example は、パーセント率で配分され、一部のパーセントに追加の Example が含まれることがあります。
  • rounding='pct1_dropremainder': 残りの Example はドロップされますが、こうすることですべての 100 分の 1 スライスにまったく同じ数の Example が確実に含まれることになります(len(5%) == 5 * len(1%) など)。

再現可能性と決定性

生成中、特定のデータセットのバージョンにおいて、TFDS はExample が決定的にディスクでシャッフルされることを保証します。そのため、データセットを 2 回生成しても(2 台のコンピュータで)、Example の順序は変わりません。

同様に、サブスプリット API は必ず Example の同じ set を選択し、これにはプラットフォームやアーキテクチャなどは考慮されません。つまり、set('train[:20%]') == set('train[:10%]') + set('train[10%:20%]') となります。

ただし、Example が読み取られる順序は決定的ではない場合があります。これは他のパラメータshuffle_files=True であるかどうか)に依存しています。