@@ -172,26 +172,49 @@ def __init__(
172172 ]
173173 )
174174 if self .use_aux_hidden_state :
175- if hasattr (self .config , "target_hidden_size" ):
176- fc_input_size = self .config .target_hidden_size * 3
177- else :
178- fc_input_size = self .config .hidden_size * 3
175+ self .num_aux_hidden_states = getattr (
176+ self .config , "num_aux_hidden_states" , None
177+ )
178+ if self .num_aux_hidden_states is None :
179+ eagle_config = getattr (self .config , "eagle_config" , None ) or {}
180+ layer_ids = eagle_config .get ("eagle_aux_hidden_state_layer_ids" )
181+ self .num_aux_hidden_states = len (layer_ids ) if layer_ids else 3
182+
183+ target_hidden_size = getattr (
184+ self .config , "target_hidden_size" , self .config .hidden_size
185+ )
186+ self .fc_input_size = target_hidden_size * self .num_aux_hidden_states
187+
179188 if self .norm_before_fc :
180189 self .input_norm = RMSNorm (
181- fc_input_size ,
190+ self . fc_input_size ,
182191 eps = self .config .rms_norm_eps ,
183192 )
184193 else :
185194 self .input_norm = None
195+
196+ use_fc_norm = getattr (self .config , "fc_norm" , False )
197+ if use_fc_norm :
198+ self .fc_norm = nn .ModuleList (
199+ [
200+ RMSNorm (target_hidden_size , eps = self .config .rms_norm_eps )
201+ for _ in range (self .num_aux_hidden_states )
202+ ]
203+ )
204+ else :
205+ self .fc_norm = None
206+
186207 self .fc = ReplicatedLinear (
187- input_size = fc_input_size ,
208+ input_size = self . fc_input_size ,
188209 output_size = self .config .hidden_size ,
189210 bias = False ,
190211 params_dtype = vllm_config .model_config .dtype ,
191212 quant_config = self .quant_config ,
192213 prefix = maybe_prefix (prefix , "fc" ),
193214 return_bias = False ,
194215 )
216+
217+ self .norm_output = getattr (self .config , "norm_output" , False )
195218 self .norm = RMSNorm (
196219 self .config .hidden_size ,
197220 eps = self .config .rms_norm_eps ,
@@ -220,7 +243,11 @@ def forward(
220243 residual = residual ,
221244 )
222245 hidden_states , hidden_prenorm = self .norm (hidden_states , residual )
223- return hidden_states , hidden_prenorm
246+
247+ # norm_output variant uses the post-norm hidden states.
248+ aux_output = hidden_states if self .norm_output else hidden_prenorm
249+
250+ return hidden_states , aux_output
224251
225252 def load_weights (self , weights : Iterable [tuple [str , torch .Tensor ]]) -> set [str ]:
226253 stacked_params_mapping = [
@@ -312,11 +339,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
312339 if self .use_parallel_drafting :
313340 self .register_buffer (
314341 "mask_hidden" ,
315- torch .zeros (
316- 1 ,
317- (3 if self .model .use_aux_hidden_state else 1 )
318- * self .config .hidden_size ,
319- ),
342+ torch .zeros (1 , self .model .fc_input_size ),
320343 persistent = False ,
321344 )
322345
@@ -371,6 +394,16 @@ def combine_hidden_states(
371394
372395 if self .model .norm_before_fc :
373396 hidden_states = self .model .input_norm (hidden_states )
397+
398+ # `norm_before_fc` adds a single RMSNorm before the FC layer, whereas `fc_norm`
399+ # applies separate RMSNorms to each chunk of the hidden states.
400+ if self .model .fc_norm is not None :
401+ chunks = hidden_states .chunk (self .model .num_aux_hidden_states , dim = - 1 )
402+ hidden_states = torch .cat (
403+ [norm (chunk ) for norm , chunk in zip (self .model .fc_norm , chunks )],
404+ dim = - 1 ,
405+ )
406+
374407 return self .model .fc (hidden_states )
375408
376409 def load_weights (self , weights : Iterable [tuple [str , torch .Tensor ]]):
0 commit comments