Represents a variable-based model for use in TensorFlow Federated.
Each VariableModel will work on a set of tf.Variables, and each method
should be a computation that can be implemented as a tf.function; this
implies the class should essentially be stateless from a Python perspective,
as each method will generally only be traced once (per set of arguments) to
create the corresponding TensorFlow graph functions. Thus, VariableModel
instances should behave as expected in both eager and graph (TF 1.0) usage.
In general, tf.Variables may be either:
- Weights, the variables needed to make predictions with the model.
- Local variables, e.g. to accumulate aggregated metrics across calls to forward_pass.
The weights can be broken down into trainable variables (variables
that can and should be trained using gradient-based methods), and
non-trainable variables (which could include fixed pre-trained layers,
or static model data). These variables are provided via the
trainable_variables, non_trainable_variables, and local_variables
properties, and must be initialized by the user of the VariableModel.
In federated learning, model weights will generally be provided by the
server, and updates to trainable model variables will be sent back to the
server. Local variables are not transmitted, and are instead initialized
locally on the device, and then used to produce aggregated_outputs which
are sent to the server.
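The round trip described above can be sketched in plain Python. This is a minimal illustration, not TFF code: plain floats stand in for tf.Variables, and client_update, its dictionary keys, and the one-parameter model y = w * x are all hypothetical names chosen for this example.

```python
def client_update(server_weights, data, learning_rate=0.1):
    """Simulates one client's work in a round (hypothetical helper, not a TFF API)."""
    # Weights arrive from the server; copy them so the server's values are untouched.
    trainable = list(server_weights["trainable"])
    # Local variables are created and initialized on the device, never broadcast.
    local = {"loss_sum": 0.0, "num_examples": 0.0}

    for x, y in data:
        error = trainable[0] * x - y
        # Accumulate metrics in the local variables.
        local["loss_sum"] += error ** 2
        local["num_examples"] += 1.0
        # Gradient step on the squared error of the model y = w * x.
        trainable[0] -= learning_rate * 2.0 * error * x

    # Only the change to the trainable weights goes back to the server,
    # together with outputs derived from the local variables.
    weights_delta = [
        new - old for new, old in zip(trainable, server_weights["trainable"])
    ]
    return weights_delta, dict(local)


server_weights = {"trainable": [0.0], "non_trainable": [1.0]}
delta, outputs = client_update(server_weights, data=[(1.0, 2.0)])
```

The server's copy of the weights is left unchanged; it would apply the returned delta itself after aggregating across clients.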
All tf.Variables should be introduced in __init__; this could move to a
build method more in line with Keras (see
https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) in
the future.
Attributes | |
---|---|
input_spec | A nested structure of tf.TensorSpec objects that the batch_input argument of forward_pass must match; see forward_pass for details.
local_variables | An iterable of tf.Variable objects; see the class comment for details.
non_trainable_variables | An iterable of tf.Variable objects; see the class comment for details.
trainable_variables | An iterable of tf.Variable objects; see the class comment for details.
Methods
forward_pass
@abc.abstractmethod
forward_pass( batch_input, training=True ) -> tff.learning.models.BatchOutput
Runs the forward pass and returns results.
This method must be serializable in a tff.tensorflow.computation or other
backend decorator. Any pure-Python or unserializable logic will not be
runnable in the federated system.
This method should not modify any variables that are part of the model
parameters, that is, variables that influence the predictions (the exceptions
are parameters that are updated rather than learned, such as BatchNorm means
and variances). Updating the model parameters is the job of the training loop.
However, this method may update aggregated metrics computed across calls to
forward_pass; the final values of such metrics can be accessed via
aggregated_outputs.
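The division of labor described above can be sketched as follows. This is a plain-Python stand-in, not a real VariableModel: TinyLinearModel, train_step, and the namedtuple used in place of tff.learning.models.BatchOutput are hypothetical, and the gradient is written out by hand because no automatic differentiation is available here.

```python
from collections import namedtuple

# Hypothetical stand-in for tff.learning.models.BatchOutput.
BatchOutput = namedtuple("BatchOutput", ["loss", "predictions", "num_examples"])


class TinyLinearModel:
    """One-parameter model y = w * x with plain floats instead of tf.Variables."""

    def __init__(self):
        self.weight = 0.0         # model parameter, changed only by the training loop
        self.loss_sum = 0.0       # local metric accumulator
        self.example_count = 0.0  # local metric accumulator

    def forward_pass(self, batch_input, training=True):
        xs, ys = batch_input["x"], batch_input["y"]
        predictions = [self.weight * x for x in xs]
        loss = sum((p - y) ** 2 for p, y in zip(predictions, ys)) / len(xs)
        # Metrics may be updated here; self.weight is deliberately left alone.
        self.loss_sum += loss * len(xs)
        self.example_count += len(xs)
        return BatchOutput(loss=loss, predictions=predictions, num_examples=len(xs))


def train_step(model, batch, learning_rate=0.1):
    """The training loop, not forward_pass, applies the parameter update."""
    output = model.forward_pass(batch, training=True)
    xs, ys = batch["x"], batch["y"]
    # Hand-written gradient of the mean squared error with respect to the weight.
    grad = sum(
        2.0 * (p - y) * x for p, y, x in zip(output.predictions, ys, xs)
    ) / len(xs)
    model.weight -= learning_rate * grad
    return output
```

Calling forward_pass alone changes only the metric accumulators; the weight moves only when train_step runs.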
Args | |
---|---|
batch_input | A nested structure that matches the structure of VariableModel.input_spec, where each tensor in batch_input satisfies tf.TensorSpec.is_compatible_with() for the corresponding tf.TensorSpec in VariableModel.input_spec.
training | If True, run the training forward pass; otherwise, run in evaluation mode. The semantics are generally the same as the training argument to keras.Model.call; this might, for example, influence how dropout or batch normalization is handled.
Returns | |
---|---|
A BatchOutput object. The object must include the loss tensor if the model will be trained via a gradient-based algorithm. |
metric_finalizers
@abc.abstractmethod
metric_finalizers() -> tff.learning.metrics.MetricFinalizersType
Creates a collections.OrderedDict of metric names to finalizers.
This method and the report_local_unfinalized_metrics() method should have
the same keys (i.e., metric names). A finalizer returned by this method is a
function (typically a tf.function-decorated callable or a
tff.tensorflow.computation-decorated TFF Computation) that takes in a
metric's unfinalized values (returned by
report_local_unfinalized_metrics()) and returns the finalized metric
values.
This method and the report_local_unfinalized_metrics() method will be used
together to build a cross-client metrics aggregator. See the documentation
of report_local_unfinalized_metrics() for more information.
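The key-matching requirement can be illustrated with a small plain-Python sketch. Plain lists of floats stand in for tensors, the lambdas stand in for tf.function-decorated finalizers, and apply_finalizers is a hypothetical helper rather than part of TFF.

```python
import collections


def report_local_unfinalized_metrics():
    # Unfinalized values: raw accumulators, not yet turned into final metrics.
    return collections.OrderedDict(
        loss=[12.0, 4.0],      # [sum of per-example losses, example count]
        num_examples=[4.0],
    )


def metric_finalizers():
    # One finalizer per metric name; the keys mirror the report above.
    return collections.OrderedDict(
        loss=lambda values: values[0] / values[1] if values[1] else 0.0,
        num_examples=lambda values: values[0],
    )


def apply_finalizers(unfinalized, finalizers):
    """Hypothetical helper: checks that the keys agree, then finalizes each metric."""
    if set(unfinalized) != set(finalizers):
        raise ValueError("Metric names of the two methods must match.")
    return collections.OrderedDict(
        (name, finalizers[name](values)) for name, values in unfinalized.items()
    )


finalized = apply_finalizers(report_local_unfinalized_metrics(), metric_finalizers())
```

Here the finalized loss is the accumulated sum divided by the example count, and a report whose names differ from the finalizers' names is rejected up front.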
Returns | |
---|---|
A collections.OrderedDict of metric names to finalizers. The metric names must be the same as those from the report_local_unfinalized_metrics() method. A finalizer is a tf.function (or tff.tensorflow.computation) decorated callable that takes in a metric's unfinalized values and returns the finalized values. This method and the report_local_unfinalized_metrics() method will be used together to build a cross-client metrics aggregator in federated training processes or evaluation computations. |
predict_on_batch
@abc.abstractmethod
predict_on_batch( batch_input, training=True )
report_local_unfinalized_metrics
@abc.abstractmethod
report_local_unfinalized_metrics() -> collections.OrderedDict[str, Any]
Creates a collections.OrderedDict of metric names to unfinalized values.
For a metric, its unfinalized values are given as a structure (typically a
list) of tensors representing values aggregated over all previous
forward_pass calls, unless reset_metrics has been called. Each time
reset_metrics is called, the local metric variables are reset, and
report_local_unfinalized_metrics only reports metrics aggregated from the
forward_pass calls since the last reset_metrics call. For a Keras
metric, its unfinalized values are typically the tensor values of its state
variables. In general, the tensors can be an arbitrary function of all the
tf.Variables of this model.
The metric names returned by this method should be the same as those
expected by metric_finalizers(); one should be able to use the
unfinalized values as input to the finalizers to get the finalized values.
Taking tf.keras.metrics.CategoricalAccuracy as an example, its unfinalized
values can be a list of two tensors (from its state variables), total and
count, and the finalizer function performs a tf.math.divide_no_nan.
In federated learning, this method returns the local results from clients,
which will typically be further aggregated across clients and made available
on the server. This method and the metric_finalizers() method will be used
together to build a cross-client metrics aggregator. For example, a simple
"sum_then_finalize" aggregator will first sum the unfinalized metric values
from clients, and then call the finalizer functions at the server.
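A minimal sketch of the "sum_then_finalize" idea, assuming plain floats in place of tensors: the sum_then_finalize and divide_no_nan functions below are hypothetical re-implementations written for illustration, not the TFF or TensorFlow functions of the same name.

```python
import collections


def divide_no_nan(numerator, denominator):
    """Plain-Python stand-in for tf.math.divide_no_nan: returns 0 when dividing by 0."""
    return numerator / denominator if denominator != 0 else 0.0


def sum_then_finalize(client_reports, finalizers):
    """Sums unfinalized values element-wise across clients, then finalizes once."""
    summed = collections.OrderedDict()
    for report in client_reports:
        for name, values in report.items():
            if name not in summed:
                summed[name] = list(values)
            else:
                summed[name] = [a + b for a, b in zip(summed[name], values)]
    return collections.OrderedDict(
        (name, finalizers[name](values)) for name, values in summed.items()
    )


# Each client reports [total, count] for accuracy, as in the example above.
client_reports = [
    collections.OrderedDict(accuracy=[3.0, 4.0]),  # 3 of 4 correct
    collections.OrderedDict(accuracy=[1.0, 6.0]),  # 1 of 6 correct
]
finalizers = collections.OrderedDict(
    accuracy=lambda values: divide_no_nan(values[0], values[1])
)
server_metrics = sum_then_finalize(client_reports, finalizers)
```

Summing before finalizing weights every example equally: the result is 4 correct out of 10 examples, whereas averaging the two per-client accuracies would give roughly 0.46.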
Because both this method and the metric_finalizers() method are defined
in a per-metric manner, users have the flexibility to call the finalizers at
the clients or at the server for different metrics. Users also have the freedom
to define a cross-client metrics aggregator that aggregates a single metric
in multiple ways.
Returns | |
---|---|
A collections.OrderedDict of metric names to unfinalized values. The metric names must be the same as those expected by the metric_finalizers() method. One should be able to use the unfinalized metric values (returned by this method) as the input to the finalizers (returned by metric_finalizers()) to get the finalized metrics. This method and the metric_finalizers() method will be used together to build a cross-client metrics aggregator when defining the federated training processes or evaluation computations. |
reset_metrics
@abc.abstractmethod
reset_metrics() -> None
Resets metric variables to their initial values.
This method is a tf.function. It is used to reset the metric variables
between different stages in a client's local computation. Each time
reset_metrics is called, the local metric variables are reset, and
report_local_unfinalized_metrics only reports metrics aggregated from the
forward_pass calls since the last reset_metrics call. If reset_metrics is
never called, report_local_unfinalized_metrics reports metrics aggregated
over all previous forward_pass calls.
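The interaction between forward_pass, reset_metrics, and report_local_unfinalized_metrics can be sketched with a plain-Python stand-in. AccuracyTracker is a hypothetical class whose forward_pass takes a simplified (num_correct, batch_size) signature; floats replace the metric tf.Variables.

```python
import collections


class AccuracyTracker:
    """Hypothetical stand-in that keeps only the metric state of a model."""

    def __init__(self):
        self.total = 0.0  # number of correct predictions seen so far
        self.count = 0.0  # number of examples seen so far

    def forward_pass(self, num_correct, batch_size):
        # Simplified signature: a real forward_pass computes these from the batch.
        self.total += num_correct
        self.count += batch_size

    def report_local_unfinalized_metrics(self):
        return collections.OrderedDict(accuracy=[self.total, self.count])

    def reset_metrics(self):
        # Return the metric variables to their initial values.
        self.total = 0.0
        self.count = 0.0


tracker = AccuracyTracker()
tracker.forward_pass(num_correct=3, batch_size=4)
tracker.forward_pass(num_correct=1, batch_size=4)
before_reset = tracker.report_local_unfinalized_metrics()

tracker.reset_metrics()
tracker.forward_pass(num_correct=2, batch_size=2)
after_reset = tracker.report_local_unfinalized_metrics()
```

The first report covers both earlier batches, while the second covers only the batch processed after the reset.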