概要
TFMA は、次のメトリクスとプロットをサポートしています。
- 標準の keras メトリクス (
tf.keras.metrics.*
)- keras メトリクスを使用するために keras モデルは必要ないことに注意してください。メトリクスは、メトリクス クラスを直接使用して、ビーム内のグラフの外側で計算されます。
標準の TFMA メトリクスとプロット (
tfma.metrics.*
)カスタム keras メトリクス (
tf.keras.metrics.Metric
から派生したメトリクス)カスタム TFMA メトリクス (カスタム ビーム コンバイナーを使用した
tfma.metrics.Metric
から派生したメトリクス、または他のメトリクスから派生したメトリクス)。
TFMA は、マルチクラス/マルチラベルの問題で使用するバイナリ分類メトリックを変換するための組み込みサポートも提供します。
- クラスIDや上位Kなどに基づく二値化
- マイクロ平均、マクロ平均などに基づいた集計メトリクス。
TFMA は、パイプライン内でサンプルがクエリ キーによって自動的にグループ化される、クエリ/ランキング ベースのメトリクスの組み込みサポートも提供します。
組み合わせると、回帰、バイナリ分類、マルチクラス/マルチラベル分類、ランキングなどのさまざまな問題に利用できる 50 以上の標準メトリクスとプロットが利用可能です。
構成
TFMA でメトリクスを構成するには 2 つの方法があります: (1) tfma.MetricsSpec
を使用する方法、または (2) Python でtf.keras.metrics.*
および/またはtfma.metrics.*
クラスのインスタンスを作成してtfma.metrics.specs_from_metrics
使用して、 tfma.MetricsSpec
のリストに変換します。
次のセクションでは、さまざまな種類の機械学習の問題に対する構成例について説明します。
回帰指標
以下は、回帰問題の構成セットアップの例です。サポートされる追加メトリクスについては、 tf.keras.metrics.*
モジュールとtfma.metrics.*
モジュールを参照してください。
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
metrics_specs {
metrics { class_name: "ExampleCount" }
metrics { class_name: "MeanSquaredError" }
metrics { class_name: "Accuracy" }
metrics { class_name: "MeanLabel" }
metrics { class_name: "MeanPrediction" }
metrics { class_name: "Calibration" }
metrics {
class_name: "CalibrationPlot"
config: '"min_value": 0, "max_value": 10'
}
}
""", tfma.EvalConfig()).metrics_specs
これと同じセットアップは、次の Python コードを使用して作成できます。
metrics = [
tfma.metrics.ExampleCount(name='example_count'),
tf.keras.metrics.MeanSquaredError(name='mse'),
tf.keras.metrics.Accuracy(name='accuracy'),
tfma.metrics.MeanLabel(name='mean_label'),
tfma.metrics.MeanPrediction(name='mean_prediction'),
tfma.metrics.Calibration(name='calibration'),
tfma.metrics.CalibrationPlot(
name='calibration', min_value=0, max_value=10)
]
metrics_specs = tfma.metrics.specs_from_metrics(metrics)
この設定は、 tfma.metrics.default_regression_specs
を呼び出しても利用できることに注意してください。
バイナリ分類メトリクス
以下は、バイナリ分類問題の設定セットアップの例です。サポートされる追加メトリクスについては、 tf.keras.metrics.*
モジュールとtfma.metrics.*
モジュールを参照してください。
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
metrics_specs {
metrics { class_name: "ExampleCount" }
metrics { class_name: "BinaryCrossentropy" }
metrics { class_name: "BinaryAccuracy" }
metrics { class_name: "AUC" }
metrics { class_name: "AUCPrecisionRecall" }
metrics { class_name: "MeanLabel" }
metrics { class_name: "MeanPrediction" }
metrics { class_name: "Calibration" }
metrics { class_name: "ConfusionMatrixPlot" }
metrics { class_name: "CalibrationPlot" }
}
""", tfma.EvalConfig()).metrics_specs
これと同じセットアップは、次の Python コードを使用して作成できます。
metrics = [
tfma.metrics.ExampleCount(name='example_count'),
tf.keras.metrics.BinaryCrossentropy(name='binary_crossentropy'),
tf.keras.metrics.BinaryAccuracy(name='accuracy'),
tf.keras.metrics.AUC(name='auc', num_thresholds=10000),
tf.keras.metrics.AUC(
name='auc_precision_recall', curve='PR', num_thresholds=10000),
tf.keras.metrics.Precision(name='precision'),
tf.keras.metrics.Recall(name='recall'),
tfma.metrics.MeanLabel(name='mean_label'),
tfma.metrics.MeanPrediction(name='mean_prediction'),
tfma.metrics.Calibration(name='calibration'),
tfma.metrics.ConfusionMatrixPlot(name='confusion_matrix_plot'),
tfma.metrics.CalibrationPlot(name='calibration_plot')
]
metrics_specs = tfma.metrics.specs_from_metrics(metrics)
この設定は、 tfma.metrics.default_binary_classification_specs
を呼び出しても利用できることに注意してください。
マルチクラス/マルチラベルの分類メトリック
以下は、マルチクラス分類問題の構成セットアップの例です。サポートされる追加メトリクスについては、 tf.keras.metrics.*
モジュールとtfma.metrics.*
モジュールを参照してください。
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
metrics_specs {
metrics { class_name: "ExampleCount" }
metrics { class_name: "SparseCategoricalCrossentropy" }
metrics { class_name: "SparseCategoricalAccuracy" }
metrics { class_name: "Precision" config: '"top_k": 1' }
metrics { class_name: "Precision" config: '"top_k": 3' }
metrics { class_name: "Recall" config: '"top_k": 1' }
metrics { class_name: "Recall" config: '"top_k": 3' }
metrics { class_name: "MultiClassConfusionMatrixPlot" }
}
""", tfma.EvalConfig()).metrics_specs
これと同じセットアップは、次の Python コードを使用して作成できます。
metrics = [
tfma.metrics.ExampleCount(name='example_count'),
tf.keras.metrics.SparseCategoricalCrossentropy(
name='sparse_categorical_crossentropy'),
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
tf.keras.metrics.Precision(name='precision', top_k=1),
tf.keras.metrics.Precision(name='precision', top_k=3),
tf.keras.metrics.Recall(name='recall', top_k=1),
tf.keras.metrics.Recall(name='recall', top_k=3),
tfma.metrics.MultiClassConfusionMatrixPlot(
name='multi_class_confusion_matrix_plot'),
]
metrics_specs = tfma.metrics.specs_from_metrics(metrics)
この設定は、 tfma.metrics.default_multi_class_classification_specs
を呼び出しても利用できることに注意してください。
マルチクラス/マルチラベルの二値化メトリクス
tfma.BinarizationOptions
を使用して、複数クラス/複数ラベルのメトリクスをバイナリ化し、クラスごと、top_k ごとなどのメトリクスを生成できます。例えば:
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
metrics_specs {
binarize: { class_ids: { values: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] } }
// Metrics to binarize
metrics { class_name: "AUC" }
...
}
""", tfma.EvalConfig()).metrics_specs
これと同じセットアップは、次の Python コードを使用して作成できます。
metrics = [
// Metrics to binarize
tf.keras.metrics.AUC(name='auc', num_thresholds=10000),
...
]
metrics_specs = tfma.metrics.specs_from_metrics(
metrics, binarize=tfma.BinarizationOptions(
class_ids={'values': [0,1,2,3,4,5,6,7,8,9]}))
マルチクラス/マルチラベルの集計メトリクス
tfma.AggregationOptions
を使用すると、複数クラス/複数ラベルのメトリックを集約して、バイナリ分類メトリックの単一の集約値を生成できます。
集約設定は二値化設定とは独立しているため、 tfma.AggregationOptions
とtfma.BinarizationOptions
の両方を同時に使用できることに注意してください。
ミクロ平均
マイクロ平均化は、 tfma.AggregationOptions
内のmicro_average
オプションを使用して実行できます。例えば:
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
metrics_specs {
aggregate: { micro_average: true }
// Metrics to aggregate
metrics { class_name: "AUC" }
...
}
""", tfma.EvalConfig()).metrics_specs
これと同じセットアップは、次の Python コードを使用して作成できます。
metrics = [
// Metrics to aggregate
tf.keras.metrics.AUC(name='auc', num_thresholds=10000),
...
]
metrics_specs = tfma.metrics.specs_from_metrics(
metrics, aggregate=tfma.AggregationOptions(micro_average=True))
マイクロ平均化では、計算に上位 k の値のみが使用されるtop_k
設定もサポートされています。例えば:
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
metrics_specs {
aggregate: {
micro_average: true
top_k_list: { values: [1, 3] }
}
// Metrics to aggregate
metrics { class_name: "AUC" }
...
}
""", tfma.EvalConfig()).metrics_specs
これと同じセットアップは、次の Python コードを使用して作成できます。
metrics = [
// Metrics to aggregate
tf.keras.metrics.AUC(name='auc', num_thresholds=10000),
...
]
metrics_specs = tfma.metrics.specs_from_metrics(
metrics,
aggregate=tfma.AggregationOptions(micro_average=True,
top_k_list={'values': [1, 3]}))
マクロ / 加重マクロ平均
マクロの平均化は、 tfma.AggregationOptions
内のmacro_average
またはweighted_macro_average
オプションを使用して実行できます。 top_k
設定が使用されない限り、マクロは平均を計算するクラスを知るためにclass_weights
を設定する必要があります。 class_weight
が指定されていない場合は、0.0 が想定されます。例えば:
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
metrics_specs {
aggregate: {
macro_average: true
class_weights: { key: 0 value: 1.0 }
class_weights: { key: 1 value: 1.0 }
class_weights: { key: 2 value: 1.0 }
class_weights: { key: 3 value: 1.0 }
class_weights: { key: 4 value: 1.0 }
class_weights: { key: 5 value: 1.0 }
class_weights: { key: 6 value: 1.0 }
class_weights: { key: 7 value: 1.0 }
class_weights: { key: 8 value: 1.0 }
class_weights: { key: 9 value: 1.0 }
}
// Metrics to aggregate
metrics { class_name: "AUC" }
...
}
""", tfma.EvalConfig()).metrics_specs
これと同じセットアップは、次の Python コードを使用して作成できます。
metrics = [
// Metrics to aggregate
tf.keras.metrics.AUC(name='auc', num_thresholds=10000),
...
]
metrics_specs = tfma.metrics.specs_from_metrics(
metrics,
aggregate=tfma.AggregationOptions(
macro_average=True, class_weights={i: 1.0 for i in range(10)}))
マイクロ平均化と同様に、マクロ平均化でも、計算で上位 k の値のみが使用されるtop_k
設定がサポートされます。例えば:
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
metrics_specs {
aggregate: {
macro_average: true
top_k_list: { values: [1, 3] }
}
// Metrics to aggregate
metrics { class_name: "AUC" }
...
}
""", tfma.EvalConfig()).metrics_specs
これと同じセットアップは、次の Python コードを使用して作成できます。
metrics = [
// Metrics to aggregate
tf.keras.metrics.AUC(name='auc', num_thresholds=10000),
...
]
metrics_specs = tfma.metrics.specs_from_metrics(
metrics,
aggregate=tfma.AggregationOptions(macro_average=True,
top_k_list={'values': [1, 3]}))
クエリ/ランキングベースのメトリクス
クエリ/ランキング ベースのメトリクスは、メトリクス仕様でquery_key
オプションを指定することで有効になります。例えば:
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
metrics_specs {
query_key: "doc_id"
metrics {
class_name: "NDCG"
config: '"gain_key": "gain", "top_k_list": [1, 2]'
}
metrics { class_name: "MinLabelPosition" }
}
""", tfma.EvalConfig()).metrics_specs
これと同じセットアップは、次の Python コードを使用して作成できます。
metrics = [
tfma.metrics.NDCG(name='ndcg', gain_key='gain', top_k_list=[1, 2]),
tfma.metrics.MinLabelPosition(name='min_label_position')
]
metrics_specs = tfma.metrics.specs_from_metrics(metrics, query_key='doc_id')
マルチモデルの評価指標
TFMA は、複数のモデルの同時評価をサポートします。マルチモデル評価を実行すると、モデルごとにメトリクスが計算されます。例えば:
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
metrics_specs {
# no model_names means all models
...
}
""", tfma.EvalConfig()).metrics_specs
モデルのサブセットに対してメトリクスを計算する必要がある場合は、 metric_specs
でmodel_names
を設定します。例えば:
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
metrics_specs {
model_names: ["my-model1"]
...
}
""", tfma.EvalConfig()).metrics_specs
specs_from_metrics
API は、モデル名の受け渡しもサポートしています。
metrics = [
...
]
metrics_specs = tfma.metrics.specs_from_metrics(
metrics, model_names=['my-model1'])
モデル比較メトリック
TFMA は、ベースライン モデルに対する候補モデルの比較メトリックの評価をサポートします。候補モデルとベースライン モデルのペアをセットアップする簡単な方法は、適切なモデル名 (tfma.BASELINE_KEY および tfma.CANDIDATE_KEY) を持つ eval_shared_model を渡すことです。
eval_config = text_format.Parse("""
model_specs {
# ... model_spec without names ...
}
metrics_spec {
# ... metrics ...
}
""", tfma.EvalConfig())
eval_shared_models = [
tfma.default_eval_shared_model(
model_name=tfma.CANDIDATE_KEY,
eval_saved_model_path='/path/to/saved/candidate/model',
eval_config=eval_config),
tfma.default_eval_shared_model(
model_name=tfma.BASELINE_KEY,
eval_saved_model_path='/path/to/saved/baseline/model',
eval_config=eval_config),
]
eval_result = tfma.run_model_analysis(
eval_shared_models,
eval_config=eval_config,
# This assumes your data is a TFRecords file containing records in the
# tf.train.Example format.
data_location="/path/to/file/containing/tfrecords",
output_path="/path/for/output")
比較メトリックは、すべての差分可能なメトリック (現在は精度や AUC などのスカラー値メトリックのみ) に対して自動的に計算されます。
複数出力モデルのメトリクス
TFMA は、異なる出力を持つモデルのメトリクスの評価をサポートします。複数出力モデルは、出力名をキーとした辞書の形式で出力予測を保存します。複数出力モデルを使用する場合、メトリックのセットに関連付けられた出力の名前を MetricsSpec のoutput_names
セクションで指定する必要があります。例えば:
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
metrics_specs {
output_names: ["my-output"]
...
}
""", tfma.EvalConfig()).metrics_specs
specs_from_metrics
API は、出力名の受け渡しもサポートしています。
metrics = [
...
]
metrics_specs = tfma.metrics.specs_from_metrics(
metrics, output_names=['my-output'])
メトリック設定のカスタマイズ
TFMA を使用すると、さまざまなメトリックで使用される設定をカスタマイズできます。たとえば、名前を変更したり、しきい値を設定したりする場合があります。これを行うには、メトリック構成にconfig
セクションを追加します。構成は、メトリクス__init__
メソッドに渡されるパラメーターの JSON 文字列バージョンを使用して指定されます (使いやすくするために、先頭と末尾の '{' および '}' 括弧は省略できます)。例えば:
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
metrics_specs {
metrics {
class_name: "ConfusionMatrixAtThresholds"
config: '"thresholds": [0.3, 0.5, 0.8]'
}
}
""", tfma.MetricsSpec()).metrics_specs
もちろん、このカスタマイズも直接サポートされています。
metrics = [
tfma.metrics.ConfusionMatrixAtThresholds(thresholds=[0.3, 0.5, 0.8]),
]
metrics_specs = tfma.metrics.specs_from_metrics(metrics)
出力
メトリック評価の出力は、使用される構成に基づいた一連のメトリック キー/値、および/またはプロット キー/値です。
メトリックキー
MetricKey は、構造化キー タイプを使用して定義されます。このキーは、メトリックの次の各側面を一意に識別します。
- メトリクス名 (
auc
、mean_label
など) - モデル名 (複数モデル評価の場合のみ使用)
- 出力名 (複数出力モデルが評価される場合にのみ使用されます)
- サブキー(例:マルチクラスモデルが二値化されている場合のクラスID)
メトリック値
MetricValuesは、さまざまなメトリック ( double
、 ConfusionMatrixAtThresholds
など) によってサポートされるさまざまな値のタイプをカプセル化するプロトを使用して定義されます。
サポートされているメトリック値のタイプは次のとおりです。
-
double_value
- double 型のラッパー。 -
bytes_value
- バイト値。 -
bounded_value
- 点ごとの推定値となる実際の値を表します。オプションで、何らかの近似境界を使用できます。value
、lower_bound
、およびupper_bound
プロパティがあります。 -
value_at_cutoffs
- カットオフ時の値 (precision@K、recall@K など)。プロパティvalues
があり、それぞれにプロパティcutoff
およびvalue
があります。 -
confusion_matrix_at_thresholds
- しきい値での混同行列。プロパティmatrices
があり、それぞれに、threshold
、precision
、recall
、およびfalse_negatives
などの混同行列値のプロパティがあります。 -
array_value
- 値の配列を返すメトリクスの場合。
プロットキー
PlotKey はメトリック キーに似ていますが、歴史的な理由により、すべてのプロット値が単一のプロトに格納されるため、プロット キーには名前がありません。
値のプロット
サポートされているすべてのプロットは、 PlotDataと呼ばれる単一のプロトに保存されます。
評価結果
評価実行からの戻り値はtfma.EvalResult
です。このレコードには、メトリック キーをマルチレベルの辞書としてエンコードするslicing_metrics
が含まれており、レベルはそれぞれ出力名、クラス ID、メトリック名、メトリック値に対応します。これは、Jupiter ノートブックの UI 表示に使用することを目的としています。基礎となるデータへのアクセスが必要な場合は、代わりにmetrics
結果ファイルを使用する必要があります ( metrics_for_slice.protoを参照)。
カスタマイズ
保存された keras (または従来の EvalSavedModel) の一部として追加されるカスタム メトリクスに加えて。 TFMA ポスト保存でメトリクスをカスタマイズするには、(1) カスタム keras メトリクス クラスを定義する方法と、(2) ビーム コンバイナを利用したカスタム TFMA メトリクス クラスを定義する方法の 2 つがあります。
どちらの場合も、メトリックは、メトリック クラスの名前と関連モジュールを指定することによって構成されます。例えば:
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
metrics_specs {
metrics { class_name: "MyMetric" module: "my.module"}
}
""", tfma.EvalConfig()).metrics_specs
カスタム Keras メトリクス
カスタム keras メトリクスを作成するには、ユーザーは実装でtf.keras.metrics.Metric
を拡張し、評価時にメトリクスのモジュールが利用可能であることを確認する必要があります。
モデルの保存後に追加されたメトリクスの場合、TFMA はラベル (y_true)、予測 (y_pred)、およびサンプルの重み (sample_weight) をupdate_state
メソッドのパラメータとして受け取るメトリクスのみをサポートすることに注意してください。
Keras メトリクスの例
以下はカスタム keras メトリクスの例です。
class MyMetric(tf.keras.metrics.Mean):
def __init__(self, name='my_metric', dtype=None):
super(MyMetric, self).__init__(name=name, dtype=dtype)
def update_state(self, y_true, y_pred, sample_weight=None):
return super(MyMetric, self).update_state(
y_pred, sample_weight=sample_weight)
カスタム TFMA メトリクス
カスタム TFMA メトリクスを作成するには、ユーザーは実装でtfma.metrics.Metric
を拡張し、評価時にメトリクスのモジュールが利用可能であることを確認する必要があります。
メトリック
tfma.metrics.Metric
実装は、メトリック値の計算に必要な計算 (おそらく複数) を作成する関数とともに、メトリック構成を定義する一連の kwargs で構成されます。使用できる主な計算タイプにはtfma.metrics.MetricComputation
とtfma.metrics.DerivedMetricComputation
の 2 つがあります。これらについては、以下のセクションで説明します。これらの計算を作成する関数には、次のパラメーターが入力として渡されます。
-
eval_config: tfam.EvalConfig
- エバリュエーターに渡される評価構成 (使用する予測キーなどのモデル仕様設定を調べるのに役立ちます)。
-
model_names: List[Text]
- メトリクスを計算するモデル名のリスト (単一モデルの場合はなし)
-
output_names: List[Text]
。- メトリクスを計算する出力名のリスト (単一モデルの場合はなし)
-
sub_keys: List[tfma.SubKey]
。- メトリクスを計算するためのサブキー (クラス ID、上位 K など) のリスト (またはなし)
-
aggregation_type: tfma.AggregationType
- 集計メトリックを計算する場合の集計のタイプ。
-
class_weights: Dict[int, float]
。- 集計メトリックを計算する場合に使用するクラスの重み。
-
query_key: Text
- クエリ/ランキングベースのメトリクスを計算する場合に使用されるクエリキー。
メトリックがこれらの設定の 1 つ以上に関連付けられていない場合、それらのパラメーターがそのシグネチャ定義から除外される可能性があります。
メトリックが各モデル、出力、およびサブキーに対して同じ方法で計算される場合、ユーティリティtfma.metrics.merge_per_key_computations
を使用して、これらの入力のそれぞれに対して同じ計算を個別に実行できます。
メトリック計算
MetricComputation
、 preprocessors
とcombiner
の組み合わせで構成されます。 preprocessors
はpreprocessor
のリストであり、入力として抽出を受け取り、結合器によって使用される初期状態を出力するbeam.DoFn
です (抽出の詳細については、アーキテクチャを参照してください)。すべてのプリプロセッサはリストの順序で順次実行されます。 preprocessors
が空の場合、コンバイナーにはStandardMetricInputsが渡されます (標準メトリック入力にはラベル、予測、および example_weights が含まれます)。 combiner
、(スライス キー、プリプロセッサ出力) のタプルを入力として受け取り、(slice_key、メトリック結果 dict) のタプルを結果として出力するbeam.CombineFn
です。
スライスはpreprocessors
とcombiner
の間で発生することに注意してください。
メトリクスの計算で両方の標準メトリクス入力を利用したいが、 features
抽出からのいくつかの特徴で拡張したい場合は、複数の結合器からの要求された特徴を 1 つの結合器にマージする特別なFeaturePreprocessorを使用できることに注意してください。すべてのコンバイナーに渡される共有 StandardMetricsInputs 値 (コンバイナーは、関心のある機能を読み取り、残りを無視する責任があります)。
例
以下は、ExampleCount を計算するための TFMA メトリック定義の非常に簡単な例です。
class ExampleCount(tfma.metrics.Metric):
def __init__(self, name: Text = 'example_count'):
super(ExampleCount, self).__init__(_example_count, name=name)
def _example_count(
name: Text = 'example_count') -> tfma.metrics.MetricComputations:
key = tfma.metrics.MetricKey(name=name)
return [
tfma.metrics.MetricComputation(
keys=[key],
preprocessors=[_ExampleCountPreprocessor()],
combiner=_ExampleCountCombiner(key))
]
class ExampleCountTest(tfma.test.testutil.TensorflowModelAnalysisTest):
def testExampleCount(self):
metric = ExampleCount()
computations = metric.computations(example_weighted=False)
computation = computations[0]
with beam.Pipeline() as pipeline:
result = (
pipeline
| 'Create' >> beam.Create([...]) # Add inputs
| 'PreProcess' >> beam.ParDo(computation.preprocessors[0])
| 'Process' >> beam.Map(tfma.metrics.to_standard_metric_inputs)
| 'AddSlice' >> beam.Map(lambda x: ((), x))
| 'ComputeMetric' >> beam.CombinePerKey(computation.combiner)
)
def check_result(got):
try:
self.assertLen(got, 1)
got_slice_key, got_metrics = got[0]
self.assertEqual(got_slice_key, ())
key = computation.keys[0]
self.assertIn(key, got_metrics)
self.assertAlmostEqual(got_metrics[key], expected_value, places=5)
except AssertionError as err:
raise util.BeamAssertException(err)
util.assert_that(result, check_result, label='result')
class _ExampleCountPreprocessor(beam.DoFn):
def process(self, extracts: tfma.Extracts) -> Iterable[int]:
yield 1
class _ExampleCountPreprocessorTest(unittest.TestCase):
def testExampleCountPreprocessor(self):
... # Init the test case here
with beam.Pipeline() as pipeline:
updated_pcoll = (
pipeline
| 'Create' >> beam.Create([...]) # Add inputs
| 'Preprocess'
>> beam.ParDo(
_ExampleCountPreprocessor()
)
)
beam_testing_util.assert_that(
updated_pcoll,
lambda result: ..., # Assert the test case
)
class _ExampleCountCombiner(beam.CombineFn):
def __init__(self, metric_key: tfma.metrics.MetricKey):
self._metric_key = metric_key
def create_accumulator(self) -> int:
return 0
def add_input(self, accumulator: int, state: int) -> int:
return accumulator + state
def merge_accumulators(self, accumulators: Iterable[int]) -> int:
accumulators = iter(accumulators)
result = next(accumulator)
for accumulator in accumulators:
result += accumulator
return result
def extract_output(self,
accumulator: int) -> Dict[tfma.metrics.MetricKey, int]:
return {self._metric_key: accumulator}
派生メトリック計算
DerivedMetricComputation
は、他のメトリック計算の出力に基づいてメトリック値を計算するために使用される結果関数で構成されます。 result 関数は、計算された値の辞書を入力として受け取り、追加のメトリック結果の辞書を出力します。
メトリックによって作成される計算のリストに、派生計算が依存する計算を含めることが許容される (推奨される) ことに注意してください。これにより、複数のメトリクス間で共有される計算を事前に作成して渡す必要がなくなります。エバリュエーターは、同じ定義を持つ計算の重複を自動的に除外するため、実際に実行される計算は 1 つだけです。
例
TJUR メトリクスは、派生メトリクスの良い例を提供します。