TFF לחקר למידה מאוחדת: דגם דחיסה ועדכון

הצג באתר TensorFlow.org הפעל בגוגל קולאב צפה במקור ב-GitHub הורד מחברת

במדריך זה, אנו משתמשים EMNIST נתון כדי להדגים כיצד לאפשר אלגוריתמים לדחיסה כדי להפחית את עלות תקשורת באלגוריתם הממוצע Federated באמצעות tff.learning.build_federated_averaging_process API ואת tensor_encoding API. לפרטים נוספים על אלגוריתם ממוצעי Federated, רואה את נייר למידה יעילה-תקשורת של דיפ רשתות מתוך מבוזר נתונים .

לפני שאנחנו מתחילים

לפני שנתחיל, אנא הפעל את הפעולות הבאות כדי לוודא שהסביבה שלך מוגדרת כהלכה. אם אינך רואה ברכה, עיין התקנת המדריך לקבלת הוראות.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade tensorflow-model-optimization
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
%load_ext tensorboard

import functools

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

from tensorflow_model_optimization.python.core.internal import tensor_encoding as te

ודא אם TFF עובד.

@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()
b'Hello, World!'

הכנת נתוני הקלט

בסעיף זה אנו טוענים ומעבדים מראש את מערך הנתונים של EMNIST הכלול ב-TFF. אנא בדוק Federated למידה עבור סיווג תמונה הדרכה לקבלת פרטים נוספים על בסיס הנתונים EMNIST.

# This value only applies to EMNIST dataset, consider choosing appropriate
# values if switching to other datasets.
MAX_CLIENT_DATASET_SIZE = 418

CLIENT_EPOCHS_PER_ROUND = 1
CLIENT_BATCH_SIZE = 20
TEST_BATCH_SIZE = 500

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
    only_digits=True)

def reshape_emnist_element(element):
  return (tf.expand_dims(element['pixels'], axis=-1), element['label'])

def preprocess_train_dataset(dataset):
  """Preprocessing function for the EMNIST training dataset."""
  return (dataset
          # Shuffle according to the largest client dataset
          .shuffle(buffer_size=MAX_CLIENT_DATASET_SIZE)
          # Repeat to do multiple local epochs
          .repeat(CLIENT_EPOCHS_PER_ROUND)
          # Batch to a fixed client batch size
          .batch(CLIENT_BATCH_SIZE, drop_remainder=False)
          # Preprocessing step
          .map(reshape_emnist_element))

emnist_train = emnist_train.preprocess(preprocess_train_dataset)

הגדרת דגם

כאן אנו מגדירים מודל keras מבוסס על CNN FedAvg orginial, ולאחר מכן לעטוף את המודל keras ב מופע של tff.learning.Model כך שהוא יכול להיות נצרך על ידי TFF.

שים לב נצטרך פונקציה אשר מייצרת מודל במקום פשוט מודל ישירות. בנוסף, הפונקציה לא רק ללכוד מודל מובנה מראש, הוא חייב ליצור מודל בהקשר שזה נקרא. הסיבה היא ש-TFF נועד ללכת למכשירים, וצריך שליטה על מתי נבנים משאבים כך שניתן יהיה ללכוד אותם ולארוז אותם.

def create_original_fedavg_cnn_model(only_digits=True):
  """The CNN model used in https://arxiv.org/abs/1602.05629."""
  data_format = 'channels_last'

  max_pool = functools.partial(
      tf.keras.layers.MaxPooling2D,
      pool_size=(2, 2),
      padding='same',
      data_format=data_format)
  conv2d = functools.partial(
      tf.keras.layers.Conv2D,
      kernel_size=5,
      padding='same',
      data_format=data_format,
      activation=tf.nn.relu)

  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
      conv2d(filters=32),
      max_pool(),
      conv2d(filters=64),
      max_pool(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation=tf.nn.relu),
      tf.keras.layers.Dense(10 if only_digits else 62),
      tf.keras.layers.Softmax(),
  ])

  return model

# Gets the type information of the input data. TFF is a strongly typed
# functional programming framework, and needs type information about inputs to 
# the model.
input_spec = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0]).element_spec

