KerasTuner interface implementation backed by Vizier Service.
tfc.CloudTuner(
hypermodel: Union[hypermodel_module.HyperModel, Callable[[hp_module.HyperParameters],
tf.keras.Model]],
project_id: Text,
region: Text,
objective: Union[Text, oracle_module.Objective] = None,
hyperparameters: hp_module.HyperParameters = None,
study_config: Optional[Dict[Text, Any]] = None,
max_trials: int = None,
study_id: Optional[Text] = None,
**kwargs
)
CloudTuner is an implementation of KerasTuner that uses Google Cloud Vizier Service as its Oracle. To learn more about KerasTuner and Oracles, refer to:
- https://keras-team.github.io/keras-tuner/
- https://keras-team.github.io/keras-tuner/documentation/oracles/
Example:
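import tensorflow_cloud as tfc

# build_model (a callable that takes HyperParameters and returns a Keras
# Model) and HPS (a HyperParameters instance) must be defined beforehand.
tuner = tfc.CloudTuner(
    build_model,
    project_id="MY_PROJECT_ID",
    region="us-central1",
    objective="accuracy",
    hyperparameters=HPS,
    max_trials=5,
    directory="tmp/MY_JOB")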
Args | |
---|---|
hypermodel | Instance of HyperModel class (or callable that takes hyperparameters and returns a Model instance). |
project_id | A GCP project ID. |
region | A GCP region, e.g. 'us-central1'. |
objective | Name of the model metric to minimize or maximize, e.g. "val_accuracy". |
hyperparameters | Can be used to override (or register in advance) hyperparameters in the search space. |
study_config | Study configuration for the Vizier service. |
max_trials | Total number of trials (model configurations) to test at most. Note that the oracle may interrupt the search before max_trials models have been tested if the search space has been exhausted. |
study_id | An identifier of the study. The full study name will be projects/{project_id}/locations/{region}/studies/{study_id}. |
**kwargs | Keyword arguments relevant to all Tuner subclasses. See the docstring for Tuner. |
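The study_id entry above implies a fixed resource-name layout. As a minimal sketch, the full Vizier study name could be assembled like this. `vizier_study_name` is a hypothetical helper for illustration and is not part of tensorflow_cloud.

```python
def vizier_study_name(project_id: str, region: str, study_id: str) -> str:
    """Builds the full study name from its parts (hypothetical helper).

    Follows the layout documented for study_id:
    projects/{project_id}/locations/{region}/studies/{study_id}
    """
    for label, value in (("project_id", project_id), ("region", region),
                         ("study_id", study_id)):
        if not value or "/" in value:
            # An empty part or a "/" would corrupt the path segments.
            raise ValueError(f"invalid {label}: {value!r}")
    return f"projects/{project_id}/locations/{region}/studies/{study_id}"


print(vizier_study_name("MY_PROJECT_ID", "us-central1", "my_study"))
# projects/MY_PROJECT_ID/locations/us-central1/studies/my_study
```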
Attributes | |
---|---|
project_dir | |
remaining_trials | Returns the number of trials remaining. Will return None if max_trials is not set. |
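The max_trials argument and the remaining_trials attribute together bound the search. The search stops once max_trials configurations have been tried, or earlier if the search space runs out. The sketch below models that budget in plain Python, assuming the remaining count is None when max_trials is unset. `remaining_trials`, `run_budgeted_search`, and the in-memory candidate list are hypothetical stand-ins for the Vizier-backed oracle, which actually suggests configurations remotely.

```python
from typing import Dict, Iterable, List, Optional


def remaining_trials(max_trials: Optional[int], completed: int) -> Optional[int]:
    """Trials left in the budget; None when no max_trials limit is set (assumed)."""
    if max_trials is None:
        return None
    return max(max_trials - completed, 0)


def run_budgeted_search(candidates: Iterable[Dict[str, float]],
                        max_trials: Optional[int]) -> List[Dict[str, float]]:
    """Tries candidates until the budget or the search space is exhausted."""
    tried: List[Dict[str, float]] = []
    for config in candidates:  # the loop also ends when the space is exhausted
        if remaining_trials(max_trials, len(tried)) == 0:
            break
        tried.append(config)  # stand-in for building and training one model
    return tried


space = [{"learning_rate": lr} for lr in (1e-1, 1e-2, 1e-3)]
print(len(run_budgeted_search(space, max_trials=5)))  # 3: space exhausted first
print(len(run_budgeted_search(space, max_trials=2)))  # 2: budget reached first
```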
Methods
get_best_hyperparameters
get_best_hyperparameters(
num_trials=1
)
Returns the best hyperparameters, as determined by the objective.
This method can be used to reinstantiate the (untrained) best model found during the search process.
Example:
best_hp = tuner.get_best_hyperparameters()[0]
model = tuner.hypermodel.build(best_hp)
Arguments | |
---|---|
num_trials | (int, optional). Number of HyperParameters objects to return. HyperParameters will be returned in sorted order based on trial performance. |
Returns | |
---|---|
List of HyperParameters objects. |
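get_best_hyperparameters returns HyperParameters sorted by trial performance, and the objective may be either minimized or maximized. Here is a minimal sketch of that ranking over completed trials. It uses a hypothetical `TrialResult` record and plain dicts in place of KerasTuner's Trial and HyperParameters classes.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class TrialResult:
    """Hypothetical stand-in for a completed trial."""
    hyperparameters: Dict[str, Any]
    score: float  # value of the objective metric for this trial


def best_hyperparameters(trials: List[TrialResult], direction: str = "max",
                         num_trials: int = 1) -> List[Dict[str, Any]]:
    """Returns hyperparameters of the top num_trials trials, best first."""
    if direction not in ("min", "max"):
        raise ValueError("direction must be 'min' or 'max'")
    ranked = sorted(trials, key=lambda t: t.score, reverse=(direction == "max"))
    return [t.hyperparameters for t in ranked[:num_trials]]


trials = [
    TrialResult({"units": 32}, score=0.81),
    TrialResult({"units": 64}, score=0.88),
    TrialResult({"units": 128}, score=0.85),
]
print(best_hyperparameters(trials, direction="max", num_trials=2))
# [{'units': 64}, {'units': 128}]
```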
get_best_models
get_best_models(
num_models=1
)
Returns the best model(s), as determined by the tuner's objective.
The models are loaded with the weights corresponding to their best checkpoint (at the end of the best epoch of the best trial).
This method is only a convenience shortcut. For best performance, it is recommended to retrain your Model on the full dataset using the best hyperparameters found during search.
Args | |
---|---|
num_models | (int, optional). Number of best models to return. Models will be returned in sorted order. Defaults to 1. |
Returns | |
---|---|
List of trained model instances. |
get_state
get_state()
Returns the current state of this object.
This method is called during save.
get_trial_dir
get_trial_dir(
trial_id
)
load_model
load_model(
trial
)
Loads a Model from a given trial.
Arguments | |
---|---|
trial | A Trial instance. For models that report intermediate results to the Oracle, load_model should generally load the best reported step by relying on trial.best_step. |
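For models that report intermediate results, load_model is expected to load the best reported step (trial.best_step), and get_best_models uses the checkpoint from the best epoch. As an illustration only, here is one way to track that best step from per-epoch objective values. The `BestStepTracker` class is hypothetical and not part of KerasTuner or tensorflow_cloud.

```python
from typing import Optional


class BestStepTracker:
    """Remembers the step whose objective value was best so far (hypothetical)."""

    def __init__(self, direction: str = "min") -> None:
        if direction not in ("min", "max"):
            raise ValueError("direction must be 'min' or 'max'")
        self.direction = direction
        self.best_step: Optional[int] = None
        self.best_value: Optional[float] = None

    def update(self, step: int, value: float) -> bool:
        """Records value for step; returns True if it is a new best."""
        improved = (
            self.best_value is None
            or (self.direction == "min" and value < self.best_value)
            or (self.direction == "max" and value > self.best_value)
        )
        if improved:
            self.best_step, self.best_value = step, value
        return improved


tracker = BestStepTracker(direction="min")
for epoch, val_loss in enumerate([0.90, 0.62, 0.55, 0.58, 0.60]):
    tracker.update(epoch, val_loss)
print(tracker.best_step, tracker.best_value)  # 2 0.55
```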
on_batch_begin
on_batch_begin(
trial, model, batch, logs
)
A hook called at the start of every batch.
Arguments | |
---|---|
trial | A Trial instance. |
model | A Keras Model. |
batch | The current batch number within the current epoch. |
logs | Additional metrics. |
on_batch_end
on_batch_end(
trial, model, batch, logs=None
)
A hook called at the end of every batch.
Arguments | |
---|---|
trial | A Trial instance. |
model | A Keras Model. |
batch | The current batch number within the current epoch. |
logs | Additional metrics. |
on_epoch_begin
on_epoch_begin(
trial, model, epoch, logs=None
)
A hook called at the start of every epoch.
Arguments | |
---|---|
trial | A Trial instance. |
model | A Keras Model. |
epoch | The current epoch number. |
logs | Additional metrics. |
on_epoch_end
on_epoch_end(
trial, model, epoch, logs=None
)
A hook called at the end of every epoch.
Arguments | |
---|---|
trial | A Trial instance. |
model | A Keras Model. |
epoch | The current epoch number. |
logs | Dict. Metrics for this epoch. This should include the value of the objective for this epoch. |
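on_epoch_end expects logs to carry the objective's value for the epoch. Below is a small, hypothetical check that pulls that value out of a Keras-style logs dict and fails loudly if the metric was not logged. `objective_value` is an illustrative helper, not a tensorflow_cloud function.

```python
from typing import Dict, Optional


def objective_value(logs: Optional[Dict[str, float]], objective: str) -> float:
    """Returns logs[objective], with a clear error if the metric is missing."""
    if not logs or objective not in logs:
        available = sorted(logs) if logs else []
        raise KeyError(f"objective {objective!r} not in epoch logs; got {available}")
    return float(logs[objective])


logs = {"loss": 0.42, "accuracy": 0.87, "val_loss": 0.47, "val_accuracy": 0.85}
print(objective_value(logs, "val_accuracy"))  # 0.85
```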
on_search_begin
on_search_begin()
A hook called at the beginning of search.
on_search_end
on_search_end()
A hook called at the end of search.
on_trial_begin
on_trial_begin(
trial
)
A hook called before starting each trial.
Arguments | |
---|---|
trial | A Trial instance. |
on_trial_end
on_trial_end(
trial
)
A hook called after each trial is run.
Arguments | |
---|---|
trial | A Trial instance. |
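The hooks above nest inside one another. Search-level hooks wrap every trial, trial-level hooks wrap every epoch, and epoch-level hooks wrap every batch. The sketch below only illustrates that nesting by recording the order in which a hypothetical driver loop would fire them. It does not reproduce KerasTuner's actual training loop.

```python
from typing import List


def hook_order(num_trials: int, epochs: int, batches: int) -> List[str]:
    """Lists hook calls in the nesting order described above (illustrative)."""
    calls = ["on_search_begin"]
    for trial in range(num_trials):
        calls.append(f"on_trial_begin(trial={trial})")
        for epoch in range(epochs):
            calls.append(f"on_epoch_begin(epoch={epoch})")
            for batch in range(batches):
                calls.append(f"on_batch_begin(batch={batch})")
                calls.append(f"on_batch_end(batch={batch})")
            # logs passed here should include the objective value for the epoch
            calls.append(f"on_epoch_end(epoch={epoch})")
        calls.append(f"on_trial_end(trial={trial})")
    calls.append("on_search_end")
    return calls


for call in hook_order(num_trials=1, epochs=1, batches=2):
    print(call)
```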
reload
reload()
Reloads this object from its project directory.
results_summary
results_summary(
num_trials=10
)
Display tuning results summary.
Args | |
---|---|
num_trials | (int, optional). Number of trials to display. Defaults to 10. |
run_trial
run_trial(
trial, *fit_args, **fit_kwargs
)
Evaluates a set of hyperparameter values.
This method is called during search to evaluate a set of hyperparameters.
Arguments | |
---|---|
trial | A Trial instance that contains the information needed to run this trial. Hyperparameters can be accessed via trial.hyperparameters. |
*fit_args | Positional arguments passed by search. |
**fit_kwargs | Keyword arguments passed by search. |
save
save()
Saves this object to its project directory.
save_model
save_model(
trial_id, model, step=0
)
Saves a Model for a given trial.
Arguments | |
---|---|
trial_id | The ID of the Trial that corresponds to this Model. |
model | The trained model. |
step | For models that report intermediate results to the Oracle, the step that this saved file should correspond to. For example, for Keras models this is the number of epochs trained. |
search
search(
*fit_args, **fit_kwargs
)
Performs a search for the best hyperparameter configurations.
Arguments | |
---|---|
*fit_args | Positional arguments that should be passed to run_trial, for example the training and validation data. |
**fit_kwargs | Keyword arguments that should be passed to run_trial, for example the training and validation data. |
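search forwards its *fit_args and **fit_kwargs unchanged to run_trial for every trial. The stripped-down `MiniTuner` below shows that pass-through. It is hypothetical, does no training, and is not the real Tuner API.

```python
from typing import Any, Dict, List, Tuple


class MiniTuner:
    """Hypothetical tuner that only demonstrates argument forwarding."""

    def __init__(self, configs: List[Dict[str, Any]]) -> None:
        self.configs = configs
        self.calls: List[Tuple[Dict[str, Any], tuple, dict]] = []

    def run_trial(self, trial: Dict[str, Any], *fit_args: Any,
                  **fit_kwargs: Any) -> None:
        # A real run_trial would build a model from the trial's hyperparameters
        # and train it with these arguments; this sketch only records them.
        self.calls.append((trial, fit_args, fit_kwargs))

    def search(self, *fit_args: Any, **fit_kwargs: Any) -> None:
        for trial in self.configs:
            self.run_trial(trial, *fit_args, **fit_kwargs)


x_train, y_train = [[0.0], [1.0]], [0, 1]
tuner = MiniTuner(configs=[{"units": 32}, {"units": 64}])
tuner.search(x_train, y_train, epochs=3, validation_split=0.5)
for trial, args, kwargs in tuner.calls:
    print(trial, len(args), kwargs)
```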
search_space_summary
search_space_summary(
extended=False
)
Print search space summary.
Args | |
---|---|
extended | Bool, optional. Display extended summary. Defaults to False. |
set_state
set_state(
state
)
Sets the current state of this object.
This method is called during reload.
Arguments | |
---|---|
state | Dict. The state to restore for this object. |
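save stores the object's state (from get_state) in its project directory, and reload reads it back and applies it with set_state. That round trip can be sketched with JSON in a temporary directory. The `StatefulObject` class, the `state.json` file name, and the JSON encoding are all assumptions made for illustration. They do not describe the on-disk format that tensorflow_cloud or KerasTuner actually uses.

```python
import json
import os
import tempfile
from typing import Any, Dict, Optional


class StatefulObject:
    """Hypothetical object following the get_state/set_state/save/reload pattern."""

    def __init__(self, project_dir: str) -> None:
        self.project_dir = project_dir
        self.completed_trials = 0
        self.best_score: Optional[float] = None

    def get_state(self) -> Dict[str, Any]:
        return {"completed_trials": self.completed_trials,
                "best_score": self.best_score}

    def set_state(self, state: Dict[str, Any]) -> None:
        self.completed_trials = state["completed_trials"]
        self.best_score = state["best_score"]

    def _state_path(self) -> str:
        return os.path.join(self.project_dir, "state.json")  # assumed file name

    def save(self) -> None:
        with open(self._state_path(), "w") as f:
            json.dump(self.get_state(), f)

    def reload(self) -> None:
        with open(self._state_path()) as f:
            self.set_state(json.load(f))


with tempfile.TemporaryDirectory() as project_dir:
    original = StatefulObject(project_dir)
    original.completed_trials, original.best_score = 3, 0.91
    original.save()

    restored = StatefulObject(project_dir)
    restored.reload()
    print(restored.get_state())  # {'completed_trials': 3, 'best_score': 0.91}
```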