Métricas posteriores a la exportación

Como sugiere el nombre, esta es una métrica que se agrega después de la exportación, antes de la evaluación.

TFMA incluye varias métricas de evaluación predefinidas, como example_count, auc, confusion_matrix_at_thresholds, precision_recall_at_k, mse, mae, por nombrar algunas. (Lista completa aquí ).

Si no encuentra una métrica existente relevante para su caso de uso o desea personalizar una métrica, puede definir su propia métrica personalizada. ¡Sigue leyendo para conocer los detalles!

Agregar métricas personalizadas en TFMA

Definición de métricas personalizadas en TFMA 1.x

Ampliar la clase base abstracta

Para agregar una métrica personalizada, cree una nueva clase que extienda la clase abstracta _PostExportMetric y defina su constructor e implemente métodos abstractos/no implementados.

Definir constructor

En el constructor, tome como parámetros toda la información relevante como label_key, predict_key, example_weight_key, metric_tag, etc. requerida para la métrica personalizada.

Implementar métodos abstractos/no implementados
  • comprobar_compatibilidad

    Implemente este método para verificar la compatibilidad de la métrica con el modelo que se está evaluando, es decir, verificar si todas las características requeridas, la etiqueta esperada y la clave de predicción están presentes en el modelo en el tipo de datos apropiado. Se necesitan tres argumentos:

    • características_dict
    • predicciones_dict
    • etiquetas_dict

    Estos diccionarios contienen referencias a tensores para el modelo.

  • get_metric_ops

    Implemente este método para proporcionar operaciones de métricas (operaciones de valor y actualización) para calcular la métrica. Similar al método check_compatibility, también requiere tres argumentos:

    • características_dict
    • predicciones_dict
    • etiquetas_dict

    Defina su lógica de cálculo métrico utilizando estas referencias a tensores para el modelo.

  • populate_stats_and_pop y populate_plots_and_pop

    Implemente esta métrica para convertir los resultados de métricas sin procesar al formato de protocolo MetricValue y PlotData . Esto requiere tres argumentos:

    • slice_key: nombre de la métrica de sector al que pertenece.
    • combinado_metrics: Diccionario que contiene resultados sin procesar.
    • output_metrics: diccionario de salida que contiene métricas en el formato de protocolo deseado.
@_export('my_metric')
class _MyMetric(_PostExportMetric):
   def __init__(self,
                target_prediction_keys: Optional[List[Text]] = None,
                labels_key: Optional[Text] = None,
                metric_tag: Optional[Text] = None):
      self._target_prediction_keys = target_prediction_keys
      self._label_keys = label_keys
      self._metric_tag = metric_tag
      self._metric_key = 'my_metric_key'

   def check_compatibility(self, features_dict:types.TensorTypeMaybeDict,
                           predictions_dict: types.TensorTypeMaybeDict,
                           labels_dict: types.TensorTypeMaybeDict) -> None:
       # Add compatibility check needed for the metric here.

   def get_metric_ops(self, features_dict: types.TensorTypeMaybeDict,
                      predictions_dict: types.TensorTypeMaybeDict,
                      labels_dict: types.TensorTypeMaybeDict
                     ) -> Dict[bytes, Tuple[types.TensorType,
                     types.TensorType]]:
        # Metric computation logic here.
        # Define value and update ops.
        value_op = compute_metric_value(...)
        update_op = create_update_op(... )
        return {self._metric_key: (value_op, update_op)}

   def populate_stats_and_pop(
       self, slice_key: slicer.SliceKeyType, combined_metrics: Dict[Text, Any],
       output_metrics: Dict[Text, metrics_pb2.MetricValue]) -> None:
       # Parses the metric and converts it into required metric format.
       metric_result = combined_metrics[self._metric_key]
       output_metrics[self._metric_key].double_value.value = metric_result

Uso

# Custom metric callback
custom_metric_callback = my_metric(
    labels_key='label',
    target_prediction_keys=['prediction'])

fairness_indicators_callback =
   post_export_metrics.fairness_indicators(
        thresholds=[0.1, 0.3, 0.5, 0.7, 0.9], labels_key=label)

add_metrics_callbacks = [custom_metric_callback,
   fairness_indicators_callback]

eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path=eval_saved_model_path,
    add_metrics_callbacks=add_metrics_callbacks)

eval_config = tfma.EvalConfig(...)

# Run evaluation
tfma.run_model_analysis(
    eval_config=eval_config, eval_shared_model=eval_shared_model)