Multi-Head Attention
Processing information in multiple subspaces simultaneously.
In practice, a single attention operation is rarely enough. Multi-Head Attention runs several attention operations ("heads") in parallel, each free to focus on a different type of relationship between tokens.
Level 1 — Why Multiple Heads?
One head might focus on grammar (subject-verb agreement), while another focuses on semantics (word meaning), and a third focuses on references (pronouns).
Level 2 — Concatenation and Projection
The outputs of all heads are concatenated into one long vector and then passed through a final linear layer that projects the result back to the original model dimension.
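For intuition, here is a minimal sketch of that concatenate-and-project step, assuming the per-head outputs have already been computed (the shapes and names below are illustrative, not taken from a specific library):

import torch
from torch import nn

# Hypothetical per-head outputs: 8 heads, each (seq_len=10, head_dim=64)
head_outputs = [torch.randn(10, 64) for _ in range(8)]

# Concatenate along the feature dimension -> (10, 512)
concat = torch.cat(head_outputs, dim=-1)

# Final linear layer (often called W_O) projects back to the model dimension
out_proj = nn.Linear(512, 512)
output = out_proj(concat)  # (10, 512)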
Multi-Head vs Single-Head
A single attention head tends to blend all of these relationships into one weighted average. Multiple heads let the model maintain several distinct "interpretations" of the sentence at the same time.
Level 3 — Parameter Efficiency
Despite having multiple heads, the total number of parameters barely changes, because the model dimension is split across the heads (e.g., 512 total dims / 8 heads = 64 dims per head).
# Example in PyTorch
import torch
from torch import nn

# 512-dim embeddings split across 8 heads (64 dims per head)
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8)

# query, key, value embeddings: (seq_len=10, batch=1, embed_dim=512)
q = torch.randn(10, 1, 512)
k = torch.randn(10, 1, 512)
v = torch.randn(10, 1, 512)

attn_output, _ = mha(q, k, v)  # attn_output shape: (10, 1, 512)
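As a quick sanity check of the Level 3 point (a sketch, not part of the original example), nn.MultiheadAttention ends up with the same parameter count whether it uses 1 head or 8, because each head operates on a 64-dim slice of the 512-dim embedding:

from torch import nn

single = nn.MultiheadAttention(embed_dim=512, num_heads=1)
multi = nn.MultiheadAttention(embed_dim=512, num_heads=8)

# Total parameters are identical: the Q/K/V and output projections
# are always (512 x 512); the heads just partition the feature dimension.
count = lambda m: sum(p.numel() for p in m.parameters())
print(count(single) == count(multi))  # True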