Causal Patching Guide
Overview
Causal patching (also called activation patching or interchange intervention) is one of the most widely used techniques in mechanistic interpretability. It answers the question: which model components are causally responsible for a specific behavior?
The idea: run the model on a clean input and a corrupted input. Then re-run the corrupted input, but patch in the clean activation at a specific component. If the model recovers the clean behavior, that component was causally important.
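The procedure above can be sketched on a toy function with no library at all. This is a minimal illustration, assuming a two-component "model" whose names (`run`, `comp_a`, `comp_b`) are hypothetical and unrelated to mlxterp's internals.

```python
# Toy stand-in for a model: two components write into a running sum
# (a crude analogue of a residual stream). Nothing here is an mlxterp API.

def run(x, patch=None):
    """Run the toy model. `patch` maps a component name to a forced activation."""
    patch = patch or {}
    resid = x
    comp_a = patch.get("comp_a", 2.0 * x)   # depends on the input
    resid = resid + comp_a
    comp_b = patch.get("comp_b", 1.0)       # constant bias, ignores the input
    resid = resid + comp_b
    return resid, {"comp_a": comp_a, "comp_b": comp_b}


clean_out, clean_acts = run(1.0)      # clean input
corrupted_out, _ = run(0.0)           # corrupted input

effects = {}
for name in ("comp_a", "comp_b"):
    # Re-run the corrupted input with one clean activation patched in.
    patched_out, _ = run(0.0, patch={name: clean_acts[name]})
    # Fraction of the clean-vs-corrupted gap that the patch closes.
    effects[name] = (patched_out - corrupted_out) / (clean_out - corrupted_out)

print(effects)
```

In this toy, `comp_a` closes most of the gap and `comp_b` closes none of it, because `comp_b` is identical in both runs. `comp_a` does not recover everything, since the input also reaches the output directly.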
Quick Start

from mlxterp import InterpretableModel
from mlxterp.causal import activation_patching

model = InterpretableModel("mlx-community/Llama-3.2-1B-Instruct-4bit")

# Which layers' MLPs are important for factual recall?
result = activation_patching(
    model,
    clean="The Eiffel Tower is in Paris",
    corrupted="The Colosseum is in Paris",
    component="mlp",
    metric="l2",
)
result.plot()  # Bar chart of per-layer effects
print(result.top_components(k=5))
Components You Can Patch

| Component | Description | Result Shape |
|---|---|---|
| "resid_post" | Full layer output (residual stream) | (n_layers,) |
| "attn" | Attention module output | (n_layers,) |
| "mlp" | MLP/feed-forward output | (n_layers,) |
| "attn_head" | Individual attention heads | (n_layers, n_heads) |
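To make the (n_layers, n_heads) result shape concrete, here is a small sketch that patches one head at a time in a toy setting. It assumes each head emits a single number and the output is their sum; `head_effects` and the values are hypothetical, not how mlxterp computes head effects.

```python
# Toy setup: each head emits one number, and the model output is the sum of
# all head outputs. Hypothetical values, not taken from a real model.
clean_heads = [[1.0, 0.0, 2.0], [0.5, 0.5, 0.0]]      # (n_layers=2, n_heads=3)
corrupted_heads = [[0.0, 0.0, 2.0], [0.5, 0.0, 0.0]]


def output(heads):
    """Sum every head output across every layer."""
    return sum(sum(layer) for layer in heads)


def head_effects(clean, corrupted):
    """Patch one (layer, head) at a time and record the recovered fraction."""
    base, target = output(corrupted), output(clean)
    matrix = []
    for layer in range(len(corrupted)):
        row = []
        for head in range(len(corrupted[layer])):
            patched = [list(values) for values in corrupted]  # copy per layer
            patched[layer][head] = clean[layer][head]
            row.append((output(patched) - base) / (target - base))
        matrix.append(row)
    return matrix


effect_matrix = head_effects(clean_heads, corrupted_heads)
for layer, row in enumerate(effect_matrix):
    print(layer, [round(v, 3) for v in row])
```

Each row corresponds to a layer and each column to a head, which is the layout a layer-by-head heatmap would display.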
Choosing a Metric

| Metric | Best For | Notes |
|---|---|---|
| "l2" | General purpose | Normalized recovery, 0-1 scale |
| "logit_diff" | IOI, factual recall | Requires correct_token and incorrect_token |
| "kl" | Distribution comparison | Negative KL (higher = better) |
| "cosine" | Direction-sensitive tasks | Good for large vocabularies |
| "ce_diff" | Loss-based evaluation | Positive = patching reduced loss |
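The table does not spell out the formulas, so the following is only one plausible reading of "normalized recovery" and "negative KL", written in plain Python. The function names and exact normalization are assumptions, not mlxterp's confirmed definitions.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def l2_recovery(clean, corrupted, patched):
    """1.0 when patched output equals clean, 0.0 when it equals corrupted (assumed form)."""
    baseline = math.dist(corrupted, clean)
    return 1.0 - math.dist(patched, clean) / baseline


def neg_kl(clean, patched):
    """Negative KL(clean || patched) over softmaxed logits; 0 is the best score."""
    p, q = softmax(clean), softmax(patched)
    return -sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


clean_logits = [2.0, 0.5, -1.0]
corrupted_logits = [-1.0, 0.5, 2.0]
halfway = [(c + k) / 2 for c, k in zip(clean_logits, corrupted_logits)]

print(l2_recovery(clean_logits, corrupted_logits, halfway))
print(neg_kl(clean_logits, halfway))
```

Under these assumed definitions, both scores increase as the patched output moves toward the clean output, which matches the "higher = better" note in the table.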
Complete Example: Factual Recall

from mlxterp import InterpretableModel
from mlxterp.causal import activation_patching

model = InterpretableModel("mlx-community/Llama-3.2-1B-Instruct-4bit")
clean = "The Eiffel Tower is in"
corrupted = "The Colosseum is in"

# Step 1: Find important layers (MLP)
mlp_result = activation_patching(
    model, clean, corrupted,
    component="mlp",
    metric="l2",
)

# Step 2: Find important layers (attention)
attn_result = activation_patching(
    model, clean, corrupted,
    component="attn",
    metric="l2",
)

# Step 3: Zoom into attention heads around the most important MLP layer
top_layer = mlp_result.top_components(k=1)[0][0]
print(f"Most important MLP layer: {top_layer}")

# Drop negative indices when top_layer is 0; if top_layer is the last layer,
# also drop top_layer + 1 so it stays within the model's layer count.
neighbors = [i for i in (top_layer - 1, top_layer, top_layer + 1) if i >= 0]
head_result = activation_patching(
    model, clean, corrupted,
    component="attn_head",
    metric="l2",
    layers=neighbors,
)
head_result.plot()  # Heatmap: layer x head
Position-Level Patching
Identify which token positions carry the critical information:

result = activation_patching(
    model, clean, corrupted,
    component="mlp",
    positions=[3, 4, 5],  # Only patch these positions
)
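Conceptually, position-level patching swaps in clean activations only at the chosen token indices and leaves the rest of the corrupted run untouched. The sketch below shows that idea on plain lists; `patch_positions` is a hypothetical helper, and the length check reflects the general assumption that clean and corrupted prompts tokenize to the same number of positions.

```python
def patch_positions(corrupted_acts, clean_acts, positions):
    """Return corrupted activations with the clean value swapped in at `positions`."""
    if len(corrupted_acts) != len(clean_acts):
        raise ValueError("clean and corrupted runs must have the same number of positions")
    patched = list(corrupted_acts)  # copy so the corrupted run is left intact
    for pos in positions:
        patched[pos] = clean_acts[pos]
    return patched


# One activation per token position (scalars here for readability).
clean_acts = [0.1, 0.9, 0.8, 0.7, 0.2, 0.3]
corrupted_acts = [0.1, 0.0, 0.0, 0.0, 0.2, 0.3]

print(patch_positions(corrupted_acts, clean_acts, positions=[1, 2]))
```

Sweeping `positions` one index at a time is a common way to localize where in the prompt the relevant information lives.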
Using CausalTrace for Multi-Patch Experiments
When you want to patch multiple components simultaneously:

with model.causal_trace(clean, corrupted) as ct:
    # Patch MLP at layers 5-7 and attention at layer 9
    ct.patch("layers.5.mlp")
    ct.patch("layers.6.mlp")
    ct.patch("layers.7.mlp")
    ct.patch("layers.9.self_attn")
    # All patches applied at once
    effect = ct.metric("l2")

print(f"Combined effect: {effect:.4f}")
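One reason to patch several components together is that effects need not add up. The toy below, with hypothetical names and an output that needs both components at once, shows single patches recovering nothing while the joint patch recovers everything.

```python
def toy_output(a, b):
    """The behavior only appears when both components are active."""
    return a * b


clean = {"a": 1.0, "b": 1.0}
corrupted = {"a": 0.0, "b": 0.0}


def recovery(patched_names):
    """Patch the named components with clean values and measure the recovered fraction."""
    acts = dict(corrupted)
    for name in patched_names:
        acts[name] = clean[name]
    base = toy_output(**corrupted)
    target = toy_output(**clean)
    return (toy_output(**acts) - base) / (target - base)


print("a alone:", recovery(["a"]))
print("b alone:", recovery(["b"]))
print("a and b:", recovery(["a", "b"]))
```

If per-component sweeps show small effects everywhere, a joint patch over a candidate set can reveal components that only matter in combination.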
Using logit_diff Metric
For tasks with a clear correct/incorrect answer:

# Take the last token id in case the tokenizer prepends special tokens
# or splits the word into several pieces.
correct_token = model.tokenizer.encode(" Paris")[-1]
incorrect_token = model.tokenizer.encode(" Rome")[-1]

result = activation_patching(
    model, clean, corrupted,
    component="mlp",
    metric="logit_diff",
    metric_kwargs={
        "correct_token": correct_token,
        "incorrect_token": incorrect_token,
    },
)
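For intuition, a logit-difference score can be sketched in a few lines. The normalization below (0 at the corrupted baseline, 1 at the clean baseline) is a common convention and an assumption here, not a confirmed description of mlxterp's "logit_diff" metric; the function names and token ids are hypothetical.

```python
def logit_diff(logits, correct_token, incorrect_token):
    """Logit of the correct token minus logit of the incorrect token."""
    return logits[correct_token] - logits[incorrect_token]


def normalized_logit_diff(clean, corrupted, patched, correct_token, incorrect_token):
    """Rescale so the corrupted run scores 0.0 and the clean run scores 1.0."""
    clean_ld = logit_diff(clean, correct_token, incorrect_token)
    corrupted_ld = logit_diff(corrupted, correct_token, incorrect_token)
    patched_ld = logit_diff(patched, correct_token, incorrect_token)
    return (patched_ld - corrupted_ld) / (clean_ld - corrupted_ld)


# Final-position logits over a tiny 3-token vocabulary (made-up numbers).
clean_logits = [0.0, 3.0, 1.0]
corrupted_logits = [0.0, 1.0, 3.0]
patched_logits = [0.0, 2.0, 2.0]

score = normalized_logit_diff(
    clean_logits, corrupted_logits, patched_logits,
    correct_token=1, incorrect_token=2,
)
print(score)
```

A score near 1 means the patch restored the model's preference for the correct token; a negative score means the patch pushed it further toward the incorrect one.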
Working with Results
Every patching function returns a PatchingResult:

# Summary
print(result.summary())

# Top components
for layer, effect in result.top_components(k=5):
    print(f"  Layer {layer}: {effect:.4f}")

# JSON export (for programmatic use)
json_data = result.to_json()

# Markdown report
print(result.to_markdown())

# Raw effect matrix
print(result.effect_matrix)           # mx.array
print(result.effect_matrix.tolist())  # Python list
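Once the effect matrix is a plain Python list, custom ranking is straightforward. The helper below is a hypothetical sketch (`top_k` is not part of mlxterp) that handles both a per-layer list and a layer-by-head nested list.

```python
def top_k(effects, k=5):
    """Rank entries of a 1D (layers) or 2D (layers x heads) nested list by effect."""
    flat = []
    for i, row in enumerate(effects):
        if isinstance(row, list):
            # 2D case: index each entry by (layer, head).
            flat.extend(((i, j), value) for j, value in enumerate(row))
        else:
            # 1D case: index each entry by layer.
            flat.append((i, row))
    return sorted(flat, key=lambda item: item[1], reverse=True)[:k]


per_layer = [0.1, 0.7, -0.2]
per_head = [[0.0, 0.9], [0.4, 0.1]]

print(top_k(per_layer, k=2))
print(top_k(per_head, k=2))
```

Sorting by raw value puts strongly negative effects last; sorting by `abs(item[1])` instead would surface components whose patching hurts the metric.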
Interpreting Results
- High positive effect: This component is important. Patching the clean activation into the corrupted run recovered the clean behavior.
- Near-zero effect: This component doesn't contribute to the difference between clean and corrupted.
- Negative effect: Patching this component made things worse (rare, but indicates interference).
Rule of Thumb
Start with component="mlp" and component="attn" at all layers to get a broad picture. Then zoom into specific layers with component="attn_head" for head-level analysis.