Commit e09b5f3
Release: 2026-01-14 (#76)
Squashed changes from error-vector-fix branch.
1 parent 9317b2a

9 files changed (+140, -49 lines)

circuit_tracer/replacement_model/replacement_model_nnsight.py

Lines changed: 2 additions & 2 deletions
@@ -303,7 +303,7 @@ def fetch_activations(
         gemma_3_it = "gemma-3" in self.cfg.model_name and self.cfg.model_name.endswith("-it")
         overlap = 0
         if gemma_3_it:
-            input_ids = self.input.squeeze(0)
+            input_ids = self.input
             ignore_prefix = torch.tensor(
                 [2, 105, 2364, 107], dtype=input_ids.dtype, device=input_ids.device
             )
@@ -541,7 +541,7 @@ def setup_attribution(self, inputs: str | torch.Tensor):
         # Compute error vectors
         error_vectors = mlp_out_cache - attribution_data["reconstruction"]
 
-        error_vectors[:, 0] = 0
+        error_vectors[:, zero_positions] = 0
         token_vectors = self.embed_weight[  # type: ignore
             tokens
         ].detach()  # (n_pos, d_model)  # type: ignore
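
A minimal sketch of the error-vector change above, with assumed toy shapes and random stand-ins for the cached tensors (mlp_out_cache, attribution_data["reconstruction"]): the mask widens from the single BOS position to an arbitrary zero_positions slice, which the gemma-3 -it path needs for its multi-token chat-template prefix.

import torch

# Assumed toy shapes; the real tensors come from the model's forward cache.
n_layers, n_pos, d_model = 2, 8, 16
mlp_out_cache = torch.randn(n_layers, n_pos, d_model)   # true MLP outputs
reconstruction = torch.randn(n_layers, n_pos, d_model)  # transcoder reconstruction

error_vectors = mlp_out_cache - reconstruction

# Old behaviour: zero only the BOS position.
# error_vectors[:, 0] = 0
# New behaviour: zero any prefix of positions, e.g. a 4-token chat prefix.
zero_positions = slice(0, 4)
error_vectors[:, zero_positions] = 0
assert error_vectors[:, :4].abs().sum() == 0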

circuit_tracer/transcoder/cross_layer_transcoder.py

Lines changed: 13 additions & 5 deletions
@@ -281,7 +281,9 @@ def select_decoder_vectors(self, features):
 
         return pos_ids, layer_ids, feat_ids, decoder_vectors, encoder_mapping
 
-    def compute_reconstruction(self, pos_ids, layer_ids, decoder_vectors):
+    def compute_reconstruction(
+        self, pos_ids, layer_ids, decoder_vectors, input_acts: torch.Tensor | None = None
+    ):
         n_pos = pos_ids.max() + 1
         flat_idx = layer_ids * n_pos + pos_ids
         recon = torch.zeros(
@@ -290,11 +292,17 @@ def compute_reconstruction(self, pos_ids, layer_ids, decoder_vectors):
             device=decoder_vectors.device,
             dtype=decoder_vectors.dtype,
         ).index_add_(0, flat_idx, decoder_vectors)
-        return recon.reshape(self.n_layers, n_pos, self.d_model) + self.b_dec[:, None]
+        recon = recon.reshape(self.n_layers, n_pos, self.d_model) + self.b_dec[:, None]
+        if self.W_skip is not None:
+            assert input_acts is not None, (
+                "Transcoder has skip connection but no input_acts were provided"
+            )
+            recon = recon + input_acts @ self.W_skip
+        return recon
 
-    def decode(self, features):
+    def decode(self, features, input_acts: torch.Tensor | None = None):
         pos_ids, layer_ids, feat_ids, decoder_vectors, _ = self.select_decoder_vectors(features)
-        return self.compute_reconstruction(pos_ids, layer_ids, decoder_vectors)
+        return self.compute_reconstruction(pos_ids, layer_ids, decoder_vectors, input_acts)
 
     def compute_skip(self, layer_id: int, inputs):
         if self.W_skip is not None:
@@ -330,7 +338,7 @@ def compute_attribution_components(self, inputs, zero_positions: slice = slice(0
         pos_ids, layer_ids, feat_ids, decoder_vectors, encoder_to_decoder_map = (
             self.select_decoder_vectors(features)
         )
-        reconstruction = self.compute_reconstruction(pos_ids, layer_ids, decoder_vectors)
+        reconstruction = self.compute_reconstruction(pos_ids, layer_ids, decoder_vectors, inputs)
 
         return {
             "activation_matrix": features,

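For orientation, a self-contained sketch of what compute_reconstruction now computes when W_skip is present (toy shapes and random weights are assumptions; only the final input_acts @ W_skip term is new in this commit):

import torch

n_layers, n_pos, d_model, n_active = 4, 6, 16, 10

decoder_vectors = torch.randn(n_active, d_model)     # one vector per active feature
pos_ids = torch.randint(0, n_pos, (n_active,))       # position of each active feature
layer_ids = torch.randint(0, n_layers, (n_active,))  # layer of each active feature
b_dec = torch.randn(n_layers, d_model)
W_skip = torch.randn(d_model, d_model)               # None for non-affine transcoders
input_acts = torch.randn(n_layers, n_pos, d_model)   # per-layer MLP inputs

# Scatter each active decoder vector into its (layer, position) slot.
flat_idx = layer_ids * n_pos + pos_ids
recon = torch.zeros(n_layers * n_pos, d_model).index_add_(0, flat_idx, decoder_vectors)
recon = recon.reshape(n_layers, n_pos, d_model) + b_dec[:, None]

# New: affine (skip-connection) transcoders add a linear read of the inputs.
recon = recon + input_acts @ W_skip
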
circuit_tracer/transcoder/single_layer_transcoder.py

Lines changed: 27 additions & 15 deletions
@@ -124,9 +124,15 @@ def encode(self, input_acts, apply_activation_function: bool = True):
             return pre_acts
         return self.activation_function(pre_acts)
 
-    def decode(self, acts):
+    def decode(self, acts, input_acts: torch.Tensor | None = None):
         W_dec = self.W_dec
-        return acts @ W_dec + self.b_dec
+        reconstruction = acts @ W_dec + self.b_dec
+        if self.W_skip is not None:
+            assert input_acts is not None, (
+                "Transcoder has skip connection but no input_acts were provided"
+            )
+            reconstruction = reconstruction + self.compute_skip(input_acts)
+        return reconstruction
 
     def compute_skip(self, input_acts):
         if self.W_skip is not None:
@@ -136,13 +142,9 @@ def compute_skip(self, input_acts):
 
     def forward(self, input_acts):
         transcoder_acts = self.encode(input_acts)
-        decoded = self.decode(transcoder_acts)
-        decoded = decoded.detach()
-        decoded.requires_grad = True
-
-        if self.W_skip is not None:
-            skip = self.compute_skip(input_acts)
-            decoded = decoded + skip
+        decoded = self.decode(transcoder_acts, input_acts)
+        # decoded = decoded.detach()
+        # decoded.requires_grad = True
 
         return decoded
 
@@ -169,7 +171,7 @@ def encode_sparse(self, input_acts, zero_positions: slice = slice(0, 1)):
 
         return sparse_acts, active_encoders
 
-    def decode_sparse(self, sparse_acts):
+    def decode_sparse(self, sparse_acts, input_acts: torch.Tensor | None = None):
         """Decode sparse activations and return reconstruction with scaled decoder vectors.
 
         Returns:
@@ -189,6 +191,11 @@ def decode_sparse(self, sparse_acts):
             n_pos, self.d_model, device=sparse_acts.device, dtype=sparse_acts.dtype
         )
         reconstruction = reconstruction.index_add_(0, pos_idx, scaled_decoders)
+        if self.W_skip is not None:
+            assert input_acts is not None, (
+                "Transcoder has skip connection but no input_acts were provided"
+            )
+            reconstruction = reconstruction + self.compute_skip(input_acts)
         reconstruction = reconstruction + self.b_dec
 
         return reconstruction, scaled_decoders
@@ -319,9 +326,12 @@ def select_decoder_vectors(self, features):
             encoder_mapping,
         )
 
-    def decode(self, acts):
+    def decode(self, acts, input_acts: torch.Tensor | None):
         return torch.stack(
-            [transcoder.decode(acts[i]) for i, transcoder in enumerate(self.transcoders)],  # type: ignore
+            [
+                transcoder.decode(acts[i], None if input_acts is None else input_acts[i])
+                for i, transcoder in enumerate[SingleLayerTranscoder](self.transcoders)  # type: ignore
+            ],
             dim=0,
         )
 
@@ -349,11 +359,13 @@ def compute_attribution_components(
         decoder_vectors = []
         sparse_acts_list = []
 
-        for layer, transcoder in enumerate(self.transcoders):
-            sparse_acts, active_encoders = transcoder.encode_sparse(  # type: ignore
+        for layer, transcoder in enumerate[SingleLayerTranscoder](self.transcoders):  # type: ignore
+            sparse_acts, active_encoders = transcoder.encode_sparse(
                 mlp_inputs[layer], zero_positions=zero_positions
             )
-            reconstruction[layer], active_decoders = transcoder.decode_sparse(sparse_acts)  # type: ignore
+            reconstruction[layer], active_decoders = transcoder.decode_sparse(
+                sparse_acts, mlp_inputs[layer]
+            )
             encoder_vectors.append(active_encoders)
             decoder_vectors.append(active_decoders)
             sparse_acts_list.append(sparse_acts)
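
The single-layer version follows the same pattern: decode now owns the skip term, so forward no longer detaches and re-adds it (hence the commented-out detach/requires_grad lines, which previously cut the backward graph at the decoder output). A sketch under assumed toy shapes:

import torch

d_model, d_sae, n_pos = 16, 64, 6
W_dec = torch.randn(d_sae, d_model)
b_dec = torch.randn(d_model)
W_skip = torch.randn(d_model, d_model)     # None when the file has no W_skip tensor

acts = torch.randn(n_pos, d_sae)           # transcoder feature activations
input_acts = torch.randn(n_pos, d_model)   # the layer's MLP input

# decode(acts, input_acts): linear decode plus the optional skip term.
reconstruction = acts @ W_dec + b_dec
if W_skip is not None:
    assert input_acts is not None, "skip connection requires input_acts"
    reconstruction = reconstruction + input_acts @ W_skip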

tests/test_attributions_gemma3_nnsight.py

Lines changed: 20 additions & 1 deletion
@@ -149,13 +149,13 @@ def verify_token_and_error_edges(
     act_rtol=1e-3,
     logit_atol=1e-5,
     logit_rtol=1e-3,
+    pos_start=1,
 ):
     s = graph.input_tokens
     adjacency_matrix = graph.adjacency_matrix.to(device=model.device, dtype=model.dtype)
     active_features = graph.active_features.to(device=model.device)
     logit_tokens = graph.logit_tokens.to(device=model.device)
     total_active_features = active_features.size(0)
-    pos_start = 1  # skip first position (BOS token)
 
     ctx = model.setup_attribution(s)
 
@@ -525,6 +525,24 @@ def test_gemma_3_1b():
     verify_feature_edges(model, graph)
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_gemma_3_1b_it():
+    s = "<bos><start_of_turn>user\nThe National Digital Analytics Group (ND"
+    model = ReplacementModel.from_pretrained(
+        "google/gemma-3-1b-it",
+        "mwhanna/gemma-scope-2-1b-it/transcoder_all/width_16k_l0_small_affine",
+        dtype=torch.float32,
+        backend="nnsight",
+    )
+    graph = attribute(s, model)
+    assert isinstance(model, NNSightReplacementModel)
+
+    print("Changing logit softcap to 0, as the logits will otherwise be off.")
+    with model.zero_softcap():
+        verify_token_and_error_edges(model, graph, pos_start=4)
+    verify_feature_edges(model, graph)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 def test_gemma_3_1b_clt():
     s = "The National Digital Analytics Group (ND"
@@ -569,5 +587,6 @@ def test_gemma_3_4b():
     test_gemma3_with_dummy_transcoders()
     test_gemma3_with_dummy_clt()
     test_gemma_3_1b()
+    test_gemma_3_1b_it()
     test_gemma_3_1b_clt()
     test_gemma_3_4b()
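
The pos_start=4 passed above matches the four-token ignore_prefix ([2, 105, 2364, 107]) from replacement_model_nnsight.py: the chat-template prefix "<bos><start_of_turn>user\n" is excluded from the token/error edge checks. A hedged sanity check of that correspondence (assumes transformers is installed, you have access to the gated checkpoint, and these ids are stable across tokenizer versions):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
prefix_ids = tok("<bos><start_of_turn>user\n", add_special_tokens=False).input_ids
print(prefix_ids)            # expected: [2, 105, 2364, 107], i.e. ignore_prefix
pos_start = len(prefix_ids)  # -> 4, the value passed to verify_token_and_error_edges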

tests/test_attributions_llama.py

Lines changed: 14 additions & 0 deletions
@@ -225,8 +225,22 @@ def test_llama_3_2_1b():
     verify_feature_edges(model, graph)
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_llama_3_2_1b_clt():
+    s = "The National Digital Analytics Group (ND"
+    model = ReplacementModel.from_pretrained(
+        "meta-llama/Llama-3.2-1B", "mntss/clt-llama-3.2-1b-524k"
+    )
+    assert isinstance(model, TransformerLensReplacementModel)
+    graph = attribute(s, model, batch_size=128)
+
+    verify_token_and_error_edges(model, graph)
+    verify_feature_edges(model, graph)
+
+
 if __name__ == "__main__":
     torch.manual_seed(42)
     test_small_llama_model()
     test_large_llama_model()
     test_llama_3_2_1b()
+    test_llama_3_2_1b_clt()

tests/test_cross_layer_transcoder.py

Lines changed: 13 additions & 6 deletions
@@ -22,7 +22,7 @@ def cleanup_cuda():
 def create_test_clt_files():
     """Create temporary CLT safetensors files for testing."""
 
-    def _create_files(n_layers=4, d_model=128, d_transcoder=512):
+    def _create_files(n_layers=4, d_model=128, d_transcoder=512, skip_connection=False):
         tmpdir = tempfile.mkdtemp()
 
         # Create encoder and decoder files for each layer
@@ -41,6 +41,11 @@ def _create_files(n_layers=4, d_model=128, d_transcoder=512):
             dec_path = os.path.join(tmpdir, f"W_dec_{i}.safetensors")
            save_file(dec_dict, dec_path)
 
+        if skip_connection:
+            skip_dict = {"W_skip": torch.randn(d_model, d_model)}
+            skip_path = os.path.join(tmpdir, "W_skip.safetensors")
+            save_file(skip_dict, skip_path)
+
         return tmpdir
 
     return _create_files
@@ -49,9 +54,10 @@ def _create_files(n_layers=4, d_model=128, d_transcoder=512):
 # === Attribution Tests ===
 
 
-def test_compute_attribution_components(create_test_clt_files):
+@pytest.mark.parametrize("skip_connection", [False, True])
+def test_compute_attribution_components(create_test_clt_files, skip_connection):
     """Test the main attribution functionality of CLT."""
-    clt_path = create_test_clt_files()
+    clt_path = create_test_clt_files(skip_connection=skip_connection)
     clt = load_clt(
         clt_path,
         device=torch.device("cpu"),
@@ -64,7 +70,7 @@ def test_compute_attribution_components(create_test_clt_files):
     inputs = torch.randn(clt.n_layers, n_pos, clt.d_model, dtype=clt.b_enc.dtype)
 
     # Compute attribution components
-    components = clt.compute_attribution_components(inputs)
+    components = clt.compute_attribution_components(inputs, zero_positions=slice(0, 1))
 
     # Verify all required components are present
     assert "activation_matrix" in components
@@ -79,9 +85,10 @@ def test_compute_attribution_components(create_test_clt_files):
     assert act_matrix.is_sparse
     assert act_matrix.shape == (clt.n_layers, n_pos, clt.d_transcoder)
 
-    # Check reconstruction
+    # Check reconstruction (only positions 1 and beyond)
     reconstruction = components["reconstruction"]
     assert reconstruction.shape == (clt.n_layers, n_pos, clt.d_model)
+    assert torch.allclose(reconstruction[:, 1:], clt(inputs)[:, 1:])
 
     # Check encoder/decoder vectors have consistent counts
     n_active_encoders = act_matrix._nnz()
@@ -93,7 +100,7 @@ def test_compute_attribution_components(create_test_clt_files):
 
     # Check decoder locations
     decoder_locs = components["decoder_locations"]
-    assert decoder_locs.shape[0] == 2  # layer and position indices
+    assert decoder_locs.shape[0] == 2
 
 
 def test_encode_sparse_with_lazy_encoder(create_test_clt_files):

tests/test_single_layer_transcoder.py

Lines changed: 19 additions & 7 deletions
@@ -23,7 +23,7 @@ def cleanup_cuda():
 def create_test_transcoder_file():
     """Create a temporary transcoder safetensors file for testing."""
 
-    def _create_file(d_model=128, d_sae=512):
+    def _create_file(d_model=128, d_sae=512, skip_connection=False):
         with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as f:
             W_enc = torch.randn(d_sae, d_model)
             W_dec = torch.randn(d_sae, d_model)
@@ -37,6 +37,9 @@ def _create_file(d_model=128, d_sae=512):
                 "b_dec": b_dec,
             }
 
+            if skip_connection:
+                state_dict["W_skip"] = torch.randn(d_model, d_model)
+
             save_file(state_dict, f.name)
             return f.name, state_dict
 
@@ -58,13 +61,16 @@ def _create_and_track(*args, **kwargs):
 # === Attribution Tests ===
 
 
-def test_transcoder_set_attribution_components(create_test_transcoder_file):
+@pytest.mark.parametrize("skip_connection", [False, True])
+def test_transcoder_set_attribution_components(create_test_transcoder_file, skip_connection):
     """Test compute_attribution_components functionality."""
     # Create test files for multiple layers
     n_layers = 3
     paths = {}
     for layer in range(n_layers):
-        path, _ = create_test_transcoder_file(d_model=128, d_sae=512)
+        path, _ = create_test_transcoder_file(
+            d_model=128, d_sae=512, skip_connection=skip_connection
+        )
         paths[layer] = path
 
     transcoder_set = load_transcoder_set(
@@ -74,7 +80,7 @@ def test_transcoder_set_attribution_components(create_test_transcoder_file):
         feature_output_hook="hook_mlp_out",
         device=torch.device("cpu"),
         lazy_encoder=False,
-        lazy_decoder=True,  # Test with lazy decoder
+        lazy_decoder=True,
     )
 
     # Create test MLP inputs
@@ -83,7 +89,9 @@ def test_transcoder_set_attribution_components(create_test_transcoder_file):
     mlp_inputs = torch.randn(n_layers, n_pos, d_model)
 
     # Compute attribution components
-    components = transcoder_set.compute_attribution_components(mlp_inputs)
+    components = transcoder_set.compute_attribution_components(
+        mlp_inputs, zero_positions=slice(0, 1)
+    )
 
     # Verify all required components are present
     assert "activation_matrix" in components
@@ -98,9 +106,13 @@ def test_transcoder_set_attribution_components(create_test_transcoder_file):
     assert act_matrix.is_sparse
     assert act_matrix.shape == (n_layers, n_pos, 512)
 
-    # Check reconstruction
+    # Check reconstruction (only positions 1 and beyond)
     reconstruction = components["reconstruction"]
     assert reconstruction.shape == (n_layers, n_pos, d_model)
+    for layer, transcoder in enumerate(transcoder_set.transcoders):
+        assert torch.allclose(
+            reconstruction[layer, 1:], transcoder(mlp_inputs[layer])[1:], rtol=1e-4, atol=1e-4
+        )
 
     # Check encoder/decoder vectors have matching counts
     n_active = act_matrix._nnz()
@@ -110,7 +122,7 @@ def test_transcoder_set_attribution_components(create_test_transcoder_file):
 
     # Check decoder locations
     decoder_locs = components["decoder_locations"]
-    assert decoder_locs.shape == (2, n_active)  # layer and position indices
+    assert decoder_locs.shape == (2, n_active)
 
 
 def test_sparse_encode_decode(create_test_transcoder_file):
