Interface for ranking pipeline to train a tf.keras.Model.
The AbstractPipeline class is an abstract class to train and validate a ranking model in tfr.keras.
To be implemented by subclasses:

- build_loss(): Contains the logic to build a tf.keras.losses.Loss, or a dict or list of tf.keras.losses.Loss objects, to be optimized in training.
- build_metrics(): Contains the logic to build a list or dict of tf.keras.metrics.Metric objects to monitor and evaluate the training.
- build_weighted_metrics(): Contains the logic to build a list or dict of tf.keras.metrics.Metric objects that take the weights into account.
- train_and_validate(): Contains the main pipeline for training and validation.
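The four methods above are declared with @abc.abstractmethod, so, assuming AbstractPipeline is built on Python's abc machinery as those decorators suggest, a subclass cannot be instantiated until it overrides all four. The sketch below demonstrates that rule with a hypothetical stand-in class, PipelineContract, instead of the real tfr.keras class so that it runs without TensorFlow; LossOnlyPipeline, FullPipeline and can_instantiate are likewise illustrative names.

```python
import abc
from typing import Any


class PipelineContract(abc.ABC):
    """Hypothetical stand-in with the same four abstract methods as AbstractPipeline."""

    @abc.abstractmethod
    def build_loss(self) -> Any:
        """Return a loss, or a dict or list of losses."""

    @abc.abstractmethod
    def build_metrics(self) -> Any:
        """Return a list or dict of metrics."""

    @abc.abstractmethod
    def build_weighted_metrics(self) -> Any:
        """Return a list or dict of metrics that take the weights into account."""

    @abc.abstractmethod
    def train_and_validate(self, *arg, **kwargs) -> Any:
        """Run the training and validation pipeline."""


class LossOnlyPipeline(PipelineContract):
    # Implements only one of the four abstract methods.
    def build_loss(self):
        return "softmax_loss"


class FullPipeline(LossOnlyPipeline):
    # Fills in the remaining three abstract methods.
    def build_metrics(self):
        return []

    def build_weighted_metrics(self):
        return []

    def train_and_validate(self, *arg, **kwargs):
        return None


def can_instantiate(cls: type) -> bool:
    """Return True if cls() succeeds, i.e. no abstract methods remain unimplemented."""
    try:
        cls()
    except TypeError:
        return False
    return True


print(can_instantiate(LossOnlyPipeline))  # False
print(can_instantiate(FullPipeline))      # True
```

Python raises TypeError at construction time, not at call time, so a missing method is caught as soon as the pipeline object is created.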
Example subclass implementation:
```python
import tensorflow as tf
import tensorflow_ranking as tfr


class BasicPipeline(tfr.keras.pipeline.AbstractPipeline):
    """Minimal pipeline: softmax loss, NDCG@{1,5,10} metrics, SGD training."""

    def __init__(self, model, train_data, valid_data, name=None):
        self._model = model
        self._train_data = train_data
        self._valid_data = valid_data
        self._name = name

    def build_loss(self):
        # Listwise softmax cross-entropy loss.
        return tfr.keras.losses.get('softmax_loss')

    def build_metrics(self):
        # NDCG at several cutoffs, evaluated without sample weights.
        return [
            tfr.keras.metrics.get(
                'ndcg', topn=topn, name='ndcg_{}'.format(topn)
            ) for topn in [1, 5, 10]
        ]

    def build_weighted_metrics(self):
        # The same NDCG metrics, evaluated with sample weights.
        return [
            tfr.keras.metrics.get(
                'ndcg', topn=topn, name='weighted_ndcg_{}'.format(topn)
            ) for topn in [1, 5, 10]
        ]

    def train_and_validate(self, *arg, **kwargs):
        self._model.compile(
            optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
            loss=self.build_loss(),
            metrics=self.build_metrics(),
            weighted_metrics=self.build_weighted_metrics())
        self._model.fit(
            x=self._train_data,
            epochs=100,
            validation_data=self._valid_data)
```
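BasicPipeline accepts *arg and **kwargs in train_and_validate() but ignores them and hardcodes epochs=100. One possible way to let callers override fit() settings is to merge their keyword arguments over a dict of defaults. The sketch below is hypothetical throughout: RecordingModel stands in for tf.keras.Model (it only records the keyword arguments passed to compile() and fit()), ConfigurablePipeline does not subclass AbstractPipeline so that it runs without TensorFlow, and DEFAULT_FIT_KWARGS is an illustrative name.

```python
from typing import Any, Dict, List


class RecordingModel:
    """Hypothetical stand-in for tf.keras.Model that records compile/fit calls."""

    def __init__(self) -> None:
        self.calls: List[Dict[str, Any]] = []

    def compile(self, **kwargs: Any) -> None:
        self.calls.append({"method": "compile", **kwargs})

    def fit(self, **kwargs: Any) -> None:
        self.calls.append({"method": "fit", **kwargs})


class ConfigurablePipeline:
    """Sketch of a train_and_validate() whose fit() settings can be overridden."""

    # Default mirrors the hardcoded value in BasicPipeline.
    DEFAULT_FIT_KWARGS: Dict[str, Any] = {"epochs": 100}

    def __init__(self, model: Any, train_data: Any, valid_data: Any) -> None:
        self._model = model
        self._train_data = train_data
        self._valid_data = valid_data

    def train_and_validate(self, *arg: Any, **kwargs: Any) -> Any:
        # Placeholder loss; a real pipeline would call build_loss() and the
        # metric builders here, as BasicPipeline does.
        self._model.compile(loss="softmax_loss")
        # Keyword arguments from the caller take precedence over the defaults.
        fit_kwargs = {**self.DEFAULT_FIT_KWARGS, **kwargs}
        self._model.fit(
            x=self._train_data,
            validation_data=self._valid_data,
            **fit_kwargs)
        return self._model


model = RecordingModel()
ConfigurablePipeline(model, "train_ds", "valid_ds").train_and_validate(epochs=5)
print(model.calls[-1]["epochs"])  # 5
```

Copying the defaults into a fresh dict on every call keeps the class-level DEFAULT_FIT_KWARGS unchanged between runs.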
Methods
build_loss
```python
@abc.abstractmethod
build_loss() -> Any
```
Returns the loss for model.compile.
Example usage:
```python
pipeline = BasicPipeline(model, train_data, valid_data)
loss = pipeline.build_loss()
```
| Returns |
|---|
| A tf.keras.losses.Loss, or a dict or list of tf.keras.losses.Loss objects. |
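BasicPipeline returns the loss registered under the key 'softmax_loss'. As a rough illustration of what a listwise softmax loss computes for one list, the sketch below normalizes the labels into a distribution and takes its cross-entropy against the softmax of the scores. It is a plain-Python approximation under stated assumptions (a single unpadded list, lists with no relevant item contribute zero loss, no reduction across a batch), not the tfr.keras implementation; the function name softmax_loss here is illustrative.

```python
import math
from typing import Sequence


def softmax_loss(labels: Sequence[float], scores: Sequence[float]) -> float:
    """Listwise softmax cross-entropy for a single list (sketch).

    Labels are normalized to sum to 1; the loss is the cross-entropy between
    that label distribution and the softmax over the scores.
    """
    label_sum = sum(labels)
    if label_sum <= 0:
        # Assumption: a list without any relevant item contributes no loss.
        return 0.0
    # Numerically stable log-sum-exp: subtract the max score first.
    max_score = max(scores)
    log_norm = max_score + math.log(sum(math.exp(s - max_score) for s in scores))
    # log_softmax(s_i) = s_i - log_norm
    return -sum((y / label_sum) * (s - log_norm) for y, s in zip(labels, scores))


# Uniform scores over three items with one relevant item: loss = log(3).
print(softmax_loss([1.0, 0.0, 0.0], [0.0, 0.0, 0.0]))
```

Raising the score of the relevant item relative to the others lowers the loss, which is what training with this loss encourages.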
build_metrics
```python
@abc.abstractmethod
build_metrics() -> Any
```
Returns a list or dict of ranking metrics for model.compile().
Example usage:
```python
pipeline = BasicPipeline(model, train_data, valid_data)
metrics = pipeline.build_metrics()
```
| Returns |
|---|
| A list or a dict of tf.keras.metrics.Metric objects. |
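BasicPipeline registers NDCG at cutoffs 1, 5 and 10. To make concrete what one such value measures, the sketch below computes NDCG@topn for a single list in plain Python: sort items by score, sum discounted gains over the top positions, and divide by the same sum for the ideal (label-sorted) order. The gain 2**label - 1 and the discount 1/log2(rank + 1) are the common textbook choices and are assumed here; the tfr.keras metric may handle padding, ties, and configurable gain or discount functions differently. The function name ndcg is illustrative.

```python
import math
from typing import Optional, Sequence


def ndcg(labels: Sequence[float], scores: Sequence[float],
         topn: Optional[int] = None) -> float:
    """NDCG@topn for one list (sketch).

    Assumes gain 2**label - 1 and discount 1 / log2(rank + 1), 1-based ranks.
    """

    def dcg(ordered_labels: Sequence[float]) -> float:
        cutoff = ordered_labels if topn is None else ordered_labels[:topn]
        return sum((2.0 ** y - 1.0) / math.log2(rank + 1)
                   for rank, y in enumerate(cutoff, start=1))

    # Order labels by descending score (stable sort keeps input order on ties).
    ranked = [y for _, y in sorted(zip(scores, labels), key=lambda p: -p[0])]
    # The ideal ordering puts the highest labels first.
    ideal_dcg = dcg(sorted(labels, reverse=True))
    return dcg(ranked) / ideal_dcg if ideal_dcg > 0 else 0.0


# The relevant item is ranked second, so NDCG@1 is 0 and NDCG@5 is 1/log2(3).
for k in [1, 5, 10]:
    print("ndcg_{}".format(k), round(ndcg([0.0, 1.0], [2.0, 1.0], topn=k), 4))
```

Only the order induced by the scores matters, not their magnitudes, which is why NDCG is a natural metric for ranking models.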
build_weighted_metrics
```python
@abc.abstractmethod
build_weighted_metrics() -> Any
```
Returns a list or dict of weighted ranking metrics for model.compile().
Example usage:
```python
pipeline = BasicPipeline(model, train_data, valid_data)
weighted_metrics = pipeline.build_weighted_metrics()
```
| Returns |
|---|
| A list or a dict of tf.keras.metrics.Metric objects. |
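In Keras, metrics passed to compile() as weighted_metrics are updated with the sample weights that fit() and evaluate() receive (for example as the third element of (x, y, sample_weight) tuples from a dataset), while metrics passed as metrics ignore them. The sketch below shows the practical effect on aggregation across lists, under the simplifying assumption of one weight per list and a plain weighted mean; how tfr.keras turns per-item weights into per-list weights is not modeled. The function aggregate and the metric values are hypothetical.

```python
from typing import Optional, Sequence


def aggregate(per_list_values: Sequence[float],
              weights: Optional[Sequence[float]] = None) -> float:
    """Mean of per-list metric values; a weighted mean when weights are given.

    Assumes one non-negative weight per list.
    """
    if weights is None:
        # Unweighted metrics: every list counts equally.
        weights = [1.0] * len(per_list_values)
    total_weight = sum(weights)
    if total_weight == 0:
        return 0.0
    return sum(v * w for v, w in zip(per_list_values, weights)) / total_weight


per_list_ndcg = [1.0, 0.5]  # hypothetical per-list NDCG values
print(aggregate(per_list_ndcg))              # 0.75  (like ndcg_k)
print(aggregate(per_list_ndcg, [1.0, 3.0]))  # 0.625 (like weighted_ndcg_k)
```

With the second list weighted three times as heavily, its lower NDCG pulls the weighted average down, which is why BasicPipeline reports both ndcg_k and weighted_ndcg_k.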
train_and_validate
```python
@abc.abstractmethod
train_and_validate(*arg, **kwargs) -> Any
```
Constructs and runs the training pipeline.
Example usage:
```python
pipeline = BasicPipeline(model, train_data, valid_data)
pipeline.train_and_validate()
```
| Args | |
|---|---|
| `*arg` | Arguments that might be used in the training pipeline. |
| `**kwargs` | Keyword arguments that might be used in the training pipeline. |
| Returns |
|---|
| None, a trained tf.keras.Model, or a path to a saved tf.keras.Model. |
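Because train_and_validate() may return None, a trained tf.keras.Model, or a path to a saved tf.keras.Model, code that consumes the result has to handle all three cases. A hypothetical helper, describe_result, sketches one way to branch on the return value; treating any str or os.PathLike value as a path is an assumption of this sketch, not something the interface specifies.

```python
import os
from typing import Any


def describe_result(result: Any) -> str:
    """Classify what a train_and_validate() call returned (hypothetical helper)."""
    if result is None:
        return "no return value"
    if isinstance(result, (str, os.PathLike)):
        # Assumption: string-like results are paths to a saved model.
        return "saved model at {}".format(os.fspath(result))
    # Anything else is treated as an in-memory trained model.
    return "trained model object of type {}".format(type(result).__name__)


for result in (None, "/tmp/saved_model", object()):
    print(describe_result(result))
```

Checking for None first matters because implementations such as BasicPipeline return nothing at all.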