Encapsulates metric logic and state.
Inherits From: Layer
tf.keras.metrics.Metric(
name=None, dtype=None, **kwargs
)
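The name and dtype constructor arguments are accepted by every concrete metric. The following is a minimal sketch, assuming TensorFlow 2.x. It uses the built-in tf.keras.metrics.Mean as the concrete subclass, and the name "avg_loss" is only an illustrative value:

```python
import tensorflow as tf

# name labels the metric (for example in training logs); dtype sets the
# dtype of its state variables and of the value returned by result().
m = tf.keras.metrics.Mean(name="avg_loss", dtype="float64")
m.update_state([1.0, 2.0])

print(m.name)            # avg_loss
print(m.result().dtype)  # <dtype: 'float64'>
```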
Usage:
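```python
import tensorflow as tf

m = tf.keras.metrics.Mean()  # any concrete Metric subclass can stand in for SomeMetric(...)
for batch in [[1.0, 2.0], [3.0, 4.0]]:
    m.update_state(batch)
print('Final result: ', m.result().numpy())
```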
Usage with tf.keras API:
```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

model.compile(optimizer=tf.keras.optimizers.RMSprop(0.01),
              loss=tf.keras.losses.categorical_crossentropy,
              metrics=[tf.keras.metrics.CategoricalAccuracy()])

data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))

dataset = tf.data.Dataset.from_tensor_slices((data, labels))
dataset = dataset.batch(32)
dataset = dataset.repeat()

model.fit(dataset, epochs=10, steps_per_epoch=30)
```
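During fit(), the metric's name becomes the key under which its per-epoch result() value is recorded. Here is a small sketch, assuming TensorFlow 2.x. It uses random data only to stay runnable, so the accuracy values themselves are meaningless:

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="rmsprop",
              loss="categorical_crossentropy",
              metrics=[tf.keras.metrics.CategoricalAccuracy()])

data = np.random.random((64, 32)).astype("float32")
labels = np.random.random((64, 10)).astype("float32")
history = model.fit(data, labels, batch_size=32, epochs=2, verbose=0)

# One value per epoch, stored under the metric's default name.
print(sorted(history.history))                       # includes 'categorical_accuracy'
print(len(history.history["categorical_accuracy"]))  # 2
```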
To be implemented by subclasses:
- __init__(): All state variables should be created in this method by calling self.add_weight(), like: self.var = self.add_weight(...)
- update_state(): Performs all updates to the state variables, like: self.var.assign_add(...)
- result(): Computes and returns a value for the metric from the state variables.
Example subclass implementation:
```python
import tensorflow as tf

class BinaryTruePositives(tf.keras.metrics.Metric):

    def __init__(self, name='binary_true_positives', **kwargs):
        super(BinaryTruePositives, self).__init__(name=name, **kwargs)
        # Scalar state variable, created once in __init__.
        self.true_positives = self.add_weight(name='tp', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, tf.bool)
        y_pred = tf.cast(y_pred, tf.bool)

        # An element is a true positive when both label and prediction are True.
        values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
        values = tf.cast(values, self.dtype)
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, self.dtype)
            # Broadcast the weights to the shape of values before multiplying.
            sample_weight = tf.broadcast_to(sample_weight, tf.shape(values))
            values = tf.multiply(values, sample_weight)
        self.true_positives.assign_add(tf.reduce_sum(values))

    def result(self):
        return self.true_positives
```
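A metric often keeps more than one state variable and combines them only in result(). The following is a minimal sketch of that pattern, assuming TensorFlow 2.x. The class MeanAbsoluteErrorSketch is hypothetical, written for illustration. It is not the built-in tf.keras.metrics.MeanAbsoluteError:

```python
import tensorflow as tf

class MeanAbsoluteErrorSketch(tf.keras.metrics.Metric):
    """Hypothetical metric: weighted running mean of |y_true - y_pred|."""

    def __init__(self, name="mae_sketch", **kwargs):
        super().__init__(name=name, **kwargs)
        # Two scalar state variables, both created in __init__ via add_weight().
        self.total = self.add_weight(name="total", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, self.dtype)
        y_pred = tf.cast(y_pred, self.dtype)
        errors = tf.abs(y_true - y_pred)
        if sample_weight is None:
            weights = tf.ones_like(errors)
        else:
            weights = tf.broadcast_to(tf.cast(sample_weight, self.dtype),
                                      tf.shape(errors))
        # Accumulate the weighted error sum and the total weight separately.
        self.total.assign_add(tf.reduce_sum(errors * weights))
        self.count.assign_add(tf.reduce_sum(weights))

    def result(self):
        # divide_no_nan yields 0 rather than NaN before the first update.
        return tf.math.divide_no_nan(self.total, self.count)


m = MeanAbsoluteErrorSketch()
m.update_state([1.0, 2.0], [1.5, 2.0])  # errors 0.5 and 0.0
m.update_state([3.0], [1.0])            # error 2.0
print(m.result().numpy())               # 2.5 / 3, about 0.8333
```

Keeping the sum and the weight separate means the mean stays exact across batches of different sizes. Averaging per-batch means would not.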
Methods
add_weight
add_weight(
name, shape=(), aggregation=tf.compat.v1.VariableAggregation.SUM,
synchronization=tf.VariableSynchronization.ON_READ, initializer=None, dtype=None
)
Adds a state variable. Only for use by subclasses.
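The shape argument lets a single state variable hold a vector of counters. The following is a minimal sketch, assuming TensorFlow 2.x. The class ClassCounts and its num_classes parameter are hypothetical, written for illustration:

```python
import tensorflow as tf

class ClassCounts(tf.keras.metrics.Metric):
    """Hypothetical metric: how often each class index appears in y_true."""

    def __init__(self, num_classes, name="class_counts", **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        # Vector-valued state: one counter per class, via the shape argument.
        self.counts = self.add_weight(
            name="counts", shape=(num_classes,), initializer="zeros")

    def update_state(self, y_true, y_pred=None, sample_weight=None):
        # This sketch only looks at y_true and ignores y_pred and sample_weight.
        labels = tf.reshape(tf.cast(y_true, tf.int32), [-1])
        # one_hot + reduce_sum turns a batch of labels into per-class counts.
        one_hot = tf.cast(tf.one_hot(labels, self.num_classes), self.dtype)
        self.counts.assign_add(tf.reduce_sum(one_hot, axis=0))

    def result(self):
        return tf.identity(self.counts)


m = ClassCounts(num_classes=3)
m.update_state([0, 2, 2])
m.update_state([1, 2])
print(m.result().numpy())  # [1. 1. 3.]
```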
reset_states
reset_states()
Resets all of the metric state variables.
This function is called between epochs/steps, when a metric is evaluated during training.
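After a reset, the metric reports only what was accumulated since the reset. The following is a short sketch, assuming TensorFlow 2.x and the built-in tf.keras.metrics.Mean. Newer tf.keras releases name this method reset_state(), so the code uses whichever name the installed version provides:

```python
import tensorflow as tf

m = tf.keras.metrics.Mean()
m.update_state([1.0, 2.0, 3.0])
before = m.result().numpy()   # 2.0

# reset_states() is the name documented here; newer tf.keras releases
# call it reset_state(), so pick whichever this installation provides.
reset = getattr(m, "reset_states", None) or m.reset_state
reset()

m.update_state([10.0])
after = m.result().numpy()    # 10.0: nothing carried over from before the reset
print(before, after)
```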
result
@abc.abstractmethod
result()
Computes and returns the metric value tensor.
Result computation is an idempotent operation that simply calculates the metric value using the state variables.
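Because result() only reads the state variables, calling it repeatedly returns the same value. Only update_state() changes that value. The following is a short sketch, assuming TensorFlow 2.x and the built-in tf.keras.metrics.Mean:

```python
import tensorflow as tf

m = tf.keras.metrics.Mean()
m.update_state([2.0, 4.0])

first = m.result().numpy()
second = m.result().numpy()   # reading again leaves the state untouched
m.update_state([6.0])
third = m.result().numpy()    # only update_state() moves the value
print(first, second, third)   # 3.0 3.0 4.0
```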
update_state
@abc.abstractmethod
update_state(*args, **kwargs)
Accumulates statistics for the metric.
Please use tf.config.experimental_run_functions_eagerly(True) to execute this function eagerly for debugging or profiling.
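The following sketch shows the debugging switch in action, assuming TensorFlow 2.3 or later. In those releases the non-experimental tf.config.run_functions_eagerly replaces the deprecated experimental name. The step function is hypothetical, written for illustration:

```python
import tensorflow as tf

m = tf.keras.metrics.Mean()

@tf.function
def step(batch):
    m.update_state(batch)
    # When functions run eagerly, this is True, so print() or a debugger
    # breakpoint placed here behaves like ordinary Python.
    return tf.executing_eagerly()

tf.config.run_functions_eagerly(True)
ran_eagerly = step(tf.constant([1.0, 2.0, 3.0]))
tf.config.run_functions_eagerly(False)  # restore normal graph execution

print(ran_eagerly, m.result().numpy())  # True 2.0
```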
Args | |
---|---|
*args, **kwargs | A mini-batch of inputs to the Metric. |
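For the built-in metrics, the mini-batch is typically y_true and y_pred passed positionally, plus an optional sample_weight keyword. The following is a short sketch, assuming TensorFlow 2.x and the built-in tf.keras.metrics.Accuracy:

```python
import tensorflow as tf

y_true = [[1], [2], [3], [4]]
y_pred = [[0], [2], [3], [4]]

# Positional mini-batch: 3 of the 4 predictions match.
m = tf.keras.metrics.Accuracy()
m.update_state(y_true, y_pred)
unweighted = m.result().numpy()   # 0.75

# The same mini-batch with a sample_weight keyword that zeroes the last two rows.
w = tf.keras.metrics.Accuracy()
w.update_state(y_true, y_pred, sample_weight=[1, 1, 0, 0])
weighted = w.result().numpy()     # 0.5
print(unweighted, weighted)
```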