A stateful process for learning tasks that produces metrics.
Inherits From: IterativeProcess
```python
tff.learning.templates.LearningProcess(
    initialize_fn: tff.Computation,
    next_fn: tff.Computation,
    get_model_weights: tff.Computation,
    set_model_weights: tff.Computation,
    *,
    get_hparams_fn: Optional[tff.Computation] = None,
    set_hparams_fn: Optional[tff.Computation] = None
)
```
This class inherits the constraints documented by `tff.templates.IterativeProcess`, including the `initialize` and `next` attributes. A `LearningProcess` also has additional attributes, including `get_model_weights` and `get_hparams`. The former extracts structures suitable for evaluation, while the latter extracts the process's hyperparameters. The corresponding `set_model_weights` and `set_hparams` attributes write these structures back into a given state.
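The accessor pairs above follow a get/set pattern over an opaque server state. The sketch below is a plain-Python analogy, not the TFF API: `ToyState`, its fields, and `ToyLearningProcess` are hypothetical stand-ins for a server state and a `LearningProcess`. It only illustrates how getters pull one piece out of the state while setters return a new state instead of mutating the old one.

```python
import collections
import dataclasses
from typing import Any, Dict, List


# Hypothetical stand-in for a server state. Real TFF states are structures
# described by a tff.Type, not a Python dataclass.
@dataclasses.dataclass(frozen=True)
class ToyState:
    model_weights: List[float]
    learning_rate: float
    round_num: int


class ToyLearningProcess:
    """Hypothetical plain-Python analogy of the LearningProcess accessors."""

    def initialize(self) -> ToyState:
        # No-arg: produces the initial state.
        return ToyState(model_weights=[0.0, 0.0], learning_rate=0.1, round_num=0)

    def get_model_weights(self, state: ToyState) -> List[float]:
        # Extract only the part of the state needed for evaluation.
        return list(state.model_weights)

    def set_model_weights(self, state: ToyState, weights: List[float]) -> ToyState:
        # Return a new state; the input state is left unchanged.
        return dataclasses.replace(state, model_weights=list(weights))

    def get_hparams(self, state: ToyState) -> Dict[str, Any]:
        # Expose hyperparameters as an ordered dictionary.
        return collections.OrderedDict(learning_rate=state.learning_rate)

    def set_hparams(self, state: ToyState, hparams: Dict[str, Any]) -> ToyState:
        # Write the hyperparameters back into a new state.
        return dataclasses.replace(state, learning_rate=hparams['learning_rate'])


process = ToyLearningProcess()
state = process.initialize()
state = process.set_model_weights(state, [1.0, 2.0])
state = process.set_hparams(state, {'learning_rate': 0.05})
print(process.get_model_weights(state))  # [1.0, 2.0]
```

Returning a fresh state from each setter matches the documented behavior of `set_model_weights`, which returns a new value of type `S`. Here `dataclasses.replace` is just one convenient way to express that in plain Python.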
For example, given a `LearningProcess` `process` and client data `data`, we could call the following to initialize, optionally load other model weights, update the state three times, and extract the model weights of the state:
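```python
# `process` is a tff.learning.templates.LearningProcess; `data` is client data.
state = process.initialize()
# Optional: state = process.set_model_weights(state, other_weights)
for _ in range(3):
  # `next` returns a LearningProcessOutput; read its `state` and `metrics`.
  output = process.next(state, data)
  state, metrics = output.state, output.metrics
model_weights = process.get_model_weights(state)
```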
Args | |
---|---|
`initialize_fn` | A no-arg `tff.Computation` that creates the initial state of the learning process. |
`next_fn` | A `tff.Computation` that defines an iterated function. Given that `initialize_fn` returns a value of type `S@SERVER`, `next_fn` must accept two arguments with types assignable from values of type `S@SERVER` and `{D*}@CLIENTS`, and must return a `LearningProcessOutput` whose `state` attribute is assignable from values of type `S@SERVER`. |
`get_model_weights` | A `tff.Computation` that accepts an input `S` whose type is assignable from the result of `initialize_fn`. It creates a representation of the state that downstream tasks can use without access to the entire server state. For example, `get_model_weights` could extract model weights suitable for computing evaluation metrics on held-out data. |
`set_model_weights` | A `tff.Computation` that accepts two inputs, `S` and `M`, where the type of `S` is assignable from values of the type returned by `initialize_fn` and `M` is a representation of the model weights stored in `S`. It replaces the model weights within the state with the incoming value and returns a new value of type `S`. |
`get_hparams_fn` | An optional `tff.Computation` that accepts the state `S` and returns the hyperparameters `H`. If not provided, this defaults to a computation that returns an empty ordered dictionary, regardless of the contents of the state. |
`set_hparams_fn` | An optional `tff.Computation` that accepts the state `S` and hyperparameters `H` (matching the output of `get_hparams_fn`) and returns an updated state `S`. If not provided, this defaults to a pass-through computation that returns the input state regardless of the hyperparameters passed in. |
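The default behavior of the two optional hyperparameter arguments can be pictured with a plain-Python sketch. The function names below are hypothetical, and the real defaults are `tff.Computation`s typed against the process state rather than ordinary Python functions. The sketch only shows the two documented behaviors: an empty ordered dictionary for the getter and a pass-through for the setter.

```python
import collections
from typing import Any, Mapping


def default_get_hparams(state: Any) -> collections.OrderedDict:
    # Ignores the state entirely and reports no hyperparameters.
    del state
    return collections.OrderedDict()


def default_set_hparams(state: Any, hparams: Mapping[str, Any]) -> Any:
    # Pass-through: the hyperparameters are ignored and the input state is returned.
    del hparams
    return state


# Hypothetical state used only for illustration.
state = {'model_weights': [0.5, -0.5], 'round_num': 3}
assert default_get_hparams(state) == collections.OrderedDict()
assert default_set_hparams(state, {'learning_rate': 0.1}) is state
```

Because the default getter always returns an empty dictionary and the default setter ignores its input, a get/set round trip through the defaults leaves the state unchanged.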
Raises | |
---|---|
`TypeError` | If `initialize_fn` or `next_fn` is not an instance of `tff.Computation`. |
`TemplateInitFnParamNotEmptyError` | If `initialize_fn` has any input arguments. |
`TemplateStateNotAssignableError` | If the state returned by either `initialize_fn` or `next_fn` is not assignable to the first input argument of `next_fn`. |
`TemplateNextFnNumArgsError` | If `next_fn` does not have exactly two input arguments. |
`LearningProcessPlacementError` | If the placements of `initialize_fn` and `next_fn` do not match the expected type placements. |
`LearningProcessOutputError` | If `next_fn` does not return a `LearningProcessOutput`. |
`GetModelWeightsTypeSignatureError` | If the input type of `get_model_weights` does not match the process state type. |
`SetModelWeightsTypeSignatureError` | If the type of the first input or the type of the output of `set_model_weights` does not match the process state type. |
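Two of the argument-count checks in the table above can be mimicked on plain Python callables with `inspect.signature`. This is only an analogy: TFF validates `tff.Computation` type signatures, not Python signatures. The exception classes below are hypothetical local stand-ins that borrow their names from the table; they are not imported from TFF.

```python
import inspect
from typing import Callable


# Hypothetical local stand-ins; the real errors are defined and raised by TFF.
class TemplateInitFnParamNotEmptyError(Exception):
    pass


class TemplateNextFnNumArgsError(Exception):
    pass


def _num_params(fn: Callable) -> int:
    # Count the declared parameters of a plain Python callable.
    return len(inspect.signature(fn).parameters)


def validate_fns(initialize_fn: Callable, next_fn: Callable) -> None:
    """Mimics two of the documented checks on plain Python callables."""
    if _num_params(initialize_fn) != 0:
        raise TemplateInitFnParamNotEmptyError(
            'initialize_fn must not take any input arguments.')
    if _num_params(next_fn) != 2:
        raise TemplateNextFnNumArgsError(
            'next_fn must take exactly two arguments: (state, client_data).')


validate_fns(lambda: 0, lambda state, data: state)  # passes silently
try:
    validate_fns(lambda seed: seed, lambda state, data: state)
except TemplateInitFnParamNotEmptyError as e:
    print('rejected:', e)
```

The remaining checks (state assignability, placements, and the `LearningProcessOutput` return type) depend on TFF's type system and have no faithful plain-Python equivalent, so they are left out of this sketch.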
Attributes | |
---|---|
`get_hparams` | A `tff.Computation` returning the hyperparameters of a server state. This computation accepts an unplaced state of the process (originally produced by the … |
`get_model_weights` | A `tff.Computation` returning the model weights of a server state. This computation accepts an unplaced state of the process (originally produced by the … |
`initialize` | A `tff.Computation` that initializes the process. This computation must have no input arguments, and its output must be the initial state of the learning process, placed at … |
`next` | A `tff.Computation` that runs one iteration of the process. The first argument of this computation should always be the current state (originally produced by the … |
`set_hparams` | A `tff.Computation` that sets the hyperparameters of a server state. This computation accepts two arguments: an unplaced state of the process (originally produced by the … |
`set_model_weights` | A `tff.Computation` that sets the model weights of a server state. This computation accepts two arguments: an unplaced state of the process (originally produced by the … |
`state_type` | The `tff.Type` of the state of the process. |