Metryki po eksporcie

Jak sama nazwa wskazuje, jest to wskaźnik dodawany po eksporcie, przed oceną.

TFMA zawiera kilka predefiniowanych wskaźników oceny, takich jak liczba_przykładów, auc, zamieszanie_matrix_at_thresholds, precyzja_przypomnienia_at_k, mse, mae, żeby wymienić tylko kilka. (Pełna lista tutaj .)

Jeśli nie znajdziesz istniejących metryk odpowiednich dla Twojego przypadku użycia lub chcesz dostosować metrykę, możesz zdefiniować własną metrykę niestandardową. Czytaj dalej, aby poznać szczegóły!

Dodawanie niestandardowych metryk w TFMA

Definiowanie metryk niestandardowych w TFMA 1.x

Rozszerz abstrakcyjną klasę bazową

Aby dodać niestandardową metrykę, utwórz nową klasę rozszerzającą klasę abstrakcyjną _PostExportMetric , zdefiniuj jej konstruktor i zaimplementuj metody abstrakcyjne/niezaimplementowane.

Zdefiniuj konstruktor

W konstruktorze przyjmij jako parametry wszystkie istotne informacje, takie jak klucz_etykiety, klucz_przewidywania, klucz_przykładowej_wagi, tag_metryki itp. wymagane dla metryki niestandardowej.

Implementuj metody abstrakcyjne/niewdrożone
  • sprawdź_kompatybilność

    Zaimplementuj tę metodę, aby sprawdzić zgodność metryki z ocenianym modelem, czyli sprawdzić, czy wszystkie wymagane cechy, oczekiwana etykieta i klucz predykcji występują w modelu w odpowiednim typie danych. Wymaga trzech argumentów:

    • cechy_dykt
    • przewidywania_dykt
    • etykiety_dykt

    Słowniki te zawierają odniesienia do tensorów dla modelu.

  • get_metric_ops

    Zaimplementuj tę metodę, aby zapewnić operacje metryki (operacje wartości i aktualizacji) w celu obliczenia metryki. Podobnie jak metoda check_compatibility, również przyjmuje trzy argumenty:

    • cechy_dykt
    • przewidywania_dykt
    • etykiety_dykt

    Zdefiniuj logikę obliczeń metrycznych, korzystając z tych odniesień do tensorów modelu.

  • populate_stats_and_pop i populate_plots_and_pop

    Zaimplementuj tę metrykę, aby przekonwertować surowe wyniki metryki na format proto MetricValue i PlotData . Wymaga to trzech argumentów:

    • plasterek_key: Nazwa metryki plasterka, do której należy.
    • Combined_metrics: Słownik zawierający surowe wyniki.
    • Output_metrics: Słownik wyjściowy zawierający metrykę w żądanym formacie proto.
@_export('my_metric')
class _MyMetric(_PostExportMetric):
   def __init__(self,
                target_prediction_keys: Optional[List[Text]] = None,
                labels_key: Optional[Text] = None,
                metric_tag: Optional[Text] = None):
      self._target_prediction_keys = target_prediction_keys
      self._label_keys = label_keys
      self._metric_tag = metric_tag
      self._metric_key = 'my_metric_key'

   def check_compatibility(self, features_dict:types.TensorTypeMaybeDict,
                           predictions_dict: types.TensorTypeMaybeDict,
                           labels_dict: types.TensorTypeMaybeDict) -> None:
       # Add compatibility check needed for the metric here.

   def get_metric_ops(self, features_dict: types.TensorTypeMaybeDict,
                      predictions_dict: types.TensorTypeMaybeDict,
                      labels_dict: types.TensorTypeMaybeDict
                     ) -> Dict[bytes, Tuple[types.TensorType,
                     types.TensorType]]:
        # Metric computation logic here.
        # Define value and update ops.
        value_op = compute_metric_value(...)
        update_op = create_update_op(... )
        return {self._metric_key: (value_op, update_op)}

   def populate_stats_and_pop(
       self, slice_key: slicer.SliceKeyType, combined_metrics: Dict[Text, Any],
       output_metrics: Dict[Text, metrics_pb2.MetricValue]) -> None:
       # Parses the metric and converts it into required metric format.
       metric_result = combined_metrics[self._metric_key]
       output_metrics[self._metric_key].double_value.value = metric_result

Stosowanie

# Custom metric callback
custom_metric_callback = my_metric(
    labels_key='label',
    target_prediction_keys=['prediction'])

fairness_indicators_callback =
   post_export_metrics.fairness_indicators(
        thresholds=[0.1, 0.3, 0.5, 0.7, 0.9], labels_key=label)

add_metrics_callbacks = [custom_metric_callback,
   fairness_indicators_callback]

eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path=eval_saved_model_path,
    add_metrics_callbacks=add_metrics_callbacks)

eval_config = tfma.EvalConfig(...)

# Run evaluation
tfma.run_model_analysis(
    eval_config=eval_config, eval_shared_model=eval_shared_model)