Optimizer factory class.
tfm.optimization.OptimizerFactory(
    config: tfm.optimization.OptimizationConfig
)
This class builds a learning rate schedule and an optimizer based on an optimization config. To use it: (1) define an optimization config, which includes the optimizer and the learning rate schedule; (2) initialize the class with that config; (3) build the learning rate; (4) build the optimizer.
A typical example of using this class:
import tensorflow_models as tfm

params = {
    'optimizer': {
        'type': 'sgd',
        'sgd': {'momentum': 0.9}
    },
    'learning_rate': {
        'type': 'stepwise',
        'stepwise': {'boundaries': [10000, 20000],
                     'values': [0.1, 0.01, 0.001]}
    },
    'warmup': {
        'type': 'linear',
        'linear': {'warmup_steps': 500, 'warmup_learning_rate': 0.01}
    }
}
opt_config = tfm.optimization.OptimizationConfig(params)
opt_factory = tfm.optimization.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
optimizer = opt_factory.build_optimizer(lr)
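The returned objects are standard Keras types, so the built optimizer can be passed to `Model.compile` or used directly in a custom training loop. A minimal sketch, where the toy model and data are placeholders and not part of the factory API:

    import tensorflow as tf

    # Toy model and batch, used only to exercise the built optimizer.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    x = tf.random.normal([8, 4])
    y = tf.random.normal([8, 1])

    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))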
| Args | |
|---|---|
| `config` | An `OptimizationConfig` instance containing the optimization config. |
Methods
build_learning_rate
build_learning_rate()
Build learning rate.
Builds the learning rate from the config. The learning rate schedule is built according to the learning rate config. If the learning rate type is 'constant', `lr_config.learning_rate` is returned.
| Returns |
|---|
| A `tf.keras.optimizers.schedules.LearningRateSchedule` instance. If the learning rate type is 'constant', the float `lr_config.learning_rate` is returned instead. |
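For illustration, with the stepwise config from the example above, the returned schedule can be queried at a given step. This is a sketch; the exact values also depend on the configured warmup wrapping:

    lr = opt_factory.build_learning_rate()
    # With the stepwise config above, lr is a LearningRateSchedule (wrapped
    # with the configured linear warmup) that maps a step to a learning rate.
    print(lr(15000))  # expected 0.01: between the boundaries 10000 and 20000
    print(lr(25000))  # expected 0.001: past the last boundary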
build_optimizer
build_optimizer(
    lr: Union[tf.keras.optimizers.schedules.LearningRateSchedule, float],
    gradient_aggregator: Optional[Callable[[List[Tuple[tf.Tensor, tf.Tensor]]], List[Tuple[tf.Tensor, tf.Tensor]]]] = None,
    gradient_transformers: Optional[List[Callable[[List[Tuple[tf.Tensor, tf.Tensor]]], List[Tuple[tf.Tensor, tf.Tensor]]]]] = None,
    postprocessor: Optional[Callable[[tf.keras.optimizers.Optimizer], tf.keras.optimizers.Optimizer]] = None,
    use_legacy_optimizer: bool = True
)
Build optimizer.
Builds the optimizer from the config. It takes the learning rate as input and builds the optimizer according to the optimizer config. Typically, the learning rate built with self.build_learning_rate() is passed as an argument to this method.
| Args | |
|---|---|
| `lr` | A floating point value, or a `tf.keras.optimizers.schedules.LearningRateSchedule` instance. |
| `gradient_aggregator` | Optional function to overwrite gradient aggregation. |
| `gradient_transformers` | Optional list of functions used to transform gradients before updates are applied to variables. The functions are applied after `gradient_aggregator`. Each function should accept and return a list of `(gradient, variable)` tuples. `clipvalue`, `clipnorm`, and `global_clipnorm` should not be set when `gradient_transformers` is passed. |
| `postprocessor` | An optional function for postprocessing the optimizer. It takes an optimizer and returns an optimizer. |
| `use_legacy_optimizer` | A boolean indicating whether to build a legacy optimizer. |
| Returns |
|---|
| A `tf.keras.optimizers.legacy.Optimizer` or `tf.keras.optimizers.experimental.Optimizer` instance. |
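A sketch of the optional hooks, assuming the `opt_factory` and `lr` built in the example above; the norm-clipping transformer and the mixed-precision postprocessor are illustrative choices, not defaults of the factory:

    def clip_by_norm_transformer(grads_and_vars):
        # Runs after gradient aggregation; must accept and return a list of
        # (gradient, variable) tuples.
        return [(tf.clip_by_norm(g, 1.0), v) for g, v in grads_and_vars]

    optimizer = opt_factory.build_optimizer(
        lr,
        gradient_transformers=[clip_by_norm_transformer],
        # A postprocessor takes an optimizer and returns an optimizer;
        # wrapping with a loss-scale optimizer is one illustrative use.
        postprocessor=lambda opt: tf.keras.mixed_precision.LossScaleOptimizer(opt),
    )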