Migrate early stopping


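This notebook demonstrates how to set up model training with early stopping, first in TensorFlow 1 with tf.estimator.Estimator and an early stopping hook, and then in TensorFlow 2 with the Keras APIs or a custom training loop. Early stopping is a regularization technique that stops training when, for example, the validation loss reaches a certain threshold.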
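To make the threshold idea concrete, here is a minimal, framework-free sketch that finds the epoch at which a threshold-based rule would stop training. The function name first_epoch_below_threshold and the loss values are hypothetical illustrations, not part of any TensorFlow API.

```python
from typing import Optional, Sequence


def first_epoch_below_threshold(val_losses: Sequence[float],
                                threshold: float) -> Optional[int]:
    """Return the 0-based epoch at which training would stop.

    Training stops at the first epoch whose validation loss is at or
    below `threshold`; returns None if the threshold is never reached.
    """
    for epoch, loss in enumerate(val_losses):
        if loss <= threshold:
            return epoch
    return None


# Hypothetical validation losses, one per epoch.
history = [0.92, 0.61, 0.44, 0.38, 0.35]
print(first_epoch_below_threshold(history, threshold=0.40))  # 3
```

In a real training loop you would check the rule after each evaluation and break out as soon as it fires, instead of scanning a finished history.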

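In TensorFlow 2, there are three ways to implement early stopping: use the built-in Keras callback tf.keras.callbacks.EarlyStopping and pass it to Model.fit; define a custom callback and pass it to Keras Model.fit; or write a custom early stopping rule in a custom training loop (with tf.GradientTape).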
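Keras's built-in tf.keras.callbacks.EarlyStopping callback stops training once a monitored metric (for example, val_loss) has not improved for patience consecutive epochs. The following framework-free sketch approximates that patience rule in "min" mode; the class PatienceEarlyStopper is hypothetical and is not the Keras implementation, but the same logic could sit inside a custom callback or a custom training loop.

```python
import math
from typing import Optional


class PatienceEarlyStopper:
    """Stop when a monitored value has not improved for `patience` epochs.

    Simplified sketch of a 'min' mode rule (e.g. for validation loss):
    an improvement must beat the best value so far by more than `min_delta`.
    """

    def __init__(self, patience: int = 2, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf
        self.best_epoch: Optional[int] = None
        self.wait = 0  # epochs since the last improvement

    def update(self, epoch: int, value: float) -> bool:
        """Record one epoch's value; return True if training should stop."""
        if value < self.best - self.min_delta:
            self.best = value
            self.best_epoch = epoch
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


stopper = PatienceEarlyStopper(patience=2)
val_losses = [0.50, 0.40, 0.41, 0.39, 0.42, 0.43, 0.30]  # hypothetical values
for epoch, loss in enumerate(val_losses):
    if stopper.update(epoch, loss):
        print(f"stop at epoch {epoch}, best epoch {stopper.best_epoch}")
        break
# prints "stop at epoch 5, best epoch 3"
```

Tracking best_epoch alongside the counter is what would let you restore the best weights or checkpoint after stopping, which is the role the restore_best_weights option plays in the Keras callback.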

Setup

import time
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds

TensorFlow 1: Early stopping with an early stopping hook and tf.estimator

Start by defining functions for loading and preprocessing the MNIST dataset, and a model-definition function to be used with tf.estimator.Estimator:

def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label

def _input_fn():
  ds_train = tfds.load(
    name='mnist',
    split='train',
    shuffle_files=True,
    as_supervised=True)

  ds_train = ds_train.map(
      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_train = ds_train.batch(128)
  ds_train = ds_train.repeat(100)
  return ds_train

def _eval_input_fn():
  ds_test = tfds.load(
    name='mnist',
    split='test',
    shuffle_files=True,
    as_supervised=True)
  ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_test = ds_test.batch(128)
  return ds_test

def _model_fn(features, labels, mode):
  flatten = tf1.layers.Flatten()(features)
  features = tf1.layers.Dense(128, 'relu')(flatten)
  logits = tf1.layers.Dense(10)(features)

  loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf1.train.AdagradOptimizer(0.005)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())

  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
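Because _input_fn batches and then repeats the training data 100 times, and TrainSpec is given no max_steps, training only ends on its own when the input is exhausted. The arithmetic below estimates that upper bound, assuming the MNIST train split holds 60,000 examples; it is a back-of-the-envelope sketch, not output from the notebook.

```python
import math

# Assumption: the MNIST 'train' split has 60,000 examples.
NUM_TRAIN_EXAMPLES = 60_000
BATCH_SIZE = 128   # ds_train.batch(128) in _input_fn
NUM_REPEATS = 100  # ds_train.repeat(100) in _input_fn

# batch() keeps the final partial batch by default (drop_remainder=False),
# and repeat() comes after batch(), so every pass over the data yields
# ceil(60000 / 128) batches, the last one holding 96 examples.
steps_per_epoch = math.ceil(NUM_TRAIN_EXAMPLES / BATCH_SIZE)
max_train_steps = steps_per_epoch * NUM_REPEATS

print(steps_per_epoch, max_train_steps)  # 469 46900
```

Without an early stopping hook the run could therefore last roughly 46,900 steps; in the logged run of the time-limited example, the hook requests a stop at global step 9853.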

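In TensorFlow 1, early stopping works by setting up an early stopping hook with tf.estimator.experimental.make_early_stopping_hook. You pass make_early_stopping_hook a function that takes no arguments as its should_stop_fn parameter, and then pass the resulting hook to the hooks argument of tf.estimator.TrainSpec. Training stops once should_stop_fn returns True.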
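The hook does not have to call should_stop_fn after every step: with run_every_secs=1 and run_every_steps=None it polls the predicate roughly once per second of training. The pure-Python sketch below simulates that polling pattern with a fake clock so it runs instantly and deterministically. FakeClock and run_training are hypothetical names, and the control flow is a simplified assumption, not the Estimator implementation (the real timer's handling of details such as the very first step may differ).

```python
from typing import Callable


class FakeClock:
    """Deterministic clock: every call to tick() advances time by `dt` seconds."""

    def __init__(self, dt: float):
        self.now = 0.0
        self.dt = dt

    def __call__(self) -> float:
        return self.now

    def tick(self) -> None:
        self.now += self.dt


def run_training(should_stop_fn: Callable[[], bool],
                 clock: FakeClock,
                 run_every_secs: float,
                 max_steps: int) -> int:
    """Run simulated steps, polling should_stop_fn at most every run_every_secs."""
    last_check = clock()
    step = 0
    while step < max_steps:
        clock.tick()  # stand-in for one training step taking `dt` seconds
        step += 1
        if clock() - last_check >= run_every_secs:
            last_check = clock()
            if should_stop_fn():  # zero-argument predicate, as the hook expects
                break
    return step


clock = FakeClock(dt=0.25)  # pretend each step takes 0.25 s
start = clock()


def should_stop_fn() -> bool:
    # Same shape as the notebook's predicate: no arguments, returns a bool.
    return clock() - start > 20.0


steps = run_training(should_stop_fn, clock, run_every_secs=1.0, max_steps=10_000)
print(steps)  # 84
```

Because the predicate is only checked once per second, training can overshoot the 20-second budget by up to one polling interval (here it stops at 21 simulated seconds), which is one reason the logged run does not stop at an exact step count.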

The following example demonstrates how to implement an early stopping technique that limits the training time to a maximum of 20 seconds:

estimator = tf1.estimator.Estimator(model_fn=_model_fn)

start_time = time.time()
max_train_seconds = 20

def should_stop_fn():
  return time.time() - start_time > max_train_seconds

early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(
    estimator=estimator,
    should_stop_fn=should_stop_fn,
    run_every_secs=1,
    run_every_steps=None)

train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[early_stopping_hook])

eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)

tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_68749/1011025907.py:1: Estimator.__init__ (from tensorflow_estimator.python.estimator.estimator) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1844: RunConfig.__init__ (from tensorflow_estimator.python.estimator.run_config) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpiy6wag0h
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpiy6wag0h', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_68749/1011025907.py:9: make_early_stopping_hook (from tensorflow_estimator.python.estimator.early_stopping) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/early_stopping.py:474: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_68749/1011025907.py:15: TrainSpec.__new__ (from tensorflow_estimator.python.estimator.training) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_68749/1011025907.py:19: EvalSpec.__new__ (from tensorflow_estimator.python.estimator.training) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_68749/1011025907.py:21: train_and_evaluate (from tensorflow_estimator.python.estimator.training) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_68749/1468818800.py:37: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1416: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1419: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1456: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpiy6wag0h/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:loss = 2.4071355, step = 0
INFO:tensorflow:global_step/sec: 383.516
INFO:tensorflow:loss = 1.2792293, step = 100 (0.263 sec)
INFO:tensorflow:global_step/sec: 450.873
INFO:tensorflow:loss = 0.77129304, step = 200 (0.221 sec)
INFO:tensorflow:global_step/sec: 466.825
INFO:tensorflow:loss = 0.6746831, step = 300 (0.214 sec)
INFO:tensorflow:global_step/sec: 455.492
INFO:tensorflow:loss = 0.5994307, step = 400 (0.220 sec)
INFO:tensorflow:global_step/sec: 462.334
INFO:tensorflow:loss = 0.47565445, step = 500 (0.216 sec)
INFO:tensorflow:global_step/sec: 636.411
INFO:tensorflow:loss = 0.43863237, step = 600 (0.157 sec)
INFO:tensorflow:global_step/sec: 637.375
INFO:tensorflow:loss = 0.35803005, step = 700 (0.157 sec)
INFO:tensorflow:global_step/sec: 641.452
INFO:tensorflow:loss = 0.50397503, step = 800 (0.155 sec)
INFO:tensorflow:global_step/sec: 617.79
INFO:tensorflow:loss = 0.38481098, step = 900 (0.162 sec)
INFO:tensorflow:global_step/sec: 548.636
INFO:tensorflow:loss = 0.43239072, step = 1000 (0.183 sec)
INFO:tensorflow:global_step/sec: 578.089
INFO:tensorflow:loss = 0.43478006, step = 1100 (0.173 sec)
INFO:tensorflow:global_step/sec: 637.798
INFO:tensorflow:loss = 0.38385656, step = 1200 (0.157 sec)
INFO:tensorflow:global_step/sec: 640.452
INFO:tensorflow:loss = 0.45495924, step = 1300 (0.156 sec)
INFO:tensorflow:global_step/sec: 630.436
INFO:tensorflow:loss = 0.29222503, step = 1400 (0.160 sec)
INFO:tensorflow:global_step/sec: 513.24
INFO:tensorflow:loss = 0.27475923, step = 1500 (0.194 sec)
INFO:tensorflow:global_step/sec: 618.569
INFO:tensorflow:loss = 0.3822915, step = 1600 (0.162 sec)
INFO:tensorflow:global_step/sec: 619.712
INFO:tensorflow:loss = 0.34821516, step = 1700 (0.161 sec)
INFO:tensorflow:global_step/sec: 603.56
INFO:tensorflow:loss = 0.31085092, step = 1800 (0.166 sec)
INFO:tensorflow:global_step/sec: 545.691
INFO:tensorflow:loss = 0.5044548, step = 1900 (0.183 sec)
INFO:tensorflow:global_step/sec: 578.335
INFO:tensorflow:loss = 0.20964296, step = 2000 (0.173 sec)
INFO:tensorflow:global_step/sec: 632.198
INFO:tensorflow:loss = 0.25611854, step = 2100 (0.158 sec)
INFO:tensorflow:global_step/sec: 561.831
INFO:tensorflow:loss = 0.29500312, step = 2200 (0.178 sec)
INFO:tensorflow:global_step/sec: 636.661
INFO:tensorflow:loss = 0.3458867, step = 2300 (0.157 sec)
INFO:tensorflow:global_step/sec: 578.296
INFO:tensorflow:loss = 0.2508843, step = 2400 (0.173 sec)
INFO:tensorflow:global_step/sec: 580.748
INFO:tensorflow:loss = 0.22855805, step = 2500 (0.172 sec)
INFO:tensorflow:global_step/sec: 602.423
INFO:tensorflow:loss = 0.1585938, step = 2600 (0.166 sec)
INFO:tensorflow:global_step/sec: 634.883
INFO:tensorflow:loss = 0.30701032, step = 2700 (0.158 sec)
INFO:tensorflow:global_step/sec: 635.129
INFO:tensorflow:loss = 0.46658728, step = 2800 (0.157 sec)
INFO:tensorflow:global_step/sec: 522.673
INFO:tensorflow:loss = 0.22727002, step = 2900 (0.191 sec)
INFO:tensorflow:global_step/sec: 631.452
INFO:tensorflow:loss = 0.32864082, step = 3000 (0.158 sec)
INFO:tensorflow:global_step/sec: 628.747
INFO:tensorflow:loss = 0.20764841, step = 3100 (0.159 sec)
INFO:tensorflow:global_step/sec: 627.033
INFO:tensorflow:loss = 0.42431408, step = 3200 (0.159 sec)
INFO:tensorflow:global_step/sec: 543.587
INFO:tensorflow:loss = 0.324377, step = 3300 (0.185 sec)
INFO:tensorflow:global_step/sec: 617.368
INFO:tensorflow:loss = 0.2624207, step = 3400 (0.162 sec)
INFO:tensorflow:global_step/sec: 621.767
INFO:tensorflow:loss = 0.19428356, step = 3500 (0.160 sec)
INFO:tensorflow:global_step/sec: 632.717
INFO:tensorflow:loss = 0.23018332, step = 3600 (0.158 sec)
INFO:tensorflow:global_step/sec: 624.224
INFO:tensorflow:loss = 0.2818337, step = 3700 (0.160 sec)
INFO:tensorflow:global_step/sec: 485.89
INFO:tensorflow:loss = 0.35542685, step = 3800 (0.206 sec)
INFO:tensorflow:global_step/sec: 578.011
INFO:tensorflow:loss = 0.22460853, step = 3900 (0.174 sec)
INFO:tensorflow:global_step/sec: 588.815
INFO:tensorflow:loss = 0.28191128, step = 4000 (0.170 sec)
INFO:tensorflow:global_step/sec: 583.65
INFO:tensorflow:loss = 0.19746388, step = 4100 (0.172 sec)
INFO:tensorflow:global_step/sec: 629.072
INFO:tensorflow:loss = 0.25752583, step = 4200 (0.159 sec)
INFO:tensorflow:global_step/sec: 548.53
INFO:tensorflow:loss = 0.29156587, step = 4300 (0.182 sec)
INFO:tensorflow:global_step/sec: 584.561
INFO:tensorflow:loss = 0.3017162, step = 4400 (0.171 sec)
INFO:tensorflow:global_step/sec: 596.634
INFO:tensorflow:loss = 0.26009542, step = 4500 (0.167 sec)
INFO:tensorflow:global_step/sec: 629.979
INFO:tensorflow:loss = 0.32197523, step = 4600 (0.159 sec)
INFO:tensorflow:global_step/sec: 580.103
INFO:tensorflow:loss = 0.15014084, step = 4700 (0.173 sec)
INFO:tensorflow:global_step/sec: 625.18
INFO:tensorflow:loss = 0.28748113, step = 4800 (0.159 sec)
INFO:tensorflow:global_step/sec: 622.796
INFO:tensorflow:loss = 0.36667243, step = 4900 (0.161 sec)
INFO:tensorflow:global_step/sec: 588.541
INFO:tensorflow:loss = 0.2555472, step = 5000 (0.170 sec)
INFO:tensorflow:global_step/sec: 541.24
INFO:tensorflow:loss = 0.321966, step = 5100 (0.185 sec)
INFO:tensorflow:global_step/sec: 559.688
INFO:tensorflow:loss = 0.21373868, step = 5200 (0.179 sec)
INFO:tensorflow:global_step/sec: 614.89
INFO:tensorflow:loss = 0.2476358, step = 5300 (0.163 sec)
INFO:tensorflow:global_step/sec: 618.975
INFO:tensorflow:loss = 0.16120149, step = 5400 (0.162 sec)
INFO:tensorflow:global_step/sec: 613.135
INFO:tensorflow:loss = 0.20572953, step = 5500 (0.163 sec)
INFO:tensorflow:global_step/sec: 624.712
INFO:tensorflow:loss = 0.1752571, step = 5600 (0.160 sec)
INFO:tensorflow:global_step/sec: 544.799
INFO:tensorflow:loss = 0.16339287, step = 5700 (0.184 sec)
INFO:tensorflow:global_step/sec: 528.028
INFO:tensorflow:loss = 0.28886348, step = 5800 (0.190 sec)
INFO:tensorflow:global_step/sec: 600.042
INFO:tensorflow:loss = 0.21474686, step = 5900 (0.167 sec)
INFO:tensorflow:global_step/sec: 584.179
INFO:tensorflow:loss = 0.2417813, step = 6000 (0.171 sec)
INFO:tensorflow:global_step/sec: 564.652
INFO:tensorflow:loss = 0.17924747, step = 6100 (0.177 sec)
INFO:tensorflow:global_step/sec: 625.477
INFO:tensorflow:loss = 0.25654313, step = 6200 (0.160 sec)
INFO:tensorflow:global_step/sec: 650.049
INFO:tensorflow:loss = 0.24964629, step = 6300 (0.154 sec)
INFO:tensorflow:global_step/sec: 656.719
INFO:tensorflow:loss = 0.2832839, step = 6400 (0.152 sec)
INFO:tensorflow:global_step/sec: 635.097
INFO:tensorflow:loss = 0.2779467, step = 6500 (0.157 sec)
INFO:tensorflow:global_step/sec: 518.757
INFO:tensorflow:loss = 0.18514237, step = 6600 (0.193 sec)
INFO:tensorflow:global_step/sec: 602.511
INFO:tensorflow:loss = 0.2525888, step = 6700 (0.166 sec)
INFO:tensorflow:global_step/sec: 643.017
INFO:tensorflow:loss = 0.3977906, step = 6800 (0.156 sec)
INFO:tensorflow:global_step/sec: 627.06
INFO:tensorflow:loss = 0.14362156, step = 6900 (0.159 sec)
INFO:tensorflow:global_step/sec: 629.254
INFO:tensorflow:loss = 0.3407214, step = 7000 (0.159 sec)
INFO:tensorflow:global_step/sec: 584.666
INFO:tensorflow:loss = 0.1927482, step = 7100 (0.172 sec)
INFO:tensorflow:global_step/sec: 627.371
INFO:tensorflow:loss = 0.24224454, step = 7200 (0.159 sec)
INFO:tensorflow:global_step/sec: 638.217
INFO:tensorflow:loss = 0.1819179, step = 7300 (0.156 sec)
INFO:tensorflow:global_step/sec: 638.029
INFO:tensorflow:loss = 0.26996285, step = 7400 (0.157 sec)
INFO:tensorflow:global_step/sec: 525.172
INFO:tensorflow:loss = 0.13532382, step = 7500 (0.191 sec)
INFO:tensorflow:global_step/sec: 582.699
INFO:tensorflow:loss = 0.23422551, step = 7600 (0.171 sec)
INFO:tensorflow:global_step/sec: 494.859
INFO:tensorflow:loss = 0.26256862, step = 7700 (0.203 sec)
INFO:tensorflow:global_step/sec: 512.49
INFO:tensorflow:loss = 0.16503167, step = 7800 (0.195 sec)
INFO:tensorflow:global_step/sec: 622.401
INFO:tensorflow:loss = 0.16881448, step = 7900 (0.161 sec)
INFO:tensorflow:global_step/sec: 549.244
INFO:tensorflow:loss = 0.17681262, step = 8000 (0.183 sec)
INFO:tensorflow:global_step/sec: 621.287
INFO:tensorflow:loss = 0.18271996, step = 8100 (0.160 sec)
INFO:tensorflow:global_step/sec: 546.022
INFO:tensorflow:loss = 0.26113954, step = 8200 (0.183 sec)
INFO:tensorflow:global_step/sec: 571.733
INFO:tensorflow:loss = 0.13036585, step = 8300 (0.175 sec)
INFO:tensorflow:global_step/sec: 632.036
INFO:tensorflow:loss = 0.20619017, step = 8400 (0.158 sec)
INFO:tensorflow:global_step/sec: 494.028
INFO:tensorflow:loss = 0.17398489, step = 8500 (0.203 sec)
INFO:tensorflow:global_step/sec: 629.83
INFO:tensorflow:loss = 0.16091517, step = 8600 (0.158 sec)
INFO:tensorflow:global_step/sec: 587.259
INFO:tensorflow:loss = 0.11213518, step = 8700 (0.171 sec)
INFO:tensorflow:global_step/sec: 579.446
INFO:tensorflow:loss = 0.24871698, step = 8800 (0.173 sec)
INFO:tensorflow:global_step/sec: 624.93
INFO:tensorflow:loss = 0.39154494, step = 8900 (0.160 sec)
INFO:tensorflow:global_step/sec: 548.934
INFO:tensorflow:loss = 0.13556643, step = 9000 (0.182 sec)
INFO:tensorflow:global_step/sec: 596.807
INFO:tensorflow:loss = 0.20212807, step = 9100 (0.168 sec)
INFO:tensorflow:global_step/sec: 521.752
INFO:tensorflow:loss = 0.23187144, step = 9200 (0.191 sec)
INFO:tensorflow:global_step/sec: 640.669
INFO:tensorflow:loss = 0.15639503, step = 9300 (0.156 sec)
INFO:tensorflow:global_step/sec: 557.668
INFO:tensorflow:loss = 0.19676518, step = 9400 (0.179 sec)
INFO:tensorflow:global_step/sec: 602.812
INFO:tensorflow:loss = 0.15681529, step = 9500 (0.166 sec)
INFO:tensorflow:global_step/sec: 461.091
INFO:tensorflow:loss = 0.2656338, step = 9600 (0.217 sec)
INFO:tensorflow:global_step/sec: 523.942
INFO:tensorflow:loss = 0.18550211, step = 9700 (0.191 sec)
INFO:tensorflow:global_step/sec: 631.217
INFO:tensorflow:loss = 0.13052791, step = 9800 (0.158 sec)
INFO:tensorflow:Requesting early stopping at global step 9853
INFO:tensorflow:Requesting early stopping at global step 9853
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9854...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9854...
INFO:tensorflow:Saving checkpoints for 9854 into /tmpfs/tmp/tmpiy6wag0h/model.ckpt.
INFO:tensorflow:Saving checkpoints for 9854 into /tmpfs/tmp/tmpiy6wag0h/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9854...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9854...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2024-01-11T18:28:07
INFO:tensorflow:Starting evaluation at 2024-01-11T18:28:07
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/evaluation.py:260: FinalOpsHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/evaluation.py:260: FinalOpsHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpiy6wag0h/model.ckpt-9854
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpiy6wag0h/model.ckpt-9854
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Inference Time : 1.06288s
INFO:tensorflow:Inference Time : 1.06288s
INFO:tensorflow:Finished evaluation at 2024-01-11-18:28:08
INFO:tensorflow:Finished evaluation at 2024-01-11-18:28:08
INFO:tensorflow:Saving dict for global step 9854: global_step = 9854, loss = 0.19562133
INFO:tensorflow:Saving dict for global step 9854: global_step = 9854, loss = 0.19562133
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 9854: /tmpfs/tmp/tmpiy6wag0h/model.ckpt-9854
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 9854: /tmpfs/tmp/tmpiy6wag0h/model.ckpt-9854
INFO:tensorflow:Loss for final step: 0.18979602.
INFO:tensorflow:Loss for final step: 0.18979602.
({'loss': 0.19562133, 'global_step': 9854}, [])

TensorFlow 2: Early stopping with a built-in callback and Model.fit

Prepare the MNIST dataset and a simple Keras model:

(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(128)

ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.005),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate), you can configure early stopping by passing the built-in callback tf.keras.callbacks.EarlyStopping to the callbacks parameter of Model.fit.

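The EarlyStopping callback monitors a user-specified metric and ends training when that metric stops improving. (For more information, check the Training and evaluation with the built-in methods guide or the API docs.)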

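Below is an example of an early stopping callback that monitors the training loss (monitor='loss') and stops training after 3 consecutive epochs without improvement (patience=3).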
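To make the patience rule concrete, here is a framework-free sketch of the bookkeeping that EarlyStopping performs for a "lower is better" metric such as a loss. It is an illustration written for this guide, not Keras's source: epochs_until_early_stop is a hypothetical helper, and the loss values are the rounded per-epoch training losses printed by the Model.fit run in this section.

```python
def epochs_until_early_stop(losses, patience, min_delta=0.0):
  """Return how many epochs run before a patience-based rule ends training.

  Sketch of the 'lower is better' rule: an epoch counts as an improvement
  only if its value beats the best value so far by more than min_delta.
  Training stops once `patience` consecutive epochs fail to improve.
  """
  best = float("inf")
  wait = 0  # consecutive epochs without improvement
  for epoch, loss in enumerate(losses):
    if loss < best - min_delta:
      best = loss
      wait = 0
    else:
      wait += 1
      if wait >= patience:
        return epoch + 1  # number of epochs that actually ran
  return len(losses)  # the rule never fired


# Rounded per-epoch training losses from the Model.fit run in this section.
losses = [0.2279, 0.0947, 0.0667, 0.0493, 0.0385, 0.0350, 0.0310,
          0.0285, 0.0266, 0.0219, 0.0200, 0.0206, 0.0236, 0.0200]
print(epochs_until_early_stop(losses, patience=3))  # prints 14
```

Epoch 11 sets the best loss (0.0200). Epochs 12 to 14 do not beat it, and a tie does not count as an improvement. The rule therefore fires after the 14th epoch, which matches the length of history.history['loss'] in that run.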

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)

# Training ends well before the 100-epoch limit; the exact epoch count varies between runs.
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)

len(history.history['loss'])
Epoch 1/100
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1704997690.618989   68919 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
469/469 [==============================] - 3s 4ms/step - loss: 0.2279 - sparse_categorical_accuracy: 0.9316 - val_loss: 0.1229 - val_sparse_categorical_accuracy: 0.9642
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0947 - sparse_categorical_accuracy: 0.9714 - val_loss: 0.0974 - val_sparse_categorical_accuracy: 0.9702
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0667 - sparse_categorical_accuracy: 0.9791 - val_loss: 0.1035 - val_sparse_categorical_accuracy: 0.9698
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0493 - sparse_categorical_accuracy: 0.9848 - val_loss: 0.1030 - val_sparse_categorical_accuracy: 0.9715
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0385 - sparse_categorical_accuracy: 0.9866 - val_loss: 0.1438 - val_sparse_categorical_accuracy: 0.9636
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0350 - sparse_categorical_accuracy: 0.9882 - val_loss: 0.1081 - val_sparse_categorical_accuracy: 0.9725
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0310 - sparse_categorical_accuracy: 0.9892 - val_loss: 0.1100 - val_sparse_categorical_accuracy: 0.9735
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0285 - sparse_categorical_accuracy: 0.9905 - val_loss: 0.1180 - val_sparse_categorical_accuracy: 0.9744
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0266 - sparse_categorical_accuracy: 0.9908 - val_loss: 0.1124 - val_sparse_categorical_accuracy: 0.9761
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0219 - sparse_categorical_accuracy: 0.9925 - val_loss: 0.1240 - val_sparse_categorical_accuracy: 0.9748
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0200 - sparse_categorical_accuracy: 0.9932 - val_loss: 0.1391 - val_sparse_categorical_accuracy: 0.9754
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0206 - sparse_categorical_accuracy: 0.9928 - val_loss: 0.1585 - val_sparse_categorical_accuracy: 0.9713
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0236 - sparse_categorical_accuracy: 0.9924 - val_loss: 0.1529 - val_sparse_categorical_accuracy: 0.9748
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0200 - sparse_categorical_accuracy: 0.9932 - val_loss: 0.1491 - val_sparse_categorical_accuracy: 0.9759
14
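The callback above monitors the training loss. When early stopping is used as a regularizer, it is more common to monitor the validation loss. You can also set restore_best_weights=True so that the model is rolled back to its best epoch when training stops. The sketch below shows that configuration. It runs on hypothetical random toy data (not MNIST) so that it finishes in seconds. Because the labels are random, there is nothing to generalize, so val_loss should stop improving after a few epochs.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data standing in for the MNIST pipeline:
# random 28x28 "images" with random labels in [0, 10).
rng = np.random.default_rng(0)
x = rng.random((512, 28, 28), dtype=np.float32)
y = rng.integers(0, 10, size=(512,))

model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(10),
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.005),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",         # watch held-out loss instead of training loss
    patience=3,                 # tolerate 3 epochs without improvement
    restore_best_weights=True,  # restore the best epoch's weights when stopping
)

history = model.fit(
    x, y,
    epochs=50,
    validation_split=0.25,  # hold out the last 25% of samples for val_loss
    callbacks=[early_stopping],
    verbose=0,
)
print("epochs run:", len(history.history["val_loss"]))
```

val_loss is only reported when validation data is supplied, through validation_split here or validation_data as in the example above. Without it, EarlyStopping cannot find the monitored value.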

TensorFlow 2: Early stopping with a custom callback and Model.fit

You can also implement a custom early stopping callback and pass it to the callbacks parameter of Model.fit (or Model.evaluate).

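In this example, the callback sets self.model.stop_training to True once a fixed amount of wall-clock time has elapsed, and that flag makes Model.fit stop training.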

class LimitTrainingTime(tf.keras.callbacks.Callback):
  def __init__(self, max_time_s):
    super().__init__()
    self.max_time_s = max_time_s
    self.start_time = None

  def on_train_begin(self, logs=None):
    # Record the wall-clock time at which training starts.
    self.start_time = time.time()

  def on_train_batch_end(self, batch, logs=None):
    # Ask Model.fit to stop once the time budget is exhausted.
    if time.time() - self.start_time > self.max_time_s:
      self.model.stop_training = True

# Limit the training time to 30 seconds.
callback = LimitTrainingTime(30)
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0194 - sparse_categorical_accuracy: 0.9939 - val_loss: 0.1414 - val_sparse_categorical_accuracy: 0.9793
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0175 - sparse_categorical_accuracy: 0.9941 - val_loss: 0.1568 - val_sparse_categorical_accuracy: 0.9779
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0144 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.1515 - val_sparse_categorical_accuracy: 0.9771
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0180 - sparse_categorical_accuracy: 0.9947 - val_loss: 0.1559 - val_sparse_categorical_accuracy: 0.9774
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0212 - sparse_categorical_accuracy: 0.9938 - val_loss: 0.1807 - val_sparse_categorical_accuracy: 0.9735
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0154 - sparse_categorical_accuracy: 0.9953 - val_loss: 0.1803 - val_sparse_categorical_accuracy: 0.9761
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0150 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1661 - val_sparse_categorical_accuracy: 0.9778
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0150 - sparse_categorical_accuracy: 0.9954 - val_loss: 0.1636 - val_sparse_categorical_accuracy: 0.9762
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0143 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.1799 - val_sparse_categorical_accuracy: 0.9761
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0143 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.1802 - val_sparse_categorical_accuracy: 0.9792
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0139 - sparse_categorical_accuracy: 0.9956 - val_loss: 0.1959 - val_sparse_categorical_accuracy: 0.9747
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0194 - sparse_categorical_accuracy: 0.9947 - val_loss: 0.2243 - val_sparse_categorical_accuracy: 0.9748
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0120 - sparse_categorical_accuracy: 0.9969 - val_loss: 0.1788 - val_sparse_categorical_accuracy: 0.9787
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0118 - sparse_categorical_accuracy: 0.9964 - val_loss: 0.1857 - val_sparse_categorical_accuracy: 0.9786
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0113 - sparse_categorical_accuracy: 0.9969 - val_loss: 0.2032 - val_sparse_categorical_accuracy: 0.9774
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0126 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.2107 - val_sparse_categorical_accuracy: 0.9773
Epoch 17/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9961 - val_loss: 0.2157 - val_sparse_categorical_accuracy: 0.9759
Epoch 18/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0116 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.2189 - val_sparse_categorical_accuracy: 0.9772
Epoch 19/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0140 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2316 - val_sparse_categorical_accuracy: 0.9772
Epoch 20/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0116 - sparse_categorical_accuracy: 0.9968 - val_loss: 0.2400 - val_sparse_categorical_accuracy: 0.9775
Epoch 21/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0108 - sparse_categorical_accuracy: 0.9972 - val_loss: 0.2283 - val_sparse_categorical_accuracy: 0.9770
Epoch 22/100
469/469 [==============================] - 1s 1ms/step - loss: 0.0093 - sparse_categorical_accuracy: 0.9970 - val_loss: 0.2263 - val_sparse_categorical_accuracy: 0.9784
22
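Early stopping is often described as stopping training once, for example, the validation loss reaches a certain threshold. A custom callback can express that threshold rule directly. In the sketch below, StopAtThreshold is a hypothetical class written for this guide. It relies only on the standard Callback hooks: on_epoch_end receives the epoch's metrics in logs, and setting self.model.stop_training ends training.

```python
import tensorflow as tf


class StopAtThreshold(tf.keras.callbacks.Callback):
  """Hypothetical callback: stop once a monitored metric drops below a target."""

  def __init__(self, monitor="val_loss", threshold=0.1):
    super().__init__()
    self.monitor = monitor
    self.threshold = threshold

  def on_epoch_end(self, epoch, logs=None):
    # Keras passes this epoch's metrics (e.g. loss, val_loss) in `logs`.
    value = (logs or {}).get(self.monitor)
    if value is not None and value < self.threshold:
      print(f"Epoch {epoch + 1}: {self.monitor} = {value:.4f} "
            f"is below {self.threshold}; stopping training.")
      self.model.stop_training = True
```

With the model and datasets defined earlier in this section, you could pass callbacks=[StopAtThreshold("val_loss", 0.1)] to model.fit. The metric is silently ignored if it is missing from logs, for example when no validation data is given.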

TensorFlow 2: Early stopping with a custom training loop

In TensorFlow 2, if you are not using the built-in Keras methods for training and evaluation, you can implement early stopping in a custom training loop.

Start by using Keras APIs to define another simple model, an optimizer, a loss function, and metrics:

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.005)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
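# The model outputs logits, so the cross-entropy metrics need from_logits=True;
# the default (from_logits=False) would treat the logits as probabilities.
train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)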

Define the parameter update functions with tf.GradientTape, and decorate them with @tf.function for a speedup:

@tf.function
def train_step(x, y):
  with tf.GradientTape() as tape:
    logits = model(x, training=True)
    loss_value = loss_fn(y, logits)
  grads = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))
  train_acc_metric.update_state(y, logits)
  train_loss_metric.update_state(y, logits)
  return loss_value

@tf.function
def test_step(x, y):
  logits = model(x, training=False)
  val_acc_metric.update_state(y, logits)
  val_loss_metric.update_state(y, logits)

Next, write a custom training loop, where you can implement your early stopping rule manually.

The example below shows how to stop training when the validation loss does not improve for a certain number of epochs (patience):

epochs = 100
patience = 5
wait = 0
best = float('inf')

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
      loss_value = train_step(x_batch_train, y_batch_train)
      if step % 200 == 0:
        print("Training loss at step %d: %.4f" % (step, loss_value.numpy()))
        print("Seen so far: %s samples" % ((step + 1) * 128))        
    train_acc = train_acc_metric.result()
    train_loss = train_loss_metric.result()
    train_acc_metric.reset_states()
    train_loss_metric.reset_states()
    print("Training acc over epoch: %.4f" % (train_acc.numpy()))

    for x_batch_val, y_batch_val in ds_test:
      test_step(x_batch_val, y_batch_val)
    val_acc = val_acc_metric.result()
    val_loss = val_loss_metric.result()
    val_acc_metric.reset_states()
    val_loss_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))

    # The early stopping strategy: stop the training if `val_loss` does not
    # decrease over a certain number of epochs.
    wait += 1
    if val_loss < best:
      best = val_loss
      wait = 0
    if wait >= patience:
      break
Start of epoch 0
Training loss at step 0: 2.3667
Seen so far: 128 samples
Training loss at step 200: 0.2732
Seen so far: 25728 samples
Training loss at step 400: 0.2398
Seen so far: 51328 samples
Training acc over epoch: 0.9300
Validation acc: 0.9622
Time taken: 2.85s

Start of epoch 1
Training loss at step 0: 0.0916
Seen so far: 128 samples
Training loss at step 200: 0.1402
Seen so far: 25728 samples
Training loss at step 400: 0.1649
Seen so far: 51328 samples
Training acc over epoch: 0.9703
Validation acc: 0.9690
Time taken: 1.11s

Start of epoch 2
Training loss at step 0: 0.0642
Seen so far: 128 samples
Training loss at step 200: 0.0878
Seen so far: 25728 samples
Training loss at step 400: 0.1178
Seen so far: 51328 samples
Training acc over epoch: 0.9791
Validation acc: 0.9725
Time taken: 1.09s

Start of epoch 3
Training loss at step 0: 0.0279
Seen so far: 128 samples
Training loss at step 200: 0.0560
Seen so far: 25728 samples
Training loss at step 400: 0.0390
Seen so far: 51328 samples
Training acc over epoch: 0.9834
Validation acc: 0.9717
Time taken: 1.19s

Start of epoch 4
Training loss at step 0: 0.0122
Seen so far: 128 samples
Training loss at step 200: 0.0434
Seen so far: 25728 samples
Training loss at step 400: 0.0586
Seen so far: 51328 samples
Training acc over epoch: 0.9868
Validation acc: 0.9709
Time taken: 1.14s

Start of epoch 5
Training loss at step 0: 0.0321
Seen so far: 128 samples
Training loss at step 200: 0.0287
Seen so far: 25728 samples
Training loss at step 400: 0.0125
Seen so far: 51328 samples
Training acc over epoch: 0.9879
Validation acc: 0.9750
Time taken: 1.14s

Start of epoch 6
Training loss at step 0: 0.0093
Seen so far: 128 samples
Training loss at step 200: 0.0358
Seen so far: 25728 samples
Training loss at step 400: 0.0420
Seen so far: 51328 samples
Training acc over epoch: 0.9886
Validation acc: 0.9719
Time taken: 1.13s

Start of epoch 7
Training loss at step 0: 0.0047
Seen so far: 128 samples
Training loss at step 200: 0.0080
Seen so far: 25728 samples
Training loss at step 400: 0.0271
Seen so far: 51328 samples
Training acc over epoch: 0.9909
Validation acc: 0.9701
Time taken: 1.20s

Start of epoch 8
Training loss at step 0: 0.0180
Seen so far: 128 samples
Training loss at step 200: 0.0406
Seen so far: 25728 samples
Training loss at step 400: 0.0498
Seen so far: 51328 samples
Training acc over epoch: 0.9913
Validation acc: 0.9708
Time taken: 1.18s

Start of epoch 9
Training loss at step 0: 0.0173
Seen so far: 128 samples
Training loss at step 200: 0.0444
Seen so far: 25728 samples
Training loss at step 400: 0.0116
Seen so far: 51328 samples
Training acc over epoch: 0.9916
Validation acc: 0.9738
Time taken: 1.18s

Start of epoch 10
Training loss at step 0: 0.0438
Seen so far: 128 samples
Training loss at step 200: 0.0663
Seen so far: 25728 samples
Training loss at step 400: 0.0328
Seen so far: 51328 samples
Training acc over epoch: 0.9925
Validation acc: 0.9725
Time taken: 1.17s

Start of epoch 11
Training loss at step 0: 0.0006
Seen so far: 128 samples
Training loss at step 200: 0.0269
Seen so far: 25728 samples
Training loss at step 400: 0.0058
Seen so far: 51328 samples
Training acc over epoch: 0.9919
Validation acc: 0.9719
Time taken: 1.12s

Start of epoch 12
Training loss at step 0: 0.0594
Seen so far: 128 samples
Training loss at step 200: 0.0026
Seen so far: 25728 samples
Training loss at step 400: 0.0661
Seen so far: 51328 samples
Training acc over epoch: 0.9926
Validation acc: 0.9765
Time taken: 1.16s

Start of epoch 13
Training loss at step 0: 0.0003
Seen so far: 128 samples
Training loss at step 200: 0.0214
Seen so far: 25728 samples
Training loss at step 400: 0.0048
Seen so far: 51328 samples
Training acc over epoch: 0.9936
Validation acc: 0.9763
Time taken: 1.24s

Start of epoch 14
Training loss at step 0: 0.0008
Seen so far: 128 samples
Training loss at step 200: 0.0078
Seen so far: 25728 samples
Training loss at step 400: 0.0009
Seen so far: 51328 samples
Training acc over epoch: 0.9941
Validation acc: 0.9746
Time taken: 1.08s

Start of epoch 15
Training loss at step 0: 0.0020
Seen so far: 128 samples
Training loss at step 200: 0.0047
Seen so far: 25728 samples
Training loss at step 400: 0.0289
Seen so far: 51328 samples
Training acc over epoch: 0.9937
Validation acc: 0.9753
Time taken: 1.10s

Start of epoch 16
Training loss at step 0: 0.0063
Seen so far: 128 samples
Training loss at step 200: 0.0698
Seen so far: 25728 samples
Training loss at step 400: 0.0442
Seen so far: 51328 samples
Training acc over epoch: 0.9941
Validation acc: 0.9737
Time taken: 1.08s

Start of epoch 17
Training loss at step 0: 0.0019
Seen so far: 128 samples
Training loss at step 200: 0.0066
Seen so far: 25728 samples
Training loss at step 400: 0.0168
Seen so far: 51328 samples
Training acc over epoch: 0.9945
Validation acc: 0.9745
Time taken: 1.17s

Start of epoch 18
Training loss at step 0: 0.0025
Seen so far: 128 samples
Training loss at step 200: 0.0123
Seen so far: 25728 samples
Training loss at step 400: 0.0914
Seen so far: 51328 samples
Training acc over epoch: 0.9929
Validation acc: 0.9700
Time taken: 1.07s

Start of epoch 19
Training loss at step 0: 0.0029
Seen so far: 128 samples
Training loss at step 200: 0.0456
Seen so far: 25728 samples
Training loss at step 400: 0.0002
Seen so far: 51328 samples
Training acc over epoch: 0.9947
Validation acc: 0.9758
Time taken: 1.15s

Start of epoch 20
Training loss at step 0: 0.0002
Seen so far: 128 samples
Training loss at step 200: 0.0008
Seen so far: 25728 samples
Training loss at step 400: 0.0010
Seen so far: 51328 samples
Training acc over epoch: 0.9966
Validation acc: 0.9771
Time taken: 1.11s

Start of epoch 21
Training loss at step 0: 0.0100
Seen so far: 128 samples
Training loss at step 200: 0.0139
Seen so far: 25728 samples
Training loss at step 400: 0.0398
Seen so far: 51328 samples
Training acc over epoch: 0.9961
Validation acc: 0.9744
Time taken: 1.09s

Start of epoch 22
Training loss at step 0: 0.0000
Seen so far: 128 samples
Training loss at step 200: 0.0008
Seen so far: 25728 samples
Training loss at step 400: 0.0008
Seen so far: 51328 samples
Training acc over epoch: 0.9958
Validation acc: 0.9784
Time taken: 1.05s

Start of epoch 23
Training loss at step 0: 0.0009
Seen so far: 128 samples
Training loss at step 200: 0.0419
Seen so far: 25728 samples
Training loss at step 400: 0.0020
Seen so far: 51328 samples
Training acc over epoch: 0.9959
Validation acc: 0.9764
Time taken: 1.11s

Start of epoch 24
Training loss at step 0: 0.0002
Seen so far: 128 samples
Training loss at step 200: 0.0192
Seen so far: 25728 samples
Training loss at step 400: 0.0186
Seen so far: 51328 samples
Training acc over epoch: 0.9955
Validation acc: 0.9754
Time taken: 1.08s

Start of epoch 25
Training loss at step 0: 0.0009
Seen so far: 128 samples
Training loss at step 200: 0.0232
Seen so far: 25728 samples
Training loss at step 400: 0.0007
Seen so far: 51328 samples
Training acc over epoch: 0.9956
Validation acc: 0.9772
Time taken: 1.09s

Start of epoch 26
Training loss at step 0: 0.0001
Seen so far: 128 samples
Training loss at step 200: 0.0427
Seen so far: 25728 samples
Training loss at step 400: 0.0034
Seen so far: 51328 samples
Training acc over epoch: 0.9946
Validation acc: 0.9785
Time taken: 1.12s

Start of epoch 27
Training loss at step 0: 0.0013
Seen so far: 128 samples
Training loss at step 200: 0.0092
Seen so far: 25728 samples
Training loss at step 400: 0.0955
Seen so far: 51328 samples
Training acc over epoch: 0.9962
Validation acc: 0.9755
Time taken: 1.12s
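The inline wait and best bookkeeping in the loop above can be factored into a small reusable helper. The helper also records which epoch was best, which is useful if you want to keep model.get_weights() from that epoch. The sketch below is plain Python, and EarlyStopper with its update and should_stop methods is a hypothetical name rather than a TensorFlow API. It applies the same rule as the loop: stop after patience consecutive epochs without a strictly lower validation loss. It also accepts an optional min_delta.

```python
class EarlyStopper:
  """Hypothetical helper for custom loops: patience-based early stopping."""

  def __init__(self, patience=5, min_delta=0.0):
    self.patience = patience
    self.min_delta = min_delta
    self.best = float("inf")
    self.best_epoch = None
    self.wait = 0  # consecutive epochs without improvement

  def update(self, epoch, value):
    """Record this epoch's value; return True if it is a new best."""
    value = float(value)  # works for Python floats and scalar eager tensors
    if value < self.best - self.min_delta:
      self.best = value
      self.best_epoch = epoch
      self.wait = 0
      return True
    self.wait += 1
    return False

  def should_stop(self):
    return self.wait >= self.patience


# Usage inside a custom loop like the one above (sketch):
#   stopper = EarlyStopper(patience=5)
#   for epoch in range(epochs):
#     ...train, then compute val_loss...
#     if stopper.update(epoch, val_loss):
#       best_weights = model.get_weights()  # keep the best weights so far
#     if stopper.should_stop():
#       break

# Demo with made-up validation losses.
stopper = EarlyStopper(patience=2)
for epoch, val_loss in enumerate([0.30, 0.25, 0.26, 0.27, 0.20]):
  stopper.update(epoch, val_loss)
  if stopper.should_stop():
    break
print(epoch, stopper.best_epoch, stopper.best)  # prints 3 1 0.25
```

In the demo, epochs 2 and 3 fail to beat 0.25, so the loop stops at epoch 3 and never sees the later 0.20. This is the trade-off that patience controls.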

Next steps