tffのClientDataの操作。

TensorFlow.orgで表示GoogleColabで実行GitHubでソースを表示 ノートブックをダウンロード

TFFでモデル化されたフェデレーション計算には、クライアント(ユーザーなど)によってキー設定されたデータセットの概念が不可欠です。 TFFは、インタフェース提供tff.simulation.datasets.ClientDataこの概念上抽象的に、及びTFFホスト(データセットstackoverflowのシェイクスピアemnistcifar100 、及びgldv2は)全てこのインタフェースを実装します。

あなたがあなた自身のデータセットで連合学習に取り組んでいる場合は、TFFは強く実装するか、あなたを奨励ClientData生成するTFFのヘルパー関数のインタフェースまたは使用1をClientData例えば、ディスク上のデータを表しtff.simulation.datasets.ClientData.from_clients_and_fn

TFFのエンドツーエンドの例のほとんどは、で始まるのようClientData実装し、オブジェクトClientDataカスタムデータセットとのインタフェースは、それが簡単にTFFで書かれた既存のコードをspelunkするようになります。さらに、 tf.data.Datasets ClientData構築物の構造を生成する直接反復処理することができるnumpyように、アレイをClientDataオブジェクトは、TFFに移動する前に、PythonベースのMLフレームワークで使用することができます。

シミュレーションを多くのマシンにスケールアップしたり、それらをデプロイしたりする場合は、いくつかのパターンを使用して作業を楽にすることができます。我々は我々が使うことができる方法のいくつかを歩いて下記ClientData私たちの小さな規模の大規模な反復-への生産実験-への展開の経験をするとTFFをできるだけスムーズ。

ClientDataをTFFに渡すためにどのパターンを使用する必要がありますか?

私たちは、TFFの二使い方について説明しますClientData深さを。以下の2つのカテゴリのいずれかに当てはまる場合は、明らかに一方を他方よりも優先します。そうでない場合は、より微妙な選択を行うために、それぞれの長所と短所をより詳細に理解する必要があります。

  • ローカルマシンでできるだけ早く反復したい。 TFFの分散ランタイムを簡単に利用できる必要はありません。

    • あなたは渡したいtf.data.Datasets直接TFFにして。
    • これは、あなたが命令的にプログラムすることができますtf.data.Datasetオブジェクト、および任意にそれらを処理します。
    • 以下のオプションよりも柔軟性があります。ロジックをクライアントにプッシュするには、このロジックがシリアル化可能である必要があります。
  • フェデレーション計算をTFFのリモートランタイムで実行したい、またはすぐに実行する予定です。

    • この場合、データセットの構築と前処理をクライアントにマッピングする必要があります。
    • あなたのこの結果は、単にのリスト渡しclient_ids直接あなたの連合の計算に。
    • データセットの構築と前処理をクライアントにプッシュすることで、シリアル化のボトルネックを回避し、数百から数千のクライアントでパフォーマンスを大幅に向上させます。

オープンソース環境をセットアップする

パッケージをインポートする

ClientDataオブジェクトの操作

のは、ロードとTFFのEMNIST模索することから始めましょうClientData

client_data, _ = tff.simulation.datasets.emnist.load_data()
Downloading emnist_all.sqlite.lzma: 100%|██████████| 170507172/170507172 [00:19<00:00, 8831921.67it/s]
2021-10-01 11:17:58.718735: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

最初のデータセットを検査するにある例のどのような種類の私たちに伝えることができClientData

first_client_id = client_data.client_ids[0]
first_client_dataset = client_data.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
# This information is also available as a `ClientData` property:
assert client_data.element_type_structure == first_client_dataset.element_spec
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])

データセットが得られることに留意されたいcollections.OrderedDict有するオブジェクトpixelslabel画素は形状のテンソルであるキーを、 [28, 28]我々は形に私達の入力を平らたいと仮定し[784]我々はこれを行うことができます1つの可能な方法は、私たちに前処理機能を適用することであろうClientDataオブジェクト。

def preprocess_dataset(dataset):
  """Create batches of 5 examples, and limit to 3 batches."""

  def map_fn(input):
    return collections.OrderedDict(
        x=tf.reshape(input['pixels'], shape=(-1, 784)),
        y=tf.cast(tf.reshape(input['label'], shape=(-1, 1)), tf.int64),
    )

  return dataset.batch(5).map(
      map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)


preprocessed_client_data = client_data.preprocess(preprocess_dataset)

