Class to set up the input, train and eval processes for a TF Ranking model.
tfr.extension.pipeline.RankingPipeline(
    context_feature_columns,
    example_feature_columns,
    hparams,
    estimator,
    label_feature_name='relevance',
    label_feature_type=tf.int64,
    dataset_reader=tfr.keras.pipeline.DatasetHparams.dataset_reader,
    best_exporter_metric=None,
    best_exporter_metric_higher_better=True,
    size_feature_name=None
)
An example use case is provided below:
import tensorflow as tf
import tensorflow_ranking as tfr
context_feature_columns = {
    "c1": tf.feature_column.numeric_column("c1", shape=(1,))
}
example_feature_columns = {
    "e1": tf.feature_column.numeric_column("e1", shape=(1,))
}
hparams = dict(
    train_input_pattern="/path/to/train/files",
    eval_input_pattern="/path/to/eval/files",
    train_batch_size=8,
    eval_batch_size=8,
    checkpoint_secs=120,
    num_checkpoints=1000,
    num_train_steps=10000,
    num_eval_steps=100,
    loss="softmax_loss",
    list_size=10,
    listwise_inference=False,
    convert_labels_to_binary=False,
    model_dir="/path/to/your/model/directory")
# See `tensorflow_ranking.estimator` for details about creating an estimator.
estimator = ...  # Placeholder: create your own estimator here.
ranking_pipeline = tfr.ext.pipeline.RankingPipeline(
    context_feature_columns,
    example_feature_columns,
    hparams,
    estimator=estimator,
    label_feature_name="relevance",
    label_feature_type=tf.int64)
ranking_pipeline.train_and_eval()
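Because hparams is a plain dict, a missing key may only surface once the pipeline tries to read it. The following is a minimal sketch of a pre-flight check, not part of TF-Ranking: `check_hparams` and `EXAMPLE_HPARAM_KEYS` are hypothetical names, and the key list simply mirrors the example above rather than an authoritative list of required hyperparameters.

```python
# Hypothetical helper, not part of TF-Ranking. The key set below mirrors the
# example hparams dict above; it is an assumption, not an official schema.
EXAMPLE_HPARAM_KEYS = (
    "train_input_pattern", "eval_input_pattern", "train_batch_size",
    "eval_batch_size", "checkpoint_secs", "num_checkpoints",
    "num_train_steps", "num_eval_steps", "loss", "list_size",
    "listwise_inference", "convert_labels_to_binary", "model_dir",
)


def check_hparams(hparams, expected_keys=EXAMPLE_HPARAM_KEYS):
    """Returns the sorted list of expected keys that are missing from `hparams`."""
    return sorted(key for key in expected_keys if key not in hparams)


# A deliberately incomplete dict to show the check reporting missing keys.
hparams = dict(train_input_pattern="/path/to/train/files", list_size=10)
missing = check_hparams(hparams)
if missing:
    print("hparams is missing:", ", ".join(missing))
```

Running a check like this before constructing the pipeline keeps configuration mistakes separate from training failures.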
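Note that you may also pass the optional arguments listed under Args, such as best_exporter_metric, best_exporter_metric_higher_better, and dataset_reader, to adjust the export strategy and the input file format.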
If you want to further customize certain RankingPipeline behaviors, create a subclass of RankingPipeline and override the related functions. We recommend overriding only the following functions:
- _make_dataset, which builds the tf.dataset for a tf-ranking model.
- _make_serving_input_fn, which defines the input function for serving.
- _export_strategies, if you have more advanced needs for model exporting.
For example, if you want to remove the best exporters, you may override _export_strategies:
class NoBestExporterRankingPipeline(tfr.ext.pipeline.RankingPipeline):

  def _export_strategies(self, event_file_pattern):
    del event_file_pattern  # Unused.
    latest_exporter = tf.estimator.LatestExporter(
        "latest_model",
        serving_input_receiver_fn=self._make_serving_input_fn())
    return [latest_exporter]

ranking_pipeline = NoBestExporterRankingPipeline(
    context_feature_columns,
    example_feature_columns,
    hparams,
    estimator=estimator)
ranking_pipeline.train_and_eval()
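For context on what a best exporter decides: going by the best_exporter_metric and best_exporter_metric_higher_better arguments, a new model replaces the exported one only when its eval result beats the best seen so far, with minimal loss used when no metric key is given. The sketch below illustrates that comparison in plain Python. `is_better` is a hypothetical helper, not a TF-Ranking or TensorFlow API, and the metric key and eval values are made up for illustration.

```python
def is_better(best_eval, current_eval, metric_key=None, higher_better=True):
    """Decides whether `current_eval` should replace `best_eval`.

    With no metric key, the result with the smaller loss wins, mirroring the
    documented behavior for best_exporter_metric=None.
    """
    if best_eval is None:
        return True  # Nothing exported yet, so the first result is kept.
    if metric_key is None:
        return current_eval["loss"] < best_eval["loss"]
    if higher_better:
        return current_eval[metric_key] > best_eval[metric_key]
    return current_eval[metric_key] < best_eval[metric_key]


# Made-up eval results for illustration only.
first = {"loss": 0.52, "metric/ndcg_5": 0.61}
second = {"loss": 0.55, "metric/ndcg_5": 0.64}

print(is_better(first, second))                              # Compares loss.
print(is_better(first, second, metric_key="metric/ndcg_5"))  # Compares the metric.
```

The two calls disagree on purpose: the second result has a higher loss but a higher metric, which is why the choice of export metric matters.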
If you want to customize your dataset reading behavior, you may override _make_dataset:
class CustomizedDatasetRankingPipeline(tfr.ext.pipeline.RankingPipeline):

  def _make_dataset(self,
                    batch_size,
                    list_size,
                    input_pattern,
                    randomize_input=True,
                    num_epochs=None):
    # Create your own dataset; please follow `tfr.data.build_ranking_dataset`.
    dataset = build_my_own_ranking_dataset(...)
    ...
    return dataset.map(self._features_and_labels)

ranking_pipeline = CustomizedDatasetRankingPipeline(
    context_feature_columns,
    example_feature_columns,
    hparams,
    estimator=estimator)
ranking_pipeline.train_and_eval()
Args | |
---|---|
context_feature_columns | (dict) Context (aka, query) feature columns. |
example_feature_columns | (dict) Example (aka, document) feature columns. |
hparams | (dict) A dict containing model hyperparameters. |
estimator | (Estimator) An Estimator instance for model train and eval. |
label_feature_name | (str) The name of the label feature. |
label_feature_type | (tf.dtype) The value type of the label feature. |
dataset_reader | (tf.Dataset) The dataset format for the input files. |
best_exporter_metric | (str) Metric key for exporting the best model. If None, exports the model with the minimal loss value. |
best_exporter_metric_higher_better | (bool) Whether a higher metric value is better. This is only used if best_exporter_metric is not None. |
size_feature_name | (str) If set, populates the feature dictionary with this name, and the corresponding value is a tf.int32 Tensor of shape [batch_size] indicating the actual sizes of the example lists before padding and truncation. If None, which is the default, this feature is not generated. |
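To illustrate what the size_feature_name feature holds, the following plain-Python sketch pads or truncates each example list to a fixed list_size and records the original lengths. `pad_and_truncate` is a hypothetical helper and the padding value is an assumption. The pipeline itself produces a tf.int32 Tensor of shape [batch_size], not a Python list.

```python
def pad_and_truncate(example_lists, list_size, pad_value=0.0):
    """Returns (fixed-size lists, original list sizes)."""
    # Sizes are recorded before any padding or truncation is applied.
    sizes = [len(examples) for examples in example_lists]
    fixed = [
        list(examples[:list_size])
        + [pad_value] * max(0, list_size - len(examples))
        for examples in example_lists
    ]
    return fixed, sizes


# Two queries with 2 and 4 documents respectively (illustrative values).
batch = [[0.3, 0.9], [0.1, 0.4, 0.7, 0.2]]
fixed, sizes = pad_and_truncate(batch, list_size=3)
print(fixed)  # Every list now has exactly list_size entries.
print(sizes)  # One entry per list, analogous to shape [batch_size].
```

Downstream code can use the recorded sizes to tell real examples from padding, which the padded lists alone no longer reveal.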
Methods
train_and_eval
train_and_eval(
    local_training=True
)
Launches train and evaluation jobs locally.