Metriche e grafici dell'analisi del modello Tensorflow

Panoramica

TFMA supporta le seguenti metriche e grafici:

  • Metriche keras standard ( tf.keras.metrics.* )
    • Tieni presente che non è necessario un modello Keras per utilizzare le metriche Keras. Le metriche vengono calcolate all'esterno del grafico in beam utilizzando direttamente le classi di metriche.
  • Metriche e grafici TFMA standard ( tfma.metrics.* )

  • Metriche Keras personalizzate (metriche derivate da tf.keras.metrics.Metric )

  • Metriche TFMA personalizzate (metriche derivate da tfma.metrics.Metric ) che utilizzano combinatori di travi personalizzati o metriche derivate da altre metriche).

TFMA fornisce inoltre supporto integrato per la conversione delle metriche di classificazione binaria da utilizzare con problemi multiclasse/multietichetta:

  • Binarizzazione basata su ID classe, top K, ecc.
  • Metriche aggregate basate su micro media, macro media, ecc.

TFMA fornisce inoltre supporto integrato per metriche basate su query/classificazione in cui gli esempi vengono raggruppati automaticamente in base a una chiave di query nella pipeline.

Insieme sono disponibili oltre 50 parametri e grafici standard per una varietà di problemi tra cui regressione, classificazione binaria, classificazione multiclasse/multietichetta, classificazione, ecc.

Configurazione

Esistono due modi per configurare i parametri in TFMA: (1) utilizzando tfma.MetricsSpec o (2) creando istanze delle tf.keras.metrics.* e/o tfma.metrics.* in python e utilizzando tfma.metrics.specs_from_metrics per convertirli in un elenco di tfma.MetricsSpec .

Le sezioni seguenti descrivono configurazioni di esempio per diversi tipi di problemi di machine learning.

Metriche di regressione

Di seguito è riportato un esempio di configurazione per un problema di regressione. Consultare i moduli tf.keras.metrics.* e tfma.metrics.* per eventuali parametri aggiuntivi supportati.

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    metrics { class_name: "ExampleCount" }
    metrics { class_name: "MeanSquaredError" }
    metrics { class_name: "Accuracy" }
    metrics { class_name: "MeanLabel" }
    metrics { class_name: "MeanPrediction" }
    metrics { class_name: "Calibration" }
    metrics {
      class_name: "CalibrationPlot"
      config: '"min_value": 0, "max_value": 10'
    }
  }
""", tfma.EvalConfig()).metrics_specs

Questa stessa configurazione può essere creata utilizzando il seguente codice Python:

metrics = [
    tfma.metrics.ExampleCount(name='example_count'),
    tf.keras.metrics.MeanSquaredError(name='mse'),
    tf.keras.metrics.Accuracy(name='accuracy'),
    tfma.metrics.MeanLabel(name='mean_label'),
    tfma.metrics.MeanPrediction(name='mean_prediction'),
    tfma.metrics.Calibration(name='calibration'),
    tfma.metrics.CalibrationPlot(
        name='calibration', min_value=0, max_value=10)
]
metrics_specs = tfma.metrics.specs_from_metrics(metrics)

Tieni presente che questa configurazione è disponibile anche chiamando tfma.metrics.default_regression_specs .

Metriche di classificazione binaria

Di seguito è riportato un esempio di configurazione per un problema di classificazione binaria. Consultare i moduli tf.keras.metrics.* e tfma.metrics.* per eventuali parametri aggiuntivi supportati.

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    metrics { class_name: "ExampleCount" }
    metrics { class_name: "BinaryCrossentropy" }
    metrics { class_name: "BinaryAccuracy" }
    metrics { class_name: "AUC" }
    metrics { class_name: "AUCPrecisionRecall" }
    metrics { class_name: "MeanLabel" }
    metrics { class_name: "MeanPrediction" }
    metrics { class_name: "Calibration" }
    metrics { class_name: "ConfusionMatrixPlot" }
    metrics { class_name: "CalibrationPlot" }
  }
""", tfma.EvalConfig()).metrics_specs

