Transformers

Neil D. Lawrence

From Deep Networks to Transformers

  • Review: Deep networks, chain rule, overparameterization
  • Today: How attention mechanisms extend these concepts
  • Focus: Multi-path chain rule in transformer architectures
  • Connection: How transformers relate to generalization theory

The Attention Mechanism

Chain Rule for Transformer Attention

Transformer Attention Structure

  • Input: \(\mathbf{X}\) appears in three roles
  • Query: \(\queryMatrix = \mathbf{X}\queryWeightMatrix\)
  • Key: \(\keyMatrix = \mathbf{X}\keyWeightMatrix\)
  • Value: \(\valueMatrix = \mathbf{X}\valueWeightMatrix\)

Attention Computation

\[\attentionMatrix = \softmax\left(\frac{\queryMatrix \keyMatrix^\top}{\sqrt{d_k}}\right)\]

\[\outputMatrix = \attentionMatrix \valueMatrix\]
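
As a concrete illustration, here is a minimal numpy sketch of the forward computation above; the function and variable names are illustrative, not the course's layered implementation.

```python
import numpy as np


def softmax_rows(Z):
    """Numerically stable softmax applied to each row of Z."""
    Z = Z - Z.max(axis=1, keepdims=True)
    exp_Z = np.exp(Z)
    return exp_Z / exp_Z.sum(axis=1, keepdims=True)


def attention_forward(X, W_q, W_k, W_v):
    """Single-head self-attention: the input X appears in all three roles."""
    Q = X @ W_q                        # queries
    K = X @ W_k                        # keys
    V = X @ W_v                        # values
    d_k = Q.shape[1]
    Z = Q @ K.T / np.sqrt(d_k)         # scaled attention logits
    A = softmax_rows(Z)                # attention weights
    O = A @ V                          # attention output
    return O, (X, Q, K, V, Z, A)       # cache reused by the backward pass
```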

The Chain Rule Challenge

  • Standard NN: \(\frac{\partial L}{\partial \mathbf{X}} = \frac{\partial L}{\partial \outputMatrix} \frac{\partial \outputMatrix}{\partial \mathbf{X}}\)
  • Transformer: \(\frac{\partial L}{\partial \mathbf{X}} = \frac{\partial L}{\partial \queryMatrix} \frac{\partial \queryMatrix}{\partial \mathbf{X}} + \frac{\partial L}{\partial \keyMatrix} \frac{\partial \keyMatrix}{\partial \mathbf{X}} + \frac{\partial L}{\partial \valueMatrix} \frac{\partial \valueMatrix}{\partial \mathbf{X}}\)

Gradient Flow Through Attention

  • Step 1: \(\frac{\partial L}{\partial \outputMatrix}\) (from next layer)
  • Step 2: \(\frac{\partial L}{\partial \valueMatrix} = \attentionMatrix^\top \frac{\partial L}{\partial \outputMatrix}\)
  • Step 3: \(\frac{\partial L}{\partial \attentionMatrix} = \frac{\partial L}{\partial \outputMatrix} \valueMatrix^\top\)

Attention Weights Gradient

\[\frac{\partial L}{\partial \attentionMatrix} = \frac{\partial L}{\partial \outputMatrix} \valueMatrix^\top\]

Query-Key Interaction Gradient

  • Attention logits: \(\logitsMatrix = \frac{\queryMatrix \keyMatrix^\top}{\sqrt{d_k}}\)
  • Gradient: \(\frac{\partial L}{\partial \logitsMatrix_{ij}} = \attentionMatrix_{ij}\left(\frac{\partial L}{\partial \attentionMatrix_{ij}} - \sum_{k} \frac{\partial L}{\partial \attentionMatrix_{ik}} \attentionMatrix_{ik}\right)\), since the softmax is applied row-wise (see the sketch below)
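
A minimal numpy sketch of this row-wise softmax backward step, assuming the attention weights and the upstream gradient \(\partial L/\partial \attentionMatrix\) have already been computed:

```python
import numpy as np


def softmax_backward(dL_dA, A):
    """Backpropagate through a row-wise softmax, A = softmax(Z) applied per row.

    dL_dZ[i, j] = A[i, j] * (dL_dA[i, j] - sum_k dL_dA[i, k] * A[i, k])
    """
    row_dot = np.sum(dL_dA * A, axis=1, keepdims=True)  # one inner product per row
    return A * (dL_dA - row_dot)
```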

Query and Key Gradients

\[\frac{\partial L}{\partial \queryMatrix} = \frac{1}{\sqrt{d_k}} \frac{\partial L}{\partial \logitsMatrix} \keyMatrix\]

\[\frac{\partial L}{\partial \keyMatrix} = \frac{1}{\sqrt{d_k}} \frac{\partial L}{\partial \logitsMatrix}^\top \queryMatrix\]

The \(1/\sqrt{d_k}\) factor appears because \(\logitsMatrix = \queryMatrix\keyMatrix^\top/\sqrt{d_k}\).

Input Matrix Gradient

\[\frac{\partial L}{\partial \mathbf{X}} = \frac{\partial L}{\partial \queryMatrix} \queryWeightMatrix^\top + \frac{\partial L}{\partial \keyMatrix} \keyWeightMatrix^\top + \frac{\partial L}{\partial \valueMatrix} \valueWeightMatrix^\top\]

Weight Matrix Gradients

  • Query weights: \(\frac{\partial L}{\partial \queryWeightMatrix} = \mathbf{X}^\top \frac{\partial L}{\partial \queryMatrix}\)
  • Key weights: \(\frac{\partial L}{\partial \keyWeightMatrix} = \mathbf{X}^\top \frac{\partial L}{\partial \keyMatrix}\)
  • Value weights: \(\frac{\partial L}{\partial \valueWeightMatrix} = \mathbf{X}^\top \frac{\partial L}{\partial \valueMatrix}\) (these steps are combined in the sketch below)
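
Putting the steps above together, a combined backward-pass sketch in numpy; it reuses the cache returned by the attention_forward sketch earlier, and the names are illustrative rather than the course's layered implementation.

```python
import numpy as np


def attention_backward(dL_dO, cache, W_q, W_k, W_v):
    """Backward pass for single-head self-attention, step by step as above."""
    X, Q, K, V, Z, A = cache
    d_k = Q.shape[1]

    # Through the output O = A V
    dL_dV = A.T @ dL_dO
    dL_dA = dL_dO @ V.T

    # Through the row-wise softmax A = softmax(Z)
    row_dot = np.sum(dL_dA * A, axis=1, keepdims=True)
    dL_dZ = A * (dL_dA - row_dot)

    # Through the scaled logits Z = Q K^T / sqrt(d_k)
    dL_dQ = dL_dZ @ K / np.sqrt(d_k)
    dL_dK = dL_dZ.T @ Q / np.sqrt(d_k)

    # Three-path chain rule back to the input X
    dL_dX = dL_dQ @ W_q.T + dL_dK @ W_k.T + dL_dV @ W_v.T

    # Weight matrix gradients
    dL_dWq = X.T @ dL_dQ
    dL_dWk = X.T @ dL_dK
    dL_dWv = X.T @ dL_dV
    return dL_dX, dL_dWq, dL_dWk, dL_dWv
```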

Multi-Head Attention with Layered Architecture

  • Multiple heads: \(\attentionHead_i = \attentionFunction(\queryMatrix_i, \keyMatrix_i, \valueMatrix_i)\)
  • Concatenation: \(\multiHeadOutput = (\attentionHead_1, \ldots, \attentionHead_h) \outputWeightMatrix\)
  • Layered implementation: Each head is an independent \(\attentionLayer\) instance
  • Gradient flow: Gradients computed independently for each head, then combined (see the sketch below)
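
A sketch of multi-head attention built from the single-head functions above; it assumes each head keeps its own weight triple and that head outputs are concatenated along the feature axis before the output projection.

```python
import numpy as np


def multi_head_forward(X, heads, W_o):
    """heads is a list of (W_q, W_k, W_v) triples, one per attention head."""
    outputs, caches = [], []
    for W_q, W_k, W_v in heads:
        O, cache = attention_forward(X, W_q, W_k, W_v)
        outputs.append(O)
        caches.append(cache)
    H = np.concatenate(outputs, axis=1)      # concatenate head outputs
    return H @ W_o, (H, caches)


def multi_head_backward(dL_dY, H, caches, heads, W_o, head_dim):
    """Each head's gradient is computed independently, then summed into X."""
    dL_dH = dL_dY @ W_o.T
    dL_dWo = H.T @ dL_dY
    dL_dX = 0.0
    for i, ((W_q, W_k, W_v), cache) in enumerate(zip(heads, caches)):
        dL_dO = dL_dH[:, i * head_dim:(i + 1) * head_dim]
        dX_i, _, _, _ = attention_backward(dL_dO, cache, W_q, W_k, W_v)
        dL_dX = dL_dX + dX_i                 # contributions from all heads accumulate
    return dL_dX, dL_dWo
```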

Cross-Attention and Mixed Attention

  • Cross-attention: \(\queryMatrix = \queryInput \queryWeightMatrix\), \(\keyMatrix = \keyValueInput \keyWeightMatrix\), \(\valueMatrix = \keyValueInput \valueWeightMatrix\)
  • Mixed attention: \(\queryMatrix = \mathbf{X}\queryWeightMatrix\), \(\keyMatrix = \keyValueInput \keyWeightMatrix\), \(\valueMatrix = \keyValueInput \valueWeightMatrix\)
  • Gradient complexity: Different input sources create separate gradient paths (sketched below)
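
A minimal cross-attention sketch (reusing softmax_rows from earlier): the query input and the key/value input receive separate gradients, and the key/value gradient is itself a two-path sum.

```python
import numpy as np


def cross_attention_forward(X_q, X_kv, W_q, W_k, W_v):
    """Queries come from X_q; keys and values come from X_kv."""
    Q = X_q @ W_q
    K = X_kv @ W_k
    V = X_kv @ W_v
    d_k = Q.shape[1]
    A = softmax_rows(Q @ K.T / np.sqrt(d_k))
    return A @ V, (X_q, X_kv, Q, K, V, A)

# In the backward pass the paths separate:
#   dL/dX_q  = dL/dQ @ W_q.T                     (one path)
#   dL/dX_kv = dL/dK @ W_k.T + dL/dV @ W_v.T     (two paths)
```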

Implementation Considerations

  • Memory efficiency: Cache the intermediates (\(\queryMatrix\), \(\keyMatrix\), \(\valueMatrix\), \(\attentionMatrix\)) needed by the backward pass
  • Numerical stability: Scale the logits by \(1/\sqrt{d_k}\) and use a max-subtracted softmax
  • Parallel computation: Each head can be computed independently
  • Gradient accumulation: Sum gradients across heads
  • Output projection: Additional \(W_o\) matrix requires gradient computation

Summary

  • Three-path chain rule: Input appears in \(Q\), \(K\), \(V\) transformations
  • Softmax gradient: Standard row-wise softmax Jacobian (each row of attention weights sums to one)
  • Multi-head complexity: Each head has independent gradients
  • Cross/mixed attention: Different input sources create separate gradient paths
  • Layered architecture: Modular design with independent gradient computation
  • Output projection: Additional gradient path through \(W_o\) matrix

Verification with Our Implementation

  • Gradient testing: Use \(\finiteDifferenceGradient\) to verify analytical gradients (a minimal check is sketched below)
  • Three-path verification: Check that \(\gradInput\) equals the sum of the contributions flowing back through the \(Q\), \(K\) and \(V\) paths
  • Cross-attention testing: Verify separate gradient paths for different inputs
  • Multi-head testing: Each head tested independently with finite differences
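
A minimal finite-difference check against the sketches above; this uses a hand-rolled central-difference helper for illustration, and the course's \(\finiteDifferenceGradient\) utility may differ in detail.

```python
import numpy as np


def finite_difference_grad(f, X, eps=1e-6):
    """Central-difference estimate of dL/dX for a scalar-valued loss f."""
    grad = np.zeros_like(X)
    for idx in np.ndindex(*X.shape):
        X_plus, X_minus = X.copy(), X.copy()
        X_plus[idx] += eps
        X_minus[idx] -= eps
        grad[idx] = (f(X_plus) - f(X_minus)) / (2 * eps)
    return grad


# Check the analytical input gradient for L = sum of the attention output.
rng = np.random.default_rng(0)
X = rng.standard_normal((4, 8))
W_q, W_k, W_v = (rng.standard_normal((8, 8)) for _ in range(3))

O, cache = attention_forward(X, W_q, W_k, W_v)
dL_dX, _, _, _ = attention_backward(np.ones_like(O), cache, W_q, W_k, W_v)

numerical = finite_difference_grad(
    lambda X_: attention_forward(X_, W_q, W_k, W_v)[0].sum(), X)
assert np.allclose(dL_dX, numerical, atol=1e-5)
```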

Transformer Architecture

Simple Transformer Implementation

Explore Different Sequence Types

Create and Test Basic Attention Layer

Test Multi-Head Attention Layer

Test Chain Rule in Attention Layer

Train Simple Attention Model with Layered Architecture

Visualise Attention Weights

Test Different Numbers of Heads

Test Different Activation Functions

Visualise Different Attention Patterns

Positional Encoding Layer Test

Build Transformer with Layered Architecture

Training Layered Transformer Model

Compare Different Sequence Types

Benefits of the New Layered Architecture

Training and Generalization

Attention as Implicit Regularization

  • Sparse attention: Many attention weights are near zero
  • Implicit regularization: Attention patterns emerge during training
  • Structural constraints: Attention provides architectural bias
  • Empirical scaling: Performance generally improves with model size

Overparameterization in Transformers

  • Large parameter counts: Modern transformers have millions to billions of parameters
  • Generalization through overparameterization: Transformers appear to work well because they are highly overparameterized, not in spite of it
  • Attention as structure: The attention mechanism provides architectural constraints
  • Connection to previous discussion: This extends our overparameterization analysis to transformers

Summary and Future Directions

  • Key insights: Multi-path chain rule, attention as regularization
  • Implementation: Practical considerations for transformer training
  • Theory: Connections to overparameterization and generalization
  • Future: What comes after transformers?

Further Reading

  • Vaswani et al. (2017) “Attention is All You Need”
  • Rogers et al. (2020) “A Primer in BERTology: What We Know About How BERT Works”
  • Elhage et al. (2021) “A Mathematical Framework for Transformer Circuits”
  • The Annotated Transformer
  • Transformer Math 101

Thanks!
