A single-replica view of the training procedure.
Inherits From: Task
tfm.vision.MaskRCNNTask(
params, logging_dir: Optional[str] = None, name: Optional[str] = None
)
The Mask R-CNN task provides artifacts for training and evaluation procedures, including loading and iterating over datasets, initializing the model, calculating the loss, post-processing, and customized metrics with reduction.
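In most workflows these artifacts are consumed by a trainer rather than called directly. The pure-Python sketch below shows, under simplifying assumptions, roughly the order in which a trainer might invoke the task's hooks. `RecordingTask` and `run_experiment_sketch` are hypothetical names invented for this illustration. They do not reproduce the Model Garden trainer, which also handles distribution strategies, checkpoint restore, and summaries.

```python
from typing import Any, Dict, List


class RecordingTask:
    """Hypothetical stand-in that only records which Task hook was called."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def build_model(self) -> Dict[str, float]:
        self.calls.append("build_model")
        return {"w": 1.0}  # toy "model": a single weight

    def initialize(self, model: Dict[str, float]) -> None:
        self.calls.append("initialize")  # real task: load a pretrained checkpoint

    def build_inputs(self, split: str) -> List[int]:
        # Stands in for build_inputs(params) called with a train or eval DataConfig.
        self.calls.append(f"build_inputs:{split}")
        return [1, 2]  # two toy "batches"

    def build_metrics(self, training: bool = True) -> List[Any]:
        self.calls.append(f"build_metrics:{'train' if training else 'eval'}")
        return []

    def train_step(self, inputs, model, optimizer=None, metrics=None):
        self.calls.append("train_step")
        return {"loss": 0.0}

    def validation_step(self, inputs, model, metrics=None):
        self.calls.append("validation_step")
        return {"loss": 0.0}


def run_experiment_sketch(task: RecordingTask) -> List[str]:
    """Drives the hooks in a plausible trainer order and returns the call log."""
    model = task.build_model()
    task.initialize(model)
    train_metrics = task.build_metrics(training=True)
    for batch in task.build_inputs("train"):
        task.train_step(batch, model, metrics=train_metrics)
    eval_metrics = task.build_metrics(training=False)
    for batch in task.build_inputs("eval"):
        task.validation_step(batch, model, metrics=eval_metrics)
    return task.calls
```

Recording calls instead of doing real work keeps the focus on ordering: the model is built and initialized once, then each dataset is iterated with the matching step method.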
Attributes | |
---|---|
logging_dir | |
task_config | |
Methods
aggregate_logs
aggregate_logs(
state: Optional[Any] = None, step_outputs: Optional[Dict[str, Any]] = None
) -> Optional[Any]
Optional aggregation over logs returned from a validation step.
build_inputs
build_inputs(
params: tfm.vision.configs.maskrcnn.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None,
dataset_fn: Optional[dataset_fn_lib.PossibleDatasetType] = None
) -> tf.data.Dataset
Builds input dataset.
build_losses
build_losses(
outputs: Mapping[str, Any],
labels: Mapping[str, Any],
aux_losses: Optional[Any] = None
) -> Dict[str, tf.Tensor]
Builds Mask R-CNN losses.
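Mask R-CNN trains several heads at once, so its loss combines per-head terms. The sketch below assumes a weighted sum of the region proposal network (RPN) objectness and box losses, the Fast R-CNN classification and box losses, and the mask loss, with auxiliary losses (for example weight regularization passed as `aux_losses`) added on top. The function name, the dictionary keys, and the default weight of 1.0 are assumptions for illustration; check the task's loss config for the actual names and values.

```python
from typing import Dict, Optional, Sequence


def combine_maskrcnn_losses(
    per_head: Dict[str, float],
    weights: Optional[Dict[str, float]] = None,
    aux_losses: Optional[Sequence[float]] = None,
) -> Dict[str, float]:
    """Hypothetical weighted combination of per-head Mask R-CNN losses."""
    weights = weights or {}
    # Weighted sum of the detection and mask head losses (missing weights default to 1.0).
    model_loss = sum(weights.get(name, 1.0) * value for name, value in per_head.items())
    # Auxiliary terms (e.g. L2 regularization) are added unweighted in this sketch.
    total_loss = model_loss + sum(aux_losses or [])
    losses = dict(per_head)
    losses["model_loss"] = model_loss
    losses["total_loss"] = total_loss
    return losses


losses = combine_maskrcnn_losses(
    {"rpn_score_loss": 0.5, "rpn_box_loss": 0.25, "frcnn_cls_loss": 1.0,
     "frcnn_box_loss": 0.5, "mask_loss": 0.75},
    weights={"rpn_box_loss": 2.0},
    aux_losses=[0.1],
)
```

Returning the individual terms alongside the totals lets each head's loss be logged separately while only `total_loss` drives the gradient.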
build_metrics
build_metrics(
training: bool = True
)
Builds detection metrics.
build_model
build_model()
Builds Mask R-CNN model.
create_optimizer
@classmethod
create_optimizer(
optimizer_config: tfm.optimization.OptimizationConfig,
runtime_config: Optional[tfm.core.base_task.RuntimeConfig] = None,
dp_config: Optional[tfm.core.base_task.DifferentialPrivacyConfig] = None
)
Creates a TF optimizer from configurations.
Args | |
---|---|
optimizer_config | the parameters of the optimization settings. |
runtime_config | the parameters of the runtime. |
dp_config | the parameters of differential privacy. |
Returns | |
---|---|
A tf.optimizers.Optimizer object. |
inference_step
inference_step(
inputs, model: tf.keras.Model
)
Performs the forward step.
With distribution strategies, this method runs on devices.
Args | |
---|---|
inputs | a dictionary of input tensors. |
model | the keras.Model. |
Returns | |
---|---|
Model outputs. |
initialize
initialize(
model: tf.keras.Model
)
Loads a pretrained checkpoint.
process_compiled_metrics
process_compiled_metrics(
compiled_metrics, labels, model_outputs
)
Process and update compiled_metrics.
Called when using the compile/fit API.
Args | |
---|---|
compiled_metrics | the compiled metrics (model.compiled_metrics). |
labels | a tensor or a nested structure of tensors. |
model_outputs | a tensor or a nested structure of tensors. For example, the output of the Keras model built by self.build_model. |
process_metrics
process_metrics(
metrics, labels, model_outputs, **kwargs
)
Process and update metrics.
Called when using a custom training loop.
Args | |
---|---|
metrics | a nested structure of metrics objects, as returned by self.build_metrics. |
labels | a tensor or a nested structure of tensors. |
model_outputs | a tensor or a nested structure of tensors. For example, the output of the Keras model built by self.build_model. |
**kwargs | other arguments. |
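In a custom training loop, `process_metrics` feeds each step's labels and outputs into the metric objects from `build_metrics`, which keep running state until they are read. The pure-Python sketch below imitates that contract with a hypothetical `RunningMean` class exposing `update_state`, `result`, and `reset_state` in the style of Keras metrics, plus a hypothetical `process_metrics_sketch` helper that walks a flat list of them. It illustrates the update pattern only and is not the Mask R-CNN implementation.

```python
from typing import Callable, List, Sequence


class RunningMean:
    """Hypothetical stand-in for a Keras-style metric with running state."""

    def __init__(self, name: str, fn: Callable[[float, float], float]) -> None:
        self.name = name
        self.fn = fn  # per-example score, e.g. absolute error
        self.total = 0.0
        self.count = 0

    def update_state(self, label: float, output: float) -> None:
        self.total += self.fn(label, output)
        self.count += 1

    def result(self) -> float:
        return self.total / self.count if self.count else 0.0

    def reset_state(self) -> None:
        self.total, self.count = 0.0, 0


def process_metrics_sketch(metrics: List[RunningMean],
                           labels: Sequence[float],
                           model_outputs: Sequence[float]) -> None:
    """Updates every metric with every (label, output) pair of one step."""
    for metric in metrics:
        for label, output in zip(labels, model_outputs):
            metric.update_state(label, output)


abs_err = RunningMean("abs_error", lambda y, p: abs(y - p))
process_metrics_sketch([abs_err], labels=[1.0, 2.0], model_outputs=[1.5, 2.0])
process_metrics_sketch([abs_err], labels=[3.0], model_outputs=[2.0])
```

Because state accumulates across calls, `abs_err.result()` here averages over all three examples seen in both steps, giving (0.5 + 0.0 + 1.0) / 3 = 0.5.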
reduce_aggregated_logs
reduce_aggregated_logs(
aggregated_logs: Dict[str, Any], global_step: Optional[tf.Tensor] = None
) -> Dict[str, tf.Tensor]
Optional reduction of aggregated logs over validation steps.
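`aggregate_logs` and `reduce_aggregated_logs` split evaluation into two phases: each validation step's outputs are folded into a running `state`, and that state is reduced to final values once all steps have run. The sketch below shows the pattern in pure Python. The `_sketch` function names, the `step_outputs` keys, and the choice of a per-key mean as the reduction are assumptions; a real detection evaluator typically needs all predictions before it can compute average precision.

```python
from typing import Dict, List, Optional


def aggregate_logs_sketch(state: Optional[Dict[str, List[float]]],
                          step_outputs: Dict[str, float]) -> Dict[str, List[float]]:
    """Appends one validation step's values to the running state."""
    if state is None:  # first validation step: start a fresh state
        state = {}
    for key, value in step_outputs.items():
        state.setdefault(key, []).append(value)
    return state


def reduce_aggregated_logs_sketch(aggregated_logs: Dict[str, List[float]],
                                  global_step: Optional[int] = None) -> Dict[str, float]:
    """Reduces each accumulated list to its mean once evaluation ends.

    global_step is unused here; it only mirrors the documented signature.
    """
    return {key: sum(values) / len(values) for key, values in aggregated_logs.items()}


state = None
for outputs in [{"loss": 1.0}, {"loss": 3.0}]:
    state = aggregate_logs_sketch(state, outputs)
final_logs = reduce_aggregated_logs_sketch(state)
```

Deferring the reduction keeps per-step work cheap and lets the final metric see the whole validation set at once.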
train_step
train_step(
inputs: Tuple[Any, Any],
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[List[Any]] = None
)
Does the forward and backward passes.
Args | |
---|---|
inputs | a tuple of (images, labels) input tensors. |
model | the model, which defines the forward pass. |
optimizer | the optimizer for this training step. |
metrics | a nested structure of metrics objects. |
Returns | |
---|---|
A dictionary of logs. |
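A train step runs the forward pass, computes the loss, backpropagates to get gradients, applies them with the optimizer, and returns a log dictionary keyed by the `loss` class variable. The sketch below walks through those stages in pure Python for a hypothetical one-weight linear model with a hand-derived gradient and plain SGD; `train_step_sketch` and its arguments are invented for illustration. The real task performs the same stages with TensorFlow automatic differentiation (`tf.GradientTape`), the Mask R-CNN losses, and the configured optimizer.

```python
from typing import Dict, List, Tuple


def train_step_sketch(inputs: Tuple[List[float], List[float]],
                      model: Dict[str, float],
                      learning_rate: float = 0.1) -> Dict[str, float]:
    """Hypothetical forward/backward step for y_hat = w * x with MSE loss."""
    features, labels = inputs
    n = len(features)
    # Forward pass.
    preds = [model["w"] * x for x in features]
    loss = sum((p - y) ** 2 for p, y in zip(preds, labels)) / n
    # Backward pass: d(loss)/dw = (2/n) * sum((w*x - y) * x).
    grad = 2.0 / n * sum((p - y) * x for p, y, x in zip(preds, labels, features))
    # Optimizer update (plain SGD, modifies the model in place).
    model["w"] -= learning_rate * grad
    # Logs are keyed by "loss", mirroring the class variable below.
    return {"loss": loss}


model = {"w": 0.0}
batch = ([1.0, 2.0], [2.0, 4.0])  # samples from y = 2x
first_logs = train_step_sketch(batch, model)
second_logs = train_step_sketch(batch, model)
```

The reported loss is computed before the update, so the first step logs 10.0 with `w` still at 0.0; the second step, run after `w` has moved to 1.0, logs a smaller loss.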
validation_step
validation_step(
inputs: Tuple[Any, Any],
model: tf.keras.Model,
metrics: Optional[List[Any]] = None
)
Validation step.
Args | |
---|---|
inputs | a tuple of (images, labels) input tensors. |
model | the keras.Model. |
metrics | a nested structure of metrics objects. |
Returns | |
---|---|
A dictionary of logs. |
Class Variables | |
---|---|
loss | 'loss' |