An UnweightedAggregationFactory for reservoir sampling values.
Inherits From: UnweightedAggregationFactory
tff.aggregators.UnweightedReservoirSamplingFactory(
sample_size: int, return_sampling_metadata: bool = False
)
The created tff.templates.AggregationProcess samples values placed at CLIENTS, and outputs the sample placed at SERVER.
The process has empty state. The measurements of this factory count, for each leaf of the client values, the number of clients holding non-finite (NaN or Inf) values in that leaf before sampling. Specifically, the returned measurements has the same structure as the client value, and every leaf node is a tf.int64 scalar tensor counting the number of clients that have a non-finite value in that leaf.
For example, suppose we are aggregating from three clients:
client_value_1 = collections.OrderedDict(a=[1.0, 2.0], b=[1.0, np.nan])
client_value_2 = collections.OrderedDict(a=[np.nan, np.inf], b=[1.0, 2.0])
client_value_3 = collections.OrderedDict(a=[1.0, 2.0], b=[np.inf, np.nan])
Then measurements will be:
collections.OrderedDict(a=tf.constant(1, dtype=tf.int64),
                        b=tf.constant(2, dtype=tf.int64))
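The counting rule in this example can be reproduced in plain Python. The following is only an illustrative sketch and not the TFF implementation: the helper name count_non_finite_clients is hypothetical, and it returns ordinary Python integers rather than tf.int64 tensors.

```python
import collections
import math


def count_non_finite_clients(client_values):
    """For each leaf key, count the clients holding a NaN or Inf in that leaf.

    Hypothetical helper for illustration; not part of the TFF API.
    """
    counts = collections.OrderedDict()
    for value in client_values:
        for key, leaf in value.items():
            # A client contributes at most 1 per leaf, however many
            # non-finite entries that leaf contains.
            has_non_finite = any(not math.isfinite(x) for x in leaf)
            counts[key] = counts.get(key, 0) + int(has_non_finite)
    return counts


nan, inf = float("nan"), float("inf")
client_values = [
    collections.OrderedDict(a=[1.0, 2.0], b=[1.0, nan]),
    collections.OrderedDict(a=[nan, inf], b=[1.0, 2.0]),
    collections.OrderedDict(a=[1.0, 2.0], b=[inf, nan]),
]
measurements = count_non_finite_clients(client_values)
print(dict(measurements))  # {'a': 1, 'b': 2}
```

Only the second client has non-finite values in leaf a, while the first and third clients have them in leaf b, which gives the counts 1 and 2.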
For more about reservoir sampling see https://en.wikipedia.org/wiki/Reservoir_sampling
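For readers unfamiliar with the technique, here is a minimal sketch of the classic single-pass reservoir sampling scheme (often called Algorithm R). It is an assumption-laden illustration only: the function name reservoir_sample and its seed parameter are hypothetical, and the factory's internal sampling procedure may differ from this variant.

```python
import random


def reservoir_sample(values, sample_size, seed=None):
    """Returns a uniform sample of up to `sample_size` items in one pass.

    Hypothetical helper illustrating Algorithm R; not the TFF implementation.
    """
    if sample_size <= 0:
        raise ValueError("sample_size must be positive.")
    rng = random.Random(seed)
    reservoir = []
    for index, value in enumerate(values):
        if index < sample_size:
            # Fill the reservoir with the first `sample_size` items.
            reservoir.append(value)
        else:
            # Keep the new item with probability sample_size / (index + 1),
            # evicting a uniformly chosen item already in the reservoir.
            slot = rng.randint(0, index)  # inclusive on both ends
            if slot < sample_size:
                reservoir[slot] = value
    return reservoir


sample = reservoir_sample(range(100), sample_size=5, seed=0)
print(len(sample))  # 5
```

When fewer than sample_size values are available, this sketch simply returns all of them.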
Args | |
---|---|
sample_size | An integer specifying the number of clients sampled (by the reservoir sampling algorithm). Values from the sampled clients are collected at the server (see the class documentation for details). |
return_sampling_metadata | If True, the result property in the returned tff.templates.MeasuredProcessOutput object contains a dictionary of sampled values and other sampling metadata (such as random values generated during reservoir sampling). Otherwise, it only contains the sampled values. |
Raises | |
---|---|
TypeError | If any argument has the wrong type. |
ValueError | If sample_size is not positive. |
Methods
create
create(
value_type: tff.types.Type
) -> tff.templates.AggregationProcess
Creates a tff.templates.AggregationProcess without weights.
The provided value_type is a non-federated tff.Type, that is, not a tff.FederatedType.
The returned tff.templates.AggregationProcess will be created for aggregation of values matching value_type placed at tff.CLIENTS.
That is, its next method will expect type <S@SERVER, {value_type}@CLIENTS>, where S is the unplaced return type of its initialize method.
Args | |
---|---|
value_type | A non-federated tff.Type. |
Returns | |
---|---|
A tff.templates.AggregationProcess. |