Questa stessa configurazione può essere creata utilizzando il seguente codice Python:

metrics = [
    tfma.metrics.ExampleCount(name='example_count'),
    tf.keras.metrics.BinaryCrossentropy(name='binary_crossentropy'),
    tf.keras.metrics.BinaryAccuracy(name='accuracy'),
    tf.keras.metrics.AUC(name='auc', num_thresholds=10000),
    tf.keras.metrics.AUC(
        name='auc_precision_recall', curve='PR', num_thresholds=10000),
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(name='recall'),
    tfma.metrics.MeanLabel(name='mean_label'),
    tfma.metrics.MeanPrediction(name='mean_prediction'),
    tfma.metrics.Calibration(name='calibration'),
    tfma.metrics.ConfusionMatrixPlot(name='confusion_matrix_plot'),
    tfma.metrics.CalibrationPlot(name='calibration_plot')
]
metrics_specs = tfma.metrics.specs_from_metrics(metrics)

Tieni presente che questa configurazione è disponibile anche chiamando tfma.metrics.default_binary_classification_specs .

Metriche di classificazione multiclasse/multietichetta

Di seguito è riportato un esempio di configurazione per un problema di classificazione multiclasse. Consultare i moduli tf.keras.metrics.* e tfma.metrics.* per eventuali parametri aggiuntivi supportati.

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    metrics { class_name: "ExampleCount" }
    metrics { class_name: "SparseCategoricalCrossentropy" }
    metrics { class_name: "SparseCategoricalAccuracy" }
    metrics { class_name: "Precision" config: '"top_k": 1' }
    metrics { class_name: "Precision" config: '"top_k": 3' }
    metrics { class_name: "Recall" config: '"top_k": 1' }
    metrics { class_name: "Recall" config: '"top_k": 3' }
    metrics { class_name: "MultiClassConfusionMatrixPlot" }
  }
""", tfma.EvalConfig()).metrics_specs

Questa stessa configurazione può essere creata utilizzando il seguente codice Python:

metrics = [
    tfma.metrics.ExampleCount(name='example_count'),
    tf.keras.metrics.SparseCategoricalCrossentropy(
        name='sparse_categorical_crossentropy'),
    tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
    tf.keras.metrics.Precision(name='precision', top_k=1),
    tf.keras.metrics.Precision(name='precision', top_k=3),
    tf.keras.metrics.Recall(name='recall', top_k=1),
    tf.keras.metrics.Recall(name='recall', top_k=3),
    tfma.metrics.MultiClassConfusionMatrixPlot(
        name='multi_class_confusion_matrix_plot'),
]
metrics_specs = tfma.metrics.specs_from_metrics(metrics)

Tieni presente che questa configurazione è disponibile anche chiamando tfma.metrics.default_multi_class_classification_specs .

Metriche binarizzate multiclasse/multietichetta

I parametri multiclasse/multietichetta possono essere binarizzati per produrre parametri per classe, per top_k e così via utilizzando tfma.BinarizationOptions . Per esempio:

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    binarize: { class_ids: { values: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] } }
    // Metrics to binarize
    metrics { class_name: "AUC" }
    ...
  }
""", tfma.EvalConfig()).metrics_specs

Questa stessa configurazione può essere creata utilizzando il seguente codice Python:

metrics = [
    // Metrics to binarize
    tf.keras.metrics.AUC(name='auc', num_thresholds=10000),
    ...
]
metrics_specs = tfma.metrics.specs_from_metrics(
    metrics, binarize=tfma.BinarizationOptions(
        class_ids={'values': [0,1,2,3,4,5,6,7,8,9]}))

Metriche aggregate multiclasse/multietichetta

I parametri multiclasse/multietichetta possono essere aggregati per produrre un singolo valore aggregato per un parametro di classificazione binaria utilizzando tfma.AggregationOptions .

