Học liên kết để phân loại hình ảnh

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Trong hướng dẫn này, chúng tôi sử dụng các ví dụ huấn luyện MNIST cổ điển để giới thiệu Learning Federated (FL) lớp API của TFF, tff.learning - một tập hợp các giao diện cấp cao hơn có thể được sử dụng để thực hiện các loại chung của nhiệm vụ học tập liên kết, chẳng hạn như đào tạo liên kết, chống lại các mô hình do người dùng cung cấp được triển khai trong TensorFlow.

Hướng dẫn này và API học liên kết, chủ yếu dành cho những người dùng muốn kết nối các mô hình TensorFlow của riêng họ vào TFF, coi mô hình sau chủ yếu là một hộp đen. Đối với một sâu sắc hơn sự hiểu biết của TFF và làm thế nào để thực hiện thuật toán học liên riêng bạn, hãy xem các hướng dẫn trên API FC Core - Tuỳ chỉnh Federated thuật toán Phần 1Phần 2 .

Để biết thêm về tff.learning , tiếp tục với Learning Federated cho Text thế hệ , hướng dẫn trong đó ngoài bao gồm mô hình tái phát, cũng chứng tỏ tải một đăng model Keras trước huấn luyện cho tinh tế với việc học liên kết hợp với đánh giá sử dụng Keras.

Trước khi chúng ta bắt đầu

Trước khi chúng tôi bắt đầu, vui lòng chạy phần sau để đảm bảo rằng môi trường của bạn được thiết lập chính xác. Nếu bạn không thấy một lời chào, xin vui lòng tham khảo các cài đặt hướng dẫn để được hướng dẫn.

# tensorflow_federated_nightly also bring in tf_nightly, which
# can causes a duplicate tensorboard install, leading to errors.
!pip uninstall --yes tensorboard tb-nightly

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio
!pip install --quiet --upgrade tb-nightly  # or tensorboard, but not both

import nest_asyncio
nest_asyncio.apply()
%load_ext tensorboard
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

tff.federated_computation(lambda: 'Hello, World!')()
b'Hello, World!'

Chuẩn bị dữ liệu đầu vào

Hãy bắt đầu với dữ liệu. Học liên kết yêu cầu một tập dữ liệu được liên kết, tức là một tập hợp dữ liệu từ nhiều người dùng. Dữ liệu Federated thường không iid , trong đó đặt ra một bộ duy nhất của những thách thức.

Để tạo điều kiện thử nghiệm, chúng tôi hạt kho TFF với một vài bộ dữ liệu, bao gồm một phiên bản liên của MNIST có chứa một phiên bản của NIST bộ dữ liệu gốc đã được tái xử lý bằng để các dữ liệu được keyed bởi nhà văn gốc các chữ số. Vì mỗi người viết có một phong cách riêng, tập dữ liệu này thể hiện kiểu hành vi không ổn định như mong đợi của các tập dữ liệu liên kết.

Đây là cách chúng tôi có thể tải nó.

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

Các bộ dữ liệu trả về bởi load_data() là trường hợp của tff.simulation.ClientData , một giao diện cho phép bạn liệt kê các thiết lập của người dùng, để xây dựng một tf.data.Dataset đại diện cho dữ liệu của một người dùng cụ thể, và để truy vấn cấu trúc của các phần tử riêng lẻ. Đây là cách bạn có thể sử dụng giao diện này để khám phá nội dung của tập dữ liệu. Hãy nhớ rằng mặc dù giao diện này cho phép bạn lặp lại các id máy khách, nhưng đây chỉ là một tính năng của dữ liệu mô phỏng. Như bạn sẽ thấy ngay sau đây, danh tính khách hàng không được sử dụng bởi khung học tập liên kết - mục đích duy nhất của chúng là cho phép bạn chọn các tập hợp con của dữ liệu để mô phỏng.

len(emnist_train.client_ids)
3383
emnist_train.element_type_structure
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

example_element = next(iter(example_dataset))

example_element['label'].numpy()
1
from matplotlib import pyplot as plt
plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
_ = plt.show()

png

Khám phá sự không đồng nhất trong dữ liệu được liên kết

Dữ liệu Federated thường không iid , người dùng thường có sự phân bố dữ liệu khác nhau tùy thuộc vào thói quen sử dụng. Một số khách hàng có thể có ít ví dụ đào tạo hơn trên thiết bị, do dữ liệu bị mờ cục bộ, trong khi một số khách hàng sẽ có nhiều ví dụ đào tạo hơn. Hãy cùng khám phá khái niệm về tính không đồng nhất dữ liệu điển hình của một hệ thống liên hợp với dữ liệu EMNIST mà chúng tôi có sẵn. Điều quan trọng cần lưu ý là phân tích sâu về dữ liệu của khách hàng chỉ có sẵn cho chúng tôi vì đây là môi trường mô phỏng nơi tất cả dữ liệu có sẵn cho chúng tôi tại địa phương. Trong môi trường liên kết sản xuất thực, bạn sẽ không thể kiểm tra dữ liệu của một khách hàng.

