TFF per la ricerca sull'apprendimento federato: compressione di modelli e aggiornamenti

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza la fonte su GitHub Scarica taccuino

In questo tutorial, utilizziamo l'EMNIST set di dati per dimostrare come attivare algoritmi di compressione con perdita di ridurre i costi di comunicazione con l'algoritmo di calcolo della media federata utilizzando il tff.learning.build_federated_averaging_process API e tensor_encoding API. Per maggiori dettagli su l'algoritmo di calcolo della media federati, vedere il documento di comunicazione-efficace apprendimento di profonda reti da decentrata dei dati .

Prima di iniziare

Prima di iniziare, eseguire quanto segue per assicurarsi che l'ambiente sia configurato correttamente. Se non vedi un saluto, si prega di fare riferimento alla installazione guida per le istruzioni.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade tensorflow-model-optimization
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
%load_ext tensorboard

import functools

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

from tensorflow_model_optimization.python.core.internal import tensor_encoding as te

Verifica se TFF funziona.

@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()
b'Hello, World!'

Preparazione dei dati di input

In questa sezione carichiamo e preprocessiamo il dataset EMNIST incluso in TFF. Si prega di verificare Federati di apprendimento per un'immagine di classificazione tutorial per ulteriori dettagli su EMNIST set di dati.

# This value only applies to EMNIST dataset, consider choosing appropriate
# values if switching to other datasets.
MAX_CLIENT_DATASET_SIZE = 418

CLIENT_EPOCHS_PER_ROUND = 1
CLIENT_BATCH_SIZE = 20
TEST_BATCH_SIZE = 500

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
    only_digits=True)

def reshape_emnist_element(element):
  return (tf.expand_dims(element['pixels'], axis=-1), element['label'])

def preprocess_train_dataset(dataset):
  """Preprocessing function for the EMNIST training dataset."""
  return (dataset
          # Shuffle according to the largest client dataset
          .shuffle(buffer_size=MAX_CLIENT_DATASET_SIZE)
          # Repeat to do multiple local epochs
          .repeat(CLIENT_EPOCHS_PER_ROUND)
          # Batch to a fixed client batch size
          .batch(CLIENT_BATCH_SIZE, drop_remainder=False)
          # Preprocessing step
          .map(reshape_emnist_element))

emnist_train = emnist_train.preprocess(preprocess_train_dataset)

Definire un modello

Qui si definisce un modello keras basato sul orginial FedAvg CNN, e poi avvolgere il modello keras in un'istanza di tff.learning.Model in modo che possa essere consumato da TFF.

Si noti che abbiamo bisogno di una funzione che produce un modello invece di un modello direttamente. Inoltre, la funzione non può semplicemente catturare un modello pre-costruito, è necessario creare il modello nel contesto che si chiama. Il motivo è che TFF è progettato per andare sui dispositivi e necessita di controllo su quando le risorse vengono costruite in modo che possano essere catturate e impacchettate.

def create_original_fedavg_cnn_model(only_digits=True):
  """The CNN model used in https://arxiv.org/abs/1602.05629."""
  data_format = 'channels_last'

  max_pool = functools.partial(
      tf.keras.layers.MaxPooling2D,
      pool_size=(2, 2),
      padding='same',
      data_format=data_format)
  conv2d = functools.partial(
      tf.keras.layers.Conv2D,
      kernel_size=5,
      padding='same',
      data_format=data_format,
      activation=tf.nn.relu)

  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
      conv2d(filters=32),
      max_pool(),
      conv2d(filters=64),
      max_pool(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation=tf.nn.relu),
      tf.keras.layers.Dense(10 if only_digits else 62),
      tf.keras.layers.Softmax(),
  ])

  return model

# Gets the type information of the input data. TFF is a strongly typed
# functional programming framework, and needs type information about inputs to 
# the model.
input_spec = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0]).element_spec

