TFF pour la recherche sur l'apprentissage fédéré : compression de modèles et de mises à jour

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Télécharger le cahier

Dans ce tutoriel, nous utilisons le EMNIST ensemble de données pour démontrer comment activer les algorithmes de compression avec perte pour réduire les coûts de communication dans l'algorithme de calcul de la moyenne fédérée utilisant l' tff.learning.build_federated_averaging_process API et l' tensor_encoding API. Pour plus de détails sur l'algorithme de calcul de la moyenne fédérée, consultez le livre d' apprentissage de la communication-efficace des réseaux profonds de données Décentralisée .

Avant de commencer

Avant de commencer, veuillez exécuter ce qui suit pour vous assurer que votre environnement est correctement configuré. Si vous ne voyez pas un message d' accueil, s'il vous plaît se référer à l' installation guide pour les instructions.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade tensorflow-model-optimization
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
%load_ext tensorboard

import functools

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

from tensorflow_model_optimization.python.core.internal import tensor_encoding as te

Vérifiez si TFF fonctionne.

@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()
b'Hello, World!'

Préparation des données d'entrée

Dans cette section, nous chargeons et prétraitons l'ensemble de données EMNIST inclus dans TFF. S'il vous plaît vérifier l' apprentissage fédéré de classification d'images tutoriel pour plus de détails sur EMNIST ensemble de données.

# This value only applies to EMNIST dataset, consider choosing appropriate
# values if switching to other datasets.
MAX_CLIENT_DATASET_SIZE = 418

CLIENT_EPOCHS_PER_ROUND = 1
CLIENT_BATCH_SIZE = 20
TEST_BATCH_SIZE = 500

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
    only_digits=True)

def reshape_emnist_element(element):
  return (tf.expand_dims(element['pixels'], axis=-1), element['label'])

def preprocess_train_dataset(dataset):
  """Preprocessing function for the EMNIST training dataset."""
  return (dataset
          # Shuffle according to the largest client dataset
          .shuffle(buffer_size=MAX_CLIENT_DATASET_SIZE)
          # Repeat to do multiple local epochs
          .repeat(CLIENT_EPOCHS_PER_ROUND)
          # Batch to a fixed client batch size
          .batch(CLIENT_BATCH_SIZE, drop_remainder=False)
          # Preprocessing step
          .map(reshape_emnist_element))

emnist_train = emnist_train.preprocess(preprocess_train_dataset)

Définir un modèle

Ici , nous définissons un modèle de keras basé sur les orginales FedAvg CNN, puis enveloppez le modèle keras dans une instance de tff.learning.Model afin qu'il puisse être consommé par TFF.

Notez que nous aurons besoin d' une fonction qui produit un modèle au lieu de simplement un modèle directement. En outre, la fonction ne peut pas simplement capturer un modèle construit avant, il doit créer le modèle dans le contexte qu'il est appelé. La raison en est que TFF est conçu pour aller aux appareils et doit contrôler le moment où les ressources sont construites afin qu'elles puissent être capturées et empaquetées.

def create_original_fedavg_cnn_model(only_digits=True):
  """The CNN model used in https://arxiv.org/abs/1602.05629."""
  data_format = 'channels_last'

  max_pool = functools.partial(
      tf.keras.layers.MaxPooling2D,
      pool_size=(2, 2),
      padding='same',
      data_format=data_format)
  conv2d = functools.partial(
      tf.keras.layers.Conv2D,
      kernel_size=5,
      padding='same',
      data_format=data_format,
      activation=tf.nn.relu)

  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
      conv2d(filters=32),
      max_pool(),
      conv2d(filters=64),
      max_pool(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation=tf.nn.relu),
      tf.keras.layers.Dense(10 if only_digits else 62),
      tf.keras.layers.Softmax(),
  ])

  return model

# Gets the type information of the input data. TFF is a strongly typed
# functional programming framework, and needs type information about inputs to 
# the model.
input_spec = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0]).element_spec

