Transformers

Neil D. Lawrence

From Deep Networks to Transformers

  • Review: Deep networks, chain rule, overparameterization
  • Today: How attention mechanisms extend these concepts
  • Focus: Multi-path chain rule in transformer architectures
  • Connection: How transformers relate to generalization theory

The Attention Mechanism

Chain Rule for Transformer Attention

Transformer Attention Structure

  • Input: \(\mathbf{X}\) appears in three roles
  • Query: \(\queryMatrix = \mathbf{X}\queryWeightMatrix\)
  • Key: \(\keyMatrix = \mathbf{X}\keyWeightMatrix\)
  • Value: \(\valueMatrix = \mathbf{X}\valueWeightMatrix\)

Attention Computation

\[\attentionMatrix = \softmax\left(\frac{\queryMatrix \keyMatrix^\top}{\sqrt{d_k}}\right)\]

\[\outputMatrix = \attentionMatrix \valueMatrix\]
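
A minimal NumPy sketch of this forward pass, assuming a single head and toy dimensions (the names follow the notation above; nothing here is taken from a particular library):

```python
import numpy as np

def attention_forward(X, W_q, W_k, W_v):
    """Single-head scaled dot-product attention: O = softmax(Q K^T / sqrt(d_k)) V."""
    Q = X @ W_q                                  # queries
    K = X @ W_k                                  # keys
    V = X @ W_v                                  # values
    d_k = K.shape[-1]
    Z = Q @ K.T / np.sqrt(d_k)                   # scaled attention logits
    Z = Z - Z.max(axis=1, keepdims=True)         # subtract row max for numerical stability
    A = np.exp(Z)
    A = A / A.sum(axis=1, keepdims=True)         # row-wise softmax: attention weights
    return A @ V, A, Q, K, V                     # attended output plus cached intermediates

# toy example: sequence of 4 tokens, model dimension 6, head dimension 8
rng = np.random.default_rng(0)
X = rng.standard_normal((4, 6))
W_q, W_k, W_v = (rng.standard_normal((6, 8)) for _ in range(3))
O, A, Q, K, V = attention_forward(X, W_q, W_k, W_v)
print(A.sum(axis=1))                             # each row of the attention matrix sums to one
```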

The Chain Rule Challenge

  • Standard NN: \(\frac{\partial L}{\partial \mathbf{X}} = \frac{\partial L}{\partial \outputMatrix} \frac{\partial \outputMatrix}{\partial \mathbf{X}}\)
  • Transformer: \(\frac{\partial L}{\partial \mathbf{X}} = \frac{\partial L}{\partial \queryMatrix} \frac{\partial \queryMatrix}{\partial \mathbf{X}} + \frac{\partial L}{\partial \keyMatrix} \frac{\partial \keyMatrix}{\partial \mathbf{X}} + \frac{\partial L}{\partial \valueMatrix} \frac{\partial \valueMatrix}{\partial \mathbf{X}}\)

Gradient Flow Through Attention

  • Step 1: \(\frac{\partial L}{\partial \outputMatrix}\) arrives from the next layer
  • Step 2: \(\frac{\partial L}{\partial \valueMatrix} = \attentionMatrix^\top \frac{\partial L}{\partial \outputMatrix}\)
  • Step 3: \(\frac{\partial L}{\partial \attentionMatrix} = \frac{\partial L}{\partial \outputMatrix} \valueMatrix^\top\) (both steps are sketched in code below)
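
Steps 2 and 3 are just the two product-rule terms for \(\outputMatrix = \attentionMatrix\valueMatrix\); a minimal sketch with toy shapes:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_v = 4, 3                         # toy sequence length and value dimension
A = rng.random((n, n))
A /= A.sum(axis=1, keepdims=True)     # attention weights, rows sum to one
V = rng.standard_normal((n, d_v))     # values
dO = rng.standard_normal((n, d_v))    # upstream gradient dL/dO from the next layer

# O = A V, so:
dV = A.T @ dO                         # dL/dV = A^T dL/dO
dA = dO @ V.T                         # dL/dA = dL/dO V^T
```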

Attention Weights Gradient

\[\frac{\partial L}{\partial \attentionMatrix} = \frac{\partial L}{\partial \outputMatrix} \valueMatrix^\top\]

Query-Key Interaction Gradient

  • Attention logits: \(\logitsMatrix = \frac{\queryMatrix \keyMatrix^\top}{\sqrt{d_k}}\)
  • Gradient (row-wise softmax Jacobian): \(\frac{\partial L}{\partial \logitsMatrix_{ij}} = \attentionMatrix_{ij}\left(\frac{\partial L}{\partial \attentionMatrix_{ij}} - \sum_{k} \attentionMatrix_{ik} \frac{\partial L}{\partial \attentionMatrix_{ik}}\right)\) (see the sketch below)
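
A sketch of this row-wise softmax Jacobian-vector product in NumPy (the function name is mine, not from any library):

```python
import numpy as np

def softmax_backward(A, dA):
    """Map dL/dA to dL/dZ for a row-wise softmax A = softmax(Z).

    dL/dZ_ij = A_ij * (dL/dA_ij - sum_k A_ik dL/dA_ik)
    """
    s = np.sum(A * dA, axis=1, keepdims=True)   # per-row inner product
    return A * (dA - s)

# sanity check: each row of dZ sums to (numerically) zero, because shifting a
# softmax's logits by a constant leaves the softmax output unchanged
rng = np.random.default_rng(1)
Z = rng.standard_normal((4, 4))
A = np.exp(Z - Z.max(axis=1, keepdims=True))
A /= A.sum(axis=1, keepdims=True)
dZ = softmax_backward(A, rng.standard_normal((4, 4)))
print(np.allclose(dZ.sum(axis=1), 0.0))
```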

Query and Key Gradients

\[\frac{\partial L}{\partial \queryMatrix} = \frac{1}{\sqrt{d_k}}\frac{\partial L}{\partial \logitsMatrix} \keyMatrix\]

\[\frac{\partial L}{\partial \keyMatrix} = \frac{1}{\sqrt{d_k}}\left(\frac{\partial L}{\partial \logitsMatrix}\right)^\top \queryMatrix\]
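
In code, remembering to carry the \(1/\sqrt{d_k}\) factor from the logit scaling (a minimal sketch with toy shapes):

```python
import numpy as np

rng = np.random.default_rng(2)
n, d_k = 4, 8
Q = rng.standard_normal((n, d_k))
K = rng.standard_normal((n, d_k))
dZ = rng.standard_normal((n, n))      # dL/dZ from the softmax backward step

# Z = Q K^T / sqrt(d_k), so both gradients pick up the 1/sqrt(d_k) factor.
dQ = dZ @ K / np.sqrt(d_k)            # dL/dQ
dK = dZ.T @ Q / np.sqrt(d_k)          # dL/dK
```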

Input Matrix Gradient

\[\frac{\partial L}{\partial \mathbf{X}} = \frac{\partial L}{\partial \queryMatrix} \queryWeightMatrix^\top + \frac{\partial L}{\partial \keyMatrix} \keyWeightMatrix^\top + \frac{\partial L}{\partial \valueMatrix} \valueWeightMatrix^\top\]
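
This is the three-path sum from the chain-rule slide: \(\mathbf{X}\) enters through \(\queryMatrix\), \(\keyMatrix\) and \(\valueMatrix\), so the three path gradients are added. A toy sketch:

```python
import numpy as np

rng = np.random.default_rng(3)
n, d, d_k = 4, 6, 8
W_q, W_k, W_v = (rng.standard_normal((d, d_k)) for _ in range(3))
dQ, dK, dV = (rng.standard_normal((n, d_k)) for _ in range(3))  # stand-ins for the gradients above

# sum over the Q, K and V paths to get the full input gradient
dX = dQ @ W_q.T + dK @ W_k.T + dV @ W_v.T
```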

Weight Matrix Gradients

  • Query weights: \(\frac{\partial L}{\partial \queryWeightMatrix} = \mathbf{X}^\top \frac{\partial L}{\partial \queryMatrix}\)
  • Key weights: \(\frac{\partial L}{\partial \keyWeightMatrix} = \mathbf{X}^\top \frac{\partial L}{\partial \keyMatrix}\)
  • Value weights: \(\frac{\partial L}{\partial \valueWeightMatrix} = \mathbf{X}^\top \frac{\partial L}{\partial \valueMatrix}\) (all three are sketched below)
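
Each projection is an ordinary linear map, so these are the familiar linear-layer weight gradients (toy sketch):

```python
import numpy as np

rng = np.random.default_rng(4)
n, d, d_k = 4, 6, 8
X = rng.standard_normal((n, d))
dQ, dK, dV = (rng.standard_normal((n, d_k)) for _ in range(3))

# Q = X W_q (and likewise for K, V), so each weight gradient is X^T times
# the gradient arriving at that projection's output.
dW_q = X.T @ dQ
dW_k = X.T @ dK
dW_v = X.T @ dV
```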

Multi-Head Attention

  • Multiple heads: \(\attentionHead_i = \attentionFunction(\queryMatrix_i, \keyMatrix_i, \valueMatrix_i)\)
  • Concatenation: \(\multiHeadOutput = \text{concat}(\attentionHead_1, \ldots, \attentionHead_h) \outputWeightMatrix\)
  • Gradient complexity: each head has its own \(\queryMatrix\), \(\keyMatrix\), \(\valueMatrix\) gradients (a forward-pass sketch follows below)
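
A minimal sketch of the multi-head forward pass, assuming the common convention of splitting the model dimension evenly across heads (names and shapes are illustrative, not from a specific library):

```python
import numpy as np

def softmax_rows(Z):
    Z = Z - Z.max(axis=-1, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=-1, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o, n_heads):
    """Run each head independently on its slice, concatenate, then project with W_o."""
    n, d = X.shape
    d_head = d // n_heads
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    heads = []
    for h in range(n_heads):
        sl = slice(h * d_head, (h + 1) * d_head)
        Z = Q[:, sl] @ K[:, sl].T / np.sqrt(d_head)
        A = softmax_rows(Z)
        heads.append(A @ V[:, sl])
    return np.concatenate(heads, axis=1) @ W_o   # concat then output projection

rng = np.random.default_rng(5)
n, d, n_heads = 4, 8, 2
X = rng.standard_normal((n, d))
W_q, W_k, W_v, W_o = (rng.standard_normal((d, d)) for _ in range(4))
print(multi_head_attention(X, W_q, W_k, W_v, W_o, n_heads).shape)   # (4, 8)
```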

Implementation Considerations

  • Memory efficiency: cache \(\attentionMatrix\), \(\queryMatrix\), \(\keyMatrix\) and \(\valueMatrix\) from the forward pass; the backward pass reuses all of them
  • Numerical stability: subtract the row-wise maximum before the softmax and keep the \(1/\sqrt{d_k}\) scaling on the logits
  • Parallel computation: each head can be computed independently
  • Gradient accumulation: gradients are summed across heads and across the \(\queryMatrix\), \(\keyMatrix\), \(\valueMatrix\) paths into the input (a finite-difference check is sketched below)
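
A useful implementation habit is to check the hand-derived gradients against finite differences. Here is a sketch for the input gradient of a single head, using a simple squared loss on the output as a stand-in for whatever the next layer would supply:

```python
import numpy as np

def attention_forward(X, W_q, W_k, W_v):
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    d_k = K.shape[-1]
    Z = Q @ K.T / np.sqrt(d_k)
    A = np.exp(Z - Z.max(axis=1, keepdims=True))
    A /= A.sum(axis=1, keepdims=True)
    return A @ V, A, Q, K, V

def attention_backward_X(dO, W_q, W_k, W_v, A, Q, K, V):
    """Hand-derived dL/dX, accumulating the Q, K and V paths."""
    d_k = K.shape[-1]
    dV = A.T @ dO
    dA = dO @ V.T
    dZ = A * (dA - np.sum(A * dA, axis=1, keepdims=True))   # softmax backward
    dQ = dZ @ K / np.sqrt(d_k)
    dK = dZ.T @ Q / np.sqrt(d_k)
    return dQ @ W_q.T + dK @ W_k.T + dV @ W_v.T

rng = np.random.default_rng(6)
X = rng.standard_normal((4, 6))
W_q, W_k, W_v = (rng.standard_normal((6, 8)) for _ in range(3))

def loss(X):
    O, *_ = attention_forward(X, W_q, W_k, W_v)
    return 0.5 * np.sum(O ** 2)       # surrogate loss, so dL/dO = O

O, A, Q, K, V = attention_forward(X, W_q, W_k, W_v)
dX = attention_backward_X(O, W_q, W_k, W_v, A, Q, K, V)

# central finite differences on a single entry of X
eps, (i, j) = 1e-5, (1, 2)
Xp, Xm = X.copy(), X.copy()
Xp[i, j] += eps
Xm[i, j] -= eps
numeric = (loss(Xp) - loss(Xm)) / (2 * eps)
print(np.isclose(dX[i, j], numeric, rtol=1e-4))
```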

Summary

  • Three-path chain rule: Input appears in \(Q\), \(K\), \(V\) transformations
  • Softmax gradient: the row-wise Jacobian couples all the attention weights within a row
  • Multi-head complexity: Each head has independent gradients
  • Efficient implementation: Careful memory and numerical considerations

Transformer Architecture

Simple Transformer Implementation

Create and Test Basic Attention

Test Multi-Head Attention

Test Chain Rule in Attention

Train Simple Attention Model

Visualise Attention Weights

Test Different Numbers of Heads

Test Different Activation Functions

Visualise Different Attention Patterns

Positional Encoding Test
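
A minimal sketch of the sinusoidal positional encoding of Vaswani et al. (2017), the standard choice for an exercise like this (names and shapes are illustrative assumptions):

```python
import numpy as np

def sinusoidal_positional_encoding(n_positions, d_model):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d_model)), cosine on the odd dimensions."""
    positions = np.arange(n_positions)[:, None]       # (n_positions, 1)
    dims = np.arange(0, d_model, 2)[None, :]          # even dimensions
    angles = positions / np.power(10000.0, dims / d_model)
    pe = np.zeros((n_positions, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

pe = sinusoidal_positional_encoding(16, 8)
print(pe.shape)   # (16, 8): added to the token embeddings before the first attention layer
```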

Simple Transformer Model
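
One way a simple transformer block could look: self-attention plus a small feed-forward layer, each with a residual connection. This is a sketch under my own simplifying assumptions (single head, layer normalisation omitted for brevity), not the notebook's actual code:

```python
import numpy as np

def softmax_rows(Z):
    Z = Z - Z.max(axis=-1, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=-1, keepdims=True)

def transformer_block(X, W_q, W_k, W_v, W_o, W_1, W_2):
    """Simplified block: self-attention + residual, then a two-layer MLP + residual."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    A = softmax_rows(Q @ K.T / np.sqrt(K.shape[-1]))
    X = X + (A @ V) @ W_o                 # attention sub-layer with residual connection
    H = np.maximum(X @ W_1, 0.0)          # ReLU feed-forward expansion
    return X + H @ W_2                    # feed-forward sub-layer with residual connection

rng = np.random.default_rng(7)
n, d, d_ff = 6, 8, 16
X = rng.standard_normal((n, d))
W_q, W_k, W_v, W_o = (rng.standard_normal((d, d)) for _ in range(4))
W_1 = rng.standard_normal((d, d_ff))
W_2 = rng.standard_normal((d_ff, d))
print(transformer_block(X, W_q, W_k, W_v, W_o, W_1, W_2).shape)   # (6, 8)
```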

Training Simple Model

Training and Generalization

Attention as Implicit Regularization

  • Sparse attention: Many attention weights are near zero
  • Implicit regularization: Attention patterns emerge during training
  • Structural constraints: Attention provides architectural bias
  • Empirical scaling: Performance generally improves with model size

Overparameterization in Transformers

  • Large parameter counts: Modern transformers have millions to billions of parameters
  • Generalization through overparameterization: transformers generalize well because they are heavily overparameterized, not in spite of it
  • Attention as structure: The attention mechanism provides architectural constraints
  • Connection to previous discussion: This extends our overparameterization analysis to transformers

Summary and Future Directions

  • Key insights: Multi-path chain rule, attention as regularization
  • Implementation: Practical considerations for transformer training
  • Theory: Connections to overparameterization and generalization
  • Future: What comes after transformers?

Further Reading

  • Vaswani et al. (2017) “Attention is All You Need”
  • Rogers et al. (2020) “A Primer in BERTology: What We Know About How BERT Works”
  • Elhage et al. (2021) “A Mathematical Framework for Transformer Circuits”
  • The Annotated Transformer
  • Transformer Math 101

Thanks!
