[DSv4] Adding TRTLLM gen attention kernel#43827
Documentation preview: https://vllm--43827.org.readthedocs.build/en/43827/
This pull request has merge conflicts that must be resolved before it can be |
df2a27f to b96b676
Hi @zyongye, the pre-commit checks have failed. Please run:
uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files
Then, commit the changes and push to your branch.
PerkzZheng left a comment
@zyongye Hi Yongye, thanks for rebasing my MR. I have left some comments.
| Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | Non-Causal | Sparse | MM Prefix | DCP | Attention Types | Compute Cap. |
| ------- | ------ | --------- | ----------- | ---------- | ---- | ---------- | ------ | --------- | --- | --------------- | ------------ |
| `V4_FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto` | 256 | 512 | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | Any |
Curious what a Block Size of 256 means here?
I think it's guidance on what block size the user should specify when launching the engine with this kernel. Since we currently use a custom KV cache layout, we restrict the block size to 256 by passing --block-size 256. I need to look into whether this is really necessary.
Thanks. FYI, it should then be 512 for FlashInfer sparse MLA.
| decode_compressed_topk_lens = token_to_req_indices | ||
|
|
||
| padded_topk = max(topk, decode_compressed_topk) | ||
| padded_topk = (padded_topk + 3) // 4 * 4 |
The FlashInfer MLA kernels require 16B alignment for the topk indices. Could you add a comment here? Thanks.
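For reference, a minimal sketch of why rounding to a multiple of 4 yields 16B alignment. It assumes the topk indices are stored as 4-byte integers (int32), which the diff does not state; `round_up` and `pad_topk` are hypothetical helper names, not functions from this PR.

```python
INDEX_BYTES = 4   # assumption: topk indices are int32
ALIGN_BYTES = 16  # alignment requirement quoted in the review comment
ELEMS_PER_ALIGN = ALIGN_BYTES // INDEX_BYTES  # 4 indices per 16 bytes


def round_up(value: int, multiple: int) -> int:
    """Round value up to the next multiple of `multiple`."""
    return (value + multiple - 1) // multiple * multiple


def pad_topk(topk: int, decode_compressed_topk: int) -> int:
    """Hypothetical helper mirroring the two padded_topk lines in the diff."""
    padded = max(topk, decode_compressed_topk)
    # Same arithmetic as (padded + 3) // 4 * 4: each row of indices then
    # spans a whole number of 16-byte chunks.
    return round_up(padded, ELEMS_PER_ALIGN)
```

With these assumptions, `pad_topk(5, 3)` gives 8, so one row of indices occupies 32 bytes.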
| # paged path and writes a contiguous 512-wide cache row per token; bf16 | ||
| # vs per-tensor fp8 is selected by ``store_full_fp8`` (with the scale | ||
| # source supplied via ``fp8_scale``). | ||
| store_full_bf16: bool = False, |
It seems better to rename this to store_full_kv; otherwise it is misleading. Sorry if it was introduced in my commits. We will also need to update the other places that use this term.
| num_decodes = swa_metadata.num_decodes | ||
| num_prefills = swa_metadata.num_prefills | ||
| num_decode_tokens = swa_metadata.num_decode_tokens | ||
| num_prefill_tokens = swa_metadata.num_prefill_tokens |
I split it into two calls (prefill and decode) in my previous MR because we pad gridDim.x to maxSeqLenQ, which can launch too many padding CTAs for mixed requests. I still observe obvious perf gains from splitting even though those padding CTAs are skipped at runtime (which means the CTA switching overhead is not negligible).
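To make the padding argument concrete, here is a rough back-of-the-envelope model. It assumes one CTA per (request, query position) with the query dimension padded to the longest query in the launch; the real grid layout of the kernel may differ, and both function names are hypothetical.

```python
def padded_cta_count(q_lens: list[int]) -> int:
    """CTAs launched when every request is padded to the longest query."""
    if not q_lens:
        return 0
    return len(q_lens) * max(q_lens)


def split_cta_count(decode_q_lens: list[int], prefill_q_lens: list[int]) -> int:
    """CTAs launched when decode and prefill requests use separate calls."""
    return padded_cta_count(decode_q_lens) + padded_cta_count(prefill_q_lens)


# Illustrative mixed batch: 30 decode requests and 2 prefill requests.
decode_q_lens = [1] * 30
prefill_q_lens = [512, 300]

single_call = padded_cta_count(decode_q_lens + prefill_q_lens)
two_calls = split_cta_count(decode_q_lens, prefill_q_lens)
useful = sum(decode_q_lens) + sum(prefill_q_lens)
print(single_call, two_calls, useful)
```

Under this simplified model the single mixed launch pads every decode request to the prefill length, so most of its CTAs do no useful work, while the split launch stays close to the number of real query tokens.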
| @dataclass | ||
| class DeepseekV4MLAModules: | ||
| """Modules used in DeepseekV4 MLA.""" | ||
|
|
||
| vllm_config: VllmConfig | ||
| fused_wqa_wkv: torch.nn.Module | ||
| q_norm: torch.nn.Module | ||
| wq_b: torch.nn.Module | ||
| kv_norm: torch.nn.Module | ||
| wo_a: torch.nn.Module | ||
| wo_b: torch.nn.Module | ||
| attn_sink: torch.nn.Module | ||
| rotary_emb: torch.nn.Module | ||
| indexer: torch.nn.Module | None | ||
| indexer_rotary_emb: torch.nn.Module | ||
| topk_indices_buffer: torch.Tensor | None | ||
| aux_stream_list: list[torch.cuda.Stream] | None = None | ||
|
|
||
|
|
||
| # --8<-- [start:multi_head_latent_attention] | ||
| @PluggableLayer.register("deepseek_v4_multi_head_latent_attention") | ||
| class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): |
Can you remove these classes as I did in #44246? I think there was an error in resolving the merge conflict.
| torch.ops.vllm.deepseek_v4_attention( | ||
| hidden_states, | ||
| positions, | ||
| o_padded, | ||
| self.layer_name, | ||
| ) |
This torch op was also removed in the main branch.
| "bhr,hdr->bhd", | ||
| (o_fp8, o_scale), | ||
| (wo_a_fp8, wo_a_scale), | ||
| torch.ops.vllm.deepseek_v4_fp8_einsum( |
| def deepseek_v4_attention_fake( | ||
| hidden_states: torch.Tensor, | ||
| positions: torch.Tensor, | ||
| out: torch.Tensor, | ||
| layer_name: str, | ||
| ) -> None: | ||
| return None | ||
|
|
||
|
|
||
| direct_register_custom_op( | ||
| op_name="deepseek_v4_attention", | ||
| op_func=deepseek_v4_attention, | ||
| mutates_args=["out"], | ||
| fake_impl=deepseek_v4_attention_fake, | ||
| ) | ||
|
|
||
|
|
||
| def deepseek_v4_fp8_einsum( | ||
| a: torch.Tensor, | ||
| a_scale: torch.Tensor, | ||
| b: torch.Tensor, | ||
| b_scale: torch.Tensor, | ||
| out: torch.Tensor, | ||
| equation: str, | ||
| recipe: list[int], | ||
| ) -> None: | ||
| fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe)) | ||
|
|
||
|
|
||
| def deepseek_v4_fp8_einsum_fake( | ||
| a: torch.Tensor, | ||
| a_scale: torch.Tensor, | ||
| b: torch.Tensor, | ||
| b_scale: torch.Tensor, | ||
| out: torch.Tensor, | ||
| equation: str, | ||
| recipe: list[int], | ||
| ) -> None: | ||
| return None | ||
|
|
||
|
|
||
| direct_register_custom_op( | ||
| op_name="deepseek_v4_fp8_einsum", | ||
| op_func=deepseek_v4_fp8_einsum, | ||
| mutates_args=["out"], | ||
| fake_impl=deepseek_v4_fp8_einsum_fake, | ||
| ) |
This also needs to be removed.
| ), | ||
| Float32(self.fp8_max), | ||
| ) | ||
| y1 = cute.arch.fmin( |
Should we avoid using fmin?
You're right. Will change.
f73111f to a37d8af
a37d8af to 3f596f3
| FLASHMLA_SPARSE_V4 = ( | ||
| "vllm.models.deepseek_v4.nvidia.flashmla.DeepseekV4FlashMLASparseBackend" | ||
| ) | ||
| FLASHINFER_MLA_SPARSE_V4 = ( | ||
| "vllm.models.deepseek_v4.nvidia.flashinfer_sparse." | ||
| "DeepseekV4FlashInferMLASparseBackend" | ||
| ) | ||
| ROCM_FLASHMLA_SPARSE_V4 = ( |
nit: What about DSV4 instead of V4?
| FLASHMLA_SPARSE_DSV4 = ( | |
| "vllm.models.deepseek_v4.nvidia.flashmla.DeepseekV4FlashMLASparseBackend" | |
| ) | |
| FLASHINFER_MLA_SPARSE_DSV4 = ( | |
| "vllm.models.deepseek_v4.nvidia.flashinfer_sparse." | |
| "DeepseekV4FlashInferMLASparseBackend" | |
| ) | |
| ROCM_FLASHMLA_SPARSE_DSV4 = ( |
Add a selectable DeepSeek V4 sparse-MLA decode backend that runs through
FlashInfer's TRTLLM-gen kernel with a contiguous bf16 / per-tensor FP8 KV
cache, alongside the existing FlashMLA path. Re-ported from PerkzZheng's
- csrc: sibling fusedDeepseekV4FullCacheKernel (contiguous 512-wide bf16 /
per-tensor fp8 insert) + full_cache_{bf16,fp8}_insert ops, in the libtorch
stable ABI (csrc/libtorch_stable/).
- nvidia/flashinfer_sparse.py: DeepseekV4FlashInferMLASparseBackend/Impl;
public flashinfer.mla.trtllm_batch_decode_sparse_mla_dsv4 launcher, two-call
decode/prefill split, q-head padding to {64,128}, fp8 scale buffers.
- registry + selection: FLASHMLA_SPARSE_DSV4 / FLASHINFER_MLA_SPARSE_DSV4 /
ROCM_FLASHMLA_SPARSE_DSV4; _select_v4_sparse_impl consults the backend;
_resolve_dsv4_kv_cache_dtype maps dtype per backend.
- compressor: CuTeDSL full-cache classes (SparseAttnCompressNormRopeStoreFullC4Kernel,
SparseAttnNormRopeStoreFullKernel) separate from the pristine legacy UE8M0
classes so the legacy path keeps its perf; build_flashinfer_mixed_sparse_indices
in common/ops.
- SWA cache / kv_cache_interface: accept bf16/fp8 dtypes, gate 576B alignment
on fp8_ds_mla.
- docs: dedicated "DeepSeek V4 Decode Backends" section.
- tests: full-cache parity (insert ops + cutedsl compressor).
Verified: kernel/compressor tests pass; e2e GSM8K on DSv4-Flash (TP=4, fp8)
matches the FlashMLA baseline (~0.953).
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
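As a reading aid for the "q-head padding to {64,128}" item in the commit message above, a minimal sketch of one plausible padding rule. It assumes the padded count is the smallest supported value that is at least the real head count, and that anything above 128 is rejected; neither detail is confirmed by the commit message, and `padded_num_q_heads` is a hypothetical name.

```python
SUPPORTED_Q_HEADS = (64, 128)  # head counts named in the commit message


def padded_num_q_heads(num_heads: int) -> int:
    """Return the smallest supported head count >= num_heads (assumed rule)."""
    if num_heads <= 0:
        raise ValueError("num_heads must be positive")
    for supported in SUPPORTED_Q_HEADS:
        if num_heads <= supported:
            return supported
    raise ValueError(f"{num_heads} q heads exceeds the supported maximum")
```

For example, under this rule a TP shard with 16 query heads would be padded to 64, and the extra heads would need to be sliced off the output afterwards.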
8e39bc7 to 6a65ae7
WoosukKwon left a comment
Thanks for the PR!
I will follow up with some refactoring.
Hi @zyongye, thanks for the PR! Have you done benchmarking to measure the performance of the TRTLLM gen attention kernel vs FlashMLA? Which settings does this kernel have better perf for dsv4?
I only tested c2048 and didn't actually see any perf improvement. I haven't checked other points yet.
Signed-off-by: JisoLya <523420504@qq.com>
Signed-off-by: Waqar Ahmed <waqar.ahmed@amd.com>
Summary
Rebase of @PerkzZheng's #42316 onto current main, plus a few materially new pieces:
- […] FlashInferMLASparseMetadata / FlashInferMLASparseMetadataBuilder so that for compress_ratio == 128 layers the mixed-sparse-index Triton kernel runs once per step instead of once per layer. The SWA-baked combine is materialized lazily on first access by ensure_sparse_indices and cached on the metadata for the remaining C128A layers in the same step. C4 / SWA-only paths are unchanged (they read layer.topk_indices_buffer, which is per-layer).
- […] into one flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw call over the mixed batch.
- […] precomputed Python-float bmm1_scale / bmm2_scale derived from the per-tensor placeholders. The TRTLLM-GEN sparse-MLA launcher takes different C++ code paths for scalar vs 1-elem-tensor scale args, so this matters for correctness.
- […] matching what the FlashInfer V4 backend reads (no UE8M0 padding).
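The once-per-step behaviour described for the compress_ratio == 128 layers can be sketched as a lazily filled cache on the per-step metadata. This is an illustration only: `StepMetadata`, `build_fn`, and `build_calls` are hypothetical names, and only the method name `ensure_sparse_indices` comes from the description above.

```python
from dataclasses import dataclass, field
from typing import Callable, Optional


@dataclass
class StepMetadata:
    """Hypothetical stand-in for the per-step attention metadata."""

    # Callable that does the expensive work (the Triton kernel in the PR).
    build_fn: Callable[[], list[int]]
    _sparse_indices: Optional[list[int]] = field(default=None, init=False)
    build_calls: int = field(default=0, init=False)

    def ensure_sparse_indices(self) -> list[int]:
        """Build the indices on first access, then reuse the cached result."""
        if self._sparse_indices is None:
            self._sparse_indices = self.build_fn()
            self.build_calls += 1
        return self._sparse_indices


# One metadata object is shared by all layers within a step.
metadata = StepMetadata(build_fn=lambda: [0, 4, 8, 12])
for _layer in range(3):
    indices = metadata.ensure_sparse_indices()
print(metadata.build_calls)
```

A fresh metadata object per step resets the cache, which is what keeps the cached indices from leaking across steps in this sketch.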
Relation to #42316
This branch is a rebase of #42316; the C128A metadata caching, single-call launcher collapse, and FP8 scale dispatch documentation are the new pieces. Opening as a separate PR rather than pushing into #42316 because the rebase […] into Port DeepSeek V4 FlashInfer sparse MLA kernels #42316 if @PerkzZheng prefers.
End-to-end eval
deepseek-ai/DeepSeek-V4-Flash, 4× GB200, TP=4, V4_FLASHINFER_MLA_SPARSE backend, --kv-cache-dtype fp8 --block-size 256, cudagraph_mode=FULL_DECODE_ONLY. vLLM commit ce4e168ba (branch tip before today's mechanical rebase onto main; rebase only resolved stale torch-stable-ABI decls in csrc/ops.h / csrc/torch_bindings.cpp):
Test plan
- pre-commit run --files vllm/models/deepseek_v4/nvidia/flashinfer_sparse.py (passed, including mypy)
- tests/kernels/test_compressor_kv_cache.py: 36/36 passed
- […] build_flashinfer_mixed_sparse_indices Triton kernel now appears once per step instead of N× per step.
AI-assisted disclosure
Developed with AI assistance (Claude Code). Per AGENTS.md, the submitter
has reviewed each changed line.
🤖 Generated with Claude Code