# Notice that we have both reshaped and renamed the elements of the ordered dict.
first_client_dataset = preprocessed_client_data.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])

さらに、シャッフルなど、より複雑な(場合によってはステートフルな)前処理を実行したい場合があります。

def preprocess_and_shuffle(dataset):
  """Applies `preprocess_dataset` above and shuffles the result."""
  preprocessed = preprocess_dataset(dataset)
  return preprocessed.shuffle(buffer_size=5)

preprocessed_and_shuffled = client_data.preprocess(preprocess_and_shuffle)

# The type signature will remain the same, but the batches will be shuffled.
first_client_dataset = preprocessed_and_shuffled.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])

インタフェースtff.Computation

今、我々がいくつかの基本的な操作を行うことができClientData我々はにフィードデータへの準備ができている、オブジェクトtff.Computation 。私たちは、定義tff.templates.IterativeProcess実装フェデレーション平均化を、そしてそれにデータを渡すの異なる方法を探ります。

def model_fn():
  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
  ])
  return tff.learning.from_keras_model(
      model,
      # Note: input spec is the _batched_ shape, and includes the 
      # label tensor which will be passed to the loss function. This model is
      # therefore configured to accept data _after_ it has been preprocessed.
      input_spec=collections.OrderedDict(
          x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
          y=tf.TensorSpec(shape=[None, 1], dtype=tf.int64)),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

trainer = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.01))

私たちはこれで作業を開始する前にIterativeProcess 、の意味論上の1件のコメントClientDataオーダーです。 ClientData目的は、一般にある連合訓練のために利用できる人口の全体を表す生産FLシステムの実行環境では使用できませんし、シミュレーションに固有のものですが。 ClientData実際に完全にユーザにバイパスフェデレーテッド・コンピューティング能力を与え、単に介して通常どおりサーバ側モデルを訓練ClientData.create_tf_dataset_from_all_clients

TFFのシミュレーション環境により、研究者は外側のループを完全に制御できます。特に、これは、クライアントの可用性、クライアントのドロップアウトなどの考慮事項を、ユーザーまたはPythonドライバースクリプトで対処する必要があることを意味します。一つは、あなたの上に標本分布を調整することにより、例えば、モデルクライアントドロップアウトのためにできたClientData's client_idsより多くのデータを持つユーザー(それに対応し、ローカル計算を長いが、実行されている)ような低確率で選択されることになります。

ただし、実際の連合システムでは、モデルトレーナーがクライアントを明示的に選択することはできません。クライアントの選択は、フェデレーション計算を実行しているシステムに委任されます。

渡すtf.data.Datasets TFFに直接

我々は間のインタフェースを持っている一つの選択肢ClientDataIterativeProcess構築のそれであるtf.data.Datasets Pythonで、そしてTFFにこれらのデータセットを渡します。

私達は私達の前処理を使用する場合ことに注意してくださいClientData我々が得たデータセットは、上記で定義された我々のモデルにより予想の適切なタイプです。

selected_client_ids = preprocessed_and_shuffled.client_ids[:10]

preprocessed_data_for_clients = [
    preprocessed_and_shuffled.create_tf_dataset_for_client(
        selected_client_ids[i]) for i in range(10)
]

state = trainer.initialize()
for _ in range(5):
  t1 = time.time()
  state, metrics = trainer.next(state, preprocessed_data_for_clients)
  t2 = time.time()
  print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
loss 2.9005744457244873, round time 4.576513767242432
loss 3.113278388977051, round time 0.49641919136047363
loss 2.7581865787506104, round time 0.4904160499572754
loss 2.87259578704834, round time 0.48976993560791016
loss 3.1202380657196045, round time 0.6724586486816406

我々はこのルートを取る場合は、しかし、我々は自明多機シミュレーションに移動することができません。私たちは地元のTensorFlowランタイムで構築するデータセットは、周囲のPython環境から状態をキャプチャし、そして、彼らはもはやそれらに利用可能である基準状態にしようとすると、直列化または直列化復元に失敗することができます。これはTensorFlowのからの不可解なエラーの例のマニフェストことができtensor_util.cc

Check failed: DT_VARIANT == input.dtype() (21 vs. 20)

クライアント上でのマッピング構築と前処理

この問題を回避するには、TFFは、そのユーザーが各クライアント上でローカルに発生したものとして、データセットのインスタンス化と前処理を検討することをお勧めします、とTFFのヘルパーまたは使用するfederated_map明示的に各クライアントで、この前処理のコードを実行します。