Đầu tiên, hãy lấy mẫu dữ liệu của một khách hàng để có cảm nhận về các ví dụ trên một thiết bị mô phỏng. Bởi vì tập dữ liệu chúng tôi đang sử dụng đã được khóa bởi người viết duy nhất, dữ liệu của một khách hàng đại diện cho chữ viết tay của một người cho một mẫu các chữ số từ 0 đến 9, mô phỏng "kiểu sử dụng" duy nhất của một người dùng.

## Example MNIST digits for one client
figure = plt.figure(figsize=(20, 4))
j = 0

for example in example_dataset.take(40):
  plt.subplot(4, 10, j+1)
  plt.imshow(example['pixels'].numpy(), cmap='gray', aspect='equal')
  plt.axis('off')
  j += 1

png

Bây giờ chúng ta hãy hình dung số lượng ví dụ trên mỗi máy khách cho mỗi nhãn chữ số MNIST. Trong môi trường liên kết, số lượng ví dụ trên mỗi máy khách có thể khác nhau khá nhiều, tùy thuộc vào hành vi của người dùng.

# Number of examples per layer for a sample of clients
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
for i in range(6):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    # Append counts individually per label to make plots
    # more colorful instead of one color per plot.
    label = example['label'].numpy()
    plot_data[label].append(label)
  plt.subplot(2, 3, i+1)
  plt.title('Client {}'.format(i))
  for j in range(10):
    plt.hist(
        plot_data[j],
        density=False,
        bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

png

Bây giờ chúng ta hãy hình dung hình ảnh trung bình trên mỗi khách hàng cho mỗi nhãn MNIST. Mã này sẽ tạo ra giá trị trung bình của mỗi pixel cho tất cả các ví dụ của người dùng cho một nhãn. Chúng ta sẽ thấy rằng hình ảnh có ý nghĩa của một khách hàng cho một chữ số sẽ trông khác với hình ảnh trung bình của một khách hàng khác cho cùng một chữ số, do kiểu chữ viết tay độc đáo của mỗi người. Chúng tôi có thể tìm hiểu về cách mỗi vòng đào tạo địa phương sẽ thúc đẩy mô hình theo một hướng khác nhau đối với mỗi khách hàng, vì chúng tôi đang học hỏi từ dữ liệu duy nhất của chính người dùng đó trong vòng đào tạo tại địa phương đó. Ở phần sau của hướng dẫn, chúng ta sẽ xem cách chúng ta có thể nhận từng bản cập nhật cho mô hình từ tất cả các khách hàng và tổng hợp chúng lại với nhau thành mô hình toàn cầu mới của chúng tôi, mô hình này đã học được từ dữ liệu duy nhất của mỗi khách hàng của chúng tôi.

# Each client has different mean images, meaning each client will be nudging
# the model in their own directions locally.

for i in range(5):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    plot_data[example['label'].numpy()].append(example['pixels'].numpy())
  f = plt.figure(i, figsize=(12, 5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(i))
  for j in range(10):
    mean_img = np.mean(plot_data[j], 0)
    plt.subplot(2, 5, j+1)
    plt.imshow(mean_img.reshape((28, 28)))
    plt.axis('off')

png

png

png

png

png

Dữ liệu người dùng có thể bị nhiễu và được gắn nhãn không đáng tin cậy. Ví dụ: nhìn vào dữ liệu của Khách hàng số 2 ở trên, chúng ta có thể thấy rằng đối với nhãn 2, có thể có một số ví dụ được gắn nhãn sai tạo ra một hình ảnh trung bình ồn ào hơn.

Xử lý trước dữ liệu đầu vào

Kể từ khi dữ liệu đã là một tf.data.Dataset , tiền xử lý có thể được thực hiện bằng biến đổi Dataset. Ở đây, chúng ta san bằng 28x28 hình ảnh vào 784 mảng -element, shuffle các ví dụ cá nhân, sắp xếp chúng vào lô, và đổi tên các tính năng từ pixelslabel để xy để sử dụng với Keras. Chúng tôi cũng ném vào một repeat trên bộ dữ liệu để chạy nhiều thời đại.

NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER, seed=1).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

Hãy xác minh điều này đã hoạt động.

preprocessed_example_dataset = preprocess(example_dataset)

sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))

sample_batch
OrderedDict([('x', array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)), ('y', array([[2],
       [1],
       [5],
       [7],
       [1],
       [7],
       [7],
       [1],
       [4],
       [7],
       [4],
       [2],
       [2],
       [5],
       [4],
       [1],
       [1],
       [0],
       [0],
       [9]], dtype=int32))])

Chúng tôi có gần như tất cả các khối xây dựng tại chỗ để xây dựng các tập dữ liệu được liên kết.

