This document explains:
- The TFDS guarantees on determinism
- The order in which TFDS reads examples
- Various caveats and gotchas
Setup
Datasets
Some context is needed to understand how TFDS reads the data.
During generation, TFDS writes the original data into standardized .tfrecord files. For big datasets, multiple .tfrecord files are created, each containing multiple examples. We call each .tfrecord file a shard.
This guide uses imagenet, which has 1024 shards:
import re
import tensorflow_datasets as tfds
imagenet = tfds.builder('imagenet2012')
num_shards = imagenet.info.splits['train'].num_shards
num_examples = imagenet.info.splits['train'].num_examples
print(f'imagenet has {num_shards} shards ({num_examples} examples)')
imagenet has 1024 shards (1281167 examples)
Finding the dataset example ids
You can skip to the following section if you only want to know about determinism.
Each dataset example is uniquely identified by an id (e.g. 'imagenet2012-train.tfrecord-01023-of-01024__32'). You can recover this id by passing read_config.add_tfds_id = True, which will add a 'tfds_id' key in the dict from the tf.data.Dataset.
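For instance, here is a minimal sketch that reads a single raw id directly (the printed id is just what the default deterministic order would yield):
read_config = tfds.ReadConfig(add_tfds_id=True)
ds = imagenet.as_dataset(split='train', read_config=read_config)
for ex in ds.take(1):
  print(ex['tfds_id'].numpy().decode('utf-8'))
# e.g. 'imagenet2012-train.tfrecord-00000-of-01024__0'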
In this tutorial, we define a small util which will print the example ids of the dataset (converted to integers to be more human-readable):
def load_dataset(builder, **as_dataset_kwargs):
  """Load the dataset with the tfds_id."""
  read_config = as_dataset_kwargs.pop('read_config', tfds.ReadConfig())
  read_config.add_tfds_id = True  # Set `True` to return the 'tfds_id' key
  return builder.as_dataset(read_config=read_config, **as_dataset_kwargs)
def print_ex_ids(
    builder,
    *,
    take: int,
    skip: int = None,
    **as_dataset_kwargs,
) -> None:
  """Print the example ids from the given dataset split."""
  ds = load_dataset(builder, **as_dataset_kwargs)
  if skip:
    ds = ds.skip(skip)
  ds = ds.take(take)
  exs = [ex['tfds_id'].numpy().decode('utf-8') for ex in ds]
  exs = [id_to_int(tfds_id, builder=builder) for tfds_id in exs]
  print(exs)
def id_to_int(tfds_id: str, builder) -> int:
  """Format the tfds_id in a more human-readable way."""
  match = re.match(r'\w+-(\w+).\w+-(\d+)-of-\d+__(\d+)', tfds_id)
  split_name, shard_id, ex_id = match.groups()
  split_info = builder.info.splits[split_name]
  return sum(split_info.shard_lengths[:int(shard_id)]) + int(ex_id)
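As a quick sanity check of the regex above (this call is an illustration, not part of the guide's original code), an id from shard 0 maps to its within-shard index, since no earlier shard contributes examples:
print(id_to_int('imagenet2012-train.tfrecord-00000-of-01024__32', builder=imagenet))
# 32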
Determinism when reading
This section explains the determinism guarantees of tfds.load.
With shuffle_files=False (default)
By default, TFDS yields examples deterministically (shuffle_files=False):
# Same as: imagenet.as_dataset(split='train').take(20)
print_ex_ids(imagenet, split='train', take=20)
print_ex_ids(imagenet, split='train', take=20)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]
For performance, TFDS reads multiple shards at the same time using tf.data.Dataset.interleave. We see in this example that TFDS switches to shard 2 after reading 16 examples (..., 14, 15, 1251, 1252, ...). More on interleave below.
Similarly, the subsplit API is also deterministic:
print_ex_ids(imagenet, split='train[67%:84%]', take=20)
print_ex_ids(imagenet, split='train[67%:84%]', take=20)
[858382, 858383, 858384, 858385, 858386, 858387, 858388, 858389, 858390, 858391, 858392, 858393, 858394, 858395, 858396, 858397, 859533, 859534, 859535, 859536]
[858382, 858383, 858384, 858385, 858386, 858387, 858388, 858389, 858390, 858391, 858392, 858393, 858394, 858395, 858396, 858397, 859533, 859534, 859535, 859536]
If you're training for more than one epoch, the above setup is not recommended, as all epochs will read the shards in the same order (so randomness is limited to the ds = ds.shuffle(buffer) buffer size).
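For multi-epoch training, a more typical sketch combines file-level and example-level shuffling (the buffer size below is an arbitrary assumption):
ds = imagenet.as_dataset(split='train', shuffle_files=True)
ds = ds.shuffle(10_000)  # example-level shuffle on top of the file-level shuffle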
With shuffle_files=True
With shuffle_files=True, shards are shuffled for each epoch, so reading is not deterministic anymore.
print_ex_ids(imagenet, split='train', shuffle_files=True, take=20)
print_ex_ids(imagenet, split='train', shuffle_files=True, take=20)
[568017, 329050, 329051, 329052, 329053, 329054, 329056, 329055, 568019, 568020, 568021, 568022, 568023, 568018, 568025, 568024, 568026, 568028, 568030, 568031]
[43790, 43791, 43792, 43793, 43796, 43794, 43797, 43798, 43795, 43799, 43800, 43801, 43802, 43803, 43804, 43805, 43806, 43807, 43809, 43810]
See the recipe below to get deterministic file shuffling.
Determinism caveat: interleave args
Changing read_config.interleave_cycle_length and read_config.interleave_block_length will change the example order.
TFDS relies on tf.data.Dataset.interleave to only load a few shards at once, improving performance and reducing memory usage.
The example order is only guaranteed to be the same for a fixed value of the interleave args. See the interleave doc to understand what cycle_length and block_length correspond to.
cycle_length=16, block_length=16 (default, same as above):
print_ex_ids(imagenet, split='train', take=20)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]
cycle_length=3, block_length=2:
read_config = tfds.ReadConfig(
    interleave_cycle_length=3,
    interleave_block_length=2,
)
print_ex_ids(imagenet, split='train', read_config=read_config, take=20)
[0, 1, 1251, 1252, 2502, 2503, 2, 3, 1253, 1254, 2504, 2505, 4, 5, 1255, 1256, 2506, 2507, 6, 7]
In the second example, we see that the dataset reads 2 (block_length=2) examples in a shard, then switches to the next shard. Every 2 * 3 (cycle_length=3) examples, it goes back to the first shard (shard0-ex0, shard0-ex1, shard1-ex0, shard1-ex1, shard2-ex0, shard2-ex1, shard0-ex2, shard0-ex3, shard1-ex2, shard1-ex3, shard2-ex2, ...).
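This read order can be reproduced with a small pure-Python sketch (assuming, for simplicity, a uniform shard length of 1251; real imagenet shards hold 1251 or 1252 examples):
def interleave_order(shard_len, cycle_length, block_length, n):
  """Simulate the interleave read order over the first `cycle_length` shards."""
  order = []
  pos = [0] * cycle_length  # next example index within each shard
  while len(order) < n:
    for shard in range(cycle_length):
      for _ in range(block_length):  # read a block from the current shard
        order.append(shard * shard_len + pos[shard])
        pos[shard] += 1
  return order[:n]

print(interleave_order(1251, cycle_length=3, block_length=2, n=20))
# [0, 1, 1251, 1252, 2502, 2503, 2, 3, 1253, 1254, 2504, 2505, 4, 5, 1255, 1256, 2506, 2507, 6, 7]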
Subsplit and example order
Each example has an id 0, 1, ..., num_examples-1. The subsplit API selects a slice of examples (e.g. train[:x] selects 0, 1, ..., x-1).
However, within the subsplit, examples are not read in increasing id order (due to shards and interleave).
More specifically, ds.take(x) and split='train[:x]' are not equivalent!
This can be seen easily in the above interleave example where examples come from different shards.
print_ex_ids(imagenet, split='train', take=25) # tfds.load(..., split='train').take(25)
print_ex_ids(imagenet, split='train[:25]', take=-1) # tfds.load(..., split='train[:25]')
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
After the first 16 (block_length) examples, .take(25) switches to the next shard, while train[:25] continues reading examples from the first shard.
Recipes
Get deterministic file shuffling
There are 2 ways to have deterministic shuffling:
- Setting the shuffle_seed. Note: this requires changing the seed at each epoch, otherwise shards will be read in the same order between epochs (see the per-epoch sketch after the example below).
read_config = tfds.ReadConfig(
    shuffle_seed=32,
)
# Deterministic order, different from the default shuffle_files=False above
print_ex_ids(imagenet, split='train', shuffle_files=True, read_config=read_config, take=22)
print_ex_ids(imagenet, split='train', shuffle_files=True, read_config=read_config, take=22)
[176411, 176412, 176413, 176414, 176415, 176416, 176417, 176418, 176419, 176420, 176421, 176422, 176423, 176424, 176425, 176426, 710647, 710648, 710649, 710650, 710651, 710652]
[176411, 176412, 176413, 176414, 176415, 176416, 176417, 176418, 176419, 176420, 176421, 176422, 176423, 176424, 176425, 176426, 710647, 710648, 710649, 710650, 710651, 710652]
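A sketch of changing the seed at each epoch (the epoch loop and num_epochs are illustrative, not part of the TFDS API):
num_epochs = 3  # illustrative value
for epoch in range(num_epochs):
  read_config = tfds.ReadConfig(shuffle_seed=epoch)  # new deterministic seed each epoch
  ds = imagenet.as_dataset(
      split='train', shuffle_files=True, read_config=read_config)
  # ... iterate over `ds` for this epoch ...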
- Using experimental_interleave_sort_fn: this gives full control over which shards are read and in which order, rather than relying on the ds.shuffle order.
def _reverse_order(file_instructions):
  return list(reversed(file_instructions))

read_config = tfds.ReadConfig(
    experimental_interleave_sort_fn=_reverse_order,
)
# Last shard (01023-of-01024) is read first
print_ex_ids(imagenet, split='train', read_config=read_config, take=5)
[1279916, 1279917, 1279918, 1279919, 1279920]
Get deterministic preemptable pipeline
This one is more complicated. There is no easy, satisfactory solution.
1. Without ds.shuffle and with deterministic shuffling, in theory it should be possible to count the examples which have been read and deduce which examples have been read within each shard (as a function of cycle_length, block_length and the shard order). Then the skip, take for each shard could be injected through experimental_interleave_sort_fn.
2. With ds.shuffle, it's likely impossible without replaying the full training pipeline. It would require saving the ds.shuffle buffer state to deduce which examples have been read. Examples could be non-continuous (e.g. shard5_ex2, shard5_ex4 read but not shard5_ex3).
3. With ds.shuffle, one way would be to save all shard_ids/example_ids read (deduced from tfds_id), then deduce the file instructions from that, as sketched below.
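A rough sketch of that last approach (the bookkeeping, including the seen_ids list, is purely illustrative):
seen_ids = []
ds = load_dataset(imagenet, split='train', shuffle_files=True)
for ex in ds.take(100):  # stand-in for the training loop
  seen_ids.append(ex['tfds_id'].numpy().decode('utf-8'))
# In a real pipeline, `seen_ids` would be persisted alongside the checkpoint,
# so that after preemption the (shard, example) pairs already read can be
# mapped back to per-shard skip/take file instructions.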
The simplest case for 1. is to have .skip(x).take(y) match train[x:x+y]. It requires:
- Set cycle_length=1 (so shards are read sequentially)
- Set shuffle_files=False
- Do not use ds.shuffle
It should only be used on huge datasets where the training runs for only 1 epoch. Examples would be read in the default shuffle order.
read_config = tfds.ReadConfig(
    interleave_cycle_length=1,  # Read shards sequentially
)
print_ex_ids(imagenet, split='train', read_config=read_config, skip=40, take=22)
# If the job gets pre-empted, using the subsplit API will skip at most `len(shard0)`
print_ex_ids(imagenet, split='train[40:]', read_config=read_config, take=22)
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]
Find which shards/examples are read for a given subsplit
With the tfds.core.DatasetInfo, you have direct access to the read instructions.
imagenet.info.splits['train[44%:45%]'].file_instructions
[FileInstruction(filename='imagenet2012-train.tfrecord-00450-of-01024', skip=700, take=-1, num_examples=551),
 FileInstruction(filename='imagenet2012-train.tfrecord-00451-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00452-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00453-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00454-of-01024', skip=0, take=-1, num_examples=1252),
 FileInstruction(filename='imagenet2012-train.tfrecord-00455-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00456-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00457-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00458-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00459-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00460-of-01024', skip=0, take=1001, num_examples=1001)]
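As a quick illustrative check, the instructions account for every example in the 1% slice (551 + 8 * 1251 + 1252 + 1001):
instructions = imagenet.info.splits['train[44%:45%]'].file_instructions
print(sum(f.num_examples for f in instructions))
# 12812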