Skip to content

Commit 29cfba8

Browse files
authored
[Speculative Decoding] Add DFlash speculators config parsing (vllm-project#38300)
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
1 parent 0f14d83 commit 29cfba8

4 files changed

Lines changed: 223 additions & 1 deletion

File tree

.buildkite/test_areas/spec_decode.yaml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,16 @@ steps:
4242
- tests/v1/e2e/spec_decode/
4343
commands:
4444
- pytest -v -s v1/e2e/spec_decode -k "draft_model or no_sync or batch_inference"
45+
46+
- label: DFlash Speculators Correctness
47+
timeout_in_minutes: 30
48+
device: h100
49+
optional: true
50+
num_devices: 1
51+
source_file_dependencies:
52+
- vllm/v1/spec_decode/
53+
- vllm/model_executor/models/qwen3_dflash.py
54+
- tests/v1/spec_decode/test_speculators_dflash.py
55+
commands:
56+
- export VLLM_ALLOW_INSECURE_SERIALIZATION=1
57+
- pytest -v -s v1/spec_decode/test_speculators_dflash.py -m slow_test
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import pytest
4+
import torch
5+
6+
from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k_offline
7+
from tests.utils import large_gpu_mark
8+
from vllm import LLM
9+
from vllm.config import SpeculativeConfig
10+
from vllm.distributed import cleanup_dist_env_and_memory
11+
12+
# Model identifier handed to both `vllm_runner` and `LLM` by the tests below.
MODEL_PATH = "nm-testing/dflash-qwen3-8b-speculators"

# Baselines for the slow correctness test. Every check is a lower bound only:
# the measured value must be >= EXPECTED * (1 - RTOL).
EXPECTED_GSM8K_ACCURACY = 0.885
ACCURACY_RTOL = 0.03
EXPECTED_ACCEPTANCE_LEN = 3.45
ACCEPTANCE_LEN_RTOL = 0.15

# Expected per-position acceptance rates (accepted_at_pos / num_drafts)
# Based on GSM8K evaluation with Qwen3-8B dflash speculators.
# Only the positions listed here are asserted; later positions are ignored.
EXPECTED_PER_POS_ACCEPTANCE_RATES = [0.795, 0.611, 0.429, 0.282]
PER_POS_RTOL = 0.15
23+
24+
25+
def compute_spec_decode_stats(
    metrics,
) -> dict:
    """Collect the spec-decode counters from `metrics` and derive ratios.

    Args:
        metrics: Iterable of metric objects exposing a `name` attribute.
            The three scalar counters are read through `.value`, the
            per-position counter through `.values`.

    Returns:
        A dict holding the raw counters, the mean acceptance length
        (`1 + accepted / drafts`), drafted tokens per step, the overall
        acceptance rate, and the per-position counts and rates. When no
        drafts were recorded the acceptance length is 1.0, the ratios are
        0 and the per-position rate list is empty.

    Raises:
        KeyError: If one of the four expected metric names is absent.
    """
    by_name = {}
    for metric in metrics:
        by_name[metric.name] = metric

    drafts = by_name["vllm:spec_decode_num_drafts"].value
    drafted_tokens = by_name["vllm:spec_decode_num_draft_tokens"].value
    accepted_tokens = by_name["vllm:spec_decode_num_accepted_tokens"].value
    accepted_by_pos = by_name["vllm:spec_decode_num_accepted_tokens_per_pos"].values

    # Ratios normalised by the number of drafts fall back to neutral values
    # when nothing was drafted, so callers never hit a division by zero.
    if drafts > 0:
        mean_accept_len = 1 + (accepted_tokens / drafts)
        tokens_per_step = drafted_tokens / drafts
    else:
        mean_accept_len = 1.0
        tokens_per_step = 0

    if drafted_tokens > 0:
        overall_rate = accepted_tokens / drafted_tokens
    else:
        overall_rate = 0

    rates_by_pos = []
    if drafts > 0:
        for count in accepted_by_pos:
            rates_by_pos.append(count / drafts)

    stats = {}
    stats["num_drafts"] = drafts
    stats["num_draft_tokens"] = drafted_tokens
    stats["num_accepted_tokens"] = accepted_tokens
    stats["acceptance_len"] = mean_accept_len
    stats["draft_tokens_per_step"] = tokens_per_step
    stats["overall_acceptance_rate"] = overall_rate
    stats["per_pos_accepted"] = list(accepted_by_pos)
    stats["per_pos_acceptance_rates"] = rates_by_pos
    return stats
52+
53+
54+
def print_spec_decode_stats(stats: dict) -> None:
    """Write a human-readable report of the spec-decode stats to stdout.

    Args:
        stats: Mapping with the keys produced by
            `compute_spec_decode_stats` (counters, derived ratios and the
            per-position lists).
    """

    def _report_lines():
        # Lazily produced so each key is looked up only when its line is
        # about to be printed, in the same order as the report itself.
        yield "\n===== Spec Decode Metrics ====="
        yield f" num_drafts: {stats['num_drafts']}"
        yield f" num_draft_tokens: {stats['num_draft_tokens']}"
        yield f" num_accepted_tokens: {stats['num_accepted_tokens']}"
        yield f" draft_tokens_per_step: {stats['draft_tokens_per_step']:.2f}"
        yield f" overall_acceptance_rate: {stats['overall_acceptance_rate']:.4f}"
        yield f" acceptance_len (1+acc/drafts): {stats['acceptance_len']:.4f}"
        yield f" per-position accepted tokens: {stats['per_pos_accepted']}"
        yield " per-position acceptance rates:"
        for pos, value in enumerate(stats["per_pos_acceptance_rates"]):
            yield f" pos {pos}: {value:.4f}"
        yield "===============================\n"

    for line in _report_lines():
        print(line)
68+
69+
70+
def test_dflash_speculators_model(vllm_runner, example_prompts, monkeypatch):
    """
    Smoke-test that a DFlash speculators checkpoint enables spec decoding.

    Loads MODEL_PATH through the `vllm_runner` fixture and checks that:
    - the engine config carries a `SpeculativeConfig`,
    - its method is 'dflash',
    - `num_speculative_tokens` is positive,
    - the draft model path equals MODEL_PATH,
    - greedy generation over `example_prompts` returns a non-empty result.
    """
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    runner_kwargs = dict(
        dtype=torch.bfloat16,
        enforce_eager=True,
        quantization="fp8",
    )
    with vllm_runner(MODEL_PATH, **runner_kwargs) as runner:
        engine_config = runner.llm.llm_engine.vllm_config
        spec_config = engine_config.speculative_config

        # The speculative config must have been derived from the
        # speculators checkpoint; nothing was passed explicitly above.
        assert isinstance(spec_config, SpeculativeConfig), (
            "Speculative config should be initialized for speculators model"
        )
        assert spec_config.method == "dflash", (
            f"Expected method='dflash', got '{spec_config.method}'"
        )
        assert spec_config.num_speculative_tokens > 0, (
            f"Expected positive speculative tokens, "
            f"got {spec_config.num_speculative_tokens}"
        )
        assert spec_config.model == MODEL_PATH, (
            f"Draft model should be {MODEL_PATH}, got {spec_config.model}"
        )

        # A short greedy run confirms the drafter is usable end to end.
        generated = runner.generate_greedy(example_prompts, max_tokens=20)
        assert generated, f"No outputs generated for speculators model {MODEL_PATH}"
109+
110+
111+
@pytest.mark.slow_test
@large_gpu_mark(min_gb=40)
def test_dflash_speculators_correctness(monkeypatch):
    """
    E2E correctness test for DFlash via the speculators auto-detect path.

    Evaluates GSM8k accuracy to ensure the speculators-format model produces
    correct outputs, and checks that acceptance length does not collapse under
    batched inference (lm-eval style).

    Observed per-position acceptance rates on GSM8K (1319 prompts):
      pos 0: 0.795, pos 1: 0.611, pos 2: 0.429, pos 3: 0.282,
      pos 4: 0.169, pos 5: 0.093, pos 6: 0.048, pos 7: 0.023
    Observed mean AL: 3.45 (GSM8K dataset, max_num_seqs=128)

    All checks are lower bounds of the form `EXPECTED * (1 - RTOL)`. The
    engine is torn down in a `finally` block so that a failed assertion does
    not leave it allocated for whatever runs next in the same process.
    """
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    spec_llm = LLM(
        model=MODEL_PATH,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=128,
        gpu_memory_utilization=0.85,
        enforce_eager=False,
        # Stats logging must stay enabled: the metrics read below come
        # from spec_llm.get_metrics().
        disable_log_stats=False,
    )

    try:
        results = evaluate_gsm8k_offline(spec_llm)
        accuracy = results["accuracy"]
        accuracy_threshold = EXPECTED_GSM8K_ACCURACY * (1 - ACCURACY_RTOL)
        assert accuracy >= accuracy_threshold, (
            f"Expected GSM8K accuracy >= {accuracy_threshold:.3f}, got {accuracy:.3f}"
        )

        current_metrics = spec_llm.get_metrics()
        stats = compute_spec_decode_stats(current_metrics)
        print_spec_decode_stats(stats)

        acceptance_len = stats["acceptance_len"]
        al_threshold = EXPECTED_ACCEPTANCE_LEN * (1 - ACCEPTANCE_LEN_RTOL)
        assert acceptance_len >= al_threshold, (
            f"DFlash speculators acceptance length too low: "
            f"{acceptance_len:.2f} < {al_threshold:.2f}"
        )

        # Check per-position acceptance rates for the first few positions.
        per_pos_rates = stats["per_pos_acceptance_rates"]
        for i, expected_rate in enumerate(EXPECTED_PER_POS_ACCEPTANCE_RATES):
            assert i < len(per_pos_rates), (
                f"Missing per-position acceptance rate for position {i}"
            )
            threshold = expected_rate * (1 - PER_POS_RTOL)
            assert per_pos_rates[i] >= threshold, (
                f"Per-position acceptance rate at pos {i} too low: "
                f"{per_pos_rates[i]:.4f} < {threshold:.4f} "
                f"(expected ~{expected_rate:.4f})"
            )
    finally:
        # Runs on success and on assertion failure alike; previously this
        # cleanup was skipped whenever any check above failed.
        del spec_llm
        torch.accelerator.empty_cache()
        cleanup_dist_env_and_memory()

vllm/model_executor/models/qwen3_dflash.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
523523
self.logits_processor = LogitsProcessor(
524524
self.config.draft_vocab_size, scale=logit_scale
525525
)
526-
self.draft_id_to_target_id = None
526+
target_vocab_size = vllm_config.model_config.get_vocab_size()
527+
if self.config.draft_vocab_size != target_vocab_size:
528+
self.draft_id_to_target_id = nn.Parameter(
529+
torch.zeros(self.config.draft_vocab_size, dtype=torch.long),
530+
requires_grad=False,
531+
)
532+
else:
533+
self.draft_id_to_target_id = None
527534

528535
def embed_input_ids(
529536
self,

vllm/transformers_utils/configs/speculators/algos.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,34 @@ def update_eagle3(config_dict: dict, pre_trained_config: dict) -> None:
4141
pre_trained_config["eagle_aux_hidden_state_layer_ids"] = config_dict[
4242
"eagle_aux_hidden_state_layer_ids"
4343
]
44+
45+
46+
@register_speculator("dflash")
47+
def update_dflash(config_dict: dict, pre_trained_config: dict) -> None:
    """
    Translate a DFlash speculators config into PreTrainedConfig kwargs.

    Mutates `pre_trained_config` in place; nothing is returned.

    Fields read from `config_dict`:
        - draft_vocab_size (optional): copied as-is, `None` when absent.
        - target_hidden_size (optional): copied only when present and not
          `None`.
        - aux_hidden_state_layer_ids (required): stored under
          `eagle_aux_hidden_state_layer_ids` and again under
          `dflash_config["target_layer_ids"]`.
        - mask_token_id (required): stored under
          `dflash_config["mask_token_id"]`.

    Raises:
        KeyError: If a required field is missing from `config_dict`.
    """
    pre_trained_config.update(
        architectures=["DFlashDraftModel"],
        draft_vocab_size=config_dict.get("draft_vocab_size"),
    )

    target_hidden = config_dict.get("target_hidden_size")
    if target_hidden is not None:
        pre_trained_config["target_hidden_size"] = target_hidden

    # The same layer-id object is published under both keys.
    layer_ids = config_dict["aux_hidden_state_layer_ids"]
    pre_trained_config["eagle_aux_hidden_state_layer_ids"] = layer_ids

    pre_trained_config["dflash_config"] = dict(
        mask_token_id=config_dict["mask_token_id"],
        target_layer_ids=layer_ids,
    )

0 commit comments

Comments
 (0)