Task object for tf-ranking BERT.
Inherits From: RankingTask
tfr.extension.premade.TFRBertTask(
    params,
    label_spec: Tuple[str, tf.io.FixedLenFeature] = None,
    logging_dir: Optional[str] = None,
    name: Optional[str] = None,
    **kwargs
)
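The label_spec argument pairs a label feature name with a tf.io.FixedLenFeature that describes how that label is parsed. The following is a minimal sketch that assumes a label feature named "relevance" holding one int64 grade per document. The name, the shape, and the default value of -1 (the value TF-Ranking conventionally uses to mark padded items) are illustrative assumptions, not defaults of TFRBertTask.

```python
import tensorflow as tf

# Hypothetical label spec: the feature name, shape, and default value are
# assumptions for illustration, not defaults of TFRBertTask.
label_spec = (
    "relevance",  # name of the label feature in the serialized examples
    tf.io.FixedLenFeature(shape=(1,), dtype=tf.int64, default_value=-1),
)

label_name, label_feature = label_spec
print(label_name, label_feature.dtype)
```

Such a tuple would then be passed as label_spec=label_spec, alongside the task params, when constructing tfr.extension.premade.TFRBertTask.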
Attributes | |
---|---|
logging_dir | |
name | Returns the name of this module as passed or determined in the constructor. |
name_scope | Returns a tf.name_scope instance for this class. |
non_trainable_variables | Sequence of non-trainable variables owned by this module and its submodules. |
submodules | Sequence of all submodules. Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on). |
task_config | |
trainable_variables | Sequence of trainable variables owned by this module and its submodules. |
variables | Sequence of variables owned by this module and its submodules. |
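The name, name_scope, submodules, and variable attributes are the standard tf.Module properties. The following sketch shows how they behave on a small module tree. ToyLayer and ToyStack are hypothetical classes invented for illustration; they are not part of TF-Ranking.

```python
import tensorflow as tf


class ToyLayer(tf.Module):
    """Hypothetical module with one trainable and one non-trainable variable."""

    def __init__(self, in_dim, out_dim, name=None):
        super().__init__(name=name)
        self.w = tf.Variable(tf.zeros([in_dim, out_dim]), name="w")
        self.count = tf.Variable(0, trainable=False, name="count")


class ToyStack(tf.Module):
    """Hypothetical parent module whose attributes are ToyLayer submodules."""

    def __init__(self, name=None):
        super().__init__(name=name)
        self.first = ToyLayer(4, 3)
        self.second = ToyLayer(3, 1)


stack = ToyStack(name="toy_stack")
print(stack.name)                          # toy_stack
print(len(stack.submodules))               # 2
print(len(stack.trainable_variables))      # 2 (the two `w` variables)
print(len(stack.non_trainable_variables))  # 2 (the two `count` variables)
print(len(stack.variables))                # 4
```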
Methods
aggregate_logs
aggregate_logs(
state=None, step_outputs=None
)
Aggregates over logs. This runs on CPU in eager mode.
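The signature suggests an accumulate-then-reduce pattern. state starts as None, and each call folds one step's outputs into it and returns the updated state. Because the method runs on CPU in eager mode, plain Python containers can presumably hold the accumulated values. The following is a hypothetical sketch: the "score" key and the list-based state are assumptions, not TFRBertTask's actual log format.

```python
from typing import Dict, List, Optional


def aggregate_logs(state: Optional[Dict[str, List[float]]] = None,
                   step_outputs: Optional[Dict[str, List[float]]] = None
                   ) -> Dict[str, List[float]]:
    """Hypothetical sketch: fold one step's outputs into the running state."""
    if state is None:
        state = {}  # first call: start an empty accumulator
    for key, values in (step_outputs or {}).items():
        state.setdefault(key, []).extend(values)
    return state


state = None
for step_outputs in ({"score": [0.9, 0.1]}, {"score": [0.4]}):
    state = aggregate_logs(state, step_outputs)
print(state)  # {'score': [0.9, 0.1, 0.4]}
```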
build_inputs
build_inputs(
params, input_context=None
)
Returns a tf.data.Dataset for the tf-ranking BERT task.
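As a rough picture of what such a dataset may contain, the following sketch builds a toy listwise tf.data.Dataset of (features, labels) pairs: 4 queries, each with a list of 3 documents of 8 token ids. The feature names follow the common BERT input convention (input_word_ids, input_mask, input_type_ids). These names, the shapes, and the (features, labels) element structure are all assumptions. The real method builds its dataset from params.

```python
import tensorflow as tf

# Hypothetical toy inputs: 4 queries, each a list of 3 documents, each document
# encoded as 8 token ids. Feature names are assumptions (BERT convention).
num_queries, list_size, seq_length = 4, 3, 8
features = {
    "input_word_ids": tf.zeros([num_queries, list_size, seq_length], tf.int32),
    "input_mask": tf.ones([num_queries, list_size, seq_length], tf.int32),
    "input_type_ids": tf.zeros([num_queries, list_size, seq_length], tf.int32),
}
# One relevance label per document; -1 marks a padded slot (an assumption).
labels = tf.constant([[1, 0, 0], [0, 2, 0], [0, 0, 1], [1, -1, -1]], tf.int32)

dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(2)
batch_features, batch_labels = next(iter(dataset))
print(batch_features["input_word_ids"].shape)  # (2, 3, 8)
print(batch_labels.shape)                      # (2, 3)
```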
build_losses
build_losses(
labels, model_outputs, aux_losses=None
) -> tf.Tensor
Standard interface to compute losses.
Args | |
---|---|
labels | optional label tensors. |
model_outputs | a nested structure of output tensors. |
aux_losses | auxiliary loss tensors, i.e. losses in keras.Model. |
Returns | |
---|---|
The total loss tensor. |
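The method returns a single total loss tensor, and aux_losses supplies extra terms such as the losses collected by a keras.Model. The following is a minimal sketch of that combination. A masked pointwise mean squared error is swapped in for TFRBertTask's actual ranking loss, which this page does not specify, and treating a label of -1 as padding is an assumption.

```python
import tensorflow as tf


def build_losses(labels, model_outputs, aux_losses=None):
    """Sketch: a masked pointwise MSE stands in for the task's ranking loss."""
    labels = tf.cast(labels, tf.float32)
    # Ignore padded list slots, marked (by assumption) with a label of -1.
    mask = tf.cast(labels >= 0.0, tf.float32)
    squared_error = tf.square(model_outputs - labels) * mask
    total_loss = tf.reduce_sum(squared_error) / tf.maximum(tf.reduce_sum(mask), 1.0)
    # Auxiliary losses, such as a keras.Model's regularization losses, are added.
    if aux_losses:
        total_loss += tf.add_n(aux_losses)
    return total_loss


labels = tf.constant([[1.0, 0.0, -1.0]])
scores = tf.constant([[0.5, 0.5, 3.0]])
loss = build_losses(labels, scores, aux_losses=[tf.constant(0.1)])
print(float(loss))  # approximately 0.35
```

The padded third item contributes nothing. The two valid items average to 0.25, and the auxiliary term adds 0.1.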
build_metrics
build_metrics(
training=None
)
Gets streaming metrics for training/validation.
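Streaming metrics accumulate state across batches until they are reset, so a validation pass can report one value over many steps. The following hypothetical sketch returns Keras metric objects and switches the set on training. The metric choices and names are assumptions, not the metrics TFRBertTask actually builds.

```python
import tensorflow as tf


def build_metrics(training=None):
    """Hypothetical sketch: the metric choices are assumptions, not
    TFRBertTask's actual metrics."""
    metrics = [tf.keras.metrics.Mean(name="mean_score")]
    if not training:
        metrics.append(tf.keras.metrics.MeanAbsoluteError(name="mae"))
    return metrics


train_metrics = build_metrics(training=True)
eval_metrics = build_metrics(training=False)
mean_score = train_metrics[0]
# The metric keeps a running total and count across update_state calls.
mean_score.update_state([0.5, 1.5])
mean_score.update_state([4.0])
print(float(mean_score.result()))  # 2.0
```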
build_model
build_model()
[Optional] Creates model architecture.
Returns | |
---|---|
A model instance. |
create_optimizer
@classmethod
create_optimizer(
    optimizer_config: OptimizationConfig,
    runtime_config: Optional[RuntimeConfig] = None,
    dp_config: Optional[DifferentialPrivacyConfig] = None
)
Creates a TF optimizer from configurations.
Args | |
---|---|
optimizer_config | the parameters of the Optimization settings. |
runtime_config | the parameters of the runtime. |
dp_config | the parameters of differential privacy. |
Returns | |
---|---|
A tf.optimizers.Optimizer object. |
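The idea of building an optimizer from a declarative config can be illustrated with a small stand-in. The following sketch uses a hypothetical ToyOptimizationConfig dataclass in place of the real OptimizationConfig. Its fields are assumptions, and the runtime and differential-privacy configs are ignored entirely. Only the mapping from a config value to a Keras optimizer is shown.

```python
import dataclasses

import tensorflow as tf


@dataclasses.dataclass
class ToyOptimizationConfig:
    """Hypothetical stand-in for OptimizationConfig; fields are assumptions."""
    optimizer_type: str = "adam"
    learning_rate: float = 1e-4


def create_optimizer(config: ToyOptimizationConfig):
    """Map a config onto a Keras optimizer instance (illustrative only)."""
    if config.optimizer_type == "adam":
        return tf.keras.optimizers.Adam(learning_rate=config.learning_rate)
    if config.optimizer_type == "sgd":
        return tf.keras.optimizers.SGD(learning_rate=config.learning_rate)
    raise ValueError(f"Unsupported optimizer type: {config.optimizer_type}")


optimizer = create_optimizer(
    ToyOptimizationConfig(optimizer_type="sgd", learning_rate=0.01))
print(type(optimizer).__name__)
```

Because the real method is a @classmethod, it can be called on the task class without first constructing a task.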
inference_step
inference_step(
inputs, model: tf.keras.Model
)
Performs the forward step.
With distribution strategies, this method runs on devices.
Args | |
---|---|
inputs | a dictionary of input tensors. |
model | the keras.Model. |
Returns | |
---|---|
Model outputs. |
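A forward step for inference typically calls the model with training=False. Layers such as dropout then behave deterministically, so repeated calls on the same inputs return identical outputs. The following is a minimal single-device sketch under that assumption. The toy model is hypothetical and stands in for the BERT ranking model.

```python
import tensorflow as tf


def inference_step(inputs, model: tf.keras.Model):
    """Sketch: a forward pass with training=False (assumed behavior)."""
    return model(inputs, training=False)


# Hypothetical scorer: dropout is active only when training=True.
model = tf.keras.Sequential([
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1),
])
inputs = tf.ones([2, 4])
first = inference_step(inputs, model)
second = inference_step(inputs, model)
print(first.shape)  # (2, 1)
```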
initialize
initialize(
model
)
Loads a pretrained checkpoint (if one exists) and then trains from iteration 0.
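The "if one exists" behavior can be sketched with tf.train.Checkpoint and tf.train.latest_checkpoint. The following sketch assumes the checkpoint location is a directory argument; the real task presumably reads it from its configuration instead. ToyModel is a hypothetical module. Only model weights are restored and no step counter is, which is consistent with training from iteration 0.

```python
import tempfile

import tensorflow as tf


class ToyModel(tf.Module):
    """Hypothetical model holding a single weight vector."""

    def __init__(self, value):
        super().__init__()
        self.w = tf.Variable(value, dtype=tf.float32)


def initialize(model, checkpoint_dir):
    """Sketch: restore the latest checkpoint in checkpoint_dir, if any."""
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    if latest is None:
        return False  # nothing to load; keep the freshly initialized weights
    tf.train.Checkpoint(model=model).restore(latest).expect_partial()
    return True


pretrained_dir = tempfile.mkdtemp()
tf.train.Checkpoint(model=ToyModel([1.0, 2.0])).save(f"{pretrained_dir}/ckpt")

fresh = ToyModel([0.0, 0.0])
print(initialize(fresh, pretrained_dir), fresh.w.numpy())  # True [1. 2.]
print(initialize(ToyModel([0.0, 0.0]), tempfile.mkdtemp()))  # False
```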
process_compiled_metrics
process_compiled_metrics(
compiled_metrics, labels, model_outputs
)
Processes and updates compiled_metrics.
Called when using the compile/fit API.
Args | |
---|---|
compiled_metrics | the compiled metrics (model.compiled_metrics). |
labels | a tensor or a nested structure of tensors. |
model_outputs | a tensor or a nested structure of tensors. For example, the output of the keras model built by self.build_model. |
process_metrics
process_metrics(
metrics, labels, model_outputs
)
Processes and updates metrics.
Called when using a custom training loop API.
Args | |
---|---|
metrics | a nested structure of metrics objects, as returned by self.build_metrics. |
labels | a tensor or a nested structure of tensors. |
model_outputs | a tensor or a nested structure of tensors. For example, the output of the keras model built by self.build_model. |
**kwargs | other args. |
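In a custom training loop, processing metrics usually means calling update_state on each metric object with the batch's labels and outputs. The following is a minimal sketch under that assumption, using a flat list of metrics rather than an arbitrary nested structure. The metric choice is hypothetical.

```python
import tensorflow as tf


def process_metrics(metrics, labels, model_outputs):
    """Sketch: update every streaming metric with this batch (assumed logic)."""
    for metric in metrics:
        metric.update_state(labels, model_outputs)


metrics = [tf.keras.metrics.MeanAbsoluteError(name="mae")]
process_metrics(metrics, tf.constant([[1.0, 0.0]]), tf.constant([[0.5, 0.5]]))
process_metrics(metrics, tf.constant([[1.0, 1.0]]), tf.constant([[1.0, 1.0]]))
print(float(metrics[0].result()))  # 0.25
```

The first batch has an absolute error of 0.5 and the second has 0.0, so the streaming mean over both is 0.25.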
reduce_aggregated_logs
reduce_aggregated_logs(
aggregated_logs, global_step=None
)
Calculates aggregated metrics and writes predictions to CSV.
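The following hypothetical sketch shows the reduce phase: it writes one CSV row per prediction and then computes a query-level metric from the aggregated logs. The log keys (query_id, score, label), the CSV columns, and the choice of mean reciprocal rank (MRR) as the metric are all assumptions. This page does not say which metrics TFRBertTask reports or what its CSV looks like.

```python
import csv
import io
from collections import defaultdict
from typing import Dict, List, Optional, TextIO


def reduce_aggregated_logs(aggregated_logs: Dict[str, List],
                           global_step: Optional[int] = None,
                           csv_file: Optional[TextIO] = None) -> Dict[str, float]:
    """Hypothetical sketch: write predictions as CSV rows, then compute MRR."""
    query_ids = aggregated_logs["query_id"]
    scores = aggregated_logs["score"]
    labels = aggregated_logs["label"]

    # One CSV row per (query, document) prediction.
    if csv_file is not None:
        writer = csv.writer(csv_file)
        writer.writerow(["global_step", "query_id", "score", "label"])
        for row in zip(query_ids, scores, labels):
            writer.writerow([global_step, *row])

    # Group documents by query.
    per_query = defaultdict(list)
    for qid, score, label in zip(query_ids, scores, labels):
        per_query[qid].append((score, label))

    # Reciprocal rank of the first relevant document (label > 0) when the
    # documents are sorted by descending score; 0.0 if none is relevant.
    reciprocal_ranks = []
    for docs in per_query.values():
        ranked = sorted(docs, key=lambda doc: doc[0], reverse=True)
        rr = next((1.0 / rank for rank, (_, label) in enumerate(ranked, start=1)
                   if label > 0), 0.0)
        reciprocal_ranks.append(rr)
    if not reciprocal_ranks:
        return {"mrr": 0.0}
    return {"mrr": sum(reciprocal_ranks) / len(reciprocal_ranks)}


logs = {"query_id": ["q1", "q1", "q2", "q2"],
        "score": [0.2, 0.9, 0.7, 0.1],
        "label": [1, 0, 1, 0]}
buffer = io.StringIO()
metrics = reduce_aggregated_logs(logs, global_step=100, csv_file=buffer)
print(metrics)  # {'mrr': 0.75}
```

For q1, the relevant document is ranked second (reciprocal rank 0.5). For q2, it is ranked first (reciprocal rank 1.0). The mean is 0.75.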
train_step
train_step(
    inputs,
    model: tf.keras.Model,
    optimizer: tf.keras.optimizers.Optimizer,
    metrics
)
Performs the forward and backward passes.
With distribution strategies, this method runs on devices.
Args | |
---|---|
inputs | a dictionary of input tensors. |
model | the model that defines the forward pass. |
optimizer | the optimizer for this training step. |
metrics | a nested structure of metrics objects. |
Returns | |
---|---|
A dictionary of logs. |
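The following is a single-device sketch of one forward and backward pass with tf.GradientTape. It does not use a distribution strategy. Several details are assumptions made to keep the example small: inputs is a (features, labels) pair instead of a dictionary, a mean squared error replaces the task's ranking loss, and the model and optimizer are toy choices.

```python
import tensorflow as tf


def train_step(inputs, model, optimizer, metrics):
    """Sketch of one forward and backward pass. The (features, labels) input
    layout and the MSE loss are assumptions, not TFRBertTask's actual logic."""
    features, labels = inputs
    with tf.GradientTape() as tape:
        outputs = model(features, training=True)
        loss = tf.reduce_mean(tf.square(outputs - labels))
        if model.losses:  # auxiliary losses, e.g. weight regularizers
            loss += tf.add_n(model.losses)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    for metric in metrics:
        metric.update_state(labels, outputs)
    return {"loss": loss}


tf.random.set_seed(0)
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
features, labels = tf.ones([4, 2]), tf.ones([4, 1])
model(features)  # build the variables before the first step
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
metrics = [tf.keras.metrics.MeanAbsoluteError()]
first = train_step((features, labels), model, optimizer, metrics)
second = train_step((features, labels), model, optimizer, metrics)
print(float(first["loss"]), float(second["loss"]))
```

The returned dictionary of logs carries the loss computed before the update, so the second step reports the loss after one gradient step.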
validation_step
validation_step(
inputs, model: tf.keras.Model, metrics=None
)
Validation step.
With distribution strategies, this method runs on devices.
Args | |
---|---|
inputs | a dictionary of input tensors. |
model | the keras.Model. |
metrics | a nested structure of metrics objects. |
Returns | |
---|---|
A dictionary of logs. |
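A validation step differs from a training step mainly in having no backward pass, so the weights stay unchanged. The following sketch checks this directly. As in any hypothetical sketch here, the (features, labels) layout, the MSE loss, and the toy model are assumptions rather than TFRBertTask's actual behavior.

```python
import tensorflow as tf


def validation_step(inputs, model, metrics=None):
    """Sketch: forward pass only, so the weights are left untouched. The
    (features, labels) layout and MSE loss are assumptions."""
    features, labels = inputs
    outputs = model(features, training=False)
    loss = tf.reduce_mean(tf.square(outputs - labels))
    for metric in metrics or []:
        metric.update_state(labels, outputs)
    return {"loss": loss}


model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
features, labels = tf.ones([4, 2]), tf.zeros([4, 1])
model(features)  # build the variables
weights_before = [v.numpy().copy() for v in model.trainable_variables]
metrics = [tf.keras.metrics.MeanSquaredError()]
logs = validation_step((features, labels), model, metrics)
weights_after = [v.numpy() for v in model.trainable_variables]
print(float(logs["loss"]))
```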
with_name_scope
@classmethod
with_name_scope(
    method
)
Decorator to automatically enter the module name scope.
import tensorflow as tf
class MyModule(tf.Module):
    @tf.Module.with_name_scope
    def __call__(self, x):
        # Create the weight lazily on the first call, inside the module's name scope.
        if not hasattr(self, 'w'):
            self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
        return tf.matmul(x, self.w)
Using the above module would produce tf.Variables and tf.Tensors whose names include the module name:
mod = MyModule()
mod(tf.ones([1, 2]))
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=...>
mod.w
<tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32, numpy=...>
Args | |
---|---|
method | The method to wrap. |
Returns | |
---|---|
The original method wrapped such that it enters the module's name scope. |
Class Variables | |
---|---|
loss | 'loss' |