ดูบน TensorFlow.org | ทำงานใน Google Colab | ดูแหล่งที่มาบน GitHub | ดาวน์โหลดโน๊ตบุ๊ค |
ก่อนที่เราจะเริ่มต้น
ก่อนที่เราจะเริ่ม โปรดเรียกใช้สิ่งต่อไปนี้เพื่อให้แน่ใจว่าสภาพแวดล้อมของคุณได้รับการตั้งค่าอย่างถูกต้อง หากคุณไม่เห็นคำทักทายโปรดดูที่ การติดตั้ง คู่มือสำหรับคำแนะนำ
!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio
import nest_asyncio
nest_asyncio.apply()
import tensorflow as tf
import tensorflow_federated as tff
ใน การจัดหมวดหมู่ภาพ และ รุ่นข้อความ บทเรียนที่เราได้เรียนรู้วิธีการตั้งค่ารูปแบบและข้อมูลที่เป็นท่อสำหรับสหพันธ์การเรียนรู้ (FL) และดำเนินการฝึกอบรมแบบ federated ผ่าน tff.learning
ชั้น API ของฉิบหาย
นี่เป็นเพียงส่วนเล็ก ๆ ของภูเขาน้ำแข็งเมื่อพูดถึงการวิจัยของ FL ในการกวดวิชานี้เราหารือถึงวิธีการดำเนินการตามขั้นตอนวิธีการเรียนรู้แบบ federated โดยไม่ต้องชะลอไป tff.learning
API เรามุ่งมั่นที่จะบรรลุสิ่งต่อไปนี้:
เป้าหมาย:
- ทำความเข้าใจโครงสร้างทั่วไปของอัลกอริธึมการเรียนรู้แบบสหพันธรัฐ
- สำรวจสหพันธ์หลักของฉิบหาย
- ใช้ Federated Core เพื่อปรับใช้ Federated Averaging โดยตรง
ในขณะที่การกวดวิชานี้มีข้อมูลในตัวเราขอแนะนำเป็นครั้งแรกที่ได้อ่าน การจัดหมวดหมู่ภาพ และ ข้อความรุ่น บทเรียน
การเตรียมข้อมูลเข้า
ก่อนอื่นเราโหลดและประมวลผลชุดข้อมูล EMNIST ที่รวมอยู่ใน TFF ล่วงหน้า สำหรับรายละเอียดเพิ่มเติมโปรดดูที่ การจัดหมวดหมู่ภาพ กวดวิชา
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
เพื่อที่จะกินอาหารชุดข้อมูลในรูปแบบของเราเราแผ่ข้อมูลและแปลงแต่ละตัวอย่างลงใน tuple ของรูปแบบ (flattened_image_vector, label)
NUM_CLIENTS = 10
BATCH_SIZE = 20
def preprocess(dataset):
def batch_format_fn(element):
"""Flatten a batch of EMNIST data and return a (features, label) tuple."""
return (tf.reshape(element['pixels'], [-1, 784]),
tf.reshape(element['label'], [-1, 1]))
return dataset.batch(BATCH_SIZE).map(batch_format_fn)
ตอนนี้เราเลือกลูกค้าจำนวนเล็กน้อย และใช้การประมวลผลล่วงหน้าด้านบนกับชุดข้อมูล
client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
for x in client_ids
]
กำลังเตรียมโมเดล
เราใช้รูปแบบเดียวกับใน การจัดหมวดหมู่ภาพ กวดวิชา รุ่นนี้ (ดำเนินการผ่านทาง tf.keras
) มีชั้นเดียวที่ซ่อนอยู่ตามชั้น softmax
def create_keras_model():
initializer = tf.keras.initializers.GlorotNormal(seed=0)
return tf.keras.models.Sequential([
tf.keras.layers.Input(shape=(784,)),
tf.keras.layers.Dense(10, kernel_initializer=initializer),
tf.keras.layers.Softmax(),
])
เพื่อที่จะใช้รูปแบบนี้ในฉิบหายเราห่อรุ่น Keras เป็น tff.learning.Model
นี้จะช่วยให้เราสามารถดำเนินการรูปแบบของ การส่งผ่านไปข้างหน้า ภายในฉิบหายและ เอาท์พุทสารสกัดจากรูปแบบ สำหรับรายละเอียดเพิ่มเติมยังเห็น การจัดหมวดหมู่ภาพ กวดวิชา
def model_fn():
keras_model = create_keras_model()
return tff.learning.from_keras_model(
keras_model,
input_spec=federated_train_data[0].element_spec,
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
ในขณะที่เราใช้ tf.keras
เพื่อสร้าง tff.learning.Model
, ฉิบหายสนับสนุนรูปแบบมากขึ้นทั่วไป โมเดลเหล่านี้มีคุณลักษณะที่เกี่ยวข้องต่อไปนี้ในการบันทึกน้ำหนักของแบบจำลอง:
-
trainable_variables
: การ iterable ของเทนเซอร์ที่สอดคล้องกับชั้นสุวินัย -
non_trainable_variables
: การ iterable ของเทนเซอร์ที่สอดคล้องกับชั้นไม่ใช่สุวินัย
สำหรับวัตถุประสงค์ของเราเราจะใช้ trainable_variables
(เนื่องจากโมเดลของเรามีเพียงแค่นั้นเท่านั้น!)
สร้างอัลกอริทึมการเรียนรู้แบบสหพันธรัฐของคุณเอง
ในขณะที่ tff.learning
API จะช่วยหนึ่งในการสร้างหลายสายพันธุ์ของสหพันธ์ Averaging มีอัลกอริทึมแบบ federated อื่น ๆ ที่ไม่เหมาะสมอย่างเรียบร้อยในกรอบนี้ ตัวอย่างเช่นคุณอาจต้องการเพิ่ม regularization คลิปหรืออัลกอริทึมที่มีความซับซ้อนมากขึ้นเช่น การฝึกอบรม GAN federated คุณก็อาจจะแทนมีความสนใจใน การวิเคราะห์แบบ federated
สำหรับอัลกอริธึมขั้นสูงเหล่านี้ เราจะต้องเขียนอัลกอริทึมของเราเองโดยใช้ TFF ในหลายกรณี อัลกอริธึมแบบรวมศูนย์มี 4 องค์ประกอบหลัก:
- ขั้นตอนการออกอากาศแบบเซิร์ฟเวอร์ถึงไคลเอ็นต์
- ขั้นตอนการอัปเดตไคลเอ็นต์ท้องถิ่น
- ขั้นตอนการอัปโหลดไคลเอ็นต์สู่เซิร์ฟเวอร์
- ขั้นตอนการอัปเดตเซิร์ฟเวอร์
ในฉิบหายเราโดยทั่วไปหมายถึงขั้นตอนวิธีแบบ federated เป็น tff.templates.IterativeProcess
(ซึ่งเราจะเรียกว่าเป็นเพียง IterativeProcess
ตลอด) นี้เป็นชั้นที่มี initialize
และ next
ฟังก์ชั่น ที่นี่ initialize
จะใช้ในการเริ่มต้นเซิร์ฟเวอร์และ next
จะดำเนินการอย่างใดอย่างหนึ่งรอบการสื่อสารของอัลกอริทึมแบบ federated มาเขียนโครงร่างของกระบวนการทำซ้ำของเราสำหรับ FedAvg กันดีกว่า
ครั้งแรกที่เรามีฟังก์ชั่นการเริ่มต้นที่เพียงสร้าง tff.learning.Model
และผลตอบแทนของน้ำหนักสุวินัย
def initialize_fn():
model = model_fn()
return model.trainable_variables
ฟังก์ชันนี้ดูดี แต่อย่างที่เราจะเห็นในภายหลัง เราจำเป็นต้องทำการปรับเปลี่ยนเล็กน้อยเพื่อให้เป็น "การคำนวณ TFF"
นอกจากนี้เรายังต้องการที่จะวาด next_fn
def next_fn(server_weights, federated_dataset):
# Broadcast the server weights to the clients.
server_weights_at_client = broadcast(server_weights)
# Each client computes their updated weights.
client_weights = client_update(federated_dataset, server_weights_at_client)
# The server averages these updates.
mean_client_weights = mean(client_weights)
# The server updates its model.
server_weights = server_update(mean_client_weights)
return server_weights
เราจะเน้นที่การใช้องค์ประกอบทั้งสี่นี้แยกกัน ก่อนอื่นเราเน้นที่ส่วนต่างๆ ที่สามารถนำมาใช้ใน TensorFlow ได้ นั่นคือขั้นตอนการอัปเดตไคลเอ็นต์และเซิร์ฟเวอร์
บล็อกเทนเซอร์โฟลว์
อัพเดทลูกค้า
เราจะใช้ของเรา tff.learning.Model
ที่จะดำเนินการฝึกอบรมลูกค้าเป็นหลักเดียวกับที่คุณจะอบรมรุ่น TensorFlow โดยเฉพาะอย่างยิ่งเราจะใช้ tf.GradientTape
เพื่อคำนวณการไล่ระดับสีในกระบวนการของข้อมูลแล้วใช้การไล่ระดับสีเหล่านี้โดยใช้ client_optimizer
เราเน้นเฉพาะตุ้มน้ำหนักที่สามารถฝึกได้
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
"""Performs training (using the server model weights) on the client's dataset."""
# Initialize the client model with the current server weights.
client_weights = model.trainable_variables
# Assign the server weights to the client model.
tf.nest.map_structure(lambda x, y: x.assign(y),
client_weights, server_weights)
# Use the client_optimizer to update the local model.
for batch in dataset:
with tf.GradientTape() as tape:
# Compute a forward pass on the batch of data
outputs = model.forward_pass(batch)
# Compute the corresponding gradient
grads = tape.gradient(outputs.loss, client_weights)
grads_and_vars = zip(grads, client_weights)
# Apply the gradient using a client optimizer.
client_optimizer.apply_gradients(grads_and_vars)
return client_weights
อัพเดทเซิฟเวอร์
การอัปเดตเซิร์ฟเวอร์สำหรับ FedAvg ง่ายกว่าการอัปเดตไคลเอ็นต์ เราจะใช้ "วานิลลา" เฉลี่ยแบบรวมศูนย์ ซึ่งเราเพียงแค่แทนที่น้ำหนักโมเดลเซิร์ฟเวอร์ด้วยค่าเฉลี่ยของน้ำหนักโมเดลไคลเอ็นต์ ย้ำอีกครั้ง เราเน้นเฉพาะตุ้มน้ำหนักที่ฝึกได้
@tf.function
def server_update(model, mean_client_weights):
"""Updates the server model weights as the average of the client model weights."""
model_weights = model.trainable_variables
# Assign the mean client weights to the server model.
tf.nest.map_structure(lambda x, y: x.assign(y),
model_weights, mean_client_weights)
return model_weights
ข้อมูลอาจจะง่ายโดยเพียงแค่กลับมา mean_client_weights
การใช้งานอย่างไรก็ตามที่สูงขึ้นของสหพันธ์ใช้ Averaging mean_client_weights
ด้วยเทคนิคที่ซับซ้อนมากขึ้นเช่นโมเมนตัมหรือปรับตัว
ท้าทาย: Implement รุ่นของ server_update
ที่ปรับปรุงน้ำหนักเซิร์ฟเวอร์เพื่อเป็นจุดกึ่งกลางของ model_weights และ mean_client_weights (หมายเหตุ: ชนิดของ "จุดกึ่งกลาง" วิธีการนี้จะคล้ายคลึงกับการทำงานที่ผ่านมาเกี่ยวกับการ เพิ่มประสิทธิภาพ Lookahead !)
จนถึงตอนนี้ เราเพิ่งเขียนโค้ด TensorFlow ล้วนๆ เท่านั้น นี่คือการออกแบบ เนื่องจาก TFF อนุญาตให้คุณใช้โค้ด TensorFlow ส่วนใหญ่ที่คุณคุ้นเคยอยู่แล้ว แต่ตอนนี้เราจะต้องระบุตรรกะประสานที่เป็นตรรกะที่สั่งการสิ่งที่ออกอากาศเซิร์ฟเวอร์ไปยังลูกค้าและสิ่งที่เป็นภาพที่ส่งลูกค้าไปยังเซิร์ฟเวอร์
นี้จะต้องสหพันธ์หลักของฉิบหาย
บทนำสู่สหพันธรัฐคอร์
สหพันธ์แกน (เอฟซี) เป็นชุดของการเชื่อมต่อในระดับต่ำกว่าที่ทำหน้าที่เป็นรากฐานสำหรับการ tff.learning
API อย่างไรก็ตาม อินเทอร์เฟซเหล่านี้ไม่จำกัดเพียงการเรียนรู้ อันที่จริง พวกเขาสามารถใช้สำหรับการวิเคราะห์และการคำนวณอื่น ๆ มากกว่าข้อมูลแบบกระจาย
ในระดับสูง คอร์รวมเป็นสภาพแวดล้อมการพัฒนาที่ช่วยให้ตรรกะของโปรแกรมแสดงออกอย่างกะทัดรัดเพื่อรวมรหัส TensorFlow กับตัวดำเนินการสื่อสารแบบกระจาย (เช่น ผลรวมแบบกระจายและการออกอากาศ) เป้าหมายคือเพื่อให้นักวิจัยและผู้ปฏิบัติงานสามารถควบคุมการสื่อสารแบบกระจายในระบบของตนได้โดยไม่ต้องใช้รายละเอียดการใช้ระบบ (เช่นการระบุการแลกเปลี่ยนข้อความเครือข่ายแบบจุดต่อจุด)
ประเด็นสำคัญประการหนึ่งคือ TFF ได้รับการออกแบบมาเพื่อรักษาความเป็นส่วนตัว ดังนั้นจึงช่วยให้สามารถควบคุมตำแหน่งข้อมูลได้อย่างชัดเจน เพื่อป้องกันการสะสมข้อมูลที่ไม่ต้องการที่ตำแหน่งเซิร์ฟเวอร์ส่วนกลาง
ข้อมูลส่วนกลาง
แนวคิดหลักใน TFF คือ "ข้อมูลรวม" ซึ่งหมายถึงการรวบรวมรายการข้อมูลที่โฮสต์ข้ามกลุ่มอุปกรณ์ในระบบแบบกระจาย (เช่น ชุดข้อมูลไคลเอ็นต์ หรือน้ำหนักโมเดลเซิร์ฟเวอร์) เรารูปแบบการเก็บรวบรวมทั้งหมดของรายการข้อมูลในอุปกรณ์ทั้งหมดเป็นค่า federated เดียว
ตัวอย่างเช่น สมมติว่าเรามีอุปกรณ์ไคลเอนต์ที่แต่ละอันมีทุ่นแทนอุณหภูมิของเซ็นเซอร์ เราสามารถแสดงเป็นลอย federated โดย
federated_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS)
ประเภทสหพันธ์ถูกกำหนดโดยชนิด T
ขององค์ประกอบของสมาชิก (เช่น. tf.float32
) และกลุ่ม G
ของอุปกรณ์ เราจะมุ่งเน้นไปที่กรณีที่ G
เป็นทั้ง tff.CLIENTS
หรือ tff.SERVER
ดังกล่าวเป็นชนิด federated จะแสดงเป็น {T}@G
ที่แสดงด้านล่าง
str(federated_float_on_clients)
'{float32}@CLIENTS'
ทำไมเราถึงสนใจเรื่องตำแหน่งมาก? เป้าหมายหลักของ TFF คือการเปิดใช้งานการเขียนโค้ดที่สามารถนำไปใช้งานบนระบบแบบกระจายจริงได้ ซึ่งหมายความว่าจำเป็นอย่างยิ่งที่จะต้องให้เหตุผลว่าอุปกรณ์ชุดย่อยใดที่เรียกใช้โค้ดใด และข้อมูลส่วนต่างๆ อยู่ที่ไหน
ฉิบหายมุ่งเน้นไปที่สิ่งที่สาม: ข้อมูลที่มีข้อมูลที่มีการวางและวิธีการที่ข้อมูลจะถูกเปลี่ยน สองคนแรกที่ถูกห่อหุ้มในรูปแบบ federated ในขณะที่ที่ผ่านมามีการห่อหุ้มในการคำนวณแบบ federated
การคำนวณแบบสหพันธรัฐ
ฉิบหายเป็นสภาพแวดล้อมการเขียนโปรแกรมอย่างยิ่งพิมพ์การทำงานที่มีหน่วยพื้นฐานคำนวณ federated เหล่านี้เป็นส่วนของตรรกะที่ยอมรับค่ารวมเป็นอินพุต และส่งกลับค่ารวมเป็นผลลัพธ์
ตัวอย่างเช่น สมมติว่าเราต้องการเฉลี่ยอุณหภูมิบนเซ็นเซอร์ลูกค้าของเรา เราสามารถกำหนดสิ่งต่อไปนี้ (โดยใช้ federated float ของเรา):
@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
return tff.federated_mean(client_temperatures)
คุณอาจจะถามว่าวิธีการที่แตกต่างกันนี้จาก tf.function
มัณฑนากรใน TensorFlow? คำตอบที่สำคัญคือว่ารหัสที่สร้างขึ้นโดย tff.federated_computation
จะไม่ TensorFlow มิได้หลามรหัส; มันเป็นคุณสมบัติของระบบกระจายในแพลตฟอร์มภาษากาวภายใน
แม้ว่าสิ่งนี้อาจฟังดูซับซ้อน แต่คุณสามารถคิดได้ว่าการคำนวณ TFF เป็นฟังก์ชันที่มีลายเซ็นประเภทที่กำหนดไว้อย่างดี ลายเซ็นประเภทนี้สามารถสอบถามโดยตรง
str(get_average_temperature.type_signature)
'({float32}@CLIENTS -> float32@SERVER)'
นี้ tff.federated_computation
ยอมรับข้อโต้แย้งของประเภท federated {float32}@CLIENTS
และค่าผลตอบแทนจากประเภท federated {float32}@SERVER
การคำนวณแบบรวมศูนย์อาจเปลี่ยนจากเซิร์ฟเวอร์หนึ่งไปยังอีกเครื่องหนึ่ง จากเครื่องลูกหนึ่งไปยังอีกเครื่องหนึ่ง หรือจากเครื่องหนึ่งไปยังอีกเครื่องหนึ่ง การคำนวณแบบรวมศูนย์สามารถประกอบได้เหมือนกับฟังก์ชันปกติ ตราบใดที่ลายเซ็นประเภทตรงกัน
เพื่อสนับสนุนการพัฒนา, ฉิบหายช่วยให้คุณสามารถเรียก tff.federated_computation
เป็นฟังก์ชั่นหลาม ตัวอย่างเช่น เราสามารถเรียก
get_average_temperature([68.5, 70.3, 69.8])
69.53334
การคำนวณแบบไม่กระตือรือร้นและ TensorFlow
มีข้อ จำกัด ที่สำคัญสองประการที่ควรทราบ ครั้งแรกเมื่อล่ามหลามพบ tff.federated_computation
มัณฑนากร, ฟังก์ชั่นที่มีการตรวจสอบทันทีและต่อเนื่องสำหรับการใช้งานในอนาคต เนื่องจากลักษณะการกระจายอำนาจของ Federated Learning การใช้งานในอนาคตนี้อาจเกิดขึ้นที่อื่น เช่น สภาพแวดล้อมการดำเนินการระยะไกล ดังนั้นการคำนวณฉิบหายเป็นพื้นฐานที่ไม่กระตือรือร้น ลักษณะการทำงานนี้จะค่อนข้างคล้ายกับที่ของ tf.function
มัณฑนากรใน TensorFlow
ประการที่สองการคำนวณแบบ federated สามารถประกอบด้วยผู้ประกอบการ federated (เช่น tff.federated_mean
) พวกเขาไม่สามารถมีการดำเนินงาน TensorFlow รหัส TensorFlow จะต้องถูกคุมขังในบล็อกตกแต่งด้วย tff.tf_computation
ส่วนใหญ่รหัส TensorFlow สามัญสามารถตกแต่งโดยตรงเช่นฟังก์ชั่นดังต่อไปนี้ที่ใช้เวลาจำนวนและเพิ่ม 0.5
กับมัน
@tff.tf_computation(tf.float32)
def add_half(x):
return tf.add(x, 0.5)
เหล่านี้ยังมีประเภทลายเซ็น แต่ไม่มีตำแหน่ง ตัวอย่างเช่น เราสามารถเรียก
str(add_half.type_signature)
'(float32 -> float32)'
ที่นี่เราเห็นความแตกต่างที่สำคัญระหว่าง tff.federated_computation
และ tff.tf_computation
อดีตมีตำแหน่งที่ชัดเจนในขณะที่หลังไม่มี
เราสามารถใช้ tff.tf_computation
บล็อกในการคำนวณแบบ federated โดยการระบุตำแหน่ง มาสร้างฟังก์ชันที่บวกครึ่ง แต่เฉพาะกับโฟลตรวมที่ไคลเอนต์ เราสามารถทำได้โดยใช้ tff.federated_map
ซึ่งใช้ที่กำหนด tff.tf_computation
ขณะที่การรักษาตำแหน่ง
@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def add_half_on_clients(x):
return tff.federated_map(add_half, x)
ฟังก์ชั่นนี้เป็นเกือบจะเหมือนกับ add_half
ยกเว้นว่าจะยอมรับเฉพาะค่ากับตำแหน่งที่ tff.CLIENTS
และค่าผลตอบแทนกับตำแหน่งเดียวกัน เราสามารถเห็นสิ่งนี้ในลายเซ็นประเภท:
str(add_half_on_clients.type_signature)
'({float32}@CLIENTS -> {float32}@CLIENTS)'
สรุป:
- TFF ทำงานบนค่าส่วนกลาง
- แต่ละค่า federated มีชนิดแบบรวมที่มีประเภท (เช่น.
tf.float32
) และตำแหน่ง (เช่น.tff.CLIENTS
) - ค่าสหพันธ์สามารถเปลี่ยนโดยใช้การคำนวณแบบ federated ซึ่งจะต้องได้รับการตกแต่งด้วย
tff.federated_computation
และลายเซ็นประเภท federated - รหัส TensorFlow จะต้องมีอยู่ในบล็อกที่มี
tff.tf_computation
ตกแต่ง - บล็อกเหล่านี้สามารถรวมเข้ากับการคำนวณแบบรวมศูนย์ได้
สร้างอัลกอริธึมการเรียนรู้แบบสหพันธรัฐของคุณเอง กลับมาอีกครั้ง
ตอนนี้เราได้เห็น Federated Core แล้ว เราจึงสามารถสร้างอัลกอริทึมการเรียนรู้แบบรวมศูนย์ของเราเองได้ โปรดจำไว้ว่าข้างต้นเราได้กำหนด initialize_fn
และ next_fn
สำหรับขั้นตอนวิธีการของเรา next_fn
จะทำให้การใช้งานของ client_update
และ server_update
เรากำหนดโดยใช้รหัส TensorFlow บริสุทธิ์
อย่างไรก็ตามในการที่จะทำให้อัลกอริทึมของเราคำนวณ federated เราจะจำเป็นต้องใช้ทั้ง next_fn
และ initialize_fn
แต่ละเป็น tff.federated_computation
TensorFlow Federated บล็อก
การสร้างการคำนวณการเริ่มต้น
ฟังก์ชั่นการเริ่มต้นจะค่อนข้างง่าย: เราจะสร้างรูปแบบการใช้ model_fn
แต่จำไว้ว่าเราจะต้องแยกออกจากเรารหัส TensorFlow ใช้ tff.tf_computation
@tff.tf_computation
def server_init():
model = model_fn()
return model.trainable_variables
จากนั้นเราจะสามารถผ่านนี้โดยตรงในการคำนวณโดยใช้ federated tff.federated_value
@tff.federated_computation
def initialize_fn():
return tff.federated_value(server_init(), tff.SERVER)
สร้าง next_fn
ตอนนี้เราใช้รหัสอัปเดตไคลเอ็นต์และเซิร์ฟเวอร์เพื่อเขียนอัลกอริทึมจริง ครั้งแรกที่เราจะเปิดของเรา client_update
เป็น tff.tf_computation
ที่ยอมรับชุดข้อมูลลูกค้าและน้ำหนักเซิร์ฟเวอร์และผลการปรับปรุงน้ำหนักของลูกค้าเมตริกซ์
เราจะต้องใช้ประเภทที่เกี่ยวข้องเพื่อตกแต่งฟังก์ชันของเราอย่างเหมาะสม โชคดีที่ประเภทของน้ำหนักเซิร์ฟเวอร์สามารถแยกได้โดยตรงจากแบบจำลองของเรา
whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)
ลองดูที่ลายเซ็นประเภทชุดข้อมูล จำไว้ว่าเราถ่ายภาพ 28 คูณ 28 (ด้วยป้ายกำกับจำนวนเต็ม) และทำให้แบน
str(tf_dataset_type)
'<float32[?,784],int32[?,1]>*'
นอกจากนี้เรายังสามารถแยกประเภทรุ่นน้ำหนักโดยใช้ของเรา server_init
ฟังก์ชั่นดังกล่าวข้างต้น
model_weights_type = server_init.type_signature.result
การตรวจสอบประเภทลายเซ็น เราจะสามารถเห็นสถาปัตยกรรมของแบบจำลองของเราได้!
str(model_weights_type)
'<float32[784,10],float32[10]>'
ตอนนี้เราสามารถสร้างของเรา tff.tf_computation
สำหรับการปรับปรุงไคลเอ็นต์
@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
model = model_fn()
client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
return client_update(model, tf_dataset, server_weights, client_optimizer)
tff.tf_computation
รุ่นของโปรแกรมปรับปรุงเซิร์ฟเวอร์สามารถกำหนดในลักษณะที่คล้ายกันโดยใช้ชนิดที่เราได้สกัดแล้ว
@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
model = model_fn()
return server_update(model, mean_client_weights)
สุดท้าย แต่ไม่น้อยเราต้องสร้าง tff.federated_computation
ที่นำมานี้ทั้งหมดเข้าด้วยกัน ฟังก์ชั่นนี้จะยอมรับค่าทั้งสองแบบ federated หนึ่งที่สอดคล้องกับน้ำหนักเซิร์ฟเวอร์ (กับตำแหน่ง tff.SERVER
) และอื่น ๆ ที่สอดคล้องกับชุดข้อมูลลูกค้า (กับตำแหน่ง tff.CLIENTS
)
โปรดทราบว่าทั้งสองประเภทนี้ถูกกำหนดไว้ข้างต้น! เราก็ต้องให้พวกเขามีตำแหน่งที่เหมาะสมโดยใช้ tff.FederatedType
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
จำองค์ประกอบ 4 ประการของอัลกอริธึม FL ได้หรือไม่
- ขั้นตอนการออกอากาศแบบเซิร์ฟเวอร์ถึงไคลเอ็นต์
- ขั้นตอนการอัปเดตไคลเอ็นต์ท้องถิ่น
- ขั้นตอนการอัปโหลดไคลเอ็นต์สู่เซิร์ฟเวอร์
- ขั้นตอนการอัปเดตเซิร์ฟเวอร์
ตอนนี้เราได้สร้างส่วนข้างต้นแล้ว แต่ละส่วนสามารถแสดงเป็นโค้ด TFF บรรทัดเดียวได้ ความเรียบง่ายนี้เป็นเหตุผลว่าทำไมเราจึงต้องระมัดระวังเป็นพิเศษในการระบุสิ่งต่างๆ เช่น ประเภทรวม!
@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
# Broadcast the server weights to the clients.
server_weights_at_client = tff.federated_broadcast(server_weights)
# Each client computes their updated weights.
client_weights = tff.federated_map(
client_update_fn, (federated_dataset, server_weights_at_client))
# The server averages these updates.
mean_client_weights = tff.federated_mean(client_weights)
# The server updates its model.
server_weights = tff.federated_map(server_update_fn, mean_client_weights)
return server_weights
ขณะนี้เรามี tff.federated_computation
สำหรับทั้งการเริ่มต้นขั้นตอนวิธีการและสำหรับการทำงานเป็นขั้นตอนหนึ่งของขั้นตอนวิธี จะเสร็จสิ้นขั้นตอนวิธีการของเราเราผ่านเหล่านี้เป็น tff.templates.IterativeProcess
federated_algorithm = tff.templates.IterativeProcess(
initialize_fn=initialize_fn,
next_fn=next_fn
)
ดู Let 's ที่ลายเซ็นประเภทของ initialize
และ next
การทำงานของกระบวนการซ้ำของเรา
str(federated_algorithm.initialize.type_signature)
'( -> <float32[784,10],float32[10]>@SERVER)'
สะท้อนให้เห็นถึงความจริงที่ว่า federated_algorithm.initialize
เป็นฟังก์ชั่นไม่หาเรื่องว่าผลตอบแทนรูปแบบชั้นเดียว (กับ 784 โดยน้ำหนัก 10 เมทริกซ์และ 10 หน่วยอคติ)
str(federated_algorithm.next.type_signature)
'(<server_weights=<float32[784,10],float32[10]>@SERVER,federated_dataset={<float32[?,784],int32[?,1]>*}@CLIENTS> -> <float32[784,10],float32[10]>@SERVER)'
ที่นี่เราจะเห็นว่า federated_algorithm.next
ยอมรับรูปแบบเซิร์ฟเวอร์และไคลเอ็นต์ข้อมูลและผลตอบแทนรูปแบบการปรับปรุงเซิร์ฟเวอร์
การประเมินอัลกอริทึม
ลองวิ่งสักสองสามรอบและดูว่าการสูญเสียจะเปลี่ยนไปอย่างไร ครั้งแรกที่เราจะกำหนดฟังก์ชั่นการประเมินผลการใช้วิธีการรวมศูนย์ที่กล่าวถึงในการกวดวิชาที่สอง
ขั้นแรก เราสร้างชุดข้อมูลการประเมินแบบรวมศูนย์ จากนั้นจึงใช้การประมวลผลล่วงหน้าแบบเดียวกับที่เราใช้สำหรับข้อมูลการฝึกอบรม
central_emnist_test = emnist_test.create_tf_dataset_from_all_clients()
central_emnist_test = preprocess(central_emnist_test)
ต่อไป เราเขียนฟังก์ชันที่ยอมรับสถานะเซิร์ฟเวอร์ และใช้ Keras เพื่อประเมินในชุดข้อมูลทดสอบ ถ้าคุณคุ้นเคยกับ tf.Keras
นี้จะมีลักษณะทุกคนคุ้นเคย แต่บันทึกการใช้งานของ set_weights
!
def evaluate(server_state):
keras_model = create_keras_model()
keras_model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
)
keras_model.set_weights(server_state)
keras_model.evaluate(central_emnist_test)
ตอนนี้ เรามาเริ่มต้นอัลกอริทึมของเราและประเมินในชุดทดสอบกัน
server_state = federated_algorithm.initialize()
evaluate(server_state)
2042/2042 [==============================] - 2s 767us/step - loss: 2.8479 - sparse_categorical_accuracy: 0.1027
มาฝึกกันสักสองสามรอบและดูว่ามีอะไรเปลี่ยนแปลงไหม
for round in range(15):
server_state = federated_algorithm.next(server_state, federated_train_data)
evaluate(server_state)
2042/2042 [==============================] - 2s 738us/step - loss: 2.5867 - sparse_categorical_accuracy: 0.0980
เราเห็นการลดลงเล็กน้อยในฟังก์ชันการสูญเสีย แม้ว่าการกระโดดจะเล็ก แต่เราได้ทำการฝึกซ้อมเพียง 15 รอบและในกลุ่มย่อยของลูกค้า เพื่อให้เห็นผลดีขึ้น เราอาจต้องทำหลายร้อยรอบหากไม่ใช่
การปรับเปลี่ยนอัลกอริทึมของเรา
ณ จุดนี้ ให้หยุดคิดเกี่ยวกับสิ่งที่เราทำสำเร็จ เราได้ปรับใช้ Federated Averaging โดยตรงโดยการรวมโค้ด TensorFlow แท้ (สำหรับการอัปเดตไคลเอ็นต์และเซิร์ฟเวอร์) กับการคำนวณแบบรวมศูนย์จาก Federated Core ของ TFF
เพื่อดำเนินการเรียนรู้ที่ซับซ้อนมากขึ้น เราสามารถปรับเปลี่ยนสิ่งที่เรามีข้างต้นได้ โดยเฉพาะอย่างยิ่ง โดยการแก้ไขโค้ด TF ล้วนๆ ด้านบน เราสามารถเปลี่ยนวิธีที่ไคลเอ็นต์ดำเนินการฝึกอบรม หรือวิธีที่เซิร์ฟเวอร์อัปเดตโมเดล
ท้าทาย: เพิ่ม คลิปลาด ไป client_update
ฟังก์ชั่น
หากเราต้องการทำการเปลี่ยนแปลงที่ใหญ่ขึ้น เราก็สามารถมีที่เก็บเซิร์ฟเวอร์และกระจายข้อมูลได้มากขึ้น ตัวอย่างเช่น เซิร์ฟเวอร์ยังสามารถเก็บอัตราการเรียนรู้ของไคลเอ็นต์ และทำให้เสื่อมลงเมื่อเวลาผ่านไป! หมายเหตุว่านี้จะต้องมีการเปลี่ยนแปลงลายเซ็นชนิดที่ใช้ใน tff.tf_computation
สายดังกล่าวข้างต้น
ท้าทายยาก: ระบบสหพันธ์ Averaging กับการเรียนรู้การสลายตัวของอัตราลูกค้า
ณ จุดนี้ คุณอาจเริ่มตระหนักถึงความยืดหยุ่นในสิ่งที่คุณสามารถนำไปใช้ในกรอบงานนี้ สำหรับความคิด (รวมถึงคำตอบให้กับความท้าทายที่ยากข้างต้น) คุณสามารถดูซอร์สโค้ดสำหรับ tff.learning.build_federated_averaging_process
หรือเช็คเอาท์ต่างๆ โครงการวิจัย โดยใช้ฉิบหาย