Một trong những cách để cung cấp dữ liệu liên để TFF trong một mô phỏng là cách đơn giản là một danh sách Python, với mỗi phần tử của danh sách tổ chức các dữ liệu của người dùng cá nhân, cho dù là một danh sách hoặc như một tf.data.Dataset . Vì chúng ta đã có một giao diện cung cấp giao diện thứ hai, hãy sử dụng nó.

Đây là một hàm trợ giúp đơn giản sẽ tạo danh sách các bộ dữ liệu từ một nhóm người dùng nhất định làm đầu vào cho một vòng đào tạo hoặc đánh giá.

def make_federated_data(client_data, client_ids):
  return [
      preprocess(client_data.create_tf_dataset_for_client(x))
      for x in client_ids
  ]

Bây giờ, chúng ta chọn khách hàng như thế nào?

Trong một kịch bản đào tạo liên hợp điển hình, chúng tôi đang đối phó với một lượng lớn thiết bị người dùng có khả năng rất lớn, chỉ một phần nhỏ trong số đó có thể khả dụng để đào tạo tại một thời điểm nhất định. Đây là trường hợp, ví dụ, khi các thiết bị khách hàng là điện thoại di động tham gia đào tạo chỉ khi được cắm vào nguồn điện, tắt mạng đo lường, và nếu không thì không hoạt động.

Tất nhiên, chúng tôi đang ở trong một môi trường mô phỏng và tất cả dữ liệu đều có sẵn tại địa phương. Thông thường, khi chạy mô phỏng, chúng tôi chỉ cần lấy mẫu ngẫu nhiên một tập hợp con khách hàng tham gia vào mỗi vòng đào tạo, nói chung là khác nhau trong mỗi vòng.

Điều đó nói rằng, như bạn có thể tìm hiểu bằng cách nghiên cứu bài báo trên trung bình Federated thuật toán, đạt hội tụ trong một hệ thống với các tập con lấy mẫu ngẫu nhiên của khách hàng trong mỗi vòng có thể mất một thời gian, và nó sẽ là không thực tế để phải chạy hàng trăm viên đạn trong hướng dẫn tương tác này.

Thay vào đó, những gì chúng tôi sẽ làm là lấy mẫu nhóm khách hàng một lần và sử dụng lại nhóm khách hàng tương tự qua các vòng để tăng tốc độ hội tụ (cố ý phù hợp quá mức với dữ liệu của một số người dùng này). Chúng tôi để nó như một bài tập cho người đọc để sửa đổi hướng dẫn này để mô phỏng lấy mẫu ngẫu nhiên - việc này khá dễ thực hiện (một khi bạn làm như vậy, hãy nhớ rằng việc đưa mô hình hội tụ có thể mất một lúc).

sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(emnist_train, sample_clients)

print('Number of client datasets: {l}'.format(l=len(federated_train_data)))
print('First dataset: {d}'.format(d=federated_train_data[0]))
Number of client datasets: 10
First dataset: <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>

Tạo mô hình với Keras

Nếu bạn đang sử dụng Keras, bạn có thể đã có mã xây dựng mô hình Keras. Đây là một ví dụ về một mô hình đơn giản sẽ đáp ứng đủ cho nhu cầu của chúng tôi.

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

Để sử dụng bất kỳ mô hình với TFF, nó cần phải được bọc trong một thể hiện của các tff.learning.Model giao diện, mà thấy nhiều phương pháp để dập tắt vượt qua phía trước của mô hình, tính chất siêu dữ liệu, vv, tương tự như Keras, mà còn giới thiệu thêm các yếu tố, chẳng hạn như các cách kiểm soát quá trình tính toán các chỉ số được liên kết. Bây giờ chúng ta đừng lo lắng về điều này; nếu bạn có một mô hình Keras như mà chúng ta vừa định nghĩa ở trên, bạn có thể có TFF quấn nó cho bạn bằng cách gọi tff.learning.from_keras_model , đi qua các mô hình và một loạt dữ liệu mẫu như các đối số, như hình dưới đây.

def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

Đào tạo mô hình trên dữ liệu liên kết

Bây giờ chúng ta có một mô hình bao bọc như tff.learning.Model để sử dụng với TFF, chúng ta có thể để cho TFF xây dựng một thuật toán trung bình Federated bằng cách gọi các chức năng helper tff.learning.build_federated_averaging_process , như sau.

Hãy ghi nhớ rằng lập luận cần phải là một nhà xây dựng (như model_fn ở trên), không phải là một ví dụ đã-xây dựng, do đó việc xây dựng các mô hình của bạn có thể xảy ra trong một bối cảnh điều khiển bởi TFF (Nếu bạn đang tò mò về lý do này, chúng tôi khuyến khích bạn đọc theo dõi hướng dẫn về thuật toán tùy chỉnh ).

