جزء خط لوله Trainer TFX یک مدل TensorFlow را آموزش می دهد.
ترینر و تنسورفلو
ترینر به طور گسترده از Python TensorFlow API برای مدل های آموزشی استفاده می کند.
جزء
مربی می گیرد:
- tf. مثال هایی که برای آموزش و ارزیابی استفاده می شوند.
- یک فایل ماژول ارائه شده توسط کاربر که منطق مربی را تعریف می کند.
- تعریف Protobuf از قطار ارگ و ارگ eval.
- (اختیاری) یک طرح داده ایجاد شده توسط یک جزء خط لوله SchemaGen و به صورت اختیاری توسط توسعه دهنده تغییر می کند.
- (اختیاری) تبدیل گراف تولید شده توسط یک جزء Transform بالادست.
- (اختیاری) مدل های از پیش آموزش دیده مورد استفاده برای سناریوهایی مانند شروع گرما.
- هایپرپارامترهای (اختیاری)، که به تابع ماژول کاربر ارسال می شود. جزئیات ادغام با Tuner را می توانید در اینجا پیدا کنید.
Trainer منتشر می کند: حداقل یک مدل برای استنتاج/خدمت (معمولا در SavedModelFormat) و به صورت اختیاری مدل دیگری برای eval (معمولا EvalSavedModel).
ما از قالبهای مدل جایگزین مانند TFLite از طریق کتابخانه بازنویسی مدل پشتیبانی میکنیم. برای نمونههایی از نحوه تبدیل هر دو مدل تخمینگر و کراس، پیوند کتابخانه بازنویسی مدل را ببینید.
مربی عمومی
Generic trainer توسعه دهندگان را قادر می سازد تا از هر API مدل TensorFlow با جزء Trainer استفاده کنند. علاوه بر برآوردگرهای TensorFlow، توسعهدهندگان میتوانند از مدلهای Keras یا حلقههای آموزشی سفارشی استفاده کنند. برای جزئیات، لطفاً RFC برای مربی عمومی را ببینید.
پیکربندی کامپوننت Trainer
کد DSL خط لوله معمولی برای Trainer عمومی به شکل زیر است:
from tfx.components import Trainer
...
trainer = Trainer(
module_file=module_file,
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Trainer یک ماژول آموزشی را فراخوانی می کند که در پارامتر module_file
مشخص شده است. اگر GenericExecutor
در custom_executor_spec
مشخص شده باشد، به جای trainer_fn
، یک run_fn
در فایل ماژول مورد نیاز است. trainer_fn
مسئول ایجاد مدل بود. علاوه بر آن، run_fn
همچنین باید بخش آموزشی را مدیریت کند و مدل آموزشدیده شده را در محل مورد نظر ارائه شده توسط FnArgs خروجی دهد:
from tfx.components.trainer.fn_args_utils import FnArgs
def run_fn(fn_args: FnArgs) -> None:
"""Build the TF model and train it."""
model = _build_keras_model()
model.fit(...)
# Save model to fn_args.serving_model_dir.
model.save(fn_args.serving_model_dir, ...)
در اینجا یک نمونه فایل ماژول با run_fn
آمده است.
توجه داشته باشید که اگر کامپوننت Transform در خط لوله استفاده نشود، Trainer مثالها را مستقیماً از ExampleGen میگیرد:
trainer = Trainer(
module_file=module_file,
examples=example_gen.outputs['examples'],
schema=infer_schema.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
جزئیات بیشتر در مرجع Trainer API موجود است.