Ver en TensorFlow.org | Ejecutar en Google Colab | Ver fuente en GitHub | Descargar cuaderno |
La noción de un conjunto de datos codificado por clientes (por ejemplo, usuarios) es esencial para la computación federada como se modela en TFF. TFF proporciona la interfaz tff.simulation.datasets.ClientData
abstraer sobre este concepto, y los conjuntos de datos donde se celebran TFF ( stackoverflow , Shakespeare , emnist , cifar100 y gldv2 ) toda implementar esta interfaz.
Si está trabajando en el aprendizaje federada con su propio conjunto de datos, TFF recomienda encarecidamente que implementar ya sea el ClientData
interfaz o utilizar una de las funciones de ayudante de TFF para generar un ClientData
que representa los datos en el disco, por ejemplo tff.simulation.datasets.ClientData.from_clients_and_fn
.
Como la mayoría de los ejemplos de extremo a extremo de TFF empezar con ClientData
objetos, la aplicación de la ClientData
interfaz con el conjunto de datos a medida que hará más fácil a través de Spelunk código existente escrito con TFF. Además, los tf.data.Datasets
que ClientData
construcciones se pueden repiten a lo largo directamente para producir estructuras de numpy
arrays, por lo ClientData
objetos se pueden utilizar con cualquier marco ML basado en Python antes de pasar a TFF.
Hay varios patrones con los que puede hacer su vida más fácil si tiene la intención de ampliar sus simulaciones a muchas máquinas o implementarlas. A continuación vamos a caminar a través de algunas de las formas en que podemos utilizar ClientData
y TFF para hacer nuestra pequeña escala iteración a la experimentación a gran escala a la producción de experiencia de implementación lo más suave posible.
¿Qué patrón debo usar para pasar ClientData a TFF?
Vamos a discutir dos usos de la TFF ClientData
de profundidad; si encaja en cualquiera de las dos categorías siguientes, claramente preferirá una sobre la otra. De lo contrario, es posible que necesite una comprensión más detallada de los pros y los contras de cada uno para tomar una decisión más matizada.
Quiero iterar lo más rápido posible en una máquina local; No necesito poder aprovechar fácilmente el tiempo de ejecución distribuido de TFF.
- Se debe pasar
tf.data.Datasets
a TFF directamente. - Esto le permite programar imperativamente con
tf.data.Dataset
objetos, y procesarlos de manera arbitraria. - Proporciona más flexibilidad que la siguiente opción; Enviar la lógica a los clientes requiere que esta lógica sea serializable.
- Se debe pasar
Quiero ejecutar mi computación federada en el tiempo de ejecución remoto de TFF, o planeo hacerlo pronto.
- En este caso, desea mapear la construcción y el preprocesamiento de conjuntos de datos a los clientes.
- Esto resulta en que pasa simplemente una lista de
client_ids
directamente a su cómputo federado. - Llevar la construcción y el preprocesamiento de conjuntos de datos a los clientes evita los cuellos de botella en la serialización y aumenta significativamente el rendimiento con cientos o miles de clientes.
Configurar un entorno de código abierto
# tensorflow_federated_nightly also bring in tf_nightly, which
# can causes a duplicate tensorboard install, leading to errors.
!pip uninstall --yes tensorboard tb-nightly
!pip install --quiet --upgrade tensorflow_federated_nightly
!pip install --quiet --upgrade nest_asyncio
import nest_asyncio
nest_asyncio.apply()
Importar paquetes
import collections
import time
import tensorflow as tf
import tensorflow_federated as tff
Manipular un objeto ClientData
Vamos a empezar por la carga y la exploración de TFF EMNIST ClientData
:
client_data, _ = tff.simulation.datasets.emnist.load_data()
Downloading emnist_all.sqlite.lzma: 100%|██████████| 170507172/170507172 [00:19<00:00, 8831921.67it/s] 2021-10-01 11:17:58.718735: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Inspeccionar el primer conjunto de datos nos puede decir qué tipo de ejemplos se encuentran en el ClientData
.
first_client_id = client_data.client_ids[0]
first_client_dataset = client_data.create_tf_dataset_for_client(
first_client_id)
print(first_client_dataset.element_spec)
# This information is also available as a `ClientData` property:
assert client_data.element_type_structure == first_client_dataset.element_spec
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])
Nota que los rendimientos de conjunto de datos collections.OrderedDict
objetos que tienen pixels
y label
llaves, donde píxeles es un tensor con forma de [28, 28]
. Supongamos que deseamos para aplanar nuestras entradas a la forma [784]
. Una posible manera de hacer esto sería aplicar una función de pre-procesamiento en nuestro ClientData
objeto.
def preprocess_dataset(dataset):
"""Create batches of 5 examples, and limit to 3 batches."""
def map_fn(input):
return collections.OrderedDict(
x=tf.reshape(input['pixels'], shape=(-1, 784)),
y=tf.cast(tf.reshape(input['label'], shape=(-1, 1)), tf.int64),
)
return dataset.batch(5).map(
map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)
preprocessed_client_data = client_data.preprocess(preprocess_dataset)
# Notice that we have both reshaped and renamed the elements of the ordered dict.
first_client_dataset = preprocessed_client_data.create_tf_dataset_for_client(
first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])
Es posible que queramos, además, realizar algún preprocesamiento más complejo (y posiblemente con estado), por ejemplo, barajar.
def preprocess_and_shuffle(dataset):
"""Applies `preprocess_dataset` above and shuffles the result."""
preprocessed = preprocess_dataset(dataset)
return preprocessed.shuffle(buffer_size=5)
preprocessed_and_shuffled = client_data.preprocess(preprocess_and_shuffle)
# The type signature will remain the same, but the batches will be shuffled.
first_client_dataset = preprocessed_and_shuffled.create_tf_dataset_for_client(
first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])
Interfaz con un tff.Computation
Ahora que podemos realizar algunas operaciones básicas con ClientData
objetos, estamos listos para los datos de alimentación a un tff.Computation
. Definimos una tff.templates.IterativeProcess
que implementa Federados de promedio , y explorar diferentes métodos de pasándole datos.
def model_fn():
model = tf.keras.models.Sequential([
tf.keras.layers.InputLayer(input_shape=(784,)),
tf.keras.layers.Dense(10, kernel_initializer='zeros'),
])
return tff.learning.from_keras_model(
model,
# Note: input spec is the _batched_ shape, and includes the
# label tensor which will be passed to the loss function. This model is
# therefore configured to accept data _after_ it has been preprocessed.
input_spec=collections.OrderedDict(
x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
y=tf.TensorSpec(shape=[None, 1], dtype=tf.int64)),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
trainer = tff.learning.build_federated_averaging_process(
model_fn,
client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.01))
Antes de comenzar a trabajar con esta IterativeProcess
, un comentario sobre la semántica de ClientData
está en orden. Un ClientData
objeto representa la totalidad de la población disponible para la formación federados, que en general es no disponible para el entorno de ejecución de un sistema de producción FL y es específica para la simulación. ClientData
de hecho le da al usuario la capacidad de derivación de la computación federados por completo y simplemente entrenar un modelo del lado del servidor como de costumbre a través de ClientData.create_tf_dataset_from_all_clients
.
El entorno de simulación de TFF pone al investigador en completo control del bucle exterior. En particular, esto implica consideraciones de disponibilidad del cliente, abandono del cliente, etc., que deben ser abordadas por el usuario o el script del controlador Python. Uno podría, por ejemplo modelo de deserción cliente mediante el ajuste de la distribución de muestreo sobre sus ClientData's
client_ids
tales que los usuarios con más datos (y correspondientemente más largo a ejecutar cálculos locales) sería seleccionado con menor probabilidad.
Sin embargo, en un sistema federado real, el entrenador modelo no puede seleccionar explícitamente a los clientes; la selección de clientes se delega al sistema que está ejecutando el cómputo federado.
Pasando tf.data.Datasets
directamente a TFF
Una de las opciones que tenemos para la interconexión entre un ClientData
y un IterativeProcess
es el de construir tf.data.Datasets
en Python, y pasando estos conjuntos de datos a TFF.
Tenga en cuenta que si usamos nuestros preprocesados ClientData
los conjuntos de datos que dió son del tipo apropiado esperado por nuestro modelo definido anteriormente.
selected_client_ids = preprocessed_and_shuffled.client_ids[:10]
preprocessed_data_for_clients = [
preprocessed_and_shuffled.create_tf_dataset_for_client(
selected_client_ids[i]) for i in range(10)
]
state = trainer.initialize()
for _ in range(5):
t1 = time.time()
state, metrics = trainer.next(state, preprocessed_data_for_clients)
t2 = time.time()
print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.compat.v1.graph_util.extract_sub_graph` WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.compat.v1.graph_util.extract_sub_graph` loss 2.9005744457244873, round time 4.576513767242432 loss 3.113278388977051, round time 0.49641919136047363 loss 2.7581865787506104, round time 0.4904160499572754 loss 2.87259578704834, round time 0.48976993560791016 loss 3.1202380657196045, round time 0.6724586486816406
Si tomamos esta ruta, sin embargo, no seremos capaces de trivialmente para mover a la simulación multimáquina. Los conjuntos de datos que construimos en el tiempo de ejecución TensorFlow local puede capturar el estado del ambiente que rodea a pitón, y fallar en la serialización o deserialización cuando intentan estado de referencia, que ya no está disponible para ellos es. Esto puede manifestarse por ejemplo en el error de inescrutable de TensorFlow tensor_util.cc
:
Check failed: DT_VARIANT == input.dtype() (21 vs. 20)
Mapeo de la construcción y preprocesamiento sobre los clientes.
Para evitar este problema, TFF recomienda a sus usuarios a tener en cuenta el conjunto de datos de instancias y el procesamiento previo como algo que ocurre localmente en cada cliente, ya utilizar los ayudantes de TFF o federated_map
para funcionar de forma explícita este código procesamiento previo a cada cliente.
Conceptualmente, la razón para preferir esto es clara: en el tiempo de ejecución local de TFF, los clientes solo "accidentalmente" tienen acceso al entorno global de Python debido al hecho de que toda la orquestación federada ocurre en una sola máquina. Vale la pena señalar en este punto que un pensamiento similar da lugar a la filosofía funcional multiplataforma, siempre serializable de TFF.
TFF hace que un cambio tan sencilla a través ClientData's
atributo dataset_computation
, un tff.Computation
que toma un client_id
y devuelve el asociado tf.data.Dataset
.
Tenga en cuenta que preprocess
simplemente trabaja con dataset_computation
; la dataset_computation
atributo del preprocesado ClientData
incorpora toda la tubería de procesamiento previo que acabamos de definir:
print('dataset computation without preprocessing:')
print(client_data.dataset_computation.type_signature)
print('\n')
print('dataset computation with preprocessing:')
print(preprocessed_and_shuffled.dataset_computation.type_signature)
dataset computation without preprocessing: (string -> <label=int32,pixels=float32[28,28]>*) dataset computation with preprocessing: (string -> <x=float32[?,784],y=int64[?,1]>*)
Podríamos invocar dataset_computation
y recibir un conjunto de datos con ganas en el tiempo de ejecución de Python, pero el poder real de este enfoque se ejerce cuando componemos con un proceso iterativo de cálculo u otra para evitar la materialización de estos conjuntos de datos en el tiempo de ejecución ansiosos mundial en absoluto. TFF proporciona una función de ayuda tff.simulation.compose_dataset_computation_with_iterative_process
que puede ser utilizado para hacer exactamente esto.
trainer_accepting_ids = tff.simulation.compose_dataset_computation_with_iterative_process(
preprocessed_and_shuffled.dataset_computation, trainer)
Tanto este tff.templates.IterativeProcesses
y la de arriba funcionar de la misma manera; pero el ex acepta preprocesados conjuntos de datos de cliente, y el último acepta cadenas que representan los ID de cliente, la manipulación tanto en la construcción y el conjunto de datos de preprocesamiento en su cuerpo - de hecho state
se puede transmitir entre los dos.
for _ in range(5):
t1 = time.time()
state, metrics = trainer_accepting_ids.next(state, selected_client_ids)
t2 = time.time()
print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
loss 2.8417396545410156, round time 1.6707067489624023 loss 2.7670371532440186, round time 0.5207102298736572 loss 2.665048122406006, round time 0.5302855968475342 loss 2.7213189601898193, round time 0.5313887596130371 loss 2.580148935317993, round time 0.5283482074737549
Escalando a un gran número de clientes
trainer_accepting_ids
de inmediato se pueden utilizar en tiempo de ejecución multimáquina de TFF, y evita materializar tf.data.Datasets
y el controlador (y por lo tanto la serialización de ellos y enviarlos a los trabajadores).
Esto acelera significativamente las simulaciones distribuidas, especialmente con una gran cantidad de clientes, y permite la agregación intermedia para evitar una sobrecarga similar de serialización / deserialización.
Deepdive opcional: componer manualmente la lógica de preprocesamiento en TFF
TFF está diseñado para la composicionalidad desde cero; el tipo de composición que acaba de realizar el ayudante de TFF está totalmente bajo nuestro control como usuarios. Podríamos tener manualmente componer el cálculo de procesamiento previo que acabamos de definir con la del entrenador propia next
sencillamente:
selected_clients_type = tff.FederatedType(preprocessed_and_shuffled.dataset_computation.type_signature.parameter, tff.CLIENTS)
@tff.federated_computation(trainer.next.type_signature.parameter[0], selected_clients_type)
def new_next(server_state, selected_clients):
preprocessed_data = tff.federated_map(preprocessed_and_shuffled.dataset_computation, selected_clients)
return trainer.next(server_state, preprocessed_data)
manual_trainer_with_preprocessing = tff.templates.IterativeProcess(initialize_fn=trainer.initialize, next_fn=new_next)
De hecho, esto es efectivamente lo que el ayudante que usamos está haciendo bajo el capó (además de realizar una verificación y manipulación de tipo adecuada). Incluso podríamos haber expresado la misma lógica ligeramente diferente, serializando preprocess_and_shuffle
en un tff.Computation
, y descomponer el federated_map
en un paso que construye conjuntos de datos de la ONU-preprocesado y otra que corre preprocess_and_shuffle
a cada cliente.
Podemos verificar que esta ruta más manual da como resultado cálculos con la misma firma de tipo que el ayudante de TFF (nombres de parámetros de módulo):
print(trainer_accepting_ids.next.type_signature)
print(manual_trainer_with_preprocessing.next.type_signature)
(<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,federated_dataset={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>) (<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,selected_clients={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>)