Base model for TFRS models.
tfrs.models.Model(
*args, **kwargs
)
Many recommender models are relatively complex, and do not neatly fit into supervised or unsupervised paradigms. This base class makes it easy to define custom training and test losses for such complex models.
This is done by asking the user to implement the following methods:
- __init__ to set up your model. Variable, task, loss, and metric initialization should go here.
- compute_loss to define the training loss. The method takes as input the raw features passed into the model and returns a loss tensor for training. As part of doing so, it should also update the model's metrics.
- [Optional] call to define how the model computes its predictions. This is not always necessary: for example, two-tower retrieval models have two well-defined submodels whose call methods are normally used directly.
Note that this base class is a thin convenience wrapper for tf.keras.Model, and
equivalent functionality can easily be achieved by overriding the train_step
and test_step methods of a plain Keras model. Doing so also makes it easy
to build even more complex training mechanisms, such as using
different optimizers for different variables or manipulating gradients.
Keras has an excellent tutorial on how to do this ("Customizing what happens in fit()").
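For comparison, here is a sketch of the plain-Keras route: a tf.keras.Model that overrides train_step and test_step directly. The class name PlainRatingModel, the attributes loss_fn and rmse, and the features "x" and "rating" are hypothetical. The training step (gradient tape, adding layer regularization losses, applying gradients, reporting metrics) is written in the style of the Keras guide and is not copied from the TFRS source.

```python
import numpy as np
import tensorflow as tf


class PlainRatingModel(tf.keras.Model):
    """Hypothetical plain-Keras model with hand-written train_step/test_step."""

    def __init__(self):
        super().__init__()
        self.scorer = tf.keras.Sequential([
            tf.keras.layers.Dense(16, activation="relu"),
            tf.keras.layers.Dense(1),
        ])
        self.loss_fn = tf.keras.losses.MeanSquaredError()
        self.rmse = tf.keras.metrics.RootMeanSquaredError()

    @property
    def metrics(self):
        # Listing the metric here lets Keras reset it at the start of each epoch.
        return [self.rmse]

    def call(self, features, training=False):
        return self.scorer(features["x"], training=training)

    def train_step(self, features):
        with tf.GradientTape() as tape:
            predictions = self(features, training=True)
            loss = self.loss_fn(features["rating"], predictions)
            # Add any layer regularization losses (there are none in this toy model).
            total_loss = loss + sum(self.losses)
        # Gradients could be modified here (e.g. clipped) before they are applied.
        gradients = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        self.rmse.update_state(features["rating"], predictions)
        return {"loss": loss, self.rmse.name: self.rmse.result()}

    def test_step(self, features):
        # Evaluation: same loss and metric, but no gradient update.
        predictions = self(features, training=False)
        loss = self.loss_fn(features["rating"], predictions)
        self.rmse.update_state(features["rating"], predictions)
        return {"loss": loss, self.rmse.name: self.rmse.result()}


# Hypothetical toy data: the target rating is the sum of four random features.
rng = np.random.default_rng(0)
x = rng.random((64, 4)).astype("float32")
rating = x.sum(axis=1, keepdims=True)
dataset = tf.data.Dataset.from_tensor_slices({"x": x, "rating": rating}).batch(16)

model = PlainRatingModel()
model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1))
history = model.fit(dataset, epochs=2, verbose=0)
evaluation = model.evaluate(dataset, verbose=0, return_dict=True)
```

Writing the steps out by hand costs more code than subclassing the TFRS base class, but every stage of the update is visible and can be changed.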
Methods
call
call(
inputs, training=None, mask=None
)
Calls the model on new inputs and returns the outputs as tensors.
In this case, call() simply reapplies all ops in the graph to the new inputs
(that is, it builds a new computational graph from the provided inputs).
Args | |
---|---|
inputs | Input tensor, or dict/list/tuple of input tensors. |
training | Boolean or boolean scalar tensor, indicating whether to run the model in training mode or inference mode. |
mask | A mask or list of masks. A mask can be either a boolean tensor or None (no mask). For more details, see the TensorFlow guide on masking and padding. |
Returns | |
---|---|
A tensor if there is a single output, or a list of tensors if there is more than one output. |
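To illustrate the training argument and the two return shapes, the sketch below builds two small functional Keras models that share layers. The layer sizes, the names "score" and "prob", and the input data are arbitrary choices for illustration. With training=False, Dropout is inactive, so repeated calls give identical outputs. A model with one output returns a single tensor, and a model with two outputs returns a list.

```python
import numpy as np
import tensorflow as tf

# Illustrative functional graph: one input, a Dropout layer, and two output heads.
inputs = tf.keras.Input(shape=(4,))
hidden = tf.keras.layers.Dense(8, activation="relu")(inputs)
hidden = tf.keras.layers.Dropout(0.5)(hidden)
score = tf.keras.layers.Dense(1, name="score")(hidden)
prob = tf.keras.layers.Dense(1, activation="sigmoid", name="prob")(hidden)

single = tf.keras.Model(inputs, score)         # single output -> a tensor
multi = tf.keras.Model(inputs, [score, prob])  # two outputs -> a list of tensors

x = np.ones((2, 4), dtype="float32")

# Inference mode: Dropout is a no-op, so both calls produce the same values.
y1 = single(x, training=False)
y2 = single(x, training=False)

outputs = multi(x, training=False)
```

Passing training=True instead would apply Dropout and make repeated calls differ.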