Masked Multi-Head Attention

Learn about the masked multi-head attention mechanism and how it works.

In our English-to-French translation task, say our training dataset looks like the one shown here:

A sample training set

| Source sentence | Target sentence |
|---|---|
| I am good | Je vais bien |
| Good morning | Bonjour |
| Thank you very much | Merci beaucoup |

Looking at the preceding dataset, we can see that it consists of pairs of source and target sentences. We saw how the decoder predicts the target sentence word by word, one word per time step; that happens only during testing.

During training, since we already have the correct target sentence, we can feed the whole target sentence as input to the decoder, with one small modification. We learned that the decoder takes <sos> as its first input token and, at every time step, appends the word it just predicted to its input, continuing until the <eos> token is produced. So, during training, we simply add the <sos> token to the beginning of the target sentence and send that as the input to the decoder.
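As a rough illustration of the test-time behaviour described above, the following Python sketch runs the word-by-word loop. The names `toy_decoder_step` and `TOY_NEXT_TOKEN` are hypothetical stand-ins for a trained decoder (they simply look up the next word in a fixed table), so only the loop structure, starting from `<sos>` and appending each prediction until `<eos>` appears, reflects the text.

```python
# Hypothetical stand-in for a trained decoder: maps the tokens generated so far
# to the next token. A real decoder computes this from the encoder output and
# the decoder input; here it is a fixed lookup table for illustration only.
TOY_NEXT_TOKEN = {
    ("<sos>",): "Je",
    ("<sos>", "Je"): "vais",
    ("<sos>", "Je", "vais"): "bien",
    ("<sos>", "Je", "vais", "bien"): "<eos>",
}


def toy_decoder_step(tokens):
    """Return the next predicted token given the decoder input so far."""
    return TOY_NEXT_TOKEN.get(tuple(tokens), "<eos>")


def greedy_decode(step_fn, max_len=20):
    """Start from <sos> and append each predicted word until <eos> is produced."""
    decoder_input = ["<sos>"]
    for _ in range(max_len):
        next_token = step_fn(decoder_input)
        if next_token == "<eos>":
            break
        decoder_input.append(next_token)
    # Drop <sos> so that only the predicted words are returned.
    return decoder_input[1:]


print(greedy_decode(toy_decoder_step))  # ['Je', 'vais', 'bien']
```

At test time the decoder has to wait for its own prediction before it can take the next step, which is why this loop runs once per output word.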

Say we are converting the English sentence 'I am good' into the French sentence 'Je vais bien'. We add the <sos> token to the beginning of the target sentence and send <sos> Je vais bien as input to the decoder, and the decoder then predicts the output Je vais bien <eos>, as shown in the following figure:
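The input/output pair used during training can be built from the target sentence as in the sketch below. It assumes simple whitespace splitting (real systems use a tokenizer), and `make_training_pair` is a hypothetical helper name. Printing the pairs position by position shows that the expected output is the input shifted one position to the left: after reading the input up to a given position, the decoder should predict the next word.

```python
def make_training_pair(target_sentence):
    """Build the decoder input and the expected decoder output for training.

    The decoder input is the target sentence with <sos> prepended; the expected
    output is the same sentence with <eos> appended, i.e. the input shifted one
    position to the left.
    """
    words = target_sentence.split()  # assumption: whitespace tokenization
    decoder_input = ["<sos>"] + words
    expected_output = words + ["<eos>"]
    return decoder_input, expected_output


dec_in, dec_out = make_training_pair("Je vais bien")
# At each position, the decoder has read dec_in up to that position
# and should predict the word in dec_out at the same position.
for position, (last_seen, target) in enumerate(zip(dec_in, dec_out)):
    print(position, last_seen, "->", target)
# 0 <sos> -> Je
# 1 Je -> vais
# 2 vais -> bien
# 3 bien -> <eos>
```

Because the whole target sentence is available at once, all four predictions can be trained in parallel, which is exactly why the decoder needs a way to stop each position from looking at the words that come after it.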

Encoder and decoder of the transformer

But how does this work? Isn't this a bit confusing? Why do we feed the entire target sentence to the decoder and let it predict the target sentence shifted by one position as output? Let's explore this in more detail.

We learned that, instead of feeding the input directly to the decoder, we convert it into an embedding (the output embedding matrix), add the positional encoding, and then feed the result to the decoder. Suppose the following matrix, X, is obtained by adding the output embedding matrix and the positional encoding:
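A toy version of this step in Python is sketched below. It assumes the sinusoidal positional encoding from the original Transformer paper and a tiny embedding size of 4; the random vectors in `embedding` are placeholders for the learned output embedding matrix, not real values. Each row of `X` corresponds to one decoder input token.

```python
import math
import random

D_MODEL = 4  # toy embedding size; real models use much larger values
TOKENS = ["<sos>", "Je", "vais", "bien"]

# Hypothetical output embedding table: random vectors stand in for learned ones.
random.seed(0)
embedding = {tok: [random.uniform(-1, 1) for _ in range(D_MODEL)] for tok in TOKENS}


def positional_encoding(pos, d_model):
    """Sinusoidal encoding: sin on even dimensions, cos on odd dimensions."""
    pe = []
    for i in range(d_model):
        # Dimensions 2k and 2k+1 share the frequency 1 / 10000^(2k / d_model).
        angle = pos / (10000 ** (2 * (i // 2) / d_model))
        pe.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
    return pe


# Row p of X is embedding(token at position p) + positional_encoding(p).
X = [
    [e + p for e, p in zip(embedding[tok], positional_encoding(pos, D_MODEL))]
    for pos, tok in enumerate(TOKENS)
]
print(len(X), "x", len(X[0]))  # 4 x 4
```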

Input matrix

Now we feed the preceding matrix X to the decoder. The first layer in the decoder is masked multi-head attention. It works similarly to the multi-head attention mechanism we learned about with the encoder, but with one small difference.
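The difference lies in a mask applied to the attention scores. As a sketch, assuming the usual formulation from the original Transformer in which scores for disallowed positions are set to negative infinity before the softmax, each word may attend only to itself and the words before it. The scores below are hypothetical and all equal, which makes the effect of the mask easy to see; `causal_mask` and `masked_softmax` are illustrative helper names.

```python
import math


def causal_mask(n):
    """mask[i][j] is True when position i may attend to position j (j <= i)."""
    return [[j <= i for j in range(n)] for i in range(n)]


def masked_softmax(scores, mask):
    """Row-wise softmax where disallowed positions get -inf, i.e. zero weight."""
    weights = []
    for row, allowed in zip(scores, mask):
        masked = [s if ok else -math.inf for s, ok in zip(row, allowed)]
        m = max(masked)  # subtract the row max for numerical stability
        exps = [math.exp(s - m) for s in masked]  # exp(-inf) == 0.0
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


tokens = ["<sos>", "Je", "vais", "bien"]
scores = [[1.0] * len(tokens) for _ in tokens]  # hypothetical equal scores
for tok, row in zip(tokens, masked_softmax(scores, causal_mask(len(tokens)))):
    print(f"{tok:>6}", [round(w, 2) for w in row])
```

In the printed weights, the `<sos>` row puts all of its weight on itself, the `Je` row splits its weight evenly between `<sos>` and `Je`, and every row gives zero weight to the words after it. This is what keeps the decoder from "seeing" the words it is supposed to predict when the whole target sentence is fed in at once.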

Computing query, key, and value matrices

To perform self-attention, we create three new matrices, called query Q, key K ...
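Assuming the standard Transformer formulation, in which Q, K, and V are obtained by multiplying the input matrix X by three weight matrices (written here as `W_Q`, `W_K`, and `W_V`), a pure-Python sketch looks like this. The sizes and the random values are illustrative assumptions; in a trained model the weight matrices are learned.

```python
import random


def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix, both given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(42)
n_tokens, d_model, d_k = 4, 4, 2  # toy sizes chosen for illustration

# Hypothetical input matrix X (one row per decoder input token) and randomly
# initialised weight matrices; in a trained model W_Q, W_K, W_V are learned.
X = random_matrix(n_tokens, d_model, rng)
W_Q = random_matrix(d_model, d_k, rng)
W_K = random_matrix(d_model, d_k, rng)
W_V = random_matrix(d_model, d_k, rng)

Q, K, V = matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V)
print(len(Q), "x", len(Q[0]))  # 4 x 2
```

Each row of Q, K, and V still corresponds to one input token, so row i of each matrix belongs to the i-th word of `<sos> Je vais bien`.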