Ver en TensorFlow.org | Ejecutar en Google Colab | Ver fuente en GitHub | Descargar libreta |
Este documento presenta tf.estimator
, una API de TensorFlow de alto nivel. Los estimadores encapsulan las siguientes acciones:
- Capacitación
- Evaluación
- Predicción
- Exportar para servir
TensorFlow implementa varios estimadores prefabricados. Los estimadores personalizados aún se admiten, pero principalmente como una medida de compatibilidad con versiones anteriores. Los estimadores personalizados no deben usarse para el código nuevo . Todos los Estimadores (prefabricados o personalizados) son clases basadas en la clase tf.estimator.Estimator
.
Para ver un ejemplo rápido, pruebe los tutoriales de Estimator . Para obtener una descripción general del diseño de la API, consulte el documento técnico .
Configuración
pip install -U tensorflow_datasets
import tempfile
import os
import tensorflow as tf
import tensorflow_datasets as tfds
Ventajas
Similar a tf.keras.Model
, un estimator
es una abstracción a nivel de modelo. El tf.estimator
proporciona algunas capacidades que aún están en desarrollo para tf.keras
. Estos son:
- Entrenamiento basado en servidor de parámetros
- Integración completa de TFX
Capacidades de los estimadores
Los estimadores proporcionan los siguientes beneficios:
- Puede ejecutar modelos basados en Estimator en un host local o en un entorno distribuido de varios servidores sin cambiar su modelo. Además, puede ejecutar modelos basados en Estimator en CPU, GPU o TPU sin volver a codificar su modelo.
- Los estimadores proporcionan un ciclo de entrenamiento distribuido seguro que controla cómo y cuándo:
- Cargar datos
- Manejar excepciones
- Cree archivos de puntos de control y recupérese de fallas
- Guardar resúmenes para TensorBoard
Al escribir una aplicación con estimadores, debe separar la canalización de entrada de datos del modelo. Esta separación simplifica los experimentos con diferentes conjuntos de datos.
Uso de estimadores prefabricados
Los estimadores prefabricados le permiten trabajar a un nivel conceptual mucho más alto que las API básicas de TensorFlow. Ya no tiene que preocuparse por crear el gráfico computacional o las sesiones, ya que Estimators maneja toda la "plomería" por usted. Además, los Estimadores prefabricados le permiten experimentar con diferentes arquitecturas de modelos realizando solo cambios mínimos en el código. tf.estimator.DNNClassifier
, por ejemplo, es una clase de Estimator prefabricada que entrena modelos de clasificación basados en redes neuronales densas y realimentadas.
Un programa de TensorFlow que se basa en un Estimator prefabricado generalmente consta de los siguientes cuatro pasos:
1. Escribe una función de entrada
Por ejemplo, puede crear una función para importar el conjunto de entrenamiento y otra función para importar el conjunto de prueba. Los estimadores esperan que sus entradas tengan el formato de un par de objetos:
- Un diccionario en el que las claves son nombres de funciones y los valores son tensores (o SparseTensors) que contienen los datos de funciones correspondientes
- Un tensor que contiene una o más etiquetas
El input_fn
debe devolver un tf.data.Dataset
que produce pares en ese formato.
Por ejemplo, el siguiente código crea un tf.data.Dataset
a partir del archivo train.csv
del conjunto de datos del Titanic:
def train_input_fn():
titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic = tf.data.experimental.make_csv_dataset(
titanic_file, batch_size=32,
label_name="survived")
titanic_batches = (
titanic.cache().repeat().shuffle(500)
.prefetch(tf.data.AUTOTUNE))
return titanic_batches
El input_fn
se ejecuta en un tf.Graph
y también puede devolver directamente un par (features_dics, labels)
que contiene tensores gráficos, pero esto es propenso a errores fuera de casos simples como devolver constantes.
2. Defina las columnas de características.
Cada tf.feature_column
identifica un nombre de característica, su tipo y cualquier preprocesamiento de entrada.
Por ejemplo, el siguiente fragmento crea tres columnas de características.
- El primero usa la función de
age
directamente como una entrada de punto flotante. - El segundo usa la función de
class
como una entrada categórica. - El tercero utiliza la
embark_town
como una entrada categórica, pero utiliza elhashing trick
para evitar la necesidad de enumerar las opciones y establecer el número de opciones.
Para obtener más información, consulte el tutorial de columnas de características .
age = tf.feature_column.numeric_column('age')
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third'])
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)
3. Cree una instancia del Estimador prefabricado relevante.
Por ejemplo, aquí hay una instancia de muestra de un Estimador prefabricado llamado LinearClassifier
:
model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
model_dir=model_dir,
feature_columns=[embark, cls, age],
n_classes=2
)
INFO:tensorflow:Using default config. INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpl24pp3cp', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
Para más información, puedes ir al tutorial del clasificador lineal .
4. Llame a un método de entrenamiento, evaluación o inferencia.
Todos los estimadores proporcionan métodos de train
, evaluate
y predict
.
model = model.train(input_fn=train_input_fn, steps=100)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/base_layer_v1.py:1684: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead. warnings.warn('`layer.add_variable` is deprecated and ' WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/ftrl.py:147: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpl24pp3cp/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.6931472, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100... INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmpl24pp3cp/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100... INFO:tensorflow:Loss for final step: 0.6319582. 2021-09-22 20:49:10.453286: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
result = model.evaluate(train_input_fn, steps=10)
for key, value in result.items():
print(key, ":", value)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:11 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpl24pp3cp/model.ckpt-100 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.74609s INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:12 INFO:tensorflow:Saving dict for global step 100: accuracy = 0.734375, accuracy_baseline = 0.640625, auc = 0.7373913, auc_precision_recall = 0.64306235, average_loss = 0.563341, global_step = 100, label/mean = 0.359375, loss = 0.563341, precision = 0.734375, prediction/mean = 0.3463129, recall = 0.40869564 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmpl24pp3cp/model.ckpt-100 accuracy : 0.734375 accuracy_baseline : 0.640625 auc : 0.7373913 auc_precision_recall : 0.64306235 average_loss : 0.563341 label/mean : 0.359375 loss : 0.563341 precision : 0.734375 prediction/mean : 0.3463129 recall : 0.40869564 global_step : 100 2021-09-22 20:49:12.168629: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
for pred in model.predict(train_input_fn):
for key, value in pred.items():
print(key, ":", value)
break
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpl24pp3cp/model.ckpt-100 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. logits : [-1.5173098] logistic : [0.17985801] probabilities : [0.820142 0.17985801] class_ids : [0] classes : [b'0'] all_class_ids : [0 1] all_classes : [b'0' b'1'] 2021-09-22 20:49:13.076528: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Beneficios de los estimadores prefabricados
Los estimadores prefabricados codifican las mejores prácticas y brindan los siguientes beneficios:
- Las mejores prácticas para determinar dónde deben ejecutarse las diferentes partes del gráfico computacional, implementando estrategias en una sola máquina o en un clúster.
- Las mejores prácticas para la redacción de eventos (resumen) y resúmenes universalmente útiles.
Si no usa Estimadores prefabricados, debe implementar las funciones anteriores usted mismo.
Estimadores personalizados
El corazón de cada Estimator, ya sea prefabricado o personalizado, es su función de modelo , model_fn
, que es un método que crea gráficos para entrenamiento, evaluación y predicción. Cuando está utilizando un Estimador prefabricado, alguien más ya ha implementado la función de modelo. Cuando confíe en un Estimador personalizado, debe escribir la función del modelo usted mismo.
Crear un Estimador a partir de un modelo de Keras
Puede convertir modelos Keras existentes en Estimadores con tf.keras.estimator.model_to_estimator
. Esto es útil si desea modernizar el código de su modelo, pero su proceso de capacitación aún requiere Estimadores.
Crea una instancia de un modelo Keras MobileNet V2 y compila el modelo con el optimizador, la pérdida y las métricas para entrenar con:
keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False
estimator_model = tf.keras.Sequential([
keras_mobilenet_v2,
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(1)
])
# Compile the model
estimator_model.compile(
optimizer='adam',
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5 9412608/9406464 [==============================] - 0s 0us/step 9420800/9406464 [==============================] - 0s 0us/step
Cree un Estimator
a partir del modelo de Keras compilado. El estado del modelo inicial del modelo Keras se conserva en el Estimator
creado:
est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpmosnmied INFO:tensorflow:Using the Keras model provided. /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/backend.py:401: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model. warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and ' /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/utils/generic_utils.py:497: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument. category=CustomMaskWarning) INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpmosnmied', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
Trate el Estimator
derivado como lo haría con cualquier otro Estimator
.
IMG_SIZE = 160 # All images will be resized to 160x160
def preprocess(image, label):
image = tf.cast(image, tf.float32)
image = (image/127.5) - 1
image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
return image, label
def train_input_fn(batch_size):
data = tfds.load('cats_vs_dogs', as_supervised=True)
train_data = data['train']
train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
return train_data
Para entrenar, llama a la función de tren de Estimator:
est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpmosnmied/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={}) INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpmosnmied/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={}) INFO:tensorflow:Warm-starting from: /tmp/tmpmosnmied/keras/keras_model.ckpt INFO:tensorflow:Warm-starting from: /tmp/tmpmosnmied/keras/keras_model.ckpt INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES. INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES. INFO:tensorflow:Warm-started 158 variables. INFO:tensorflow:Warm-started 158 variables. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpmosnmied/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpmosnmied/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.6994096, step = 0 INFO:tensorflow:loss = 0.6994096, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpmosnmied/model.ckpt. INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpmosnmied/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Loss for final step: 0.68789804. INFO:tensorflow:Loss for final step: 0.68789804. <tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f4b1c1e9890>
De manera similar, para evaluar, llame a la función de evaluación del Estimador:
est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/training.py:2470: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically. warnings.warn('`Model.state_updates` will be removed in a future version. ' INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:36 INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:36 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpmosnmied/model.ckpt-50 INFO:tensorflow:Restoring parameters from /tmp/tmpmosnmied/model.ckpt-50 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 3.89658s INFO:tensorflow:Inference Time : 3.89658s INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:39 INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:39 INFO:tensorflow:Saving dict for global step 50: accuracy = 0.525, global_step = 50, loss = 0.6723582 INFO:tensorflow:Saving dict for global step 50: accuracy = 0.525, global_step = 50, loss = 0.6723582 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpmosnmied/model.ckpt-50 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpmosnmied/model.ckpt-50 {'accuracy': 0.525, 'loss': 0.6723582, 'global_step': 50}
Para obtener más detalles, consulte la documentación de tf.keras.estimator.model_to_estimator
.
Guardar puntos de control basados en objetos con Estimator
Los estimadores guardan de forma predeterminada los puntos de control con nombres de variables en lugar del gráfico de objetos descrito en la guía de puntos de control . tf.train.Checkpoint
leerá los puntos de control basados en nombres, pero los nombres de las variables pueden cambiar al mover partes de un modelo fuera del model_fn
del Estimator. Para la compatibilidad futura, guardar los puntos de control basados en objetos hace que sea más fácil entrenar un modelo dentro de un Estimator y luego usarlo fuera de uno.
import tensorflow.compat.v1 as tf_compat
def toy_dataset():
inputs = tf.range(10.)[:, None]
labels = inputs * 5. + tf.range(5.)[None, :]
return tf.data.Dataset.from_tensor_slices(
dict(x=inputs, y=labels)).repeat().batch(2)
class Net(tf.keras.Model):
"""A simple linear model."""
def __init__(self):
super(Net, self).__init__()
self.l1 = tf.keras.layers.Dense(5)
def call(self, x):
return self.l1(x)
def model_fn(features, labels, mode):
net = Net()
opt = tf.keras.optimizers.Adam(0.1)
ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
optimizer=opt, net=net)
with tf.GradientTape() as tape:
output = net(features['x'])
loss = tf.reduce_mean(tf.abs(output - features['y']))
variables = net.trainable_variables
gradients = tape.gradient(loss, variables)
return tf.estimator.EstimatorSpec(
mode,
loss=loss,
train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
ckpt.step.assign_add(1)),
# Tell the Estimator to save "ckpt" in an object-based format.
scaffold=tf_compat.train.Scaffold(saver=ckpt))
tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config. INFO:tensorflow:Using default config. INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 4.659403, step = 0 INFO:tensorflow:loss = 4.659403, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Loss for final step: 39.58891. INFO:tensorflow:Loss for final step: 39.58891. <tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f4b7c451fd0>
tf.train.Checkpoint
puede cargar los puntos de control del Estimator desde su model_dir
.
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy() # From est.train(..., steps=10)
10
Modelos guardados de Estimadores
Los estimadores exportan modelos guardados a través tf.Estimator.export_saved_model
.
input_column = tf.feature_column.numeric_column("x")
estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])
def input_fn():
return tf.data.Dataset.from_tensor_slices(
({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)
INFO:tensorflow:Using default config. INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp30_d7xz6 WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp30_d7xz6 INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp30_d7xz6', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp30_d7xz6', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp30_d7xz6/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp30_d7xz6/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.6931472, step = 0 INFO:tensorflow:loss = 0.6931472, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmp30_d7xz6/model.ckpt. INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmp30_d7xz6/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Loss for final step: 0.4022895. INFO:tensorflow:Loss for final step: 0.4022895. <tensorflow_estimator.python.estimator.canned.linear.LinearClassifierV2 at 0x7f4b1c10fd10>
Para guardar un Estimator
, debe crear un serving_input_receiver
. Esta función crea una parte de un tf.Graph
que analiza los datos sin procesar recibidos por el modelo guardado.
El módulo tf.estimator.export
contiene funciones para ayudar a construir estos receivers
.
El siguiente código crea un receptor, basado en feature_columns
, que acepta búferes de protocolo tf.Example
serializados, que a menudo se usan con tf-serving .
tmpdir = tempfile.mkdtemp()
serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
tf.feature_column.make_parse_example_spec([input_column]))
estimator_base_path = os.path.join(tmpdir, 'from_estimator')
estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info. INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification'] INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification'] INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression'] INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression'] INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict'] INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict'] INFO:tensorflow:Signatures INCLUDED in export for Train: None INFO:tensorflow:Signatures INCLUDED in export for Train: None INFO:tensorflow:Signatures INCLUDED in export for Eval: None INFO:tensorflow:Signatures INCLUDED in export for Eval: None INFO:tensorflow:Restoring parameters from /tmp/tmp30_d7xz6/model.ckpt-50 INFO:tensorflow:Restoring parameters from /tmp/tmp30_d7xz6/model.ckpt-50 INFO:tensorflow:Assets added to graph. INFO:tensorflow:Assets added to graph. INFO:tensorflow:No assets to write. INFO:tensorflow:No assets to write. INFO:tensorflow:SavedModel written to: /tmp/tmpi_szzuj1/from_estimator/temp-1632343781/saved_model.pb INFO:tensorflow:SavedModel written to: /tmp/tmpi_szzuj1/from_estimator/temp-1632343781/saved_model.pb
También puede cargar y ejecutar ese modelo, desde python:
imported = tf.saved_model.load(estimator_path)
def predict(x):
example = tf.train.Example()
example.features.feature["x"].float_list.value.extend([x])
return imported.signatures["predict"](
examples=tf.constant([example.SerializeToString()]))
print(predict(1.5))
print(predict(3.5))
{'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[1]])>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'1']], dtype=object)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.2974025]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.5738074]], dtype=float32)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.42619258, 0.5738074 ]], dtype=float32)>} {'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[0]])>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'0']], dtype=object)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-1.1919093]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.23291764]], dtype=float32)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.7670824 , 0.23291762]], dtype=float32)>}
tf.estimator.export.build_raw_serving_input_receiver_fn
le permite crear funciones de entrada que toman tensores sin procesar en lugar de tf.train.Example
s.
Uso de tf.distribute.Strategy
con Estimator (soporte limitado)
tf.estimator
es una API de TensorFlow de capacitación distribuida que originalmente admitía el enfoque del servidor de parámetros asincrónicos. tf.estimator
ahora es compatible con tf.distribute.Strategy
. Si usa tf.estimator
, puede cambiar al entrenamiento distribuido con muy pocos cambios en su código. Con esto, los usuarios de Estimator ahora pueden realizar capacitación distribuida sincrónica en múltiples GPU y múltiples trabajadores, así como usar TPU. Este soporte en Estimator es, sin embargo, limitado. Consulte la sección Qué es compatible ahora a continuación para obtener más detalles.
Usar tf.distribute.Strategy
con Estimator es ligeramente diferente que en el caso de Keras. En lugar de usar strategy.scope
, ahora pasa el objeto de estrategia a RunConfig
para el Estimador.
Puede consultar la guía de capacitación distribuida para obtener más información.
Aquí hay un fragmento de código que muestra esto con un Estimator LinearRegressor
y MirroredStrategy
prefabricados:
mirrored_strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(
train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)
regressor = tf.estimator.LinearRegressor(
feature_columns=[tf.feature_column.numeric_column('feats')],
optimizer='SGD',
config=config)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Initializing RunConfig with distribution strategies. INFO:tensorflow:Initializing RunConfig with distribution strategies. INFO:tensorflow:Not using Distribute Coordinator. INFO:tensorflow:Not using Distribute Coordinator. WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpftw63jyd WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpftw63jyd INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpftw63jyd', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f4b0c04c050>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f4b0c04c050>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None} INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpftw63jyd', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f4b0c04c050>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f4b0c04c050>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}
Aquí, usa un estimador prefabricado, pero el mismo código también funciona con un estimador personalizado. train_distribute
determina cómo se distribuirá la capacitación y eval_distribute
determina cómo se distribuirá la evaluación. Esta es otra diferencia con Keras, donde usa la misma estrategia tanto para el entrenamiento como para la evaluación.
Ahora puede entrenar y evaluar este Estimador con una función de entrada:
def input_fn():
dataset = tf.data.Dataset.from_tensors(({"feats":[1.]}, [1.]))
return dataset.repeat(1000).batch(10)
regressor.train(input_fn=input_fn, steps=10)
regressor.evaluate(input_fn=input_fn, steps=10)
INFO:tensorflow:Calling model_fn. /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py:374: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options. warnings.warn("To make it possible to preserve tf.data options across " INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/util.py:95: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version. Instructions for updating: Use the iterator's `initializer` property instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/util.py:95: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version. Instructions for updating: Use the iterator's `initializer` property instead. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpftw63jyd/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpftw63jyd/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... 2021-09-22 20:49:45.706166: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} } . Registered: device='CPU' 2021-09-22 20:49:45.707521: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} } . Registered: device='CPU' INFO:tensorflow:loss = 1.0, step = 0 INFO:tensorflow:loss = 1.0, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpftw63jyd/model.ckpt. INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpftw63jyd/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Loss for final step: 2.877698e-13. INFO:tensorflow:Loss for final step: 2.877698e-13. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:46 INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:46 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpftw63jyd/model.ckpt-10 INFO:tensorflow:Restoring parameters from /tmp/tmpftw63jyd/model.ckpt-10 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. 2021-09-22 20:49:46.680821: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} } . Registered: device='CPU' 2021-09-22 20:49:46.682161: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} } . Registered: device='CPU' INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.26514s INFO:tensorflow:Inference Time : 0.26514s INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:46 INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:46 INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994 INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmpftw63jyd/model.ckpt-10 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmpftw63jyd/model.ckpt-10 {'average_loss': 1.4210855e-14, 'label/mean': 1.0, 'loss': 1.4210855e-14, 'prediction/mean': 0.99999994, 'global_step': 10}
Otra diferencia a destacar aquí entre Estimator y Keras es el manejo de entrada. En Keras, cada lote del conjunto de datos se divide automáticamente en múltiples réplicas. En Estimator, sin embargo, no realiza la división automática de lotes ni fragmenta automáticamente los datos entre diferentes trabajadores. Tiene control total sobre cómo desea que se distribuyan sus datos entre trabajadores y dispositivos, y debe proporcionar un input_fn
para especificar cómo distribuir sus datos.
Su input_fn
se llama una vez por trabajador, lo que proporciona un conjunto de datos por trabajador. Luego, un lote de ese conjunto de datos se alimenta a una réplica en ese trabajador, por lo que consume N lotes para N réplicas en 1 trabajador. En otras palabras, el conjunto de datos devuelto por input_fn
debe proporcionar lotes de tamaño PER_REPLICA_BATCH_SIZE
. Y el tamaño de lote global para un paso se puede obtener como PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync
.
Al realizar la capacitación de varios trabajadores, debe dividir los datos entre los trabajadores o mezclar con una semilla aleatoria en cada uno. Puedes consultar un ejemplo de cómo hacerlo en el tutorial Formación multitrabajador con Estimator .
Y de manera similar, también puede usar estrategias de servidores de parámetros y trabajadores múltiples. El código sigue siendo el mismo, pero debe usar tf.estimator.train_and_evaluate
y establecer variables de entorno TF_CONFIG
para cada binario que se ejecuta en su clúster.
¿Qué es compatible ahora?
Hay soporte limitado para el entrenamiento con Estimator usando todas las estrategias excepto TPUStrategy
. La capacitación básica y la evaluación deberían funcionar, pero una serie de características avanzadas como v1.train.Scaffold
no lo hacen. También puede haber una serie de errores en esta integración y no hay planes para mejorar activamente este soporte (el enfoque está en Keras y el soporte de bucle de entrenamiento personalizado). Si es posible, debería preferir usar tf.distribute
con esas API en su lugar.
API de entrenamiento | Estrategia reflejada | TPUEstrategia | MultiWorkerMirroredStrategy | Estrategia de almacenamiento central | ParámetroServidorEstrategia |
---|---|---|---|---|---|
API del estimador | Soporte limitado | No soportado | Soporte limitado | Soporte limitado | Soporte limitado |
Ejemplos y tutoriales
Aquí hay algunos ejemplos completos que muestran cómo usar varias estrategias con Estimator:
- El tutorial Capacitación de varios trabajadores con Estimator muestra cómo puede capacitarse con varios trabajadores mediante
MultiWorkerMirroredStrategy
en el conjunto de datos MNIST. - Un ejemplo integral de ejecución de capacitación para varios trabajadores con estrategias de distribución en
tensorflow/ecosystem
utilizando plantillas de Kubernetes. Comienza con un modelo Keras y lo convierte en un Estimator utilizando la APItf.keras.estimator.model_to_estimator
. - El modelo oficial de ResNet50 , que se puede entrenar con
MirroredStrategy
oMultiWorkerMirroredStrategy
.