A model whose forward pass is parameterized by explicitly passed model weights.
```python
tff.learning.models.FunctionalModel(
    *,
    initial_weights: ModelWeights,
    predict_on_batch_fn: Callable[[ModelWeights, Any, bool], Any],
    loss_fn: Callable[[Any, Any, Any], Any],
    metrics_fns: tuple[
        InitializeMetricsStateFn, UpdateMetricsStateFn, FinalizeMetricsFn
    ] = (empty_metrics_state, noop_update_metrics, noop_finalize_metrics),
    input_spec: Any
)
```
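Because the weights are an explicit argument rather than state held by the model, the same forward pass can be evaluated against any set of weights. The plain-Python sketch below illustrates this under stated assumptions: the linear model, the function name `predict_on_batch`, and the weight layout `(([kernel...], bias), ())` are hypothetical, and plain floats stand in for the `tf.Tensor`-convertible values TFF expects. It is not TFF code.

```python
# Framework-free sketch (hypothetical names): a forward pass that is a pure
# function of explicitly passed weights, not of variables owned by a model.


def predict_on_batch(model_weights, x, training):
    """Toy linear model: each prediction is row . kernel + bias."""
    del training  # A linear model behaves the same in training and inference.
    trainable, _non_trainable = model_weights
    kernel, bias = trainable
    return [sum(xi * ki for xi, ki in zip(row, kernel)) + bias for row in x]


# Both weight sets use the (trainable, non_trainable) 2-tuple layout.
initial_weights = (([0.0, 0.0], 0.0), ())
updated_weights = (([1.0, 2.0], 0.5), ())

batch_x = [[1.0, 1.0], [2.0, 0.0]]

# One function, two weight sets, and no state mutated between the calls.
print(predict_on_batch(initial_weights, batch_x, training=False))  # [0.0, 0.0]
print(predict_on_batch(updated_weights, batch_x, training=False))  # [3.5, 2.5]
```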
Args | |
---|---|
`initial_weights` | A 2-tuple `(trainable, non_trainable)` where each element is a sequence of weights. Weights must be values convertible to `tf.Tensor` (e.g. `numpy.ndarray`, Python sequences, etc.), but not `tf.Tensor` values themselves. |
`predict_on_batch_fn` | A `tf.function`-decorated callable that takes three arguments: `model_weights`, with the same structure as `initial_weights`; `x`, the first element of `batch_input` (or `input_spec`); and `training`, a boolean indicating whether the call is part of a training pass (relevant for layers such as Dropout and BatchNormalization). It must return either a tensor of predictions or a structure whose first element (as determined by `tf.nest.flatten()`) is a tensor of predictions. |
`loss_fn` | A callable that takes three arguments: `output`, the tensor(s) returned by `predict_on_batch` in a form the loss function can interpret; `label`, the second element of `batch_input`; and an optional `sample_weight` that weights the output. |
`metrics_fns` | A 3-tuple of callables that, respectively, initialize the metrics state, update the metrics state, and finalize the metric values. This can be the result of `tff.learning.metrics.create_functional_metric_fns` or custom user-written callables. |
`input_spec` | A 2-tuple `(x, y)` where each element is a nested structure of `tf.TensorSpec`. `x` describes the batched model inputs and defines the shape and dtype of the `x` argument to `predict_on_batch_fn`; `y` describes the batched labels for those inputs and defines the shape and dtype of the `label` argument to `loss_fn`. |
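Each batch is an `(x, y)` pair matching `input_spec`: `x` goes to `predict_on_batch_fn`, and the predictor's output goes to `loss_fn` together with `y` as `label`. The sketch below traces that routing in plain Python. The linear model, the squared-error loss, the helper `batch_loss`, and the choice to return a `(predictions, extra)` tuple are illustrative assumptions; in TFF the "first element" is determined by `tf.nest.flatten()` over tensors, which this sketch does not reproduce.

```python
# Framework-free sketch (hypothetical names) of how one (x, y) batch flows
# through a predict function and a loss function with the documented arity.


def predict_on_batch_fn(model_weights, x, training):
    """Toy linear model returning a (predictions, extra) tuple."""
    del training  # This linear model behaves the same in training and inference.
    (kernel, bias), _non_trainable = model_weights
    predictions = [sum(xi * ki for xi, ki in zip(row, kernel)) + bias for row in x]
    extra = {"batch_size": len(x)}  # Auxiliary output that the loss ignores.
    return predictions, extra


def loss_fn(output, label, sample_weight=None):
    """Mean squared error over the predictions part of the output."""
    del sample_weight  # Sample weighting is not used in this sketch.
    predictions = output[0]  # Predictions are the first element of the output.
    return sum((p - y) ** 2 for p, y in zip(predictions, label)) / len(label)


def batch_loss(model_weights, batch_input, training=True):
    """Hypothetical helper: routes x to the predictor and y to the loss."""
    x, y = batch_input  # Same (x, y) layout as input_spec.
    output = predict_on_batch_fn(model_weights, x, training)
    return loss_fn(output, y)


weights = (([1.0, 2.0], 0.5), ())
batch = ([[1.0, 1.0], [2.0, 0.0]], [3.0, 3.0])
print(batch_loss(weights, batch))  # 0.25
```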
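Attributes | |
---|---|
`initial_weights` | The initial model weights, as passed to the constructor. |
`input_spec` | The `(x, y)` input specification, as passed to the constructor. |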
Methods
finalize_metrics
```python
@tf.function
finalize_metrics(
    state: types.MetricsState
) -> collections.OrderedDict[str, Any]
```
Finalizes the metrics state into an ordered dictionary of named metric values.
initialize_metrics_state
```python
@tf.function
initialize_metrics_state() -> types.MetricsState
```
Returns the initial metrics state.
loss
```python
loss(
    output: Any, label: Any, sample_weight: Optional[Any] = None
) -> float
```
Returns the loss value based on the model output and the label.
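`loss` accepts an optional `sample_weight`, but how the weights are applied is up to the `loss_fn` supplied at construction. The sketch below assumes one common choice, a weighted mean of per-example squared errors, written in plain Python; the function name `weighted_mse` is hypothetical, and other losses may normalize weighted errors differently.

```python
from typing import Optional, Sequence


def weighted_mse(
    output: Sequence[float],
    label: Sequence[float],
    sample_weight: Optional[Sequence[float]] = None,
) -> float:
    """Hypothetical loss_fn: squared error averaged with optional weights."""
    errors = [(o - y) ** 2 for o, y in zip(output, label)]
    if sample_weight is None:
        # No weights: every example counts equally.
        return sum(errors) / len(errors)
    # Weighted mean: an example with weight 0.0 contributes nothing.
    total_weight = sum(sample_weight)
    return sum(w * e for w, e in zip(sample_weight, errors)) / total_weight


predictions = [3.5, 2.5, 0.0]
labels = [3.0, 3.0, 4.0]
print(weighted_mse(predictions, labels))                   # 5.5
print(weighted_mse(predictions, labels, [1.0, 1.0, 0.0]))  # 0.25
```

Setting a weight to zero is a simple way to mask out padded or invalid examples in a batch.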
predict_on_batch
```python
@tf.function
predict_on_batch(
    model_weights: ModelWeights, x: Any, training: bool = True
)
```
Returns tensor(s) interpretable by the loss function.
update_metrics_state
```python
@tf.function
update_metrics_state(
    state: GenericMetricsState,
    labels: Any,
    batch_output: tff.learning.models.BatchOutput,
    sample_weight: Optional[Any] = None
) -> GenericMetricsState
```
Updates the metrics state with a batch of labels and the corresponding model output.
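`initialize_metrics_state`, `update_metrics_state`, and `finalize_metrics` mirror the `metrics_fns` triple: create a state once, fold each batch into it, then convert it into named values. The plain-Python sketch below walks that cycle with a hypothetical mean-absolute-error metric. The dict-based state and `ToyBatchOutput` are assumptions standing in for the real `tf.function`-based state and `tff.learning.models.BatchOutput`; only the call order and argument order follow the signatures above.

```python
import collections
from typing import NamedTuple, Sequence


class ToyBatchOutput(NamedTuple):
    """Hypothetical stand-in for tff.learning.models.BatchOutput."""

    predictions: Sequence[float]


def initialize_metrics_state():
    # Running sums, so that batches can be folded in one at a time.
    return {"abs_error_sum": 0.0, "num_examples": 0.0}


def update_metrics_state(state, labels, batch_output, sample_weight=None):
    # Default to weight 1.0 per example when no sample weights are given.
    if sample_weight is None:
        sample_weight = [1.0] * len(labels)
    abs_error = sum(
        w * abs(p - y)
        for w, p, y in zip(sample_weight, batch_output.predictions, labels)
    )
    # Return a new state instead of mutating the input state.
    return {
        "abs_error_sum": state["abs_error_sum"] + abs_error,
        "num_examples": state["num_examples"] + sum(sample_weight),
    }


def finalize_metrics(state):
    # Convert the running sums into named metric values.
    return collections.OrderedDict(
        mean_absolute_error=state["abs_error_sum"] / state["num_examples"],
        num_examples=state["num_examples"],
    )


state = initialize_metrics_state()
state = update_metrics_state(state, [3.0, 3.0], ToyBatchOutput([3.5, 2.5]))
state = update_metrics_state(state, [5.0], ToyBatchOutput([3.0]))
metrics = finalize_metrics(state)
print(metrics["mean_absolute_error"])  # 1.0
```

Keeping running sums rather than a running mean is what lets batches of different sizes be combined exactly before the final division.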