Il componente pipeline Trainer TFX addestra un modello TensorFlow.
Formatore e TensorFlow
Trainer fa ampio uso dell'API Python TensorFlow per l'addestramento dei modelli.
Componente
Il formatore prende:
- tf.Esempi utilizzati per la formazione e la valutazione.
- Un file del modulo fornito dall'utente che definisce la logica del trainer.
- Definizione del protobuf degli argomenti train e degli argomenti eval.
- (Facoltativo) Uno schema di dati creato da un componente della pipeline SchemaGen e facoltativamente modificato dallo sviluppatore.
- (Facoltativo) grafico di trasformazione prodotto da un componente di trasformazione upstream.
- (Facoltativo) modelli preaddestrati utilizzati per scenari come l'avvio a caldo.
- (Facoltativo) iperparametri, che verranno passati alla funzione del modulo utente. I dettagli dell'integrazione con Tuner possono essere trovati qui .
Il trainer emette: almeno un modello per l'inferenza/elaborazione (tipicamente in SavedModelFormat) e facoltativamente un altro modello per eval (tipicamente un EvalSavedModel).
Forniamo supporto per formati di modello alternativi come TFLite tramite la libreria di riscrittura dei modelli . Consulta il collegamento alla libreria di riscrittura dei modelli per esempi di come convertire i modelli Estimator e Keras.
Allenatore generico
Il trainer generico consente agli sviluppatori di utilizzare qualsiasi API del modello TensorFlow con il componente Trainer. Oltre agli stimatori TensorFlow, gli sviluppatori possono utilizzare modelli Keras o cicli di formazione personalizzati. Per i dettagli, consultare la RFC per il trainer generico .
Configurazione del componente Trainer
Il tipico codice DSL della pipeline per il Trainer generico sarebbe simile al seguente:
from tfx.components import Trainer
...
trainer = Trainer(
module_file=module_file,
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Il trainer richiama un modulo di formazione, specificato nel parametro module_file
. Invece di trainer_fn
, è richiesto un run_fn
nel file del modulo se GenericExecutor
è specificato in custom_executor_spec
. Il trainer_fn
era responsabile della creazione del modello. In aggiunta a ciò, run_fn
deve anche gestire la parte di training e inviare il modello addestrato nella posizione desiderata fornita da FnArgs :
from tfx.components.trainer.fn_args_utils import FnArgs
def run_fn(fn_args: FnArgs) -> None:
"""Build the TF model and train it."""
model = _build_keras_model()
model.fit(...)
# Save model to fn_args.serving_model_dir.
model.save(fn_args.serving_model_dir, ...)
Ecco un file di modulo di esempio con run_fn
.
Tieni presente che se il componente Transform non viene utilizzato nella pipeline, il Trainer prenderà direttamente gli esempi da EsempioGen:
trainer = Trainer(
module_file=module_file,
examples=example_gen.outputs['examples'],
schema=infer_schema.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Maggiori dettagli sono disponibili nel riferimento API Trainer .