Tái cấu trúc liên hợp để thừa số hóa ma trận

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Hướng dẫn này tìm hiểu học tập liên phần địa phương, nơi một số thông số khách hàng không bao giờ được tổng hợp trên máy chủ. Điều này hữu ích cho các mô hình có các tham số dành riêng cho người dùng (ví dụ: mô hình thừa số hóa ma trận) và để đào tạo trong các cài đặt giới hạn giao tiếp. Chúng tôi xây dựng trên khái niệm được giới thiệu trong Học Federated cho hình ảnh Phân loại hướng dẫn; như trong hướng dẫn này, chúng tôi giới thiệu API cấp cao trong tff.learning cho đào tạo và đánh giá liên.

Chúng tôi bắt đầu bằng việc thúc đẩy học tập liên phần địa phương cho ma trận nhân tử . Chúng tôi mô tả Federated Tái thiết ( giấy , bài viết trên blog ), một thuật toán thực tế cho học tập liên phần địa phương theo tỷ lệ. Chúng tôi chuẩn bị tập dữ liệu MovieLens 1M, xây dựng mô hình cục bộ một phần, đào tạo và đánh giá nó.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import collections
import functools
import io
import os
import requests
import zipfile
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(42)

Cơ sở: Dữ liệu hóa ma trận

Matrix nhân tử đã được một kỹ thuật phổ biến trong lịch sử cho việc học các kiến nghị và nhúng cơ quan đại diện cho các hạng mục dựa trên tương tác người dùng. Ví dụ kinh điển là giới thiệu phim, nơi có \(n\) người dùng và \(m\) phim ảnh, và người dùng đã đánh giá một số phim. Với một người dùng, chúng tôi sử dụng lịch sử xếp hạng của họ và xếp hạng của những người dùng tương tự để dự đoán xếp hạng của người dùng cho những bộ phim mà họ chưa xem. Nếu chúng tôi có một mô hình có thể dự đoán xếp hạng, thật dễ dàng để giới thiệu cho người dùng những bộ phim mới mà họ sẽ thích.

Đối với nhiệm vụ này, nó rất hữu ích để đại diện cho xếp hạng của người sử dụng như một \(n \times m\) ma trận \(R\):

Động lực tạo cơ sở dữ liệu ma trận (CC BY-SA 3.0; Người dùng Wikipedia Moshanin)

Ma trận này nói chung là thưa thớt, vì người dùng thường chỉ xem một phần nhỏ phim trong tập dữ liệu. Kết quả của ma trận nhân tử là hai ma trận: một \(n \times k\) ma trận \(U\) đại diện \(k\)embeddings sử dụng chiều cho mỗi người dùng, và một \(m \times k\) ma trận \(I\) đại diện \(k\)embeddings mục chiều cho mỗi mục. Mục tiêu đào tạo đơn giản nhất là để đảm bảo rằng sản phẩm chấm của người dùng và mục embeddings là tiên đoán xếp hạng quan sát \(O\):

\[argmin_{U,I} \sum_{(u, i) \in O} (R_{ui} - U_u I_i^T)^2\]

Điều này tương đương với việc giảm thiểu sai số bình phương trung bình giữa xếp hạng quan sát được và xếp hạng được dự đoán bằng cách lấy sản phẩm chấm của người dùng và mặt hàng nhúng tương ứng. Một cách khác để giải thích điều này là để đảm bảo này mà \(R \approx UI^T\) cho xếp hạng nổi tiếng, vì thế "ma trận nhân tử". Nếu điều này khó hiểu, đừng lo lắng – chúng ta sẽ không cần biết chi tiết về phân tích nhân tử của ma trận trong phần còn lại của hướng dẫn.

Khám phá dữ liệu MovieLens

Hãy bắt đầu bằng cách tải các MovieLens 1M dữ liệu, trong đó bao gồm 1.000.209 xếp hạng phim từ 6040 người dùng trên 3706 phim.

def download_movielens_data(dataset_path):
  """Downloads and copies MovieLens data to local /tmp directory."""
  if dataset_path.startswith('http'):
    r = requests.get(dataset_path)
    z = zipfile.ZipFile(io.BytesIO(r.content))
    z.extractall(path='/tmp')
  else:
    tf.io.gfile.makedirs('/tmp/ml-1m/')
    for filename in ['ratings.dat', 'movies.dat', 'users.dat']:
      tf.io.gfile.copy(
          os.path.join(dataset_path, filename),
          os.path.join('/tmp/ml-1m/', filename),
          overwrite=True)

