Creates a baseline task for next-word prediction on Stack Overflow.
tff.simulation.baselines.stackoverflow.create_word_prediction_task(
    train_client_spec: tff.simulation.baselines.ClientSpec,
    eval_client_spec: Optional[tff.simulation.baselines.ClientSpec] = None,
    sequence_length: int = constants.DEFAULT_SEQUENCE_LENGTH,
    vocab_size: int = constants.DEFAULT_WORD_VOCAB_SIZE,
    num_out_of_vocab_buckets: int = 1,
    cache_dir: Optional[str] = None,
    use_synthetic_data: bool = False
) -> tff.simulation.baselines.BaselineTask
The goal of the task is to take sequence_length words from a post and predict
the next word. Here, all posts are drawn from the Stack Overflow forum, and a
client corresponds to a user.
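The input/target relationship behind next-word prediction can be sketched in plain Python. This is an illustrative sketch under stated assumptions, not the library's own preprocessing: the helper name make_next_word_example, the "<pad>" token, and the choice to truncate or pad each post to sequence_length + 1 tokens are all hypothetical. Shifting the window by one position gives, at every step, the word the model should predict next.

```python
from typing import List, Tuple

PAD = "<pad>"  # Hypothetical padding token chosen for this sketch.


def make_next_word_example(
    words: List[str], sequence_length: int
) -> Tuple[List[str], List[str]]:
    """Builds one (input, target) pair from a post (hypothetical helper).

    The post is truncated or padded to sequence_length + 1 tokens. The input is
    the first sequence_length tokens and the target is the same window shifted
    left by one, so target[i] is the word that follows input[i].
    """
    if sequence_length <= 0:
        raise ValueError("sequence_length must be a positive integer")
    window = words[: sequence_length + 1]
    window = window + [PAD] * (sequence_length + 1 - len(window))
    return window[:-1], window[1:]


inputs, targets = make_next_word_example(
    "how do i sort a list".split(), sequence_length=4)
print(inputs)   # ['how', 'do', 'i', 'sort']
print(targets)  # ['do', 'i', 'sort', 'a']
```

Posts shorter than the window are padded, so every example has the same length; words past sequence_length + 1 are simply dropped in this sketch.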
Args:
  train_client_spec: A tff.simulation.baselines.ClientSpec specifying how to
    preprocess train client data.
  eval_client_spec: An optional tff.simulation.baselines.ClientSpec specifying
    how to preprocess evaluation client data. If set to None, the evaluation
    datasets will use a batch size of 64 with no extra preprocessing.
  sequence_length: A positive integer dictating the length of each word
    sequence in a client's dataset. By default, this is set to
    tff.simulation.baselines.stackoverflow.DEFAULT_SEQUENCE_LENGTH.
  vocab_size: An integer specifying how many of the most frequent words in the
    entire corpus to use as the task's vocabulary. By default, this is set to
    tff.simulation.baselines.stackoverflow.DEFAULT_WORD_VOCAB_SIZE.
  num_out_of_vocab_buckets: The number of buckets into which words outside
    the vocabulary are mapped. Defaults to 1.
  cache_dir: An optional directory in which to cache the downloaded datasets.
    If None, they will be cached to ~/.tff/.
  use_synthetic_data: A boolean indicating whether to use synthetic Stack
    Overflow data. This option should only be used for testing purposes, to
    avoid downloading the entire Stack Overflow dataset. A synthetic vocabulary
    will also be used (not necessarily of size vocab_size).
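How vocab_size and num_out_of_vocab_buckets interact can be illustrated with a small, self-contained sketch. It assumes a common scheme that is not confirmed here as the library's exact implementation: the vocab_size most frequent words receive ids 0 to vocab_size - 1, and every other word is hashed into one of num_out_of_vocab_buckets extra ids. The helpers build_vocab and word_to_id are hypothetical, and zlib.crc32 stands in for whatever hash the library actually uses; it is chosen because Python's built-in hash() of strings varies between processes.

```python
import collections
import zlib
from typing import Dict, Iterable, List


def build_vocab(posts: Iterable[List[str]], vocab_size: int) -> Dict[str, int]:
    """Maps the vocab_size most frequent words to ids 0..vocab_size-1 (sketch)."""
    counts = collections.Counter(word for post in posts for word in post)
    # Ties are broken by first occurrence, so the mapping is deterministic.
    return {word: i for i, (word, _) in enumerate(counts.most_common(vocab_size))}


def word_to_id(word: str, vocab: Dict[str, int], vocab_size: int,
               num_out_of_vocab_buckets: int) -> int:
    """Returns the vocabulary id, or a hashed out-of-vocabulary bucket id."""
    if num_out_of_vocab_buckets <= 0:
        raise ValueError("num_out_of_vocab_buckets must be positive")
    if word in vocab:
        return vocab[word]
    # OOV ids occupy vocab_size .. vocab_size + num_out_of_vocab_buckets - 1.
    bucket = zlib.crc32(word.encode("utf-8")) % num_out_of_vocab_buckets
    return vocab_size + bucket


posts = [["python", "list", "sort"], ["python", "dict"], ["sort", "python"]]
vocab = build_vocab(posts, vocab_size=2)  # {'python': 0, 'sort': 1}
ids = [word_to_id(w, vocab, 2, 1) for w in ["python", "sort", "kotlin"]]
print(ids)  # [0, 1, 2]
```

In this sketch, a single bucket (the default num_out_of_vocab_buckets=1) collapses every unknown word onto the one id vocab_size; more buckets let the model distinguish some unknown words at the cost of a larger output layer.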