mlxterp API Documentation¶
Complete API reference for mlxterp.
Table of Contents¶
- InterpretableModel
- Tracing
- Interventions
- Utilities
- Core Components
- Metrics
- Analysis Results
- Causal Interpretability
- Text Generation
- Conversation Analysis
- Visualization
- Research Workflows
- AutoInterp Ratchet Loop
- Automated Interpretability
- Report Generation
InterpretableModel¶
Class: InterpretableModel¶
Main entry point for wrapping MLX models to add interpretability features.
Constructor¶
InterpretableModel(
model: Union[nn.Module, str],
tokenizer: Optional[Any] = None,
layer_attr: str = "layers",
embedding_path: Optional[str] = None,
norm_path: Optional[str] = None,
lm_head_path: Optional[str] = None
)
Parameters:
- model (nn.Module or str): Either:
  - An MLX nn.Module instance to wrap
  - A model name/path string (attempts to load via mlx_lm.load())
- tokenizer (Optional[Any]): Tokenizer for processing text inputs. If None and model is loaded from string, attempts to load tokenizer automatically.
- layer_attr (str, default: "layers"): Name of the attribute containing the model's transformer layers. Common values:
  - "layers" (Llama, Mistral)
  - "h" (GPT-2)
  - "transformer.h" (some GPT variants)
- embedding_path (Optional[str], default: None): Override path for the token embedding layer. Used for weight-tied output projection in get_token_predictions. Auto-detected if not specified. Tried paths: model.embed_tokens, model.model.embed_tokens, embed_tokens, tok_embeddings, wte.
- norm_path (Optional[str], default: None): Override path for the final layer normalization. Used in logit_lens for projecting intermediate activations. Auto-detected if not specified. Tried paths: model.norm, model.model.norm, norm, ln_f, model.ln_f.
- lm_head_path (Optional[str], default: None): Override path for the output projection layer. If not found, falls back to weight-tied embedding. Tried paths: lm_head, model.lm_head, model.model.lm_head, output, head.
Returns: InterpretableModel instance
Example:
from mlxterp import InterpretableModel
# Load from model name (auto-detection works)
model = InterpretableModel("mlx-community/Llama-3.2-1B-Instruct")
# Wrap existing model
import mlx.nn as nn
base_model = nn.Module() # Your model
model = InterpretableModel(base_model, tokenizer=my_tokenizer)
# Custom layer attribute
model = InterpretableModel(gpt2_model, layer_attr="h")
# Custom model with non-standard attribute names
model = InterpretableModel(
custom_model,
tokenizer=my_tokenizer,
embedding_path="my_custom_embeddings",
norm_path="my_final_norm",
lm_head_path="my_output_projection"
)
Method: trace¶
Create a tracing context for capturing activations and applying interventions.
model.trace(
inputs: Union[str, List[str], mx.array, List[int]],
interventions: Optional[Dict[str, Callable]] = None
) -> Trace
Parameters:
- inputs: Input data in one of the following formats:
  - str: Single text prompt (requires tokenizer)
  - List[str]: Batch of text prompts (requires tokenizer)
  - mx.array: Token array, shape (batch, seq_len)
  - List[int]: Single sequence of token IDs
- interventions (Optional[Dict[str, Callable]]): Dictionary mapping module names to intervention functions. Module names use dot notation (e.g., "layers.3.self_attn" for mlx-lm models).
Returns: Trace context manager
Example:
# Basic tracing
with model.trace("Hello world"):
    output = model.output.save()
# With interventions
from mlxterp import interventions as iv
with model.trace("Test", interventions={"layers.3": iv.scale(0.5)}):
    output = model.output.save()
# Batch inputs
with model.trace(["Hello", "World"]):
    acts = model.layers[3].output.save()
Method: encode¶
Encode text to token IDs using the model's tokenizer.
Parameters:
- text (str): Text string to encode
Returns: List of token IDs
Raises: ValueError if no tokenizer is available
Example:
Method: decode¶
Decode token IDs to text using the model's tokenizer.
Parameters:
- tokens (List[int] or mx.array): Token IDs to decode
Returns: Decoded text string
Raises: ValueError if no tokenizer is available
Example:
text = model.decode([128000, 9906, 1917])
print(text) # "<|begin_of_text|>Hello world"
# Also works with mx.array
import mlx.core as mx
tokens_array = mx.array([128000, 9906, 1917])
text = model.decode(tokens_array)
Method: encode_batch¶
Encode multiple texts to token IDs.
Parameters:
- texts (List[str]): List of text strings to encode
Returns: List of token ID lists
Raises: ValueError if no tokenizer is available
Example:
token_lists = model.encode_batch(["Hello", "World", "Test"])
print(token_lists)
# [[128000, 9906], [128000, 10343], [128000, 2323]]
Method: token_to_str¶
Convert a single token ID to its string representation.
Parameters:
- token_id (int): Token ID to decode
Returns: String representation of the token
Raises: ValueError if no tokenizer is available
Example:
# Decode individual tokens
tokens = model.encode("Hello world")
for i, token_id in enumerate(tokens):
    token_str = model.token_to_str(token_id)
    print(f"Token {i}: {token_id} -> '{token_str}'")
# Output:
# Token 0: 128000 -> '<|begin_of_text|>'
# Token 1: 9906 -> 'Hello'
# Token 2: 1917 -> ' world'
Property: vocab_size¶
Get the vocabulary size of the tokenizer.
Returns: Vocabulary size, or None if no tokenizer is available
Example:
Attribute: tokenizer¶
Direct access to the underlying tokenizer for advanced operations.
Type: Tokenizer object (varies by model)
Example:
# Access tokenizer directly for advanced features
tokenizer = model.tokenizer
# Use tokenizer-specific methods
if hasattr(tokenizer, 'special_tokens'):
    print(tokenizer.special_tokens)
Method: get_token_predictions¶
Decode hidden states to token predictions using the model's output projection.
model.get_token_predictions(
hidden_state: mx.array,
top_k: int = 10,
return_scores: bool = False,
embedding_layer: Optional[Any] = None,
lm_head: Optional[Any] = None
) -> Union[List[int], List[tuple]]
Parameters:
- hidden_state (mx.array): Hidden state tensor, shape (hidden_dim,) or (batch, hidden_dim)
- top_k (int, default: 10): Number of top predictions to return
- return_scores (bool, default: False): If True, return (token_id, score) tuples
- embedding_layer (Optional[Any], default: None): Override embedding layer for weight-tied projection. If provided, uses this layer's weights for output projection.
- lm_head (Optional[Any], default: None): Override lm_head layer. If provided, uses this layer directly. Takes precedence over embedding_layer.
Returns: List of token IDs or (token_id, score) tuples
Example:
# Get predictions from a specific layer
with model.trace("The capital of France is") as trace:
    layer_6 = trace.activations["model.model.layers.6"]
# Get last token's hidden state
last_token_hidden = layer_6[0, -1, :]
# Get top predictions
predictions = model.get_token_predictions(last_token_hidden, top_k=5)
# Decode to words
for token_id in predictions:
    print(model.token_to_str(token_id))
# With scores
predictions_with_scores = model.get_token_predictions(
    last_token_hidden,
    top_k=5,
    return_scores=True
)
for token_id, score in predictions_with_scores:
    token_str = model.token_to_str(token_id)
    print(f"{token_str}: {score:.2f}")
# Custom model with override at call time
predictions = model.get_token_predictions(
    hidden,
    top_k=5,
    lm_head=custom_model.my_lm_head
)
Notes:
- Automatically handles weight-tied models (uses embedding weights transposed)
- Works with quantized embeddings (dequantizes automatically)
- Model-agnostic: auto-detects embedding/lm_head paths for various architectures
- Useful for analyzing what any layer "thinks" at a specific position
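The weight-tied projection mentioned in the notes can be illustrated without MLX. The following is a minimal sketch, not the library's implementation: the helper name top_k_predictions and the toy numbers are hypothetical, and it assumes the score of each token is the dot product between its embedding row and the hidden state.

```python
from typing import List, Tuple

def top_k_predictions(
    hidden: List[float],
    embedding: List[List[float]],  # shape (vocab_size, hidden_dim)
    top_k: int = 10,
) -> List[Tuple[int, float]]:
    """Hypothetical helper: score tokens with the transposed embedding matrix."""
    # Weight tying: logits[token] = dot(embedding[token], hidden)
    logits = [sum(w * h for w, h in zip(row, hidden)) for row in embedding]
    # Rank token IDs by score, highest first, and keep the first top_k
    ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    return [(i, logits[i]) for i in ranked[:top_k]]

# Toy vocabulary of 4 tokens with hidden_dim = 3 (made-up numbers)
toy_embedding = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [1.0, 1.0, 0.0],
]
toy_hidden = [0.5, 2.0, -1.0]
toy_top2 = top_k_predictions(toy_hidden, toy_embedding, top_k=2)
print(toy_top2)
```

Dropping the score from each tuple gives the plain list of token IDs that corresponds to return_scores=False.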
Method: logit_lens¶
Apply logit lens technique to see what each layer predicts at each token position.
The logit lens projects each layer's hidden states through the final layer norm and embedding matrix to see what tokens each layer predicts at each position in the input sequence.
model.logit_lens(
text: str,
top_k: int = 1,
layers: Optional[List[int]] = None,
position: Optional[int] = None,
plot: bool = False,
max_display_tokens: int = 15,
figsize: tuple = (16, 10),
cmap: str = 'viridis',
font_family: Optional[str] = None,
final_norm: Optional[Any] = None,
skip_norm: bool = False
) -> Dict[int, List[List[tuple]]]
Parameters:
- text (str): Input text to analyze
- top_k (int, default: 1): Number of top predictions to return per position
- layers (Optional[List[int]], default: None): Specific layers to analyze (None = all)
- position (Optional[int], default: None): Specific position to analyze (None = all). Supports negative indexing (-1 = last position).
- plot (bool, default: False): If True, display a heatmap visualization showing predictions
- max_display_tokens (int, default: 15): Maximum number of tokens to show in visualization (from the end)
- figsize (tuple, default: (16, 10)): Figure size for plot (width, height)
- cmap (str, default: 'viridis'): Colormap for heatmap
- font_family (Optional[str], default: None): Font for plot (use 'Arial Unicode MS' for CJK support)
- final_norm (Optional[Any], default: None): Override final layer normalization module. If provided, uses this module instead of auto-detected norm layer.
- skip_norm (bool, default: False): If True, skip final layer normalization entirely. Useful for models without a final norm layer.
Returns: Dict mapping layer_idx -> list of positions -> list of (token_id, score, token_str) tuples
Structure: {layer_idx: [[pos_0_predictions], [pos_1_predictions], ...]}
Model Compatibility: This method automatically detects model structure and works with:
- mlx-lm models (model.model.norm, model.model.embed_tokens)
- GPT-2 style models (ln_f, wte)
- Custom models (use norm_path constructor arg or final_norm/skip_norm parameters)
Example:
# Get predictions at all positions for all layers
results = model.logit_lens("The capital of France is")
# Access predictions for layer 10 at position 3
layer_10_predictions = results[10]
pos_3_top_pred = layer_10_predictions[3][0] # (token_id, score, token_str)
print(f"Layer 10, Position 3: {pos_3_top_pred[2]}")
# Show what each layer predicts at the LAST position
text = "The capital of France is"
results = model.logit_lens(text, layers=[0, 5, 10, 15])
for layer_idx in [0, 5, 10, 15]:
    # Get prediction at last position
    last_pos_pred = results[layer_idx][-1][0][2]
    print(f"Layer {layer_idx}: '{last_pos_pred}'")
# Output:
# Layer 0: ' the'
# Layer 5: ' a'
# Layer 10: ' Paris'
# Layer 15: ' Paris'
# Show predictions at each position for a specific layer
text = "The capital of France"
results = model.logit_lens(text, layers=[10])
tokens = model.encode(text)
for pos_idx, predictions in enumerate(results[10]):
    input_token = model.token_to_str(tokens[pos_idx])
    pred_token = predictions[0][2]  # Top prediction
    print(f"Position {pos_idx} ('{input_token}') -> '{pred_token}'")
# Visualize with heatmap
results = model.logit_lens(
"The Eiffel Tower is located in the city of",
plot=True,
max_display_tokens=15,
figsize=(16, 10)
)
# Displays a heatmap with:
# - X-axis: Input token positions
# - Y-axis: Model layers
# - Cell values: Top predicted token at each (layer, position)
# - Colors: Different predictions shown with different colors
# Model without final normalization
results = model.logit_lens("Hello world", skip_norm=True)
# Custom final norm at call time
results = model.logit_lens(
"Hello world",
final_norm=custom_model.my_final_norm
)
Note: Plotting requires matplotlib: pip install matplotlib
Use Cases:
- Understand how model predictions evolve through layers
- Debug model behavior at intermediate layers
- Visualize progressive refinement of predictions
- Identify where in the model certain facts are computed
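The normalize-then-project step behind the logit lens can be sketched in plain Python. This is an illustration under stated assumptions rather than the library's code: rms_norm here is a weightless RMS normalization (real final norm layers carry learned parameters), logit_lens_sketch is a hypothetical name, and the hidden states are made-up numbers for a single position.

```python
import math
from typing import Dict, List

def rms_norm(x: List[float], eps: float = 1e-6) -> List[float]:
    """Weightless RMS normalization (assumption: learned scale omitted)."""
    scale = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / scale for v in x]

def logit_lens_sketch(
    layer_hiddens: Dict[int, List[float]],  # layer_idx -> hidden state at one position
    unembed: List[List[float]],             # shape (vocab_size, hidden_dim)
    skip_norm: bool = False,
) -> Dict[int, int]:
    """Return the top token ID that each layer's hidden state projects to."""
    result = {}
    for layer_idx, hidden in layer_hiddens.items():
        h = hidden if skip_norm else rms_norm(hidden)
        logits = [sum(w * v for w, v in zip(row, h)) for row in unembed]
        result[layer_idx] = max(range(len(logits)), key=logits.__getitem__)
    return result

# Two-token vocabulary, hidden_dim = 2 (toy values)
toy_unembed = [[1.0, 0.0], [0.0, 1.0]]
toy_hiddens = {0: [3.0, 1.0], 1: [0.5, 4.0]}
toy_lens = logit_lens_sketch(toy_hiddens, toy_unembed)
print(toy_lens)
```

In this toy case the early layer favors token 0 and the later layer favors token 1, which is the kind of layer-by-layer shift the heatmap is meant to expose.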
Method: tuned_lens¶
Apply tuned lens for improved layer-wise predictions.
The tuned lens technique (Belrose et al., 2023) uses learned affine transformations for each layer to correct for coordinate system mismatches between layers, producing more accurate intermediate predictions than the standard logit lens.
model.tuned_lens(
text: str,
tuned_lens: TunedLens,
top_k: int = 1,
layers: Optional[List[int]] = None,
position: Optional[int] = None,
plot: bool = False,
max_display_tokens: int = 15,
figsize: tuple = (16, 10),
cmap: str = 'viridis',
font_family: Optional[str] = None,
final_norm: Any = None,
skip_norm: bool = False
) -> Dict[int, List[List[tuple]]]
Parameters:
- text (str): Input text to analyze
- tuned_lens (TunedLens): Trained TunedLens instance with layer translators
- top_k (int, default: 1): Number of top predictions to return per position
- layers (Optional[List[int]], default: None): Specific layers to analyze (None = all)
- position (Optional[int], default: None): Specific position to analyze (None = all). Supports negative indexing.
- plot (bool, default: False): If True, display a heatmap visualization
- max_display_tokens (int, default: 15): Maximum number of tokens to show in visualization
- figsize (tuple, default: (16, 10)): Figure size for plot
- cmap (str, default: 'viridis'): Colormap for heatmap
- font_family (Optional[str]): Font for plot (auto-detected if None)
- final_norm (Any, default: None): Override for final layer norm. Pass a callable to use a custom norm.
- skip_norm (bool, default: False): If True, skip final layer normalization (for models without it)
Returns: Dict mapping layer_idx -> list of positions -> list of (token_id, score, token_str) tuples
Example:
from mlxterp import InterpretableModel, TunedLens
model = InterpretableModel("mlx-community/Llama-3.2-1B-Instruct")
# Option 1: Train new tuned lens
tuned_lens = model.train_tuned_lens(
dataset=["Sample text 1", "Sample text 2", ...],
num_steps=250,
save_path="tuned_lens_llama.npz"
)
# Option 2: Load pre-trained tuned lens
tuned_lens = model.load_tuned_lens("tuned_lens_llama.npz")
# Apply tuned lens
results = model.tuned_lens(
"The capital of France is",
tuned_lens,
layers=[0, 5, 10, 15],
plot=True
)
# Compare with regular logit lens
regular_results = model.logit_lens("The capital of France is", layers=[0, 5, 10, 15])
Reference: Belrose et al., "Eliciting Latent Predictions from Transformers with the Tuned Lens" (https://arxiv.org/abs/2303.08112)
Method: train_tuned_lens¶
Train a tuned lens for this model.
The tuned lens technique trains small affine transformations for each layer to correct for coordinate system mismatches, producing more accurate intermediate predictions.
model.train_tuned_lens(
dataset: List[str],
num_steps: int = 250,
learning_rate: float = 1.0,
momentum: float = 0.9,
max_seq_len: int = 2048,
gradient_clip: float = 1.0,
save_path: Optional[str] = None,
verbose: bool = True,
callback: Optional[Callable[[int, float], None]] = None
) -> TunedLens
Parameters:
- dataset (List[str]): List of text strings for training
- num_steps (int, default: 250): Number of training steps
- learning_rate (float, default: 1.0): Initial learning rate (uses linear decay)
- momentum (float, default: 0.9): Nesterov momentum coefficient
- max_seq_len (int, default: 2048): Maximum sequence length for training chunks
- gradient_clip (float, default: 1.0): Gradient clipping norm
- save_path (Optional[str]): Path to save trained weights
- verbose (bool, default: True): Print training progress
- callback (Optional[Callable]): Callback function called with (step, loss) after each step
Returns: Trained TunedLens instance
Raises:
- ValueError: If dataset is empty or contains only whitespace
- ValueError: If dataset has fewer tokens than max_seq_len
- ValueError: If max_seq_len is less than 10 tokens
- ValueError: If model hidden dimension cannot be determined
Training Details (from paper):
- Optimizer: SGD with Nesterov momentum (0.9)
- Learning rate: 1.0 with linear decay over training steps
- Gradient clipping: norm 1.0
- Loss: KL divergence between translator prediction and final output
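Two of the training details above, the linear learning-rate decay and the KL-divergence loss, can be sketched in plain Python. The function names are hypothetical, and two points are assumptions rather than documented behavior: that the rate decays all the way to zero at num_steps, and that the divergence is taken as KL(final || translator).

```python
import math
from typing import List

def linear_decay_lr(step: int, num_steps: int = 250, base_lr: float = 1.0) -> float:
    """Learning rate at a given step (assumption: decays linearly to 0)."""
    return base_lr * (1.0 - step / num_steps)

def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p: List[float], q: List[float], eps: float = 1e-12) -> float:
    """KL(p || q) for two discrete distributions of equal length."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)

final_probs = softmax([2.0, 1.0, 0.1])       # stands in for the model's final output
translator_probs = softmax([0.1, 1.0, 2.0])  # stands in for a translator's prediction
loss = kl_divergence(final_probs, translator_probs)
print(linear_decay_lr(125), loss)
```

The loss is zero only when the translator's distribution matches the final output exactly, which is the target the translators are trained toward.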
Example:
# Load sample texts for training
texts = [
"The capital of France is Paris.",
"Machine learning is a subset of artificial intelligence.",
# ... more training texts
]
# Train tuned lens
tuned_lens = model.train_tuned_lens(
dataset=texts,
num_steps=250,
save_path="my_tuned_lens.npz",
verbose=True
)
Method: load_tuned_lens¶
Load a pre-trained tuned lens from a file.
Parameters:
- path (str): Path to the saved tuned lens weights (expects .npz and .json files)
Returns: Loaded TunedLens instance
Example:
tuned_lens = model.load_tuned_lens("tuned_lens_llama.npz")
results = model.tuned_lens("Hello world", tuned_lens)
Method: activation_patching¶
Automated activation patching to identify important layers for a task.
This helper method performs activation patching across all (or specified) layers to determine which components are critical for a specific task. It automates the boilerplate of running clean/corrupted inputs, patching activations, and measuring recovery.
model.activation_patching(
clean_text: str,
corrupted_text: str,
component: str = "mlp",
layers: Optional[List[int]] = None,
metric: str = "l2",
plot: bool = False,
figsize: tuple = (12, 8),
cmap: str = "RdBu_r"
) -> Dict[int, float]
Parameters:
- clean_text (str): Clean/correct input text
- corrupted_text (str): Corrupted input text (differs in the aspect you're studying)
- component (str, default: "mlp"): Component to patch. Options:
  - "mlp" - Full MLP block
  - "self_attn" - Full attention block
  - "mlp.gate_proj" - MLP gate projection
  - "mlp.up_proj" - MLP up projection
  - "mlp.down_proj" - MLP down projection
  - "self_attn.q_proj" - Query projection
  - "self_attn.k_proj" - Key projection
  - "self_attn.v_proj" - Value projection
  - "self_attn.o_proj" - Output projection
- layers (Optional[List[int]], default: None): Specific layers to test (None = all layers)
- metric (str, default: "l2"): Distance metric. Options:
  - "l2": Euclidean distance (default, with overflow protection)
  - "cosine": Cosine distance (recommended for large vocabularies)
  - "mse": Mean squared error (most stable for huge models > 100k vocab)

  Recommendation:
  - Vocab < 50k: use "l2"
  - Vocab 50k-100k: use "l2" or "cosine"
  - Vocab > 100k: use "mse" or "cosine"
- plot (bool, default: False): If True, display a bar chart of recovery percentages
- figsize (tuple, default: (12, 8)): Figure size for plot
- cmap (str, default: "RdBu_r"): Colormap for plot (blue = positive, red = negative)
Returns: Dict mapping layer_idx -> recovery percentage
Recovery Interpretation:
- Positive % (e.g., +40%): Layer is important for the task
- Negative % (e.g., -20%): Layer encodes the corruption
- ~0%: Layer is not relevant to this task
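The signs in the interpretation above follow from how a recovery percentage is typically computed. The formula below is an assumption for illustration, since this reference does not state the exact expression the method uses: it compares the distance between the clean and corrupted outputs (the baseline) with the distance that remains after patching one layer. The function name is hypothetical.

```python
def recovery_percent(baseline_distance: float, patched_distance: float) -> float:
    """Hypothetical recovery formula (assumed, not taken from the library).

    baseline_distance: distance(clean output, corrupted output)
    patched_distance:  distance(clean output, output after patching one layer)
    """
    if baseline_distance == 0:
        raise ValueError("clean and corrupted outputs are identical")
    return 100.0 * (baseline_distance - patched_distance) / baseline_distance

print(recovery_percent(10.0, 6.0))   # patching moved the output closer to clean
print(recovery_percent(10.0, 12.0))  # patching moved the output further away
print(recovery_percent(10.0, 10.0))  # patching had no effect
```

Under this definition a positive value means the patch closed part of the gap to the clean output, a negative value means it widened the gap, and zero means the layer made no difference.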
Example:
# Find which MLP layers are important for factual recall
results = model.activation_patching(
clean_text="Paris is the capital of France",
corrupted_text="London is the capital of France",
component="mlp",
plot=True
)
# Analyze results
sorted_results = sorted(results.items(), key=lambda x: x[1], reverse=True)
print("\nMost important layers:")
for layer_idx, recovery in sorted_results[:3]:
    print(f" Layer {layer_idx}: {recovery:.1f}% recovery")
# Output:
# Layer 0: +43.1% recovery ← Very important!
# Layer 15: +24.2% recovery ← Important
# Layer 6: +17.6% recovery ← Somewhat important
# Test specific layers only
results = model.activation_patching(
clean_text="Paris is the capital of France",
corrupted_text="London is the capital of France",
component="self_attn",
layers=[0, 5, 10, 15],
plot=False
)
# Test MLP sub-components
results = model.activation_patching(
clean_text="The cat sits on the mat",
corrupted_text="The cat sit on the mat",
component="mlp.gate_proj",
layers=[3, 8, 12]
)
# For large vocabulary models (> 100k tokens), use MSE metric
results = model.activation_patching(
clean_text="Paris is the capital of France",
corrupted_text="London is the capital of France",
component="mlp",
metric="mse", # Most stable for Qwen, GPT-4 scale models
plot=True
)
Note: Plotting requires matplotlib: pip install matplotlib
Distance Metrics:
The method supports three distance metrics for measuring output differences:
- L2 (Euclidean) - Default, best for models with < 50k vocab
- Cosine - Best for direction-based similarity, good for 50k-150k vocab
- MSE (Mean Squared Error) - Most stable for very large models (> 100k vocab)
Example with large vocabulary model (Qwen: 151k tokens):
# Without correct metric - gets NaN
results = model.activation_patching(..., metric="l2") # ❌ May overflow
# With correct metric - works perfectly
results = model.activation_patching(..., metric="mse") # ✅ Stable
See Also: Activation Patching Guide for comprehensive coverage including metric selection and numerical details
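For reference, the three metrics have the following textbook definitions. This is a plain-Python sketch with hypothetical function names, not the library's implementation, and it omits the overflow protection mentioned for "l2". Inputs are assumed to be non-zero vectors of equal length.

```python
import math
from typing import List

def l2_distance(a: List[float], b: List[float]) -> float:
    """Euclidean distance: square root of the summed squared differences."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def cosine_distance(a: List[float], b: List[float]) -> float:
    """1 - cosine similarity; depends on direction only, not magnitude."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)

def mse(a: List[float], b: List[float]) -> float:
    """Mean squared error: squared differences averaged over the dimension."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

print(l2_distance([0.0, 0.0], [3.0, 4.0]))
print(cosine_distance([1.0, 0.0], [0.0, 1.0]))
print(mse([0.0, 0.0], [3.0, 4.0]))
```

MSE divides the squared differences by the vector length, so its value does not grow with vocabulary size the way an unnormalized sum does, which is one plausible reason it behaves better on very wide output vectors.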
Method: named_modules¶
Iterator over all modules in the wrapped model with their names.
Returns: Iterator yielding (name, module) tuples
Example:
Attribute: layers¶
Indexed access to model layers via LayerListProxy.
Type: LayerListProxy
Example:
# Access specific layer
layer_3 = model.layers[3]
# Iterate over layers
for layer in model.layers:
    print(layer)
# Get number of layers
num_layers = len(model.layers)
Attribute: output¶
Access to model output within a trace context. Returns an OutputProxy that can be saved.
Type: OutputProxy
Example:
Tracing¶
Class: Trace¶
Context manager for tracing model execution. Created by InterpretableModel.trace().
Attributes¶
- output (mx.array): Model output after trace completes
- saved_values (Dict[str, mx.array]): Values saved with .save()
- activations (Dict[str, mx.array]): All captured activations by module name
Methods¶
get(name: str) -> Optional[mx.array]¶
Get a saved value by name.
with model.trace(input) as trace:
    model.layers[3].output.save()
activation = trace.get("layers.3.output")
get_activation(name: str) -> Optional[mx.array]¶
Get an activation by module name.
with model.trace(input) as trace:
    pass  # Activations captured automatically
# Note: get_activation requires the full key (not normalized)
attn_act = trace.get_activation("model.model.layers.3.self_attn")
Class: OutputProxy¶
Wraps module outputs to provide .save() functionality.
Method: save() -> Any¶
Save the wrapped value to the current trace context and return the unwrapped value.
Returns: The underlying value (usually mx.array)
Example:
with model.trace(input):
    # Save returns the actual array (use self_attn for mlx-lm models)
    attn = model.layers[3].self_attn.output.save()
    print(attn.shape)  # Can use immediately
Interventions¶
Intervention functions modify activations during forward passes.
Module: mlxterp.interventions¶
Namespace containing pre-built intervention functions.
from mlxterp import interventions as iv
# Use intervention functions
with model.trace(input, interventions={"layers.3": iv.scale(0.5)}):
    output = model.output.save()
Function: zero_out¶
Set activations to zero.
Example:
Function: scale¶
Multiply activations by a constant factor.
Parameters:
- factor (float): Scaling factor
Example:
# Reduce by 50%
with model.trace(input, interventions={"layers.3": iv.scale(0.5)}):
    output = model.output.save()
# Amplify
with model.trace(input, interventions={"layers.3": iv.scale(2.0)}):
    output = model.output.save()
Function: add_vector¶
Add a steering vector to activations.
Parameters:
- vector (mx.array): Vector to add (must be broadcastable to activation shape)
Example:
import mlx.core as mx
# Create steering vector
steering = mx.random.normal((hidden_dim,))
with model.trace(input, interventions={"layers.5": iv.add_vector(steering)}):
    steered_output = model.output.save()
Function: replace_with¶
Replace activations with a fixed value.
Parameters:
- value: Replacement value (array or scalar)
Example:
# Replace with zeros
with model.trace(input, interventions={"layers.3": iv.replace_with(0.0)}):
    output = model.output.save()
# Replace with custom array
custom = mx.ones((batch, seq_len, hidden_dim))
with model.trace(input, interventions={"layers.3": iv.replace_with(custom)}):
    output = model.output.save()
Function: clamp¶
Clamp activation values to a range.
Parameters:
- min_val (Optional[float]): Minimum value
- max_val (Optional[float]): Maximum value
Example:
# Clamp to [-1, 1]
with model.trace(input, interventions={"layers.3": iv.clamp(-1.0, 1.0)}):
    output = model.output.save()
# Only maximum
with model.trace(input, interventions={"layers.3": iv.clamp(max_val=10.0)}):
    output = model.output.save()
Function: noise¶
Add Gaussian noise to activations.
Parameters:
- std (float, default: 0.1): Standard deviation of noise
Example:
with model.trace(input, interventions={"layers.3": iv.noise(std=0.2)}):
    output = model.output.save()
Class: InterventionComposer¶
Compose multiple interventions into a single function.
Method: add¶
Add an intervention to the composition.
Returns: self for chaining
Method: build¶
Build the composed intervention function.
Returns: Composed intervention function
Example:
from mlxterp import interventions as iv
# Compose multiple interventions
combined = iv.compose() \
.add(iv.scale(0.8)) \
.add(iv.noise(std=0.1)) \
.add(iv.clamp(-5.0, 5.0)) \
.build()
with model.trace(input, interventions={"layers.3": combined}):
    output = model.output.save()
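The add/build chaining pattern can be sketched with plain Python lists standing in for arrays. ComposerSketch, scale, and clamp below are hypothetical stand-ins written for this illustration, and applying the steps in the order they were added is an assumption about how the composed function behaves.

```python
from typing import Callable, List, Optional

Vector = List[float]
Intervention = Callable[[Vector], Vector]

def scale(factor: float) -> Intervention:
    """Return a function that multiplies every value by factor."""
    return lambda act: [v * factor for v in act]

def clamp(min_val: Optional[float] = None, max_val: Optional[float] = None) -> Intervention:
    """Return a function that limits values to [min_val, max_val]."""
    def apply(act: Vector) -> Vector:
        out = []
        for v in act:
            if min_val is not None:
                v = max(v, min_val)
            if max_val is not None:
                v = min(v, max_val)
            out.append(v)
        return out
    return apply

class ComposerSketch:
    """Hypothetical stand-in for a builder that chains interventions."""
    def __init__(self) -> None:
        self._steps: List[Intervention] = []

    def add(self, fn: Intervention) -> "ComposerSketch":
        self._steps.append(fn)
        return self  # returning self enables chaining

    def build(self) -> Intervention:
        steps = list(self._steps)  # snapshot so later add() calls do not leak in
        def composed(act: Vector) -> Vector:
            for fn in steps:  # assumption: applied in insertion order
                act = fn(act)
            return act
        return composed

combined_sketch = ComposerSketch().add(scale(2.0)).add(clamp(-1.0, 1.0)).build()
print(combined_sketch([0.25, 3.0, -4.0]))
```

Order matters here: scaling before clamping bounds the final values, whereas clamping first and scaling afterwards would let values leave the clamp range again.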
Custom Interventions¶
Create your own intervention functions:
import mlx.core as mx
def my_intervention(activation: mx.array) -> mx.array:
    """Custom activation modification"""
    # Your logic here
    return mx.tanh(activation)
with model.trace(input, interventions={"layers.3": my_intervention}):
    output = model.output.save()
Requirements:
- Function signature: (mx.array) -> mx.array
- Must return array with same shape as input
- Can use any MLX operations
Utilities¶
Function: get_activations¶
Collect activations for specified layers and token positions.
get_activations(
model: InterpretableModel,
prompts: Union[str, List[str]],
layers: Optional[List[int]] = None,
positions: Union[int, List[int]] = -1
) -> Dict[str, mx.array]
Parameters:
- model: InterpretableModel instance
- prompts: Single prompt or list of prompts
- layers: Layer indices to collect (None = all layers)
- positions: Token position(s) to extract
  - -1: Last token
  - 0: First token
  - [0, -1]: First and last tokens
Returns: Dict mapping "layer_{i}" to activation arrays
Shapes:
- Single position: (batch_size, hidden_dim)
- Multiple positions: (batch_size, num_positions, hidden_dim)
Example:
from mlxterp import get_activations
# Single prompt, multiple layers
acts = get_activations(model, "Hello world", layers=[3, 8, 12])
print(acts["layer_3"].shape) # (1, hidden_dim)
# Batch prompts
acts = get_activations(
model,
["Hello", "World", "Test"],
layers=[5],
positions=-1
)
print(acts["layer_5"].shape) # (3, hidden_dim)
# Multiple positions
acts = get_activations(
model,
"Test prompt",
layers=[3],
positions=[0, -1] # First and last token
)
print(acts["layer_3"].shape) # (1, 2, hidden_dim)
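The relationship between the positions argument and the output shape can be shown with nested lists in place of arrays. select_positions is a hypothetical helper written for this illustration; it only mirrors the documented shapes and makes no claim about how the library slices internally.

```python
from typing import List, Union

def select_positions(acts: List[List[List[float]]], positions: Union[int, List[int]]):
    """acts has shape (batch, seq_len, hidden_dim) as nested lists."""
    if isinstance(positions, int):
        # Single position: the sequence axis is dropped -> (batch, hidden_dim)
        return [seq[positions] for seq in acts]
    # Several positions: the sequence axis becomes num_positions
    return [[seq[p] for p in positions] for seq in acts]

# One prompt, three tokens, hidden_dim = 2 (toy values)
toy_acts = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]
last_only = select_positions(toy_acts, -1)
first_and_last = select_positions(toy_acts, [0, -1])
print(last_only, first_and_last)
```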
Function: batch_get_activations¶
Memory-efficient batch processing for large datasets.
batch_get_activations(
model: InterpretableModel,
prompts: List[str],
layers: Optional[List[int]] = None,
positions: Union[int, List[int]] = -1,
batch_size: int = 8
) -> Dict[str, mx.array]
Parameters:
- model: InterpretableModel instance
- prompts: List of prompts
- layers: Layer indices to collect
- positions: Token position(s) to extract
- batch_size: Number of prompts per batch
Returns: Dict mapping "layer_{i}" to concatenated activation arrays
Example:
from mlxterp import batch_get_activations
# Process 1000 prompts efficiently
large_dataset = [f"Prompt {i}" for i in range(1000)]
acts = batch_get_activations(
model,
prompts=large_dataset,
layers=[3, 8, 12],
batch_size=32
)
print(acts["layer_3"].shape) # (1000, hidden_dim)
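The batching strategy amounts to slicing the prompt list into chunks of batch_size, processing each chunk, and concatenating the per-chunk results. The sketch below uses hypothetical helper names and a stand-in per-batch function (string lengths instead of activations) so that it runs without a model.

```python
from typing import Callable, Iterator, List

def batched(items: List[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

def process_in_batches(
    prompts: List[str],
    per_batch: Callable[[List[str]], List[int]],
    batch_size: int = 8,
) -> List[int]:
    """Run per_batch on each chunk and concatenate results in prompt order."""
    results: List[int] = []
    for chunk in batched(prompts, batch_size):
        results.extend(per_batch(chunk))
    return results

chunk_sizes = [len(c) for c in batched([f"Prompt {i}" for i in range(10)], 4)]
lengths = process_in_batches(["a", "bb", "ccc"], lambda c: [len(p) for p in c], batch_size=2)
print(chunk_sizes, lengths)
```

Only one chunk's intermediate data needs to be held at a time, which is the source of the memory savings; the last chunk may be smaller than batch_size.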
Function: collect_activations¶
Direct activation collection with caching.
collect_activations(
model: InterpretableModel,
inputs: Any,
layers: Optional[List[str]] = None
) -> ActivationCache
Parameters:
- model: InterpretableModel instance
- inputs: Input data
- layers: List of layer names to cache (None = all)
Returns: ActivationCache object
Example:
from mlxterp import collect_activations
cache = collect_activations(
model,
"Test input",
layers=["layers.3", "layers.8"]
)
# Access cached activations
act_3 = cache.get("layers.3")
print(f"Cached {len(cache)} activations")
print(f"Available keys: {cache.keys()}")
Core Components¶
Advanced usage: Direct access to core components.
Class: ModuleProxy¶
Wraps nn.Module to intercept forward passes. Created automatically by InterpretableModel.
Attributes:
- output: OutputProxy for the module's output
Example:
# Access through InterpretableModel.layers
proxy = model.layers[3] # Returns ModuleProxy
print(type(proxy)) # ModuleProxy
with model.trace(input):
    act = proxy.output.save()
Class: LayerListProxy¶
Provides indexed access to model layers.
Methods:
- __getitem__(idx): Get layer at index
- __len__(): Number of layers
- __iter__(): Iterate over layers
Example:
# Created automatically
layers = model.layers
# Access
layer_3 = layers[3]
# Length
print(len(layers)) # 12
# Iterate
for i, layer in enumerate(layers):
    print(f"Layer {i}: {layer}")
Class: ActivationCache¶
Storage for cached activations.
Attributes:
- activations (Dict[str, mx.array]): Activation storage
- metadata (Optional[Dict]): Additional information
Methods:
- get(name): Get activation by name
- keys(): List all cached names
- __contains__(name): Check if activation exists
- __len__(): Number of cached activations
Example:
from mlxterp import collect_activations
cache = collect_activations(model, input)
# Access
act = cache.get("layers.3")
# Check existence
if "layers.3" in cache:
    print("Found!")
# List all
for name in cache.keys():
    print(name)
Class: ModuleResolver¶
Generic module resolution for different MLX model architectures. Automatically finds embedding, final norm, and lm_head modules using fallback chains.
Constructor:
ModuleResolver(
model: nn.Module,
embedding_path: Optional[str] = None,
norm_path: Optional[str] = None,
lm_head_path: Optional[str] = None
)
Methods:
- get_embedding_layer(): Get token embedding layer
- get_final_norm(): Get final layer normalization
- get_lm_head(): Get output projection layer
- get_output_projection(): Get output projection with weight-tied detection. Returns (module, path, is_weight_tied)
- clear_cache(): Clear resolved module cache (call after modifying model structure)
Fallback Chains:
| Component | Resolution Order |
|---|---|
| Embedding | model.embed_tokens, model.model.embed_tokens, embed_tokens, tok_embeddings, wte |
| Final Norm | model.norm, model.model.norm, norm, ln_f, model.ln_f |
| LM Head | lm_head, model.lm_head, model.model.lm_head, output, head (falls back to embedding if not found) |
Example:
from mlxterp.core import ModuleResolver
# Create resolver for a model
resolver = ModuleResolver(model)
# Get components
embedding = resolver.get_embedding_layer()
norm = resolver.get_final_norm()
proj, path, is_tied = resolver.get_output_projection()
if is_tied:
    print("Model uses weight-tied embedding for output")
# With custom paths
resolver = ModuleResolver(
model,
embedding_path="my_embed",
norm_path="my_norm"
)
# Cache invalidation (after modifying model structure)
resolver.clear_cache() # Force re-resolution on next access
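A fallback chain of this kind can be sketched as a dotted-path lookup that returns the first candidate that resolves. The helper names are hypothetical and SimpleNamespace objects stand in for real modules. One assumption to note: in this sketch an explicit override is the only path tried, and whether the library falls back to auto-detection when an override fails is not stated here.

```python
from types import SimpleNamespace
from typing import Any, List, Optional, Tuple

def resolve_path(root: Any, dotted_path: str) -> Optional[Any]:
    """Follow a dotted attribute path; return None if any part is missing."""
    obj = root
    for part in dotted_path.split("."):
        if not hasattr(obj, part):
            return None
        obj = getattr(obj, part)
    return obj

def resolve_first(
    root: Any, candidates: List[str], override: Optional[str] = None
) -> Tuple[Optional[Any], Optional[str]]:
    """Return (module, path) for the first path that resolves."""
    paths = [override] if override else list(candidates)
    for path in paths:
        found = resolve_path(root, path)
        if found is not None:
            return found, path
    return None, None

# Candidate order copied from the Final Norm row of the table above
NORM_CANDIDATES = ["model.norm", "model.model.norm", "norm", "ln_f", "model.ln_f"]

# Toy objects standing in for a GPT-2 style model and an mlx-lm style model
gpt2_like = SimpleNamespace(ln_f="gpt2-final-norm", wte="gpt2-embeddings")
llama_like = SimpleNamespace(model=SimpleNamespace(norm="llama-final-norm"))

print(resolve_first(gpt2_like, NORM_CANDIDATES))
print(resolve_first(llama_like, NORM_CANDIDATES))
```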
Class: TunedLens¶
Learned affine translators for each layer, implementing the Tuned Lens technique from Belrose et al. (2023).
The tuned lens uses layer-specific affine transformations (Wx + b) to map hidden states from each layer into a space where the final output projection can make accurate predictions. This corrects for coordinate system mismatches between layers.
Parameters:
- num_layers (int): Number of transformer layers in the model
- hidden_dim (int): Dimension of hidden states
Attributes:
- num_layers: Number of layers
- hidden_dim: Hidden dimension
- translators: List of linear layers, one per transformer layer
Methods:
| Method | Description |
|---|---|
| __call__(h, layer_idx) | Apply translator for a specific layer |
| save(path) | Save weights and config to files (.npz and .json) |
| load(path) | Load tuned lens from saved files (classmethod) |
Example:
from mlxterp import TunedLens
# Create tuned lens
tuned_lens = TunedLens(num_layers=32, hidden_dim=4096)
# Apply to hidden state from layer 10
translated = tuned_lens(hidden_state, layer_idx=10)
# Save and load
tuned_lens.save("my_tuned_lens")
loaded = TunedLens.load("my_tuned_lens")
Reference: Belrose et al., "Eliciting Latent Predictions from Transformers with the Tuned Lens" (https://arxiv.org/abs/2303.08112)
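The affine map described above (Wx + b) can be illustrated without MLX. The following is a minimal sketch in plain Python, not the library's implementation: `affine_translate` and the 3-dimensional values are hypothetical and exist only to show what one per-layer translator computes.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def affine_translate(h: Vector, W: Matrix, b: Vector) -> Vector:
    """Return W @ h + b for a single hidden-state vector (hypothetical helper)."""
    return [
        sum(w_ij * h_j for w_ij, h_j in zip(row, h)) + b_i
        for row, b_i in zip(W, b)
    ]


# Hypothetical 3-dimensional hidden state
h = [0.5, -1.0, 2.0]

# An identity translator with zero bias leaves the hidden state unchanged
identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
unchanged = affine_translate(h, identity, [0.0, 0.0, 0.0])

# A translator that rescales the first coordinate and shifts the last one
W = [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
b = [0.0, 0.0, 1.0]
translated = affine_translate(h, W, b)
print(translated)
```

In the tuned lens, one such (W, b) pair is learned per layer, and the translated state is then passed to the final output projection.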
Function: normalize_layer_key¶
Normalize activation keys by removing model prefixes.
Example:
from mlxterp.core import normalize_layer_key
normalize_layer_key("model.model.layers.0") # "layers.0"
normalize_layer_key("model.layers.5.self_attn") # "layers.5.self_attn"
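The normalization shown above can be pictured as stripping leading `model.` segments from a dotted key. This is a hedged sketch of that idea only: `strip_model_prefix` is a hypothetical stand-in, and the real `normalize_layer_key` may handle additional prefixes.

```python
def strip_model_prefix(key: str) -> str:
    """Drop every leading 'model' segment from a dotted activation key.

    Hypothetical re-implementation for illustration; not the library code.
    """
    parts = key.split(".")
    # Keep at least one segment so a bare "model" key is not emptied
    while len(parts) > 1 and parts[0] == "model":
        parts.pop(0)
    return ".".join(parts)


print(strip_model_prefix("model.model.layers.0"))
print(strip_model_prefix("model.layers.5.self_attn"))
print(strip_model_prefix("layers.3"))
```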
Function: find_layer_key_pattern¶
Find the correct activation key pattern for a layer index.
find_layer_key_pattern(
activations: dict,
layer_idx: int,
component: Optional[str] = None
) -> Optional[str]
Example:
from mlxterp.core import find_layer_key_pattern
# Find layer 5's key in activations dict
key = find_layer_key_pattern(trace.activations, 5)
# Returns "model.model.layers.5" or "layers.5" etc.
# Find specific component
key = find_layer_key_pattern(trace.activations, 5, "self_attn")
Type Annotations¶
Common types used in mlxterp:
from typing import Union, List, Dict, Callable, Optional, Any
import mlx.core as mx
import mlx.nn as nn
# Input types
InputType = Union[str, List[str], mx.array, List[int]]
# Intervention function type
InterventionFn = Callable[[mx.array], mx.array]
# Interventions dict
InterventionsDict = Dict[str, InterventionFn]
# Model type
ModelType = Union[nn.Module, str]
Error Handling¶
Common Exceptions¶
ValueError¶
Raised when:
- String input provided without tokenizer
- Invalid input format
- Model cannot be loaded from string
# Will raise ValueError
model = InterpretableModel(base_model) # No tokenizer
with model.trace("text input"): # Needs tokenizer!
    pass
Solution: Provide a tokenizer, e.g. InterpretableModel(base_model, tokenizer=tokenizer)
AttributeError¶
Raised when:
- Accessing a non-existent module attribute
- The layer attribute doesn't exist
# Will raise AttributeError
model = InterpretableModel(custom_model, layer_attr="transformer")
# If custom_model doesn't have 'transformer' attribute
Solution: Specify the correct layer attribute through layer_attr (for example "layers", "h", or "transformer.h")
Best Practices¶
1. Always Use Context Managers¶
# ✅ Good
with model.trace(input):
    act = model.layers[3].output.save()
# ❌ Avoid
trace = model.trace(input)
# Missing context manager!
2. Save Early, Access Later¶
# ✅ Good
with model.trace(input) as t:
    model.layers[3].output.save()
# Access after trace completes
act = t.get("layers.3.output")
# ❌ Avoid trying to access during trace
with model.trace(input) as t:
    act = t.get("layers.3.output") # Not saved yet!
3. Use Utility Functions for Common Tasks¶
# ✅ Good - Use get_activations
from mlxterp import get_activations
acts = get_activations(model, prompts, layers=[3, 8])
# ❌ Avoid manual loops
acts = {}
for layer_idx in [3, 8]:
    with model.trace(prompts):
        acts[layer_idx] = model.layers[layer_idx].output.save()
4. Batch Large Datasets¶
# ✅ Good - Use batching
from mlxterp import batch_get_activations
acts = batch_get_activations(model, large_list, batch_size=32)
# ❌ Avoid loading everything at once
with model.trace(large_list): # May run out of memory
    acts = model.layers[3].output.save()
Visualization Module¶
The visualization module provides tools for attention pattern analysis and visualization.
from mlxterp.visualization import (
# Attention extraction and visualization
get_attention_patterns,
attention_heatmap,
attention_from_trace,
AttentionVisualizationConfig,
# Pattern detection
AttentionPatternDetector,
detect_head_types,
detect_induction_heads,
induction_score,
previous_token_score,
first_token_score,
copying_score,
)
Function: get_attention_patterns¶
Extract attention weight patterns from a trace.
Parameters:
- trace (Trace): Completed trace context with captured attention weights
- layers (Optional[List[int]], default: None): Specific layers to extract (None = all)
Returns: Dict mapping layer_idx -> attention array of shape (batch, heads, seq_len, seq_len)
Example:
from mlxterp import InterpretableModel
from mlxterp.visualization import get_attention_patterns
model = InterpretableModel("mlx-community/Llama-3.2-1B-Instruct-4bit")
with model.trace("The capital of France is") as trace:
    pass
# Get all attention patterns
patterns = get_attention_patterns(trace)
print(f"Found {len(patterns)} layers")
print(f"Shape: {patterns[0].shape}") # (1, 32, 6, 6)
# Get specific layers
patterns = get_attention_patterns(trace, layers=[0, 5, 10])
Function: attention_heatmap¶
Create a heatmap visualization of attention patterns.
attention_heatmap(
attention: np.ndarray,
tokens: List[str],
head_idx: int = 0,
title: Optional[str] = None,
colorscale: str = "Blues",
backend: str = "matplotlib",
mask_upper_tri: bool = True,
figsize: tuple = (8, 6)
) -> Any
Parameters:
- attention (np.ndarray): Attention weights, shape (batch, heads, seq_q, seq_k)
- tokens (List[str]): Token strings for axis labels
- head_idx (int, default: 0): Which attention head to visualize
- title (Optional[str]): Plot title
- colorscale (str, default: "Blues"): Colormap name
- backend (str, default: "matplotlib"): Backend ("matplotlib", "plotly", "circuitsviz")
- mask_upper_tri (bool, default: True): Mask future positions (for causal attention)
- figsize (tuple, default: (8, 6)): Figure size for matplotlib
Returns: Figure object (type depends on backend)
Example:
from mlxterp.visualization import get_attention_patterns, attention_heatmap
with model.trace("Hello world") as trace:
    pass
patterns = get_attention_patterns(trace, layers=[5])
tokens = model.to_str_tokens("Hello world")
# Create heatmap for head 0
fig = attention_heatmap(
patterns[5],
tokens,
head_idx=0,
title="Layer 5, Head 0",
backend="matplotlib"
)
Function: attention_from_trace¶
High-level function to visualize attention patterns from a trace.
attention_from_trace(
trace: Trace,
tokens: List[str],
layers: Optional[List[int]] = None,
heads: Optional[List[int]] = None,
mode: str = "single",
head_notation: str = "LH",
config: Optional[AttentionVisualizationConfig] = None
) -> Any
Parameters:
- trace (Trace): Completed trace context
- tokens (List[str]): Token strings for labels
- layers (Optional[List[int]]): Layers to visualize (default: first layer)
- heads (Optional[List[int]]): Heads to visualize (default: first head)
- mode (str, default: "single"): Visualization mode:
  - "single": One heatmap
  - "grid": Grid of multiple heatmaps
- head_notation (str, default: "LH"): Notation for titles ("LH" = L5H3, "dot" = 5.3)
- config (Optional[AttentionVisualizationConfig]): Custom configuration
Returns: Figure object
Example:
from mlxterp.visualization import attention_from_trace, AttentionVisualizationConfig
with model.trace("The quick brown fox") as trace:
    pass
tokens = model.to_str_tokens("The quick brown fox")
# Single heatmap
fig = attention_from_trace(trace, tokens, layers=[5], heads=[0], mode="single")
# Grid of multiple heads
config = AttentionVisualizationConfig(backend="matplotlib", colorscale="Blues")
fig = attention_from_trace(
trace, tokens,
layers=[0, 5, 10],
heads=[0, 1, 2, 3],
mode="grid",
config=config
)
Class: AttentionVisualizationConfig¶
Configuration for attention visualization.
AttentionVisualizationConfig(
colorscale: str = "Blues",
mask_upper_tri: bool = True,
backend: str = "auto",
figsize: tuple = (10, 8),
show_colorbar: bool = True,
font_size: int = 10
)
Parameters:
- colorscale (str): Colormap name (default: "Blues")
- mask_upper_tri (bool): Mask future positions (default: True)
- backend (str): Visualization backend (default: "auto", which tries circuitsviz, then plotly, then matplotlib)
- figsize (tuple): Figure size (default: (10, 8))
- show_colorbar (bool): Show colorbar (default: True)
- font_size (int): Font size for labels (default: 10)
Function: induction_score¶
Compute induction head score for an attention pattern.
Induction heads implement the pattern [A][B]...[A] -> predict [B] by attending from position i to position i - seq_len + 1 in repeated sequences.
Parameters:
- attention_pattern (np.ndarray): Attention weights, shape (seq_q, seq_k)
- seq_len (int): Length of the repeated subsequence
Returns: Induction score (0-1, higher = more induction-like)
Example:
import numpy as np
from mlxterp.visualization import induction_score
# Create pattern from repeated sequence "ABC ABC"
# In a true induction head, position 3 (second A, 0-indexed) attends to position 1 (first B)
seq_len = 3
total_len = 6
# Synthetic perfect induction pattern
pattern = np.zeros((total_len, total_len))
for i in range(seq_len, total_len):
    pattern[i, i - seq_len + 1] = 1.0
score = induction_score(pattern, seq_len)
print(f"Induction score: {score:.3f}") # ~1.0
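One plausible way to turn the offset rule above into a number is to average the attention each query position `i >= seq_len` places on key `i - seq_len + 1`. The sketch below does exactly that in plain Python. `offset_diagonal_mean` is a hypothetical name, and the library's `induction_score` may normalize or aggregate differently.

```python
from typing import List


def offset_diagonal_mean(pattern: List[List[float]], seq_len: int) -> float:
    """Mean attention from query i to key i - seq_len + 1, for i >= seq_len.

    Illustrative assumption about how an induction score could be computed.
    """
    queries = range(seq_len, len(pattern))
    if len(queries) == 0:
        return 0.0
    return sum(pattern[i][i - seq_len + 1] for i in queries) / len(queries)


seq_len, total_len = 3, 6

# Perfect induction pattern: all mass on the token after the earlier match
perfect = [[0.0] * total_len for _ in range(total_len)]
for i in range(seq_len, total_len):
    perfect[i][i - seq_len + 1] = 1.0

# Uniform causal attention: row i spreads 1 / (i + 1) over keys 0..i
uniform = [
    [1.0 / (i + 1) if j <= i else 0.0 for j in range(total_len)]
    for i in range(total_len)
]

perfect_score = offset_diagonal_mean(perfect, seq_len)
uniform_score = offset_diagonal_mean(uniform, seq_len)
print(perfect_score, uniform_score)
```

The uniform pattern scores well below the perfect one, which is the separation a threshold such as 0.4 relies on.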
Function: previous_token_score¶
Compute previous token head score.
Previous token heads attend strongly to position i-1 (the immediately preceding token).
Parameters:
- attention_pattern (np.ndarray): Attention weights, shape (seq_q, seq_k)
Returns: Previous token score (0-1, higher = attends more to previous position)
Example:
import numpy as np
from mlxterp.visualization import previous_token_score
# Perfect previous token pattern
pattern = np.zeros((5, 5))
for i in range(1, 5):
    pattern[i, i - 1] = 1.0
score = previous_token_score(pattern)
print(f"Previous token score: {score:.3f}") # ~1.0
Function: first_token_score¶
Compute first token (BOS) head score.
First token heads attend strongly to position 0 (typically BOS or first token).
Parameters:
- attention_pattern (np.ndarray): Attention weights, shape (seq_q, seq_k)
Returns: First token score (0-1, higher = attends more to position 0)
Example:
import numpy as np
from mlxterp.visualization import first_token_score
# Perfect first token pattern
pattern = np.zeros((5, 5))
pattern[:, 0] = 1.0 # All positions attend to first
score = first_token_score(pattern)
print(f"First token score: {score:.3f}") # ~1.0
Function: copying_score¶
Compute copying head score from OV circuit.
Copying heads increase the logit of the attended-to token.
Parameters:
- ov_circuit (np.ndarray): OV circuit matrix W_V @ W_O, shape (d_model, d_model)
- unembedding (Optional[np.ndarray]): Optional unembedding matrix for full analysis
Returns: Copying score (higher = more copying behavior)
Class: AttentionPatternDetector¶
Detector for classifying attention head types.
AttentionPatternDetector(
induction_threshold: float = 0.4,
previous_token_threshold: float = 0.5,
first_token_threshold: float = 0.3,
current_token_threshold: float = 0.3
)
Parameters:
- induction_threshold (float): Threshold for induction head classification
- previous_token_threshold (float): Threshold for previous token head
- first_token_threshold (float): Threshold for first token head
- current_token_threshold (float): Threshold for current token (self-attention) head
Methods:
analyze_head¶
Compute all pattern scores for a single attention head.
analyze_head(
attention_pattern: np.ndarray,
seq_len_for_induction: Optional[int] = None
) -> Dict[str, float]
Returns: Dict of pattern type -> score
classify_head¶
Classify an attention head into one or more types.
classify_head(
attention_pattern: np.ndarray,
seq_len_for_induction: Optional[int] = None
) -> List[str]
Returns: List of head type labels that exceed thresholds
Example:
from mlxterp.visualization import AttentionPatternDetector
detector = AttentionPatternDetector(
previous_token_threshold=0.5,
first_token_threshold=0.3
)
# Analyze a head
scores = detector.analyze_head(attention_pattern)
print(f"Scores: {scores}")
# {'previous_token': 0.85, 'first_token': 0.1, 'current_token': 0.05}
# Classify
types = detector.classify_head(attention_pattern)
print(f"Classification: {types}") # ['previous_token']
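The relationship between `analyze_head` scores and `classify_head` labels can be sketched as a per-type threshold comparison. The helper below is hypothetical and assumes a strict "greater than" test, which the text above only implies with "exceed thresholds".

```python
from typing import Dict, List


def labels_above_threshold(
    scores: Dict[str, float], thresholds: Dict[str, float]
) -> List[str]:
    """Return each head type whose score exceeds its own threshold.

    Hypothetical helper; the strict comparison is an assumption.
    """
    return [
        name
        for name, score in scores.items()
        if name in thresholds and score > thresholds[name]
    ]


# Default thresholds taken from the constructor signature above
thresholds = {
    "induction": 0.4,
    "previous_token": 0.5,
    "first_token": 0.3,
    "current_token": 0.3,
}
# Scores copied from the example output above
scores = {"previous_token": 0.85, "first_token": 0.1, "current_token": 0.05}

labels = labels_above_threshold(scores, thresholds)
print(labels)
```

Because each type has its own threshold, a head can receive several labels at once or none at all.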
Function: detect_head_types¶
Detect attention head types across a model.
detect_head_types(
model: InterpretableModel,
text: str,
threshold: float = 0.4,
layers: Optional[List[int]] = None
) -> Dict[str, List[Tuple[int, int]]]
Parameters:
- model (InterpretableModel): Model to analyze
- text (str): Input text for analysis
- threshold (float, default: 0.4): Score threshold for classification
- layers (Optional[List[int]]): Specific layers to analyze (None = all)
Returns: Dict mapping head type to list of (layer, head) tuples
Example:
from mlxterp.visualization import detect_head_types
head_types = detect_head_types(
model,
"The quick brown fox jumps over the lazy dog",
threshold=0.3,
layers=[0, 5, 10, 15]
)
print(f"Previous token heads: {len(head_types['previous_token'])}")
print(f"First token heads: {len(head_types['first_token'])}")
# Print specific heads
for layer, head in head_types['previous_token'][:5]:
    print(f" L{layer}H{head}")
Function: detect_induction_heads¶
Detect induction heads using repeated random token sequences.
detect_induction_heads(
model: InterpretableModel,
n_random_tokens: int = 50,
n_repeats: int = 2,
threshold: float = 0.4,
layers: Optional[List[int]] = None,
seed: int = 42
) -> List[HeadScore]
Parameters:
- model (InterpretableModel): Model to analyze
- n_random_tokens (int, default: 50): Number of random tokens in subsequence
- n_repeats (int, default: 2): Number of times to repeat the subsequence
- threshold (float, default: 0.4): Score threshold for detection
- layers (Optional[List[int]]): Specific layers to analyze (None = all)
- seed (int, default: 42): Random seed for reproducibility
Returns: List of HeadScore objects for heads above threshold, sorted by score descending
Example:
from mlxterp.visualization import detect_induction_heads
# Find induction heads
induction_heads = detect_induction_heads(
model,
n_random_tokens=50,
threshold=0.3,
layers=[0, 5, 10, 15]
)
print(f"Found {len(induction_heads)} induction heads")
for head in induction_heads[:10]:
    print(f" L{head.layer}H{head.head}: {head.score:.3f}")
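The probe input implied by `n_random_tokens`, `n_repeats`, and `seed` is a random block of token IDs repeated back to back. A minimal sketch of building such an input follows. `repeated_random_tokens` and `vocab_size` are hypothetical, and the library may sample tokens or handle BOS differently.

```python
import random
from typing import List


def repeated_random_tokens(
    vocab_size: int, n_random_tokens: int = 50, n_repeats: int = 2, seed: int = 42
) -> List[int]:
    """Build [block] * n_repeats from a seeded random block of token IDs.

    Hypothetical helper for illustration only.
    """
    rng = random.Random(seed)  # Local generator keeps the result reproducible
    block = [rng.randrange(vocab_size) for _ in range(n_random_tokens)]
    return block * n_repeats


tokens = repeated_random_tokens(vocab_size=1000)
print(len(tokens), tokens[:5])
```

Because the block is random, the only way to predict tokens in the second copy is to look back at the first copy, which is the behavior induction heads are scored on.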
Class: HeadScore¶
Score for an attention head (dataclass). The detect_induction_heads example above reads its layer, head, and score fields.
Method: InterpretableModel.to_str_tokens¶
Convert text or token IDs to a list of token strings.
model.to_str_tokens(
input: Union[str, List[int], mx.array],
prepend_bos: bool = False
) -> List[str]
Parameters:
- input: Text string, list of token IDs, or mx.array of tokens
- prepend_bos (bool, default: False): Whether to prepend BOS token (only for string input)
Returns: List of token strings
Example:
# From text
tokens = model.to_str_tokens("Hello world")
print(tokens) # ['<|begin_of_text|>', 'Hello', ' world']
# From token IDs
token_ids = model.encode("Hello world")
tokens = model.to_str_tokens(token_ids)
print(tokens) # ['<|begin_of_text|>', 'Hello', ' world']
Metrics¶
Module: mlxterp.metrics
Standard metrics for causal interpretability experiments. All metrics follow a uniform signature so they can be passed to any patching or attribution function.
Metric Signature¶
Every metric function accepts:
def metric(
    patched_logits: mx.array,    # Logits after intervention
    clean_logits: mx.array,      # Logits from clean/correct input
    corrupted_logits: mx.array,  # Logits from corrupted input
    **kwargs,                    # Metric-specific parameters
) -> float:
    ...
Returns a scalar where higher = more recovery of clean behavior.
Function: logit_diff¶
Measures recovery of the logit difference between correct and incorrect tokens. The standard metric for Indirect Object Identification (IOI) and factual recall tasks.
from mlxterp.metrics import logit_diff
effect = logit_diff(
patched_logits, clean_logits, corrupted_logits,
correct_token=tokenizer.encode(" Paris")[0],
incorrect_token=tokenizer.encode(" Rome")[0],
)
# Returns: 1.0 = fully recovered, 0.0 = no change, <0 = worse
Required kwargs:
- correct_token (int): Token ID for the correct completion
- incorrect_token (int): Token ID for the incorrect completion
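The return scale described above (1.0 = fully recovered, 0.0 = no change, negative = worse) is what a baseline-normalized logit difference produces. The sketch below works on a single logit vector for the final position. `logit_diff_recovery` is a hypothetical name and the exact normalization used by `logit_diff` is an assumption.

```python
from typing import Sequence


def logit_diff_recovery(
    patched: Sequence[float],
    clean: Sequence[float],
    corrupted: Sequence[float],
    correct_token: int,
    incorrect_token: int,
) -> float:
    """Normalized recovery: (patched - corrupted) / (clean - corrupted).

    Each term is logit[correct] - logit[incorrect]. Assumed formula.
    """
    def diff(logits: Sequence[float]) -> float:
        return logits[correct_token] - logits[incorrect_token]

    baseline = diff(corrupted)
    span = diff(clean) - baseline
    if span == 0:
        return 0.0  # Clean and corrupted agree, so there is nothing to recover
    return (diff(patched) - baseline) / span


# Toy 3-token vocabulary: token 0 is correct, token 1 is incorrect
clean = [4.0, 1.0, 0.0]
corrupted = [1.0, 3.0, 0.0]

full = logit_diff_recovery(clean, clean, corrupted, 0, 1)
none = logit_diff_recovery(corrupted, clean, corrupted, 0, 1)
partial = logit_diff_recovery([2.0, 2.0, 0.0], clean, corrupted, 0, 1)
worse = logit_diff_recovery([0.0, 4.0, 0.0], clean, corrupted, 0, 1)
print(full, none, partial, worse)
```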
Function: kl_divergence¶
KL divergence between patched and clean output distributions. Returns negative KL so higher = better (consistent with other metrics).
from mlxterp.metrics import kl_divergence
effect = kl_divergence(patched_logits, clean_logits, corrupted_logits)
# Returns: 0.0 when distributions match, negative when they diverge
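To see why the value is 0.0 for matching distributions and negative otherwise, the computation can be sketched on plain lists. The direction KL(clean || patched) used here is an assumption, and `negative_kl` is a hypothetical stand-in for the library function.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over one logit vector."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def negative_kl(patched_logits: List[float], clean_logits: List[float]) -> float:
    """Return -KL(clean || patched); the KL direction is an assumption."""
    p = softmax(clean_logits)
    q = softmax(patched_logits)
    kl = sum(p_i * math.log(p_i / q_i) for p_i, q_i in zip(p, q) if p_i > 0)
    return -kl


clean = [2.0, 0.5, -1.0]
same = negative_kl(clean, clean)
shifted = negative_kl([0.0, 2.0, -1.0], clean)
print(same, shifted)
```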
Function: cross_entropy_diff¶
Difference in cross-entropy loss: CE(corrupted) - CE(patched). Positive means patching reduced the loss.
from mlxterp.metrics import cross_entropy_diff
effect = cross_entropy_diff(
patched_logits, clean_logits, corrupted_logits,
target_token=42, # Optional: auto-detected from clean argmax if omitted
)
Function: l2_distance¶
Normalized L2 recovery: what fraction of the clean-corrupted distance was recovered.
from mlxterp.metrics import l2_distance
effect = l2_distance(patched_logits, clean_logits, corrupted_logits)
# Returns: 1.0 = patched == clean, 0.0 = patched == corrupted
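The two endpoints above are consistent with the formula 1 - ||patched - clean|| / ||corrupted - clean||. A hedged sketch of that formula follows. `l2_recovery` is a hypothetical name, and the library may reduce over batch or position dimensions in ways not shown here.

```python
import math
from typing import Sequence


def l2_recovery(
    patched: Sequence[float], clean: Sequence[float], corrupted: Sequence[float]
) -> float:
    """Fraction of the clean-to-corrupted L2 distance that patching removed."""
    def dist(a: Sequence[float], b: Sequence[float]) -> float:
        return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

    baseline = dist(corrupted, clean)
    if baseline == 0:
        return 0.0  # Inputs already agree, so recovery is undefined; report 0
    return 1.0 - dist(patched, clean) / baseline


clean = [1.0, 2.0]
corrupted = [4.0, 6.0]      # Distance 5 from clean
midpoint = [2.5, 4.0]       # Halfway between clean and corrupted

at_clean = l2_recovery(clean, clean, corrupted)
at_corrupted = l2_recovery(corrupted, clean, corrupted)
halfway = l2_recovery(midpoint, clean, corrupted)
print(at_clean, at_corrupted, halfway)
```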
Function: cosine_distance¶
Cosine similarity recovery between patched and clean outputs.
from mlxterp.metrics import cosine_distance
effect = cosine_distance(patched_logits, clean_logits, corrupted_logits)
Function: get_metric¶
Look up a metric by name string, or pass through a callable.
from mlxterp.metrics import get_metric
fn = get_metric("logit_diff") # Returns the logit_diff function
fn = get_metric("l2") # Alias for l2_distance
fn = get_metric(my_custom_fn) # Returns my_custom_fn unchanged
Available names: "logit_diff", "kl_divergence" / "kl", "cross_entropy_diff" / "ce_diff", "l2_distance" / "l2", "cosine_distance" / "cosine"
Analysis Results¶
Module: mlxterp.results
Every analysis function returns a structured result object that exposes a .data attribute plus .summary(), .to_json(), .to_markdown(), and .plot() methods. This supports both human reading and agent/programmatic processing.
Class: AnalysisResult¶
Base class for all results.
from mlxterp.results import AnalysisResult
result.data # Dict of structured results
result.metadata # Dict of parameters and timing
result.result_type # String tag (e.g., "patching", "dla")
result.summary() # One-line human-readable summary
result.to_json() # JSON serialization (handles mx.array)
result.to_markdown() # Markdown report
result.plot() # Visualization (subclass-specific)
Class: PatchingResult¶
Returned by activation_patching() and path_patching().
from mlxterp.causal import activation_patching
result = activation_patching(model, clean, corrupted, component="mlp")
result.effect_matrix # mx.array: (n_layers,) or (n_layers, n_heads)
result.layers # List of layer indices tested
result.component # "mlp", "attn", "resid_post", "attn_head"
result.metric_name # Name of the metric used
# Get top components by effect size
top = result.top_components(k=5)
# Returns: [(layer_idx, effect_score), ...]
# Visualize
result.plot() # Heatmap or bar chart depending on dimensions
Class: AttributionResult¶
Returned by attribution_patching().
result.attribution_scores # mx.array of scores
result.layers # Layers analyzed
result.component # Component analyzed
result.method # "finite_diff" or "gradient"
Class: DLAResult¶
Returned by direct_logit_attribution().
result.head_contributions # mx.array: per-attention-layer contributions
result.mlp_contributions # mx.array: per-MLP contributions
result.target_token # Token ID being attributed
result.target_token_str # String representation
Class: GenerationResult¶
Returned by model.generate().
result.text # Generated text string
result.tokens # List of token IDs
result.token_logits # Per-token logit distributions (optional)
result.prompt # Original prompt
Class: ConversationResult¶
Returned by ConversationTrace.to_result().
result.turns # List of turn metadata dicts
result.cross_turn_attention # Turn x turn attention matrix
Class: CircuitResult¶
Returned by acdc() and feature_circuit().
result.nodes # List of component names in the circuit
result.edges # List of (sender, receiver, weight) tuples
result.threshold # Pruning threshold used
# Export as graph
graph = result.to_graph()
# Returns: {"nodes": [{"id": ...}], "edges": [{"source": ..., "target": ..., "weight": ...}]}
Causal Interpretability¶
Module: mlxterp.causal
The causal package provides tools for understanding which model components cause specific behaviors, from activation patching to automated circuit discovery.
Function: activation_patching¶
The core causal analysis tool. Patches clean activations into corrupted forward passes to identify important components.
from mlxterp.causal import activation_patching
# Layer-level patching: which layers matter?
result = activation_patching(
model,
clean="The Eiffel Tower is in",
corrupted="The Colosseum is in",
component="mlp", # "resid_post", "attn", "mlp", "attn_head"
metric="l2", # or "logit_diff", "kl", "cosine", "ce_diff"
layers=range(16), # None = all layers
)
print(result.summary())
result.plot() # Heatmap of effects
# Position-level: which token positions matter?
result = activation_patching(
model, clean, corrupted,
component="mlp",
positions=[3, 4, 5], # Only patch these positions
)
# Head-level: which attention heads matter?
result = activation_patching(
model, clean, corrupted,
component="attn_head", # Returns (n_layers, n_heads) matrix
metric="logit_diff",
metric_kwargs={"correct_token": 123, "incorrect_token": 456},
)
top = result.top_components(k=10)
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| model | InterpretableModel | required | Model to analyze |
| clean | str or mx.array | required | Clean/correct input |
| corrupted | str or mx.array | required | Corrupted/counterfactual input |
| layers | List[int] | None (all) | Layer indices to test |
| component | str | "resid_post" | Component to patch |
| metric | str or Callable | "l2" | Metric function |
| positions | List[int] | None (all) | Token positions to patch |
| metric_kwargs | Dict | None | Extra metric arguments |
| verbose | bool | False | Print progress |
Class: CausalTrace¶
Declarative context manager for paired clean/corrupted analysis. Reduces boilerplate for multi-component patching experiments.
from mlxterp.causal import CausalTrace
# Using model.causal_trace() shortcut:
with model.causal_trace(clean_text, corrupted_text) as ct:
    # Register patches declaratively
    ct.patch("layers.5.mlp")
    ct.patch("layers.7.self_attn", positions=[3, 4])
    # Compute metric with all patches applied at once
    effect = ct.metric("l2")
    print(f"Combined patching effect: {effect:.4f}")
# Access clean activations
with model.causal_trace(clean, corrupted) as ct:
    act = ct.get_clean_activation("mlp", layer=5)
    print(f"Clean MLP output shape: {act.shape}")
# Use canonical names with layer parameter
with model.causal_trace(clean, corrupted) as ct:
    ct.patch("mlp", layer=0)
    ct.patch("attn", layer=3)
    effect = ct.metric("logit_diff", correct_token=100, incorrect_token=200)
Methods:
- ct.patch(component, positions=None, layer=None): Register a component to be patched from clean into corrupted
- ct.metric(metric_fn, **kwargs): Apply all registered patches and compute metric
- ct.get_clean_activation(component, layer=None): Retrieve a specific clean activation
Class: ResidualStreamAccessor¶
Structured access to the residual stream at each layer, decomposed into attention and MLP contributions.
from mlxterp.causal import ResidualStreamAccessor
with model.trace("The capital of France is") as trace:
    pass
rs = ResidualStreamAccessor(trace.activations)
# Access residual stream at each layer
pre = rs.resid_pre(5) # Input to layer 5
post = rs.resid_post(5) # Output of layer 5
mid = rs.resid_mid(5) # Between attn and MLP (= pre + attn)
# Component contributions (raw, before residual add)
attn_out = rs.attn_contribution(5) # What attention added
mlp_out = rs.mlp_contribution(5) # What MLP added
total = rs.layer_contribution(5) # Total: post - pre
# Discover available layers
layers = rs.available_layers() # [0, 1, 2, ..., 15]
Key invariants:
- resid_pre[i] == resid_post[i-1] (layer i's input is layer i-1's output)
- resid_post[i] == resid_pre[i] + attn_contribution[i] + mlp_contribution[i] (approximately, for standard transformer blocks)
- resid_mid[i] == resid_pre[i] + attn_contribution[i]
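The invariants above can be checked on a toy two-layer residual stream. In this sketch the attention and MLP contributions are fixed hypothetical vectors rather than functions of the stream, so it demonstrates only the bookkeeping and not a real transformer block.

```python
from typing import Dict, List

Vector = List[float]


def add(a: Vector, b: Vector) -> Vector:
    """Element-wise vector sum."""
    return [x + y for x, y in zip(a, b)]


def run_block(resid_pre: Vector, attn_out: Vector, mlp_out: Vector) -> Dict[str, Vector]:
    """Apply the two residual additions of a standard pre-norm style block."""
    resid_mid = add(resid_pre, attn_out)   # After attention is added
    resid_post = add(resid_mid, mlp_out)   # After the MLP is added
    return {
        "resid_pre": resid_pre,
        "resid_mid": resid_mid,
        "resid_post": resid_post,
        "attn": attn_out,
        "mlp": mlp_out,
    }


embedding = [1.0, 0.0]
layer0 = run_block(embedding, attn_out=[0.5, 0.25], mlp_out=[-0.25, 1.0])
# Layer 1 consumes layer 0's output, which is the first invariant
layer1 = run_block(layer0["resid_post"], attn_out=[0.0, 0.5], mlp_out=[1.0, 0.0])

for name, layer in (("layer0", layer0), ("layer1", layer1)):
    total = add(layer["attn"], layer["mlp"])
    print(name, layer["resid_post"], add(layer["resid_pre"], total))
```

In a real model the equalities hold only approximately, as noted above, because of numerical precision and architecture-specific details.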
Function: direct_logit_attribution¶
Decomposes the final logit prediction into per-component contributions by projecting each component's output through the unembedding matrix.
from mlxterp.causal import direct_logit_attribution
result = direct_logit_attribution(
model,
text="The capital of France is",
target_token=None, # Auto-detect from argmax, or specify token ID
position=-1, # Token position to analyze (-1 = last)
layers=None, # None = all layers
)
print(f"Target: '{result.target_token_str}' (id={result.target_token})")
print(f"Attention contributions: {result.head_contributions}")
print(f"MLP contributions: {result.mlp_contributions}")
print(result.summary())
Function: attribution_patching¶
Fast approximation of activation patching using gradient-based finite differences. Roughly 100x faster than brute-force patching, at the cost of approximation error.
from mlxterp.causal import attribution_patching
result = attribution_patching(
model,
clean="The Eiffel Tower is in",
corrupted="The Colosseum is in",
component="resid_post", # "resid_post", "attn", "mlp"
metric="l2",
n_projections=1, # More projections = more accurate, slower
eps=1e-3, # Perturbation size
)
print(result.summary())
# Scores approximate brute-force patching effects
print(f"Attribution scores: {result.attribution_scores}")
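The idea behind a finite-difference estimate can be sketched on a toy metric: nudge the corrupted activation a small step `eps` toward the clean one and measure the slope. This is an illustration under stated assumptions only. `directional_estimate` and `toy_metric` are hypothetical, and the role of `n_projections` in the library is not modeled here.

```python
from typing import Callable, List

Vector = List[float]


def directional_estimate(
    metric: Callable[[Vector], float],
    corrupted_act: Vector,
    clean_act: Vector,
    eps: float = 1e-3,
) -> float:
    """Estimate metric(clean) - metric(corrupted) from one small nudge.

    First-order approximation: slope along (clean - corrupted), measured at
    the corrupted activation.
    """
    direction = [c - k for c, k in zip(clean_act, corrupted_act)]
    nudged = [k + eps * d for k, d in zip(corrupted_act, direction)]
    return (metric(nudged) - metric(corrupted_act)) / eps


def toy_metric(act: Vector) -> float:
    """Hypothetical linear metric, so the first-order estimate is exact."""
    return 2.0 * act[0] - 1.0 * act[1]


corrupted = [0.0, 1.0]
clean = [1.0, 0.0]

estimate = directional_estimate(toy_metric, corrupted, clean)
exact = toy_metric(clean) - toy_metric(corrupted)
print(estimate, exact)
```

For a nonlinear metric the estimate and the exact patching effect diverge as the clean and corrupted activations move further apart, which is the approximation error traded for speed.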
Function: path_patching¶
Measures the causal effect along a specific sender-receiver path by freezing all other components.
from mlxterp.causal import path_patching
result = path_patching(
model,
clean="The Eiffel Tower is in",
corrupted="The Colosseum is in",
sender="layers.7.self_attn", # Source component
receiver="layers.9.self_attn", # Target component
metric="l2",
)
print(f"Path effect: {result.data['effect']:.4f}")
print(f"Components frozen: {result.data['n_frozen']}")
Function: acdc¶
Automated Circuit Discovery. Iteratively tests each component's importance and prunes those below a threshold.
from mlxterp.causal import acdc
circuit = acdc(
model,
clean="The Eiffel Tower is in",
corrupted="The Colosseum is in",
threshold=0.01, # Minimum effect to keep a node
components=["attn", "mlp"], # Component types to test
layers=range(16),
)
print(circuit.summary())
# "Circuit: 8 nodes, 12 edges (threshold=0.01)"
print("Important components:", circuit.nodes)
graph = circuit.to_graph() # For visualization
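The pruning step described above can be sketched as a filter over measured effects. This is a simplified node-level illustration: `prune_components` and the toy effect values are hypothetical, comparing by absolute value is an assumption, and the edge discovery that a full ACDC run performs is not shown.

```python
from typing import Callable, Dict, List


def prune_components(
    components: List[str],
    effect_of: Callable[[str], float],
    threshold: float = 0.01,
) -> List[str]:
    """Keep components whose effect magnitude reaches the threshold."""
    kept = []
    for name in components:
        if abs(effect_of(name)) >= threshold:
            kept.append(name)
    return kept


# Hypothetical per-component patching effects
toy_effects: Dict[str, float] = {
    "layers.0.mlp": 0.002,
    "layers.5.mlp": 0.150,
    "layers.7.self_attn": -0.040,
    "layers.9.self_attn": 0.005,
}

circuit_nodes = prune_components(
    list(toy_effects), toy_effects.__getitem__, threshold=0.01
)
print(circuit_nodes)
```

Raising the threshold shrinks the circuit, so it is worth sweeping a few values before interpreting the surviving nodes.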
Feature-Level Analysis¶
Function: feature_patching¶
Compute causal effect of individual SAE features via ablation.
from mlxterp.causal.feature_circuits import feature_patching
effects = feature_patching(
model, sae,
text="The capital of France is",
layer=10,
component="mlp",
top_k=20, # Test top-20 most active features
)
# Returns: {feature_id: effect_score, ...}
# Or test specific features
effects = feature_patching(
model, sae, text,
layer=10,
feature_ids=[42, 123, 456],
)
Function: feature_circuit¶
Discover which SAE features are causally important.
from mlxterp.causal.feature_circuits import feature_circuit
circuit = feature_circuit(
model, sae, text,
layer=10,
component="mlp",
threshold=0.01,
top_k=50,
)
print(circuit.nodes) # ["L10.mlp.f42", "L10.mlp.f123", ...]
Text Generation¶
Module: mlxterp.generation
Generate text token-by-token with optional interventions applied at each step.
Function: generate / model.generate()¶
# Basic generation
result = model.generate(
"The capital of France is",
max_tokens=20,
temperature=0.0, # 0.0 = greedy (deterministic)
)
print(result.text) # " Paris. Paris is..."
print(result.tokens) # [3681, 13, 3681, ...]
# Generation with interventions (steering)
from mlxterp.core.intervention import add_vector, scale
result = model.generate(
"I think the movie was",
max_tokens=30,
temperature=0.7,
interventions={"layers.5": add_vector(positive_sentiment_vec)},
)
# Top-k and top-p sampling
result = model.generate(
prompt,
max_tokens=50,
temperature=0.8,
top_k=40,
top_p=0.95,
)
# Callback for custom stopping or logging
def my_callback(step, token_id, logits):
    print(f"Step {step}: token {token_id}")
    return step >= 20 # Return True to stop
result = model.generate(
prompt,
max_tokens=100,
callback=my_callback,
stop_tokens=[tokenizer.eos_token_id],
)
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| prompt | str, mx.array, List[int] | required | Input prompt |
| max_tokens | int | 50 | Maximum tokens to generate |
| temperature | float | 0.0 | Sampling temperature (0 = greedy) |
| top_k | int | 0 | Top-k filtering (0 = disabled) |
| top_p | float | 1.0 | Nucleus sampling (1.0 = disabled) |
| interventions | Dict[str, Callable] | None | Interventions applied at each step |
| stop_tokens | List[int] | None | Token IDs that stop generation |
| callback | Callable | None | callback(step, token_id, logits) -> bool |
Returns: GenerationResult
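How `temperature`, `top_k`, and `top_p` interact can be shown on a single logit vector. The sketch below is a generic illustration of these sampling filters in plain Python. `filtered_probs` is a hypothetical helper, and the order in which the library applies the filters is an assumption.

```python
import math
from typing import List


def filtered_probs(
    logits: List[float],
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
) -> List[float]:
    """Return the sampling distribution after temperature, top-k and top-p."""
    if temperature == 0.0:
        # Greedy decoding: all probability mass on the arg-max token
        best = max(range(len(logits)), key=lambda i: logits[i])
        return [1.0 if i == best else 0.0 for i in range(len(logits))]

    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Token indices sorted from most to least likely
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]  # Keep only the k most likely tokens
    if top_p < 1.0:
        kept, cumulative = [], 0.0
        for i in order:
            kept.append(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break  # Smallest prefix whose mass reaches top_p
        order = kept

    keep = set(order)
    mass = sum(probs[i] for i in order)
    return [probs[i] / mass if i in keep else 0.0 for i in range(len(probs))]


logits = [2.0, 1.0, 0.0, -1.0]
greedy = filtered_probs(logits, temperature=0.0)
top2 = filtered_probs(logits, top_k=2)
nucleus = filtered_probs(logits, top_p=0.5)
print(greedy, top2, nucleus)
```

With the defaults `top_k=0` and `top_p=1.0` both filters are disabled, matching the table above.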
Conversation Analysis¶
Module: mlxterp.conversation
Tools for analyzing multi-turn conversations: turn detection, per-turn activation slicing, and cross-turn attention analysis.
Class: Turn¶
Represents a single turn in a conversation with token position tracking.
from mlxterp.conversation import Turn
turn = Turn(
index=0,
role="user",
full_start=0, # Including template tokens
full_end=15,
content_start=3, # Content only (no role markers)
content_end=12,
)
turn.content_positions # slice(3, 12)
turn.full_positions # slice(0, 15)
turn.n_content_tokens # 9
turn.n_total_tokens # 15
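The position properties above follow directly from the start and end indices. The dataclass below is a hypothetical re-creation (`TurnSketch` is not the library class) showing how the slices and counts could be derived and then used to index a token list.

```python
from dataclasses import dataclass


@dataclass
class TurnSketch:
    """Hypothetical stand-in for Turn, with half-open [start, end) ranges."""
    index: int
    role: str
    full_start: int
    full_end: int
    content_start: int
    content_end: int

    @property
    def content_positions(self) -> slice:
        return slice(self.content_start, self.content_end)

    @property
    def full_positions(self) -> slice:
        return slice(self.full_start, self.full_end)

    @property
    def n_content_tokens(self) -> int:
        return self.content_end - self.content_start

    @property
    def n_total_tokens(self) -> int:
        return self.full_end - self.full_start


turn = TurnSketch(index=0, role="user", full_start=0, full_end=15,
                  content_start=3, content_end=12)

# Slices can index any per-token sequence, such as token IDs or activations
token_ids = list(range(100, 115))
content_ids = token_ids[turn.content_positions]
print(turn.n_content_tokens, turn.n_total_tokens, content_ids)
```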
Class: TurnList¶
Container with indexing, slicing, and role-based filtering.
from mlxterp.conversation import TurnList
turns[0] # First turn
turns[0:2] # First two turns (returns TurnList)
turns.by_role("user") # All user turns
turns.by_role("assistant") # All assistant turns
turns.roles # ["user", "assistant", "user"]
turns.content_positions() # All content positions across turns
Function: detect_turns¶
Detect turn boundaries in a tokenized conversation using the chat template.
from mlxterp.conversation import detect_turns
messages = [
{"role": "user", "content": "My name is Alice."},
{"role": "assistant", "content": "Nice to meet you!"},
{"role": "user", "content": "What's my name?"},
]
turns = detect_turns(tokenizer, messages)
print(turns)
# [Turn(0, role='user', content=5:12, full=1:13),
# Turn(1, role='assistant', content=16:22, full=13:23),
# Turn(2, role='user', content=26:31, full=23:32)]
Supports Llama 3, ChatML (Qwen), Gemma, and custom templates.
Class: ConversationTrace¶
Context manager for multi-turn conversation analysis. Uses model.conversation_trace().
conversation = [
{"role": "user", "content": "The capital of France is Paris."},
{"role": "assistant", "content": "That's correct!"},
{"role": "user", "content": "What's the capital of France?"},
]
with model.conversation_trace(conversation) as ct:
    # Access turns
    print(ct.turns) # TurnList with 3 turns
    # Get activations for a specific turn
    turn0_act = ct.get_turn_activation(0, "layers.5")
    turn2_act = ct.get_turn_activation(2, "layers.5")
    # Cross-turn attention: how much does each turn attend to others?
    cross_attn = ct.cross_turn_attention(layer=5, head=0)
    # Returns: (3, 3) matrix; entry [i,j] = avg attention from turn i to turn j
    # Convert to structured result
    result = ct.to_result()
    print(result.summary())
Methods:
| Method | Description |
|---|---|
| ct.get_turn_activation(turn_idx, component, content_only=True) | Get activations sliced to a turn |
| ct.cross_turn_attention(layer, head) | Turn x turn attention aggregation |
| ct.to_result() | Convert to ConversationResult |
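The turn x turn aggregation can be sketched as follows: for each pair of turns, sum the attention a query token places on the keys of the target turn, then average over the query tokens of the source turn. This definition of "avg attention" is an assumption, and `turn_by_turn_attention` is a hypothetical helper operating on a single head's (seq, seq) map.

```python
from typing import List, Tuple


def turn_by_turn_attention(
    attention: List[List[float]], spans: List[Tuple[int, int]]
) -> List[List[float]]:
    """Aggregate a (seq, seq) attention map into a (turns, turns) matrix.

    spans holds half-open (start, end) token ranges, one per turn.
    """
    matrix = []
    for q_start, q_end in spans:
        row = []
        for k_start, k_end in spans:
            # Attention mass each query token in turn i places on turn j
            masses = [sum(attention[q][k_start:k_end]) for q in range(q_start, q_end)]
            row.append(sum(masses) / len(masses))
        matrix.append(row)
    return matrix


# Toy causal attention over 4 tokens; each row sums to 1
attention = [
    [1.0, 0.0, 0.0, 0.0],
    [0.5, 0.5, 0.0, 0.0],
    [0.25, 0.25, 0.5, 0.0],
    [0.25, 0.25, 0.25, 0.25],
]
spans = [(0, 2), (2, 4)]  # Two turns of two tokens each

cross = turn_by_turn_attention(attention, spans)
print(cross)
```

Under causal masking the matrix is lower triangular: an earlier turn cannot attend to a later one.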
Visualization (Patching & Dashboards)¶
Patching Visualization¶
Module: mlxterp.visualization.patching
from mlxterp.causal import activation_patching
from mlxterp.visualization.patching import plot_patching_result, plot_patching_comparison
# Plot a single result (auto-detects bar chart vs heatmap)
result = activation_patching(model, clean, corrupted, component="mlp")
fig = result.plot() # Calls plot_patching_result internally
# Or call directly with options
fig = plot_patching_result(result, figsize=(14, 6), cmap="viridis", title="MLP Patching")
# Compare multiple results side by side
results = [
activation_patching(model, clean, corrupted, component="mlp"),
activation_patching(model, clean, corrupted, component="attn"),
]
fig = plot_patching_comparison(results, title="MLP vs Attention")
Feature Dashboards¶
Module: mlxterp.visualization.dashboards
Generate standalone HTML dashboards for SAE feature analysis.
from mlxterp.visualization.dashboards import (
max_activating_examples,
feature_activation_histogram,
generate_feature_dashboard_html,
)
# Find texts that maximally activate a feature
examples = max_activating_examples(
    model, sae,
    feature_id=42,
    texts=dataset_texts,
    layer=10,
    component="mlp",
    top_k=10,
)
# Returns: [{"text": "...", "activation_value": 0.95, "token_position": 7}, ...]
# Compute activation distribution
histogram = feature_activation_histogram(
    model, sae,
    feature_id=42,
    texts=dataset_texts,
    layer=10,
)
# Returns: {"bin_edges": [...], "counts": [...], "mean": 0.12, "std": 0.3, "sparsity": 0.85}
# Generate HTML dashboard
html = generate_feature_dashboard_html(42, examples, histogram)
with open("feature_42_dashboard.html", "w") as f:
    f.write(html)
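To make the histogram fields above concrete, here is a rough pure-Python sketch of how such a summary could be computed from a flat list of activation values. It is not the library's code: the function name `summarize_activations`, the equal-width binning, and the definition of sparsity as the fraction of exactly-zero activations are assumptions.

```python
import statistics
from typing import Dict, Sequence


def summarize_activations(values: Sequence[float], num_bins: int = 4) -> Dict[str, object]:
    """Build a histogram summary with the same keys as the documented return value."""
    lo, hi = min(values), max(values)
    # Fall back to a width of 1.0 when all values are identical.
    width = (hi - lo) / num_bins or 1.0
    bin_edges = [lo + i * width for i in range(num_bins + 1)]
    counts = [0] * num_bins
    for v in values:
        # Clamp the index so the maximum value lands in the last bin.
        idx = min(int((v - lo) / width), num_bins - 1)
        counts[idx] += 1
    return {
        "bin_edges": bin_edges,
        "counts": counts,
        "mean": statistics.fmean(values),
        "std": statistics.pstdev(values),
        # Assumption: sparsity = fraction of exactly-zero activations.
        "sparsity": sum(1 for v in values if v == 0.0) / len(values),
    }


# Hypothetical activation values for a single feature.
summary = summarize_activations([0.0, 0.0, 0.0, 1.0, 2.0, 4.0])
```

A high sparsity value paired with a small mean is what one would typically expect from an SAE feature that fires on only a few tokens.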
Intervention: replace_at_positions¶
Position-level activation patching (added to mlxterp.core.intervention).
from mlxterp.core.intervention import replace_at_positions
# Only replace positions 3 and 4, leaving all others unchanged
with model.trace(input, interventions={
    "layers.5.mlp": replace_at_positions(clean_activation, [3, 4])
}):
    output = model.output.save()
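Conceptually, position-level patching overwrites only the selected sequence positions and passes every other position through untouched. The sketch below illustrates that idea on nested Python lists of shape (seq_len, hidden); the name `make_position_patch` is hypothetical, and real activations would be MLX arrays (likely with a batch dimension), so treat this as an illustration rather than the library's implementation.

```python
from typing import Callable, List, Sequence

Activation = List[List[float]]  # (seq_len, hidden) as nested lists


def make_position_patch(
    source: Activation, positions: Sequence[int]
) -> Callable[[Activation], Activation]:
    """Return a function that overwrites only the given positions with `source`."""
    wanted = set(positions)

    def patch(target: Activation) -> Activation:
        # Build a new activation; the input is left unmodified.
        return [
            list(source[i]) if i in wanted else list(row)
            for i, row in enumerate(target)
        ]

    return patch


# Toy data: three positions, hidden size 2.
clean_act = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
corrupted_act = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
patched = make_position_patch(clean_act, [1])(corrupted_act)
```

Returning a closure mirrors how the intervention is passed in the `interventions` dictionary above: it is constructed once with the source activation and applied later to whatever activation flows through the hooked module.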
Function: resolve_component¶
Cross-architecture component name resolution (in mlxterp.core.module_resolver).
from mlxterp.core.module_resolver import resolve_component
# Find the actual activation key for "mlp" at layer 5
key = resolve_component("mlp", 5, trace.activations)
# Returns: "model.model.layers.5.mlp" (Llama) or "model.h.5.mlp" (GPT-2)
# Canonical names: "resid_post", "resid_pre", "attn", "mlp", "attn_head"
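One way to picture this kind of resolution is as a suffix match over the recorded activation keys. The following sketch assumes that layer containers are named `layers` or `h`, as in the two examples above, and uses a hypothetical helper `find_component_key`. The actual resolver presumably also maps canonical names such as `attn` onto architecture-specific module names, which this sketch does not attempt.

```python
import re
from typing import Iterable, Optional


def find_component_key(component: str, layer: int, keys: Iterable[str]) -> Optional[str]:
    """Return the first key whose path ends with '<container>.<layer>.<component>'."""
    # Assumption: the layer container is called "layers" or "h".
    pattern = re.compile(rf"(?:^|\.)(?:layers|h)\.{layer}\.{re.escape(component)}$")
    for key in keys:
        if pattern.search(key):
            return key
    return None


# Hypothetical activation keys for two architectures.
llama_keys = [
    "model.model.layers.15.mlp",
    "model.model.layers.5.self_attn",
    "model.model.layers.5.mlp",
]
gpt2_keys = ["model.h.5.attn", "model.h.5.mlp"]

llama_key = find_component_key("mlp", 5, llama_keys)
gpt2_key = find_component_key("mlp", 5, gpt2_keys)
```

Anchoring the layer index between literal dots keeps layer 5 from accidentally matching layer 15.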
Research Workflows¶
Module: mlxterp.workflows
Pre-built multi-step pipelines that chain analysis tools and return WorkflowResult.
Function: behavior_localization¶
Pipeline: DLA → MLP patching → attention patching → head-level patching.
from mlxterp.workflows import behavior_localization
result = behavior_localization(
    model, clean, corrupted,
    metric="l2",
    steps=["dla", "patch_mlp", "patch_attn", "patch_heads"],
    top_k=5,
    verbose=True,
)
print(result.narrative)
mlp_result = result.get_step("patch_mlp")
Function: circuit_discovery¶
Pipeline: attribution patching → activation patching → ACDC.
from mlxterp.workflows import circuit_discovery
result = circuit_discovery(model, clean, corrupted, threshold=0.01)
circuit = result.get_step("acdc")
Function: feature_investigation¶
Pipeline: find active SAE features → ablation → max-activating examples.
from mlxterp.workflows import feature_investigation
result = feature_investigation(model, sae, text, layer=10, dataset=texts)
Class: WorkflowResult¶
Returned by all workflow functions. Extends AnalysisResult.
result.steps # List of (step_name, step_result) tuples
result.narrative # Human-readable investigation narrative
result.get_step(name) # Get a specific step's result
result.to_markdown() # Full markdown report with all steps
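As a rough mental model, the steps list and `get_step` lookup described above behave like an ordered list of named results. The dataclass below is a hypothetical stand-in, not the real `WorkflowResult`: the class name, the toy step payloads, and the choice to raise `KeyError` for a missing step are assumptions.

```python
from dataclasses import dataclass, field
from typing import Any, List, Tuple


@dataclass
class WorkflowResultSketch:
    """Hypothetical stand-in: ordered (step_name, step_result) pairs plus a narrative."""

    steps: List[Tuple[str, Any]] = field(default_factory=list)
    narrative: str = ""

    def get_step(self, name: str) -> Any:
        """Return the result of the first step called `name`."""
        for step_name, step_result in self.steps:
            if step_name == name:
                return step_result
        # Assumption: a missing step is an error rather than None.
        raise KeyError(f"no step named {name!r}")


# Toy payloads standing in for real step results.
demo = WorkflowResultSketch(
    steps=[("dla", {"top_layer": 9}), ("patch_mlp", {"top_layer": 7})],
    narrative="Toy narrative for illustration.",
)
mlp_step = demo.get_step("patch_mlp")
```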
AutoInterp Ratchet Loop¶
Module: mlxterp.autointerpret
Karpathy AutoResearch-style three-file contract for overnight interpretability experiments.
Function: init_autointerpret¶
Generate a project scaffold:
from mlxterp.autointerpret import init_autointerpret
path = init_autointerpret(
    output_dir="my_investigation",
    model_name="mlx-community/Llama-3.2-1B-Instruct-4bit",
    research_question="How does this model recall factual associations?",
    max_experiments=100,
)
# Creates: setup.py, experiment.py, program.md, CLAUDE.md, results.jsonl, findings/
Class: AutoInterpret¶
Programmatic runner for automated experiment loops:
from mlxterp.autointerpret import AutoInterpret
runner = AutoInterpret(model=model, output_dir="exp", max_experiments=50)
entry = runner.run_experiment(
    name="mlp_scan",
    fn=lambda: activation_patching(model, clean, corrupted, component="mlp"),
    hypothesis="MLPs are important",
)
print(runner.summary())
print(f"Done: {runner.is_done}")
Class: MetricRegistry¶
Registry of metrics for experiment use:
from mlxterp.autointerpret import MetricRegistry
metrics = MetricRegistry()
metrics.register("logit_diff", fn, description="Logit difference recovery")
print(metrics.names) # ["logit_diff"]
fn = metrics.get("logit_diff")
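The register/get pattern above amounts to a name-to-callable mapping. A minimal sketch of that pattern follows; `MetricRegistrySketch` is a hypothetical class written for illustration, and details such as insertion-ordered `names` and a `KeyError` on unknown names are assumptions rather than documented behavior.

```python
from typing import Callable, Dict, List


class MetricRegistrySketch:
    """Hypothetical name -> metric function registry."""

    def __init__(self) -> None:
        self._metrics: Dict[str, Callable] = {}
        self._descriptions: Dict[str, str] = {}

    def register(self, name: str, fn: Callable, description: str = "") -> None:
        """Store a metric function under a name, with an optional description."""
        self._metrics[name] = fn
        self._descriptions[name] = description

    @property
    def names(self) -> List[str]:
        """Registered metric names, in insertion order."""
        return list(self._metrics)

    def get(self, name: str) -> Callable:
        """Look up a metric; raises KeyError for unknown names (an assumption)."""
        return self._metrics[name]


def toy_logit_diff(clean_logit: float, patched_logit: float) -> float:
    """Toy metric used only for this example."""
    return patched_logit - clean_logit


registry = MetricRegistrySketch()
registry.register("logit_diff", toy_logit_diff, description="Logit difference recovery")
```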
Class: ExperimentLog¶
Append-only JSONL experiment logging:
from mlxterp.autointerpret import ExperimentLog, ExperimentEntry
log = ExperimentLog("results.jsonl")
log.append(ExperimentEntry(hypothesis="H1", method="patching", conclusion="Found it"))
print(log.summary())
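Append-only JSONL logging means each experiment is written as one JSON object on its own line, and earlier lines are never rewritten. The sketch below shows that pattern with the standard library only; `EntrySketch`, `append_entry`, and `read_entries` are hypothetical names, and the three fields simply mirror the keyword arguments used in the example above.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass
from typing import List


@dataclass
class EntrySketch:
    """Hypothetical entry with the fields used in the example above."""

    hypothesis: str
    method: str
    conclusion: str


def append_entry(path: str, entry: EntrySketch) -> None:
    """Append one JSON object per line; existing lines are never rewritten."""
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(asdict(entry)) + "\n")


def read_entries(path: str) -> List[dict]:
    """Parse every non-empty line of the log back into a dict."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Write to a temporary directory so the example leaves no files behind in the project.
log_path = os.path.join(tempfile.mkdtemp(), "results.jsonl")
append_entry(log_path, EntrySketch("H1", "patching", "Found it"))
append_entry(log_path, EntrySketch("H2", "ablation", "No effect"))
entries = read_entries(log_path)
```

Because each line is independent, a run that crashes partway through still leaves every previously logged experiment readable.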
Automated Interpretability¶
Module: mlxterp.auto_interp
LLM-generated descriptions of SAE features, produced via the Claude API.
Function: auto_label_feature¶
Label a single feature using max-activating examples + LLM:
from mlxterp.auto_interp import auto_label_feature
label = auto_label_feature(model, sae, feature_id=42, texts=dataset, layer=10)
print(f"{label.label}: {label.description} ({label.confidence:.0%})")
Function: auto_label_features¶
Batch label multiple features:
from mlxterp.auto_interp import auto_label_features
labels = auto_label_features(model, sae, texts=dataset, layer=10, top_k_features=20)
Function: sensitivity_test¶
Validate a label via sensitivity testing:
from mlxterp.auto_interp import sensitivity_test
label = sensitivity_test(model, sae, label, test_texts=validation, layer=10)
print(f"Passed: {label.sensitivity_passed}")
Class: FeatureLabel¶
Dataclass for labeled features:
label.feature_id # int
label.label # "short_label"
label.description # "Longer description"
label.confidence # 0.0-1.0
label.evidence # List of max-activating examples
label.sensitivity_passed # True/False/None
Requires: the anthropic package (pip install anthropic) and the ANTHROPIC_API_KEY environment variable.
Report Generation¶
Module: mlxterp.reports
Generate shareable reports from any AnalysisResult.
Function: generate_report¶
from mlxterp.reports import generate_report
# Markdown
md = generate_report(result, title="My Analysis", format="markdown")
# HTML with embedded plots
html = generate_report(results, title="Report", format="html", include_plots=True)
# Include raw JSON appendix
md = generate_report(result, include_json=True)
Function: save_report¶
from mlxterp.reports import save_report
save_report(result, "report.md", title="Analysis Report")
save_report(result, "report.html", title="HTML Report") # Auto-detects format
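Format auto-detection of this kind is usually driven by the file extension. The helper below is a hypothetical sketch of such a mapping, not the logic inside `save_report`; the set of recognized extensions and the markdown fallback are assumptions.

```python
from pathlib import Path


def detect_report_format(path: str, default: str = "markdown") -> str:
    """Map a file extension to a report format name."""
    suffix = Path(path).suffix.lower()
    # Assumption: only these extensions are recognized; anything else uses `default`.
    known = {
        ".md": "markdown",
        ".markdown": "markdown",
        ".html": "html",
        ".htm": "html",
    }
    return known.get(suffix, default)


md_format = detect_report_format("report.md")
html_format = detect_report_format("report.html")
```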
Version History¶
0.1.0 (Current)¶
- Initial release
- Core tracing functionality
- Intervention system
- Basic utility functions
- Support for any MLX model
- New: Attention visualization module
  - Attention weight capture during tracing
  - Pattern extraction and visualization
  - Pattern detection (induction, previous token, first token heads)
  - Multiple backends (matplotlib, plotly, circuitsviz)