TensorFlow.org에서 보기 | Google Colab에서 실행하기 | GitHub에서소스 보기 | 노트북 다운로드하기 |
TF1에서 tf.metrics
는 모든 메트릭 함수에 대한 API 네임스페이스입니다. 각 메트릭은 label
과 prediction
을 입력 매개변수로 사용하고 해당 메트릭 텐서를 결과로 반환하는 함수입니다. TF2에서 tf.keras.metrics
는 모든 메트릭 함수와 객체를 포함합니다. Metric
객체는 tf.keras.Model
과 tf.keras.layers.layer
와 함께 사용하여 메트릭 값을 계산할 수 있습니다.
설치하기
몇 가지 필요한 TensorFlow 가져오기로 시작합니다.
import tensorflow as tf
import tensorflow.compat.v1 as tf1
2022-12-14 20:49:39.466833: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory 2022-12-14 20:49:39.466926: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory 2022-12-14 20:49:39.466935: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
그리고 데모용으로 몇 가지 간단한 데이터를 준비합니다.
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [0, 0, 1]
eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]
eval_labels = [0, 1, 1]
TF1: Estimator를 사용하는 tf.compat.v1.metrics
TF1에서 메트릭은 eval_metric_ops
로 EstimatorSpec
에 추가될 수 있으며 연산은 tf.metrics
에 정의된 모든 메트릭 함수를 통해 생성됩니다. 예제에 따라 tf.metrics.accuracy
를 사용하는 방법을 확인할 수 있습니다.
def _input_fn():
return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)
def _eval_input_fn():
return tf1.data.Dataset.from_tensor_slices(
(eval_features, eval_labels)).batch(1)
def _model_fn(features, labels, mode):
logits = tf1.layers.Dense(2)(features)
predictions = tf.math.argmax(input=logits, axis=1)
loss = tf1.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
optimizer = tf1.train.AdagradOptimizer(0.05)
train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
accuracy = tf1.metrics.accuracy(labels=labels, predictions=predictions)
return tf1.estimator.EstimatorSpec(mode,
predictions=predictions,
loss=loss,
train_op=train_op,
eval_metric_ops={'accuracy': accuracy})
estimator = tf1.estimator.Estimator(model_fn=_model_fn)
estimator.train(_input_fn)
INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpppdg7p6b INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpppdg7p6b', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpppdg7p6b/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.057665426, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3... INFO:tensorflow:Saving checkpoints for 3 into /tmpfs/tmp/tmpppdg7p6b/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3... INFO:tensorflow:Loss for final step: 7.887593. <tensorflow_estimator.python.estimator.estimator.Estimator at 0x7f94fbb4be50>
estimator.evaluate(_eval_input_fn)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-12-14T20:49:44 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpppdg7p6b/model.ckpt-3 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Inference Time : 0.25716s INFO:tensorflow:Finished evaluation at 2022-12-14-20:49:45 INFO:tensorflow:Saving dict for global step 3: accuracy = 0.33333334, global_step = 3, loss = 8.581979 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 3: /tmpfs/tmp/tmpppdg7p6b/model.ckpt-3 {'accuracy': 0.33333334, 'loss': 8.581979, 'global_step': 3}
또한, tf.estimator.add_metrics()
를 통해 메트릭을 Estimator에 직접 추가할 수 있습니다.
def mean_squared_error(labels, predictions):
labels = tf.cast(labels, predictions.dtype)
return {"mean_squared_error":
tf1.metrics.mean_squared_error(labels=labels, predictions=predictions)}
estimator = tf1.estimator.add_metrics(estimator, mean_squared_error)
estimator.evaluate(_eval_input_fn)
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpppdg7p6b', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-12-14T20:49:45 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpppdg7p6b/model.ckpt-3 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Inference Time : 0.15928s INFO:tensorflow:Finished evaluation at 2022-12-14-20:49:45 INFO:tensorflow:Saving dict for global step 3: accuracy = 0.33333334, global_step = 3, loss = 8.581979, mean_squared_error = 0.6666667 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 3: /tmpfs/tmp/tmpppdg7p6b/model.ckpt-3 {'accuracy': 0.33333334, 'loss': 8.581979, 'mean_squared_error': 0.6666667, 'global_step': 3}
TF2: tf.keras.Model을 사용하는 Keras 메트릭 API
TF2에서 tf.keras.metrics
는 모든 메트릭 클래스와 함수를 포함합니다. 이는 OOP 스타일로 설계되어 있으며 다른 tf.keras
API와 긴밀하게 통합합니다. 모든 메트릭은 tf.keras.metrics
네임스페이스에서 찾을 수 있으며 일반적으로 tf.compat.v1.metrics
와 tf.keras.metrics
사이에서 직접 매핑합니다.
다음 예제에서 메트릭이 model.compile()
메서드에 추가됩니다. 사용자는 레이블 및 예측 텐서를 지정하지 않고 메트릭 인스턴스만 생성하면 됩니다. Keras 모델은 모델 출력과 레이블을 메트릭 객체로 라우팅합니다.
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
eval_dataset = tf.data.Dataset.from_tensor_slices(
(eval_features, eval_labels)).batch(1)
inputs = tf.keras.Input((2,))
logits = tf.keras.layers.Dense(2)(inputs)
predictions = tf.math.argmax(input=logits, axis=1)
model = tf.keras.models.Model(inputs, predictions)
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
model.compile(optimizer, loss='mse', metrics=[tf.keras.metrics.Accuracy()])
model.evaluate(eval_dataset, return_dict=True)
3/3 [==============================] - 0s 4ms/step - loss: 0.6667 - accuracy: 0.3333 {'loss': 0.6666666865348816, 'accuracy': 0.3333333432674408}
Eager 실행을 사용하도록 설정하면 tf.keras.metrics.Metric
인스턴스를 직접 사용하여 numpy 데이터 또는 Eager 텐서를 평가할 수 있습니다. tf.keras.metrics.Metric
객체는 상태 저장 컨테이너입니다. 메트릭 값은 metric.update_state(y_true, y_pred)
를 통해 업데이트할 수 있으며 결과는 metrics.result()
로 검색할 수 있습니다.
accuracy = tf.keras.metrics.Accuracy()
accuracy.update_state(y_true=[0, 0, 1, 1], y_pred=[0, 0, 0, 1])
accuracy.result().numpy()
0.75
accuracy.update_state(y_true=[0, 0, 1, 1], y_pred=[0, 0, 0, 0])
accuracy.update_state(y_true=[0, 0, 1, 1], y_pred=[1, 1, 0, 0])
accuracy.result().numpy()
0.41666666
tf.keras.metrics.Metric
에 대한 자세한 내용은 tf.keras.metrics.Metric
의 API 문서와 마이그레이션 가이드를 참조하세요.
TF1.x 옵티마이저를 Keras 옵티마이저로 마이그레이션하기
Adam 옵티마이저와 경사 하강 옵티마이저와 같은 tf.compat.v1.train
의 옵티마이저는 tf.keras.optimizers
에서 동일한 기능을 합니다.
아래 표에는 이러한 레거시 옵티마이저를 Keras에 맞게 변환하는 방법이 요약되어 있습니다. 기본 학습률 업데이트하기와 같은 추가 단계가 필요하지 않는 한 TF1.x 버전을 TF2 버전으로 직접 교체할 수 있습니다.
옵티마이저를 변환하면 이전 체크포인트가 호환되지 않을 수 있습니다.
TF1.x | TF2 | 추가 단계 |
---|---|---|
`tf.v1.train.GradientDescentOptimizer` | tf.keras.optimizers.SGD |
없음 |
`tf.v1.train.MomentumOptimizer` | tf.keras.optimizers.SGD |
`momentum` 인수 포함 |
`tf.v1.train.AdamOptimizer` | tf.keras.optimizers.Adam |
`beta1`과 `beta2` 인수를 `beta_1`과 `beta_2`로 이름 변경 |
`tf.v1.train.RMSPropOptimizer` | tf.keras.optimizers.RMSprop |
`decay` 인수를 `rho`로 이름 변경 |
`tf.v1.train.AdadeltaOptimizer` | tf.keras.optimizers.Adadelta |
없음 |
`tf.v1.train.AdagradOptimizer` | tf.keras.optimizers.Adagrad |
없음 |
`tf.v1.train.FtrlOptimizer` | tf.keras.optimizers.Ftrl |
`accum_name`와 `linear_name` 인수 제거 |
`tf.contrib.AdamaxOptimizer` | tf.keras.optimizers.Adamax |
`beta1`과 `beta2` 인수를 `beta_1`과 `beta_2`로 이름 변경 |
`tf.contrib.Nadam` | tf.keras.optimizers.Nadam |
`beta1`과 `beta2` 인수를 `beta_1`과 `beta_2`로 이름 변경 |
참고: TF2에서 모든 엡실론(수치 안정 상수)은 이제 1e-8
대신 1e-7
로 기본 설정됩니다. 이 차이는 대부분의 사용 사례에서 무시할 수 있습니다.