Sentence-BERT with a Siamese Network
Learn how Sentence-BERT uses the Siamese network architecture to fine-tune a pre-trained BERT model for sentence pair classification and regression tasks.
Sentence-BERT uses the Siamese network architecture to fine-tune the pre-trained BERT model for sentence pair tasks. In this lesson, we'll understand why the Siamese architecture is useful and how the fine-tuning works. First, we'll see how Sentence-BERT handles a sentence pair classification task, and then how it handles a sentence pair regression task.
Sentence pair classification task
Suppose we have a dataset containing sentence pairs and a binary label indicating whether the sentences in each pair are similar (1) or dissimilar (0), as shown in the following table:
Sample dataset:

| Sentence 1 | Sentence 2 | Label |
| --- | --- | --- |
| I completed my assignment | I completed my homework | 1 |
| The game was boring | This is a great place | 0 |
| The food is delicious | The food is tasty | 1 |
| ... | ... | ... |
Now, let's see how to fine-tune the pre-trained BERT model on the preceding dataset using the Siamese architecture for the sentence pair classification task. Let's take the first sentence pair from our dataset: sentence 1 is *I completed my assignment*, and sentence 2 is *I completed my homework*.
We need to classify whether the given sentence pair is similar (1) or dissimilar (0). First, we tokenize both sentences and add the [CLS] token at the beginning and the [SEP] token at the end of each sentence, as shown:

Sentence 1 tokens: [CLS], I, completed, my, assignment, [SEP]

Sentence 2 tokens: [CLS], I, completed, my, homework, [SEP]
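As a concrete sketch of this step, here is how the tokenization might look with the Hugging Face transformers library and the bert-base-uncased checkpoint (both are assumptions; the lesson itself is library-agnostic):

```python
from transformers import BertTokenizer

# Load a BERT tokenizer (bert-base-uncased is an assumed checkpoint).
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

sentence1 = "I completed my assignment"
sentence2 = "I completed my homework"

# tokenize() splits the text into WordPiece tokens; we add the [CLS] and
# [SEP] tokens ourselves, exactly as described above.
tokens1 = ["[CLS]"] + tokenizer.tokenize(sentence1) + ["[SEP]"]
tokens2 = ["[CLS]"] + tokenizer.tokenize(sentence2) + ["[SEP]"]

print(tokens1)  # e.g. ['[CLS]', 'i', 'completed', 'my', 'assignment', '[SEP]']
print(tokens2)  # e.g. ['[CLS]', 'i', 'completed', 'my', 'homework', '[SEP]']
# (The exact subword splits depend on the tokenizer's vocabulary.)
```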
Now, we feed the tokens to the pre-trained BERT model and obtain the representation of each token. We learned that Sentence-BERT uses a Siamese network, which consists of two identical networks that share the same weights. So, here, we use two identical pre-trained BERT models: we feed the tokens from sentence 1 to one BERT model and the tokens from sentence 2 to the other, and compute the representations of the two sentences.
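Because the two networks share the same weights, a common way to realize the Siamese setup in code is to reuse a single BERT instance for both sentences. Here is a minimal sketch, again assuming PyTorch and the Hugging Face transformers library:

```python
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

# Two "identical networks sharing weights" reduce, in practice, to one
# model applied to each sentence separately.
inputs1 = tokenizer("I completed my assignment", return_tensors="pt")
inputs2 = tokenizer("I completed my homework", return_tensors="pt")

with torch.no_grad():
    # Token-level representations, shape [1, seq_len, 768] for bert-base.
    token_reps1 = model(**inputs1).last_hidden_state
    token_reps2 = model(**inputs2).last_hidden_state
```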
To compute the representation of an entire sentence from these token representations, we apply a pooling operation, either mean or max pooling; by default, Sentence-BERT uses mean pooling. After applying the pooling operation, we have a representation for each sentence in the given pair, u for sentence 1 and v for sentence 2, as shown in the following figure:
In the preceding figure, u denotes the representation of sentence 1 and v denotes the representation of sentence 2.
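As a rough sketch of the pooling step (assuming PyTorch; the mean_pool helper below is illustrative, not a library function), mean pooling averages the token representations, typically ignoring padded positions via the attention mask:

```python
import torch

def mean_pool(token_reps: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # token_reps: [batch, seq_len, hidden]; attention_mask: [batch, seq_len]
    mask = attention_mask.unsqueeze(-1).float()   # [batch, seq_len, 1]
    summed = (token_reps * mask).sum(dim=1)       # sum over non-padding tokens
    counts = mask.sum(dim=1).clamp(min=1e-9)      # avoid division by zero
    return summed / counts

# Dummy token representations standing in for BERT's output
# (hidden size 768, as in bert-base).
token_reps1 = torch.randn(1, 6, 768)
token_reps2 = torch.randn(1, 6, 768)
mask = torch.ones(1, 6, dtype=torch.long)

u = mean_pool(token_reps1, mask)  # representation of sentence 1
v = mean_pool(token_reps2, mask)  # representation of sentence 2
print(u.shape, v.shape)           # torch.Size([1, 768]) torch.Size([1, 768])
```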