def tff_model_fn():
  keras_model = create_original_fedavg_cnn_model()
  return tff.learning.from_keras_model(
      keras_model=keras_model,
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

Entraînement du modèle et sortie des métriques d'entraînement

Nous sommes maintenant prêts à construire un algorithme de moyenne fédérée et à entraîner le modèle défini sur l'ensemble de données EMNIST.

D' abord , nous devons construire un algorithme de calcul de la moyenne fédérée utilisant l' tff.learning.build_federated_averaging_process API.

federated_averaging = tff.learning.build_federated_averaging_process(
    model_fn=tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

Exécutons maintenant l'algorithme de moyenne fédérée. L'exécution d'un algorithme d'apprentissage fédéré du point de vue de TFF ressemble à ceci :

  1. Initialisez l'algorithme et obtenez l'état initial du serveur. L'état du serveur contient les informations nécessaires pour exécuter l'algorithme. Rappelez-vous, puisque TFF est fonctionnel, que cet état inclut à la fois tout état d'optimiseur utilisé par l'algorithme (par exemple les termes de quantité de mouvement) ainsi que les paramètres du modèle eux-mêmes - ceux-ci seront passés en tant qu'arguments et renvoyés en tant que résultats des calculs TFF.
  2. Exécutez l'algorithme tour par tour. À chaque tour, un nouvel état de serveur sera renvoyé à la suite de l'entraînement de chaque client au modèle sur ses données. Typiquement en un tour :
    1. Le serveur a diffusé le modèle à tous les clients participants.
    2. Chaque client effectue un travail basé sur le modèle et ses propres données.
    3. Le serveur agrège tout le modèle pour produire un état de serveur qui contient un nouveau modèle.

Pour plus de détails, voir s'il vous plaît personnalisée fédérée Algorithmes, Partie 2: Mise en œuvre fédérée calcul de la moyenne tutoriel.

Les métriques d'entraînement sont écrites dans le répertoire Tensorboard pour être affichées après l'entraînement.

Charger les fonctions utilitaires

def train(federated_averaging_process, num_rounds, num_clients_per_round, summary_writer):
  """Trains the federated averaging process and output metrics."""
  # Create a environment to get communication cost.
  environment = set_sizing_environment()

  # Initialize the Federated Averaging algorithm to get the initial server state.
  state = federated_averaging_process.initialize()

  with summary_writer.as_default():
    for round_num in range(num_rounds):
      # Sample the clients parcitipated in this round.
      sampled_clients = np.random.choice(
          emnist_train.client_ids,
          size=num_clients_per_round,
          replace=False)
      # Create a list of `tf.Dataset` instances from the data of sampled clients.
      sampled_train_data = [
          emnist_train.create_tf_dataset_for_client(client)
          for client in sampled_clients
      ]
      # Round one round of the algorithm based on the server state and client data
      # and output the new state and metrics.
      state, metrics = federated_averaging_process.next(state, sampled_train_data)

      # For more about size_info, please see https://www.tensorflow.org/federated/api_docs/python/tff/framework/SizeInfo
      size_info = environment.get_size_info()
      broadcasted_bits = size_info.broadcast_bits[-1]
      aggregated_bits = size_info.aggregate_bits[-1]

      print('round {:2d}, metrics={}, broadcasted_bits={}, aggregated_bits={}'.format(round_num, metrics, format_size(broadcasted_bits), format_size(aggregated_bits)))

      # Add metrics to Tensorboard.
      for name, value in metrics['train'].items():
          tf.summary.scalar(name, value, step=round_num)

      # Add broadcasted and aggregated data size to Tensorboard.
      tf.summary.scalar('cumulative_broadcasted_bits', broadcasted_bits, step=round_num)
      tf.summary.scalar('cumulative_aggregated_bits', aggregated_bits, step=round_num)
      summary_writer.flush()
# Clean the log directory to avoid conflicts.
try:
  tf.io.gfile.rmtree('/tmp/logs/scalars')
except tf.errors.OpError as e:
  pass  # Path doesn't exist

# Set up the log directory and writer for Tensorboard.
logdir = "/tmp/logs/scalars/original/"
summary_writer = tf.summary.create_file_writer(logdir)

train(federated_averaging_process=federated_averaging, num_rounds=10,
      num_clients_per_round=10, summary_writer=summary_writer)
round  0, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.07383774), ('loss', 2.3276227)])), ('stat', OrderedDict([('num_examples', 1097)]))]), broadcasted_bits=507.62Mibit, aggregated_bits=507.62Mibit
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.099585064), ('loss', 2.3152695)])), ('stat', OrderedDict([('num_examples', 964)]))]), broadcasted_bits=1015.24Mibit, aggregated_bits=1015.24Mibit
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09760766), ('loss', 2.3077576)])), ('stat', OrderedDict([('num_examples', 1045)]))]), broadcasted_bits=1.49Gibit, aggregated_bits=1.49Gibit
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.0963035), ('loss', 2.3066626)])), ('stat', OrderedDict([('num_examples', 1028)]))]), broadcasted_bits=1.98Gibit, aggregated_bits=1.98Gibit
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.10694184), ('loss', 2.3033001)])), ('stat', OrderedDict([('num_examples', 1066)]))]), broadcasted_bits=2.48Gibit, aggregated_bits=2.48Gibit
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.1185567), ('loss', 2.2999184)])), ('stat', OrderedDict([('num_examples', 970)]))]), broadcasted_bits=2.97Gibit, aggregated_bits=2.97Gibit
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.11751663), ('loss', 2.296883)])), ('stat', OrderedDict([('num_examples', 902)]))]), broadcasted_bits=3.47Gibit, aggregated_bits=3.47Gibit
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13063477), ('loss', 2.2990246)])), ('stat', OrderedDict([('num_examples', 1087)]))]), broadcasted_bits=3.97Gibit, aggregated_bits=3.97Gibit
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.12742382), ('loss', 2.2971866)])), ('stat', OrderedDict([('num_examples', 1083)]))]), broadcasted_bits=4.46Gibit, aggregated_bits=4.46Gibit
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13555992), ('loss', 2.2934425)])), ('stat', OrderedDict([('num_examples', 1018)]))]), broadcasted_bits=4.96Gibit, aggregated_bits=4.96Gibit

