Commit b5e383c

[gpt-oss] raise error for flashinfer backend without trtllm (vllm-project#24482)
Signed-off-by: Chen Zhang <[email protected]>
1 parent: 9a16130

File tree: 1 file changed, +10 -2 lines

vllm/v1/attention/backends/flashinfer.py

Lines changed: 10 additions & 2 deletions
@@ -216,7 +216,11 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         self.window_left = self.global_hyperparameters.window_left
         self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
         self.has_sinks = self.global_hyperparameters.has_sinks
-
+        if self.has_sinks and not supports_trtllm_attention()[0]:
+            raise NotImplementedError(
+                "FlashInfer backend currently does not support attention "
+                "sinks, please use trtllm on blackwell or flash attention on "
+                "earlier GPUs.")
         # Preparing persistent buffers (device-side)
         self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
                                            dtype=torch.int32,
@@ -408,7 +412,11 @@ def build(self,
            self.q_data_type,
            is_prefill=False,
            has_sinks=self.has_sinks)
-
+        if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
+            raise NotImplementedError(
+                "FlashInfer backend currently does not support attention "
+                "sinks, please use trtllm on blackwell or flash attention on "
+                "earlier GPUs.")
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             q_data_type=self.q_data_type,
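
Both hunks apply the same fail-fast pattern: detect an unsupported combination (attention sinks requested while TRT-LLM attention kernels are unavailable) up front and raise `NotImplementedError` with an actionable message, rather than letting attention run without the sink terms. The sketch below illustrates that pattern in isolation; `supports_trtllm_attention` here is a hypothetical stand-in stub (it only mimics the tuple shape implied by `supports_trtllm_attention()[0]` in the diff), and `SinkAwareBackend` is not a vLLM class.

```python
# Minimal sketch of the capability-guard pattern used in this commit.
# supports_trtllm_attention() below is a hypothetical stand-in; the real
# vLLM helper inspects the GPU and environment, which is not reproduced here.
from typing import Optional


def supports_trtllm_attention() -> tuple[bool, Optional[str]]:
    """Stand-in: report whether TRT-LLM attention kernels can be used."""
    # Hard-coded "unsupported" for illustration (e.g. a pre-Blackwell GPU).
    return False, "pre-Blackwell GPU"


class SinkAwareBackend:
    """Toy backend that refuses configurations it cannot serve correctly."""

    def __init__(self, has_sinks: bool) -> None:
        self.has_sinks = has_sinks
        supported, reason = supports_trtllm_attention()
        if self.has_sinks and not supported:
            # Fail fast at construction time instead of silently computing
            # attention without the sink terms.
            raise NotImplementedError(
                "FlashInfer backend currently does not support attention "
                "sinks, please use trtllm on blackwell or flash attention "
                f"on earlier GPUs (reason: {reason}).")


if __name__ == "__main__":
    try:
        SinkAwareBackend(has_sinks=True)
    except NotImplementedError as exc:
        print(f"caught expected error: {exc}")
```

The commit also repeats the check in `build()` because the prefill/decode TRT-LLM decision is made per batch there, so a configuration that passed the constructor-time check can still need to be rejected at metadata-build time.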
