text.RandomItemSelector

An ItemSelector implementation that randomly selects items in a batch.

RandomItemSelector randomly selects items in a batch, subject to the restrictions given by max_selections_per_batch, selection_rate, and unselectable_ids.

Example:

import tensorflow as tf
import tensorflow_text as tf_text

vocab = ["[UNK]", "[MASK]", "[RANDOM]", "[CLS]", "[SEP]",
         "abc", "def", "ghi"]
# In masked language model work there are commonly special tokens
# that should not be masked, such as CLS, SEP, and probably any
# OOV (out-of-vocab) tokens, here called UNK.
# If the code uses bucketed OOV tokens, that might be a use case for
# overriding `get_selectable()` to exclude a range of IDs rather than
# enumerating them.
tf.random.set_seed(1234)
selector = tf_text.RandomItemSelector(
    max_selections_per_batch=2,
    selection_rate=0.2,
    unselectable_ids=[0, 3, 4])  # indices of UNK, CLS, SEP
selection = selector.get_selection_mask(
    tf.ragged.constant([[3, 5, 7, 7], [4, 6, 7, 5]]), axis=1)
print(selection)
<tf.RaggedTensor [[False, False, False, True], [False, False, True, False]]>

The selection skips the first element of each row (the IDs of the CLS and SEP tokens) and picks random elements from the rest of each segment. With a different random seed, the selections may differ.

Args
max_selections_per_batch An int giving the maximum number of items to mask out.
selection_rate The rate at which items are randomly selected.
unselectable_ids (optional) A list of Python ints or a 1D Tensor of ints; these IDs will not be masked.
shuffle_fn (optional) A function that shuffles a 1D Tensor. Defaults to tf.random.shuffle.
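
How the two numeric arguments might interact can be sketched in plain Python. This is an illustration, not the library's implementation. It assumes that, for each row, the number of items chosen is the ceiling of selection_rate times the number of selectable items, capped at max_selections_per_batch. That assumption agrees with the example above (three selectable items per row, a rate of 0.2, one selection per row), but this page does not state it. The name `sketch_selection_mask`, the `rng` parameter, and the use of nested lists in place of a RaggedTensor are all hypothetical.

```python
import math
import random


def sketch_selection_mask(rows, max_selections_per_batch, selection_rate,
                          unselectable_ids=(), rng=None):
    """Hypothetical pure-Python illustration; not the tf_text implementation."""
    rng = rng or random.Random()
    blocked = set(unselectable_ids)
    mask = []
    for row in rows:
        # Positions whose IDs are allowed to be selected.
        candidates = [i for i, item in enumerate(row) if item not in blocked]
        # Assumed rule: ceil(rate * number of candidates), capped by the
        # maximum and by the number of candidates actually available.
        count = min(math.ceil(len(candidates) * selection_rate),
                    max_selections_per_batch,
                    len(candidates))
        chosen = set(rng.sample(candidates, count))
        mask.append([i in chosen for i in range(len(row))])
    return mask


mask = sketch_selection_mask(
    [[3, 5, 7, 7], [4, 6, 7, 5]],
    max_selections_per_batch=2,
    selection_rate=0.2,
    unselectable_ids=[0, 3, 4],
    rng=random.Random(1234))
print(mask)  # one True per row, never at position 0
```

Which position is chosen in each row depends on the seed; only the count per row and the exclusion of the blocked IDs are fixed under the stated assumption.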

Attributes

max_selections_per_batch

selection_rate

shuffle_fn

unselectable_ids

Methods

get_selectable

Return a boolean mask of items that can be chosen for selection.

The default implementation marks as selectable all items whose IDs are not in the unselectable_ids list. Override it when selectability requires a more complex or algorithmic rule.

Args
input_ids A RaggedTensor.
axis The axis to apply selection on.

Returns
A RaggedTensor with dtype bool and shape input_ids.shape[:axis]. Its values are True if the corresponding item (or broadcasted subitems) should be considered for masking. In the default implementation, all items of input_ids that are not listed in unselectable_ids (the constructor argument) are considered selectable.
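
The difference between the default rule and a range-based override can be sketched in plain Python. Both functions below are hypothetical illustrations on nested lists. A real override of get_selectable would take and return RaggedTensors, and its internals are not shown on this page.

```python
def sketch_get_selectable(rows, unselectable_ids):
    """Default-style rule: an item is selectable unless its ID is listed."""
    blocked = set(unselectable_ids)
    return [[item not in blocked for item in row] for row in rows]


def sketch_get_selectable_range(rows, first_blocked_id, last_blocked_id):
    """Override-style rule: exclude a contiguous, inclusive range of IDs."""
    return [[not (first_blocked_id <= item <= last_blocked_id) for item in row]
            for row in rows]


rows = [[3, 5, 7, 7], [4, 6, 7, 5]]
by_list = sketch_get_selectable(rows, unselectable_ids=[0, 3, 4])
by_range = sketch_get_selectable_range(rows, first_blocked_id=0,
                                       last_blocked_id=4)
print(by_list)   # [[False, True, True, True], [False, True, True, True]]
print(by_range)  # [[False, True, True, True], [False, True, True, True]]
```

The range form avoids enumerating every excluded ID, which is the situation the bucketed-OOV comment in the class example describes.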

get_selection_mask

Returns a mask of items that have been selected.

The default implementation simply returns all items not excluded by get_selectable.

Args
input_ids A RaggedTensor.
axis (optional) An int giving the dimension to apply selection on. Defaults to the 1st dimension.

Returns
A RaggedTensor with shape input_ids.shape[:axis]. Its values are True if the corresponding item (or broadcasted subitems) should be selected.
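
One way the returned mask might be used in masked language model preprocessing is to replace every selected item with the ID of the [MASK] token. The sketch below does this on nested lists in plain Python. The helper `apply_selection` is hypothetical and not part of the library. The input IDs and the mask are copied from the class example, and MASK_ID is the index of "[MASK]" in that example's vocabulary.

```python
MASK_ID = 1  # index of "[MASK]" in the example vocabulary


def apply_selection(rows, selection, mask_id=MASK_ID):
    """Replace each selected item with mask_id; leave the others unchanged."""
    return [[mask_id if selected else item
             for item, selected in zip(row, row_selection)]
            for row, row_selection in zip(rows, selection)]


input_ids = [[3, 5, 7, 7], [4, 6, 7, 5]]
selection = [[False, False, False, True], [False, False, True, False]]
masked = apply_selection(input_ids, selection)
print(masked)  # [[3, 5, 7, 1], [4, 6, 1, 5]]
```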