Một lưu ý quan trọng trên các thuật toán trung bình Federated dưới đây, có 2 tối ưu: một ưu _client và tối ưu hóa _SERVER. Tôi ưu hoa _client chỉ được sử dụng để tính toán cập nhật mô hình cục bộ trên mỗi khách hàng. Tôi ưu hoa _SERVER áp dụng bản cập nhật tính trung bình để mô hình toàn cầu của máy chủ. Đặc biệt, điều này có nghĩa là lựa chọn trình tối ưu hóa và tốc độ học được sử dụng có thể cần phải khác với lựa chọn bạn đã sử dụng để đào tạo mô hình trên tập dữ liệu iid tiêu chuẩn. Chúng tôi khuyên bạn nên bắt đầu với SGD thông thường, có thể với tỷ lệ học tập nhỏ hơn bình thường. Tỷ lệ học tập mà chúng tôi sử dụng chưa được điều chỉnh cẩn thận, hãy thoải mái thử nghiệm.

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

Chuyện gì vừa xảy ra vậy? TFF đã xây dựng một cặp tính toán liên và đóng gói chúng thành một tff.templates.IterativeProcess trong đó những tính toán có sẵn như là một cặp tính initializenext .

Tóm lại, tính toán liên là những chương trình trong ngôn ngữ nội bộ của TFF có thể thể hiện các thuật toán liên khác nhau (bạn có thể tìm thêm về điều này trong tùy chỉnh các thuật toán hướng dẫn). Trong trường hợp này, hai tính tạo ra và đóng gói vào iterative_process thực hiện Federated trung bình .

Mục tiêu của TFF là xác định các phép tính theo cách mà chúng có thể được thực thi trong các cài đặt học liên kết thực, nhưng hiện tại chỉ thời gian chạy mô phỏng thực thi cục bộ mới được thực hiện. Để thực thi một phép tính trong trình mô phỏng, bạn chỉ cần gọi nó giống như một hàm Python. Môi trường thông dịch mặc định này không được thiết kế cho hiệu suất cao, nhưng nó sẽ đủ cho hướng dẫn này; chúng tôi hy vọng sẽ cung cấp thời gian chạy mô phỏng hiệu suất cao hơn để tạo điều kiện cho nghiên cứu quy mô lớn hơn trong các bản phát hành trong tương lai.

Hãy bắt đầu với initialize tính toán. Như trường hợp của tất cả các phép tính liên hợp, bạn có thể coi nó như một hàm. Tính toán không có đối số và trả về một kết quả - biểu diễn trạng thái của quá trình Tính trung bình liên kết trên máy chủ. Mặc dù chúng tôi không muốn đi sâu vào chi tiết của TFF, nhưng có thể mang tính hướng dẫn để xem trạng thái này trông như thế nào. Bạn có thể hình dung nó như sau.

str(iterative_process.initialize.type_signature)
'( -> <model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER)'

Trong khi các loại chữ ký trên lúc đầu có vẻ khó hiểu chút, bạn có thể nhận ra rằng tình trạng máy chủ bao gồm một model (các thông số mô hình ban đầu cho MNIST đó sẽ được phân phối cho tất cả các thiết bị), và optimizer_state (bổ sung thông tin được duy trì bởi các máy chủ, chẳng hạn như số vòng để sử dụng cho lịch biểu siêu tham số, v.v.).

Hãy gọi initialize tính toán để xây dựng tình trạng máy chủ.

state = iterative_process.initialize()

Thứ hai của cặp tính liên kết, next , đại diện cho một vòng duy nhất của Federated trung bình, trong đó bao gồm đẩy tình trạng máy chủ (bao gồm các thông số mô hình) cho khách hàng, trên thiết bị đào tạo về dữ liệu địa phương của họ, thu thập và cập nhật mô hình trung bình và tạo ra một mô hình cập nhật mới tại máy chủ.

Về mặt lý thuyết, bạn có thể nghĩ next là có một loại chữ ký chức năng rằng ngoại hình như sau.

SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS

Đặc biệt, ta nên suy nghĩ về next() không như là một chức năng mà chạy trên một máy chủ, mà đúng hơn là một đại diện chức năng tường thuật của toàn bộ tính toán phân tán - một số nguyên liệu đầu vào được cung cấp bởi máy chủ ( SERVER_STATE ), nhưng mỗi tham gia thiết bị đóng góp tập dữ liệu cục bộ của riêng nó.

Hãy chạy một vòng đào tạo và hình dung kết quả. Chúng tôi có thể sử dụng dữ liệu được liên kết mà chúng tôi đã tạo ở trên cho một mẫu người dùng.

state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.12345679), ('loss', 3.1193738)])), ('stat', OrderedDict([('num_examples', 4860)]))])

Hãy chạy thêm vài vòng nữa. Như đã lưu ý trước đó, thông thường tại thời điểm này, bạn sẽ chọn một tập hợp con dữ liệu mô phỏng của mình từ một mẫu người dùng mới được chọn ngẫu nhiên cho mỗi vòng để mô phỏng một triển khai thực tế trong đó người dùng liên tục đến và đi, nhưng trong sổ ghi chép tương tác này, cho vì mục đích minh chứng là chúng tôi sẽ chỉ sử dụng lại những người dùng giống nhau, để hệ thống hội tụ nhanh chóng.

NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13518518), ('loss', 2.9834728)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.14382716), ('loss', 2.861665)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.17407407), ('loss', 2.7957022)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.19917695), ('loss', 2.6146567)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.21975309), ('loss', 2.529761)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.2409465), ('loss', 2.4053504)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.2611111), ('loss', 2.315389)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.30823046), ('loss', 2.1240263)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round 10, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.33312756), ('loss', 2.1164262)])), ('stat', OrderedDict([('num_examples', 4860)]))])

Tổn thất trong huấn luyện đang giảm dần sau mỗi đợt huấn luyện liên đoàn, cho thấy mô hình đang hội tụ. Có một số cảnh báo quan trọng với những số liệu đào tạo, tuy nhiên, xem phần đánh giá sau này trong hướng dẫn này.

Hiển thị số liệu mô hình trong TensorBoard

Tiếp theo, hãy hình dung các số liệu từ các phép tính liên hợp này bằng Tensorboard.

Hãy bắt đầu bằng cách tạo thư mục và trình viết tóm tắt tương ứng để ghi các số liệu vào.

logdir = "/tmp/logs/scalars/training/"
summary_writer = tf.summary.create_file_writer(logdir)
state = iterative_process.initialize()

Vẽ biểu đồ các chỉ số vô hướng có liên quan với cùng một người viết tóm tắt.

with summary_writer.as_default():
  for round_num in range(1, NUM_ROUNDS):
    state, metrics = iterative_process.next(state, federated_train_data)
    for name, value in metrics['train'].items():
      tf.summary.scalar(name, value, step=round_num)

Khởi động TensorBoard với thư mục nhật ký gốc được chỉ định ở trên. Có thể mất vài giây để tải dữ liệu.

!ls {logdir}
%tensorboard --logdir {logdir} --port=0
events.out.tfevents.1629557449.ebe6e776479e64ea-4903924a278.borgtask.google.com.458912.1.v2
Launching TensorBoard...
Reusing TensorBoard on port 50681 (pid 292785), started 0:30:30 ago. (Use '!kill 292785' to kill it.)
<IPython.core.display.Javascript at 0x7fd6617e02d0>
# Uncomment and run this this cell to clean your directory of old output for
# future graphs from this directory. We don't run it by default so that if 
# you do a "Runtime > Run all" you don't lose your results.

# !rm -R /tmp/logs/scalars/*

Để xem các chỉ số đánh giá theo cách tương tự, bạn có thể tạo một thư mục eval riêng, như "nhật ký / vô hướng / eval", để ghi vào TensorBoard.

Tùy chỉnh việc triển khai mô hình

Keras là đề nghị cấp cao mô hình API cho TensorFlow , và chúng tôi khuyến khích sử dụng mô hình Keras (thông qua tff.learning.from_keras_model ) trong TFF bất cứ khi nào có thể.

Tuy nhiên, tff.learning cung cấp một giao diện mô hình cấp thấp hơn, tff.learning.Model , mà cho thấy nhiều chức năng tối thiểu cần thiết cho việc sử dụng một mô hình cho việc học tập liên. Trực tiếp thực hiện các giao diện này (có thể vẫn còn sử dụng các khối xây dựng như tf.keras.layers ) cho phép tuỳ biến tối đa mà không sửa đổi bên trong của thuật toán học liên.

Vì vậy, chúng ta hãy làm lại từ đầu.

Xác định các biến mô hình, chuyển tiếp và số liệu

Bước đầu tiên là xác định các biến TensorFlow mà chúng ta sẽ làm việc với. Để làm cho đoạn mã sau dễ đọc hơn, hãy xác định cấu trúc dữ liệu để đại diện cho toàn bộ tập hợp. Điều này sẽ bao gồm các biến như weightsbias rằng chúng tôi sẽ đào tạo, cũng như các biến mà sẽ tổ chức thống kê khác nhau tích lũy và quầy chúng tôi sẽ cập nhật trong thời gian đào tạo, chẳng hạn như loss_sum , accuracy_sum , và num_examples .

MnistVariables = collections.namedtuple(
    'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')

Đây là một phương pháp tạo các biến. Vì lợi ích của sự đơn giản, chúng tôi đại diện tất cả các thống kê như tf.float32 , vì điều đó sẽ loại bỏ sự cần thiết của loại chuyển đổi ở giai đoạn sau. Gói initializers biến như lambdas là một yêu cầu áp đặt bởi các biến tài nguyên .

def create_mnist_variables():
  return MnistVariables(
      weights=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
          name='weights',
          trainable=True),
      bias=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(10)),
          name='bias',
          trainable=True),
      num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
      loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
      accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))

