Conversation-Level Analysis Guide

Overview

Most interpretability research analyzes single prompts. But LLMs are used in multi-turn conversations where meaning builds across turns, context accumulates, and interventions on early turns cascade through later ones.

mlxterp provides conversation-level analysis tools that let you:

  • Trace full conversations as a single forward pass
  • Detect turn boundaries automatically from chat templates
  • Slice activations by turn to study per-turn representations
  • Measure cross-turn attention to understand information flow between turns

Key Concepts

Turns and Chat Templates

Chat templates wrap each turn in special tokens. For example, Llama 3 renders a two-turn exchange as:

<|begin_of_text|><|start_header_id|>user<|end_header_id|>
My name is Alice.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
Nice to meet you!<|eot_id|>

mlxterp automatically detects these boundaries and separates content tokens (the actual message) from template tokens (role markers, end-of-turn markers).
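
To show what separating content tokens from template tokens involves, the sketch below scans a Llama 3 token stream for its header and end-of-turn markers. It is a hypothetical, simplified helper (split_turns is not part of mlxterp). It works on token strings rather than token IDs, assumes each role name is a single token, and uses half-open spans like Python slices.

```python
from typing import List, Tuple

# Simplified token strings for the Llama 3 example above. A real Llama 3
# tokenizer also emits a "\n\n" token after <|end_header_id|>; it is left out
# here to keep the sketch short.
TOKENS = [
    "<|begin_of_text|>",
    "<|start_header_id|>", "user", "<|end_header_id|>",
    "My", " name", " is", " Alice", ".", "<|eot_id|>",
    "<|start_header_id|>", "assistant", "<|end_header_id|>",
    "Nice", " to", " meet", " you", "!", "<|eot_id|>",
]

# (role, full_start, full_end, content_start, content_end)
Span = Tuple[str, int, int, int, int]


def split_turns(tokens: List[str]) -> List[Span]:
    """Split a Llama 3 token stream into turns using its marker tokens.

    Spans are half-open [start, end). The full span covers the header and the
    closing <|eot_id|>; the content span covers only the message tokens.
    """
    turns = []
    i = 0
    while i < len(tokens):
        if tokens[i] != "<|start_header_id|>":
            i += 1  # e.g. <|begin_of_text|>, which belongs to no turn
            continue
        role = tokens[i + 1]  # assumes the role name is a single token
        content_start = i + 3  # skip <|start_header_id|>, role, <|end_header_id|>
        content_end = tokens.index("<|eot_id|>", content_start)
        full_end = content_end + 1  # include the <|eot_id|> marker
        turns.append((role, i, full_end, content_start, content_end))
        i = full_end
    return turns


for role, full_start, full_end, content_start, content_end in split_turns(TOKENS):
    print(role, (full_start, full_end), TOKENS[content_start:content_end])
```

A real detector would also need to treat the "\n\n" separator and the <|begin_of_text|> token as template overhead, which is why working from the tokenizer's own chat template is more robust than hard-coding markers.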

The Turn Dataclass

Each turn tracks its position in the full token sequence:

from mlxterp.conversation import Turn

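# Turn.index: position of the turn in the conversation (0-based)
# Turn.role: "user", "assistant", or "system"
# Turn.content_start / content_end: just the message text
# Turn.full_start / full_end: including template overhead (role markers, end-of-turn markers)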
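
A small stand-in class makes the relationship between the two spans concrete. TurnSpan below is hypothetical and is not mlxterp's Turn. It only borrows the documented field names and assumes half-open spans, so the number of template tokens in a turn is the full length minus the content length.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TurnSpan:
    """Hypothetical stand-in for mlxterp's Turn; field names mirror the docs."""
    index: int
    role: str
    full_start: int
    full_end: int
    content_start: int
    content_end: int

    def __post_init__(self) -> None:
        # The content span must sit inside the full span (half-open intervals).
        if not (self.full_start <= self.content_start
                <= self.content_end <= self.full_end):
            raise ValueError("content span must lie within the full span")

    @property
    def n_content(self) -> int:
        """Number of message tokens."""
        return self.content_end - self.content_start

    @property
    def n_template(self) -> int:
        """Tokens in the full span that are role or end-of-turn markers."""
        return (self.full_end - self.full_start) - self.n_content


t = TurnSpan(index=0, role="user", full_start=1, full_end=10,
             content_start=4, content_end=9)
print(t.n_content, t.n_template)  # 5 4
```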

Quick Start

from mlxterp import InterpretableModel

model = InterpretableModel("mlx-community/Llama-3.2-1B-Instruct-4bit")

conversation = [
    {"role": "user", "content": "My name is Alice."},
    {"role": "assistant", "content": "Nice to meet you, Alice!"},
    {"role": "user", "content": "What's my name?"},
]

with model.conversation_trace(conversation) as ct:
    # See detected turns
    print(ct.turns)

    # Get activations for the question turn
    question_act = ct.get_turn_activation(2, "layers.10")
    print(f"Question turn activations shape: {question_act.shape}")

    # Cross-turn attention: does the model look back at turn 0?
    cross_attn = ct.cross_turn_attention(layer=10, head=0)
    print(f"Turn 2 -> Turn 0 attention: {float(cross_attn[2, 0]):.4f}")

Turn Detection

Automatic Detection

detect_turns() uses the tokenizer's chat template to find boundaries:

from mlxterp.conversation import detect_turns

messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there!"},
    {"role": "user", "content": "How are you?"},
]

turns = detect_turns(model.tokenizer, messages)

for turn in turns:
    print(f"Turn {turn.index} ({turn.role}): "
          f"content tokens {turn.content_start}:{turn.content_end}, "
          f"full tokens {turn.full_start}:{turn.full_end}")
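
How detect_turns works internally isn't documented here. One common way to locate turns with a chat template is a prefix diff: render the conversation up to turn i and up to turn i+1, and the extra tokens form turn i's full span. The sketch below assumes that approach and uses a toy whitespace "tokenizer" and a Llama-3-style template (toy_render, rfind_sublist, and turn_spans are all hypothetical). With real tokenizers, a template that injects a default system prompt, or tokens that merge across a boundary, can shift the diff, so treat this as an illustration only.

```python
from typing import Callable, Dict, List, Tuple

Message = Dict[str, str]


def toy_render(messages: List[Message]) -> List[str]:
    """Hypothetical Llama-3-style chat template over whitespace-split 'tokens'."""
    out = ["<|begin_of_text|>"]
    for m in messages:
        out += ["<|start_header_id|>", m["role"], "<|end_header_id|>"]
        out += m["content"].split()
        out.append("<|eot_id|>")
    return out


def rfind_sublist(haystack: List[str], needle: List[str], start: int, end: int) -> int:
    """Index of the last occurrence of needle inside haystack[start:end], or -1."""
    for i in range(end - len(needle), start - 1, -1):
        if haystack[i:i + len(needle)] == needle:
            return i
    return -1


def turn_spans(render: Callable[[List[Message]], List[str]],
               messages: List[Message]) -> List[Tuple[int, int, int, int]]:
    """(full_start, full_end, content_start, content_end) for each message."""
    tokens = render(messages)
    spans = []
    for i, msg in enumerate(messages):
        # Whatever rendering messages[:i+1] adds on top of messages[:i] is turn i.
        full_start = len(render(messages[:i]))
        full_end = len(render(messages[:i + 1]))
        content = msg["content"].split()
        # Search from the end so a message word that equals the role name
        # (e.g. a message saying "user") is not confused with the header.
        content_start = rfind_sublist(tokens, content, full_start, full_end)
        if content_start < 0:
            raise ValueError(f"content of turn {i} not found in its span")
        spans.append((full_start, full_end, content_start, content_start + len(content)))
    return spans


messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there!"},
    {"role": "user", "content": "How are you?"},
]
print(turn_spans(toy_render, messages))
# [(1, 6, 4, 5), (6, 12, 9, 11), (12, 19, 15, 18)]
```

The prefix diff needs nothing template-specific, which is what lets a single routine cover any tokenizer that exposes a chat template.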

Supported Templates

Model Family          Template                           Support
Llama 3               <|start_header_id|>...<|eot_id|>   Auto-detected
ChatML (Qwen, etc.)   <|im_start|>...<|im_end|>          Auto-detected
Gemma                 <start_of_turn>...<end_of_turn>    Auto-detected
Custom                Any with apply_chat_template       Supported
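
One plausible way to choose among these families is to look for each pair of markers in the rendered prompt text, for example the string returned by a Hugging Face tokenizer's apply_chat_template(messages, tokenize=False). The helper below is a hypothetical sketch; mlxterp's actual detection logic is not described in this guide. A template outside the table makes the sketch return None, which is where a generic fallback would take over.

```python
from typing import Dict, Optional, Tuple

# (start-of-turn marker, end-of-turn marker) for each family in the table above.
TEMPLATE_MARKERS: Dict[str, Tuple[str, str]] = {
    "llama3": ("<|start_header_id|>", "<|eot_id|>"),
    "chatml": ("<|im_start|>", "<|im_end|>"),
    "gemma": ("<start_of_turn>", "<end_of_turn>"),
}


def guess_template_family(rendered: str) -> Optional[str]:
    """Return the first family whose two markers both occur in the rendered prompt."""
    for family, (start, end) in TEMPLATE_MARKERS.items():
        if start in rendered and end in rendered:
            return family
    return None  # unknown template: fall back to a generic strategy


print(guess_template_family("<|im_start|>user\nHello<|im_end|>\n"))  # chatml
print(guess_template_family("<start_of_turn>user\nHello<end_of_turn>\n"))  # gemma
print(guess_template_family("User: Hello\n"))  # None
```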

Filtering Turns

turns = detect_turns(model.tokenizer, messages)

# Get all user turns
user_turns = turns.by_role("user")
print(f"User turns: {len(user_turns)}")

# Get assistant turns
assistant_turns = turns.by_role("assistant")

# Slice turns
first_two = turns[0:2]

# List all roles
print(turns.roles)  # ["user", "assistant", "user"]

# Get all content token positions
positions = turns.content_positions()
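
The container behaviour above can be mimicked in a few lines, which helps when reasoning about what each call returns. MiniTurnList and Span are hypothetical stand-ins, not mlxterp classes. In particular, the sketch assumes content_positions() returns every token index that falls inside some turn's content span, in order; the guide does not spell out its exact return type.

```python
from dataclasses import dataclass
from typing import List, Sequence, Union


@dataclass(frozen=True)
class Span:
    role: str
    content_start: int
    content_end: int


class MiniTurnList:
    """Hypothetical container mimicking the TurnList behaviour described above."""

    def __init__(self, turns: Sequence[Span]):
        self._turns = list(turns)

    def __len__(self) -> int:
        return len(self._turns)

    def __getitem__(self, key: Union[int, slice]):
        # Slicing returns another container so filters can be chained.
        if isinstance(key, slice):
            return MiniTurnList(self._turns[key])
        return self._turns[key]

    def by_role(self, role: str) -> "MiniTurnList":
        return MiniTurnList([t for t in self._turns if t.role == role])

    @property
    def roles(self) -> List[str]:
        return [t.role for t in self._turns]

    def content_positions(self) -> List[int]:
        # Assumption: all token indices inside any content span, in order.
        return [p for t in self._turns for p in range(t.content_start, t.content_end)]


turns = MiniTurnList([Span("user", 4, 5), Span("assistant", 9, 11), Span("user", 15, 18)])
print(turns.roles)                    # ['user', 'assistant', 'user']
print(len(turns.by_role("user")))     # 2
print(turns[0:2].roles)               # ['user', 'assistant']
print(turns.content_positions())      # [4, 9, 10, 15, 16, 17]
```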

Conversation Tracing

Per-Turn Activations

Extract activations for a specific turn only:

with model.conversation_trace(conversation) as ct:
    # Content only (default) — excludes template tokens
    turn0_act = ct.get_turn_activation(0, "layers.5")

    # Full turn — includes role markers and end tokens
    turn0_full = ct.get_turn_activation(0, "layers.5", content_only=False)

    print(f"Content tokens: {turn0_act.shape}")
    print(f"Full turn tokens: {turn0_full.shape}")

Cross-Turn Attention

Measure how much each turn attends to other turns:

with model.conversation_trace(conversation) as ct:
    # (n_turns, n_turns) matrix
    cross_attn = ct.cross_turn_attention(layer=10, head=0)

    # Does turn 2 (the question) attend to turn 0 (the name)?
    print(f"Question -> Name turn: {float(cross_attn[2, 0]):.4f}")
    print(f"Question -> Response turn: {float(cross_attn[2, 1]):.4f}")
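
This guide does not say exactly how cross_turn_attention reduces a head's token-level attention to a turn-level matrix. The sketch below assumes one reasonable choice: entry [i, j] is the attention mass that query tokens in turn i place on key tokens in turn j, summed over keys and averaged over queries. turn_attention is a hypothetical helper that takes a single head's (seq_len, seq_len) attention map.

```python
from typing import List, Tuple

import numpy as np


def turn_attention(attn: np.ndarray, spans: List[Tuple[int, int]]) -> np.ndarray:
    """Aggregate a (seq_len, seq_len) attention map into (n_turns, n_turns).

    attn[q, k] is how much query position q attends to key position k.
    Spans are half-open (start, end) token ranges, one per turn.
    """
    n = len(spans)
    out = np.zeros((n, n))
    for i, (qs, qe) in enumerate(spans):
        for j, (ks, ke) in enumerate(spans):
            # Sum over the key turn, then average over the query turn's tokens.
            out[i, j] = attn[qs:qe, ks:ke].sum(axis=1).mean()
    return out


# Toy example: 6 tokens in three 2-token turns, and a causal head that
# spreads its attention uniformly over every visible position.
seq_len = 6
causal = np.tril(np.ones((seq_len, seq_len)))
attn = causal / causal.sum(axis=1, keepdims=True)
spans = [(0, 2), (2, 4), (4, 6)]
matrix = turn_attention(attn, spans)
print(np.round(matrix, 3))
```

Summing over keys makes each entry a fraction of attention mass, and averaging over queries keeps long turns from dominating. Because of the causal mask, the upper triangle is zero. When the spans skip tokens (for example a BOS "attention sink" or template tokens with content-only spans), rows can sum to less than 1.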

Converting to Result

ct.to_result() packages the traced conversation into a ConversationResult, which stays usable after the with block exits:

with model.conversation_trace(conversation) as ct:
    result = ct.to_result()

print(result.summary())
print(result.to_json())

Example: Name Recall Analysis

Investigate whether the model remembers a name from an earlier turn:

from mlxterp import InterpretableModel

model = InterpretableModel("mlx-community/Llama-3.2-1B-Instruct-4bit")

conversation = [
    {"role": "user", "content": "My name is Alice."},
    {"role": "assistant", "content": "Nice to meet you, Alice!"},
    {"role": "user", "content": "What did I say my name was?"},
]

with model.conversation_trace(conversation) as ct:
    print(f"Detected {len(ct.turns)} turns")

    # Which layers attend back to the name? (Llama-3.2-1B has 16 layers)
    for layer in range(16):
        cross = ct.cross_turn_attention(layer=layer, head=0)
        if cross is not None:
            score = float(cross[2, 0])  # Question -> Name turn
            if score > 0.05:
                print(f"Layer {layer}: turn 2 attends to turn 0 with score {score:.4f}")

    # Get the activation at the question turn
    question_act = ct.get_turn_activation(2, "layers.10")
    print(f"Question turn activation shape: {question_act.shape}")
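
A fixed cutoff such as 0.05 is easier to read against a baseline. If a causal head spread its attention uniformly over all visible tokens, a query turn would still put some mass on earlier turns. The hypothetical helper below computes that uniform-attention score for a pair of turns. It assumes the turn-level score is key mass summed over the key turn and averaged over query tokens, and it takes spans as absolute, half-open token positions. A measured score well above the baseline suggests the head favours that turn.

```python
from typing import Tuple


def uniform_baseline(query_span: Tuple[int, int], key_span: Tuple[int, int]) -> float:
    """Expected turn-to-turn score if a causal head attended uniformly.

    Query position q can see positions 0..q, so it would give each visible
    key 1 / (q + 1) of its attention.
    """
    qs, qe = query_span
    ks, ke = key_span
    total = 0.0
    for q in range(qs, qe):
        visible = q + 1
        keys_in_span = max(0, min(ke, q + 1) - ks)  # key-span tokens at or before q
        total += keys_in_span / visible
    return total / (qe - qs)


# Hypothetical spans: a 2-token "question" at 4..6 and a 2-token "name" at 0..2.
print(round(uniform_baseline((4, 6), (0, 2)), 4))  # 0.3667
```

Short conversations leave few visible tokens, so their uniform baseline can sit well above a fixed threshold; comparing against the baseline avoids flagging heads that merely attend evenly.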

Example: Information Flow Across Turns

conversation = [
    {"role": "user", "content": "The capital of France is Paris."},
    {"role": "assistant", "content": "That's correct!"},
    {"role": "user", "content": "What's the capital of France?"},
]

with model.conversation_trace(conversation) as ct:
    # Check attention flow at a few sampled layers
    for layer in [5, 10, 15]:
        for head in range(4):  # Check first 4 heads
            cross = ct.cross_turn_attention(layer=layer, head=head)
            if cross is not None:
                # How much does the question attend to the fact?
                fact_attn = float(cross[2, 0])
                if fact_attn > 0.1:
                    print(f"L{layer}H{head}: question->fact = {fact_attn:.3f}")
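
Printing every head above a threshold gets noisy as the search widens. The hypothetical helper below ranks (layer, head) pairs by their score for one turn pair instead. It takes any callable that returns an (n_turns, n_turns) matrix or None, matching how ct.cross_turn_attention is used in the examples above. Inside a conversation_trace block, one might pass lambda l, h: ct.cross_turn_attention(layer=l, head=h). The fake function here only lets the sketch run standalone.

```python
from typing import Callable, List, Optional, Sequence, Tuple

import numpy as np

CrossFn = Callable[[int, int], Optional[np.ndarray]]


def rank_heads(cross_fn: CrossFn, layers: Sequence[int], heads: Sequence[int],
               query_turn: int, key_turn: int,
               top_k: int = 5) -> List[Tuple[float, int, int]]:
    """Return the top_k (score, layer, head) triples for one turn pair."""
    scores = []
    for layer in layers:
        for head in heads:
            cross = cross_fn(layer, head)
            if cross is None:  # skip missing data, as the examples above do
                continue
            scores.append((float(cross[query_turn, key_turn]), layer, head))
    scores.sort(reverse=True)  # highest score first
    return scores[:top_k]


def fake_cross(layer: int, head: int) -> Optional[np.ndarray]:
    """Stand-in for ct.cross_turn_attention: no data for layer 3."""
    if layer == 3:
        return None
    return np.full((3, 3), layer * 0.1 + head * 0.01)


print(rank_heads(fake_cross, range(5), range(2), query_turn=2, key_turn=0, top_k=3))
```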

API Reference Summary

Class/Function                                Purpose
Turn                                          Dataclass for turn boundaries (positions, role)
TurnList                                      Container with indexing, slicing, by_role(), roles, content_positions()
detect_turns(tokenizer, messages)             Auto-detect turn boundaries from the chat template
model.conversation_trace(messages)            Context manager for multi-turn tracing
ct.get_turn_activation(turn_idx, component)   Slice a component's activations to one turn
ct.cross_turn_attention(layer, head)          (n_turns, n_turns) cross-turn attention matrix
ct.to_result()                                Convert to a ConversationResult