Advanced Features
Attention Probabilities
For supported models, you can access (and modify) the attention probabilities:
with model.trace("The cat sat on the mat"):
    # Access attention probabilities for layer 5
    # Shape: (batch, heads, seq_len, seq_len)
    attn_probs = model.attention_probabilities[5].save()

    # Modify attention patterns: remove attention to the first token.
    # Query position 0 is deliberately left untouched. Assuming a causal
    # (decoder-only) model, that query can only attend to the first token,
    # so zeroing it would leave an all-zero row and the renormalization
    # below would divide by zero, turning the downstream logits into NaN.
    attn_probs[:, :, 1:, 0] = 0
    attn_probs /= attn_probs.sum(dim=-1, keepdim=True)  # Renormalize

    modified_logits = model.logits.save()
Check what’s happening under the hood by printing the source:
# Print the source associated with the attention-probabilities accessor
# (presumably the attention code being hooked — confirm against your model)
model.attention_probabilities.print_source()
Prompt Utilities
Track the probabilities of specific target tokens:
from nnterp.prompt_utils import Prompt, run_prompts

# Create prompt with target tokens.
# Each dictionary key names a category; its value is either a single string
# or a list of strings to track for that category.
prompt = Prompt.from_strings(
    "The capital of France (not England or Spain) is",
    {
        "target": "Paris",
        "traps": ["London", "Madrid"],
        "longstring": "the country of France",
    },
    model.tokenizer,
)

# Check what tokens are tracked
for name, tokens in prompt.target_tokens.items():
    print(f"{name}: {model.tokenizer.convert_ids_to_tokens(tokens)}")

# Get probabilities (run_prompts takes a list of prompts)
results = run_prompts(model, [prompt])
for target, probs in results.items():
    print(f"{target}: {probs.shape}")  # Shape: (batch_size,)
Combined with Interventions
from nnterp.interventions import logit_lens

# `prompts` is a list of Prompt objects; here it reuses the single `prompt`
# created earlier on this page. Append more Prompt objects to batch them.
prompts = [prompt]

# Use interventions with target tracking
results = run_prompts(model, prompts, get_probs_func=logit_lens)
# Returns probabilities for each target category across all layers
Visualization
Plot the top tokens at each layer:
from nnterp.display import plot_topk_tokens, prompts_to_df

# Logit-lens probabilities for a single prompt
probs = logit_lens(model, "The capital of France is")

# Interactive plot of the top 5 tokens per layer for the first prompt
first_prompt_probs = probs[0]
plot_topk_tokens(
    first_prompt_probs,
    model.tokenizer,
    k=5,
    title="Top 5 tokens at each layer",
)
Prompt Analysis
# Convert prompts to DataFrame
# (`prompts` is expected to be a list of Prompt objects)
df = prompts_to_df(prompts, model.tokenizer)
# NOTE(review): `display` is not imported in this snippet — presumably the
# Jupyter/IPython built-in; in a plain script use print(df) instead.
display(df)
Plot Target Evolution
import plotly.graph_objects as go

# From logit lens results with target tracking
results = run_prompts(model, prompts, get_probs_func=logit_lens)

fig = go.Figure()

# One line per target category, drawn for the first prompt (index 0);
# the x axis is the position along that prompt's per-layer probabilities.
for category, probs in results.items():
    fig.add_trace(
        go.Scatter(
            x=list(range(len(probs[0]))),
            y=probs[0].tolist(),
            mode="lines+markers",
            name=category,
        )
    )

fig.update_layout(
    title="Target Token Probabilities Across Layers",
    xaxis_title="Layer",
    yaxis_title="Probability",
)
fig.show()