Démarrez TensorBoard avec le répertoire de journal racine spécifié ci-dessus pour afficher les métriques d'entraînement. Le chargement des données peut prendre quelques secondes. À l'exception de la perte et de la précision, nous produisons également la quantité de données diffusées et agrégées. Les données diffusées font référence aux tenseurs que le serveur envoie à chaque client, tandis que les données agrégées font référence aux tenseurs que chaque client renvoie au serveur.

%tensorboard --logdir /tmp/logs/scalars/ --port=0
Launching TensorBoard...
Reusing TensorBoard on port 34445 (pid 579503), started 1:53:14 ago. (Use '!kill 579503' to kill it.)
<IPython.core.display.Javascript at 0x7f9135ef1630>

Créer une fonction de diffusion et d'agrégation personnalisée

Maintenant , nous allons mettre en œuvre la fonction d'utiliser des algorithmes de compression avec perte sur les données diffusées et des données agrégées à l' aide de l' tensor_encoding API.

Tout d'abord, nous définissons deux fonctions :

  • broadcast_encoder_fn qui crée une instance de te.core.SimpleEncoder à tenseurs encodent ou variables dans le serveur à la communication client (données de diffusion).
  • mean_encoder_fn qui crée une instance de te.core.GatherEncoder de tenseurs de codage ou de variables dans le client vers le serveur communication vers (données d' agrégation).

Il est important de noter que nous n'appliquons pas une méthode de compression à l'ensemble du modèle à la fois. Au lieu de cela, nous décidons comment (et si) compresser chaque variable du modèle indépendamment. La raison en est qu'en général, les petites variables telles que les biais sont plus sensibles à l'inexactitude, et étant relativement petites, les économies potentielles de communication sont également relativement faibles. Par conséquent, nous ne compressons pas les petites variables par défaut. Dans cet exemple, nous appliquons une quantification uniforme sur 8 bits (256 buckets) à chaque variable de plus de 10 000 éléments et n'appliquons l'identité qu'aux autres variables.

def broadcast_encoder_fn(value):
  """Function for building encoded broadcast."""
  spec = tf.TensorSpec(value.shape, value.dtype)
  if value.shape.num_elements() > 10000:
    return te.encoders.as_simple_encoder(
        te.encoders.uniform_quantization(bits=8), spec)
  else:
    return te.encoders.as_simple_encoder(te.encoders.identity(), spec)


def mean_encoder_fn(tensor_spec):
  """Function for building a GatherEncoder."""
  spec = tf.TensorSpec(tensor_spec.shape, tensor_spec.dtype)
  if tensor_spec.shape.num_elements() > 10000:
    return te.encoders.as_gather_encoder(
        te.encoders.uniform_quantization(bits=8), spec)
  else:
    return te.encoders.as_gather_encoder(te.encoders.identity(), spec)

TFF fournit des API pour convertir la fonction du codeur dans un format tff.learning.build_federated_averaging_process API peut consommer. En utilisant le tff.learning.framework.build_encoded_broadcast_from_model et tff.aggregators.MeanFactory , nous pouvons créer deux objets qui peuvent être transmis dans broadcast_process et model_update_aggregation_factory agruments de tff.learning.build_federated_averaging_process pour créer des algorithmes avec calcul de la moyenne Federated un algorithme de compression avec perte.

encoded_broadcast_process = (
    tff.learning.framework.build_encoded_broadcast_process_from_model(
        tff_model_fn, broadcast_encoder_fn))

mean_factory = tff.aggregators.MeanFactory(
    tff.aggregators.EncodedSumFactory(mean_encoder_fn), # numerator
    tff.aggregators.EncodedSumFactory(mean_encoder_fn), # denominator
)

federated_averaging_with_compression = tff.learning.build_federated_averaging_process(
    tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    broadcast_process=encoded_broadcast_process,
    model_update_aggregation_factory=mean_factory)

Entraîner à nouveau le modèle

Exécutons maintenant le nouvel algorithme de moyenne fédérée.

logdir_for_compression = "/tmp/logs/scalars/compression/"
summary_writer_for_compression = tf.summary.create_file_writer(
    logdir_for_compression)

train(federated_averaging_process=federated_averaging_with_compression, 
      num_rounds=10,
      num_clients_per_round=10,
      summary_writer=summary_writer_for_compression)
round  0, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.093), ('loss', 2.3194966)])), ('stat', OrderedDict([('num_examples', 1000)]))]), broadcasted_bits=146.46Mibit, aggregated_bits=146.46Mibit
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.10432034), ('loss', 2.3079953)])), ('stat', OrderedDict([('num_examples', 949)]))]), broadcasted_bits=292.92Mibit, aggregated_bits=292.93Mibit
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.07886754), ('loss', 2.3101337)])), ('stat', OrderedDict([('num_examples', 989)]))]), broadcasted_bits=439.38Mibit, aggregated_bits=439.39Mibit
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09774436), ('loss', 2.305069)])), ('stat', OrderedDict([('num_examples', 1064)]))]), broadcasted_bits=585.84Mibit, aggregated_bits=585.85Mibit
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09404097), ('loss', 2.302943)])), ('stat', OrderedDict([('num_examples', 1074)]))]), broadcasted_bits=732.30Mibit, aggregated_bits=732.32Mibit
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09), ('loss', 2.304385)])), ('stat', OrderedDict([('num_examples', 1000)]))]), broadcasted_bits=878.77Mibit, aggregated_bits=878.78Mibit
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.14368932), ('loss', 2.2973824)])), ('stat', OrderedDict([('num_examples', 1030)]))]), broadcasted_bits=1.00Gibit, aggregated_bits=1.00Gibit
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.12140871), ('loss', 2.2993405)])), ('stat', OrderedDict([('num_examples', 1079)]))]), broadcasted_bits=1.14Gibit, aggregated_bits=1.14Gibit
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13600783), ('loss', 2.2953267)])), ('stat', OrderedDict([('num_examples', 1022)]))]), broadcasted_bits=1.29Gibit, aggregated_bits=1.29Gibit
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13844621), ('loss', 2.295768)])), ('stat', OrderedDict([('num_examples', 1004)]))]), broadcasted_bits=1.43Gibit, aggregated_bits=1.43Gibit

