RNN & Attention
Recurrent neural networks and attention mechanisms for sequence modeling.
Recurrent Neural Networks
Vanilla RNN
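A vanilla RNN applies the same update at every time step: $h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b)$, where $x_t$ is the input, $h_t$ the hidden state, and the weights are shared across time. The output at each step can feed a loss (many-to-many), or only the final hidden state can be used to classify the whole sequence (many-to-one). Bidirectional RNNs run one RNN forward and one backward over the sequence and concatenate their states, so each position sees both past and future context. This is common in tagging, but unsuitable for causal, autoregressive generation, where future tokens are not yet available.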
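A minimal NumPy sketch of the recurrence, assuming small random weights in place of trained ones (rnn_forward, init_params, and the sizes are illustrative). It shows the per-step states (many-to-many), the final state (many-to-one), and a bidirectional pair built from two independent RNNs.

```python
import numpy as np

def rnn_forward(xs, W_xh, W_hh, b, h0=None):
    """Run h_t = tanh(W_hh h_{t-1} + W_xh x_t + b) over a sequence.

    xs: (seq_len, input_size); returns all hidden states, shape (seq_len, hidden_size).
    """
    hidden_size = W_hh.shape[0]
    h = np.zeros(hidden_size) if h0 is None else h0
    states = []
    for x_t in xs:
        h = np.tanh(W_hh @ h + W_xh @ x_t + b)  # the same weights are reused at every step
        states.append(h)
    return np.stack(states)

rng = np.random.default_rng(0)
seq_len, input_size, hidden_size = 6, 3, 4

def init_params():
    # Illustrative random weights (W_xh, W_hh, b); a real model learns these.
    return (rng.normal(scale=0.5, size=(hidden_size, input_size)),
            rng.normal(scale=0.5, size=(hidden_size, hidden_size)),
            np.zeros(hidden_size))

xs = rng.normal(size=(seq_len, input_size))
fwd_params, bwd_params = init_params(), init_params()

H = rnn_forward(xs, *fwd_params)   # many-to-many: one state per step
h_last = H[-1]                     # many-to-one: a summary of the whole sequence

# Bidirectional: a second RNN reads the reversed sequence; flipping its states back
# pairs, at each position t, the forward state (past) with the backward state (future).
H_bwd = rnn_forward(xs[::-1], *bwd_params)[::-1]
H_bi = np.concatenate([H, H_bwd], axis=1)  # (seq_len, 2 * hidden_size)
print(H.shape, h_last.shape, H_bi.shape)
```

The backward RNN's state at position 0 has already read the entire sequence, which is exactly why this construction cannot be used when generating tokens left to right.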
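Training uses backpropagation through time (BPTT), which unrolls the network over the sequence; in practice BPTT is usually truncated to a fixed window to limit memory. Very long dependencies still challenge plain RNNs, because gradients passed back through many steps tend to vanish or explode.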
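Why long spans are hard can be seen directly: the gradient of $h_t$ with respect to $h_0$ is a product of per-step Jacobians $\mathrm{diag}(1 - h_t^2)\, W_{hh}$. The NumPy sketch below tracks the norm of that product; the weight scale is an illustrative assumption chosen so the recurrent matrix is contractive (bias omitted), and larger weights can make the product explode instead.

```python
import numpy as np

def jacobian_norms(W_hh, W_xh, xs):
    """Spectral norm of d h_t / d h_0 for a tanh RNN (no bias), for t = 1..T."""
    h = np.zeros(W_hh.shape[0])
    J = np.eye(W_hh.shape[0])  # d h_0 / d h_0
    norms = []
    for x_t in xs:
        h = np.tanh(W_hh @ h + W_xh @ x_t)
        # Chain rule: (d h_t / d h_{t-1}) @ (d h_{t-1} / d h_0)
        J = np.diag(1.0 - h ** 2) @ W_hh @ J
        norms.append(np.linalg.norm(J, 2))  # largest singular value
    return norms

rng = np.random.default_rng(0)
hidden, inp, T = 16, 8, 50
W_hh = rng.normal(scale=0.25 / np.sqrt(hidden), size=(hidden, hidden))  # assumed small scale
W_xh = rng.normal(scale=1.0 / np.sqrt(inp), size=(hidden, inp))
xs = rng.normal(size=(T, inp))

norms = jacobian_norms(W_hh, W_xh, xs)
for t in (1, 10, 25, 50):
    print(t, f"{norms[t - 1]:.2e}")
```

The norm shrinks roughly geometrically with $t$, so an input 50 steps back has almost no influence on the gradient; truncating BPTT discards even that small contribution.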
LSTM and GRU
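LSTM adds a cell state $c_t$ and three gates: forget, input, and output. The cell is updated additively, $c_t = f_t \odot c_{t-1} + i_t \odot g_t$, which gives gradients a "highway" that reduces vanishing compared with repeated tanh squashing alone. GRU uses fewer gates (reset and update) and no separate cell state; it often reaches similar quality with fewer parameters. In PyTorch, nn.LSTM and nn.GRU take the same core constructor arguments as nn.RNN, though nn.LSTM returns its final state as a tuple (h_n, c_n).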
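A minimal NumPy sketch of one LSTM step, assuming the four gate pre-activations are stacked in the order i, f, g, o (the order PyTorch also uses for its weight matrices). The helper recurrent_param_count is hypothetical; it assumes PyTorch's layout of one input matrix, one recurrent matrix, and two bias vectors per gate block, which is where GRU's roughly 3/4 parameter count comes from.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, U, b):
    """One LSTM step. W: (4H, input), U: (4H, H), b: (4H,), gates stacked as i, f, g, o."""
    H = h_prev.shape[0]
    z = W @ x_t + U @ h_prev + b
    i = sigmoid(z[0 * H:1 * H])  # input gate: how much new content to write
    f = sigmoid(z[1 * H:2 * H])  # forget gate: how much of the old cell to keep
    g = np.tanh(z[2 * H:3 * H])  # candidate content
    o = sigmoid(z[3 * H:4 * H])  # output gate: how much of the cell to expose
    c_t = f * c_prev + i * g     # additive update: the gradient "highway"
    h_t = o * np.tanh(c_t)
    return h_t, c_t

def recurrent_param_count(input_size, hidden_size, n_gate_blocks):
    # Per gate block: input matrix + recurrent matrix + two bias vectors (PyTorch layout).
    return n_gate_blocks * (hidden_size * input_size + hidden_size * hidden_size + 2 * hidden_size)

rng = np.random.default_rng(0)
inp, H = 128, 256
W = rng.normal(scale=0.1, size=(4 * H, inp))  # illustrative random weights
U = rng.normal(scale=0.1, size=(4 * H, H))
b = np.zeros(4 * H)
h, c = np.zeros(H), np.zeros(H)
for x_t in rng.normal(size=(5, inp)):
    h, c = lstm_step(x_t, h, c, W, U, b)

print(recurrent_param_count(128, 256, 4))  # one LSTM layer: 4 gate blocks
print(recurrent_param_count(128, 256, 3))  # one GRU layer: 3 gate blocks
```

When the forget gate is near 1 and the input gate near 0, $c_t \approx c_{t-1}$, so information and gradients pass through the cell largely unchanged.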
PyTorch: nn.LSTM
import torch
import torch.nn as nn
# x: (batch, seq_len, input_size)
lstm = nn.LSTM(input_size=128, hidden_size=256, num_layers=2, batch_first=True)
x = torch.randn(32, 50, 128)
out, (h_n, c_n) = lstm(x)
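# out: (32, 50, 256), the top layer's hidden state at every step; use it for per-step heads
# h_n, c_n: (num_layers, batch, hidden_size) = (2, 32, 256), the final states of each layer
# out[:, -1, :] is the last timestep; for a unidirectional LSTM it equals h_n[-1]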
Summary
- RNNs map sequences via a recurrent hidden state and shared weights across time.
- Gradients in BPTT can vanish or explode over long spans; LSTM/GRU gates mitigate vanishing.
- Bidirectional RNNs use future context; unidirectional suits online/decoding settings.
- Next: Attention—soft, content-based aggregation that powers Transformers.
Attention Mechanism
Scaled Dot-Product Attention
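For queries $Q$, keys $K$, and values $V$ (matrices whose rows are the individual vectors), $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(QK^\top / \sqrt{d_k}\right) V$, with the softmax applied row by row. The product $QK^\top$ scores how well each query aligns with each key. If query and key components are independent with zero mean and unit variance, each dot product has variance $d_k$ (the key dimension), so dividing by $\sqrt{d_k}$ keeps the logits at unit scale and prevents the softmax from saturating when $d_k$ is large. The result is a mixture of value rows: each query's output is a convex combination of the values.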
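The formula translates almost line for line into NumPy. This sketch uses random matrices as stand-ins for real queries, keys, and values, and ends with a quick empirical check of the variance argument for the $\sqrt{d_k}$ scaling.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    """Q: (n_q, d_k), K: (n_k, d_k), V: (n_k, d_v) -> output (n_q, d_v), weights (n_q, n_k)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)   # query-key alignment scores
    weights = softmax(scores, axis=-1)  # each row is non-negative and sums to 1
    return weights @ V, weights         # each output row is a convex combination of V's rows

rng = np.random.default_rng(0)
Q = rng.normal(size=(3, 64))
K = rng.normal(size=(5, 64))
V = rng.normal(size=(5, 16))
out, w = scaled_dot_product_attention(Q, K, V)
print(out.shape, w.shape)

# Why scale: with unit-variance entries, q . k has variance d_k, so raw logits grow with d_k.
q = rng.normal(size=(10000, 64))
k = rng.normal(size=(10000, 64))
raw = (q * k).sum(axis=1)
print(raw.var(), (raw / np.sqrt(64)).var())  # roughly 64 versus roughly 1
```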
Multi-head attention runs several attention operations in parallel, each with its own learned linear projections of Q, K, and V, then concatenates the head outputs and applies a final projection. Different heads can specialize, for example in syntactic, long-range, or local patterns.
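A minimal NumPy sketch of multi-head self-attention, assuming the common implementation choice of one large projection per Q, K, and V that is then split into num_heads slices of size d_model / num_heads (equivalent to separate per-head projections). The weights are random placeholders, not trained values.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o, num_heads):
    """Self-attention. X: (seq_len, d_model); every W: (d_model, d_model)."""
    seq_len, d_model = X.shape
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_head = d_model // num_heads

    def split_heads(M):  # (seq_len, d_model) -> (num_heads, seq_len, d_head)
        return M.reshape(seq_len, num_heads, d_head).transpose(1, 0, 2)

    Q, K, V = split_heads(X @ W_q), split_heads(X @ W_k), split_heads(X @ W_v)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)  # (num_heads, seq_len, seq_len)
    weights = softmax(scores, axis=-1)                   # one attention pattern per head
    heads = weights @ V                                  # (num_heads, seq_len, d_head)
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)  # concatenate heads
    return concat @ W_o, weights                         # final output projection

rng = np.random.default_rng(0)
d_model, num_heads, seq_len = 256, 8, 10
Ws = [rng.normal(scale=d_model ** -0.5, size=(d_model, d_model)) for _ in range(4)]
X = rng.normal(size=(seq_len, d_model))
out, weights = multi_head_attention(X, *Ws, num_heads=num_heads)
print(out.shape, weights.shape)  # (10, 256) (8, 10, 10)
```

Splitting one projection keeps the total cost close to that of a single full-width head while still giving each head its own attention pattern.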
Masking
For language modeling, a position must not attend to future tokens. A causal mask sets the logits above the diagonal (key index greater than query index) to $-\infty$ before the softmax, so those weights become exactly zero. Padding masks do the same for pad tokens in batched sequences, so real tokens ignore them. Vision Transformers apply the same attention machinery to sequences of image patches.
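Both masks can be built and combined as boolean arrays. This NumPy sketch assumes the convention True = blocked (the convention nn.MultiheadAttention uses for its boolean masks), and uses random scores in place of real $QK^\top / \sqrt{d_k}$ logits; the sequence lengths are made up for illustration.

```python
import numpy as np

def masked_softmax(scores, blocked):
    """Softmax over the last axis after setting blocked positions (True) to -inf."""
    scores = np.where(blocked, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)  # exp(-inf) becomes exactly 0
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)

seq_len = 5
# Causal mask: True strictly above the diagonal, i.e. key index j > query index i.
causal = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)

# Padding mask for a batch of 2: sequence 0 has 5 real tokens, sequence 1 has 3 plus 2 pads.
lengths = np.array([5, 3])
is_pad = np.arange(seq_len)[None, :] >= lengths[:, None]  # (batch, seq_len)

# Block a key if it lies in the future OR is padding; broadcasts to (batch, query, key).
# Key 0 is never blocked here, so no row is fully masked (that would give NaN).
blocked = causal[None, :, :] | is_pad[:, None, :]

rng = np.random.default_rng(0)
scores = rng.normal(size=(2, seq_len, seq_len))
weights = masked_softmax(scores, blocked)
print(np.round(weights[1], 2))
```

The query rows at pad positions still produce outputs; they are usually excluded later, for example by masking them out of the loss.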
PyTorch: MultiheadAttention
import torch
import torch.nn as nn
# embed_dim must be divisible by num_heads
mha = nn.MultiheadAttention(embed_dim=256, num_heads=8, batch_first=True)
# x: (batch, seq_len, embed_dim)
x = torch.randn(4, 100, 256)
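out, attn_weights = mha(x, x, x)  # self-attention: query, key, and value are all x
# out: (4, 100, 256); attn_weights: (4, 100, 100), averaged over heads by default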
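Masks are passed to nn.MultiheadAttention through its attn_mask and key_padding_mask arguments. A sketch, assuming a recent PyTorch version; the per-sequence lengths are made up for illustration. For both boolean masks, True marks positions that may not be attended to.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, embed_dim = 4, 10, 256
mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=8, batch_first=True)
x = torch.randn(batch, seq_len, embed_dim)

# attn_mask: bool (seq_len, seq_len); True above the diagonal blocks attention to the future.
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

# key_padding_mask: bool (batch, seq_len); True marks padding keys to ignore.
lengths = torch.tensor([10, 8, 6, 10])  # hypothetical real lengths per sequence
key_padding_mask = torch.arange(seq_len)[None, :] >= lengths[:, None]

out, attn_weights = mha(x, x, x, attn_mask=causal_mask, key_padding_mask=key_padding_mask)
# shapes: out (4, 10, 256), attn_weights (4, 10, 10) averaged over heads
print(out.shape, attn_weights.shape)
```

Because every row can still attend to key 0, no row is fully masked; a row with every key blocked would produce NaN weights.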
Full Transformers stack MHA with feed-forward nets, residuals, and layer norm—see dedicated transformer tutorials for the complete block.
Summary
- Attention = softmax-normalized key–query similarity applied to values.
- Scaling by $\sqrt{d_k}$ stabilizes gradients; multi-head increases representational flexibility.
- Self-attention draws Q, K, and V from the same sequence; encoder–decoder (cross-)attention takes Q from the decoder and K, V from the encoder output.
- Masks enforce causality and ignore padding; the score matrix has one entry per query–key pair, so time and memory grow quadratically with sequence length.
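The last two points can be made concrete with a NumPy sketch that omits the learned Q, K, V projections for brevity (an assumption; real layers project first). encoder_out and decoder_h are hypothetical stand-ins for encoder and decoder states.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    scores = Q @ K.T / np.sqrt(Q.shape[-1])  # (n_q, n_k): one score per query-key pair
    return softmax(scores) @ V, scores

rng = np.random.default_rng(0)
d = 32
encoder_out = rng.normal(size=(12, d))  # hypothetical encoder states, source length 12
decoder_h = rng.normal(size=(7, d))     # hypothetical decoder states, target length 7

# Self-attention: Q, K, and V all come from the same sequence.
self_out, self_scores = attention(decoder_h, decoder_h, decoder_h)
# Cross-attention: Q from the decoder, K and V from the encoder output.
cross_out, cross_scores = attention(decoder_h, encoder_out, encoder_out)
print(self_scores.shape, cross_scores.shape, cross_out.shape)  # (7, 7) (7, 12) (7, 32)

# A self-attention score matrix has n * n entries: doubling n quadruples the count.
for n in (512, 1024, 2048):
    print(n, n * n)
```

The output always has one row per query, so cross-attention output length follows the decoder, while its cost scales with the product of the two lengths.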