mlxterp API Documentation¶
Complete API reference for mlxterp.
InterpretableModel¶
Class: InterpretableModel¶
Main entry point for wrapping MLX models to add interpretability features.
Constructor¶
InterpretableModel(
model: Union[nn.Module, str],
tokenizer: Optional[Any] = None,
layer_attr: str = "layers",
embedding_path: Optional[str] = None,
norm_path: Optional[str] = None,
lm_head_path: Optional[str] = None
)
Parameters:
- model (nn.Module or str): Either an MLX nn.Module instance to wrap, or a model name/path string (mlxterp attempts to load it via mlx_lm.load()).
- tokenizer (Optional[Any]): Tokenizer for processing text inputs. If None and the model is loaded from a string, attempts to load the tokenizer automatically.
- layer_attr (str, default: "layers"): Name of the attribute containing the model's transformer layers. Common values: "layers" (Llama, Mistral), "h" (GPT-2), "transformer.h" (some GPT variants).
- embedding_path (Optional[str], default: None): Override path for the token embedding layer. Used for the weight-tied output projection in get_token_predictions. Auto-detected if not specified. Tried paths: model.embed_tokens, model.model.embed_tokens, embed_tokens, tok_embeddings, wte.
- norm_path (Optional[str], default: None): Override path for the final layer normalization. Used in logit_lens for projecting intermediate activations. Auto-detected if not specified. Tried paths: model.norm, model.model.norm, norm, ln_f, model.ln_f.
- lm_head_path (Optional[str], default: None): Override path for the output projection layer. If not found, falls back to the weight-tied embedding. Tried paths: lm_head, model.lm_head, model.model.lm_head, output, head.
Returns: InterpretableModel instance
Example:
from mlxterp import InterpretableModel
# Load from model name (auto-detection works)
model = InterpretableModel("mlx-community/Llama-3.2-1B-Instruct")
# Wrap existing model
import mlx.nn as nn
base_model = nn.Module() # Your model
model = InterpretableModel(base_model, tokenizer=my_tokenizer)
# Custom layer attribute
model = InterpretableModel(gpt2_model, layer_attr="h")
# Custom model with non-standard attribute names
model = InterpretableModel(
custom_model,
tokenizer=my_tokenizer,
embedding_path="my_custom_embeddings",
norm_path="my_final_norm",
lm_head_path="my_output_projection"
)
Method: trace¶
Create a tracing context for capturing activations and applying interventions.
model.trace(
inputs: Union[str, List[str], mx.array, List[int]],
interventions: Optional[Dict[str, Callable]] = None
) -> Trace
Parameters:
- inputs: Input data in various formats:
  - str: Single text prompt (requires tokenizer)
  - List[str]: Batch of text prompts (requires tokenizer)
  - mx.array: Token array, shape (batch, seq_len)
  - List[int]: Single sequence of token IDs
- interventions (Optional[Dict[str, Callable]]): Dictionary mapping module names to intervention functions. Module names use dot notation (e.g., "layers.3.self_attn" for mlx-lm models).
Returns: Trace context manager
Example:
# Basic tracing
with model.trace("Hello world"):
output = model.output.save()
# With interventions
from mlxterp import interventions as iv
with model.trace("Test", interventions={"layers.3": iv.scale(0.5)}):
output = model.output.save()
# Batch inputs
with model.trace(["Hello", "World"]):
acts = model.layers[3].output.save()
Method: encode¶
Encode text to token IDs using the model's tokenizer.
Parameters:
- text (str): Text string to encode
Returns: List of token IDs
Raises: ValueError if no tokenizer is available
Example:
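tokens = model.encode("Hello world")
print(tokens)  # [128000, 9906, 1917] with a Llama-3 tokenizer; IDs vary by model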
Method: decode¶
Decode token IDs to text using the model's tokenizer.
Parameters:
- tokens (List[int] or mx.array): Token IDs to decode
Returns: Decoded text string
Raises: ValueError if no tokenizer is available
Example:
text = model.decode([128000, 9906, 1917])
print(text) # "<|begin_of_text|>Hello world"
# Also works with mx.array
import mlx.core as mx
tokens_array = mx.array([128000, 9906, 1917])
text = model.decode(tokens_array)
Method: encode_batch¶
Encode multiple texts to token IDs.
Parameters:
- texts (List[str]): List of text strings to encode
Returns: List of token ID lists
Raises: ValueError if no tokenizer is available
Example:
token_lists = model.encode_batch(["Hello", "World", "Test"])
print(token_lists)
# [[128000, 9906], [128000, 10343], [128000, 2323]]
Method: token_to_str¶
Convert a single token ID to its string representation.
Parameters:
- token_id (int): Token ID to decode
Returns: String representation of the token
Raises: ValueError if no tokenizer is available
Example:
# Decode individual tokens
tokens = model.encode("Hello world")
for i, token_id in enumerate(tokens):
token_str = model.token_to_str(token_id)
print(f"Token {i}: {token_id} -> '{token_str}'")
# Output:
# Token 0: 128000 -> '<|begin_of_text|>'
# Token 1: 9906 -> 'Hello'
# Token 2: 1917 -> ' world'
Property: vocab_size¶
Get the vocabulary size of the tokenizer.
Returns: Vocabulary size, or None if no tokenizer is available
Example:
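print(model.vocab_size)  # e.g. 128256 for Llama-3.2 models; None if no tokenizer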
Attribute: tokenizer¶
Direct access to the underlying tokenizer for advanced operations.
Type: Tokenizer object (varies by model)
Example:
# Access tokenizer directly for advanced features
tokenizer = model.tokenizer
# Use tokenizer-specific methods
if hasattr(tokenizer, 'special_tokens'):
print(tokenizer.special_tokens)
Method: get_token_predictions¶
Decode hidden states to token predictions using the model's output projection.
model.get_token_predictions(
hidden_state: mx.array,
top_k: int = 10,
return_scores: bool = False,
embedding_layer: Optional[Any] = None,
lm_head: Optional[Any] = None
) -> Union[List[int], List[tuple]]
Parameters:
- hidden_state (mx.array): Hidden state tensor, shape (hidden_dim,) or (batch, hidden_dim)
- top_k (int, default: 10): Number of top predictions to return
- return_scores (bool, default: False): If True, return (token_id, score) tuples
- embedding_layer (Optional[Any], default: None): Override embedding layer for the weight-tied projection. If provided, this layer's weights are used for the output projection.
- lm_head (Optional[Any], default: None): Override lm_head layer. If provided, this layer is used directly. Takes precedence over embedding_layer.
Returns: List of token IDs or (token_id, score) tuples
Example:
# Get predictions from a specific layer
with model.trace("The capital of France is") as trace:
layer_6 = trace.activations["model.model.layers.6"]
# Get last token's hidden state
last_token_hidden = layer_6[0, -1, :]
# Get top predictions
predictions = model.get_token_predictions(last_token_hidden, top_k=5)
# Decode to words
for token_id in predictions:
print(model.token_to_str(token_id))
# With scores
predictions_with_scores = model.get_token_predictions(
last_token_hidden,
top_k=5,
return_scores=True
)
for token_id, score in predictions_with_scores:
token_str = model.token_to_str(token_id)
print(f"{token_str}: {score:.2f}")
# Custom model with override at call time
predictions = model.get_token_predictions(
hidden,
top_k=5,
lm_head=custom_model.my_lm_head
)
Notes:
- Automatically handles weight-tied models (uses the embedding weights transposed)
- Works with quantized embeddings (dequantizes automatically)
- Model-agnostic: auto-detects embedding/lm_head paths for various architectures
- Useful for analyzing what any layer "thinks" at a specific position
Method: logit_lens¶
Apply logit lens technique to see what each layer predicts at each token position.
The logit lens projects each layer's hidden states through the final layer norm and embedding matrix to see what tokens each layer predicts at each position in the input sequence.
model.logit_lens(
text: str,
top_k: int = 1,
layers: Optional[List[int]] = None,
position: Optional[int] = None,
plot: bool = False,
max_display_tokens: int = 15,
figsize: tuple = (16, 10),
cmap: str = 'viridis',
font_family: Optional[str] = None,
final_norm: Optional[Any] = None,
skip_norm: bool = False
) -> Dict[int, List[List[tuple]]]
Parameters:
- text (str): Input text to analyze
- top_k (int, default: 1): Number of top predictions to return per position
- layers (Optional[List[int]], default: None): Specific layers to analyze (None = all)
- position (Optional[int], default: None): Specific position to analyze (None = all). Supports negative indexing (-1 = last position).
- plot (bool, default: False): If True, display a heatmap visualization showing predictions
- max_display_tokens (int, default: 15): Maximum number of tokens to show in the visualization (from the end)
- figsize (tuple, default: (16, 10)): Figure size for the plot (width, height)
- cmap (str, default: 'viridis'): Colormap for the heatmap
- font_family (Optional[str], default: None): Font for the plot (use 'Arial Unicode MS' for CJK support)
- final_norm (Optional[Any], default: None): Override final layer normalization module. If provided, this module is used instead of the auto-detected norm layer.
- skip_norm (bool, default: False): If True, skip final layer normalization entirely. Useful for models without a final norm layer.
Returns: Dict mapping layer_idx -> list of positions -> list of (token_id, score, token_str) tuples
Structure: {layer_idx: [[pos_0_predictions], [pos_1_predictions], ...]}
Model Compatibility: This method automatically detects model structure and works with:
- mlx-lm models (model.model.norm, model.model.embed_tokens)
- GPT-2 style models (ln_f, wte)
- Custom models (use norm_path constructor arg or final_norm/skip_norm parameters)
Example:
# Get predictions at all positions for all layers
results = model.logit_lens("The capital of France is")
# Access predictions for layer 10 at position 3
layer_10_predictions = results[10]
pos_3_top_pred = layer_10_predictions[3][0] # (token_id, score, token_str)
print(f"Layer 10, Position 3: {pos_3_top_pred[2]}")
# Show what each layer predicts at the LAST position
text = "The capital of France is"
results = model.logit_lens(text, layers=[0, 5, 10, 15])
for layer_idx in [0, 5, 10, 15]:
# Get prediction at last position
last_pos_pred = results[layer_idx][-1][0][2]
print(f"Layer {layer_idx}: '{last_pos_pred}'")
# Output:
# Layer 0: ' the'
# Layer 5: ' a'
# Layer 10: ' Paris'
# Layer 15: ' Paris'
# Show predictions at each position for a specific layer
text = "The capital of France"
results = model.logit_lens(text, layers=[10])
tokens = model.encode(text)
for pos_idx, predictions in enumerate(results[10]):
input_token = model.token_to_str(tokens[pos_idx])
pred_token = predictions[0][2] # Top prediction
print(f"Position {pos_idx} ('{input_token}') -> '{pred_token}'")
# Visualize with heatmap
results = model.logit_lens(
"The Eiffel Tower is located in the city of",
plot=True,
max_display_tokens=15,
figsize=(16, 10)
)
# Displays a heatmap with:
# - X-axis: Input token positions
# - Y-axis: Model layers
# - Cell values: Top predicted token at each (layer, position)
# - Colors: Different predictions shown with different colors
# Model without final normalization
results = model.logit_lens("Hello world", skip_norm=True)
# Custom final norm at call time
results = model.logit_lens(
"Hello world",
final_norm=custom_model.my_final_norm
)
Note: Plotting requires matplotlib: pip install matplotlib
Use Cases:
- Understand how model predictions evolve through layers
- Debug model behavior at intermediate layers
- Visualize progressive refinement of predictions
- Identify where in the model certain facts are computed
Method: tuned_lens¶
Apply tuned lens for improved layer-wise predictions.
The tuned lens technique (Belrose et al., 2023) uses learned affine transformations for each layer to correct for coordinate system mismatches between layers, producing more accurate intermediate predictions than the standard logit lens.
model.tuned_lens(
text: str,
tuned_lens: TunedLens,
top_k: int = 1,
layers: Optional[List[int]] = None,
position: Optional[int] = None,
plot: bool = False,
max_display_tokens: int = 15,
figsize: tuple = (16, 10),
cmap: str = 'viridis',
font_family: Optional[str] = None,
final_norm: Any = None,
skip_norm: bool = False
) -> Dict[int, List[List[tuple]]]
Parameters:
- text (str): Input text to analyze
- tuned_lens (TunedLens): Trained TunedLens instance with layer translators
- top_k (int, default: 1): Number of top predictions to return per position
- layers (Optional[List[int]], default: None): Specific layers to analyze (None = all)
- position (Optional[int], default: None): Specific position to analyze (None = all). Supports negative indexing.
- plot (bool, default: False): If True, display a heatmap visualization
- max_display_tokens (int, default: 15): Maximum number of tokens to show in the visualization
- figsize (tuple, default: (16, 10)): Figure size for the plot
- cmap (str, default: 'viridis'): Colormap for the heatmap
- font_family (Optional[str]): Font for the plot (auto-detected if None)
- final_norm (Any, default: None): Override for the final layer norm. Pass a callable to use a custom norm.
- skip_norm (bool, default: False): If True, skip final layer normalization (for models without it)
Returns: Dict mapping layer_idx -> list of positions -> list of (token_id, score, token_str) tuples
Example:
from mlxterp import InterpretableModel, TunedLens
model = InterpretableModel("mlx-community/Llama-3.2-1B-Instruct")
# Option 1: Train new tuned lens
tuned_lens = model.train_tuned_lens(
dataset=["Sample text 1", "Sample text 2", ...],
num_steps=250,
save_path="tuned_lens_llama.npz"
)
# Option 2: Load pre-trained tuned lens
tuned_lens = model.load_tuned_lens("tuned_lens_llama.npz")
# Apply tuned lens
results = model.tuned_lens(
"The capital of France is",
tuned_lens,
layers=[0, 5, 10, 15],
plot=True
)
# Compare with regular logit lens
regular_results = model.logit_lens("The capital of France is", layers=[0, 5, 10, 15])
Reference: Belrose et al., "Eliciting Latent Predictions from Transformers with the Tuned Lens" (https://arxiv.org/abs/2303.08112)
Method: train_tuned_lens¶
Train a tuned lens for this model.
The tuned lens technique trains small affine transformations for each layer to correct for coordinate system mismatches, producing more accurate intermediate predictions.
model.train_tuned_lens(
dataset: List[str],
num_steps: int = 250,
learning_rate: float = 1.0,
momentum: float = 0.9,
max_seq_len: int = 2048,
gradient_clip: float = 1.0,
save_path: Optional[str] = None,
verbose: bool = True,
callback: Optional[Callable[[int, float], None]] = None
) -> TunedLens
Parameters:
- dataset (List[str]): List of text strings for training
- num_steps (int, default: 250): Number of training steps
- learning_rate (float, default: 1.0): Initial learning rate (uses linear decay)
- momentum (float, default: 0.9): Nesterov momentum coefficient
- max_seq_len (int, default: 2048): Maximum sequence length for training chunks
- gradient_clip (float, default: 1.0): Gradient clipping norm
- save_path (Optional[str]): Path to save trained weights
- verbose (bool, default: True): Print training progress
- callback (Optional[Callable]): Callback function called with (step, loss) after each step
Returns: Trained TunedLens instance
Raises:
- ValueError: If dataset is empty or contains only whitespace
- ValueError: If dataset has fewer tokens than max_seq_len
- ValueError: If max_seq_len is less than 10 tokens
- ValueError: If the model hidden dimension cannot be determined
Training Details (from the paper):
- Optimizer: SGD with Nesterov momentum (0.9)
- Learning rate: 1.0 with linear decay over training steps
- Gradient clipping: norm 1.0
- Loss: KL divergence between the translator prediction and the final output
Example:
# Load sample texts for training
texts = [
"The capital of France is Paris.",
"Machine learning is a subset of artificial intelligence.",
# ... more training texts
]
# Train tuned lens
tuned_lens = model.train_tuned_lens(
dataset=texts,
num_steps=250,
save_path="my_tuned_lens.npz",
verbose=True
)
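The callback parameter can be used for custom logging; a minimal sketch (losses is just a local list):
# Record (step, loss) pairs during training
losses = []
tuned_lens = model.train_tuned_lens(
    dataset=texts,
    num_steps=250,
    verbose=False,
    callback=lambda step, loss: losses.append((step, loss))
)
print(f"Final loss: {losses[-1][1]:.4f}")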
Method: load_tuned_lens¶
Load a pre-trained tuned lens from a file.
Parameters:
- path (str): Path to the saved tuned lens weights (expects .npz and .json files)
Returns: Loaded TunedLens instance
Example:
tuned_lens = model.load_tuned_lens("tuned_lens_llama.npz")
results = model.tuned_lens("Hello world", tuned_lens)
Method: activation_patching¶
Automated activation patching to identify important layers for a task.
This helper method performs activation patching across all (or specified) layers to determine which components are critical for a specific task. It automates the boilerplate of running clean/corrupted inputs, patching activations, and measuring recovery.
model.activation_patching(
clean_text: str,
corrupted_text: str,
component: str = "mlp",
layers: Optional[List[int]] = None,
metric: str = "l2",
plot: bool = False,
figsize: tuple = (12, 8),
cmap: str = "RdBu_r"
) -> Dict[int, float]
Parameters:
- clean_text (str): Clean/correct input text
- corrupted_text (str): Corrupted input text (differing in the aspect you're studying)
- component (str, default: "mlp"): Component to patch. Options:
  - "mlp": Full MLP block
  - "self_attn": Full attention block
  - "mlp.gate_proj": MLP gate projection
  - "mlp.up_proj": MLP up projection
  - "mlp.down_proj": MLP down projection
  - "self_attn.q_proj": Query projection
  - "self_attn.k_proj": Key projection
  - "self_attn.v_proj": Value projection
  - "self_attn.o_proj": Output projection
- layers (Optional[List[int]], default: None): Specific layers to test (None = all layers)
- metric (str, default: "l2"): Distance metric. Options:
  - "l2": Euclidean distance (default, with overflow protection)
  - "cosine": Cosine distance (recommended for large vocabularies)
  - "mse": Mean squared error (most stable for huge models, > 100k vocab)
  Recommendation: vocab < 50k, use "l2"; vocab 50k-100k, use "l2" or "cosine"; vocab > 100k, use "mse" or "cosine".
- plot (bool, default: False): If True, display a bar chart of recovery percentages
- figsize (tuple, default: (12, 8)): Figure size for the plot
- cmap (str, default: "RdBu_r"): Colormap for the plot (blue = positive, red = negative)
Returns: Dict mapping layer_idx -> recovery percentage
Recovery Interpretation:
- Positive % (e.g., +40%): Layer is important for the task
- Negative % (e.g., -20%): Layer encodes the corruption
- ~0%: Layer is not relevant to this task
Example:
# Find which MLP layers are important for factual recall
results = model.activation_patching(
clean_text="Paris is the capital of France",
corrupted_text="London is the capital of France",
component="mlp",
plot=True
)
# Analyze results
sorted_results = sorted(results.items(), key=lambda x: x[1], reverse=True)
print("\nMost important layers:")
for layer_idx, recovery in sorted_results[:3]:
print(f" Layer {layer_idx}: {recovery:.1f}% recovery")
# Output:
# Layer 0: +43.1% recovery ← Very important!
# Layer 15: +24.2% recovery ← Important
# Layer 6: +17.6% recovery ← Somewhat important
# Test specific layers only
results = model.activation_patching(
clean_text="Paris is the capital of France",
corrupted_text="London is the capital of France",
component="self_attn",
layers=[0, 5, 10, 15],
plot=False
)
# Test MLP sub-components
results = model.activation_patching(
clean_text="The cat sits on the mat",
corrupted_text="The cat sit on the mat",
component="mlp.gate_proj",
layers=[3, 8, 12]
)
# For large vocabulary models (> 100k tokens), use MSE metric
results = model.activation_patching(
clean_text="Paris is the capital of France",
corrupted_text="London is the capital of France",
component="mlp",
metric="mse", # Most stable for Qwen, GPT-4 scale models
plot=True
)
Note: Plotting requires matplotlib: pip install matplotlib
Distance Metrics:
The method supports three distance metrics for measuring output differences:
- L2 (Euclidean): Default; best for models with < 50k vocab
- Cosine: Best for direction-based similarity; good for 50k-150k vocab
- MSE (Mean Squared Error): Most stable for very large models (> 100k vocab)
Example with large vocabulary model (Qwen: 151k tokens):
# Without correct metric - gets NaN
results = model.activation_patching(..., metric="l2") # ❌ May overflow
# With correct metric - works perfectly
results = model.activation_patching(..., metric="mse") # ✅ Stable
See Also: Activation Patching Guide for comprehensive coverage including metric selection and numerical details
Method: named_modules¶
Iterator over all modules in the wrapped model with their names.
Returns: Iterator yielding (name, module) tuples
Example:
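# Enumerate every module in the wrapped model
for name, module in model.named_modules():
    print(name, type(module).__name__)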
Attribute: layers¶
Indexed access to model layers via LayerListProxy.
Type: LayerListProxy
Example:
# Access specific layer
layer_3 = model.layers[3]
# Iterate over layers
for layer in model.layers:
print(layer)
# Get number of layers
num_layers = len(model.layers)
Attribute: output¶
Access to model output within a trace context. Returns an OutputProxy that can be saved.
Type: OutputProxy
Example:
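with model.trace("Hello world"):
    logits = model.output.save()
print(logits.shape)  # (batch, seq_len, vocab_size)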
Tracing¶
Class: Trace¶
Context manager for tracing model execution. Created by InterpretableModel.trace().
Attributes¶
- output (mx.array): Model output after the trace completes
- saved_values (Dict[str, mx.array]): Values saved with .save()
- activations (Dict[str, mx.array]): All captured activations, keyed by module name
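Example:
with model.trace("Hello world") as trace:
    pass
# Inspect everything the trace captured
for name, act in trace.activations.items():
    print(name, act.shape)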
Methods¶
get(name: str) -> Optional[mx.array]¶
Get a saved value by name.
with model.trace(input) as trace:
model.layers[3].output.save()
activation = trace.get("layers.3.output")
get_activation(name: str) -> Optional[mx.array]¶
Get an activation by module name.
with model.trace(input) as trace:
pass # Activations captured automatically
# Note: get_activation requires the full key (not normalized)
attn_act = trace.get_activation("model.model.layers.3.self_attn")
Class: OutputProxy¶
Wraps module outputs to provide .save() functionality.
Method: save() -> Any¶
Save the wrapped value to the current trace context and return the unwrapped value.
Returns: The underlying value (usually mx.array)
Example:
with model.trace(input):
# Save returns the actual array (use self_attn for mlx-lm models)
attn = model.layers[3].self_attn.output.save()
print(attn.shape) # Can use immediately
Interventions¶
Intervention functions modify activations during forward passes.
Module: mlxterp.interventions¶
Namespace containing pre-built intervention functions.
from mlxterp import interventions as iv
# Use intervention functions
with model.trace(input, interventions={"layers.3": iv.scale(0.5)}):
output = model.output.save()
Function: zero_out¶
Set activations to zero.
Example:
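# Ablate layer 3's output entirely. zero_out takes no parameters; if your
# version exposes it as a factory, call iv.zero_out() instead.
with model.trace(input, interventions={"layers.3": iv.zero_out}):
    output = model.output.save()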
Function: scale¶
Multiply activations by a constant factor.
Parameters:
- factor (float): Scaling factor
Example:
# Reduce by 50%
with model.trace(input, interventions={"layers.3": iv.scale(0.5)}):
output = model.output.save()
# Amplify
with model.trace(input, interventions={"layers.3": iv.scale(2.0)}):
output = model.output.save()
Function: add_vector¶
Add a steering vector to activations.
Parameters:
- vector (mx.array): Vector to add (must be broadcastable to activation shape)
Example:
import mlx.core as mx
# Create steering vector
steering = mx.random.normal((hidden_dim,))
with model.trace(input, interventions={"layers.5": iv.add_vector(steering)}):
steered_output = model.output.save()
Function: replace_with¶
Replace activations with a fixed value.
Parameters:
- value: Replacement value (array or scalar)
Example:
# Replace with zeros
with model.trace(input, interventions={"layers.3": iv.replace_with(0.0)}):
output = model.output.save()
# Replace with custom array
custom = mx.ones((batch, seq_len, hidden_dim))
with model.trace(input, interventions={"layers.3": iv.replace_with(custom)}):
output = model.output.save()
Function: clamp¶
Clamp activation values to a range.
Parameters:
- min_val (Optional[float]): Minimum value
- max_val (Optional[float]): Maximum value
Example:
# Clamp to [-1, 1]
with model.trace(input, interventions={"layers.3": iv.clamp(-1.0, 1.0)}):
output = model.output.save()
# Only maximum
with model.trace(input, interventions={"layers.3": iv.clamp(max_val=10.0)}):
output = model.output.save()
Function: noise¶
Add Gaussian noise to activations.
Parameters:
- std (float, default: 0.1): Standard deviation of noise
Example:
with model.trace(input, interventions={"layers.3": iv.noise(std=0.2)}):
output = model.output.save()
Class: InterventionComposer¶
Compose multiple interventions into a single function.
Method: add¶
Add an intervention to the composition.
Returns: self for chaining
Method: build¶
Build the composed intervention function.
Returns: Composed intervention function
Example:
from mlxterp import interventions as iv
# Compose multiple interventions
combined = iv.compose() \
.add(iv.scale(0.8)) \
.add(iv.noise(std=0.1)) \
.add(iv.clamp(-5.0, 5.0)) \
.build()
with model.trace(input, interventions={"layers.3": combined}):
output = model.output.save()
Custom Interventions¶
Create your own intervention functions:
import mlx.core as mx
def my_intervention(activation: mx.array) -> mx.array:
"""Custom activation modification"""
# Your logic here
return mx.tanh(activation)
with model.trace(input, interventions={"layers.3": my_intervention}):
output = model.output.save()
Requirements:
- Function signature: (mx.array) -> mx.array
- Must return array with same shape as input
- Can use any MLX operations
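Because an intervention receives the full activation tensor, it can also target specific token positions. A minimal sketch, assuming activations of shape (batch, seq_len, hidden_dim); zero_last_token is illustrative, not part of the API:
import mlx.core as mx
def zero_last_token(activation: mx.array) -> mx.array:
    """Zero out only the final token position."""
    kept = activation[:, :-1, :]                      # all positions but the last
    zeros = mx.zeros_like(activation[:, -1:, :])      # zeros for the last position
    return mx.concatenate([kept, zeros], axis=1)
with model.trace(input, interventions={"layers.3": zero_last_token}):
    output = model.output.save()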
Utilities¶
Function: get_activations¶
Collect activations for specified layers and token positions.
get_activations(
model: InterpretableModel,
prompts: Union[str, List[str]],
layers: Optional[List[int]] = None,
positions: Union[int, List[int]] = -1
) -> Dict[str, mx.array]
Parameters:
- model: InterpretableModel instance
- prompts: Single prompt or list of prompts
- layers: Layer indices to collect (None = all layers)
- positions: Token position(s) to extract:
  - -1: Last token
  - 0: First token
  - [0, -1]: First and last tokens
Returns: Dict mapping "layer_{i}" to activation arrays
Shapes:
- Single position: (batch_size, hidden_dim)
- Multiple positions: (batch_size, num_positions, hidden_dim)
Example:
from mlxterp import get_activations
# Single prompt, multiple layers
acts = get_activations(model, "Hello world", layers=[3, 8, 12])
print(acts["layer_3"].shape) # (1, hidden_dim)
# Batch prompts
acts = get_activations(
model,
["Hello", "World", "Test"],
layers=[5],
positions=-1
)
print(acts["layer_5"].shape) # (3, hidden_dim)
# Multiple positions
acts = get_activations(
model,
"Test prompt",
layers=[3],
positions=[0, -1] # First and last token
)
print(acts["layer_3"].shape) # (1, 2, hidden_dim)
Function: batch_get_activations¶
Memory-efficient batch processing for large datasets.
batch_get_activations(
model: InterpretableModel,
prompts: List[str],
layers: Optional[List[int]] = None,
positions: Union[int, List[int]] = -1,
batch_size: int = 8
) -> Dict[str, mx.array]
Parameters:
- model: InterpretableModel instance
- prompts: List of prompts
- layers: Layer indices to collect
- positions: Token position(s) to extract
- batch_size: Number of prompts per batch
Returns: Dict mapping "layer_{i}" to concatenated activation arrays
Example:
from mlxterp import batch_get_activations
# Process 1000 prompts efficiently
large_dataset = [f"Prompt {i}" for i in range(1000)]
acts = batch_get_activations(
model,
prompts=large_dataset,
layers=[3, 8, 12],
batch_size=32
)
print(acts["layer_3"].shape) # (1000, hidden_dim)
Function: collect_activations¶
Direct activation collection with caching.
collect_activations(
model: InterpretableModel,
inputs: Any,
layers: Optional[List[str]] = None
) -> ActivationCache
Parameters:
- model: InterpretableModel instance
- inputs: Input data
- layers: List of layer names to cache (None = all)
Returns: ActivationCache object
Example:
from mlxterp import collect_activations
cache = collect_activations(
model,
"Test input",
layers=["layers.3", "layers.8"]
)
# Access cached activations
act_3 = cache.get("layers.3")
print(f"Cached {len(cache)} activations")
print(f"Available keys: {cache.keys()}")
Core Components¶
Advanced usage: Direct access to core components.
Class: ModuleProxy¶
Wraps nn.Module to intercept forward passes. Created automatically by InterpretableModel.
Attributes:
- output: OutputProxy for the module's output
Example:
# Access through InterpretableModel.layers
proxy = model.layers[3] # Returns ModuleProxy
print(type(proxy)) # ModuleProxy
with model.trace(input):
act = proxy.output.save()
Class: LayerListProxy¶
Provides indexed access to model layers.
Methods:
- __getitem__(idx): Get layer at index
- __len__(): Number of layers
- __iter__(): Iterate over layers
Example:
# Created automatically
layers = model.layers
# Access
layer_3 = layers[3]
# Length
print(len(layers)) # 12
# Iterate
for i, layer in enumerate(layers):
print(f"Layer {i}: {layer}")
Class: ActivationCache¶
Storage for cached activations.
Attributes:
- activations (Dict[str, mx.array]): Activation storage
- metadata (Optional[Dict]): Additional information
Methods:
- get(name): Get activation by name
- keys(): List all cached names
- __contains__(name): Check if activation exists
- __len__(): Number of cached activations
Example:
from mlxterp import collect_activations
cache = collect_activations(model, input)
# Access
act = cache.get("layers.3")
# Check existence
if "layers.3" in cache:
print("Found!")
# List all
for name in cache.keys():
print(name)
Class: ModuleResolver¶
Generic module resolution for different MLX model architectures. Automatically finds embedding, final norm, and lm_head modules using fallback chains.
Constructor:
ModuleResolver(
model: nn.Module,
embedding_path: Optional[str] = None,
norm_path: Optional[str] = None,
lm_head_path: Optional[str] = None
)
Methods:
- get_embedding_layer(): Get the token embedding layer
- get_final_norm(): Get the final layer normalization
- get_lm_head(): Get the output projection layer
- get_output_projection(): Get the output projection with weight-tied detection. Returns (module, path, is_weight_tied)
- clear_cache(): Clear the resolved-module cache (call after modifying the model structure)
Fallback Chains:
| Component | Resolution Order |
|---|---|
| Embedding | model.embed_tokens, model.model.embed_tokens, embed_tokens, tok_embeddings, wte |
| Final Norm | model.norm, model.model.norm, norm, ln_f, model.ln_f |
| LM Head | lm_head, model.lm_head, model.model.lm_head, output, head (falls back to embedding if not found) |
Example:
from mlxterp.core import ModuleResolver
# Create resolver for a model
resolver = ModuleResolver(model)
# Get components
embedding = resolver.get_embedding_layer()
norm = resolver.get_final_norm()
proj, path, is_tied = resolver.get_output_projection()
if is_tied:
print("Model uses weight-tied embedding for output")
# With custom paths
resolver = ModuleResolver(
model,
embedding_path="my_embed",
norm_path="my_norm"
)
# Cache invalidation (after modifying model structure)
resolver.clear_cache() # Force re-resolution on next access
Class: TunedLens¶
Learned affine translators for each layer, implementing the Tuned Lens technique from Belrose et al. (2023).
The tuned lens uses layer-specific affine transformations (Wx + b) to map hidden states from each layer into a space where the final output projection can make accurate predictions. This corrects for coordinate system mismatches between layers.
Parameters:
- num_layers (int): Number of transformer layers in the model
- hidden_dim (int): Dimension of hidden states
Attributes:
- num_layers: Number of layers
- hidden_dim: Hidden dimension
- translators: List of linear layers, one per transformer layer
Methods:
| Method | Description |
|---|---|
| __call__(h, layer_idx) | Apply the translator for a specific layer |
| save(path) | Save weights and config to files (.npz and .json) |
| load(path) | Load a tuned lens from saved files (classmethod) |
Example:
from mlxterp import TunedLens
# Create tuned lens
tuned_lens = TunedLens(num_layers=32, hidden_dim=4096)
# Apply to hidden state from layer 10
translated = tuned_lens(hidden_state, layer_idx=10)
# Save and load
tuned_lens.save("my_tuned_lens")
loaded = TunedLens.load("my_tuned_lens")
Reference: Belrose et al., "Eliciting Latent Predictions from Transformers with the Tuned Lens" (https://arxiv.org/abs/2303.08112)
Function: normalize_layer_key¶
Normalize activation keys by removing model prefixes.
Example:
from mlxterp.core import normalize_layer_key
normalize_layer_key("model.model.layers.0") # "layers.0"
normalize_layer_key("model.layers.5.self_attn") # "layers.5.self_attn"
Function: find_layer_key_pattern¶
Find the correct activation key pattern for a layer index.
find_layer_key_pattern(
activations: dict,
layer_idx: int,
component: Optional[str] = None
) -> Optional[str]
Example:
from mlxterp.core import find_layer_key_pattern
# Find layer 5's key in activations dict
key = find_layer_key_pattern(trace.activations, 5)
# Returns "model.model.layers.5" or "layers.5" etc.
# Find specific component
key = find_layer_key_pattern(trace.activations, 5, "self_attn")
Type Annotations¶
Common types used in mlxterp:
from typing import Union, List, Dict, Callable, Optional, Any
import mlx.core as mx
import mlx.nn as nn
# Input types
InputType = Union[str, List[str], mx.array, List[int]]
# Intervention function type
InterventionFn = Callable[[mx.array], mx.array]
# Interventions dict
InterventionsDict = Dict[str, InterventionFn]
# Model type
ModelType = Union[nn.Module, str]
Error Handling¶
Common Exceptions¶
ValueError¶
Raised when:
- String input is provided without a tokenizer
- Input format is invalid
- The model cannot be loaded from a string
# Will raise ValueError
model = InterpretableModel(base_model) # No tokenizer
with model.trace("text input"): # Needs tokenizer!
pass
Solution: Provide a tokenizer:
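# Fix: pass the tokenizer at construction time
model = InterpretableModel(base_model, tokenizer=my_tokenizer)
with model.trace("text input"):  # Works now
    pass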
AttributeError¶
Raised when:
- Accessing a non-existent module attribute
- The layer attribute doesn't exist
# Will raise AttributeError
model = InterpretableModel(custom_model, layer_attr="transformer")
# If custom_model doesn't have 'transformer' attribute
Solution: Specify the correct layer attribute:
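# Fix: point layer_attr at the real container, e.g. "h" for GPT-2 style models
model = InterpretableModel(custom_model, layer_attr="h")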
Best Practices¶
1. Always Use Context Managers¶
# ✅ Good
with model.trace(input):
act = model.layers[3].output.save()
# ❌ Avoid
trace = model.trace(input)
# Missing context manager!
2. Save Early, Access Later¶
# ✅ Good
with model.trace(input) as t:
model.layers[3].output.save()
# Access after trace completes
act = t.get("layers.3.output")
# ❌ Avoid trying to access during trace
with model.trace(input) as t:
act = t.get("layers.3.output") # Not saved yet!
3. Use Utility Functions for Common Tasks¶
# ✅ Good - Use get_activations
from mlxterp import get_activations
acts = get_activations(model, prompts, layers=[3, 8])
# ❌ Avoid manual loops
acts = {}
for layer_idx in [3, 8]:
with model.trace(prompts):
acts[layer_idx] = model.layers[layer_idx].output.save()
4. Batch Large Datasets¶
# ✅ Good - Use batching
from mlxterp import batch_get_activations
acts = batch_get_activations(model, large_list, batch_size=32)
# ❌ Avoid loading everything at once
with model.trace(large_list): # May run out of memory
acts = model.layers[3].output.save()
Visualization Module¶
The visualization module provides tools for attention pattern analysis and visualization.
from mlxterp.visualization import (
# Attention extraction and visualization
get_attention_patterns,
attention_heatmap,
attention_from_trace,
AttentionVisualizationConfig,
# Pattern detection
AttentionPatternDetector,
detect_head_types,
detect_induction_heads,
induction_score,
previous_token_score,
first_token_score,
copying_score,
)
Function: get_attention_patterns¶
Extract attention weight patterns from a trace.
Parameters:
- trace (Trace): Completed trace context with captured attention weights
- layers (Optional[List[int]], default: None): Specific layers to extract (None = all)
Returns: Dict mapping layer_idx -> attention array of shape (batch, heads, seq_len, seq_len)
Example:
from mlxterp import InterpretableModel
from mlxterp.visualization import get_attention_patterns
model = InterpretableModel("mlx-community/Llama-3.2-1B-Instruct-4bit")
with model.trace("The capital of France is") as trace:
pass
# Get all attention patterns
patterns = get_attention_patterns(trace)
print(f"Found {len(patterns)} layers")
print(f"Shape: {patterns[0].shape}") # (1, 32, 6, 6)
# Get specific layers
patterns = get_attention_patterns(trace, layers=[0, 5, 10])
Function: attention_heatmap¶
Create a heatmap visualization of attention patterns.
attention_heatmap(
attention: np.ndarray,
tokens: List[str],
head_idx: int = 0,
title: Optional[str] = None,
colorscale: str = "Blues",
backend: str = "matplotlib",
mask_upper_tri: bool = True,
figsize: tuple = (8, 6)
) -> Any
Parameters:
- attention (np.ndarray): Attention weights, shape (batch, heads, seq_q, seq_k)
- tokens (List[str]): Token strings for axis labels
- head_idx (int, default: 0): Which attention head to visualize
- title (Optional[str]): Plot title
- colorscale (str, default: "Blues"): Colormap name
- backend (str, default: "matplotlib"): Backend ("matplotlib", "plotly", "circuitsviz")
- mask_upper_tri (bool, default: True): Mask future positions (for causal attention)
- figsize (tuple, default: (8, 6)): Figure size for matplotlib
Returns: Figure object (type depends on backend)
Example:
from mlxterp.visualization import get_attention_patterns, attention_heatmap
with model.trace("Hello world") as trace:
pass
patterns = get_attention_patterns(trace, layers=[5])
tokens = model.to_str_tokens("Hello world")
# Create heatmap for head 0
fig = attention_heatmap(
patterns[5],
tokens,
head_idx=0,
title="Layer 5, Head 0",
backend="matplotlib"
)
Function: attention_from_trace¶
High-level function to visualize attention patterns from a trace.
attention_from_trace(
trace: Trace,
tokens: List[str],
layers: Optional[List[int]] = None,
heads: Optional[List[int]] = None,
mode: str = "single",
head_notation: str = "LH",
config: Optional[AttentionVisualizationConfig] = None
) -> Any
Parameters:
- trace (Trace): Completed trace context
- tokens (List[str]): Token strings for labels
- layers (Optional[List[int]]): Layers to visualize (default: first layer)
- heads (Optional[List[int]]): Heads to visualize (default: first head)
- mode (str, default: "single"): Visualization mode:
  - "single": One heatmap
  - "grid": Grid of multiple heatmaps
- head_notation (str, default: "LH"): Notation for titles ("LH" = L5H3, "dot" = 5.3)
- config (Optional[AttentionVisualizationConfig]): Custom configuration
Returns: Figure object
Example:
from mlxterp.visualization import attention_from_trace, AttentionVisualizationConfig
with model.trace("The quick brown fox") as trace:
pass
tokens = model.to_str_tokens("The quick brown fox")
# Single heatmap
fig = attention_from_trace(trace, tokens, layers=[5], heads=[0], mode="single")
# Grid of multiple heads
config = AttentionVisualizationConfig(backend="matplotlib", colorscale="Blues")
fig = attention_from_trace(
trace, tokens,
layers=[0, 5, 10],
heads=[0, 1, 2, 3],
mode="grid",
config=config
)
Class: AttentionVisualizationConfig¶
Configuration for attention visualization.
AttentionVisualizationConfig(
colorscale: str = "Blues",
mask_upper_tri: bool = True,
backend: str = "auto",
figsize: tuple = (10, 8),
show_colorbar: bool = True,
font_size: int = 10
)
Parameters:
- colorscale (str): Colormap name (default: "Blues")
- mask_upper_tri (bool): Mask future positions (default: True)
- backend (str): Visualization backend (default: "auto"; tries circuitsviz, then plotly, then matplotlib)
- figsize (tuple): Figure size (default: (10, 8))
- show_colorbar (bool): Show colorbar (default: True)
- font_size (int): Font size for labels (default: 10)
Function: induction_score¶
Compute induction head score for an attention pattern.
Induction heads implement the pattern [A][B]...[A] -> predict [B] by attending from position i to position i - seq_len + 1 in repeated sequences.
Parameters:
- attention_pattern (np.ndarray): Attention weights, shape (seq_q, seq_k)
- seq_len (int): Length of the repeated subsequence
Returns: Induction score (0-1, higher = more induction-like)
Example:
import numpy as np
from mlxterp.visualization import induction_score
# Create pattern from repeated sequence "ABC ABC"
# In a true induction head, position 4 (second A) attends to position 1 (first B)
seq_len = 3
total_len = 6
# Synthetic perfect induction pattern
pattern = np.zeros((total_len, total_len))
for i in range(seq_len, total_len):
pattern[i, i - seq_len + 1] = 1.0
score = induction_score(pattern, seq_len)
print(f"Induction score: {score:.3f}") # ~1.0
Function: previous_token_score¶
Compute previous token head score.
Previous token heads attend strongly to position i-1 (the immediately preceding token).
Parameters:
- attention_pattern (np.ndarray): Attention weights, shape (seq_q, seq_k)
Returns: Previous token score (0-1, higher = attends more to previous position)
Example:
import numpy as np
from mlxterp.visualization import previous_token_score
# Perfect previous token pattern
pattern = np.zeros((5, 5))
for i in range(1, 5):
pattern[i, i-1] = 1.0
score = previous_token_score(pattern)
print(f"Previous token score: {score:.3f}") # ~1.0
Function: first_token_score¶
Compute first token (BOS) head score.
First token heads attend strongly to position 0 (typically BOS or first token).
Parameters:
- attention_pattern (np.ndarray): Attention weights, shape (seq_q, seq_k)
Returns: First token score (0-1, higher = attends more to position 0)
Example:
import numpy as np
from mlxterp.visualization import first_token_score
# Perfect first token pattern
pattern = np.zeros((5, 5))
pattern[:, 0] = 1.0 # All positions attend to first
score = first_token_score(pattern)
print(f"First token score: {score:.3f}") # ~1.0
Function: copying_score¶
Compute copying head score from OV circuit.
Copying heads increase the logit of the attended-to token.
Parameters:
- ov_circuit (np.ndarray): OV circuit matrix W_V @ W_O, shape (d_model, d_model)
- unembedding (Optional[np.ndarray]): Optional unembedding matrix for full analysis
Returns: Copying score (higher = more copying behavior)
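Example:
import numpy as np
from mlxterp.visualization import copying_score
# A minimal sketch: W_V and W_O are random stand-ins for a head's value
# and output projection matrices (illustrative names, not part of the API)
d_model = 64
rng = np.random.default_rng(0)
W_V = rng.normal(size=(d_model, d_model))
W_O = rng.normal(size=(d_model, d_model))
score = copying_score(W_V @ W_O)
print(f"Copying score: {score:.3f}")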
Class: AttentionPatternDetector¶
Detector for classifying attention head types.
AttentionPatternDetector(
induction_threshold: float = 0.4,
previous_token_threshold: float = 0.5,
first_token_threshold: float = 0.3,
current_token_threshold: float = 0.3
)
Parameters:
- induction_threshold (float): Threshold for induction head classification
- previous_token_threshold (float): Threshold for previous token heads
- first_token_threshold (float): Threshold for first token heads
- current_token_threshold (float): Threshold for current token (self-attention) heads
Methods:
analyze_head¶
Compute all pattern scores for a single attention head.
analyze_head(
attention_pattern: np.ndarray,
seq_len_for_induction: Optional[int] = None
) -> Dict[str, float]
Returns: Dict of pattern type -> score
classify_head¶
Classify an attention head into one or more types.
classify_head(
attention_pattern: np.ndarray,
seq_len_for_induction: Optional[int] = None
) -> List[str]
Returns: List of head type labels that exceed thresholds
Example:
from mlxterp.visualization import AttentionPatternDetector
detector = AttentionPatternDetector(
previous_token_threshold=0.5,
first_token_threshold=0.3
)
# Analyze a head
scores = detector.analyze_head(attention_pattern)
print(f"Scores: {scores}")
# {'previous_token': 0.85, 'first_token': 0.1, 'current_token': 0.05}
# Classify
types = detector.classify_head(attention_pattern)
print(f"Classification: {types}") # ['previous_token']
Function: detect_head_types¶
Detect attention head types across a model.
detect_head_types(
model: InterpretableModel,
text: str,
threshold: float = 0.4,
layers: Optional[List[int]] = None
) -> Dict[str, List[Tuple[int, int]]]
Parameters:
- model (InterpretableModel): Model to analyze
- text (str): Input text for analysis
- threshold (float, default: 0.4): Score threshold for classification
- layers (Optional[List[int]]): Specific layers to analyze (None = all)
Returns: Dict mapping head type to list of (layer, head) tuples
Example:
from mlxterp.visualization import detect_head_types
head_types = detect_head_types(
model,
"The quick brown fox jumps over the lazy dog",
threshold=0.3,
layers=[0, 5, 10, 15]
)
print(f"Previous token heads: {len(head_types['previous_token'])}")
print(f"First token heads: {len(head_types['first_token'])}")
# Print specific heads
for layer, head in head_types['previous_token'][:5]:
print(f" L{layer}H{head}")
Function: detect_induction_heads¶
Detect induction heads using repeated random token sequences.
detect_induction_heads(
model: InterpretableModel,
n_random_tokens: int = 50,
n_repeats: int = 2,
threshold: float = 0.4,
layers: Optional[List[int]] = None,
seed: int = 42
) -> List[HeadScore]
Parameters:
- model (InterpretableModel): Model to analyze
- n_random_tokens (int, default: 50): Number of random tokens in the subsequence
- n_repeats (int, default: 2): Number of times to repeat the subsequence
- threshold (float, default: 0.4): Score threshold for detection
- layers (Optional[List[int]]): Specific layers to analyze (None = all)
- seed (int, default: 42): Random seed for reproducibility
Returns: List of HeadScore objects for heads above threshold, sorted by score descending
Example:
from mlxterp.visualization import detect_induction_heads
# Find induction heads
induction_heads = detect_induction_heads(
model,
n_random_tokens=50,
threshold=0.3,
layers=[0, 5, 10, 15]
)
print(f"Found {len(induction_heads)} induction heads")
for head in induction_heads[:10]:
print(f" L{head.layer}H{head.head}: {head.score:.3f}")
Class: HeadScore¶
Score for an attention head (dataclass) with fields layer, head, and score, as used in the detect_induction_heads example above.
Method: InterpretableModel.to_str_tokens¶
Convert text or token IDs to a list of token strings.
model.to_str_tokens(
input: Union[str, List[int], mx.array],
prepend_bos: bool = False
) -> List[str]
Parameters:
- input: Text string, list of token IDs, or mx.array of tokens
- prepend_bos (bool, default: False): Whether to prepend the BOS token (only for string input)
Returns: List of token strings
Example:
# From text
tokens = model.to_str_tokens("Hello world")
print(tokens) # ['<|begin_of_text|>', 'Hello', ' world']
# From token IDs
token_ids = model.encode("Hello world")
tokens = model.to_str_tokens(token_ids)
print(tokens) # ['<|begin_of_text|>', 'Hello', ' world']
Version History¶
0.1.0 (Current)¶
- Initial release
- Core tracing functionality
- Intervention system
- Basic utility functions
- Support for any MLX model
- New: Attention visualization module
- Attention weight capture during tracing
- Pattern extraction and visualization
- Pattern detection (induction, previous token, first token heads)
- Multiple backends (matplotlib, plotly, circuitsviz)