Self-Attention
Understand in depth how self-attention works.
What is self-attention?
“Self-attention is an attention mechanism relating different positions of a single sequence in order to compute a representation of the sequence.” ~ Ashish Vaswani et al. from Google Brain
Self-attention enables us to find correlations between different words (tokens) of the input, and these correlations reflect the syntactic and contextual structure of the sentence.
Let’s take the input sequence “Hello I love you” as an example.
A trained self-attention layer will associate the word “love” with the words “I” and “you” more strongly (with a higher weight) than with the word “Hello”. From linguistics, we know that these words share a subject-verb-object relationship, which gives us an intuitive sense of what self-attention captures.
In practice, the original transformer model uses three different representations of the embedding matrix: the Queries, Keys, and Values.
This can easily be implemented by multiplying our input with three different weight matrices $W_Q$, $W_K$, and $W_V$. In essence, it is just a matrix multiplication applied to the original word embeddings.
You can think of it as a linear projection. Here is a visual to help you out:
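In code, the three projections can be sketched as three bias-free linear layers applied to the same input. This is a minimal sketch under stated assumptions: the sizes (4 tokens, `d_model = 8`, `d_k = 8`) and the names `to_q`, `to_k`, `to_v` are illustrative choices, and the weights are randomly initialized rather than trained.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions): 4 tokens, e.g. "Hello I love you",
# embedding size d_model = 8, projected size d_k = 8.
seq_len, d_model, d_k = 4, 8, 8

# X stacks one embedding vector per token, row by row: (seq_len, d_model).
X = torch.randn(seq_len, d_model)

# Three separate learnable linear projections (hypothetical names), no bias.
# nn.Linear stores its weight with shape (out_features, in_features).
to_q = nn.Linear(d_model, d_k, bias=False)
to_k = nn.Linear(d_model, d_k, bias=False)
to_v = nn.Linear(d_model, d_k, bias=False)

Q, K, V = to_q(X), to_k(X), to_v(X)  # each has shape (seq_len, d_k)

# Each projection is just the matrix multiplication X @ W, with W = weight.T.
W_Q = to_q.weight.T                   # (d_model, d_k)
assert torch.allclose(Q, X @ W_Q)

print(Q.shape, K.shape, V.shape)
```

Using `nn.Linear` instead of raw matrices is only a convenience: it registers the weights as trainable parameters, while the computation stays the same matrix multiplication.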
Having the Query, Key, and Value matrices, we can now apply the self-attention layer as:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
In the original paper, the scaled dot-product attention was chosen as a scoring function to represent the correlation between two words (the attention weight).
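Note that we could also use a different similarity function. The $\sqrt{d_k}$ term is simply a scaling factor: without it, the dot products grow large in magnitude as $d_k$ increases, which pushes the softmax into regions where its gradients are extremely small.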
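To see why the $\sqrt{d_k}$ scaling helps, here is a small numerical sketch. It assumes query and key entries are independent with mean 0 and variance 1 (an idealization, not trained vectors); under that assumption the raw dot product $q \cdot k$ has variance $d_k$, while $q \cdot k / \sqrt{d_k}$ has variance 1 regardless of $d_k$.

```python
import torch

torch.manual_seed(0)
n_pairs = 10_000
variances = {}

for d_k in (4, 64, 512):
    # Idealized queries/keys: entries i.i.d. with mean 0 and variance 1.
    q = torch.randn(n_pairs, d_k)
    k = torch.randn(n_pairs, d_k)
    dots = (q * k).sum(dim=-1)   # n_pairs raw dot products q . k
    scaled = dots / d_k ** 0.5   # the same dot products divided by sqrt(d_k)
    variances[d_k] = (dots.var().item(), scaled.var().item())
    print(f"d_k={d_k:3d}  var(raw)={variances[d_k][0]:7.1f}  "
          f"var(scaled)={variances[d_k][1]:.2f}")
```

The raw variance should come out close to $d_k$ and the scaled variance close to 1, so the softmax inputs stay in a similar range whatever the head size.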
Following the video database-query paradigm that we introduced before, the $QK^T$ term simply measures the similarity between each search query and each entry in the database.
Finally, we apply a softmax function to get the final attention weights as a probability distribution.
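The scoring and softmax steps can be sketched as follows. This is a minimal sketch with randomly generated `Q` and `K` (illustrative sizes, not from the lesson); it shows that applying the softmax over the last dimension gives every query a probability distribution over all keys.

```python
import math
import torch

torch.manual_seed(0)
seq_len, d_k = 4, 8
Q = torch.randn(seq_len, d_k)
K = torch.randn(seq_len, d_k)

# Similarity of every query (row of Q) with every key (row of K), scaled.
scores = Q @ K.T / math.sqrt(d_k)         # (seq_len, seq_len)

# Softmax over the last dimension: row i holds query i's weights over keys.
weights = torch.softmax(scores, dim=-1)

print(weights.sum(dim=-1))  # every row sums to 1 (up to float rounding)
```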
Remember that we have distinguished the Keys ($K$) from the Values ($V$) as distinct representations. Thus, the final representation is the attention-weight matrix (the softmax output) multiplied by the Value ($V$) matrix.
Personally, I like to think of the attention matrix as where to look and the Value matrix as what we actually want to retrieve.
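A tiny hand-made example makes the "where to look" versus "what to retrieve" split concrete. The numbers below are chosen by hand for illustration (they are not learned weights): each row of `weights` says where a query looks, and multiplying by `V` returns a weighted average of the value rows.

```python
import torch

# Toy Value matrix: 3 tokens, d_v = 2 (hand-picked numbers).
V = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [2.0, 2.0]])

# Hand-picked attention weights ("where to look"); each row sums to 1.
weights = torch.tensor([[1.0, 0.0, 0.0],   # query 0 looks only at token 0
                        [0.5, 0.5, 0.0],   # query 1 splits between tokens 0 and 1
                        [0.0, 0.0, 1.0]])  # query 2 looks only at token 2

# "What we get": each output row is a weighted average of V's rows.
out = weights @ V
print(out)  # rows: [1, 0], [0.5, 0.5], [2, 2]
```

With a one-hot row the query copies a single value vector; with a spread-out row it blends several.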
Notice any differences from the vector similarities we saw earlier?
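First, we have matrices instead of vectors and, as a result, matrix multiplications. Second, instead of dividing by the vector magnitudes (as in cosine similarity), we scale down by $\sqrt{d_k}$, where $d_k$ is the dimension of the key (and query) vectors. Unlike the number of words in a sentence, which varies from input to input, $d_k$ is fixed by the model architecture.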
A matrix multiplication can be thought of as many vector-matrix multiplications carried out in parallel, one for each row; here, those rows are simply the individual queries.
We “query” all the projected words at once by stacking their embedding vectors in a matrix and projecting it linearly to obtain $Q$, $K$, and $V$.
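The claim that one matrix multiplication handles all queries in parallel can be checked directly. This sketch uses random `Q`, `K`, `V` of illustrative size and a hypothetical helper `attend_one`; it computes attention one query vector at a time and compares the result with the single batched expression.

```python
import math
import torch

torch.manual_seed(0)
seq_len, d_k = 5, 16
Q = torch.randn(seq_len, d_k)
K = torch.randn(seq_len, d_k)
V = torch.randn(seq_len, d_k)

def attend_one(q):
    """Hypothetical helper: attention for a single query vector of shape (d_k,)."""
    scores = K @ q / math.sqrt(d_k)           # (seq_len,) similarity to every key
    return torch.softmax(scores, dim=-1) @ V  # (d_k,) weighted sum of the values

# One query at a time...
looped = torch.stack([attend_one(Q[i]) for i in range(seq_len)])

# ...versus all queries at once, stacked as the rows of Q.
batched = torch.softmax(Q @ K.T / math.sqrt(d_k), dim=-1) @ V

print(torch.allclose(looped, batched, atol=1e-5))  # True
```

The batched form is what makes self-attention fast on GPUs: the loop disappears into a couple of large matrix multiplications.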
Isn’t that awesome?
To make things clearer, you can find a PyTorch implementation below:
import math
from typing import Any, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):
    """Computes softmax(Q K^T / sqrt(d_k)) V."""

    def __init__(self):
        super().__init__()

    def forward(self,
                query: torch.FloatTensor,
                key: torch.FloatTensor,
                value: torch.FloatTensor,
                mask: Optional[torch.ByteTensor] = None,
                dropout: Optional[nn.Dropout] = None
                ) -> Tuple[torch.Tensor, Any]:
        """
        Args:
            `query`: shape (batch_size, n_heads, max_len, d_q)
            `key`: shape (batch_size, n_heads, max_len, d_k)
            `value`: shape (batch_size, n_heads, max_len, d_v)
            `mask`: shape (batch_size, 1, 1, max_len)
            `dropout`: nn.Dropout
        Returns:
            `weighted value`: shape (batch_size, n_heads, max_len, d_v)
            `weight matrix`: shape (batch_size, n_heads, max_len, max_len)
        """
        d_k = query.size(-1)  # d_k = d_model / n_heads
        # Similarity of every query with every key, scaled by sqrt(d_k).
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Masked positions (mask == 0, e.g. padding) get a huge negative
            # score, so they receive ~0 weight after the softmax.
            scores = scores.masked_fill(mask.eq(0), -1e9)
        # Normalize over the keys: each query's weights sum to 1.
        p_attn = F.softmax(scores, dim=-1)
        if dropout is not None:
            p_attn = dropout(p_attn)
        # Weighted sum of the values, plus the attention weights themselves.
        return torch.matmul(p_attn, value), p_attn
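Here is a usage sketch showing how a padding mask of shape (batch_size, 1, 1, max_len) fits the class above. To keep the block self-contained, it re-creates the same forward computation as a hypothetical function `attention` (without dropout). It also assumes token id 0 marks padding; that convention and the token ids are illustrative, not from the lesson.

```python
import math
import torch
import torch.nn.functional as F

def attention(query, key, value, mask=None):
    """Hypothetical functional mirror of ScaledDotProductAttention.forward."""
    d_k = query.size(-1)
    scores = query @ key.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask.eq(0), -1e9)
    p_attn = F.softmax(scores, dim=-1)
    return p_attn @ value, p_attn

torch.manual_seed(0)
batch_size, n_heads, max_len, d_k = 2, 4, 5, 8
PAD_ID = 0  # assumption: token id 0 marks padding

# Two hypothetical token-id sequences; the second ends with two padding tokens.
token_ids = torch.tensor([[5, 7, 2, 9, 4],
                          [3, 8, 6, PAD_ID, PAD_ID]])

# 1 = real token, 0 = padding, reshaped to (batch_size, 1, 1, max_len)
# so it broadcasts over heads and query positions.
mask = (token_ids != PAD_ID).long().view(batch_size, 1, 1, max_len)

q = torch.randn(batch_size, n_heads, max_len, d_k)
k = torch.randn(batch_size, n_heads, max_len, d_k)
v = torch.randn(batch_size, n_heads, max_len, d_k)

out, p_attn = attention(q, k, v, mask)
print(out.shape)     # torch.Size([2, 4, 5, 8])
print(p_attn.shape)  # torch.Size([2, 4, 5, 5])
# No query in the second sequence attends to its padded key positions.
print(p_attn[1, :, :, 3:].max().item())
```

On PyTorch 2.0 or newer, `torch.nn.functional.scaled_dot_product_attention` offers a fused version of the same computation; note that its boolean `attn_mask` uses True for positions that should take part in attention.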
Before we build fancy transformer blocks, we need to delve into one more critical concept: multi-head self-attention.