Tieni presente che le impostazioni di aggregazione sono indipendenti dalle impostazioni di binarizzazione, quindi puoi utilizzare sia tfma.AggregationOptions che tfma.BinarizationOptions contemporaneamente.

Micromedia

La micro media può essere eseguita utilizzando l'opzione micro_average in tfma.AggregationOptions . Per esempio:

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    aggregate: { micro_average: true }
    // Metrics to aggregate
    metrics { class_name: "AUC" }
    ...
  }
""", tfma.EvalConfig()).metrics_specs

Questa stessa configurazione può essere creata utilizzando il seguente codice Python:

metrics = [
    // Metrics to aggregate
    tf.keras.metrics.AUC(name='auc', num_thresholds=10000),
    ...
]
metrics_specs = tfma.metrics.specs_from_metrics(
    metrics, aggregate=tfma.AggregationOptions(micro_average=True))

La micro media supporta anche l'impostazione di top_k in cui nel calcolo vengono utilizzati solo i valori k principali. Per esempio:

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    aggregate: {
      micro_average: true
      top_k_list: { values: [1, 3] }
    }
    // Metrics to aggregate
    metrics { class_name: "AUC" }
    ...
  }
""", tfma.EvalConfig()).metrics_specs

Questa stessa configurazione può essere creata utilizzando il seguente codice Python:

metrics = [
    // Metrics to aggregate
    tf.keras.metrics.AUC(name='auc', num_thresholds=10000),
    ...
]
metrics_specs = tfma.metrics.specs_from_metrics(
    metrics,
    aggregate=tfma.AggregationOptions(micro_average=True,
                                      top_k_list={'values': [1, 3]}))

Macro/Media macroponderata

La media delle macro può essere eseguita utilizzando le opzioni macro_average weighted_macro_average all'interno di tfma.AggregationOptions . A meno che non vengano utilizzate le impostazioni top_k , la macro richiede l'impostazione di class_weights per sapere per quali classi calcolare la media. Se non viene fornito un class_weight , si presuppone 0,0. Per esempio:

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    aggregate: {
      macro_average: true
      class_weights: { key: 0 value: 1.0 }
      class_weights: { key: 1 value: 1.0 }
      class_weights: { key: 2 value: 1.0 }
      class_weights: { key: 3 value: 1.0 }
      class_weights: { key: 4 value: 1.0 }
      class_weights: { key: 5 value: 1.0 }
      class_weights: { key: 6 value: 1.0 }
      class_weights: { key: 7 value: 1.0 }
      class_weights: { key: 8 value: 1.0 }
      class_weights: { key: 9 value: 1.0 }
    }
    // Metrics to aggregate
    metrics { class_name: "AUC" }
    ...
  }
""", tfma.EvalConfig()).metrics_specs

Questa stessa configurazione può essere creata utilizzando il seguente codice Python:

metrics = [
    // Metrics to aggregate
    tf.keras.metrics.AUC(name='auc', num_thresholds=10000),
    ...
]
metrics_specs = tfma.metrics.specs_from_metrics(
    metrics,
    aggregate=tfma.AggregationOptions(
        macro_average=True, class_weights={i: 1.0 for i in range(10)}))

Come la micro media, anche la macro media supporta l'impostazione di top_k in cui nel calcolo vengono utilizzati solo i valori k principali. Per esempio:

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    aggregate: {
      macro_average: true
      top_k_list: { values: [1, 3] }
    }
    // Metrics to aggregate
    metrics { class_name: "AUC" }
    ...
  }
""", tfma.EvalConfig()).metrics_specs

Questa stessa configurazione può essere creata utilizzando il seguente codice Python:

metrics = [
    // Metrics to aggregate
    tf.keras.metrics.AUC(name='auc', num_thresholds=10000),
    ...
]
metrics_specs = tfma.metrics.specs_from_metrics(
    metrics,
    aggregate=tfma.AggregationOptions(macro_average=True,
                                      top_k_list={'values': [1, 3]}))

Metriche basate su query/classifica

