Метрики пост-экспорта

Как следует из названия, это показатель, который добавляется после экспорта, перед оценкой.

TFMA поставляется с несколькими предопределенными оценочными метриками, такими как example_count, auc, смущение_matrix_at_thresholds, Precision_recall_at_k, mse, mae и многие другие. (Полный список здесь .)

Если вы не нашли существующие метрики, соответствующие вашему варианту использования, или хотите настроить метрику, вы можете определить свою собственную метрику. Читайте подробности!

Добавление пользовательских метрик в TFMA

Определение пользовательских метрик в TFMA 1.x

Расширение абстрактного базового класса

Чтобы добавить пользовательскую метрику, создайте новый класс, расширяющий абстрактный класс _PostExportMetric , определите его конструктор и реализуйте абстрактные/нереализованные методы.

Определить конструктор

В конструкторе возьмите в качестве параметров всю соответствующую информацию, такую ​​как label_key, Prediction_key, example_weight_key, metric_tag и т. д., необходимую для пользовательской метрики.

Реализация абстрактных/нереализованных методов
  • проверка_совместимости

    Внедрите этот метод, чтобы проверить совместимость метрики с оцениваемой моделью, т. е. проверить, присутствуют ли в модели все необходимые функции, ожидаемая метка и ключ прогнозирования в соответствующем типе данных. Требуется три аргумента:

    • Features_dict
    • предсказания_дикт
    • labels_dict

    Эти словари содержат ссылки на тензоры для модели.

  • get_metric_ops

    Реализуйте этот метод, чтобы предоставить операции по метрике (операции по значению и обновлению) для вычисления метрики. Подобно методу check_compatibility, он также принимает три аргумента:

    • Features_dict
    • предсказания_дикт
    • labels_dict

    Определите логику вычисления показателей, используя эти ссылки на тензоры для модели.

  • populate_stats_and_pop и populate_plots_and_pop

    Реализуйте эту метрику для преобразования необработанных результатов метрики в формат MetricValue и PlotData . Для этого требуется три аргумента:

    • срез_ключ: имя метрики среза, которому принадлежит.
    • Комбинированные_метрики: словарь, содержащий необработанные результаты.
    • output_metrics: выходной словарь, содержащий метрику в желаемом прото-формате.
@_export('my_metric')
class _MyMetric(_PostExportMetric):
   def __init__(self,
                target_prediction_keys: Optional[List[Text]] = None,
                labels_key: Optional[Text] = None,
                metric_tag: Optional[Text] = None):
      self._target_prediction_keys = target_prediction_keys
      self._label_keys = label_keys
      self._metric_tag = metric_tag
      self._metric_key = 'my_metric_key'

   def check_compatibility(self, features_dict:types.TensorTypeMaybeDict,
                           predictions_dict: types.TensorTypeMaybeDict,
                           labels_dict: types.TensorTypeMaybeDict) -> None:
       # Add compatibility check needed for the metric here.

   def get_metric_ops(self, features_dict: types.TensorTypeMaybeDict,
                      predictions_dict: types.TensorTypeMaybeDict,
                      labels_dict: types.TensorTypeMaybeDict
                     ) -> Dict[bytes, Tuple[types.TensorType,
                     types.TensorType]]:
        # Metric computation logic here.
        # Define value and update ops.
        value_op = compute_metric_value(...)
        update_op = create_update_op(... )
        return {self._metric_key: (value_op, update_op)}

   def populate_stats_and_pop(
       self, slice_key: slicer.SliceKeyType, combined_metrics: Dict[Text, Any],
       output_metrics: Dict[Text, metrics_pb2.MetricValue]) -> None:
       # Parses the metric and converts it into required metric format.
       metric_result = combined_metrics[self._metric_key]
       output_metrics[self._metric_key].double_value.value = metric_result

Использование

# Custom metric callback
custom_metric_callback = my_metric(
    labels_key='label',
    target_prediction_keys=['prediction'])

fairness_indicators_callback =
   post_export_metrics.fairness_indicators(
        thresholds=[0.1, 0.3, 0.5, 0.7, 0.9], labels_key=label)

add_metrics_callbacks = [custom_metric_callback,
   fairness_indicators_callback]

eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path=eval_saved_model_path,
    add_metrics_callbacks=add_metrics_callbacks)

eval_config = tfma.EvalConfig(...)

# Run evaluation
tfma.run_model_analysis(
    eval_config=eval_config, eval_shared_model=eval_shared_model)