text.mask_language_model

Applies dynamic language model masking.

mask_language_model implements the Masked LM and Masking Procedure described in BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (https://arxiv.org/pdf/1810.04805.pdf). It uses an ItemSelector to select the items for masking and a MaskValuesChooser to assign values to the selected items. Some selected items are deliberately left unchanged; the purpose of this is to bias the representation towards the actual observed item.

Masking is performed on items along an axis. For each selected item, a decision is taken independently at random to replace it with [MASK], replace it with a random token from the full vocab, or leave it unchanged. Note that the masking decision is broadcast to the sub-dimensions.
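As a rough illustration of the three-way decision described above, here is a minimal pure-Python sketch. It is not the library's implementation: the function name choose_value, the toy vocab, and the 0.8 / 0.1 rates (the proportions used in the BERT paper's masking procedure) are assumptions made for the example.

```python
import random

MASK = "[MASK]"


def choose_value(token, vocab, rng, mask_rate=0.8, random_rate=0.1):
    """Return the replacement for one item that was already selected for masking.

    mask_rate and random_rate are illustrative defaults, not values read
    from the library.
    """
    draw = rng.random()  # uniform in [0, 1)
    if draw < mask_rate:
        return MASK  # replace with the mask token
    if draw < mask_rate + random_rate:
        return rng.choice(vocab)  # replace with a random vocab token
    return token  # leave the item unchanged


rng = random.Random(0)
vocab = ["Sp", "##onge", "bob", "Sq", "##uare", "##pants"]
print([choose_value("Ob", vocab, rng) for _ in range(5)])
```

With these rates, most selected items become [MASK], a few become a random token, and the remainder keep their original value while still counting as selected.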

For example, in a RaggedTensor of shape [batch, (wordpieces)] with axis=1, each wordpiece is independently masked (or not).

With the following input:

[[b"Sp", b"##onge", b"bob", b"Sq", b"##uare", b"##pants"],
 [b"Bar", b"##ack", b"Ob", b"##ama"],
 [b"Mar", b"##vel", b"A", b"##ven", b"##gers"]]

mask_language_model could end up masking individual wordpieces:

[[b"[MASK]", b"##onge", b"bob", b"Sq", b"[MASK]", b"##pants"],
 [b"Bar", b"##ack", b"[MASK]", b"##ama"],
 [b"[MASK]", b"##vel", b"A", b"##ven", b"##gers"]]

Or with random tokens inserted:

[[b"[MASK]", b"##onge", b"bob", b"Sq", b"[MASK]", b"##pants"],
 [b"Bar", b"##ack", b"Sq", b"##ama"],  # random token inserted for 'Ob'
 [b"Bar", b"##vel", b"A", b"##ven", b"##gers"]]  # random token inserted for 'Mar'

In a RaggedTensor of shape [batch, (words), (wordpieces)], whole words get masked (or not). If a word gets masked, all its tokens are independently either replaced by [MASK], replaced by random tokens, or left unchanged. Note that any arbitrary span that can be constructed by a RaggedTensor can be masked in the same way.

For example, if we have a RaggedTensor with shape [batch, (token), (wordpieces)]:

[[[b"Sp", b"##onge"], [b"bob"], [b"Sq", b"##uare", b"##pants"]],
 [[b"Bar", b"##ack"], [b"Ob", b"##ama"]],
 [[b"Mar", b"##vel"], [b"A", b"##ven", b"##gers"]]]

mask_language_model could mask whole spans (items grouped together by the same 1st dimension):

[[[b"[MASK]", b"[MASK]"], [b"bob"], [b"Sq", b"##uare", b"##pants"]],
 [[b"Bar", b"##ack"], [b"[MASK]", b"[MASK]"]],
 [[b"[MASK]", b"[MASK]"], [b"A", b"##ven", b"##gers"]]]

or insert random items in spans:

[[[b"Mar", b"##ama"], [b"bob"], [b"Sq", b"##uare", b"##pants"]],
 [[b"Bar", b"##ack"], [b"##onge", b"##gers"]],
 [[b"Ob", b"Sp"], [b"A", b"##ven", b"##gers"]]]
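
The whole-span behaviour shown above can be sketched in plain Python for a single batch row. This is an illustration only, not the library's code: the helper mask_whole_words, the explicit set of selected word indices, and the rates are hypothetical. It also assumes, following the note that the masking decision is broadcast to the sub-dimensions, that one decision is drawn per selected word, with random replacements drawn separately for each wordpiece.

```python
import random

MASK = "[MASK]"


def mask_whole_words(words, selected, vocab, rng, mask_rate=0.8, random_rate=0.1):
    """Mask the words whose indices are in `selected`; copy the others unchanged."""
    out = []
    for index, pieces in enumerate(words):
        if index not in selected:
            out.append(list(pieces))
            continue
        draw = rng.random()  # one decision per selected word (assumption)
        if draw < mask_rate:
            out.append([MASK] * len(pieces))  # every wordpiece becomes [MASK]
        elif draw < mask_rate + random_rate:
            out.append([rng.choice(vocab) for _ in pieces])  # random token per wordpiece
        else:
            out.append(list(pieces))  # selected, but left unchanged
    return out


row = [["Sp", "##onge"], ["bob"], ["Sq", "##uare", "##pants"]]
vocab = [piece for word in row for piece in word]
# mask_rate=1.0 makes the outcome deterministic for the demonstration.
print(mask_whole_words(row, {0}, vocab, random.Random(0), mask_rate=1.0, random_rate=0.0))
# → [['[MASK]', '[MASK]'], ['bob'], ['Sq', '##uare', '##pants']]
```

In the real API the selection is made by the ItemSelector rather than passed in as a set; the set here only keeps the example deterministic.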

Args:
input_ids: A RaggedTensor of n dimensions (where n >= 2) on which masking will be applied to items up to dimension 1.
item_selector: An instance of ItemSelector that is used for selecting items to be masked.
mask_values_chooser: An instance of MaskValuesChooser which determines the values assigned to the ids chosen for masking.
axis: The axis along which items are treated atomically for masking.

Returns:
A tuple of (masked_input_ids, masked_positions, masked_ids) where:
masked_input_ids: A RaggedTensor with the same shape and dtype as input_ids, but with the items at masked_positions possibly replaced by the mask token or a random id, or left unchanged.
masked_positions: A RaggedTensor of ints with shape [batch, (num_masked)] containing the positions of the items selected for masking.
masked_ids: A RaggedTensor with shape [batch, (num_masked)] and the same dtype as input_ids, containing the original values before masking, which serve as the labels for the task.
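
To make the relationship between the three outputs concrete, the following pure-Python sketch builds them for a rank-2 input from positions and replacement values that are supplied by hand. The helper apply_masking, the id values, and MASK_ID = 103 are hypothetical; the real function obtains the positions from the ItemSelector and the replacement values from the MaskValuesChooser.

```python
MASK_ID = 103  # hypothetical mask token id, chosen for the example


def apply_masking(input_ids, selected_positions, new_values):
    """Build (masked_input_ids, masked_positions, masked_ids) for nested lists."""
    masked_input_ids, masked_positions, masked_ids = [], [], []
    for row, positions, values in zip(input_ids, selected_positions, new_values):
        new_row = list(row)
        for pos, value in zip(positions, values):
            new_row[pos] = value  # overwrite the selected position
        masked_input_ids.append(new_row)
        masked_positions.append(list(positions))
        masked_ids.append([row[pos] for pos in positions])  # original ids, used as labels
    return masked_input_ids, masked_positions, masked_ids


input_ids = [[11, 12, 13, 14], [21, 22]]
selected = [[0, 2], [1]]
# Position 2 of row 0 is selected but keeps its original id (the "no change" case).
values = [[MASK_ID, 13], [MASK_ID]]
print(apply_masking(input_ids, selected, values))
# → ([[103, 12, 13, 14], [21, 103]], [[0, 2], [1]], [[11, 13], [22]])
```

Note that a position can appear in masked_positions even when its value in masked_input_ids is unchanged, which is why masked_ids, not a diff of the inputs, is used for the labels.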