エクスポート後のメトリクス
名前が示すように、これはエクスポート後の評価前に追加されるメトリックです。
TFMA には、example_count、auc、construction_matrix_at_thresholds、precision_recall_at_k、mse、mae など、いくつかの事前定義された評価指標がパッケージ化されています。 (完全なリストはここにあります。)
ユースケースに関連する既存のメトリクスが見つからない場合、またはメトリクスをカスタマイズしたい場合は、独自のカスタム メトリクスを定義できます。詳細については続きをお読みください。
TFMA でのカスタム メトリクスの追加
TFMA 1.x でのカスタム メトリクスの定義
抽象基本クラスの拡張
カスタム メトリックを追加するには、 _PostExportMetric抽象クラスを拡張する新しいクラスを作成し、そのコンストラクターを定義して、抽象/未実装メソッドを実装します。
コンストラクターの定義
コンストラクターでは、カスタム メトリックに必要な label_key、prediction_key、example_weight_key、metric_tag などのすべての関連情報をパラメーターとして受け取ります。
抽象/未実装メソッドの実装
このメソッドを実装して、評価対象のモデルとメトリクスの互換性をチェックします。つまり、すべての必要な特徴、予期されるラベル、予測キーが適切なデータ型でモデルに存在するかどうかをチェックします。 3 つの引数を取ります。
- features_dict
- 予測_dict
- ラベル_dict
これらの辞書には、モデルの Tensor への参照が含まれています。
このメソッドを実装して、メトリックを計算するためのメトリック操作 (値操作と更新操作) を提供します。 check_compatibility メソッドと同様に、このメソッドも 3 つの引数を取ります。
- features_dict
- 予測_dict
- ラベル_dict
モデルの Tensor へのこれらの参照を使用して、メトリクス計算ロジックを定義します。
Populate_stats_and_popとPopulate_plots_and_pop
このメトリックを実装して、生のメトリック結果をMetricValueおよびPlotDataプロト形式に変換します。これには 3 つの引数が必要です。
- lice_key: スライス メトリックが属する名前。
- combined_metrics: 生の結果を含むディクショナリ。
- Output_metrics: 必要なプロト形式のメトリックを含む出力ディクショナリ。
@_export('my_metric')
class _MyMetric(_PostExportMetric):
def __init__(self,
target_prediction_keys: Optional[List[Text]] = None,
labels_key: Optional[Text] = None,
metric_tag: Optional[Text] = None):
self._target_prediction_keys = target_prediction_keys
self._label_keys = label_keys
self._metric_tag = metric_tag
self._metric_key = 'my_metric_key'
def check_compatibility(self, features_dict:types.TensorTypeMaybeDict,
predictions_dict: types.TensorTypeMaybeDict,
labels_dict: types.TensorTypeMaybeDict) -> None:
# Add compatibility check needed for the metric here.
def get_metric_ops(self, features_dict: types.TensorTypeMaybeDict,
predictions_dict: types.TensorTypeMaybeDict,
labels_dict: types.TensorTypeMaybeDict
) -> Dict[bytes, Tuple[types.TensorType,
types.TensorType]]:
# Metric computation logic here.
# Define value and update ops.
value_op = compute_metric_value(...)
update_op = create_update_op(... )
return {self._metric_key: (value_op, update_op)}
def populate_stats_and_pop(
self, slice_key: slicer.SliceKeyType, combined_metrics: Dict[Text, Any],
output_metrics: Dict[Text, metrics_pb2.MetricValue]) -> None:
# Parses the metric and converts it into required metric format.
metric_result = combined_metrics[self._metric_key]
output_metrics[self._metric_key].double_value.value = metric_result
使用法
# Custom metric callback
custom_metric_callback = my_metric(
labels_key='label',
target_prediction_keys=['prediction'])
fairness_indicators_callback =
post_export_metrics.fairness_indicators(
thresholds=[0.1, 0.3, 0.5, 0.7, 0.9], labels_key=label)
add_metrics_callbacks = [custom_metric_callback,
fairness_indicators_callback]
eval_shared_model = tfma.default_eval_shared_model(
eval_saved_model_path=eval_saved_model_path,
add_metrics_callbacks=add_metrics_callbacks)
eval_config = tfma.EvalConfig(...)
# Run evaluation
tfma.run_model_analysis(
eval_config=eval_config, eval_shared_model=eval_shared_model)