Week 5: Transformers

[jupyter][google colab][reveal]

Neil D. Lawrence

Abstract:

This lecture builds on deep neural networks to explore transformer architectures, focusing on how attention mechanisms require sophisticated chain rule applications and how they connect to the overparameterization and generalization themes.

ML Foundations Course Notebook Setup

[edit]

We install some bespoke code for creating and saving plots as well as loading data sets.

%%capture
%pip install notutils
%pip install git+https://github.com/lawrennd/ods.git
%pip install git+https://github.com/lawrennd/mlai.git
import notutils
import pods
import mlai
import mlai.plot as plot

From Deep Networks to Transformers

[edit]

In our previous lectures, we explored how composing layers of basis functions creates deep neural networks, and we examined the chain rule and automatic differentiation that make training these networks possible. We've also seen how structured data can be handled through convolutional neural networks, graph neural networks, and recurrent networks. Today we'll see how these foundations extend to one of the most important architectural innovations in deep learning: the transformer.

Transformers represent a fundamental shift from the sequential processing of RNNs to parallel attention mechanisms. This creates new challenges for automatic differentiation, as we’ll see.

The Attention Mechanism

The key insight of transformers is the attention mechanism, which allows the model to focus on different parts of the input sequence simultaneously. This creates a more complex gradient flow than standard neural networks.

Chain Rule for Transformer Attention

[edit]

The transformer attention mechanism is more complex than standard neural networks because the same input matrix \(\mathbf{X}\) appears in three different linear transformations. This creates a more intricate chain rule when computing gradients.

This multiple appearance is what allows the transformer to handle variable-length sequences, but it makes the chain rule computation a little more complex than for a standard neural network.
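To fix notation (reconstructing the equations the text refers to, with weight-matrix names matching the code below), the queries, keys and values are all linear transformations of the same input,
\[
\mathbf{Q} = \mathbf{X}\mathbf{W}_q, \qquad \mathbf{K} = \mathbf{X}\mathbf{W}_k, \qquad \mathbf{V} = \mathbf{X}\mathbf{W}_v.
\]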

The attention mechanism computes a weighted combination of values, where the weights are determined by the similarity between queries and keys. The softmax ensures the weights sum to one.
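Putting the pieces together (a standard formulation, stated here with key dimension \(d_k\)),
\[
\mathbf{A} = \text{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d_k}}\right), \qquad \mathbf{H} = \mathbf{A}\mathbf{V},
\]
with the softmax applied row-wise so that each row of \(\mathbf{A}\) sums to one.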

In a standard neural network, we have a single path from input to output. In transformer attention, we have three parallel paths through the same input, making the chain rule more complex.

The gradient flow through attention involves computing how the loss changes with respect to the attention weights and the value matrix.

The gradient with respect to the attention matrix comes from the product with the value matrix. This tells us how much each attention weight should change.
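In the notation above, with attention output \(\mathbf{H} = \mathbf{A}\mathbf{V}\), these two pieces are (standard matrix-calculus identities, restated here for reference)
\[
\frac{\partial L}{\partial \mathbf{A}} = \frac{\partial L}{\partial \mathbf{H}}\mathbf{V}^\top, \qquad \frac{\partial L}{\partial \mathbf{V}} = \mathbf{A}^\top \frac{\partial L}{\partial \mathbf{H}}.
\]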

The gradient through the softmax requires the standard softmax gradient formula, accounting for the fact that attention weights sum to one.
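For a single row \(\mathbf{a} = \text{softmax}(\mathbf{s})\) of attention weights, that formula is
\[
\frac{\partial a_i}{\partial s_j} = a_i\left(\delta_{ij} - a_j\right),
\]
where \(\delta_{ij}\) is the Kronecker delta; the normalisation is what couples every weight in the row.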

The gradients for queries and keys come from their interaction in the attention logits. Each query interacts with all keys, and each key interacts with all queries.
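With scaled logits \(\mathbf{S} = \mathbf{Q}\mathbf{K}^\top/\sqrt{d_k}\), that interaction gives
\[
\frac{\partial L}{\partial \mathbf{Q}} = \frac{1}{\sqrt{d_k}}\frac{\partial L}{\partial \mathbf{S}}\mathbf{K}, \qquad \frac{\partial L}{\partial \mathbf{K}} = \frac{1}{\sqrt{d_k}}\left(\frac{\partial L}{\partial \mathbf{S}}\right)^\top\mathbf{Q}.
\]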

Finally, we combine all three gradient paths to get the gradient with respect to the input matrix. This is the key insight: the same input appears in three different transformations.
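Written out (restating the sum just described, using the same weight matrices as above),
\[
\frac{\partial L}{\partial \mathbf{X}} = \frac{\partial L}{\partial \mathbf{Q}}\mathbf{W}_q^\top + \frac{\partial L}{\partial \mathbf{K}}\mathbf{W}_k^\top + \frac{\partial L}{\partial \mathbf{V}}\mathbf{W}_v^\top.
\]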

The gradients for the weight matrices follow the standard pattern: input matrix transposed times the gradient of the output.
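Spelled out in the same notation (a standard result, included for reference),
\[
\frac{\partial L}{\partial \mathbf{W}_q} = \mathbf{X}^\top\frac{\partial L}{\partial \mathbf{Q}}, \qquad \frac{\partial L}{\partial \mathbf{W}_k} = \mathbf{X}^\top\frac{\partial L}{\partial \mathbf{K}}, \qquad \frac{\partial L}{\partial \mathbf{W}_v} = \mathbf{X}^\top\frac{\partial L}{\partial \mathbf{V}}.
\]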

Multi-head attention adds another layer of complexity. Each head computes its own attention, and the gradients must be computed for each head separately before being combined.
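In one common formulation (the mlai implementation used below may differ in detail), head \(h\) computes its own attention with its own projections before the heads are concatenated and projected,
\[
\mathbf{H}_h = \text{softmax}\!\left(\frac{\mathbf{X}\mathbf{W}_q^{(h)}\left(\mathbf{X}\mathbf{W}_k^{(h)}\right)^\top}{\sqrt{d_k}}\right)\mathbf{X}\mathbf{W}_v^{(h)}, \qquad \mathbf{H} = \left[\mathbf{H}_1, \ldots, \mathbf{H}_{n_h}\right]\mathbf{W}_o,
\]
so in the backward pass the gradient of the concatenation is split back into per-head blocks and each head's chain rule is applied separately before the pieces are summed into \(\partial L/\partial \mathbf{X}\).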

Implementing transformer gradients efficiently requires careful attention to memory usage and numerical stability. The softmax operation can be numerically unstable for large attention scores.
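As a small illustration of that point (a minimal sketch, not the mlai implementation), the usual fix is to subtract the row-wise maximum from the attention scores before exponentiating; the softmax output is unchanged but overflow is avoided:

import numpy as np

def stable_softmax(scores):
    """Row-wise softmax with the maximum subtracted for numerical stability."""
    # Subtracting the max does not change the result because softmax is
    # invariant to adding a constant to every logit in a row.
    shifted = scores - scores.max(axis=-1, keepdims=True)
    exp_scores = np.exp(shifted)
    return exp_scores / exp_scores.sum(axis=-1, keepdims=True)

# Large attention scores that would overflow a naive implementation
scores = np.array([[1000.0, 1001.0, 1002.0]])
print(stable_softmax(scores))  # finite values that sum to 1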

The transformer attention mechanism requires a more sophisticated understanding of the chain rule because the same input participates in multiple parallel computations.

Transformer Architecture

Now we’ll see how to build a complete transformer model, integrating all the components we’ve discussed.

Simple Transformer Implementation

[edit]
import numpy as np

# Create synthetic sequence data: 200 sequences of length 8 over a vocabulary of 30
# (create_synthetic_sequence_data is a helper assumed to be provided alongside this notebook)
X_seq, y_seq = create_synthetic_sequence_data(200, 8, 30)

print(f"Sequence data: {X_seq.shape} -> {y_seq.shape}")
print(f"Sample sequence: {X_seq[0]}")
print(f"Target sequence: {y_seq[0]}")

Create and Test Basic Attention

from mlai import Attention
d_model = 64
n_heads = 4
seq_length = 8
vocab_size = 30

# Create basic attention mechanism
attention = Attention(d_model)

# Test forward pass
X_test = np.random.randn(2, seq_length, d_model)
attn_output, attn_weights = attention.forward(X_test, X_test, X_test)

print("Basic Attention Test:")
print(f"Input shape: {X_test.shape}")
print(f"Output shape: {attn_output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print(f"Attention weights sum (should be 1): {attn_weights.sum(axis=-1)[0, 0]}")
print(f"Model parameters: {attention.W_q.size + attention.W_k.size + attention.W_v.size + attention.W_o.size}")

Test Multi-Head Attention

from mlai import MultiHeadAttention
# Test multi-head attention (built from basic attention)
multi_head_attention = MultiHeadAttention(d_model, n_heads)

# Forward pass
X_test = np.random.randn(2, seq_length, d_model)
attn_output, attn_weights = multi_head_attention.forward(X_test, X_test, X_test)

print("Multi-Head Attention Test:")
print(f"Input shape: {X_test.shape}")
print(f"Output shape: {attn_output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print(f"Attention weights sum (should be 1): {attn_weights.sum(axis=-1)[0, 0]}")
print(f"Number of heads: {n_heads}")

Test Chain Rule in Attention

# Test gradient flow through attention (demonstrating chain rule)
from mlai import MeanSquaredError

X_test = np.random.randn(2, seq_length, d_model)

# Forward pass through attention
output, attn_weights = attention.forward(X_test, X_test, X_test)

# Create dummy loss using proper loss function (consistent with neural network)
target = np.random.randn(2, seq_length, d_model)
loss_fn = MeanSquaredError()
loss_value = loss_fn.forward(output, target)

# Backward pass (demonstrates three-path chain rule)
loss_gradient = loss_fn.gradient(output, target)
gradients = attention.backward(loss_gradient, X_test, X_test, X_test, attn_weights)

print("Chain Rule Demonstration:")
print(f"Loss value: {loss_value:.4f}")
print(f"Input gradient shape: {gradients['grad_input'].shape}")
print(f"Input gradient norm: {np.linalg.norm(gradients['grad_input']):.4f}")
print("This shows how gradients flow through Q, K, V transformations")
print(f"Three-path chain rule: grad_query + grad_key + grad_value = grad_input")
print(f"Gradient verification: {np.allclose(gradients['grad_input'], gradients['grad_query'] + gradients['grad_key'] + gradients['grad_value'])}")

Train Simple Attention Model

# Train a simple attention model on the synthetic sequences
# (train_attention_model is a helper assumed to be provided alongside this notebook)
X_seq, y_seq = create_synthetic_sequence_data(200, 8, 30)
model, losses = train_attention_model(X_seq, y_seq)

Figure: Transformer Training Progress for sequence modeling

Visualise Attention Weights

Figure: Attention weights visualisation from the first head showing which positions the model attends to

Test Different Numbers of Heads

# Test different numbers of heads (showing composition)
d_model = 64
seq_length = 8

for n_heads in [1, 2, 4, 8]:
    multi_head_attention = MultiHeadAttention(d_model, n_heads)
    X_test = np.random.randn(2, seq_length, d_model)
    
    output, attn_weights = multi_head_attention.forward(X_test, X_test, X_test)
    
    print(f"n_heads={n_heads}: output shape={output.shape}, attn shape={attn_weights.shape}")
    print(f"  Each head processes {d_model//n_heads} dimensions")

Test Different Activation Functions

# Compare different activation functions for attention
from mlai import SoftmaxActivation, SigmoidAttentionActivation, IdentityMinusSoftmaxActivation

d_model = 32
seq_length = 4
X_test = np.random.randn(1, seq_length, d_model)

print("Comparing Attention Activation Functions:")
print("=" * 50)

# Standard softmax attention
softmax_attention = Attention(d_model, activation=SoftmaxActivation())
output_softmax, weights_softmax = softmax_attention.forward(X_test, X_test, X_test)

print("1. SoftmaxActivation (Standard):")
print(f"   Weights sum: {weights_softmax.sum(axis=-1)[0, 0]:.6f}")
print(f"   Weights range: [{weights_softmax.min():.6f}, {weights_softmax.max():.6f}]")
print(f"   Attention matrix:\n{weights_softmax[0, :, :]}")

# Sigmoid with normalization
sigmoid_attention = Attention(d_model, activation=SigmoidAttentionActivation())
output_sigmoid, weights_sigmoid = sigmoid_attention.forward(X_test, X_test, X_test)

print("\n2. SigmoidAttentionActivation:")
print(f"   Weights sum: {weights_sigmoid.sum(axis=-1)[0, 0]:.6f}")
print(f"   Weights range: [{weights_sigmoid.min():.6f}, {weights_sigmoid.max():.6f}]")
print(f"   Attention matrix:\n{weights_sigmoid[0, :, :]}")

# Identity minus softmax (interesting alternative)
identity_attention = Attention(d_model, activation=IdentityMinusSoftmaxActivation())
output_identity, weights_identity = identity_attention.forward(X_test, X_test, X_test)

print("\n3. IdentityMinusSoftmaxActivation:")
print(f"   Weights sum: {weights_identity.sum(axis=-1)[0, 0]:.6f}")
print(f"   Weights range: [{weights_identity.min():.6f}, {weights_identity.max():.6f}]")
print(f"   Diagonal entries (1-softmax): {np.diag(weights_identity[0, :, :])}")
print(f"   Off-diagonal entries (-softmax): {weights_identity[0, 0, 1]:.6f}, {weights_identity[0, 1, 0]:.6f}")
print(f"   Attention matrix:\n{weights_identity[0, :, :]}")

print("\nKey Differences:")
print("- Softmax: Standard attention, weights sum to 1, all positive")
print("- Sigmoid: Alternative activation, weights sum to 1, all positive") 
print("- Identity-Minus-Softmax: Diagonal positive (1-softmax), off-diagonal negative (-softmax), sum to 0")
print("  This creates a 'contrast' attention pattern that emphasizes self-connections while")
print("  de-emphasizing connections to other positions.")

Visualise Different Attention Patterns

Figure: Comparison of different attention activation functions showing how they create different attention patterns

# Show the different attention patterns
print("Attention Pattern Analysis:")
print("Softmax: Standard attention, weights sum to 1")
print("Sigmoid: Alternative activation, weights sum to 1") 
print("Identity-Softmax: Contrast attention, weights sum to 0")
print("This demonstrates different attention behaviors!")

The different attention activation functions create fundamentally different behaviors:

Softmax Attention (Standard):

- All weights are positive (0 to 1)
- Each row sums to 1 (probability distribution)
- Represents ‘how much to attend to each position’
- Higher values = more attention

Sigmoid + Normalization:

- Similar to softmax but uses sigmoid activation
- All weights positive, rows sum to 1
- Alternative way to create attention weights

Identity Minus Softmax:

- Diagonal entries: positive (1 - softmax)
- Off-diagonal entries: negative (-softmax)
- Each row sums to 0 (not 1!)
- Creates ‘contrast’ attention:
  - Positive diagonal = ‘attend to self’
  - Negative off-diagonal = ‘de-emphasize others’
- Could be useful for tasks requiring:
  - Self-focus (diagonal emphasis)
  - Contrast learning (positive vs negative weights)
  - Sparse attention patterns

This demonstrates how different activation functions can create fundamentally different attention behaviors, even with the same underlying Q, K, V computation!

Positional Encoding Test

from mlai import PositionalEncoding
# Test positional encoding
pe = PositionalEncoding(d_model, max_length=100)
X_test = np.random.randn(2, seq_length, d_model)

X_with_pe = pe.forward(X_test)

print("Positional Encoding Test:")
print(f"Input shape: {X_test.shape}")
print(f"Output shape: {X_with_pe.shape}")
print(f"PE added: {np.allclose(X_test + pe.pe[:seq_length], X_with_pe)}")

Simple Transformer Model

from mlai import Transformer
from mlai import Model  # base Model class, assumed importable from mlai like the classes above
# Create transformer model using the proper Model class
vocab_size = 30
d_model = 64
n_heads = 4

# Use the Transformer model class (inherits from Model)
transformer = Transformer(d_model=d_model, n_heads=n_heads, vocab_size=vocab_size)

print("Transformer Model Test:")
print(f"Model dimension: {transformer.d_model}")
print(f"Number of heads: {transformer.n_heads}")
print(f"Vocabulary size: {transformer.vocab_size}")
print(f"Is a Model: {isinstance(transformer, Model)}")
print(f"Has predict method: {hasattr(transformer, 'predict')}")
print(f"Has objective method: {hasattr(transformer, 'objective')}")
print(f"Has fit method: {hasattr(transformer, 'fit')}")
# Test forward pass with proper Model interface
X_test = np.random.randint(0, vocab_size, (2, 8))
output = transformer.predict(X_test)

print("Transformer Model Test:")
print(f"Input shape: {X_test.shape}")
print(f"Output shape: {output.shape}")
print(f"Model parameters: {transformer.embedding.size + transformer.output_projection.size + sum(head.W_q.size + head.W_k.size + head.W_v.size + head.W_o.size for head in transformer.attention.attention_heads)}")
print("This model follows the same pattern as NeuralNetwork - it's a proper Model class")

Training Simple Model

# Train the transformer on the synthetic sequences
# (train_transformer_model is a helper assumed to be provided alongside this notebook)
X_seq, y_seq = create_synthetic_sequence_data(200, 8, 30)
transformer_model, transformer_losses = train_transformer_model(X_seq, y_seq, vocab_size=30)

Figure: Simple Transformer Training Progress

Training and Generalization

How do transformers relate to the overparameterization and generalization themes we discussed in the previous lecture?

Attention as Implicit Regularization

The attention mechanism provides a form of implicit regularization. Unlike the explicit regularization we discussed for standard neural networks, attention creates sparse, interpretable patterns that emerge during training.

Overparameterization in Transformers

Transformers generalize well precisely because they are highly overparameterized. This extends our previous discussion of how overparameterization enables generalization through the optimization process, with the attention mechanism providing additional structural constraints.

Summary and Future Directions

Transformers represent a significant evolution in deep learning architectures, but they also raise new questions about optimization, generalization, and the fundamental principles of learning. The attention mechanism provides a new form of inductive bias that we’re still learning to understand theoretically.

Further Reading

Thanks!

For more information on these subjects and more you might want to check the following resources.

References