tfr.keras.pipeline.AbstractPipeline

Interface for a ranking pipeline that trains a tf.keras.Model.

AbstractPipeline is an abstract base class for pipelines that train and validate a ranking model in tfr.keras.

To be implemented by subclasses:

  • build_loss(): Contains the logic to build a tf.keras.losses.Loss, or a dict or list of tf.keras.losses.Loss objects, to be optimized in training.
  • build_metrics(): Contains the logic to build a list or dict of tf.keras.metrics.Metric objects to monitor and evaluate training.
  • build_weighted_metrics(): Contains the logic to build a list or dict of tf.keras.metrics.Metric objects that are weighted by sample weights (passed to model.compile() as weighted_metrics).
  • train_and_validate(): Contains the main pipeline for training and validation.
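The four methods above form an abstract-method contract. The stdlib-only sketch below assumes that this contract behaves like Python's abc abstract methods; PipelineInterface, IncompletePipeline, CompletePipeline and can_instantiate are hypothetical names, not TF-Ranking APIs. It shows that a subclass omitting any of the four methods cannot be instantiated:

```python
import abc


class PipelineInterface(abc.ABC):
    """Hypothetical stand-in mirroring the four methods a subclass must implement."""

    @abc.abstractmethod
    def build_loss(self):
        """Return a loss, or a dict/list of losses."""

    @abc.abstractmethod
    def build_metrics(self):
        """Return a list or dict of metrics."""

    @abc.abstractmethod
    def build_weighted_metrics(self):
        """Return a list or dict of metrics that use sample weights."""

    @abc.abstractmethod
    def train_and_validate(self, *arg, **kwargs):
        """Run training and validation."""


class IncompletePipeline(PipelineInterface):
    # Only one of the four abstract methods is implemented.
    def build_loss(self):
        return "softmax_loss"


class CompletePipeline(PipelineInterface):
    # All four abstract methods are implemented, so instances can be created.
    def build_loss(self):
        return "softmax_loss"

    def build_metrics(self):
        return ["ndcg_1", "ndcg_5", "ndcg_10"]

    def build_weighted_metrics(self):
        return ["weighted_ndcg_1", "weighted_ndcg_5", "weighted_ndcg_10"]

    def train_and_validate(self, *arg, **kwargs):
        return None


def can_instantiate(cls):
    """Return True if cls() succeeds, False if abc blocks it with a TypeError."""
    try:
        cls()
    except TypeError:
        return False
    return True
```

Here can_instantiate(CompletePipeline) is True, while can_instantiate(IncompletePipeline) is False.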

Example subclass implementation:

import tensorflow as tf
import tensorflow_ranking as tfr


class BasicPipeline(tfr.keras.pipeline.AbstractPipeline):

  def __init__(self, model, train_data, valid_data, name=None):
    self._model = model
    self._train_data = train_data
    self._valid_data = valid_data
    self._name = name

  def build_loss(self):
    return tfr.keras.losses.get('softmax_loss')

  def build_metrics(self):
    return [
        tfr.keras.metrics.get(
            'ndcg', topn=topn, name='ndcg_{}'.format(topn)
        ) for topn in [1, 5, 10]
    ]

  def build_weighted_metrics(self):
    return [
        tfr.keras.metrics.get(
            'ndcg', topn=topn, name='weighted_ndcg_{}'.format(topn)
        ) for topn in [1, 5, 10]
    ]

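  def train_and_validate(self, *arg, **kwargs):
    # *arg and **kwargs are not used in this basic example.
    # Compile the model with the loss and metrics built above, then train it.
    self._model.compile(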
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
        loss=self.build_loss(),
        metrics=self.build_metrics(),
        weighted_metrics=self.build_weighted_metrics())
    self._model.fit(
        x=self._train_data,
        epochs=100,
        validation_data=self._valid_data)
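Because train_and_validate() in the example only calls compile() and fit() on the model, its wiring can be checked without TensorFlow. In the stdlib-only sketch below, RecordingModel is a hypothetical test double and StubPipeline is a hypothetical class that mirrors BasicPipeline's structure, with string placeholders standing in for real Keras losses, metrics and optimizer:

```python
class RecordingModel:
    """Hypothetical test double that records the calls a pipeline makes."""

    def __init__(self):
        self.calls = []

    def compile(self, **kwargs):
        self.calls.append(("compile", kwargs))

    def fit(self, **kwargs):
        self.calls.append(("fit", kwargs))


class StubPipeline:
    """Mirrors BasicPipeline's structure with placeholder losses and metrics."""

    def __init__(self, model, train_data, valid_data):
        self._model = model
        self._train_data = train_data
        self._valid_data = valid_data

    def build_loss(self):
        return "softmax_loss"

    def build_metrics(self):
        return ["ndcg_{}".format(topn) for topn in [1, 5, 10]]

    def build_weighted_metrics(self):
        return ["weighted_ndcg_{}".format(topn) for topn in [1, 5, 10]]

    def train_and_validate(self, *arg, **kwargs):
        # Same order as BasicPipeline: compile with the built objects, then fit.
        self._model.compile(
            optimizer="sgd",
            loss=self.build_loss(),
            metrics=self.build_metrics(),
            weighted_metrics=self.build_weighted_metrics())
        self._model.fit(
            x=self._train_data,
            epochs=100,
            validation_data=self._valid_data)


model = RecordingModel()
StubPipeline(model, "train", "valid").train_and_validate()
```

After the call, model.calls holds one "compile" entry followed by one "fit" entry, so a unit test can assert on both the order and the arguments.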

Methods

build_loss

Returns the loss to pass to model.compile().

Example usage:

pipeline = BasicPipeline(model, train_data, valid_data)
loss = pipeline.build_loss()

Returns
A tf.keras.losses.Loss or a dict or list of tf.keras.losses.Loss.
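The example's build_loss() returns 'softmax_loss', a listwise loss. The stdlib-only sketch below computes a softmax cross-entropy for a single list. It assumes the common formulation in which graded labels are normalized into a target distribution and compared with the softmax of the scores. It is not TF-Ranking's implementation: weights, padding and reduction across lists are not modeled, and softmax_cross_entropy_for_list is a hypothetical helper:

```python
import math


def softmax_cross_entropy_for_list(labels, scores):
    """Hypothetical helper: listwise softmax cross-entropy for one query.

    labels: non-negative relevance labels, normalized here to sum to 1.
    scores: model scores (logits) for the same items.
    """
    label_sum = sum(labels)
    if label_sum <= 0:
        # No relevant items: this sketch contributes zero loss.
        return 0.0
    targets = [y / label_sum for y in labels]
    # Numerically stable log-softmax: subtract the max score first.
    max_score = max(scores)
    log_norm = max_score + math.log(sum(math.exp(s - max_score) for s in scores))
    log_probs = [s - log_norm for s in scores]
    return -sum(t * lp for t, lp in zip(targets, log_probs))
```

With equal scores on two items and one relevant item, the loss is log 2. Raising the relevant item's score lowers the loss.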

build_metrics

Returns a list or dict of ranking metrics to pass to model.compile().

Example usage:

pipeline = BasicPipeline(model, train_data, valid_data)
metrics = pipeline.build_metrics()

Returns
A list or a dict of tf.keras.metrics.Metric objects.
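The example requests 'ndcg' at topn 1, 5 and 10. For intuition, the stdlib-only sketch below computes NDCG@k for one list. It assumes the widely used gain 2^label - 1 and discount 1/log2(rank + 1). Treat it as an illustration rather than TF-Ranking's exact metric, because tie-breaking, padding and weighting may differ. ndcg_at_k is a hypothetical helper:

```python
import math


def ndcg_at_k(labels, scores, k):
    """Hypothetical helper: NDCG@k for a single list of items."""

    def dcg(ordered_labels):
        # Rank is 1-based, so position i (0-based) is discounted by log2(i + 2).
        return sum((2 ** y - 1) / math.log2(i + 2)
                   for i, y in enumerate(ordered_labels[:k]))

    # Order labels by descending model score (Python's sort is stable on ties).
    ranked = [y for _, y in sorted(zip(scores, labels),
                                   key=lambda pair: pair[0], reverse=True)]
    ideal_dcg = dcg(sorted(labels, reverse=True))
    return dcg(ranked) / ideal_dcg if ideal_dcg > 0 else 0.0
```

Ranking the only relevant item second gives NDCG@1 = 0 and NDCG@2 = 1/log2(3) ≈ 0.63. A perfectly ordered list scores 1.0 at any k.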

build_weighted_metrics

Returns a list or dict of weighted ranking metrics to pass as weighted_metrics to model.compile().

Example usage:

pipeline = BasicPipeline(model, train_data, valid_data)
weighted_metrics = pipeline.build_weighted_metrics()

Returns
A list or a dict of tf.keras.metrics.Metric objects.
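Keras evaluates weighted_metrics using the sample weights supplied to fit() or evaluate(). The stdlib-only sketch below is a simplified illustration of what a weight changes. It assumes that per-list metric values are reduced as a weighted mean. How TF-Ranking maps per-example weights to per-list weights is not modeled, and weighted_mean is a hypothetical helper:

```python
def weighted_mean(values, weights):
    """Hypothetical helper: weighted average of per-list metric values."""
    total_weight = sum(weights)
    if total_weight == 0:
        return 0.0
    return sum(v * w for v, w in zip(values, weights)) / total_weight


# NDCG of two lists: the first is ranked perfectly, the second is not.
per_list_ndcg = [1.0, 0.0]
unweighted = weighted_mean(per_list_ndcg, [1.0, 1.0])
weighted = weighted_mean(per_list_ndcg, [3.0, 1.0])
```

Giving the first list three times the weight moves the aggregate from 0.5 to 0.75. This is why the example reports weighted_ndcg_* separately from ndcg_*.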

train_and_validate

Constructs and runs the training pipeline.

Example usage:

pipeline = BasicPipeline(model, train_data, valid_data)
pipeline.train_and_validate()

Args
*arg: Positional arguments that may be used in the training pipeline.
**kwargs: Keyword arguments that may be used in the training pipeline.
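The base class does not prescribe how *arg and **kwargs are used. One possible convention is to treat keyword arguments as overrides for the fit() defaults. The sketch below uses a hypothetical RecordingModel stand-in so that it runs without TensorFlow. OverridablePipeline is also hypothetical:

```python
class RecordingModel:
    """Hypothetical stand-in that records the keyword arguments passed to fit()."""

    def __init__(self):
        self.fit_kwargs = None

    def fit(self, **kwargs):
        self.fit_kwargs = kwargs


class OverridablePipeline:
    """Hypothetical pipeline whose train_and_validate forwards **kwargs to fit()."""

    def __init__(self, model, train_data, valid_data):
        self._model = model
        self._train_data = train_data
        self._valid_data = valid_data

    def train_and_validate(self, *arg, **kwargs):
        # Defaults first; caller-supplied keyword arguments take precedence.
        fit_kwargs = {"epochs": 100, "validation_data": self._valid_data, **kwargs}
        self._model.fit(x=self._train_data, **fit_kwargs)


model = RecordingModel()
OverridablePipeline(model, "train", "valid").train_and_validate(epochs=5)
```

Here the caller's epochs=5 replaces the default of 100, and validation_data keeps its default. In this sketch, passing x as a keyword would raise a TypeError for a duplicate argument.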

Returns
None, a trained tf.keras.Model, or a path to a saved tf.keras.Model.
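The return type is not fixed, so calling code may need to branch on it. The minimal sketch below assumes that a string or os.PathLike return value is a path and that any other non-None value is a model object. describe_result is a hypothetical helper:

```python
import os


def describe_result(result):
    """Hypothetical helper: classify the value returned by train_and_validate()."""
    if result is None:
        # The pipeline returned nothing (e.g. it trained its model in place).
        return "none"
    if isinstance(result, (str, os.PathLike)):
        # Treated here as a path to a saved model.
        return "path"
    # Anything else is assumed to be a trained model object.
    return "model"
```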