All TFDS datasets expose various data splits (e.g. `'train'`, `'test'`) which
can be explored in the catalog. Any alphabetical string can be used as a split
name, apart from `all` (a reserved term corresponding to the union of all
splits; see below).

In addition to the "official" dataset splits, TFDS allows you to select
slice(s) of split(s) and various combinations.
## Slicing API
Slicing instructions are specified in `tfds.load` or
`tfds.DatasetBuilder.as_dataset` through the `split=` kwarg.

```python
ds = tfds.load('my_dataset', split='train[:75%]')

builder = tfds.builder('my_dataset')
ds = builder.as_dataset(split='test+train[:75%]')
```
Split can be:

- Plain split names (a string such as `'train'`, `'test'`, ...): All examples
  within the split are selected.
- Slices: Slices have the same semantics as Python slice notation. Slices can
  be:
  - Absolute (`'train[123:450]'`, `'train[:4000]'`): (see the note below for a
    caveat about read order)
  - Percent (`'train[:75%]'`, `'train[25%:75%]'`): Divide the full data into
    even slices. If the data is not evenly divisible, some percents might
    contain additional examples. Fractional percents are supported.
  - Shard (`'train[:4shard]'`, `'train[4shard]'`): Select all examples in the
    requested shard (see `info.splits['train'].num_shards` to get the number
    of shards of the split, and the sketch after this list).
- Union of splits (`'train+test'`, `'train[:25%]+test'`): Splits will be
  interleaved together.
- Full dataset (`'all'`): `'all'` is a special split name corresponding to the
  union of all splits (equivalent to `'train+test+...'`).
- List of splits (`['train', 'test']`): Multiple `tf.data.Dataset` are
  returned separately:

```python
# Returns both train and test split separately
train_ds, test_ds = tfds.load('mnist', split=['train', 'test[:50%]'])
```
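For instance, the shard count from the metadata can be plugged back into a
shard slice; a minimal sketch, with `'my_dataset'` as a placeholder name:

```python
# Sketch: read only the first half of the 'train' split's shards.
builder = tfds.builder('my_dataset')
num_shards = builder.info.splits['train'].num_shards
ds = builder.as_dataset(split=f'train[:{num_shards // 2}shard]')
```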
## `tfds.even_splits` & multi-host training
`tfds.even_splits` generates a list of non-overlapping sub-splits of the same
size.

```python
# Divide the dataset into 3 even parts, each containing 1/3 of the data
split0, split1, split2 = tfds.even_splits('train', n=3)

ds = tfds.load('my_dataset', split=split2)
```
This can be particularly useful when training in a distributed setting, where each host should receive a slice of the original data.
With Jax, this can be simplified even further using
`tfds.split_for_jax_process`:

```python
split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)
```
`tfds.split_for_jax_process` is a simple alias for:

```python
# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]
```
`tfds.even_splits`, `tfds.split_for_jax_process`, ... accept any split value
as input (e.g. `'train[75%:]+test'`).
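For instance, a minimal sketch applying `tfds.even_splits` to a compound split
expression (`'my_dataset'` is a placeholder):

```python
# Sketch: 3 even, non-overlapping sub-splits of 'train[75%:]+test'.
splits = tfds.even_splits('train[75%:]+test', n=3)
ds = tfds.load('my_dataset', split=splits[0])
```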
## Slicing and metadata
It is possible to get additional info on the splits/subsplits
(`num_examples`, `file_instructions`, ...) using the dataset info:

```python
builder = tfds.builder('my_dataset')
builder.info.splits['train'].num_examples  # 10_000
builder.info.splits['train[:75%]'].num_examples  # 7_500 (also works with slices)
builder.info.splits.keys()  # ['train', 'test']
```
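As another minimal sketch, `file_instructions` (mentioned above) describes the
record files, and the example ranges within them, that back a given subsplit:

```python
# Sketch: which files/example ranges does 'train[:75%]' read from?
builder = tfds.builder('my_dataset')
print(builder.info.splits['train[:75%]'].file_instructions)
```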
## Cross validation
Examples of 10-fold cross-validation using the string API:

```python
vals_ds = tfds.load('mnist', split=[
    f'train[{k}%:{k+10}%]' for k in range(0, 100, 10)
])
trains_ds = tfds.load('mnist', split=[
    f'train[:{k}%]+train[{k+10}%:]' for k in range(0, 100, 10)
])
```
The validation datasets are each going to be 10%: `[0%:10%]`, `[10%:20%]`,
..., `[90%:100%]`. And the training datasets are each going to be the
complementary 90%: `[10%:100%]` (for a corresponding validation set of
`[0%:10%]`), `[0%:10%]+[20%:100%]` (for a validation set of `[10%:20%]`), ...
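The fold pairs can then be consumed together; a minimal sketch, where
`train_model` and `evaluate` are hypothetical placeholders for your own code:

```python
# Sketch: iterate over the 10 (train, validation) fold pairs defined above.
for fold, (train_ds, val_ds) in enumerate(zip(trains_ds, vals_ds)):
    model = train_model(train_ds)    # hypothetical training function
    score = evaluate(model, val_ds)  # hypothetical evaluation function
    print(f'Fold {fold}: {score}')
```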
## `tfds.core.ReadInstruction` and rounding
Rather than `str`, it is possible to pass splits as
`tfds.core.ReadInstruction`.

For example, `split = 'train[50%:75%] + test'` is equivalent to:

```python
split = (
    tfds.core.ReadInstruction(
        'train',
        from_=50,
        to=75,
        unit='%',
    )
    + tfds.core.ReadInstruction('test')
)
ds = tfds.load('my_dataset', split=split)
```
`unit` can be:

- `abs`: Absolute slicing
- `%`: Percent slicing
- `shard`: Shard slicing
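For example, a sketch mirroring the string forms above: the absolute slice
`'train[123:450]'` can be written as

```python
split = tfds.core.ReadInstruction('train', from_=123, to=450, unit='abs')
ds = tfds.load('my_dataset', split=split)
```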
`tfds.ReadInstruction` also has a `rounding` argument. If the number of
examples in the dataset does not divide evenly:

- `rounding='closest'` (default): The remaining examples are distributed among
  the percents, so some percents might contain additional examples.
- `rounding='pct1_dropremainder'`: The remaining examples are dropped, but
  this guarantees that all percents contain the exact same number of examples
  (e.g. `len(5%) == 5 * len(1%)`).
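A minimal sketch using the `rounding` argument (`'my_dataset'` is a
placeholder):

```python
# Sketch: a 5% slice with 'pct1_dropremainder' rounding, so that
# len(5%) == 5 * len(1%) holds exactly.
split = tfds.core.ReadInstruction(
    'train', from_=0, to=5, unit='%', rounding='pct1_dropremainder')
ds = tfds.load('my_dataset', split=split)
```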
## Reproducibility & determinism
During generation, for a given dataset version, TFDS guarantees that examples
are deterministically shuffled on disk. So generating the dataset twice (on
two different computers) won't change the example order.
Similarly, the subsplit API will always select the same set of examples,
regardless of platform, architecture, etc. This means
`set('train[:20%]') == set('train[:10%]') + set('train[10%:20%]')`.
However, the order in which examples are read might not be deterministic. This
depends on other parameters (e.g. whether `shuffle_files=True`).
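For instance (a sketch; whether reading is fully deterministic also depends on
your `tf.data` options):

```python
# Sketch: with shuffle_files=False (the default), files are read in a
# fixed order; with shuffle_files=True, shard order changes between runs.
ds = tfds.load('mnist', split='train', shuffle_files=False)
```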