Continually saving the "best" model or the model's weights/parameters has many benefits. These include being able to track training progress and to load saved models from different saved states.
In TensorFlow 1, to configure checkpoint saving during training/validation with the tf.estimator.Estimator APIs, you specify a schedule in tf.estimator.RunConfig or use tf.estimator.CheckpointSaverHook. This guide demonstrates how to migrate from this workflow to the TensorFlow 2 Keras APIs.
In TensorFlow 2, you can configure tf.keras.callbacks.ModelCheckpoint in several ways:
- Save the "best" version according to a monitored metric by setting save_best_only=True, where monitor can be, for example, 'loss', 'val_loss', 'accuracy', or 'val_accuracy'.
- Save continually at a certain frequency (using the save_freq argument).
- Save only the weights/parameters instead of the entire model by setting save_weights_only to True.
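The three configurations above can be sketched as follows. This is a minimal example under stated assumptions: the tiny random dataset, the model, and the file names are hypothetical and chosen only so the snippet runs in seconds. The `.keras` and `.weights.h5` suffixes are also an assumption, picked because recent Keras versions check the file suffix (full models must end in `.keras`, weights-only files in `.weights.h5`).

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()

# 1) Keep only the "best" full model according to a monitored metric.
#    'val_loss' should go down, so mode='min'.
best_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(ckpt_dir, "best.keras"),
    monitor="val_loss",
    mode="min",
    save_best_only=True)

# 2) Save continually at a fixed frequency: save_freq is either 'epoch'
#    (the default) or an integer number of batches (here: every 2 batches).
freq_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(ckpt_dir, "every2-{epoch:02d}.keras"),
    save_freq=2)

# 3) Save only the weights instead of the whole model, once per epoch.
weights_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(ckpt_dir, "weights-{epoch:02d}.weights.h5"),
    save_weights_only=True)

# Hypothetical toy data and model, only to trigger the callbacks.
x = np.random.rand(64, 28, 28).astype("float32")
y = np.random.randint(0, 10, size=(64,))
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
model.fit(x, y, epochs=2, batch_size=16, validation_split=0.25,
          callbacks=[best_cb, freq_cb, weights_cb], verbose=0)

print(sorted(os.listdir(ckpt_dir)))
```

The `{epoch:02d}` placeholder is filled in by Keras (1-based), so each epoch gets its own file instead of overwriting the previous one.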
For more details, refer to the tf.keras.callbacks.ModelCheckpoint API docs and the Save checkpoints during training section in the Save and load models tutorial. Learn more about the Checkpoint format in the TF Checkpoint format section of the Save and load Keras models guide. In addition, to add fault tolerance, you can use tf.keras.callbacks.BackupAndRestore, or tf.train.Checkpoint for manual checkpointing. Learn more in the Fault tolerance migration guide.
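Manual checkpointing with tf.train.Checkpoint can be sketched as below. This is a minimal example that assumes two plain tf.Variable objects (a hypothetical weight vector and step counter) stand in for real training state; it saves once, wipes the in-memory values to simulate a crash, and restores them from disk.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical stand-ins for training state: a weight vector and a step counter.
weight = tf.Variable([1.0, 2.0, 3.0])
step = tf.Variable(0, dtype=tf.int64)
checkpoint = tf.train.Checkpoint(weight=weight, step=step)

# Save after a (pretend) training step; save() returns the checkpoint path prefix.
step.assign_add(1)
save_path = checkpoint.save(os.path.join(tempfile.mkdtemp(), "ckpt"))

# Simulate losing the in-memory state, for example after a crash.
weight.assign([0.0, 0.0, 0.0])
step.assign(0)

# Restore from disk; assert_consumed() checks that every saved value was matched.
checkpoint.restore(save_path).assert_consumed()
print(os.path.basename(save_path), weight.numpy(), int(step.numpy()))
```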
Keras callbacks are objects that are called at different points during training/evaluation/prediction in the built-in Keras Model.fit/Model.evaluate/Model.predict APIs. Learn more in the Next steps section at the end of the guide.
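To make the "called at different points" idea concrete, here is a minimal sketch of a custom callback. The `EventRecorder` class, the toy model, and the random data are hypothetical; the hook names (`on_train_begin`, `on_epoch_end`, `on_test_begin`, `on_predict_begin`) are standard tf.keras.callbacks.Callback methods. The same instance is passed to fit, evaluate, and predict to show which hook each one triggers.

```python
import numpy as np
import tensorflow as tf


class EventRecorder(tf.keras.callbacks.Callback):
    """Hypothetical callback that records a few of the hooks Keras invokes."""

    def __init__(self):
        super().__init__()
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_end(self, epoch, logs=None):
        # `epoch` is 0-based; `logs` holds this epoch's metrics, e.g. {'loss': ...}.
        self.events.append(f"epoch_end:{epoch}")

    def on_test_begin(self, logs=None):
        self.events.append("test_begin")

    def on_predict_begin(self, logs=None):
        self.events.append("predict_begin")


x = np.random.rand(8, 4).astype("float32")
y = np.random.rand(8, 1).astype("float32")
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

recorder = EventRecorder()
model.fit(x, y, epochs=2, verbose=0, callbacks=[recorder])
model.evaluate(x, y, verbose=0, callbacks=[recorder])
model.predict(x, verbose=0, callbacks=[recorder])
print(recorder.events)
# ['train_begin', 'epoch_end:0', 'epoch_end:1', 'test_begin', 'predict_begin']
```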
Setup
Start with imports and a simple dataset for demonstration purposes:
import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11493376/11490434 [==============================] - 0s 0us/step 11501568/11490434 [==============================] - 0s 0us/step
TensorFlow 1: Save checkpoints with tf.estimator APIs
This TensorFlow 1 example shows how to configure tf.estimator.RunConfig to save checkpoints at every step during training and evaluation with the tf.estimator.Estimator APIs:
feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]

# Write summaries and save a checkpoint after every training step.
config = tf1.estimator.RunConfig(save_summary_steps=1,
                                 save_checkpoints_steps=1)

path = tempfile.mkdtemp()

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config=config
)

train_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_train},
    y=y_train.astype(np.int32),
    num_epochs=10,
    batch_size=50,
    shuffle=True,
)

test_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_test},
    y=y_test.astype(np.int32),
    num_epochs=10,
    shuffle=False
)

train_spec = tf1.estimator.TrainSpec(input_fn=train_input_fn, max_steps=10)
# Evaluate on 10 batches whenever a new checkpoint is available;
# throttle_secs=0 removes the minimum delay between evaluations.
eval_spec = tf1.estimator.EvalSpec(input_fn=test_input_fn,
                                   steps=10,
                                   throttle_secs=0)

tf1.estimator.train_and_evaluate(estimator=classifier,
                                 train_spec=train_spec,
                                 eval_spec=eval_spec)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmplrkjo9in', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} WARNING:tensorflow:From /tmp/ipykernel_20296/3980459272.py:18: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead. WARNING:tensorflow:From /tmp/ipykernel_20296/3980459272.py:18: The name tf.estimator.inputs.numpy_input_fn is deprecated. Please use tf.compat.v1.estimator.inputs.numpy_input_fn instead. INFO:tensorflow:Not using Distribute Coordinator. INFO:tensorflow:Running training and evaluation locally (non-distributed). INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps 1 or save_checkpoints_secs None. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:397: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. 
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:65: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:491: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py:914: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmplrkjo9in/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1... INFO:tensorflow:Saving checkpoints for 1 into /tmp/tmplrkjo9in/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:47 INFO:tensorflow:Graph was finalized. 
INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-1 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.26374s INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:47 INFO:tensorflow:Saving dict for global step 1: accuracy = 0.1765625, average_loss = 2.2546134, global_step = 1, loss = 288.5905 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1: /tmp/tmplrkjo9in/model.ckpt-1 INFO:tensorflow:loss = 118.3231, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2... INFO:tensorflow:Saving checkpoints for 2 into /tmp/tmplrkjo9in/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:48 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-2 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. 
INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.36662s INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:48 INFO:tensorflow:Saving dict for global step 2: accuracy = 0.2859375, average_loss = 2.1868849, global_step = 2, loss = 279.92126 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 2: /tmp/tmplrkjo9in/model.ckpt-2 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3... INFO:tensorflow:Saving checkpoints for 3 into /tmp/tmplrkjo9in/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:48 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-3 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.22792s INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:48 INFO:tensorflow:Saving dict for global step 3: accuracy = 0.35078126, average_loss = 2.1220195, global_step = 3, loss = 271.6185 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 3: /tmp/tmplrkjo9in/model.ckpt-3 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4... 
INFO:tensorflow:Saving checkpoints for 4 into /tmp/tmplrkjo9in/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:49 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-4 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.22387s INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:49 INFO:tensorflow:Saving dict for global step 4: accuracy = 0.40234375, average_loss = 2.0655982, global_step = 4, loss = 264.39658 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 4: /tmp/tmplrkjo9in/model.ckpt-4 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5... INFO:tensorflow:Saving checkpoints for 5 into /tmp/tmplrkjo9in/model.ckpt. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1054: remove_checkpoint (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version. Instructions for updating: Use standard file APIs to delete files with this prefix. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:49 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-5 INFO:tensorflow:Running local_init_op. 
INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.22548s INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:49 INFO:tensorflow:Saving dict for global step 5: accuracy = 0.42421874, average_loss = 2.0072064, global_step = 5, loss = 256.92242 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5: /tmp/tmplrkjo9in/model.ckpt-5 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6... INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmplrkjo9in/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:50 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-6 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. 
INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.22806s INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:50 INFO:tensorflow:Saving dict for global step 6: accuracy = 0.43984374, average_loss = 1.9473753, global_step = 6, loss = 249.26404 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 6: /tmp/tmplrkjo9in/model.ckpt-6 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7... INFO:tensorflow:Saving checkpoints for 7 into /tmp/tmplrkjo9in/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:50 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-7 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.23091s INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:50 INFO:tensorflow:Saving dict for global step 7: accuracy = 0.44296876, average_loss = 1.8903366, global_step = 7, loss = 241.96309 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 7: /tmp/tmplrkjo9in/model.ckpt-7 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8... 
INFO:tensorflow:Saving checkpoints for 8 into /tmp/tmplrkjo9in/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:51 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-8 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.22453s INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:51 INFO:tensorflow:Saving dict for global step 8: accuracy = 0.44453126, average_loss = 1.8294731, global_step = 8, loss = 234.17256 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 8: /tmp/tmplrkjo9in/model.ckpt-8 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9... INFO:tensorflow:Saving checkpoints for 9 into /tmp/tmplrkjo9in/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:51 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-9 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. 
INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.22271s INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:51 INFO:tensorflow:Saving dict for global step 9: accuracy = 0.47734374, average_loss = 1.7674354, global_step = 9, loss = 226.23174 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 9: /tmp/tmplrkjo9in/model.ckpt-9 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmplrkjo9in/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:52 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-10 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.38483s INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:52 INFO:tensorflow:Saving dict for global step 10: accuracy = 0.5140625, average_loss = 1.7108486, global_step = 10, loss = 218.98862 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmplrkjo9in/model.ckpt-10 INFO:tensorflow:Loss for final step: 96.2236. 
({'accuracy': 0.5140625, 'average_loss': 1.7108486, 'loss': 218.98862, 'global_step': 10}, [])
%ls {classifier.model_dir}
checkpoint eval/ events.out.tfevents.1642127326.kokoro-gcp-ubuntu-prod-837339153 graph.pbtxt model.ckpt-10.data-00000-of-00001 model.ckpt-10.index model.ckpt-10.meta model.ckpt-6.data-00000-of-00001 model.ckpt-6.index model.ckpt-6.meta model.ckpt-7.data-00000-of-00001 model.ckpt-7.index model.ckpt-7.meta model.ckpt-8.data-00000-of-00001 model.ckpt-8.index model.ckpt-8.meta model.ckpt-9.data-00000-of-00001 model.ckpt-9.index model.ckpt-9.meta
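Only model.ckpt-6 through model.ckpt-10 remain in this listing because RunConfig keeps at most keep_checkpoint_max checkpoints (5 here, as the `'_keep_checkpoint_max': 5` entry in the log shows) and deletes older ones. In a TensorFlow 2 custom training loop, the max_to_keep argument of tf.train.CheckpointManager plays a similar role. A minimal sketch, assuming a single tf.Variable as a hypothetical stand-in for the model state and numbering checkpoints by step to mimic Estimator's model.ckpt-<global_step> names:

```python
import os
import tempfile

import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)  # hypothetical stand-in for training state
checkpoint = tf.train.Checkpoint(step=step)

ckpt_dir = tempfile.mkdtemp()
manager = tf.train.CheckpointManager(
    checkpoint, ckpt_dir, max_to_keep=5, checkpoint_name="model.ckpt")

for _ in range(10):
    step.assign_add(1)
    # Number each checkpoint by the current step; the oldest files beyond
    # max_to_keep are deleted automatically.
    manager.save(checkpoint_number=step)

kept = [os.path.basename(p) for p in manager.checkpoints]
print(kept)
print(os.path.basename(manager.latest_checkpoint))
```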
TensorFlow 2: Save checkpoints with a Keras callback for Model.fit
In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate) for training/evaluation, you can configure tf.keras.callbacks.ModelCheckpoint and then pass it to the callbacks parameter of Model.fit (or Model.evaluate). (Learn more in the API documentation and in the Using callbacks section of the Training and evaluation with the built-in methods guide.)
In the example below, you will use a tf.keras.callbacks.ModelCheckpoint callback to store the checkpoints in a temporary directory:
def create_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.Flatten(input_shape=(28, 28)),
      tf.keras.layers.Dense(512, activation='relu'),
      tf.keras.layers.Dropout(0.2),
      tf.keras.layers.Dense(10, activation='softmax')
  ])

model = create_model()
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'],
              steps_per_execution=10)

log_dir = tempfile.mkdtemp()
# With the default arguments, the callback saves the full model to `filepath`
# at the end of every epoch, overwriting the previous save.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=log_dir)

model.fit(x=x_train,
          y=y_train,
          epochs=10,
          validation_data=(x_test, y_test),
          callbacks=[model_checkpoint_callback])
Epoch 1/10 1840/1875 [============================>.] - ETA: 0s - loss: 0.2224 - accuracy: 0.9348 2022-01-14 02:28:56.714889: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them. INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets 1875/1875 [==============================] - 4s 2ms/step - loss: 0.2208 - accuracy: 0.9354 - val_loss: 0.1132 - val_accuracy: 0.9669 Epoch 2/10 1870/1875 [============================>.] - ETA: 0s - loss: 0.0961 - accuracy: 0.9706INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets 1875/1875 [==============================] - 3s 1ms/step - loss: 0.0962 - accuracy: 0.9706 - val_loss: 0.0784 - val_accuracy: 0.9753 Epoch 3/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.0696 - accuracy: 0.9781INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets 1875/1875 [==============================] - 3s 2ms/step - loss: 0.0695 - accuracy: 0.9782 - val_loss: 0.0684 - val_accuracy: 0.9788 Epoch 4/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.0529 - accuracy: 0.9826INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets 1875/1875 [==============================] - 3s 1ms/step - loss: 0.0531 - accuracy: 0.9826 - val_loss: 0.0671 - val_accuracy: 0.9791 Epoch 5/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.0423 - accuracy: 0.9860INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets 1875/1875 [==============================] - 3s 1ms/step - loss: 0.0424 - accuracy: 0.9860 - val_loss: 0.0772 - val_accuracy: 0.9757 Epoch 6/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.0345 - accuracy: 0.9888INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets 1875/1875 [==============================] - 3s 1ms/step - loss: 0.0345 - accuracy: 0.9888 - val_loss: 0.0669 - val_accuracy: 0.9811 Epoch 7/10 1860/1875 [============================>.] 
- ETA: 0s - loss: 0.0314 - accuracy: 0.9895INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets 1875/1875 [==============================] - 3s 1ms/step - loss: 0.0313 - accuracy: 0.9895 - val_loss: 0.0718 - val_accuracy: 0.9800 Epoch 8/10 1870/1875 [============================>.] - ETA: 0s - loss: 0.0298 - accuracy: 0.9899INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets 1875/1875 [==============================] - 3s 1ms/step - loss: 0.0298 - accuracy: 0.9899 - val_loss: 0.0632 - val_accuracy: 0.9825 Epoch 9/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.0230 - accuracy: 0.9925INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets 1875/1875 [==============================] - 3s 1ms/step - loss: 0.0231 - accuracy: 0.9924 - val_loss: 0.0748 - val_accuracy: 0.9800 Epoch 10/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.0220 - accuracy: 0.9920INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets 1875/1875 [==============================] - 3s 1ms/step - loss: 0.0222 - accuracy: 0.9920 - val_loss: 0.0703 - val_accuracy: 0.9825 <keras.callbacks.History at 0x7f638c204410>
%ls {model_checkpoint_callback.filepath}
assets/ keras_metadata.pb saved_model.pb variables/
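The saved checkpoint can then be loaded back for inference or further training. A minimal self-contained sketch of that round trip, with two assumptions: the tiny random dataset and model are hypothetical, and the checkpoint path ends in `.keras` rather than being a directory as above, because recent Keras versions expect that suffix for full-model saves. Since the callback saves at the end of the only epoch, the restored model should reproduce the trained model's predictions.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical toy data shaped like MNIST, just to keep the example fast.
x = np.random.rand(32, 28, 28).astype("float32")
y = np.random.randint(0, 10, size=(32,))

model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

ckpt_path = os.path.join(tempfile.mkdtemp(), "model.keras")
model.fit(x, y, epochs=1, verbose=0,
          callbacks=[tf.keras.callbacks.ModelCheckpoint(filepath=ckpt_path)])

# Load the full model (architecture + weights) back from the checkpoint file.
restored = tf.keras.models.load_model(ckpt_path)
same = np.allclose(model.predict(x, verbose=0),
                   restored.predict(x, verbose=0), atol=1e-6)
print(same)
```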
Next steps
Learn more about checkpointing in:
- API documentation: tf.keras.callbacks.ModelCheckpoint
- Tutorial: Save and load models (the Save checkpoints during training section)
- Guide: Save and load Keras models (the TF Checkpoint format section)
Learn more about callbacks in:
- API documentation: tf.keras.callbacks.Callback
- Guide: Writing your own callbacks
- Guide: Training and evaluation with the built-in methods (the Using callbacks section)
You may also find the following migration-related resources useful:
- The Fault tolerance migration guide: tf.keras.callbacks.BackupAndRestore for Model.fit, or the tf.train.Checkpoint and tf.train.CheckpointManager APIs for a custom training loop
- The Early stopping migration guide: tf.keras.callbacks.EarlyStopping is a built-in early stopping callback
- The TensorBoard migration guide: TensorBoard enables tracking and displaying metrics
- The LoggingTensorHook and StopAtStepHook to Keras callbacks migration guide
- The SessionRunHook to Keras callbacks guide