Xem trên TensorFlow.org | Chạy trong Google Colab | Xem nguồn trên GitHub | Tải xuống sổ ghi chép |
Trong hướng dẫn này, chúng tôi sử dụng EMNIST bộ dữ liệu để chứng minh làm thế nào để cho phép các thuật toán nén lossy để giảm chi phí thông tin liên lạc trong thuật toán trung bình Federated sử dụng tff.learning.build_federated_averaging_process
API và tensor_encoding API. Để biết thêm chi tiết về thuật toán trung bình Federated, xem giấy Learning Truyền thông-hiệu quả của Deep Networks từ phân cấp dữ liệu .
Trước khi chúng ta bắt đầu
Trước khi chúng tôi bắt đầu, vui lòng chạy phần sau để đảm bảo rằng môi trường của bạn được thiết lập chính xác. Nếu bạn không thấy một lời chào, xin vui lòng tham khảo các cài đặt hướng dẫn để được hướng dẫn.
!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade tensorflow-model-optimization
!pip install --quiet --upgrade nest-asyncio
import nest_asyncio
nest_asyncio.apply()
%load_ext tensorboard
import functools
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
from tensorflow_model_optimization.python.core.internal import tensor_encoding as te
Xác minh xem TFF có hoạt động hay không.
@tff.federated_computation
def hello_world():
return 'Hello, World!'
hello_world()
b'Hello, World!'
Chuẩn bị dữ liệu đầu vào
Trong phần này, chúng tôi tải và xử lý trước tập dữ liệu EMNIST có trong TFF. Vui lòng kiểm tra ra Federated Learning cho Phân loại hình hướng dẫn để biết thêm chi tiết về EMNIST tập dữ liệu.
# This value only applies to EMNIST dataset, consider choosing appropriate
# values if switching to other datasets.
MAX_CLIENT_DATASET_SIZE = 418
CLIENT_EPOCHS_PER_ROUND = 1
CLIENT_BATCH_SIZE = 20
TEST_BATCH_SIZE = 500
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
only_digits=True)
def reshape_emnist_element(element):
return (tf.expand_dims(element['pixels'], axis=-1), element['label'])
def preprocess_train_dataset(dataset):
"""Preprocessing function for the EMNIST training dataset."""
return (dataset
# Shuffle according to the largest client dataset
.shuffle(buffer_size=MAX_CLIENT_DATASET_SIZE)
# Repeat to do multiple local epochs
.repeat(CLIENT_EPOCHS_PER_ROUND)
# Batch to a fixed client batch size
.batch(CLIENT_BATCH_SIZE, drop_remainder=False)
# Preprocessing step
.map(reshape_emnist_element))
emnist_train = emnist_train.preprocess(preprocess_train_dataset)
Xác định một mô hình
Ở đây chúng ta xác định một mô hình keras dựa trên orginial FedAvg CNN, và sau đó quấn mô hình keras trong một thể hiện của tff.learning.Model để nó có thể được tiêu thụ bởi TFF.
Lưu ý rằng chúng tôi sẽ cần một chức năng mà tạo ra một mô hình thay vì chỉ đơn giản là một mô hình trực tiếp. Bên cạnh đó, chức năng có thể không chỉ chụp một mô hình pre-xây dựng, nó phải tạo ra các mô hình trong bối cảnh mà nó được gọi. Lý do là TFF được thiết kế để đi đến các thiết bị và cần kiểm soát thời điểm tài nguyên được xây dựng để chúng có thể được thu thập và đóng gói.
def create_original_fedavg_cnn_model(only_digits=True):
"""The CNN model used in https://arxiv.org/abs/1602.05629."""
data_format = 'channels_last'
max_pool = functools.partial(
tf.keras.layers.MaxPooling2D,
pool_size=(2, 2),
padding='same',
data_format=data_format)
conv2d = functools.partial(
tf.keras.layers.Conv2D,
kernel_size=5,
padding='same',
data_format=data_format,
activation=tf.nn.relu)
model = tf.keras.models.Sequential([
tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
conv2d(filters=32),
max_pool(),
conv2d(filters=64),
max_pool(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(512, activation=tf.nn.relu),
tf.keras.layers.Dense(10 if only_digits else 62),
tf.keras.layers.Softmax(),
])
return model
# Gets the type information of the input data. TFF is a strongly typed
# functional programming framework, and needs type information about inputs to
# the model.
input_spec = emnist_train.create_tf_dataset_for_client(
emnist_train.client_ids[0]).element_spec
def tff_model_fn():
keras_model = create_original_fedavg_cnn_model()
return tff.learning.from_keras_model(
keras_model=keras_model,
input_spec=input_spec,
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
Đào tạo mô hình và xuất các chỉ số đào tạo
Bây giờ chúng ta đã sẵn sàng để xây dựng thuật toán Trung bình liên kết và đào tạo mô hình đã xác định trên tập dữ liệu EMNIST.
Đầu tiên chúng ta cần phải xây dựng một thuật toán trung bình Federated sử dụng tff.learning.build_federated_averaging_process API.
federated_averaging = tff.learning.build_federated_averaging_process(
model_fn=tff_model_fn,
client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
Bây giờ hãy chạy thuật toán Trung bình Liên kết. Việc thực thi thuật toán Học liên kết từ quan điểm của TFF trông như thế này:
- Khởi tạo thuật toán và lấy trạng thái máy chủ inital. Trạng thái máy chủ chứa thông tin cần thiết để thực hiện thuật toán. Nhớ lại, vì TFF là chức năng, trạng thái này bao gồm cả bất kỳ trạng thái nào của trình tối ưu hóa mà thuật toán sử dụng (ví dụ: các thuật ngữ xung lượng) cũng như bản thân các tham số của mô hình - những tham số này sẽ được chuyển dưới dạng đối số và trả về dưới dạng kết quả từ tính toán TFF.
- Thực hiện thuật toán từng vòng. Trong mỗi vòng, một trạng thái máy chủ mới sẽ được trả về do mỗi máy khách đào tạo mô hình trên dữ liệu của nó. Thông thường trong một vòng:
- Máy chủ quảng bá mô hình cho tất cả các máy khách tham gia.
- Mỗi khách hàng thực hiện công việc dựa trên mô hình và dữ liệu của chính nó.
- Máy chủ tổng hợp tất cả các mô hình để tạo ra một trạng thái máy chủ chứa một mô hình mới.
Để biết thêm chi tiết, vui lòng xem Tuỳ chỉnh Federated thuật toán, Phần 2: Thực hiện Federated trung bình hướng dẫn.
Các chỉ số đào tạo được ghi vào thư mục Tensorboard để hiển thị sau khóa đào tạo.
Tải các chức năng tiện ích
def format_size(size):
"""A helper function for creating a human-readable size."""
size = float(size)
for unit in ['bit','Kibit','Mibit','Gibit']:
if size < 1024.0:
return "{size:3.2f}{unit}".format(size=size, unit=unit)
size /= 1024.0
return "{size:.2f}{unit}".format(size=size, unit='TiB')
def set_sizing_environment():
"""Creates an environment that contains sizing information."""
# Creates a sizing executor factory to output communication cost
# after the training finishes. Note that sizing executor only provides an
# estimate (not exact) of communication cost, and doesn't capture cases like
# compression of over-the-wire representations. However, it's perfect for
# demonstrating the effect of compression in this tutorial.
sizing_factory = tff.framework.sizing_executor_factory()
# TFF has a modular runtime you can configure yourself for various
# environments and purposes, and this example just shows how to configure one
# part of it to report the size of things.
context = tff.framework.ExecutionContext(executor_fn=sizing_factory)
tff.framework.set_default_context(context)
return sizing_factory
def train(federated_averaging_process, num_rounds, num_clients_per_round, summary_writer):
"""Trains the federated averaging process and output metrics."""
# Create a environment to get communication cost.
environment = set_sizing_environment()
# Initialize the Federated Averaging algorithm to get the initial server state.
state = federated_averaging_process.initialize()
with summary_writer.as_default():
for round_num in range(num_rounds):
# Sample the clients parcitipated in this round.
sampled_clients = np.random.choice(
emnist_train.client_ids,
size=num_clients_per_round,
replace=False)
# Create a list of `tf.Dataset` instances from the data of sampled clients.
sampled_train_data = [
emnist_train.create_tf_dataset_for_client(client)
for client in sampled_clients
]
# Round one round of the algorithm based on the server state and client data
# and output the new state and metrics.
state, metrics = federated_averaging_process.next(state, sampled_train_data)
# For more about size_info, please see https://www.tensorflow.org/federated/api_docs/python/tff/framework/SizeInfo
size_info = environment.get_size_info()
broadcasted_bits = size_info.broadcast_bits[-1]
aggregated_bits = size_info.aggregate_bits[-1]
print('round {:2d}, metrics={}, broadcasted_bits={}, aggregated_bits={}'.format(round_num, metrics, format_size(broadcasted_bits), format_size(aggregated_bits)))
# Add metrics to Tensorboard.
for name, value in metrics['train'].items():
tf.summary.scalar(name, value, step=round_num)
# Add broadcasted and aggregated data size to Tensorboard.
tf.summary.scalar('cumulative_broadcasted_bits', broadcasted_bits, step=round_num)
tf.summary.scalar('cumulative_aggregated_bits', aggregated_bits, step=round_num)
summary_writer.flush()
# Clean the log directory to avoid conflicts.
try:
tf.io.gfile.rmtree('/tmp/logs/scalars')
except tf.errors.OpError as e:
pass # Path doesn't exist
# Set up the log directory and writer for Tensorboard.
logdir = "/tmp/logs/scalars/original/"
summary_writer = tf.summary.create_file_writer(logdir)
train(federated_averaging_process=federated_averaging, num_rounds=10,
num_clients_per_round=10, summary_writer=summary_writer)
round 0, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.07383774), ('loss', 2.3276227)])), ('stat', OrderedDict([('num_examples', 1097)]))]), broadcasted_bits=507.62Mibit, aggregated_bits=507.62Mibit round 1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.099585064), ('loss', 2.3152695)])), ('stat', OrderedDict([('num_examples', 964)]))]), broadcasted_bits=1015.24Mibit, aggregated_bits=1015.24Mibit round 2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09760766), ('loss', 2.3077576)])), ('stat', OrderedDict([('num_examples', 1045)]))]), broadcasted_bits=1.49Gibit, aggregated_bits=1.49Gibit round 3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.0963035), ('loss', 2.3066626)])), ('stat', OrderedDict([('num_examples', 1028)]))]), broadcasted_bits=1.98Gibit, aggregated_bits=1.98Gibit round 4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.10694184), ('loss', 2.3033001)])), ('stat', OrderedDict([('num_examples', 1066)]))]), broadcasted_bits=2.48Gibit, aggregated_bits=2.48Gibit round 5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.1185567), ('loss', 2.2999184)])), ('stat', OrderedDict([('num_examples', 970)]))]), broadcasted_bits=2.97Gibit, aggregated_bits=2.97Gibit round 6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.11751663), ('loss', 2.296883)])), ('stat', OrderedDict([('num_examples', 902)]))]), broadcasted_bits=3.47Gibit, aggregated_bits=3.47Gibit round 7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13063477), ('loss', 2.2990246)])), ('stat', OrderedDict([('num_examples', 1087)]))]), broadcasted_bits=3.97Gibit, aggregated_bits=3.97Gibit round 8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.12742382), ('loss', 2.2971866)])), ('stat', OrderedDict([('num_examples', 1083)]))]), broadcasted_bits=4.46Gibit, aggregated_bits=4.46Gibit round 9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13555992), ('loss', 2.2934425)])), ('stat', OrderedDict([('num_examples', 1018)]))]), broadcasted_bits=4.96Gibit, aggregated_bits=4.96Gibit
Khởi động TensorBoard với thư mục nhật ký gốc được chỉ định ở trên để hiển thị các chỉ số đào tạo. Có thể mất vài giây để tải dữ liệu. Ngoại trừ Mất mát và Độ chính xác, chúng tôi cũng xuất lượng dữ liệu tổng hợp và truyền phát. Dữ liệu được truyền phát đề cập đến tensors mà máy chủ đẩy đến từng máy khách trong khi dữ liệu tổng hợp đề cập đến tensors mỗi máy khách quay trở lại máy chủ.
%tensorboard --logdir /tmp/logs/scalars/ --port=0
Launching TensorBoard... Reusing TensorBoard on port 34445 (pid 579503), started 1:53:14 ago. (Use '!kill 579503' to kill it.) <IPython.core.display.Javascript at 0x7f9135ef1630>
Xây dựng một chức năng tổng hợp và truyền phát tùy chỉnh
Bây giờ chúng ta hãy thực hiện chức năng để sử dụng các thuật toán nén tổn hao trên dữ liệu phát sóng và dữ liệu tổng hợp bằng cách sử dụng tensor_encoding API.
Đầu tiên, chúng tôi xác định hai chức năng:
-
broadcast_encoder_fn
mà tạo ra một thể hiện của te.core.SimpleEncoder để tensors mã hóa hoặc biến trong máy chủ để giao tiếp khách hàng (số liệu Broadcast). -
mean_encoder_fn
mà tạo ra một thể hiện của te.core.GatherEncoder để tensors mã hóa hoặc biến trong client tới server communicaiton (dữ liệu tập hợp).
Điều quan trọng cần lưu ý là chúng tôi không áp dụng một phương pháp nén cho toàn bộ mô hình cùng một lúc. Thay vào đó, chúng tôi quyết định cách (và liệu) nén từng biến của mô hình một cách độc lập hay không. Lý do là nói chung, các biến nhỏ như độ lệch nhạy cảm hơn với sự không chính xác và tương đối nhỏ, khả năng tiết kiệm truyền thông cũng tương đối nhỏ. Do đó, chúng tôi không nén các biến nhỏ theo mặc định. Trong ví dụ này, chúng tôi áp dụng lượng tử hóa thống nhất thành 8 bit (256 nhóm) cho mọi biến có hơn 10000 phần tử và chỉ áp dụng nhận dạng cho các biến khác.
def broadcast_encoder_fn(value):
"""Function for building encoded broadcast."""
spec = tf.TensorSpec(value.shape, value.dtype)
if value.shape.num_elements() > 10000:
return te.encoders.as_simple_encoder(
te.encoders.uniform_quantization(bits=8), spec)
else:
return te.encoders.as_simple_encoder(te.encoders.identity(), spec)
def mean_encoder_fn(tensor_spec):
"""Function for building a GatherEncoder."""
spec = tf.TensorSpec(tensor_spec.shape, tensor_spec.dtype)
if tensor_spec.shape.num_elements() > 10000:
return te.encoders.as_gather_encoder(
te.encoders.uniform_quantization(bits=8), spec)
else:
return te.encoders.as_gather_encoder(te.encoders.identity(), spec)
TFF cung cấp các API để chuyển đổi chức năng mã hóa sang một định dạng mà tff.learning.build_federated_averaging_process
API có thể tiêu thụ. Bằng việc sử dụng tff.learning.framework.build_encoded_broadcast_from_model
và tff.aggregators.MeanFactory
, chúng ta có thể tạo ra hai đối tượng có thể được thông qua vào broadcast_process
và model_update_aggregation_factory
agruments của tff.learning.build_federated_averaging_process
để tạo ra một thuật toán Federated Trung bình với một thuật toán nén lossy.
encoded_broadcast_process = (
tff.learning.framework.build_encoded_broadcast_process_from_model(
tff_model_fn, broadcast_encoder_fn))
mean_factory = tff.aggregators.MeanFactory(
tff.aggregators.EncodedSumFactory(mean_encoder_fn), # numerator
tff.aggregators.EncodedSumFactory(mean_encoder_fn), # denominator
)
federated_averaging_with_compression = tff.learning.build_federated_averaging_process(
tff_model_fn,
client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
broadcast_process=encoded_broadcast_process,
model_update_aggregation_factory=mean_factory)
Đào tạo lại mô hình
Bây giờ, hãy chạy thuật toán Trung bình Liên kết mới.
logdir_for_compression = "/tmp/logs/scalars/compression/"
summary_writer_for_compression = tf.summary.create_file_writer(
logdir_for_compression)
train(federated_averaging_process=federated_averaging_with_compression,
num_rounds=10,
num_clients_per_round=10,
summary_writer=summary_writer_for_compression)
round 0, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.093), ('loss', 2.3194966)])), ('stat', OrderedDict([('num_examples', 1000)]))]), broadcasted_bits=146.46Mibit, aggregated_bits=146.46Mibit round 1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.10432034), ('loss', 2.3079953)])), ('stat', OrderedDict([('num_examples', 949)]))]), broadcasted_bits=292.92Mibit, aggregated_bits=292.93Mibit round 2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.07886754), ('loss', 2.3101337)])), ('stat', OrderedDict([('num_examples', 989)]))]), broadcasted_bits=439.38Mibit, aggregated_bits=439.39Mibit round 3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09774436), ('loss', 2.305069)])), ('stat', OrderedDict([('num_examples', 1064)]))]), broadcasted_bits=585.84Mibit, aggregated_bits=585.85Mibit round 4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09404097), ('loss', 2.302943)])), ('stat', OrderedDict([('num_examples', 1074)]))]), broadcasted_bits=732.30Mibit, aggregated_bits=732.32Mibit round 5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09), ('loss', 2.304385)])), ('stat', OrderedDict([('num_examples', 1000)]))]), broadcasted_bits=878.77Mibit, aggregated_bits=878.78Mibit round 6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.14368932), ('loss', 2.2973824)])), ('stat', OrderedDict([('num_examples', 1030)]))]), broadcasted_bits=1.00Gibit, aggregated_bits=1.00Gibit round 7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.12140871), ('loss', 2.2993405)])), ('stat', OrderedDict([('num_examples', 1079)]))]), broadcasted_bits=1.14Gibit, aggregated_bits=1.14Gibit round 8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13600783), ('loss', 2.2953267)])), ('stat', OrderedDict([('num_examples', 1022)]))]), broadcasted_bits=1.29Gibit, aggregated_bits=1.29Gibit round 9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13844621), ('loss', 2.295768)])), ('stat', OrderedDict([('num_examples', 1004)]))]), broadcasted_bits=1.43Gibit, aggregated_bits=1.43Gibit
Khởi động lại TensorBoard để so sánh các chỉ số đào tạo giữa hai lần chạy.
Như bạn có thể thấy trong Tensorboard, có một sự giảm đáng kể giữa orginial
và compression
đường cong trong broadcasted_bits
và aggregated_bits
lô trong khi loss
và sparse_categorical_accuracy
âm mưu hai đường cong là khá tương tự.
Kết luận, chúng tôi đã triển khai một thuật toán nén có thể đạt được hiệu suất tương tự như thuật toán Trung bình liên kết gốc trong khi chi phí chung giảm đáng kể.
%tensorboard --logdir /tmp/logs/scalars/ --port=0
Launching TensorBoard... Reusing TensorBoard on port 34445 (pid 579503), started 1:54:12 ago. (Use '!kill 579503' to kill it.) <IPython.core.display.Javascript at 0x7f9140eb5ef0>
Bài tập
Để triển khai một thuật toán nén tùy chỉnh và áp dụng nó vào vòng lặp đào tạo, bạn có thể:
- Thực hiện một thuật toán nén mới như một lớp con của
EncodingStageInterface
hoặc biến thể tổng quát hơn của nó,AdaptiveEncodingStageInterface
sau ví dụ này . - Xây dựng mới của bạn
Encoder
và chuyên nó cho phát sóng mô hình hay mô hình cập nhật trung bình . - Sử dụng các đối tượng để xây dựng toàn bộ tính toán đào tạo .
Các câu hỏi nghiên cứu mở có giá trị tiềm năng bao gồm: lượng tử hóa không đồng nhất, nén không mất dữ liệu như mã hóa huffman và cơ chế nén thích ứng dựa trên thông tin từ các vòng huấn luyện trước đó.
Tài liệu đọc đề xuất: