Federe Öğrenim Araştırması için TFF: Modelleme ve Güncelleme Sıkıştırma

TensorFlow.org'da görüntüleyin Google Colab'da çalıştırın Kaynağı GitHub'da görüntüleyin Not defterini indir

Bu eğitimde, kullandığımız EMNIST kullanarak Federe Averaging algoritmasında iletişim maliyetini azaltmak için kayıplı sıkıştırma algoritmaları etkinleştirme göstermek için veri kümesini tff.learning.build_federated_averaging_process API ve tensor_encoding API. Federe Averaging algoritması hakkında ayrıntılı bilgi için, kağıt bkz Yerinden Verilerden Derin Ağların Haberleşme-Verimli Öğrenme .

Başlamadan önce

Başlamadan önce, ortamınızın doğru şekilde kurulduğundan emin olmak için lütfen aşağıdakileri çalıştırın. Eğer bir selamlama görmüyorsanız, bakınız Kurulum talimatları için rehber.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade tensorflow-model-optimization
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
%load_ext tensorboard

import functools

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

from tensorflow_model_optimization.python.core.internal import tensor_encoding as te

TFF'nin çalışıp çalışmadığını doğrulayın.

@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()
b'Hello, World!'

Giriş verilerinin hazırlanması

Bu bölümde TFF'de bulunan EMNIST veri setini yükleyip ön işleme alıyoruz. Kontrol edin Görüntü Sınıflandırma için Federe Öğrenme EMNIST veri kümesi hakkında daha fazla ayrıntı için öğretici.

# This value only applies to EMNIST dataset, consider choosing appropriate
# values if switching to other datasets.
MAX_CLIENT_DATASET_SIZE = 418

CLIENT_EPOCHS_PER_ROUND = 1
CLIENT_BATCH_SIZE = 20
TEST_BATCH_SIZE = 500

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
    only_digits=True)

def reshape_emnist_element(element):
  return (tf.expand_dims(element['pixels'], axis=-1), element['label'])

def preprocess_train_dataset(dataset):
  """Preprocessing function for the EMNIST training dataset."""
  return (dataset
          # Shuffle according to the largest client dataset
          .shuffle(buffer_size=MAX_CLIENT_DATASET_SIZE)
          # Repeat to do multiple local epochs
          .repeat(CLIENT_EPOCHS_PER_ROUND)
          # Batch to a fixed client batch size
          .batch(CLIENT_BATCH_SIZE, drop_remainder=False)
          # Preprocessing step
          .map(reshape_emnist_element))

emnist_train = emnist_train.preprocess(preprocess_train_dataset)

Model tanımlama

Burada orginial FedAvg CNN dayalı bir keras modeli tanımlamak ve sonra bir örneği de keras modeli sarmak tff.learning.Model o TFF tarafından tüketilebilir şekilde yerleştirin.

Doğrudan basit bir model yerine bir model üreten bir işlev gerektiğini unutmayın. Buna ek olarak, sadece önceden inşa modeli yakalamak olamaz fonksiyonu, buna denir bu bağlamda modeli oluşturmak gerekir. Bunun nedeni, TFF'nin cihazlara gitmek üzere tasarlanmış olması ve yakalanıp paketlenebilmeleri için kaynakların ne zaman oluşturulduğu üzerinde kontrole ihtiyaç duymasıdır.

def create_original_fedavg_cnn_model(only_digits=True):
  """The CNN model used in https://arxiv.org/abs/1602.05629."""
  data_format = 'channels_last'

  max_pool = functools.partial(
      tf.keras.layers.MaxPooling2D,
      pool_size=(2, 2),
      padding='same',
      data_format=data_format)
  conv2d = functools.partial(
      tf.keras.layers.Conv2D,
      kernel_size=5,
      padding='same',
      data_format=data_format,
      activation=tf.nn.relu)

  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
      conv2d(filters=32),
      max_pool(),
      conv2d(filters=64),
      max_pool(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation=tf.nn.relu),
      tf.keras.layers.Dense(10 if only_digits else 62),
      tf.keras.layers.Softmax(),
  ])

  return model

# Gets the type information of the input data. TFF is a strongly typed
# functional programming framework, and needs type information about inputs to 
# the model.
input_spec = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0]).element_spec

