Métricas de pós-exportação

Como o nome sugere, esta é uma métrica que é adicionada após a exportação, antes da avaliação.

O TFMA é empacotado com várias métricas de avaliação predefinidas, como example_count, auc, confusion_matrix_at_thresholds, precision_recall_at_k, mse, mae, para citar alguns. (Lista completa aqui .)

Se você não encontrar uma métrica existente relevante para seu caso de uso ou quiser personalizar uma métrica, defina sua própria métrica personalizada. Leia sobre os detalhes!

Adicionando métricas personalizadas no TFMA

Definindo métricas personalizadas no TFMA 1.x

Estender a classe base abstrata

Para adicionar uma métrica personalizada, crie uma nova classe estendendo a classe abstrata _PostExportMetric e defina seu construtor e implemente métodos abstratos/não implementados.

Definir construtor

No construtor, tome como parâmetros todas as informações relevantes, como label_key, forecast_key, example_weight_key, metric_tag etc. necessárias para a métrica personalizada.

Implementar métodos abstratos/não implementados
  • verificação_compatibilidade

    Implemente este método para verificar a compatibilidade da métrica com o modelo que está sendo avaliado, ou seja, verificar se todos os recursos necessários, rótulo esperado e chave de previsão estão presentes no modelo no tipo de dados apropriado. São necessários três argumentos:

    • features_dict
    • previsões_dict
    • labels_dict

    Esses dicionários contêm referências a tensores para o modelo.

  • get_metric_ops

    Implemente este método para fornecer operações de métrica (operações de valor e atualização) para calcular a métrica. Semelhante ao método check_compatibility, ele também recebe três argumentos:

    • features_dict
    • previsões_dict
    • labels_dict

    Defina sua lógica de cálculo de métrica usando essas referências a tensores para o modelo.

  • populate_stats_and_pop e populate_plots_and_pop

    Implemente essa métrica para converter os resultados brutos da métrica para o formato proto MetricValue e PlotData . Isso leva três argumentos:

    • slice_key: nome da métrica de fatia à qual pertence.
    • Combined_metrics: Dicionário contendo resultados brutos.
    • output_metrics: Dicionário de saída contendo métrica no formato proto desejado.
@_export('my_metric')
class _MyMetric(_PostExportMetric):
   def __init__(self,
                target_prediction_keys: Optional[List[Text]] = None,
                labels_key: Optional[Text] = None,
                metric_tag: Optional[Text] = None):
      self._target_prediction_keys = target_prediction_keys
      self._label_keys = label_keys
      self._metric_tag = metric_tag
      self._metric_key = 'my_metric_key'

   def check_compatibility(self, features_dict:types.TensorTypeMaybeDict,
                           predictions_dict: types.TensorTypeMaybeDict,
                           labels_dict: types.TensorTypeMaybeDict) -> None:
       # Add compatibility check needed for the metric here.

   def get_metric_ops(self, features_dict: types.TensorTypeMaybeDict,
                      predictions_dict: types.TensorTypeMaybeDict,
                      labels_dict: types.TensorTypeMaybeDict
                     ) -> Dict[bytes, Tuple[types.TensorType,
                     types.TensorType]]:
        # Metric computation logic here.
        # Define value and update ops.
        value_op = compute_metric_value(...)
        update_op = create_update_op(... )
        return {self._metric_key: (value_op, update_op)}

   def populate_stats_and_pop(
       self, slice_key: slicer.SliceKeyType, combined_metrics: Dict[Text, Any],
       output_metrics: Dict[Text, metrics_pb2.MetricValue]) -> None:
       # Parses the metric and converts it into required metric format.
       metric_result = combined_metrics[self._metric_key]
       output_metrics[self._metric_key].double_value.value = metric_result

Uso

# Custom metric callback
custom_metric_callback = my_metric(
    labels_key='label',
    target_prediction_keys=['prediction'])

fairness_indicators_callback =
   post_export_metrics.fairness_indicators(
        thresholds=[0.1, 0.3, 0.5, 0.7, 0.9], labels_key=label)

add_metrics_callbacks = [custom_metric_callback,
   fairness_indicators_callback]

eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path=eval_saved_model_path,
    add_metrics_callbacks=add_metrics_callbacks)

eval_config = tfma.EvalConfig(...)

# Run evaluation
tfma.run_model_analysis(
    eval_config=eval_config, eval_shared_model=eval_shared_model)