@@ -372,60 +372,20 @@ def _forward(
372372
373373 workspace = _get_flashinfer_dsv4_workspace (q .device )
374374
375- # Split decode and prefill into two launcher calls (PR #42316 pattern).
376- # The TRTLLM-GEN DSV4 sparse-MLA kernel was tuned for uniform-q batches
377- # and an earlier attempt to fold both halves into a single call produced
378- # subtly wrong attention outputs (~3pt GSM8K drop). Decode uses the
379- # absolute cum_seq_lens_q (it already starts at 0); prefill uses a
380- # rebased cum_seq_lens_q so its sliced query view re-anchors at 0.
381- if num_decode_tokens > 0 :
382- decode_query_start_loc = query_start_loc [: num_decodes + 1 ]
383- decode_query_start_loc_cpu = query_start_loc_cpu [: num_decodes + 1 ]
384- decode_query_lens_cpu = (
385- decode_query_start_loc_cpu [1 :] - decode_query_start_loc_cpu [:- 1 ]
386- )
387- flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw (
388- query = query [:num_decode_tokens ],
389- swa_kv_cache = swa_k_cache ,
390- workspace_buffer = workspace ,
391- sparse_indices = sparse_indices [:num_decode_tokens ],
392- compressed_kv_cache = compressed_kv_cache ,
393- sparse_topk_lens = sparse_topk_lens [:num_decode_tokens ],
394- seq_lens = seq_lens [:num_decodes ],
395- out = output [:num_decode_tokens ],
396- bmm1_scale = bmm1_scale ,
397- bmm2_scale = bmm2_scale ,
398- sinks = layer .attn_sink ,
399- cum_seq_lens_q = decode_query_start_loc ,
400- max_q_len = int (decode_query_lens_cpu .max ().item ()),
401- )
402-
403- if num_prefill_tokens > 0 :
404- # Prefill query view starts at offset num_decode_tokens inside the
405- # combined batch; the launcher expects cum_seq_lens_q to index into
406- # that sliced view, so subtract the decode base.
407- prefill_query_start_loc = (
408- query_start_loc [num_decodes : num_reqs + 1 ]
409- - query_start_loc [num_decodes ]
410- )
411- prefill_query_start_loc_cpu = query_start_loc_cpu [
412- num_decodes : num_reqs + 1
413- ]
414- prefill_query_lens_cpu = (
415- prefill_query_start_loc_cpu [1 :] - prefill_query_start_loc_cpu [:- 1 ]
416- )
417- flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw (
418- query = query [num_decode_tokens :num_tokens ],
419- swa_kv_cache = swa_k_cache ,
420- workspace_buffer = workspace ,
421- sparse_indices = sparse_indices [num_decode_tokens :num_tokens ],
422- compressed_kv_cache = compressed_kv_cache ,
423- sparse_topk_lens = sparse_topk_lens [num_decode_tokens :num_tokens ],
424- seq_lens = seq_lens [num_decodes :num_reqs ],
425- out = output [num_decode_tokens :num_tokens ],
426- bmm1_scale = bmm1_scale ,
427- bmm2_scale = bmm2_scale ,
428- sinks = layer .attn_sink ,
429- cum_seq_lens_q = prefill_query_start_loc ,
430- max_q_len = int (prefill_query_lens_cpu .max ().item ()),
431- )
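375+ # Single-call launcher over the full mixed decode+prefill batch. NOTE(review): the split decode/prefill path this replaces documented that an earlier single-call attempt produced subtly wrong attention outputs (~3pt GSM8K drop) -- confirm that issue is resolved before relying on this path.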
376+ query_lens_cpu = query_start_loc_cpu [1 :] - query_start_loc_cpu [:- 1 ]
377+ flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw (
378+ query = query ,
379+ swa_kv_cache = swa_k_cache ,
380+ workspace_buffer = workspace ,
381+ sparse_indices = sparse_indices [:num_tokens ],
382+ compressed_kv_cache = compressed_kv_cache ,
383+ sparse_topk_lens = sparse_topk_lens [:num_tokens ],
384+ seq_lens = seq_lens [:num_reqs ],
385+ out = output ,
386+ bmm1_scale = bmm1_scale ,
387+ bmm2_scale = bmm2_scale ,
388+ sinks = layer .attn_sink ,
389+ cum_seq_lens_q = query_start_loc ,
390+ max_q_len = int (query_lens_cpu .max ().item ()),
391+ )
0 commit comments