def tff_model_fn():
  keras_model = create_original_fedavg_cnn_model()
  return tff.learning.from_keras_model(
      keras_model=keras_model,
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

Modeli eğitme ve eğitim metriklerinin çıktısını alma

Artık bir Birleşik Ortalama Algoritması oluşturmaya ve tanımlanan modeli EMNIST veri kümesi üzerinde eğitmeye hazırız.

Önce kullanarak Federe Averaging algoritması oluşturmak için gereken tff.learning.build_federated_averaging_process API.

federated_averaging = tff.learning.build_federated_averaging_process(
    model_fn=tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

Şimdi Federated Averaging algoritmasını çalıştıralım. Bir Federated Learning algoritmasının TFF perspektifinden yürütülmesi şuna benzer:

  1. Algoritmayı başlatın ve ilk sunucu durumunu alın. Sunucu durumu, algoritmayı gerçekleştirmek için gerekli bilgileri içerir. TFF işlevsel olduğundan, bu durumun hem algoritmanın kullandığı herhangi bir optimize edici durumunu (örneğin momentum terimleri) hem de model parametrelerinin kendisini içerdiğini hatırlayın - bunlar argüman olarak geçirilecek ve TFF hesaplamalarının sonuçları olarak döndürülecektir.
  2. Algoritmayı teker teker yürütün. Her turda, her müşterinin modeli kendi verileri üzerinde eğitmesinin sonucu olarak yeni bir sunucu durumu döndürülecektir. Tipik olarak bir turda:
    1. Sunucu, modeli katılan tüm istemcilere yayınlar.
    2. Her müşteri, modele ve kendi verilerine dayalı olarak iş gerçekleştirir.
    3. Sunucu, yeni bir model içeren bir sunucu durumu oluşturmak için tüm modeli toplar.

Daha fazla ayrıntı için bakınız Uygulama Federe Averaging: Özel Federe Algoritmalar, Bölüm 2 öğretici.

Eğitim metrikleri, eğitimden sonra görüntülenmek üzere Tensorboard dizinine yazılır.

Yardımcı fonksiyonları yükle

def train(federated_averaging_process, num_rounds, num_clients_per_round, summary_writer):
  """Trains the federated averaging process and output metrics."""
  # Create a environment to get communication cost.
  environment = set_sizing_environment()

  # Initialize the Federated Averaging algorithm to get the initial server state.
  state = federated_averaging_process.initialize()

  with summary_writer.as_default():
    for round_num in range(num_rounds):
      # Sample the clients parcitipated in this round.
      sampled_clients = np.random.choice(
          emnist_train.client_ids,
          size=num_clients_per_round,
          replace=False)
      # Create a list of `tf.Dataset` instances from the data of sampled clients.
      sampled_train_data = [
          emnist_train.create_tf_dataset_for_client(client)
          for client in sampled_clients
      ]
      # Round one round of the algorithm based on the server state and client data
      # and output the new state and metrics.
      state, metrics = federated_averaging_process.next(state, sampled_train_data)

      # For more about size_info, please see https://www.tensorflow.org/federated/api_docs/python/tff/framework/SizeInfo
      size_info = environment.get_size_info()
      broadcasted_bits = size_info.broadcast_bits[-1]
      aggregated_bits = size_info.aggregate_bits[-1]

      print('round {:2d}, metrics={}, broadcasted_bits={}, aggregated_bits={}'.format(round_num, metrics, format_size(broadcasted_bits), format_size(aggregated_bits)))

      # Add metrics to Tensorboard.
      for name, value in metrics['train'].items():
          tf.summary.scalar(name, value, step=round_num)

      # Add broadcasted and aggregated data size to Tensorboard.
      tf.summary.scalar('cumulative_broadcasted_bits', broadcasted_bits, step=round_num)
      tf.summary.scalar('cumulative_aggregated_bits', aggregated_bits, step=round_num)
      summary_writer.flush()
# Clean the log directory to avoid conflicts.
try:
  tf.io.gfile.rmtree('/tmp/logs/scalars')
except tf.errors.OpError as e:
  pass  # Path doesn't exist

# Set up the log directory and writer for Tensorboard.
logdir = "/tmp/logs/scalars/original/"
summary_writer = tf.summary.create_file_writer(logdir)

train(federated_averaging_process=federated_averaging, num_rounds=10,
      num_clients_per_round=10, summary_writer=summary_writer)
round  0, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.07383774), ('loss', 2.3276227)])), ('stat', OrderedDict([('num_examples', 1097)]))]), broadcasted_bits=507.62Mibit, aggregated_bits=507.62Mibit
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.099585064), ('loss', 2.3152695)])), ('stat', OrderedDict([('num_examples', 964)]))]), broadcasted_bits=1015.24Mibit, aggregated_bits=1015.24Mibit
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09760766), ('loss', 2.3077576)])), ('stat', OrderedDict([('num_examples', 1045)]))]), broadcasted_bits=1.49Gibit, aggregated_bits=1.49Gibit
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.0963035), ('loss', 2.3066626)])), ('stat', OrderedDict([('num_examples', 1028)]))]), broadcasted_bits=1.98Gibit, aggregated_bits=1.98Gibit
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.10694184), ('loss', 2.3033001)])), ('stat', OrderedDict([('num_examples', 1066)]))]), broadcasted_bits=2.48Gibit, aggregated_bits=2.48Gibit
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.1185567), ('loss', 2.2999184)])), ('stat', OrderedDict([('num_examples', 970)]))]), broadcasted_bits=2.97Gibit, aggregated_bits=2.97Gibit
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.11751663), ('loss', 2.296883)])), ('stat', OrderedDict([('num_examples', 902)]))]), broadcasted_bits=3.47Gibit, aggregated_bits=3.47Gibit
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13063477), ('loss', 2.2990246)])), ('stat', OrderedDict([('num_examples', 1087)]))]), broadcasted_bits=3.97Gibit, aggregated_bits=3.97Gibit
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.12742382), ('loss', 2.2971866)])), ('stat', OrderedDict([('num_examples', 1083)]))]), broadcasted_bits=4.46Gibit, aggregated_bits=4.46Gibit
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13555992), ('loss', 2.2934425)])), ('stat', OrderedDict([('num_examples', 1018)]))]), broadcasted_bits=4.96Gibit, aggregated_bits=4.96Gibit

