Métriques post-exportation
Comme son nom l'indique, il s'agit d'une métrique ajoutée après l'exportation, avant l'évaluation.
TFMA est fourni avec plusieurs métriques d'évaluation prédéfinies, comme example_count, auc, confusion_matrix_at_thresholds, précision_recall_at_k, mse, mae, pour n'en nommer que quelques-unes. (Liste complète ici .)
Si vous ne trouvez pas de métrique existante pertinente pour votre cas d'utilisation ou si vous souhaitez personnaliser une métrique, vous pouvez définir votre propre métrique personnalisée. Lisez la suite pour les détails !
Ajout de métriques personnalisées dans TFMA
Définition de métriques personnalisées dans TFMA 1.x
Étendre la classe de base abstraite
Pour ajouter une métrique personnalisée, créez une nouvelle classe étendant la classe abstraite _PostExportMetric , définissez son constructeur et implémentez des méthodes abstraites/non implémentées.
Définir le constructeur
Dans le constructeur, prenez comme paramètres toutes les informations pertinentes telles que label_key, prédiction_key, example_weight_key, metric_tag, etc. requises pour la métrique personnalisée.
Implémenter des méthodes abstraites/non implémentées
Implémentez cette méthode pour vérifier la compatibilité de la métrique avec le modèle en cours d'évaluation, c'est-à-dire vérifier si toutes les fonctionnalités requises, l'étiquette attendue et la clé de prédiction sont présentes dans le modèle dans le type de données approprié. Il faut trois arguments :
- fonctionnalités_dict
- prédictions_dict
- labels_dict
Ces dictionnaires contiennent des références aux Tensors pour le modèle.
Implémentez cette méthode pour fournir des opérations de métrique (opérations de valeur et de mise à jour) pour calculer la métrique. Semblable à la méthode check_compatibility, elle prend également trois arguments :
- fonctionnalités_dict
- prédictions_dict
- étiquettes_dict
Définissez votre logique de calcul de métriques à l'aide de ces références aux tenseurs pour le modèle.
populate_stats_and_pop et populate_plots_and_pop
Implémentez cette métrique pour convertir les résultats bruts des métriques au format proto MetricValue et PlotData . Cela nécessite trois arguments :
- slice_key : nom de la métrique de tranche à laquelle appartient.
- Combined_metrics : dictionnaire contenant les résultats bruts.
- output_metrics : dictionnaire de sortie contenant la métrique au format de prototype souhaité.
@_export('my_metric')
class _MyMetric(_PostExportMetric):
def __init__(self,
target_prediction_keys: Optional[List[Text]] = None,
labels_key: Optional[Text] = None,
metric_tag: Optional[Text] = None):
self._target_prediction_keys = target_prediction_keys
self._label_keys = label_keys
self._metric_tag = metric_tag
self._metric_key = 'my_metric_key'
def check_compatibility(self, features_dict:types.TensorTypeMaybeDict,
predictions_dict: types.TensorTypeMaybeDict,
labels_dict: types.TensorTypeMaybeDict) -> None:
# Add compatibility check needed for the metric here.
def get_metric_ops(self, features_dict: types.TensorTypeMaybeDict,
predictions_dict: types.TensorTypeMaybeDict,
labels_dict: types.TensorTypeMaybeDict
) -> Dict[bytes, Tuple[types.TensorType,
types.TensorType]]:
# Metric computation logic here.
# Define value and update ops.
value_op = compute_metric_value(...)
update_op = create_update_op(... )
return {self._metric_key: (value_op, update_op)}
def populate_stats_and_pop(
self, slice_key: slicer.SliceKeyType, combined_metrics: Dict[Text, Any],
output_metrics: Dict[Text, metrics_pb2.MetricValue]) -> None:
# Parses the metric and converts it into required metric format.
metric_result = combined_metrics[self._metric_key]
output_metrics[self._metric_key].double_value.value = metric_result
Usage
# Custom metric callback
custom_metric_callback = my_metric(
labels_key='label',
target_prediction_keys=['prediction'])
fairness_indicators_callback =
post_export_metrics.fairness_indicators(
thresholds=[0.1, 0.3, 0.5, 0.7, 0.9], labels_key=label)
add_metrics_callbacks = [custom_metric_callback,
fairness_indicators_callback]
eval_shared_model = tfma.default_eval_shared_model(
eval_saved_model_path=eval_saved_model_path,
add_metrics_callbacks=add_metrics_callbacks)
eval_config = tfma.EvalConfig(...)
# Run evaluation
tfma.run_model_analysis(
eval_config=eval_config, eval_shared_model=eval_shared_model)