download_movielens_data('http://files.grouplens.org/datasets/movielens/ml-1m.zip')
def load_movielens_data(
    data_directory: str = "/tmp",
) -> Tuple[pd.DataFrame, pd.DataFrame]:
  """Loads pandas DataFrames for ratings, movies, users from data directory."""
  # Load pandas DataFrames from data directory. Assuming data is formatted as
  # specified in http://files.grouplens.org/datasets/movielens/ml-1m-README.txt.
  ratings_df = pd.read_csv(
      os.path.join(data_directory, "ml-1m", "ratings.dat"),
      sep="::",
      names=["UserID", "MovieID", "Rating", "Timestamp"], engine="python")
  movies_df = pd.read_csv(
      os.path.join(data_directory, "ml-1m", "movies.dat"),
      sep="::",
      names=["MovieID", "Title", "Genres"], engine="python")

  # Create dictionaries mapping from old IDs to new (remapped) IDs for both
  # MovieID and UserID. Use the movies and users present in ratings_df to
  # determine the mapping, since movies and users without ratings are unneeded.
  movie_mapping = {
      old_movie: new_movie for new_movie, old_movie in enumerate(
          ratings_df.MovieID.astype("category").cat.categories)
  }
  user_mapping = {
      old_user: new_user for new_user, old_user in enumerate(
          ratings_df.UserID.astype("category").cat.categories)
  }

  # Map each DataFrame consistently using the now-fixed mapping.
  ratings_df.MovieID = ratings_df.MovieID.map(movie_mapping)
  ratings_df.UserID = ratings_df.UserID.map(user_mapping)
  movies_df.MovieID = movies_df.MovieID.map(movie_mapping)

  # Remove nulls resulting from some movies being in movies_df but not
  # ratings_df.
  movies_df = movies_df[pd.notnull(movies_df.MovieID)]

  return ratings_df, movies_df

Hãy tải và khám phá một vài Pandas DataFrames chứa dữ liệu xếp hạng và phim.

ratings_df, movies_df = load_movielens_data()

Chúng ta có thể thấy rằng mỗi ví dụ xếp hạng có xếp hạng từ 1-5, UserID tương ứng, MovieID tương ứng và dấu thời gian.

ratings_df.head()

Mỗi bộ phim có một tiêu đề và có nhiều thể loại.

movies_df.head()

Luôn luôn là một ý kiến ​​hay để hiểu các số liệu thống kê cơ bản của tập dữ liệu:

print('Num users:', len(set(ratings_df.UserID)))
print('Num movies:', len(set(ratings_df.MovieID)))
Num users: 6040
Num movies: 3706
ratings = ratings_df.Rating.tolist()

plt.hist(ratings, bins=5)
plt.xticks([1, 2, 3, 4, 5])
plt.ylabel('Count')
plt.xlabel('Rating')
plt.show()

print('Average rating:', np.mean(ratings))
print('Median rating:', np.median(ratings))

png

Average rating: 3.581564453029317
Median rating: 4.0

Chúng tôi cũng có thể sắp xếp các thể loại phim phổ biến nhất.

movie_genres_list = movies_df.Genres.tolist()
# Count the number of times each genre describes a movie.
genre_count = collections.defaultdict(int)
for genres in movie_genres_list:
  curr_genres_list = genres.split('|')
  for genre in curr_genres_list:
    genre_count[genre] += 1
genre_name_list, genre_count_list = zip(*genre_count.items())

plt.figure(figsize=(11, 11))
plt.pie(genre_count_list, labels=genre_name_list)
plt.title('MovieLens Movie Genres')
plt.show()

png

Dữ liệu này được phân chia tự nhiên thành các xếp hạng từ những người dùng khác nhau, vì vậy chúng tôi mong đợi sự không đồng nhất về dữ liệu giữa các khách hàng. Dưới đây, chúng tôi hiển thị các thể loại phim được đánh giá phổ biến nhất cho những người dùng khác nhau. Chúng tôi có thể quan sát thấy sự khác biệt đáng kể giữa những người dùng.

def print_top_genres_for_user(ratings_df, movies_df, user_id):
  """Prints top movie genres for user with ID user_id."""
  user_ratings_df = ratings_df[ratings_df.UserID == user_id]
  movie_ids = user_ratings_df.MovieID

  genre_count = collections.Counter()
  for movie_id in movie_ids:
    genres_string = movies_df[movies_df.MovieID == movie_id].Genres.tolist()[0]
    for genre in genres_string.split('|'):
      genre_count[genre] += 1

  print(f'\nFor user {user_id}:')
  for (genre, freq) in genre_count.most_common(5):
    print(f'{genre} was rated {freq} times')

print_top_genres_for_user(ratings_df, movies_df, user_id=0)
print_top_genres_for_user(ratings_df, movies_df, user_id=10)
print_top_genres_for_user(ratings_df, movies_df, user_id=19)
For user 0:
Drama was rated 21 times
Children's was rated 20 times
Animation was rated 18 times
Musical was rated 14 times
Comedy was rated 14 times

For user 10:
Comedy was rated 84 times
Drama was rated 54 times
Romance was rated 22 times
Thriller was rated 18 times
Action was rated 9 times

For user 19:
Action was rated 17 times
Sci-Fi was rated 9 times
Thriller was rated 9 times
Drama was rated 6 times
Crime was rated 5 times

Xử lý trước dữ liệu MovieLens

Bây giờ chúng ta sẽ chuẩn bị các tập dữ liệu MovieLens như một danh sách tf.data.Dataset s đại diện cho dữ liệu của người dùng để sử dụng với TFF.

