El componente de canalización Trainer TFX entrena un modelo de TensorFlow.
Entrenador y TensorFlow
Trainer hace un uso extensivo de la API Python TensorFlow para entrenar modelos.
Componente
El entrenador toma:
- tf.Ejemplos utilizados para entrenamiento y evaluación.
- Un archivo de módulo proporcionado por el usuario que define la lógica del entrenador.
- Definición de Protobuf de argumentos de tren y argumentos de evaluación.
- (Opcional) Un esquema de datos creado por un componente de canalización de SchemaGen y, opcionalmente, modificado por el desarrollador.
- (Opcional) gráfico de transformación producido por un componente Transform ascendente.
- (Opcional) Modelos previamente entrenados utilizados para escenarios como el inicio en caliente.
- Hiperparámetros (opcionales), que se pasarán a la función del módulo de usuario. Los detalles de la integración con Tuner se pueden encontrar aquí .
El entrenador emite: al menos un modelo para inferencia/publicación (normalmente en SavedModelFormat) y, opcionalmente, otro modelo para evaluación (normalmente un EvalSavedModel).
Brindamos soporte para formatos de modelos alternativos como TFLite a través de la Biblioteca de reescritura de modelos . Consulte el enlace a la Biblioteca de reescritura de modelos para ver ejemplos de cómo convertir los modelos Estimator y Keras.
Entrenador genérico
El entrenador genérico permite a los desarrolladores utilizar cualquier API modelo de TensorFlow con el componente Trainer. Además de los estimadores de TensorFlow, los desarrolladores pueden utilizar modelos Keras o bucles de entrenamiento personalizados. Para obtener más información, consulte el RFC para el entrenador genérico .
Configuración del componente de entrenador
El código DSL de canalización típico para el Entrenador genérico se vería así:
from tfx.components import Trainer
...
trainer = Trainer(
module_file=module_file,
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Trainer invoca un módulo de capacitación, que se especifica en el parámetro module_file
. En lugar de trainer_fn
, se requiere run_fn
en el archivo del módulo si GenericExecutor
se especifica en custom_executor_spec
. trainer_fn
fue responsable de crear el modelo. Además de eso, run_fn
también necesita manejar la parte de entrenamiento y enviar el modelo entrenado a la ubicación deseada proporcionada por FnArgs :
from tfx.components.trainer.fn_args_utils import FnArgs
def run_fn(fn_args: FnArgs) -> None:
"""Build the TF model and train it."""
model = _build_keras_model()
model.fit(...)
# Save model to fn_args.serving_model_dir.
model.save(fn_args.serving_model_dir, ...)
Aquí hay un archivo de módulo de ejemplo con run_fn
.
Tenga en cuenta que si el componente Transformar no se utiliza en el proceso, el formador tomará los ejemplos de EjemploGen directamente:
trainer = Trainer(
module_file=module_file,
examples=example_gen.outputs['examples'],
schema=infer_schema.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Más detalles están disponibles en la referencia de Trainer API .