Компонент конвейера Trainer TFX обучает модель TensorFlow.
Тренер и TensorFlow
Trainer широко использует API Python TensorFlow для обучения моделей.
Компонент
Тренер берет:
- tf.Примеры, используемые для обучения и оценки.
- Предоставленный пользователем файл модуля, определяющий логику тренера.
- Protobuf определение аргументов поезда и аргументов оценки.
- (Необязательно) Схема данных, созданная компонентом конвейера SchemaGen и при необходимости измененная разработчиком.
- (Необязательно) граф преобразования, созданный вышестоящим компонентом Transform.
- (Необязательно) предварительно обученные модели, используемые для таких сценариев, как теплый запуск.
- (Необязательно) гиперпараметры, которые будут переданы в функцию пользовательского модуля. Подробности интеграции с Tuner можно найти здесь .
Тренер выдает: как минимум одну модель для вывода/обслуживания (обычно в SavedModelFormat) и, возможно, еще одну модель для оценки (обычно EvalSavedModel).
Мы предоставляем поддержку альтернативных форматов моделей, таких как TFLite, через Библиотеку перезаписи моделей . См. ссылку на Библиотеку перезаписи моделей, где приведены примеры преобразования моделей Estimator и Keras.
Общий тренер
Универсальный тренажер позволяет разработчикам использовать любой API модели TensorFlow с компонентом Trainer. В дополнение к оценщикам TensorFlow разработчики могут использовать модели Keras или собственные циклы обучения. Подробную информацию см. в RFC для универсального тренажера .
Настройка компонента тренера
Типичный код конвейера DSL для универсального Trainer будет выглядеть так:
from tfx.components import Trainer
...
trainer = Trainer(
module_file=module_file,
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Trainer вызывает обучающий модуль, указанный в параметре module_file
. Вместо trainer_fn
в файле модуля требуется run_fn
, если GenericExecutor
указан в custom_executor_spec
. trainer_fn
отвечал за создание модели. В дополнение к этому, run_fn
также должен обрабатывать обучающую часть и выводить обученную модель в желаемое место, заданное FnArgs :
from tfx.components.trainer.fn_args_utils import FnArgs
def run_fn(fn_args: FnArgs) -> None:
"""Build the TF model and train it."""
model = _build_keras_model()
model.fit(...)
# Save model to fn_args.serving_model_dir.
model.save(fn_args.serving_model_dir, ...)
Вот пример файла модуля с run_fn
.
Обратите внимание: если компонент Transform не используется в конвейере, то Trainer будет напрямую брать примеры из SampleGen:
trainer = Trainer(
module_file=module_file,
examples=example_gen.outputs['examples'],
schema=infer_schema.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Более подробная информация доступна в справочнике Trainer API .