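This notebook demonstrates how to set up model training with early stopping, first in TensorFlow 1 with tf.estimator.Estimator and an early stopping hook, and then in TensorFlow 2 with the Keras APIs or a custom training loop. Early stopping is a regularization technique that stops training once a stopping criterion is met, for example when the validation loss reaches a certain threshold.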
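As a minimal, framework-agnostic sketch of the threshold rule just described (the helper name `first_stop_epoch` and the loss values are hypothetical and not part of TensorFlow), training stops at the first epoch whose validation loss is at or below the threshold:

```python
def first_stop_epoch(val_losses, threshold):
    """Return the 0-based epoch whose validation loss first reaches threshold, or None."""
    for epoch, loss in enumerate(val_losses):
        if loss <= threshold:  # "reaches" the threshold: at or below it
            return epoch
    return None  # threshold never reached; training runs to completion


print(first_stop_epoch([0.9, 0.5, 0.31, 0.29, 0.27], threshold=0.3))  # 3
```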
In TensorFlow 2, there are three ways to implement early stopping:
- Use a built-in Keras callback (tf.keras.callbacks.EarlyStopping) and pass it to Model.fit.
- Define a custom callback and pass it to Keras Model.fit.
- Write a custom early stopping rule in a custom training loop (with tf.GradientTape).
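The patience idea behind a built-in callback such as tf.keras.callbacks.EarlyStopping can be sketched in plain Python. This is a simplified illustration loosely modeled on that callback's `patience` and `min_delta` arguments, not its actual implementation; the class name `PatienceStopper`, the helper `stop_epoch`, and the loss values are hypothetical. It tracks the best validation loss seen so far and signals a stop after `patience` consecutive epochs without an improvement larger than `min_delta`:

```python
import math


class PatienceStopper:
    """Hypothetical helper: stop after `patience` epochs without improvement."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience    # epochs to wait after the last improvement
        self.min_delta = min_delta  # smallest decrease that counts as an improvement
        self.best = math.inf        # best (lowest) validation loss so far
        self.wait = 0               # epochs since the last improvement

    def update(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


def stop_epoch(val_losses, patience, min_delta=0.0):
    """Return the 0-based epoch at which training would stop, or None."""
    stopper = PatienceStopper(patience, min_delta)
    for epoch, loss in enumerate(val_losses):
        if stopper.update(loss):
            return epoch
    return None


losses = [0.50, 0.40, 0.38, 0.39, 0.385, 0.381, 0.30]
print(stop_epoch(losses, patience=3))                  # 5
print(stop_epoch(losses, patience=3, min_delta=0.05))  # 4
```

Training stops at epoch 5 here even though epoch 6 would have improved the loss; `patience` controls exactly this trade-off between stopping early and waiting out a plateau.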
Setup
import time
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds
TensorFlow 1: Early stopping with an early stopping hook and tf.estimator
Start by defining functions that load and preprocess the MNIST dataset, and the model definition to be used with tf.estimator.Estimator:
def normalize_img(image, label):
  # Scale uint8 pixel values in [0, 255] to float32 values in [0, 1].
  return tf.cast(image, tf.float32) / 255., label

def _input_fn():
  # Training input: MNIST train split, normalized, batched, repeated 100 times.
  ds_train = tfds.load(
      name='mnist',
      split='train',
      shuffle_files=True,
      as_supervised=True)
  ds_train = ds_train.map(
      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_train = ds_train.batch(128)
  ds_train = ds_train.repeat(100)
  return ds_train

def _eval_input_fn():
  # Evaluation input: MNIST test split, normalized and batched (no repeat).
  ds_test = tfds.load(
      name='mnist',
      split='test',
      shuffle_files=True,
      as_supervised=True)
  ds_test = ds_test.map(
      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_test = ds_test.batch(128)
  return ds_test

def _model_fn(features, labels, mode):
  # A small MLP: flatten the 28x28 images, one hidden ReLU layer, 10-class logits.
  flatten = tf1.layers.Flatten()(features)
  features = tf1.layers.Dense(128, 'relu')(flatten)
  logits = tf1.layers.Dense(10)(features)

  loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf1.train.AdagradOptimizer(0.005)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())

  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
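Because `_input_fn` batches the 60,000-example MNIST training split into batches of 128 and repeats it 100 times, the input pipeline alone caps training at a fixed number of steps. A quick check of that upper bound, assuming the default `drop_remainder=False` so that the final partial batch of each pass is kept:

```python
import math

TRAIN_EXAMPLES = 60_000  # size of the MNIST 'train' split
BATCH_SIZE = 128         # ds_train.batch(128)
REPEATS = 100            # ds_train.repeat(100)

# Batches per pass over the data; the last batch holds 60000 - 468*128 = 96 examples.
steps_per_pass = math.ceil(TRAIN_EXAMPLES / BATCH_SIZE)
max_steps = steps_per_pass * REPEATS
print(steps_per_pass, max_steps)  # 469 46900
```

Training therefore ends after at most 46,900 steps; any earlier stop has to come from a hook such as an early stopping hook.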
In TensorFlow 1, early stopping works by setting up an early stopping hook with tf.estimator.experimental.make_early_stopping_hook. You pass a function that takes no arguments to make_early_stopping_hook as its should_stop_fn parameter. Training stops once should_stop_fn returns True.
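Because should_stop_fn takes no arguments, any state it needs has to be captured from the enclosing scope. A minimal sketch of a factory that builds such a function (the name `make_time_limit_should_stop` and the injectable `clock` parameter are hypothetical conveniences for testing, not part of the Estimator API):

```python
import time


def make_time_limit_should_stop(max_seconds, clock=time.monotonic):
    """Build a zero-argument callable that returns True once max_seconds have elapsed."""
    start = clock()  # the timer starts when the function is built

    def should_stop_fn():
        return clock() - start > max_seconds

    return should_stop_fn


# Usage with a fake clock that returns 0, 5, then 25 seconds.
ticks = iter([0.0, 5.0, 25.0])
should_stop = make_time_limit_should_stop(20, clock=lambda: next(ticks))
print(should_stop())  # False (5 s elapsed)
print(should_stop())  # True  (25 s elapsed)
```

In the example that follows, `start_time` is recorded before `train_and_evaluate` is called, so the 20-second budget also covers graph construction and dataset setup, not only the training steps themselves.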
The following example shows how to implement an early stopping technique that limits the training time to a maximum of 20 seconds:
estimator = tf1.estimator.Estimator(model_fn=_model_fn)

start_time = time.time()
max_train_seconds = 20

def should_stop_fn():
  # Request a stop once more than max_train_seconds have elapsed since start_time.
  return time.time() - start_time > max_train_seconds

# Evaluate should_stop_fn about once per second rather than every N steps.
early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(
    estimator=estimator,
    should_stop_fn=should_stop_fn,
    run_every_secs=1,
    run_every_steps=None)

train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[early_stopping_hook])

eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)

tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Requesting early stopping at global step 9853
INFO:tensorflow:Saving checkpoints for 9854 into /tmpfs/tmp/tmpiy6wag0h/model.ckpt.
INFO:tensorflow:Saving dict for global step 9854: global_step = 9854, loss = 0.19562133
INFO:tensorflow:Loss for final step: 0.18979602.
({'loss': 0.19562133, 'global_step': 9854}, [])
TensorFlow 2: Early stopping with a built-in callback and Model.fit
Prepare the MNIST dataset and a simple Keras model.
(ds_train, ds_test), ds_info = tfds.load(
'mnist',
split=['train', 'test'],
shuffle_files=True,
as_supervised=True,
with_info=True,
)
ds_train = ds_train.map(
normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(128)
ds_test = ds_test.map(
normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)
])
model.compile(
optimizer=tf.keras.optimizers.Adam(0.005),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
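In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate), you can configure early stopping by passing the built-in callback tf.keras.callbacks.EarlyStopping to the callbacks parameter of Model.fit.
The EarlyStopping callback monitors a user-specified metric and ends training once that metric stops improving. (For details, check the Training and evaluation with the built-in methods guide or the API docs.)
Below is an example of an early stopping callback that monitors the training loss (monitor='loss') and stops training after 3 consecutive epochs without improvement (the patience).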
callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
# With patience=3, training stops long before the 100 requested epochs (after 14 in the run below).
history = model.fit(
ds_train,
epochs=100,
validation_data=ds_test,
callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1704997690.618989 68919 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
469/469 [==============================] - 3s 4ms/step - loss: 0.2279 - sparse_categorical_accuracy: 0.9316 - val_loss: 0.1229 - val_sparse_categorical_accuracy: 0.9642
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0947 - sparse_categorical_accuracy: 0.9714 - val_loss: 0.0974 - val_sparse_categorical_accuracy: 0.9702
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0667 - sparse_categorical_accuracy: 0.9791 - val_loss: 0.1035 - val_sparse_categorical_accuracy: 0.9698
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0493 - sparse_categorical_accuracy: 0.9848 - val_loss: 0.1030 - val_sparse_categorical_accuracy: 0.9715
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0385 - sparse_categorical_accuracy: 0.9866 - val_loss: 0.1438 - val_sparse_categorical_accuracy: 0.9636
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0350 - sparse_categorical_accuracy: 0.9882 - val_loss: 0.1081 - val_sparse_categorical_accuracy: 0.9725
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0310 - sparse_categorical_accuracy: 0.9892 - val_loss: 0.1100 - val_sparse_categorical_accuracy: 0.9735
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0285 - sparse_categorical_accuracy: 0.9905 - val_loss: 0.1180 - val_sparse_categorical_accuracy: 0.9744
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0266 - sparse_categorical_accuracy: 0.9908 - val_loss: 0.1124 - val_sparse_categorical_accuracy: 0.9761
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0219 - sparse_categorical_accuracy: 0.9925 - val_loss: 0.1240 - val_sparse_categorical_accuracy: 0.9748
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0200 - sparse_categorical_accuracy: 0.9932 - val_loss: 0.1391 - val_sparse_categorical_accuracy: 0.9754
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0206 - sparse_categorical_accuracy: 0.9928 - val_loss: 0.1585 - val_sparse_categorical_accuracy: 0.9713
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0236 - sparse_categorical_accuracy: 0.9924 - val_loss: 0.1529 - val_sparse_categorical_accuracy: 0.9748
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0200 - sparse_categorical_accuracy: 0.9932 - val_loss: 0.1491 - val_sparse_categorical_accuracy: 0.9759
14
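To see how `patience` interacts with the monitored value, the patience rule can be replayed by hand. The sketch below is a plain-Python approximation of that rule, not Keras's own implementation, and the helper name `epochs_until_stop` is hypothetical. It assumes the defaults that apply when monitoring a loss, `mode='min'` and `min_delta=0`. Fed the rounded per-epoch training losses printed above, it also stops after 14 epochs: the best value, 0.0200, is reached in epoch 11, and the next three epochs fail to go below it.

```python
def epochs_until_stop(values, patience, min_delta=0.0):
    """Return how many epochs run before a patience rule halts training.

    Hypothetical helper approximating EarlyStopping with mode='min':
    an epoch counts as an improvement only if the monitored value drops
    below the best value seen so far by more than `min_delta`.
    """
    best = float("inf")
    wait = 0  # consecutive epochs without improvement
    for epoch, current in enumerate(values):
        if current < best - min_delta:
            best = current
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch + 1  # epochs are 0-indexed
    return len(values)


# Rounded per-epoch `loss` values from the Model.fit run above.
train_losses = [0.2279, 0.0947, 0.0667, 0.0493, 0.0385, 0.0350, 0.0310,
                0.0285, 0.0266, 0.0219, 0.0200, 0.0206, 0.0236, 0.0200]
print(epochs_until_stop(train_losses, patience=3))  # 14
```

Raising `min_delta` makes small decreases count as "no improvement". With `min_delta=0.005`, for example, the same sequence stops after 13 epochs.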
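TensorFlow 2: Early stopping with a custom callback and Model.fit
You can also implement a custom early stopping callback, which can likewise be passed to the callbacks parameter of Model.fit (or Model.evaluate).
In this example, training stops once the callback sets self.model.stop_training to True, which happens when training has run longer than a given time limit.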
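class LimitTrainingTime(tf.keras.callbacks.Callback):
    """Stops training once a wall-clock time budget (in seconds) is exceeded."""

    def __init__(self, max_time_s):
        super().__init__()
        self.max_time_s = max_time_s
        self.start_time = None

    def on_train_begin(self, logs=None):
        # Record when Model.fit starts training.
        self.start_time = time.time()

    def on_train_batch_end(self, batch, logs=None):
        # Checked after every batch, so training can stop mid-epoch.
        now = time.time()
        if now - self.start_time > self.max_time_s:
            self.model.stop_training = True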
# Limit the training time to 30 seconds.
callback = LimitTrainingTime(30)
history = model.fit(
ds_train,
epochs=100,
validation_data=ds_test,
callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0194 - sparse_categorical_accuracy: 0.9939 - val_loss: 0.1414 - val_sparse_categorical_accuracy: 0.9793
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0175 - sparse_categorical_accuracy: 0.9941 - val_loss: 0.1568 - val_sparse_categorical_accuracy: 0.9779
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0144 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.1515 - val_sparse_categorical_accuracy: 0.9771
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0180 - sparse_categorical_accuracy: 0.9947 - val_loss: 0.1559 - val_sparse_categorical_accuracy: 0.9774
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0212 - sparse_categorical_accuracy: 0.9938 - val_loss: 0.1807 - val_sparse_categorical_accuracy: 0.9735
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0154 - sparse_categorical_accuracy: 0.9953 - val_loss: 0.1803 - val_sparse_categorical_accuracy: 0.9761
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0150 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1661 - val_sparse_categorical_accuracy: 0.9778
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0150 - sparse_categorical_accuracy: 0.9954 - val_loss: 0.1636 - val_sparse_categorical_accuracy: 0.9762
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0143 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.1799 - val_sparse_categorical_accuracy: 0.9761
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0143 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.1802 - val_sparse_categorical_accuracy: 0.9792
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0139 - sparse_categorical_accuracy: 0.9956 - val_loss: 0.1959 - val_sparse_categorical_accuracy: 0.9747
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0194 - sparse_categorical_accuracy: 0.9947 - val_loss: 0.2243 - val_sparse_categorical_accuracy: 0.9748
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0120 - sparse_categorical_accuracy: 0.9969 - val_loss: 0.1788 - val_sparse_categorical_accuracy: 0.9787
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0118 - sparse_categorical_accuracy: 0.9964 - val_loss: 0.1857 - val_sparse_categorical_accuracy: 0.9786
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0113 - sparse_categorical_accuracy: 0.9969 - val_loss: 0.2032 - val_sparse_categorical_accuracy: 0.9774
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0126 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.2107 - val_sparse_categorical_accuracy: 0.9773
Epoch 17/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9961 - val_loss: 0.2157 - val_sparse_categorical_accuracy: 0.9759
Epoch 18/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0116 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.2189 - val_sparse_categorical_accuracy: 0.9772
Epoch 19/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0140 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2316 - val_sparse_categorical_accuracy: 0.9772
Epoch 20/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0116 - sparse_categorical_accuracy: 0.9968 - val_loss: 0.2400 - val_sparse_categorical_accuracy: 0.9775
Epoch 21/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0108 - sparse_categorical_accuracy: 0.9972 - val_loss: 0.2283 - val_sparse_categorical_accuracy: 0.9770
Epoch 22/100
469/469 [==============================] - 1s 1ms/step - loss: 0.0093 - sparse_categorical_accuracy: 0.9970 - val_loss: 0.2263 - val_sparse_categorical_accuracy: 0.9784
22
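Another common custom rule stops training as soon as a monitored value crosses a fixed threshold, for example a target validation loss. It uses the same `self.model.stop_training` mechanism, but checks an epoch-level value from `logs` in `on_epoch_end`. The following is a minimal sketch under stated assumptions. The class name `StopAtThreshold` and its `monitor` and `threshold` arguments are hypothetical, not part of Keras. The tiny random dataset is only there to exercise the callback. The threshold is deliberately huge, so the rule fires after the first epoch.

```python
import numpy as np
import tensorflow as tf


class StopAtThreshold(tf.keras.callbacks.Callback):
    """Hypothetical callback: stop once `monitor` drops to `threshold` or below."""

    def __init__(self, monitor="val_loss", threshold=0.1):
        super().__init__()
        self.monitor = monitor
        self.threshold = threshold
        self.stopped_epoch = None  # 0-based epoch at which training was stopped

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get(self.monitor)
        if current is None:
            return  # monitored key not logged, e.g. no validation data
        if current <= self.threshold:
            self.stopped_epoch = epoch
            self.model.stop_training = True


# Tiny synthetic data, only to exercise the callback.
x_demo = np.random.rand(64, 4).astype("float32")
y_demo = np.random.randint(0, 2, size=(64,))
demo_model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
demo_model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

# A deliberately huge threshold, so the rule fires after the first epoch.
stopper = StopAtThreshold(monitor="val_loss", threshold=1e9)
demo_history = demo_model.fit(x_demo, y_demo, epochs=10, validation_split=0.25,
                              callbacks=[stopper], verbose=0)
print(len(demo_history.history["loss"]))  # 1
```

With a realistic threshold, training continues until the monitored value first reaches it, or until `epochs` runs out.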
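TensorFlow 2: Early stopping with a custom training loop
In TensorFlow 2, if you are not using the built-in Keras methods for training and evaluation, you can implement early stopping in a custom training loop.
Start by using Keras APIs to define another simple model, an optimizer, a loss function, and metrics.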
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(0.005)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
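# The model outputs logits, so the loss metrics need from_logits=True (like loss_fn).
train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)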
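The final `Dense(10)` layer has no activation, so the model outputs logits. A crossentropy metric that receives them should therefore be created with `from_logits=True`, as `loss_fn` is. This matters here because the early stopping rule below is driven by the validation loss metric. The small check below uses made-up random logits and labels. It shows that such a metric reproduces `loss_fn`. The default `from_logits=False` treats the same numbers as probabilities and reports a different value.

```python
import numpy as np
import tensorflow as tf

# Made-up logits and integer labels for a 10-class problem.
demo_logits = tf.random.normal((8, 10), seed=0)
demo_labels = tf.random.uniform((8,), maxval=10, dtype=tf.int32, seed=1)

reference_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric_logits = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)
metric_default = tf.keras.metrics.SparseCategoricalCrossentropy()  # from_logits=False

metric_logits.update_state(demo_labels, demo_logits)
metric_default.update_state(demo_labels, demo_logits)

expected = float(reference_loss(demo_labels, demo_logits))
print(np.isclose(float(metric_logits.result()), expected))   # True
print(np.isclose(float(metric_default.result()), expected))  # False
```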
Define a training step that updates the parameters with tf.GradientTape, plus a validation step, and compile both with the @tf.function decorator for a speedup.
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_acc_metric.update_state(y, logits)
    train_loss_metric.update_state(y, logits)
    return loss_value

@tf.function
def test_step(x, y):
    logits = model(x, training=False)
    val_acc_metric.update_state(y, logits)
    val_loss_metric.update_state(y, logits)
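Next, write a custom training loop in which you implement the early stopping rule by hand.
The example below stops training when the validation loss has not improved for a given number of epochs (the patience).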
epochs = 100
patience = 5
wait = 0
best = float('inf')

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Training pass over the whole training set.
    for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
        loss_value = train_step(x_batch_train, y_batch_train)
        if step % 200 == 0:
            print("Training loss at step %d: %.4f" % (step, loss_value.numpy()))
            print("Seen so far: %s samples" % ((step + 1) * 128))
    train_acc = train_acc_metric.result()
    train_loss = train_loss_metric.result()
    train_acc_metric.reset_state()
    train_loss_metric.reset_state()
    print("Training acc over epoch: %.4f" % (train_acc.numpy()))

    # Validation pass.
    for x_batch_val, y_batch_val in ds_test:
        test_step(x_batch_val, y_batch_val)
    val_acc = val_acc_metric.result()
    val_loss = val_loss_metric.result()
    val_acc_metric.reset_state()
    val_loss_metric.reset_state()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))

    # The early stopping strategy: stop the training if `val_loss` does not
    # decrease over a certain number of epochs.
    wait += 1
    if val_loss < best:
        best = val_loss
        wait = 0
    if wait >= patience:
        break
Start of epoch 0 Training loss at step 0: 2.3667 Seen so far: 128 samples Training loss at step 200: 0.2732 Seen so far: 25728 samples Training loss at step 400: 0.2398 Seen so far: 51328 samples Training acc over epoch: 0.9300 Validation acc: 0.9622 Time taken: 2.85s
Start of epoch 1 Training loss at step 0: 0.0916 Seen so far: 128 samples Training loss at step 200: 0.1402 Seen so far: 25728 samples Training loss at step 400: 0.1649 Seen so far: 51328 samples Training acc over epoch: 0.9703 Validation acc: 0.9690 Time taken: 1.11s
Start of epoch 2 Training loss at step 0: 0.0642 Seen so far: 128 samples Training loss at step 200: 0.0878 Seen so far: 25728 samples Training loss at step 400: 0.1178 Seen so far: 51328 samples Training acc over epoch: 0.9791 Validation acc: 0.9725 Time taken: 1.09s
Start of epoch 3 Training loss at step 0: 0.0279 Seen so far: 128 samples Training loss at step 200: 0.0560 Seen so far: 25728 samples Training loss at step 400: 0.0390 Seen so far: 51328 samples Training acc over epoch: 0.9834 Validation acc: 0.9717 Time taken: 1.19s
Start of epoch 4 Training loss at step 0: 0.0122 Seen so far: 128 samples Training loss at step 200: 0.0434 Seen so far: 25728 samples Training loss at step 400: 0.0586 Seen so far: 51328 samples Training acc over epoch: 0.9868 Validation acc: 0.9709 Time taken: 1.14s
Start of epoch 5 Training loss at step 0: 0.0321 Seen so far: 128 samples Training loss at step 200: 0.0287 Seen so far: 25728 samples Training loss at step 400: 0.0125 Seen so far: 51328 samples Training acc over epoch: 0.9879 Validation acc: 0.9750 Time taken: 1.14s
Start of epoch 6 Training loss at step 0: 0.0093 Seen so far: 128 samples Training loss at step 200: 0.0358 Seen so far: 25728 samples Training loss at step 400: 0.0420 Seen so far: 51328 samples Training acc over epoch: 0.9886 Validation acc: 0.9719 Time taken: 1.13s
Start of epoch 7 Training loss at step 0: 0.0047 Seen so far: 128 samples Training loss at step 200: 0.0080 Seen so far: 25728 samples Training loss at step 400: 0.0271 Seen so far: 51328 samples Training acc over epoch: 0.9909 Validation acc: 0.9701 Time taken: 1.20s
Start of epoch 8 Training loss at step 0: 0.0180 Seen so far: 128 samples Training loss at step 200: 0.0406 Seen so far: 25728 samples Training loss at step 400: 0.0498 Seen so far: 51328 samples Training acc over epoch: 0.9913 Validation acc: 0.9708 Time taken: 1.18s
Start of epoch 9 Training loss at step 0: 0.0173 Seen so far: 128 samples Training loss at step 200: 0.0444 Seen so far: 25728 samples Training loss at step 400: 0.0116 Seen so far: 51328 samples Training acc over epoch: 0.9916 Validation acc: 0.9738 Time taken: 1.18s
Start of epoch 10 Training loss at step 0: 0.0438 Seen so far: 128 samples Training loss at step 200: 0.0663 Seen so far: 25728 samples Training loss at step 400: 0.0328 Seen so far: 51328 samples Training acc over epoch: 0.9925 Validation acc: 0.9725 Time taken: 1.17s
Start of epoch 11 Training loss at step 0: 0.0006 Seen so far: 128 samples Training loss at step 200: 0.0269 Seen so far: 25728 samples Training loss at step 400: 0.0058 Seen so far: 51328 samples Training acc over epoch: 0.9919 Validation acc: 0.9719 Time taken: 1.12s
Start of epoch 12 Training loss at step 0: 0.0594 Seen so far: 128 samples Training loss at step 200: 0.0026 Seen so far: 25728 samples Training loss at step 400: 0.0661 Seen so far: 51328 samples Training acc over epoch: 0.9926 Validation acc: 0.9765 Time taken: 1.16s
Start of epoch 13 Training loss at step 0: 0.0003 Seen so far: 128 samples Training loss at step 200: 0.0214 Seen so far: 25728 samples Training loss at step 400: 0.0048 Seen so far: 51328 samples Training acc over epoch: 0.9936 Validation acc: 0.9763 Time taken: 1.24s
Start of epoch 14 Training loss at step 0: 0.0008 Seen so far: 128 samples Training loss at step 200: 0.0078 Seen so far: 25728 samples Training loss at step 400: 0.0009 Seen so far: 51328 samples Training acc over epoch: 0.9941 Validation acc: 0.9746 Time taken: 1.08s
Start of epoch 15 Training loss at step 0: 0.0020 Seen so far: 128 samples Training loss at step 200: 0.0047 Seen so far: 25728 samples Training loss at step 400: 0.0289 Seen so far: 51328 samples Training acc over epoch: 0.9937 Validation acc: 0.9753 Time taken: 1.10s
Start of epoch 16 Training loss at step 0: 0.0063 Seen so far: 128 samples Training loss at step 200: 0.0698 Seen so far: 25728 samples Training loss at step 400: 0.0442 Seen so far: 51328 samples Training acc over epoch: 0.9941 Validation acc: 0.9737 Time taken: 1.08s
Start of epoch 17 Training loss at step 0: 0.0019 Seen so far: 128 samples Training loss at step 200: 0.0066 Seen so far: 25728 samples Training loss at step 400: 0.0168 Seen so far: 51328 samples Training acc over epoch: 0.9945 Validation acc: 0.9745 Time taken: 1.17s
Start of epoch 18 Training loss at step 0: 0.0025 Seen so far: 128 samples Training loss at step 200: 0.0123 Seen so far: 25728 samples Training loss at step 400: 0.0914 Seen so far: 51328 samples Training acc over epoch: 0.9929 Validation acc: 0.9700 Time taken: 1.07s
Start of epoch 19 Training loss at step 0: 0.0029 Seen so far: 128 samples Training loss at step 200: 0.0456 Seen so far: 25728 samples Training loss at step 400: 0.0002 Seen so far: 51328 samples Training acc over epoch: 0.9947 Validation acc: 0.9758 Time taken: 1.15s
Start of epoch 20 Training loss at step 0: 0.0002 Seen so far: 128 samples Training loss at step 200: 0.0008 Seen so far: 25728 samples Training loss at step 400: 0.0010 Seen so far: 51328 samples Training acc over epoch: 0.9966 Validation acc: 0.9771 Time taken: 1.11s
Start of epoch 21 Training loss at step 0: 0.0100 Seen so far: 128 samples Training loss at step 200: 0.0139 Seen so far: 25728 samples Training loss at step 400: 0.0398 Seen so far: 51328 samples Training acc over epoch: 0.9961 Validation acc: 0.9744 Time taken: 1.09s
Start of epoch 22 Training loss at step 0: 0.0000 Seen so far: 128 samples Training loss at step 200: 0.0008 Seen so far: 25728 samples Training loss at step 400: 0.0008 Seen so far: 51328 samples Training acc over epoch: 0.9958 Validation acc: 0.9784 Time taken: 1.05s
Start of epoch 23 Training loss at step 0: 0.0009 Seen so far: 128 samples Training loss at step 200: 0.0419 Seen so far: 25728 samples Training loss at step 400: 0.0020 Seen so far: 51328 samples Training acc over epoch: 0.9959 Validation acc: 0.9764 Time taken: 1.11s
Start of epoch 24 Training loss at step 0: 0.0002 Seen so far: 128 samples Training loss at step 200: 0.0192 Seen so far: 25728 samples Training loss at step 400: 0.0186 Seen so far: 51328 samples Training acc over epoch: 0.9955 Validation acc: 0.9754 Time taken: 1.08s
Start of epoch 25 Training loss at step 0: 0.0009 Seen so far: 128 samples Training loss at step 200: 0.0232 Seen so far: 25728 samples Training loss at step 400: 0.0007 Seen so far: 51328 samples Training acc over epoch: 0.9956 Validation acc: 0.9772 Time taken: 1.09s
Start of epoch 26 Training loss at step 0: 0.0001 Seen so far: 128 samples Training loss at step 200: 0.0427 Seen so far: 25728 samples Training loss at step 400: 0.0034 Seen so far: 51328 samples Training acc over epoch: 0.9946 Validation acc: 0.9785 Time taken: 1.12s
Start of epoch 27 Training loss at step 0: 0.0013 Seen so far: 128 samples Training loss at step 200: 0.0092 Seen so far: 25728 samples Training loss at step 400: 0.0955 Seen so far: 51328 samples Training acc over epoch: 0.9962 Validation acc: 0.9755 Time taken: 1.12s
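The loop above stops early but leaves the model with the weights from the last epoch it ran. Those are not necessarily the weights with the lowest validation loss. A common extension is to snapshot the weights with `model.get_weights()` whenever the validation loss improves, then restore them with `model.set_weights()` after the loop. This is analogous to the `restore_best_weights=True` option of `tf.keras.callbacks.EarlyStopping`. The following is a minimal, self-contained sketch. It assumes a hypothetical helper, `fit_with_early_stopping`, and uses a tiny synthetic dataset in place of MNIST. It also averages per-batch validation losses, which equals the overall mean only when all validation batches have the same size, as they do here.

```python
import numpy as np
import tensorflow as tf


def fit_with_early_stopping(model, optimizer, loss_fn, train_ds, val_ds,
                            epochs=100, patience=5):
    """Hypothetical helper: custom loop with patience and best-weight restore."""
    best = float("inf")
    best_weights = None
    wait = 0
    val_history = []
    for epoch in range(epochs):
        for x, y in train_ds:
            with tf.GradientTape() as tape:
                loss = loss_fn(y, model(x, training=True))
            grads = tape.gradient(loss, model.trainable_weights)
            optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Validation loss for this epoch (mean of equally sized batches).
        val_loss = float(np.mean(
            [loss_fn(y, model(x, training=False)).numpy() for x, y in val_ds]))
        val_history.append(val_loss)

        if val_loss < best:
            best, wait = val_loss, 0
            best_weights = model.get_weights()  # snapshot of the best epoch
        else:
            wait += 1
            if wait >= patience:
                break

    if best_weights is not None:
        model.set_weights(best_weights)  # roll back to the best epoch
    return best, val_history


# Tiny synthetic data standing in for MNIST.
rng = np.random.default_rng(0)
x_all = rng.normal(size=(256, 8)).astype("float32")
y_all = (x_all[:, 0] > 0).astype("int64")
train = tf.data.Dataset.from_tensor_slices((x_all[:192], y_all[:192])).batch(32)
val = tf.data.Dataset.from_tensor_slices((x_all[192:], y_all[192:])).batch(32)

net = tf.keras.Sequential([tf.keras.layers.Dense(16, activation="relu"),
                           tf.keras.layers.Dense(2)])
best_val, val_history = fit_with_early_stopping(
    net, tf.keras.optimizers.Adam(0.01),
    tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    train, val, epochs=30, patience=3)
print(len(val_history), best_val)
```

After the call, evaluating `net` on `val` gives the best validation loss seen, not the loss of the final epoch.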
Next steps
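- Learn more about the Keras built-in early stopping callback API in the API docs.
- Learn to write custom Keras callbacks, including early stopping at a minimum loss.
- Learn about Training and evaluation with the Keras built-in methods.
- Explore common regularization techniques in the Overfit and underfit tutorial, which uses the EarlyStopping callback.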