A task for semantic segmentation.
Inherits From: Task
tfm.vision.SemanticSegmentationTask(
params, logging_dir: Optional[str] = None, name: Optional[str] = None
)
Attributes | |
---|---|
logging_dir
|
|
task_config
|
Methods
aggregate_logs
aggregate_logs(
state=None, step_outputs=None
)
Optional aggregation over logs returned from a validation step.
Given step_outputs from a validation step, this function aggregates the logs after each eval_step() (see the eval_reduce() function in official/core/base_trainer.py). It runs on the CPU and can be used to aggregate metrics during validation when there are too many metrics to fit into TPU memory. Note that this may increase latency due to data transfer between TPU and CPU. Also, the step output from a validation step may be a tuple with elements from replicas, in which case the elements need to be concatenated.
Args | |
---|---|
state
|
The current state of training, for example, it can be a sequence of metrics. |
step_outputs
|
Logs from a validation step. Can be a dictionary. |
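The aggregation pattern described above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the helper `concat_replica_outputs`, the function name `aggregate_logs_sketch`, the `pixel_accuracy` key, and the use of Python lists in place of tensors are all assumptions made for the example.

```python
from typing import Dict, List, Optional


def concat_replica_outputs(step_output) -> List[float]:
    """Concatenates a tuple of per-replica values; passes a plain list through."""
    if isinstance(step_output, tuple):
        merged: List[float] = []
        for replica_values in step_output:
            merged.extend(replica_values)
        return merged
    return list(step_output)


def aggregate_logs_sketch(
    state: Optional[Dict[str, List[float]]] = None,
    step_outputs: Optional[Dict[str, object]] = None,
) -> Dict[str, List[float]]:
    """Accumulates the outputs of one validation step into a running state."""
    if state is None:
        state = {}  # first validation step: start from an empty state
    for name, value in (step_outputs or {}).items():
        state.setdefault(name, []).extend(concat_replica_outputs(value))
    return state


# Two validation steps on a hypothetical two-replica setup.
state = None
state = aggregate_logs_sketch(state, {"pixel_accuracy": ([0.5, 0.7], [0.9])})
state = aggregate_logs_sketch(state, {"pixel_accuracy": ([0.6], [0.8, 1.0])})
```

After both steps, `state` holds a single flat list per key, which is the form a later reduction over validation steps can consume.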
build_inputs
build_inputs(
params: tfm.vision.configs.semantic_segmentation.DataConfig
,
input_context: Optional[tf.distribute.InputContext] = None
)
Builds segmentation input.
build_losses
build_losses(
labels: Mapping[str, tf.Tensor],
model_outputs: Union[Mapping[str, tf.Tensor], tf.Tensor],
aux_losses: Optional[Any] = None
)
Segmentation loss.
Args | |
---|---|
labels
|
A mapping of ground-truth label tensors. |
model_outputs
|
Output logits of the classifier. |
aux_losses
|
Auxiliary loss tensors, i.e. losses in keras.Model.
|
Returns | |
---|---|
The total loss tensor. |
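As a rough illustration of how a total loss can combine a per-pixel segmentation loss with auxiliary losses, the sketch below uses plain Python lists instead of tensors. The choice of softmax cross-entropy averaged over pixels, the flat list layout, and the name `build_losses_sketch` are assumptions made for the example; the task's actual loss is determined by its configuration.

```python
import math
from typing import List, Optional, Sequence


def softmax_cross_entropy(logits: Sequence[float], label: int) -> float:
    """Cross-entropy of one pixel, using a numerically stable log-sum-exp."""
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return log_sum_exp - logits[label]


def build_losses_sketch(
    labels: Sequence[int],
    model_outputs: Sequence[Sequence[float]],
    aux_losses: Optional[List[float]] = None,
) -> float:
    """Mean per-pixel cross-entropy plus the sum of any auxiliary losses."""
    per_pixel = [
        softmax_cross_entropy(logits, label)
        for logits, label in zip(model_outputs, labels)
    ]
    total = sum(per_pixel) / len(per_pixel)
    if aux_losses:
        total += sum(aux_losses)  # e.g. regularization terms collected by the model
    return total


# Two pixels, two classes, uniform logits, plus one hypothetical auxiliary loss.
loss = build_losses_sketch([0, 1], [[0.0, 0.0], [0.0, 0.0]], aux_losses=[0.25])
```

With uniform logits over two classes each pixel contributes log(2), so `loss` here equals log(2) + 0.25.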
build_metrics
build_metrics(
training: bool = True
)
Gets streaming metrics for training/validation.
build_model
build_model()
Builds segmentation model.
create_optimizer
@classmethod
create_optimizer( optimizer_config:
tfm.optimization.OptimizationConfig
, runtime_config: Optional[tfm.core.base_task.RuntimeConfig
] = None, dp_config: Optional[tfm.core.base_task.DifferentialPrivacyConfig
] = None )
Creates a TF optimizer from configurations.
Args | |
---|---|
optimizer_config
|
the parameters of the Optimization settings. |
runtime_config
|
the parameters of the runtime. |
dp_config
|
the parameters of differential privacy. |
Returns | |
---|---|
A tf.optimizers.Optimizer object. |
inference_step
inference_step(
inputs: tf.Tensor, model: tf.keras.Model
)
Performs the forward step.
initialize
initialize(
model: tf.keras.Model
)
Loads pretrained checkpoint.
process_compiled_metrics
process_compiled_metrics(
compiled_metrics, labels, model_outputs
)
Processes and updates compiled_metrics.
Called when using the compile/fit API.
Args | |
---|---|
compiled_metrics
|
the compiled metrics (model.compiled_metrics). |
labels
|
a tensor or a nested structure of tensors. |
model_outputs
|
a tensor or a nested structure of tensors. For example, output of the keras model built by self.build_model. |
process_metrics
process_metrics(
metrics, labels, model_outputs, **kwargs
)
Processes and updates metrics.
Called when using custom training loop API.
Args | |
---|---|
metrics
|
a nested structure of metrics objects. The return of function self.build_metrics. |
labels
|
a tensor or a nested structure of tensors. |
model_outputs
|
a tensor or a nested structure of tensors. For example, output of the keras model built by self.build_model. |
**kwargs
|
other args. |
reduce_aggregated_logs
reduce_aggregated_logs(
aggregated_logs, global_step=None
)
Optional reduce of aggregated logs over validation steps.
This function reduces aggregated logs at the end of validation, and can be used to compute the final metrics. It runs on the CPU and is called in each eval_end() of the base trainer (see the eval_end() function in official/core/base_trainer.py).
Args | |
---|---|
aggregated_logs
|
Aggregated logs over multiple validation steps. |
global_step
|
An optional variable of global step. |
Returns | |
---|---|
A dictionary of reduced results. |
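A minimal sketch of such a reduction, assuming the aggregated logs are a dictionary of flat Python lists, one list per metric. The name `reduce_aggregated_logs_sketch`, the `pixel_accuracy` key, and the choice of a simple mean are assumptions for illustration only.

```python
from typing import Dict, List, Optional


def reduce_aggregated_logs_sketch(
    aggregated_logs: Dict[str, List[float]],
    global_step: Optional[int] = None,
) -> Dict[str, float]:
    """Reduces each aggregated list of values to its mean."""
    # global_step is accepted to mirror the documented signature but is unused here.
    reduced: Dict[str, float] = {}
    for name, values in aggregated_logs.items():
        reduced[name] = sum(values) / len(values) if values else 0.0
    return reduced


result = reduce_aggregated_logs_sketch(
    {"pixel_accuracy": [0.5, 1.0, 0.75]}, global_step=100
)
```

The returned dictionary maps each metric name to one scalar, matching the "dictionary of reduced results" described above.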
train_step
train_step(
inputs: Tuple[Any, Any],
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[List[Any]] = None
)
Runs the forward and backward passes.
Args | |
---|---|
inputs
|
a tuple of (features, labels) input tensors. |
model
|
the model, forward pass definition. |
optimizer
|
the optimizer for this training step. |
metrics
|
a nested structure of metrics objects. |
Returns | |
---|---|
A dictionary of logs. |
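The overall shape of a training step (unpack the inputs, run the forward pass, compute the loss, apply a gradient update, return a dictionary of logs) can be illustrated with a toy example in plain Python. Everything here is hypothetical: the one-weight linear model, the `sgd` helper, the `(name, function)` metric pairs, and the assumption that the logs are keyed by the string `'loss'` listed under Class Variables.

```python
from typing import Callable, Dict, List, Optional, Sequence, Tuple

LOSS_KEY = "loss"  # same string as the `loss` class variable listed for the task

Metric = Tuple[str, Callable[[Sequence[float], Sequence[float]], float]]


def sgd(weight: float, grad: float, lr: float = 0.1) -> float:
    """Plain gradient descent update for a single scalar weight."""
    return weight - lr * grad


def mean_absolute_error(labels: Sequence[float], preds: Sequence[float]) -> float:
    return sum(abs(p - y) for p, y in zip(preds, labels)) / len(labels)


def toy_train_step(
    inputs: Tuple[Sequence[float], Sequence[float]],
    model: Dict[str, float],
    optimizer: Callable[[float, float], float],
    metrics: Optional[List[Metric]] = None,
) -> Dict[str, float]:
    features, labels = inputs  # inputs is a (features, labels) pair
    n = len(features)
    # Forward pass of a one-weight linear model: prediction = w * x.
    predictions = [model["w"] * x for x in features]
    # Mean squared error and its gradient with respect to w.
    loss = sum((p - y) ** 2 for p, y in zip(predictions, labels)) / n
    grad = sum(2.0 * (p - y) * x for p, y, x in zip(predictions, labels, features)) / n
    # Backward pass: update the weight in place.
    model["w"] = optimizer(model["w"], grad)
    logs = {LOSS_KEY: loss}
    for name, metric_fn in metrics or []:
        logs[name] = metric_fn(labels, predictions)
    return logs


model = {"w": 0.0}
logs = toy_train_step(
    ([1.0, 2.0], [2.0, 4.0]), model, sgd, metrics=[("mae", mean_absolute_error)]
)
```

The returned `logs` dictionary carries the loss under `LOSS_KEY` alongside any metric values, and `model["w"]` has moved toward the target slope of 2.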
validation_step
validation_step(
inputs: Tuple[Any, Any],
model: tf.keras.Model,
metrics: Optional[List[Any]] = None
)
Validation step.
Args | |
---|---|
inputs
|
a tuple of (features, labels) input tensors. |
model
|
the keras.Model. |
metrics
|
a nested structure of metrics objects. |
Returns | |
---|---|
A dictionary of logs. |
Class Variables | |
---|---|
loss |
'loss'
|