Fine-Tuning BERT for Extractive Summarization
Learn how to fine-tune the BERT model for extractive summarization.
Let's learn how to fine-tune the BERT model to perform text summarization. First, we will understand how to fine-tune BERT for extractive summarization, and then we will see how to fine-tune BERT for abstractive summarization.
Extractive summarization using BERT
To fine-tune the pre-trained BERT for the extractive summarization task, we slightly modify the input data format of the BERT model. Before looking into the modified input data format, let's first recall how we feed the input data to the BERT model.
Say we have two sentences: 'Paris is a beautiful city. I love Paris'. First, we tokenize the sentences, add a [CLS] token only at the beginning of the first sentence, and add a [SEP] token at the end of every sentence. Before feeding the tokens to BERT, we convert them into embeddings using three embedding layers: token embedding, segment embedding, and position embedding. We sum the three embeddings element-wise and feed the result as input to BERT. The input data format of BERT is shown in the following figure:
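To make the standard input format concrete, here is a minimal Python sketch that builds the tokens, segment ids, and position ids for the two example sentences and then sums the three embeddings element-wise. It assumes plain whitespace tokenization (real BERT uses a WordPiece tokenizer) and small random lookup tables standing in for BERT's learned embedding layers. The names `build_bert_pair_input`, `embed`, and the three tables are hypothetical, not part of any library.

```python
import random

DIM = 4  # toy embedding size (BERT-base uses 768)
rng = random.Random(0)

def random_vector():
    return [rng.uniform(-1.0, 1.0) for _ in range(DIM)]

# Hypothetical stand-ins for BERT's learned embedding layers
token_table = {}                                        # filled lazily, one vector per distinct token
segment_table = [random_vector(), random_vector()]      # segment A (index 0) and segment B (index 1)
position_table = [random_vector() for _ in range(512)]  # one vector per position

def build_bert_pair_input(sentence_a, sentence_b):
    """Standard BERT input: [CLS] only at the start, [SEP] after each sentence."""
    tokens_a = sentence_a.split()  # whitespace split stands in for WordPiece
    tokens_b = sentence_b.split()
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    # [CLS], sentence A and its [SEP] belong to segment A (0); the rest to segment B (1)
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    position_ids = list(range(len(tokens)))
    return tokens, segment_ids, position_ids

def embed(tokens, segment_ids, position_ids):
    """Element-wise sum of token, segment, and position embeddings for each token."""
    embeddings = []
    for tok, seg, pos in zip(tokens, segment_ids, position_ids):
        if tok not in token_table:
            token_table[tok] = random_vector()
        summed = [t + s + p for t, s, p in
                  zip(token_table[tok], segment_table[seg], position_table[pos])]
        embeddings.append(summed)
    return embeddings

tokens, segment_ids, position_ids = build_bert_pair_input("Paris is a beautiful city", "I love Paris")
embeddings = embed(tokens, segment_ids, position_ids)
print(tokens)
# ['[CLS]', 'Paris', 'is', 'a', 'beautiful', 'city', '[SEP]', 'I', 'love', 'Paris', '[SEP]']
```

Note that both occurrences of 'Paris' share the same token embedding but end up with different summed embeddings, because their segment and position embeddings differ.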
The BERT model takes this input and returns the representation of every token as output, as shown in the following figure:
Now the question is, how can we use BERT for the text summarization task? We know that the BERT model gives a representation of every token. But we don't need a representation of every token; instead, we need a representation of every sentence.
Need for a representation of every sentence
We learned that in extractive summarization, we create a summary by selecting only the important sentences. We know that a sentence representation captures the meaning of that sentence. If we get a representation of every sentence, then based on that representation, we can decide whether the sentence is important or not. If it is important, we add it to the summary; otherwise, we discard it. Thus, if we obtain the representation of every sentence using BERT, we can feed each representation to a classifier, and the classifier will tell us whether the sentence is important or not.
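The selection step can be sketched in a few lines of Python. This sketch assumes the classifier is a single linear layer followed by a sigmoid, with a 0.5 threshold; the sentence vectors, weights, and bias are hypothetical toy numbers, whereas in practice the vectors come from BERT and the classifier's parameters are learned during fine-tuning.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def score_sentence(sentence_vec, weights, bias):
    """Probability that a sentence belongs in the summary (linear layer + sigmoid)."""
    logit = sum(w * h for w, h in zip(weights, sentence_vec)) + bias
    return sigmoid(logit)

def extract_summary(sentences, sentence_vecs, weights, bias, threshold=0.5):
    """Keep sentences whose score exceeds the threshold, preserving their original order."""
    return [s for s, vec in zip(sentences, sentence_vecs)
            if score_sentence(vec, weights, bias) > threshold]

# Hypothetical toy values
sentences = ["sent one", "sent two", "sent three"]
sentence_vecs = [[0.9, 0.1], [-0.8, 0.3], [0.7, 0.6]]
weights = [2.0, 1.0]
bias = -0.5

summary = extract_summary(sentences, sentence_vecs, weights, bias)
print(summary)  # ['sent one', 'sent three']
```

Keeping the selected sentences in their original order is a design choice: an extractive summary usually reads better when the sentences appear as they did in the source text.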
Getting a representation of the sentence
Okay, how can we get the representation of a sentence? Can we use the representation of the [CLS] token as the representation of the sentence? Yes! But there is a small catch here. We learned that we add the [CLS] token only at the beginning of the first sentence, but in the text summarization task, we feed multiple sentences to the BERT model, and we need the representation of all the sentences.
So, in this case, we modify the input data format of the BERT model. We add a [CLS] token at the beginning of every sentence so that the representation of each [CLS] token can serve as the representation of the sentence that follows it.
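Once every sentence starts with its own [CLS] token, getting the sentence representations amounts to reading BERT's output at those positions. The sketch below assumes we already have one output vector per input token; here those outputs are hypothetical numbers standing in for a real BERT forward pass, and `sentence_representations` is an illustrative helper, not a library function.

```python
def sentence_representations(tokens, token_outputs):
    """Return the output vector at every [CLS] position, i.e. one vector per sentence."""
    cls_positions = [i for i, tok in enumerate(tokens) if tok == "[CLS]"]
    return [token_outputs[i] for i in cls_positions]

# Modified input: a [CLS] before every sentence
tokens = ["[CLS]", "sent", "one", "[SEP]", "[CLS]", "sent", "two", "[SEP]"]
# Hypothetical stand-in for BERT's per-token outputs (2-dim vectors)
token_outputs = [[float(i), float(10 * i)] for i in range(len(tokens))]

sentence_vecs = sentence_representations(tokens, token_outputs)
print(sentence_vecs)  # [[0.0, 0.0], [4.0, 40.0]]
```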
Say we have three sentences: 'sent one', 'sent two', and 'sent three'. First, we tokenize the sentences, add the [CLS] token at the beginning of every sentence, and separate the sentences with the [SEP] token. The resulting input tokens are: [CLS] sent one [SEP] [CLS] sent two [SEP] [CLS] sent three [SEP]
As we can observe, we have added the [CLS] token at the beginning of every sentence and added the [SEP] token at the end of every sentence.
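Building this modified input can be sketched as a simple loop over the sentences. As before, whitespace splitting is an assumption standing in for BERT's WordPiece tokenizer, and `build_summarization_input` is a hypothetical helper name.

```python
def build_summarization_input(sentences):
    """Modified input format: [CLS] before and [SEP] after every sentence."""
    tokens = []
    for sentence in sentences:
        tokens.append("[CLS]")
        tokens.extend(sentence.split())  # whitespace split stands in for WordPiece
        tokens.append("[SEP]")
    return tokens

tokens = build_summarization_input(["sent one", "sent two", "sent three"])
print(" ".join(tokens))
```

Compared with the standard format, the only change is that the [CLS] token is emitted once per sentence rather than once per input.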
Converting tokens to embeddings
Next, we feed the input tokens to the token, segment, and position embedding layers and convert the input tokens into embeddings.
Token embedding layer
The token embedding layer is shown in the following figure:
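Conceptually, the token embedding layer is a lookup table: each token is mapped to an id, and the id selects a row of an embedding matrix. The sketch below assumes a tiny vocabulary built from the example tokens and a randomly initialized 8-dimensional table; real BERT uses a fixed WordPiece vocabulary of about 30,000 tokens and learned embeddings (768-dimensional in BERT-base). All function names are hypothetical.

```python
import random

def build_vocab(tokens):
    """Assign an integer id to each distinct token, in order of first appearance."""
    vocab = {}
    for tok in tokens:
        if tok not in vocab:
            vocab[tok] = len(vocab)
    return vocab

def make_embedding_table(vocab_size, dim, seed=0):
    """Randomly initialized embedding matrix: one row per vocabulary entry."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 0.02) for _ in range(dim)] for _ in range(vocab_size)]

def token_embeddings(tokens, vocab, table):
    """Look up the embedding row for every input token."""
    return [table[vocab[tok]] for tok in tokens]

tokens = ["[CLS]", "sent", "one", "[SEP]", "[CLS]", "sent", "two", "[SEP]",
          "[CLS]", "sent", "three", "[SEP]"]
vocab = build_vocab(tokens)
table = make_embedding_table(len(vocab), dim=8)
embeddings = token_embeddings(tokens, vocab, table)
```

Because the lookup depends only on the token itself, every [CLS] token receives the same token embedding; what distinguishes the sentences is added by the segment and position embeddings.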
Segment embedding layer
The next layer is the segment embedding layer. We know that segment embedding is used to distinguish between the two given sentences. The segment embedding layer returns one of two embeddings,