Компонент конвейера Trainer TFX

Компонент конвейера Trainer TFX обучает модель TensorFlow.

Тренер и TensorFlow

Trainer широко использует API Python TensorFlow для обучения моделей.

Компонент

Тренер берет:

  • tf.Примеры, используемые для обучения и оценки.
  • Предоставленный пользователем файл модуля, определяющий логику тренера.
  • Protobuf определение аргументов поезда и аргументов оценки.
  • (Необязательно) Схема данных, созданная компонентом конвейера SchemaGen и при необходимости измененная разработчиком.
  • (Необязательно) граф преобразования, созданный вышестоящим компонентом Transform.
  • (Необязательно) предварительно обученные модели, используемые для таких сценариев, как теплый запуск.
  • (Необязательно) гиперпараметры, которые будут переданы в функцию пользовательского модуля. Подробности интеграции с Tuner можно найти здесь .

Тренер выдает: как минимум одну модель для вывода/обслуживания (обычно в SavedModelFormat) и, возможно, еще одну модель для оценки (обычно EvalSavedModel).

Мы предоставляем поддержку альтернативных форматов моделей, таких как TFLite, через Библиотеку перезаписи моделей . См. ссылку на Библиотеку перезаписи моделей, где приведены примеры преобразования моделей Estimator и Keras.

Общий тренер

Универсальный тренажер позволяет разработчикам использовать любой API модели TensorFlow с компонентом Trainer. В дополнение к оценщикам TensorFlow разработчики могут использовать модели Keras или собственные циклы обучения. Подробную информацию см. в RFC для универсального тренажера .

Настройка компонента тренера

Типичный код конвейера DSL для универсального Trainer будет выглядеть так:

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer вызывает обучающий модуль, указанный в параметре module_file . Вместо trainer_fn в файле модуля требуется run_fn , если GenericExecutor указан в custom_executor_spec . trainer_fn отвечал за создание модели. В дополнение к этому, run_fn также должен обрабатывать обучающую часть и выводить обученную модель в желаемое место, заданное FnArgs :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

Вот пример файла модуля с run_fn .

Обратите внимание: если компонент Transform не используется в конвейере, то Trainer будет напрямую брать примеры из SampleGen:

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Более подробная информация доступна в справочнике Trainer API .