Eğitim ölçümlerini görüntülemek için TensorBoard'u yukarıda belirtilen kök günlük dizini ile başlatın. Verilerin yüklenmesi birkaç saniye sürebilir. Kayıp ve Doğruluk dışında, yayınlanan ve toplu veri miktarını da çıkarıyoruz. Yayınlanan veriler, sunucunun her bir istemciye ittiği tensörleri ifade ederken, toplu veriler, her müşterinin sunucuya döndürdüğü tensörleri ifade eder.

%tensorboard --logdir /tmp/logs/scalars/ --port=0
Launching TensorBoard...
Reusing TensorBoard on port 34445 (pid 579503), started 1:53:14 ago. (Use '!kill 579503' to kill it.)
<IPython.core.display.Javascript at 0x7f9135ef1630>

Özel bir yayın ve toplama işlevi oluşturun

Şimdi kullanılarak yayınlanan verilere ve toplanan veriler üzerinde kayıplı sıkıştırma algoritmaları kullanmak işlevi uygulamak izin tensor_encoding API.

İlk olarak, iki fonksiyon tanımlıyoruz:

  • broadcast_encoder_fn bir örneğini oluşturur te.core.SimpleEncoder müşteri iletişimi (Yayın veriler) sunucu şifreleyen tensörlerle veya değişkenlerin.
  • mean_encoder_fn bir örneğini oluşturur te.core.GatherEncoder sunucu Haberleşme (Toplama veriler) müşteri kodlamak tensörlerle veya değişkenlerin.

Tüm modele aynı anda bir sıkıştırma yöntemi uygulamadığımızı belirtmek önemlidir. Bunun yerine, modelin her bir değişkenini bağımsız olarak nasıl sıkıştıracağımıza (ve sıkıştırıp sıkıştırmayacağımıza) karar veririz. Bunun nedeni, genellikle, önyargılar gibi küçük değişkenlerin yanlışlığa daha duyarlı olması ve nispeten küçük olmaları nedeniyle potansiyel iletişim tasarruflarının da nispeten küçük olmasıdır. Bu nedenle, varsayılan olarak küçük değişkenleri sıkıştırmıyoruz. Bu örnekte, 10000'den fazla öğeye sahip her değişkene 8 bite (256 kova) tek biçimli niceleme uyguluyoruz ve yalnızca diğer değişkenlere özdeşlik uyguluyoruz.

def broadcast_encoder_fn(value):
  """Function for building encoded broadcast."""
  spec = tf.TensorSpec(value.shape, value.dtype)
  if value.shape.num_elements() > 10000:
    return te.encoders.as_simple_encoder(
        te.encoders.uniform_quantization(bits=8), spec)
  else:
    return te.encoders.as_simple_encoder(te.encoders.identity(), spec)


def mean_encoder_fn(tensor_spec):
  """Function for building a GatherEncoder."""
  spec = tf.TensorSpec(tensor_spec.shape, tensor_spec.dtype)
  if tensor_spec.shape.num_elements() > 10000:
    return te.encoders.as_gather_encoder(
        te.encoders.uniform_quantization(bits=8), spec)
  else:
    return te.encoders.as_gather_encoder(te.encoders.identity(), spec)

TFF bir biçime kodlayıcı fonksiyonunu dönüştürmek için API'ler sağlar tff.learning.build_federated_averaging_process API tüketebilir. Kullanarak tff.learning.framework.build_encoded_broadcast_from_model ve tff.aggregators.MeanFactory , biz içine geçirilebilir iki nesneleri oluşturabilir broadcast_process ve model_update_aggregation_factory ait agruments tff.learning.build_federated_averaging_process kayıplı sıkıştırma algoritması ile bir Federe Averaging algoritmalar oluşturmak için.