Với các biến cho tham số mô hình và thống kê tích lũy đã có sẵn, giờ đây chúng ta có thể xác định phương pháp chuyển tiếp tính toán tổn thất, đưa ra dự đoán và cập nhật thống kê tích lũy cho một lô dữ liệu đầu vào, như sau.

def predict_on_batch(variables, x):
  return tf.nn.softmax(tf.matmul(x, variables.weights) + variables.bias)

def mnist_forward_pass(variables, batch):
  y = predict_on_batch(variables, batch['x'])
  predictions = tf.cast(tf.argmax(y, 1), tf.int32)

  flat_labels = tf.reshape(batch['y'], [-1])
  loss = -tf.reduce_mean(
      tf.reduce_sum(tf.one_hot(flat_labels, 10) * tf.math.log(y), axis=[1]))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(predictions, flat_labels), tf.float32))

  num_examples = tf.cast(tf.size(batch['y']), tf.float32)

  variables.num_examples.assign_add(num_examples)
  variables.loss_sum.assign_add(loss * num_examples)
  variables.accuracy_sum.assign_add(accuracy * num_examples)

  return loss, predictions

Tiếp theo, chúng tôi xác định một hàm trả về một tập hợp các chỉ số cục bộ, một lần nữa bằng cách sử dụng TensorFlow. Đây là các giá trị (ngoài các bản cập nhật mô hình, được xử lý tự động) đủ điều kiện để được tổng hợp vào máy chủ trong quá trình học tập hoặc đánh giá được liên kết.

Ở đây, chúng tôi chỉ đơn giản là trả lại trung bình lossaccuracy , cũng như num_examples , mà chúng tôi sẽ cần phải cân một cách chính xác những đóng góp từ những người dùng khác nhau khi tính toán uẩn liên.

def get_local_mnist_metrics(variables):
  return collections.OrderedDict(
      num_examples=variables.num_examples,
      loss=variables.loss_sum / variables.num_examples,
      accuracy=variables.accuracy_sum / variables.num_examples)

Cuối cùng, chúng ta cần phải xác định làm thế nào để tổng hợp các số liệu địa phương phát ra từ mỗi thiết bị thông qua get_local_mnist_metrics . Đây là phần duy nhất của mã mà không được viết bằng TensorFlow - đó là một tính toán liên bày tỏ trong TFF. Nếu bạn muốn tìm hiểu sâu hơn, lướt qua toàn bộ các tùy chỉnh các thuật toán hướng dẫn, nhưng trong hầu hết các ứng dụng, bạn sẽ không thực sự cần phải; các biến thể của mẫu hiển thị bên dưới là đủ. Đây là những gì nó trông giống như:

@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
  return collections.OrderedDict(
      num_examples=tff.federated_sum(metrics.num_examples),
      loss=tff.federated_mean(metrics.loss, metrics.num_examples),
      accuracy=tff.federated_mean(metrics.accuracy, metrics.num_examples))

Các đầu vào metrics tương ứng với tham số cho OrderedDict trả về bởi get_local_mnist_metrics trên, nhưng giới phê bình các giá trị không còn tf.Tensors - họ là "đóng hộp" như tff.Value s, để làm cho nó rõ ràng bạn không còn có thể thao tác chúng bằng cách sử TensorFlow, nhưng chỉ sử dụng khai thác liên TFF như tff.federated_meantff.federated_sum . Từ điển tổng hợp toàn cầu được trả về xác định tập hợp số liệu sẽ có sẵn trên máy chủ.

Xây dựng một thể hiện của tff.learning.Model

Với tất cả những điều trên, chúng tôi đã sẵn sàng xây dựng một biểu diễn mô hình để sử dụng với TFF tương tự như một biểu diễn được tạo cho bạn khi bạn cho phép TFF nhập mô hình Keras.

from typing import Callable, List, OrderedDict

class MnistModel(tff.learning.Model):

  def __init__(self):
    self._variables = create_mnist_variables()

  @property
  def trainable_variables(self):
    return [self._variables.weights, self._variables.bias]

  @property
  def non_trainable_variables(self):
    return []

  @property
  def local_variables(self):
    return [
        self._variables.num_examples, self._variables.loss_sum,
        self._variables.accuracy_sum
    ]

  @property
  def input_spec(self):
    return collections.OrderedDict(
        x=tf.TensorSpec([None, 784], tf.float32),
        y=tf.TensorSpec([None, 1], tf.int32))

  @tf.function
  def predict_on_batch(self, x, training=True):
    del training
    return predict_on_batch(self._variables, x)

  @tf.function
  def forward_pass(self, batch, training=True):
    del training
    loss, predictions = mnist_forward_pass(self._variables, batch)
    num_exmaples = tf.shape(batch['x'])[0]
    return tff.learning.BatchOutput(
        loss=loss, predictions=predictions, num_examples=num_exmaples)

  @tf.function
  def report_local_outputs(self):
    return get_local_mnist_metrics(self._variables)

  @property
  def federated_output_computation(self):
    return aggregate_mnist_metrics_across_clients

  @tf.function
  def report_local_unfinalized_metrics(
      self) -> OrderedDict[str, List[tf.Tensor]]:
    """Creates an `OrderedDict` of metric names to unfinalized values."""
    return collections.OrderedDict(
        num_examples=[self._variables.num_examples],
        loss=[self._variables.loss_sum, self._variables.num_examples],
        accuracy=[self._variables.accuracy_sum, self._variables.num_examples])

  def metric_finalizers(
      self) -> OrderedDict[str, Callable[[List[tf.Tensor]], tf.Tensor]]:
    """Creates an `OrderedDict` of metric names to finalizers."""
    return collections.OrderedDict(
        num_examples=tf.function(func=lambda x: x[0]),
        loss=tf.function(func=lambda x: x[0] / x[1]),
        accuracy=tf.function(func=lambda x: x[0] / x[1]))