def tff_model_fn():
  keras_model = create_original_fedavg_cnn_model()
  return tff.learning.from_keras_model(
      keras_model=keras_model,
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

הכשרת המודל ופלטת מדדי הדרכה

כעת אנו מוכנים לבנות אלגוריתם ממוצע פדרלי ולאמן את המודל המוגדר על מערך הנתונים של EMNIST.

ראשית עלינו לבנות אלגוריתם ממוצעים Federated באמצעות tff.learning.build_federated_averaging_process API.

federated_averaging = tff.learning.build_federated_averaging_process(
    model_fn=tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

עכשיו בואו נריץ את אלגוריתם הממוצע הפדרציה. הביצוע של אלגוריתם למידה פדרציה מנקודת המבט של TFF נראה כך:

  1. אתחל את האלגוריתם וקבל את מצב השרת הראשוני. מצב השרת מכיל מידע הכרחי לביצוע האלגוריתם. נזכיר, מכיוון ש-TFF הוא פונקציונלי, שמצב זה כולל גם כל מצב אופטימיזציה שהאלגוריתם משתמש בו (למשל מונחי מומנטום) וגם את פרמטרי המודל עצמם - אלה יועברו כארגומנטים ויוחזרו כתוצאות מחישובי TFF.
  2. בצע את האלגוריתם סיבוב אחר סיבוב. בכל סבב, מצב שרת חדש יוחזר כתוצאה מכל לקוח אימון המודל על הנתונים שלו. בדרך כלל בסיבוב אחד:
    1. השרת שידר את המודל לכל הלקוחות המשתתפים.
    2. כל לקוח מבצע עבודה על בסיס המודל והנתונים שלו.
    3. השרת מצרף את כל המודל כדי לייצר מצב שרת המכיל מודל חדש.

לפרטים נוספים, ראו אלגוריתמי Federated מותאמים אישית, חלק 2: ממוצעי היישום Federated הדרכה.

מדדי אימון נכתבים לספריית Tensorboard להצגה לאחר האימון.

טען פונקציות שירות

def train(federated_averaging_process, num_rounds, num_clients_per_round, summary_writer):
  """Trains the federated averaging process and output metrics."""
  # Create a environment to get communication cost.
  environment = set_sizing_environment()

  # Initialize the Federated Averaging algorithm to get the initial server state.
  state = federated_averaging_process.initialize()

  with summary_writer.as_default():
    for round_num in range(num_rounds):
      # Sample the clients parcitipated in this round.
      sampled_clients = np.random.choice(
          emnist_train.client_ids,
          size=num_clients_per_round,
          replace=False)
      # Create a list of `tf.Dataset` instances from the data of sampled clients.
      sampled_train_data = [
          emnist_train.create_tf_dataset_for_client(client)
          for client in sampled_clients
      ]
      # Round one round of the algorithm based on the server state and client data
      # and output the new state and metrics.
      state, metrics = federated_averaging_process.next(state, sampled_train_data)

      # For more about size_info, please see https://www.tensorflow.org/federated/api_docs/python/tff/framework/SizeInfo
      size_info = environment.get_size_info()
      broadcasted_bits = size_info.broadcast_bits[-1]
      aggregated_bits = size_info.aggregate_bits[-1]

      print('round {:2d}, metrics={}, broadcasted_bits={}, aggregated_bits={}'.format(round_num, metrics, format_size(broadcasted_bits), format_size(aggregated_bits)))

      # Add metrics to Tensorboard.
      for name, value in metrics['train'].items():
          tf.summary.scalar(name, value, step=round_num)

      # Add broadcasted and aggregated data size to Tensorboard.
      tf.summary.scalar('cumulative_broadcasted_bits', broadcasted_bits, step=round_num)
      tf.summary.scalar('cumulative_aggregated_bits', aggregated_bits, step=round_num)
      summary_writer.flush()
# Clean the log directory to avoid conflicts.
try:
  tf.io.gfile.rmtree('/tmp/logs/scalars')
except tf.errors.OpError as e:
  pass  # Path doesn't exist

# Set up the log directory and writer for Tensorboard.
logdir = "/tmp/logs/scalars/original/"
summary_writer = tf.summary.create_file_writer(logdir)

train(federated_averaging_process=federated_averaging, num_rounds=10,
      num_clients_per_round=10, summary_writer=summary_writer)
round  0, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.07383774), ('loss', 2.3276227)])), ('stat', OrderedDict([('num_examples', 1097)]))]), broadcasted_bits=507.62Mibit, aggregated_bits=507.62Mibit
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.099585064), ('loss', 2.3152695)])), ('stat', OrderedDict([('num_examples', 964)]))]), broadcasted_bits=1015.24Mibit, aggregated_bits=1015.24Mibit
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09760766), ('loss', 2.3077576)])), ('stat', OrderedDict([('num_examples', 1045)]))]), broadcasted_bits=1.49Gibit, aggregated_bits=1.49Gibit
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.0963035), ('loss', 2.3066626)])), ('stat', OrderedDict([('num_examples', 1028)]))]), broadcasted_bits=1.98Gibit, aggregated_bits=1.98Gibit
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.10694184), ('loss', 2.3033001)])), ('stat', OrderedDict([('num_examples', 1066)]))]), broadcasted_bits=2.48Gibit, aggregated_bits=2.48Gibit
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.1185567), ('loss', 2.2999184)])), ('stat', OrderedDict([('num_examples', 970)]))]), broadcasted_bits=2.97Gibit, aggregated_bits=2.97Gibit
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.11751663), ('loss', 2.296883)])), ('stat', OrderedDict([('num_examples', 902)]))]), broadcasted_bits=3.47Gibit, aggregated_bits=3.47Gibit
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13063477), ('loss', 2.2990246)])), ('stat', OrderedDict([('num_examples', 1087)]))]), broadcasted_bits=3.97Gibit, aggregated_bits=3.97Gibit
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.12742382), ('loss', 2.2971866)])), ('stat', OrderedDict([('num_examples', 1083)]))]), broadcasted_bits=4.46Gibit, aggregated_bits=4.46Gibit
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13555992), ('loss', 2.2934425)])), ('stat', OrderedDict([('num_examples', 1018)]))]), broadcasted_bits=4.96Gibit, aggregated_bits=4.96Gibit

