Architecture¶
This document explains the design and implementation of mlxterp.
Design Philosophy¶
mlxterp follows these core principles:
- Simplicity First: Make mechanistic interpretability accessible with minimal code
- Clean API: Context managers and direct attribute access (inspired by nnterp)
- Model Agnostic: Generic wrapping works with any MLX model (inspired by nnsight)
- MLX Native: Leverage MLX's unique features (lazy evaluation, unified memory)
- Minimal Abstractions: Avoid unnecessary complexity and boilerplate
System Architecture¶
┌─────────────────────────────────────────┐
│           InterpretableModel            │
│           (User-facing API)             │
└───────────────┬─────────────────────────┘
                │
        ┌───────┴───────┐
        │               │
 ┌──────▼──────┐ ┌──────▼───────┐
 │    Trace    │ │LayerListProxy│
 │(Context Mgr)│ │(Layer Access)│
 └──────┬──────┘ └──────┬───────┘
        │               │
        └───────┬───────┘
                │
        ┌───────▼──────┐
        │ TraceContext │
        │(Global State)│
        └───────┬──────┘
                │
     ┌──────────┼──────────┐
     │          │          │
 ┌───▼───┐   ┌──▼──┐  ┌────▼────┐
 │Proxies│   │Cache│  │Intervene│
 └───────┘   └─────┘  └─────────┘
Core Components¶
1. InterpretableModel¶
File: mlxterp/model.py
The main entry point for users. Provides:
- Model loading (from string or nn.Module)
- Layer discovery and wrapping
- Trace context creation
- Delegation to underlying model
Key Methods:
- trace(): Create tracing context
- named_modules(): Iterate over modules
- parameters(): Access parameters
Design Decision: Uses composition over inheritance - wraps models rather than subclassing.
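Putting these together, a typical session looks like the following (a minimal sketch: the import path is inferred from the file layout above, and the model name and prompt are placeholders):
from mlxterp import InterpretableModel

# Load by name, or wrap an existing nn.Module
model = InterpretableModel("model-name")

with model.trace("The quick brown fox"):
    act = model.layers[3].output.save()
    out = model.output.save()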
2. Proxy System¶
File: mlxterp/core/proxy.py
Implements transparent wrapping for clean attribute access.
ModuleProxy¶
Wraps nn.Module instances to intercept forward passes:
class ModuleProxy:
    def __init__(self, module, name):
        self._module = module
        self._name = name

    def __call__(self, *args, **kwargs):
        # Call the original module
        result = self._module(*args, **kwargs)
        # Capture the result in the active trace context
        ctx = TraceContext.current()
        if ctx:
            ctx.activations[self._name] = result
        return result
Key Features:
- Transparent wrapping (the user never sees the proxy)
- Automatic activation capture
- Intervention support
- Lazy submodule wrapping
OutputProxy¶
Provides .save() functionality:
class OutputProxy:
    def save(self):
        ctx = TraceContext.current()
        ctx.save(self._name, self._value)
        return self._value
LayerListProxy¶
Provides indexed access to layers:
model.layers[3] # Returns ModuleProxy
model.layers[3].self_attn.output # Returns OutputProxy (use self_attn for mlx-lm models)
3. Trace Context¶
File: mlxterp/core/trace.py
Context manager for tracing execution.
Lifecycle¶
- __enter__: Set up the context, push it onto the global stack, and execute the forward pass
- User code: Access activations and call .save() on outputs
- __exit__: Copy saved values and activations, restore layers, and pop the context
Design Decision: Forward pass executes in __enter__ so activations are immediately available inside the context. Saved values and activations are copied in __exit__ so they remain available after the context.
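Concretely (a short sketch reusing the API shown earlier):
with model.trace(prompt):
    # The forward pass already ran in __enter__, so this is available now
    act = model.layers[0].output.save()

# After __exit__, saved values remain accessible
print(act.shape)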
Global State Management¶
Uses a stack to support nested traces (though typically only one active):
class TraceContext:
    _stack = []

    @classmethod
    def current(cls):
        return cls._stack[-1] if cls._stack else None
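The matching push() and pop() helpers used by Trace below are presumably thin wrappers around the same stack; a sketch:
class TraceContext:
    ...  # as above

    @classmethod
    def push(cls, ctx):
        cls._stack.append(ctx)

    @classmethod
    def pop(cls):
        return cls._stack.pop()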
4. Intervention System¶
File: mlxterp/core/intervention.py
Interventions are simply functions: (mx.array) -> mx.array
Design Decision: Functional approach over class hierarchies for simplicity.
Pre-built Interventions¶
- zero_out: Set to zeros
- scale(factor): Multiply by a constant
- add_vector(vec): Add a steering vector
- replace_with(value): Replace with a fixed value
- clamp(min, max): Clamp to a range
- noise(std): Add Gaussian noise
Composition¶
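Because interventions are plain functions, they compose by ordinary function composition. A minimal sketch (this compose helper is illustrative, not a documented mlxterp API):
def compose(*fns):
    # Apply interventions left to right: compose(f, g)(x) == g(f(x))
    def _composed(x):
        for fn in fns:
            x = fn(x)
        return x
    return _composed

# Scale down, then push along a (precomputed) steering direction
interventions = {"layers.3": compose(scale(0.5), add_vector(steering_vec))}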
5. Caching System¶
File: mlxterp/core/cache.py
Simple dictionary-based caching with utility functions:
- ActivationCache: Storage container
- collect_activations(): Helper for common patterns
Design Decision: Simple dict storage rather than complex cache invalidation - rely on MLX's lazy evaluation.
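Usage follows the dictionary model directly (a sketch; the key format is assumed to match the intervention keys shown elsewhere in this document):
cache = collect_activations(model, prompts)
resid = cache["layers.3"]  # a lazy mx.array until evaluated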
MLX Integration¶
Leveraging MLX Features¶
1. Lazy Evaluation¶
MLX arrays are lazy - computation deferred until mx.eval():
# These don't execute immediately
a = mx.random.normal((100, 100))
b = mx.matmul(a, a)
c = mx.sum(b)
# Execute now
mx.eval(c)
How mlxterp uses it:
- Activations captured as lazy arrays
- Only materialized when accessed via .save()
- Efficient memory usage
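In a trace, this means capture is cheap and work happens only when a saved array is used (a sketch; in MLX, printing or mx.eval forces evaluation):
with model.trace(prompt):
    act = model.layers[3].output.save()  # still lazy here

mx.eval(act)  # computation is materialized at this point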
2. Unified Memory¶
MLX arrays live in shared memory - no CPU/GPU transfers:
# Automatically available on both CPU and GPU
arr = mx.array([1, 2, 3])
# No need for .to('cuda') or .cpu()
How mlxterp uses it:
- Zero-copy activation caching
- Efficient intervention application
- No device management needed
3. Dynamic Graphs¶
Computation graphs built dynamically (like PyTorch eager mode):
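For example (a minimal illustration of define-by-run execution, not mlxterp code):
import mlx.core as mx

def f(x):
    # The graph is whatever Python executes; control flow is plain Python
    if x.sum() > 0:  # converting to a bool evaluates the scalar
        return mx.tanh(x)
    return x * 2

print(f(mx.array([1.0, -2.0, 3.0])))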
How mlxterp uses it:
- Easy to intercept operations
- No need for graph tracing hooks
- Simple proxy-based wrapping works
4. Module System¶
MLX modules are dict-based:
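For example, parameters flatten into (path, array) pairs (a small sketch using MLX's public utilities, not mlxterp internals):
import mlx.nn as nn
from mlx.utils import tree_flatten

layer = nn.Linear(4, 4)
for path, param in tree_flatten(layer.parameters()):
    print(path, param.shape)  # e.g. weight (4, 4) and bias (4,)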
How mlxterp uses it:
- Easy parameter inspection via tree_flatten()
- Natural module traversal via named_modules()
- Simple attribute access patterns
Implementation Patterns¶
Pattern 1: Proxy-Based Interception¶
Instead of monkey-patching or hooks:
# MLX module
original_module = nn.Linear(128, 128)
# Wrap with proxy
proxy = ModuleProxy(original_module, "layer_name")
# Calls are intercepted transparently
output = proxy(input) # Captured in trace context
Benefits:
- No modification of the original module
- Type-safe (the proxy delegates to the original)
- Easy to add and remove
Pattern 2: Context Manager Pattern¶
Clean setup/teardown:
class Trace:
    def __enter__(self):
        # Setup: create and push a new context
        self.context = TraceContext()
        TraceContext.push(self.context)
        # Execute: run the model forward pass immediately so that
        # users can access activations inside the with block
        self.output = self.model_forward(self.inputs)
        return self

    def __exit__(self, *args):
        # Copy saved_values from the context (for values saved inside the block)
        self.saved_values = self.context.saved_values.copy()
        # Cleanup: pop the context
        TraceContext.pop()
Benefits:
- Guaranteed cleanup
- Clear lifecycle
- Pythonic API
Pattern 3: Functional Interventions¶
Simple functions over classes:
# An intervention is just a function
def scale(factor: float):
    def _scale(x: mx.array) -> mx.array:
        return x * factor
    return _scale

# Use it
interventions = {"layers.3": scale(0.5)}
Benefits:
- Easy to compose
- No inheritance complexity
- Clear semantics
Pattern 4: Lazy Proxy Creation¶
Proxies created on-demand:
class ModuleProxy:
    def __getattr__(self, name):
        if name not in self._subproxies:
            attr = getattr(self._module, name)
            if not isinstance(attr, nn.Module):
                # Non-module attributes pass straight through
                return attr
            # Create the proxy only when first accessed
            self._subproxies[name] = ModuleProxy(attr, f"{self._name}.{name}")
        return self._subproxies[name]
Benefits:
- Efficient memory usage
- Automatic submodule discovery
- No upfront cost
Comparison with Other Libraries¶
vs TransformerLens¶
| Aspect | mlxterp | TransformerLens |
|---|---|---|
| Approach | Generic wrapping | Model-specific classes |
| Models | Any MLX model | Specific architectures |
| API | Context managers | Direct methods |
| Framework | MLX | PyTorch |
| Lines of Code | ~600 | ~10,000+ |
Trade-offs:
- mlxterp: More flexible, less feature-complete
- TransformerLens: More features, less flexible
vs nnsight¶
| Aspect | mlxterp | nnsight |
|---|---|---|
| Approach | Generic wrapping | Generic wrapping |
| API Style | Context managers | Context managers |
| Framework | MLX | PyTorch |
| Focus | Simplicity | Completeness |
Inspiration: mlxterp closely follows nnsight's design philosophy.
vs nnterp¶
| Aspect | mlxterp | nnterp |
|---|---|---|
| Approach | Generic wrapping | Standardized naming |
| API | Clean context managers | Clean context managers |
| Framework | MLX | PyTorch/nnsight |
| Focus | Apple Silicon | Model unification |
Inspiration: mlxterp adopts nnterp's clean API design.
Performance Considerations¶
Memory Efficiency¶
- Lazy Evaluation: Activations are not computed until needed
- Selective Saving: Keep only what you explicitly .save()
- Unified Memory: No device transfers
- Batch Processing: Utilities for large datasets
Computational Efficiency¶
- Minimal Overhead: Proxy calls are lightweight
- Metal Acceleration: Inherits MLX's Metal optimizations
- Zero-Copy Operations: Interventions don't copy data unnecessarily
Best Practices¶
# ✅ Good: Only save needed activations
with model.trace(input):
    important = [3, 8, 12]
    acts = {i: model.layers[i].output.save() for i in important}

# ❌ Avoid: Saving everything
with model.trace(input):
    all_acts = [model.layers[i].output.save() for i in range(100)]
Extension Points¶
Custom Interventions¶
Easy to add:
def my_intervention(x: mx.array) -> mx.array:
    # Your logic here, e.g. damp the activation
    return x * 0.9

# Use immediately
with model.trace(input, interventions={"layers.3": my_intervention}):
    output = model.output.save()
Custom Utilities¶
Built on core components:
from mlxterp.core import collect_activations

def my_analysis_function(model, prompts):
    cache = collect_activations(model, prompts)
    # Your analysis here, e.g. summary statistics per layer
    results = {name: act.mean() for name, act in cache.items()}
    return results
Integration with Other Tools¶
Compatible with MLX ecosystem:
# Use with mlx_lm
from mlx_lm import load

model_base, tokenizer = load("model-name")
model = InterpretableModel(model_base, tokenizer)

# Use with MLX training: access the underlying model
import mlx.optimizers as optim

model.model.train()
# MLX optimizers take hyperparameters, not parameters;
# parameters are passed at update time via optimizer.update(model, grads)
optimizer = optim.Adam(learning_rate=1e-3)
Future Enhancements¶
Potential additions (maintaining simplicity):
- Visualization Tools: Attention pattern plotting
- Circuit Discovery: Automated path finding
- Sparse Autoencoders: Feature extraction support
- Logit Lens: Intermediate decoding utilities
- Benchmarks: Standard interpretability tasks
Summary¶
mlxterp achieves its goals through:
- Clean abstraction layers: Each component has clear responsibility
- MLX integration: Leverages framework strengths
- Simple patterns: Proxies, context managers, functions
- Minimal code: ~600 lines of core functionality
- Extensible design: Easy to add custom behavior
The architecture prioritizes simplicity and usability over feature completeness, making mechanistic interpretability accessible to researchers on Apple Silicon.