TensorFlow.orgで表示 | GoogleColabで実行 | GitHubでソースを表示 | ノートブックをダウンロード |
TFFでモデル化されたフェデレーション計算には、クライアント(ユーザーなど)によってキー設定されたデータセットの概念が不可欠です。 TFFは、インタフェース提供tff.simulation.datasets.ClientData
この概念上抽象的に、及びTFFホスト(データセットstackoverflowの、シェイクスピア、 emnist 、 cifar100 、及びgldv2は)全てこのインタフェースを実装します。
あなたがあなた自身のデータセットで連合学習に取り組んでいる場合は、TFFは強く実装するか、あなたを奨励ClientData
生成するTFFのヘルパー関数のインタフェースまたは使用1をClientData
例えば、ディスク上のデータを表しtff.simulation.datasets.ClientData.from_clients_and_fn
。
TFFのエンドツーエンドの例のほとんどは、で始まるのようClientData
実装し、オブジェクトClientData
カスタムデータセットとのインタフェースは、それが簡単にTFFで書かれた既存のコードをspelunkするようになります。さらに、 tf.data.Datasets
ClientData
構築物の構造を生成する直接反復処理することができるnumpy
ように、アレイをClientData
オブジェクトは、TFFに移動する前に、PythonベースのMLフレームワークで使用することができます。
シミュレーションを多くのマシンにスケールアップしたり、それらをデプロイしたりする場合は、いくつかのパターンを使用して作業を楽にすることができます。我々は我々が使うことができる方法のいくつかを歩いて下記ClientData
私たちの小さな規模の大規模な反復-への生産実験-への展開の経験をするとTFFをできるだけスムーズ。
ClientDataをTFFに渡すためにどのパターンを使用する必要がありますか?
私たちは、TFFの二使い方について説明しますClientData
深さを。以下の2つのカテゴリのいずれかに当てはまる場合は、明らかに一方を他方よりも優先します。そうでない場合は、より微妙な選択を行うために、それぞれの長所と短所をより詳細に理解する必要があります。
ローカルマシンでできるだけ早く反復したい。 TFFの分散ランタイムを簡単に利用できる必要はありません。
- あなたは渡したい
tf.data.Datasets
直接TFFにして。 - これは、あなたが命令的にプログラムすることができます
tf.data.Dataset
オブジェクト、および任意にそれらを処理します。 - 以下のオプションよりも柔軟性があります。ロジックをクライアントにプッシュするには、このロジックがシリアル化可能である必要があります。
- あなたは渡したい
フェデレーション計算をTFFのリモートランタイムで実行したい、またはすぐに実行する予定です。
- この場合、データセットの構築と前処理をクライアントにマッピングする必要があります。
- あなたのこの結果は、単にのリスト渡し
client_ids
直接あなたの連合の計算に。 - データセットの構築と前処理をクライアントにプッシュすることで、シリアル化のボトルネックを回避し、数百から数千のクライアントでパフォーマンスを大幅に向上させます。
オープンソース環境をセットアップする
# tensorflow_federated_nightly also bring in tf_nightly, which
# can causes a duplicate tensorboard install, leading to errors.
!pip uninstall --yes tensorboard tb-nightly
!pip install --quiet --upgrade tensorflow_federated_nightly
!pip install --quiet --upgrade nest_asyncio
import nest_asyncio
nest_asyncio.apply()
パッケージをインポートする
import collections
import time
import tensorflow as tf
import tensorflow_federated as tff
ClientDataオブジェクトの操作
のは、ロードとTFFのEMNIST模索することから始めましょうClientData
:
client_data, _ = tff.simulation.datasets.emnist.load_data()
Downloading emnist_all.sqlite.lzma: 100%|██████████| 170507172/170507172 [00:19<00:00, 8831921.67it/s] 2021-10-01 11:17:58.718735: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
最初のデータセットを検査するにある例のどのような種類の私たちに伝えることができClientData
。
first_client_id = client_data.client_ids[0]
first_client_dataset = client_data.create_tf_dataset_for_client(
first_client_id)
print(first_client_dataset.element_spec)
# This information is also available as a `ClientData` property:
assert client_data.element_type_structure == first_client_dataset.element_spec
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])
データセットが得られることに留意されたいcollections.OrderedDict
有するオブジェクトpixels
とlabel
画素は形状のテンソルであるキーを、 [28, 28]
我々は形に私達の入力を平らたいと仮定し[784]
我々はこれを行うことができます1つの可能な方法は、私たちに前処理機能を適用することであろうClientData
オブジェクト。
def preprocess_dataset(dataset):
"""Create batches of 5 examples, and limit to 3 batches."""
def map_fn(input):
return collections.OrderedDict(
x=tf.reshape(input['pixels'], shape=(-1, 784)),
y=tf.cast(tf.reshape(input['label'], shape=(-1, 1)), tf.int64),
)
return dataset.batch(5).map(
map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)
preprocessed_client_data = client_data.preprocess(preprocess_dataset)
# Notice that we have both reshaped and renamed the elements of the ordered dict.
first_client_dataset = preprocessed_client_data.create_tf_dataset_for_client(
first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])
さらに、シャッフルなど、より複雑な(場合によってはステートフルな)前処理を実行したい場合があります。
def preprocess_and_shuffle(dataset):
"""Applies `preprocess_dataset` above and shuffles the result."""
preprocessed = preprocess_dataset(dataset)
return preprocessed.shuffle(buffer_size=5)
preprocessed_and_shuffled = client_data.preprocess(preprocess_and_shuffle)
# The type signature will remain the same, but the batches will be shuffled.
first_client_dataset = preprocessed_and_shuffled.create_tf_dataset_for_client(
first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])
インタフェースtff.Computation
今、我々がいくつかの基本的な操作を行うことができClientData
我々はにフィードデータへの準備ができている、オブジェクトtff.Computation
。私たちは、定義tff.templates.IterativeProcess
実装フェデレーション平均化を、そしてそれにデータを渡すの異なる方法を探ります。
def model_fn():
model = tf.keras.models.Sequential([
tf.keras.layers.InputLayer(input_shape=(784,)),
tf.keras.layers.Dense(10, kernel_initializer='zeros'),
])
return tff.learning.from_keras_model(
model,
# Note: input spec is the _batched_ shape, and includes the
# label tensor which will be passed to the loss function. This model is
# therefore configured to accept data _after_ it has been preprocessed.
input_spec=collections.OrderedDict(
x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
y=tf.TensorSpec(shape=[None, 1], dtype=tf.int64)),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
trainer = tff.learning.build_federated_averaging_process(
model_fn,
client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.01))
私たちはこれで作業を開始する前にIterativeProcess
、の意味論上の1件のコメントClientData
オーダーです。 ClientData
目的は、一般にある連合訓練のために利用できる人口の全体を表す生産FLシステムの実行環境では使用できませんし、シミュレーションに固有のものですが。 ClientData
実際に完全にユーザにバイパスフェデレーテッド・コンピューティング能力を与え、単に介して通常どおりサーバ側モデルを訓練ClientData.create_tf_dataset_from_all_clients
。
TFFのシミュレーション環境により、研究者は外側のループを完全に制御できます。特に、これは、クライアントの可用性、クライアントのドロップアウトなどの考慮事項を、ユーザーまたはPythonドライバースクリプトで対処する必要があることを意味します。一つは、あなたの上に標本分布を調整することにより、例えば、モデルクライアントドロップアウトのためにできたClientData's
client_ids
より多くのデータを持つユーザー(それに対応し、ローカル計算を長いが、実行されている)ような低確率で選択されることになります。
ただし、実際の連合システムでは、モデルトレーナーがクライアントを明示的に選択することはできません。クライアントの選択は、フェデレーション計算を実行しているシステムに委任されます。
渡すtf.data.Datasets
TFFに直接
我々は間のインタフェースを持っている一つの選択肢ClientData
とIterativeProcess
構築のそれであるtf.data.Datasets
Pythonで、そしてTFFにこれらのデータセットを渡します。
私達は私達の前処理を使用する場合ことに注意してくださいClientData
我々が得たデータセットは、上記で定義された我々のモデルにより予想の適切なタイプです。
selected_client_ids = preprocessed_and_shuffled.client_ids[:10]
preprocessed_data_for_clients = [
preprocessed_and_shuffled.create_tf_dataset_for_client(
selected_client_ids[i]) for i in range(10)
]
state = trainer.initialize()
for _ in range(5):
t1 = time.time()
state, metrics = trainer.next(state, preprocessed_data_for_clients)
t2 = time.time()
print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.compat.v1.graph_util.extract_sub_graph` WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.compat.v1.graph_util.extract_sub_graph` loss 2.9005744457244873, round time 4.576513767242432 loss 3.113278388977051, round time 0.49641919136047363 loss 2.7581865787506104, round time 0.4904160499572754 loss 2.87259578704834, round time 0.48976993560791016 loss 3.1202380657196045, round time 0.6724586486816406
我々はこのルートを取る場合は、しかし、我々は自明多機シミュレーションに移動することができません。私たちは地元のTensorFlowランタイムで構築するデータセットは、周囲のPython環境から状態をキャプチャし、そして、彼らはもはやそれらに利用可能である基準状態にしようとすると、直列化または直列化復元に失敗することができます。これはTensorFlowのからの不可解なエラーの例のマニフェストことができtensor_util.cc
:
Check failed: DT_VARIANT == input.dtype() (21 vs. 20)
クライアント上でのマッピング構築と前処理
この問題を回避するには、TFFは、そのユーザーが各クライアント上でローカルに発生したものとして、データセットのインスタンス化と前処理を検討することをお勧めします、とTFFのヘルパーまたは使用するfederated_map
明示的に各クライアントで、この前処理のコードを実行します。
概念的には、これを好む理由は明らかです。TFFのローカルランタイムでは、フェデレーションオーケストレーション全体が単一のマシンで行われているため、クライアントは「誤って」グローバルPython環境にのみアクセスできます。この時点で、同様の考え方がTFFのクロスプラットフォームで、常にシリアル化可能な機能哲学を生み出すことは注目に値します。
TFFは、経由して、このような変更が簡単になりますClientData's
属性dataset_computation
、 tff.Computation
取りclient_id
および関連返しtf.data.Dataset
。
注意preprocess
単純で動作しますdataset_computation
。 dataset_computation
前処理の属性ClientData
我々だけで定義された全体の前処理パイプラインが組み込まれています。
print('dataset computation without preprocessing:')
print(client_data.dataset_computation.type_signature)
print('\n')
print('dataset computation with preprocessing:')
print(preprocessed_and_shuffled.dataset_computation.type_signature)
dataset computation without preprocessing: (string -> <label=int32,pixels=float32[28,28]>*) dataset computation with preprocessing: (string -> <x=float32[?,784],y=int64[?,1]>*)
私たちは、呼び出すことができdataset_computation
し、Pythonランタイムでの熱心なデータセットを受け取り、私たちは、すべてのグローバル熱心ランタイムでこれらのデータセットを実体化を避けるために、反復プロセスまたは別の計算で構成したときに、このアプローチの本当の力を行使しています。 TFFは、ヘルパー関数を提供しtff.simulation.compose_dataset_computation_with_iterative_process
まさにこれを行うために使用することができます。
trainer_accepting_ids = tff.simulation.compose_dataset_computation_with_iterative_process(
preprocessed_and_shuffled.dataset_computation, trainer)
この両方tff.templates.IterativeProcesses
と同じように実行する上記1。前者前処理され、クライアントデータセットを受け入れ、後者はクライアントIDを表す文字列、両方のデータセットの構築を処理し、その本体に前処理受け付け-実際にstate
2の間を通過することができます。
for _ in range(5):
t1 = time.time()
state, metrics = trainer_accepting_ids.next(state, selected_client_ids)
t2 = time.time()
print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
loss 2.8417396545410156, round time 1.6707067489624023 loss 2.7670371532440186, round time 0.5207102298736572 loss 2.665048122406006, round time 0.5302855968475342 loss 2.7213189601898193, round time 0.5313887596130371 loss 2.580148935317993, round time 0.5283482074737549
多数のクライアントへのスケーリング
trainer_accepting_ids
すぐにTFFのマルチマシンの実行時に使用することができ、そしてことを回避するには、マテリアライズtf.data.Datasets
とコントローラを(したがって、それらをシリアル化し、労働者にそれらを送信します)。
これにより、特に多数のクライアントで分散シミュレーションが大幅に高速化され、中間集約が可能になり、同様のシリアル化/逆シリアル化のオーバーヘッドを回避できます。
オプションの詳細:TFFで前処理ロジックを手動で作成する
TFFは、ゼロから構成性を実現するように設計されています。 TFFのヘルパーによって実行されたばかりの種類の構成は、ユーザーとして完全に制御できます。私たちはただで定義された前処理計算を手動で構成する必要があり可能性があり、トレーナー自身のnext
非常に単純:
selected_clients_type = tff.FederatedType(preprocessed_and_shuffled.dataset_computation.type_signature.parameter, tff.CLIENTS)
@tff.federated_computation(trainer.next.type_signature.parameter[0], selected_clients_type)
def new_next(server_state, selected_clients):
preprocessed_data = tff.federated_map(preprocessed_and_shuffled.dataset_computation, selected_clients)
return trainer.next(server_state, preprocessed_data)
manual_trainer_with_preprocessing = tff.templates.IterativeProcess(initialize_fn=trainer.initialize, next_fn=new_next)
実際、これは事実上、私たちが使用したヘルパーが内部で行っていることです(さらに、適切な型のチェックと操作を実行しています)。私たちも、シリアル化により、若干異なる同じロジックを表明している可能性がpreprocess_and_shuffle
にtff.Computation
、および分解federated_map
非前処理されたデータセットと実行される別の構築1つのステップにpreprocess_and_shuffle
各クライアントでは。
このより手動のパスにより、TFFのヘルパー(モジュロパラメーター名)と同じ型アノテーションを使用して計算が行われることを確認できます。
print(trainer_accepting_ids.next.type_signature)
print(manual_trainer_with_preprocessing.next.type_signature)
(<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,federated_dataset={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>) (<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,selected_clients={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>)