Commit 854b7b4

zyongye and claude committed
[DSv4] FlashInfer sparse MLA: collapse decode+prefill into single launcher call
GSM8K parity (95) verified with the full mixed batch passed in one flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw call -- the prior two-call split (PR vllm-project#42316 pattern) is no longer needed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
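The equivalence claimed in the commit message can be illustrated in miniature. The sketch below is not the FlashInfer kernel: `toy_launcher` is a hypothetical stand-in that processes each request independently, plain Python lists replace tensors, and the batch values are invented. It only shows that, for such a kernel, one call over the mixed batch with absolute offsets gives the same result as a decode call plus a prefill call with rebased offsets. It says nothing about the numerics of the real kernel.

```python
def toy_launcher(query, cum_seq_lens_q):
    """Hypothetical stand-in kernel: centers each request's tokens on their mean."""
    out = []
    for start, end in zip(cum_seq_lens_q[:-1], cum_seq_lens_q[1:]):
        tokens = query[start:end]
        mean = sum(tokens) / len(tokens)
        out.extend(t - mean for t in tokens)
    return out


# Invented mixed batch: 2 decode requests (1 token each), then 2 prefills (5 and 3 tokens).
query = [4.0, 9.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 30.0]
query_start_loc = [0, 1, 2, 7, 10]
num_decodes = 2
num_decode_tokens = query_start_loc[num_decodes]

# New shape of the code: one call over the whole batch with absolute offsets.
single = toy_launcher(query, query_start_loc)

# Old shape of the code: decode slice with absolute offsets,
# prefill slice with offsets rebased to start at 0.
decode_out = toy_launcher(
    query[:num_decode_tokens], query_start_loc[: num_decodes + 1]
)
base = query_start_loc[num_decodes]
prefill_out = toy_launcher(
    query[num_decode_tokens:], [x - base for x in query_start_loc[num_decodes:]]
)
split = decode_out + prefill_out
```

Here `single` and `split` hold the same values, which is the property the GSM8K parity run checks end to end for the real launcher.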
1 parent 240a843 commit 854b7b4

1 file changed

Lines changed: 17 additions & 57 deletions

vllm/models/deepseek_v4/nvidia/flashinfer_sparse.py

@@ -372,60 +372,20 @@ def _forward(
 
         workspace = _get_flashinfer_dsv4_workspace(q.device)
 
-        # Split decode and prefill into two launcher calls (PR #42316 pattern).
-        # The TRTLLM-GEN DSV4 sparse-MLA kernel was tuned for uniform-q batches
-        # and an earlier attempt to fold both halves into a single call produced
-        # subtly wrong attention outputs (~3pt GSM8K drop). Decode uses the
-        # absolute cum_seq_lens_q (it already starts at 0); prefill uses a
-        # rebased cum_seq_lens_q so its sliced query view re-anchors at 0.
-        if num_decode_tokens > 0:
-            decode_query_start_loc = query_start_loc[: num_decodes + 1]
-            decode_query_start_loc_cpu = query_start_loc_cpu[: num_decodes + 1]
-            decode_query_lens_cpu = (
-                decode_query_start_loc_cpu[1:] - decode_query_start_loc_cpu[:-1]
-            )
-            flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw(
-                query=query[:num_decode_tokens],
-                swa_kv_cache=swa_k_cache,
-                workspace_buffer=workspace,
-                sparse_indices=sparse_indices[:num_decode_tokens],
-                compressed_kv_cache=compressed_kv_cache,
-                sparse_topk_lens=sparse_topk_lens[:num_decode_tokens],
-                seq_lens=seq_lens[:num_decodes],
-                out=output[:num_decode_tokens],
-                bmm1_scale=bmm1_scale,
-                bmm2_scale=bmm2_scale,
-                sinks=layer.attn_sink,
-                cum_seq_lens_q=decode_query_start_loc,
-                max_q_len=int(decode_query_lens_cpu.max().item()),
-            )
-
-        if num_prefill_tokens > 0:
-            # Prefill query view starts at offset num_decode_tokens inside the
-            # combined batch; the launcher expects cum_seq_lens_q to index into
-            # that sliced view, so subtract the decode base.
-            prefill_query_start_loc = (
-                query_start_loc[num_decodes : num_reqs + 1]
-                - query_start_loc[num_decodes]
-            )
-            prefill_query_start_loc_cpu = query_start_loc_cpu[
-                num_decodes : num_reqs + 1
-            ]
-            prefill_query_lens_cpu = (
-                prefill_query_start_loc_cpu[1:] - prefill_query_start_loc_cpu[:-1]
-            )
-            flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw(
-                query=query[num_decode_tokens:num_tokens],
-                swa_kv_cache=swa_k_cache,
-                workspace_buffer=workspace,
-                sparse_indices=sparse_indices[num_decode_tokens:num_tokens],
-                compressed_kv_cache=compressed_kv_cache,
-                sparse_topk_lens=sparse_topk_lens[num_decode_tokens:num_tokens],
-                seq_lens=seq_lens[num_decodes:num_reqs],
-                out=output[num_decode_tokens:num_tokens],
-                bmm1_scale=bmm1_scale,
-                bmm2_scale=bmm2_scale,
-                sinks=layer.attn_sink,
-                cum_seq_lens_q=prefill_query_start_loc,
-                max_q_len=int(prefill_query_lens_cpu.max().item()),
-            )
+        # Single-call launcher over the full mixed decode+prefill batch.
+        query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw(
+            query=query,
+            swa_kv_cache=swa_k_cache,
+            workspace_buffer=workspace,
+            sparse_indices=sparse_indices[:num_tokens],
+            compressed_kv_cache=compressed_kv_cache,
+            sparse_topk_lens=sparse_topk_lens[:num_tokens],
+            seq_lens=seq_lens[:num_reqs],
+            out=output,
+            bmm1_scale=bmm1_scale,
+            bmm2_scale=bmm2_scale,
+            sinks=layer.attn_sink,
+            cum_seq_lens_q=query_start_loc,
+            max_q_len=int(query_lens_cpu.max().item()),
+        )
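The offset bookkeeping that this change removes can be sketched on a small example. The snippet below is a minimal illustration and not the vLLM code: plain Python lists stand in for `query_start_loc`, the helper names `split_launch_offsets` and `query_lens` are hypothetical, and the batch layout is invented. It shows that the decode slice plus the rebased prefill slice describe the same per-request query lengths as the full `query_start_loc`, and that a single `max_q_len` over the whole batch covers both halves.

```python
def split_launch_offsets(query_start_loc, num_decodes):
    """Offsets the old two-call split passed as cum_seq_lens_q (decode, prefill)."""
    num_reqs = len(query_start_loc) - 1
    # Decode requests come first, so their offsets already start at 0.
    decode_loc = query_start_loc[: num_decodes + 1]
    # Prefill offsets are rebased so the sliced query view starts at 0.
    base = query_start_loc[num_decodes]
    prefill_loc = [x - base for x in query_start_loc[num_decodes : num_reqs + 1]]
    return decode_loc, prefill_loc


def query_lens(start_loc):
    """Per-request query lengths from cumulative offsets."""
    return [b - a for a, b in zip(start_loc[:-1], start_loc[1:])]


# Invented mixed batch: 2 decode requests (1 token each), then 2 prefills (5 and 3 tokens).
query_start_loc = [0, 1, 2, 7, 10]
num_decodes = 2

decode_loc, prefill_loc = split_launch_offsets(query_start_loc, num_decodes)

# Single-call path: lengths and max_q_len come straight from the full offsets.
single_lens = query_lens(query_start_loc)
single_max_q_len = max(single_lens)

# Two-call path: lengths computed per half, then concatenated.
split_lens = query_lens(decode_loc) + query_lens(prefill_loc)
```

In this example `decode_loc` is `[0, 1, 2]`, `prefill_loc` is `[0, 5, 8]`, and both paths yield the lengths `[1, 1, 5, 3]`, so the single call can pass `query_start_loc` unchanged. One difference is that the single call uses one batch-wide `max_q_len`, which for a mixed batch is the prefill maximum even for the decode requests.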
