Multihead Attention

Dive into the new concept of multihead attention, which allows transformers to capture diverse features and enhance interpretability.

Now, let's discuss a variant of self-attention (and of the attention mechanism in general) known as multihead attention. This concept is vital for encoding multiple features with the transformer model.

We've seen how queries (Q, the information the model is looking for within the input sequence), keys (K, which establish relationships with other elements in the sequence, providing context and connections), and values (V, which hold the actual information content of an element in the sequence) are projections of the same thing, essentially representing different views or features of the input, especially in the context of natural language processing (NLP). We'll explore this further.
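To make this concrete, here is a minimal NumPy sketch (not the course's code) showing Q, K, and V as three separate linear projections of the same input embeddings. The sizes and random matrices are arbitrary placeholders chosen purely for illustration.

```python
import numpy as np

# Minimal sketch: queries, keys, and values are three different linear
# projections of the same input embeddings X (sizes here are arbitrary).
rng = np.random.default_rng(0)
seq_len, d_model, d_k = 5, 16, 8

X = rng.standard_normal((seq_len, d_model))   # 5 token embeddings
W_q = rng.standard_normal((d_model, d_k))     # query projection
W_k = rng.standard_normal((d_model, d_k))     # key projection
W_v = rng.standard_normal((d_model, d_k))     # value projection

Q, K, V = X @ W_q, X @ W_k, X @ W_v           # three views of the same input
print(Q.shape, K.shape, V.shape)              # (5, 8) each
```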

Understanding multihead attention

For example, in NLP, we can examine the part-of-speech tag of a word and query its relationship with other part-of-speech tags in the same sentence. This is useful for understanding connections between named entities or resolving references.

Consider the example sentence, "The student didn't attempt the quiz because it was too difficult." In this sentence, a single word such as "it" can refer to different words depending on the sentence's structure. Each type of attribute has its own unique projection.

The need for multiple features

What if we want to include more than one feature at once? We're not limited to a single projection. We aim to include various perspectives, similar to the concept of channels in convolutional neural networks. This is where multihead attention comes in.

[Figure: Multihead attention]

In mathematical terms, multihead attention is quite straightforward. It involves concatenating multiple attention outputs, each computed with its own set of projection matrices for Q, K, and V. This approach allows us to capture various features, not just one.
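Concretely, the standard formulation from the original transformer paper, "Attention Is All You Need," can be written as:

$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\,W^O,
\qquad \text{head}_i = \text{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V),
$$

where each head applies the familiar scaled dot-product attention, $\text{Attention}(Q, K, V) = \text{softmax}\!\left(QK^\top/\sqrt{d_k}\right)V$, and $W_i^Q$, $W_i^K$, $W_i^V$, and $W^O$ are learned projection matrices.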

Multihead attention as multiple feature detectors

This idea of multihead attention aligns with the need for multiple feature detectors. In other areas of deep learning, like fully connected layers, we have multiple output neurons, each dedicated to detecting different features. Similarly, in convolutional neural networks (CNNs), we employ different kernels to detect various features, resulting in distinct feature maps that act as channels. These channels collectively encode different feature projections.

[Figure: Recognizing a cat using a ConvNet]

In multihead attention, the process is analogous. The key distinction, as discussed earlier, lies in the inductive bias of CNNs, which assume spatial relationships in images. In contrast, attention doesn't make such assumptions; it calculates global attention scores, applies softmax, and multiplies these scores by the values. This approach is particularly valuable for several reasons. Unlike CNNs, where spatial relationships are explicitly assumed and learned, attention allows the model to capture dependencies across the entire input sequence. By calculating global attention scores, the model can assign varying degrees of importance to different parts of the input, focusing on the most relevant information.

To grasp the significance of this process, consider it as a mechanism that allows the model to dynamically attend to different parts of the input sequence, adaptively adjusting its focus based on the context and content of the data. Multiple such operations can be performed and concatenated, mirroring the idea of channels in CNNs, and providing the model with a powerful tool to capture intricate relationships and dependencies within the input data.
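As an illustration only, here is a minimal NumPy sketch of this idea: each head computes scaled dot-product attention with its own projection matrices, and the head outputs are concatenated like channels. The random initialization is a placeholder, and details of a full transformer layer (such as the final output projection and masking) are deliberately omitted.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    # Scaled dot-product attention: global scores -> softmax -> weighted sum of values.
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    return softmax(scores, axis=-1) @ V

def multihead_attention(X, heads):
    # Each head has its own (W_q, W_k, W_v); the head outputs are concatenated,
    # much like channels in a CNN feature map.
    outputs = [attention(X @ W_q, X @ W_k, X @ W_v) for W_q, W_k, W_v in heads]
    return np.concatenate(outputs, axis=-1)

rng = np.random.default_rng(0)
seq_len, d_model, d_k, n_heads = 5, 16, 4, 4
X = rng.standard_normal((seq_len, d_model))
heads = [tuple(rng.standard_normal((d_model, d_k)) for _ in range(3))
         for _ in range(n_heads)]
print(multihead_attention(X, heads).shape)    # (5, 16): n_heads * d_k features per token
```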

Extending to RNNs

This concept is natural and extends to recurrent neural networks (RNNs) as well. In matrix equations, we use multiple projection matrices (e.g., $W_0^q$ ...
