
Attention Mechanism: The Eyes of Neural Networks

Attention allows neural networks to focus dynamically on the relevant parts of their input. This page is a mathematical and practical reference for additive, multiplicative, self-, and multi-head attention, spanning machine translation, LLMs, and multimodal AI.

Bahdanau: additive attention (2014)

Luong: multiplicative attention (2015)

Self-attention: Q, K, V (2017)

Multi-head: parallel attention heads

What is Attention?

Attention is a neural component that dynamically computes a weighted sum of values, where weights depend on the similarity between a query and corresponding keys. It allows models to focus on specific parts of the input when producing each output element — mimicking visual attention.

Query (Q) vs. Keys (K) → similarity scores → softmax → attention weights → weighted sum of Values (V) → context vector

Core idea: Not all input elements are equally important. Learn to assign importance dynamically.
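The whole mechanism fits in a few lines. A toy sketch of one attention step, with a single query attending over three key/value pairs (all tensors here are random, illustrative values):

```python
import torch

# Toy attention step: one query, three key/value pairs.
torch.manual_seed(0)
d = 4
query = torch.randn(d)       # what the model is looking for
keys = torch.randn(3, d)     # what each input element advertises for matching
values = torch.randn(3, d)   # what each input element contributes

scores = keys @ query                   # query-key similarity
weights = torch.softmax(scores, dim=0)  # normalized importance, sums to 1
context = weights @ values              # weighted sum of values
```

The context vector is a convex combination of the values: elements whose keys match the query dominate it, and the weights are learned end to end.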

The Alignment Problem: Why Attention?

Seq2Seq without Attention

Encoder compresses entire source into one fixed-size vector → information bottleneck. Long sentences degrade rapidly.

"I love cats" → fixed vector (5-dim) → "Je ___ ?"

Seq2Seq with Attention

Decoder looks at all encoder states, weights them dynamically. Solves bottleneck, improves long-range translation.

Alignment: "cat" ↔ "chat" at step 3

Breakthrough (2014): Bahdanau et al. introduced attention to neural machine translation. BLEU scores jumped, and long sentences became tractable.

Bahdanau Attention (Additive)

Additive Attention Score

eᵢⱼ = vₐᵀ tanh(Wₐ[sᵢ₋₁; hⱼ])

αᵢⱼ = exp(eᵢⱼ) / Σₖ exp(eᵢₖ)   (softmax over source positions)

Context vector: cᵢ = Σⱼ αᵢⱼ hⱼ

Historical Significance

First attention mechanism for NLP, used in RNN encoder-decoders. Relatively expensive: a small feed-forward network must be evaluated for every decoder-encoder state pair.

Key components: bidirectional RNN encoder, concatenation of decoder state and encoder output, tanh nonlinearity.

Bahdanau Attention (PyTorch)
import torch
import torch.nn as nn

class BahdanauAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.W_a = nn.Linear(hidden_dim * 2, hidden_dim)  # takes [s; h]
        self.v_a = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, query, encoder_outputs):
        # query: decoder hidden state (batch, hidden)
        # encoder_outputs: (batch, seq_len, hidden)
        seq_len = encoder_outputs.size(1)
        query = query.unsqueeze(1).repeat(1, seq_len, 1)  # (batch, seq_len, hidden)

        # Additive score: v_a^T tanh(W_a [s; h])
        energy = torch.tanh(self.W_a(torch.cat((query, encoder_outputs), dim=2)))
        scores = self.v_a(energy).squeeze(2)              # (batch, seq_len)
        attn_weights = torch.softmax(scores, dim=1)

        # Context: weighted sum of encoder states
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)
        return context, attn_weights

Luong Attention (Multiplicative)

Scoring Functions

Dot: score = sᵀ h

General: score = sᵀ W h

Concat: score = vᵀ tanh(W[s; h])

Key Differences

Luong computes attention from the current decoder state, i.e. after the RNN step (Bahdanau uses the previous state, before it). Simpler and faster; uses only the top-layer decoder state.

Types: global (all source steps) vs local (window).

Luong Dot-Product Attention
import torch

def luong_dot_attention(query, encoder_outputs):
    # query: current decoder state (batch, 1, hidden)
    # encoder_outputs: (batch, seq_len, hidden)
    scores = torch.bmm(query, encoder_outputs.transpose(1, 2))  # (batch, 1, seq_len)
    attn_weights = torch.softmax(scores, dim=2)
    context = torch.bmm(attn_weights, encoder_outputs)          # (batch, 1, hidden)
    return context, attn_weights
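The local variant mentioned above can be sketched by masking scores outside a window around a given source position. A minimal sketch assuming a fixed, monotonic window (not the predictive alignment from the Luong paper); `local_dot_attention` and its parameters are illustrative names:

```python
import torch

def local_dot_attention(query, encoder_outputs, center, window=2):
    # query: (batch, 1, hidden); encoder_outputs: (batch, seq_len, hidden)
    seq_len = encoder_outputs.size(1)
    scores = torch.bmm(query, encoder_outputs.transpose(1, 2))  # (batch, 1, seq_len)

    # Zero out everything outside [center - window, center + window]
    positions = torch.arange(seq_len)
    outside = (positions < center - window) | (positions > center + window)
    scores = scores.masked_fill(outside.view(1, 1, -1), float('-inf'))

    attn_weights = torch.softmax(scores, dim=2)
    context = torch.bmm(attn_weights, encoder_outputs)
    return context, attn_weights
```

Positions outside the window receive exactly zero weight, so the cost per decoder step drops from O(n) to O(window).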

Scaled Dot-Product Attention

The Transformer Formula

Attention(Q, K, V) = softmax(QKᵀ / √dₖ) V

Q, K, V: queries, keys, values matrices.
√dₖ: scaling factor prevents softmax saturation.

Why Scaling?

For large dₖ, dot products grow large in magnitude, pushing softmax into regions of vanishing gradients. Scaling fixes this.
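A quick empirical check of this argument: with i.i.d. unit-variance components, the dot product q·k has variance dₖ, so dividing by √dₖ restores unit variance (random data, illustrative only):

```python
import torch

torch.manual_seed(0)
d_k = 512
q = torch.randn(10000, d_k)
k = torch.randn(10000, d_k)

dots = (q * k).sum(dim=1)             # raw dot products
print(dots.std())                     # ≈ sqrt(d_k) ≈ 22.6
print((dots / d_k ** 0.5).std())      # ≈ 1 after scaling
```

Without the scaling, logits with standard deviation ~22 drive the softmax to near one-hot outputs, where gradients vanish.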

Scaled Dot-Product Attention (PyTorch)
import math
import torch

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        # Positions where mask == 0 are excluded from attention
        scores = scores.masked_fill(mask == 0, -1e9)

    attention_weights = torch.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, V)
    return output, attention_weights

Attention matrix: rows = queries, columns = keys. Each row sums to 1.
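A self-contained check of the row-sum property, with the formula written inline so the snippet runs on its own (random Q, K, V):

```python
import math
import torch

torch.manual_seed(0)
Q = torch.randn(2, 5, 16)   # (batch, seq_len, d_k)
K = torch.randn(2, 5, 16)
V = torch.randn(2, 5, 16)

# softmax(QK^T / sqrt(d_k)) V
weights = torch.softmax(Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1)), dim=-1)
output = weights @ V

print(weights.sum(dim=-1))   # every row sums to 1
```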

Multi-Head Attention

Instead of one attention function, project Q, K, V h times with different linear projections, perform attention in parallel, concatenate, and project.

Diverse Representations

Each head learns different relationships: syntactic, semantic, coreference, positional.

MultiHead(Q, K, V) = Concat(head₁, ..., headₕ) Wᴼ

headᵢ = Attention(QWᵢ^Q, KWᵢ^K, VWᵢ^V)

Typical Values

h = 8, 12, 16, 32. dₖ = d_v = d_model / h.
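Putting the projections, head split, and output projection together, a minimal multi-head module might look like this (a sketch, not a production implementation; `MultiHeadAttention` is an illustrative name):

```python
import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Minimal multi-head attention: project, split into heads, attend, merge."""
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_k = d_model / num_heads must be integral"
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V):
        batch = Q.size(0)

        def split(x):  # (batch, seq, d_model) -> (batch, heads, seq, d_k)
            return x.view(batch, -1, self.num_heads, self.d_k).transpose(1, 2)

        q, k, v = split(self.W_q(Q)), split(self.W_k(K)), split(self.W_v(V))

        # Scaled dot-product attention per head, in parallel
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
        weights = torch.softmax(scores, dim=-1)

        # Merge heads back: (batch, heads, seq, d_k) -> (batch, seq, d_model)
        out = (weights @ v).transpose(1, 2).contiguous()
        out = out.view(batch, -1, self.num_heads * self.d_k)
        return self.W_o(out), weights
```

With d_model = 64 and h = 8, each head works in a 8-dimensional subspace, so the total cost matches a single full-width attention.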

Attention Variants: Self, Cross, Causal

Self-Attention

Q, K, V from same sequence. Each token attends to all tokens in the same sequence. Captures intra-sequence dependencies.

Encoders BERT

Cross-Attention

Q from decoder, K, V from encoder. Decoder attends to input sequence. Essential for seq2seq.

T5, BART

Causal (Masked) Attention

Prevents attending to future tokens. Upper triangular mask set to -∞. Used in autoregressive decoders.

GPT, Llama

Causal Attention Mask
import torch

def causal_mask(size):
    """Boolean mask: True strictly above the diagonal (future positions).
    Apply with scores.masked_fill(mask, float('-inf')) before the softmax.
    Note: scaled_dot_product_attention above masks where mask == 0, so pass
    ~causal_mask(size) to that function instead."""
    return torch.triu(torch.ones(size, size), diagonal=1).bool()
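Applying such a mask and verifying the causality property end to end (self-contained, with random scores):

```python
import math
import torch

torch.manual_seed(0)
seq_len, d_k = 5, 8
Q = torch.randn(1, seq_len, d_k)
K = torch.randn(1, seq_len, d_k)

scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()  # True above diagonal
weights = torch.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)

# Position i can only attend to positions 0..i: the upper triangle is exactly 0,
# because exp(-inf) = 0 in the softmax.
print(weights[0].triu(diagonal=1).sum())
```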

Visualizing Attention Weights

Alignment Matrix

Plot attention weights as heatmap. Rows = decoder steps, Cols = encoder steps. Reveals word alignment.

[0.9, 0.05, 0.05]
[0.1, 0.8, 0.1]
[0.1, 0.1, 0.8]

Probing Attention Heads

Certain heads specialize: positional heads attend to previous/next token, syntactic heads attend to dependent tokens, rare word heads.

Tools: BertViz, exBERT, AttentionViz for interactive exploration.

Attention Beyond NLP

Vision

Spatial attention: Attend to relevant image regions. ViT uses self-attention on patches. Cross-attention in image captioning.

Audio

Speech recognition: Attend to acoustic frames. Listen, Attend and Spell (LAS).

Video

Temporal attention: Focus on relevant frames. Video transformers.

Multimodal

CLIP, Flamingo, LLaVA: cross-attention between image and text.

Graphs

Graph Attention Networks (GAT): attend to neighbor nodes.

Reinforcement Learning

Attend to relevant observations in memory.

Attention Types – Cheatsheet

Bahdanau: additive, concat score
Luong: dot, general scores
Scaled dot-product: QKᵀ/√dₖ
Multi-head: parallel heads
Self: intra-sequence
Cross: encoder-decoder
Causal: autoregressive
Spatial: vision

Attention Mechanism Comparison

Attention Type       | Score Function            | Complexity  | Typical Use
Bahdanau (additive)  | vₐᵀ tanh(W[s; h])         | O(n·d²)     | RNN seq2seq
Luong (dot)          | sᵀh                       | O(n·d)      | RNN, efficient
Scaled dot-product   | QKᵀ/√dₖ                   | O(n²·d)     | Transformers
Multi-head           | h × scaled dot            | O(n²·d·h)   | BERT, GPT
Graph attention      | LeakyReLU(aᵀ[Whᵢ; Whⱼ])   | O(E·d)      | Graph networks

Attention Pitfalls & Debugging

⚠️ Attention collapse: all weights become (near-)uniform. Causes: poor initialization, insufficient training, model too small.
⚠️ Quadratic complexity: self-attention is O(n²) in sequence length. Mitigations: sparse attention, Linformer, Longformer.
✅ Debug: always visualize attention matrices; entropy should be moderate (neither near zero nor fully uniform).
✅ Multi-head diversity: check correlation between heads; low correlation means diverse features.
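The entropy and head-diversity checks can be sketched directly on attention weights (synthetic weights here; `attention_entropy` is an illustrative helper, and `torch.corrcoef` requires PyTorch ≥ 1.10):

```python
import torch

def attention_entropy(weights, eps=1e-9):
    # weights: (..., seq_len), each row of an attention matrix sums to 1
    return -(weights * (weights + eps).log()).sum(dim=-1)

torch.manual_seed(0)
seq_len = 8
uniform = torch.full((seq_len,), 1.0 / seq_len)   # collapsed/uniform distribution
peaked = torch.zeros(seq_len)
peaked[3] = 1.0                                   # fully peaked distribution

print(attention_entropy(uniform))  # log(8) ≈ 2.08, maximal
print(attention_entropy(peaked))   # ≈ 0, minimal

# Head diversity: correlation between flattened per-head weight matrices.
# Off-diagonal entries near ±1 indicate redundant heads.
heads = torch.softmax(torch.randn(4, seq_len, seq_len), dim=-1)
corr = torch.corrcoef(heads.flatten(1))           # (4, 4)
```

Healthy trained heads usually sit between these extremes: entropy well above zero but clearly below the uniform maximum.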