The notion of a dataset keyed by clients (e.g. users) is essential to federated computation as modeled in TFF. TFF provides the interface tff.simulation.datasets.ClientData to abstract over this concept, and the datasets which TFF hosts (stackoverflow, shakespeare, emnist, cifar100, and gldv2) all implement this interface.
If you are working on federated learning with your own dataset, TFF strongly encourages you to either implement the ClientData interface or use one of TFF's helper functions to generate a ClientData which represents your data on disk, e.g. tff.simulation.datasets.ClientData.from_clients_and_fn.
As most of TFF's end-to-end examples start with ClientData objects, implementing the ClientData interface with your custom dataset will make it easier to spelunk through existing code written with TFF. Further, the tf.data.Datasets which ClientData constructs can be iterated over directly to yield structures of numpy arrays, so ClientData objects can be used with any Python-based ML framework before moving to TFF.
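For example, a client's dataset can be consumed as plain numpy structures; below is a minimal sketch, assuming a ClientData object named client_data (such as the EMNIST data loaded later in this tutorial).
# A minimal sketch: iterate one client's tf.data.Dataset as numpy structures,
# which can then be fed to any Python-based ML framework.
example_dataset = client_data.create_tf_dataset_for_client(
    client_data.client_ids[0])
for element in example_dataset.as_numpy_iterator():
  # `element` is an OrderedDict of numpy arrays (for EMNIST, 'label' and 'pixels').
  print({key: value.shape for key, value in element.items()})
  break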
There are several patterns that can make your life easier if you intend to scale up your simulations to many machines or deploy them. Below we walk through a few of the ways we can use ClientData and TFF to make the journey from small-scale iteration, to large-scale experimentation, to production deployment as smooth as possible.
Which pattern should I use to pass ClientData into TFF?
We will discuss two usages of TFF's ClientData in depth; if you fit in either of the two categories below, you will clearly prefer one over the other. If not, you may need a more detailed understanding of the pros and cons of each to make a more nuanced choice.
I want to iterate as quickly as possible on a local machine; I don't need to be able to easily take advantage of TFF's distributed runtime.
- You want to pass tf.data.Datasets to TFF directly.
- This allows you to program imperatively with tf.data.Dataset objects, and process them arbitrarily.
- It provides more flexibility than the option below; pushing logic to the clients requires that this logic be serializable.
I want to run my federated computation in TFF's remote runtime, or I plan to do so soon.
- In this case you want to map dataset construction and preprocessing to clients.
- This results in you passing simply a list of client_ids directly to your federated computation.
- Pushing dataset construction and preprocessing to the clients avoids bottlenecks in serialization, and significantly increases performance with hundreds to thousands of clients.
Set up open-source environment
# tensorflow_federated_nightly also brings in tf_nightly, which
# can cause a duplicate tensorboard install, leading to errors.
pip uninstall --yes tensorboard tb-nightly
pip install --quiet --upgrade federated_language
pip install --quiet --upgrade tensorflow_federated
Import packages
import collections
import time
import federated_language
import tensorflow as tf
import tensorflow_federated as tff
Manipulating a ClientData object
Let's begin by loading and exploring TFF's EMNIST ClientData:
client_data, _ = tff.simulation.datasets.emnist.load_data()
Inspecting the first dataset can tell us what type of examples are in the ClientData.
first_client_id = client_data.client_ids[0]
first_client_dataset = client_data.create_tf_dataset_for_client(
first_client_id)
print(first_client_dataset.element_spec)
# This information is also available as a `ClientData` property:
assert client_data.element_type_structure == first_client_dataset.element_spec
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])
Note that the dataset yields collections.OrderedDict objects that have pixels and label keys, where pixels is a tensor with shape [28, 28]. Suppose we wish to flatten our inputs out to shape [784]. One possible way we can do this would be to apply a pre-processing function to our ClientData object.
def preprocess_dataset(dataset):
"""Create batches of 5 examples, and limit to 3 batches."""
def map_fn(input):
return collections.OrderedDict(
x=tf.reshape(input['pixels'], shape=(-1, 784)),
y=tf.cast(tf.reshape(input['label'], shape=(-1, 1)), tf.int64),
)
return dataset.batch(5).map(
map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)
preprocessed_client_data = client_data.preprocess(preprocess_dataset)
# Notice that we have both reshaped and renamed the elements of the ordered dict.
first_client_dataset = preprocessed_client_data.create_tf_dataset_for_client(
first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])
In addition, we may want to perform some more complex (and possibly stateful) preprocessing, for example shuffling.
def preprocess_and_shuffle(dataset):
"""Applies `preprocess_dataset` above and shuffles the result."""
preprocessed = preprocess_dataset(dataset)
return preprocessed.shuffle(buffer_size=5)
preprocessed_and_shuffled = client_data.preprocess(preprocess_and_shuffle)
# The type signature will remain the same, but the batches will be shuffled.
first_client_dataset = preprocessed_and_shuffled.create_tf_dataset_for_client(
first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])
Interfacing with a tff.Computation
Now that we can perform some basic manipulations with ClientData objects, we are ready to feed data to a tff.Computation. We define a tff.templates.IterativeProcess which implements Federated Averaging, and explore different methods of passing it data.
keras_model = tf.keras.models.Sequential([
tf.keras.layers.InputLayer(input_shape=(784,)),
tf.keras.layers.Dense(10, kernel_initializer='zeros'),
])
tff_model = tff.learning.models.functional_model_from_keras(
keras_model,
loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
# Note: input spec is the _batched_ shape, and includes the
# label tensor which will be passed to the loss function. This model is
# therefore configured to accept data _after_ it has been preprocessed.
input_spec=collections.OrderedDict(
x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
y=tf.TensorSpec(shape=[None, 1], dtype=tf.int64),
),
metrics_constructor=collections.OrderedDict(
loss=lambda: tf.keras.metrics.SparseCategoricalCrossentropy(
from_logits=True
),
accuracy=tf.keras.metrics.SparseCategoricalAccuracy,
),
)
trainer = tff.learning.algorithms.build_weighted_fed_avg(
tff_model,
client_optimizer_fn=tff.learning.optimizers.build_sgdm(learning_rate=0.01),
)
Before we begin working with this IterativeProcess, one comment on the semantics of ClientData is in order. A ClientData object represents the entirety of the population available for federated training, which in general is not available to the execution environment of a production FL system and is specific to simulation. ClientData indeed gives the user the capacity to bypass federated computing entirely and simply train a server-side model as usual via ClientData.create_tf_dataset_from_all_clients.
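For instance, one could pool every client's examples and train a standard server-side Keras model. The snippet below is a hedged sketch of that shortcut (not part of this tutorial's main path), reusing the preprocessed ClientData defined above.
# A sketch of bypassing federated training: pool all clients' (preprocessed)
# examples into a single dataset and train a centralized model on it.
centralized_dataset = preprocessed_client_data.create_tf_dataset_from_all_clients()
centralized_model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=(784,)),
    tf.keras.layers.Dense(10, kernel_initializer='zeros'),
])
centralized_model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
# The preprocessed elements are OrderedDicts with 'x' and 'y' keys; map them
# to (features, label) tuples for Keras.
centralized_model.fit(
    centralized_dataset.map(lambda element: (element['x'], element['y'])),
    epochs=1)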
TFF's simulation environment puts the researcher in complete control of the outer loop. In particular, this implies that considerations of client availability, client dropout, etc., must be addressed by the user or the Python driver script. One could for example model client dropout by adjusting the sampling distribution over your ClientData's client_ids such that users with more data (and correspondingly longer-running local computations) would be selected with lower probability.
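As a concrete illustration, here is a hedged sketch of such a driver-script sampling scheme, assuming a hypothetical client_num_examples mapping from client id to example count.
import numpy as np

# Weight each client id inversely to its (hypothetical) example count, so that
# clients with more data, and therefore longer local computations, are
# selected with lower probability.
def sample_clients(client_ids, client_num_examples, num_clients, seed=None):
  rng = np.random.default_rng(seed)
  weights = np.array(
      [1.0 / client_num_examples[cid] for cid in client_ids], dtype=np.float64)
  probabilities = weights / weights.sum()
  return list(
      rng.choice(client_ids, size=num_clients, replace=False, p=probabilities))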
In a real federated system, however, clients cannot be selected explicitly by the model trainer; the selection of clients is delegated to the system which is executing the federated computation.
Passing tf.data.Datasets directly to TFF
One option we have for interfacing between a ClientData and an IterativeProcess is that of constructing tf.data.Datasets in Python, and passing these datasets to TFF.
Notice that if we use our preprocessed ClientData, the datasets we yield are of the appropriate type expected by our model defined above.
selected_client_ids = preprocessed_and_shuffled.client_ids[:10]
preprocessed_data_for_clients = [
preprocessed_and_shuffled.create_tf_dataset_for_client(
selected_client_ids[i]
)
for i in range(10)
]
state = trainer.initialize()
for _ in range(5):
t1 = time.time()
result = trainer.next(state, preprocessed_data_for_clients)
state = result.state
train_metrics = result.metrics['client_work']['train']
t2 = time.time()
print(f'loss {train_metrics["loss"]:.2f}, round time {t2 - t1:.2f} seconds')
loss 2.89, round time 2.35 seconds
loss 3.05, round time 2.26 seconds
loss 2.80, round time 0.63 seconds
loss 2.94, round time 3.18 seconds
loss 3.17, round time 2.44 seconds
If we take this route, however, we will be unable to trivially move to multimachine simulation. The datasets we construct in the local TensorFlow runtime can capture state from the surrounding Python environment, and fail in serialization or deserialization when they attempt to reference state which is no longer available to them. This can manifest, for example, in the inscrutable error from TensorFlow's tensor_util.cc:
Check failed: DT_VARIANT == input.dtype() (21 vs. 20)
Mapping construction and preprocessing over the clients
To avoid this issue, TFF recommends that its users consider dataset instantiation and preprocessing as something that happens locally on each client, and use TFF's helpers or federated_map to explicitly run this preprocessing code at each client.
Conceptually, the reason for preferring this is clear: in TFF's local runtime, the clients only "accidentally" have access to the global Python environment due to the fact that the entire federated orchestration is happening on a single machine. It is worth noting at this point that similar thinking gives rise to TFF's cross-platform, always-serializable, functional philosophy.
TFF makes such a change simple via ClientData's attribute dataset_computation, a tff.Computation which takes a client_id and returns the associated tf.data.Dataset.
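For instance, invoking this attribute directly for a single client id materializes that client's dataset eagerly in the local Python runtime; a minimal sketch (not one of the tutorial's original cells):
# Invoke the `dataset_computation` attribute eagerly for one client id; as
# discussed below, this yields an eager dataset in the local Python runtime.
eager_client_dataset = client_data.dataset_computation(first_client_id)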
Note that preprocess simply works with dataset_computation; the dataset_computation attribute of the preprocessed ClientData incorporates the entire preprocessing pipeline we just defined:
print('dataset computation without preprocessing:')
print(client_data.dataset_computation.type_signature)
print('\n')
print('dataset computation with preprocessing:')
print(preprocessed_and_shuffled.dataset_computation.type_signature)
dataset computation without preprocessing:
(str -> <label=int32,pixels=float32[28,28]>*)

dataset computation with preprocessing:
(str -> <x=float32[?,784],y=int64[?,1]>*)
We could invoke dataset_computation and receive an eager dataset in the Python runtime, but the real power of this approach is exercised when we compose with an iterative process or another computation to avoid materializing these datasets in the global eager runtime at all. TFF provides a helper function tff.simulation.compose_dataset_computation_with_iterative_process which can be used to do exactly this.
trainer_accepting_ids = tff.simulation.compose_dataset_computation_with_iterative_process(
preprocessed_and_shuffled.dataset_computation, trainer)
This tff.templates.IterativeProcess and the one above run the same way, but the one above accepts preprocessed client datasets, while this one accepts strings representing client ids and handles both dataset construction and preprocessing in its body. In fact, state can be passed between the two.
for _ in range(5):
t1 = time.time()
result = trainer_accepting_ids.next(state, selected_client_ids)
state = result.state
train_metrics = result.metrics['client_work']['train']
t2 = time.time()
print(f'loss {train_metrics["loss"]:.2f}, round time {t2 - t1:.2f} seconds')
Scaling to large numbers of clients
trainer_accepting_ids can immediately be used in TFF's multimachine runtime, and avoids materializing tf.data.Datasets in the controller (and therefore serializing them and sending them out to the workers).
This significantly speeds up distributed simulations, especially with a large number of clients, and enables intermediate aggregation to avoid similar serialization/deserialization overhead.
Optional deep dive: manually composing preprocessing logic in TFF
TFF is designed for compositionality from the ground up; the kind of composition just performed by TFF's helper is fully within our control as users. We could quite simply have manually composed the preprocessing computation we just defined with the trainer's own next:
selected_clients_type = federated_language.FederatedType(
preprocessed_and_shuffled.dataset_computation.type_signature.parameter,
tff.CLIENTS,
)
@tff.federated_computation(
trainer.next.type_signature.parameter[0], selected_clients_type
)
def new_next(server_state, selected_clients):
preprocessed_data = tff.federated_map(
preprocessed_and_shuffled.dataset_computation, selected_clients
)
return trainer.next(server_state, preprocessed_data)
manual_trainer_with_preprocessing = tff.templates.IterativeProcess(
initialize_fn=trainer.initialize, next_fn=new_next
)
In fact, this is effectively what the helper we used is doing under the hood (plus performing appropriate type checking and manipulation). We could even have expressed the same logic slightly differently, by serializing preprocess_and_shuffle into a tff.Computation, and decomposing the federated_map into one step which constructs un-preprocessed datasets and another which runs preprocess_and_shuffle at each client.
We can verify that this more-manual path results in computations with the same type signature as TFF's helper (modulo parameter names):
print(trainer_accepting_ids.next.type_signature)
print(manual_trainer_with_preprocessing.next.type_signature)
(<state=<global_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,distributor=<>,client_work=<>,aggregator=<value_sum_process=<>,weight_sum_process=<>>,finalizer=<learning_rate=float32>>@SERVER,client_data={str}@CLIENTS> -> <state=<global_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,distributor=<>,client_work=<>,aggregator=<value_sum_process=<>,weight_sum_process=<>>,finalizer=<learning_rate=float32>>@SERVER,metrics=<distributor=<>,client_work=<train=<loss=float32,accuracy=float32>>,aggregator=<mean_value=<>,mean_weight=<>>,finalizer=<update_non_finite=int32>>@SERVER>)
(<server_state=<global_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,distributor=<>,client_work=<>,aggregator=<value_sum_process=<>,weight_sum_process=<>>,finalizer=<learning_rate=float32>>@SERVER,selected_clients={str}@CLIENTS> -> <state=<global_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,distributor=<>,client_work=<>,aggregator=<value_sum_process=<>,weight_sum_process=<>>,finalizer=<learning_rate=float32>>@SERVER,metrics=<distributor=<>,client_work=<train=<loss=float32,accuracy=float32>>,aggregator=<mean_value=<>,mean_weight=<>>,finalizer=<update_non_finite=int32>>@SERVER>)
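For completeness, here is a hedged sketch of the alternative decomposition described above; it assumes tff.tf_computation is available to serialize preprocess_and_shuffle over the un-preprocessed dataset type, and then chains two federated_map calls.
# Serialize the preprocessing function against the raw (un-preprocessed)
# sequence type produced by `client_data.dataset_computation`.
tf_preprocess_and_shuffle = tff.tf_computation(
    preprocess_and_shuffle,
    client_data.dataset_computation.type_signature.result,
)

@tff.federated_computation(
    trainer.next.type_signature.parameter[0], selected_clients_type
)
def decomposed_next(server_state, selected_clients):
  # Step 1: construct the raw datasets at each client from their ids.
  raw_client_datasets = tff.federated_map(
      client_data.dataset_computation, selected_clients)
  # Step 2: run the serialized preprocessing at each client.
  preprocessed_client_datasets = tff.federated_map(
      tf_preprocess_and_shuffle, raw_client_datasets)
  return trainer.next(server_state, preprocessed_client_datasets)

decomposed_trainer = tff.templates.IterativeProcess(
    initialize_fn=trainer.initialize, next_fn=decomposed_next)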