tfr.keras.pipeline.MultiTaskPipeline

Pipeline for multi-task training.

Inherits From: ModelFitPipeline, AbstractPipeline

This pipeline handles a set of losses and labels. It is mainly intended to work with MultiLabelDatasetBuilder.

Use subclassing to customize the losses and metrics.
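
The subclassing pattern can be sketched as follows. This is a minimal, TensorFlow-free illustration: `StandInMultiTaskPipeline` and `CustomLossPipeline` are hypothetical names, and the stand-in base class only mimics the assumed behavior of `build_loss` (one loss per task, keyed by task name). In real use the base class would be `tfr.keras.pipeline.MultiTaskPipeline`, and the override would return loss objects rather than strings.

```python
class StandInMultiTaskPipeline:
    """Hypothetical stand-in for tfr.keras.pipeline.MultiTaskPipeline."""

    def __init__(self, loss_by_task):
        # Maps task name -> loss key, like the `loss` dict in PipelineHparams.
        self._loss_by_task = dict(loss_by_task)

    def build_loss(self):
        # Assumed default behavior: one loss per task, keyed by task name.
        return dict(self._loss_by_task)


class CustomLossPipeline(StandInMultiTaskPipeline):
    """Overrides build_loss to swap the loss used for a single task."""

    def build_loss(self):
        losses = super().build_loss()
        # Customization point: replace only the loss for "task2".
        losses["task2"] = "pairwise_hinge_loss"
        return losses


pipeline = CustomLossPipeline(
    {"task1": "softmax_loss", "task2": "pairwise_logistic_loss"})
custom_losses = pipeline.build_loss()
print(custom_losses)
```

The same approach would apply to `build_metrics` and `build_weighted_metrics`: call the parent implementation, then adjust the per-task entries that need to change.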

Example usage:

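import tensorflow as tf
import tensorflow_ranking as tfr

# The pipeline classes used below are available under tfr.keras.pipeline.
DatasetHparams = tfr.keras.pipeline.DatasetHparams
PipelineHparams = tfr.keras.pipeline.PipelineHparams
MultiLabelDatasetBuilder = tfr.keras.pipeline.MultiLabelDatasetBuilder
MultiTaskPipeline = tfr.keras.pipeline.MultiTaskPipeline
_PADDING_LABEL = -1.  # Assumed padding value for labels; adjust to your data.

context_feature_spec = {}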
example_feature_spec = {
    "example_feature_1": tf.io.FixedLenFeature(
        shape=(1,), dtype=tf.float32, default_value=0.0)
}
mask_feature_name = "list_mask"
label_spec_tuple = ("utility",
                    tf.io.FixedLenFeature(
                        shape=(1,),
                        dtype=tf.float32,
                        default_value=_PADDING_LABEL))
label_spec = {"task1": label_spec_tuple, "task2": label_spec_tuple}
weight_spec = ("weight",
               tf.io.FixedLenFeature(
                   shape=(1,), dtype=tf.float32, default_value=1.))
dataset_hparams = DatasetHparams(
    train_input_pattern="train.dat",
    valid_input_pattern="valid.dat",
    train_batch_size=128,
    valid_batch_size=128)
pipeline_hparams = PipelineHparams(
    model_dir="model/",
    num_epochs=2,
    steps_per_epoch=5,
    validation_steps=2,
    learning_rate=0.01,
    loss={
        "task1": "softmax_loss",
        "task2": "pairwise_logistic_loss"
    },
    loss_weights={
        "task1": 1.0,
        "task2": 2.0
    },
    export_best_model=True)
model_builder = MultiTaskModelBuilder(...)
dataset_builder = MultiLabelDatasetBuilder(
    context_feature_spec,
    example_feature_spec,
    mask_feature_name,
    label_spec,
    dataset_hparams,
    sample_weight_spec=weight_spec)
pipeline = MultiTaskPipeline(model_builder, dataset_builder, pipeline_hparams)
pipeline.train_and_validate(verbose=1)
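
In the example above, the same task names ("task1", "task2") key `label_spec`, `loss`, and `loss_weights`. A mismatch between these dicts is an easy mistake to make, so a small pre-flight check can help. The sketch below is plain Python; `check_task_names` is a hypothetical helper, not part of the library, and it assumes that every task in `label_spec` should have exactly one entry in each of the other two dicts.

```python
def check_task_names(label_spec, loss, loss_weights):
    """Reports task names that do not line up across the three configs.

    Returns a dict keyed by config name ("loss", "loss_weights"); each value
    lists the tasks missing from that config and the extra ones it contains.
    An empty dict means all task names agree.
    """
    tasks = set(label_spec)
    problems = {}
    for name, mapping in (("loss", loss), ("loss_weights", loss_weights)):
        missing = sorted(tasks - set(mapping))
        extra = sorted(set(mapping) - tasks)
        if missing or extra:
            problems[name] = {"missing": missing, "extra": extra}
    return problems


# Only the keys matter for this check, so the spec values are placeholders.
label_spec = {"task1": "label_spec_tuple", "task2": "label_spec_tuple"}
loss = {"task1": "softmax_loss", "task2": "pairwise_logistic_loss"}
loss_weights = {"task1": 1.0, "task2": 2.0}

problems = check_task_names(label_spec, loss, loss_weights)
print(problems)
```

Running such a check before constructing the pipeline surfaces a misspelled task name early, rather than at model compile time.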

Args
model_builder A ModelBuilder instance for model fit.
dataset_builder An AbstractDatasetBuilder instance to load training and validation datasets and create signatures for SavedModel.
hparams A dict containing model hyperparameters.
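
To make the role of the `loss_weights` hyperparameter concrete, the following sketch shows how per-task losses would combine into a single training objective. It assumes the pipeline forwards `loss_weights` to Keras `Model.compile`, which minimizes the weighted sum of the individual losses. The function `combine_task_losses` and the per-task loss values are illustrative only and do not come from the library or from a real training run.

```python
def combine_task_losses(per_task_loss, loss_weights):
    """Weighted sum of per-task losses; tasks without a weight default to 1.0."""
    return sum(
        loss_weights.get(task, 1.0) * value
        for task, value in per_task_loss.items())


# Made-up loss values for one training step, for illustration only.
per_task_loss = {"task1": 0.5, "task2": 0.25}
loss_weights = {"task1": 1.0, "task2": 2.0}

total = combine_task_losses(per_task_loss, loss_weights)
print(total)
```

With these weights, a unit change in the "task2" loss moves the objective twice as much as the same change in the "task1" loss.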

Methods

build_callbacks

Sets up Callbacks.

Example usage:

model_builder = ModelBuilder(...)
dataset_builder = DatasetBuilder(...)
hparams = PipelineHparams(...)
pipeline = BasicModelFitPipeline(model_builder, dataset_builder, hparams)
callbacks = pipeline.build_callbacks()

Returns
A list of tf.keras.callbacks.Callback instances, or a tf.keras.callbacks.CallbackList, for TensorBoard and checkpointing.

build_loss

See AbstractPipeline.

build_metrics

See AbstractPipeline.

build_weighted_metrics

See AbstractPipeline.

export_saved_model

Exports the trained model with signatures.

Example usage:

model_builder = ModelBuilder(...)
dataset_builder = DatasetBuilder(...)
hparams = PipelineHparams(...)
pipeline = BasicModelFitPipeline(model_builder, dataset_builder, hparams)
pipeline.export_saved_model(model_builder.build(), 'saved_model/')

Args
model Model to be saved.
export_to Specifies the directory the model is to be exported to.
checkpoint If given, export the model with weights from this checkpoint.

train_and_validate

Main function to train the model with TPU strategy.

Example usage:

context_feature_spec = {}
example_feature_spec = {
    "example_feature_1": tf.io.FixedLenFeature(
        shape=(1,), dtype=tf.float32, default_value=0.0)
}
mask_feature_name = "list_mask"
label_spec = {
    "utility": tf.io.FixedLenFeature(
        shape=(1,), dtype=tf.float32, default_value=0.0)
}
dataset_hparams = DatasetHparams(
    train_input_pattern="train.dat",
    valid_input_pattern="valid.dat",
    train_batch_size=128,
    valid_batch_size=128)
pipeline_hparams = PipelineHparams(
    model_dir="model/",
    num_epochs=2,
    steps_per_epoch=5,
    validation_steps=2,
    learning_rate=0.01,
    loss="softmax_loss")
model_builder = SimpleModelBuilder(
    context_feature_spec, example_feature_spec, mask_feature_name)
dataset_builder = SimpleDatasetBuilder(
    context_feature_spec,
    example_feature_spec,
    mask_feature_name,
    label_spec,
    dataset_hparams)
pipeline = BasicModelFitPipeline(
    model_builder, dataset_builder, pipeline_hparams)
pipeline.train_and_validate(verbose=1)

Args
verbose An int for the verbosity level.