Chúng tôi thực hiện hai chức năng:

  • create_tf_datasets : mất DataFrame xếp hạng của chúng tôi và tạo ra một danh sách các người dùng chia tf.data.Dataset s.
  • split_tf_datasets : mất một danh sách các tập hợp dữ liệu và chia chúng thành tàu / val / kiểm tra bằng cách sử dụng, do đó val / bộ kiểm tra có chứa chỉ xếp hạng từ những người dùng vô hình trong đào tạo. Điển hình là trong tiêu chuẩn trung ma trận nhân tử chúng tôi thực sự chia để các bộ val / test chứa điểm lấy ra xếp hạng từ những người dùng nhìn thấy, kể từ khi người dùng không nhìn thấy không có embeddings người dùng. Trong trường hợp của chúng tôi, sau này chúng tôi sẽ thấy rằng phương pháp chúng tôi sử dụng để kích hoạt phân tích nhân tử ma trận trong FL cũng cho phép nhanh chóng tạo lại các nhúng của người dùng cho những người dùng không nhìn thấy.
def create_tf_datasets(ratings_df: pd.DataFrame,
                       batch_size: int = 1,
                       max_examples_per_user: Optional[int] = None,
                       max_clients: Optional[int] = None) -> List[tf.data.Dataset]:
  """Creates TF Datasets containing the movies and ratings for all users."""
  num_users = len(set(ratings_df.UserID))
  # Optionally limit to `max_clients` to speed up data loading.
  if max_clients is not None:
    num_users = min(num_users, max_clients)

  def rating_batch_map_fn(rating_batch):
    """Maps a rating batch to an OrderedDict with tensor values."""
    # Each example looks like: {x: movie_id, y: rating}.
    # We won't need the UserID since each client will only look at their own
    # data.
    return collections.OrderedDict([
        ("x", tf.cast(rating_batch[:, 1:2], tf.int64)),
        ("y", tf.cast(rating_batch[:, 2:3], tf.float32))
    ])

  tf_datasets = []
  for user_id in range(num_users):
    # Get subset of ratings_df belonging to a particular user.
    user_ratings_df = ratings_df[ratings_df.UserID == user_id]

    tf_dataset = tf.data.Dataset.from_tensor_slices(user_ratings_df)

    # Define preprocessing operations.
    tf_dataset = tf_dataset.take(max_examples_per_user).shuffle(
        buffer_size=max_examples_per_user, seed=42).batch(batch_size).map(
        rating_batch_map_fn,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    tf_datasets.append(tf_dataset)

  return tf_datasets


def split_tf_datasets(
    tf_datasets: List[tf.data.Dataset],
    train_fraction: float = 0.8,
    val_fraction: float = 0.1,
) -> Tuple[List[tf.data.Dataset], List[tf.data.Dataset], List[tf.data.Dataset]]:
  """Splits a list of user TF datasets into train/val/test by user.
  """
  np.random.seed(42)
  np.random.shuffle(tf_datasets)

  train_idx = int(len(tf_datasets) * train_fraction)
  val_idx = int(len(tf_datasets) * (train_fraction + val_fraction))

  # Note that the val and test data contains completely different users, not
  # just unseen ratings from train users.
  return (tf_datasets[:train_idx], tf_datasets[train_idx:val_idx],
          tf_datasets[val_idx:])
# We limit the number of clients to speed up dataset creation. Feel free to pass
# max_clients=None to load all clients' data.
tf_datasets = create_tf_datasets(
    ratings_df=ratings_df,
    batch_size=5,
    max_examples_per_user=300,
    max_clients=2000)

# Split the ratings into training/val/test by client.
tf_train_datasets, tf_val_datasets, tf_test_datasets = split_tf_datasets(
    tf_datasets,
    train_fraction=0.8,
    val_fraction=0.1)

Khi kiểm tra nhanh, chúng tôi có thể in một loạt dữ liệu đào tạo. Chúng ta có thể thấy rằng mỗi ví dụ riêng lẻ chứa MovieID dưới phím "x" và xếp hạng dưới phím "y". Lưu ý rằng chúng tôi sẽ không cần UserID vì mỗi người dùng chỉ thấy dữ liệu của riêng họ.

print(next(iter(tf_train_datasets[0])))
OrderedDict([('x', <tf.Tensor: shape=(5, 1), dtype=int64, numpy=
array([[1907],
       [2891],
       [1574],
       [2785],
       [2775]])>), ('y', <tf.Tensor: shape=(5, 1), dtype=float32, numpy=
array([[3.],
       [3.],
       [3.],
       [4.],
       [3.]], dtype=float32)>)])

Chúng tôi có thể vẽ biểu đồ hiển thị số lượng xếp hạng trên mỗi người dùng.

def count_examples(curr_count, batch):
  return curr_count + tf.size(batch['x'])

num_examples_list = []
# Compute number of examples for every other user.
for i in range(0, len(tf_train_datasets), 2):
  num_examples = tf_train_datasets[i].reduce(tf.constant(0), count_examples).numpy()
  num_examples_list.append(num_examples)

plt.hist(num_examples_list, bins=10)
plt.ylabel('Count')
plt.xlabel('Number of Examples')
plt.show()

png

Bây giờ chúng ta đã tải và khám phá dữ liệu, chúng ta sẽ thảo luận về cách đưa sự phân tích nhân tử của ma trận vào học liên kết. Trong quá trình này, chúng tôi sẽ thúc đẩy học tập liên kết cục bộ một phần.

Đưa dữ liệu hóa ma trận vào FL

Mặc dù thừa số hóa ma trận đã được sử dụng theo cách truyền thống trong cài đặt tập trung, nhưng nó đặc biệt có liên quan trong học tập liên kết: xếp hạng người dùng có thể tồn tại trên các thiết bị khách riêng biệt và chúng tôi có thể muốn tìm hiểu cách nhúng và đề xuất cho người dùng và các mục mà không cần tập trung dữ liệu. Vì mỗi người dùng có một người dùng nhúng tương ứng, điều tự nhiên là mỗi khách hàng lưu trữ bản nhúng của người dùng của họ – điều này sẽ tốt hơn nhiều so với một máy chủ trung tâm lưu trữ tất cả các bản nhúng của người dùng.

Một đề xuất để đưa dữ liệu hóa ma trận sang FL như sau:

  1. Các cửa hàng máy chủ và gửi ma trận mục \(I\) cho các khách hàng lấy mẫu mỗi vòng
  2. Khách hàng cập nhật bảng kê hàng và người sử dụng cá nhân của họ nhúng \(U_u\) sử dụng SGD vào mục tiêu trên
  3. Cập nhật cho \(I\) được tổng hợp trên máy chủ, cập nhật các bản sao máy chủ của \(I\) cho vòng tiếp theo

Cách tiếp cận này là một phần địa phương -đó là, một số thông số khách hàng không bao giờ được tổng hợp bởi máy chủ. Mặc dù cách tiếp cận này hấp dẫn, nhưng nó yêu cầu khách hàng duy trì trạng thái qua các vòng, cụ thể là người dùng nhúng của họ. Các thuật toán liên kết trạng thái ít thích hợp hơn cho cài đặt FL trên nhiều thiết bị: trong các cài đặt này, quy mô dân số thường lớn hơn nhiều so với số lượng khách hàng tham gia vào mỗi vòng và một khách hàng thường tham gia nhiều nhất một lần trong quá trình đào tạo. Bên cạnh đó dựa vào nhà nước mà có thể không được khởi tạo, các thuật toán trạng thái có thể dẫn đến suy giảm hiệu suất trong cài đặt trên nhiều thiết bị do Nhà nước nhận được khi khách hàng được thường xuyên lấy mẫu. Quan trọng là, trong cài đặt thừa số hóa ma trận, một thuật toán trạng thái dẫn đến tất cả các ứng dụng khách không nhìn thấy được thiếu nhúng người dùng được đào tạo và trong quá trình đào tạo quy mô lớn, phần lớn người dùng có thể không nhìn thấy. Để biết thêm về các động lực cho các thuật toán không quốc tịch ở trên nhiều thiết bị FL, xem Wang et al. Năm 2021 3.1.1Reddi et al. 2020 Giây 5.1 .

Federated Tái thiết ( Singhal et al. 2021 ) là một lựa chọn quốc tịch cho cách tiếp cận nói trên. Ý tưởng chính là thay vì lưu trữ các lần nhúng của người dùng qua các vòng, khách hàng sẽ cấu trúc lại các lần nhúng của người dùng khi cần thiết. Khi FedRecon được áp dụng để phân tích nhân tử ma trận, quá trình huấn luyện sẽ diễn ra như sau:

  1. Các cửa hàng máy chủ và gửi ma trận mục \(I\) cho các khách hàng lấy mẫu mỗi vòng
  2. Mỗi khách hàng đóng băng \(I\) và huấn luyện sử dụng nhúng của họ \(U_u\) sử dụng một hoặc nhiều bước của SGD (tái)
  3. Mỗi khách hàng đóng băng \(U_u\) và xe lửa \(I\) sử dụng một hoặc nhiều bước của SGD
  4. Cập nhật cho \(I\) được tổng hợp trên người dùng, cập nhật các bản sao máy chủ của \(I\) cho vòng tiếp theo

Cách tiếp cận này không yêu cầu khách hàng duy trì trạng thái qua các vòng. Các tác giả cũng cho thấy trong bài báo rằng phương pháp này dẫn đến việc tái tạo nhanh quá trình nhúng của người dùng cho các khách hàng không nhìn thấy (Phần 4.2, Hình 3 và Bảng 1), cho phép phần lớn khách hàng không tham gia đào tạo có một mô hình được đào tạo , cho phép đề xuất cho những khách hàng này. Xem Federated Tái Blog Google AI bài cho kết quả quan trọng hơn.

Xác định mô hình

Tiếp theo, chúng tôi sẽ xác định mô hình phân tích nhân tử ma trận cục bộ sẽ được đào tạo trên các thiết bị khách. Mô hình này sẽ bao gồm đầy đủ ma trận mục \(I\) và người dùng nhúng đơn \(U_u\) cho khách hàng \(u\). Lưu ý rằng khách hàng sẽ không cần phải lưu trữ đầy đủ sử dụng ma trận \(U\).

Chúng tôi sẽ xác định những điều sau:

  • UserEmbedding : một lớp Keras đơn giản đại diện cho một đơn num_latent_factors nhúng sử dụng chiều.
  • get_matrix_factorization_model : một hàm trả về một tff.learning.reconstruction.Model chứa logic mô hình, trong đó có những lớp được tổng hợp trên toàn cầu trên máy chủ và những lớp vẫn còn địa phương. Chúng tôi cần thông tin bổ sung này để khởi tạo quy trình đào tạo Tái thiết Liên bang. Dưới đây chúng tôi sản xuất các tff.learning.reconstruction.Model từ một mô hình Keras sử dụng tff.learning.reconstruction.from_keras_model . Tương tự như tff.learning.Model , chúng tôi cũng có thể thực hiện một phong tục tff.learning.reconstruction.Model bằng cách thực hiện các giao diện lớp.
class UserEmbedding(tf.keras.layers.Layer):
  """Keras layer representing an embedding for a single user, used below."""

  def __init__(self, num_latent_factors, **kwargs):
    super().__init__(**kwargs)
    self.num_latent_factors = num_latent_factors

  def build(self, input_shape):
    self.embedding = self.add_weight(
        shape=(1, self.num_latent_factors),
        initializer='uniform',
        dtype=tf.float32,
        name='UserEmbeddingKernel')
    super().build(input_shape)

  def call(self, inputs):
    return self.embedding

  def compute_output_shape(self):
    return (1, self.num_latent_factors)


def get_matrix_factorization_model(
    num_items: int,
    num_latent_factors: int) -> tff.learning.reconstruction.Model:
  """Defines a Keras matrix factorization model."""
  # Layers with variables will be partitioned into global and local layers.
  # We'll pass this to `tff.learning.reconstruction.from_keras_model`.
  global_layers = []
  local_layers = []

  # Extract the item embedding.
  item_input = tf.keras.layers.Input(shape=[1], name='Item')
  item_embedding_layer = tf.keras.layers.Embedding(
      num_items,
      num_latent_factors,
      name='ItemEmbedding')
  global_layers.append(item_embedding_layer)
  flat_item_vec = tf.keras.layers.Flatten(name='FlattenItems')(
      item_embedding_layer(item_input))

  # Extract the user embedding.
  user_embedding_layer = UserEmbedding(
      num_latent_factors,
      name='UserEmbedding')
  local_layers.append(user_embedding_layer)

  # The item_input never gets used by the user embedding layer,
  # but this allows the model to directly use the user embedding.
  flat_user_vec = user_embedding_layer(item_input)

  # Compute the dot product between the user embedding, and the item one.
  pred = tf.keras.layers.Dot(
      1, normalize=False, name='Dot')([flat_user_vec, flat_item_vec])

  input_spec = collections.OrderedDict(
      x=tf.TensorSpec(shape=[None, 1], dtype=tf.int64),
      y=tf.TensorSpec(shape=[None, 1], dtype=tf.float32))

  model = tf.keras.Model(inputs=item_input, outputs=pred)

  return tff.learning.reconstruction.from_keras_model(
      keras_model=model,
      global_layers=global_layers,
      local_layers=local_layers,
      input_spec=input_spec)

Analagous đến giao diện cho Federated trung bình, giao diện cho Federated Tái hy vọng một model_fn không có đối số mà trả về một tff.learning.reconstruction.Model .

# This will be used to produce our training process.
# User and item embeddings will be 50-dimensional.
model_fn = functools.partial(
    get_matrix_factorization_model,
    num_items=3706,
    num_latent_factors=50)

Chúng tôi sẽ tiếp theo xác định loss_fnmetrics_fn , nơi loss_fn là một hàm không có đối số trả về một mất mát Keras sử dụng để đào tạo các mô hình, và metrics_fn là một hàm không có đối số trả về một danh sách các số liệu Keras để đánh giá. Chúng cần thiết để xây dựng các tính toán đào tạo và đánh giá.

Chúng tôi sẽ sử dụng Mean Squared Error làm tổn thất, như đã đề cập ở trên. Để đánh giá, chúng tôi sẽ sử dụng độ chính xác xếp hạng (khi sản phẩm chấm dự đoán của mô hình được làm tròn đến số nguyên gần nhất, tần suất nó khớp với xếp hạng nhãn là bao nhiêu?).

class RatingAccuracy(tf.keras.metrics.Mean):
  """Keras metric computing accuracy of reconstructed ratings."""

  def __init__(self,
               name: str = 'rating_accuracy',
               **kwargs):
    super().__init__(name=name, **kwargs)

  def update_state(self,
                   y_true: tf.Tensor,
                   y_pred: tf.Tensor,
                   sample_weight: Optional[tf.Tensor] = None):
    absolute_diffs = tf.abs(y_true - y_pred)
    # A [batch_size, 1] tf.bool tensor indicating correctness within the
    # threshold for each example in a batch. A 0.5 threshold corresponds
    # to correctness when predictions are rounded to the nearest whole
    # number.
    example_accuracies = tf.less_equal(absolute_diffs, 0.5)
    super().update_state(example_accuracies, sample_weight=sample_weight)


loss_fn = lambda: tf.keras.losses.MeanSquaredError()
metrics_fn = lambda: [RatingAccuracy()]

Đào tạo và Đánh giá

Bây giờ chúng tôi có mọi thứ chúng tôi cần để xác định quá trình đào tạo. Một khác biệt quan trọng từ giao diện cho Federated trung bình là bây giờ chúng ta vượt qua trong một reconstruction_optimizer_fn , mà sẽ được sử dụng khi xây dựng lại các thông số địa phương (trong trường hợp của chúng tôi, embeddings người dùng). Nó thường hợp lý để sử dụng SGD ở đây, với một tương tự hoặc hơi hạ thấp tỷ lệ học hơn so với khách hàng tối ưu hóa tốc độ học tập. Chúng tôi cung cấp cấu hình hoạt động bên dưới. Điều này chưa được điều chỉnh cẩn thận, vì vậy hãy thoải mái thử nghiệm với các giá trị khác nhau.

Kiểm tra các tài liệu hướng dẫn để biết thêm chi tiết và các tùy chọn.

# We'll use this by doing:
# state = training_process.initialize()
# state, metrics = training_process.next(state, federated_train_data)
training_process = tff.learning.reconstruction.build_training_process(
    model_fn=model_fn,
    loss_fn=loss_fn,
    metrics_fn=metrics_fn,
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0),
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.5),
    reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1))

