Komponent potoku Trainer TFX szkoli model TensorFlow.
Trener i TensorFlow
Trainer w szerokim zakresie wykorzystuje interfejs API Pythona TensorFlow do modeli szkoleniowych.
Część
Trener bierze:
- tf.Przykłady użyte do szkolenia i ewaluacji.
- Plik modułu dostarczony przez użytkownika, który definiuje logikę trenera.
- Definicja protobufa argumentów pociągowych i argumentów ewaluacyjnych.
- (Opcjonalnie) Schemat danych utworzony przez komponent potoku SchemaGen i opcjonalnie zmieniony przez programistę.
- (Opcjonalnie) wykres transformacji utworzony przez nadrzędny komponent Transform.
- (Opcjonalnie) wstępnie przeszkolone modele używane w scenariuszach, takich jak ciepły start.
- (Opcjonalnie) hiperparametry, które zostaną przekazane do funkcji modułu użytkownika. Szczegóły integracji z Tunerem znajdziesz tutaj .
Trener emituje: Co najmniej jeden model do wnioskowania/obsługiwania (zwykle w formacie SavedModelFormat) i opcjonalnie inny model do eval (zazwyczaj EvalSavedModel).
Zapewniamy obsługę alternatywnych formatów modeli, takich jak TFLite, za pośrednictwem biblioteki przepisywania modeli . Zobacz łącze do Biblioteki przepisywania modeli, aby zapoznać się z przykładami konwertowania modeli estymatora i modelu Keras.
Trener ogólny
Generic trainer umożliwia programistom korzystanie z dowolnego interfejsu API modelu TensorFlow z komponentem Trainer. Oprócz estymatorów TensorFlow programiści mogą korzystać z modeli Keras lub niestandardowych pętli szkoleniowych. Szczegółowe informacje można znaleźć w dokumencie RFC dotyczącym generycznego trenera .
Konfiguracja komponentu Trainer
Typowy kod DSL potoku dla ogólnego Trainera wyglądałby następująco:
from tfx.components import Trainer
...
trainer = Trainer(
module_file=module_file,
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Trener wywołuje moduł szkoleniowy, który jest określony w parametrze module_file
. Zamiast trainer_fn
, w pliku modułu wymagany jest run_fn
, jeśli w pliku custom_executor_spec
określono GenericExecutor
. Za stworzenie modelu odpowiedzialny był trainer_fn
. Oprócz tego run_fn
musi również obsłużyć część szkoleniową i wyprowadzić przeszkolony model do żądanej lokalizacji określonej przez FnArgs :
from tfx.components.trainer.fn_args_utils import FnArgs
def run_fn(fn_args: FnArgs) -> None:
"""Build the TF model and train it."""
model = _build_keras_model()
model.fit(...)
# Save model to fn_args.serving_model_dir.
model.save(fn_args.serving_model_dir, ...)
Oto przykładowy plik modułu z run_fn
.
Należy pamiętać, że jeśli w potoku nie zostanie użyty komponent Transform, Trainer pobierze przykłady bezpośrednio z PrzykładGen:
trainer = Trainer(
module_file=module_file,
examples=example_gen.outputs['examples'],
schema=infer_schema.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Więcej szczegółów można znaleźć w dokumentacji Trainer API .