Activation Patching Guide¶
Activation patching is a fundamental technique in mechanistic interpretability for identifying which components of a neural network are important for specific tasks.
Overview¶
Goal: Determine which layers or components are critical for a task by:
- Running a "clean" input through the model
- Running a "corrupted" input through the model
- Patching clean activations into the corrupted run at different locations
- Measuring how much this recovers the clean output
Key insight: If patching a component significantly recovers the clean output, that component is important for the task.
Quick Start: Using the Helper Function¶
The easiest way to perform activation patching is with the built-in activation_patching() method:
from mlxterp import InterpretableModel
from mlx_lm import load
# Load model
base_model, tokenizer = load('mlx-community/Llama-3.2-1B-Instruct-4bit')
model = InterpretableModel(base_model, tokenizer=tokenizer)
# Find important layers - that's it!
results = model.activation_patching(
clean_text="Paris is the capital of France",
corrupted_text="London is the capital of France",
component="mlp",
plot=True
)
# Analyze
sorted_results = sorted(results.items(), key=lambda x: x[1], reverse=True)
print("Top 3 most important layers:")
for layer_idx, recovery in sorted_results[:3]:
print(f" Layer {layer_idx}: {recovery:.1f}% recovery")
The helper function handles all the boilerplate:
- Running clean and corrupted inputs
- Patching each layer automatically
- Measuring recovery with L2 distance
- Optional visualization
Continue reading for details on interpretation and manual implementation.
Quick Example¶
import mlx.core as mx
from mlxterp import InterpretableModel, interventions as iv
from mlx_lm import load
# Load model
base_model, tokenizer = load('mlx-community/Llama-3.2-1B-Instruct-4bit')
model = InterpretableModel(base_model, tokenizer=tokenizer)
# Define clean vs corrupted inputs
clean_text = "Paris is the capital of France"
corrupted_text = "London is the capital of France"
# Get baseline outputs
with model.trace(clean_text):
clean_output = model.output.save()
with model.trace(corrupted_text):
corrupted_output = model.output.save()
mx.eval(clean_output, corrupted_output)
# Helper function to measure distance
def l2_distance(a, b):
return float(mx.sqrt(mx.sum((a - b) ** 2)))
baseline = l2_distance(corrupted_output[0, -1], clean_output[0, -1])
print(f"Baseline L2 distance: {baseline:.2f}")
# Patch MLP at layer 10
with model.trace(clean_text) as trace:
clean_mlp = trace.activations["model.model.layers.10.mlp"]
mx.eval(clean_mlp)
with model.trace(corrupted_text,
interventions={"layers.10.mlp": iv.replace_with(clean_mlp)}):
patched_output = model.output.save()
mx.eval(patched_output)
dist = l2_distance(patched_output[0, -1], clean_output[0, -1])
recovery = (baseline - dist) / baseline * 100
print(f"Layer 10 MLP: {recovery:.1f}% recovery")
Complete Procedure¶
Step 1: Define Clean and Corrupted Inputs¶
Choose inputs that differ in exactly the aspect you want to study:
# Factual knowledge task
clean_text = "Paris is the capital of France"
corrupted_text = "London is the capital of France"
# Sentiment task
clean_text = "This movie was amazing"
corrupted_text = "This movie was terrible"
# Grammatical task
clean_text = "The cat sits on the mat"
corrupted_text = "The cat sit on the mat"
Step 2: Get Baseline Measurements¶
# Get clean output
with model.trace(clean_text):
clean_output = model.output.save()
# Get corrupted output
with model.trace(corrupted_text):
corrupted_output = model.output.save()
mx.eval(clean_output, corrupted_output)
# Measure baseline distance
def l2_distance(a, b):
"""L2 (Euclidean) distance between output logits"""
return float(mx.sqrt(mx.sum((a - b) ** 2)))
baseline = l2_distance(corrupted_output[0, -1], clean_output[0, -1])
Step 3: Patch Each Component¶
results = {}
for layer_idx in range(len(model.layers)):
# Get clean activation for this component
with model.trace(clean_text) as trace:
clean_mlp = trace.activations[f"model.model.layers.{layer_idx}.mlp"]
mx.eval(clean_mlp)
# Patch into corrupted run
with model.trace(corrupted_text,
interventions={f"layers.{layer_idx}.mlp": iv.replace_with(clean_mlp)}):
patched_output = model.output.save()
mx.eval(patched_output)
# Measure recovery
dist = l2_distance(patched_output[0, -1], clean_output[0, -1])
recovery = (baseline - dist) / baseline * 100
results[layer_idx] = recovery
print(f"Layer {layer_idx:2d}: {recovery:6.1f}% recovery")
Step 4: Analyze Results¶
# Sort by importance
sorted_results = sorted(results.items(), key=lambda x: x[1], reverse=True)
print("\nMost important layers:")
for layer_idx, recovery in sorted_results[:5]:
print(f" Layer {layer_idx:2d}: {recovery:5.1f}% recovery")
Understanding Results¶
Positive Recovery¶
High positive recovery (e.g., +40%) = Important layer
- Patching this component significantly recovers clean output
- This layer is critical for the task
- Often seen in early layers (feature extraction) and final layers (output formation)
Example: Layer 0 MLP at +43.1% recovery (see Interpreting Example Results below).
Negative Recovery¶
Negative recovery (e.g., -20%) = Encodes corruption
- Patching makes output WORSE than corrupted baseline
- This layer actively encodes the corrupted information
- This is expected and informative!
Example: Layer 10 MLP at -23.5% recovery in the same run; patching it pushes the output further from the clean answer.
Near-Zero Recovery¶
~0% recovery = Not relevant
- Patching has minimal effect
- Layer doesn't significantly contribute to this specific task
- Might be important for other tasks
Components to Patch¶
You can patch different granularities:
Full MLP¶
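Patching the full MLP output follows the same pattern as the Quick Example above:

# Entire MLP block for layer i
interventions={f"layers.{i}.mlp": iv.replace_with(clean_mlp)}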
MLP Sub-components¶
# Gate projection
interventions={f"layers.{i}.mlp.gate_proj": iv.replace_with(clean_gate)}
# Up projection
interventions={f"layers.{i}.mlp.up_proj": iv.replace_with(clean_up)}
# Down projection
interventions={f"layers.{i}.mlp.down_proj": iv.replace_with(clean_down)}
Attention Components¶
# Full attention
interventions={f"layers.{i}.self_attn": iv.replace_with(clean_attn)}
# Query projection
interventions={f"layers.{i}.self_attn.q_proj": iv.replace_with(clean_q)}
# Key projection
interventions={f"layers.{i}.self_attn.k_proj": iv.replace_with(clean_k)}
# Value projection
interventions={f"layers.{i}.self_attn.v_proj": iv.replace_with(clean_v)}
# Output projection
interventions={f"layers.{i}.self_attn.o_proj": iv.replace_with(clean_o)}
Common Pitfalls¶
❌ DON'T: Patch Entire Layer Output¶
# WRONG - patches entire residual stream
with model.trace(corrupted_text,
interventions={"layers.10": iv.replace_with(clean_act)}):
pass
Why wrong: This replaces the entire residual stream at that point, so every downstream layer runs on clean activations. You'll get perfect recovery no matter which layer you patch, which tells you nothing about individual components.
❌ DON'T: Use Lambda Closures¶
# WRONG - lambda closure bug
for i in range(num_layers):
with model.trace(clean_text):
clean_act = model.layers[i].output.save()
# Bug: all interventions use the SAME activation
with model.trace(corrupted_text,
interventions={f"layers.{i}": lambda x: clean_act}):
pass
Why wrong: The lambda captures the variable clean_act, not its current value, so every intervention ends up using whatever clean_act last held, i.e. the final layer's activation.
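If you do need a raw callable, the standard Python workaround is to bind the activation through a default argument, which is evaluated when the lambda is defined rather than when it is called. This is only a sketch to illustrate the closure behavior; iv.replace_with() below remains the recommended fix:

# Default argument binds the value of clean_act at definition time,
# so each lambda keeps its own layer's activation
for i in range(len(model.layers)):
    with model.trace(clean_text) as trace:
        clean_act = trace.activations[f"model.model.layers.{i}.mlp"]
        mx.eval(clean_act)
    with model.trace(corrupted_text,
                     interventions={f"layers.{i}.mlp": lambda x, act=clean_act: act}):
        pass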
✅ DO: Use iv.replace_with()¶
# CORRECT
with model.trace(corrupted_text,
interventions={"layers.10.mlp": iv.replace_with(clean_mlp)}):
pass
✅ DO: Choose the Right Distance Metric¶
The activation_patching() helper supports three distance metrics. Choose based on your model's vocabulary size:
# For small/medium models (< 50k vocab)
results = model.activation_patching(
clean_text="...",
corrupted_text="...",
metric="l2" # Default - Euclidean distance
)
# For large vocabulary models (> 100k vocab)
results = model.activation_patching(
clean_text="...",
corrupted_text="...",
metric="mse" # Most stable for huge models
)
Why: KL divergence can give NaN, and L2 can overflow on large vocabularies. See Distance Metrics section below.
Advanced: Position-Specific Patching¶
Patch activations only at specific token positions:
# Patch only the last token's MLP activation
with model.trace(clean_text) as trace:
    clean_mlp = trace.activations["model.model.layers.10.mlp"]
    mx.eval(clean_mlp)

# Create the patched activation: clean for the last token, corrupted for the others
def selective_patch(corrupted_activation):
    # Keep the corrupted activations at every position except the final token,
    # then splice in the clean activation at that position
    return mx.concatenate(
        [corrupted_activation[:, :-1, :], clean_mlp[:, -1:, :]],
        axis=1,
    )
with model.trace(corrupted_text,
interventions={"layers.10.mlp": selective_patch}):
patched_output = model.output.save()
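Recovery for the position-specific patch is then measured exactly as before, reusing baseline, clean_output, and l2_distance from the earlier steps:

mx.eval(patched_output)
dist = l2_distance(patched_output[0, -1], clean_output[0, -1])
recovery = (baseline - dist) / baseline * 100
print(f"Layer 10 MLP (last token only): {recovery:.1f}% recovery")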
Distance Metrics¶
The activation_patching() helper uses distance metrics to measure how different the outputs are. Choosing the right metric is crucial, especially for large models.
Available Metrics¶
1. L2 Distance (Euclidean) - Default¶
Formula:
d(a, b) = √(Σ (aᵢ - bᵢ)²)
When to use: Small to medium models (vocabulary < 50k tokens)
Implementation:
def l2_distance(a, b):
diff = a - b
# Use float32 for accumulation to prevent overflow
diff_f32 = diff.astype(mx.float32)
squared_sum = mx.sum(diff_f32 * diff_f32)
# Check for overflow
if mx.isinf(squared_sum):
# Fallback to MSE-based calculation
mse = mx.mean(diff_f32 * diff_f32)
return float(mx.sqrt(mse) * mx.sqrt(float(diff.size)))
return float(mx.sqrt(squared_sum))
Why it can fail: With large vocabularies (e.g., 150k tokens), summing 150k squared differences can overflow to inf, especially in float16.
Example:
results = model.activation_patching(
clean_text="Paris is the capital of France",
corrupted_text="London is the capital of France",
metric="l2" # Default
)
2. Cosine Distance¶
Formula:
d(a, b) = 1 - (a · b) / (||a|| × ||b||)
where:
a · b = Σ(aᵢ × bᵢ) # Dot product
||a|| = √(Σ aᵢ²) # L2 norm
When to use: Medium to large models (50k - 150k tokens), or when you want direction-based similarity
Implementation:
def cosine_distance(a, b):
a_f32 = a.astype(mx.float32)
b_f32 = b.astype(mx.float32)
a_norm = mx.sqrt(mx.sum(a_f32 * a_f32))
b_norm = mx.sqrt(mx.sum(b_f32 * b_f32))
if mx.isinf(a_norm) or mx.isinf(b_norm):
# Fallback: normalize by mean instead of sum
a_normalized = a_f32 / mx.sqrt(mx.mean(a_f32 * a_f32))
b_normalized = b_f32 / mx.sqrt(mx.mean(b_f32 * b_f32))
return float(1.0 - mx.mean(a_normalized * b_normalized))
a_normalized = a_f32 / a_norm
b_normalized = b_f32 / b_norm
return float(1.0 - mx.sum(a_normalized * b_normalized))
Why it's better for large models: Normalization prevents overflow by dividing before accumulation.
Example:
results = model.activation_patching(
clean_text="Paris is the capital of France",
corrupted_text="London is the capital of France",
metric="cosine"
)
3. Mean Squared Error (MSE) - Most Stable¶
Formula:
d(a, b) = (1/n) × Σ (aᵢ - bᵢ)²
where:
n = number of elements (the vocabulary size for output logits)
When to use: Very large models (vocabulary > 100k tokens), or when numerical stability is critical
Implementation:
def mse_distance(a, b):
diff = a.astype(mx.float32) - b.astype(mx.float32)
return float(mx.mean(diff * diff))
Why it's most stable: Averages over all elements instead of summing, preventing overflow even with millions of dimensions.
Example:
# Recommended for Qwen (151k vocab), GPT-4 scale models
results = model.activation_patching(
clean_text="Paris is the capital of France",
corrupted_text="London is the capital of France",
metric="mse" # Most stable
)
Metric Selection Guide¶
| Model Characteristics | Recommended Metric | Reason |
|---|---|---|
| Vocab < 50k tokens (e.g., Llama-3.2-1B) | "l2" (default) | Fast, accurate, no overflow risk |
| Vocab 50k - 100k (e.g., Llama-3-8B) | "l2" or "cosine" | L2 with overflow protection works well |
| Vocab > 100k tokens (e.g., Qwen-30B: 151k) | "mse" or "cosine" | Most numerically stable |
| Direction matters (studying vector directions) | "cosine" | Measures angle, not magnitude |
| Magnitude matters (studying activation sizes) | "l2" or "mse" | Measures absolute difference |
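A quick way to check which regime your model falls into is to read the vocabulary size off the output logits themselves (reusing clean_output from Step 2); the thresholds below simply mirror the table above:

# The last dimension of the logits is the vocabulary size
vocab_size = clean_output.shape[-1]
if vocab_size < 50_000:
    metric = "l2"
elif vocab_size <= 100_000:
    metric = "cosine"
else:
    metric = "mse"
print(f"vocab={vocab_size}, using metric={metric!r}")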
Real-World Example: Qwen Model¶
The Qwen3-30B model has a vocabulary of 151,936 tokens. Here's what happens with each metric:
from mlxterp import InterpretableModel
from mlx_lm import load
base_model, tokenizer = load('mlx-community/Qwen3-30B-A3B-Thinking-2507-4bit')
model = InterpretableModel(base_model, tokenizer=tokenizer)
With L2 (without protection):
Output shape: (1, 6, 151936) # 151k logits!
Baseline: inf # Overflow!
Recovery: nan, nan, nan... # All NaN
With MSE ✅:
Baseline: 0.6480
Layer 10: 17.9% recovery ← Works perfectly!
Layer 30: 7.5% recovery
Layer 0: -298.6% recovery
With Cosine ✅:
Baseline: 0.0079
Layer 10: 11.9% recovery ← Also works!
Layer 40: 9.1% recovery
Layer 0: -45.0% recovery
Recovery Calculation¶
Recovery percentage is computed as:
baseline_dist = distance(corrupted_output, clean_output)
patched_dist = distance(patched_output, clean_output)
recovery = (baseline_dist - patched_dist) / baseline_dist * 100
Interpretation:
- High positive %: Patching reduced distance significantly → layer is important
- Negative %: Patching increased distance → layer encodes the corruption
- ~0%: Patching had no effect → layer is not relevant
Why Not KL Divergence?¶
KL divergence is commonly used in research papers, but it has numerical issues:
# KL Divergence (NOT recommended)
def kl_divergence(p, q):
p = mx.softmax(p, axis=-1)
q = mx.softmax(q, axis=-1)
return mx.sum(p * (mx.log(p) - mx.log(q))) # NaN from log(0)!
Problems:
1. log(0) produces -inf
2. Very small probabilities (< 1e-7) cause numerical instability
3. Requires adding epsilon: log(p + ε) - but what epsilon?
4. With 150k vocab, many probabilities are ~0
Better alternatives: L2, MSE, or cosine distance are more robust.
Interpreting Example Results¶
Layer 0 MLP: +43.1% recovery
Layer 2 MLP: +16.7% recovery
Layer 6 MLP: +17.6% recovery
Layer 7 MLP: -18.4% recovery
Layer 10 MLP: -23.5% recovery
Layer 15 MLP: +24.2% recovery
Interpretation:
- Layer 0 (43% recovery): Critical for early feature extraction
- Layers 2, 6 (16.7% and 17.6% recovery): Contribute to the task but are not critical
- Layers 7, 10 (negative): Encode the corruption ("London")
- Layer 15 (24% recovery): Important for final output formation
Insight: The model processes the factual knowledge primarily in early (Layer 0) and late (Layer 15) layers, while middle layers (7-10) encode the specific entity mentioned ("London").
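A natural follow-up is to patch the highest-recovery layers together and check whether their effects compound. The sketch below assumes the interventions dict accepts multiple entries in a single run (consistent with its shape, but worth verifying against your mlxterp version):

# Patch layer 0 and layer 15 MLPs in the same corrupted run
with model.trace(clean_text) as trace:
    clean_mlp_0 = trace.activations["model.model.layers.0.mlp"]
    clean_mlp_15 = trace.activations["model.model.layers.15.mlp"]
    mx.eval(clean_mlp_0, clean_mlp_15)

with model.trace(corrupted_text,
                 interventions={
                     "layers.0.mlp": iv.replace_with(clean_mlp_0),
                     "layers.15.mlp": iv.replace_with(clean_mlp_15),
                 }):
    patched_output = model.output.save()

mx.eval(patched_output)
dist = l2_distance(patched_output[0, -1], clean_output[0, -1])
print(f"Layers 0 + 15 MLP: {(baseline - dist) / baseline * 100:.1f}% recovery")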
Complete Working Example¶
See examples/activation_patching_example.py for a complete, tested implementation.
References¶
- Classic paper: Causal Tracing for GPT-2
- TransformerLens: Similar techniques in PyTorch
- nnsight: Generic activation patching framework