Voir sur TensorFlow.org | Exécuter dans Google Colab | Voir la source sur GitHub | Télécharger le cahier |
Ce document présente tf.estimator
, une API TensorFlow de haut niveau. Les estimateurs encapsulent les actions suivantes :
- Entraînement
- Évaluation
- Prédiction
- Exporter pour servir
TensorFlow implémente plusieurs estimateurs prédéfinis. Les estimateurs personnalisés sont toujours pris en charge, mais principalement en tant que mesure de rétrocompatibilité. Les estimateurs personnalisés ne doivent pas être utilisés pour le nouveau code . Tous les estimateurs, prédéfinis ou personnalisés, sont des classes basées sur la classe tf.estimator.Estimator
.
Pour un exemple rapide, essayez les didacticiels Estimator . Pour un aperçu de la conception de l'API, consultez le livre blanc .
Installer
pip install -U tensorflow_datasets
import tempfile
import os
import tensorflow as tf
import tensorflow_datasets as tfds
Avantages
Semblable à un tf.keras.Model
, un estimator
est une abstraction au niveau du modèle. Le tf.estimator
fournit certaines fonctionnalités actuellement encore en développement pour tf.keras
. Ceux-ci sont:
- Formation basée sur le serveur de paramètres
- Intégration complète de TFX
Capacités des estimateurs
Les estimateurs offrent les avantages suivants :
- Vous pouvez exécuter des modèles basés sur Estimator sur un hôte local ou sur un environnement multiserveur distribué sans modifier votre modèle. De plus, vous pouvez exécuter des modèles basés sur Estimator sur des CPU, des GPU ou des TPU sans recoder votre modèle.
- Les estimateurs fournissent une boucle de formation distribuée sécurisée qui contrôle comment et quand :
- Charger les données
- Gérer les exceptions
- Créer des fichiers de point de contrôle et récupérer des échecs
- Enregistrer les résumés pour TensorBoard
Lors de l'écriture d'une application avec des estimateurs, vous devez séparer le pipeline d'entrée de données du modèle. Cette séparation simplifie les expériences avec différents ensembles de données.
Utilisation d'estimateurs prédéfinis
Les estimateurs prédéfinis vous permettent de travailler à un niveau conceptuel beaucoup plus élevé que les API TensorFlow de base. Vous n'avez plus à vous soucier de la création du graphique ou des sessions de calcul puisque les estimateurs gèrent toute la "plomberie" pour vous. De plus, les estimateurs prédéfinis vous permettent d'expérimenter différentes architectures de modèles en n'apportant que des modifications minimes au code. tf.estimator.DNNClassifier
, par exemple, est une classe Estimator prédéfinie qui forme des modèles de classification basés sur des réseaux de neurones denses et à anticipation.
Un programme TensorFlow s'appuyant sur un estimateur prédéfini comprend généralement les quatre étapes suivantes :
1. Ecrire une fonction d'entrée
Par exemple, vous pouvez créer une fonction pour importer l'ensemble d'apprentissage et une autre fonction pour importer l'ensemble de test. Les estimateurs s'attendent à ce que leurs entrées soient formatées comme une paire d'objets :
- Un dictionnaire dans lequel les clés sont des noms de caractéristiques et les valeurs sont des Tensors (ou SparseTensors) contenant les données de caractéristiques correspondantes
- Un Tensor contenant un ou plusieurs libellés
L' input_fn
doit renvoyer un tf.data.Dataset
qui produit des paires dans ce format.
Par exemple, le code suivant crée un tf.data.Dataset
à partir du fichier train.csv
du jeu de données Titanic :
def train_input_fn():
titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic = tf.data.experimental.make_csv_dataset(
titanic_file, batch_size=32,
label_name="survived")
titanic_batches = (
titanic.cache().repeat().shuffle(500)
.prefetch(tf.data.AUTOTUNE))
return titanic_batches
Le input_fn
est exécuté dans un tf.Graph
et peut également renvoyer directement une paire (features_dics, labels)
contenant des tenseurs de graphe, mais cela est sujet aux erreurs en dehors des cas simples comme le retour de constantes.
2. Définissez les colonnes de fonction.
Chaque tf.feature_column
identifie un nom de fonctionnalité, son type et tout prétraitement d'entrée.
Par exemple, l'extrait de code suivant crée trois colonnes de caractéristiques.
- La première utilise la fonction d'
age
directement comme entrée à virgule flottante. - La seconde utilise la fonction de
class
comme entrée catégorique. - La troisième utilise l'
embark_town
comme entrée catégorique, mais utilise l'hashing trick
pour éviter d'avoir à énumérer les options et à définir le nombre d'options.
Pour plus d'informations, consultez le didacticiel sur les colonnes de fonctions .
age = tf.feature_column.numeric_column('age')
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third'])
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)
3. Instanciez l'estimateur prédéfini pertinent.
Par exemple, voici un exemple d'instanciation d'un estimateur préfabriqué nommé LinearClassifier
:
model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
model_dir=model_dir,
feature_columns=[embark, cls, age],
n_classes=2
)
INFO:tensorflow:Using default config. INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpl24pp3cp', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
Pour plus d'informations, vous pouvez consulter le didacticiel sur le classificateur linéaire .
4. Appelez une méthode de formation, d'évaluation ou d'inférence.
Tous les estimateurs fournissent des méthodes d' train
, d' evaluate
et de predict
.
model = model.train(input_fn=train_input_fn, steps=100)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/base_layer_v1.py:1684: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead. warnings.warn('`layer.add_variable` is deprecated and ' WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/ftrl.py:147: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpl24pp3cp/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.6931472, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100... INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmpl24pp3cp/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100... INFO:tensorflow:Loss for final step: 0.6319582. 2021-09-22 20:49:10.453286: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
result = model.evaluate(train_input_fn, steps=10)
for key, value in result.items():
print(key, ":", value)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:11 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpl24pp3cp/model.ckpt-100 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.74609s INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:12 INFO:tensorflow:Saving dict for global step 100: accuracy = 0.734375, accuracy_baseline = 0.640625, auc = 0.7373913, auc_precision_recall = 0.64306235, average_loss = 0.563341, global_step = 100, label/mean = 0.359375, loss = 0.563341, precision = 0.734375, prediction/mean = 0.3463129, recall = 0.40869564 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmpl24pp3cp/model.ckpt-100 accuracy : 0.734375 accuracy_baseline : 0.640625 auc : 0.7373913 auc_precision_recall : 0.64306235 average_loss : 0.563341 label/mean : 0.359375 loss : 0.563341 precision : 0.734375 prediction/mean : 0.3463129 recall : 0.40869564 global_step : 100 2021-09-22 20:49:12.168629: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
for pred in model.predict(train_input_fn):
for key, value in pred.items():
print(key, ":", value)
break
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpl24pp3cp/model.ckpt-100 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. logits : [-1.5173098] logistic : [0.17985801] probabilities : [0.820142 0.17985801] class_ids : [0] classes : [b'0'] all_class_ids : [0 1] all_classes : [b'0' b'1'] 2021-09-22 20:49:13.076528: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Avantages des estimateurs préfabriqués
Les estimateurs prédéfinis encodent les meilleures pratiques, offrant les avantages suivants :
- Meilleures pratiques pour déterminer où les différentes parties du graphe de calcul doivent s'exécuter, en mettant en œuvre des stratégies sur une seule machine ou sur un cluster.
- Meilleures pratiques pour la rédaction d'événements (résumés) et résumés universellement utiles.
Si vous n'utilisez pas d'estimateurs prédéfinis, vous devez implémenter vous-même les fonctionnalités précédentes.
Estimateurs personnalisés
Le cœur de chaque estimateur, qu'il soit prédéfini ou personnalisé, est sa fonction de modèle , model_fn
, qui est une méthode qui crée des graphiques pour la formation, l'évaluation et la prédiction. Lorsque vous utilisez un estimateur prédéfini, quelqu'un d'autre a déjà implémenté la fonction de modèle. Lorsque vous vous appuyez sur un estimateur personnalisé, vous devez écrire vous-même la fonction de modèle.
Créer un estimateur à partir d'un modèle Keras
Vous pouvez convertir des modèles Keras existants en estimateurs avec tf.keras.estimator.model_to_estimator
. Ceci est utile si vous souhaitez moderniser le code de votre modèle, mais que votre pipeline de formation nécessite toujours des estimateurs.
Instanciez un modèle Keras MobileNet V2 et compilez le modèle avec l'optimiseur, la perte et les métriques avec lesquels vous entraîner :
keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False
estimator_model = tf.keras.Sequential([
keras_mobilenet_v2,
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(1)
])
# Compile the model
estimator_model.compile(
optimizer='adam',
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5 9412608/9406464 [==============================] - 0s 0us/step 9420800/9406464 [==============================] - 0s 0us/step
Créez un Estimator
à partir du modèle Keras compilé. L'état initial du modèle Keras est conservé dans l' Estimator
créé :
est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpmosnmied INFO:tensorflow:Using the Keras model provided. /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/backend.py:401: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model. warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and ' /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/utils/generic_utils.py:497: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument. category=CustomMaskWarning) INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpmosnmied', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
Traitez l' Estimator
dérivé comme vous le feriez avec n'importe quel autre Estimator
.
IMG_SIZE = 160 # All images will be resized to 160x160
def preprocess(image, label):
image = tf.cast(image, tf.float32)
image = (image/127.5) - 1
image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
return image, label
def train_input_fn(batch_size):
data = tfds.load('cats_vs_dogs', as_supervised=True)
train_data = data['train']
train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
return train_data
Pour entraîner, appelez la fonction d'entraînement d'Estimator :
est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpmosnmied/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={}) INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpmosnmied/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={}) INFO:tensorflow:Warm-starting from: /tmp/tmpmosnmied/keras/keras_model.ckpt INFO:tensorflow:Warm-starting from: /tmp/tmpmosnmied/keras/keras_model.ckpt INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES. INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES. INFO:tensorflow:Warm-started 158 variables. INFO:tensorflow:Warm-started 158 variables. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpmosnmied/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpmosnmied/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.6994096, step = 0 INFO:tensorflow:loss = 0.6994096, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpmosnmied/model.ckpt. INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpmosnmied/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Loss for final step: 0.68789804. INFO:tensorflow:Loss for final step: 0.68789804. <tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f4b1c1e9890>
De même, pour évaluer, appelez la fonction d'évaluation de l'estimateur :
est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/training.py:2470: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically. warnings.warn('`Model.state_updates` will be removed in a future version. ' INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:36 INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:36 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpmosnmied/model.ckpt-50 INFO:tensorflow:Restoring parameters from /tmp/tmpmosnmied/model.ckpt-50 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 3.89658s INFO:tensorflow:Inference Time : 3.89658s INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:39 INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:39 INFO:tensorflow:Saving dict for global step 50: accuracy = 0.525, global_step = 50, loss = 0.6723582 INFO:tensorflow:Saving dict for global step 50: accuracy = 0.525, global_step = 50, loss = 0.6723582 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpmosnmied/model.ckpt-50 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpmosnmied/model.ckpt-50 {'accuracy': 0.525, 'loss': 0.6723582, 'global_step': 50}
Pour plus de détails, veuillez vous référer à la documentation de tf.keras.estimator.model_to_estimator
.
Enregistrement de points de contrôle basés sur des objets avec Estimator
Les estimateurs enregistrent par défaut les points de contrôle avec des noms de variables plutôt que le graphique d'objets décrit dans le guide des points de contrôle. tf.train.Checkpoint
lira les points de contrôle basés sur le nom, mais les noms de variables peuvent changer lors du déplacement de parties d'un modèle en dehors du model_fn
de l'Estimator. Pour une compatibilité ascendante, l'enregistrement de points de contrôle basés sur des objets facilite la formation d'un modèle à l'intérieur d'un estimateur, puis son utilisation en dehors de celui-ci.
import tensorflow.compat.v1 as tf_compat
def toy_dataset():
inputs = tf.range(10.)[:, None]
labels = inputs * 5. + tf.range(5.)[None, :]
return tf.data.Dataset.from_tensor_slices(
dict(x=inputs, y=labels)).repeat().batch(2)
class Net(tf.keras.Model):
"""A simple linear model."""
def __init__(self):
super(Net, self).__init__()
self.l1 = tf.keras.layers.Dense(5)
def call(self, x):
return self.l1(x)
def model_fn(features, labels, mode):
net = Net()
opt = tf.keras.optimizers.Adam(0.1)
ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
optimizer=opt, net=net)
with tf.GradientTape() as tape:
output = net(features['x'])
loss = tf.reduce_mean(tf.abs(output - features['y']))
variables = net.trainable_variables
gradients = tape.gradient(loss, variables)
return tf.estimator.EstimatorSpec(
mode,
loss=loss,
train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
ckpt.step.assign_add(1)),
# Tell the Estimator to save "ckpt" in an object-based format.
scaffold=tf_compat.train.Scaffold(saver=ckpt))
tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config. INFO:tensorflow:Using default config. INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 4.659403, step = 0 INFO:tensorflow:loss = 4.659403, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Loss for final step: 39.58891. INFO:tensorflow:Loss for final step: 39.58891. <tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f4b7c451fd0>
tf.train.Checkpoint
peut alors charger les points de contrôle de l'estimateur à partir de son model_dir
.
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy() # From est.train(..., steps=10)
10
Modèles enregistrés à partir d'estimateurs
Les estimateurs exportent SavedModels via tf.Estimator.export_saved_model
.
input_column = tf.feature_column.numeric_column("x")
estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])
def input_fn():
return tf.data.Dataset.from_tensor_slices(
({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)
INFO:tensorflow:Using default config. INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp30_d7xz6 WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp30_d7xz6 INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp30_d7xz6', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp30_d7xz6', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp30_d7xz6/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp30_d7xz6/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.6931472, step = 0 INFO:tensorflow:loss = 0.6931472, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmp30_d7xz6/model.ckpt. INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmp30_d7xz6/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Loss for final step: 0.4022895. INFO:tensorflow:Loss for final step: 0.4022895. <tensorflow_estimator.python.estimator.canned.linear.LinearClassifierV2 at 0x7f4b1c10fd10>
Pour enregistrer un Estimator
, vous devez créer serving_input_receiver
. Cette fonction construit une partie d'un tf.Graph
qui analyse les données brutes reçues par le SavedModel.
Le module tf.estimator.export
contient des fonctions pour aider à construire ces receivers
.
Le code suivant construit un récepteur, basé sur feature_columns
, qui accepte les tampons de protocole sérialisés tf.Example
, qui sont souvent utilisés avec tf-serving .
tmpdir = tempfile.mkdtemp()
serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
tf.feature_column.make_parse_example_spec([input_column]))
estimator_base_path = os.path.join(tmpdir, 'from_estimator')
estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info. INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification'] INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification'] INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression'] INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression'] INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict'] INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict'] INFO:tensorflow:Signatures INCLUDED in export for Train: None INFO:tensorflow:Signatures INCLUDED in export for Train: None INFO:tensorflow:Signatures INCLUDED in export for Eval: None INFO:tensorflow:Signatures INCLUDED in export for Eval: None INFO:tensorflow:Restoring parameters from /tmp/tmp30_d7xz6/model.ckpt-50 INFO:tensorflow:Restoring parameters from /tmp/tmp30_d7xz6/model.ckpt-50 INFO:tensorflow:Assets added to graph. INFO:tensorflow:Assets added to graph. INFO:tensorflow:No assets to write. INFO:tensorflow:No assets to write. INFO:tensorflow:SavedModel written to: /tmp/tmpi_szzuj1/from_estimator/temp-1632343781/saved_model.pb INFO:tensorflow:SavedModel written to: /tmp/tmpi_szzuj1/from_estimator/temp-1632343781/saved_model.pb
Vous pouvez également charger et exécuter ce modèle, à partir de python :
imported = tf.saved_model.load(estimator_path)
def predict(x):
example = tf.train.Example()
example.features.feature["x"].float_list.value.extend([x])
return imported.signatures["predict"](
examples=tf.constant([example.SerializeToString()]))
print(predict(1.5))
print(predict(3.5))
{'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[1]])>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'1']], dtype=object)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.2974025]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.5738074]], dtype=float32)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.42619258, 0.5738074 ]], dtype=float32)>} {'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[0]])>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'0']], dtype=object)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-1.1919093]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.23291764]], dtype=float32)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.7670824 , 0.23291762]], dtype=float32)>}
tf.estimator.export.build_raw_serving_input_receiver_fn
vous permet de créer des fonctions d'entrée qui prennent des tenseurs bruts plutôt que tf.train.Example
s.
Utilisation de tf.distribute.Strategy
avec Estimator (support limité)
tf.estimator
est une API TensorFlow de formation distribuée qui prenait en charge à l'origine l'approche du serveur de paramètres asynchrone. tf.estimator
prend désormais en charge tf.distribute.Strategy
. Si vous utilisez tf.estimator
, vous pouvez passer à la formation distribuée avec très peu de modifications de votre code. Grâce à cela, les utilisateurs d'Estimator peuvent désormais effectuer une formation distribuée synchrone sur plusieurs GPU et plusieurs travailleurs, ainsi qu'utiliser des TPU. Cette prise en charge dans Estimator est cependant limitée. Consultez la section Ce qui est maintenant pris en charge ci-dessous pour plus de détails.
L'utilisation de tf.distribute.Strategy
avec Estimator est légèrement différente de celle de Keras. Au lieu d'utiliser strategy.scope
, vous transmettez maintenant l'objet de stratégie dans le RunConfig
pour l'Estimator.
Vous pouvez vous référer au guide de formation distribué pour plus d'informations.
Voici un extrait de code qui montre cela avec un Estimator LinearRegressor
et MirroredStrategy
prédéfinis :
mirrored_strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(
train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)
regressor = tf.estimator.LinearRegressor(
feature_columns=[tf.feature_column.numeric_column('feats')],
optimizer='SGD',
config=config)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Initializing RunConfig with distribution strategies. INFO:tensorflow:Initializing RunConfig with distribution strategies. INFO:tensorflow:Not using Distribute Coordinator. INFO:tensorflow:Not using Distribute Coordinator. WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpftw63jyd WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpftw63jyd INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpftw63jyd', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f4b0c04c050>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f4b0c04c050>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None} INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpftw63jyd', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f4b0c04c050>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f4b0c04c050>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}
Ici, vous utilisez un estimateur prédéfini, mais le même code fonctionne également avec un estimateur personnalisé. train_distribute
détermine comment la formation sera distribuée et eval_distribute
détermine comment l'évaluation sera distribuée. C'est une autre différence avec Keras où vous utilisez la même stratégie pour la formation et l'évaluation.
Vous pouvez maintenant entraîner et évaluer cet estimateur avec une fonction d'entrée :
def input_fn():
dataset = tf.data.Dataset.from_tensors(({"feats":[1.]}, [1.]))
return dataset.repeat(1000).batch(10)
regressor.train(input_fn=input_fn, steps=10)
regressor.evaluate(input_fn=input_fn, steps=10)
INFO:tensorflow:Calling model_fn. /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py:374: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options. warnings.warn("To make it possible to preserve tf.data options across " INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/util.py:95: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version. Instructions for updating: Use the iterator's `initializer` property instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/util.py:95: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version. Instructions for updating: Use the iterator's `initializer` property instead. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpftw63jyd/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpftw63jyd/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... 2021-09-22 20:49:45.706166: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} } . Registered: device='CPU' 2021-09-22 20:49:45.707521: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} } . Registered: device='CPU' INFO:tensorflow:loss = 1.0, step = 0 INFO:tensorflow:loss = 1.0, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpftw63jyd/model.ckpt. INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpftw63jyd/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Loss for final step: 2.877698e-13. INFO:tensorflow:Loss for final step: 2.877698e-13. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:46 INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:46 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpftw63jyd/model.ckpt-10 INFO:tensorflow:Restoring parameters from /tmp/tmpftw63jyd/model.ckpt-10 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. 2021-09-22 20:49:46.680821: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} } . Registered: device='CPU' 2021-09-22 20:49:46.682161: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} } . Registered: device='CPU' INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.26514s INFO:tensorflow:Inference Time : 0.26514s INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:46 INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:46 INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994 INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmpftw63jyd/model.ckpt-10 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmpftw63jyd/model.ckpt-10 {'average_loss': 1.4210855e-14, 'label/mean': 1.0, 'loss': 1.4210855e-14, 'prediction/mean': 0.99999994, 'global_step': 10}
Une autre différence à souligner ici entre Estimator et Keras est la gestion des entrées. Dans Keras, chaque lot de l'ensemble de données est automatiquement réparti entre les multiples répliques. Dans Estimator, cependant, vous n'effectuez pas de fractionnement automatique des lots, ni ne partagez automatiquement les données entre différents travailleurs. Vous avez un contrôle total sur la façon dont vous souhaitez que vos données soient distribuées entre les travailleurs et les appareils, et vous devez fournir un input_fn
pour spécifier comment distribuer vos données.
Votre input_fn
est appelé une fois par worker, donnant ainsi un jeu de données par worker. Ensuite, un lot de cet ensemble de données est transmis à une réplique sur ce travailleur, consommant ainsi N lots pour N répliques sur 1 travailleur. En d'autres termes, l'ensemble de données renvoyé par input_fn
doit fournir des lots de taille PER_REPLICA_BATCH_SIZE
. Et la taille de lot globale pour une étape peut être obtenue sous PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync
.
Lorsque vous effectuez une formation multi-travailleurs, vous devez soit diviser vos données entre les travailleurs, soit mélanger avec une graine aléatoire sur chacun. Vous pouvez consulter un exemple de la façon de procéder dans le didacticiel Formation multi-travailleurs avec Estimator .
Et de même, vous pouvez également utiliser des stratégies multi-travailleurs et de serveurs de paramètres. Le code reste le même, mais vous devez utiliser tf.estimator.train_and_evaluate
et définir les variables d'environnement TF_CONFIG
pour chaque binaire exécuté dans votre cluster.
Qu'est-ce qui est pris en charge maintenant ?
La prise en charge de l'entraînement avec Estimator à l'aide de toutes les stratégies à l'exception de TPUStrategy
est limitée. La formation et l'évaluation de base devraient fonctionner, mais un certain nombre de fonctionnalités avancées telles que v1.train.Scaffold
ne fonctionnent pas. Il peut également y avoir un certain nombre de bogues dans cette intégration et il n'est pas prévu d'améliorer activement cette prise en charge (l'accent est mis sur Keras et la prise en charge de la boucle d'entraînement personnalisée). Si possible, vous devriez préférer utiliser tf.distribute
avec ces API à la place.
API de formation | Stratégie en miroir | TPUStratégie | MultiWorkerMirroredStrategy | Stratégie de stockage central | ParameterServerStrategy |
---|---|---|---|---|---|
API d'estimation | Assistance limitée | Non supporté | Assistance limitée | Assistance limitée | Assistance limitée |
Exemples et tutoriels
Voici quelques exemples de bout en bout qui montrent comment utiliser diverses stratégies avec Estimator :
- Le didacticiel Formation multi-travailleurs avec Estimator montre comment vous pouvez vous entraîner avec plusieurs travailleurs à l'aide de
MultiWorkerMirroredStrategy
sur le jeu de données MNIST. - Un exemple de bout en bout d' exécution d'une formation multi-travailleurs avec des stratégies de distribution dans
tensorflow/ecosystem
à l'aide de modèles Kubernetes. Il commence par un modèle Keras et le convertit en estimateur à l'aide de l'APItf.keras.estimator.model_to_estimator
. - Le modèle ResNet50 officiel, qui peut être entraîné à l'aide de
MirroredStrategy
ouMultiWorkerMirroredStrategy
.