Chúng tôi cũng có thể xác định một phép tính để đánh giá mô hình toàn cầu được đào tạo của chúng tôi.

# We'll use this by doing:
# eval_metrics = evaluation_computation(state.model, tf_val_datasets)
# where `state` is the state from the training process above.
evaluation_computation = tff.learning.reconstruction.build_federated_evaluation(
    model_fn,
    loss_fn=loss_fn,
    metrics_fn=metrics_fn,
    reconstruction_optimizer_fn=functools.partial(
            tf.keras.optimizers.SGD, 0.1))

Chúng ta có thể khởi tạo trạng thái quá trình đào tạo và kiểm tra nó. Quan trọng nhất, chúng ta có thể thấy rằng trạng thái máy chủ này chỉ lưu trữ các biến mặt hàng (hiện được khởi tạo ngẫu nhiên) chứ không phải bất kỳ nhúng nào của người dùng.

state = training_process.initialize()
print(state.model)
print('Item variables shape:', state.model.trainable[0].shape)
ModelWeights(trainable=[array([[-0.02840446,  0.01196523, -0.01864688, ...,  0.03020107,
         0.00121176,  0.00146852],
       [ 0.01330637,  0.04741272, -0.01487445, ..., -0.03352419,
         0.0104811 ,  0.03506917],
       [-0.04132779,  0.04883525, -0.04799002, ...,  0.00246904,
         0.00586842,  0.01506213],
       ...,
       [ 0.0216659 ,  0.00734354,  0.00471039, ...,  0.01596491,
        -0.00220431, -0.01559857],
       [-0.00319657, -0.01740328,  0.02808609, ..., -0.00501985,
        -0.03850871, -0.03844522],
       [ 0.03791947, -0.00035037,  0.04217024, ...,  0.00365371,
         0.00283421,  0.00897921]], dtype=float32)], non_trainable=[])
