Voir sur TensorFlow.org | Exécuter dans Google Colab | Voir la source sur GitHub | Télécharger le cahier |
Dans TensorFlow 1, vous utilisez tf.estimator.LoggingTensorHook
pour surveiller et consigner les tenseurs, tandis que tf.estimator.StopAtStepHook
permet d'arrêter l'entraînement à une étape spécifiée lors de l'entraînement avec tf.estimator.Estimator
. Ce notebook montre comment migrer de ces API vers leurs équivalents dans TensorFlow 2 à l'aide de rappels Keras personnalisés ( tf.keras.callbacks.Callback
) avec Model.fit
.
Les rappels Keras sont des objets qui sont appelés à différents moments pendant la formation/l'évaluation/la prédiction dans les API Keras Model.fit
/ Model.evaluate
/ Model.predict Model.predict
. Vous pouvez en savoir plus sur les rappels dans la documentation de l'API tf.keras.callbacks.Callback
, ainsi que dans les guides Rédaction de vos propres rappels et Formation et évaluation avec les méthodes intégrées (section Utilisation des rappels ). Pour migrer de SessionRunHook
dans TensorFlow 1 vers les rappels Keras dans TensorFlow 2, consultez le guide de formation Migrer avec logique assistée .
Installer
Commencez par des importations et un jeu de données simple à des fins de démonstration :
import tensorflow as tf
import tensorflow.compat.v1 as tf1
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
# Define an input function.
def _input_fn():
return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)
TensorFlow 1 : Enregistrer les Tensors et arrêter l'entraînement avec les API tf.estimator
Dans TensorFlow 1, vous définissez différents crochets pour contrôler le comportement d'entraînement. Ensuite, vous transmettez ces crochets à tf.estimator.EstimatorSpec
.
Dans l'exemple ci-dessous :
- Pour surveiller/enregistrer des tenseurs (par exemple, des poids ou des pertes de modèle), vous utilisez
tf.estimator.LoggingTensorHook
(tf.train.LoggingTensorHook
est son alias). - Pour arrêter l'entraînement à une étape spécifique, vous utilisez
tf.estimator.StopAtStepHook
(tf.train.StopAtStepHook
est son alias).
def _model_fn(features, labels, mode):
dense = tf1.layers.Dense(1)
logits = dense(features)
loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
optimizer = tf1.train.AdagradOptimizer(0.05)
train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
# Define the stop hook.
stop_hook = tf1.train.StopAtStepHook(num_steps=2)
# Access tensors to be logged by names.
kernel_name = tf.identity(dense.weights[0])
bias_name = tf.identity(dense.weights[1])
logging_weight_hook = tf1.train.LoggingTensorHook(
tensors=[kernel_name, bias_name],
every_n_iter=1)
# Log the training loss by the tensor object.
logging_loss_hook = tf1.train.LoggingTensorHook(
{'loss from LoggingTensorHook': loss},
every_n_secs=3)
# Pass all hooks to `EstimatorSpec`.
return tf1.estimator.EstimatorSpec(mode,
loss=loss,
train_op=train_op,
training_hooks=[stop_hook,
logging_weight_hook,
logging_loss_hook])
estimator = tf1.estimator.Estimator(model_fn=_model_fn)
# Begin training.
# The training will stop after 2 steps, and the weights/loss will also be logged.
estimator.train(_input_fn)
INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp3q__3yt7 INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp3q__3yt7', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/adagrad.py:77: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp3q__3yt7/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.025395721, step = 0 INFO:tensorflow:Tensor("Identity:0", shape=(2, 1), dtype=float32) = [[-1.0769143] [ 1.0241832]], Tensor("Identity_1:0", shape=(1,), dtype=float32) = [0.] INFO:tensorflow:loss from LoggingTensorHook = 0.025395721 INFO:tensorflow:Tensor("Identity:0", shape=(2, 1), dtype=float32) = [[-1.1124082] [ 0.9824805]], Tensor("Identity_1:0", shape=(1,), dtype=float32) = [-0.03549388] (0.026 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2... INFO:tensorflow:Saving checkpoints for 2 into /tmp/tmp3q__3yt7/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2... INFO:tensorflow:Loss for final step: 0.09248222. <tensorflow_estimator.python.estimator.estimator.Estimator at 0x7f05ec414d10>
TensorFlow 2 : enregistrez les Tensors et arrêtez l'entraînement avec des rappels personnalisés et Model.fit
Dans TensorFlow 2, lorsque vous utilisez le Keras Model.fit
(ou Model.evaluate
) intégré pour l'entraînement/l'évaluation, vous pouvez configurer la surveillance du tenseur et l'arrêt de l'entraînement en définissant des Keras personnalisés tf.keras.callbacks.Callback
s. Ensuite, vous les transmettez au paramètre callbacks
de Model.fit
(ou Model.evaluate
). (Pour en savoir plus, consultez le guide Écrire vos propres rappels .)
Dans l'exemple ci-dessous :
- Pour recréer les fonctionnalités de
StopAtStepHook
, définissez un rappel personnalisé (nomméStopAtStepCallback
ci-dessous) où vous remplacez la méthodeon_batch_end
pour arrêter la formation après un certain nombre d'étapes. - Pour recréer le comportement
LoggingTensorHook
, définissez un rappel personnalisé (LoggingTensorCallback
) dans lequel vous enregistrez et produisez manuellement les tenseurs enregistrés, car l'accès aux tenseurs par noms n'est pas pris en charge. Vous pouvez également implémenter la fréquence de journalisation dans le rappel personnalisé. L'exemple ci-dessous imprimera les poids toutes les deux étapes. D'autres stratégies comme la journalisation toutes les N secondes sont également possibles.
class StopAtStepCallback(tf.keras.callbacks.Callback):
def __init__(self, stop_step=None):
super().__init__()
self._stop_step = stop_step
def on_batch_end(self, batch, logs=None):
if self.model.optimizer.iterations >= self._stop_step:
self.model.stop_training = True
print('\nstop training now')
class LoggingTensorCallback(tf.keras.callbacks.Callback):
def __init__(self, every_n_iter):
super().__init__()
self._every_n_iter = every_n_iter
self._log_count = every_n_iter
def on_batch_end(self, batch, logs=None):
if self._log_count > 0:
self._log_count -= 1
print("Logging Tensor Callback: dense/kernel:",
model.layers[0].weights[0])
print("Logging Tensor Callback: dense/bias:",
model.layers[0].weights[1])
print("Logging Tensor Callback loss:", logs["loss"])
else:
self._log_count -= self._every_n_iter
Lorsque vous avez terminé, transmettez les nouveaux StopAtStepCallback
et LoggingTensorCallback
—au paramètre callbacks
de Model.fit
:
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
model.compile(optimizer, "mse")
# Begin training.
# The training will stop after 2 steps, and the weights/loss will also be logged.
model.fit(dataset, callbacks=[StopAtStepCallback(stop_step=2),
LoggingTensorCallback(every_n_iter=2)])
1/3 [=========>....................] - ETA: 0s - loss: 3.2473Logging Tensor Callback: dense/kernel: <tf.Variable 'dense/kernel:0' shape=(2, 1) dtype=float32, numpy= array([[-0.27049014], [-0.73790836]], dtype=float32)> Logging Tensor Callback: dense/bias: <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.04980864], dtype=float32)> Logging Tensor Callback loss: 3.2473244667053223 stop training now Logging Tensor Callback: dense/kernel: <tf.Variable 'dense/kernel:0' shape=(2, 1) dtype=float32, numpy= array([[-0.22285421], [-0.6911988 ]], dtype=float32)> Logging Tensor Callback: dense/bias: <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.09196297], dtype=float32)> Logging Tensor Callback loss: 5.644947052001953 3/3 [==============================] - 0s 4ms/step - loss: 5.6449 <keras.callbacks.History at 0x7f053022be90>
Prochaines étapes
En savoir plus sur les rappels dans :
- Documentation API :
tf.keras.callbacks.Callback
- Guide : Écrire vos propres rappels
- Guide : Entraînement et évaluation avec les méthodes intégrées (section Utilisation des rappels )
Les ressources liées à la migration suivantes peuvent également vous être utiles :
- Le guide de migration d'arrêt anticipé :
tf.keras.callbacks.EarlyStopping
est un rappel d'arrêt anticipé intégré - Le guide de migration TensorBoard : TensorBoard permet de suivre et d'afficher des métriques
- Le guide Formation avec migration logique assistée : De
SessionRunHook
dans TensorFlow 1 aux rappels Keras dans TensorFlow 2