Interface to build a tf.keras.Model for ranking.
The AbstractModelBuilder serves as the interface between model building and training. The training pipeline calls the build() method to get the model constructed inside the strategy scope that the pipeline uses. As a result, all variables in the model, the optimizers, and the metrics are created under that scope. See ModelFitPipeline in pipeline.py for an example.
The build() method must be implemented by a subclass. The simplest approach is to define everything inside build(), as you would when defining any tf.keras.Model:
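```python
import tensorflow as tf
import tensorflow_ranking as tfr

class MyModelBuilder(tfr.keras.model.AbstractModelBuilder):
  def build(self) -> tf.keras.Model:
    inputs = ...   # Create the model's input tensors, e.g. with tf.keras.Input.
    outputs = ...  # Compute the ranking scores from `inputs`.
    return tf.keras.Model(inputs=inputs, outputs=outputs)
```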
MyModelBuilder should work with ModelFitPipeline. To make model building more structured for ranking problems, we also define subclasses such as ModelBuilderWithMask.
Methods
build
@abc.abstractmethod
build() -> tf.keras.Model
The build method to be implemented by a subclass.