Item variables shape: (3706, 50)

Chúng tôi cũng có thể thử đánh giá mô hình được khởi tạo ngẫu nhiên của mình trên các máy khách xác thực. Đánh giá tái thiết liên kết ở đây bao gồm những điều sau:

  1. Các máy chủ gửi các ma trận mục \(I\) cho các khách hàng đánh giá mẫu
  2. Mỗi khách hàng đóng băng \(I\) và huấn luyện sử dụng nhúng của họ \(U_u\) sử dụng một hoặc nhiều bước của SGD (tái)
  3. Mỗi khách hàng tính toán của mất mát và số liệu sử dụng máy chủ \(I\) và xây dựng lại \(U_u\) trên một phần vô hình của dữ liệu địa phương
  4. Tổn thất và chỉ số được tính trung bình giữa những người dùng để tính toán tổn thất tổng thể và chỉ số

Lưu ý rằng bước 1 và bước 2 giống như đối với đào tạo. Kết nối này là rất quan trọng, vì đào tạo theo cùng một cách chúng tôi đánh giá dẫn đến một hình thức meta-học tập, hay học cách để học hỏi. Trong trường hợp này, mô hình đang học cách tìm hiểu các biến toàn cục (ma trận mục) dẫn đến việc tái tạo hiệu quả các biến cục bộ (người dùng nhúng). Để biết thêm về vấn đề này, xem Sec. 4.2 của bài báo.

