Wraps AggregationFactory to report additional measurements.
```python
tff.aggregators.add_measurements(
    inner_agg_factory: tff.aggregators.AggregationFactory,
    *,
    client_measurement_fn: Optional[Callable[..., dict[str, Any]]] = None,
    server_measurement_fn: Optional[Callable[..., dict[str, Any]]] = None
) -> tff.aggregators.AggregationFactory
```
The function client_measurement_fn should be a Python callable that will be called as client_measurement_fn(value) if inner_agg_factory is unweighted, or as client_measurement_fn(value, weight) if it is weighted. It must be traceable by TFF, accept tff.Value objects placed at CLIENTS as inputs, and return a collections.OrderedDict mapping string names to tensor values placed at SERVER. These values are added to the measurement dict produced by inner_agg_factory.
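A minimal plain-Python sketch of this calling convention follows. It is not the TFF implementation and uses no TFF APIs. Assumptions: a list stands in for a value placed at CLIENTS (one entry per client), a plain float stands in for a tensor placed at SERVER, and the extra measurements are merged into the inner measurements by simple key insertion. The names run_client_measurement, max_abs_value and weighted_count are hypothetical.

```python
import collections
from typing import Callable, Optional, Sequence


def run_client_measurement(
    client_measurement_fn: Callable[..., dict],
    client_values: Sequence[float],
    client_weights: Optional[Sequence[float]] = None,
    inner_measurements: Optional[dict] = None,
) -> collections.OrderedDict:
    """Hypothetical model of how client_measurement_fn is invoked.

    Lists model CLIENTS-placed values; returned floats model SERVER-placed
    measurements. This is an illustration, not TFF's actual code path.
    """
    if client_weights is None:
        # Unweighted inner factory: called with the value only.
        extra = client_measurement_fn(client_values)
    else:
        # Weighted inner factory: called with the value and the weight.
        extra = client_measurement_fn(client_values, client_weights)
    # Add the extra named measurements to the inner factory's measurements.
    merged = collections.OrderedDict(inner_measurements or {})
    merged.update(extra)
    return merged


def max_abs_value(values):
    # Example unweighted measurement: largest absolute client value.
    return collections.OrderedDict(max_abs=max(abs(v) for v in values))


def weighted_count(values, weights):
    # Example weighted measurement: total client weight.
    return collections.OrderedDict(total_weight=sum(weights))
```

For example, run_client_measurement(max_abs_value, [1.0, -3.0, 2.0], inner_measurements={"mean_value": 0.0}) yields a dict containing both mean_value and max_abs. How TFF itself handles a name that collides with an existing measurement key is not modeled here.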
Similarly, server_measurement_fn should be a Python callable that will be called as server_measurement_fn(result), where result is the server-placed result of the inner aggregation.
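The server-side case can be sketched the same way, again in plain Python rather than TFF. Assumptions: the aggregated result is modeled as a list of floats, the returned measurements are merged by key insertion, and run_server_measurement and l2_norm_of_result are hypothetical names used only for illustration.

```python
import collections
import math
from typing import Callable, Optional, Sequence


def run_server_measurement(
    server_measurement_fn: Callable[..., dict],
    result: Sequence[float],
    inner_measurements: Optional[dict] = None,
) -> collections.OrderedDict:
    """Hypothetical model: call server_measurement_fn on the aggregate result."""
    extra = server_measurement_fn(result)
    # Add the extra named measurements to the inner factory's measurements.
    merged = collections.OrderedDict(inner_measurements or {})
    merged.update(extra)
    return merged


def l2_norm_of_result(result):
    # Example server measurement: Euclidean norm of the aggregated vector.
    return collections.OrderedDict(
        result_l2_norm=math.sqrt(sum(x * x for x in result)))
```

Unlike the client-side function, this one sees only the already-aggregated value, so it cannot inspect individual client contributions.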
At least one of client_measurement_fn and server_measurement_fn must be specified.
| Returns |
|---|
| An AggregationFactory that reports additional measurements. |