Le metriche basate su query/classifica sono abilitate specificando l'opzione query_key nelle specifiche delle metriche. Per esempio:

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    query_key: "doc_id"
    metrics {
      class_name: "NDCG"
      config: '"gain_key": "gain", "top_k_list": [1, 2]'
    }
    metrics { class_name: "MinLabelPosition" }
  }
""", tfma.EvalConfig()).metrics_specs

Questa stessa configurazione può essere creata utilizzando il seguente codice Python:

metrics = [
    tfma.metrics.NDCG(name='ndcg', gain_key='gain', top_k_list=[1, 2]),
    tfma.metrics.MinLabelPosition(name='min_label_position')
]
metrics_specs = tfma.metrics.specs_from_metrics(metrics, query_key='doc_id')

Metriche di valutazione multi-modello

TFMA supporta la valutazione di più modelli contemporaneamente. Quando viene eseguita la valutazione multi-modello, le metriche verranno calcolate per ciascun modello. Per esempio:

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    # no model_names means all models
    ...
  }
""", tfma.EvalConfig()).metrics_specs

Se è necessario calcolare le metriche per un sottoinsieme di modelli, impostare model_names in metric_specs . Per esempio:

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    model_names: ["my-model1"]
    ...
  }
""", tfma.EvalConfig()).metrics_specs

L'API specs_from_metrics supporta anche il passaggio di nomi di modelli:

metrics = [
    ...
]
metrics_specs = tfma.metrics.specs_from_metrics(
    metrics, model_names=['my-model1'])

Metriche di confronto dei modelli

TFMA supporta la valutazione delle metriche di confronto per un modello candidato rispetto a un modello di base. Un modo semplice per impostare la coppia di modelli candidato e baseline è passare un eval_shared_model con i nomi di modello corretti (tfma.BASELINE_KEY e tfma.CANDIDATE_KEY):


eval_config = text_format.Parse("""
  model_specs {
    # ... model_spec without names ...
  }
  metrics_spec {
    # ... metrics ...
  }
""", tfma.EvalConfig())

eval_shared_models = [
  tfma.default_eval_shared_model(
      model_name=tfma.CANDIDATE_KEY,
      eval_saved_model_path='/path/to/saved/candidate/model',
      eval_config=eval_config),
  tfma.default_eval_shared_model(
      model_name=tfma.BASELINE_KEY,
      eval_saved_model_path='/path/to/saved/baseline/model',
      eval_config=eval_config),
]

eval_result = tfma.run_model_analysis(
    eval_shared_models,
    eval_config=eval_config,
    # This assumes your data is a TFRecords file containing records in the
    # tf.train.Example format.
    data_location="/path/to/file/containing/tfrecords",
    output_path="/path/for/output")

Le metriche di confronto vengono calcolate automaticamente per tutte le metriche differenziabili (attualmente solo le metriche dei valori scalari come precisione e AUC).

Metriche del modello multi-output

TFMA supporta la valutazione delle metriche su modelli che hanno output diversi. I modelli multi-output memorizzano le previsioni di output sotto forma di un dict codificato in base al nome dell'output. Quando vengono utilizzati modelli a più output, i nomi degli output associati a un set di parametri devono essere specificati nella sezione output_names di MetricsSpec. Per esempio:

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    output_names: ["my-output"]
    ...
  }
""", tfma.EvalConfig()).metrics_specs

L'API specs_from_metrics supporta anche il passaggio di nomi di output:

metrics = [
    ...
]
metrics_specs = tfma.metrics.specs_from_metrics(
    metrics, output_names=['my-output'])

Personalizzazione delle impostazioni metriche

TFMA consente la personalizzazione delle impostazioni utilizzate con metriche diverse. Ad esempio, potresti voler modificare il nome, impostare soglie, ecc. Ciò viene fatto aggiungendo una sezione config alla configurazione della metrica. La configurazione viene specificata utilizzando la versione stringa JSON dei parametri che verrebbero passati al metodo metrics __init__ (per facilità d'uso le parentesi iniziali e finali "{" e "}" possono essere omesse). Per esempio:

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    metrics {
      class_name: "ConfusionMatrixAtThresholds"
      config: '"thresholds": [0.3, 0.5, 0.8]'
    }
  }
