An ItemSelector
implementation that randomly selects items in a batch.
text.RandomItemSelector(
max_selections_per_batch,
selection_rate,
unselectable_ids=None,
shuffle_fn=None
)
RandomItemSelector randomly selects items in a batch, subject to the
given restrictions (max_selections_per_batch, selection_rate, and
unselectable_ids).
Example:
import tensorflow as tf
import tensorflow_text as tf_text

vocab = ["[UNK]", "[MASK]", "[RANDOM]", "[CLS]", "[SEP]",
         "abc", "def", "ghi"]
# Note that commonly in masked language model work, there are
# special tokens we don't want to mask, like CLS, SEP, and probably
# any OOV (out-of-vocab) tokens here called UNK.
# Note that if e.g. there are bucketed OOV tokens in the code,
# that might be a use case for overriding `get_selectable()` to
# exclude a range of IDs rather than enumerating them.
tf.random.set_seed(1234)
selector = tf_text.RandomItemSelector(
max_selections_per_batch=2,
selection_rate=0.2,
unselectable_ids=[0, 3, 4]) # indices of UNK, CLS, SEP
selection = selector.get_selection_mask(
tf.ragged.constant([[3, 5, 7, 7], [4, 6, 7, 5]]), axis=1)
print(selection)
<tf.RaggedTensor [[False, False, False, True], [False, False, True, False]]>
The selection skips the first element of each row (the CLS and SEP token IDs)
and picks randomly among the remaining elements of each segment. With a
different random seed, the selections may differ.
Args |
max_selections_per_batch
|
An int giving the maximum number of items to select (mask out).
|
selection_rate
|
The rate at which items are randomly selected.
|
unselectable_ids
|
(optional) A list of Python ints or a 1D Tensor of ints
giving the IDs that will not be masked.
|
shuffle_fn
|
(optional) A function that shuffles a 1D Tensor. Defaults
to tf.random.shuffle.
|
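To make the interplay of selection_rate and max_selections_per_batch concrete, here is a minimal plain-Python sketch. It is not the library's code: the helper name `sketch_num_to_select` is hypothetical, and the rounding rule (round the rate-based count up, then cap it) is an assumption, though it is consistent with the example above, where three selectable items at a rate of 0.2 yield one selection per row.

```python
import math


def sketch_num_to_select(num_selectable, selection_rate, max_selections_per_batch):
    """Hypothetical helper: how many items to pick from one row."""
    # Assumption: the rate-based count is rounded up, then capped.
    by_rate = math.ceil(num_selectable * selection_rate)
    return min(by_rate, max_selections_per_batch)


# Three selectable items, rate 0.2, cap 2 (the numbers from the example).
print(sketch_num_to_select(3, 0.2, 2))   # 1
# A long row is limited by the cap rather than by the rate.
print(sketch_num_to_select(50, 0.2, 2))  # 2
```

Under this assumption, the cap only matters once the rate-based count exceeds it; for short rows the rate alone decides.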
Attributes |
max_selections_per_batch
|
|
selection_rate
|
|
shuffle_fn
|
|
unselectable_ids
|
|
Methods
get_selectable
get_selectable(
input_ids, axis
)
Return a boolean mask of items that can be chosen for selection.
The default implementation marks as selectable all items whose IDs are not
in the unselectable_ids list. Override this method when selectability needs
a more complex or algorithmic rule.
Args |
input_ids
|
A RaggedTensor.
|
axis
|
The axis to apply selection on.
|
Returns |
A RaggedTensor with dtype bool and shape input_ids.shape[:axis]. Its
values are True if the corresponding item (or broadcasted subitems) should
be considered for masking. In the default implementation, all input_ids
items that are not listed in unselectable_ids (the constructor argument)
are considered selectable.
|
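As an illustration of the selectability rule, the following plain-Python sketch mirrors the default behavior on nested lists and shows a range-based variant of the kind an override might use for bucketed OOV IDs. Both function names are hypothetical, and the sketch works on Python lists rather than a RaggedTensor, so it is an analogue of the idea and not the library's implementation.

```python
def sketch_get_selectable(rows, unselectable_ids):
    """Hypothetical list-based analogue of the default rule."""
    blocked = set(unselectable_ids)
    # An item is selectable when its ID is not in the blocked set.
    return [[item not in blocked for item in row] for row in rows]


def sketch_get_selectable_range(rows, first_blocked, last_blocked):
    """Hypothetical variant: block a contiguous, inclusive range of IDs."""
    return [[not (first_blocked <= item <= last_blocked) for item in row]
            for row in rows]


rows = [[3, 5, 7, 7], [4, 6, 7, 5]]
print(sketch_get_selectable(rows, [0, 3, 4]))
# [[False, True, True, True], [False, True, True, True]]
print(sketch_get_selectable_range(rows, 0, 4))
# [[False, True, True, True], [False, True, True, True]]
```

The range variant avoids enumerating every excluded ID, which is the motivation given for overriding the method.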
get_selection_mask
get_selection_mask(
input_ids, axis
)
Returns a mask of items that have been selected.
The default implementation simply returns all items not excluded by
get_selectable.
Args |
input_ids
|
A RaggedTensor.
|
axis
|
(optional) An int detailing the dimension to apply selection on.
Default is the 1st dimension.
|
Returns |
A RaggedTensor with shape input_ids.shape[:axis]. Its values are True
if the corresponding item (or broadcasted subitems) should be selected.
|
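The role of shuffle_fn in random selection can be sketched in plain Python as "shuffle the selectable positions, then keep the first few". The function `sketch_selection_mask` below is hypothetical and operates on nested lists; the count rule (round up, then cap) is an assumption, and the real class works on RaggedTensors and may differ in detail.

```python
import math
import random


def sketch_selection_mask(rows, max_selections_per_batch, selection_rate,
                          unselectable_ids=(), shuffle_fn=None):
    """Hypothetical list-based analogue of random item selection."""
    if shuffle_fn is None:
        # Stand-in for a random shuffle of a 1D sequence of positions.
        shuffle_fn = lambda positions: random.sample(positions, len(positions))
    blocked = set(unselectable_ids)
    masks = []
    for row in rows:
        # Positions whose IDs are allowed to be selected.
        candidates = [i for i, item in enumerate(row) if item not in blocked]
        # Assumption: round the rate-based count up, then apply the cap.
        count = min(math.ceil(len(candidates) * selection_rate),
                    max_selections_per_batch)
        # Shuffle the candidate positions and keep the first `count`.
        chosen = set(shuffle_fn(candidates)[:count])
        masks.append([i in chosen for i in range(len(row))])
    return masks


rows = [[3, 5, 7, 7], [4, 6, 7, 5]]
# An identity "shuffle" makes the result deterministic: the first candidate wins.
print(sketch_selection_mask(rows, 2, 0.2, [0, 3, 4], shuffle_fn=lambda p: p))
# [[False, True, False, False], [False, True, False, False]]
```

Passing a deterministic function in place of the shuffle, as in the last call, is one way to get reproducible selections in a test.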