Redémarrez TensorBoard pour comparer les métriques d'entraînement entre deux exécutions.

Comme vous pouvez le voir dans Tensorboard, il y a une réduction significative entre les orginial et compression des courbes dans les broadcasted_bits et aggregated_bits parcelles alors que dans la loss et sparse_categorical_accuracy intrigue les deux courbes sont assez similaire.

En conclusion, nous avons implémenté un algorithme de compression qui peut atteindre des performances similaires à celles de l'algorithme original de moyenne fédérée tandis que le coût de communication est considérablement réduit.

%tensorboard --logdir /tmp/logs/scalars/ --port=0
Launching TensorBoard...
Reusing TensorBoard on port 34445 (pid 579503), started 1:54:12 ago. (Use '!kill 579503' to kill it.)
<IPython.core.display.Javascript at 0x7f9140eb5ef0>

Des exercices

Pour implémenter un algorithme de compression personnalisé et l'appliquer à la boucle d'entraînement, vous pouvez :

  1. Mettre en œuvre un nouvel algorithme de compression comme une sous - classe de EncodingStageInterface ou sa variante plus générale, AdaptiveEncodingStageInterface suivant cet exemple .
  2. Construire votre nouveau Encoder et spécialisé pour la diffusion de modèle ou la mise à jour modèle moyenne .
  3. Utilisez ces objets pour construire l'ensemble de calcul de la formation .

Les questions de recherche ouvertes potentiellement intéressantes comprennent : la quantification non uniforme, la compression sans perte telle que le codage huffman et les mécanismes d'adaptation de la compression en fonction des informations des cycles de formation précédents.

Lectures recommandées :