הצג באתר TensorFlow.org | הפעל בגוגל קולאב | צפה במקור ב-GitHub | הורד מחברת |
זוהי אלטרנטיבה Build Federated משלך אלגוריתם למידה הדרכה ואת simple_fedavg למשל לבנות תהליך איטרטיבי אישית עבור מיצוע Federated האלגוריתם. המדריך הזה משתמש אופטימיזציה TFF במקום אופטימיזציה Keras. ההפשטה של אופטימיזציית TFF מיועדת להיות במצב-in-state-out כדי שיהיה קל יותר לשילוב בתהליך איטרטיבי של TFF. tff.learning
APIs גם לקבל אופטימיזציה TFF כארגומנט קלט.
לפני שאנחנו מתחילים
לפני שנתחיל, אנא הפעל את הפעולות הבאות כדי לוודא שהסביבה שלך מוגדרת כהלכה. אם אינך רואה ברכה, עיין התקנה מדריך לקבלת הוראות.
!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio
import nest_asyncio
nest_asyncio.apply()
import functools
import attr
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
הכנת נתונים ומודל
עיבוד נתוני EMNIST המודל מאוד דומים simple_fedavg למשל.
only_digits=True
# Load dataset.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(only_digits)
# Define preprocessing functions.
def preprocess_fn(dataset, batch_size=16):
def batch_format_fn(element):
return (tf.expand_dims(element['pixels'], -1), element['label'])
return dataset.batch(batch_size).map(batch_format_fn)
# Preprocess and sample clients for prototyping.
train_client_ids = sorted(emnist_train.client_ids)
train_data = emnist_train.preprocess(preprocess_fn)
central_test_data = preprocess_fn(
emnist_train.create_tf_dataset_for_client(train_client_ids[0]))
# Define model.
def create_keras_model():
"""The CNN model used in https://arxiv.org/abs/1602.05629."""
data_format = 'channels_last'
input_shape = [28, 28, 1]
max_pool = functools.partial(
tf.keras.layers.MaxPooling2D,
pool_size=(2, 2),
padding='same',
data_format=data_format)
conv2d = functools.partial(
tf.keras.layers.Conv2D,
kernel_size=5,
padding='same',
data_format=data_format,
activation=tf.nn.relu)
model = tf.keras.models.Sequential([
conv2d(filters=32, input_shape=input_shape),
max_pool(),
conv2d(filters=64),
max_pool(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(512, activation=tf.nn.relu),
tf.keras.layers.Dense(10 if only_digits else 62),
])
return model
# Wrap as `tff.learning.Model`.
def model_fn():
keras_model = create_keras_model()
return tff.learning.from_keras_model(
keras_model,
input_spec=central_test_data.element_spec,
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
תהליך איטרטיבי מותאם אישית
במקרים רבים, לאלגוריתמים מאוחדים יש 4 מרכיבים עיקריים:
- שלב שידור שרת ללקוח.
- שלב עדכון לקוח מקומי.
- שלב העלאת לקוח לשרת.
- שלב עדכון שרת.
בשנת TFF, אנחנו מייצגים אלגוריתמים Federated בדרך כלל בתור tff.templates.IterativeProcess
(אשר אנו מתייחסים כאל סתם IterativeProcess
לאורך). זוהי מחלקה המכילה initialize
ו next
פונקציות. הנה, initialize
משמשת לאתחל את השרת, ואת next
תבצע סיבוב תקשורת אחד של אלגוריתם Federated.
נציג רכיבים שונים לבניית אלגוריתם הממוצע המאוחד (FedAvg), שישתמש באופטימיזציה בשלב עדכון הלקוח, ובאופטימיזציה נוספת בשלב עדכון השרת. הלוגיקה הליבה של עדכוני לקוח ושרת יכולה לבוא לידי ביטוי כבלוקים TF טהורים.
בלוקים TF: עדכון לקוח ושרת
על כל לקוח, מקומי client_optimizer
מאותחל זאת כדי לעדכן את המשקולות מודל הלקוח. בשרת, server_optimizer
ישתמש המדינה מהסיבוב הקודם, ולעדכן את המדינה לקראת הסיבוב הבא.
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
"""Performs local training on the client's dataset."""
# Initialize the client model with the current server weights.
client_weights = model.trainable_variables
# Assign the server weights to the client model.
tf.nest.map_structure(lambda x, y: x.assign(y),
client_weights, server_weights)
# Initialize the client optimizer.
trainable_tensor_specs = tf.nest.map_structure(
lambda v: tf.TensorSpec(v.shape, v.dtype), client_weights)
optimizer_state = client_optimizer.initialize(trainable_tensor_specs)
# Use the client_optimizer to update the local model.
for batch in iter(dataset):
with tf.GradientTape() as tape:
# Compute a forward pass on the batch of data.
outputs = model.forward_pass(batch)
# Compute the corresponding gradient.
grads = tape.gradient(outputs.loss, client_weights)
# Apply the gradient using a client optimizer.
optimizer_state, updated_weights = client_optimizer.next(
optimizer_state, client_weights, grads)
tf.nest.map_structure(lambda a, b: a.assign(b),
client_weights, updated_weights)
# Return model deltas.
return tf.nest.map_structure(tf.subtract, client_weights, server_weights)
@attr.s(eq=False, frozen=True, slots=True)
class ServerState(object):
trainable_weights = attr.ib()
optimizer_state = attr.ib()
@tf.function
def server_update(server_state, mean_model_delta, server_optimizer):
"""Updates the server model weights."""
# Use aggregated negative model delta as pseudo gradient.
negative_weights_delta = tf.nest.map_structure(
lambda w: -1.0 * w, mean_model_delta)
new_optimizer_state, updated_weights = server_optimizer.next(
server_state.optimizer_state, server_state.trainable_weights,
negative_weights_delta)
return tff.structure.update_struct(
server_state,
trainable_weights=updated_weights,
optimizer_state=new_optimizer_state)
בלוקים TFF: tff.tf_computation
ו tff.federated_computation
כעת אנו משתמשים ב-TFF לתזמור ובונים את התהליך האיטרטיבי עבור FedAvg. אנחנו צריכים לעטוף את אבני TF כמוגדר לעיל עם tff.tf_computation
, ושיטות TFF השימוש tff.federated_broadcast
, tff.federated_map
, tff.federated_mean
בתוך tff.federated_computation
פונקציה. זה קל להשתמש tff.learning.optimizers.Optimizer
APIs עם initialize
ו next
פונקציות בעת הגדרת תהליך מנהג איטרטיבי.
# 1. Server and client optimizer to be used.
server_optimizer = tff.learning.optimizers.build_sgdm(
learning_rate=0.05, momentum=0.9)
client_optimizer = tff.learning.optimizers.build_sgdm(
learning_rate=0.01)
# 2. Functions return initial state on server.
@tff.tf_computation
def server_init():
model = model_fn()
trainable_tensor_specs = tf.nest.map_structure(
lambda v: tf.TensorSpec(v.shape, v.dtype), model.trainable_variables)
optimizer_state = server_optimizer.initialize(trainable_tensor_specs)
return ServerState(
trainable_weights=model.trainable_variables,
optimizer_state=optimizer_state)
@tff.federated_computation
def server_init_tff():
return tff.federated_value(server_init(), tff.SERVER)
# 3. One round of computation and communication.
server_state_type = server_init.type_signature.result
print('server_state_type:\n',
server_state_type.formatted_representation())
trainable_weights_type = server_state_type.trainable_weights
print('trainable_weights_type:\n',
trainable_weights_type.formatted_representation())
# 3-1. Wrap server and client TF blocks with `tff.tf_computation`.
@tff.tf_computation(server_state_type, trainable_weights_type)
def server_update_fn(server_state, model_delta):
return server_update(server_state, model_delta, server_optimizer)
whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)
print('tf_dataset_type:\n',
tf_dataset_type.formatted_representation())
@tff.tf_computation(tf_dataset_type, trainable_weights_type)
def client_update_fn(dataset, server_weights):
model = model_fn()
return client_update(model, dataset, server_weights, client_optimizer)
# 3-2. Orchestration with `tff.federated_computation`.
federated_server_type = tff.FederatedType(server_state_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
@tff.federated_computation(federated_server_type, federated_dataset_type)
def run_one_round(server_state, federated_dataset):
# Server-to-client broadcast.
server_weights_at_client = tff.federated_broadcast(
server_state.trainable_weights)
# Local client update.
model_deltas = tff.federated_map(
client_update_fn, (federated_dataset, server_weights_at_client))
# Client-to-server upload and aggregation.
mean_model_delta = tff.federated_mean(model_deltas)
# Server update.
server_state = tff.federated_map(
server_update_fn, (server_state, mean_model_delta))
return server_state
# 4. Build the iterative process for FedAvg.
fedavg_process = tff.templates.IterativeProcess(
initialize_fn=server_init_tff, next_fn=run_one_round)
print('type signature of `initialize`:\n',
fedavg_process.initialize.type_signature.formatted_representation())
print('type signature of `next`:\n',
fedavg_process.next.type_signature.formatted_representation())
server_state_type: < trainable_weights=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] >, optimizer_state=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] > > trainable_weights_type: < float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] > tf_dataset_type: < float32[?,28,28,1], int32[?] >* type signature of `initialize`: ( -> < trainable_weights=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] >, optimizer_state=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] > >@SERVER) type signature of `next`: (< server_state=< trainable_weights=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] >, optimizer_state=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] > >@SERVER, federated_dataset={< float32[?,28,28,1], int32[?] >*}@CLIENTS > -> < trainable_weights=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] >, optimizer_state=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] > >@SERVER)
הערכת האלגוריתם
אנו מעריכים את הביצועים במערך נתונים מרכזי של הערכה.
def evaluate(server_state):
keras_model = create_keras_model()
tf.nest.map_structure(
lambda var, t: var.assign(t),
keras_model.trainable_weights, server_state.trainable_weights)
metric = tf.keras.metrics.SparseCategoricalAccuracy()
for batch in iter(central_test_data):
preds = keras_model(batch[0], training=False)
metric.update_state(y_true=batch[1], y_pred=preds)
return metric.result().numpy()
server_state = fedavg_process.initialize()
acc = evaluate(server_state)
print('Initial test accuracy', acc)
# Evaluate after a few rounds
CLIENTS_PER_ROUND=2
sampled_clients = train_client_ids[:CLIENTS_PER_ROUND]
sampled_train_data = [
train_data.create_tf_dataset_for_client(client)
for client in sampled_clients]
for round in range(20):
server_state = fedavg_process.next(server_state, sampled_train_data)
acc = evaluate(server_state)
print('Test accuracy', acc)
Initial test accuracy 0.09677419 Test accuracy 0.13978495