
Commit 1daa8c3

[https://nvbugs/5340941][https://nvbugs/5375785] - fix: Wrap attentio… (#6355)
Signed-off-by: Jin Li <[email protected]>
1 parent f39d621 commit 1daa8c3

16 files changed (+412, -455 lines changed)


tensorrt_llm/_torch/attention_backend/flashinfer.py

Lines changed: 88 additions & 129 deletions
@@ -478,6 +478,82 @@ def update_quant_config(self, new_quant_config: Optional[QuantConfig]):
         self.has_fp8_kv_cache = self.quant_config.layer_quant_mode.has_fp8_kv_cache(
         )
 
+    def forward_impl(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        metadata: FlashInferAttentionMetadata,
+        attention_mask_type: int,
+        output: torch.Tensor,
+        attention_mask_data: Optional[torch.Tensor] = None,
+        attention_window_size: Optional[int] = None,
+    ) -> None:
+        # Query
+        q = q.view(-1, self.num_heads, self.head_dim)
+
+        # Key and Value
+        kv_cache = metadata.kv_cache_manager.get_buffers(self.layer_idx)
+
+        if k is not None and v is not None:
+            k = k.view(-1, self.num_kv_heads, self.head_dim)
+            v = v.view(-1, self.num_kv_heads, self.head_dim)
+
+            if self.has_fp8_kv_cache:
+                assert kv_cache.dtype == torch.float8_e4m3fn, (
+                    f"KV cache should have fp8 dtype, but get {kv_cache.dtype}")
+                k = k.to(torch.float8_e4m3fn)
+                v = v.to(torch.float8_e4m3fn)
+            assert k.dtype == v.dtype == kv_cache.dtype, (
+                f"KV cache dtype {kv_cache.dtype} does not match k/v dtype {k.dtype}/{v.dtype}"
+            )
+
+            flashinfer.page.append_paged_kv_cache(
+                append_key=k,
+                append_value=v,
+                batch_indices=metadata.batch_indices,
+                positions=metadata.positions,
+                paged_kv_cache=kv_cache,
+                kv_indices=metadata.paged_kv_indices,
+                kv_indptr=metadata.paged_kv_indptr,
+                kv_last_page_len=metadata.paged_kv_last_page_len,
+                kv_layout=metadata.kv_layout)
+
+        num_contexts = metadata.num_contexts
+        num_generations = metadata.num_generations
+        num_ctx_tokens = metadata.num_ctx_tokens
+
+        def prefill_forward(plan_params: PlanParams, out: torch.Tensor):
+            wrapper = metadata.get_prefill_wrapper(plan_params)
+            wrapper.run(q[:num_ctx_tokens],
+                        kv_cache,
+                        out=out.view(-1, self.num_heads, self.head_dim))
+
+        def decode_forward(plan_params: PlanParams, out: torch.Tensor):
+            wrapper = metadata.get_decode_wrapper(plan_params)
+            wrapper.run(q[num_ctx_tokens:],
+                        kv_cache,
+                        out=out.view(-1, self.num_heads, self.head_dim))
+
+        # this will do nothing if the last forward pass had the same parameters
+        plan_params = metadata.plan(self.num_heads,
+                                    self.num_kv_heads,
+                                    self.head_dim,
+                                    q_dtype=q.dtype,
+                                    kv_dtype=kv_cache.dtype,
+                                    q_scaling=self.q_scaling,
+                                    attention_window_size=attention_window_size,
+                                    attention_mask_type=attention_mask_type,
+                                    attention_mask_data=attention_mask_data)
+
+        if num_contexts == 0:
+            decode_forward(plan_params, output)
+        elif num_generations == 0:
+            prefill_forward(plan_params, output)
+        else:
+            prefill_forward(plan_params, output[:num_ctx_tokens, :])
+            decode_forward(plan_params, output[num_ctx_tokens:, :])
+
     def forward(self,
                 q: torch.Tensor,
                 k: Optional[torch.Tensor],
@@ -487,6 +563,7 @@ def forward(self,
                 attention_window_size: Optional[int] = None,
                 attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
                 attention_mask_data: Optional[torch.Tensor] = None,
+                output: Optional[torch.Tensor] = None,
                 **kwargs) -> torch.Tensor:
         if attention_mask == CustomAttentionMask.CUSTOM:
             assert attention_mask_data is not None, "attention_mask_data is required for custom attention mask."
@@ -502,133 +579,15 @@ def forward(self,
         else:
             raise ValueError("Unexpected attention mask type")
 
-        return forward_pattern(q=q,
-                               k=k,
-                               v=v,
-                               num_heads=self.num_heads,
-                               head_dim=self.head_dim,
-                               num_kv_heads=self.num_kv_heads,
-                               layer_idx=self.layer_idx,
-                               has_fp8_kv_cache=self.has_fp8_kv_cache,
-                               attention_mask_type=attention_mask_type,
-                               q_scaling=self.q_scaling,
-                               attention_mask_data=attention_mask_data,
-                               attention_window_size=attention_window_size)
-
-
-@torch.library.custom_op("trtllm::flashinfer_forward", mutates_args=())
-def forward_pattern(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    num_heads: int,
-    head_dim: int,
-    num_kv_heads: int,
-    layer_idx: int,
-    has_fp8_kv_cache: bool,
-    attention_mask_type: int,
-    q_scaling: Optional[float] = None,
-    attention_mask_data: Optional[torch.Tensor] = None,
-    attention_window_size: Optional[int] = None,
-) -> torch.Tensor:
-    '''
-    Wrapping the flashinfer forward as a custom op is required to fix `torch.compile` graph breaks,
-    otherwise it will graph break when calling `metadata.num_contexts` since it convert tensor's sum directly to int.
-    '''
-    # torch.compile does not support custom object as arguments, so we have to use global function to get the metadata.
-    extra_attrs = get_model_extra_attrs()
-    if extra_attrs is not None:
-        metadata_ref = extra_attrs.get("attention_metadata", None)
-        metadata = metadata_ref() if metadata_ref is not None else None
-    else:
-        metadata = get_global_attrs().attention_metadata()
-
-    assert isinstance(
-        metadata,
-        FlashInferAttentionMetadata,
-    )
-
-    # Query
-    q = q.view(-1, num_heads, head_dim)
-
-    # Key and Value
-    kv_cache = metadata.kv_cache_manager.get_buffers(layer_idx)
-
-    if k is not None and v is not None:
-        k = k.view(-1, num_kv_heads, head_dim)
-        v = v.view(-1, num_kv_heads, head_dim)
-
-        if has_fp8_kv_cache:
-            assert kv_cache.dtype == torch.float8_e4m3fn, f"KV cache should have fp8 dtype, but get {kv_cache.dtype}"
-            k = k.to(torch.float8_e4m3fn)
-            v = v.to(torch.float8_e4m3fn)
-        assert k.dtype == v.dtype == kv_cache.dtype, f"KV cache dtype {kv_cache.dtype} does not match k/v dtype {k.dtype}/{v.dtype}"
-
-        flashinfer.page.append_paged_kv_cache(
-            append_key=k,
-            append_value=v,
-            batch_indices=metadata.batch_indices,
-            positions=metadata.positions,
-            paged_kv_cache=kv_cache,
-            kv_indices=metadata.paged_kv_indices,
-            kv_indptr=metadata.paged_kv_indptr,
-            kv_last_page_len=metadata.paged_kv_last_page_len,
-            kv_layout=metadata.kv_layout)
-
-    num_contexts = metadata.num_contexts
-    num_generations = metadata.num_generations
-    num_ctx_tokens = metadata.num_ctx_tokens
-
-    def prefill_forward(plan_params: PlanParams):
-        wrapper = metadata.get_prefill_wrapper(plan_params)
-        output = wrapper.run(q[:num_ctx_tokens], kv_cache)
-        output = output.view(num_ctx_tokens, -1)
-        return output
-
-    def decode_forward(plan_params: PlanParams):
-        wrapper = metadata.get_decode_wrapper(plan_params)
-        output = wrapper.run(q[num_ctx_tokens:], kv_cache)
-        output = output.view(num_generations, -1)
+        if output is None:
+            output = torch.empty_like(q)
+
+        self.forward_impl(q=q,
+                          k=k,
+                          v=v,
+                          metadata=metadata,
+                          attention_mask_type=attention_mask_type,
+                          attention_mask_data=attention_mask_data,
+                          attention_window_size=attention_window_size,
+                          output=output)
         return output
-
-    # this will do nothing if the last forward pass had the same parameters
-    plan_params = metadata.plan(num_heads,
-                                num_kv_heads,
-                                head_dim,
-                                q_dtype=q.dtype,
-                                kv_dtype=kv_cache.dtype,
-                                q_scaling=q_scaling,
-                                attention_window_size=attention_window_size,
-                                attention_mask_type=attention_mask_type,
-                                attention_mask_data=attention_mask_data)
-
-    if num_contexts > 0:
-        ctx_output = prefill_forward(plan_params)
-
-    if num_generations > 0:
-        gen_output = decode_forward(plan_params)
-
-    if num_contexts > 0 and num_generations > 0:
-        output = torch.cat([ctx_output, gen_output], dim=0)
-    elif num_contexts > 0:
-        output = ctx_output
-    elif num_generations > 0:
-        output = gen_output
-
-    return output
-
-
-@forward_pattern.register_fake
-def _(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    num_heads: int,
-    head_dim: int,
-    num_kv_heads: int,
-    layer_idx: int,
-    has_fp8_kv_cache: bool,
-    attention_mask_type: int,
-    attention_mask_data: Optional[torch.Tensor],
-):
-    return torch.empty_like(q)
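
Note on the change above: the deleted module-level `trtllm::flashinfer_forward` custom op returned a freshly allocated tensor, while the new path preallocates `output` in `forward` and lets `forward_impl` write into it; the compilation passes further down now track `trtllm::attn_custom_op_inplace` instead. A minimal sketch of that general in-place custom-op pattern (the `demo::attn_inplace` name and the plain-SDPA kernel are illustrative stand-ins, not the TensorRT-LLM implementation):

```python
from typing import Optional

import torch


# Illustrative only: expose an attention kernel to torch.compile as an op that
# mutates a caller-provided `output` buffer instead of returning a new tensor.
@torch.library.custom_op("demo::attn_inplace", mutates_args=("output", ))
def attn_inplace(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                 output: torch.Tensor) -> None:
    # Stand-in eager kernel: plain SDPA written into the preallocated buffer.
    output.copy_(torch.nn.functional.scaled_dot_product_attention(q, k, v))


@attn_inplace.register_fake
def _(q, k, v, output) -> None:
    # Nothing to allocate at trace time; the op only writes into `output`.
    return None


def forward(q, k, v, output: Optional[torch.Tensor] = None) -> torch.Tensor:
    if output is None:
        output = torch.empty_like(q)
    attn_inplace(q, k, v, output)
    return output
```

With this shape the compiler only needs to know which argument the op mutates (compare the `inplace_info` hunk below), and no output allocation has to be described in the fake registration.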

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 40 additions & 32 deletions
@@ -2,7 +2,7 @@
 import os
 import weakref
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Optional, Tuple, Union
 
 import torch
 
@@ -11,8 +11,8 @@
 from tensorrt_llm.logger import logger
 from tensorrt_llm.models.modeling_utils import QuantConfig
 
-from ..utils import (Fp4QuantizedTensor, compute_swizzled_sf_shape,
-                     get_global_attrs, get_model_extra_attrs)
+from ..utils import (compute_swizzled_sf_shape, get_global_attrs,
+                     get_model_extra_attrs)
 from .interface import (AttentionBackend, AttentionInputType, AttentionMask,
                         AttentionMetadata, KVCacheParams, MLAParams,
                         PositionalEmbeddingParams, PredefinedAttentionMask,
@@ -263,6 +263,35 @@ def plan(
         self.spec_decoding_generation_lengths = spec_decoding_generation_lengths
         self.kwargs.update(kwargs)
 
+    def create_output(self, q: torch.Tensor, out_dtype: torch.dtype):
+        num_tokens = q.size(0)
+        attention_input_type = (AttentionInputType(self.attention_input_type)
+                                if self.attention_input_type is not None else
+                                AttentionInputType.mixed)
+        if out_dtype is None:
+            out_dtype = q.dtype
+        is_gen_only = attention_input_type == AttentionInputType.generation_only
+        v_head_size = self.head_size
+        if self.is_mla_enable:
+            v_head_size = self.kv_lora_rank if is_gen_only else self.v_head_dim
+        if out_dtype == torch.uint8:
+            num_nvfp4_elements_per_container = 2
+            scaling_vector_size = 16
+            size_per_token = self.num_heads * v_head_size
+            output = q.new_empty(
+                (num_tokens,
+                 size_per_token // num_nvfp4_elements_per_container),
+                dtype=torch.uint8)
+            # Create a sf (scaling factors) tensor for NVFP4 (use INT8 as the container dtype).
+            output_sf = q.new_empty(compute_swizzled_sf_shape(
+                num_tokens, size_per_token // scaling_vector_size),
+                                    dtype=torch.uint8)
+        else:
+            output = q.new_empty((num_tokens, self.num_heads * v_head_size),
+                                 dtype=out_dtype)
+            output_sf = None
+        return output, output_sf
+
     def run(
         self,
         q: torch.Tensor,
@@ -361,30 +390,7 @@ def run(
 
         if output is None:
             assert output_sf is None
-            num_tokens = q.size(0)
-            attention_input_type = (AttentionInputType(
-                self.attention_input_type) if self.attention_input_type
-                                    is not None else AttentionInputType.mixed)
-            if out_dtype is None:
-                out_dtype = q.dtype
-            is_gen_only = attention_input_type == AttentionInputType.generation_only
-            v_head_size = self.head_size if not self.is_mla_enable else self.kv_lora_rank if is_gen_only else self.v_head_dim
-            if out_dtype == torch.uint8:
-                num_nvfp4_elements_per_container = 2
-                scaling_vector_size = 16
-                size_per_token = self.num_heads * v_head_size
-                output = q.new_empty(
-                    (num_tokens,
-                     size_per_token // num_nvfp4_elements_per_container),
-                    dtype=torch.uint8)
-                # Create a sf (scaling factors) tensor for NVFP4 (use INT8 as the container dtype).
-                output_sf = q.new_empty(compute_swizzled_sf_shape(
-                    num_tokens, size_per_token // scaling_vector_size),
-                                        dtype=torch.uint8)
-            else:
-                output = q.new_empty((num_tokens, self.num_heads * v_head_size),
-                                     dtype=out_dtype)
-                output_sf = None
+            output, output_sf = self.create_output(q, out_dtype)
         else:
             # output is provided, expect output_sf be provided as well if has NVFP4 output.
             assert out_dtype is None or out_dtype != torch.uint8 or output_sf is not None
@@ -1089,10 +1095,11 @@ def forward(
         mla_context_paged_kv: Optional[torch.Tensor] = None,
         mla_context_kv_cache_block_offsets: Optional[torch.Tensor] = None,
         softmax_stats_tensor: Optional[torch.Tensor] = None,
+        enable_attn_nvfp4_output: bool = True,
         output: Optional[torch.Tensor] = None,
        output_sf: Optional[torch.Tensor] = None,
         **kwargs,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
         assert isinstance(
             metadata,
             TrtllmAttentionMetadata,
@@ -1111,7 +1118,8 @@ def forward(
                                   metadata)
 
         use_nvfp4_output = False
-        if self.has_nvfp4 and self.support_nvfp4_output():
+        if enable_attn_nvfp4_output and self.has_nvfp4 and self.support_nvfp4_output(
+        ):
             # Runtime check whether the NVFP4 output kernel is available.
             use_nvfp4_output = self.wrapper.is_nvfp4_output_kernel_available(
                 tokens_per_block=metadata.tokens_per_block,
@@ -1184,9 +1192,9 @@ def forward(
             update_kv_cache=not metadata.is_cross or k is not None,
             attention_mask=attention_mask)
 
-        if out_dtype == torch.uint8:
-            assert output_sf is not None
-            return Fp4QuantizedTensor(output, output_sf)
+        if use_nvfp4_output:
+            return output, output_sf
+
         return output
 
     @classmethod
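
For reference, the NVFP4 branch of the new `create_output` helper packs two 4-bit values into each `uint8` container and allocates one scale factor per 16 elements. A standalone sketch of that shape math, ignoring the swizzling that `compute_swizzled_sf_shape` applies on top (the helper below is illustrative, not part of the codebase):

```python
def nvfp4_output_shapes(num_tokens: int, num_heads: int, v_head_size: int):
    """Unswizzled shape math behind create_output when out_dtype is torch.uint8."""
    num_nvfp4_elements_per_container = 2  # two FP4 values per uint8 byte
    scaling_vector_size = 16              # one scale factor per 16 FP4 elements
    size_per_token = num_heads * v_head_size
    packed = (num_tokens, size_per_token // num_nvfp4_elements_per_container)
    scale_factors = (num_tokens, size_per_token // scaling_vector_size)
    return packed, scale_factors


# e.g. 32 heads x 128 head size = 4096 elements per token
# -> (8, 2048) packed uint8 payload and (8, 256) scale factors before swizzling
print(nvfp4_output_shapes(num_tokens=8, num_heads=32, v_head_size=128))
```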

tensorrt_llm/_torch/compilation/piecewise_optimizer.py

Lines changed: 2 additions & 3 deletions
@@ -246,17 +246,16 @@ def piecewise_optimizer(
         if node.op in ("output", "placeholder"):
             continue
         if (not stop_partition and is_call_function(node, [
-                torch.ops.trtllm.attention_inplace.default,
+                torch.ops.trtllm.attn_custom_op_inplace.default,
                 torch.ops.trtllm.mla_custom_op_inplace.default,
                 torch.ops.aten.index.Tensor,
                 torch.ops.aten.cumsum.default,
         ])):
             idx += 1
             node_to_graph_id[node] = idx
             exclude_modules_id.append(idx)
-            if node.target != torch.ops.trtllm.attention_inplace.default and node.target != torch.ops.trtllm.mla_custom_op_inplace.default:
+            if node.target != torch.ops.trtllm.attn_custom_op_inplace.default and node.target != torch.ops.trtllm.mla_custom_op_inplace.default:
                 # We only know it is safe to continue splitting after attention
-                # since attention_inplace will not produce any new tensor
                 stop_partition = True
         else:
             idx += 1

tensorrt_llm/_torch/compilation/utils.py

Lines changed: 1 addition & 2 deletions
@@ -49,9 +49,8 @@ def inplace_info():
             1: "input",
             2: "residual"
         },
-        torch.ops.trtllm.attention_inplace.default: {
+        torch.ops.trtllm.attn_custom_op_inplace.default: {
             1: "output",
-            2: "output_sf"
         },
         torch.ops.trtllm.mla_custom_op_inplace.default: {
             1: "output"
