This is an alternative to the Building Your Own Federated Learning Algorithm tutorial and the simple_fedavg example for building a custom iterative process for the federated averaging algorithm. This tutorial uses TFF optimizers instead of Keras optimizers. The TFF optimizer abstraction is designed to be state-in-state-out, which makes it easier to incorporate into a TFF iterative process. The tff.learning APIs also accept TFF optimizers as input arguments.
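To make "state-in-state-out" concrete, here is a minimal pure-Python sketch of an SGD-with-momentum optimizer: initialize returns a state object, and next takes the old state and returns a new state together with the updated weights. The class SimpleSGDM, the SGDMState tuple, and the use of flat lists of floats as weights are hypothetical illustrations, not the tff.learning.optimizers API. The momentum rule (accumulator = momentum * accumulator + gradient, then weights -= learning_rate * accumulator) is one common formulation, assumed here for illustration.

```python
from typing import List, NamedTuple, Sequence, Tuple


class SGDMState(NamedTuple):
    """Hypothetical optimizer state: one momentum accumulator per weight."""
    accumulator: Tuple[float, ...]


class SimpleSGDM:
    """Hypothetical state-in-state-out SGD with momentum on flat lists of floats."""

    def __init__(self, learning_rate: float, momentum: float = 0.0):
        # Hyperparameters are fixed at construction time; everything that
        # changes between steps lives in the explicit state instead.
        self.learning_rate = learning_rate
        self.momentum = momentum

    def initialize(self, num_weights: int) -> SGDMState:
        # Create a fresh state with zeroed accumulators.
        return SGDMState(accumulator=tuple(0.0 for _ in range(num_weights)))

    def next(self, state: SGDMState, weights: Sequence[float],
             gradients: Sequence[float]) -> Tuple[SGDMState, List[float]]:
        # Pure function of its inputs: nothing is mutated in place.
        new_acc = tuple(self.momentum * a + g
                        for a, g in zip(state.accumulator, gradients))
        new_weights = [w - self.learning_rate * a
                       for w, a in zip(weights, new_acc)]
        return SGDMState(accumulator=new_acc), new_weights
```

Because next never mutates its inputs, the caller owns the state and passes it explicitly from one step to the next (for example, state, weights = opt.next(state, weights, grads)). This explicit hand-off is what lets optimizer state cross computation boundaries, whereas a Keras optimizer keeps its slots in its own internal variables.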
Before we start
Before we start, please run the following to make sure that your environment is correctly set up. If you don't see a greeting, please refer to the Installation guide for instructions.
!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio
import nest_asyncio
nest_asyncio.apply()
import functools
import attr
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
Preparing the data and model
The EMNIST data processing and the model are very similar to the simple_fedavg example.
only_digits=True
# Load dataset.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(only_digits)
# Define preprocessing functions.
def preprocess_fn(dataset, batch_size=16):
  def batch_format_fn(element):
    return (tf.expand_dims(element['pixels'], -1), element['label'])
  return dataset.batch(batch_size).map(batch_format_fn)
# Preprocess and sample clients for prototyping.
train_client_ids = sorted(emnist_train.client_ids)
train_data = emnist_train.preprocess(preprocess_fn)
central_test_data = preprocess_fn(
    emnist_train.create_tf_dataset_for_client(train_client_ids[0]))
# Define model.
def create_keras_model():
  """The CNN model used in https://arxiv.org/abs/1602.05629."""
  data_format = 'channels_last'
  input_shape = [28, 28, 1]
  max_pool = functools.partial(
      tf.keras.layers.MaxPooling2D,
      pool_size=(2, 2),
      padding='same',
      data_format=data_format)
  conv2d = functools.partial(
      tf.keras.layers.Conv2D,
      kernel_size=5,
      padding='same',
      data_format=data_format,
      activation=tf.nn.relu)
  model = tf.keras.models.Sequential([
      conv2d(filters=32, input_shape=input_shape),
      max_pool(),
      conv2d(filters=64),
      max_pool(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation=tf.nn.relu),
      tf.keras.layers.Dense(10 if only_digits else 62),
  ])
  return model
# Wrap as `tff.learning.Model`.
def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=central_test_data.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
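As a sanity check on the type signatures printed later in this tutorial, the following pure-Python sketch computes the shapes of the trainable variables of create_keras_model, in layer order. It assumes that the 'same' convolutions preserve the spatial size and that each 2x2 'same' max pool (stride 2 by default) halves it, rounding up, so 28 becomes 14 and then 7, and the Flatten layer produces 7 * 7 * 64 = 3136 features. The helper name cnn_trainable_shapes is hypothetical.

```python
from typing import List, Tuple


def cnn_trainable_shapes(image_size: int = 28, channels: int = 1,
                         filters: Tuple[int, ...] = (32, 64),
                         kernel_size: int = 5, dense_units: int = 512,
                         num_classes: int = 10) -> List[Tuple[int, ...]]:
    """Hypothetical helper: shapes of the CNN's trainable variables."""
    shapes: List[Tuple[int, ...]] = []
    in_channels, size = channels, image_size
    for f in filters:
        # Conv2D kernel [k, k, in, out] and bias [out]; 'same' padding keeps size.
        shapes += [(kernel_size, kernel_size, in_channels, f), (f,)]
        in_channels = f
        # 2x2 max pooling with 'same' padding and stride 2: ceil(size / 2).
        size = -(-size // 2)
    flat = size * size * in_channels  # Output width of the Flatten layer.
    shapes += [(flat, dense_units), (dense_units,),
               (dense_units, num_classes), (num_classes,)]
    return shapes
```

With only_digits=True the last Dense layer has 10 outputs; calling cnn_trainable_shapes(num_classes=62) gives the shapes for the full EMNIST variant.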
Custom iterative process
In many cases, federated algorithms have 4 main components:
- A server-to-client broadcast step.
- A local client update step.
- A client-to-server upload step.
- A server update step.
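The four steps can be illustrated with a toy, pure-Python round in which the model is a single float and each client minimizes its local loss (w - target)**2 with a few gradient steps. All function names, the learning rate, and the number of local steps are hypothetical choices for this sketch; the aggregation is an unweighted mean, like the tff.federated_mean call used later in this tutorial.

```python
from statistics import mean
from typing import List


def broadcast(server_weight: float, num_clients: int) -> List[float]:
    # Step 1 (server-to-client broadcast): every client gets a copy.
    return [server_weight] * num_clients


def local_update(weight: float, target: float,
                 lr: float = 0.1, steps: int = 5) -> float:
    # Step 2 (local client update): gradient steps on (w - target)**2.
    w = weight
    for _ in range(steps):
        w -= lr * 2.0 * (w - target)
    return w - weight  # The client reports its model delta.


def upload_and_aggregate(deltas: List[float]) -> float:
    # Step 3 (client-to-server upload): the server averages the deltas.
    return mean(deltas)


def server_step(server_weight: float, mean_delta: float) -> float:
    # Step 4 (server update): apply the averaged delta.
    return server_weight + mean_delta


def run_round(server_weight: float, client_targets: List[float]) -> float:
    copies = broadcast(server_weight, len(client_targets))
    deltas = [local_update(w, t) for w, t in zip(copies, client_targets)]
    return server_step(server_weight, upload_and_aggregate(deltas))
```

Calling run_round repeatedly with client targets [1.0, 3.0] moves the server weight toward 2.0, the average of the targets, because every client's local loss is a quadratic of the same curvature centered on its own target.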
In TFF, we generally represent federated algorithms as a tff.templates.IterativeProcess (which we refer to simply as an IterativeProcess throughout). This is a class that contains initialize and next functions. Here, initialize is used to initialize the server, and next performs one round of communication of the federated algorithm.
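The initialize/next contract can be mimicked in plain Python to show how state flows between rounds. ToyIterativeProcess below is a hypothetical stand-in, not tff.templates.IterativeProcess: initialize takes no arguments and returns the server state, and next takes that state plus the round's client data and returns the new state, which the caller feeds back in on the following round. The toy next_fn simply moves a weight halfway toward the clients' average value.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List

State = Dict[str, Any]


@dataclass(frozen=True)
class ToyIterativeProcess:
    """Hypothetical stand-in for the initialize/next contract (not the TFF class)."""
    initialize: Callable[[], State]
    next: Callable[[State, List[float]], State]


def init_fn() -> State:
    # Called once to produce the initial server state.
    return {"round": 0, "weight": 0.0}


def next_fn(state: State, client_values: List[float]) -> State:
    # One round: move the weight halfway toward the clients' average value.
    avg = sum(client_values) / len(client_values)
    return {"round": state["round"] + 1,
            "weight": 0.5 * (state["weight"] + avg)}


process = ToyIterativeProcess(initialize=init_fn, next=next_fn)
state = process.initialize()
for _ in range(3):
    # The state returned by one round is the input to the next.
    state = process.next(state, [1.0, 2.0, 3.0])
```

After three rounds the weight has moved from 0.0 to 1.0, 1.5, and then 1.75, which only works because each round receives the previous round's output state.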
We will introduce the different components needed to build the federated averaging (FedAvg) algorithm, which uses one optimizer in the client update step and another optimizer in the server update step. The core logic of the client and server updates can be expressed as pure TF blocks.
TF blocks: client and server updates
On each client, a local client_optimizer is initialized and used to update the client model weights. On the server, server_optimizer uses the state from the previous round and updates that state for the next round.
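Here is a pure-Python sketch of the same client-side logic for a one-feature linear model y_hat = w * x + b trained with mean squared error: start from the server weights, create a fresh optimizer state, take one optimizer step per batch, and return the model delta rather than the weights. The plain SGD optimizer (sgd_initialize/sgd_next), its learning rate of 0.1, the linear model, and the name client_update_sketch are all assumptions made for illustration.

```python
from typing import Dict, List, Sequence, Tuple

Batch = Sequence[Tuple[float, float]]  # (x, y) pairs


def sgd_initialize() -> Dict[str, int]:
    # Hypothetical minimal optimizer state: a step counter.
    return {"step": 0}


def sgd_next(state: Dict[str, int], weights: List[float],
             grads: List[float], lr: float = 0.1):
    # State in, state out: return the new state and the updated weights.
    new_weights = [w - lr * g for w, g in zip(weights, grads)]
    return {"step": state["step"] + 1}, new_weights


def batch_gradient(weights: List[float], batch: Batch) -> List[float]:
    # Gradient of the mean squared error of y_hat = w * x + b over the batch.
    w, b = weights
    residuals = [(w * x + b - y, x) for x, y in batch]
    n = len(batch)
    return [sum(2.0 * r * x for r, x in residuals) / n,
            sum(2.0 * r for r, _ in residuals) / n]


def client_update_sketch(dataset: Sequence[Batch],
                         server_weights: List[float]) -> List[float]:
    client_weights = list(server_weights)  # Start from the server model.
    state = sgd_initialize()  # Fresh local optimizer state every round.
    for batch in dataset:
        grads = batch_gradient(client_weights, batch)
        state, client_weights = sgd_next(state, client_weights, grads)
    # Report the change relative to the server weights, not the weights.
    return [c - s for c, s in zip(client_weights, server_weights)]
```

A client with no data returns an all-zero delta, so it leaves the average unchanged apart from diluting it.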
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Performs local training on the client's dataset."""
  # Initialize the client model with the current server weights.
  client_weights = model.trainable_variables
  # Assign the server weights to the client model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)
  # Initialize the client optimizer.
  trainable_tensor_specs = tf.nest.map_structure(
      lambda v: tf.TensorSpec(v.shape, v.dtype), client_weights)
  optimizer_state = client_optimizer.initialize(trainable_tensor_specs)
  # Use the client_optimizer to update the local model.
  for batch in iter(dataset):
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data.
      outputs = model.forward_pass(batch)
    # Compute the corresponding gradient.
    grads = tape.gradient(outputs.loss, client_weights)
    # Apply the gradient using a client optimizer.
    optimizer_state, updated_weights = client_optimizer.next(
        optimizer_state, client_weights, grads)
    tf.nest.map_structure(lambda a, b: a.assign(b),
                          client_weights, updated_weights)
  # Return model deltas.
  return tf.nest.map_structure(tf.subtract, client_weights, server_weights)
@attr.s(eq=False, frozen=True, slots=True)
class ServerState(object):
  trainable_weights = attr.ib()
  optimizer_state = attr.ib()
@tf.function
def server_update(server_state, mean_model_delta, server_optimizer):
  """Updates the server model weights."""
  # Use aggregated negative model delta as pseudo gradient.
  negative_weights_delta = tf.nest.map_structure(
      lambda w: -1.0 * w, mean_model_delta)
  new_optimizer_state, updated_weights = server_optimizer.next(
      server_state.optimizer_state, server_state.trainable_weights,
      negative_weights_delta)
  return tff.structure.update_struct(
      server_state,
      trainable_weights=updated_weights,
      optimizer_state=new_optimizer_state)
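The server step treats the negated average delta as a gradient. A minimal pure-Python sketch with a plain SGD server optimizer (a simplification of the SGD-with-momentum optimizer built later in this tutorial) shows why this works: one step w - lr * (-delta) equals w + lr * delta, so a server learning rate of 1.0 recovers the basic averaging update that adds the mean delta to the server model. The function name server_update_sketch and the flat-list weights are hypothetical.

```python
from typing import List


def server_update_sketch(server_weights: List[float],
                         mean_model_delta: List[float],
                         server_lr: float = 1.0) -> List[float]:
    # Negate the averaged client delta so it can be used as a gradient.
    pseudo_gradient = [-d for d in mean_model_delta]
    # One plain SGD step, w - lr * g, which is the same as w + lr * delta.
    return [w - server_lr * g for w, g in zip(server_weights, pseudo_gradient)]
```

Swapping in a server optimizer with momentum, as the tutorial does with build_sgdm(learning_rate=0.05, momentum=0.9), changes only this server step; the client update is unaffected.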
TFF blocks: tff.tf_computation and tff.federated_computation
We now use TFF for orchestration and build the iterative process for FedAvg. We have to wrap the TF blocks defined above with tff.tf_computation, and use the TFF methods tff.federated_broadcast, tff.federated_map, and tff.federated_mean in a tff.federated_computation function. The tff.learning.optimizers.Optimizer API, with its initialize and next functions, is easy to use when defining a custom iterative process.
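The placement semantics of the three intrinsics can be sketched in pure Python by tagging each value as living at the server or at the clients. AtServer, AtClients, and the *_sim helpers below are hypothetical simulations of what tff.federated_broadcast, tff.federated_map, and tff.federated_mean (without weights) do; they are not TFF's implementation. The num_clients argument of broadcast_sim exists only because this sketch has no runtime that knows the client population, and the toy "local update" just moves each client to its own value.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass(frozen=True)
class AtServer:
    """A value placed at the server (hypothetical simulation type)."""
    value: Any


@dataclass(frozen=True)
class AtClients:
    """One value per client (hypothetical simulation type)."""
    values: Tuple[Any, ...]


def broadcast_sim(x: AtServer, num_clients: int) -> AtClients:
    # Like tff.federated_broadcast: the same server value at every client.
    return AtClients(tuple([x.value] * num_clients))


def map_sim(fn: Callable[..., Any], args: Tuple[Any, ...]):
    # Like tff.federated_map: apply fn pointwise, keeping the placement.
    if all(isinstance(a, AtServer) for a in args):
        return AtServer(fn(*(a.value for a in args)))
    if all(isinstance(a, AtClients) for a in args):
        return AtClients(tuple(fn(*vals) for vals in zip(*(a.values for a in args))))
    raise TypeError("all arguments must share a single placement")


def mean_sim(x: AtClients) -> AtServer:
    # Like an unweighted tff.federated_mean: the average lands at the server.
    return AtServer(sum(x.values) / len(x.values))


def run_one_round_sim(server_weight: AtServer, client_data: AtClients) -> AtServer:
    at_clients = broadcast_sim(server_weight, len(client_data.values))
    # Toy local update: each client's delta moves the weight to its own value.
    deltas = map_sim(lambda d, w: d - w, (client_data, at_clients))
    return map_sim(lambda w, d: w + d, (server_weight, mean_sim(deltas)))
```

The same map helper is applied to client-placed values for the local update and to server-placed values for the server update, mirroring how run_one_round in this tutorial calls tff.federated_map for both steps.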
# 1. Server and client optimizers to be used.
server_optimizer = tff.learning.optimizers.build_sgdm(
    learning_rate=0.05, momentum=0.9)
client_optimizer = tff.learning.optimizers.build_sgdm(
    learning_rate=0.01)
# 2. Functions that return the initial state on the server.
@tff.tf_computation
def server_init():
  model = model_fn()
  trainable_tensor_specs = tf.nest.map_structure(
      lambda v: tf.TensorSpec(v.shape, v.dtype), model.trainable_variables)
  optimizer_state = server_optimizer.initialize(trainable_tensor_specs)
  return ServerState(
      trainable_weights=model.trainable_variables,
      optimizer_state=optimizer_state)
@tff.federated_computation
def server_init_tff():
  return tff.federated_value(server_init(), tff.SERVER)
# 3. One round of computation and communication.
server_state_type = server_init.type_signature.result
print('server_state_type:\n',
      server_state_type.formatted_representation())
trainable_weights_type = server_state_type.trainable_weights
print('trainable_weights_type:\n',
      trainable_weights_type.formatted_representation())
# 3-1. Wrap server and client TF blocks with `tff.tf_computation`.
@tff.tf_computation(server_state_type, trainable_weights_type)
def server_update_fn(server_state, model_delta):
  return server_update(server_state, model_delta, server_optimizer)
whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)
print('tf_dataset_type:\n',
      tf_dataset_type.formatted_representation())
@tff.tf_computation(tf_dataset_type, trainable_weights_type)
def client_update_fn(dataset, server_weights):
  model = model_fn()
  return client_update(model, dataset, server_weights, client_optimizer)
# 3-2. Orchestration with `tff.federated_computation`.
federated_server_type = tff.FederatedType(server_state_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
@tff.federated_computation(federated_server_type, federated_dataset_type)
def run_one_round(server_state, federated_dataset):
  # Server-to-client broadcast.
  server_weights_at_client = tff.federated_broadcast(
      server_state.trainable_weights)
  # Local client update.
  model_deltas = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))
  # Client-to-server upload and aggregation.
  mean_model_delta = tff.federated_mean(model_deltas)
  # Server update.
  server_state = tff.federated_map(
      server_update_fn, (server_state, mean_model_delta))
  return server_state
# 4. Build the iterative process for FedAvg.
fedavg_process = tff.templates.IterativeProcess(
    initialize_fn=server_init_tff, next_fn=run_one_round)
print('type signature of `initialize`:\n',
      fedavg_process.initialize.type_signature.formatted_representation())
print('type signature of `next`:\n',
      fedavg_process.next.type_signature.formatted_representation())
server_state_type:
< trainable_weights=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] >, optimizer_state=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] > >
trainable_weights_type:
< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] >
tf_dataset_type:
< float32[?,28,28,1], int32[?] >*
type signature of `initialize`:
( -> < trainable_weights=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] >, optimizer_state=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] > >@SERVER)
type signature of `next`:
(< server_state=< trainable_weights=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] >, optimizer_state=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] > >@SERVER, federated_dataset={< float32[?,28,28,1], int32[?] >*}@CLIENTS > -> < trainable_weights=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] >, optimizer_state=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] > >@SERVER)
Evaluating the algorithm
We evaluate the performance on a centralized evaluation dataset.
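The metric used in evaluate accumulates correct predictions across all batches. The following pure-Python sketch of that bookkeeping uses StreamingAccuracy, a hypothetical class that mirrors the update_state/result pattern but is not the Keras implementation. It takes the argmax of each row of logits as the predicted class; applying a softmax would not change the argmax, so the raw logits can be compared directly.

```python
from typing import Sequence


class StreamingAccuracy:
    """Hypothetical accumulator mirroring the update_state/result pattern."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update_state(self, y_true: Sequence[int],
                     y_pred: Sequence[Sequence[float]]) -> None:
        for label, logits in zip(y_true, y_pred):
            # argmax over the logits; softmax would not change the winner.
            predicted = max(range(len(logits)), key=logits.__getitem__)
            self.correct += int(predicted == label)
            self.total += 1

    def result(self) -> float:
        # Accuracy over every example seen so far (0.0 before any update).
        return self.correct / self.total if self.total else 0.0
```

Because the counts are pooled, a 1-example batch that is fully correct followed by a 3-example batch with one correct prediction gives 2 / 4 = 0.5, not the 0.67 you would get by averaging the two per-batch accuracies.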
def evaluate(server_state):
  keras_model = create_keras_model()
  tf.nest.map_structure(
      lambda var, t: var.assign(t),
      keras_model.trainable_weights, server_state.trainable_weights)
  metric = tf.keras.metrics.SparseCategoricalAccuracy()
  for batch in iter(central_test_data):
    preds = keras_model(batch[0], training=False)
    metric.update_state(y_true=batch[1], y_pred=preds)
  return metric.result().numpy()
server_state = fedavg_process.initialize()
acc = evaluate(server_state)
print('Initial test accuracy', acc)
# Evaluate after a few rounds
CLIENTS_PER_ROUND = 2
sampled_clients = train_client_ids[:CLIENTS_PER_ROUND]
sampled_train_data = [
    train_data.create_tf_dataset_for_client(client)
    for client in sampled_clients]
for round_num in range(20):
  server_state = fedavg_process.next(server_state, sampled_train_data)
acc = evaluate(server_state)
print('Test accuracy', acc)
Initial test accuracy 0.09677419
Test accuracy 0.13978495