encoded_broadcast_process = (
    tff.learning.framework.build_encoded_broadcast_process_from_model(
        tff_model_fn, broadcast_encoder_fn))

mean_factory = tff.aggregators.MeanFactory(
    tff.aggregators.EncodedSumFactory(mean_encoder_fn), # numerator
    tff.aggregators.EncodedSumFactory(mean_encoder_fn), # denominator
)

federated_averaging_with_compression = tff.learning.build_federated_averaging_process(
    tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    broadcast_process=encoded_broadcast_process,
    model_update_aggregation_factory=mean_factory)

Modeli tekrar eğitmek

Şimdi yeni Birleşik Ortalama Algoritmasını çalıştıralım.

logdir_for_compression = "/tmp/logs/scalars/compression/"
summary_writer_for_compression = tf.summary.create_file_writer(
    logdir_for_compression)

train(federated_averaging_process=federated_averaging_with_compression, 
      num_rounds=10,
      num_clients_per_round=10,
      summary_writer=summary_writer_for_compression)
round  0, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.093), ('loss', 2.3194966)])), ('stat', OrderedDict([('num_examples', 1000)]))]), broadcasted_bits=146.46Mibit, aggregated_bits=146.46Mibit
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.10432034), ('loss', 2.3079953)])), ('stat', OrderedDict([('num_examples', 949)]))]), broadcasted_bits=292.92Mibit, aggregated_bits=292.93Mibit
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.07886754), ('loss', 2.3101337)])), ('stat', OrderedDict([('num_examples', 989)]))]), broadcasted_bits=439.38Mibit, aggregated_bits=439.39Mibit
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09774436), ('loss', 2.305069)])), ('stat', OrderedDict([('num_examples', 1064)]))]), broadcasted_bits=585.84Mibit, aggregated_bits=585.85Mibit
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09404097), ('loss', 2.302943)])), ('stat', OrderedDict([('num_examples', 1074)]))]), broadcasted_bits=732.30Mibit, aggregated_bits=732.32Mibit
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09), ('loss', 2.304385)])), ('stat', OrderedDict([('num_examples', 1000)]))]), broadcasted_bits=878.77Mibit, aggregated_bits=878.78Mibit
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.14368932), ('loss', 2.2973824)])), ('stat', OrderedDict([('num_examples', 1030)]))]), broadcasted_bits=1.00Gibit, aggregated_bits=1.00Gibit
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.12140871), ('loss', 2.2993405)])), ('stat', OrderedDict([('num_examples', 1079)]))]), broadcasted_bits=1.14Gibit, aggregated_bits=1.14Gibit
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13600783), ('loss', 2.2953267)])), ('stat', OrderedDict([('num_examples', 1022)]))]), broadcasted_bits=1.29Gibit, aggregated_bits=1.29Gibit
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13844621), ('loss', 2.295768)])), ('stat', OrderedDict([('num_examples', 1004)]))]), broadcasted_bits=1.43Gibit, aggregated_bits=1.43Gibit

İki çalıştırma arasındaki eğitim ölçümlerini karşılaştırmak için TensorBoard'u yeniden başlatın.

Eğer Tensorboard görebileceğiniz gibi, aralarında anlamlı bir azalma olduğu orginial ve compression eğrilerin broadcasted_bits ve aggregated_bits ise araziler loss ve sparse_categorical_accuracy iki eğrileri oldukça benzemektedir arsa.

Sonuç olarak, iletişim maliyeti önemli ölçüde azalırken, orijinal Birleşik Ortalama Algoritması ile benzer performans elde edebilen bir sıkıştırma algoritması uyguladık.

%tensorboard --logdir /tmp/logs/scalars/ --port=0
Launching TensorBoard...
Reusing TensorBoard on port 34445 (pid 579503), started 1:54:12 ago. (Use '!kill 579503' to kill it.)
<IPython.core.display.Javascript at 0x7f9140eb5ef0>

Egzersizler

Özel bir sıkıştırma algoritması uygulamak ve bunu eğitim döngüsüne uygulamak için şunları yapabilirsiniz:

  1. Bir alt sınıfı olarak yeni bir sıkıştırma algoritması uygulanması EncodingStageInterface veya daha genel varyantı AdaptiveEncodingStageInterface aşağıdaki örnekte .
  2. Yeni Construct Encoder ve bunu uzmanlaşmak modeli yayın veya yenilenen modeli ortalamadan .
  3. Tüm inşa etmek bu nesneleri kullanın eğitim hesaplama .

Potansiyel olarak değerli açık araştırma soruları şunları içerir: tekdüze olmayan niceleme, huffman kodlaması gibi kayıpsız sıkıştırma ve önceki eğitim turlarından elde edilen bilgilere dayalı olarak sıkıştırmayı uyarlamak için mekanizmalar.

Önerilen okuma materyalleri: