Ver en TensorFlow.org | Ejecutar en Google Colab | Ver fuente en GitHub | Descargar cuaderno |
Este tutorial muestra cómo TFF se pueden utilizar para entrenar a un gran modelo en el que cada dispositivo cliente sólo descarga y actualiza una pequeña parte del modelo, utilizando tff.federated_select
y agregación escasa. Si bien este tutorial es bastante auto-contenido, la tff.federated_select
tutorial y personalizada FL algoritmos tutorial proporcionan buenas introducciones a algunas de las técnicas utilizadas aquí.
Concretamente, en este tutorial consideramos la regresión logística para la clasificación de etiquetas múltiples, prediciendo qué "etiquetas" están asociadas con una cadena de texto en función de una representación de características de bolsa de palabras. Es importante destacar que los costos de comunicación y cálculo del lado del cliente son controlados por una constante fija ( MAX_TOKENS_SELECTED_PER_CLIENT
), y no escalan con el tamaño del vocabulario general, lo que podría ser muy importante en situaciones prácticas.
!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio
import nest_asyncio
nest_asyncio.apply()
import collections
import itertools
import numpy as np
from typing import Callable, List, Tuple
import tensorflow as tf
import tensorflow_federated as tff
tff.backends.native.set_local_python_execution_context()
Cada cliente federated_select
las filas de los pesos modelo por un máximo de esto muchas fichas únicas. Esto superior limita el tamaño del modelo local del cliente y la cantidad de servidor -> cliente ( federated_select
) y el cliente -> servidor (federated_aggregate
) lleva a cabo la comunicación.
Este tutorial aún debería ejecutarse correctamente incluso si lo configura en un valor tan pequeño como 1 (asegurándose de que no se seleccionen todos los tokens de cada cliente) o en un valor grande, aunque la convergencia del modelo puede verse afectada.
MAX_TOKENS_SELECTED_PER_CLIENT = 6
También definimos algunas constantes para varios tipos. Por esta colab, un contador es un identificador entero para una palabra en particular después de analizar el conjunto de datos.
# There are some constraints on types
# here that will require some explicit type conversions:
# - `tff.federated_select` requires int32
# - `tf.SparseTensor` requires int64 indices.
TOKEN_DTYPE = tf.int64
SELECT_KEY_DTYPE = tf.int32
# Type for counts of token occurences.
TOKEN_COUNT_DTYPE = tf.int32
# A sparse feature vector can be thought of as a map
# from TOKEN_DTYPE to FEATURE_DTYPE.
# Our features are {0, 1} indicators, so we could potentially
# use tf.int8 as an optimization.
FEATURE_DTYPE = tf.int32
Configuración del problema: conjunto de datos y modelo
Construimos un pequeño conjunto de datos de juguetes para facilitar la experimentación en este tutorial. Sin embargo, el formato del conjunto de datos es compatible con Federated Stackoverflow , y el pre-procesamiento y la arquitectura modelo se adoptó desde el Stackoverflow problema de predicción etiqueta de optimización Federados adaptativa .
Análisis y preprocesamiento de conjuntos de datos
NUM_OOV_BUCKETS = 1
BatchType = collections.namedtuple('BatchType', ['tokens', 'tags'])
def build_to_ids_fn(word_vocab: List[str],
tag_vocab: List[str]) -> Callable[[tf.Tensor], tf.Tensor]:
"""Constructs a function mapping examples to sequences of token indices."""
word_table_values = np.arange(len(word_vocab), dtype=np.int64)
word_table = tf.lookup.StaticVocabularyTable(
tf.lookup.KeyValueTensorInitializer(word_vocab, word_table_values),
num_oov_buckets=NUM_OOV_BUCKETS)
tag_table_values = np.arange(len(tag_vocab), dtype=np.int64)
tag_table = tf.lookup.StaticVocabularyTable(
tf.lookup.KeyValueTensorInitializer(tag_vocab, tag_table_values),
num_oov_buckets=NUM_OOV_BUCKETS)
def to_ids(example):
"""Converts a Stack Overflow example to a bag-of-words/tags format."""
sentence = tf.strings.join([example['tokens'], example['title']],
separator=' ')
# We represent that label (output tags) densely.
raw_tags = example['tags']
tags = tf.strings.split(raw_tags, sep='|')
tags = tag_table.lookup(tags)
tags, _ = tf.unique(tags)
tags = tf.one_hot(tags, len(tag_vocab) + NUM_OOV_BUCKETS)
tags = tf.reduce_max(tags, axis=0)
# We represent the features as a SparseTensor of {0, 1}s.
words = tf.strings.split(sentence)
tokens = word_table.lookup(words)
tokens, _ = tf.unique(tokens)
# Note: We could choose to use the word counts as the feature vector
# instead of just {0, 1} values (see tf.unique_with_counts).
tokens = tf.reshape(tokens, shape=(tf.size(tokens), 1))
tokens_st = tf.SparseTensor(
tokens,
tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE),
dense_shape=(len(word_vocab) + NUM_OOV_BUCKETS,))
tokens_st = tf.sparse.reorder(tokens_st)
return BatchType(tokens_st, tags)
return to_ids
def build_preprocess_fn(word_vocab, tag_vocab):
@tf.function
def preprocess_fn(dataset):
to_ids = build_to_ids_fn(word_vocab, tag_vocab)
# We *don't* shuffle in order to make this colab deterministic for
# easier testing and reproducibility.
# But real-world training should use `.shuffle()`.
return dataset.map(to_ids, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return preprocess_fn
Un pequeño conjunto de datos de juguetes
Construimos un pequeño conjunto de datos de juguetes con un vocabulario global de 12 palabras y 3 clientes. Este pequeño ejemplo es útil para probar los casos de borde (por ejemplo, tenemos dos clientes con menos de MAX_TOKENS_SELECTED_PER_CLIENT = 6
tokens distintas, y una con más) y el desarrollo del código.
Sin embargo, los casos de uso del mundo real de este enfoque serían vocabularios globales de decenas de millones o más, con quizás miles de tokens distintos apareciendo en cada cliente. Debido a que el formato de los datos es la misma, la extensión de los problemas del banco de pruebas más realistas, por ejemplo, el tff.simulation.datasets.stackoverflow.load_data()
conjunto de datos, debe ser sencillo.
Primero, definimos nuestros vocabularios de palabras y etiquetas.
# Features
FRUIT_WORDS = ['apple', 'orange', 'pear', 'kiwi']
VEGETABLE_WORDS = ['carrot', 'broccoli', 'arugula', 'peas']
FISH_WORDS = ['trout', 'tuna', 'cod', 'salmon']
WORD_VOCAB = FRUIT_WORDS + VEGETABLE_WORDS + FISH_WORDS
# Labels
TAG_VOCAB = ['FRUIT', 'VEGETABLE', 'FISH']
Ahora, creamos 3 clientes con pequeños conjuntos de datos locales. Si está ejecutando este tutorial en colab, puede ser útil usar la función "reflejar celda en pestaña" para anclar esta celda y su salida con el fin de interpretar / verificar la salida de las funciones desarrolladas a continuación.
preprocess_fn = build_preprocess_fn(WORD_VOCAB, TAG_VOCAB)
def make_dataset(raw):
d = tf.data.Dataset.from_tensor_slices(
# Matches the StackOverflow formatting
collections.OrderedDict(
tokens=tf.constant([t[0] for t in raw]),
tags=tf.constant([t[1] for t in raw]),
title=['' for _ in raw]))
d = preprocess_fn(d)
return d
# 4 distinct tokens
CLIENT1_DATASET = make_dataset([
('apple orange apple orange', 'FRUIT'),
('carrot trout', 'VEGETABLE|FISH'),
('orange apple', 'FRUIT'),
('orange', 'ORANGE|CITRUS') # 2 OOV tag
])
# 6 distinct tokens
CLIENT2_DATASET = make_dataset([
('pear cod', 'FRUIT|FISH'),
('arugula peas', 'VEGETABLE'),
('kiwi pear', 'FRUIT'),
('sturgeon', 'FISH'), # OOV word
('sturgeon bass', 'FISH') # 2 OOV words
])
# A client with all possible words & tags (13 distinct tokens).
# With MAX_TOKENS_SELECTED_PER_CLIENT = 6, we won't download the model
# slices for all tokens that occur on this client.
CLIENT3_DATASET = make_dataset([
(' '.join(WORD_VOCAB + ['oovword']), '|'.join(TAG_VOCAB)),
# Mathe the OOV token and 'salmon' occur in the largest number
# of examples on this client:
('salmon oovword', 'FISH|OOVTAG')
])
print('Word vocab')
for i, word in enumerate(WORD_VOCAB):
print(f'{i:2d} {word}')
print('\nTag vocab')
for i, tag in enumerate(TAG_VOCAB):
print(f'{i:2d} {tag}')
Word vocab 0 apple 1 orange 2 pear 3 kiwi 4 carrot 5 broccoli 6 arugula 7 peas 8 trout 9 tuna 10 cod 11 salmon Tag vocab 0 FRUIT 1 VEGETABLE 2 FISH
Defina constantes para los números sin procesar de características de entrada (tokens / palabras) y etiquetas (etiquetas de publicación). Nuestros espacios de entrada / salida reales son NUM_OOV_BUCKETS = 1
más grande porque se añade un símbolo OOV / etiqueta.
NUM_WORDS = len(WORD_VOCAB)
NUM_TAGS = len(TAG_VOCAB)
WORD_VOCAB_SIZE = NUM_WORDS + NUM_OOV_BUCKETS
TAG_VOCAB_SIZE = NUM_TAGS + NUM_OOV_BUCKETS
Cree versiones por lotes de los conjuntos de datos y lotes individuales, que serán útiles para probar el código a medida que avanzamos.
batched_dataset1 = CLIENT1_DATASET.batch(2)
batched_dataset2 = CLIENT2_DATASET.batch(3)
batched_dataset3 = CLIENT3_DATASET.batch(2)
batch1 = next(iter(batched_dataset1))
batch2 = next(iter(batched_dataset2))
batch3 = next(iter(batched_dataset3))
Definir un modelo con entradas escasas
Usamos un modelo de regresión logística independiente simple para cada etiqueta.
def create_logistic_model(word_vocab_size: int, vocab_tags_size: int):
model = tf.keras.models.Sequential([
tf.keras.layers.InputLayer(input_shape=(word_vocab_size,), sparse=True),
tf.keras.layers.Dense(
vocab_tags_size,
activation='sigmoid',
kernel_initializer=tf.keras.initializers.zeros,
# For simplicity, don't use a bias vector; this means the model
# is a single tensor, and we only need sparse aggregation of
# the per-token slices of the model. Generalizing to also handle
# other model weights that are fully updated
# (non-dense broadcast and aggregate) would be a good exercise.
use_bias=False),
])
return model
Asegurémonos de que funcione, primero haciendo predicciones:
model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
p = model.predict(batch1.tokens)
print(p)
[[0.5 0.5 0.5 0.5] [0.5 0.5 0.5 0.5]]
Y una formación sencilla centralizada:
model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.001),
loss=tf.keras.losses.BinaryCrossentropy())
model.train_on_batch(batch1.tokens, batch1.tags)
Bloques de construcción para el cálculo federado
Vamos a implementar una versión simple del valor promedio Federados algoritmo con la diferencia fundamental de que cada dispositivo sólo se descarga un subconjunto relevante del modelo, y sólo contribuye cambios a ese subconjunto.
Utilizamos M
como forma abreviada de MAX_TOKENS_SELECTED_PER_CLIENT
. En un nivel alto, una ronda de capacitación implica estos pasos:
Cada cliente participante escanea su conjunto de datos local, analiza las cadenas de entrada y las asigna a los tokens correctos (índices int). Esto requiere el acceso al diccionario mundial (grande) (esto podría evitarse usando función de hash técnicas). Luego contamos escasamente cuántas veces ocurre cada token. Si
U
fichas únicas se producen en el dispositivo, elegimos losnum_actual_tokens = min(U, M)
mayor número de testigos frecuentes al tren.Los clientes utilizan
federated_select
para recuperar los coeficientes del modelo para losnum_actual_tokens
seleccionados fichas desde el servidor. Cada modelo rebanada es un tensor de la forma(TAG_VOCAB_SIZE, )
, por lo que el total de datos transmitidos al cliente es en la mayoría de tamañoTAG_VOCAB_SIZE * M
(véase la nota a continuación).Los clientes construir un mapeo
global_token -> local_token
donde el testigo local (int index) es el índice de la ficha global en la lista de símbolos seleccionados.Los clientes utilizan una "pequeña" versión del modelo global que sólo tiene coeficientes de a lo sumo
M
fichas, de la gama[0, num_actual_tokens)
. Elglobal -> local
asignación se utiliza para inicializar los parámetros de este modelo densas de las rodajas de modelo seleccionados.Los clientes a capacitar su modelo local usando SGD en los datos que se procesan con el
global -> local
mapeo.Los clientes recurren los parámetros de su modelo local en
IndexedSlices
actualizaciones utilizando ellocal -> global
asignación de índice de las filas. El servidor agrega estas actualizaciones mediante una agregación de suma escasa.El servidor toma el resultado (denso) de la agregación anterior, lo divide por el número de clientes que participan y aplica la actualización promedio resultante al modelo global.
En esta sección construimos los cimientos de estos pasos, que luego se combinan en una final federated_computation
que captura la lógica completa de una ronda de entrenamiento.
Contar fichas de cliente y decidir qué modelo de rebanadas de federated_select
Cada dispositivo debe decidir qué "segmentos" del modelo son relevantes para su conjunto de datos de entrenamiento local. Para nuestro problema, hacemos esto contando (¡escasamente!) Cuántos ejemplos contienen cada token en el conjunto de datos de entrenamiento del cliente.
@tf.function
def token_count_fn(token_counts, batch):
"""Adds counts from `batch` to the running `token_counts` sum."""
# Sum across the batch dimension.
flat_tokens = tf.sparse.reduce_sum(
batch.tokens, axis=0, output_is_sparse=True)
flat_tokens = tf.cast(flat_tokens, dtype=TOKEN_COUNT_DTYPE)
return tf.sparse.add(token_counts, flat_tokens)
# Simple tests
# Create the initial zero token counts using empty tensors.
initial_token_counts = tf.SparseTensor(
indices=tf.zeros(shape=(0, 1), dtype=TOKEN_DTYPE),
values=tf.zeros(shape=(0,), dtype=TOKEN_COUNT_DTYPE),
dense_shape=(WORD_VOCAB_SIZE,))
client_token_counts = batched_dataset1.reduce(initial_token_counts,
token_count_fn)
tokens = tf.reshape(client_token_counts.indices, (-1,)).numpy()
print('tokens:', tokens)
np.testing.assert_array_equal(tokens, [0, 1, 4, 8])
# The count is the number of *examples* in which the token/word
# occurs, not the total number of occurences, since we still featurize
# multiple occurences in the same example as a "1".
counts = client_token_counts.values.numpy()
print('counts:', counts)
np.testing.assert_array_equal(counts, [2, 3, 1, 1])
tokens: [0 1 4 8] counts: [2 3 1 1]
Vamos a seleccionar los parámetros de los modelos correspondientes a la MAX_TOKENS_SELECTED_PER_CLIENT
ocurre con más frecuencia fichas en el dispositivo. Si menos de esta cantidad de fichas ocurren en el dispositivo, que la almohadilla de la lista para permitir el uso de federated_select
.
Tenga en cuenta que otras estrategias posiblemente sean mejores, por ejemplo, seleccionar tokens al azar (quizás en función de su probabilidad de ocurrencia). Esto garantizaría que todas las secciones del modelo (para las que el cliente tiene datos) tengan alguna posibilidad de actualizarse.
@tf.function
def keys_for_client(client_dataset, max_tokens_per_client):
"""Computes a set of max_tokens_per_client keys."""
initial_token_counts = tf.SparseTensor(
indices=tf.zeros((0, 1), dtype=TOKEN_DTYPE),
values=tf.zeros((0,), dtype=TOKEN_COUNT_DTYPE),
dense_shape=(WORD_VOCAB_SIZE,))
client_token_counts = client_dataset.reduce(initial_token_counts,
token_count_fn)
# Find the most-frequently occuring tokens
tokens = tf.reshape(client_token_counts.indices, shape=(-1,))
counts = client_token_counts.values
perm = tf.argsort(counts, direction='DESCENDING')
tokens = tf.gather(tokens, perm)
counts = tf.gather(counts, perm)
num_raw_tokens = tf.shape(tokens)[0]
actual_num_tokens = tf.minimum(max_tokens_per_client, num_raw_tokens)
selected_tokens = tokens[:actual_num_tokens]
paddings = [[0, max_tokens_per_client - tf.shape(selected_tokens)[0]]]
padded_tokens = tf.pad(selected_tokens, paddings=paddings)
# Make sure the type is statically determined
padded_tokens = tf.reshape(padded_tokens, shape=(max_tokens_per_client,))
# We will pass these tokens as keys into `federated_select`, which
# requires SELECT_KEY_DTYPE=tf.int32 keys.
padded_tokens = tf.cast(padded_tokens, dtype=SELECT_KEY_DTYPE)
return padded_tokens, actual_num_tokens
# Simple test
# Case 1: actual_num_tokens > max_tokens_per_client
selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 3)
assert tf.size(selected_tokens) == 3
assert actual_num_tokens == 3
# Case 2: actual_num_tokens < max_tokens_per_client
selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 10)
assert tf.size(selected_tokens) == 10
assert actual_num_tokens == 4
Asignar tokens globales a tokens locales
La selección anterior nos da un conjunto denso de fichas en el rango [0, actual_num_tokens)
que utilizaremos para el modelo en el dispositivo. Sin embargo, el conjunto de datos que leemos tiene fichas de la gama mucho más amplia mundial vocabulario [0, WORD_VOCAB_SIZE)
.
Por lo tanto, necesitamos mapear los tokens globales a sus tokens locales correspondientes. Los ID de testigo locales están simplemente dados por los índices en los selected_tokens
tensor computado en el paso anterior.
@tf.function
def map_to_local_token_ids(client_data, client_keys):
global_to_local = tf.lookup.StaticHashTable(
# Note int32 -> int64 maps are not supported
tf.lookup.KeyValueTensorInitializer(
keys=tf.cast(client_keys, dtype=TOKEN_DTYPE),
# Note we need to use tf.shape, not the static
# shape client_keys.shape[0]
values=tf.range(0, limit=tf.shape(client_keys)[0],
dtype=TOKEN_DTYPE)),
# We use -1 for tokens that were not selected, which can occur for clients
# with more than MAX_TOKENS_SELECTED_PER_CLIENT distinct tokens.
# We will simply remove these invalid indices from the batch below.
default_value=-1)
def to_local_ids(sparse_tokens):
indices_t = tf.transpose(sparse_tokens.indices)
batch_indices = indices_t[0] # First column
tokens = indices_t[1] # Second column
tokens = tf.map_fn(
lambda global_token_id: global_to_local.lookup(global_token_id), tokens)
# Remove tokens that aren't actually available (looked up as -1):
available_tokens = tokens >= 0
tokens = tokens[available_tokens]
batch_indices = batch_indices[available_tokens]
updated_indices = tf.transpose(
tf.concat([[batch_indices], [tokens]], axis=0))
st = tf.sparse.SparseTensor(
updated_indices,
tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE),
dense_shape=sparse_tokens.dense_shape)
st = tf.sparse.reorder(st)
return st
return client_data.map(lambda b: BatchType(to_local_ids(b.tokens), b.tags))
# Simple test
client_keys, actual_num_tokens = keys_for_client(
batched_dataset3, MAX_TOKENS_SELECTED_PER_CLIENT)
client_keys = client_keys[:actual_num_tokens]
d = map_to_local_token_ids(batched_dataset3, client_keys)
batch = next(iter(d))
all_tokens = tf.gather(batch.tokens.indices, indices=1, axis=1)
# Confirm we have local indices in the range [0, MAX):
assert tf.math.reduce_max(all_tokens) < MAX_TOKENS_SELECTED_PER_CLIENT
assert tf.math.reduce_max(all_tokens) >= 0
Entrene el (sub) modelo local en cada cliente
Nota federated_select
devolverá los sectores seleccionados como tf.data.Dataset
en el mismo orden que las teclas de selección. Entonces, primero definimos una función de utilidad para tomar tal conjunto de datos y convertirlo en un solo tensor denso que se puede usar como los pesos del modelo del modelo de cliente.
@tf.function
def slices_dataset_to_tensor(slices_dataset):
"""Convert a dataset of slices to a tensor."""
# Use batching to gather all of the slices into a single tensor.
d = slices_dataset.batch(MAX_TOKENS_SELECTED_PER_CLIENT,
drop_remainder=False)
iter_d = iter(d)
tensor = next(iter_d)
# Make sure we have consumed everything
opt = iter_d.get_next_as_optional()
tf.Assert(tf.logical_not(opt.has_value()), data=[''], name='CHECK_EMPTY')
return tensor
# Simple test
weights = np.random.random(
size=(MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE)).astype(np.float32)
model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(weights)
weights2 = slices_dataset_to_tensor(model_slices_as_dataset)
np.testing.assert_array_equal(weights, weights2)
Ahora tenemos todos los componentes que necesitamos para definir un ciclo de entrenamiento local simple que se ejecutará en cada cliente.
@tf.function
def client_train_fn(model, client_optimizer,
model_slices_as_dataset, client_data,
client_keys, actual_num_tokens):
initial_model_weights = slices_dataset_to_tensor(model_slices_as_dataset)
assert len(model.trainable_variables) == 1
model.trainable_variables[0].assign(initial_model_weights)
# Only keep the "real" (unpadded) keys.
client_keys = client_keys[:actual_num_tokens]
client_data = map_to_local_token_ids(client_data, client_keys)
loss_fn = tf.keras.losses.BinaryCrossentropy()
for features, labels in client_data:
with tf.GradientTape() as tape:
predictions = model(features)
loss = loss_fn(labels, predictions)
grads = tape.gradient(loss, model.trainable_variables)
client_optimizer.apply_gradients(zip(grads, model.trainable_variables))
model_weights_delta = model.trainable_weights[0] - initial_model_weights
model_weights_delta = tf.slice(model_weights_delta, begin=[0, 0],
size=[actual_num_tokens, -1])
return client_keys, model_weights_delta
# Simple test
# Note if you execute this cell a second time, you need to also re-execute
# the preceeding cell to avoid "tf.function-decorated function tried to
# create variables on non-first call" errors.
on_device_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT,
TAG_VOCAB_SIZE)
client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
client_keys, actual_num_tokens = keys_for_client(
batched_dataset2, MAX_TOKENS_SELECTED_PER_CLIENT)
model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(
np.zeros((MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE),
dtype=np.float32))
keys, delta = client_train_fn(
on_device_model,
client_optimizer,
model_slices_as_dataset,
client_data=batched_dataset3,
client_keys=client_keys,
actual_num_tokens=actual_num_tokens)
print(delta)
Rebanadas indexadas agregadas
Utilizamos tff.federated_aggregate
para construir una suma escasa federado para IndexedSlices
. Esta sencilla aplicación tiene la limitación de que el dense_shape
se conoce de antemano de forma estática. Tenga en cuenta también que esta suma es única semi-escasa, en el sentido de que el cliente -> servidor de comunicación es escasa, pero el servidor mantiene una representación densa de la suma de accumulate
y merge
, y da salida a esta representación densa.
def federated_indexed_slices_sum(slice_indices, slice_values, dense_shape):
"""
Sumes IndexedSlices@CLIENTS to a dense @SERVER Tensor.
Intermediate aggregation is performed by converting to a dense representation,
which may not be suitable for all applications.
Args:
slice_indices: An IndexedSlices.indices tensor @CLIENTS.
slice_values: An IndexedSlices.values tensor @CLIENTS.
dense_shape: A statically known dense shape.
Returns:
A dense tensor placed @SERVER representing the sum of the client's
IndexedSclies.
"""
slices_dtype = slice_values.type_signature.member.dtype
zero = tff.tf_computation(
lambda: tf.zeros(dense_shape, dtype=slices_dtype))()
@tf.function
def accumulate_slices(dense, client_value):
indices, slices = client_value
# There is no built-in way to add `IndexedSlices`, but
# tf.convert_to_tensor is a quick way to convert to a dense representation
# so we can add them.
return dense + tf.convert_to_tensor(
tf.IndexedSlices(slices, indices, dense_shape))
return tff.federated_aggregate(
(slice_indices, slice_values),
zero=zero,
accumulate=tff.tf_computation(accumulate_slices),
merge=tff.tf_computation(lambda d1, d2: tf.add(d1, d2, name='merge')),
report=tff.tf_computation(lambda d: d))
Construir un mínimo federated_computation
como una prueba
dense_shape = (6, 2)
indices_type = tff.TensorType(tf.int64, (None,))
values_type = tff.TensorType(tf.float32, (None, 2))
client_slice_type = tff.type_at_clients(
(indices_type, values_type))
@tff.federated_computation(client_slice_type)
def test_sum_indexed_slices(indices_values_at_client):
indices, values = indices_values_at_client
return federated_indexed_slices_sum(indices, values, dense_shape)
print(test_sum_indexed_slices.type_signature)
({<int64[?],float32[?,2]>}@CLIENTS -> float32[6,2]@SERVER)
x = tf.IndexedSlices(
values=np.array([[2., 2.1], [0., 0.1], [1., 1.1], [5., 5.1]],
dtype=np.float32),
indices=[2, 0, 1, 5],
dense_shape=dense_shape)
y = tf.IndexedSlices(
values=np.array([[0., 0.3], [3.1, 3.2]], dtype=np.float32),
indices=[1, 3],
dense_shape=dense_shape)
# Sum one.
result = test_sum_indexed_slices([(x.indices, x.values)])
np.testing.assert_array_equal(tf.convert_to_tensor(x), result)
# Sum two.
expected = [[0., 0.1], [1., 1.4], [2., 2.1], [3.1, 3.2], [0., 0.], [5., 5.1]]
result = test_sum_indexed_slices([(x.indices, x.values), (y.indices, y.values)])
np.testing.assert_array_almost_equal(expected, result)
Poniendo todo junto en un federated_computation
Ahora tenemos utiliza TFF para unir entre sí los componentes en un tff.federated_computation
.
DENSE_MODEL_SHAPE = (WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
client_data_type = tff.SequenceType(batched_dataset1.element_spec)
model_type = tff.TensorType(tf.float32, shape=DENSE_MODEL_SHAPE)
Usamos una función básica de entrenamiento del servidor basada en Federated Averaging, aplicando la actualización con una tasa de aprendizaje del servidor de 1.0. Es importante que apliquemos una actualización (delta) al modelo, en lugar de simplemente promediar los modelos proporcionados por el cliente, ya que, de lo contrario, si un cliente no entrena una porción determinada del modelo en una ronda determinada, sus coeficientes se podrían poner a cero. fuera.
@tff.tf_computation
def server_update(current_model_weights, update_sum, num_clients):
average_update = update_sum / num_clients
return current_model_weights + average_update
Necesitamos un par más tff.tf_computation
componentes:
# Function to select slices from the model weights in federated_select:
select_fn = tff.tf_computation(
lambda model_weights, index: tf.gather(model_weights, index))
# We need to wrap `client_train_fn` as a `tff.tf_computation`, making
# sure we do any operations that might construct `tf.Variable`s outside
# of the `tf.function` we are wrapping.
@tff.tf_computation
def client_train_fn_tff(model_slices_as_dataset, client_data, client_keys,
actual_num_tokens):
# Note this is amaller than the global model, using
# MAX_TOKENS_SELECTED_PER_CLIENT which is much smaller than WORD_VOCAB_SIZE.
# W7e would like a model of size `actual_num_tokens`, but we
# can't build the model dynamically, so we will slice off the padded
# weights at the end.
client_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT,
TAG_VOCAB_SIZE)
client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
return client_train_fn(client_model, client_optimizer,
model_slices_as_dataset, client_data, client_keys,
actual_num_tokens)
@tff.tf_computation
def keys_for_client_tff(client_data):
return keys_for_client(client_data, MAX_TOKENS_SELECTED_PER_CLIENT)
¡Ahora estamos listos para juntar todas las piezas!
@tff.federated_computation(
tff.type_at_server(model_type), tff.type_at_clients(client_data_type))
def sparse_model_update(server_model, client_data):
max_tokens = tff.federated_value(MAX_TOKENS_SELECTED_PER_CLIENT, tff.SERVER)
keys_at_clients, actual_num_tokens = tff.federated_map(
keys_for_client_tff, client_data)
model_slices = tff.federated_select(keys_at_clients, max_tokens, server_model,
select_fn)
update_keys, update_slices = tff.federated_map(
client_train_fn_tff,
(model_slices, client_data, keys_at_clients, actual_num_tokens))
dense_update_sum = federated_indexed_slices_sum(update_keys, update_slices,
DENSE_MODEL_SHAPE)
num_clients = tff.federated_sum(tff.federated_value(1.0, tff.CLIENTS))
updated_server_model = tff.federated_map(
server_update, (server_model, dense_update_sum, num_clients))
return updated_server_model
print(sparse_model_update.type_signature)
(<server_model=float32[13,4]@SERVER,client_data={<tokens=<indices=int64[?,2],values=int32[?],dense_shape=int64[2]>,tags=float32[?,4]>*}@CLIENTS> -> float32[13,4]@SERVER)
¡Entrenemos un modelo!
Ahora que tenemos nuestra función de entrenamiento, probémosla.
server_model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
server_model.compile( # Compile to make evaluation easy.
optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.0), # Unused
loss=tf.keras.losses.BinaryCrossentropy(),
metrics=[
tf.keras.metrics.Precision(name='precision'),
tf.keras.metrics.AUC(name='auc'),
tf.keras.metrics.Recall(top_k=2, name='recall_at_2'),
])
def evaluate(model, dataset, name):
metrics = model.evaluate(dataset, verbose=0)
metrics_str = ', '.join([f'{k}={v:.2f}' for k, v in
(zip(server_model.metrics_names, metrics))])
print(f'{name}: {metrics_str}')
print('Before training')
evaluate(server_model, batched_dataset1, 'Client 1')
evaluate(server_model, batched_dataset2, 'Client 2')
evaluate(server_model, batched_dataset3, 'Client 3')
model_weights = server_model.trainable_weights[0]
client_datasets = [batched_dataset1, batched_dataset2, batched_dataset3]
for _ in range(10): # Run 10 rounds of FedAvg
# We train on 1, 2, or 3 clients per round, selecting
# randomly.
cohort_size = np.random.randint(1, 4)
clients = np.random.choice([0, 1, 2], cohort_size, replace=False)
print('Training on clients', clients)
model_weights = sparse_model_update(
model_weights, [client_datasets[i] for i in clients])
server_model.set_weights([model_weights])
print('After training')
evaluate(server_model, batched_dataset1, 'Client 1')
evaluate(server_model, batched_dataset2, 'Client 2')
evaluate(server_model, batched_dataset3, 'Client 3')
Before training Client 1: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.60 Client 2: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.50 Client 3: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.40 Training on clients [0 1] Training on clients [0 2 1] Training on clients [2 0] Training on clients [1 0 2] Training on clients [2] Training on clients [2 0] Training on clients [1 2 0] Training on clients [0] Training on clients [2] Training on clients [1 2] After training Client 1: loss=0.67, precision=0.80, auc=0.91, recall_at_2=0.80 Client 2: loss=0.68, precision=0.67, auc=0.96, recall_at_2=1.00 Client 3: loss=0.65, precision=1.00, auc=0.93, recall_at_2=0.80