Aggregation factory for weighted mean.
Inherits From: WeightedAggregationFactory
tff.aggregators.MeanFactory(
    value_sum_factory: Optional[tff.aggregators.UnweightedAggregationFactory] = None,
    weight_sum_factory: Optional[tff.aggregators.UnweightedAggregationFactory] = None,
    no_nan_division: bool = False
)
The created tff.templates.AggregationProcess computes the weighted mean of values placed at CLIENTS and outputs the mean placed at SERVER.
The input arguments of the next attribute of the process returned by create are <state, value, weight>, where weight is a scalar broadcast to the structure of value, and the weighted mean refers to the expression sum(value * weight) / sum(weight).
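As a rough illustration of the expression above, the following plain-Python sketch (not TFF code) computes sum(value * weight) / sum(weight) over a list of client contributions. It assumes, for illustration only, that each client's value is a flat dict of floats standing in for a nested structure and that each weight is a single float broadcast to every leaf. The function name weighted_mean is hypothetical.

```python
from typing import Dict, List

# Hypothetical representation: each client's value is a flat dict of floats
# (standing in for a nested structure); each client's weight is one scalar.
Value = Dict[str, float]


def weighted_mean(values: List[Value], weights: List[float]) -> Value:
    """Computes sum(value * weight) / sum(weight) leaf by leaf."""
    weight_sum = sum(weights)
    result = {}
    for key in values[0]:
        # The scalar weight of each client is broadcast to every leaf.
        weighted_sum = sum(v[key] * w for v, w in zip(values, weights))
        result[key] = weighted_sum / weight_sum
    return result


client_values = [{"a": 1.0, "b": 10.0}, {"a": 3.0, "b": 20.0}]
client_weights = [1.0, 3.0]
print(weighted_mean(client_values, client_weights))  # {'a': 2.5, 'b': 17.5}
```

Each leaf is averaged independently with the same per-client weight, which is what broadcasting a scalar weight to the structure of value amounts to.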
The implementation is parameterized by two inner aggregation factories responsible for the summations above, with the following high-level steps:
- Multiplication of value and weight at CLIENTS.
- Delegation to the inner value_sum_factory and weight_sum_factory to realize the sum of weighted values and the sum of weights.
- Division of the summed weighted values by the summed weights at SERVER.
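The three steps can be sketched in plain Python, with ordinary functions standing in for the inner factories. This is not the TFF implementation: mean_with_inner_sums, value_sum_fn, weight_sum_fn and clipped_sum are hypothetical names, values are scalars for brevity, and CLIENTS/SERVER placement is only indicated in comments.

```python
from typing import Callable, List, Sequence

# Hypothetical stand-in for an inner factory: any function that reduces
# a sequence of client contributions to a single server-side total.
SumFn = Callable[[Sequence[float]], float]


def mean_with_inner_sums(
    values: List[float],
    weights: List[float],
    value_sum_fn: SumFn = sum,
    weight_sum_fn: SumFn = sum,
) -> float:
    # Step 1 (at CLIENTS): each client multiplies its value by its weight.
    weighted_values = [v * w for v, w in zip(values, weights)]
    # Step 2: delegate the two summations to the inner "factories".
    value_total = value_sum_fn(weighted_values)
    weight_total = weight_sum_fn(weights)
    # Step 3 (at SERVER): divide summed weighted values by summed weights.
    return value_total / weight_total


def clipped_sum(xs: Sequence[float], bound: float = 5.0) -> float:
    """Hypothetical inner sum that clips each contribution to [-bound, bound]."""
    return sum(max(-bound, min(bound, x)) for x in xs)


print(mean_with_inner_sums([1.0, 3.0], [1.0, 3.0]))  # 2.5
print(mean_with_inner_sums([1.0, 3.0], [1.0, 3.0], value_sum_fn=clipped_sum))  # 1.5
```

Swapping value_sum_fn changes only how the weighted values are summed, while the multiplication and the final division stay fixed; this mirrors the separation between the outer factory and its two inner summation factories.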
Note that the division at SERVER can protect against division by 0, as specified by the no_nan_division constructor argument.
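A minimal sketch of the server-side division, assuming plain Python floats. The helper divide is hypothetical: it mimics IEEE 754 behavior when no_nan_division is False and returns 0 when it is True, matching the documented semantics. TensorFlow's tf.math.divide_no_nan has the same return-0-on-zero-denominator behavior, but this sketch does not claim that MeanFactory uses that op internally.

```python
import math


def divide(numerator: float, denominator: float, no_nan_division: bool) -> float:
    """Hypothetical server-side division step for a scalar mean."""
    if denominator == 0.0:
        if no_nan_division:
            # Documented behavior: the mean is 0 when the sum of weights is 0.
            return 0.0
        # Without protection, mimic IEEE 754: 0/0 is NaN, x/0 is a signed infinity.
        if numerator == 0.0 or math.isnan(numerator):
            return math.nan
        return math.copysign(math.inf, numerator) * math.copysign(1.0, denominator)
    return numerator / denominator


print(divide(6.0, 4.0, no_nan_division=False))  # 1.5
print(divide(0.0, 0.0, no_nan_division=True))   # 0.0
print(divide(0.0, 0.0, no_nan_division=False))  # nan
```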
The state is the composed state of the aggregation processes created by the two inner aggregation factories. The same holds for measurements.
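To make the composition concrete, the sketch below models each inner process as a plain Python object with initialize and next methods, and builds the outer state and measurements by nesting the inner ones. CountingSumProcess, ToyMeanProcess and the dictionary keys are hypothetical illustrations, not TFF's actual classes or structure keys.

```python
from typing import Any, Dict, List, Tuple


class CountingSumProcess:
    """Hypothetical inner process: sums its inputs and counts rounds in its state."""

    def initialize(self) -> int:
        return 0  # State: number of rounds run so far.

    def next(self, state: int, values: List[float]) -> Tuple[int, float, Dict[str, Any]]:
        # Returns (new_state, result, measurements).
        return state + 1, sum(values), {"num_inputs": len(values)}


class ToyMeanProcess:
    """Hypothetical outer process whose state and measurements nest the inner ones."""

    def __init__(self, value_sum: CountingSumProcess, weight_sum: CountingSumProcess):
        self.value_sum = value_sum
        self.weight_sum = weight_sum

    def initialize(self) -> Dict[str, Any]:
        # The outer state is simply the two inner states side by side.
        return {
            "value_sum_process": self.value_sum.initialize(),
            "weight_sum_process": self.weight_sum.initialize(),
        }

    def next(self, state: Dict[str, Any], values: List[float], weights: List[float]):
        weighted = [v * w for v, w in zip(values, weights)]
        v_state, v_total, v_meas = self.value_sum.next(state["value_sum_process"], weighted)
        w_state, w_total, w_meas = self.weight_sum.next(state["weight_sum_process"], weights)
        new_state = {"value_sum_process": v_state, "weight_sum_process": w_state}
        # Measurements are composed the same way as the state.
        measurements = {"mean_value": v_meas, "mean_weight": w_meas}
        return new_state, v_total / w_total, measurements


process = ToyMeanProcess(CountingSumProcess(), CountingSumProcess())
state = process.initialize()
state, mean, measurements = process.next(state, [1.0, 3.0], [1.0, 3.0])
print(state)         # {'value_sum_process': 1, 'weight_sum_process': 1}
print(mean)          # 2.5
print(measurements)  # {'mean_value': {'num_inputs': 2}, 'mean_weight': {'num_inputs': 2}}
```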
Args | |
---|---|
value_sum_factory | An optional tff.aggregators.UnweightedAggregationFactory responsible for summation of weighted values. If not specified, tff.aggregators.SumFactory is used. |
weight_sum_factory | An optional tff.aggregators.UnweightedAggregationFactory responsible for summation of weights. If not specified, tff.aggregators.SumFactory is used. |
no_nan_division | A bool. If True, the computed mean is 0 if the sum of weights is equal to 0. |
Raises | |
---|---|
TypeError | If the provided value_sum_factory or weight_sum_factory is not an instance of tff.aggregators.UnweightedAggregationFactory. |
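A hedged sketch of the constructor-time check described in the Raises table, using hypothetical stand-in classes rather than the real tff.aggregators types. It also shows the defaulting rule from the Args table, where a missing inner factory is replaced by a sum factory. ToyMeanFactory and both stand-in classes are illustrative only.

```python
class UnweightedAggregationFactory:
    """Hypothetical stand-in for tff.aggregators.UnweightedAggregationFactory."""


class SumFactory(UnweightedAggregationFactory):
    """Hypothetical stand-in for tff.aggregators.SumFactory."""


class ToyMeanFactory:
    """Hypothetical sketch of MeanFactory's argument handling only."""

    def __init__(self, value_sum_factory=None, weight_sum_factory=None,
                 no_nan_division=False):
        # Missing inner factories default to a plain sum factory.
        if value_sum_factory is None:
            value_sum_factory = SumFactory()
        if weight_sum_factory is None:
            weight_sum_factory = SumFactory()
        # Reject anything that is not an unweighted aggregation factory.
        for name, factory in (("value_sum_factory", value_sum_factory),
                              ("weight_sum_factory", weight_sum_factory)):
            if not isinstance(factory, UnweightedAggregationFactory):
                raise TypeError(
                    f"{name} must be an UnweightedAggregationFactory, "
                    f"got {type(factory).__name__}")
        self.value_sum_factory = value_sum_factory
        self.weight_sum_factory = weight_sum_factory
        self.no_nan_division = no_nan_division


try:
    ToyMeanFactory(weight_sum_factory="sum")
except TypeError as e:
    print(e)  # weight_sum_factory must be an UnweightedAggregationFactory, got str
```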
Methods
create
create(
    value_type: factory.ValueType, weight_type: factory.ValueType
) -> tff.templates.AggregationProcess
Creates a tff.templates.AggregationProcess with weights.
The provided value_type and weight_type are non-federated tff.Types. That is, neither is a tff.FederatedType.
The returned tff.templates.AggregationProcess will be created for aggregation of pairs of values matching value_type and weight_type placed at tff.CLIENTS. That is, its next method will expect type <S@SERVER, {value_type}@CLIENTS, {weight_type}@CLIENTS>, where S is the unplaced return type of its initialize method.
Args | |
---|---|
value_type | A non-federated tff.Type. |
weight_type | A non-federated tff.Type. |
Returns | |
---|---|
A tff.templates.AggregationProcess. |