Điều quan trọng là các bước 2 và 3 phải được thực hiện bằng cách sử dụng các phần riêng biệt của dữ liệu cục bộ của khách hàng, để đảm bảo đánh giá công bằng. Theo mặc định, cả quá trình đào tạo và tính toán đánh giá đều sử dụng mọi ví dụ khác để tái tạo và sử dụng nửa còn lại sau tái tạo. Hành vi này có thể được tùy chỉnh bằng cách sử dụng dataset_split_fn luận (chúng tôi sẽ khám phá này hơn nữa sau).

# We shouldn't expect good evaluation results here, since we haven't trained
# yet!
eval_metrics = evaluation_computation(state.model, tf_val_datasets)
print('Initial Eval:', eval_metrics['eval'])
Initial Eval: OrderedDict([('loss', 14.340279), ('rating_accuracy', 0.0)])

Tiếp theo chúng ta có thể thử chạy một vòng đào tạo. Để làm cho mọi thứ thực tế hơn, chúng tôi sẽ lấy mẫu ngẫu nhiên 50 khách hàng mỗi vòng mà không cần thay thế. Chúng ta vẫn nên mong đợi các chỉ số đào tạo kém, vì chúng tôi chỉ thực hiện một vòng đào tạo.

federated_train_data = np.random.choice(tf_train_datasets, size=50, replace=False).tolist()
state, metrics = training_process.next(state, federated_train_data)
print(f'Train metrics:', metrics['train'])
Train metrics: OrderedDict([('rating_accuracy', 0.0), ('loss', 14.317455)])