def tff_model_fn():
  keras_model = create_original_fedavg_cnn_model()
  return tff.learning.from_keras_model(
      keras_model=keras_model,
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

Addestramento del modello e output delle metriche di addestramento

Ora siamo pronti per costruire un algoritmo di media federata e addestrare il modello definito sul set di dati EMNIST.

In primo luogo abbiamo bisogno di costruire un algoritmo di calcolo della media federati utilizzando il tff.learning.build_federated_averaging_process API.

federated_averaging = tff.learning.build_federated_averaging_process(
    model_fn=tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

Ora eseguiamo l'algoritmo di media federata. L'esecuzione di un algoritmo di apprendimento federato dal punto di vista di TFF si presenta così:

  1. Inizializza l'algoritmo e ottieni lo stato iniziale del server. Lo stato del server contiene le informazioni necessarie per eseguire l'algoritmo. Ricordiamo, poiché TFF è funzionale, che questo stato include sia qualsiasi stato di ottimizzazione utilizzato dall'algoritmo (ad es. termini di momento) sia i parametri del modello stessi: questi verranno passati come argomenti e restituiti come risultati dei calcoli TFF.
  2. Eseguire l'algoritmo round per round. In ogni round, verrà restituito un nuovo stato del server come risultato di ogni client che addestra il modello sui suoi dati. Tipicamente in un turno:
    1. Il server trasmette il modello a tutti i client partecipanti.
    2. Ogni cliente esegue il lavoro in base al modello e ai propri dati.
    3. Il server aggrega tutto il modello per produrre uno stato del server che contiene un nuovo modello.

Per maggiori dettagli, si veda personalizzati federati Algoritmi, Parte 2: Implementazione Federati della media tutorial.

Le metriche di addestramento vengono scritte nella directory Tensorboard per essere visualizzate dopo l'addestramento.

Carica funzioni di utilità

def train(federated_averaging_process, num_rounds, num_clients_per_round, summary_writer):
  """Trains the federated averaging process and output metrics."""
  # Create a environment to get communication cost.
  environment = set_sizing_environment()

  # Initialize the Federated Averaging algorithm to get the initial server state.
  state = federated_averaging_process.initialize()

  with summary_writer.as_default():
    for round_num in range(num_rounds):
      # Sample the clients parcitipated in this round.
      sampled_clients = np.random.choice(
          emnist_train.client_ids,
          size=num_clients_per_round,
          replace=False)
      # Create a list of `tf.Dataset` instances from the data of sampled clients.
      sampled_train_data = [
          emnist_train.create_tf_dataset_for_client(client)
          for client in sampled_clients
      ]
      # Round one round of the algorithm based on the server state and client data
      # and output the new state and metrics.
      state, metrics = federated_averaging_process.next(state, sampled_train_data)

      # For more about size_info, please see https://www.tensorflow.org/federated/api_docs/python/tff/framework/SizeInfo
      size_info = environment.get_size_info()
      broadcasted_bits = size_info.broadcast_bits[-1]
      aggregated_bits = size_info.aggregate_bits[-1]

      print('round {:2d}, metrics={}, broadcasted_bits={}, aggregated_bits={}'.format(round_num, metrics, format_size(broadcasted_bits), format_size(aggregated_bits)))

      # Add metrics to Tensorboard.
      for name, value in metrics['train'].items():
          tf.summary.scalar(name, value, step=round_num)

      # Add broadcasted and aggregated data size to Tensorboard.
      tf.summary.scalar('cumulative_broadcasted_bits', broadcasted_bits, step=round_num)
      tf.summary.scalar('cumulative_aggregated_bits', aggregated_bits, step=round_num)
      summary_writer.flush()
# Clean the log directory to avoid conflicts.
try:
  tf.io.gfile.rmtree('/tmp/logs/scalars')
except tf.errors.OpError as e:
  pass  # Path doesn't exist

# Set up the log directory and writer for Tensorboard.
logdir = "/tmp/logs/scalars/original/"
summary_writer = tf.summary.create_file_writer(logdir)

train(federated_averaging_process=federated_averaging, num_rounds=10,
      num_clients_per_round=10, summary_writer=summary_writer)
round  0, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.07383774), ('loss', 2.3276227)])), ('stat', OrderedDict([('num_examples', 1097)]))]), broadcasted_bits=507.62Mibit, aggregated_bits=507.62Mibit
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.099585064), ('loss', 2.3152695)])), ('stat', OrderedDict([('num_examples', 964)]))]), broadcasted_bits=1015.24Mibit, aggregated_bits=1015.24Mibit
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09760766), ('loss', 2.3077576)])), ('stat', OrderedDict([('num_examples', 1045)]))]), broadcasted_bits=1.49Gibit, aggregated_bits=1.49Gibit
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.0963035), ('loss', 2.3066626)])), ('stat', OrderedDict([('num_examples', 1028)]))]), broadcasted_bits=1.98Gibit, aggregated_bits=1.98Gibit
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.10694184), ('loss', 2.3033001)])), ('stat', OrderedDict([('num_examples', 1066)]))]), broadcasted_bits=2.48Gibit, aggregated_bits=2.48Gibit
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.1185567), ('loss', 2.2999184)])), ('stat', OrderedDict([('num_examples', 970)]))]), broadcasted_bits=2.97Gibit, aggregated_bits=2.97Gibit
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.11751663), ('loss', 2.296883)])), ('stat', OrderedDict([('num_examples', 902)]))]), broadcasted_bits=3.47Gibit, aggregated_bits=3.47Gibit
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13063477), ('loss', 2.2990246)])), ('stat', OrderedDict([('num_examples', 1087)]))]), broadcasted_bits=3.97Gibit, aggregated_bits=3.97Gibit
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.12742382), ('loss', 2.2971866)])), ('stat', OrderedDict([('num_examples', 1083)]))]), broadcasted_bits=4.46Gibit, aggregated_bits=4.46Gibit
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13555992), ('loss', 2.2934425)])), ('stat', OrderedDict([('num_examples', 1018)]))]), broadcasted_bits=4.96Gibit, aggregated_bits=4.96Gibit

