A convenience class for a task's data and preprocessing logic.
tff.simulation.baselines.BaselineTaskDatasets(
    train_data: tff.simulation.datasets.ClientData,
    test_data: tff.simulation.datasets.ClientData,
    validation_data: Optional[tff.simulation.datasets.ClientData] = None,
    train_preprocess_fn: Optional[tff.Computation] = None,
    eval_preprocess_fn: Optional[tff.Computation] = None
)
Args | |
---|---|
train_data | A tff.simulation.datasets.ClientData for training. |
test_data | A tff.simulation.datasets.ClientData or a tf.data.Dataset for computing test metrics. |
validation_data | An optional tff.simulation.datasets.ClientData or a tf.data.Dataset for computing validation metrics. |
train_preprocess_fn | An optional callable accepting and returning a tf.data.Dataset, used to perform dataset preprocessing for training. If set to None, the identity map is used for all train preprocessing. |
eval_preprocess_fn | An optional callable accepting and returning a tf.data.Dataset, used to perform evaluation (e.g., validation, testing) preprocessing. If None, evaluation preprocessing is done via the identity map. |
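The identity-map fallback for missing preprocessing functions can be illustrated with a small plain-Python sketch. This is a hypothetical stand-in, not the TFF implementation: the class name TaskDatasetsSketch and its preprocess_train/preprocess_eval methods are illustrative only, and Python lists stand in for tf.data.Dataset so the example runs without TensorFlow.

```python
from typing import Callable, List, Optional

# Hypothetical sketch: plain lists stand in for tf.data.Dataset.
Dataset = List[int]
PreprocessFn = Callable[[Dataset], Dataset]


def _identity(dataset: Dataset) -> Dataset:
    """Returns the dataset unchanged (the documented default behavior)."""
    return dataset


class TaskDatasetsSketch:
    """Hypothetical stand-in modeling only the preprocessing defaults."""

    def __init__(
        self,
        train_preprocess_fn: Optional[PreprocessFn] = None,
        eval_preprocess_fn: Optional[PreprocessFn] = None,
    ):
        # Store the user-supplied functions as given (possibly None).
        self.train_preprocess_fn = train_preprocess_fn
        self.eval_preprocess_fn = eval_preprocess_fn

    def preprocess_train(self, dataset: Dataset) -> Dataset:
        # Fall back to the identity map when no train function was passed.
        fn = self.train_preprocess_fn or _identity
        return fn(dataset)

    def preprocess_eval(self, dataset: Dataset) -> Dataset:
        # Fall back to the identity map when no eval function was passed.
        fn = self.eval_preprocess_fn or _identity
        return fn(dataset)


def double_all(dataset: Dataset) -> Dataset:
    """A toy train preprocessing function."""
    return [2 * x for x in dataset]


task = TaskDatasetsSketch(train_preprocess_fn=double_all)
print(task.preprocess_train([1, 2, 3]))  # [2, 4, 6]
print(task.preprocess_eval([1, 2, 3]))   # [1, 2, 3]
```

Keeping the stored attribute as None while applying the identity map at use time matches the documented attributes, which report None when no preprocessing function was supplied.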
Attributes | |
---|---|
train_data | A tff.simulation.datasets.ClientData for training. |
test_data | The test data for the baseline task. Can be a tff.simulation.datasets.ClientData or a tf.data.Dataset. |
validation_data | The validation data for the baseline task. Can be one of tff.simulation.datasets.ClientData, tf.data.Dataset, or None if the task does not have a validation dataset. |
train_preprocess_fn | A callable accepting and returning tf.data.Dataset instances, used for preprocessing train datasets. Set to None if no train preprocessing occurs for the task. |
eval_preprocess_fn | A callable accepting and returning tf.data.Dataset instances, used for preprocessing evaluation datasets. Set to None if no eval preprocessing occurs for the task. |
element_type_structure | A nested structure of tf.TensorSpec objects defining the type of the elements contained in datasets associated with this task. |
Methods
get_centralized_test_data
get_centralized_test_data() -> tf.data.Dataset
Returns a tf.data.Dataset of test data for the task.
If the baseline task has centralized data, then this method will return the centralized data after applying preprocessing. If the test data is federated, then this method will first amalgamate the client datasets into a single dataset, then apply preprocessing.
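The ordering described above (merge first, then preprocess) can be sketched in plain Python. This is a hypothetical illustration, not the TFF code: the function name centralized_test_data is made up, lists stand in for tf.data.Dataset, a dict of client id to list stands in for federated ClientData, and visiting clients in sorted id order is an assumption chosen only to make the result deterministic.

```python
from typing import Callable, Dict, List, Optional, Union

# Assumptions for this sketch: lists stand in for tf.data.Dataset, and a
# dict mapping client ids to lists stands in for federated ClientData.
Dataset = List[int]
FederatedData = Dict[str, Dataset]


def centralized_test_data(
    test_data: Union[Dataset, FederatedData],
    eval_preprocess_fn: Optional[Callable[[Dataset], Dataset]] = None,
) -> Dataset:
    """Merges federated test data into one dataset, then preprocesses it."""
    if isinstance(test_data, dict):
        # Federated: concatenate every client's examples into one dataset,
        # visiting clients in sorted id order for determinism.
        merged: Dataset = []
        for client_id in sorted(test_data):
            merged.extend(test_data[client_id])
    else:
        # Centralized: the data is already a single dataset.
        merged = list(test_data)
    # Preprocessing runs once, on the merged dataset.
    if eval_preprocess_fn is None:
        return merged
    return eval_preprocess_fn(merged)


federated = {"client_b": [3, 4], "client_a": [1, 2]}
print(centralized_test_data(federated, lambda ds: [x * 10 for x in ds]))
# [10, 20, 30, 40]
```

In this sketch, applying the function once to the merged dataset rather than to each client dataset means that operations such as batching act across client boundaries.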
sample_train_clients
sample_train_clients(
num_clients: int, replace: bool = False, random_seed: Optional[int] = None
) -> list[tf.data.Dataset]
Samples training clients uniformly at random.
Args | |
---|---|
num_clients | A positive integer representing the number of clients to be sampled. |
replace | Whether to sample with replacement. If set to False, then num_clients cannot exceed the number of training clients in the associated train data. |
random_seed | An optional integer used to set a random seed for sampling. If no random seed is passed or the random seed is set to None, this will attempt to set the random seed according to the current system time (see numpy.random.RandomState for details). |
Returns | |
---|---|
A list of tf.data.Dataset instances representing the client datasets. |
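The sampling contract (uniform choice, optional replacement, optional seed) can be sketched with Python's standard library. This is a hypothetical model, not the TFF implementation: it uses random.Random as a stand-in for numpy.random.RandomState, represents the train data as a dict of client id to list, and the standalone function signature is illustrative only.

```python
import random
from typing import Dict, List, Optional

# Assumption: lists stand in for tf.data.Dataset, and a dict of
# client id -> list stands in for the train ClientData.
Dataset = List[int]


def sample_train_clients(
    train_data: Dict[str, Dataset],
    num_clients: int,
    replace: bool = False,
    random_seed: Optional[int] = None,
) -> List[Dataset]:
    """Samples client datasets uniformly at random (hypothetical sketch)."""
    if num_clients <= 0:
        raise ValueError("num_clients must be a positive integer.")
    client_ids = sorted(train_data)
    # random.Random(None) seeds from OS randomness when available, otherwise
    # from the system time; this loosely mirrors RandomState(None).
    rng = random.Random(random_seed)
    if replace:
        # With replacement, the same client may be drawn more than once.
        chosen = rng.choices(client_ids, k=num_clients)
    else:
        # Without replacement, num_clients may not exceed the number of
        # clients; random.sample raises ValueError in that case.
        chosen = rng.sample(client_ids, num_clients)
    return [train_data[cid] for cid in chosen]


clients = {f"client_{i}": [i] for i in range(5)}
first = sample_train_clients(clients, 3, random_seed=42)
second = sample_train_clients(clients, 3, random_seed=42)
print(first == second)  # True
```

Fixing random_seed makes the sampled client list reproducible across calls, which is useful when comparing training runs.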
summary
summary(
print_fn: Callable[[str], Any] = print
)
Prints a summary of the train, test, and validation data.
The summary will be printed as a table containing information on the type of train, test, and validation data (i.e., federated or centralized) and the number of clients each data structure has (if it is federated). For example, if the train data has 10 clients, and both the test and validation data are centralized, then this will print the following table:
Split |Dataset Type |Number of Clients |
=============================================
Train |Federated |10 |
Test |Centralized |N/A |
Validation |Centralized |N/A |
_____________________________________________
In addition, this will print two lines after the table indicating whether train and eval preprocessing functions were passed in. In the example above, if we passed in a train preprocessing function but no eval preprocessing function, it would also print the lines:
Train Preprocess Function: True
Eval Preprocess Function: False
To capture the summary, you can use a custom print function. For example, setting print_fn = summary_list.append will cause each of the lines above to be appended to summary_list.
Args | |
---|---|
print_fn | An optional callable accepting string inputs. Used to print each row of the summary. Defaults to print if not specified. |
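Routing every row through print_fn, as described for summary, can be sketched as follows. This is a hypothetical approximation: print_summary and its parameters are illustrative names, a client count of None marks a centralized split, and the column widths only approximate the table layout described above rather than reproduce TFF's exact output.

```python
from typing import Any, Callable, Optional


def print_summary(
    train_clients: Optional[int],
    test_clients: Optional[int],
    validation_clients: Optional[int],
    has_train_preprocess: bool,
    has_eval_preprocess: bool,
    print_fn: Callable[[str], Any] = print,
) -> None:
    """Emits a summary table row by row via print_fn (hypothetical sketch).

    A client count of None marks a centralized split.
    """
    rows = [("Train", train_clients), ("Test", test_clients),
            ("Validation", validation_clients)]
    header = f"{'Split':<11}|{'Dataset Type':<13}|{'Number of Clients':<18}|"
    print_fn(header)
    print_fn("=" * len(header))
    for split, num_clients in rows:
        kind = "Centralized" if num_clients is None else "Federated"
        count = "N/A" if num_clients is None else str(num_clients)
        print_fn(f"{split:<11}|{kind:<13}|{count:<18}|")
    print_fn("_" * len(header))
    # The two trailing lines report whether preprocessing functions exist.
    print_fn(f"Train Preprocess Function: {has_train_preprocess}")
    print_fn(f"Eval Preprocess Function: {has_eval_preprocess}")


# Capture the output instead of printing it.
summary_list = []
print_summary(10, None, None, True, False, print_fn=summary_list.append)
print(len(summary_list))  # 8
print(summary_list[-1])   # Eval Preprocess Function: False
```

Because every line goes through a single callable, passing list.append (or a logger method) captures the output without any change to the summary logic.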