Creates an aggregator with compression and adaptive zeroing and clipping.
tff.learning.compression_aggregator(
    *,
    zeroing: bool = True,
    clipping: bool = True,
    weighted: bool = True,
    debug_measurements_fn: Optional[
        Callable[[tff.aggregators.AggregationFactory],
                 tff.aggregators.AggregationFactory]] = None,
    **kwargs
) -> tff.aggregators.AggregationFactory
Zeroes out extremely large values for robustness to data corruption on
clients, and clips values in the L2 norm to a moderately high norm for
robustness to outliers. After weighting in the mean, the weighted values are
uniformly quantized to reduce the size of the model update communicated from
clients to the server. For details, see Suresh et al. (2017),
http://proceedings.mlr.press/v70/suresh17a/suresh17a.pdf. The default
configuration is chosen so that compression does not have an adverse effect on
trained model quality for typical tasks.
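
A minimal usage sketch (assuming a recent TFF release; the tiny Keras model
and model_fn below are illustrative, not part of this API) plugs the
aggregator into weighted federated averaging:

import tensorflow as tf
import tensorflow_federated as tff

def model_fn():
  # Illustrative single-layer model; any tff.learning model works here.
  keras_model = tf.keras.Sequential([
      tf.keras.layers.Dense(10, input_shape=(784,), kernel_initializer='zeros')
  ])
  return tff.learning.models.from_keras_model(
      keras_model,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      input_spec=(
          tf.TensorSpec(shape=(None, 784), dtype=tf.float32),
          tf.TensorSpec(shape=(None, 1), dtype=tf.int32),
      ),
  )

# Aggregator with adaptive zeroing, adaptive L2 clipping and uniform
# quantization of the weighted client updates.
model_aggregator = tff.learning.compression_aggregator(
    zeroing=True, clipping=True, weighted=True)

# Plug the aggregator into weighted federated averaging.
learning_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn=model_fn,
    client_optimizer_fn=tff.learning.optimizers.build_sgdm(learning_rate=0.1),
    model_aggregator=model_aggregator,
)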
Args

zeroing
    Whether to enable adaptive zeroing for data corruption mitigation.
clipping
    Whether to enable adaptive clipping in the L2 norm for robustness. Note
    that this clipping is performed prior to the per-coordinate clipping
    required for quantization.
weighted
    Whether the mean is weighted (vs. unweighted).
debug_measurements_fn
    A callable to add measurements suitable for debugging learning
    algorithms, with possible values None,
    tff.learning.add_debug_measurements or
    tff.learning.add_debug_measurements_with_mixed_dtype.
**kwargs
    Keyword arguments.
Raises

TypeError
    If debug_measurements_fn yields an aggregation factory whose weight type
    does not match weighted.
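
As a sketch of the debug_measurements_fn hook (the settings here are
illustrative), the built-in tff.learning.add_debug_measurements can be
passed directly; the factory it produces must have a weight type consistent
with weighted, otherwise the TypeError above is raised:

# Weighted aggregator instrumented with extra debugging measurements.
debug_aggregator = tff.learning.compression_aggregator(
    weighted=True,
    debug_measurements_fn=tff.learning.add_debug_measurements,
)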