הפעל את TensorBoard עם ספריית יומן השורש שצוינה למעלה כדי להציג את מדדי ההדרכה. ייתכן שיחלפו מספר שניות עד שהנתונים ייטענו. פרט לאובדן ודיוק, אנו מפלטים גם את כמות הנתונים המשודרים והמצטברים. נתונים משודרים מתייחסים לטנזורים שהשרת דוחף לכל לקוח בעוד נתונים מצטברים מתייחסים לטנזורים שכל לקוח מחזיר לשרת.

%tensorboard --logdir /tmp/logs/scalars/ --port=0
Launching TensorBoard...
Reusing TensorBoard on port 34445 (pid 579503), started 1:53:14 ago. (Use '!kill 579503' to kill it.)
<IPython.core.display.Javascript at 0x7f9135ef1630>

בניית שידור מותאם אישית ופונקציית צבירה

עכשיו בואו ליישם פונקציה להשתמש באלגוריתמי דחיסת lossy על נתונים משודרים ונתונים מצטברים באמצעות tensor_encoding API.

ראשית, אנו מגדירים שתי פונקציות:

  • broadcast_encoder_fn אשר יוצר מופע של te.core.SimpleEncoder כדי tensors לקודד או משתנה שרת (נתוני שידור) תקשורת הלקוח.
  • mean_encoder_fn אשר יוצר מופע של te.core.GatherEncoder כדי tensors לקודד או משתנה מלקוח לשרת communicaiton (נתוני Aggregation).

חשוב לציין שאנו לא מיישמים שיטת דחיסה על כל הדגם בבת אחת. במקום זאת, אנו מחליטים כיצד (ואם) לדחוס כל משתנה של המודל באופן עצמאי. הסיבה היא שבדרך כלל, משתנים קטנים כמו הטיות רגישים יותר לאי דיוק, ובהיותם קטנים יחסית, החיסכון הפוטנציאלי בתקשורת הוא גם קטן יחסית. מכאן שאנו לא דוחסים משתנים קטנים כברירת מחדל. בדוגמה זו, אנו מיישמים קוונטיזציה אחידה על 8 סיביות (256 דליים) על כל משתנה עם יותר מ-10000 אלמנטים, ומחילים זהות רק על משתנים אחרים.

def broadcast_encoder_fn(value):
  """Function for building encoded broadcast."""
  spec = tf.TensorSpec(value.shape, value.dtype)
  if value.shape.num_elements() > 10000:
    return te.encoders.as_simple_encoder(
        te.encoders.uniform_quantization(bits=8), spec)
  else:
    return te.encoders.as_simple_encoder(te.encoders.identity(), spec)


def mean_encoder_fn(tensor_spec):
  """Function for building a GatherEncoder."""
  spec = tf.TensorSpec(tensor_spec.shape, tensor_spec.dtype)
  if tensor_spec.shape.num_elements() > 10000:
    return te.encoders.as_gather_encoder(
        te.encoders.uniform_quantization(bits=8), spec)
  else:
    return te.encoders.as_gather_encoder(te.encoders.identity(), spec)

TFF מספק APIs להמיר את הפונקציה מקודד לפורמט tff.learning.build_federated_averaging_process API יכול לצרוך. באמצעות tff.learning.framework.build_encoded_broadcast_from_model ו tff.aggregators.MeanFactory , אנחנו יכולים ליצור שני עצמים שיכולים להיות מועברים לתוך broadcast_process ו model_update_aggregation_factory agruments של tff.learning.build_federated_averaging_process ליצור אלגוריתמים ממוצעים Federated עם אלגוריתם דחיסה lossy.

