مکانیسم تحمل خطا را مهاجرت کنید

مشاهده در TensorFlow.org در Google Colab اجرا شود مشاهده منبع در GitHub دانلود دفترچه یادداشت

تحمل خطا به مکانیزمی اشاره دارد که به صورت دوره ای حالت های اشیاء قابل ردیابی، مانند پارامترها و مدل ها را ذخیره می کند. این به شما امکان می دهد در صورت خرابی برنامه/ماشین در طول تمرین، آنها را بازیابی کنید.

این راهنما ابتدا نحوه افزودن تحمل خطا را به آموزش با tf.estimator.Estimator در TensorFlow 1 با مشخص کردن ذخیره متریک با tf.estimator.RunConfig نشان می دهد. سپس، نحوه اجرای تحمل خطا را برای آموزش در Tensorflow 2 به دو روش یاد خواهید گرفت:

هر دوی این روش ها از حالت های آموزشی در فایل های چک پوینت نسخه پشتیبان تهیه و بازیابی می کنند.

برپایی

import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
import time
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

TensorFlow 1: نقاط بازرسی را با tf.estimator.RunConfig ذخیره کنید

در TensorFlow 1، می‌توانید یک tf.estimator را پیکربندی کنید تا با پیکربندی tf.estimator.RunConfig ، در هر مرحله نقاط بازرسی ذخیره شود.

در این مثال، با نوشتن قلابی شروع کنید که به طور مصنوعی در خلال ایست بازرسی پنجم خطا ایجاد می کند:

class InterruptHook(tf1.train.SessionRunHook):
  # A hook for artificially interrupting training.
  def begin(self):
    self._step = -1

  def before_run(self, run_context):
    self._step += 1

  def after_run(self, run_context, run_values):
    if self._step == 5:
      raise RuntimeError('Interruption')

سپس، tf.estimator.Estimator را برای ذخیره هر چک و استفاده از مجموعه داده MNIST پیکربندی کنید:

feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]
config = tf1.estimator.RunConfig(save_summary_steps=1,
                                 save_checkpoints_steps=1)

path = tempfile.mkdtemp()

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config = config
)

train_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_train},
    y=y_train.astype(np.int32),
    num_epochs=10,
    batch_size=50,
    shuffle=True,
)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpv15yxr9g', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmp/ipykernel_20837/314197976.py:17: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead.

WARNING:tensorflow:From /tmp/ipykernel_20837/314197976.py:17: The name tf.estimator.inputs.numpy_input_fn is deprecated. Please use tf.compat.v1.estimator.inputs.numpy_input_fn instead.

آموزش مدل را شروع کنید. یک استثنا مصنوعی توسط قلابی که قبلاً تعریف کردید ایجاد می شود.

try:
  classifier.train(input_fn=train_input_fn,
                   hooks=[InterruptHook()],
                   max_steps=10)
except Exception as e:
  print(f'{type(e).__name__}:{e}')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:397: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:65: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:491: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py:914: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1...