Avvia TensorBoard con la directory del registro principale specificata sopra per visualizzare le metriche di addestramento. Il caricamento dei dati può richiedere alcuni secondi. Fatta eccezione per Perdita e Precisione, forniamo anche la quantità di dati trasmessi e aggregati. I dati trasmessi si riferiscono ai tensori che il server invia a ciascun client mentre i dati aggregati si riferiscono ai tensori che ogni client restituisce al server.

%tensorboard --logdir /tmp/logs/scalars/ --port=0
Launching TensorBoard...
Reusing TensorBoard on port 34445 (pid 579503), started 1:53:14 ago. (Use '!kill 579503' to kill it.)
<IPython.core.display.Javascript at 0x7f9135ef1630>

Crea una trasmissione personalizzata e una funzione di aggregazione

Ora cerchiamo di implementare la funzione di utilizzare algoritmi di compressione con perdita di dati trasmessi e dati aggregati utilizzando la tensor_encoding API.

Innanzitutto definiamo due funzioni:

  • broadcast_encoder_fn che crea un'istanza di te.core.SimpleEncoder per tensori codificare o variabili nel server per la comunicazione client (dati Broadcast).
  • mean_encoder_fn che crea un'istanza di te.core.GatherEncoder per tensori codificare o variabili nel client al server communicaiton (dati di aggregazione).

È importante notare che non applichiamo un metodo di compressione all'intero modello in una volta. Invece, decidiamo come (e se) comprimere ciascuna variabile del modello in modo indipendente. Il motivo è che generalmente piccole variabili come i bias sono più sensibili all'imprecisione e, essendo relativamente piccole, anche i potenziali risparmi di comunicazione sono relativamente piccoli. Quindi non comprimiamo piccole variabili per impostazione predefinita. In questo esempio, applichiamo la quantizzazione uniforme a 8 bit (256 bucket) a ogni variabile con più di 10000 elementi e applichiamo l'identità solo alle altre variabili.

def broadcast_encoder_fn(value):
  """Function for building encoded broadcast."""
  spec = tf.TensorSpec(value.shape, value.dtype)
  if value.shape.num_elements() > 10000:
    return te.encoders.as_simple_encoder(
        te.encoders.uniform_quantization(bits=8), spec)
  else:
    return te.encoders.as_simple_encoder(te.encoders.identity(), spec)


def mean_encoder_fn(tensor_spec):
  """Function for building a GatherEncoder."""
  spec = tf.TensorSpec(tensor_spec.shape, tensor_spec.dtype)
  if tensor_spec.shape.num_elements() > 10000:
    return te.encoders.as_gather_encoder(
        te.encoders.uniform_quantization(bits=8), spec)
  else:
    return te.encoders.as_gather_encoder(te.encoders.identity(), spec)

TFF fornisce API per convertire la funzione dell'encoder in un formato che tff.learning.build_federated_averaging_process API può consumare. Utilizzando il tff.learning.framework.build_encoded_broadcast_from_model e tff.aggregators.MeanFactory , possiamo creare due oggetti che possono essere trasmessi in broadcast_process e model_update_aggregation_factory agruments di tff.learning.build_federated_averaging_process per creare federata averaging algoritmi con un algoritmo di compressione lossy.

encoded_broadcast_process = (
    tff.learning.framework.build_encoded_broadcast_process_from_model(
        tff_model_fn, broadcast_encoder_fn))

