tfrs.tasks.Ranking
A ranking task.
Inherits From: Task
tfrs.tasks.Ranking(
    loss: Optional[tf.keras.losses.Loss] = None,
    metrics: Optional[List[tf.keras.metrics.Metric]] = None,
    prediction_metrics: Optional[List[tf.keras.metrics.Metric]] = None,
    label_metrics: Optional[List[tf.keras.metrics.Metric]] = None,
    loss_metrics: Optional[List[tf.keras.metrics.Metric]] = None,
    name: Optional[Text] = None
) -> None
Recommender systems are often composed of two components:
- a retrieval model, retrieving O(thousands) candidates from a corpus of
O(millions) candidates.
- a ranker model, scoring the candidates retrieved by the retrieval model to
return a ranked shortlist of a few dozen candidates.
This task helps with building ranker models. Usually, these will involve
predicting signals such as clicks, cart additions, likes, ratings, and
purchases.
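The two-stage flow described above can be sketched in plain Python. This is only an illustration of the retrieve-then-rank idea, not part of the library: the function names, the toy corpus, and the scores are all hypothetical, and a real system would use learned models for both stages.

```python
from typing import Callable, Dict, List, Tuple


def retrieve(corpus_scores: Dict[str, float], k: int) -> List[str]:
    """Stage 1 (hypothetical): keep the k items with the highest cheap retrieval score."""
    return sorted(corpus_scores, key=corpus_scores.get, reverse=True)[:k]


def rank(
    candidates: List[str],
    ranker: Callable[[str], float],
    shortlist_size: int,
) -> List[Tuple[str, float]]:
    """Stage 2 (hypothetical): score each retrieved candidate and return the best few."""
    scored = [(candidate, ranker(candidate)) for candidate in candidates]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:shortlist_size]


# Toy data: cheap retrieval scores for the whole corpus, and a more
# expensive predicted rating that stands in for the ranker model.
corpus_scores = {"a": 0.9, "b": 0.1, "c": 0.7, "d": 0.4, "e": 0.8}
predicted_rating = {"a": 2.5, "b": 5.0, "c": 4.5, "d": 1.0, "e": 3.0}

candidates = retrieve(corpus_scores, k=3)
shortlist = rank(candidates, lambda item: predicted_rating[item], shortlist_size=2)
print(candidates)  # ['a', 'e', 'c']
print(shortlist)   # [('c', 4.5), ('e', 3.0)]
```

Note that item "b" has the highest predicted rating but never reaches the ranker, because the retrieval stage filtered it out. The ranker only reorders what retrieval returns.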
Args
loss | Loss function. Defaults to BinaryCrossentropy.
metrics | List of Keras metrics to be evaluated.
prediction_metrics | List of Keras metrics used to summarize the predictions.
label_metrics | List of Keras metrics used to summarize the labels.
loss_metrics | List of Keras metrics used to summarize the loss.
name | Optional task name.
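To make the default loss concrete, here is a minimal pure-Python sketch of mean binary cross-entropy over a batch. It assumes the predictions are probabilities in [0, 1], which matches the Keras default of `from_logits=False`; the helper name `binary_crossentropy` and the clipping constant are illustrative choices, not the library's implementation.

```python
import math
from typing import Sequence


def binary_crossentropy(
    labels: Sequence[float],
    probabilities: Sequence[float],
    eps: float = 1e-7,
) -> float:
    """Mean binary cross-entropy; predictions are clipped away from 0 and 1."""
    total = 0.0
    for y, p in zip(labels, probabilities):
        p = min(max(p, eps), 1.0 - eps)  # avoid log(0)
        total += -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
    return total / len(labels)


labels = [1.0, 0.0, 1.0]          # e.g. clicked / not clicked
probabilities = [0.9, 0.2, 0.6]   # model outputs interpreted as probabilities
bce = binary_crossentropy(labels, probabilities)
print(round(bce, 4))  # 0.2798
```

Because this loss expects probabilities, it suits binary signals such as clicks. For a regression-style target such as a star rating, a different loss (for example `tf.keras.losses.MeanSquaredError`) can be passed through the `loss` argument instead.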
Methods
call
call(
    labels: tf.Tensor,
    predictions: tf.Tensor,
    sample_weight: Optional[tf.Tensor] = None,
    training: bool = False,
    compute_metrics: bool = True
) -> tf.Tensor
Computes the task loss and metrics.
Args
labels | Tensor of labels.
predictions | Tensor of predictions.
sample_weight | Tensor of sample weights.
training | Whether the training loss (True) or the test loss (False) is being computed.
compute_metrics | Whether to compute metrics. Set this to False during training to speed it up.
Returns
loss | Tensor of loss values.
|
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-04-26 UTC.