Builds datasets for multi-task training.
Inherits From: BaseDatasetBuilder, AbstractDatasetBuilder
tfr.keras.pipeline.MultiLabelDatasetBuilder(
    context_feature_spec: Dict[str, Union[tf.io.FixedLenFeature, tf.io.VarLenFeature, tf.io.RaggedFeature]],
    example_feature_spec: Dict[str, Union[tf.io.FixedLenFeature, tf.io.VarLenFeature, tf.io.RaggedFeature]],
    mask_feature_name: str,
    label_spec: Dict[str, Tuple[str, tf.io.FixedLenFeature]],
    hparams: tfr.keras.pipeline.DatasetHparams,
    sample_weight_spec: Optional[Tuple[str, tf.io.FixedLenFeature]] = None
)
This builder supports a single dataset with multiple labels, provided as a dict. Multiple datasets are not yet supported; the builder may be extended when that use case arises.
Example usage:
import tensorflow as tf
import tensorflow_ranking as tfr

# Assumed label value for padded examples; adjust to match your data.
_PADDING_LABEL = -1.

context_feature_spec = {}
example_feature_spec = {
    "example_feature_1": tf.io.FixedLenFeature(
        shape=(1,), dtype=tf.float32, default_value=0.0)
}
mask_feature_name = "list_mask"
# Both tasks read their labels from the same "utility" feature.
label_spec_tuple = ("utility",
                    tf.io.FixedLenFeature(
                        shape=(1,),
                        dtype=tf.float32,
                        default_value=_PADDING_LABEL))
label_spec = {"task1": label_spec_tuple, "task2": label_spec_tuple}
weight_spec = ("weight",
               tf.io.FixedLenFeature(
                   shape=(1,), dtype=tf.float32, default_value=1.))
dataset_hparams = tfr.keras.pipeline.DatasetHparams(
    train_input_pattern="train.dat",
    valid_input_pattern="valid.dat",
    train_batch_size=128,
    valid_batch_size=128)
dataset_builder = tfr.keras.pipeline.MultiLabelDatasetBuilder(
    context_feature_spec,
    example_feature_spec,
    mask_feature_name,
    label_spec,
    dataset_hparams,
    sample_weight_spec=weight_spec)
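
To make the shape of `label_spec` more concrete, the following is a minimal sketch of how a dict of task names can turn one parsed feature dict into a dict of labels. It is an illustration only, not the library's implementation: `split_labels` is a hypothetical helper, plain Python lists stand in for tensors, and removing the label feature from the model inputs is an assumption.

```python
from typing import Dict, List, Tuple

Batch = List[List[float]]  # stand-in for a [list_size, 1] tensor


def split_labels(
    parsed: Dict[str, Batch],
    label_names: Dict[str, str],
) -> Tuple[Dict[str, Batch], Dict[str, Batch]]:
    """Hypothetical helper: builds one label entry per task name."""
    features = dict(parsed)  # copy so the input dict is not mutated
    # Each task looks up its label by the feature name in its spec.
    labels = {task: features[name] for task, name in label_names.items()}
    # Assumption: label features are dropped from the model inputs.
    for name in set(label_names.values()):
        features.pop(name)
    return features, labels


parsed = {
    "example_feature_1": [[0.3], [0.7]],
    "utility": [[1.0], [0.0]],
}
# Mirrors label_spec above: both tasks point at the "utility" feature.
label_names = {"task1": "utility", "task2": "utility"}
features, labels = split_labels(parsed, label_names)
```

Here `labels` is keyed by task name (`"task1"`, `"task2"`), which is the form a multi-task model with one output per task would typically consume.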
Args | |
---|---|
context_feature_spec | Maps context (aka, query) names to feature specs. |
example_feature_spec | Maps example (aka, document) names to feature specs. |
mask_feature_name | If set, populates the feature dictionary with this name; the corresponding value is a tf.bool Tensor of shape [batch_size, list_size] indicating whether each example is an actual example or padding. |
label_spec | A dict that maps task names to label specs. Each label spec is a tuple of a label name and a tf.io.FixedLenFeature spec. |
hparams | A tfr.keras.pipeline.DatasetHparams object containing dataset hyperparameters, such as input patterns and batch sizes. |
sample_weight_spec | A tuple of a feature name and a tf.io.FixedLenFeature spec for the per-example weight. |
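
The mask described for `mask_feature_name` can be pictured with a small sketch. This is not the library's code: `pad_with_mask` is a hypothetical helper, plain Python lists stand in for tensors, and the convention that `True` marks an actual example while `False` marks padding is an assumption.

```python
from typing import List, Tuple


def pad_with_mask(
    lists: List[List[float]],
    list_size: int,
    padding_value: float = 0.0,
) -> Tuple[List[List[float]], List[List[bool]]]:
    """Hypothetical helper: pads each list and records which slots are real."""
    padded, mask = [], []
    for examples in lists:
        kept = examples[:list_size]  # truncate lists longer than list_size
        n_pad = list_size - len(kept)
        padded.append(kept + [padding_value] * n_pad)
        # Assumed convention: True = actual example, False = padding.
        mask.append([True] * len(kept) + [False] * n_pad)
    return padded, mask


# Two queries with 2 and 1 documents, padded to list_size = 3.
batch = [[0.3, 0.7], [0.5]]
padded, list_mask = pad_with_mask(batch, list_size=3)
```

Both results have shape [batch_size, list_size], here [2, 3], so a loss or metric can use the mask to ignore padded slots.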
Methods
build_signatures
build_signatures(
    model: tf.keras.Model
) -> Any
See AbstractDatasetBuilder.
build_train_dataset
build_train_dataset() -> tf.data.Dataset
See AbstractDatasetBuilder.
build_valid_dataset
build_valid_dataset() -> tf.data.Dataset
See AbstractDatasetBuilder.