Builds a learning process for Mime Lite with optimizer scheduling.
tff.learning.algorithms.build_mime_lite_with_optimizer_schedule(
    model_fn: Union[Callable[[], tff.learning.models.VariableModel], tff.learning.models.FunctionalModel],
    learning_rate_fn: Callable[[int], float],
    base_optimizer: tff.learning.optimizers.Optimizer,
    server_optimizer: tff.learning.optimizers.Optimizer = sgdm.build_sgdm(1.0),
    client_weighting: Optional[tff.learning.ClientWeighting] = tff.learning.ClientWeighting.NUM_EXAMPLES,
    model_distributor: Optional[tff.learning.templates.DistributionProcess] = None,
    model_aggregator: Optional[tff.aggregators.WeightedAggregationFactory] = None,
    full_gradient_aggregator: Optional[tff.aggregators.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[tff.learning.metrics.MetricsAggregatorType] = None,
    loop_implementation: tff.learning.LoopImplementation = tff.learning.LoopImplementation.DATASET_REDUCE
) -> tff.learning.templates.LearningProcess
This function creates a tff.learning.templates.LearningProcess that performs the Mime Lite algorithm on client models. The iterative process has the following methods inherited from tff.learning.templates.LearningProcess:
- initialize: A tff.Computation with the functional type signature ( -> S@SERVER), where S is a tff.learning.templates.LearningAlgorithmState representing the initial state of the server.
- next: A tff.Computation with the functional type signature (<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize, and {B*}@CLIENTS represents the client datasets. The output L contains the updated server state, as well as aggregated metrics at the server, including client training metrics and any other metrics from distribution and aggregation processes.
- get_model_weights: A tff.Computation with type signature (S -> M), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize and next, and M represents the type of the model weights used during training.
- set_model_weights: A tff.Computation with type signature (<S, M> -> S), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize, and M represents the type of the model weights used during training.
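To make the four type signatures concrete, here is a minimal pure-Python sketch. Everything in it is hypothetical: ToyState, next_round, and the trivial "training" step are toy stand-ins, not the TFF API and not Mime Lite. They only mirror how state flows through initialize, next, get_model_weights, and set_model_weights, with L modeled as a (state, metrics) pair.

```python
from dataclasses import dataclass, replace
from typing import Dict, List, Tuple


# Hypothetical toy stand-ins, NOT the TFF API: they only mirror the type
# signatures of the four computations, using plain Python values.
@dataclass(frozen=True)
class ToyState:
    """Plays the role of S: the model weights plus a round counter."""
    model_weights: float
    round_num: int


def initialize() -> ToyState:
    # ( -> S@SERVER): build the initial server state.
    return ToyState(model_weights=0.0, round_num=0)


def next_round(state: ToyState,
               client_datasets: List[List[float]]) -> Tuple[ToyState, Dict[str, int]]:
    # (<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>), with L modeled as a
    # (new_state, metrics) pair. "Training" here just moves the weights to the
    # mean of all client examples; it is not Mime Lite.
    examples = [x for dataset in client_datasets for x in dataset]
    new_state = replace(state,
                        model_weights=sum(examples) / len(examples),
                        round_num=state.round_num + 1)
    return new_state, {"num_examples": len(examples)}


def get_model_weights(state: ToyState) -> float:
    # (S -> M)
    return state.model_weights


def set_model_weights(state: ToyState, weights: float) -> ToyState:
    # (<S, M> -> S): replace the weights, keep the rest of the state.
    return replace(state, model_weights=weights)


state = initialize()
state, metrics = next_round(state, [[1.0, 2.0], [3.0]])
print(get_model_weights(state), metrics)  # prints: 2.0 {'num_examples': 3}
state = set_model_weights(state, 5.0)
```

The toy function is named next_round only to avoid shadowing Python's built-in next; on the object returned by this builder, the computation is exposed as next.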
Each time the next method is called, the server model is communicated to each client using the provided model_distributor. For each client, local training is performed using base_optimizer, whose state is communicated by the server and kept intact during local training. This optimizer state is updated only at the server, based on the full gradients that the clients evaluate on their datasets at the current server model. The client full gradients are aggregated by the weighted full_gradient_aggregator. Each client also computes the difference between its model after training and its initial model, and these model deltas are aggregated by the weighted model_aggregator. Both aggregations are weighted according to client_weighting. The aggregate model delta is then applied to the server model weights by server_optimizer; with the default (SGD with a learning rate of 1.0), this amounts to adding the delta to the existing server model.
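The round described above can be sketched in plain Python under several simplifying assumptions: a scalar model, a toy quadratic loss 0.5 * mean((w - x)^2) per client, a toy SGD-with-momentum rule standing in for base_optimizer, plain SGD standing in for server_optimizer, weighting by number of examples, a fixed learning rate instead of learning_rate_fn, and full-batch gradients for the local steps (the paper's algorithm uses minibatches). All names are hypothetical; this is not TFF's implementation, only an illustration of which state is held fixed and which is aggregated.

```python
from typing import List, Tuple


def grad(w: float, dataset: List[float]) -> float:
    # Gradient of the toy loss 0.5 * mean((w - x)^2) over the client's examples.
    return w - sum(dataset) / len(dataset)


def sgdm_next(acc: float, w: float, g: float, lr: float,
              momentum: float) -> Tuple[float, float]:
    # Toy SGD-with-momentum rule: acc' = momentum * acc + g, w' = w - lr * acc'.
    new_acc = momentum * acc + g
    return new_acc, w - lr * new_acc


def client_update(server_w: float, server_acc: float, dataset: List[float],
                  lr: float, momentum: float,
                  local_steps: int) -> Tuple[float, float, int]:
    # Full gradient on the whole local dataset, evaluated at the server model.
    full_grad = grad(server_w, dataset)
    w = server_w
    for _ in range(local_steps):
        # The optimizer is applied with the server's state, and the updated
        # state it returns is discarded, so the state stays fixed locally.
        _, w = sgdm_next(server_acc, w, grad(w, dataset), lr, momentum)
    return w - server_w, full_grad, len(dataset)


def mime_lite_round(server_w: float, server_acc: float,
                    client_datasets: List[List[float]], lr: float,
                    momentum: float = 0.9, local_steps: int = 2,
                    server_lr: float = 1.0) -> Tuple[float, float]:
    results = [client_update(server_w, server_acc, ds, lr, momentum, local_steps)
               for ds in client_datasets]
    # Weighted means, using the number of examples as each client's weight.
    total = sum(n for _, _, n in results)
    avg_delta = sum(n * d for d, _, n in results) / total
    avg_full_grad = sum(n * g for _, g, n in results) / total
    # The optimizer state is updated only here, from the aggregated full gradient.
    new_acc, _ = sgdm_next(server_acc, server_w, avg_full_grad, lr, momentum)
    # Server optimizer: plain SGD on the negated delta, so with server_lr=1.0
    # the averaged delta is simply added to the server model.
    new_w = server_w - server_lr * (-avg_delta)
    return new_w, new_acc


w, acc = 0.0, 0.0
clients = [[1.0, 3.0], [6.0]]
for round_num in range(2):
    w, acc = mime_lite_round(w, acc, clients, lr=0.5)
    print(round_num, w, acc)
```

Keeping the client-side momentum fixed at the server's value is what lets every client follow the same (global) optimizer trajectory within a round, while only the server advances that state using the aggregated full gradient.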
The Mime Lite algorithm is based on the paper "Breaking the centralized barrier for cross-device federated learning." Sai Praneeth Karimireddy, Martin Jaggi, Satyen Kale, Mehryar Mohri, Sashank Reddi, Sebastian U. Stich, and Ananda Theertha Suresh. Advances in Neural Information Processing Systems 34 (2021). https://proceedings.neurips.cc/paper/2021/file/f0e6be4ce76ccfa73c5a540d992d0756-Paper.pdf
Note that Keras optimizers are not supported. This is because the Mime Lite algorithm applies the optimizer at clients without changing its state (the optimizer's tf.Variables, in the case of Keras), which is not possible with Keras optimizers without reaching into private implementation details and incurring additional computation and memory cost at clients.
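The toy contrast below, in plain Python, illustrates why a functional optimizer interface makes "apply without changing the state" straightforward. Neither SgdmState, functional_sgdm_next, nor StatefulSgdm is a TFF or Keras API; StatefulSgdm only mimics an optimizer that keeps its accumulator in mutable storage and updates it as a side effect of every step.

```python
from typing import NamedTuple, Tuple


class SgdmState(NamedTuple):
    """Immutable optimizer state for a toy SGD-with-momentum rule."""
    accumulator: float


def functional_sgdm_next(state: SgdmState, weights: float, gradient: float,
                         lr: float = 0.1,
                         momentum: float = 0.9) -> Tuple[SgdmState, float]:
    # Pure function: the caller decides whether to keep the returned state.
    acc = momentum * state.accumulator + gradient
    return SgdmState(acc), weights - lr * acc


class StatefulSgdm:
    """Toy stateful optimizer whose accumulator is mutated in place on every step."""

    def __init__(self, accumulator: float, lr: float = 0.1, momentum: float = 0.9):
        self.accumulator = accumulator
        self.lr = lr
        self.momentum = momentum

    def apply(self, weights: float, gradient: float) -> float:
        self.accumulator = self.momentum * self.accumulator + gradient  # side effect
        return weights - self.lr * self.accumulator


gradients = [0.5, -0.5]

# Client loop with a functional optimizer: the server-provided state is passed
# in unchanged at every step and the returned state is dropped.
server_state = SgdmState(accumulator=2.0)
w_functional = 1.0
for g in gradients:
    _, w_functional = functional_sgdm_next(server_state, w_functional, g)

# The same loop with a stateful optimizer drifts away from the server state.
stateful = StatefulSgdm(accumulator=2.0)
w_stateful = 1.0
for g in gradients:
    w_stateful = stateful.apply(w_stateful, g)

print(server_state.accumulator, stateful.accumulator)
```

A stateful optimizer could only reproduce the functional behavior by saving and restoring its internal state around every step, which is the kind of extra computation and memory cost described in the note above.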
Args | |
---|---|
model_fn | A no-arg function that returns a tff.learning.models.VariableModel, or an instance of a tff.learning.models.FunctionalModel. When passing a callable, the callable must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error.
learning_rate_fn | A callable accepting an integer round number and returning a float to be used as a learning rate for the optimizer. learning_rate_fn must be serializable by TensorFlow (e.g. via tf.function).
base_optimizer | A tff.learning.optimizers.Optimizer used both to create and update a global optimizer state and to optimize at clients given that global state, which stays fixed during client optimization.
server_optimizer | A tff.learning.optimizers.Optimizer used to apply the aggregate model update to the global model weights.
client_weighting | A member of tff.learning.ClientWeighting that specifies a built-in weighting method. By default, weighting by number of examples is used.
model_distributor | An optional DistributionProcess that distributes the model weights on the server to the clients. If set to None, the distributor is constructed via distributors.build_broadcast_process.
model_aggregator | An optional tff.aggregators.WeightedAggregationFactory used to aggregate client updates on the server. If None, this is set to tff.aggregators.MeanFactory.
full_gradient_aggregator | An optional tff.aggregators.WeightedAggregationFactory used to aggregate the full gradients computed on the client datasets. If None, this is set to tff.aggregators.MeanFactory.
metrics_aggregator | A function that takes in the metric finalizers (i.e., tff.learning.models.VariableModel.metric_finalizers()) and a tff.types.StructWithPythonType of the unfinalized metrics (i.e., the TFF type of tff.learning.models.VariableModel.report_local_unfinalized_metrics()), and returns a tff.Computation for aggregating the unfinalized metrics. If None, this is set to tff.learning.metrics.sum_then_finalize.
loop_implementation | Changes the implementation of the generated training loop. See tff.learning.LoopImplementation for more details.
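As an illustration of the shape learning_rate_fn is expected to have (integer round number in, float learning rate out), here is a hypothetical step-decay schedule written in plain Python. The constants are arbitrary assumptions. In practice the round number may reach the function as a TensorFlow tensor, so a real schedule must be written with operations TensorFlow can trace; this sketch has not been checked for that.

```python
def learning_rate_fn(round_num: int) -> float:
    """Hypothetical step-decay schedule: start at 0.1 and halve every 10 rounds."""
    return 0.1 * 0.5 ** (round_num // 10)


for round_num in (0, 9, 10, 25):
    print(round_num, learning_rate_fn(round_num))
```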
Returns | |
---|---|
A tff.learning.templates.LearningProcess. |