Turn a Keras metric construction method into a tuple of pure functions.
tff.learning.metrics.create_functional_metric_fns(
    metrics_constructor: Union[MetricConstructor, MetricsConstructor, MetricConstructors]
) -> tuple[Callable[[], StateVar], Callable[[StateVar, Any, Any, Any], StateVar],
           Callable[[StateVar], Any]]
This can be used to convert Keras metrics for use in
tff.learning.models.FunctionalModel. The method traces the metric logic into
three tf.functions with explicit state parameters that replace the closure
over the internal tf.Variables of the tf.keras.metrics.Metric.
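The idea of replacing a closure over internal variables with an explicit state argument can be illustrated without TensorFlow. The following is a plain-Python sketch, not TFF code: the hypothetical `StatefulAccuracy` class keeps its running totals as attributes (standing in for the tf.Variables of a Keras metric), while the hypothetical `initialize`, `update`, and `finalize` functions carry the same totals as an explicit `(total, count)` tuple. Unlike the functions TFF returns, this sketch takes raw predictions instead of a tff.learning.models.BatchOutput and traces nothing into a tf.function.

```python
# Stateful version: running totals live inside the object, analogous to the
# tf.Variables closed over by a tf.keras.metrics.Metric.
class StatefulAccuracy:
    def __init__(self):
        self.total = 0.0
        self.count = 0.0

    def update_state(self, labels, predictions):
        for y, p in zip(labels, predictions):
            self.total += 1.0 if y == p else 0.0
            self.count += 1.0

    def result(self):
        return self.total / self.count if self.count else 0.0


# Functional version (hypothetical): the state is passed in and returned,
# so each function is pure and holds no hidden variables.
def initialize():
    return (0.0, 0.0)  # (total, count): the "zero" state


def update(state, labels, predictions):
    total, count = state
    for y, p in zip(labels, predictions):
        total += 1.0 if y == p else 0.0
        count += 1.0
    return (total, count)


def finalize(state):
    total, count = state
    return total / count if count else 0.0


metric = StatefulAccuracy()
metric.update_state([1.0, 1.0], [0.0, 1.0])
print(metric.result())  # 0.5

state = initialize()
state = update(state, [1.0, 1.0], [0.0, 1.0])
print(finalize(state))  # 0.5
```

Both paths compute the same value; the functional one simply makes the state visible, which is what allows it to be serialized, aggregated, or passed between computations.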
Example:
>>> metric = tf.keras.metrics.Accuracy()
>>> metric.update_state([1.0, 1.0], [0.0, 1.0])
>>> metric.result() # == 0.5
>>>
>>> metric_fns = tff.learning.metrics.create_functional_metric_fns(
...     tf.keras.metrics.Accuracy)
>>> initialize, update, finalize = metric_fns
>>> state = initialize()
>>> batch_output = tff.learning.models.BatchOutput(predictions=[0.0, 1.0])
>>> state = update(state, [1.0, 1.0], batch_output)
>>> finalize(state) # == 0.5
Returns:
A 3-tuple of tf.functions, namely (initialize, update, finalize).
initialize is a no-arg function used to create the algebraic "zero"
before reducing the metric over batches of examples. update is a function
that takes the state, the labels, and the
tff.learning.models.BatchOutput structure from the model's forward pass,
and is used to add an observation to the metric. finalize takes only a
state argument and returns the final metric value based on the observations
previously added.
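To sketch how the three functions compose when reducing over several batches, the following plain-Python example (not the TFF-generated functions) starts from `initialize()`, folds `update` across the batches, and calls `finalize` once at the end. The pure accuracy functions and the `(correct, seen)` state layout are assumptions made for illustration; the real `update` receives a tff.learning.models.BatchOutput rather than raw predictions.

```python
from functools import reduce


# Hypothetical pure accuracy functions; state is (num_correct, num_seen).
def initialize():
    return (0.0, 0.0)


def update(state, labels, predictions):
    correct, seen = state
    for y, p in zip(labels, predictions):
        correct += 1.0 if y == p else 0.0
        seen += 1.0
    return (correct, seen)


def finalize(state):
    correct, seen = state
    return correct / seen if seen else 0.0


batches = [
    ([1.0, 1.0], [0.0, 1.0]),  # 1 of 2 correct
    ([0.0, 1.0], [0.0, 1.0]),  # 2 of 2 correct
]

# initialize() supplies the starting "zero"; each step adds one batch.
state = reduce(lambda s, batch: update(s, *batch), batches, initialize())
print(state)            # (3.0, 4.0)
print(finalize(state))  # 0.75
```

Because `initialize()` is the identity for `update` (an empty batch leaves it unchanged), the same fold works for any number of batches, including zero.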
Raises:
TypeError: If metrics_constructor is not a callable or OrderedDict, or
if metrics_constructor is a callable returning values of the wrong type.