Multi-Head Attention

Learning multiple relationships

Why One Head Isn't Enough

Scaled dot-product attention is powerful, but it has a limitation: each position can only attend in one way. The attention weights form a single distribution — if "cat" attends 60% to "the" and 40% to "sat", there is no room to simultaneously capture that "cat" is the subject of "sat" and that "the" is its determiner.
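To make the limitation concrete, here is a minimal NumPy sketch of single-head scaled dot-product attention (illustrative, not from the original text). Note that it produces exactly one weight distribution per position:

```python
# Minimal single-head scaled dot-product attention (NumPy sketch).
# One set of Q/K/V -> one attention distribution per position.
import numpy as np

def attention(Q, K, V):
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                 # (seq, seq) similarities
    w = np.exp(scores - scores.max(-1, keepdims=True))
    w /= w.sum(-1, keepdims=True)                   # softmax: rows sum to 1
    return w @ V                                    # one weighted mix per token

np.random.seed(0)
X = np.random.randn(4, 8)                           # 4 tokens, d_model = 8
out = attention(X, X, X)
print(out.shape)                                    # (4, 8)
```

Because each row of the weight matrix is a single probability distribution, a token cannot simultaneously put "full" attention on two different kinds of relationships.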

Language is rich. Words relate to each other in multiple ways at once:

  • Syntactically: "The" modifies "cat", "cat" is the subject of "sat"
  • Semantically: "cat" and "sat" are conceptually linked (animals sit)
  • Positionally: nearby words often belong together

A single attention pattern cannot capture all these relationships. The solution? Run multiple attention operations in parallel, each with its own learned projections. Let different heads specialize in different types of relationships.

The Multi-Head Mechanism

The original paper puts it elegantly:

"Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions."

Attention Is All You Need, Vaswani et al., 2017

Here is how it works. Instead of one large attention operation, we run h smaller ones. For each head i:

\text{head}_i = \text{Attention}(XW_i^Q, XW_i^K, XW_i^V)

Each head has its own query, key, and value projection matrices. This means each head can learn to look for different things.

The key insight is dimension splitting. If our model dimension is d_model and we have h heads, each head works in d_k = d_model / h dimensions. The total computation is roughly the same as single-head attention with full dimensionality, but we get multiple perspectives.
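The per-head projections and dimension splitting can be sketched as follows (a NumPy illustration with made-up sizes, not the original's code):

```python
# Per-head projections: h heads, each in d_k = d_model / h dimensions.
import numpy as np

d_model, h = 8, 4
d_k = d_model // h                          # dimension splitting: 8 / 4 = 2

rng = np.random.default_rng(0)
X = rng.standard_normal((5, d_model))       # 5 tokens

def softmax(s):
    e = np.exp(s - s.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

heads = []
for i in range(h):
    # Each head has its own learned Q, K, V projections (random here).
    Wq, Wk, Wv = (rng.standard_normal((d_model, d_k)) for _ in range(3))
    Q, K, V = X @ Wq, X @ Wk, X @ Wv        # each (5, d_k)
    A = softmax(Q @ K.T / np.sqrt(d_k))     # this head's own attention pattern
    heads.append(A @ V)                     # (5, d_k) output per head

print([hd.shape for hd in heads])           # four (5, 2) outputs
```

Each head computes attention in its own small subspace, which is what lets different heads learn different patterns at roughly the cost of one full-size head.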

Real-world scale: GPT-3 uses 96 attention heads, each with dimension 128, for a total model dimension of 12,288. That is 96 different ways to attend at every layer.

Interactive: Multiple Attention Heads

Positional Head

Attends primarily to nearby tokens

Attention weights (rows: attending token, columns: attended-to token):

         The   clever   fox   jumped
The      50%    35%     10%      5%
clever   30%    45%     20%      5%
fox       8%    32%     45%     15%
jumped    5%    10%     35%     50%

This head learned to attend to nearby positions. Notice the diagonal pattern — each token pays most attention to itself and its immediate neighbors. This captures local context.

Each head sees the same input but learns different attention patterns

Toggle between heads to see how each one attends differently. Then switch to grid view to see all patterns at once. Notice the variety — one head focuses locally, another tracks syntax, another captures meaning.

What Different Heads Learn

Researchers have studied what different heads actually learn, and the results are fascinating. Different heads consistently specialize:

Positional heads attend primarily to nearby tokens. They capture local context and word order. The pattern looks like a diagonal band — each position attends to itself and its neighbors.

Syntactic heads track grammatical structure. Verbs attend to their subjects. Pronouns attend to their antecedents. Adjectives attend to the nouns they modify. The network discovers grammar without explicit supervision.

Semantic heads connect related concepts. "Paris" and "France" attend to each other. "Hot" and "cold" show patterns of opposition. These heads encode world knowledge learned from training data.

Copy/identity heads simply pass information through. Each position attends mostly to itself. Sometimes the best move is to preserve the original representation unchanged.

Some heads remain mysterious. Researchers call them "weird heads" — their patterns don't match obvious linguistic categories, yet ablating them hurts performance. The network discovered something useful that we don't fully understand yet.

Interactive: Head Specialization Examples

Example sentence: "I saw the movie"

         I     saw   the   movie
I       80%   10%    5%      5%
saw     70%   20%    5%      5%
the     10%   75%   10%      5%
movie    5%   10%   75%     10%

This head consistently attends to the previous token. This pattern helps the model understand word order and build sequential context. It's one of the most common head types found in transformers.

Explore three different head types with real example sentences. The highlighted cells show the characteristic pattern each head has learned. Notice how different heads extract completely different information from the same input.

Concatenate and Project

After each head computes its attention, we need to combine the results. The approach is simple: concatenate all head outputs, then apply a linear projection.

\text{MultiHead}(X) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^O

The concatenation stacks the h head outputs (each of dimension d_k) into a single vector of dimension d_model. The projection matrix W^O then transforms this back to the model dimension.

Why project after concatenating? Two reasons:

  1. Dimension matching: The output needs the same shape as the input to feed into subsequent layers.
  2. Cross-head mixing: The projection lets information from different heads combine. Head 3's syntactic insight can interact with Head 7's semantic knowledge.
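The concatenate-then-project step can be sketched in a few lines of NumPy (sizes chosen to match the small demo below: 4 heads of dimension 2; the random matrices stand in for learned weights):

```python
# Concatenate h head outputs, then project with W_O back to d_model.
import numpy as np

rng = np.random.default_rng(1)
h, d_k, d_model = 4, 2, 8                   # 4 heads x 2 dims = 8 = d_model
heads = [rng.standard_normal((5, d_k)) for _ in range(h)]  # per-head outputs

concat = np.concatenate(heads, axis=-1)     # (5, 8): heads stacked side by side
W_O = rng.standard_normal((d_model, d_model))
out = concat @ W_O                          # projection mixes info across heads

print(concat.shape, out.shape)              # (5, 8) (5, 8)
```

Every output dimension of `out` is a weighted combination of all heads' outputs, which is exactly the cross-head mixing described above.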

Interactive: Concatenation and Projection

Individual Head Outputs → Concatenate Heads → Project to Model Dimension → Final Output

Each attention head produces an output of dimension d_k (here, 2). With 4 heads, we have 4 separate outputs of 2 dimensions each.

4 heads × 2 dimensions each = 8 total dimensions

Watch the flow from individual heads through concatenation to the final projection. Each step preserves then combines information, giving the model multiple views synthesized into one output.

The Power of Parallel Processing

Multi-head attention is not just about expressivity — it is also efficient. All heads compute in parallel. On modern GPUs, running 96 small attention operations is barely slower than running one large one, but gives far richer representations.

This parallelism is why transformers scale so well. Unlike RNNs that must process tokens sequentially, transformers process all positions and all heads simultaneously. The attention mechanism is embarrassingly parallel.
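In practice, all heads are computed in one batched operation rather than in a Python loop. A NumPy sketch of this trick (illustrative; real implementations use the same reshape-and-batch idea on GPU tensors):

```python
# All heads at once: one big projection, then a reshape into (seq, h, d_k).
import numpy as np

seq, d_model, h = 6, 8, 4
d_k = d_model // h
rng = np.random.default_rng(2)

X = rng.standard_normal((seq, d_model))
# One d_model x d_model matrix holds every head's projection side by side.
Q = (X @ rng.standard_normal((d_model, d_model))).reshape(seq, h, d_k)
K = (X @ rng.standard_normal((d_model, d_model))).reshape(seq, h, d_k)
V = (X @ rng.standard_normal((d_model, d_model))).reshape(seq, h, d_k)

# Batched attention: every head's scores in one einsum, no sequential loop.
scores = np.einsum('qhd,khd->hqk', Q, K) / np.sqrt(d_k)
w = np.exp(scores - scores.max(-1, keepdims=True))
w /= w.sum(-1, keepdims=True)               # per-head softmax over keys
out = np.einsum('hqk,khd->qhd', w, V).reshape(seq, d_model)
print(out.shape)                            # (6, 8)
```

Because the head axis is just another batch dimension, h heads cost roughly one large matrix multiply, which is the efficiency argument made above.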

The combination of expressivity (multiple relationship types) and efficiency (parallel computation) is what makes multi-head attention the backbone of modern language models. Every GPT, every BERT, every Claude and Gemini — they all rely on this mechanism.

Key Takeaways

  • Single-head attention can only capture one type of relationship at a time
  • Multi-head attention runs h parallel attention operations, each with its own learned projections
  • Different heads specialize: positional, syntactic, semantic, identity, and "weird" heads
  • The paper's insight: attending to "different representation subspaces at different positions" captures richer structure
  • Outputs are concatenated then projected, mixing information across heads
  • GPT-3 uses 96 heads × 128 dimensions = 12,288 model dimension
  • Parallel computation makes multi-head attention both expressive and efficient