""", tfma.MetricsSpec()).metrics_specs

Naturalmente questa personalizzazione è supportata anche direttamente:

metrics = [
   tfma.metrics.ConfusionMatrixAtThresholds(thresholds=[0.3, 0.5, 0.8]),
]
metrics_specs = tfma.metrics.specs_from_metrics(metrics)

Uscite

L'output di una valutazione metrica è una serie di chiavi/valori metrici e/o chiavi/valori di grafico in base alla configurazione utilizzata.

Chiavi metriche

Le MetricKey vengono definite utilizzando un tipo di chiave strutturata. Questa chiave identifica in modo univoco ciascuno dei seguenti aspetti di una metrica:

  • Nome della metrica ( auc , mean_label , ecc.)
  • Nome del modello (utilizzato solo in caso di valutazione multi-modello)
  • Nome dell'output (utilizzato solo se vengono valutati modelli a più uscite)
  • Chiave secondaria (ad esempio ID classe se il modello multiclasse è binarizzato)

Valore metrico

I MetricValues ​​vengono definiti utilizzando un protocollo che incapsula i diversi tipi di valore supportati dai diversi parametri (ad esempio double , ConfusionMatrixAtThresholds , ecc.).

Di seguito sono riportati i tipi di valori parametrici supportati:

  • double_value - Un wrapper per un tipo double.
  • bytes_value : un valore in byte.
  • bounded_value - Rappresenta un valore reale che potrebbe essere una stima puntuale, facoltativamente con limiti approssimativi di qualche tipo. Ha le proprietà value , lower_bound e upper_bound .
  • value_at_cutoffs - Valore ai cutoff (ad esempio precisione@K, ​​richiamo@K). Ha values di proprietà, ognuno dei quali ha proprietà cutoff e value .
  • confusion_matrix_at_thresholds - Matrice di confusione alle soglie. Dispone di matrices di proprietà, ognuna delle quali ha proprietà per valori di threshold , precision , recall e matrice di confusione come false_negatives .
  • array_value - Per parametri che restituiscono una matrice di valori.

Chiavi di trama

Le PlotKey sono simili alle chiavi metriche, tranne per il fatto che per ragioni storiche tutti i valori dei grafici sono archiviati in un unico protocollo, quindi la chiave della trama non ha un nome.

Tracciare i valori

Tutti i grafici supportati sono archiviati in un unico protocollo chiamato PlotData .

RisultatoValutazione

Il risultato di un'esecuzione di valutazione è un tfma.EvalResult . Questo record contiene slicing_metrics che codificano la chiave della metrica come un dict a più livelli in cui i livelli corrispondono rispettivamente al nome di output, all'ID di classe, al nome della metrica e al valore della metrica. Questo è destinato ad essere utilizzato per la visualizzazione dell'interfaccia utente in un notebook Jupiter. Se è necessario l'accesso ai dati sottostanti, è necessario utilizzare invece il file dei risultati metrics (vedere metrics_for_slice.proto ).

Personalizzazione

Oltre alle metriche personalizzate aggiunte come parte di un keras salvato (o di un EvalSavedModel legacy). Esistono due modi per personalizzare le metriche nel salvataggio successivo a TFMA: (1) definendo una classe di metriche Keras personalizzata e (2) definendo una classe di metriche TFMA personalizzata supportata da un combinatore di travi.

In entrambi i casi, le metriche vengono configurate specificando il nome della classe di metriche e del modulo associato. Per esempio:

from google.protobuf import text_format

metrics_specs = text_format.Parse("""
  metrics_specs {
    metrics { class_name: "MyMetric" module: "my.module"}
  }