概念的には、これを好む理由は明らかです。TFFのローカルランタイムでは、フェデレーションオーケストレーション全体が単一のマシンで行われているため、クライアントは「誤って」グローバルPython環境にのみアクセスできます。この時点で、同様の考え方がTFFのクロスプラットフォームで、常にシリアル化可能な機能哲学を生み出すことは注目に値します。

TFFは、経由して、このような変更が簡単になりますClientData's属性dataset_computationtff.Computation取りclient_idおよび関連返しtf.data.Dataset

注意preprocess単純で動作しますdataset_computationdataset_computation前処理の属性ClientData我々だけで定義された全体の前処理パイプラインが組み込まれています。

print('dataset computation without preprocessing:')
print(client_data.dataset_computation.type_signature)
print('\n')
print('dataset computation with preprocessing:')
print(preprocessed_and_shuffled.dataset_computation.type_signature)
dataset computation without preprocessing:
(string -> <label=int32,pixels=float32[28,28]>*)


dataset computation with preprocessing:
(string -> <x=float32[?,784],y=int64[?,1]>*)

私たちは、呼び出すことができdataset_computationし、Pythonランタイムでの熱心なデータセットを受け取り、私たちは、すべてのグローバル熱心ランタイムでこれらのデータセットを実体化を避けるために、反復プロセスまたは別の計算で構成したときに、このアプローチの本当の力を行使しています。 TFFは、ヘルパー関数を提供しtff.simulation.compose_dataset_computation_with_iterative_processまさにこれを行うために使用することができます。

trainer_accepting_ids = tff.simulation.compose_dataset_computation_with_iterative_process(
    preprocessed_and_shuffled.dataset_computation, trainer)

この両方tff.templates.IterativeProcessesと同じように実行する上記1。前者前処理され、クライアントデータセットを受け入れ、後者はクライアントIDを表す文字列、両方のデータセットの構築を処理し、その本体に前処理受け付け-実際にstate 2の間を通過することができます。

for _ in range(5):
  t1 = time.time()
  state, metrics = trainer_accepting_ids.next(state, selected_client_ids)
  t2 = time.time()
  print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
loss 2.8417396545410156, round time 1.6707067489624023
loss 2.7670371532440186, round time 0.5207102298736572
loss 2.665048122406006, round time 0.5302855968475342
loss 2.7213189601898193, round time 0.5313887596130371
loss 2.580148935317993, round time 0.5283482074737549

多数のクライアントへのスケーリング

trainer_accepting_idsすぐにTFFのマルチマシンの実行時に使用することができ、そしてことを回避するには、マテリアライズtf.data.Datasetsとコントローラを(したがって、それらをシリアル化し、労働者にそれらを送信します)。

これにより、特に多数のクライアントで分散シミュレーションが大幅に高速化され、中間集約が可能になり、同様のシリアル化/逆シリアル化のオーバーヘッドを回避できます。

オプションの詳細:TFFで前処理ロジックを手動で作成する

TFFは、ゼロから構成性を実現するように設計されています。 TFFのヘルパーによって実行されたばかりの種類の構成は、ユーザーとして完全に制御できます。私たちはただで定義された前処理計算を手動で構成する必要があり可能性があり、トレーナー自身のnext非常に単純:

selected_clients_type = tff.FederatedType(preprocessed_and_shuffled.dataset_computation.type_signature.parameter, tff.CLIENTS)

@tff.federated_computation(trainer.next.type_signature.parameter[0], selected_clients_type)
def new_next(server_state, selected_clients):
  preprocessed_data = tff.federated_map(preprocessed_and_shuffled.dataset_computation, selected_clients)
  return trainer.next(server_state, preprocessed_data)

manual_trainer_with_preprocessing = tff.templates.IterativeProcess(initialize_fn=trainer.initialize, next_fn=new_next)

実際、これは事実上、私たちが使用したヘルパーが内部で行っていることです(さらに、適切な型のチェックと操作を実行しています)。私たちも、シリアル化により、若干異なる同じロジックを表明している可能性がpreprocess_and_shuffletff.Computation 、および分解federated_map非前処理されたデータセットと実行される別の構築1つのステップにpreprocess_and_shuffle各クライアントでは。

このより手動のパスにより、TFFのヘルパー(モジュロパラメーター名)と同じ型アノテーションを使用して計算が行われることを確認できます。

print(trainer_accepting_ids.next.type_signature)
print(manual_trainer_with_preprocessing.next.type_signature)
(<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,federated_dataset={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>)
(<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,selected_clients={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>)