encoded_broadcast_process = (
    tff.learning.framework.build_encoded_broadcast_process_from_model(
        tff_model_fn, broadcast_encoder_fn))

mean_factory = tff.aggregators.MeanFactory(
    tff.aggregators.EncodedSumFactory(mean_encoder_fn), # numerator
    tff.aggregators.EncodedSumFactory(mean_encoder_fn), # denominator
)

federated_averaging_with_compression = tff.learning.build_federated_averaging_process(
    tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    broadcast_process=encoded_broadcast_process,
    model_update_aggregation_factory=mean_factory)

אימון מחדש של הדגם

עכשיו בואו נריץ את אלגוריתם הממוצע הפדרציה החדש.

logdir_for_compression = "/tmp/logs/scalars/compression/"
summary_writer_for_compression = tf.summary.create_file_writer(
    logdir_for_compression)

train(federated_averaging_process=federated_averaging_with_compression, 
      num_rounds=10,
      num_clients_per_round=10,
      summary_writer=summary_writer_for_compression)
round  0, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.093), ('loss', 2.3194966)])), ('stat', OrderedDict([('num_examples', 1000)]))]), broadcasted_bits=146.46Mibit, aggregated_bits=146.46Mibit
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.10432034), ('loss', 2.3079953)])), ('stat', OrderedDict([('num_examples', 949)]))]), broadcasted_bits=292.92Mibit, aggregated_bits=292.93Mibit
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.07886754), ('loss', 2.3101337)])), ('stat', OrderedDict([('num_examples', 989)]))]), broadcasted_bits=439.38Mibit, aggregated_bits=439.39Mibit
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09774436), ('loss', 2.305069)])), ('stat', OrderedDict([('num_examples', 1064)]))]), broadcasted_bits=585.84Mibit, aggregated_bits=585.85Mibit
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09404097), ('loss', 2.302943)])), ('stat', OrderedDict([('num_examples', 1074)]))]), broadcasted_bits=732.30Mibit, aggregated_bits=732.32Mibit
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09), ('loss', 2.304385)])), ('stat', OrderedDict([('num_examples', 1000)]))]), broadcasted_bits=878.77Mibit, aggregated_bits=878.78Mibit
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.14368932), ('loss', 2.2973824)])), ('stat', OrderedDict([('num_examples', 1030)]))]), broadcasted_bits=1.00Gibit, aggregated_bits=1.00Gibit
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.12140871), ('loss', 2.2993405)])), ('stat', OrderedDict([('num_examples', 1079)]))]), broadcasted_bits=1.14Gibit, aggregated_bits=1.14Gibit
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13600783), ('loss', 2.2953267)])), ('stat', OrderedDict([('num_examples', 1022)]))]), broadcasted_bits=1.29Gibit, aggregated_bits=1.29Gibit
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13844621), ('loss', 2.295768)])), ('stat', OrderedDict([('num_examples', 1004)]))]), broadcasted_bits=1.43Gibit, aggregated_bits=1.43Gibit

הפעל שוב את TensorBoard כדי להשוות את מדדי האימון בין שתי ריצות.

כפי שניתן לראות ב Tensorboard, יש ירידה משמעותית בין orginial ו compression העקומה ב broadcasted_bits ו aggregated_bits החלקה בעוד loss ו sparse_categorical_accuracy עלילת שתי העקומות הן די similiar.

לסיכום, הטמענו אלגוריתם דחיסה שיכול להשיג ביצועים דומים לאלגוריתם הממוצע הפדרציה המקורי, בעוד שעלות הקומינוקציה מופחתת באופן משמעותי.

%tensorboard --logdir /tmp/logs/scalars/ --port=0
Launching TensorBoard...
Reusing TensorBoard on port 34445 (pid 579503), started 1:54:12 ago. (Use '!kill 579503' to kill it.)
<IPython.core.display.Javascript at 0x7f9140eb5ef0>

תרגילים

כדי ליישם אלגוריתם דחיסה מותאם אישית ולהחיל אותו על לולאת האימון, אתה יכול:

  1. ליישם אלגוריתם דחיסה חדש בתור תת של EncodingStageInterface או וריאנט שלה כללי יותר, AdaptiveEncodingStageInterface הבא בדוגמא זו .
  2. Construct החדש שלך Encoder ו מתמחה בו שידור מודל או מיצוע עדכון מודל .
  3. השימוש חפצים אלה כדי לבנות את כל החישובים אימונים .

שאלות מחקר פתוחות בעלות ערך פוטנציאלי כוללות: קוונטיזציה לא אחידה, דחיסה ללא הפסדים כגון קידוד האפמן ומנגנונים להתאמת דחיסה על סמך המידע מסבבי אימון קודמים.

חומרי קריאה מומלצים: