Metriche post-esportazione
Come suggerisce il nome, si tratta di una metrica che viene aggiunta dopo l'esportazione, prima della valutazione.
TFMA è dotato di diverse metriche di valutazione predefinite, come example_count, auc, confusion_matrix_at_thresholds, Precision_recall_at_k, mse, mae, solo per citarne alcuni. (L'elenco completo qui .)
Se non trovi una metrica esistente pertinente al tuo caso d'uso o desideri personalizzare una metrica, puoi definire la tua metrica personalizzata. Continua a leggere per i dettagli!
Aggiunta di metriche personalizzate in TFMA
Definizione delle metriche personalizzate in TFMA 1.x
Estendi la classe base astratta
Per aggiungere una metrica personalizzata, crea una nuova classe che estende la classe astratta _PostExportMetric e definisci il suo costruttore e implementa metodi astratti/non implementati.
Definire Costruttore
Nel costruttore, prendi come parametri tutte le informazioni rilevanti come label_key, Recommendation_key, example_weight_key, metric_tag, ecc. richieste per la metrica personalizzata.
Implementare metodi astratti/non implementati
Implementare questo metodo per verificare la compatibilità della metrica con il modello da valutare, ovvero verificare se tutte le funzionalità richieste, l'etichetta prevista e la chiave di previsione sono presenti nel modello nel tipo di dati appropriato. Sono necessari tre argomenti:
- features_dict
- previsioni_dict
- etichette_dict
Questi dizionari contengono riferimenti ai tensori per il modello.
Implementare questo metodo per fornire operazioni metriche (operazioni valore e aggiornamento) per calcolare la metrica. Similmente al metodo check_compatibility, richiede anche tre argomenti:
- features_dict
- previsioni_dict
- etichette_dict
Definisci la logica di calcolo della metrica utilizzando questi riferimenti ai tensori per il modello.
populate_stats_and_pop e populate_plots_and_pop
Implementa questa metrica per convertire i risultati della metrica grezza nel formato proto MetricValue e PlotData . Ciò richiede tre argomenti:
- slice_key: nome della metrica della sezione a cui appartiene.
- combined_metrics: dizionario contenente risultati grezzi.
- output_metrics: dizionario di output contenente la metrica nel formato proto desiderato.
@_export('my_metric')
class _MyMetric(_PostExportMetric):
def __init__(self,
target_prediction_keys: Optional[List[Text]] = None,
labels_key: Optional[Text] = None,
metric_tag: Optional[Text] = None):
self._target_prediction_keys = target_prediction_keys
self._label_keys = label_keys
self._metric_tag = metric_tag
self._metric_key = 'my_metric_key'
def check_compatibility(self, features_dict:types.TensorTypeMaybeDict,
predictions_dict: types.TensorTypeMaybeDict,
labels_dict: types.TensorTypeMaybeDict) -> None:
# Add compatibility check needed for the metric here.
def get_metric_ops(self, features_dict: types.TensorTypeMaybeDict,
predictions_dict: types.TensorTypeMaybeDict,
labels_dict: types.TensorTypeMaybeDict
) -> Dict[bytes, Tuple[types.TensorType,
types.TensorType]]:
# Metric computation logic here.
# Define value and update ops.
value_op = compute_metric_value(...)
update_op = create_update_op(... )
return {self._metric_key: (value_op, update_op)}
def populate_stats_and_pop(
self, slice_key: slicer.SliceKeyType, combined_metrics: Dict[Text, Any],
output_metrics: Dict[Text, metrics_pb2.MetricValue]) -> None:
# Parses the metric and converts it into required metric format.
metric_result = combined_metrics[self._metric_key]
output_metrics[self._metric_key].double_value.value = metric_result
Utilizzo
# Custom metric callback
custom_metric_callback = my_metric(
labels_key='label',
target_prediction_keys=['prediction'])
fairness_indicators_callback =
post_export_metrics.fairness_indicators(
thresholds=[0.1, 0.3, 0.5, 0.7, 0.9], labels_key=label)
add_metrics_callbacks = [custom_metric_callback,
fairness_indicators_callback]
eval_shared_model = tfma.default_eval_shared_model(
eval_saved_model_path=eval_saved_model_path,
add_metrics_callbacks=add_metrics_callbacks)
eval_config = tfma.EvalConfig(...)
# Run evaluation
tfma.run_model_analysis(
eval_config=eval_config, eval_shared_model=eval_shared_model)