Tutorial 1: The Logit Lens¶
Paper: interpreting GPT: the logit lens by nostalgebraist (2020)
Difficulty: Beginner | Time: 1-2 hours
Model Choice
The original paper analyzed GPT-2. This tutorial uses Llama-3.2-1B-Instruct to demonstrate the technique on a modern model. The core findings (predictions crystallizing in later layers, early layers predicting related tokens) generalize across transformer architectures, though specific layer numbers may differ. For exact paper reproduction, use a GPT-2 model via InterpretableModel("gpt2").
Learning Objectives¶
By the end of this tutorial, you will:
- Understand the residual stream concept in transformers
- Know how to project intermediate hidden states to token predictions
- Observe how predictions evolve across layers
- Use mlxterp's built-in logit_lens() method
- Interpret logit lens visualizations
Introduction¶
What is the Logit Lens?¶
The logit lens is a technique for peeking inside transformer models to see what they're "thinking" at each layer. The key insight is that hidden states at intermediate layers can be projected through the output embedding matrix (the "unembedding" matrix) to produce token logits (unnormalized scores that can be converted to probabilities via softmax).
Logits vs Probabilities
Throughout this tutorial, "scores" refer to logits (raw model outputs before softmax), not probabilities. A higher logit means the model favors that token more strongly. To get probabilities, apply mx.softmax(logits, axis=-1) over the vocabulary dimension. Softmax is monotonic, so the relative ranking of tokens within a layer is the same either way.
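The relationship between logits and probabilities can be sketched in plain Python, without MLX. The token strings and logit values below are made up for illustration; only the softmax formula itself is standard.

```python
import math

def softmax(logits):
    """Convert raw logits to probabilities (numerically stable form)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for three candidate tokens
tokens = [" Paris", " London", " the"]
logits = [4.0, 2.0, 1.0]
probs = softmax(logits)

# Ranking by logit and ranking by probability give the same order
rank_by_logit = sorted(range(len(logits)), key=lambda i: -logits[i])
rank_by_prob = sorted(range(len(probs)), key=lambda i: -probs[i])

for i in rank_by_prob:
    print(f"{tokens[i]!r}: logit={logits[i]:.1f}, prob={probs[i]:.3f}")
```

Subtracting the maximum before exponentiating does not change the result; it only avoids overflow for large logits.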
Why does this work?¶
Transformers use a residual stream architecture. Each layer reads from and writes to a running sum. In GPT-2 and Llama-style blocks the MLP reads the stream after attention has written to it (the normalization inside each block is omitted here for brevity):
x_0 = embedding(input)
h_1 = x_0 + attention_1(x_0)
x_1 = h_1 + mlp_1(h_1)
...
h_n = x_{n-1} + attention_n(x_{n-1})
x_n = h_n + mlp_n(h_n)
output = unembed(layer_norm(x_n))
The logit lens asks: What if we applied the unembedding at intermediate layers?
This reveals how the model iteratively refines its predictions from input to output.
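To make the idea concrete, here is a toy sketch in plain Python. It is not a real transformer: the 2-dimensional residual stream, the 3-token vocabulary, the unembedding rows, and the per-layer updates are all invented, and the final layer norm is left out. It only shows the mechanics of unembedding the running sum after each layer.

```python
def matvec(matrix, vec):
    """Multiply a (rows x cols) matrix by a length-cols vector."""
    return [sum(w * v for w, v in zip(row, vec)) for row in matrix]

def argmax(values):
    """Index of the largest value."""
    return max(range(len(values)), key=lambda i: values[i])

# Toy setup (all numbers are made up)
vocab = ["cat", "dog", "Paris"]
unembed = [
    [1.0, 0.0],    # row for "cat"
    [0.0, 1.0],    # row for "dog"
    [0.75, 0.75],  # row for "Paris"
]

x = [1.0, 0.0]        # x_0: embedding of the input
layer_updates = [     # what each toy layer writes into the stream
    [0.0, 0.125],
    [0.0, 0.5],
    [0.25, 0.5],
]

top_tokens = []
for i, update in enumerate(layer_updates, start=1):
    x = [a + b for a, b in zip(x, update)]  # residual connection: x_i = x_{i-1} + update
    logits = matvec(unembed, x)             # logit lens: unembed the intermediate state
    top_tokens.append(vocab[argmax(logits)])
    print(f"layer {i}: logits={logits}, top={top_tokens[-1]!r}")
```

In this toy run the top token switches from "cat" to "Paris" at the second layer and stays there, which is the kind of refinement the logit lens is meant to expose.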
Approximation Caveat
The logit lens is an approximation. The final layer norm and unembedding matrix are trained only on the final layer's output, not intermediate layers. This means early layer predictions may be systematically biased. The Tuned Lens addresses this by learning layer-specific corrections.
Key Finding from the Paper¶
The original paper found that:
"The model doesn't seem to be encoding the output distribution in some clever, inscrutable way that gets decoded at the very end. Instead, the residual stream appears to contain something like a 'running probability distribution' over tokens, which each layer refines."
Prerequisites¶
# Install mlxterp if you haven't already
# pip install mlxterp
from mlxterp import InterpretableModel
import mlx.core as mx
Part 1: Manual Implementation¶
Let's first understand the logit lens by implementing it manually.
Step 1: Load a Model¶
from mlxterp import InterpretableModel
# Load a model (any MLX-compatible model works)
model = InterpretableModel("mlx-community/Llama-3.2-1B-Instruct-4bit")
print(f"Model loaded with {len(model.layers)} layers")
Step 2: Get Hidden States at Each Layer¶
# Our test prompt
prompt = "The capital of France is"
# Trace through the model to capture all layer outputs
with model.trace(prompt) as trace:
    pass
# Show what activations we captured
print(f"Captured {len(trace.activations)} activation tensors")
print("\nLayer activation keys:")
for key in sorted(trace.activations.keys()):
    if 'layers' in key and key.endswith(('.self_attn', '.mlp', '.input_layernorm')):
        continue  # Skip component-level activations for now
    shape = trace.activations[key].shape if hasattr(trace.activations[key], 'shape') else 'N/A'
    print(f"  {key}: {shape}")
Step 3: Project to Token Predictions¶
The key operation is projecting hidden states through the unembedding:
def manual_logit_lens(model, hidden_state):
    """
    Manually apply the logit lens to a hidden state.

    Args:
        model: InterpretableModel instance
        hidden_state: Hidden state tensor of shape (hidden_dim,) or (seq_len, hidden_dim)

    Returns:
        List of (token_id, score, token_str) tuples for the top-k tokens,
        where score is a logit, not a probability
    """
    # Step 1: Apply final layer normalization
    # This is critical - the model expects normalized inputs to the output projection
    final_norm = model._module_resolver.get_final_norm()
    normalized = final_norm(hidden_state)
    # Step 2: Get predictions through the output projection (unembedding)
    predictions = model.get_token_predictions(normalized, top_k=5, return_scores=True)
    # Step 3: Convert to readable format
    results = []
    for token_id, score in predictions:
        token_str = model.token_to_str(token_id)
        results.append((token_id, score, token_str))
    return results
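Step 1 above matters because intermediate hidden states can have a very different overall scale from the final layer's output. Llama-family models use RMSNorm as their final norm; the following is a minimal plain-Python sketch of that operation, assuming unit gains and an illustrative eps value (the real module has learned per-dimension weights). The input vector is made up.

```python
import math

def rms_norm(x, weight=None, eps=1e-5):
    """RMSNorm: rescale x to unit root-mean-square, then apply per-dimension gains."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    weight = weight or [1.0] * len(x)  # assumption: unit gains when none are given
    return [w * v / rms for w, v in zip(weight, x)]

hidden = [0.5, -2.0, 1.5, 3.0]        # made-up hidden state
scaled = [10.0 * v for v in hidden]   # same direction, 10x the magnitude

normed = rms_norm(hidden)
normed_scaled = rms_norm(scaled)

# After normalization the two vectors are (almost) identical
max_diff = max(abs(a - b) for a, b in zip(normed, normed_scaled))
rms_out = math.sqrt(sum(v * v for v in normed) / len(normed))
print(f"max difference: {max_diff:.2e}, output RMS: {rms_out:.6f}")
```

Because the norm removes the overall magnitude, the unembedding sees inputs on the scale it was trained with, whichever layer the hidden state came from.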
Step 4: Apply at Each Layer¶
# Get tokens
tokens = model.encode(prompt)
print(f"Input tokens: {[model.token_to_str(t) for t in tokens]}")
# Analyze the last token position (where the model predicts the next token)
print("\n" + "="*60)
print("Logit Lens: Predictions at each layer")
print("="*60)
# Find layer activations
from mlxterp.core.module_resolver import find_layer_key_pattern
for layer_idx in range(0, len(model.layers), 4):  # Every 4th layer
    layer_key = find_layer_key_pattern(trace.activations, layer_idx)
    if layer_key is None:
        continue
    layer_output = trace.activations[layer_key]  # Shape: (1, seq_len, hidden_dim)
    last_token_hidden = layer_output[0, -1, :]   # Shape: (hidden_dim,)
    # Apply logit lens
    predictions = manual_logit_lens(model, last_token_hidden)
    # Show top prediction
    top_token = predictions[0][2]
    top_score = predictions[0][1]
    print(f"Layer {layer_idx:2d}: '{top_token}' (score: {top_score:.3f})")
Part 2: Using the Built-in Method¶
mlxterp provides a built-in logit_lens() method that handles all of this automatically:
# Simple usage
results = model.logit_lens(
    "The capital of France is",
    layers=[0, 4, 8, 12, 15],  # Analyze specific layers
    top_k=3                    # Get top 3 predictions
)
# Print results
print("\nLogit Lens Results (built-in method):")
print("-" * 50)
for layer_idx in sorted(results.keys()):
    layer_preds = results[layer_idx]
    last_pos_preds = layer_preds[-1]  # Last token position
    print(f"\nLayer {layer_idx}:")
    for i, (token_id, score, token_str) in enumerate(last_pos_preds[:3]):
        print(f"  {i+1}. '{token_str}' (score: {score:.4f})")
Visualization¶
The logit_lens() method can generate visualizations:
# Generate a heatmap visualization
results = model.logit_lens(
    "The capital of France is",
    plot=True,              # Enable visualization
    max_display_tokens=10,  # Limit tokens shown
    figsize=(14, 8)
)
This produces a heatmap showing:
- X-axis: Token positions in the input
- Y-axis: Layer index
- Color: Confidence of the top prediction
- Cell text: Top predicted token
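The data behind such a heatmap can be sketched without any plotting. The example below assumes the results structure used earlier in this tutorial, a dict mapping layer index to a per-position list of (token_id, score, token_str) tuples. The mock values and the to_grid helper are hypothetical and not part of mlxterp.

```python
# Mock results; every id, score, and token here is invented for illustration
mock_results = {
    0:  [[(11, 1.2, " the")], [(42, 0.9, " of")]],
    8:  [[(11, 2.5, " the")], [(77, 3.1, " city")]],
    15: [[(23, 6.0, " capital")], [(99, 9.4, " Paris")]],
}

def to_grid(results):
    """Return (layers, grid) where grid[row][col] is the top token at that layer and position."""
    layers = sorted(results)
    grid = [[preds[0][2] for preds in results[layer]] for layer in layers]
    return layers, grid

layers, grid = to_grid(mock_results)

# Print one row per layer; showing the last layer first is an arbitrary choice
for layer, row in zip(reversed(layers), reversed(grid)):
    print(f"L{layer:2d} | " + " | ".join(f"{tok:>9s}" for tok in row))
```

Each printed cell corresponds to the "cell text" of the heatmap; the score in position 1 of each tuple would drive the color.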
Part 3: Reproducing Paper Findings¶
Finding 1: Predictions Crystallize in Later Layers¶
The paper observed that correct predictions often appear suddenly in middle-to-late layers:
# Test with a factual prompt
prompt = "The Eiffel Tower is located in"
results = model.logit_lens(prompt, top_k=1)
print("When does the correct answer ('Paris') appear?")
print("-" * 50)
for layer_idx in sorted(results.keys()):
    top_pred = results[layer_idx][-1][0][2]  # Top prediction at last position
    marker = " <-- CORRECT" if 'Paris' in top_pred else ""
    print(f"Layer {layer_idx:2d}: {top_pred:15s}{marker}")
Finding 2: Early Layers Often Predict Related Tokens¶
# Early layers often predict semantically related but incorrect tokens
prompt = "Barack Obama was the 44th president of the United"
results = model.logit_lens(prompt, top_k=5)
print("\nEarly vs Late Layer Predictions:")
print("=" * 60)
early_layer = min(results.keys())
late_layer = max(results.keys())
print(f"\nEarly Layer ({early_layer}):")
for token_id, score, token_str in results[early_layer][-1][:5]:
    print(f"  '{token_str}': {score:.4f}")
print(f"\nLate Layer ({late_layer}):")
for token_id, score, token_str in results[late_layer][-1][:5]:
    print(f"  '{token_str}': {score:.4f}")
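One way to put a number on "related but different" is to measure how much the early and late top-k lists overlap. The sketch below uses the Jaccard index on token ids. The top_k_overlap helper is hypothetical (not an mlxterp function), and the prediction lists are invented stand-ins for results[early_layer][-1] and results[late_layer][-1].

```python
def top_k_overlap(preds_a, preds_b):
    """Jaccard overlap between the token ids of two top-k prediction lists."""
    ids_a = {token_id for token_id, _, _ in preds_a}
    ids_b = {token_id for token_id, _, _ in preds_b}
    if not ids_a and not ids_b:
        return 1.0  # two empty lists are trivially identical
    return len(ids_a & ids_b) / len(ids_a | ids_b)

# Made-up (token_id, score, token_str) lists for an early and a late layer
early = [(1, 3.0, " Kingdom"), (2, 2.5, " Nations"), (3, 2.0, " States")]
late = [(3, 12.0, " States"), (4, 5.0, " State"), (1, 4.0, " Kingdom")]

overlap = top_k_overlap(early, late)
print(f"Top-k overlap between early and late layer: {overlap:.2f}")
```

A value near 0 means the two layers propose entirely different candidates; a value near 1 means they already agree on the candidate set and differ mainly in ordering.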
Finding 3: Position-Dependent Evolution¶
Different positions evolve differently:
prompt = "The quick brown fox jumps over the lazy"
results = model.logit_lens(prompt, layers=list(range(0, len(model.layers), 2)))
tokens = model.encode(prompt)
print("\nEvolution at different positions:")
print("=" * 60)
# Compare early vs late position
# An early position and the final position. Which word sits at index 2
# depends on the tokenizer (for example, whether a BOS token is prepended).
positions_to_check = [2, len(tokens)-1]
for pos in positions_to_check:
    token_str = model.token_to_str(tokens[pos]) if pos < len(tokens) else "?"
    print(f"\nPosition {pos} ('{token_str}'):")
    for layer_idx in sorted(results.keys())[:4]:
        if pos < len(results[layer_idx]):
            top_pred = results[layer_idx][pos][0][2]
            print(f"  Layer {layer_idx:2d}: '{top_pred}'")
Part 4: Exercises¶
Exercise 1: Compare Different Prompt Types¶
Try the logit lens on different types of prompts:
prompts = [
    # Factual
    "The largest planet in our solar system is",
    # Linguistic
    "She went to the store and bought some",
    # Arithmetic
    "Two plus two equals",
    # Code
    "def hello_world():\n print(",
]
for prompt in prompts:
    print(f"\nPrompt: '{prompt}'")
    results = model.logit_lens(prompt, layers=[0, 8, 15])
    final_pred = results[max(results.keys())][-1][0][2]
    print(f"Final prediction: '{final_pred}'")
Exercise 2: Find the "Crystallization Point"¶
Write a function that finds the layer where the correct answer first appears:
def find_crystallization_layer(model, prompt, expected_token):
    """Find the first layer where the expected token becomes the top prediction."""
    results = model.logit_lens(prompt)
    for layer_idx in sorted(results.keys()):
        top_pred = results[layer_idx][-1][0][2]
        if expected_token.lower() in top_pred.lower():
            return layer_idx
    return None
# Test it
layer = find_crystallization_layer(
    model,
    "The capital of Japan is",
    "Tokyo"
)
print(f"'Tokyo' crystallizes at layer: {layer}")
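A possible extension of this exercise: the first layer where the answer appears is not always the layer where it stays. The sketch below looks for the first layer from which the expected token remains the top prediction through the last layer. It works on a results dict directly so it can run on mock data; first_stable_layer and the mock values are hypothetical, not part of mlxterp.

```python
def first_stable_layer(results, expected_token):
    """First layer from which expected_token is the top last-position prediction
    at every remaining layer. Returns None if the final layer does not predict it."""
    stable_from = None
    for layer_idx in sorted(results):
        top_pred = results[layer_idx][-1][0][2]
        if expected_token.lower() in top_pred.lower():
            if stable_from is None:
                stable_from = layer_idx
        else:
            stable_from = None  # the prediction flipped away, so reset
    return stable_from

# Made-up results: "Tokyo" appears at layer 4, is displaced at 8, and returns at 12
mock_results = {
    0:  [[(1, 1.0, " the")]],
    4:  [[(2, 2.0, " Tokyo")]],
    8:  [[(3, 2.5, " Kyoto")]],
    12: [[(2, 6.0, " Tokyo")]],
    15: [[(2, 9.0, " Tokyo")]],
}

stable_layer = first_stable_layer(mock_results, "Tokyo")
print(f"'Tokyo' is stable from layer: {stable_layer}")
```

On the mock data the first appearance is layer 4 but the stable point is layer 12, so the two definitions of "crystallization" can disagree.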
Exercise 3: Confidence Evolution¶
Plot how the confidence of the correct prediction changes across layers:
import matplotlib.pyplot as plt
def plot_confidence_evolution(model, prompt, target_token):
    """Plot the score (logit) of a specific token across layers."""
    results = model.logit_lens(prompt, top_k=100)  # Get more predictions
    layers = sorted(results.keys())
    confidences = []
    for layer_idx in layers:
        preds = results[layer_idx][-1]
        # Find the score of the target token. Use NaN (a gap in the plot) when it
        # is not in the top 100, because 0.0 is itself a valid logit value.
        conf = float("nan")
        for token_id, score, token_str in preds:
            if target_token.lower() in token_str.lower():
                conf = score
                break
        confidences.append(conf)
    plt.figure(figsize=(10, 5))
    plt.plot(layers, confidences, 'b-o')
    plt.xlabel('Layer')
    plt.ylabel(f'Score (logit) for "{target_token}"')
    plt.title(f'Confidence Evolution: "{prompt}"')
    plt.grid(True, alpha=0.3)
    plt.show()
# Try it
# plot_confidence_evolution(model, "The capital of France is", "Paris")
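Raw scores at different layers are not necessarily on the same scale, so it can also help to track the rank of the target token rather than its logit. The following is a small sketch on mock data; target_rank is a hypothetical helper and the prediction lists are invented.

```python
def target_rank(preds, target_token):
    """1-based rank of target_token within a top-k list, or None if it is absent."""
    for rank, (_, _, token_str) in enumerate(preds, start=1):
        if target_token.lower() in token_str.lower():
            return rank
    return None

# Made-up last-position predictions at three layers, sorted by descending score
mock_results = {
    0:  [[(5, 2.0, " the"), (6, 1.5, " a"), (7, 1.0, " France")]],
    8:  [[(8, 4.0, " London"), (9, 3.5, " Paris"), (5, 1.0, " the")]],
    15: [[(9, 11.0, " Paris"), (8, 6.0, " London"), (5, 2.0, " the")]],
}

ranks = {
    layer: target_rank(mock_results[layer][-1], "Paris")
    for layer in sorted(mock_results)
}
print(ranks)
```

A rank of None means the token is outside the requested top-k at that layer, which is different from it having a low score.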
Summary¶
In this tutorial, you learned:
- The residual stream concept: Transformers maintain a running representation that each layer refines
- How the logit lens works: Project intermediate hidden states through the unembedding matrix
- Key finding: Models iteratively refine predictions, with correct answers often "crystallizing" in middle-to-late layers
- How to use mlxterp: Both manual implementation and the built-in logit_lens() method
Next Steps¶
- Tutorial 2: Tuned Lens - Learn how to improve logit lens predictions with learned probes
- Tutorial 3: Causal Tracing - Use activation patching to localize factual knowledge
References¶
- nostalgebraist (2020). interpreting GPT: the logit lens. LessWrong.
- Elhage et al. (2021). A Mathematical Framework for Transformer Circuits. Anthropic.