Lihat di TensorFlow.org | Jalankan di Google Colab | Lihat sumber di GitHub | Unduh buku catatan |
Gagasan tentang kumpulan data yang dikunci oleh klien (misalnya pengguna) sangat penting untuk komputasi gabungan seperti yang dimodelkan dalam TFF. TFF menyediakan antarmuka tff.simulation.datasets.ClientData
untuk abstrak lebih konsep ini, dan yang TFF host (dataset stackoverflow , shakespeare , emnist , cifar100 , dan gldv2 ) semua mengimplementasikan interface ini.
Jika Anda bekerja pada pembelajaran Federasi dengan dataset Anda sendiri, TFF sangat mendorong Anda untuk baik melaksanakan ClientData
antarmuka atau menggunakan salah satu dari fungsi pembantu TFF untuk menghasilkan ClientData
yang mewakili data Anda pada disk, misalnya tff.simulation.datasets.ClientData.from_clients_and_fn
.
Karena kebanyakan dari TFF contoh end-to-end mulai dengan ClientData
objek, menerapkan ClientData
antarmuka dengan dataset kustom Anda akan membuat lebih mudah untuk spelunk melalui kode yang ada ditulis dengan TFF. Selanjutnya, tf.data.Datasets
yang ClientData
konstruksi dapat mengulangi lebih langsung untuk menghasilkan struktur numpy
array, sehingga ClientData
benda dapat digunakan dengan kerangka ML berbasis Python sebelum pindah ke TFF.
Ada beberapa pola yang dapat Anda gunakan untuk membuat hidup Anda lebih mudah jika Anda berniat untuk meningkatkan simulasi Anda ke banyak mesin atau menerapkannya. Di bawah ini kami akan berjalan melalui beberapa cara kita dapat menggunakan ClientData
dan TFF untuk membuat skala kecil iterasi-to skala besar eksperimen-produksi pengalaman penyebaran kami sebagai halus mungkin.
Pola mana yang harus saya gunakan untuk meneruskan ClientData ke TFF?
Kita akan membahas dua penggunaan dari TFF ClientData
secara mendalam; jika Anda termasuk dalam salah satu dari dua kategori di bawah ini, Anda jelas akan lebih memilih satu dari yang lain. Jika tidak, Anda mungkin memerlukan pemahaman yang lebih rinci tentang pro dan kontra dari masing-masing untuk membuat pilihan yang lebih bernuansa.
Saya ingin mengulangi secepat mungkin di mesin lokal; Saya tidak perlu dapat dengan mudah memanfaatkan runtime terdistribusi TFF.
- Anda ingin lulus
tf.data.Datasets
ke TFF langsung. - Hal ini memungkinkan Anda untuk program imperatif dengan
tf.data.Dataset
benda, dan proses mereka sewenang-wenang. - Ini memberikan lebih banyak fleksibilitas daripada opsi di bawah ini; mendorong logika ke klien mengharuskan logika ini dapat serial.
- Anda ingin lulus
Saya ingin menjalankan komputasi gabungan saya di runtime jarak jauh TFF, atau saya berencana untuk melakukannya segera.
- Dalam hal ini Anda ingin memetakan konstruksi set data dan prapemrosesan ke klien.
- Hasil dalam Anda ini melewati hanya daftar
client_ids
langsung ke perhitungan federasi Anda. - Mendorong konstruksi set data dan prapemrosesan ke klien menghindari kemacetan dalam serialisasi, dan secara signifikan meningkatkan kinerja dengan ratusan hingga ribuan klien.
Siapkan lingkungan sumber terbuka
# tensorflow_federated_nightly also bring in tf_nightly, which
# can causes a duplicate tensorboard install, leading to errors.
!pip uninstall --yes tensorboard tb-nightly
!pip install --quiet --upgrade tensorflow_federated_nightly
!pip install --quiet --upgrade nest_asyncio
import nest_asyncio
nest_asyncio.apply()
paket impor
import collections
import time
import tensorflow as tf
import tensorflow_federated as tff
Memanipulasi objek ClientData
Mari kita mulai dengan bongkar menjelajahi TFF EMNIST ClientData
:
client_data, _ = tff.simulation.datasets.emnist.load_data()
Downloading emnist_all.sqlite.lzma: 100%|██████████| 170507172/170507172 [00:19<00:00, 8831921.67it/s] 2021-10-01 11:17:58.718735: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Memeriksa dataset pertama dapat memberitahu kita apa jenis contoh adalah di ClientData
.
first_client_id = client_data.client_ids[0]
first_client_dataset = client_data.create_tf_dataset_for_client(
first_client_id)
print(first_client_dataset.element_spec)
# This information is also available as a `ClientData` property:
assert client_data.element_type_structure == first_client_dataset.element_spec
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])
Perhatikan bahwa hasil dataset collections.OrderedDict
objek yang memiliki pixels
dan label
kunci, di mana pixel adalah tensor dengan bentuk [28, 28]
. Misalkan kita ingin meratakan masukan kami keluar ke bentuk [784]
. Salah satu cara yang mungkin bisa kita lakukan ini akan menjadi untuk menerapkan fungsi pre-processing untuk kami ClientData
objek.
def preprocess_dataset(dataset):
"""Create batches of 5 examples, and limit to 3 batches."""
def map_fn(input):
return collections.OrderedDict(
x=tf.reshape(input['pixels'], shape=(-1, 784)),
y=tf.cast(tf.reshape(input['label'], shape=(-1, 1)), tf.int64),
)
return dataset.batch(5).map(
map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)
preprocessed_client_data = client_data.preprocess(preprocess_dataset)
# Notice that we have both reshaped and renamed the elements of the ordered dict.
first_client_dataset = preprocessed_client_data.create_tf_dataset_for_client(
first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])
Kami mungkin ingin selain melakukan beberapa pemrosesan awal yang lebih kompleks (dan mungkin stateful), misalnya pengocokan.
def preprocess_and_shuffle(dataset):
"""Applies `preprocess_dataset` above and shuffles the result."""
preprocessed = preprocess_dataset(dataset)
return preprocessed.shuffle(buffer_size=5)
preprocessed_and_shuffled = client_data.preprocess(preprocess_and_shuffle)
# The type signature will remain the same, but the batches will be shuffled.
first_client_dataset = preprocessed_and_shuffled.create_tf_dataset_for_client(
first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])
Berinteraksi dengan tff.Computation
Sekarang kita dapat melakukan beberapa manipulasi dasar dengan ClientData
objek, kami siap untuk data umpan ke tff.Computation
. Kami mendefinisikan tff.templates.IterativeProcess
yang mengimplementasikan Federasi Averaging , dan mengeksplorasi metode yang berbeda lewat itu data.
def model_fn():
model = tf.keras.models.Sequential([
tf.keras.layers.InputLayer(input_shape=(784,)),
tf.keras.layers.Dense(10, kernel_initializer='zeros'),
])
return tff.learning.from_keras_model(
model,
# Note: input spec is the _batched_ shape, and includes the
# label tensor which will be passed to the loss function. This model is
# therefore configured to accept data _after_ it has been preprocessed.
input_spec=collections.OrderedDict(
x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
y=tf.TensorSpec(shape=[None, 1], dtype=tf.int64)),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
trainer = tff.learning.build_federated_averaging_process(
model_fn,
client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.01))
Sebelum kita mulai bekerja dengan ini IterativeProcess
, satu komentar pada semantik ClientData
adalah dalam rangka. Sebuah ClientData
objek mewakili keseluruhan dari populasi yang tersedia untuk pelatihan federasi, yang pada umumnya tidak tersedia untuk lingkungan eksekusi dari sistem produksi FL dan khusus untuk simulasi. ClientData
memang memberikan pengguna kemampuan untuk komputasi federasi memotong sepenuhnya dan hanya melatih model server-side seperti biasa melalui ClientData.create_tf_dataset_from_all_clients
.
Lingkungan simulasi TFF menempatkan peneliti dalam kendali penuh atas loop luar. Secara khusus ini menyiratkan pertimbangan ketersediaan klien, klien putus sekolah, dll, harus ditangani oleh pengguna atau skrip driver Python. Satu bisa misalnya model client putus sekolah dengan menyesuaikan distribusi sampling atas Anda ClientData's
client_ids
sehingga pengguna dengan data yang lebih (dan Sejalan lagi berjalan perhitungan lokal) akan dipilih dengan probabilitas yang lebih rendah.
Namun, dalam sistem federasi nyata, klien tidak dapat dipilih secara eksplisit oleh pelatih model; pemilihan klien didelegasikan ke sistem yang menjalankan komputasi gabungan.
Melewati tf.data.Datasets
langsung ke TFF
Salah satu pilihan yang kita miliki untuk interfacing antara ClientData
dan IterativeProcess
adalah bahwa membangun tf.data.Datasets
di Python, dan melewati dataset ini untuk TFF.
Perhatikan bahwa jika kita menggunakan preprocessed kami ClientData
dataset kami menghasilkan adalah dari jenis yang sesuai yang diharapkan oleh model kami yang didefinisikan di atas.
selected_client_ids = preprocessed_and_shuffled.client_ids[:10]
preprocessed_data_for_clients = [
preprocessed_and_shuffled.create_tf_dataset_for_client(
selected_client_ids[i]) for i in range(10)
]
state = trainer.initialize()
for _ in range(5):
t1 = time.time()
state, metrics = trainer.next(state, preprocessed_data_for_clients)
t2 = time.time()
print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.compat.v1.graph_util.extract_sub_graph` WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.compat.v1.graph_util.extract_sub_graph` loss 2.9005744457244873, round time 4.576513767242432 loss 3.113278388977051, round time 0.49641919136047363 loss 2.7581865787506104, round time 0.4904160499572754 loss 2.87259578704834, round time 0.48976993560791016 loss 3.1202380657196045, round time 0.6724586486816406
Jika kita mengambil rute ini, namun, kami tidak akan dapat sepele pindah ke simulasi MULTIMESIN. Dataset kita membangun di runtime TensorFlow lokal dapat menangkap negara dari lingkungan python sekitarnya, dan gagal dalam serialisasi atau deserialization ketika mereka mencoba untuk negara referensi yang tidak lagi tersedia untuk mereka. Hal ini dapat terwujud misalnya dalam kesalahan ajaib dari TensorFlow ini tensor_util.cc
:
Check failed: DT_VARIANT == input.dtype() (21 vs. 20)
Pemetaan konstruksi dan preprocessing atas klien
Untuk menghindari masalah ini, TFF merekomendasikan penggunanya untuk mempertimbangkan dataset Instansiasi dan preprocessing sebagai sesuatu yang terjadi secara lokal pada setiap klien, dan menggunakan pembantu TFF atau federated_map
secara eksplisit menjalankan ini kode preprocessing pada setiap klien.
Secara konseptual, alasan untuk memilih ini jelas: dalam runtime lokal TFF, klien hanya "secara tidak sengaja" memiliki akses ke lingkungan Python global karena fakta bahwa seluruh orkestrasi federasi terjadi pada satu mesin. Perlu dicatat pada titik ini bahwa pemikiran serupa memunculkan filosofi fungsional lintas platform, selalu serialisasi, dan fungsional TFF.
TFF membuat perubahan yang sederhana melalui ClientData's
atribut dataset_computation
, sebuah tff.Computation
yang membutuhkan client_id
dan mengembalikan terkait tf.data.Dataset
.
Perhatikan bahwa preprocess
hanya bekerja dengan dataset_computation
; yang dataset_computation
atribut dari preprocessed ClientData
menggabungkan seluruh pipa preprocessing kita hanya didefinisikan:
print('dataset computation without preprocessing:')
print(client_data.dataset_computation.type_signature)
print('\n')
print('dataset computation with preprocessing:')
print(preprocessed_and_shuffled.dataset_computation.type_signature)
dataset computation without preprocessing: (string -> <label=int32,pixels=float32[28,28]>*) dataset computation with preprocessing: (string -> <x=float32[?,784],y=int64[?,1]>*)
Kita bisa memanggil dataset_computation
dan menerima dataset bersemangat dalam runtime Python, tapi kekuatan nyata dari pendekatan ini dilaksanakan ketika kita menulis dengan proses berulang atau perhitungan lain untuk menghindari mewujudkan dataset ini dalam runtime bersemangat global yang sama sekali. TFF menyediakan fungsi pembantu tff.simulation.compose_dataset_computation_with_iterative_process
yang dapat digunakan untuk melakukan hal ini.
trainer_accepting_ids = tff.simulation.compose_dataset_computation_with_iterative_process(
preprocessed_and_shuffled.dataset_computation, trainer)
Kedua ini tff.templates.IterativeProcesses
dan satu di atas dijalankan dengan cara yang sama; namun mantan menerima dataset client preprocessed, dan yang terakhir menerima string yang mewakili id klien, penanganan baik konstruksi dataset dan preprocessing di tubuhnya - sebenarnya state
dapat dilalui antara keduanya.
for _ in range(5):
t1 = time.time()
state, metrics = trainer_accepting_ids.next(state, selected_client_ids)
t2 = time.time()
print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
loss 2.8417396545410156, round time 1.6707067489624023 loss 2.7670371532440186, round time 0.5207102298736572 loss 2.665048122406006, round time 0.5302855968475342 loss 2.7213189601898193, round time 0.5313887596130371 loss 2.580148935317993, round time 0.5283482074737549
Menskalakan ke sejumlah besar klien
trainer_accepting_ids
dapat langsung digunakan dalam TFF runtime MULTIMESIN, dan menghindari mewujudkan tf.data.Datasets
dan controller (dan karena itu serialisasi mereka dan mengirim mereka keluar untuk para pekerja).
Ini secara signifikan mempercepat simulasi terdistribusi, terutama dengan sejumlah besar klien, dan memungkinkan agregasi menengah untuk menghindari overhead serialisasi/deserialisasi yang serupa.
Deepdive opsional: menyusun logika prapemrosesan secara manual di TFF
TFF dirancang untuk komposisi dari bawah ke atas; jenis komposisi yang baru saja dilakukan oleh helper TFF sepenuhnya berada dalam kendali kami sebagai pengguna. Kita bisa memiliki manual menyusun perhitungan preprocessing kita hanya didefinisikan dengan pelatih sendiri next
cukup sederhana:
selected_clients_type = tff.FederatedType(preprocessed_and_shuffled.dataset_computation.type_signature.parameter, tff.CLIENTS)
@tff.federated_computation(trainer.next.type_signature.parameter[0], selected_clients_type)
def new_next(server_state, selected_clients):
preprocessed_data = tff.federated_map(preprocessed_and_shuffled.dataset_computation, selected_clients)
return trainer.next(server_state, preprocessed_data)
manual_trainer_with_preprocessing = tff.templates.IterativeProcess(initialize_fn=trainer.initialize, next_fn=new_next)
Faktanya, inilah yang secara efektif dilakukan oleh helper yang kami gunakan di bawah tenda (ditambah melakukan pengecekan dan manipulasi tipe yang sesuai). Kita bahkan bisa menyatakan logika yang sama sedikit berbeda, dengan serialisasi preprocess_and_shuffle
menjadi tff.Computation
, dan dekomposisi federated_map
menjadi satu langkah yang membangun un-preprocessed dataset dan lain yang berjalan preprocess_and_shuffle
di setiap klien.
Kami dapat memverifikasi bahwa jalur yang lebih manual ini menghasilkan komputasi dengan tanda tangan tipe yang sama dengan helper TFF (nama parameter modulo):
print(trainer_accepting_ids.next.type_signature)
print(manual_trainer_with_preprocessing.next.type_signature)
(<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,federated_dataset={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>) (<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,selected_clients={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>)