A single attention head can only learn one pattern of relationships between positions in a sequence. But language is rich with multiple simultaneous relationships: syntactic structure, semantic similarity, coreference chains, and positional patterns all coexist. Multi-head attention solves this by running several attention computations in parallel, each learning to focus on different types of relationships. It is one of the key innovations that make transformers so effective.

The Limitation of Single-Head Attention

Imagine you are reading a complex sentence. A single attention head must decide on one set of attention weights for each token. If it focuses on syntactic dependencies, it might miss semantic relationships. If it captures nearby word associations, it might overlook long-range coreferences.

Consider this sentence: "The lawyer who argued the case in court yesterday won the appeal." A single attention operation has to decide: should "won" attend strongly to "lawyer" (the subject), to "case" (the object of the relative clause), to "yesterday" (temporal context), or to "appeal" (its direct object)? Ideally, we want to capture all these relationships simultaneously.

Single-head attention is like having one pair of eyes. Multi-head attention gives the model multiple pairs, each looking at the input from a different perspective.

How Multi-Head Attention Works

The multi-head attention mechanism works by splitting the attention computation into multiple parallel "heads," each operating on a lower-dimensional subspace of the input.

Here is the step-by-step process:

  1. Project inputs: The input embeddings (of dimension d_model) are projected into h separate sets of queries, keys, and values using different learned weight matrices. Each head has its own W_Q, W_K, and W_V, projecting into dimension d_k = d_model / h.
  2. Compute attention per head: Each head independently computes scaled dot-product attention on its projected Q, K, and V matrices.
  3. Concatenate outputs: The outputs from all h heads are concatenated back into a single vector of dimension d_model.
  4. Final linear projection: A final weight matrix W_O projects the concatenated output, mixing information from all heads.

In equation form:

  MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W_O
  where head_i = Attention(Q * W_Q_i, K * W_K_i, V * W_V_i)
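The four steps above can be sketched in plain NumPy. This is a minimal illustration, not a production implementation: the per-head projections are fused into single d_model x d_model matrices (mathematically equivalent to h separate W_Q_i, W_K_i, W_V_i), and all names are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o, h):
    """X: (seq_len, d_model); W_q, W_k, W_v, W_o: (d_model, d_model)."""
    seq_len, d_model = X.shape
    d_k = d_model // h

    # 1. Project inputs, then split the last dimension into h heads.
    def split_heads(M):
        return M.reshape(seq_len, h, d_k).transpose(1, 0, 2)  # (h, seq_len, d_k)

    Q, K, V = split_heads(X @ W_q), split_heads(X @ W_k), split_heads(X @ W_v)

    # 2. Scaled dot-product attention, computed independently per head.
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)  # (h, seq_len, seq_len)
    heads = softmax(scores) @ V                       # (h, seq_len, d_k)

    # 3. Concatenate head outputs back into (seq_len, d_model).
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)

    # 4. Final linear projection mixes information across heads.
    return concat @ W_o

rng = np.random.default_rng(0)
d_model, h, seq_len = 64, 8, 10
W_q, W_k, W_v, W_o = (rng.standard_normal((d_model, d_model)) * 0.1 for _ in range(4))
X = rng.standard_normal((seq_len, d_model))
out = multi_head_attention(X, W_q, W_k, W_v, W_o, h)
print(out.shape)  # (10, 64)
```

Note that the total work is roughly the same as one full-width attention: each head attends in d_k = d_model / h dimensions, so the h heads together cost about as much as a single head in d_model dimensions.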

Key Takeaway

Multi-head attention runs h parallel attention operations on lower-dimensional projections, then combines the results. This costs roughly the same as single-head attention on the full dimension, but captures richer patterns.

What Different Heads Learn

Analyses of trained transformer models reveal that individual attention heads learn distinct specializations: different heads consistently focus on different linguistic phenomena.

Syntactic Heads

Some heads learn to track grammatical relationships. In BERT, researchers found heads that reliably attend from a verb to its subject, from a noun to its determiner, or from a preposition to its object. These heads essentially learn a parse tree without ever being explicitly trained on syntactic annotations.

Positional Heads

Certain heads develop consistent positional attention patterns. Some always attend to the immediately preceding token, others to the token two positions back. These heads help the model track local context and sequential patterns.

Semantic Heads

Other heads learn to connect semantically related tokens regardless of their position. These might link "doctor" and "patient" or "rain" and "umbrella," capturing thematic relationships in the text.

Rare Pattern Heads

Some heads appear to specialize in unusual or rare patterns -- attending to punctuation, special tokens like [CLS] or [SEP], or specific positional patterns. While individually these may seem less important, they contribute to the model's overall ability to handle edge cases.

The Numbers: Head Counts in Practice

Different transformer architectures use different numbers of attention heads:

  • BERT-base: 12 heads with d_model = 768, so each head operates on d_k = 64 dimensions
  • BERT-large: 16 heads with d_model = 1024, so d_k = 64
  • GPT-2: 12 heads (small, d_model = 768) up to 25 heads (XL, d_model = 1600), keeping d_k = 64 across sizes
  • GPT-3 (175B): 96 heads with d_model = 12288, so d_k = 128
  • LLaMA-70B: 64 heads with d_model = 8192, d_k = 128

Notice that d_k is typically kept at 64 or 128 dimensions per head, regardless of model size. Larger models simply use more heads rather than larger heads.
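The relationship d_k = d_model / h can be checked directly against the configurations listed above (a quick arithmetic sketch; model names and figures are taken from the list):

```python
# (d_model, heads) pairs from the list above; per-head width d_k = d_model // heads
configs = {
    "BERT-base":  (768, 12),
    "BERT-large": (1024, 16),
    "GPT-3":      (12288, 96),
    "LLaMA-70B":  (8192, 64),
}

for name, (d_model, h) in configs.items():
    d_k = d_model // h
    print(f"{name}: d_k = {d_k}")
    assert d_k in (64, 128)  # per-head width stays small even as models grow
```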

Head Pruning: Not All Heads Are Equal

An important finding from research on multi-head attention is that not all heads are equally important. Studies have shown that many heads can be removed (pruned) after training with minimal impact on performance. In some cases, removing up to 40% of attention heads from a trained model causes less than 1% degradation in task performance.

This suggests several things:

  • The model is somewhat over-parameterized in its attention mechanism
  • Redundancy between heads provides robustness during training
  • Head pruning is a viable compression technique for deployment

However, the heads that do matter are critical. Removing certain key heads can cause dramatic performance drops, confirming that the specialization of heads is real and functionally important.
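One common way to study head importance is to gate each head's output with a 0/1 mask before the final projection: zeroing a head's gate is equivalent to pruning it. A rough NumPy sketch of the idea, with all names hypothetical:

```python
import numpy as np

def masked_head_combine(heads, W_o, head_mask):
    """heads: (h, seq_len, d_k); W_o: (h*d_k, h*d_k); head_mask: (h,) of 0/1 gates."""
    h, seq_len, d_k = heads.shape
    gated = heads * head_mask[:, None, None]               # silence pruned heads
    concat = gated.transpose(1, 0, 2).reshape(seq_len, h * d_k)
    return concat @ W_o

rng = np.random.default_rng(1)
h, seq_len, d_k = 4, 5, 8
heads = rng.standard_normal((h, seq_len, d_k))
W_o = rng.standard_normal((h * d_k, h * d_k)) * 0.1

full = masked_head_combine(heads, W_o, np.ones(h))
pruned = masked_head_combine(heads, W_o, np.array([1.0, 0.0, 1.0, 1.0]))  # drop head 1
print(np.abs(full - pruned).mean())  # nonzero: the dropped head carried signal
```

Measuring the task-level impact of each such mask, rather than the raw output change shown here, is how pruning studies decide which heads are expendable.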

Multi-head attention is not just parallelism for speed. Each head learns a different aspect of language, and together they create a richer, more nuanced understanding than any single attention mechanism could achieve alone.

Grouped Query Attention: A Modern Optimization

As models have grown larger, researchers have found ways to make multi-head attention more efficient. One important innovation is Grouped Query Attention (GQA), used in models like LLaMA 2.

In standard multi-head attention, each head has its own separate Q, K, and V projections. In GQA, multiple query heads share the same key and value projections. For example, if a model has 32 query heads and 8 key-value groups, every 4 query heads share one set of keys and values.

This dramatically reduces the memory required for the KV cache during inference, which is the primary bottleneck for serving large language models. GQA provides most of the representation power of full multi-head attention at a fraction of the memory cost.

An even more aggressive variant is Multi-Query Attention (MQA), where all query heads share a single set of keys and values. While MQA offers the largest memory savings, GQA strikes a better balance between efficiency and quality.
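The grouping can be sketched by broadcasting each key/value group across its query heads before the usual attention computation. This toy NumPy version uses a hypothetical 8-query-head, 2-group configuration, not LLaMA 2's actual 32/8 split:

```python
import numpy as np

def grouped_query_attention(Q, K, V, n_groups):
    """Q: (n_q_heads, seq, d_k); K, V: (n_groups, seq, d_k).
    Each group of query heads shares one set of keys and values."""
    n_q_heads, _, d_k = Q.shape
    repeat = n_q_heads // n_groups
    # Broadcast each K/V group across its query heads -- the core of GQA.
    K = np.repeat(K, repeat, axis=0)  # (n_q_heads, seq, d_k)
    V = np.repeat(V, repeat, axis=0)

    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)  # softmax over key positions
    return weights @ V

rng = np.random.default_rng(2)
n_q_heads, n_groups, seq, d_k = 8, 2, 6, 16
Q = rng.standard_normal((n_q_heads, seq, d_k))
K = rng.standard_normal((n_groups, seq, d_k))
V = rng.standard_normal((n_groups, seq, d_k))
out = grouped_query_attention(Q, K, V, n_groups)
print(out.shape)             # (8, 6, 16)
# KV cache per token scales with n_groups, not n_q_heads:
print(n_groups / n_q_heads)  # 0.25, i.e. a 4x smaller KV cache
```

Setting n_groups = 1 recovers MQA, and n_groups = n_q_heads recovers standard multi-head attention, which makes GQA a tunable dial between the two.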

Key Takeaway

Multi-head attention enables transformers to simultaneously capture syntactic, semantic, and positional patterns. Modern variants like Grouped Query Attention maintain this expressiveness while dramatically reducing the memory cost for inference.