Il componente della pipeline TFX Trainer

Il componente pipeline Trainer TFX addestra un modello TensorFlow.

Formatore e TensorFlow

Trainer fa ampio uso dell'API Python TensorFlow per l'addestramento dei modelli.

Componente

Il formatore prende:

  • tf.Esempi utilizzati per la formazione e la valutazione.
  • Un file del modulo fornito dall'utente che definisce la logica del trainer.
  • Definizione del protobuf degli argomenti train e degli argomenti eval.
  • (Facoltativo) Uno schema di dati creato da un componente della pipeline SchemaGen e facoltativamente modificato dallo sviluppatore.
  • (Facoltativo) grafico di trasformazione prodotto da un componente di trasformazione upstream.
  • (Facoltativo) modelli preaddestrati utilizzati per scenari come l'avvio a caldo.
  • (Facoltativo) iperparametri, che verranno passati alla funzione del modulo utente. I dettagli dell'integrazione con Tuner possono essere trovati qui .

Il trainer emette: almeno un modello per l'inferenza/elaborazione (tipicamente in SavedModelFormat) e facoltativamente un altro modello per eval (tipicamente un EvalSavedModel).

Forniamo supporto per formati di modello alternativi come TFLite tramite la libreria di riscrittura dei modelli . Consulta il collegamento alla libreria di riscrittura dei modelli per esempi di come convertire i modelli Estimator e Keras.

Allenatore generico

Il trainer generico consente agli sviluppatori di utilizzare qualsiasi API del modello TensorFlow con il componente Trainer. Oltre agli stimatori TensorFlow, gli sviluppatori possono utilizzare modelli Keras o cicli di formazione personalizzati. Per i dettagli, consultare la RFC per il trainer generico .

Configurazione del componente Trainer

Il tipico codice DSL della pipeline per il Trainer generico sarebbe simile al seguente:

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Il trainer richiama un modulo di formazione, specificato nel parametro module_file . Invece di trainer_fn , è richiesto un run_fn nel file del modulo se GenericExecutor è specificato in custom_executor_spec . Il trainer_fn era responsabile della creazione del modello. In aggiunta a ciò, run_fn deve anche gestire la parte di training e inviare il modello addestrato nella posizione desiderata fornita da FnArgs :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

Ecco un file di modulo di esempio con run_fn .

Tieni presente che se il componente Transform non viene utilizzato nella pipeline, il Trainer prenderà direttamente gli esempi da EsempioGen:

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Maggiori dettagli sono disponibili nel riferimento API Trainer .