Như bạn thấy, các phương pháp trừu tượng và tài sản được xác định bởi tff.learning.Model tương ứng với các đoạn mã trong phần trước đó giới thiệu các biến và xác định thiệt hại và thống kê.

Dưới đây là một số điểm đáng chú ý:

  • Tất cả các trạng thái đó mô hình của bạn sẽ sử dụng phải được chụp như biến TensorFlow, như TFF không sử dụng Python trong thời gian chạy (nhớ mã của bạn nên được viết như vậy mà nó có thể được triển khai đến các thiết bị di động, xem các tùy chỉnh các thuật toán hướng dẫn cho một sâu hơn bình luận về lý do).
  • Mô hình của bạn nên mô tả những gì hình thức của dữ liệu mà nó chấp nhận ( input_spec ), như nói chung, TFF là một môi trường mạnh mẽ, đánh máy và muốn xác định loại chữ ký cho tất cả các thành phần. Khai báo định dạng của đầu vào mô hình của bạn là một phần thiết yếu của nó.
  • Mặc dù về mặt kỹ thuật không cần thiết, chúng tôi khuyên bạn nên gói tất cả các logic TensorFlow (về phía trước vượt qua, tính toán số liệu, vv) như tf.function s, vì điều này sẽ giúp đảm bảo TensorFlow thể được tuần tự, và loại bỏ sự cần thiết phụ thuộc kiểm soát rõ ràng.

Trên đây là đủ để đánh giá và các thuật toán như Federated SGD. Tuy nhiên, đối với Tính trung bình liên kết, chúng ta cần chỉ định cách mô hình sẽ đào tạo cục bộ trên mỗi lô. Chúng tôi sẽ chỉ định một trình tối ưu hóa cục bộ khi xây dựng thuật toán Trung bình Liên kết.

Mô phỏng đào tạo liên đoàn với mô hình mới

Với tất cả những điều ở trên, phần còn lại của quá trình trông giống như những gì chúng ta đã thấy - chỉ cần thay thế hàm tạo mô hình bằng hàm tạo của lớp mô hình mới của chúng tôi và sử dụng hai phép tính liên kết trong quy trình lặp lại mà bạn đã tạo để chuyển qua các vòng huấn luyện.

iterative_process = tff.learning.build_federated_averaging_process(
    MnistModel,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))
state = iterative_process.initialize()
state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 3.0708053), ('accuracy', 0.12777779)])), ('stat', OrderedDict([('num_examples', 4860)]))])
for round_num in range(2, 11):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 3.011699), ('accuracy', 0.13024691)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.7408307), ('accuracy', 0.15576132)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.6761012), ('accuracy', 0.17921811)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.675567), ('accuracy', 0.1855967)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.5664043), ('accuracy', 0.20329218)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.4179392), ('accuracy', 0.24382716)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.3237286), ('accuracy', 0.26687244)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.1861682), ('accuracy', 0.28209877)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round 10, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.046388), ('accuracy', 0.32037038)])), ('stat', OrderedDict([('num_examples', 4860)]))])

Để xem các chỉ số này trong TensorBoard, hãy tham khảo các bước được liệt kê ở trên trong "Hiển thị số liệu của mô hình trong TensorBoard".

Đánh giá

Tất cả các thử nghiệm của chúng tôi cho đến nay chỉ trình bày các chỉ số đào tạo được liên kết - các chỉ số trung bình trên tất cả các lô dữ liệu được đào tạo trên tất cả các khách hàng trong vòng. Điều này dẫn đến những lo ngại bình thường về việc trang bị quá nhiều, đặc biệt là vì chúng tôi đã sử dụng cùng một nhóm khách hàng trên mỗi vòng để đơn giản hóa, nhưng có thêm khái niệm về trang bị quá mức trong các chỉ số đào tạo cụ thể cho thuật toán Trung bình liên kết. Điều này dễ thấy nhất nếu chúng ta tưởng tượng rằng mỗi khách hàng có một lô dữ liệu duy nhất và chúng tôi đào tạo trên lô đó cho nhiều lần lặp lại (kỷ nguyên). Trong trường hợp này, mô hình cục bộ sẽ nhanh chóng phù hợp chính xác với một lô đó và do đó chỉ số độ chính xác cục bộ mà chúng tôi trung bình sẽ đạt tới 1,0. Do đó, những thước đo đào tạo này có thể được coi là một dấu hiệu cho thấy việc đào tạo đang tiến bộ, nhưng không nhiều hơn.