INFO:tensorflow:Saving checkpoints for 1 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1...
INFO:tensorflow:loss = 118.92192, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2...
INFO:tensorflow:Saving checkpoints for 2 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3...
INFO:tensorflow:Saving checkpoints for 3 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4...
INFO:tensorflow:Saving checkpoints for 4 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5...
INFO:tensorflow:Saving checkpoints for 5 into /tmp/tmpv15yxr9g/model.ckpt.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1054: remove_checkpoint (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to delete files with this prefix.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6...
INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6...
RuntimeError:Interruption

با استفاده از آخرین بازرسی ذخیره شده tf.estimator.Estimator را بازسازی کنید و به آموزش ادامه دهید:

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config = config
)
classifier.train(input_fn=train_input_fn,
                   max_steps = 10)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpv15yxr9g', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpv15yxr9g/model.ckpt-6
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1161: get_checkpoint_mtimes (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file utilities to get mtimes.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6...
INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7...
INFO:tensorflow:Saving checkpoints for 7 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7...
INFO:tensorflow:loss = 105.44863, step = 6
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8...
INFO:tensorflow:Saving checkpoints for 8 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9...
INFO:tensorflow:Saving checkpoints for 9 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 100.47882.
<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x7fcfe8165150>

TensorFlow 2: با callback و Model.fit پشتیبان بگیرید و بازیابی کنید

در TensorFlow 2، اگر از Keras Model.fit API برای آموزش استفاده می کنید، می توانید پاسخ تماس tf.keras.callbacks.BackupAndRestore را برای افزودن قابلیت تحمل خطا ارائه دهید.

برای کمک به نشان دادن این موضوع، اجازه دهید ابتدا با تعریف یک کلاس برگشتی شروع کنیم که به طور مصنوعی در حین بازرسی پنجم خطا ایجاد می کند:

class InterruptingCallback(tf.keras.callbacks.Callback):
  # A callback for artificially interrupting training.
  def on_epoch_end(self, epoch, log=None):
    if epoch == 4:
      raise RuntimeError('Interruption')

سپس، یک مدل Keras ساده را تعریف و نمونه‌سازی کنید، تابع ضرر را تعریف کنید، Model.compile را فراخوانی کنید، و یک پاسخ تماس tf.keras.callbacks.BackupAndRestore کنید که نقاط بازرسی را در یک فهرست موقت ذخیره می‌کند:

def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
  ])

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10)

log_dir = tempfile.mkdtemp()

backup_restore_callback = tf.keras.callbacks.BackupAndRestore(
    backup_dir = log_dir
)

اکنون آموزش مدل را با Model.fit کنید. در طول آموزش، نقاط بازرسی به لطف backup_restore_callback تعریف شده در بالا ذخیره می شوند، در حالی که InterruptingCallback یک استثنا مصنوعی برای شبیه سازی یک شکست ایجاد می کند.

try:
  model.fit(x=x_train,
            y=y_train,
            epochs=10,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback, InterruptingCallback()])
except Exception as e:
  print(f'{type(e).__name__}:{e}')
Epoch 1/10
1875/1875 [==============================] - 3s 2ms/step - loss: 0.2186 - accuracy: 0.9352 - val_loss: 0.1267 - val_accuracy: 0.9615
Epoch 2/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0967 - accuracy: 0.9700 - val_loss: 0.0910 - val_accuracy: 0.9718
Epoch 3/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0687 - accuracy: 0.9784 - val_loss: 0.0679 - val_accuracy: 0.9797
Epoch 4/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0527 - accuracy: 0.9829 - val_loss: 0.0623 - val_accuracy: 0.9814
Epoch 5/10
1860/1875 [============================>.] - ETA: 0s - loss: 0.0434 - accuracy: 0.9857RuntimeError:Interruption

سپس، مدل Keras را نمونه‌سازی کنید، Model.compile و آموزش مدل را با Model.fit از یک چکپوینت ذخیره‌شده قبلی ادامه دهید:

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10)
model.fit(x=x_train,
            y=y_train,
            epochs=10,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback])
Epoch 6/10
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0370 - accuracy: 0.9879 - val_loss: 0.0732 - val_accuracy: 0.9791
Epoch 7/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0306 - accuracy: 0.9898 - val_loss: 0.0601 - val_accuracy: 0.9827
Epoch 8/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0259 - accuracy: 0.9913 - val_loss: 0.0655 - val_accuracy: 0.9819
Epoch 9/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0244 - accuracy: 0.9918 - val_loss: 0.0746 - val_accuracy: 0.9812
Epoch 10/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0221 - accuracy: 0.9923 - val_loss: 0.0818 - val_accuracy: 0.9813
<keras.callbacks.History at 0x7fcfe0647350>

TensorFlow 2: نقاط بازرسی دستی را با یک حلقه آموزشی سفارشی بنویسید

