A single-replica view of the training procedure.
tfm.core.base_task.Task(
params, logging_dir: Optional[str] = None, name: Optional[str] = None
)
Tasks provide the artifacts for training and validation procedures: loading and iterating over datasets, the training and validation steps, and computation of the loss and of customized metrics with reduction.
Attributes | |
---|---|
logging_dir | |
task_config | |
Methods
aggregate_logs
aggregate_logs(
state, step_logs
)
Optional aggregation over logs returned from a validation step.
Given step_logs from a validation step, this function aggregates the logs after each eval_step() (see the eval_reduce() function in official/core/base_trainer.py). It runs on CPU and can be used to aggregate metrics during validation when there are too many metrics to fit into TPU memory. Note that this may increase latency due to data transfer between TPU and CPU. Also, the step output from a validation step may be a tuple with elements from replicas; in that case, the elements need to be concatenated.
Args | |
---|---|
state | The current state of training; for example, it can be a sequence of metrics. |
step_logs | Logs from a validation step. Can be a dictionary. |
build_inputs
@abc.abstractmethod
build_inputs(
params, input_context: Optional[tf.distribute.InputContext] = None
)
Returns a dataset or a nested structure of dataset functions.
Dataset functions define per-host datasets with the per-replica batch size. With distributed training, this method runs on remote hosts.
Args | |
---|---|
params | Hyperparameters used to create the input pipelines; can be a dataclass, ConfigDict, namedtuple, etc. |
input_context | Optional distribution input pipeline context. |
Returns | |
---|---|
A nested structure of per-replica input functions. |
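In TensorFlow, a per-host input function would usually be a tf.data pipeline that shards with dataset.shard(input_context.num_input_pipelines, input_context.input_pipeline_id) and batches with input_context.get_per_replica_batch_size(global_batch_size). The sketch below reproduces that sharding and batching logic in plain Python so it runs without TensorFlow; FakeInputContext and the global_batch_size argument are hypothetical stand-ins, not part of the Task API.

```python
from dataclasses import dataclass
from typing import Iterable, Iterator, List, Optional


@dataclass
class FakeInputContext:
    """Hypothetical stand-in for tf.distribute.InputContext."""
    num_input_pipelines: int = 1
    input_pipeline_id: int = 0
    num_replicas_in_sync: int = 1

    def get_per_replica_batch_size(self, global_batch_size: int) -> int:
        # Require the global batch to split evenly across replicas.
        if global_batch_size % self.num_replicas_in_sync:
            raise ValueError("global_batch_size must be divisible by num_replicas_in_sync")
        return global_batch_size // self.num_replicas_in_sync


def build_inputs(examples: Iterable[int], global_batch_size: int,
                 input_context: Optional[FakeInputContext] = None) -> Iterator[List[int]]:
    """Yields per-replica batches from this host's shard of `examples`."""
    ctx = input_context or FakeInputContext()
    batch_size = ctx.get_per_replica_batch_size(global_batch_size)
    # Each input pipeline (host) keeps every n-th example, offset by its id,
    # like tf.data.Dataset.shard(num_shards, index).
    shard = list(examples)[ctx.input_pipeline_id::ctx.num_input_pipelines]
    # Emit only full batches, similar to batch(..., drop_remainder=True).
    for start in range(0, len(shard) - batch_size + 1, batch_size):
        yield shard[start:start + batch_size]


# Usage: host 1 of 2, with 4 replicas sharing a global batch of 8.
ctx = FakeInputContext(num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=4)
print(list(build_inputs(range(16), global_batch_size=8, input_context=ctx)))
# [[1, 3], [5, 7], [9, 11], [13, 15]]
```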
build_losses
build_losses(
labels, model_outputs, aux_losses=None
) -> tf.Tensor
Standard interface to compute losses.
Args | |
---|---|
labels | Optional label tensors. |
model_outputs | A nested structure of output tensors. |
aux_losses | Auxiliary loss tensors, i.e., the losses collected in keras.Model.losses. |
Returns | |
---|---|
The total loss tensor. |
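A common way to implement this interface is to average a per-example loss over the batch and add the auxiliary losses (for example, regularization terms from keras.Model.losses). The plain-Python sketch below assumes mean squared error as the per-example loss and scalar auxiliary losses; a real task would operate on tensors with ops such as tf.reduce_mean and tf.add_n.

```python
from typing import Optional, Sequence


def build_losses(labels: Sequence[float], model_outputs: Sequence[float],
                 aux_losses: Optional[Sequence[float]] = None) -> float:
    """Mean squared error over the batch plus any auxiliary losses."""
    per_example = [(y - p) ** 2 for y, p in zip(labels, model_outputs)]
    total_loss = sum(per_example) / len(per_example)
    if aux_losses:
        # Auxiliary terms (e.g. regularization) are added to the averaged loss.
        total_loss += sum(aux_losses)
    return total_loss


print(build_losses([1.0, 2.0], [1.0, 4.0], aux_losses=[0.5]))  # 2.5
```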
build_metrics
build_metrics(
training: bool = True
)
Gets streaming metrics for training/validation.
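"Streaming" here means the metric accumulates state across batches and reports a result on demand. The sketch below shows that idea with a hypothetical plain-Python mean metric that follows the Keras update_state/result/reset_state naming, plus a build_metrics that uses the training flag to return a different (assumed) metric set for validation.

```python
from typing import List, Sequence


class StreamingMean:
    """Running mean with a Keras-like update_state/result/reset_state API."""

    def __init__(self, name: str):
        self.name = name
        self.reset_state()

    def reset_state(self) -> None:
        self._total = 0.0
        self._count = 0

    def update_state(self, values: Sequence[float]) -> None:
        # Accumulates across batches instead of overwriting, hence "streaming".
        self._total += sum(values)
        self._count += len(values)

    def result(self) -> float:
        return self._total / self._count if self._count else 0.0


def build_metrics(training: bool = True) -> List[StreamingMean]:
    """Hypothetical split: extra metrics are only tracked during validation."""
    metrics = [StreamingMean("loss")]
    if not training:
        metrics.append(StreamingMean("abs_error"))
    return metrics


loss_metric = build_metrics(training=True)[0]
loss_metric.update_state([1.0, 2.0])  # batch 1
loss_metric.update_state([3.0])       # batch 2
print(loss_metric.result())  # 2.0
print([m.name for m in build_metrics(training=False)])  # ['loss', 'abs_error']
```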
build_model
build_model() -> tf.keras.Model
[Optional] Creates model architecture.
Returns | |
---|---|
A model instance. |
create_optimizer
@classmethod
create_optimizer(
optimizer_config: tfm.optimization.OptimizationConfig,
runtime_config: Optional[tfm.core.base_task.RuntimeConfig] = None,
dp_config: Optional[tfm.core.base_task.DifferentialPrivacyConfig] = None
)
Creates a TF optimizer from configurations.
Args | |
---|---|
optimizer_config | The parameters of the optimization settings. |
runtime_config | The parameters of the runtime. |
dp_config | The parameters of differential privacy. |
Returns | |
---|---|
A tf.optimizers.Optimizer object. |
inference_step
inference_step(
inputs, model: tf.keras.Model
)
Performs the forward step.
With distribution strategies, this method runs on devices.
Args | |
---|---|
inputs | A dictionary of input tensors. |
model | The keras.Model. |
Returns | |
---|---|
Model outputs. |
initialize
initialize(
model: tf.keras.Model
)
[Optional] A callback function used as CheckpointManager's init_fn.
This function is called only when no checkpoint is found for the model. If a checkpoint exists, it is loaded instead and this function is not called. Use this callback to load a pretrained checkpoint saved in a directory other than model_dir.
Args | |
---|---|
model | The keras.Model built or used by this task. |
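The "restore if present, otherwise initialize" control flow can be sketched in plain Python. It is modeled on tf.train.CheckpointManager.restore_or_initialize() and its init_fn argument, but the on-disk layout (one file per checkpoint named ckpt-&lt;step&gt;) and the restore callback are hypothetical simplifications, not TensorFlow's real checkpoint format.

```python
import os
import tempfile
from typing import Callable, Optional


def restore_or_initialize(model_dir: str,
                          restore: Callable[[str], None],
                          init_fn: Optional[Callable[[], None]] = None) -> Optional[str]:
    """Restores the newest checkpoint in model_dir; otherwise calls init_fn."""
    # Hypothetical layout: one file per checkpoint, named "ckpt-<step>".
    names = [n for n in os.listdir(model_dir) if n.startswith("ckpt-")]
    if names:
        newest = max(names, key=lambda n: int(n[len("ckpt-"):]))
        latest = os.path.join(model_dir, newest)
        restore(latest)  # A checkpoint exists, so init_fn is NOT called.
        return latest
    if init_fn is not None:
        init_fn()  # No checkpoint: e.g. Task.initialize(model) loads pretrained weights.
    return None


with tempfile.TemporaryDirectory() as model_dir:
    calls = []
    restore_or_initialize(model_dir, calls.append, lambda: calls.append("init"))
    open(os.path.join(model_dir, "ckpt-10"), "w").close()
    restore_or_initialize(model_dir, calls.append, lambda: calls.append("init"))
    print(calls[0], os.path.basename(calls[1]))  # init ckpt-10
```

In a Task subclass, initialize(model) plays the role of init_fn, typically restoring pretrained weights from a directory other than model_dir.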
process_compiled_metrics
process_compiled_metrics(
compiled_metrics, labels, model_outputs
)
Processes and updates compiled_metrics.
Called when using the compile/fit API.
Args | |
---|---|
compiled_metrics | The compiled metrics (model.compiled_metrics). |
labels | A tensor or a nested structure of tensors. |
model_outputs | A tensor or a nested structure of tensors; for example, the output of the Keras model built by self.build_model. |
process_metrics
process_metrics(
metrics, labels, model_outputs, **kwargs
)
Processes and updates metrics.
Called when using a custom training loop.
Args | |
---|---|
metrics | A nested structure of metrics objects, as returned by self.build_metrics. |
labels | A tensor or a nested structure of tensors. |
model_outputs | A tensor or a nested structure of tensors; for example, the output of the Keras model built by self.build_model. |
**kwargs | Other arguments. |
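For metrics that follow the Keras update_state(y_true, y_pred) convention, processing can be as simple as looping over what build_metrics returned and feeding each metric the step's labels and outputs. The sketch assumes a flat list of metrics and uses a hypothetical plain-Python Accuracy class in place of a tf.keras.metrics object; process_compiled_metrics follows a similar idea, typically forwarding labels and model_outputs to compiled_metrics.update_state.

```python
from typing import List, Sequence


class Accuracy:
    """Hypothetical streaming accuracy with a Keras-style update_state."""

    def __init__(self, name: str = "accuracy"):
        self.name = name
        self.correct = 0
        self.total = 0

    def update_state(self, labels: Sequence[int], model_outputs: Sequence[int]) -> None:
        self.correct += sum(int(y == p) for y, p in zip(labels, model_outputs))
        self.total += len(labels)

    def result(self) -> float:
        return self.correct / self.total if self.total else 0.0


def process_metrics(metrics: List[Accuracy], labels: Sequence[int],
                    model_outputs: Sequence[int], **kwargs) -> None:
    """Updates every metric with this step's labels and model outputs."""
    del kwargs  # Unused here; a subclass might read extra inputs from it.
    for metric in metrics:
        metric.update_state(labels, model_outputs)


metrics = [Accuracy()]
process_metrics(metrics, labels=[1, 0, 1, 1], model_outputs=[1, 1, 1, 0])
print(metrics[0].result())  # 0.5
```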
reduce_aggregated_logs
reduce_aggregated_logs(
aggregated_logs, global_step: Optional[tf.Tensor] = None
)
Optional reduction of aggregated logs over validation steps.
This function reduces the aggregated logs at the end of validation and can be used to compute the final metrics. It runs on CPU, in each eval_end() of the base trainer (see the eval_end() function in official/core/base_trainer.py).
Args | |
---|---|
aggregated_logs | Aggregated logs over multiple validation steps. |
global_step | An optional global step variable. |
Returns | |
---|---|
A dictionary of reduced results. |
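A sketch of the final reduction, assuming the aggregated logs are a dict of flat lists under hypothetical "predictions" and "labels" keys (for example, what an aggregate_logs override might have collected). It computes one dataset-level accuracy after the last validation step and returns it in a dictionary; the metric choice is illustrative only.

```python
from typing import Dict, List, Optional


def reduce_aggregated_logs(aggregated_logs: Dict[str, List[int]],
                           global_step: Optional[int] = None) -> Dict[str, float]:
    """Turns values collected over all validation steps into final metrics."""
    del global_step  # Not needed for this hypothetical metric.
    predictions = aggregated_logs["predictions"]
    labels = aggregated_logs["labels"]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels) if labels else 0.0}


print(reduce_aggregated_logs({"predictions": [1, 0, 1, 0], "labels": [1, 1, 1, 0]}))
# {'accuracy': 0.75}
```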
train_step
train_step(
inputs,
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics=None
)
Performs the forward and backward passes.
With distribution strategies, this method runs on devices.
Args | |
---|---|
inputs | A dictionary of input tensors. |
model | The model, which defines the forward pass. |
optimizer | The optimizer for this training step. |
metrics | A nested structure of metrics objects. |
Returns | |
---|---|
A dictionary of logs. |
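The forward, backward, and update sequence can be illustrated without TensorFlow. In Model Garden, a train_step typically runs the forward pass under tf.GradientTape, gets gradients with tape.gradient, and applies them with optimizer.apply_gradients; here a hypothetical one-feature linear model with hand-derived gradients and plain SGD stands in for all three. The returned dict is keyed by "loss", matching the loss class variable listed at the end of this page.

```python
from typing import Dict, List, Sequence, Tuple

LOSS_KEY = "loss"  # Same value as the Task.loss class variable.


def train_step(inputs: Tuple[Sequence[float], Sequence[float]],
               weights: List[float], learning_rate: float = 0.1) -> Dict[str, float]:
    """One SGD step for y = w * x + b with mean squared error."""
    features, labels = inputs
    w, b = weights
    n = len(features)
    # Forward pass and loss.
    preds = [w * x + b for x in features]
    errors = [p - y for p, y in zip(preds, labels)]
    loss = sum(e * e for e in errors) / n
    # Backward pass: analytic gradients of the mean squared error.
    grad_w = 2.0 * sum(e * x for e, x in zip(errors, features)) / n
    grad_b = 2.0 * sum(errors) / n
    # Optimizer update (plain SGD), applied to the weights in place.
    weights[0] = w - learning_rate * grad_w
    weights[1] = b - learning_rate * grad_b
    return {LOSS_KEY: loss}


weights = [0.0, 0.0]
batch = ([1.0, 2.0], [2.0, 4.0])
print(train_step(batch, weights))  # {'loss': 10.0}
print(weights)  # updated in place, approximately [1.0, 0.6]
```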
validation_step
validation_step(
inputs, model: tf.keras.Model, metrics=None
)
Validation step.
With distribution strategies, this method runs on devices.
Args | |
---|---|
inputs | A dictionary of input tensors. |
model | The keras.Model. |
metrics | A nested structure of metrics objects. |
Returns | |
---|---|
A dictionary of logs. |
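A validation step shares the forward pass and loss with training but computes no gradients and leaves the weights untouched; it also updates the streaming metrics. The sketch uses the same kind of hypothetical one-feature linear model and a hypothetical MeanAbsoluteError class in place of Keras objects.

```python
from typing import Dict, List, Optional, Sequence, Tuple


class MeanAbsoluteError:
    """Hypothetical streaming metric with a Keras-style update_state."""

    def __init__(self) -> None:
        self.total, self.count = 0.0, 0

    def update_state(self, labels: Sequence[float], preds: Sequence[float]) -> None:
        self.total += sum(abs(y - p) for y, p in zip(labels, preds))
        self.count += len(labels)

    def result(self) -> float:
        return self.total / self.count if self.count else 0.0


def validation_step(inputs: Tuple[Sequence[float], Sequence[float]],
                    weights: List[float],
                    metrics: Optional[List[MeanAbsoluteError]] = None) -> Dict[str, float]:
    """Forward pass and loss only: no gradients, no weight update."""
    features, labels = inputs
    w, b = weights
    preds = [w * x + b for x in features]
    loss = sum((p - y) ** 2 for p, y in zip(preds, labels)) / len(labels)
    for metric in metrics or []:
        metric.update_state(labels, preds)
    return {"loss": loss}


mae = MeanAbsoluteError()
print(validation_step(([1.0, 2.0], [2.0, 4.0]), [1.0, 1.0], [mae]))  # {'loss': 0.5}
print(mae.result())  # 0.5
```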
Class Variables | |
---|---|
loss | 'loss' |