ส่วนประกอบไปป์ไลน์ Trainer TFX จะฝึกโมเดล TensorFlow
เทรนเนอร์และ TensorFlow
Trainer ใช้ Python TensorFlow API อย่างกว้างขวางสำหรับโมเดลการฝึก
ส่วนประกอบ
ผู้ฝึกสอนใช้เวลา:
- tf.ตัวอย่างที่ใช้สำหรับการฝึกอบรมและประเมินผล
- ไฟล์โมดูลที่ผู้ใช้จัดเตรียมไว้ซึ่งกำหนดตรรกะของผู้ฝึกสอน
- คำจำกัดความ Protobuf ของ train args และ eval args
- (ไม่บังคับ) สคีมาข้อมูลที่สร้างโดยส่วนประกอบไปป์ไลน์ SchemaGen และอาจแก้ไขโดยนักพัฒนา
- (ทางเลือก) กราฟการแปลงที่สร้างโดยส่วนประกอบการแปลงต้นน้ำ
- (ไม่บังคับ) โมเดลที่ได้รับการฝึกอบรมล่วงหน้าที่ใช้สำหรับสถานการณ์ เช่น วอร์มสตาร์ท
- (ไม่บังคับ) ไฮเปอร์พารามิเตอร์ ซึ่งจะถูกส่งไปยังฟังก์ชันโมดูลผู้ใช้ ดูรายละเอียดของการทำงานร่วมกับ Tuner ได้ ที่นี่
ผู้ฝึกสอนส่งเสียง: อย่างน้อยหนึ่งรุ่นสำหรับการอนุมาน/การให้บริการ (โดยทั่วไปจะอยู่ใน SavedModelFormat) และอีกรุ่นหนึ่งสำหรับ eval (โดยทั่วไปคือ EvalSavedModel)
เราให้การสนับสนุนรูปแบบโมเดลทางเลือก เช่น TFLite ผ่านทาง Model Rewriting Library ดูลิงก์ไปยัง Model Rewriting Library สำหรับตัวอย่างวิธีแปลงทั้งโมเดล Estimator และ Keras
เทรนเนอร์ทั่วไป
โปรแกรมฝึกสอนทั่วไปช่วยให้นักพัฒนาสามารถใช้ API ของโมเดล TensorFlow กับส่วนประกอบของโปรแกรมฝึกสอนได้ นอกจาก TensorFlow Estimators แล้ว นักพัฒนายังสามารถใช้โมเดล Keras หรือลูปการฝึกแบบกำหนดเองได้ สำหรับรายละเอียด โปรดดู RFC สำหรับผู้ฝึกสอนทั่วไป
การกำหนดค่าส่วนประกอบเทรนเนอร์
รหัส DSL ไปป์ไลน์ทั่วไปสำหรับ Trainer ทั่วไปจะมีลักษณะดังนี้:
from tfx.components import Trainer
...
trainer = Trainer(
module_file=module_file,
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
ผู้ฝึกสอนเรียกใช้โมดูลการฝึกอบรม ซึ่งระบุไว้ในพารามิเตอร์ module_file
แทนที่จะเป็น trainer_fn
จำเป็นต้องใช้ run_fn
ในไฟล์โมดูลหากระบุ GenericExecutor
ใน custom_executor_spec
trainer_fn
มีหน้าที่รับผิดชอบในการสร้างแบบจำลอง นอกจากนั้น run_fn
ยังต้องจัดการส่วนการฝึกและส่งออกโมเดลที่ได้รับการฝึกไปยังตำแหน่งที่ต้องการที่กำหนดโดย FnArgs :
from tfx.components.trainer.fn_args_utils import FnArgs
def run_fn(fn_args: FnArgs) -> None:
"""Build the TF model and train it."""
model = _build_keras_model()
model.fit(...)
# Save model to fn_args.serving_model_dir.
model.save(fn_args.serving_model_dir, ...)
นี่คือ ตัวอย่างไฟล์โมดูล ที่มี run_fn
โปรดทราบว่าหากไม่ได้ใช้ส่วนประกอบ Transform ในไปป์ไลน์ Trainer จะนำตัวอย่างจาก ExampleGen โดยตรง:
trainer = Trainer(
module_file=module_file,
examples=example_gen.outputs['examples'],
schema=infer_schema.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
มีรายละเอียดเพิ่มเติมใน ข้อมูลอ้างอิง Trainer API