اگر از یک حلقه آموزشی سفارشی در TensorFlow 2 استفاده می‌کنید، می‌توانید مکانیزم تحمل خطا را با API‌های tf.train.Checkpoint و tf.train.CheckpointManager کنید.

این مثال نشان می دهد که چگونه:

  • از یک شی tf.train.Checkpoint برای ایجاد دستی یک نقطه بازرسی استفاده کنید، جایی که اشیاء قابل ردیابی که می خواهید ذخیره کنید به عنوان ویژگی تنظیم می شوند.
  • از tf.train.CheckpointManager برای مدیریت چندین ایست بازرسی استفاده کنید.

با تعریف و نمونه سازی مدل Keras، بهینه ساز و تابع ضرر شروع کنید. سپس، یک Checkpoint ایجاد کنید که دو شیء را با حالت های قابل ردیابی (مدل و بهینه ساز)، و همچنین یک CheckpointManager برای ثبت و نگهداری چندین چک پوینت در یک فهرست موقت ایجاد کنید.

model = create_model()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
log_dir = tempfile.mkdtemp()
epochs = 5
steps_per_epoch = 5

checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
            checkpoint, log_dir, max_to_keep=2)

اکنون، یک حلقه آموزشی سفارشی را پیاده سازی کنید که در آن پس از اولین دوره، هر بار که یک دوره جدید شروع می شود، آخرین بازرسی بارگذاری می شود:

for epoch in range(epochs):
  if epoch > 0:
      tf.train.load_checkpoint(save_path)
  print(f"\nStart of epoch {epoch}")

  for step in range(steps_per_epoch):
    with tf.GradientTape() as tape:

      logits = model(x_train, training=True)
      loss_value = loss_fn(y_train, logits)

      grads = tape.gradient(loss_value, model.trainable_weights)
      optimizer.apply_gradients(zip(grads, model.trainable_weights))

    save_path = checkpoint_manager.save()
    print(f"Checkpoint saved to {save_path}")
    print(f"Training loss at step {step}: {loss_value}")
Start of epoch 0
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-1
Training loss at step 0: 2.3636362552642822
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-2
Training loss at step 1: 2.3626415729522705
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-3
Training loss at step 2: 2.3613197803497314
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-4
Training loss at step 3: 2.360600233078003
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-5
Training loss at step 4: 2.3589422702789307

Start of epoch 1
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-6
Training loss at step 0: 2.3563339710235596
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-7
Training loss at step 1: 2.3568854331970215
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-8
Training loss at step 2: 2.354109287261963
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-9
Training loss at step 3: 2.3532731533050537
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-10
Training loss at step 4: 2.351112127304077

Start of epoch 2
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-11
Training loss at step 0: 2.348905563354492
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-12
Training loss at step 1: 2.349478006362915
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-13
Training loss at step 2: 2.3487260341644287
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-14
Training loss at step 3: 2.345991611480713
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-15
Training loss at step 4: 2.3451104164123535

Start of epoch 3
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-16
Training loss at step 0: 2.3441312313079834
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-17
Training loss at step 1: 2.341529130935669
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-18
Training loss at step 2: 2.342329263687134
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-19
Training loss at step 3: 2.340449571609497
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-20
Training loss at step 4: 2.3367927074432373

Start of epoch 4
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-21
Training loss at step 0: 2.3366076946258545
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-22
Training loss at step 1: 2.335028886795044
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-23
Training loss at step 2: 2.3338520526885986
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-24
Training loss at step 3: 2.3345272541046143
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-25
Training loss at step 4: 2.332385301589966

مراحل بعدی

برای کسب اطلاعات بیشتر در مورد تحمل خطا و چک پوینت در TensorFlow 2، مستندات زیر را در نظر بگیرید:

همچنین ممکن است مطالب زیر مربوط به آموزش توزیع شده مفید باشد: