...

Masked Language Modeling

Learn about using masked language modeling and whole word masking techniques to pre-train the BERT model.

BERT is an auto-encoding language model, meaning that it reads a sentence in both directions to make a prediction. In the masked language modeling (MLM) task, given an input sentence, we randomly mask 15% of the words and train the network to predict them. To predict a masked word, the model reads the sentence in both directions and uses this bidirectional context to make its prediction.
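To make this concrete, here is a minimal sketch (not the actual pre-training code) of how roughly 15% of the tokens in a sentence could be selected at random and replaced with a [MASK] token; the function name and the use of Python's random module are illustrative assumptions:

import random

def mask_tokens(tokens, mask_prob=0.15, mask_token="[MASK]"):
    # Copy the token list so the original stays unchanged
    masked = list(tokens)
    # Pick roughly 15% of the positions at random (at least one)
    num_to_mask = max(1, round(mask_prob * len(masked)))
    for position in random.sample(range(len(masked)), num_to_mask):
        masked[position] = mask_token
    return masked

print(mask_tokens(['Paris', 'is', 'a', 'beautiful', 'city']))
# A possible output: ['Paris', 'is', 'a', 'beautiful', '[MASK]']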

Training the BERT model for the MLM task

Let's understand how masked language modeling works with an example, using the same sentences we saw earlier: 'Paris is a beautiful city' and 'I love Paris'.

Tokenizing the sentences

First, we tokenize the sentences and get the tokens, as shown here:
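tokens = [Paris, is, a, beautiful, city, I, love, Paris]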

Adding [CLS] and [SEP] tokens

Now, we add the [CLS] token at the beginning of the first sentence and the [SEP] token at the end of every sentence, as shown here:
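tokens = [ [CLS], Paris, is, a, beautiful, city, [SEP], I, love, Paris, [SEP] ]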

Masking the tokens

Next, we randomly mask 15% of the tokens (words) in our preceding tokens list. Say we mask the word 'city'; then we replace it with a [MASK] token, as shown here:
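tokens = [ [CLS], Paris, is, a, beautiful, [MASK], [SEP], I, love, Paris, [SEP] ]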

As we can observe from the preceding tokens list, we have replaced the word 'city' with a [MASK] token. Now we train our BERT model to predict the masked token.
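As a quick illustration of what such a trained model can do, here is a minimal sketch using the Hugging Face transformers library to let an already pre-trained BERT checkpoint fill in the [MASK] token in our example. The checkpoint name 'bert-base-uncased' is an assumption, and this is only an inference demo, not the pre-training procedure itself:

import torch
from transformers import BertTokenizer, BertForMaskedLM

# Load an already pre-trained BERT model and its tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

# Our example input with the word 'city' masked out
text = "[CLS] Paris is a beautiful [MASK] [SEP] I love Paris [SEP]"
inputs = tokenizer(text, add_special_tokens=False, return_tensors="pt")

# Locate the [MASK] position and pick the highest-scoring word for it
mask_positions = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
with torch.no_grad():
    logits = model(**inputs).logits
predicted_ids = logits[0, mask_positions].argmax(dim=-1)
print(tokenizer.decode(predicted_ids))  # a well-trained model should print something like 'city'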

There is a small catch here: masking tokens in this way creates a discrepancy between pre-training and fine-tuning. We train BERT by predicting the [MASK] token, and after pre-training we fine-tune the model for downstream tasks such as sentiment analysis. But during fine-tuning, the input will not contain any [MASK] tokens, so there is a mismatch between how BERT is pre-trained and how it is used during fine-tuning.

The 80-10-10% rule

To overcome this issue, we apply the 80-10-10% rule. We learned that we randomly mask 15% of the tokens in the sentence. Now, for these 15% of tokens, we do the following:

  • For 80% of the time, we replace the token (actual word) with the [MASK] token. So, for 80% of the time, the input to the model will be as follows:

...