Thực hiện đánh giá trên dữ liệu liên kết, bạn có thể xây dựng thêm tính liên thiết kế chỉ cho mục đích này, bằng cách sử dụng tff.learning.build_federated_evaluation chức năng, và đi qua trong constructor mô hình của bạn như một cuộc tranh cãi. Lưu ý rằng không giống như với Federated trung bình, nơi mà chúng tôi đã sử dụng MnistTrainableModel , nó cũng đủ để vượt qua MnistModel . Đánh giá không thực hiện giảm độ dốc và không cần phải xây dựng trình tối ưu hóa.

Đối với thí nghiệm và nghiên cứu, khi kiểm tra dữ liệu tập trung có sẵn, Federated Learning cho Text thế hệ cho thấy một lựa chọn đánh giá: lấy trọng lượng đào tạo từ học liên kết, áp dụng chúng vào một mô hình Keras tiêu chuẩn, và sau đó chỉ cần gọi tf.keras.models.Model.evaluate() trên một tập dữ liệu tập trung.

evaluation = tff.learning.build_federated_evaluation(MnistModel)

Bạn có thể kiểm tra chữ ký kiểu trừu tượng của hàm đánh giá như sau.

str(evaluation.type_signature)
'(<server_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,784],y=int32[?,1]>*}@CLIENTS> -> <eval=<num_examples=float32,loss=float32,accuracy=float32>,stat=<num_examples=int64>>@SERVER)'

Không cần phải được quan tâm về các chi tiết vào thời điểm này, chỉ cần lưu ý rằng nó có dạng tổng quát sau đây, tương tự như tff.templates.IterativeProcess.next nhưng với hai sự khác biệt quan trọng. Đầu tiên, chúng tôi không trả lại trạng thái máy chủ, vì đánh giá không sửa đổi mô hình hoặc bất kỳ khía cạnh nào khác của trạng thái - bạn có thể coi nó là trạng thái không trạng thái. Thứ hai, đánh giá chỉ cần mô hình và không yêu cầu bất kỳ phần nào khác của trạng thái máy chủ có thể được liên kết với đào tạo, chẳng hạn như các biến trình tối ưu hóa.

SERVER_MODEL, FEDERATED_DATA -> TRAINING_METRICS

Hãy gọi đánh giá về trạng thái mới nhất mà chúng tôi đạt được trong quá trình huấn luyện. Để trích xuất mới nhất của mô hình đào tạo từ trạng thái máy chủ, bạn chỉ cần truy cập vào .model thành viên, như sau.

train_metrics = evaluation(state.model, federated_train_data)

Đây là những gì chúng tôi nhận được. Lưu ý rằng những con số trông tốt hơn một chút so với những gì được báo cáo bởi vòng đào tạo cuối cùng ở trên. Theo quy ước, các chỉ số đào tạo được báo cáo bởi quá trình đào tạo lặp đi lặp lại thường phản ánh hiệu suất của mô hình khi bắt đầu vòng đào tạo, do đó, các chỉ số đánh giá sẽ luôn đi trước một bước.

str(train_metrics)
"OrderedDict([('eval', OrderedDict([('num_examples', 4860.0), ('loss', 1.7510437), ('accuracy', 0.2788066)])), ('stat', OrderedDict([('num_examples', 4860)]))])"

Bây giờ, hãy biên dịch một mẫu thử nghiệm của dữ liệu được liên kết và chạy lại đánh giá trên dữ liệu thử nghiệm. Dữ liệu sẽ đến từ cùng một mẫu người dùng thực, nhưng từ một tập dữ liệu riêng biệt.

federated_test_data = make_federated_data(emnist_test, sample_clients)

len(federated_test_data), federated_test_data[0]
(10,
 <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>)
test_metrics = evaluation(state.model, federated_test_data)
str(test_metrics)
"OrderedDict([('eval', OrderedDict([('num_examples', 580.0), ('loss', 1.8361608), ('accuracy', 0.2413793)])), ('stat', OrderedDict([('num_examples', 580)]))])"

Điều này kết thúc hướng dẫn. Chúng tôi khuyến khích bạn chơi với các tham số (ví dụ: kích thước lô, số lượng người dùng, kỷ nguyên, tỷ lệ học tập, v.v.), để sửa đổi mã ở trên để mô phỏng đào tạo trên các mẫu ngẫu nhiên của người dùng trong mỗi vòng và khám phá các hướng dẫn khác chúng tôi đã phát triển.