NNsight Best Practices#
This guide covers essential NNsight patterns and best practices for using nnterp efficiently and correctly. Understanding these patterns will help you avoid common pitfalls and write more performant code.
Execution Order Requirements#
Critical Rule: You must access model internals in the same order as the model executes them.
In NNsight, interventions must be written in forward-pass order: you cannot access layer 2’s output before layer 1’s output.
from nnterp import StandardizedTransformer

model = StandardizedTransformer("gpt2")

# ✅ CORRECT: Access layers in forward order
with model.trace("My tailor is rich"):
    l1 = model.layers_output[1]  # Access layer 1 first
    l2 = model.layers_output[2]  # Then layer 2
    logits = model.logits  # Finally the output

# ❌ INCORRECT: This will fail!
try:
    with model.trace("My tailor is rich"):
        l2 = model.layers_output[2]  # Access layer 2 first
        l1 = model.layers_output[1]  # Then layer 1 - ERROR!
except Exception as e:
    print(f"Error: {e}")
This applies to all model components:
with model.trace("Hello"):
# ✅ CORRECT: Forward pass order
attn_0 = model.attentions_output[0] # Layer 0 attention
mlp_0 = model.mlps_output[0] # Layer 0 MLP
layer_0 = model.layers_output[0] # Layer 0 output
attn_1 = model.attentions_output[1] # Layer 1 attention
# ... and so on
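The same ordering applies when you write to activations: perform the write before reading anything downstream of it. A minimal sketch, assuming these accessors also accept assignment:
with model.trace("My tailor is rich"):
    # Overwrite layer 3's output before touching any later component
    model.layers_output[3] = model.layers_output[3] * 0.5
    logits = model.logits.save()  # Read the final output only after the write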
Gradient Computation#
To compute gradients, you must use the .backward() context and save gradients inside it.
Basic Gradient Computation#
import torch

with model.trace("My tailor is rich"):
    # Save the activation you want gradients for
    l1_out = model.layers_output[1]

    # Access model output (must be after l1_out)
    logits = model.output.logits

    # Compute gradients inside backward context
    with logits.sum().backward():
        l1_grad = l1_out.grad.save()
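Because l1_grad was created with .save(), it is available as a regular tensor once the trace exits:
print(l1_grad.shape)  # Same shape as the layer 1 output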
Multiple Backward Passes#
For multiple backward passes in one trace, pass retain_graph=True to every backward call except the last:
with model.trace("My tailor is rich"):
l1_out = model.layers_output[1]
logits = model.output.logits
# First backward pass
with logits.sum().backward(retain_graph=True):
l1_grad_1 = l1_out.grad.save()
# Second backward pass with different objective
with (logits.sum() ** 2).backward():
l1_grad_2 = l1_out.grad.save()
# Gradients will be different
assert not torch.allclose(l1_grad_1, l1_grad_2)
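The scalar you backpropagate from need not be logits.sum(). A sketch that instead uses the last position's logit for one token (the token id 911 here is an arbitrary placeholder):
with model.trace("My tailor is rich"):
    l1_out = model.layers_output[1]
    logits = model.output.logits
    # Gradient of one specific logit w.r.t. the layer 1 output
    target_logit = logits[0, -1, 911]  # placeholder token id
    with target_logit.backward():
        l1_grad = l1_out.grad.save()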
Don’t forget the execution order!#
# ❌ INCORRECT: Accessing layers_output[1] after logits
with model.trace("My tailor is rich"):
    logits = model.logits
    with logits.sum().backward():
        # This would fail - can't access layers_output[1] after the output
        l1_grad = model.layers_output[1].grad.save()

# ✅ CORRECT: Save activation first
with model.trace("My tailor is rich"):
    l1_out = model.layers_output[1]  # Save first
    logits = model.logits
    with logits.sum().backward():
        l1_grad = l1_out.grad.save()  # Use saved activation
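With both the activation and its gradient saved, you can compute simple scores after the trace. A sketch of an input-times-gradient style attribution (a standard attribution heuristic, not part of nnterp's API):
with model.trace("My tailor is rich"):
    l1_out = model.layers_output[1].save()  # Save the activation itself too
    logits = model.logits
    with logits.sum().backward():
        l1_grad = l1_out.grad.save()

# Elementwise activation x gradient, summed over the hidden dimension
attribution = (l1_out * l1_grad).sum(dim=-1)
print(attribution.shape)  # (batch, seq_len)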
Performance Optimization#
Use tracer.stop() to Skip Unnecessary Computations#
When you only need intermediate activations, use tracer.stop() to prevent the model from computing subsequent layers:
import time

# Without tracer.stop() - computes all layers
start = time.time()
for _ in range(10):
    with model.trace("Hello world"):
        layer_5_out = model.layers_output[5].save()
time_without_stop = time.time() - start

# With tracer.stop() - only computes up to layer 5
start = time.time()
for _ in range(10):
    with model.trace("Hello world") as tracer:
        layer_5_out = model.layers_output[5].save()
        tracer.stop()  # Stop here - don't compute remaining layers
time_with_stop = time.time() - start

print(f"Speedup: {time_without_stop / time_with_stop:.2f}x")
This can provide significant speedups (often 2-5x) when working with large models and only analyzing intermediate layers.
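The savings compound when you sweep over many prompts. A minimal sketch of collecting layer-5 activations for a list of prompts (the prompts here are placeholders):
prompts = ["The cat sat", "My tailor is rich", "Hello world"]

layer_5_acts = []
for prompt in prompts:
    with model.trace(prompt) as tracer:
        layer_5_acts.append(model.layers_output[5].save())
        tracer.stop()  # Layers 6+ are never computed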
Caching Activations#
NNsight 0.5+ includes built-in caching for collecting multiple activations efficiently:
# Cache activations from multiple layers
with model.trace("Hello world") as tracer:
    # Cache every other layer
    cache = tracer.cache(
        modules=[model.layers[i] for i in range(0, model.num_layers, 2)]
    ).save()
    # Don't call tracer.stop() before the cached modules have run!

# Access cached activations
print(cache.keys())  # Shows module names
print(cache["model.layers.0"].output.shape)  # Layer 0 output
print(cache["model.layers.2"].output.shape)  # Layer 2 output
Important: As of 0.5.dev8, the cache keys use the original module names, not the renamed names nnterp exposes.
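Since the keys follow the original module paths shown above, you can reconstruct them programmatically. A sketch assuming the "model.layers.{i}" naming from the example keys:
for i in range(0, model.num_layers, 2):
    acts = cache[f"model.layers.{i}"].output
    print(i, acts.shape)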