Bây giờ chúng ta hãy thiết lập một vòng huấn luyện để huấn luyện qua nhiều vòng.

NUM_ROUNDS = 20

train_losses = []
train_accs = []

state = training_process.initialize()

# This may take a couple minutes to run.
for i in range(NUM_ROUNDS):
  federated_train_data = np.random.choice(tf_train_datasets, size=50, replace=False).tolist()
  state, metrics = training_process.next(state, federated_train_data)
  print(f'Train round {i}:', metrics['train'])
  train_losses.append(metrics['train']['loss'])
  train_accs.append(metrics['train']['rating_accuracy'])


eval_metrics = evaluation_computation(state.model, tf_val_datasets)
print('Final Eval:', eval_metrics['eval'])
Train round 0: OrderedDict([('rating_accuracy', 0.0), ('loss', 14.7013445)])
Train round 1: OrderedDict([('rating_accuracy', 0.0), ('loss', 14.459233)])
Train round 2: OrderedDict([('rating_accuracy', 0.0), ('loss', 14.52466)])
Train round 3: OrderedDict([('rating_accuracy', 0.0), ('loss', 14.087793)])
Train round 4: OrderedDict([('rating_accuracy', 0.011243612), ('loss', 11.110232)])
Train round 5: OrderedDict([('rating_accuracy', 0.06366048), ('loss', 8.267054)])
Train round 6: OrderedDict([('rating_accuracy', 0.12331288), ('loss', 5.2693872)])
Train round 7: OrderedDict([('rating_accuracy', 0.14264487), ('loss', 5.1511016)])
Train round 8: OrderedDict([('rating_accuracy', 0.21046545), ('loss', 3.8246362)])
Train round 9: OrderedDict([('rating_accuracy', 0.21320973), ('loss', 3.303812)])
Train round 10: OrderedDict([('rating_accuracy', 0.21651311), ('loss', 3.4864292)])
Train round 11: OrderedDict([('rating_accuracy', 0.23476052), ('loss', 3.0105433)])
Train round 12: OrderedDict([('rating_accuracy', 0.21981856), ('loss', 3.1807854)])
Train round 13: OrderedDict([('rating_accuracy', 0.27683082), ('loss', 2.3382564)])
Train round 14: OrderedDict([('rating_accuracy', 0.26080742), ('loss', 2.7009728)])
Train round 15: OrderedDict([('rating_accuracy', 0.2733109), ('loss', 2.2993557)])
Train round 16: OrderedDict([('rating_accuracy', 0.29282996), ('loss', 2.5278995)])
Train round 17: OrderedDict([('rating_accuracy', 0.30204678), ('loss', 2.060092)])
Train round 18: OrderedDict([('rating_accuracy', 0.2940266), ('loss', 2.0976772)])
Train round 19: OrderedDict([('rating_accuracy', 0.3086304), ('loss', 2.0626144)])
Final Eval: OrderedDict([('loss', 1.9961331), ('rating_accuracy', 0.30322924)])

Chúng tôi có thể lập biểu đồ về sự mất mát và độ chính xác trong quá trình huấn luyện qua các vòng. Các siêu tham số trong sổ tay này chưa được điều chỉnh cẩn thận, vì vậy hãy thử các khách hàng khác nhau mỗi vòng, tỷ lệ học tập, số vòng và tổng số khách hàng để cải thiện các kết quả này.

plt.plot(range(NUM_ROUNDS), train_losses)
plt.ylabel('Train Loss')
plt.xlabel('Round')
plt.title('Train Loss')
plt.show()

plt.plot(range(NUM_ROUNDS), train_accs)
plt.ylabel('Train Accuracy')
plt.xlabel('Round')
plt.title('Train Accuracy')
plt.show()

png

png

Cuối cùng, chúng tôi có thể tính toán các chỉ số trên một tập hợp thử nghiệm chưa nhìn thấy khi chúng tôi điều chỉnh xong.

eval_metrics = evaluation_computation(state.model, tf_test_datasets)
print('Final Test:', eval_metrics['eval'])
Final Test: OrderedDict([('loss', 1.9566978), ('rating_accuracy', 0.30792442)])

Khám phá thêm