""", tfma.EvalConfig()).metrics_specs

Metriche Keras personalizzate

Per creare una metrica keras personalizzata, gli utenti devono estendere tf.keras.metrics.Metric con la loro implementazione e quindi assicurarsi che il modulo della metrica sia disponibile al momento della valutazione.

Tieni presente che per le metriche aggiunte dopo il salvataggio del modello, TFMA supporta solo le metriche che accettano etichetta (ovvero y_true), previsione (y_pred) e peso di esempio (sample_weight) come parametri per il metodo update_state .

Esempio metrico di Keras

Di seguito è riportato un esempio di metrica Keras personalizzata:

class MyMetric(tf.keras.metrics.Mean):

  def __init__(self, name='my_metric', dtype=None):
    super(MyMetric, self).__init__(name=name, dtype=dtype)

  def update_state(self, y_true, y_pred, sample_weight=None):
    return super(MyMetric, self).update_state(
        y_pred, sample_weight=sample_weight)

Metriche TFMA personalizzate

Per creare una metrica TFMA personalizzata, gli utenti devono estendere tfma.metrics.Metric con la loro implementazione e quindi assicurarsi che il modulo della metrica sia disponibile al momento della valutazione.

Metrico

Un'implementazione tfma.metrics.Metric è costituita da un insieme di kwarg che definiscono la configurazione delle metriche insieme a una funzione per creare i calcoli (possibilmente multipli) necessari per calcolare il valore delle metriche. È possibile utilizzare due tipi di calcolo principali: tfma.metrics.MetricComputation e tfma.metrics.DerivedMetricComputation , descritti nelle sezioni seguenti. Alla funzione che crea questi calcoli verranno passati i seguenti parametri come input:

  • eval_config: tfam.EvalConfig
    • La configurazione eval passata al valutatore (utile per cercare le impostazioni delle specifiche del modello come la chiave di previsione da utilizzare, ecc.).
  • model_names: List[Text]
    • Elenco di nomi di modelli per cui calcolare le metriche (nessuno se modello singolo)
  • output_names: List[Text] .
    • Elenco di nomi di output per cui calcolare i parametri (nessuno se modello singolo)
  • sub_keys: List[tfma.SubKey] .
    • Elenco di sottochiavi (ID classe, K superiore, ecc.) per calcolare le metriche per (o Nessuna)
  • aggregation_type: tfma.AggregationType
    • Tipo di aggregazione se si calcola una metrica di aggregazione.
  • class_weights: Dict[int, float] .
    • Pesi delle classi da utilizzare se si calcola una metrica di aggregazione.
  • query_key: Text
    • Chiave di query utilizzata se si calcola una metrica basata su query/classificazione.

Se una metrica non è associata a una o più di queste impostazioni, potrebbe escludere tali parametri dalla definizione della firma.

Se una metrica viene calcolata allo stesso modo per ciascun modello, output e chiave secondaria, è possibile utilizzare l'utilità tfma.metrics.merge_per_key_computations per eseguire gli stessi calcoli separatamente per ciascuno di questi input.

Calcolo metrico

Un MetricComputation è costituito da una combinazione di preprocessors e un combiner . I preprocessors sono un elenco di preprocessor , che è un beam.DoFn che accetta gli estratti come input e restituisce lo stato iniziale che verrà utilizzato dal combinatore (vedi architettura per maggiori informazioni su cosa sono gli estratti). Tutti i preprocessori verranno eseguiti in sequenza nell'ordine dell'elenco. Se i preprocessors sono vuoti, al combinatore verrà passato StandardMetricInputs (gli input metrici standard contengono etichette, previsioni e example_weights). Il combiner è un beam.CombineFn che accetta una tupla di (slice key, output del preprocessore) come input e restituisce una tupla di (slice_key, metric results dict) come risultato.

Si noti che l'affettamento avviene tra i preprocessors e combiner .

Si noti che se un calcolo metrico vuole utilizzare entrambi gli input metrici standard, ma aumentarli con alcune funzionalità dagli estratti features , è possibile utilizzare lo speciale FeaturePreprocessor che unirà le funzionalità richieste da più combinatori in un unico valore StandardMetricsInputs condiviso che viene passato a tutti i combinatori (i combinatori sono responsabili di leggere le funzionalità a cui sono interessati e di ignorare il resto).

Esempio

Quello che segue è un esempio molto semplice di definizione della metrica TFMA per il calcolo di EsempioCount:

class ExampleCount(tfma.metrics.Metric):

  def __init__(self, name: Text = 'example_count'):
    super(ExampleCount, self).__init__(_example_count, name=name)


def _example_count(
    name: Text = 'example_count') -> tfma.metrics.MetricComputations:
  key = tfma.metrics.MetricKey(name=name)
  return [
      tfma.metrics.MetricComputation(
          keys=[key],
          preprocessors=[_ExampleCountPreprocessor()],
          combiner=_ExampleCountCombiner(key))
  ]


class ExampleCountTest(tfma.test.testutil.TensorflowModelAnalysisTest):

  def testExampleCount(self):
    metric = ExampleCount()
    computations = metric.computations(example_weighted=False)
    computation = computations[0]

    with beam.Pipeline() as pipeline:
      result = (
          pipeline
          | 'Create' >> beam.Create([...])  # Add inputs
          | 'PreProcess' >> beam.ParDo(computation.preprocessors[0])
          | 'Process' >> beam.Map(tfma.metrics.to_standard_metric_inputs)
          | 'AddSlice' >> beam.Map(lambda x: ((), x))
          | 'ComputeMetric' >> beam.CombinePerKey(computation.combiner)
      )

      def check_result(got):
        try:
          self.assertLen(got, 1)
          got_slice_key, got_metrics = got[0]
          self.assertEqual(got_slice_key, ())
          key = computation.keys[0]
          self.assertIn(key, got_metrics)
          self.assertAlmostEqual(got_metrics[key], expected_value, places=5)
        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(result, check_result, label='result')

class _ExampleCountPreprocessor(beam.DoFn):

  def process(self, extracts: tfma.Extracts) -> Iterable[int]:
    yield 1


class _ExampleCountPreprocessorTest(unittest.TestCase):

  def testExampleCountPreprocessor(self):
    ...  # Init the test case here
    with beam.Pipeline() as pipeline:
      updated_pcoll = (
          pipeline
          | 'Create' >> beam.Create([...])  # Add inputs
          | 'Preprocess'
          >> beam.ParDo(
              _ExampleCountPreprocessor()
          )
      )

      beam_testing_util.assert_that(
          updated_pcoll,
          lambda result: ...,  # Assert the test case
      )


class _ExampleCountCombiner(beam.CombineFn):

  def __init__(self, metric_key: tfma.metrics.MetricKey):
    self._metric_key = metric_key

  def create_accumulator(self) -> int:
    return 0

  def add_input(self, accumulator: int, state: int) -> int:
    return accumulator + state

  def merge_accumulators(self, accumulators: Iterable[int]) -> int:
    accumulators = iter(accumulators)
    result = next(accumulator)
    for accumulator in accumulators:
      result += accumulator
    return result

  def extract_output(self,
                     accumulator: int) -> Dict[tfma.metrics.MetricKey, int]:
    return {self._metric_key: accumulator}

Calcolo metrico derivato

Un DerivedMetricComputation è costituito da una funzione di risultato utilizzata per calcolare i valori metrici in base all'output di altri calcoli metrici. La funzione di risultato accetta un dict di valori calcolati come input e restituisce un dict di risultati metrici aggiuntivi.

Si noti che è accettabile (consigliato) includere i calcoli da cui dipende un calcolo derivato nell'elenco dei calcoli creati da una metrica. In questo modo si evita di dover precreare e passare calcoli condivisi tra più parametri. Il valutatore deduplica automaticamente i calcoli che hanno la stessa definizione in modo che venga effettivamente eseguito un solo calcolo.

Esempio

Le metriche TJUR forniscono un buon esempio di metriche derivate.