Runs training and/or evaluation as configured by the experiment params.
tfm.core.train_lib.run_experiment(
    distribution_strategy: tf.distribute.Strategy,
    task: tfm.core.base_task.Task,
    mode: str,
    params: tfm.core.base_trainer.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
    train_actions: Optional[List[orbit.Action]] = None,
    eval_actions: Optional[List[orbit.Action]] = None,
    trainer: Optional[tfm.core.base_trainer.Trainer] = None,
    controller_cls=orbit.Controller,
    summary_manager: Optional[orbit.utils.SummaryManager] = None,
    eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
    enable_async_checkpointing: bool = False
) -> Tuple[tf.keras.Model, Mapping[str, Any]]
Args:
  distribution_strategy: The tf.distribute.Strategy used to run training and
    evaluation.
  task: A Task instance.
  mode: A str specifying the mode. One of 'train', 'eval', 'train_and_eval',
    or 'continuous_eval'.
  params: An ExperimentConfig instance.
  model_dir: A str, the path where model checkpoints and summaries are stored.
  run_post_eval: Whether to run evaluation once after training. If True, the
    resulting metrics logs are returned.
  save_summary: Whether to save train and validation summaries.
  train_actions: Optional list of Orbit train actions.
  eval_actions: Optional list of Orbit eval actions.
  trainer: The base_trainer.Trainer instance. It should be created within
    distribution_strategy.scope().
  controller_cls: The controller class that manages the train and eval
    process. Must be an orbit.Controller subclass.
  summary_manager: A summary manager instance that overrides the default
    summary manager.
  eval_summary_manager: An eval summary manager instance that overrides the
    default eval summary manager.
  enable_async_checkpointing: Optional boolean indicating whether to enable
    asynchronous checkpoint saving.
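The mode argument selects which training or evaluation loop the controller runs. The following is a simplified sketch, assuming each documented mode maps onto the similarly named orbit.Controller method (train, evaluate, train_and_evaluate, evaluate_continuously). RecordingController and dispatch_mode are hypothetical names invented for illustration, not Model Garden APIs, and the step counts are passed explicitly here rather than read from the experiment config.

```python
from typing import Any, Dict, List, Optional, Tuple


class RecordingController:
    """Hypothetical stand-in for orbit.Controller: records calls instead of running loops."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, Dict[str, Any]]] = []

    def train(self, steps: int) -> None:
        self.calls.append(("train", {"steps": steps}))

    def evaluate(self, steps: int) -> Dict[str, Any]:
        self.calls.append(("evaluate", {"steps": steps}))
        return {}

    def train_and_evaluate(self, train_steps: int, eval_steps: int, eval_interval: int) -> None:
        self.calls.append(("train_and_evaluate", {
            "train_steps": train_steps,
            "eval_steps": eval_steps,
            "eval_interval": eval_interval,
        }))

    def evaluate_continuously(self, steps: int, timeout: Optional[float] = None) -> None:
        self.calls.append(("evaluate_continuously", {"steps": steps, "timeout": timeout}))


def dispatch_mode(
    controller: Any,
    mode: str,
    train_steps: int,
    eval_steps: int,
    eval_interval: int,
    eval_timeout: Optional[float] = None,
) -> None:
    """Hypothetical helper: route one of the four documented modes to a controller loop."""
    if mode == "train":
        controller.train(steps=train_steps)
    elif mode == "eval":
        controller.evaluate(steps=eval_steps)
    elif mode == "train_and_eval":
        # A real orbit.Controller alternates training with evaluation
        # every eval_interval steps.
        controller.train_and_evaluate(
            train_steps=train_steps, eval_steps=eval_steps, eval_interval=eval_interval)
    elif mode == "continuous_eval":
        # A real orbit.Controller evaluates each new checkpoint as it is written.
        controller.evaluate_continuously(steps=eval_steps, timeout=eval_timeout)
    else:
        # This sketch rejects anything outside the documented modes.
        raise ValueError(f"Unsupported mode: {mode!r}")


controller = RecordingController()
dispatch_mode(controller, "train_and_eval", train_steps=1000, eval_steps=50, eval_interval=200)
print(controller.calls)
# [('train_and_evaluate', {'train_steps': 1000, 'eval_steps': 50, 'eval_interval': 200})]
```

Keeping the dispatch as a plain if/elif chain makes it easy to see that each mode triggers exactly one controller loop.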
Returns:
  A 2-tuple of (model, eval_logs).
    model: A tf.keras.Model instance.
    eval_logs: The eval metrics logs when run_post_eval is True; otherwise,
      an empty dict ({}).
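The (model, eval_logs) return contract can be modeled in plain Python. This sketch assumes only what the Returns description states: eval_logs holds the metrics of one evaluation pass when run_post_eval is True, and is an empty dict otherwise. finish_experiment and its evaluate callback are hypothetical stand-ins, not Model Garden APIs, and a placeholder string takes the place of a tf.keras.Model.

```python
from typing import Any, Callable, Mapping, Tuple


def finish_experiment(
    model: Any,
    run_post_eval: bool,
    evaluate: Callable[[], Mapping[str, Any]],
) -> Tuple[Any, Mapping[str, Any]]:
    """Hypothetical helper mirroring the documented (model, eval_logs) contract."""
    if run_post_eval:
        # One extra evaluation pass after training; its metrics become eval_logs.
        return model, evaluate()
    # No post-training evaluation: eval_logs is an empty dict, never None.
    return model, {}


# Usage with placeholder values (a real call returns a tf.keras.Model).
model, eval_logs = finish_experiment("my-model", True, lambda: {"accuracy": 0.91})
print(eval_logs)  # {'accuracy': 0.91}

model, eval_logs = finish_experiment("my-model", False, lambda: {"accuracy": 0.91})
print(eval_logs)  # {}
```

Because eval_logs is always a mapping, callers can iterate over it or pass it to a logger without first checking for None.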