# mlxterp Quick Start Guide
Get started with mlxterp in 5 minutes!
## Installation

```bash
# Clone the repository
git clone https://github.com/coairesearch/mlxterp
cd mlxterp

# Install in development mode
pip install -e .

# Or install with optional dependencies
pip install -e ".[viz]"
```
## Prerequisites

- macOS with Apple Silicon (M1/M2/M3)
- Python 3.9+
- MLX framework (`pip install mlx`)
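If you're unsure whether MLX is set up correctly, this check (an illustrative snippet, not from the original guide) should report a GPU device on Apple Silicon:

```bash
python -c "import mlx.core as mx; print(mx.default_device())"
```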
## Your First Script

Create a file `test_mlxterp.py`:
```python
from mlxterp import InterpretableModel
import mlx.core as mx
import mlx.nn as nn

# 1. Create a simple model
class SimpleTransformer(nn.Module):
    def __init__(self, hidden_dim=64):
        super().__init__()
        self.layers = [
            nn.Linear(hidden_dim, hidden_dim)
            for _ in range(4)
        ]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
            x = nn.relu(x)
        return x

# 2. Wrap with InterpretableModel
base_model = SimpleTransformer()
model = InterpretableModel(base_model)

# 3. Create input
input_data = mx.random.normal((1, 10, 64))  # (batch, seq, hidden)

# 4. Trace execution
with model.trace(input_data):
    layer_0 = model.layers[0].output.save()
    layer_2 = model.layers[2].output.save()

print(f"Layer 0 shape: {layer_0.shape}")
print(f"Layer 2 shape: {layer_2.shape}")
```
Run it:
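```bash
python test_mlxterp.py
```

Since both saved layers are 64-to-64 linear maps, the activations keep the input's shape, and you should see:

```
Layer 0 shape: (1, 10, 64)
Layer 2 shape: (1, 10, 64)
```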
## Common Use Cases
### 1. Load a Real Model

```python
from mlxterp import InterpretableModel

# Automatically loads model and tokenizer
# Note: MLX models require a quantization suffix (e.g., -4bit, -8bit, -bf16)
model = InterpretableModel("mlx-community/Llama-3.2-1B-Instruct-4bit")

# Use with text (use self_attn for mlx-lm models)
with model.trace("Hello world"):
    attn = model.layers[5].self_attn.output.save()
```
### 2. Work with Tokenizers

```python
# Encode text to tokens
tokens = model.encode("Hello world")
print(tokens)  # [128000, 9906, 1917]

# Decode tokens to text
text = model.decode(tokens)
print(text)  # "Hello world"

# Analyze individual tokens
for i, token_id in enumerate(tokens):
    token_str = model.token_to_str(token_id)
    print(f"Token {i}: '{token_str}'")

# Get vocabulary size
print(f"Vocab size: {model.vocab_size}")
```
### 3. Collect Activations

```python
from mlxterp import get_activations

acts = get_activations(
    model,
    prompts=["Hello", "World"],
    layers=[3, 8, 12],
    positions=-1  # last token
)

print(acts["layer_3"].shape)  # (2, hidden_dim)
```
### 4. Apply Interventions

```python
import mlx.core as mx
from mlxterp import interventions as iv

# Scale down layer 4
with model.trace("Test", interventions={"layers.4": iv.scale(0.5)}):
    output = model.output.save()

# Add a steering vector (hidden_dim is the model's hidden size)
steering = mx.random.normal((hidden_dim,))
with model.trace("Test", interventions={"layers.5": iv.add_vector(steering)}):
    output = model.output.save()
```
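Interventions are just callables applied to a module's output (section 6 below passes a plain lambda the same way), so you can write your own. A minimal sketch, assuming the same single-argument interface as `iv.scale`:

```python
# Zero-ablate layer 6 by replacing its output with zeros
with model.trace("Test", interventions={"layers.6": lambda x: x * 0}):
    ablated = model.output.save()
```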
### 5. Logit Lens

See what each layer predicts at each token position:
```python
# Analyze predictions across layers
text = "The capital of France is"
results = model.logit_lens(text, layers=[0, 5, 10, 15])

# See how the prediction at the LAST position evolves through layers.
# results[layer_idx][-1][0][2]: last position, top prediction, token string
for layer_idx in [0, 5, 10, 15]:
    last_pos_pred = results[layer_idx][-1][0][2]
    print(f"Layer {layer_idx}: '{last_pos_pred}'")

# Output:
# Layer 0: ' the'
# Layer 5: ' a'
# Layer 10: ' Paris'
# Layer 15: ' Paris'

# Visualize with a heatmap showing ALL positions
results = model.logit_lens(
    "The Eiffel Tower is located in the city of",
    plot=True,
    max_display_tokens=15
)
# Shows: input tokens (x-axis) × layers (y-axis) with predicted tokens
```
### 6. Activation Patching

```python
# Get clean activation
with model.trace("The capital of France is"):
    clean_act = model.layers[8].output.save()

# Patch it into a run on a different input
with model.trace("The capital of Spain is",
                 interventions={"layers.8": lambda x: clean_act}):
    patched = model.output.save()
```
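To quantify the patch's effect, you can compare against an unpatched run of the same corrupted prompt. A sketch, assuming `model.output` holds final logits of shape (batch, seq, vocab):

```python
import mlx.core as mx

# Baseline: the corrupted prompt with no patch
with model.trace("The capital of Spain is"):
    corrupted = model.output.save()

# How much did patching layer 8 move the last-position logits?
diff = mx.abs(patched[0, -1] - corrupted[0, -1]).max()
print(f"Max logit change at last position: {diff.item():.3f}")
```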
## What's Next?
- Read the full README for detailed examples
- Check the API Reference for complete API documentation
- Explore the examples for more use cases
- Read the Architecture Guide to understand the design
## Troubleshooting

### "Could not load model" or "Repository Not Found"
Problem: Model loading fails with a 404 error.
Solution: MLX Community models require a quantization suffix:
```python
# WRONG - base model name doesn't exist
model = InterpretableModel("mlx-community/Llama-3.2-1B-Instruct")

# CORRECT - use a quantization suffix
model = InterpretableModel("mlx-community/Llama-3.2-1B-Instruct-4bit")  # 4-bit quantized
model = InterpretableModel("mlx-community/Llama-3.2-1B-Instruct-8bit")  # 8-bit quantized
model = InterpretableModel("mlx-community/Llama-3.2-1B-Instruct-bf16")  # bfloat16
```

Available quantization options vary by model. Check MLX Community for available variants.
"String input provided but no tokenizer available"¶
Problem: You're passing text but no tokenizer was loaded.
Solution: Pass a tokenizer or use token arrays:
```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("model-name")
model = InterpretableModel(base_model, tokenizer=tokenizer)
```
"Model does not have 'layers' attribute"¶
Problem: Your model uses a different attribute name.
Solution: Specify the correct attribute:
```python
# For GPT-2
model = InterpretableModel(gpt2_model, layer_attr="h")

# For custom models
model = InterpretableModel(custom_model, layer_attr="transformer.blocks")
```
### Import Error

Problem: `ModuleNotFoundError: No module named 'mlxterp'`
Solution: Install the package:
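From the repository root (the same command as in Installation above):

```bash
pip install -e .
```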
Community & Support¶
- Issues: GitHub Issues
- Discussions: GitHub Discussions
- Examples: See
examples/directory
## Next Steps
Try these examples in order:

1. Basic tracing (you just did this!)
2. `examples/basic_usage.py` - More comprehensive examples
3. Explore your own models
4. Experiment with interventions
5. Analyze activation patterns