This tutorial describes how to set up high-performance simulations using a TFF runtime deployed on Kubernetes.
For demonstration purposes, we'll use the TFF simulation for image classification from the tutorial Federated Learning for Image Classification, but we'll run it against a multi-machine setup consisting of two TFF workers running in Kubernetes. We'll use the same EMNIST dataset for training, but split into two partitions, one for each TFF worker.
This tutorial refers to the following Google Cloud services:
- GKE to create the Kubernetes cluster. All the steps after the cluster is created can be used with any Kubernetes installation.
- Filestore to serve the training data. Any storage medium that can be mounted as a Kubernetes persistent volume works as well.
Launch the TFF Workers on Kubernetes
Package TFF Worker Binary
worker_service.py contains the source code for our custom TFF worker. It runs a simulation server with custom logic for loading a dataset partition and sampling from it for each round of federated learning. (To learn more, see Loading Remote Data in TFF.)
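The source of worker_service.py is not reproduced here, so the following is only a minimal sketch of what "sampling from a partition for each round" could look like. The class name PartitionSampler, the client ID strings, and the seed parameter are all hypothetical and are not taken from the actual worker code.

```python
import random
from typing import List, Optional, Sequence


class PartitionSampler:
  """Hypothetical helper: picks client IDs from one worker's partition."""

  def __init__(self, client_ids: Sequence[str], seed: Optional[int] = None):
    if not client_ids:
      raise ValueError('The partition must contain at least one client.')
    self._client_ids = list(client_ids)
    # A dedicated generator keeps sampling reproducible when a seed is given.
    self._rng = random.Random(seed)

  def sample(self, num_clients: int) -> List[str]:
    """Returns `num_clients` distinct client IDs chosen uniformly at random."""
    if num_clients > len(self._client_ids):
      raise ValueError('Requested more clients than the partition holds.')
    return self._rng.sample(self._client_ids, num_clients)


# Made-up client IDs standing in for the clients of one EMNIST partition.
sampler = PartitionSampler([f'client_{i}' for i in range(100)], seed=0)
round_clients = sampler.sample(5)
```

Calling sample once per round gives each round a fresh subset of the partition's clients. The real worker may sample differently.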
We're going to deploy our TFF worker as a containerized application on Kubernetes. Let's start by building a Docker image. Using this Dockerfile, we can package the code by running,
$ WORKER_IMAGE=tff-worker-service:latest
$ docker build --tag $WORKER_IMAGE --file "./Dockerfile" .
(Assuming worker_service.py and Dockerfile are located in your working directory.)
Then publish the image to a container repository where it can be accessed by the Kubernetes cluster we're about to create, e.g.,
$ docker push $WORKER_IMAGE
Create a Kubernetes Cluster
The following step only needs to be done once. The cluster can be re-used for future workloads.
Follow the GKE instructions to create a cluster with Filestore CSI driver enabled, e.g.,
gcloud container clusters create tff-cluster --addons=GcpFilestoreCsiDriver
The commands to interact with GCP can be run locally or in the Google Cloud Shell. We recommend the Google Cloud Shell since it doesn't require additional setup.
The rest of this tutorial assumes that the cluster is named tff-cluster, but the actual name isn't important.
Deploy the TFF Worker Application
worker_deployment.yaml declares the configuration for standing up two TFF workers, each as its own Kubernetes deployment with two replica pods. We can apply this configuration to our running cluster,
kubectl apply -f worker_deployment.yaml
Once the changes have been requested, you can check that the pods are ready,
kubectl get pod
NAME                                        READY   STATUS    RESTARTS   AGE
tff-workers-deployment-1-6bb8d458d5-hjl9d   1/1     Running   0          5m
tff-workers-deployment-1-6bb8d458d5-jgt4b   1/1     Running   0          5m
tff-workers-deployment-2-6cb76c6f5d-hqt88   1/1     Running   0          5m
tff-workers-deployment-2-6cb76c6f5d-xk92h   1/1     Running   0          5m
Each worker instance runs behind a load balancer with an endpoint. Look up the external IP address of the load balancers,
kubectl get service
NAME                    TYPE           CLUSTER-IP    EXTERNAL-IP     PORT(S)        AGE
tff-workers-service-1   LoadBalancer   XX.XX.X.XXX   XX.XXX.XX.XXX   80:31830/TCP   6m
tff-workers-service-2   LoadBalancer   XX.XX.X.XXX   XX.XXX.XX.XXX   80:31319/TCP   6m
You'll need these external IP addresses later to connect the training loop to the running workers.
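As a small convenience, the external IPs can be turned into host:port strings and checked for typos before they are handed to gRPC. This is only a sketch: build_grpc_targets is a hypothetical helper, the addresses are documentation-range placeholders rather than real endpoints, and only IPv4 addresses are handled.

```python
import ipaddress
from typing import List, Sequence


def build_grpc_targets(ip_addresses: Sequence[str], port: int = 80) -> List[str]:
  """Validates IPv4 addresses and formats them as 'host:port' targets."""
  targets = []
  for ip in ip_addresses:
    # Raises a ValueError subclass for malformed input such as an
    # unreplaced 'XX.XXX.XX.XXX' placeholder.
    ipaddress.IPv4Address(ip)
    targets.append(f'{ip}:{port}')
  return targets


# Placeholder addresses (TEST-NET-3); substitute your services' EXTERNAL-IP values.
targets = build_grpc_targets(['203.0.113.10', '203.0.113.11'])
```

Port 80 matches the service port shown in the `kubectl get service` output above.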
Prepare Training Data
The EMNIST partitions we'll consume for training can be downloaded from TFF's public dataset repository,
gsutil cp -r gs://tff-datasets-public/emnist-partitions/2-partition .
You can then upload each partition to its worker by copying it to one of that worker's replicas, e.g.,
kubectl cp emnist_part_1.sqlite tff-workers-deployment-1-6bb8d458d5-hjl9d:/root/worker/data/emnist_partition.sqlite
kubectl cp emnist_part_2.sqlite tff-workers-deployment-2-6cb76c6f5d-hqt88:/root/worker/data/emnist_partition.sqlite
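Pod names change whenever a deployment is recreated, so it can be convenient to generate the copy commands from a mapping instead of typing them by hand. The sketch below only builds and prints the commands. The helper name kubectl_cp_commands is hypothetical, and the pod names are the example names from the output above.

```python
import shlex
from typing import Dict, List

# Destination path inside the worker container, as used in the commands above.
REMOTE_PATH = '/root/worker/data/emnist_partition.sqlite'


def kubectl_cp_commands(partition_to_pod: Dict[str, str]) -> List[List[str]]:
  """Builds one `kubectl cp` argument list per (local file, pod) pair."""
  return [
      ['kubectl', 'cp', local_file, f'{pod}:{REMOTE_PATH}']
      for local_file, pod in partition_to_pod.items()
  ]


commands = kubectl_cp_commands({
    'emnist_part_1.sqlite': 'tff-workers-deployment-1-6bb8d458d5-hjl9d',
    'emnist_part_2.sqlite': 'tff-workers-deployment-2-6cb76c6f5d-hqt88',
})
for command in commands:
  # Print for review; nothing is executed here.
  print(shlex.join(command))
```

To actually run a command, each argument list could be passed to subprocess.run(command, check=True).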
Run Simulation
Now we're ready to run simulations against our cluster.
Set Up TFF Environment
!pip install --quiet --upgrade tensorflow-federated
!pip install --quiet --upgrade nest-asyncio
import nest_asyncio
nest_asyncio.apply()
Define the Training Procedure
The following defines the dataset iteration methodology, the model architecture, and the round-over-round process for federated learning. (For more detail, see the Federated Learning for Image Classification tutorial.)
import collections
from typing import Any, List, Optional

import tensorflow as tf
import tensorflow_federated as tff


class FederatedData(tff.program.FederatedDataSource,
                    tff.program.FederatedDataSourceIterator):
  """Interface for interacting with the federated training data."""

  def __init__(self, type_spec: tff.FederatedType):
    self._type_spec = type_spec
    self._capabilities = [tff.program.Capability.RANDOM_UNIFORM]

  @property
  def federated_type(self) -> tff.FederatedType:
    return self._type_spec

  @property
  def capabilities(self) -> List[tff.program.Capability]:
    return self._capabilities

  def iterator(self) -> tff.program.FederatedDataSourceIterator:
    return self

  def select(self, num_clients: Optional[int] = None) -> Any:
    # range() needs a concrete count, so reject a missing value early.
    if num_clients is None:
      raise ValueError('num_clients must be provided.')
    # One placeholder URI per requested client; the data is loaded remotely.
    data_uris = [f'uri://{i}' for i in range(num_clients)]
    return tff.framework.CreateDataDescriptor(
        arg_uris=data_uris, arg_type=self._type_spec)


# Each element is a single flattened 28x28 image with its integer label.
input_spec = collections.OrderedDict([
    ('x', tf.TensorSpec(shape=(1, 784), dtype=tf.float32, name=None)),
    ('y', tf.TensorSpec(shape=(1, 1), dtype=tf.int32, name=None))
])
element_type = tff.types.StructWithPythonType(
    input_spec, container_type=collections.OrderedDict)
dataset_type = tff.types.SequenceType(element_type)

train_data_source = FederatedData(type_spec=dataset_type)
train_data_iterator = train_data_source.iterator()


def model_fn():
  # A single dense layer followed by softmax over the 10 digit classes.
  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(units=10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])
  return tff.learning.from_keras_model(
      model,
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])


trainer = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))


def train_loop(num_rounds=10, num_clients=10):
  state = trainer.initialize()
  # `round_num` avoids shadowing the built-in round().
  for round_num in range(1, num_rounds + 1):
    train_data = train_data_iterator.select(num_clients)
    result = trainer.next(state, train_data)
    state = result.state
    train_metrics = result.metrics['client_work']['train']
    print('round {:2d}, metrics={}'.format(round_num, train_metrics))
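To make the aggregation step of the trainer more concrete, here is a simplified sketch of example-weighted averaging, the idea behind weighted federated averaging. It is not TFF's implementation: the function name weighted_average, the flat list representation of model updates, and the numbers are illustrative assumptions.

```python
from typing import List, Sequence


def weighted_average(client_updates: Sequence[Sequence[float]],
                     num_examples: Sequence[int]) -> List[float]:
  """Averages per-client update vectors, weighting each by its example count."""
  if len(client_updates) != len(num_examples):
    raise ValueError('Need exactly one weight per client update.')
  total = sum(num_examples)
  if total <= 0:
    raise ValueError('The total number of examples must be positive.')
  dim = len(client_updates[0])
  return [
      sum(weight * update[i]
          for update, weight in zip(client_updates, num_examples)) / total
      for i in range(dim)
  ]


# Two hypothetical clients; the first trained on three times as many examples.
updates = [[1.0, 0.0], [0.0, 1.0]]
weights = [30, 10]
averaged = weighted_average(updates, weights)  # [0.75, 0.25]
```

Clients with more examples pull the aggregate further toward their own update, which is why the first client's component dominates in the example.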
Connect to TFF Workers
By default, TFF executes all computations locally. In this step we tell TFF to connect to the Kubernetes services we set up above. Be sure to copy the external IP addresses of your services here.
import grpc
# Placeholders: replace with the EXTERNAL-IP values from `kubectl get service`.
ip_address_1 = '0.0.0.0'
ip_address_2 = '0.0.0.0'
port = 80
channels = [
    grpc.insecure_channel(f'{ip_address_1}:{port}'),
    grpc.insecure_channel(f'{ip_address_2}:{port}')
]
tff.backends.native.set_remote_python_execution_context(channels)
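If training later appears to hang, a quick way to rule out networking problems is to check that each load balancer accepts TCP connections. The sketch below uses only the standard library. The helper name is_reachable is hypothetical, and a successful TCP connection does not prove that the TFF worker behind it is healthy.

```python
import socket


def is_reachable(host: str, port: int, timeout_seconds: float = 3.0) -> bool:
  """Returns True if a TCP connection to host:port succeeds within the timeout."""
  try:
    with socket.create_connection((host, port), timeout=timeout_seconds):
      return True
  except OSError:
    # Covers refused connections, timeouts, and unroutable hosts.
    return False
```

For example, is_reachable(ip_address_1, port) and is_reachable(ip_address_2, port) should both return True once the services have been assigned external IPs.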
Execute Training
train_loop()
round  1, metrics=OrderedDict([('sparse_categorical_accuracy', 0.10557769), ('loss', 12.475689), ('num_examples', 5020), ('num_batches', 5020)])
round  2, metrics=OrderedDict([('sparse_categorical_accuracy', 0.11940298), ('loss', 10.497084), ('num_examples', 5360), ('num_batches', 5360)])
round  3, metrics=OrderedDict([('sparse_categorical_accuracy', 0.16223507), ('loss', 7.569645), ('num_examples', 5190), ('num_batches', 5190)])
round  4, metrics=OrderedDict([('sparse_categorical_accuracy', 0.2648384), ('loss', 6.0947175), ('num_examples', 5105), ('num_batches', 5105)])
round  5, metrics=OrderedDict([('sparse_categorical_accuracy', 0.29003084), ('loss', 6.2815433), ('num_examples', 4865), ('num_batches', 4865)])
round  6, metrics=OrderedDict([('sparse_categorical_accuracy', 0.40237388), ('loss', 4.630901), ('num_examples', 5055), ('num_batches', 5055)])
round  7, metrics=OrderedDict([('sparse_categorical_accuracy', 0.4288425), ('loss', 4.2358975), ('num_examples', 5270), ('num_batches', 5270)])
round  8, metrics=OrderedDict([('sparse_categorical_accuracy', 0.46349892), ('loss', 4.3829923), ('num_examples', 4630), ('num_batches', 4630)])
round  9, metrics=OrderedDict([('sparse_categorical_accuracy', 0.492094), ('loss', 3.8121278), ('num_examples', 4680), ('num_batches', 4680)])
round 10, metrics=OrderedDict([('sparse_categorical_accuracy', 0.5872674), ('loss', 3.058461), ('num_examples', 5105), ('num_batches', 5105)])