Multi-Head Attention

Capturing Multiple Types of Relationships

Difficulty: Intermediate
Duration: 10-12 min
Prerequisites: Softmax & Attention

Why One Attention Head Isn't Enough

A single attention head computes one set of attention weights, that is, one way of deciding which tokens are relevant to which. But language has many simultaneous types of relationships.

Consider the word "sat" in "The cat sat down":

  • Syntactic: "sat" needs to find its subject ("cat") and modifier ("down")
  • Semantic: "sat" relates to the concept of physical position
  • Positional: "sat" is near "cat" (adjacent) and "down" (adjacent)

A single attention head must compress all these relationship types into one set of weights. It might learn to focus on syntactic relationships but miss positional ones, or vice versa.

Multi-head attention solves this by running multiple attention heads in parallel, each with its own Q, K, V projections. Each head can specialize in a different relationship type:

  • Head 0 might learn syntactic dependencies (subject-verb)
  • Head 1 might learn positional/local patterns (adjacent words)
  • Head 2 might learn semantic similarity
  • Head 3 might learn coreference (pronoun resolution)

In practice, researchers find that different heads do specialize, though the patterns are often more nuanced than these clean categories.

One Head Cannot Capture All Relationship Types

| Relationship Type | Example in "The cat sat down" | Single Head Can Capture? |
|---|---|---|
| Subject-verb | "sat" ← "cat" (who sat?) | Maybe, if this is what it learns |
| Verb-modifier | "sat" ← "down" (how?) | Conflicts with subject-verb focus |
| Determiner-noun | "The" → "cat" (which cat?) | May be ignored if head focuses on verbs |
| Local context | Adjacent word patterns | May be missed for long-range focus |

Scaling Attention with Multiple Heads

| Approach | Heads | What Each Head Sees | Capacity |
|---|---|---|---|
| Single-head | 1 | One attention pattern for everything | Limited, must compromise |
| Multi-head (h=2) | 2 | Each head has own Q, K, V weights | 2 independent patterns |
| Multi-head (h=8) | 8 | Each head learns different relationships | 8 independent patterns |
| GPT-3 (h=96) | 96 | Rich, diverse attention patterns | 96 independent patterns per layer |

Multi-Head Attention — Lesson Content

See how splitting attention into multiple heads lets the model capture syntactic, semantic, and positional patterns simultaneously.

A single attention head must compress all types of token relationships into one set of weights. Multi-head attention runs multiple heads in parallel, each with its own Q, K, V subspace, letting the model attend to different relationship types simultaneously. Using "The cat sat down" with 2 attention heads, you'll see how Q, K, V are split, compare the distinct attention patterns each head learns, and understand how concatenation and output projection combine everything back together.

Learning Objectives

  • Explain why one attention head is insufficient
  • Describe how Q, K, V are split across heads
  • Compare attention patterns from different heads
  • Understand concatenation and output projection
  • List the benefits of multi-head attention

Step 2: Splitting Q, K, V into Heads

In multi-head attention, we don't just run the same attention multiple times. We **split** the Q, K, V vectors into smaller pieces, one per head. For our example:

  • **d_model** = 8 (total embedding dimension)
  • **n_heads** = 2 (number of attention heads)
  • **d_head** = 8 / 2 = 4 (dimension per head)

Each token's 8-dimensional Q, K, V vectors are split into two 4-dimensional pieces:

  • **Head 0** gets dimensions [0, 1, 2, 3]
  • **Head 1** gets dimensions [4, 5, 6, 7]

Each head then runs the full attention computation (Q·K^T / √d_head, softmax, × V) independently on its 4-dimensional slice. Each slice comes from a different set of columns of the learned projection matrices, so slicing is equivalent to giving each head its own smaller Q, K, V projections. This is why the heads can attend to different features.

This is computationally efficient: splitting does not increase the size of the matrix multiplications. Two heads of dimension 4 need the same number of multiply-adds for Q·K^T and for the weighted sum over V as one head of dimension 8. The only extra work is one softmax per head, so the additional attention patterns come almost for free.

MultiHead(Q, K, V) = Concat(head_0, head_1, ...) × W_O

head_i = Attention(Q_i, K_i, V_i)

Where Q_i, K_i, V_i are the i-th slice of Q, K, V
d_head = d_model / n_heads = 8 / 2 = 4
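
The split described above can be sketched in NumPy. This is a minimal illustration, not the lesson's actual implementation: the embeddings `X`, the projection `W_Q`, and the helper `split_heads` are hypothetical names, and the values are random placeholders rather than the lesson's pre-set weights.

```python
import numpy as np

d_model, n_heads = 8, 2
d_head = d_model // n_heads  # 4 dimensions per head

# Placeholder data (assumption): random embeddings for the 4 tokens
# "The cat sat down" and a random query projection matrix.
rng = np.random.default_rng(0)
X = rng.normal(size=(4, d_model))
W_Q = rng.normal(size=(d_model, d_model))
Q = X @ W_Q  # projected queries, shape (4, 8)


def split_heads(x, n_heads):
    """Reshape (tokens, d_model) into (n_heads, tokens, d_head)."""
    n_tokens, d_model = x.shape
    d_head = d_model // n_heads
    return x.reshape(n_tokens, n_heads, d_head).transpose(1, 0, 2)


Q_heads = split_heads(Q, n_heads)
print(Q_heads.shape)  # (2, 4, 4)

# Head 0 holds dimensions 0-3 and head 1 holds dimensions 4-7.
assert np.array_equal(Q_heads[0], Q[:, :4])
assert np.array_equal(Q_heads[1], Q[:, 4:])

# Slicing Q matches projecting with a slice of W_Q's columns, which is
# why each head effectively has its own (smaller) projection.
assert np.allclose(Q_heads[1], X @ W_Q[:, 4:])
```

The same `split_heads` call would apply unchanged to K and V.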

Step 3: Head 0: Attention Pattern

Let's examine what Head 0 learns. Using the first 4 dimensions of Q and K, Head 0 computes its own attention weights independently. In the attention heatmap for Head 0, each cell shows how much one token (row) attends to another (column) from this head's perspective.

**What to look for:**

  • Does this head focus on specific syntactic relationships?
  • Does it show strong diagonal attention (each token attending mostly to itself)?
  • Does it capture subject-verb or modifier relationships?

Different heads develop different specializations during training. In our small example, the patterns are driven by the pre-set weight matrices, but they illustrate how two heads can produce very different attention distributions from the same input.

Compare this heatmap to Head 1 in the next step. The patterns will be noticeably different, showing that each head captures distinct information about the relationships between tokens.

Step 4: Head 1: A Different Perspective

Now let's look at Head 1, which uses dimensions 4-7 of Q and K. Compare this pattern to Head 0: you should see meaningful differences.

**Comparing the two heads:**

Head 0 and Head 1 produce different attention distributions because they operate on different slices of the Q and K projections. Each slice captures different learned features, so the resulting attention patterns naturally differ.

Large language models use many more heads (GPT-3, for example, uses 96 per layer). In studies of trained transformers, researchers have reported heads that appear to specialize in:

  • **Positional heads:** Attend primarily to the previous or next token
  • **Syntactic heads:** Track subject-verb agreement across long distances
  • **Rare token heads:** Attend to rare or surprising words in context
  • **Delimiter heads:** Focus on punctuation and sentence boundaries
  • **Induction heads:** Copy patterns from earlier in the context

The diversity of attention patterns is a key reason multi-head attention works so well: the model gets many different "views" of the same sequence.
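
The per-head attention weights behind the two heatmaps can be sketched as follows. The Q and K values here are random placeholders (an assumption, since the lesson's pre-set weight matrices are not shown), so the printed numbers will not match the lesson's heatmaps. The point is only that each head yields its own 4 × 4 weight matrix.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


tokens = ["The", "cat", "sat", "down"]
d_model, n_heads = 8, 2
d_head = d_model // n_heads

# Placeholder projected queries and keys (assumption: random values).
rng = np.random.default_rng(0)
Q = rng.normal(size=(len(tokens), d_model))
K = rng.normal(size=(len(tokens), d_model))

weights = []
for h in range(n_heads):
    cols = slice(h * d_head, (h + 1) * d_head)  # head 0: 0-3, head 1: 4-7
    scores = Q[:, cols] @ K[:, cols].T / np.sqrt(d_head)
    weights.append(softmax(scores))  # each row sums to 1

for h, w in enumerate(weights):
    print(f"Head {h} attention weights (rows attend to columns):")
    print(np.round(w, 2))
```

Because the two heads read different columns of Q and K, the two printed matrices differ even though the input tokens are identical.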

Step 5: Concatenation & Output Projection

After all heads compute their attention outputs independently, we need to combine them back into a single representation. This happens in two steps:

**Step 1: Concatenate**

Each head produces a 4-dimensional output per token. We concatenate all heads' outputs:

output = Concat(head_0_output, head_1_output) → 8 dimensions

This gives us back the original dimensionality (d_model = 8) while preserving information from both heads.

**Step 2: Linear projection (W_O)**

The concatenated output is multiplied by a learned output projection matrix W_O:

final_output = Concat(head_0, head_1) × W_O

W_O has shape (d_model × d_model) = (8 × 8). It serves two purposes:

  1. It **mixes** information across heads: head 0's findings can influence the features in head 1's dimension range
  2. It provides an additional learned transformation to refine the combined representation

This output projection is crucial. Without it, the heads' outputs would simply be stacked side by side with no interaction. W_O allows the model to learn how to best combine the different types of information each head extracted.

With h heads indexed from 0:

MultiHead(Q, K, V) = Concat(head_0, ..., head_{h-1}) × W_O

head_i = softmax(Q_i K_i^T / √d_head) × V_i

Shapes:
  Each head_i: (4 tokens × 4 dims)
  Concat: (4 tokens × 8 dims)
  W_O: (8 × 8)
  Output: (4 tokens × 8 dims)
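
The shapes listed above can be checked with a short NumPy sketch of the whole pipeline. The function name `multi_head_attention` and all input values are illustrative assumptions (random placeholders), not the lesson's own code or weights.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_attention(Q, K, V, W_O, n_heads):
    """Slice Q, K, V per head, attend, concatenate, then apply W_O."""
    n_tokens, d_model = Q.shape
    d_head = d_model // n_heads
    head_outputs = []
    for h in range(n_heads):
        cols = slice(h * d_head, (h + 1) * d_head)
        scores = Q[:, cols] @ K[:, cols].T / np.sqrt(d_head)
        head_outputs.append(softmax(scores) @ V[:, cols])  # (n_tokens, d_head)
    concat = np.concatenate(head_outputs, axis=-1)  # (n_tokens, d_model)
    return concat @ W_O  # (n_tokens, d_model)


# Placeholder inputs (assumption): 4 tokens, d_model = 8, random values.
rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(4, 8)) for _ in range(3))
W_O = rng.normal(size=(8, 8))

out = multi_head_attention(Q, K, V, W_O, n_heads=2)
print(out.shape)  # (4, 8)
```

Passing an identity matrix for `W_O` returns the plain concatenation, which makes it easy to see that without a learned W_O each head's output stays confined to its own four columns.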

Step 6: Benefits of Multi-Head Attention

Multi-head attention provides several advantages that help make transformers so powerful:

**1. Diverse Attention Patterns**

Each head can specialize in different types of relationships: syntactic, semantic, positional, or even more abstract patterns. A single head must compromise; multiple heads can each focus on what they do best.

**2. Richer Representations**

The concatenated output combines insights from all heads. Token "sat" might gather subject information from Head 0 and modifier information from Head 1, ending up with a richer understanding of its context than any single head could provide.

**3. Robustness**

If one head learns a noisy or unhelpful pattern, the other heads can compensate. This redundancy makes the model more robust to initialization luck and training noise.

**4. Negligible Extra Cost**

Splitting dimensions across heads instead of adding new dimensions means the matrix multiplications in multi-head attention cost the same as in single-head attention with the same d_model. The only added work is one softmax per head, so you get multiple attention patterns almost for free.

**5. Interpretability**

Individual heads can be examined to understand what the model learned. Researchers have found specific heads associated with specific linguistic phenomena, making the model somewhat interpretable.
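
The cost claim can be checked with simple arithmetic. The sketch below counts multiply-adds for the two matrix products in the attention core (Q·K^T and weights × V) and the number of softmax entries. The function name `attention_core_cost` is hypothetical, and the count deliberately ignores the Q, K, V and W_O projections, which do not depend on the number of heads.

```python
def attention_core_cost(n_tokens, d_model, n_heads):
    """Return (multiply-adds in the two matmuls, softmax entries)."""
    d_head = d_model // n_heads
    # Per head: Q_i K_i^T and weights @ V_i each take n_tokens^2 * d_head.
    matmul_macs = n_heads * 2 * n_tokens * n_tokens * d_head
    # Per head: one n_tokens x n_tokens matrix goes through softmax.
    softmax_entries = n_heads * n_tokens * n_tokens
    return matmul_macs, softmax_entries


for h in (1, 2, 8):
    macs, entries = attention_core_cost(n_tokens=4, d_model=8, n_heads=h)
    print(f"h={h}: {macs} multiply-adds, {entries} softmax entries")
# h=1: 256 multiply-adds, 16 softmax entries
# h=2: 256 multiply-adds, 32 softmax entries
# h=8: 256 multiply-adds, 128 softmax entries
```

The multiply-add count stays fixed as the number of heads grows, while the softmax work scales with the head count.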

Step 7: Test Your Understanding

You've learned how multi-head attention splits Q, K, V into parallel heads that each capture different patterns, then combines them. Let's test your understanding!

Prerequisites

  • Scaled dot-product attention
  • Query, Key, Value projections
  • Matrix multiplication

Key Concepts

  • Multi-Head Attention
  • Head Splitting
  • Attention Pattern Specialization
  • Concatenation
  • Output Projection