Métricas posteriores a la exportación
Como sugiere el nombre, esta es una métrica que se agrega después de la exportación, antes de la evaluación.
TFMA incluye varias métricas de evaluación predefinidas, como example_count, auc, confusion_matrix_at_thresholds, precision_recall_at_k, mse, mae, por nombrar algunas. (Lista completa aquí ).
Si no encuentra una métrica existente relevante para su caso de uso o desea personalizar una métrica, puede definir su propia métrica personalizada. ¡Sigue leyendo para conocer los detalles!
Agregar métricas personalizadas en TFMA
Definición de métricas personalizadas en TFMA 1.x
Ampliar la clase base abstracta
Para agregar una métrica personalizada, cree una nueva clase que extienda la clase abstracta _PostExportMetric y defina su constructor e implemente métodos abstractos/no implementados.
Definir constructor
En el constructor, tome como parámetros toda la información relevante como label_key, predict_key, example_weight_key, metric_tag, etc. requerida para la métrica personalizada.
Implementar métodos abstractos/no implementados
Implemente este método para verificar la compatibilidad de la métrica con el modelo que se está evaluando, es decir, verificar si todas las características requeridas, la etiqueta esperada y la clave de predicción están presentes en el modelo en el tipo de datos apropiado. Se necesitan tres argumentos:
- características_dict
- predicciones_dict
- etiquetas_dict
Estos diccionarios contienen referencias a tensores para el modelo.
Implemente este método para proporcionar operaciones de métricas (operaciones de valor y actualización) para calcular la métrica. Similar al método check_compatibility, también requiere tres argumentos:
- características_dict
- predicciones_dict
- etiquetas_dict
Defina su lógica de cálculo métrico utilizando estas referencias a tensores para el modelo.
populate_stats_and_pop y populate_plots_and_pop
Implemente esta métrica para convertir los resultados de métricas sin procesar al formato de protocolo MetricValue y PlotData . Esto requiere tres argumentos:
- slice_key: nombre de la métrica de sector al que pertenece.
- combinado_metrics: Diccionario que contiene resultados sin procesar.
- output_metrics: diccionario de salida que contiene métricas en el formato de protocolo deseado.
@_export('my_metric')
class _MyMetric(_PostExportMetric):
def __init__(self,
target_prediction_keys: Optional[List[Text]] = None,
labels_key: Optional[Text] = None,
metric_tag: Optional[Text] = None):
self._target_prediction_keys = target_prediction_keys
self._label_keys = label_keys
self._metric_tag = metric_tag
self._metric_key = 'my_metric_key'
def check_compatibility(self, features_dict:types.TensorTypeMaybeDict,
predictions_dict: types.TensorTypeMaybeDict,
labels_dict: types.TensorTypeMaybeDict) -> None:
# Add compatibility check needed for the metric here.
def get_metric_ops(self, features_dict: types.TensorTypeMaybeDict,
predictions_dict: types.TensorTypeMaybeDict,
labels_dict: types.TensorTypeMaybeDict
) -> Dict[bytes, Tuple[types.TensorType,
types.TensorType]]:
# Metric computation logic here.
# Define value and update ops.
value_op = compute_metric_value(...)
update_op = create_update_op(... )
return {self._metric_key: (value_op, update_op)}
def populate_stats_and_pop(
self, slice_key: slicer.SliceKeyType, combined_metrics: Dict[Text, Any],
output_metrics: Dict[Text, metrics_pb2.MetricValue]) -> None:
# Parses the metric and converts it into required metric format.
metric_result = combined_metrics[self._metric_key]
output_metrics[self._metric_key].double_value.value = metric_result
Uso
# Custom metric callback
custom_metric_callback = my_metric(
labels_key='label',
target_prediction_keys=['prediction'])
fairness_indicators_callback =
post_export_metrics.fairness_indicators(
thresholds=[0.1, 0.3, 0.5, 0.7, 0.9], labels_key=label)
add_metrics_callbacks = [custom_metric_callback,
fairness_indicators_callback]
eval_shared_model = tfma.default_eval_shared_model(
eval_saved_model_path=eval_saved_model_path,
add_metrics_callbacks=add_metrics_callbacks)
eval_config = tfma.EvalConfig(...)
# Run evaluation
tfma.run_model_analysis(
eval_config=eval_config, eval_shared_model=eval_shared_model)