Rất tốt khi hoàn thành cuốn sổ này. Chúng tôi đề xuất các bài tập sau để khám phá thêm việc học liên kết cục bộ từng phần, được sắp xếp theo thứ tự tăng dần theo độ khó:

  • Các triển khai điển hình của Tính trung bình liên kết thực hiện nhiều lần chuyển cục bộ (kỷ nguyên) đối với dữ liệu (ngoài việc thực hiện một lần chuyển dữ liệu qua nhiều đợt). Đối với Tái thiết Liên bang, chúng tôi có thể muốn kiểm soát số lượng các bước riêng biệt cho quá trình tái thiết và đào tạo sau tái thiết. Đi qua các dataset_split_fn tranh luận với các nhà xây dựng tính toán đào tạo và đánh giá cho phép kiểm soát số lượng các bước và các thời kỳ trên cả hai tái thiết và sau khi xây dựng lại bộ dữ liệu. Như một bài tập, hãy thử thực hiện 3 kỷ nguyên địa phương của đào tạo tái thiết, giới hạn ở 50 bước và 1 kỷ nguyên địa phương của đào tạo sau tái thiết, giới hạn ở 50 bước. Gợi ý: Bạn có thể tìm thấy tff.learning.reconstruction.build_dataset_split_fn hữu ích. Khi bạn đã hoàn thành việc này, hãy thử điều chỉnh các siêu tham số này và các siêu tham số khác có liên quan như tốc độ học tập và kích thước lô để có được kết quả tốt hơn.

  • Hành vi mặc định của đào tạo và đánh giá Liên kết Tái thiết là chia đôi dữ liệu cục bộ của khách hàng cho mỗi quá trình tái tạo và sau xây dựng lại. Trong trường hợp khách hàng có rất ít dữ liệu cục bộ, có thể hợp lý nếu chỉ sử dụng lại dữ liệu để tái tạo và sau xây dựng lại cho quá trình đào tạo (không dùng để đánh giá, điều này sẽ dẫn đến đánh giá không công bằng). Cố gắng làm cho sự thay đổi này cho quá trình đào tạo, đảm bảo dataset_split_fn để đánh giá vẫn giữ tái thiết và sau khi xây dựng lại dữ liệu rời nhau. Gợi ý: tff.learning.reconstruction.simple_dataset_split_fn có thể có ích.

  • Ở trên, chúng tôi đã tạo ra một tff.learning.Model từ một mô hình Keras sử dụng tff.learning.reconstruction.from_keras_model . Chúng tôi cũng có thể thực hiện một mô hình tùy chỉnh sử dụng TensorFlow tinh khiết 2.0 thực hiện giao diện mô hình . Hãy thử thay đổi get_matrix_factorization_model để xây dựng và trả về một lớp học kéo dài tff.learning.reconstruction.Model , thực hiện các phương pháp của nó. Gợi ý: mã nguồn của tff.learning.reconstruction.from_keras_model cung cấp một ví dụ về việc mở rộng tff.learning.reconstruction.Model lớp. Tham khảo còn cho triển khai mô hình tùy chỉnh trong hình ảnh EMNIST phân loại hướng dẫn cho một cuộc tập trận tương tự trong việc mở rộng một tff.learning.Model .

  • Trong hướng dẫn này, chúng tôi đã thúc đẩy việc học liên kết cục bộ một phần trong bối cảnh phân tích nhân tử ma trận, trong đó việc gửi nhúng người dùng đến máy chủ sẽ làm rò rỉ một cách đáng kể các tùy chọn của người dùng. Chúng ta cũng có thể áp dụng Xây dựng lại liên kết trong các cài đặt khác như một cách để đào tạo các mô hình cá nhân hơn (vì một phần của mô hình là hoàn toàn cục bộ đối với mỗi người dùng) trong khi giảm giao tiếp (vì các tham số cục bộ không được gửi đến máy chủ). Nói chung, bằng cách sử dụng giao diện được trình bày ở đây, chúng ta có thể sử dụng bất kỳ mô hình liên hợp nào thường được đào tạo toàn cầu và thay vào đó phân vùng các biến của nó thành các biến toàn cục và biến cục bộ. Ví dụ khám phá trong giấy Federated Tái thiết là cá nhân dự đoán từ tiếp theo: ở đây, mỗi người dùng có bộ địa phương riêng của họ embeddings lời cho out-of-từ vựng từ, cho phép các mô hình để lóng chụp của người sử dụng và đạt được cá nhân hóa mà không cần giao tiếp bổ sung. Như một bài tập, hãy thử triển khai (dưới dạng mô hình Keras hoặc mô hình TensorFlow 2.0 tùy chỉnh) một mô hình khác để sử dụng với Cấu trúc liên kết. Gợi ý: triển khai mô hình phân loại EMNIST với nhúng người dùng cá nhân, trong đó nhúng người dùng cá nhân được nối với các tính năng hình ảnh CNN trước lớp Dày cuối cùng của mô hình. Bạn có thể tái sử dụng nhiều mã từ hướng dẫn này (ví dụ như UserEmbedding lớp) và hình ảnh phân loại hướng dẫn .


Nếu bạn vẫn đang tìm kiếm thêm về việc học liên phần địa phương, kiểm tra giấy Federated Tái thiếtmã nguồn mở mã thử nghiệm .