[Speculative Decoding] Add DFlash speculators config parsing #38300

Merged: mgoin merged 19 commits into vllm-project:main from ZhanqiuHu:dflash-speculators on Apr 15, 2026

@ZhanqiuHu ZhanqiuHu commented Mar 27, 2026

Contributor

Summary

  • Adds DFlash speculators config parsing support (algos.py)
  • Allows user --speculative-config to override auto-detected values (config.py)
  • Updates qwen3_dflash.py weight loading: d2t/t2d/verifier handling (similar to Eagle3 patterns)
  • Adds E2E test for DFlash speculators auto-detect path
  • Closes [Feature]: dflash speculator model support #38240

Test Results (shanjiaz/speculators-dflash-format, Qwen3-8B target)

GSM8K Correctness (1319 questions, 5-shot, batched)

  • Accuracy: 0.885 (Qwen3-8B baseline: ~0.87-0.92)
  • Mean AL: 1.84

Magpie Acceptance Rates (200 prompts, batch-size-1)

  • Mean AL: 1.77 (min 1.45, max 2.09)
  • Per-position acceptance rate: [0.478, 0.181, 0.069, 0.023, 0.007, 0.002, 0.001, 0.000]
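
The two Magpie numbers can be checked against each other. Assuming the metrics are defined as in the validation script (acceptance length is 1 plus accepted tokens per draft, and each per-position rate is the accepted count at that position divided by the number of drafts), the mean AL is 1 plus the sum of the per-position rates. A minimal check, with a function name chosen for illustration:

```python
def mean_al_from_per_position(rates: list[float]) -> float:
    """Mean acceptance length implied by per-position acceptance rates."""
    # One token always comes from the target model; each position adds
    # its own acceptance probability on top of that.
    return 1.0 + sum(rates)


magpie_rates = [0.478, 0.181, 0.069, 0.023, 0.007, 0.002, 0.001, 0.000]
implied_al = mean_al_from_per_position(magpie_rates)
print(f"{implied_al:.3f}")
```

The implied value is about 1.76, close to the reported 1.77. A small gap is expected because the reported mean averages per-prompt acceptance lengths, while the per-position rates are aggregated over all drafts.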

@shanjiaz @fynnsu Ready for review. Needs confirmation on qwen3_dflash.py changes.

Magpie validation script
"""
Test DFlash speculators auto-detect path.

Loads a speculators-format model directly from HF (no config patching)
and measures acceptance length on the magpie dataset.

Usage:
    chg run --gpus 1 -- python my_wip/dflash_speculators/test_speculators_path.py
"""

import os

from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.v1.metrics.reader import Metric

DEFAULT_MODEL = "shanjiaz/speculators-dflash-format"


def _metric_map(metrics):
    return {m.name: m.value for m in metrics if hasattr(m, "value")}


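def compute_acceptance_len(
    metrics: list[Metric], prev_metrics: list[Metric] | None = None
) -> float:
    """Return the mean acceptance length, optionally since a previous snapshot.

    The counters are read as cumulative totals, so subtracting ``prev_metrics``
    isolates the drafts and accepted tokens of the most recent request.
    """
    name2metric = _metric_map(metrics)
    n_drafts = name2metric["vllm:spec_decode_num_drafts"]
    n_accepted = name2metric["vllm:spec_decode_num_accepted_tokens"]
    if prev_metrics is not None:
        prev = _metric_map(prev_metrics)
        n_drafts -= prev["vllm:spec_decode_num_drafts"]
        n_accepted -= prev["vllm:spec_decode_num_accepted_tokens"]
    if n_drafts <= 0:
        # No drafts in this window: report the baseline of one token per step.
        return 1.0
    # One token from the target model plus the accepted draft tokens per draft.
    return 1 + (n_accepted / n_drafts)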


def load_magpie_dataset(max_prompts=200):
    from datasets import load_dataset
    ds = load_dataset("Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1", split="train")
    prompts = []
    for i, row in enumerate(ds):
        if i >= max_prompts:
            break
        if "instruction" in row and row["instruction"]:
            prompts.append(row["instruction"])
        elif "conversations" in row and row["conversations"]:
            for turn in row["conversations"]:
                if turn.get("role") == "user" or turn.get("from") == "human":
                    prompts.append(turn.get("content", turn.get("value", "")))
                    break
    return prompts


def main():
    model_path = os.environ.get("SPEC_MODEL_PATH", DEFAULT_MODEL)

    print(f"\nLoading model via speculators auto-detect: {model_path}")
    llm = LLM(
        model=model_path,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=128,
        gpu_memory_utilization=0.85,
        enforce_eager=False,
        disable_log_stats=False,
    )

    sc = llm.llm_engine.vllm_config.speculative_config
    print(f"Detected: method={sc.method}, "
          f"num_speculative_tokens={sc.num_speculative_tokens}, "
          f"draft_model={sc.model}")

    tokenizer = llm.get_tokenizer()

    print("\nLoading magpie dataset...")
    prompts_raw = load_magpie_dataset(max_prompts=200)
    print(f"Loaded {len(prompts_raw)} prompts")

    sampling_params = SamplingParams(temperature=0, max_tokens=2048)

    prev_metrics = None
    acceptance_lengths = []

    for i in tqdm(range(len(prompts_raw)), desc="Processing magpie"):
        prompt_text = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompts_raw[i]}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )

        llm.generate([prompt_text], sampling_params, use_tqdm=False)
        current_metrics = llm.get_metrics()
        al = compute_acceptance_len(current_metrics, prev_metrics)
        prev_metrics = current_metrics
        acceptance_lengths.append(al)

    mean_al = sum(acceptance_lengths) / len(acceptance_lengths)

    print("\n" + "=" * 60)
    print("RESULTS — Speculators Auto-Detect Path")
    print(f"Model: {model_path}")
    print(f"Dataset: magpie ({len(prompts_raw)} prompts)")
    print(f"Mean Acceptance Length: {mean_al:.3f}")
    print(f"Min AL: {min(acceptance_lengths):.3f}")
    print(f"Max AL: {max(acceptance_lengths):.3f}")

    final_metrics = llm.get_metrics()
    name2val = _metric_map(final_metrics)
    n_drafts = name2val.get("vllm:spec_decode_num_drafts", 0)
    per_pos_rates = []
    if n_drafts > 0:
        for m in final_metrics:
            if hasattr(m, "values") and "per_pos" in m.name:
                per_pos_rates = [v / n_drafts for v in m.values]
                break
    if per_pos_rates:
        rate_strs = ", ".join(f"{r:.3f}" for r in per_pos_rates)
        print(f"Per-position acceptance rate: [{rate_strs}]")

    print("=" * 60)

    print("\nFinal aggregate metrics:")
    for key in sorted(name2val):
        if "spec_decode" in key:
            print(f"  {key}: {name2val[key]}")


if __name__ == "__main__":
    main()

@mergify mergify Bot added the new-model, qwen, speculative-decoding, and v1 labels Mar 27, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request implements DFlash speculative decoding for Qwen3 models in the vLLM V1 engine. Key additions include the DFlashProposer, a specialized Qwen3 DFlash model with fused KV precomputation, and a Triton kernel for efficient input preparation. The changes also extend the attention selector to support non-causal attention required by DFlash. Feedback was provided regarding the _raise_if_multimodal override in the proposer, which currently enables an untested code path; it is recommended to remove this override to maintain stability for multimodal inputs.

Comment thread vllm/v1/spec_decode/dflash.py
mergify Bot commented Mar 30, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ZhanqiuHu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Mar 30, 2026
@ZhanqiuHu ZhanqiuHu force-pushed the dflash-speculators branch from f229ecb to fa56363 Compare March 31, 2026 15:40
@ZhanqiuHu ZhanqiuHu marked this pull request as ready for review March 31, 2026 15:41
@mergify mergify Bot removed the needs-rebase label Mar 31, 2026
@ZhanqiuHu ZhanqiuHu force-pushed the dflash-speculators branch from fa56363 to 41bd6e7 Compare April 8, 2026 19:50
from vllm.config import SpeculativeConfig
from vllm.distributed import cleanup_dist_env_and_memory

MODEL_PATH = "shanjiaz/speculators-dflash-format"

Contributor

Could you use this model instead? https://huggingface.co/nm-testing/dflash-qwen3-8b-speculators Thanks!

Contributor Author

Updated!

dsikka commented Apr 8, 2026

Contributor

Thank you!

@rahul-tuli please help review

@ZhanqiuHu ZhanqiuHu force-pushed the dflash-speculators branch from 41bd6e7 to 3c202ed Compare April 8, 2026 20:30
mergify Bot commented Apr 8, 2026

Contributor

Hi @ZhanqiuHu, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@rahul-tuli rahul-tuli left a comment

Contributor

LGTM! pending two nits

Comment thread tests/v1/spec_decode/test_speculators_dflash.py Outdated
Comment thread tests/v1/spec_decode/test_speculators_dflash.py Outdated
@ZhanqiuHu ZhanqiuHu force-pushed the dflash-speculators branch from 024300c to 410498e Compare April 9, 2026 15:25
Comment thread vllm/model_executor/models/qwen3_dflash.py Outdated
Comment thread vllm/transformers_utils/config.py Outdated
Comment thread vllm/model_executor/models/qwen3_dflash.py Outdated

EXPECTED_GSM8K_ACCURACY = 0.885
ACCURACY_RTOL = 0.03
EXPECTED_ACCEPTANCE_LEN = 1.84
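
For context, a relative tolerance such as `ACCURACY_RTOL` is commonly applied as a band around the expected value. The sketch below assumes that interpretation; the helper `within_rtol` is hypothetical, and the test's actual assertion is not shown in this thread.

```python
EXPECTED_GSM8K_ACCURACY = 0.885
ACCURACY_RTOL = 0.03


def within_rtol(measured: float, expected: float, rtol: float) -> bool:
    """Hypothetical check: measured lies within rtol * |expected| of expected."""
    return abs(measured - expected) <= rtol * abs(expected)


# Accepted band under this assumption.
low = EXPECTED_GSM8K_ACCURACY * (1 - ACCURACY_RTOL)
high = EXPECTED_GSM8K_ACCURACY * (1 + ACCURACY_RTOL)
print(f"{low:.3f} .. {high:.3f}")
print(within_rtol(0.886, EXPECTED_GSM8K_ACCURACY, ACCURACY_RTOL))
```

Under this reading, a 3% relative tolerance around 0.885 admits accuracies from roughly 0.858 to 0.912.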

Member

Is this a dummy checkpoint? That score seems really low for a DFlash speculator

Contributor Author

This checkpoint was trained with speculators using a different setup than the original dflash qwen3-8b config. For example, it only uses 3 layers instead of 5. I confirmed the expected acceptance length with @shanjiaz.

Member

@shanjiaz is there a reason the accuracy is so low? 50% on the first token indicates either a very short/poor training run, or a problem in the code. What's going on here?

Member

@ZhanqiuHu could you run the DFlash reference checkpoint on the same test to get a baseline to compare against? I'm shocked it would be so low, even with fewer params

@shanjiaz shanjiaz Apr 9, 2026

Contributor

@benchislett This is a relatively short run with limited data only used for testing. We'll replace this with a better checkpoint when we have the time/resources to produce a better one.

Contributor Author

@benchislett z-lab/Qwen3-8B-DFlash-b16, 5 layers, GSM8K-5shot:

Accuracy: 0.886
AL: 3.70 
Per-position: 0.756, 0.584, 0.440, 0.320, 0.227, 0.153, 0.102, 0.066

Member

OK, acceptable for now but please create a github issue to track and update it when you have a better checkpoint.

Member

There are lots of bugs in specdec that tend to only manifest on later predicted tokens, especially with parallel drafting. Coverage of those issues is worse when the test model has a very low AR, since acceptance falls off fast anyway.

Contributor Author

Here: #39519

Contributor

Hey @benchislett, we have trained a new dflash model that gets a better acceptance rate:

SpecDecoding metrics: Mean acceptance length: 3.42, Accepted throughput: 4858.14 tokens/s, Drafted throughput: 16030.11 tokens/s, Accepted: 59735 tokens, Drafted: 197104 tokens, Per-position acceptance rate: 0.794, 0.607, 0.424, 0.277, 0.166, 0.090, 0.046, 0.021, Avg Draft acceptance rate: 30.3%

Let us know : )
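
The figures in that log line can be cross-checked from the token counts. The sketch assumes 8 speculative tokens per draft (one per listed position) and the same acceptance-length definition as the validation script in the PR description; both are assumptions rather than facts stated in the log.

```python
accepted_tokens = 59735
drafted_tokens = 197104
tokens_per_draft = 8  # assumption: one drafted token per listed position

# Number of draft rounds implied by the drafted-token count.
num_drafts = drafted_tokens / tokens_per_draft
# Fraction of all drafted tokens that were accepted.
draft_acceptance_rate = accepted_tokens / drafted_tokens
# One target token per round plus the accepted draft tokens per round.
mean_acceptance_length = 1 + accepted_tokens / num_drafts

print(f"{draft_acceptance_rate:.1%}, {mean_acceptance_length:.2f}")
```

Under these assumptions the values round to 30.3% and 3.42, matching the average draft acceptance rate and mean acceptance length reported in the log.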

Comment thread tests/v1/spec_decode/test_speculators_dflash.py
@benchislett benchislett added the ready label Apr 13, 2026
@benchislett benchislett enabled auto-merge (squash) April 13, 2026 18:38
auto-merge was automatically disabled April 13, 2026 22:04

Head branch was pushed to by a user without write access

@ZhanqiuHu ZhanqiuHu force-pushed the dflash-speculators branch from 4b5cc91 to c4ceb01 Compare April 13, 2026 22:04
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…diff

@ZhanqiuHu ZhanqiuHu force-pushed the dflash-speculators branch from 48677ec to c036e37 Compare April 14, 2026 14:42
@mergify mergify Bot added the ci/build label Apr 14, 2026
ZhanqiuHu and others added 7 commits April 14, 2026 17:04
@mgoin mgoin merged commit 0b790a2 into vllm-project:main Apr 15, 2026
62 checks passed
@ZhanqiuHu ZhanqiuHu deleted the dflash-speculators branch June 4, 2026 17:44