mean_factory = tff.aggregators.MeanFactory(
    tff.aggregators.EncodedSumFactory(mean_encoder_fn), # numerator
    tff.aggregators.EncodedSumFactory(mean_encoder_fn), # denominator
)

federated_averaging_with_compression = tff.learning.build_federated_averaging_process(
    tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    broadcast_process=encoded_broadcast_process,
    model_update_aggregation_factory=mean_factory)

Addestrare di nuovo il modello

Ora eseguiamo il nuovo algoritmo di media federata.

logdir_for_compression = "/tmp/logs/scalars/compression/"
summary_writer_for_compression = tf.summary.create_file_writer(
    logdir_for_compression)

train(federated_averaging_process=federated_averaging_with_compression, 
      num_rounds=10,
      num_clients_per_round=10,
      summary_writer=summary_writer_for_compression)
round  0, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.093), ('loss', 2.3194966)])), ('stat', OrderedDict([('num_examples', 1000)]))]), broadcasted_bits=146.46Mibit, aggregated_bits=146.46Mibit
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.10432034), ('loss', 2.3079953)])), ('stat', OrderedDict([('num_examples', 949)]))]), broadcasted_bits=292.92Mibit, aggregated_bits=292.93Mibit
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.07886754), ('loss', 2.3101337)])), ('stat', OrderedDict([('num_examples', 989)]))]), broadcasted_bits=439.38Mibit, aggregated_bits=439.39Mibit
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09774436), ('loss', 2.305069)])), ('stat', OrderedDict([('num_examples', 1064)]))]), broadcasted_bits=585.84Mibit, aggregated_bits=585.85Mibit
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09404097), ('loss', 2.302943)])), ('stat', OrderedDict([('num_examples', 1074)]))]), broadcasted_bits=732.30Mibit, aggregated_bits=732.32Mibit
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09), ('loss', 2.304385)])), ('stat', OrderedDict([('num_examples', 1000)]))]), broadcasted_bits=878.77Mibit, aggregated_bits=878.78Mibit
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.14368932), ('loss', 2.2973824)])), ('stat', OrderedDict([('num_examples', 1030)]))]), broadcasted_bits=1.00Gibit, aggregated_bits=1.00Gibit
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.12140871), ('loss', 2.2993405)])), ('stat', OrderedDict([('num_examples', 1079)]))]), broadcasted_bits=1.14Gibit, aggregated_bits=1.14Gibit
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13600783), ('loss', 2.2953267)])), ('stat', OrderedDict([('num_examples', 1022)]))]), broadcasted_bits=1.29Gibit, aggregated_bits=1.29Gibit
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13844621), ('loss', 2.295768)])), ('stat', OrderedDict([('num_examples', 1004)]))]), broadcasted_bits=1.43Gibit, aggregated_bits=1.43Gibit

Avvia nuovamente TensorBoard per confrontare le metriche di allenamento tra due esecuzioni.

Come si può vedere nel Tensorboard, v'è una riduzione significativa tra i orginial e compression curve nei broadcasted_bits e aggregated_bits trame, mentre in loss e sparse_categorical_accuracy trama le due curve sono abbastanza simile.

In conclusione, abbiamo implementato un algoritmo di compressione che può ottenere prestazioni simili all'algoritmo originale di media federata mentre il costo di comminuzione è significativamente ridotto.

%tensorboard --logdir /tmp/logs/scalars/ --port=0
Launching TensorBoard...
Reusing TensorBoard on port 34445 (pid 579503), started 1:54:12 ago. (Use '!kill 579503' to kill it.)
<IPython.core.display.Javascript at 0x7f9140eb5ef0>

Esercizi

Per implementare un algoritmo di compressione personalizzato e applicarlo al ciclo di addestramento, puoi:

  1. Implementare un nuovo algoritmo di compressione come una sottoclasse di EncodingStageInterface o sua variante più generale, AdaptiveEncodingStageInterface seguendo questo esempio .
  2. Costruisci il tuo nuovo Encoder e si specializzano per il modello di trasmissione o modello di aggiornamento della media .
  3. Utilizzare gli oggetti per costruire l'intero calcolo di formazione .

Le domande di ricerca aperte potenzialmente preziose includono: quantizzazione non uniforme, compressione senza perdita di dati come la codifica Huffman e meccanismi per adattare la compressione in base alle informazioni dei precedenti cicli di addestramento.

Materiali di lettura consigliati: