Transformers
Neil D. Lawrence
Attention Computation
\[\attentionMatrix = \softmax\left(\frac{\queryMatrix \keyMatrix^\top}{\sqrt{d_k}}\right)\]
\[\outputMatrix = \attentionMatrix \valueMatrix\]
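A minimal NumPy sketch of this forward computation (illustrative only; the variable names, toy dimensions and random data are my own, not part of the lecture):

```python
import numpy as np

def softmax(Z):
    """Row-wise softmax, subtracting the row maximum for numerical stability."""
    Z = Z - Z.max(axis=-1, keepdims=True)
    expZ = np.exp(Z)
    return expZ / expZ.sum(axis=-1, keepdims=True)

def attention(Q, K, V):
    """Scaled dot-product attention: return the weights A and the output O = A V."""
    d_k = Q.shape[-1]
    A = softmax(Q @ K.T / np.sqrt(d_k))
    return A, A @ V

# Toy example: n = 4 tokens, d_model = d_k = 8.
rng = np.random.default_rng(0)
n, d_model, d_k = 4, 8, 8
X = rng.standard_normal((n, d_model))
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_k)) for _ in range(3))
Q, K, V = X @ W_Q, X @ W_K, X @ W_V
A, O = attention(Q, K, V)
```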
The Chain Rule Challenge
- Standard NN: \(\frac{\partial L}{\partial \mathbf{X}} = \frac{\partial L}{\partial \outputMatrix} \frac{\partial \outputMatrix}{\partial \mathbf{X}}\)
- Transformer: \(\frac{\partial L}{\partial \mathbf{X}} = \frac{\partial L}{\partial \queryMatrix} \frac{\partial \queryMatrix}{\partial \mathbf{X}} + \frac{\partial L}{\partial \keyMatrix} \frac{\partial \keyMatrix}{\partial \mathbf{X}} + \frac{\partial L}{\partial \valueMatrix} \frac{\partial \valueMatrix}{\partial \mathbf{X}}\)
Gradient Flow Through Attention
- Step 1: \(\frac{\partial L}{\partial \outputMatrix}\) (from next layer)
- Step 2: \(\frac{\partial L}{\partial \valueMatrix} = \attentionMatrix^\top \frac{\partial L}{\partial \outputMatrix}\)
- Step 3: \(\frac{\partial L}{\partial \attentionMatrix} = \frac{\partial L}{\partial \outputMatrix} \valueMatrix^\top\)
Attention Weights Gradient
\[\frac{\partial L}{\partial \attentionMatrix} = \frac{\partial L}{\partial \outputMatrix} \valueMatrix^\top\]
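Continuing the NumPy sketch above, these two steps are a couple of lines; a random matrix stands in for \(\frac{\partial L}{\partial \outputMatrix}\):

```python
# Backward pass through O = A V (A is n x n, V is n x d_k).
dO = rng.standard_normal(O.shape)  # placeholder for the gradient arriving from the next layer

dV = A.T @ dO   # dL/dV = A^T dL/dO
dA = dO @ V.T   # dL/dA = dL/dO V^T
```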
Query-Key Interaction Gradient
- Attention logits: \(\logitsMatrix = \frac{\queryMatrix \keyMatrix^\top}{\sqrt{d_k}}\)
- Gradient (row-wise softmax): \(\frac{\partial L}{\partial \logitsMatrix_{ij}} = \attentionMatrix_{ij}\left(\frac{\partial L}{\partial \attentionMatrix_{ij}} - \sum_{k} \frac{\partial L}{\partial \attentionMatrix_{ik}} \attentionMatrix_{ik}\right)\)
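In code this is the standard row-wise softmax backward pass (a sketch continuing the example above):

```python
def softmax_backward(A, dA):
    """Backward through A = softmax(Z) applied row-wise:
    dZ = A * (dA - rowsum(dA * A))."""
    rowsum = (dA * A).sum(axis=-1, keepdims=True)
    return A * (dA - rowsum)

dZ = softmax_backward(A, dA)  # dL/dZ, the gradient of the attention logits
```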
Query and Key Gradients
\[\frac{\partial L}{\partial \queryMatrix} = \frac{1}{\sqrt{d_k}} \frac{\partial L}{\partial \logitsMatrix} \keyMatrix\]
\[\frac{\partial L}{\partial \keyMatrix} = \frac{1}{\sqrt{d_k}} \left(\frac{\partial L}{\partial \logitsMatrix}\right)^\top \queryMatrix\]
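In the running sketch, note how the \(1/\sqrt{d_k}\) scaling carries through into both gradients:

```python
dQ = (dZ @ K) / np.sqrt(d_k)    # dL/dQ
dK = (dZ.T @ Q) / np.sqrt(d_k)  # dL/dK
```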
Weight Matrix Gradients
- Query weights: \(\frac{\partial L}{\partial \queryWeightMatrix} = \mathbf{X}^\top \frac{\partial L}{\partial \queryMatrix}\)
- Key weights: \(\frac{\partial L}{\partial \keyWeightMatrix} = \mathbf{X}^\top \frac{\partial L}{\partial \keyMatrix}\)
- Value weights: \(\frac{\partial L}{\partial \valueWeightMatrix} = \mathbf{X}^\top \frac{\partial L}{\partial \valueMatrix}\)
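These weight gradients, together with the three-path input gradient from the chain rule slide, close the loop in the running sketch (`X`, `W_Q`, `W_K`, `W_V` are the toy matrices defined earlier):

```python
# Gradients for the projection weights (Q = X W_Q, K = X W_K, V = X W_V).
dW_Q = X.T @ dQ
dW_K = X.T @ dK
dW_V = X.T @ dV

# The input gradient sums the contributions from all three projections.
dX = dQ @ W_Q.T + dK @ W_K.T + dV @ W_V.T
```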
Multi-Head Attention
- Multiple heads: \(\attentionHead_i = \attentionFunction(\queryMatrix_i, \keyMatrix_i, \valueMatrix_i)\)
- Concatenation: \(\multiHeadOutput = \text{concat}(\attentionHead_1, \ldots, \attentionHead_h) \outputWeightMatrix\)
- Gradient complexity: Each head has its own \(Q\), \(K\), \(V\) gradients
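A sketch of the multi-head forward pass under the same toy setup (the per-head projections and the output projection are random placeholders, and the head sizes are an arbitrary choice):

```python
h, d_head = 2, 4  # two heads, each of dimension 4
W_O = rng.standard_normal((h * d_head, d_model))

heads = []
for i in range(h):
    # Each head has its own projection matrices, and hence its own gradients.
    W_Qi, W_Ki, W_Vi = (rng.standard_normal((d_model, d_head)) for _ in range(3))
    _, head_i = attention(X @ W_Qi, X @ W_Ki, X @ W_Vi)
    heads.append(head_i)

multi_head_output = np.concatenate(heads, axis=-1) @ W_O
```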
Implementation Considerations
- Memory efficiency: Cache the attention weights \(\attentionMatrix\) and the projections \(\queryMatrix\), \(\keyMatrix\), \(\valueMatrix\) that the backward pass reuses (see the sketch after this list)
- Numerical stability: Subtract the row-wise maximum from the logits before exponentiating in the softmax
- Parallel computation: Each head can be computed independently
- Gradient accumulation: The input \(\mathbf{X}\) receives summed gradient contributions from every head
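One way these considerations play out in code is to cache exactly what the backward pass needs during the forward pass; the sketch below packages the running example into a forward/backward pair (an illustrative pattern, not a production implementation):

```python
def attention_forward(Q, K, V):
    d_k = Q.shape[-1]
    A = softmax(Q @ K.T / np.sqrt(d_k))
    cache = (A, Q, K, V)  # stored intermediates for the backward pass
    return A @ V, cache

def attention_backward(dO, cache):
    A, Q, K, V = cache
    d_k = Q.shape[-1]
    dV = A.T @ dO
    dA = dO @ V.T
    dZ = softmax_backward(A, dA)
    dQ = (dZ @ K) / np.sqrt(d_k)
    dK = (dZ.T @ Q) / np.sqrt(d_k)
    return dQ, dK, dV
```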
Summary
- Three-path chain rule: Input appears in \(Q\), \(K\), \(V\) transformations
- Softmax gradient: Standard formula with attention weight constraints
- Multi-head complexity: Each head has independent gradients
- Efficient implementation: Careful memory and numerical considerations
Create and Test Basic Attention
Test Multi-Head Attention
Test Chain Rule in Attention
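A hedged sketch of one way such a test could look: compare the analytic gradient from the backward pass above with a central finite-difference estimate for the scalar loss \(L = \sum_{ij} \outputMatrix_{ij}\) (the loss and tolerance choices are illustrative):

```python
def numerical_dQ(Q, K, V, eps=1e-5):
    """Central finite-difference estimate of dL/dQ for L = sum(O)."""
    G = np.zeros_like(Q)
    for idx in np.ndindex(*Q.shape):
        Qp, Qm = Q.copy(), Q.copy()
        Qp[idx] += eps
        Qm[idx] -= eps
        G[idx] = (attention(Qp, K, V)[1].sum() - attention(Qm, K, V)[1].sum()) / (2 * eps)
    return G

O, cache = attention_forward(Q, K, V)
dQ_analytic, _, _ = attention_backward(np.ones_like(O), cache)
assert np.allclose(dQ_analytic, numerical_dQ(Q, K, V), atol=1e-5)
```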
Train Simple Attention Model
Visualise Attention Weights
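A minimal visualisation sketch using the weights from the toy example (matplotlib assumed available; styling choices are arbitrary):

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
im = ax.imshow(A, cmap="viridis")  # rows: queries, columns: keys
ax.set_xlabel("key position")
ax.set_ylabel("query position")
fig.colorbar(im, ax=ax, label="attention weight")
plt.show()
```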
Test Different Numbers of Heads
Test Different Activation Functions
Visualise Different Attention Patterns
Training Simple Model
Training and Generalization
Attention as Implicit Regularization
- Sparse attention: Many attention weights are near zero
- Implicit regularization: Attention patterns emerge during training
- Structural constraints: Attention provides architectural bias
- Empirical scaling: Performance generally improves with model size
Summary and Future Directions
- Key insights: Multi-path chain rule, attention as regularization
- Implementation: Practical considerations for transformer training
- Theory: Connections to overparameterization and generalization
- Future: What comes after transformers?
Further Reading
- Vaswani et al. (2017) “Attention is All You Need”
- Rogers et al. (2020) “A Primer in BERTology: What We Know About How BERT Works”
- Elhage et al. (2021) “A Mathematical Framework for Transformer Circuits”
- The Annotated Transformer
- Transformer Math 101