In this tutorial, we use the EMNIST dataset to demonstrate how to enable lossy compression algorithms to reduce communication cost in the Federated Averaging algorithm using the tff.learning API. For more details on the Federated Averaging algorithm, see the paper Communication-Efficient Learning of Deep Networks from Decentralized Data.
Before we start
Before we start, please run the following to make sure that your environment is correctly set up. If you don't see a greeting, please refer to the Installation guide for instructions.
pip install --quiet --upgrade tensorflow-federated
pip install --quiet --upgrade tensorflow-model-optimization
%load_ext tensorboard
import functools
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
Verify that TFF is working.
@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()
b'Hello, World!'
Preparing the input data
In this section we load and preprocess the EMNIST dataset included in TFF. Please check out the Federated Learning for Image Classification tutorial for more details about the EMNIST dataset.
# This value only applies to the EMNIST dataset; consider choosing appropriate
# values if switching to other datasets.
MAX_CLIENT_DATASET_SIZE = 418

CLIENT_EPOCHS_PER_ROUND = 1
CLIENT_BATCH_SIZE = 20
TEST_BATCH_SIZE = 500

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
    only_digits=True)

def reshape_emnist_element(element):
  return (tf.expand_dims(element['pixels'], axis=-1), element['label'])

def preprocess_train_dataset(dataset):
  """Preprocessing function for the EMNIST training dataset."""
  return (dataset
          # Shuffle according to the largest client dataset
          .shuffle(buffer_size=MAX_CLIENT_DATASET_SIZE)
          # Repeat to do multiple local epochs
          .repeat(CLIENT_EPOCHS_PER_ROUND)
          # Batch to a fixed client batch size
          .batch(CLIENT_BATCH_SIZE, drop_remainder=False)
          # Preprocessing step
          .map(reshape_emnist_element))

emnist_train = emnist_train.preprocess(preprocess_train_dataset)
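To make the effect of each pipeline stage concrete, here is a minimal NumPy sketch that mimics preprocess_train_dataset for a single client held in memory. It is an illustration only: numpy_client_batches and the in-memory arrays are hypothetical stand-ins for a tf.data.Dataset, and the sketch assumes the shuffle buffer covers the whole client dataset, which the comment on MAX_CLIENT_DATASET_SIZE (the largest client dataset) suggests is the case for EMNIST.

```python
import numpy as np


def numpy_client_batches(pixels, labels, epochs=1, batch_size=20, seed=0):
  """Hypothetical NumPy analogue of preprocess_train_dataset for one client.

  Assumes the shuffle buffer is at least as large as the client dataset, so
  each epoch is a full uniform permutation (tf.data reshuffles every epoch
  by default).
  """
  rng = np.random.default_rng(seed)
  n = len(labels)
  # shuffle + repeat: one fresh permutation per local epoch, concatenated.
  order = np.concatenate([rng.permutation(n) for _ in range(epochs)])
  batches = []
  # batch(drop_remainder=False): the last batch may be smaller.
  for start in range(0, len(order), batch_size):
    idx = order[start:start + batch_size]
    # map(reshape_emnist_element): add a trailing channel axis.
    batches.append((pixels[idx][..., np.newaxis], labels[idx]))
  return batches


# A synthetic client with 45 examples of 28x28 pixels.
pixels = np.zeros((45, 28, 28), dtype=np.float32)
labels = np.arange(45) % 10
batches = numpy_client_batches(pixels, labels)
print(len(batches), batches[0][0].shape, batches[-1][0].shape)
# prints: 3 (20, 28, 28, 1) (5, 28, 28, 1)
```

Because repeat comes before batch, batches can straddle epoch boundaries: with two local epochs, the same 45-example client yields ceil(90 / 20) = 5 batches rather than 2 × 3 = 6.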
Defining a model
Here we define a Keras model based on the original FedAvg CNN, and then wrap the Keras model in an instance of tff.learning.models.VariableModel so that it can be consumed by TFF.
Note that we need a function that produces a model rather than the model itself. In addition, the function cannot simply capture a pre-constructed model; it must create the model in the context in which it is called. The reason is that TFF is designed to go to devices, and needs control over when resources are constructed so that they can be captured and packaged up.
def create_original_fedavg_cnn_model(only_digits=True):
  """The CNN model used in https://arxiv.org/abs/1602.05629."""
  data_format = 'channels_last'

  max_pool = functools.partial(
      tf.keras.layers.MaxPooling2D,
      pool_size=(2, 2),
      padding='same',
      data_format=data_format)
  conv2d = functools.partial(
      tf.keras.layers.Conv2D,
      kernel_size=5,
      padding='same',
      data_format=data_format,
      activation=tf.nn.relu)

  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
      conv2d(filters=32),
      max_pool(),
      conv2d(filters=64),
      max_pool(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation=tf.nn.relu),
      tf.keras.layers.Dense(10 if only_digits else 62),
      tf.keras.layers.Softmax(),
  ])

  return model
# Gets the type information of the input data. TFF is a strongly typed
# functional programming framework, and needs type information about inputs to
# the model.
input_spec = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0]).element_spec

def tff_model_fn():
  keras_model = create_original_fedavg_cnn_model()
  return tff.learning.models.from_keras_model(
      keras_model=keras_model,
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
Training the model and outputting training metrics
Now we are ready to construct a Federated Averaging algorithm and train the defined model on the EMNIST dataset.
First, we build a Federated Averaging algorithm using the tff.learning.algorithms.build_weighted_fed_avg API.
federated_averaging = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn=tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
Now let's run the Federated Averaging algorithm. From the perspective of TFF, executing a Federated Learning algorithm looks like this:
- Initialize the algorithm and get the initial server state. The server state contains the information needed to perform the algorithm. Recall that, since TFF is functional, this state includes both any optimizer state the algorithm uses (e.g. momentum terms) and the model parameters themselves; these are passed as arguments to, and returned as results from, TFF computations.
- Execute the algorithm round by round. In each round, a new server state is returned as the result of each client training the model on its data. Typically, in one round:
  - The server broadcasts the model to all the participating clients.
  - Each client performs work based on the model and its own data.
  - The server aggregates the clients' model updates to produce a new server state that contains the updated model.
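The round structure above can be sketched in plain NumPy. This is a toy illustration, not TFF's implementation: it assumes a linear least-squares model instead of the CNN, and client_update and server_update are hypothetical helpers. It mirrors the configuration used in this tutorial (client SGD with learning rate 0.02 and batch size 20, server SGD with learning rate 1.0, client updates weighted by number of examples). In this sketch each client reports its update as a pseudo-gradient (initial minus trained weights); that sign convention is a choice of the sketch.

```python
import numpy as np


def client_update(weights, x, y, lr=0.02, batch_size=20):
  """Hypothetical client step: local minibatch SGD on a least-squares loss."""
  w = weights.copy()
  for start in range(0, len(y), batch_size):
    xb, yb = x[start:start + batch_size], y[start:start + batch_size]
    # Gradient of 0.5 * mean((xb @ w - yb) ** 2) with respect to w.
    grad = xb.T @ (xb @ w - yb) / len(yb)
    w -= lr * grad
  # Pseudo-gradient (initial minus trained) plus the example count,
  # which serves as this client's aggregation weight.
  return weights - w, len(y)


def server_update(weights, client_results, server_lr=1.0):
  """Example-weighted mean of client updates, then one server SGD step."""
  deltas = np.stack([delta for delta, _ in client_results])
  counts = np.array([n for _, n in client_results], dtype=np.float64)
  mean_delta = (counts[:, None] * deltas).sum(axis=0) / counts.sum()
  return weights - server_lr * mean_delta


rng = np.random.default_rng(0)
true_w = np.array([1.0, -2.0, 0.5])
# Three clients holding different amounts of noise-free synthetic data.
clients = []
for n in (30, 50, 20):
  x = rng.normal(size=(n, 3))
  clients.append((x, x @ true_w))

weights = np.zeros(3)
for round_num in range(5):
  # Broadcast: every client starts from the same server weights.
  results = [client_update(weights, x, y) for x, y in clients]
  # Aggregate the weighted updates and produce the new server model.
  weights = server_update(weights, results)
print(weights)
```

With server_lr=1.0 the server step cancels the pseudo-gradient sign exactly, so the new model is the example-weighted average of the client models; a different server learning rate or an optimizer with momentum would move away from plain averaging.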
For more details, please see the Custom Federated Algorithms, Part 2: Implementing Federated Averaging tutorial.
Training metrics are written to a TensorBoard log directory so they can be displayed after training.
def train(federated_averaging_process, num_rounds, num_clients_per_round, summary_writer):
  """Trains the federated averaging process and outputs metrics."""

  # Initialize the Federated Averaging algorithm to get the initial server state.
  state = federated_averaging_process.initialize()

  with summary_writer.as_default():
    for round_num in range(num_rounds):
      # Sample the clients participating in this round.
      sampled_clients = np.random.choice(
          emnist_train.client_ids,
          size=num_clients_per_round,
          replace=False)
      # Create a list of `tf.data.Dataset` instances from the data of sampled clients.
      sampled_train_data = [
          emnist_train.create_tf_dataset_for_client(client)
          for client in sampled_clients
      ]
      # Run one round of the algorithm based on the server state and client data
      # and output the new state and metrics.
      result = federated_averaging_process.next(state, sampled_train_data)
      state = result.state
      train_metrics = result.metrics['client_work']['train']

      # Add metrics to Tensorboard.
      for name, value in train_metrics.items():
        tf.summary.scalar(name, value, step=round_num)
      summary_writer.flush()
# Clean the log directory to avoid conflicts.
try:
  tf.io.gfile.rmtree('/tmp/logs/scalars')
except tf.errors.OpError as e:
  pass  # Path doesn't exist

# Set up the log directory and writer for Tensorboard.
logdir = "/tmp/logs/scalars/original/"
summary_writer = tf.summary.create_file_writer(logdir)

train(federated_averaging_process=federated_averaging, num_rounds=10,
      num_clients_per_round=10, summary_writer=summary_writer)
round 0, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.092454836), ('loss', 2.310193), ('num_examples', 941), ('num_batches', 51)]), broadcasted_bits=507.62Mibit, aggregated_bits=507.62Mibit
round 1, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.10029791), ('loss', 2.3102622), ('num_examples', 1007), ('num_batches', 55)]), broadcasted_bits=1015.24Mibit, aggregated_bits=1015.25Mibit
round 2, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.10710711), ('loss', 2.3048222), ('num_examples', 999), ('num_batches', 54)]), broadcasted_bits=1.49Gibit, aggregated_bits=1.49Gibit
round 3, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.1061061), ('loss', 2.3066027), ('num_examples', 999), ('num_batches', 55)]), broadcasted_bits=1.98Gibit, aggregated_bits=1.98Gibit
round 4, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.1287594), ('loss', 2.2999024), ('num_examples', 1064), ('num_batches', 58)]), broadcasted_bits=2.48Gibit, aggregated_bits=2.48Gibit
round 5, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.13529412), ('loss', 2.2994456), ('num_examples', 1020), ('num_batches', 55)]), broadcasted_bits=2.97Gibit, aggregated_bits=2.97Gibit
round 6, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.124045804), ('loss', 2.2947247), ('num_examples', 1048), ('num_batches', 57)]), broadcasted_bits=3.47Gibit, aggregated_bits=3.47Gibit
round 7, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.14217557), ('loss', 2.290349), ('num_examples', 1048), ('num_batches', 57)]), broadcasted_bits=3.97Gibit, aggregated_bits=3.97Gibit
round 8, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.14641434), ('loss', 2.290953), ('num_examples', 1004), ('num_batches', 56)]), broadcasted_bits=4.46Gibit, aggregated_bits=4.46Gibit
round 9, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.1695238), ('loss', 2.2859888), ('num_examples', 1050), ('num_batches', 57)]), broadcasted_bits=4.96Gibit, aggregated_bits=4.96Gibit
Start TensorBoard with the root log directory specified above to display the training metrics. It can take a few seconds for the data to load. In addition to loss and accuracy, we also output the amount of broadcasted and aggregated data. Broadcasted data refers to tensors the server pushes to each client, while aggregated data refers to tensors each client returns to the server.
%tensorboard --logdir /tmp/logs/scalars/ --port=0
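The broadcast numbers in the log above can be checked by hand. The sketch below counts the trainable parameters of create_original_fedavg_cnn_model from its layer shapes (5x5 kernels, 'same' padding, two 2x2 max pools shrinking 28x28 to 7x7, 10 output classes for only_digits=True). It assumes every parameter is sent once as a 32-bit float to each of the 10 sampled clients per round; conv_params and dense_params are hypothetical helpers.

```python
def conv_params(kernel, in_ch, out_ch):
  # Kernel weights plus one bias per output channel.
  return kernel * kernel * in_ch * out_ch + out_ch


def dense_params(in_dim, out_dim):
  # Weight matrix plus one bias per output unit.
  return in_dim * out_dim + out_dim


# Two 2x2 'same' max pools shrink 28x28 to 14x14 and then to 7x7.
flat_dim = 7 * 7 * 64
num_params = (conv_params(5, 1, 32)         # first conv2d, 32 filters
              + conv_params(5, 32, 64)      # second conv2d, 64 filters
              + dense_params(flat_dim, 512)
              + dense_params(512, 10))      # only_digits=True: 10 classes
bits_per_client = num_params * 32           # float32 values
clients_per_round = 10
mibit_per_round = bits_per_client * clients_per_round / 2**20
print(num_params, round(mibit_per_round, 2))
# prints: 1663370 507.62
```

This matches the 507.62 Mibit broadcast in round 0, and since the logged counters are cumulative they grow by about the same amount every round.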
Build a custom aggregation function
Now let's build an aggregation function that applies lossy compression to the aggregated data. We will use TFF's API to create a tff.aggregators.AggregationFactory for this. While researchers may often want to implement their own (which can be done via the tff.aggregators API), we will use a built-in method for doing so, specifically tff.learning.compression_aggregator.
It is important to note that this aggregator does not apply compression to the entire model at once. Instead, it applies compression only to those variables in the model that are sufficiently large. Small variables such as biases are generally more sensitive to inaccuracy, and because they are small, compressing them would save relatively little communication anyway.
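To see why leaving small variables uncompressed costs little, the sketch below applies a size-threshold policy to the CNN's variables. The variable names, the 20,000-element threshold, and the 8-bit width are illustrative assumptions rather than values read from tff.learning.compression_aggregator, and the estimate ignores any per-tensor metadata a real quantizer would also send.

```python
import math

# Hypothetical names; shapes follow create_original_fedavg_cnn_model
# with only_digits=True.
variable_shapes = {
    'conv1/kernel': (5, 5, 1, 32), 'conv1/bias': (32,),
    'conv2/kernel': (5, 5, 32, 64), 'conv2/bias': (64,),
    'dense1/kernel': (3136, 512), 'dense1/bias': (512,),
    'dense2/kernel': (512, 10), 'dense2/bias': (10,),
}
THRESHOLD = 20_000  # assumed: only tensors larger than this are compressed
QUANT_BITS = 8      # assumed quantization width for large tensors

sizes = {name: math.prod(shape) for name, shape in variable_shapes.items()}
compressed = [name for name, n in sizes.items() if n > THRESHOLD]
total_bits = sum(n * (QUANT_BITS if name in compressed else 32)
                 for name, n in sizes.items())
# Fraction of all parameters that stay at full 32-bit precision.
small_share = sum(n for name, n in sizes.items()
                  if name not in compressed) / sum(sizes.values())
print(compressed)
print(total_bits, f'{small_share:.2%}')
# prints:
# ['conv2/kernel', 'dense1/kernel']
# 13463872 0.39%
```

Under these assumptions the six small tensors hold only 0.39% of the parameters, so keeping them at full precision still leaves the estimate close to the ideal 4x reduction (13,463,872 bits versus 53,227,840 bits uncompressed).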
compression_aggregator = tff.learning.compression_aggregator()
isinstance(compression_aggregator, tff.aggregators.WeightedAggregationFactory)
True
Above, you can see that the compression aggregator is a weighted aggregation factory, which means that it involves weighted aggregation (in contrast to aggregators meant for differential privacy, which are often unweighted).
This aggregation factory can be directly plugged into FedAvg via its model_aggregator argument.
federated_averaging_with_compression = tff.learning.algorithms.build_weighted_fed_avg(
    tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    model_aggregator=compression_aggregator)
Training the model again
Now let's run the new Federated Averaging algorithm.
logdir_for_compression = "/tmp/logs/scalars/compression/"
summary_writer_for_compression = tf.summary.create_file_writer(
    logdir_for_compression)

train(federated_averaging_process=federated_averaging_with_compression,
      num_rounds=10,
      num_clients_per_round=10,
      summary_writer=summary_writer_for_compression)
round 0, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.087804876), ('loss', 2.3126457), ('num_examples', 1025), ('num_batches', 55)]), broadcasted_bits=507.62Mibit, aggregated_bits=146.47Mibit
round 1, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.073267326), ('loss', 2.3111901), ('num_examples', 1010), ('num_batches', 56)]), broadcasted_bits=1015.24Mibit, aggregated_bits=292.93Mibit
round 2, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.08925144), ('loss', 2.3071017), ('num_examples', 1042), ('num_batches', 57)]), broadcasted_bits=1.49Gibit, aggregated_bits=439.40Mibit
round 3, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.07985144), ('loss', 2.3061485), ('num_examples', 1077), ('num_batches', 59)]), broadcasted_bits=1.98Gibit, aggregated_bits=585.86Mibit
round 4, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.11947791), ('loss', 2.302166), ('num_examples', 996), ('num_batches', 55)]), broadcasted_bits=2.48Gibit, aggregated_bits=732.33Mibit
round 5, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.12195122), ('loss', 2.2997446), ('num_examples', 984), ('num_batches', 54)]), broadcasted_bits=2.97Gibit, aggregated_bits=878.79Mibit
round 6, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.10429448), ('loss', 2.2997215), ('num_examples', 978), ('num_batches', 55)]), broadcasted_bits=3.47Gibit, aggregated_bits=1.00Gibit
round 7, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.16857143), ('loss', 2.2961135), ('num_examples', 1050), ('num_batches', 56)]), broadcasted_bits=3.97Gibit, aggregated_bits=1.14Gibit
round 8, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.1399177), ('loss', 2.2942808), ('num_examples', 972), ('num_batches', 54)]), broadcasted_bits=4.46Gibit, aggregated_bits=1.29Gibit
round 9, train_metrics=OrderedDict([('sparse_categorical_accuracy', 0.14202899), ('loss', 2.2972558), ('num_examples', 1035), ('num_batches', 57)]), broadcasted_bits=4.96Gibit, aggregated_bits=1.43Gibit
Start TensorBoard again to compare the training metrics between the two runs.
As you can see in TensorBoard, there is a significant reduction between the original and compression curves in the aggregated_bits plot (in the logs above, 1.43 Gibit versus 4.96 Gibit after 10 rounds), while in the loss and sparse_categorical_accuracy plots the two curves are quite similar.
In conclusion, we used a built-in compression algorithm that achieves performance similar to the original Federated Averaging algorithm while significantly reducing the communication cost.
%tensorboard --logdir /tmp/logs/scalars/ --port=0
Exercises
To implement a custom compression algorithm and apply it to the training loop, you can:
- Implement a new compression algorithm as a subclass of tff.aggregators.MeanFactory.
- Perform training with the compression algorithm to see if it does better than the algorithm above.
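As a possible starting point for the first exercise, here is the encode/decode math of one common lossy scheme, stochastic uniform quantization, in NumPy. This is a swapped-in illustration: the tutorial does not say which scheme tff.learning.compression_aggregator uses, the function names are hypothetical, and wrapping the math in a tff.aggregators factory is left to the exercise. Stochastic rounding keeps the decoded value unbiased in expectation, which is useful when many client updates are averaged.

```python
import numpy as np


def quantize(x, bits=8, rng=None):
  """Stochastically rounds x onto 2**bits evenly spaced levels in [min, max].

  bits must be at most 8 because codes are stored as uint8.
  """
  rng = np.random.default_rng() if rng is None else rng
  lo, hi = float(x.min()), float(x.max())
  levels = 2 ** bits - 1
  scale = (hi - lo) / levels if hi > lo else 1.0
  # Position of each value, measured in quantization steps above lo.
  pos = (x - lo) / scale
  floor = np.floor(pos)
  # Round up with probability equal to the fractional part (unbiased).
  codes = floor + (rng.random(x.shape) < (pos - floor))
  # Guard against floating-point overshoot at the top level.
  codes = np.clip(codes, 0, levels)
  return codes.astype(np.uint8), lo, scale


def dequantize(codes, lo, scale):
  """Maps integer codes back to floats."""
  return lo + codes.astype(np.float64) * scale


rng = np.random.default_rng(0)
update = rng.normal(size=10_000)
codes, lo, scale = quantize(update, bits=8, rng=rng)
decoded = dequantize(codes, lo, scale)
# 8 bits per value plus two floats (lo, scale) per tensor; the error per
# value stays within one quantization step.
print(codes.dtype, np.abs(decoded - update).max(), scale)
```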
Potentially valuable open research questions include: non-uniform quantization, lossless compression such as Huffman coding, and mechanisms for adapting compression based on information from previous training rounds.
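For the lossless direction, a minimal Huffman coding sketch shows how entropy coding could shrink already-quantized updates further when some codes occur much more often than others. It only computes code lengths with the standard library (heapq, collections.Counter); it does not emit a bitstream or integrate with TFF. huffman_code_lengths is a hypothetical helper, and the code histogram below is made up for illustration.

```python
import heapq
from collections import Counter
from itertools import count


def huffman_code_lengths(symbols):
  """Returns {symbol: code length in bits} for a Huffman code over symbols."""
  freqs = Counter(symbols)
  if len(freqs) == 1:
    # A single distinct symbol still needs one bit per occurrence.
    return {sym: 1 for sym in freqs}
  tiebreak = count()  # avoids comparing dicts when frequencies tie
  heap = [(f, next(tiebreak), {sym: 0}) for sym, f in freqs.items()]
  heapq.heapify(heap)
  while len(heap) > 1:
    f1, _, a = heapq.heappop(heap)
    f2, _, b = heapq.heappop(heap)
    # Merging two subtrees adds one bit to every code inside them.
    merged = {sym: depth + 1 for sym, depth in {**a, **b}.items()}
    heapq.heappush(heap, (f1 + f2, next(tiebreak), merged))
  return heap[0][2]


# Made-up histogram of 8-bit codes, skewed toward the middle value.
codes = [128] * 70 + [127] * 12 + [129] * 12 + [0] * 3 + [255] * 3
lengths = huffman_code_lengths(codes)
freqs = Counter(codes)
huffman_bits = sum(freqs[s] * lengths[s] for s in freqs)
print(huffman_bits, 'bits vs', 8 * len(codes), 'bits at fixed 8-bit width')
# prints: 154 bits vs 800 bits at fixed 8-bit width
```

The gain depends entirely on how skewed the code distribution is; for a near-uniform histogram Huffman coding saves almost nothing over fixed-width codes, and the code table itself must also be communicated.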
Recommended reading materials: