Métricas posteriores a la exportación
Como sugiere el nombre, esta es una métrica que se agrega después de la exportación, antes de la evaluación.
TFMA incluye varias métricas de evaluación predefinidas, como example_count, auc, confusion_matrix_at_thresholds, precision_recall_at_k, mse, mae, por nombrar algunas. (Lista completa aquí .)
Si no encuentra una métrica existente relevante para su caso de uso, o desea personalizar una métrica, puede definir su propia métrica personalizada. ¡Siga leyendo para conocer los detalles!
Adición de métricas personalizadas en TFMA
Definición de métricas personalizadas en TFMA 1.x
Ampliar clase base abstracta
Para agregar una métrica personalizada, cree una nueva clase que extienda la clase abstracta _PostExportMetric y defina su constructor e implemente métodos abstractos/no implementados.
Definir constructor
En el constructor, tome como parámetros toda la información relevante como etiqueta_clave, predicción_clave, ejemplo_peso_clave, métrica_etiqueta, etc. requerida para la métrica personalizada.
Implementar métodos abstractos/no implementados
Implemente este método para verificar la compatibilidad de la métrica con el modelo que se está evaluando, es decir, verificar si todas las características requeridas, la etiqueta esperada y la clave de predicción están presentes en el modelo en el tipo de datos apropiado. Se necesitan tres argumentos:
- características_dict
- predicciones_dict
- etiquetas_dict
Estos diccionarios contienen referencias a Tensores para el modelo.
Implemente este método para proporcionar operaciones de métricas (operaciones de valor y actualización) para calcular la métrica. Similar al método check_compatibility, también toma tres argumentos:
- características_dict
- predicciones_dict
- etiquetas_dict
Defina la lógica de cálculo de su métrica utilizando estas referencias a Tensores para el modelo.
populate_stats_and_pop y populate_plots_and_pop
Implemente esta métrica para convertir los resultados de la métrica sin procesar al formato de prototipo MetricValue y PlotData . Esto toma tres argumentos:
- slice_key: nombre de la métrica de segmento a la que pertenece.
- métricas_combinadas: diccionario que contiene resultados sin procesar.
- output_metrics: diccionario de salida que contiene la métrica en el formato de prototipo deseado.
@_export('my_metric')
class _MyMetric(_PostExportMetric):
def __init__(self,
target_prediction_keys: Optional[List[Text]] = None,
labels_key: Optional[Text] = None,
metric_tag: Optional[Text] = None):
self._target_prediction_keys = target_prediction_keys
self._label_keys = label_keys
self._metric_tag = metric_tag
self._metric_key = 'my_metric_key'
def check_compatibility(self, features_dict:types.TensorTypeMaybeDict,
predictions_dict: types.TensorTypeMaybeDict,
labels_dict: types.TensorTypeMaybeDict) -> None:
# Add compatibility check needed for the metric here.
def get_metric_ops(self, features_dict: types.TensorTypeMaybeDict,
predictions_dict: types.TensorTypeMaybeDict,
labels_dict: types.TensorTypeMaybeDict
) -> Dict[bytes, Tuple[types.TensorType,
types.TensorType]]:
# Metric computation logic here.
# Define value and update ops.
value_op = compute_metric_value(...)
update_op = create_update_op(... )
return {self._metric_key: (value_op, update_op)}
def populate_stats_and_pop(
self, slice_key: slicer.SliceKeyType, combined_metrics: Dict[Text, Any],
output_metrics: Dict[Text, metrics_pb2.MetricValue]) -> None:
# Parses the metric and converts it into required metric format.
metric_result = combined_metrics[self._metric_key]
output_metrics[self._metric_key].double_value.value = metric_result
Uso
# Custom metric callback
custom_metric_callback = my_metric(
labels_key='label',
target_prediction_keys=['prediction'])
fairness_indicators_callback =
post_export_metrics.fairness_indicators(
thresholds=[0.1, 0.3, 0.5, 0.7, 0.9], labels_key=label)
add_metrics_callbacks = [custom_metric_callback,
fairness_indicators_callback]
eval_shared_model = tfma.default_eval_shared_model(
eval_saved_model_path=eval_saved_model_path,
add_metrics_callbacks=add_metrics_callbacks)
eval_config = tfma.EvalConfig(...)
# Run evaluation
tfma.run_model_analysis(
eval_config=eval_config, eval_shared_model=eval_shared_model)