|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +import pytest |
| 4 | +import torch |
| 5 | + |
| 6 | +from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k_offline |
| 7 | +from tests.utils import large_gpu_mark |
| 8 | +from vllm import LLM |
| 9 | +from vllm.config import SpeculativeConfig |
| 10 | +from vllm.distributed import cleanup_dist_env_and_memory |
| 11 | + |
| 12 | +MODEL_PATH = "nm-testing/dflash-qwen3-8b-speculators" |
| 13 | + |
| 14 | +EXPECTED_GSM8K_ACCURACY = 0.885 |
| 15 | +ACCURACY_RTOL = 0.03 |
| 16 | +EXPECTED_ACCEPTANCE_LEN = 3.45 |
| 17 | +ACCEPTANCE_LEN_RTOL = 0.15 |
| 18 | + |
| 19 | +# Expected per-position acceptance rates (accepted_at_pos / num_drafts) |
| 20 | +# Based on GSM8K evaluation with Qwen3-8B dflash speculators. |
| 21 | +EXPECTED_PER_POS_ACCEPTANCE_RATES = [0.795, 0.611, 0.429, 0.282] |
| 22 | +PER_POS_RTOL = 0.15 |
| 23 | + |
| 24 | + |
| 25 | +def compute_spec_decode_stats( |
| 26 | + metrics, |
| 27 | +) -> dict: |
| 28 | + """Extract all spec-decode metrics and compute derived stats.""" |
| 29 | + name2metric = {m.name: m for m in metrics} |
| 30 | + |
| 31 | + n_drafts = name2metric["vllm:spec_decode_num_drafts"].value |
| 32 | + n_draft_tokens = name2metric["vllm:spec_decode_num_draft_tokens"].value |
| 33 | + n_accepted = name2metric["vllm:spec_decode_num_accepted_tokens"].value |
| 34 | + |
| 35 | + per_pos_vec = name2metric["vllm:spec_decode_num_accepted_tokens_per_pos"].values |
| 36 | + |
| 37 | + acceptance_len = 1 + (n_accepted / n_drafts) if n_drafts > 0 else 1.0 |
| 38 | + draft_tokens_per_step = (n_draft_tokens / n_drafts) if n_drafts > 0 else 0 |
| 39 | + overall_acceptance_rate = (n_accepted / n_draft_tokens) if n_draft_tokens > 0 else 0 |
| 40 | + per_pos_rates = [v / n_drafts for v in per_pos_vec] if n_drafts > 0 else [] |
| 41 | + |
| 42 | + return { |
| 43 | + "num_drafts": n_drafts, |
| 44 | + "num_draft_tokens": n_draft_tokens, |
| 45 | + "num_accepted_tokens": n_accepted, |
| 46 | + "acceptance_len": acceptance_len, |
| 47 | + "draft_tokens_per_step": draft_tokens_per_step, |
| 48 | + "overall_acceptance_rate": overall_acceptance_rate, |
| 49 | + "per_pos_accepted": list(per_pos_vec), |
| 50 | + "per_pos_acceptance_rates": per_pos_rates, |
| 51 | + } |
| 52 | + |
| 53 | + |
| 54 | +def print_spec_decode_stats(stats: dict) -> None: |
| 55 | + """Print all spec-decode metrics and derived values.""" |
| 56 | + print("\n===== Spec Decode Metrics =====") |
| 57 | + print(f" num_drafts: {stats['num_drafts']}") |
| 58 | + print(f" num_draft_tokens: {stats['num_draft_tokens']}") |
| 59 | + print(f" num_accepted_tokens: {stats['num_accepted_tokens']}") |
| 60 | + print(f" draft_tokens_per_step: {stats['draft_tokens_per_step']:.2f}") |
| 61 | + print(f" overall_acceptance_rate: {stats['overall_acceptance_rate']:.4f}") |
| 62 | + print(f" acceptance_len (1+acc/drafts): {stats['acceptance_len']:.4f}") |
| 63 | + print(" per-position accepted tokens:", stats["per_pos_accepted"]) |
| 64 | + print(" per-position acceptance rates:") |
| 65 | + for i, rate in enumerate(stats["per_pos_acceptance_rates"]): |
| 66 | + print(f" pos {i}: {rate:.4f}") |
| 67 | + print("===============================\n") |
| 68 | + |
| 69 | + |
| 70 | +def test_dflash_speculators_model(vllm_runner, example_prompts, monkeypatch): |
| 71 | + """ |
| 72 | + Test DFlash speculators model properly initializes speculative decoding. |
| 73 | +
|
| 74 | + Verifies: |
| 75 | + 1. Speculative config is automatically initialized from speculators config |
| 76 | + 2. Method is detected as 'dflash' |
| 77 | + 3. The draft model path is correctly set |
| 78 | + 4. Speculative tokens count is valid (num_speculative_tokens=8) |
| 79 | + 5. Text generation works with speculative decoding enabled |
| 80 | + """ |
| 81 | + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") |
| 82 | + |
| 83 | + with vllm_runner( |
| 84 | + MODEL_PATH, |
| 85 | + dtype=torch.bfloat16, |
| 86 | + enforce_eager=True, |
| 87 | + quantization="fp8", |
| 88 | + ) as vllm_model: |
| 89 | + vllm_config = vllm_model.llm.llm_engine.vllm_config |
| 90 | + |
| 91 | + assert isinstance(vllm_config.speculative_config, SpeculativeConfig), ( |
| 92 | + "Speculative config should be initialized for speculators model" |
| 93 | + ) |
| 94 | + |
| 95 | + spec_config = vllm_config.speculative_config |
| 96 | + assert spec_config.method == "dflash", ( |
| 97 | + f"Expected method='dflash', got '{spec_config.method}'" |
| 98 | + ) |
| 99 | + assert spec_config.num_speculative_tokens > 0, ( |
| 100 | + f"Expected positive speculative tokens, " |
| 101 | + f"got {spec_config.num_speculative_tokens}" |
| 102 | + ) |
| 103 | + assert spec_config.model == MODEL_PATH, ( |
| 104 | + f"Draft model should be {MODEL_PATH}, got {spec_config.model}" |
| 105 | + ) |
| 106 | + |
| 107 | + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens=20) |
| 108 | + assert vllm_outputs, f"No outputs generated for speculators model {MODEL_PATH}" |
| 109 | + |
| 110 | + |
| 111 | +@pytest.mark.slow_test |
| 112 | +@large_gpu_mark(min_gb=40) |
| 113 | +def test_dflash_speculators_correctness(monkeypatch): |
| 114 | + """ |
| 115 | + E2E correctness test for DFlash via the speculators auto-detect path. |
| 116 | +
|
| 117 | + Evaluates GSM8k accuracy to ensure the speculators-format model produces |
| 118 | + correct outputs, and checks that acceptance length does not collapse under |
| 119 | + batched inference (lm-eval style). |
| 120 | +
|
| 121 | + Observed per-position acceptance rates on GSM8K (1319 prompts): |
| 122 | + pos 0: 0.795, pos 1: 0.611, pos 2: 0.429, pos 3: 0.282, |
| 123 | + pos 4: 0.169, pos 5: 0.093, pos 6: 0.048, pos 7: 0.023 |
| 124 | + Observed mean AL: 3.45 (GSM8K dataset, max_num_seqs=128) |
| 125 | + """ |
| 126 | + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") |
| 127 | + |
| 128 | + spec_llm = LLM( |
| 129 | + model=MODEL_PATH, |
| 130 | + trust_remote_code=True, |
| 131 | + max_model_len=4096, |
| 132 | + max_num_seqs=128, |
| 133 | + gpu_memory_utilization=0.85, |
| 134 | + enforce_eager=False, |
| 135 | + disable_log_stats=False, |
| 136 | + ) |
| 137 | + |
| 138 | + results = evaluate_gsm8k_offline(spec_llm) |
| 139 | + accuracy = results["accuracy"] |
| 140 | + accuracy_threshold = EXPECTED_GSM8K_ACCURACY * (1 - ACCURACY_RTOL) |
| 141 | + assert accuracy >= accuracy_threshold, ( |
| 142 | + f"Expected GSM8K accuracy >= {accuracy_threshold:.3f}, got {accuracy:.3f}" |
| 143 | + ) |
| 144 | + |
| 145 | + current_metrics = spec_llm.get_metrics() |
| 146 | + stats = compute_spec_decode_stats(current_metrics) |
| 147 | + print_spec_decode_stats(stats) |
| 148 | + |
| 149 | + acceptance_len = stats["acceptance_len"] |
| 150 | + al_threshold = EXPECTED_ACCEPTANCE_LEN * (1 - ACCEPTANCE_LEN_RTOL) |
| 151 | + assert acceptance_len >= al_threshold, ( |
| 152 | + f"DFlash speculators acceptance length too low: " |
| 153 | + f"{acceptance_len:.2f} < {al_threshold:.2f}" |
| 154 | + ) |
| 155 | + |
| 156 | + # Check per-position acceptance rates for the first few positions. |
| 157 | + per_pos_rates = stats["per_pos_acceptance_rates"] |
| 158 | + for i, expected_rate in enumerate(EXPECTED_PER_POS_ACCEPTANCE_RATES): |
| 159 | + assert i < len(per_pos_rates), ( |
| 160 | + f"Missing per-position acceptance rate for position {i}" |
| 161 | + ) |
| 162 | + threshold = expected_rate * (1 - PER_POS_RTOL) |
| 163 | + assert per_pos_rates[i] >= threshold, ( |
| 164 | + f"Per-position acceptance rate at pos {i} too low: " |
| 165 | + f"{per_pos_rates[i]:.4f} < {threshold:.4f} " |
| 166 | + f"(expected ~{expected_rate:.4f})" |
| 167 | + ) |
| 168 | + |
| 169 | + del spec_llm |
| 170 | + torch.accelerator.empty_cache() |
| 171 | + cleanup_dist_env_and_memory() |
0 commit comments