@@ -124,9 +124,15 @@ def encode(self, input_acts, apply_activation_function: bool = True):
124124 return pre_acts
125125 return self .activation_function (pre_acts )
126126
127- def decode (self , acts ):
127+ def decode (self , acts , input_acts : torch . Tensor | None = None ):
128128 W_dec = self .W_dec
129- return acts @ W_dec + self .b_dec
129+ reconstruction = acts @ W_dec + self .b_dec
130+ if self .W_skip is not None :
131+ assert input_acts is not None , (
132+ "Transcoder has skip connection but no input_acts were provided"
133+ )
134+ reconstruction = reconstruction + self .compute_skip (input_acts )
135+ return reconstruction
130136
131137 def compute_skip (self , input_acts ):
132138 if self .W_skip is not None :
@@ -136,13 +142,9 @@ def compute_skip(self, input_acts):
136142
def forward(self, input_acts):
    """Run the transcoder end to end: encode to features, then decode.

    Args:
        input_acts: MLP input activations.

    Returns:
        The reconstructed activations. `decode` receives `input_acts` so it
        can apply the skip connection when this transcoder has one.
    """
    transcoder_acts = self.encode(input_acts)
    # The old detach()/requires_grad dance was left behind as commented-out
    # dead code by the previous change; it is removed here entirely.
    return self.decode(transcoder_acts, input_acts)
148150
@@ -169,7 +171,7 @@ def encode_sparse(self, input_acts, zero_positions: slice = slice(0, 1)):
169171
170172 return sparse_acts , active_encoders
171173
172- def decode_sparse (self , sparse_acts ):
174+ def decode_sparse (self , sparse_acts , input_acts : torch . Tensor | None = None ):
173175 """Decode sparse activations and return reconstruction with scaled decoder vectors.
174176
175177 Returns:
@@ -189,6 +191,11 @@ def decode_sparse(self, sparse_acts):
189191 n_pos , self .d_model , device = sparse_acts .device , dtype = sparse_acts .dtype
190192 )
191193 reconstruction = reconstruction .index_add_ (0 , pos_idx , scaled_decoders )
194+ if self .W_skip is not None :
195+ assert input_acts is not None , (
196+ "Transcoder has skip connection but no input_acts were provided"
197+ )
198+ reconstruction = reconstruction + self .compute_skip (input_acts )
192199 reconstruction = reconstruction + self .b_dec
193200
194201 return reconstruction , scaled_decoders
@@ -319,9 +326,12 @@ def select_decoder_vectors(self, features):
319326 encoder_mapping ,
320327 )
321328
322- def decode (self , acts ):
329+ def decode (self , acts , input_acts : torch . Tensor | None ):
323330 return torch .stack (
324- [transcoder .decode (acts [i ]) for i , transcoder in enumerate (self .transcoders )], # type: ignore
331+ [
332+ transcoder .decode (acts [i ], None if input_acts is None else input_acts [i ])
333+ for i , transcoder in enumerate [SingleLayerTranscoder ](self .transcoders ) # type: ignore
334+ ],
325335 dim = 0 ,
326336 )
327337
@@ -349,11 +359,13 @@ def compute_attribution_components(
349359 decoder_vectors = []
350360 sparse_acts_list = []
351361
352- for layer , transcoder in enumerate (self .transcoders ):
353- sparse_acts , active_encoders = transcoder .encode_sparse ( # type: ignore
362+ for layer , transcoder in enumerate [ SingleLayerTranscoder ] (self .transcoders ): # type: ignore
363+ sparse_acts , active_encoders = transcoder .encode_sparse (
354364 mlp_inputs [layer ], zero_positions = zero_positions
355365 )
356- reconstruction [layer ], active_decoders = transcoder .decode_sparse (sparse_acts ) # type: ignore
366+ reconstruction [layer ], active_decoders = transcoder .decode_sparse (
367+ sparse_acts , mlp_inputs [layer ]
368+ )
357369 encoder_vectors .append (active_encoders )
358370 decoder_vectors .append (active_decoders )
359371 sparse_acts_list .append (sparse_acts )