Migrate early stopping


This notebook demonstrates how to set up model training with early stopping, first in TensorFlow 1 with tf.estimator.Estimator and an early stopping hook, and then in TensorFlow 2 with Keras APIs or a custom training loop. Early stopping is a regularization technique that stops training if, for example, the validation loss reaches a certain threshold.
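Here is a minimal pure-Python sketch of the threshold rule described above. No TensorFlow is needed. The per-epoch validation losses are made-up numbers, and train_until_threshold is a hypothetical helper written for this illustration, not a TensorFlow API:

```python
def train_until_threshold(val_losses, threshold):
    """Return the number of epochs run before stopping.

    `val_losses` stands in for the validation loss measured after each epoch
    (hypothetical values; in real training they come from evaluation).
    Training stops as soon as the validation loss drops to `threshold` or below.
    """
    epochs_run = 0
    for val_loss in val_losses:
        epochs_run += 1  # One (simulated) epoch of training plus evaluation.
        if val_loss <= threshold:
            break  # Early stop: the target validation loss has been reached.
    return epochs_run


print(train_until_threshold([0.9, 0.5, 0.3, 0.2, 0.1], threshold=0.25))  # 4
```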

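In TensorFlow 2, there are three ways to implement early stopping:

- Use a built-in Keras callback, tf.keras.callbacks.EarlyStopping, and pass it to Model.fit.
- Define a custom callback and pass it to Keras Model.fit.
- Write a custom early stopping rule in a custom training loop (with tf.GradientTape).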

Setup

import time
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds

TensorFlow 1: Early stopping with an early stopping hook and tf.estimator

Start by defining functions that load and preprocess the MNIST dataset, and a model definition to be used with tf.estimator.Estimator:

def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label

def _input_fn():
  ds_train = tfds.load(
    name='mnist',
    split='train',
    shuffle_files=True,
    as_supervised=True)

  ds_train = ds_train.map(
      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_train = ds_train.batch(128)
  ds_train = ds_train.repeat(100)
  return ds_train

def _eval_input_fn():
  ds_test = tfds.load(
    name='mnist',
    split='test',
    shuffle_files=True,
    as_supervised=True)
  ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_test = ds_test.batch(128)
  return ds_test

def _model_fn(features, labels, mode):
  flatten = tf1.layers.Flatten()(features)
  features = tf1.layers.Dense(128, 'relu')(flatten)
  logits = tf1.layers.Dense(10)(features)

  loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf1.train.AdagradOptimizer(0.005)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())

  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

In TensorFlow 1, early stopping works by setting up an early stopping hook with tf.estimator.experimental.make_early_stopping_hook. You pass a function that takes no arguments to make_early_stopping_hook as its should_stop_fn parameter, and then pass the resulting hook to tf.estimator.TrainSpec through its hooks argument. Training stops once should_stop_fn returns True.
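The hook's contract can be modeled in a few lines of plain Python. The sketch below is a simplification under stated assumptions, not TensorFlow's implementation: the hypothetical PredicateStopper class calls a zero-argument should_stop_fn at most once every run_every_secs seconds (mirroring the run_every_secs argument used in the example that follows) and keeps reporting True once the predicate has fired. FakeClock is also hypothetical; the clock is injectable so the behavior can be checked without waiting:

```python
import time


class PredicateStopper:
    """Hypothetical, simplified model of an early stopping hook (not TF's code).

    Calls `should_stop_fn` (a function with no arguments) at most once every
    `run_every_secs` seconds and keeps reporting True once it has fired.
    """

    def __init__(self, should_stop_fn, run_every_secs=1.0, clock=time.monotonic):
        self.should_stop_fn = should_stop_fn
        self.run_every_secs = run_every_secs
        self.clock = clock
        self.last_check = None
        self.stop_requested = False

    def after_step(self):
        """Call after each training step; returns True once training should stop."""
        now = self.clock()
        due = self.last_check is None or now - self.last_check >= self.run_every_secs
        if due and not self.stop_requested:
            self.last_check = now
            if self.should_stop_fn():
                self.stop_requested = True
        return self.stop_requested


class FakeClock:
    """Test clock that advances by `step` seconds on every reading."""

    def __init__(self, step):
        self.t = 0.0
        self.step = step

    def __call__(self):
        self.t += self.step
        return self.t


calls = []


def should_stop_fn():
    calls.append(1)
    return len(calls) >= 3  # Pretend the stopping condition holds on the 3rd check.


stopper = PredicateStopper(should_stop_fn, run_every_secs=1.0, clock=FakeClock(0.25))
steps = 0
while True:
    steps += 1  # One simulated training step.
    if stopper.after_step():
        break
print(steps, len(calls))  # 9 3
```

With steps of 0.25 simulated seconds and run_every_secs=1.0, the predicate is evaluated only on steps 1, 5 and 9, so a cheap time-based check does not need to run on every step.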

The following example demonstrates how to implement an early stopping technique that limits the training time to a maximum of 20 seconds:

estimator = tf1.estimator.Estimator(model_fn=_model_fn)

start_time = time.time()
max_train_seconds = 20

def should_stop_fn():
  return time.time() - start_time > max_train_seconds

early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(
    estimator=estimator,
    should_stop_fn=should_stop_fn,
    run_every_secs=1,
    run_every_steps=None)

train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[early_stopping_hook])

eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)

tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpocmc6_bo
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpocmc6_bo', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/adagrad.py:77: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpocmc6_bo/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.3545606, step = 0
INFO:tensorflow:global_step/sec: 94.5711
INFO:tensorflow:loss = 1.3383636, step = 100 (1.060 sec)
INFO:tensorflow:global_step/sec: 158.428
INFO:tensorflow:loss = 0.7937969, step = 200 (0.631 sec)
INFO:tensorflow:global_step/sec: 287.334
INFO:tensorflow:loss = 0.69060934, step = 300 (0.349 sec)
INFO:tensorflow:global_step/sec: 286.658
INFO:tensorflow:loss = 0.59314424, step = 400 (0.349 sec)
INFO:tensorflow:global_step/sec: 311.591
INFO:tensorflow:loss = 0.50495726, step = 500 (0.320 sec)
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 536 vs previous value: 536. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:global_step/sec: 538.395
INFO:tensorflow:loss = 0.43083754, step = 600 (0.186 sec)
INFO:tensorflow:global_step/sec: 503.72
INFO:tensorflow:loss = 0.381118, step = 700 (0.198 sec)
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 715 vs previous value: 715. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:global_step/sec: 482.019
INFO:tensorflow:loss = 0.49349022, step = 800 (0.207 sec)
INFO:tensorflow:global_step/sec: 508.316
INFO:tensorflow:loss = 0.38730466, step = 900 (0.199 sec)
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 987 vs previous value: 987. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:global_step/sec: 452.89
INFO:tensorflow:loss = 0.44916487, step = 1000 (0.219 sec)
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 1042 vs previous value: 1042. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:global_step/sec: 519.401
INFO:tensorflow:loss = 0.44320562, step = 1100 (0.192 sec)
INFO:tensorflow:global_step/sec: 510.25
INFO:tensorflow:loss = 0.3758085, step = 1200 (0.196 sec)
INFO:tensorflow:global_step/sec: 518.649
INFO:tensorflow:loss = 0.46760654, step = 1300 (0.193 sec)
INFO:tensorflow:global_step/sec: 474.056
INFO:tensorflow:loss = 0.29544568, step = 1400 (0.211 sec)
INFO:tensorflow:global_step/sec: 461.406
INFO:tensorflow:loss = 0.28616875, step = 1500 (0.217 sec)
INFO:tensorflow:global_step/sec: 486.2
INFO:tensorflow:loss = 0.4114887, step = 1600 (0.206 sec)
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 1678 vs previous value: 1678. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:global_step/sec: 507.701
INFO:tensorflow:loss = 0.35298553, step = 1700 (0.197 sec)
INFO:tensorflow:global_step/sec: 490.541
INFO:tensorflow:loss = 0.3363277, step = 1800 (0.204 sec)
INFO:tensorflow:global_step/sec: 460.083
INFO:tensorflow:loss = 0.50634325, step = 1900 (0.217 sec)
INFO:tensorflow:global_step/sec: 436.782
INFO:tensorflow:loss = 0.2063987, step = 2000 (0.229 sec)
INFO:tensorflow:global_step/sec: 475.841
INFO:tensorflow:loss = 0.27246287, step = 2100 (0.210 sec)
INFO:tensorflow:global_step/sec: 483.322
INFO:tensorflow:loss = 0.31674564, step = 2200 (0.207 sec)
INFO:tensorflow:global_step/sec: 442.257
INFO:tensorflow:loss = 0.3334998, step = 2300 (0.226 sec)
INFO:tensorflow:global_step/sec: 476.38
INFO:tensorflow:loss = 0.2549953, step = 2400 (0.210 sec)
INFO:tensorflow:global_step/sec: 467.543
INFO:tensorflow:loss = 0.21111101, step = 2500 (0.214 sec)
INFO:tensorflow:global_step/sec: 497.051
INFO:tensorflow:loss = 0.15878338, step = 2600 (0.201 sec)
INFO:tensorflow:global_step/sec: 461.785
INFO:tensorflow:loss = 0.31587577, step = 2700 (0.219 sec)
INFO:tensorflow:global_step/sec: 493.743
INFO:tensorflow:loss = 0.47478187, step = 2800 (0.200 sec)
INFO:tensorflow:global_step/sec: 463.477
INFO:tensorflow:loss = 0.2499526, step = 2900 (0.216 sec)
INFO:tensorflow:global_step/sec: 538.27
INFO:tensorflow:loss = 0.34210858, step = 3000 (0.186 sec)
INFO:tensorflow:global_step/sec: 508.741
INFO:tensorflow:loss = 0.2128592, step = 3100 (0.197 sec)
INFO:tensorflow:global_step/sec: 519.319
INFO:tensorflow:loss = 0.40954083, step = 3200 (0.192 sec)
INFO:tensorflow:global_step/sec: 468.989
INFO:tensorflow:loss = 0.34270883, step = 3300 (0.213 sec)
INFO:tensorflow:global_step/sec: 479.856
INFO:tensorflow:loss = 0.26599607, step = 3400 (0.209 sec)
INFO:tensorflow:global_step/sec: 495.76
INFO:tensorflow:loss = 0.21713805, step = 3500 (0.201 sec)
INFO:tensorflow:global_step/sec: 440.282
INFO:tensorflow:loss = 0.22268976, step = 3600 (0.228 sec)
INFO:tensorflow:global_step/sec: 495.629
INFO:tensorflow:loss = 0.28974164, step = 3700 (0.201 sec)
INFO:tensorflow:global_step/sec: 468.695
INFO:tensorflow:loss = 0.37919793, step = 3800 (0.214 sec)
INFO:tensorflow:global_step/sec: 529.005
INFO:tensorflow:loss = 0.23738712, step = 3900 (0.189 sec)
INFO:tensorflow:global_step/sec: 494.809
INFO:tensorflow:loss = 0.29650036, step = 4000 (0.204 sec)
INFO:tensorflow:global_step/sec: 525.629
INFO:tensorflow:loss = 0.20826155, step = 4100 (0.188 sec)
INFO:tensorflow:global_step/sec: 509.573
INFO:tensorflow:loss = 0.26417816, step = 4200 (0.196 sec)
INFO:tensorflow:global_step/sec: 472.845
INFO:tensorflow:loss = 0.31241363, step = 4300 (0.212 sec)
INFO:tensorflow:global_step/sec: 510.868
INFO:tensorflow:loss = 0.32773697, step = 4400 (0.195 sec)
INFO:tensorflow:global_step/sec: 492.967
INFO:tensorflow:loss = 0.28609803, step = 4500 (0.203 sec)
INFO:tensorflow:global_step/sec: 507.394
INFO:tensorflow:loss = 0.32142323, step = 4600 (0.197 sec)
INFO:tensorflow:global_step/sec: 475.176
INFO:tensorflow:loss = 0.14882785, step = 4700 (0.211 sec)
INFO:tensorflow:global_step/sec: 503.718
INFO:tensorflow:loss = 0.312344, step = 4800 (0.198 sec)
INFO:tensorflow:global_step/sec: 497.659
INFO:tensorflow:loss = 0.37370217, step = 4900 (0.201 sec)
INFO:tensorflow:global_step/sec: 477.736
INFO:tensorflow:loss = 0.2663591, step = 5000 (0.209 sec)
INFO:tensorflow:global_step/sec: 496.559
INFO:tensorflow:loss = 0.34745598, step = 5100 (0.202 sec)
INFO:tensorflow:global_step/sec: 475.989
INFO:tensorflow:loss = 0.21809828, step = 5200 (0.210 sec)
INFO:tensorflow:global_step/sec: 474.464
INFO:tensorflow:loss = 0.2474105, step = 5300 (0.211 sec)
INFO:tensorflow:global_step/sec: 488.774
INFO:tensorflow:loss = 0.1611641, step = 5400 (0.204 sec)
INFO:tensorflow:global_step/sec: 504.942
INFO:tensorflow:loss = 0.2306528, step = 5500 (0.198 sec)
INFO:tensorflow:global_step/sec: 514.058
INFO:tensorflow:loss = 0.20716992, step = 5600 (0.195 sec)
INFO:tensorflow:global_step/sec: 458.899
INFO:tensorflow:loss = 0.16730343, step = 5700 (0.217 sec)
INFO:tensorflow:global_step/sec: 495.197
INFO:tensorflow:loss = 0.2906361, step = 5800 (0.202 sec)
INFO:tensorflow:global_step/sec: 482.244
INFO:tensorflow:loss = 0.24669808, step = 5900 (0.207 sec)
INFO:tensorflow:global_step/sec: 484.946
INFO:tensorflow:loss = 0.26403594, step = 6000 (0.207 sec)
INFO:tensorflow:global_step/sec: 486.74
INFO:tensorflow:loss = 0.19804293, step = 6100 (0.206 sec)
INFO:tensorflow:global_step/sec: 436.727
INFO:tensorflow:loss = 0.25344175, step = 6200 (0.229 sec)
INFO:tensorflow:global_step/sec: 428.73
INFO:tensorflow:loss = 0.2430937, step = 6300 (0.232 sec)
INFO:tensorflow:global_step/sec: 449.706
INFO:tensorflow:loss = 0.2842306, step = 6400 (0.222 sec)
INFO:tensorflow:global_step/sec: 440.873
INFO:tensorflow:loss = 0.2641199, step = 6500 (0.227 sec)
INFO:tensorflow:global_step/sec: 424.092
INFO:tensorflow:loss = 0.19028814, step = 6600 (0.237 sec)
INFO:tensorflow:global_step/sec: 450.352
INFO:tensorflow:loss = 0.24667627, step = 6700 (0.221 sec)
INFO:tensorflow:global_step/sec: 462.774
INFO:tensorflow:loss = 0.40046322, step = 6800 (0.216 sec)
INFO:tensorflow:global_step/sec: 460.854
INFO:tensorflow:loss = 0.14105138, step = 6900 (0.217 sec)
INFO:tensorflow:Requesting early stopping at global step 6916
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6917...
INFO:tensorflow:Saving checkpoints for 6917 into /tmp/tmpocmc6_bo/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6917...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-09-22T20:07:35
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpocmc6_bo/model.ckpt-6917
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Inference Time : 0.79520s
INFO:tensorflow:Finished evaluation at 2021-09-22-20:07:36
INFO:tensorflow:Saving dict for global step 6917: global_step = 6917, loss = 0.227278
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 6917: /tmp/tmpocmc6_bo/model.ckpt-6917
INFO:tensorflow:Loss for final step: 0.13882703.
({'loss': 0.227278, 'global_step': 6917}, [])
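One detail of the example above: start_time is captured when the cell runs, so any time spent before the first training step (building the graph, setting up the input pipeline) also counts against the 20-second budget. As a hedged alternative, the hypothetical factory below starts the timer on the first call of should_stop_fn, assuming that call happens around the start of training, and uses time.monotonic, which is not affected by system clock changes. The returned function still takes no arguments, so it could be passed as should_stop_fn to make_early_stopping_hook:

```python
import time


def make_time_limit_fn(max_train_seconds, clock=time.monotonic):
    """Return a zero-argument should_stop_fn (hypothetical helper).

    The timer starts on the first call, and the function returns True once
    more than `max_train_seconds` have passed since then.
    """
    start = None

    def should_stop_fn():
        nonlocal start
        now = clock()
        if start is None:
            start = now  # Start timing on the first check, not at definition time.
        return now - start > max_train_seconds

    return should_stop_fn


# Fake clock readings (in seconds) so the example runs instantly.
readings = iter([100.0, 105.0, 121.0])
should_stop_fn = make_time_limit_fn(20, clock=lambda: next(readings))
print([should_stop_fn() for _ in range(3)])  # [False, False, True]
```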

TensorFlow 2: Early stopping with a built-in callback and Model.fit

Prepare the MNIST dataset and a simple Keras model:

(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(128)

ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.005),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

In TensorFlow 2, when you use the built-in Keras Model.fit, you can configure early stopping by passing a built-in callback, tf.keras.callbacks.EarlyStopping, to the callbacks parameter of Model.fit.

The EarlyStopping callback monitors a user-specified metric and ends training when that metric stops improving. (Check the Training and evaluation with the built-in methods guide or the API docs for more information.)
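The patience rule can be sketched in plain Python. The function below is a simplified, hypothetical re-creation for illustration, not the Keras source; the real tf.keras.callbacks.EarlyStopping also handles options such as mode, baseline and restore_best_weights. Replaying the per-epoch training loss printed by the Model.fit run in the next example reproduces the 21 epochs that run actually lasted:

```python
def epochs_until_early_stop(history, patience, min_delta=0.0):
    """Simplified patience rule for a metric that should decrease (e.g. 'loss').

    Returns the number of epochs that would run. An epoch counts as an
    improvement only if the value is lower than the best value so far by more
    than `min_delta`; after `patience` epochs in a row without improvement,
    training stops.
    """
    best = float("inf")
    wait = 0
    for epoch, value in enumerate(history, start=1):
        if value < best - min_delta:
            best = value  # New best value: reset the patience counter.
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch  # Stop after this epoch.
    return len(history)  # The rule never fired.


# Per-epoch training loss printed by the Model.fit run in this guide.
losses = [0.2371, 0.1028, 0.0703, 0.0552, 0.0420, 0.0387, 0.0321, 0.0285,
          0.0263, 0.0250, 0.0274, 0.0241, 0.0205, 0.0196, 0.0214, 0.0178,
          0.0177, 0.0141, 0.0188, 0.0190, 0.0153]
print(epochs_until_early_stop(losses, patience=3))  # 21
```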

Below is an example of an early stopping callback that monitors the loss and stops training once the number of consecutive epochs with no improvement reaches 3 (patience):

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)

# Only around 25 epochs are run during training, instead of 100.
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)

len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 5s 8ms/step - loss: 0.2371 - sparse_categorical_accuracy: 0.9293 - val_loss: 0.1334 - val_sparse_categorical_accuracy: 0.9611
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.1028 - sparse_categorical_accuracy: 0.9686 - val_loss: 0.1062 - val_sparse_categorical_accuracy: 0.9667
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0703 - sparse_categorical_accuracy: 0.9783 - val_loss: 0.0993 - val_sparse_categorical_accuracy: 0.9707
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0552 - sparse_categorical_accuracy: 0.9822 - val_loss: 0.1040 - val_sparse_categorical_accuracy: 0.9680
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0420 - sparse_categorical_accuracy: 0.9865 - val_loss: 0.1033 - val_sparse_categorical_accuracy: 0.9716
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0387 - sparse_categorical_accuracy: 0.9871 - val_loss: 0.1167 - val_sparse_categorical_accuracy: 0.9691
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0321 - sparse_categorical_accuracy: 0.9893 - val_loss: 0.1396 - val_sparse_categorical_accuracy: 0.9672
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0285 - sparse_categorical_accuracy: 0.9902 - val_loss: 0.1397 - val_sparse_categorical_accuracy: 0.9671
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0263 - sparse_categorical_accuracy: 0.9915 - val_loss: 0.1296 - val_sparse_categorical_accuracy: 0.9715
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0250 - sparse_categorical_accuracy: 0.9915 - val_loss: 0.1440 - val_sparse_categorical_accuracy: 0.9715
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0274 - sparse_categorical_accuracy: 0.9910 - val_loss: 0.1439 - val_sparse_categorical_accuracy: 0.9710
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0241 - sparse_categorical_accuracy: 0.9923 - val_loss: 0.1429 - val_sparse_categorical_accuracy: 0.9718
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0205 - sparse_categorical_accuracy: 0.9929 - val_loss: 0.1451 - val_sparse_categorical_accuracy: 0.9753
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0196 - sparse_categorical_accuracy: 0.9936 - val_loss: 0.1562 - val_sparse_categorical_accuracy: 0.9750
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0214 - sparse_categorical_accuracy: 0.9930 - val_loss: 0.1531 - val_sparse_categorical_accuracy: 0.9748
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0178 - sparse_categorical_accuracy: 0.9941 - val_loss: 0.1712 - val_sparse_categorical_accuracy: 0.9731
Epoch 17/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0177 - sparse_categorical_accuracy: 0.9947 - val_loss: 0.1715 - val_sparse_categorical_accuracy: 0.9755
Epoch 18/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1826 - val_sparse_categorical_accuracy: 0.9730
Epoch 19/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0188 - sparse_categorical_accuracy: 0.9942 - val_loss: 0.1919 - val_sparse_categorical_accuracy: 0.9732
Epoch 20/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0190 - sparse_categorical_accuracy: 0.9944 - val_loss: 0.1703 - val_sparse_categorical_accuracy: 0.9777
Epoch 21/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0153 - sparse_categorical_accuracy: 0.9951 - val_loss: 0.1725 - val_sparse_categorical_accuracy: 0.9764
21
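Note that this example monitors the training loss (monitor='loss'). In the output above, the validation loss is lowest at epoch 3 and mostly rises afterwards, a typical sign of overfitting. As a rough check, the sketch below replays the printed val_loss values through the same simplified patience rule to see where monitor='val_loss' with patience=3 would have stopped. This only replays logged numbers; an actual rerun would not reproduce them exactly, because the run is not deterministic:

```python
# Per-epoch validation loss, copied from the Model.fit output above.
val_losses = [0.1334, 0.1062, 0.0993, 0.1040, 0.1033, 0.1167, 0.1396, 0.1397,
              0.1296, 0.1440, 0.1439, 0.1429, 0.1451, 0.1562, 0.1531, 0.1712,
              0.1715, 0.1826, 0.1919, 0.1703, 0.1725]

# 1-based epoch with the lowest validation loss.
best_epoch = min(range(len(val_losses)), key=val_losses.__getitem__) + 1

# Simplified patience rule (patience=3, min_delta=0) applied to val_loss.
best, wait, stop_epoch = float("inf"), 0, len(val_losses)
for epoch, value in enumerate(val_losses, start=1):
    if value < best:
        best, wait = value, 0
    else:
        wait += 1
        if wait >= 3:
            stop_epoch = epoch
            break

print(best_epoch, stop_epoch)  # 3 6
```

If you monitor the validation loss instead, tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True) additionally restores the weights from the best epoch when early stopping is triggered.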

TensorFlow 2: Early stopping with a custom callback and Model.fit

You can also implement a custom early stopping callback, which can likewise be passed to the callbacks parameter of Model.fit.
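To see why setting a flag is enough to stop training, here is a minimal pure-Python sketch of a training loop that checks a stop_training attribute after every batch. TinyModel, Callback and StopAfterBatches are hypothetical stand-ins written for this illustration, not Keras classes; Keras' actual fit loop does more bookkeeping (validation, epoch-end callbacks, history), and this sketch only shows the flag check:

```python
class Callback:
    """Hypothetical minimal callback base class (not the Keras one)."""

    def set_model(self, model):
        self.model = model

    def on_train_begin(self, logs=None):
        pass

    def on_train_batch_end(self, batch, logs=None):
        pass


class TinyModel:
    """Hypothetical stand-in for a model with a `stop_training` flag."""

    def __init__(self):
        self.stop_training = False

    def fit(self, num_batches, epochs, callbacks=()):
        """Run simulated batches; return the number of epochs started."""
        self.stop_training = False  # Reset the flag at the start of every fit.
        for cb in callbacks:
            cb.set_model(self)
            cb.on_train_begin()
        epochs_started = 0
        for _ in range(epochs):
            epochs_started += 1
            for batch in range(num_batches):
                # A real training step would update the weights here.
                for cb in callbacks:
                    cb.on_train_batch_end(batch)
                if self.stop_training:
                    return epochs_started  # A callback asked to stop.
        return epochs_started


class StopAfterBatches(Callback):
    """Sets model.stop_training once `max_batches` batches have run in total."""

    def __init__(self, max_batches):
        self.max_batches = max_batches
        self.seen = 0

    def on_train_begin(self, logs=None):
        self.seen = 0

    def on_train_batch_end(self, batch, logs=None):
        self.seen += 1
        if self.seen >= self.max_batches:
            self.model.stop_training = True


model = TinyModel()
# 469 batches per epoch, as in the MNIST run; 1000 batches end inside epoch 3.
print(model.fit(num_batches=469, epochs=100, callbacks=[StopAfterBatches(1000)]))  # 3
```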

In this example, the training process is stopped once self.model.stop_training is set to True:

class LimitTrainingTime(tf.keras.callbacks.Callback):
  def __init__(self, max_time_s):
    super().__init__()
    self.max_time_s = max_time_s
    self.start_time = None

  def on_train_begin(self, logs=None):
    # Record when training starts.
    self.start_time = time.time()

  def on_train_batch_end(self, batch, logs=None):
    # Ask Keras to stop training once the time budget is used up.
    now = time.time()
    if now - self.start_time > self.max_time_s:
      self.model.stop_training = True

# Limit the training time to 30 seconds.
callback = LimitTrainingTime(30)
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0131 - sparse_categorical_accuracy: 0.9961 - val_loss: 0.1911 - val_sparse_categorical_accuracy: 0.9749
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0133 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.1999 - val_sparse_categorical_accuracy: 0.9755
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0153 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1927 - val_sparse_categorical_accuracy: 0.9770
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0145 - sparse_categorical_accuracy: 0.9957 - val_loss: 0.2279 - val_sparse_categorical_accuracy: 0.9753
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9959 - val_loss: 0.2272 - val_sparse_categorical_accuracy: 0.9755
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0132 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.2352 - val_sparse_categorical_accuracy: 0.9747
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0144 - sparse_categorical_accuracy: 0.9960 - val_loss: 0.2421 - val_sparse_categorical_accuracy: 0.9734
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0128 - sparse_categorical_accuracy: 0.9964 - val_loss: 0.2260 - val_sparse_categorical_accuracy: 0.9785
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0129 - sparse_categorical_accuracy: 0.9965 - val_loss: 0.2472 - val_sparse_categorical_accuracy: 0.9752
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0143 - sparse_categorical_accuracy: 0.9961 - val_loss: 0.2166 - val_sparse_categorical_accuracy: 0.9768
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0145 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2289 - val_sparse_categorical_accuracy: 0.9781
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0119 - sparse_categorical_accuracy: 0.9968 - val_loss: 0.2310 - val_sparse_categorical_accuracy: 0.9777
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0144 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.2617 - val_sparse_categorical_accuracy: 0.9781
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0119 - sparse_categorical_accuracy: 0.9972 - val_loss: 0.3007 - val_sparse_categorical_accuracy: 0.9754
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0150 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.3014 - val_sparse_categorical_accuracy: 0.9767
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0143 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2815 - val_sparse_categorical_accuracy: 0.9750
Epoch 17/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0129 - sparse_categorical_accuracy: 0.9967 - val_loss: 0.2606 - val_sparse_categorical_accuracy: 0.9765
Epoch 18/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0103 - sparse_categorical_accuracy: 0.9975 - val_loss: 0.2602 - val_sparse_categorical_accuracy: 0.9777
Epoch 19/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0098 - sparse_categorical_accuracy: 0.9979 - val_loss: 0.2594 - val_sparse_categorical_accuracy: 0.9780
Epoch 20/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0156 - sparse_categorical_accuracy: 0.9965 - val_loss: 0.3008 - val_sparse_categorical_accuracy: 0.9755
Epoch 21/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0110 - sparse_categorical_accuracy: 0.9974 - val_loss: 0.2662 - val_sparse_categorical_accuracy: 0.9765
Epoch 22/100
469/469 [==============================] - 1s 1ms/step - loss: 0.0083 - sparse_categorical_accuracy: 0.9978 - val_loss: 0.2587 - val_sparse_categorical_accuracy: 0.9797
22
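The same stop_training mechanism can express other stopping rules, for example stopping once the validation loss falls below a chosen threshold. The following is a minimal sketch, not a built-in Keras class: StopAtThreshold and its monitor and threshold arguments are hypothetical names. It checks the metric in on_epoch_end, where Keras passes that epoch's logs (including val_ entries when validation_data is given).

```python
import tensorflow as tf


class StopAtThreshold(tf.keras.callbacks.Callback):
  """Hypothetical callback: stop training once `monitor` drops below `threshold`."""

  def __init__(self, monitor='val_loss', threshold=0.1):
    super().__init__()
    self.monitor = monitor
    self.threshold = threshold

  def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}
    value = logs.get(self.monitor)
    if value is None:
      # The metric is missing, e.g. when fit() was called without validation_data.
      return
    if value < self.threshold:
      print(f"Epoch {epoch + 1}: {self.monitor}={value:.4f} "
            f"is below {self.threshold}, stopping.")
      # Keras checks this flag and ends fit() after the current epoch.
      self.model.stop_training = True
```

It would be passed like any other callback, e.g. model.fit(ds_train, epochs=100, validation_data=ds_test, callbacks=[StopAtThreshold('val_loss', 0.08)]); the threshold value here is arbitrary.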

TensorFlow 2: Early stopping with a custom training loop

In TensorFlow 2, you can implement early stopping in a custom training loop if you are not training and evaluating with the built-in Keras methods.

Start by using the Keras APIs to define another simple model, an optimizer, a loss function, and metrics:

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.005)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

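# The model outputs logits, so the cross-entropy metrics need from_logits=True
# (matching loss_fn); the accuracy metrics take the argmax and work on logits directly.
train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)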

Define the training and evaluation step functions with tf.GradientTape, and wrap them in the @tf.function decorator for a speedup:

@tf.function
def train_step(x, y):
  # Forward pass under the tape so the loss can be differentiated.
  with tf.GradientTape() as tape:
    logits = model(x, training=True)
    loss_value = loss_fn(y, logits)
  # Backward pass and weight update.
  grads = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))
  # Accumulate the training metrics for the current epoch.
  train_acc_metric.update_state(y, logits)
  train_loss_metric.update_state(y, logits)
  return loss_value

@tf.function
def test_step(x, y):
  # Inference only: no gradients, and layers run in inference mode.
  logits = model(x, training=False)
  val_acc_metric.update_state(y, logits)
  val_loss_metric.update_state(y, logits)
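In train_step, tape.gradient returns the gradient of the loss with respect to each trainable weight, and apply_gradients moves the weights against it. A minimal standalone illustration with a single scalar variable; the plain gradient-descent update and its learning rate of 0.25 are illustrative stand-ins for the optimizer, not part of the tutorial's model:

```python
import tensorflow as tf

w = tf.Variable(3.0)

# Record the computation so the tape can differentiate it.
with tf.GradientTape() as tape:
  loss = w * w

# d(w**2)/dw at w = 3.0 is 2 * 3.0 = 6.0.
grad = tape.gradient(loss, w)
print(grad.numpy())  # 6.0

# One manual gradient-descent step: w <- w - 0.25 * grad = 3.0 - 1.5.
w.assign_sub(0.25 * grad)
print(w.numpy())  # 1.5
```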

Next, write a custom training loop, in which you implement the early stopping rule by hand.

The example below shows how to stop training when the validation loss has not decreased for patience consecutive epochs:

epochs = 100
patience = 5
wait = 0
best = float('inf')  # Lowest validation loss seen so far.

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Training: one pass over the training set.
    for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
      loss_value = train_step(x_batch_train, y_batch_train)
      if step % 200 == 0:
        print("Training loss at step %d: %.4f" % (step, loss_value.numpy()))
        print("Seen so far: %s samples" % ((step + 1) * 128))
    train_acc = train_acc_metric.result()
    train_loss = train_loss_metric.result()
    train_acc_metric.reset_states()
    train_loss_metric.reset_states()
    print("Training acc over epoch: %.4f" % (train_acc.numpy()))

    # Validation: one pass over the test set.
    for x_batch_val, y_batch_val in ds_test:
      test_step(x_batch_val, y_batch_val)
    val_acc = val_acc_metric.result()
    val_loss = val_loss_metric.result()
    val_acc_metric.reset_states()
    val_loss_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))

    # The early stopping strategy: stop the training if `val_loss` does not
    # decrease over a certain number of epochs.
    wait += 1
    if val_loss < best:
      best = val_loss
      wait = 0
    if wait >= patience:
      break
Start of epoch 0
Training loss at step 0: 2.3073
Seen so far: 128 samples
Training loss at step 200: 0.2164
Seen so far: 25728 samples
Training loss at step 400: 0.2186
Seen so far: 51328 samples
Training acc over epoch: 0.9321
Validation acc: 0.9644
Time taken: 1.73s

Start of epoch 1
Training loss at step 0: 0.0733
Seen so far: 128 samples
Training loss at step 200: 0.1581
Seen so far: 25728 samples
Training loss at step 400: 0.1625
Seen so far: 51328 samples
Training acc over epoch: 0.9704
Validation acc: 0.9681
Time taken: 1.23s

Start of epoch 2
Training loss at step 0: 0.0501
Seen so far: 128 samples
Training loss at step 200: 0.1389
Seen so far: 25728 samples
Training loss at step 400: 0.1495
Seen so far: 51328 samples
Training acc over epoch: 0.9779
Validation acc: 0.9703
Time taken: 1.17s

Start of epoch 3
Training loss at step 0: 0.0513
Seen so far: 128 samples
Training loss at step 200: 0.0638
Seen so far: 25728 samples
Training loss at step 400: 0.0930
Seen so far: 51328 samples
Training acc over epoch: 0.9830
Validation acc: 0.9719
Time taken: 1.20s

Start of epoch 4
Training loss at step 0: 0.0251
Seen so far: 128 samples
Training loss at step 200: 0.0482
Seen so far: 25728 samples
Training loss at step 400: 0.0872
Seen so far: 51328 samples
Training acc over epoch: 0.9849
Validation acc: 0.9672
Time taken: 1.18s

Start of epoch 5
Training loss at step 0: 0.0417
Seen so far: 128 samples
Training loss at step 200: 0.0302
Seen so far: 25728 samples
Training loss at step 400: 0.0362
Seen so far: 51328 samples
Training acc over epoch: 0.9878
Validation acc: 0.9703
Time taken: 1.21s
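The rule stated in the loop's comment (stop when val_loss has not decreased for patience epochs) can also be factored into a small, framework-independent helper, which makes it easy to unit-test. This is a sketch under stated assumptions: the PatienceStopper class is hypothetical, and its min_delta parameter borrows the idea of a minimum required improvement from tf.keras.callbacks.EarlyStopping. The loss values in the usage part are made up.

```python
class PatienceStopper:
  """Hypothetical helper: signals a stop when the monitored loss has not
  improved by more than `min_delta` for `patience` consecutive epochs."""

  def __init__(self, patience=5, min_delta=0.0):
    self.patience = patience
    self.min_delta = min_delta
    self.best = float('inf')
    self.best_epoch = None
    self.wait = 0

  def update(self, epoch, loss):
    """Records `loss` for `epoch`; returns True if training should stop."""
    if loss < self.best - self.min_delta:
      # Improvement: remember it and reset the patience counter.
      self.best = loss
      self.best_epoch = epoch
      self.wait = 0
    else:
      self.wait += 1
    return self.wait >= self.patience


# Usage with made-up validation losses: the best value is at epoch 2,
# and the next three epochs do not improve on it.
stopper = PatienceStopper(patience=3)
for epoch, val_loss in enumerate([0.30, 0.25, 0.20, 0.21, 0.22, 0.23, 0.10]):
  if stopper.update(epoch, val_loss):
    print(f"Stopping at epoch {epoch}; best epoch was {stopper.best_epoch} "
          f"with loss {stopper.best:.2f}")
    break
```

Inside the custom loop, the helper would replace the wait/best bookkeeping with if stopper.update(epoch, float(val_loss)): break. To also keep the best model, one option is to store model.get_weights() whenever stopper.best_epoch == epoch and restore it with model.set_weights(...) after the loop.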

Next steps