Skip to content

Commit 1242196

Browse files
authored
[Model] Support post-norm architecture for EAGLE-3 speculators (#42764)
Signed-off-by: Doğaç Eldenk <dogacel@gmail.com>
1 parent a65093c commit 1242196

3 files changed

Lines changed: 80 additions & 18 deletions

File tree

vllm/model_executor/models/deepseek_eagle3.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -199,11 +199,18 @@ def __init__(
199199
]
200200
)
201201

202-
# fc layer for combining auxiliary hidden states (3x hidden size input)
203-
if hasattr(self.config, "target_hidden_size"):
204-
fc_input_size = self.config.target_hidden_size * 3
205-
else:
206-
fc_input_size = self.config.hidden_size * 3
202+
# fc layer for combining auxiliary hidden states
203+
num_aux_hidden_states = getattr(self.config, "num_aux_hidden_states", None)
204+
if num_aux_hidden_states is None:
205+
eagle_config = getattr(self.config, "eagle_config", None) or {}
206+
layer_ids = eagle_config.get("eagle_aux_hidden_state_layer_ids")
207+
num_aux_hidden_states = len(layer_ids) if layer_ids else 3
208+
self.num_aux_hidden_states = num_aux_hidden_states
209+
210+
target_hidden_size = getattr(
211+
self.config, "target_hidden_size", self.config.hidden_size
212+
)
213+
fc_input_size = target_hidden_size * num_aux_hidden_states
207214

208215
self.fc = ReplicatedLinear(
209216
input_size=fc_input_size,
@@ -215,6 +222,18 @@ def __init__(
215222
return_bias=False,
216223
)
217224

225+
use_fc_norm = getattr(self.config, "fc_norm", False)
226+
if use_fc_norm:
227+
self.fc_norm = nn.ModuleList(
228+
[
229+
RMSNorm(target_hidden_size, eps=self.config.rms_norm_eps)
230+
for _ in range(self.num_aux_hidden_states)
231+
]
232+
)
233+
else:
234+
self.fc_norm = None
235+
236+
self.norm_output = getattr(self.config, "norm_output", False)
218237
self.norm = RMSNorm(
219238
self.config.hidden_size,
220239
eps=self.config.rms_norm_eps,
@@ -242,8 +261,13 @@ def forward(
242261
hidden_states=hidden_states,
243262
residual=residual,
244263
)
264+
245265
hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
246-
return hidden_states, hidden_prenorm
266+
267+
# norm_output variant uses the post-norm hidden states.
268+
aux_output = hidden_states if self.norm_output else hidden_prenorm
269+
270+
return hidden_states, aux_output
247271

248272
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
249273
stacked_params_mapping = [

vllm/model_executor/models/llama_eagle3.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -172,26 +172,49 @@ def __init__(
172172
]
173173
)
174174
if self.use_aux_hidden_state:
175-
if hasattr(self.config, "target_hidden_size"):
176-
fc_input_size = self.config.target_hidden_size * 3
177-
else:
178-
fc_input_size = self.config.hidden_size * 3
175+
self.num_aux_hidden_states = getattr(
176+
self.config, "num_aux_hidden_states", None
177+
)
178+
if self.num_aux_hidden_states is None:
179+
eagle_config = getattr(self.config, "eagle_config", None) or {}
180+
layer_ids = eagle_config.get("eagle_aux_hidden_state_layer_ids")
181+
self.num_aux_hidden_states = len(layer_ids) if layer_ids else 3
182+
183+
target_hidden_size = getattr(
184+
self.config, "target_hidden_size", self.config.hidden_size
185+
)
186+
self.fc_input_size = target_hidden_size * self.num_aux_hidden_states
187+
179188
if self.norm_before_fc:
180189
self.input_norm = RMSNorm(
181-
fc_input_size,
190+
self.fc_input_size,
182191
eps=self.config.rms_norm_eps,
183192
)
184193
else:
185194
self.input_norm = None
195+
196+
use_fc_norm = getattr(self.config, "fc_norm", False)
197+
if use_fc_norm:
198+
self.fc_norm = nn.ModuleList(
199+
[
200+
RMSNorm(target_hidden_size, eps=self.config.rms_norm_eps)
201+
for _ in range(self.num_aux_hidden_states)
202+
]
203+
)
204+
else:
205+
self.fc_norm = None
206+
186207
self.fc = ReplicatedLinear(
187-
input_size=fc_input_size,
208+
input_size=self.fc_input_size,
188209
output_size=self.config.hidden_size,
189210
bias=False,
190211
params_dtype=vllm_config.model_config.dtype,
191212
quant_config=self.quant_config,
192213
prefix=maybe_prefix(prefix, "fc"),
193214
return_bias=False,
194215
)
216+
217+
self.norm_output = getattr(self.config, "norm_output", False)
195218
self.norm = RMSNorm(
196219
self.config.hidden_size,
197220
eps=self.config.rms_norm_eps,
@@ -220,7 +243,11 @@ def forward(
220243
residual=residual,
221244
)
222245
hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
223-
return hidden_states, hidden_prenorm
246+
247+
# norm_output variant uses the post-norm hidden states.
248+
aux_output = hidden_states if self.norm_output else hidden_prenorm
249+
250+
return hidden_states, aux_output
224251

225252
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
226253
stacked_params_mapping = [
@@ -312,11 +339,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
312339
if self.use_parallel_drafting:
313340
self.register_buffer(
314341
"mask_hidden",
315-
torch.zeros(
316-
1,
317-
(3 if self.model.use_aux_hidden_state else 1)
318-
* self.config.hidden_size,
319-
),
342+
torch.zeros(1, self.model.fc_input_size),
320343
persistent=False,
321344
)
322345

@@ -371,6 +394,16 @@ def combine_hidden_states(
371394

372395
if self.model.norm_before_fc:
373396
hidden_states = self.model.input_norm(hidden_states)
397+
398+
# `norm_before_fc` adds a single RMSNorm before the FC layer, whereas `fc_norm`
399+
# applies separate RMSNorms to each chunk of the hidden states.
400+
if self.model.fc_norm is not None:
401+
chunks = hidden_states.chunk(self.model.num_aux_hidden_states, dim=-1)
402+
hidden_states = torch.cat(
403+
[norm(chunk) for norm, chunk in zip(self.model.fc_norm, chunks)],
404+
dim=-1,
405+
)
406+
374407
return self.model.fc(hidden_states)
375408

376409
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

vllm/v1/worker/gpu_model_runner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5210,9 +5210,14 @@ def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None:
52105210
layer_ids = getattr(hf_config, "eagle_aux_hidden_state_layer_ids", None)
52115211
if not layer_ids:
52125212
dflash_config = getattr(hf_config, "dflash_config", None)
5213+
eagle_config = getattr(hf_config, "eagle_config", None)
5214+
52135215
if dflash_config and isinstance(dflash_config, dict):
52145216
layer_ids = dflash_config.get("target_layer_ids")
52155217

5218+
if eagle_config and isinstance(eagle_config, dict):
5219+
layer_ids = eagle_config.get("eagle_aux_hidden_state_layer_ids")
5220+
52165221
if layer_ids and isinstance(layer_ids, (list, tuple)):
52175222
return tuple(layer_ids)
52185223

0 commit comments

Comments
 (0)