[DSv4] Adding TRTLLM gen attention kernel#43827

Merged
WoosukKwon merged 1 commit into
vllm-project:mainfrom
zyongye:dsv4-sparse-mla-flashinfer-rebased
Jun 4, 2026

Conversation

@zyongye

@zyongye zyongye commented May 28, 2026

Member

Summary

Rebase of @PerkzZheng's #42316 onto current main, plus a few materially
new pieces:

  • Once-per-step C128A metadata caching: adds FlashInferMLASparseMetadata
    and FlashInferMLASparseMetadataBuilder so that, for compress_ratio == 128
    layers, the mixed-sparse-index Triton kernel runs once per step instead
    of once per layer. The SWA-baked combine is materialized lazily on first
    access by ensure_sparse_indices and cached on the metadata for the
    remaining C128A layers in the same step. C4 / SWA-only paths are
    unchanged (they read layer.topk_indices_buffer, which is per-layer).
  • Single-call launcher — collapses the previous decode+prefill split
    into one flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw call over
    the mixed batch.
  • FP8 scalar-vs-tensor scale dispatch — documents and consolidates the
    precomputed Python-float bmm1_scale / bmm2_scale derived from the
    per-tensor placeholders. The TRTLLM-GEN sparse-MLA launcher takes
    different C++ code paths for scalar vs 1-elem-tensor scale args, so this
    matters for correctness.
  • CuTeDSL compressor writes contiguous BF16 / per-tensor FP8 cache,
    matching what the FlashInfer V4 backend reads (no UE8M0 padding).
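
The scalar-vs-tensor point in the list above can be illustrated with a minimal sketch. The helper `fold_fp8_scales`, the `FoldedScales` container, and the folding formula (query scale times KV scale times softmax scale for the first batched matmul, KV scale for the second) are assumptions made for illustration, not the PR's actual code. The only point is that the folded values are computed once as plain Python floats instead of being passed as 1-element tensors.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FoldedScales:
    """Hypothetical container for scales folded down to Python floats."""

    bmm1_scale: float  # applied to the Q @ K^T product
    bmm2_scale: float  # applied to the P @ V product


def fold_fp8_scales(q_scale: float, kv_scale: float, softmax_scale: float) -> FoldedScales:
    """Fold per-tensor scales into two scalars (assumed formula, see lead-in)."""
    # float() guarantees a Python scalar, so a launcher that dispatches on
    # "scalar vs tensor" argument types would see the scalar variant.
    return FoldedScales(
        bmm1_scale=float(q_scale) * float(kv_scale) * float(softmax_scale),
        bmm2_scale=float(kv_scale),
    )


scales = fold_fp8_scales(q_scale=1.0, kv_scale=0.5, softmax_scale=0.125)
```

Computing the floats once, outside the per-layer hot path, also avoids a device-to-host sync that reading a tensor scale on every call would otherwise require.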

Relation to #42316

This branch is a rebase of #42316; the C128A metadata caching, single-call
launcher collapse, and FP8 scale dispatch documentation are the new pieces.
Opening as a separate PR rather than pushing into #42316 because the rebase

End-to-end eval

deepseek-ai/DeepSeek-V4-Flash, 4× GB200, TP=4, V4_FLASHINFER_MLA_SPARSE
backend, --kv-cache-dtype fp8 --block-size 256,
cudagraph_mode=FULL_DECODE_ONLY. vLLM commit ce4e168ba (branch tip
before today's mechanical rebase onto main; rebase only resolved stale
torch-stable-ABI decls in csrc/ops.h / csrc/torch_bindings.cpp):

| Task | Setting | n | Score |
| ---- | ------- | - | ----- |
| GSM8K | 5-shot, T=0, completions | 1319 | 0.9538 strict / 0.9530 flexible |
| GPQA-Diamond | 0-shot, T=1.0, top_p=0.95, thinking=on, 4× epochs | 792 | 0.8586 |
| AIME25 | 0-shot, T=1.0, top_p=0.95, thinking=on, 4× epochs | 120 | 0.9750 |

Test plan

  • pre-commit run --files vllm/models/deepseek_v4/nvidia/flashinfer_sparse.py (passed, including mypy)
  • tests/kernels/test_compressor_kv_cache.py — 36/36 passed
  • End-to-end eval on DSv4-Flash (table above)
  • Profile to confirm that the build_flashinfer_mixed_sparse_indices Triton
    kernel now appears once per step instead of once per layer (N times per step).
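
The once-per-step behaviour that the profile is meant to confirm can be sketched as a lazily built, cached field on a per-step metadata object. Only the method name `ensure_sparse_indices` comes from the PR description; `StepMetadata`, `build_indices`, and `fake_kernel` are hypothetical stand-ins, and the index values are made up.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Optional


@dataclass
class StepMetadata:
    """Hypothetical per-step metadata shared by every C128A layer."""

    build_indices: Callable[[], List[int]]  # stand-in for the Triton kernel
    _sparse_indices: Optional[List[int]] = field(default=None, init=False)

    def ensure_sparse_indices(self) -> List[int]:
        # Build on first access, then reuse for the rest of the step.
        if self._sparse_indices is None:
            self._sparse_indices = self.build_indices()
        return self._sparse_indices


calls: List[int] = []  # records each (simulated) kernel launch


def fake_kernel() -> List[int]:
    calls.append(1)
    return [0, 4, 8]


meta = StepMetadata(build_indices=fake_kernel)  # one object per step
for _layer in range(8):  # eight layers sharing the same metadata
    indices = meta.ensure_sparse_indices()
```

After the loop, `calls` holds a single entry even though eight layers asked for the indices. A new `StepMetadata` would be created on the next step, which resets the cache.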

AI-assisted disclosure

Developed with AI assistance (Claude Code). Per AGENTS.md, the submitter
has reviewed each changed line.

🤖 Generated with Claude Code

@mergify

mergify Bot commented May 28, 2026

Contributor

Documentation preview: https://vllm--43827.org.readthedocs.build/en/43827/

@mergify mergify Bot added documentation Improvements or additions to documentation deepseek Related to DeepSeek models nvidia v1 labels May 28, 2026
@mergify

mergify Bot commented May 28, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @zyongye.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 28, 2026
@zyongye zyongye changed the title [DSv4] FlashInfer sparse MLA: rebase + once-per-step C128A metadata caching [DSv4] Adding TRTLLM gen attention kernel May 28, 2026
@zyongye zyongye force-pushed the dsv4-sparse-mla-flashinfer-rebased branch from df2a27f to b96b676 Compare May 28, 2026 03:25
@zyongye zyongye requested a review from hmellor as a code owner May 28, 2026 03:25
@mergify mergify Bot removed the needs-rebase label May 28, 2026
@mergify

mergify Bot commented May 28, 2026

Contributor

Hi @zyongye, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@PerkzZheng PerkzZheng left a comment

@zyongye Hi Yongye, thanks for rebasing my MR. I have left some comments.

Comment thread docs/design/attention_backends.md Outdated

| Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | Non-Causal | Sparse | MM Prefix | DCP | Attention Types | Compute Cap. |
| ------- | ------ | --------- | ----------- | ---------- | ---- | ---------- | ------ | --------- | --- | --------------- | ------------ |
| `V4_FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto` | 256 | 512 | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | Any |

Curious what the 256 under Block Sizes means here?

Member Author

I think it's guidance on which block size the user should specify when launching the engine with this kernel. Currently, because we use a custom KV cache layout, we restrict the block size to 256 by passing --block-size 256. I need to look into whether this is really necessary.

Thanks. For your information, it should then be 512 for FlashInfer sparse MLA.

decode_compressed_topk_lens = token_to_req_indices

padded_topk = max(topk, decode_compressed_topk)
padded_topk = (padded_topk + 3) // 4 * 4

FlashInfer MLA kernels require 16B alignment for the topk indices. Could you add a comment here explaining that? Thanks.
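
The rounding in the snippet under review can be read as follows. This is a minimal sketch that assumes the top-k indices are stored as 4-byte integers, which the thread does not state; `pad_topk` and the constants are illustrative names.

```python
INDEX_BYTES = 4   # assumption: int32 top-k indices
ALIGN_BYTES = 16  # alignment requirement mentioned in the review


def pad_topk(topk: int, index_bytes: int = INDEX_BYTES, align_bytes: int = ALIGN_BYTES) -> int:
    """Round topk up so one row of indices spans a multiple of align_bytes."""
    per_group = align_bytes // index_bytes  # 4 indices fit in 16 bytes
    return (topk + per_group - 1) // per_group * per_group


padded = [pad_topk(k) for k in (1, 4, 509, 512)]
```

With 4-byte indices, `(topk + 3) // 4 * 4` is exactly this round-up to a multiple of four entries. Note that it only makes the row stride a multiple of 16 bytes; the base pointer still has to be aligned separately.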

# paged path and writes a contiguous 512-wide cache row per token; bf16
# vs per-tensor fp8 is selected by ``store_full_fp8`` (with the scale
# source supplied via ``fp8_scale``).
store_full_bf16: bool = False,

It seems better to rename this to store_full_kv; otherwise it is misleading. Sorry if it was introduced in my commits. We will also need to update the other places that use this term.

num_decodes = swa_metadata.num_decodes
num_prefills = swa_metadata.num_prefills
num_decode_tokens = swa_metadata.num_decode_tokens
num_prefill_tokens = swa_metadata.num_prefill_tokens

I split it into two calls (prefill and decode) in my previous MR because we pad gridDim.x to maxSeqLenQ, which can launch too many padding CTAs for mixed requests. I still observe obvious perf gains from splitting, even though those padding CTAs are skipped at runtime (which means the CTA switching overhead is not negligible).
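
The effect described in this comment can be estimated with a simplified model. It assumes one CTA per query position per request along gridDim.x, which is an illustration rather than the kernel's real launch geometry; `padded_ctas` and the batch sizes are made up.

```python
from typing import Sequence


def padded_ctas(q_lens: Sequence[int]) -> int:
    """CTAs that carry no real query token when the grid is padded to max(q_lens)."""
    if not q_lens:
        return 0
    return len(q_lens) * max(q_lens) - sum(q_lens)


decode_q_lens = [1] * 32       # 32 decode requests, one query token each
prefill_q_lens = [2048, 1024]  # 2 prefill requests

# One launch over the mixed batch pads every decode request to 2048.
single_call = padded_ctas(decode_q_lens + prefill_q_lens)
# Two launches pad decodes to 1 and prefills to 2048 independently.
split_calls = padded_ctas(decode_q_lens) + padded_ctas(prefill_q_lens)
```

Under this model the single mixed launch schedules 66528 padding CTAs, against 1024 for the split launches. Even if each padding CTA exits immediately, that is the scheduling overhead the comment refers to.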

@zyongye zyongye added ready ONLY add when PR is ready to merge/full CI is needed and removed ready ONLY add when PR is ready to merge/full CI is needed labels May 31, 2026
@mergify

mergify Bot commented Jun 2, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @zyongye.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Jun 2, 2026
Comment thread vllm/models/deepseek_v4/attention.py Outdated
Comment on lines +182 to +203
@dataclass
class DeepseekV4MLAModules:
    """Modules used in DeepseekV4 MLA."""

    vllm_config: VllmConfig
    fused_wqa_wkv: torch.nn.Module
    q_norm: torch.nn.Module
    wq_b: torch.nn.Module
    kv_norm: torch.nn.Module
    wo_a: torch.nn.Module
    wo_b: torch.nn.Module
    attn_sink: torch.nn.Module
    rotary_emb: torch.nn.Module
    indexer: torch.nn.Module | None
    indexer_rotary_emb: torch.nn.Module
    topk_indices_buffer: torch.Tensor | None
    aux_stream_list: list[torch.cuda.Stream] | None = None


# --8<-- [start:multi_head_latent_attention]
@PluggableLayer.register("deepseek_v4_multi_head_latent_attention")
class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):

Collaborator

Can you remove these classes as I did in #44246? I think there was an error in resolving the merge conflict.

Comment thread vllm/models/deepseek_v4/attention.py Outdated
Comment on lines +395 to +400
torch.ops.vllm.deepseek_v4_attention(
    hidden_states,
    positions,
    o_padded,
    self.layer_name,
)

Collaborator

This torch op was also removed in the main branch

Comment thread vllm/models/deepseek_v4/attention.py Outdated
"bhr,hdr->bhd",
(o_fp8, o_scale),
(wo_a_fp8, wo_a_scale),
torch.ops.vllm.deepseek_v4_fp8_einsum(

Collaborator

ditto

Comment thread vllm/models/deepseek_v4/attention.py Outdated
Comment on lines +709 to +755
def deepseek_v4_attention_fake(
    hidden_states: torch.Tensor,
    positions: torch.Tensor,
    out: torch.Tensor,
    layer_name: str,
) -> None:
    return None


direct_register_custom_op(
    op_name="deepseek_v4_attention",
    op_func=deepseek_v4_attention,
    mutates_args=["out"],
    fake_impl=deepseek_v4_attention_fake,
)


def deepseek_v4_fp8_einsum(
    a: torch.Tensor,
    a_scale: torch.Tensor,
    b: torch.Tensor,
    b_scale: torch.Tensor,
    out: torch.Tensor,
    equation: str,
    recipe: list[int],
) -> None:
    fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe))


def deepseek_v4_fp8_einsum_fake(
    a: torch.Tensor,
    a_scale: torch.Tensor,
    b: torch.Tensor,
    b_scale: torch.Tensor,
    out: torch.Tensor,
    equation: str,
    recipe: list[int],
) -> None:
    return None


direct_register_custom_op(
    op_name="deepseek_v4_fp8_einsum",
    op_func=deepseek_v4_fp8_einsum,
    mutates_args=["out"],
    fake_impl=deepseek_v4_fp8_einsum_fake,
)

Collaborator

This also needs to be removed

),
Float32(self.fp8_max),
)
y1 = cute.arch.fmin(

Collaborator

Should we avoid using fmin?

Member Author

You're right. Will change

@zyongye zyongye force-pushed the dsv4-sparse-mla-flashinfer-rebased branch 2 times, most recently from f73111f to a37d8af Compare June 3, 2026 05:30
@zyongye zyongye added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 3, 2026
@zyongye zyongye force-pushed the dsv4-sparse-mla-flashinfer-rebased branch from a37d8af to 3f596f3 Compare June 3, 2026 18:13
Comment thread vllm/v1/attention/backends/registry.py Outdated
Comment on lines +80 to +87
FLASHMLA_SPARSE_V4 = (
    "vllm.models.deepseek_v4.nvidia.flashmla.DeepseekV4FlashMLASparseBackend"
)
FLASHINFER_MLA_SPARSE_V4 = (
    "vllm.models.deepseek_v4.nvidia.flashinfer_sparse."
    "DeepseekV4FlashInferMLASparseBackend"
)
ROCM_FLASHMLA_SPARSE_V4 = (

Collaborator

nit: What about DSV4 instead of V4?

Suggested change
- FLASHMLA_SPARSE_V4 = (
-     "vllm.models.deepseek_v4.nvidia.flashmla.DeepseekV4FlashMLASparseBackend"
- )
- FLASHINFER_MLA_SPARSE_V4 = (
-     "vllm.models.deepseek_v4.nvidia.flashinfer_sparse."
-     "DeepseekV4FlashInferMLASparseBackend"
- )
- ROCM_FLASHMLA_SPARSE_V4 = (
+ FLASHMLA_SPARSE_DSV4 = (
+     "vllm.models.deepseek_v4.nvidia.flashmla.DeepseekV4FlashMLASparseBackend"
+ )
+ FLASHINFER_MLA_SPARSE_DSV4 = (
+     "vllm.models.deepseek_v4.nvidia.flashinfer_sparse."
+     "DeepseekV4FlashInferMLASparseBackend"
+ )
+ ROCM_FLASHMLA_SPARSE_DSV4 = (

Member Author

Done

@mergify

mergify Bot commented Jun 4, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @zyongye.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Jun 4, 2026
Add a selectable DeepSeek V4 sparse-MLA decode backend that runs through
FlashInfer's TRTLLM-gen kernel with a contiguous bf16 / per-tensor FP8 KV
cache, alongside the existing FlashMLA path. Re-ported from PerkzZheng's

- csrc: sibling fusedDeepseekV4FullCacheKernel (contiguous 512-wide bf16 /
  per-tensor fp8 insert) + full_cache_{bf16,fp8}_insert ops, in the libtorch
  stable ABI (csrc/libtorch_stable/).
- nvidia/flashinfer_sparse.py: DeepseekV4FlashInferMLASparseBackend/Impl;
  public flashinfer.mla.trtllm_batch_decode_sparse_mla_dsv4 launcher, two-call
  decode/prefill split, q-head padding to {64,128}, fp8 scale buffers.
- registry + selection: FLASHMLA_SPARSE_DSV4 / FLASHINFER_MLA_SPARSE_DSV4 /
  ROCM_FLASHMLA_SPARSE_DSV4; _select_v4_sparse_impl consults the backend;
  _resolve_dsv4_kv_cache_dtype maps dtype per backend.
- compressor: CuTeDSL full-cache classes (SparseAttnCompressNormRopeStoreFullC4Kernel,
  SparseAttnNormRopeStoreFullKernel) separate from the pristine legacy UE8M0
  classes so the legacy path keeps its perf; build_flashinfer_mixed_sparse_indices
  in common/ops.
- SWA cache / kv_cache_interface: accept bf16/fp8 dtypes, gate 576B alignment
  on fp8_ds_mla.
- docs: dedicated "DeepSeek V4 Decode Backends" section.
- tests: full-cache parity (insert ops + cutedsl compressor).
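
One way to read "q-head padding to {64,128}" in the list above is that the per-rank query head count is rounded up to the smallest supported value. The sketch below assumes that reading; `padded_num_q_heads` and its error handling are hypothetical and not taken from the PR.

```python
SUPPORTED_Q_HEADS = (64, 128)  # values from the commit message; semantics assumed


def padded_num_q_heads(num_heads: int) -> int:
    """Return the smallest supported head count that fits num_heads."""
    if num_heads <= 0:
        raise ValueError(f"num_heads must be positive, got {num_heads}")
    for supported in SUPPORTED_Q_HEADS:
        if num_heads <= supported:
            return supported
    raise ValueError(f"{num_heads} query heads exceed the largest supported count")


examples = {n: padded_num_q_heads(n) for n in (16, 64, 65, 128)}
```

The extra heads would carry padding values and their outputs would be dropped after the kernel returns.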

Verified: kernel/compressor tests pass; e2e GSM8K on DSv4-Flash (TP=4, fp8)
matches the FlashMLA baseline (~0.953).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
@zyongye zyongye force-pushed the dsv4-sparse-mla-flashinfer-rebased branch from 8e39bc7 to 6a65ae7 Compare June 4, 2026 04:45
@mergify mergify Bot removed the needs-rebase label Jun 4, 2026
@WoosukKwon WoosukKwon self-assigned this Jun 4, 2026
@WoosukKwon WoosukKwon self-requested a review June 4, 2026 08:01

@WoosukKwon WoosukKwon left a comment

Collaborator

Thanks for the PR!
I will follow up with some refactoring.

@github-project-automation github-project-automation Bot moved this to Ready in NVIDIA Jun 4, 2026
@WoosukKwon WoosukKwon enabled auto-merge (squash) June 4, 2026 08:02
@WoosukKwon WoosukKwon merged commit b5235fc into vllm-project:main Jun 4, 2026
164 of 165 checks passed
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA Jun 4, 2026
@wzhao18

wzhao18 commented Jun 4, 2026

Contributor

Hi @zyongye, thanks for the PR! Have you done benchmarking to measure the performance of the TRTLLM gen attention kernel vs FlashMLA? Which settings does this kernel have better perf for dsv4?

@zyongye zyongye deleted the dsv4-sparse-mla-flashinfer-rebased branch June 4, 2026 16:55
@zyongye

zyongye commented Jun 4, 2026

Member Author

> Hi @zyongye, thanks for the PR! Have you done benchmarking to measure the performance of the TRTLLM gen attention kernel vs FlashMLA? Which settings does this kernel have better perf for dsv4?

I only tested c2048, and I didn't actually see any perf improvement. I haven't checked other points yet.

JisoLya pushed a commit to JisoLya/vllm that referenced this pull request Jun 5, 2026
knight0528 pushed a commit to knight0528/vllm that referenced this pull request Jun 8, 2026
waqahmed-amd-fi pushed a commit to waqahmed-amd-fi/vllm that referenced this pull request Jun 10, 2026
Signed-off-by: Waqar Ahmed <waqar.ahmed@amd.com>
Saddss pushed a commit to Saddss/vllm that referenced this pull request Jun 14, 2026

Labels

deepseek Related to DeepSeek models documentation Improvements or additions to documentation nvidia ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

Status: Done

4 participants