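Fault tolerance refers to a mechanism of periodically saving the states of trackable objects, such as parameters and models. This enables you to recover them in the event of a program or machine failure during training.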
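The idea is not specific to TensorFlow. The following minimal pure-Python sketch illustrates it. The function `train`, the JSON file, and the toy state (a step counter and one scalar "weight") are all hypothetical. The loop saves its state every few steps, a crash is simulated partway through, and the restarted run resumes from the last saved state instead of starting over.

```python
import json
import os
import tempfile


def train(ckpt_path, total_steps, save_every, fail_at=None):
  """Toy training loop that checkpoints its state every `save_every` steps.

  Raises RuntimeError at step `fail_at` to simulate a crash.
  """
  # Resume from the last checkpoint if one exists; otherwise start fresh.
  if os.path.exists(ckpt_path):
    with open(ckpt_path) as f:
      state = json.load(f)
  else:
    state = {"step": 0, "weight": 0.0}

  while state["step"] < total_steps:
    # One "training step": update the toy parameter and the step counter.
    state["weight"] += 0.5
    state["step"] += 1
    if state["step"] % save_every == 0:
      # Write to a temporary file, then rename it, so that a crash
      # mid-write cannot corrupt the previous checkpoint.
      tmp_path = ckpt_path + ".tmp"
      with open(tmp_path, "w") as f:
        json.dump(state, f)
      os.replace(tmp_path, ckpt_path)
    if state["step"] == fail_at:
      raise RuntimeError("Interruption")
  return state


ckpt = os.path.join(tempfile.mkdtemp(), "state.json")
try:
  train(ckpt, total_steps=10, save_every=3, fail_at=7)
except RuntimeError as e:
  print("crashed:", e)

# The checkpoint from step 6 survives, so the restarted run redoes only steps 7-10.
final = train(ckpt, total_steps=10, save_every=3)
print(final)  # {'step': 10, 'weight': 5.0}
```

Work done after the last checkpoint (step 7 here) is lost and repeated. That is why the save frequency matters in the TensorFlow examples below.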
This guide first demonstrates how to add fault tolerance to training with tf.estimator.Estimator in TensorFlow 1, by specifying checkpoint saving with tf.estimator.RunConfig. You will then learn two ways to implement fault tolerance for training in TensorFlow 2:
- If you use the Keras Model.fit API, you can pass it the tf.keras.callbacks.BackupAndRestore callback.
- If you use a custom training loop (with tf.GradientTape), you can save checkpoints at arbitrary points with the tf.train.Checkpoint and tf.train.CheckpointManager APIs.
Both of these methods back up and restore the training state in checkpoint files.
Setup
Install tf-nightly. The save_freq argument of tf.keras.callbacks.BackupAndRestore, which saves checkpoints every given number of steps, was introduced in TensorFlow 2.10.
pip install tf-nightly
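You might prefer a regular release. In that case, one option is to pass save_freq only when the installed version is at least 2.10. The following sketch uses hypothetical helpers (`parse_version`, `supports_step_saving`) that are not TensorFlow APIs. You could call them with `tf.__version__`. They compare the leading version numbers as integer tuples, because plain string comparison treats "2.9" as newer than "2.10".

```python
import re


def parse_version(version):
  """Return the leading (major, minor) integers of a version string.

  Accepts release strings like "2.10.1" and nightly-style strings
  like "2.12.0-dev20221214".
  """
  match = re.match(r"(\d+)\.(\d+)", version)
  if match is None:
    raise ValueError(f"Unrecognized version string: {version!r}")
  return int(match.group(1)), int(match.group(2))


def supports_step_saving(version, minimum=(2, 10)):
  """True if `version` is at least `minimum` (the release that added save_freq)."""
  return parse_version(version) >= minimum


print(supports_step_saving("2.9.3"))               # False
print(supports_step_saving("2.10.0"))              # True
print(supports_step_saving("2.12.0-dev20221214"))  # True
print("2.9" > "2.10")                              # True: why string comparison is wrong
```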
import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
import time
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
TensorFlow 1: Save checkpoints with tf.estimator.RunConfig
In TensorFlow 1, you can configure tf.estimator to save a checkpoint at every step by configuring tf.estimator.RunConfig.
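With save_checkpoints_steps=1, a checkpoint is written at every global step. However, only the newest keep_checkpoint_max checkpoints stay on disk (RunConfig's default is 5), and older ones are deleted. The following pure-Python sketch models only that bookkeeping; it is not Estimator code. It applies the bookkeeping to the interrupted run in this guide, in which checkpoints 0 through 6 are written.

```python
from collections import deque


def retained_checkpoints(saved_steps, keep_max=5):
  """Simulate keep-the-newest-N retention over a sequence of saves."""
  kept = deque()
  deleted = []
  for step in saved_steps:
    kept.append(f"model.ckpt-{step}")
    if len(kept) > keep_max:
      # Once the limit is exceeded, the oldest checkpoint is removed.
      deleted.append(kept.popleft())
  return list(kept), deleted


# Checkpoints 0..6 are written before the hook interrupts training.
kept, deleted = retained_checkpoints(range(7), keep_max=5)
print(kept)     # ['model.ckpt-2', 'model.ckpt-3', 'model.ckpt-4', 'model.ckpt-5', 'model.ckpt-6']
print(deleted)  # ['model.ckpt-0', 'model.ckpt-1']
```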
In this example, start by writing a hook that artificially throws an error at the sixth training step, when its zero-based step counter reaches 5.
class InterruptHook(tf1.train.SessionRunHook):
  # A hook for artificially interrupting training.
  def begin(self):
    self._step = -1

  def before_run(self, run_context):
    self._step += 1

  def after_run(self, run_context, run_values):
    if self._step == 5:
      raise RuntimeError('Interruption')
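begin() sets the counter to -1, and before_run() increments it before each step, so the counter is zero-based. As a result, the error fires in after_run() of the sixth session.run call. This matches the log below, where checkpoint 6 is the last one saved before the RuntimeError.

The following sketch uses a hypothetical plain-Python stand-in, FakeInterruptHook, with the same three methods (it does not subclass tf1.train.SessionRunHook). It drives the stand-in with a simplified loop that assumes the order begin, then before_run/after_run around each step.

```python
class FakeInterruptHook:
  """Plain-Python copy of InterruptHook's counting logic (no TensorFlow)."""

  def begin(self):
    self._step = -1

  def before_run(self, run_context):
    self._step += 1

  def after_run(self, run_context, run_values):
    if self._step == 5:
      raise RuntimeError('Interruption')


def run_steps(hook, max_runs):
  """Call the hook methods in session order; return how many runs completed."""
  hook.begin()
  completed = 0
  try:
    for _ in range(max_runs):
      hook.before_run(None)
      completed += 1  # The training step itself would run here.
      hook.after_run(None, None)
  except RuntimeError as e:
    print(f"{type(e).__name__}: {e} after {completed} runs")
  return completed


print(run_steps(FakeInterruptHook(), max_runs=10))  # prints the error message, then 6
```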
Next, configure tf.estimator.Estimator to save a checkpoint at every step and to use the MNIST dataset:
feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]
config = tf1.estimator.RunConfig(save_summary_steps=1,
save_checkpoints_steps=1)
path = tempfile.mkdtemp()
classifier = tf1.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[256, 32],
optimizer=tf1.train.AdamOptimizer(0.001),
n_classes=10,
dropout=0.2,
model_dir=path,
config = config
)
train_input_fn = tf1.estimator.inputs.numpy_input_fn(
x={"x": x_train},
y=y_train.astype(np.int32),
num_epochs=10,
batch_size=50,
shuffle=True,
)
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpo_j1tnnb', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} WARNING:tensorflow:From /tmpfs/tmp/ipykernel_199879/314197976.py:17: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead. WARNING:tensorflow:From /tmpfs/tmp/ipykernel_199879/314197976.py:17: numpy_input_fn (from tensorflow_estimator.python.estimator.inputs.numpy_io) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead.
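numpy_input_fn with batch_size=50 and num_epochs=10 could supply far more batches than this example uses. Training is actually bounded by the max_steps=10 passed to classifier.train(). The arithmetic below assumes the 60,000-image MNIST training split.

```python
num_examples = 60_000  # Size of the MNIST training split (x_train.shape[0]).
batch_size = 50
num_epochs = 10

steps_per_epoch = num_examples // batch_size    # Batches in one pass over the data.
input_fn_steps = steps_per_epoch * num_epochs   # When numpy_input_fn would run out.
max_steps = 10                                  # The cap passed to classifier.train().

print(steps_per_epoch, input_fn_steps, min(max_steps, input_fn_steps))  # 1200 12000 10
```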
Start training the model. The hook you defined earlier will raise an artificial exception.
try:
  classifier.train(input_fn=train_input_fn,
                   hooks=[InterruptHook()],
                   max_steps=10)
except Exception as e:
  print(f'{type(e).__name__}:{e}')
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_199879/2587623597.py:3: object.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:385: StopAtStepHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:60: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:491: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. INFO:tensorflow:Calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/dnn.py:446: dnn_logit_fn_builder (from tensorflow_estimator.python.estimator.canned.dnn) is deprecated and will be removed in a future version. 
Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/model_fn.py:250: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. INFO:tensorflow:Done calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1414: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1417: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/basic_session_run_hooks.py:232: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1454: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. INFO:tensorflow:Create CheckpointSaverHook. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version. 
Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:910: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. 2022-12-14 22:42:59.835374: W tensorflow/core/common_runtime/type_inference.cc:339] Type inference failed. This indicates an invalid graph that escaped type checking. Error message: INVALID_ARGUMENT: expected compatible input types, but input 1: type_id: TFT_OPTIONAL args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_INT64 } } } is neither a subtype nor a supertype of the combined inputs preceding it: type_id: TFT_OPTIONAL args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_INT32 } } } while inferring type of node 'dnn/zero_fraction/cond/output/_18' INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpo_j1tnnb/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. 
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1... INFO:tensorflow:Saving checkpoints for 1 into /tmpfs/tmp/tmpo_j1tnnb/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1... INFO:tensorflow:loss = 117.2641, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2... INFO:tensorflow:Saving checkpoints for 2 into /tmpfs/tmp/tmpo_j1tnnb/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3... INFO:tensorflow:Saving checkpoints for 3 into /tmpfs/tmp/tmpo_j1tnnb/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4... INFO:tensorflow:Saving checkpoints for 4 into /tmpfs/tmp/tmpo_j1tnnb/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5... INFO:tensorflow:Saving checkpoints for 5 into /tmpfs/tmp/tmpo_j1tnnb/model.ckpt. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/saver.py:1067: remove_checkpoint (from tensorflow.python.checkpoint.checkpoint_management) is deprecated and will be removed in a future version. 
Instructions for updating: Use standard file APIs to delete files with this prefix. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6... INFO:tensorflow:Saving checkpoints for 6 into /tmpfs/tmp/tmpo_j1tnnb/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6... RuntimeError:Interruption
Rebuild the tf.estimator.Estimator with the same model_dir. It restores from the last checkpoint saved there and continues training.
classifier = tf1.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[256, 32],
optimizer=tf1.train.AdamOptimizer(0.001),
n_classes=10,
dropout=0.2,
model_dir=path,
config = config
)
classifier.train(input_fn=train_input_fn,
max_steps = 10)
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpo_j1tnnb', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpo_j1tnnb/model.ckpt-6 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/saver.py:1176: get_checkpoint_mtimes (from tensorflow.python.checkpoint.checkpoint_management) is deprecated and will be removed in a future version. Instructions for updating: Use standard file utilities to get mtimes. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6... INFO:tensorflow:Saving checkpoints for 6 into /tmpfs/tmp/tmpo_j1tnnb/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7... INFO:tensorflow:Saving checkpoints for 7 into /tmpfs/tmp/tmpo_j1tnnb/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7... 
INFO:tensorflow:loss = 100.51071, step = 6 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8... INFO:tensorflow:Saving checkpoints for 8 into /tmpfs/tmp/tmpo_j1tnnb/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9... INFO:tensorflow:Saving checkpoints for 9 into /tmpfs/tmp/tmpo_j1tnnb/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Saving checkpoints for 10 into /tmpfs/tmp/tmpo_j1tnnb/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Loss for final step: 96.29817. <tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x7fa50ef7c910>
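The resumed call trains only global steps 7 through 10. This is because max_steps is an absolute cap on the global step restored from model.ckpt-6, whereas the steps argument of train() counts additional steps from wherever training resumes.

The following sketch is a pure-Python model of that bookkeeping. The helper steps_to_run is hypothetical, not Estimator code. For simplicity it requires exactly one of the two arguments; the real API also allows neither, in which case training runs until the input is exhausted.

```python
def steps_to_run(restored_global_step, steps=None, max_steps=None):
  """How many steps a train() call would run after restoring a checkpoint.

  `steps` is incremental; `max_steps` caps the global step.
  Simplification: exactly one of them must be set.
  """
  if (steps is None) == (max_steps is None):
    raise ValueError("Set exactly one of steps or max_steps.")
  if steps is not None:
    return steps
  # Nothing runs if the restored global step has already reached the cap.
  return max(0, max_steps - restored_global_step)


print(steps_to_run(6, max_steps=10))   # 4  -> global steps 7, 8, 9, 10
print(steps_to_run(6, steps=10))       # 10 -> would continue to global step 16
print(steps_to_run(10, max_steps=10))  # 0  -> training is skipped
```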
TensorFlow 2: Back up and restore with a callback and Model.fit
In TensorFlow 2, if you use the Keras Model.fit API for training, you can add fault tolerance by passing it the tf.keras.callbacks.BackupAndRestore callback.
To demonstrate this, first define a Keras Callback class that artificially throws an error at the end of the fourth epoch (zero-based epoch index 3).
class InterruptAtEpoch(tf.keras.callbacks.Callback):
  # A callback for artificially interrupting training.
  def __init__(self, interrupting_epoch=3):
    super().__init__()
    # Zero-based epoch index at whose end the error is raised.
    self.interrupting_epoch = interrupting_epoch

  def on_epoch_end(self, epoch, logs=None):
    if epoch == self.interrupting_epoch:
      raise RuntimeError('Interruption')
Then, define and instantiate a simple Keras model, define the loss function, call Model.compile, and set up a tf.keras.callbacks.BackupAndRestore callback. The callback saves checkpoints in a temporary directory at epoch boundaries.
def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
  ])
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model = create_model()
model.compile(optimizer='adam',
loss=loss,
metrics=['accuracy'])
log_dir = tempfile.mkdtemp()
backup_restore_callback = tf.keras.callbacks.BackupAndRestore(
backup_dir = log_dir)
Start training the model with Model.fit. During training, the tf.keras.callbacks.BackupAndRestore callback instantiated above saves checkpoints. Meanwhile, the InterruptAtEpoch class raises an artificial exception to simulate a failure after the fourth epoch.
try:
  model.fit(x=x_train,
            y=y_train,
            epochs=10,
            steps_per_epoch=100,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback, InterruptAtEpoch()])
except Exception as e:
  print(f'{type(e).__name__}:{e}')
Epoch 1/10 100/100 [==============================] - 2s 11ms/step - loss: 0.4694 - accuracy: 0.8690 - val_loss: 0.2198 - val_accuracy: 0.9398 Epoch 2/10 100/100 [==============================] - 1s 8ms/step - loss: 0.1980 - accuracy: 0.9438 - val_loss: 0.1538 - val_accuracy: 0.9559 Epoch 3/10 100/100 [==============================] - 1s 8ms/step - loss: 0.1452 - accuracy: 0.9598 - val_loss: 0.1219 - val_accuracy: 0.9643 Epoch 4/10 94/100 [===========================>..] - ETA: 0s - loss: 0.1167 - accuracy: 0.9671RuntimeError:Interruption
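Callbacks run in the order they appear in the callbacks list, and backup_restore_callback comes before InterruptAtEpoch. So when the exception is raised at the end of epoch index 3, that epoch's state has presumably already been backed up. This is consistent with the resumed run starting at "Epoch 5/10".

The following sketch is a simplified pure-Python model of that ordering. Plain functions stand in for callbacks, and a dict stands in for backup_dir. It is an illustration, not Keras code.

```python
def run_epoch_end_callbacks(callbacks, epoch):
  """Invoke each callback's epoch-end handler in list order."""
  for callback in callbacks:
    callback(epoch)


def simulate(order, num_epochs=10, interrupting_epoch=3):
  """Return the 1-based epoch number a resumed run would start from."""
  backup = {"last_completed_epoch": None}  # Stands in for backup_dir.

  def backup_callback(epoch):
    backup["last_completed_epoch"] = epoch  # Zero-based epoch index.

  def interrupt_callback(epoch):
    if epoch == interrupting_epoch:
      raise RuntimeError("Interruption")

  named = {"backup": backup_callback, "interrupt": interrupt_callback}
  callbacks = [named[name] for name in order]
  try:
    for epoch in range(num_epochs):
      run_epoch_end_callbacks(callbacks, epoch)
  except RuntimeError:
    pass
  # Resume after the last backed-up epoch; the progress bar counts from 1.
  last = backup["last_completed_epoch"]
  next_index = 0 if last is None else last + 1
  return next_index + 1


print(simulate(["backup", "interrupt"]))  # 5, matching "Epoch 5/10" in the output
print(simulate(["interrupt", "backup"]))  # 4: the interrupted epoch would be redone
```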
Next, instantiate the Keras model, call Model.compile, and continue training the model with Model.fit from the previously saved checkpoint.
model = create_model()
model.compile(optimizer='adam',
loss=loss,
metrics=['accuracy'],
steps_per_execution=10)
model.fit(x=x_train,
y=y_train,
epochs=10,
steps_per_epoch=100,
validation_data=(x_test, y_test),
callbacks=[backup_restore_callback])
Epoch 5/10 100/100 [==============================] - 2s 19ms/step - loss: 0.0943 - accuracy: 0.9733 - val_loss: 0.0899 - val_accuracy: 0.9730 Epoch 6/10 100/100 [==============================] - 0s 5ms/step - loss: 0.0778 - accuracy: 0.9778 - val_loss: 0.0825 - val_accuracy: 0.9752 Epoch 7/10 100/100 [==============================] - 0s 5ms/step - loss: 0.0679 - accuracy: 0.9801 - val_loss: 0.0770 - val_accuracy: 0.9756 Epoch 8/10 100/100 [==============================] - 0s 5ms/step - loss: 0.0577 - accuracy: 0.9834 - val_loss: 0.0715 - val_accuracy: 0.9775 Epoch 9/10 100/100 [==============================] - 0s 5ms/step - loss: 0.0523 - accuracy: 0.9848 - val_loss: 0.0703 - val_accuracy: 0.9791 Epoch 10/10 100/100 [==============================] - 0s 5ms/step - loss: 0.0451 - accuracy: 0.9873 - val_loss: 0.0644 - val_accuracy: 0.9800 <keras.callbacks.History at 0x7fa41c5dce20>
Define another Callback class that artificially throws an error at the 140th step, counted across epochs.
class InterruptAtStep(tf.keras.callbacks.Callback):
  # A callback for artificially interrupting training.
  def __init__(self, interrupting_step=140):
    super().__init__()
    # Counts batches across epochs; it is not reset at epoch boundaries.
    self.total_step_count = 0
    self.interrupting_step = interrupting_step

  def on_batch_begin(self, batch, logs=None):
    self.total_step_count += 1

  def on_batch_end(self, batch, logs=None):
    if self.total_step_count == self.interrupting_step:
      print("\nInterrupting at step count", self.total_step_count)
      raise RuntimeError('Interruption')
Note: This section uses a feature that was only available in tf-nightly until TensorFlow 2.10 was released.
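To save a checkpoint every 30 steps, set save_freq in the BackupAndRestore callback to 30. InterruptAtStep raises an artificial exception to simulate a failure at step 40 of epoch 2 (total step count 140), counting epochs from 1 as the progress bar does. The last checkpoint before the failure is saved at step 20 of epoch 2 (total step count 120).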
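The arithmetic behind these numbers is sketched below. It assumes, as the resume point in this guide indicates, that the save_freq counter runs across epoch boundaries. The helper to_epoch_and_step is hypothetical; it maps a 1-based total step count to the 1-based (epoch, step) pair shown by Keras.

```python
steps_per_epoch = 100
save_freq = 30
interrupting_step = 140  # Total step count at which InterruptAtStep raises.


def to_epoch_and_step(total_step):
  """Map a 1-based total step count to a 1-based (epoch, step-in-epoch) pair."""
  epoch_index, step_index = divmod(total_step - 1, steps_per_epoch)
  return epoch_index + 1, step_index + 1


# Total step counts at which a backup is written before the failure.
saves = list(range(save_freq, interrupting_step + 1, save_freq))
last_save = saves[-1]

print(saves)                                 # [30, 60, 90, 120]
print(to_epoch_and_step(interrupting_step))  # (2, 40)
print(to_epoch_and_step(last_save))          # (2, 20)
print(to_epoch_and_step(last_save + 1))      # (2, 21): where the resumed run starts
```

Steps 121 to 140 were trained but not backed up, so the resumed run repeats them.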
log_dir_2 = tempfile.mkdtemp()
backup_restore_callback = tf.keras.callbacks.BackupAndRestore(
backup_dir = log_dir_2, save_freq=30
)
model = create_model()
model.compile(optimizer='adam',
loss=loss,
metrics=['accuracy'])
try:
  model.fit(x=x_train,
            y=y_train,
            epochs=10,
            steps_per_epoch=100,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback, InterruptAtStep()])
except Exception as e:
  print(f'{type(e).__name__}:{e}')
Epoch 1/10 100/100 [==============================] - 2s 12ms/step - loss: 0.4631 - accuracy: 0.8691 - val_loss: 0.2199 - val_accuracy: 0.9390 Epoch 2/10 38/100 [==========>...................] - ETA: 0s - loss: 0.2255 - accuracy: 0.9354 Interrupting at step count 140 RuntimeError:Interruption
Next, instantiate the Keras model, call Model.compile, and continue training the model with Model.fit from the previously saved checkpoint. Note that training resumes from step 21 of epoch 2.
model = create_model()
model.compile(optimizer='adam',
loss=loss,
metrics=['accuracy'],
steps_per_execution=10)
model.fit(x=x_train,
y=y_train,
epochs=10,
steps_per_epoch=100,
validation_data=(x_test, y_test),
callbacks=[backup_restore_callback])
Epoch 2/10 100/100 [==============================] - 2s 17ms/step - loss: 0.1932 - accuracy: 0.9463 - val_loss: 0.1557 - val_accuracy: 0.9563 Epoch 3/10 100/100 [==============================] - 1s 5ms/step - loss: 0.1427 - accuracy: 0.9592 - val_loss: 0.1207 - val_accuracy: 0.9647 Epoch 4/10 100/100 [==============================] - 0s 5ms/step - loss: 0.1155 - accuracy: 0.9666 - val_loss: 0.1018 - val_accuracy: 0.9695 Epoch 5/10 100/100 [==============================] - 0s 5ms/step - loss: 0.0950 - accuracy: 0.9729 - val_loss: 0.0904 - val_accuracy: 0.9737 Epoch 6/10 100/100 [==============================] - 0s 5ms/step - loss: 0.0797 - accuracy: 0.9775 - val_loss: 0.0815 - val_accuracy: 0.9760 Epoch 7/10 100/100 [==============================] - 0s 4ms/step - loss: 0.0666 - accuracy: 0.9809 - val_loss: 0.0736 - val_accuracy: 0.9780 Epoch 8/10 100/100 [==============================] - 0s 4ms/step - loss: 0.0578 - accuracy: 0.9838 - val_loss: 0.0701 - val_accuracy: 0.9790 Epoch 9/10 100/100 [==============================] - 0s 5ms/step - loss: 0.0521 - accuracy: 0.9853 - val_loss: 0.0662 - val_accuracy: 0.9791 Epoch 10/10 100/100 [==============================] - 0s 5ms/step - loss: 0.0442 - accuracy: 0.9873 - val_loss: 0.0652 - val_accuracy: 0.9796 <keras.callbacks.History at 0x7fa4545119a0>
TensorFlow 2: Write manual checkpoints with a custom training loop
If you use a custom training loop in TensorFlow 2, you can implement a fault tolerance mechanism with the tf.train.Checkpoint and tf.train.CheckpointManager APIs.
This example demonstrates how to:
- Use a tf.train.Checkpoint object to manually create a checkpoint, where the trackable objects you want to save are set as attributes.
- Use a tf.train.CheckpointManager to manage multiple checkpoints.
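CheckpointManager names its checkpoints ckpt-1, ckpt-2, and so on, taking the number from the checkpoint's save counter. With max_to_keep=2, it deletes all but the two most recent. The following pure-Python sketch models that behavior; it is not the TensorFlow implementation. It applies the model to the 5 epochs × 5 steps = 25 saves made by the training loop in this section.

```python
def simulate_manager(num_saves, max_to_keep):
  """Mimic CheckpointManager-style numbering and keep-the-newest retention."""
  kept = []
  latest = None
  for save_counter in range(1, num_saves + 1):
    latest = f"ckpt-{save_counter}"
    kept.append(latest)
    # Only the `max_to_keep` most recent checkpoints stay on disk.
    kept = kept[-max_to_keep:]
  return kept, latest


epochs, steps_per_epoch = 5, 5
kept, latest = simulate_manager(epochs * steps_per_epoch, max_to_keep=2)
print(kept)    # ['ckpt-24', 'ckpt-25']
print(latest)  # ckpt-25 (latest_checkpoint would return a path ending in this)
```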
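Begin by defining and instantiating the Keras model, the optimizer, and the loss function. Then, create a Checkpoint that manages two objects with trackable states (the model and the optimizer), and a CheckpointManager that logs and keeps several checkpoints in a temporary directory.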
model = create_model()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
log_dir = tempfile.mkdtemp()
epochs = 5
steps_per_epoch = 5
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, log_dir, max_to_keep=2)
Now, implement a custom training loop that, after the first epoch, restores the last checkpoint each time a new epoch starts:
for epoch in range(epochs):
  if epoch > 0:
    # Restore the model and optimizer state from the latest checkpoint.
    # (tf.train.load_checkpoint only returns a reader; it does not restore state.)
    checkpoint.restore(checkpoint_manager.latest_checkpoint)
  print(f"\nStart of epoch {epoch}")

  for step in range(steps_per_epoch):
    with tf.GradientTape() as tape:
      logits = model(x_train, training=True)
      loss_value = loss_fn(y_train, logits)

    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    save_path = checkpoint_manager.save()
    print(f"Checkpoint saved to {save_path}")
    print(f"Training loss at step {step}: {loss_value}")
Start of epoch 0 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-1 Training loss at step 0: 2.3486015796661377 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-2 Training loss at step 1: 2.3473939895629883 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-3 Training loss at step 2: 2.3460042476654053 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-4 Training loss at step 3: 2.3438942432403564 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-5 Training loss at step 4: 2.343919277191162 Start of epoch 1 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-6 Training loss at step 0: 2.341978073120117 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-7 Training loss at step 1: 2.341369152069092 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-8 Training loss at step 2: 2.3405568599700928 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-9 Training loss at step 3: 2.339402198791504 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-10 Training loss at step 4: 2.3377020359039307 Start of epoch 2 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-11 Training loss at step 0: 2.335585355758667 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-12 Training loss at step 1: 2.334397315979004 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-13 Training loss at step 2: 2.332718849182129 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-14 Training loss at step 3: 2.3335583209991455 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-15 Training loss at step 4: 2.3313987255096436 Start of epoch 3 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-16 Training loss at step 0: 2.3284120559692383 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-17 Training loss at step 1: 2.3293144702911377 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-18 Training loss at step 2: 2.3262722492218018 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-19 Training loss at step 3: 2.3256146907806396 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-20 Training loss at step 4: 2.3249266147613525 Start of epoch 4 Checkpoint saved to 
/tmpfs/tmp/tmpf3an9neo/ckpt-21 Training loss at step 0: 2.3231191635131836 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-22 Training loss at step 1: 2.3204562664031982 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-23 Training loss at step 2: 2.320456027984619 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-24 Training loss at step 3: 2.319236993789673 Checkpoint saved to /tmpfs/tmp/tmpf3an9neo/ckpt-25 Training loss at step 4: 2.3184564113616943
Next steps
To learn more about fault tolerance and checkpointing in TensorFlow 2, consider the following documentation:
- The tf.keras.callbacks.BackupAndRestore callback API docs.
- The tf.train.Checkpoint and tf.train.CheckpointManager API docs.
- The Training checkpoints guide, including the Writing checkpoints section.
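You may also find the following material related to distributed training useful:
- The Fault tolerance section in the Multi-worker training with Keras tutorial.
- The Handling task failure section in the Parameter server training tutorial.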