Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions tensorrt_llm/_torch/attention_backend/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from ..utils import get_global_attrs, get_model_extra_attrs
from .interface import (AttentionBackend, AttentionMask, AttentionMetadata,
PredefinedAttentionMask)
CustomAttentionMask, PredefinedAttentionMask)

try:
check_cuda_arch()
Expand Down Expand Up @@ -366,6 +366,12 @@ def _plan_with_params(self, plan_params: PlanParams) -> PlanParams:
is_causal = plan_params.attention_mask_type == AttentionMaskType.causal

def prefill_plan():
# Setting `window_left` to -1 for custom attention mask is important.
# Else, FlashInfer proceeds to use SWA regardless of attention_mask_data.
if plan_params.attention_mask_data is not None:
window_left = -1
else:
window_left = plan_params.window_left
prefill_wrapper.plan(
self.qo_indptr[:self.num_contexts + 1],
self.paged_kv_indptr_prefill[:self.num_contexts + 1],
Expand All @@ -377,9 +383,10 @@ def prefill_plan():
self.page_size,
causal=is_causal,
sm_scale=plan_params.sm_scale,
window_left=plan_params.window_left,
window_left=window_left,
q_data_type=plan_params.q_dtype,
kv_data_type=plan_params.kv_dtype,
custom_mask=plan_params.attention_mask_data,
)

if plan_params in self._plan_params_to_wrappers:
Expand Down Expand Up @@ -473,8 +480,14 @@ def forward(self,
*,
attention_window_size: Optional[int] = None,
attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
attention_mask_data: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor:
if attention_mask == PredefinedAttentionMask.CAUSAL:
if attention_mask == CustomAttentionMask.CUSTOM:
assert attention_mask_data is not None, "attention_mask_data is required for custom attention mask."
attention_mask_type = int(AttentionMaskType.custom_mask)
attention_mask_data = attention_mask_data if attention_mask_data.ndim == 1 else attention_mask_data.flatten(
)
elif attention_mask == PredefinedAttentionMask.CAUSAL:
attention_mask_type = int(AttentionMaskType.causal)
attention_mask_data = None
elif attention_mask == PredefinedAttentionMask.FULL:
Expand Down
10 changes: 8 additions & 2 deletions tensorrt_llm/_torch/attention_backend/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,8 +501,14 @@ class PredefinedAttentionMask(str, Enum):
FULL = "full"


# May extend to custom attention mask type
AttentionMask = Union[PredefinedAttentionMask]
class CustomAttentionMask(str, Enum):
"""
Custom attention mask types
"""
CUSTOM = "custom"


AttentionMask = Union[PredefinedAttentionMask, CustomAttentionMask]


class AttentionBackend(Generic[TMetadata]):
Expand Down
141 changes: 134 additions & 7 deletions tensorrt_llm/_torch/models/modeling_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from tensorrt_llm.functional import PositionEmbeddingType, RotaryScalingType
from tensorrt_llm.mapping import Mapping

from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import (PositionalEmbeddingParams,
from ..attention_backend import AttentionMetadata, FlashInferAttentionMetadata
from ..attention_backend.interface import (AttentionMask, CustomAttentionMask,
PositionalEmbeddingParams,
PredefinedAttentionMask, RopeParams)
from ..distributed import AllReduceParams
from ..model_config import ModelConfig
Expand Down Expand Up @@ -101,14 +102,19 @@ def forward(
position_ids: Optional[torch.IntTensor],
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.
CAUSAL,
attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
mrope_config: Optional[dict] = None,
all_reduce_params: Optional[AllReduceParams] = None,
lora_params: Optional[dict] = None,
attention_mask_data: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:

if attention_mask_data is not None:
assert isinstance(
attn_metadata, FlashInferAttentionMetadata
), "Only FlashInfer backend supports custom attention mask currently."
assert attention_mask == CustomAttentionMask.CUSTOM
return super().forward(position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
Expand All @@ -117,6 +123,7 @@ def forward(
all_reduce_params=all_reduce_params,
lora_params=lora_params,
attention_window_size=self.attention_window_size,
attention_mask_data=attention_mask_data,
**kwargs)

def apply_qk_norm(self, q, k):
Expand Down Expand Up @@ -214,6 +221,7 @@ def forward(
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
attention_mask_data: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:

Expand All @@ -223,6 +231,9 @@ def forward(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
attention_mask=CustomAttentionMask.CUSTOM if attention_mask_data
is not None else PredefinedAttentionMask.CAUSAL,
attention_mask_data=attention_mask_data,
**kwargs,
)
hidden_states = self.post_attention_layernorm(hidden_states)
Expand Down Expand Up @@ -267,6 +278,8 @@ def forward(
input_ids: Optional[torch.IntTensor] = None,
position_ids: Optional[torch.IntTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
local_attention_mask_data: Optional[torch.Tensor] = None,
global_attention_mask_data: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
Expand All @@ -280,9 +293,13 @@ def forward(
hidden_states = inputs_embeds.to(self.dtype)

for decoder_layer in self.layers:
hidden_states = decoder_layer(position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata)
hidden_states = decoder_layer(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
attention_mask_data=local_attention_mask_data
if decoder_layer.self_attn.is_sliding else
global_attention_mask_data)

hidden_states = self.norm(hidden_states)
return hidden_states
Expand All @@ -301,21 +318,131 @@ def __init__(
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.vocab_size)

def get_context_mask(
self,
image_token_mask: torch.BoolTensor,
effective_sliding_window: Optional[int] = None,
):
"""
Returns an attention mask such that text tokens attend to each other in causal fashion while image
tokens attend in causal fashion as well as to all other image tokens in a bidirectional manner.
Args:
image_token_mask: A boolean tensor of shape (sequence_length,) where True indicates an image token.
effective_sliding_window: The effective sliding window size for the attention mask. Default is None, which means no sliding window.
For Gemma3, this is the sliding window size from config (e.g. 512 for 1B model).
Returns:
A boolean attention mask of shape (sequence_length, sequence_length).
"""
device = image_token_mask.device
sequence_length = len(image_token_mask)
if effective_sliding_window is None or effective_sliding_window >= sequence_length:
causal_mask = torch.arange(
sequence_length, device=device).unsqueeze(0) <= torch.arange(
sequence_length, device=device).unsqueeze(1)
else:
attention_mask_1 = (torch.arange(sequence_length,
device=device).unsqueeze(0)
<= torch.arange(sequence_length,
device=device).unsqueeze(1))
attention_mask_2 = (
torch.arange(sequence_length, device=device).unsqueeze(0)
> torch.arange(sequence_length, device=device).unsqueeze(1) -
effective_sliding_window)
causal_mask = attention_mask_1 & attention_mask_2

# Apply a bidirectional mask for image tokens.
token_type_ids = torch.zeros(sequence_length,
dtype=torch.int32,
device=device)
# 1 for image tokens, 0 for text tokens.
token_type_ids[image_token_mask] = 1
token_type_mask = token_type_ids.unsqueeze(
0) == token_type_ids.unsqueeze(1)
# If text token, do not change anything.
token_type_mask[token_type_ids == 0] = False
causal_mask = causal_mask.masked_fill(token_type_mask, True)
return causal_mask

# ASSUMPTIONS:
# 1) Chunked prefill is disabled to avoid chunking image tokens as they need bidirectional attention.
# 2) KV cache reuse is disabled to avoid partially matched image tokens (entire image must be reused to get things correct).
def get_flashinfer_attention_mask(
self,
image_token_mask: torch.BoolTensor,
attn_metadata: AttentionMetadata,
effective_sliding_window: Optional[int] = None) -> torch.Tensor:
"""
This is specifically needed for context phase requests. Currently, we don't create custom mask for generation requests because FlashInfer backend
doesn't use it anyway and there's nothing special we need to do for generation requests.
- This function will only be called for a batch when there's at least one context request in the batch with image tokens.
- In context phase, each sample's input_ids may have a mix of image tokens and text tokens where tokens corresponding to an image
appear as a contiguous blob. Example: torch.IntTensor([2, 3, 4, 5, img_idx, img_idx, img_idx, ..., img_idx, 100])
- While the text tokens attend to other tokens in a causal fashion, image tokens attend to others in a causal fashion and well as
attend to other image tokens in a bidirectional manner. Hence, the need for custom masking.
Args:
image_token_mask: A boolean tensor of shape (len(input_ids),) where True indicates an image token. This corresponds to concatenated
list of tokens for all samples in the batch.
attn_metadata: The attention metadata for the batch.
effective_sliding_window: The effective sliding window size for the attention mask. Default is None, which means no sliding window.
For Gemma3, this is the sliding window size from config (e.g. 512 for 1B model).
Returns:
A flattened boolean mask of shape (sum(q_len[i] * k_len[i] for i in range(batch_size)).
"""

assert isinstance(
attn_metadata, FlashInferAttentionMetadata
), "Only FlashInfer backend supports custom mask currently."
num_contexts = attn_metadata.num_contexts
assert num_contexts > 0, "There should be at least one context request in the batch for custom mask."

qo_indptr = attn_metadata.qo_indptr[:num_contexts + 1]
cached_token_lens = attn_metadata.cached_token_lens[:num_contexts]
assert (cached_token_lens == 0).all(
), "cached_token_lens should be 0 for context requests since chunked prefill and kv cache reuse must be disabled."

# Create masks for context requests.
context_mask_list = []
for i in range(num_contexts):
mask_i = self.get_context_mask(
image_token_mask=image_token_mask[qo_indptr[i]:qo_indptr[i +
1]],
effective_sliding_window=effective_sliding_window,
)
context_mask_list.append(mask_i.flatten())
return torch.cat(context_mask_list, dim=0).contiguous()

def forward(
self,
attn_metadata: AttentionMetadata,
input_ids: torch.IntTensor = None,
position_ids: Optional[torch.IntTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
return_context_logits: bool = False,
image_token_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:

local_attention_mask_data = None
global_attention_mask_data = None
if image_token_mask is not None:
global_attention_mask_data = self.get_flashinfer_attention_mask(
image_token_mask=image_token_mask,
attn_metadata=attn_metadata,
effective_sliding_window=None,
)
local_attention_mask_data = self.get_flashinfer_attention_mask(
image_token_mask=image_token_mask,
attn_metadata=attn_metadata,
effective_sliding_window=self.config.sliding_window,
)

output = self.model(
input_ids=input_ids,
attn_metadata=attn_metadata,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
local_attention_mask_data=local_attention_mask_data,
global_attention_mask_data=global_attention_mask_data,
)

return self.logits_processor.forward(
Expand Down
19 changes: 16 additions & 3 deletions tensorrt_llm/_torch/models/modeling_gemma3vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __call__(
input_ids = preprocess_outputs[0]["mm_processor_kwargs"]["input_ids"]
mm_features = self._process(pixel_values)
multimodal_data = {}
multimodal_data["multimodal_embedding"] = mm_features
multimodal_data["multimodal_embedding"] = mm_features.squeeze(dim=0)
return input_ids[0].to(torch.int32).tolist(), {
"multimodal_data": multimodal_data
}
Expand Down Expand Up @@ -129,6 +129,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,

self.model_config = model_config
self.vocab_size = config.text_config.vocab_size
self.sliding_window = config.text_config.sliding_window
self.model_dtype = getattr(config.text_config, "torch_dtype",
torch.float16)
logger.info(f"[Gemma3Model::__init__]{self.dtype=} {self.model_dtype=}")
Expand Down Expand Up @@ -172,12 +173,24 @@ def forward(
mm_embed
) == num_context_requests, "Number of multimodal features (if provided) should be equal to number of context requests"

mm_token_ids = torch.tensor([self.image_token_index
]).to(input_ids.device)
mm_token_mask = None
if len(mm_embed) > 0:
# Get token type ids. 0 corresponds to text tokens, 1 corresponds to image tokens.
mm_token_mask = torch.isin(input_ids, mm_token_ids)
input_ids, inputs_embeds = fuse_input_embeds(
embedding_layer=self.llm.model.embed_tokens,
input_ids=input_ids,
mm_embeds=mm_embed,
mm_token_ids=torch.tensor([self.image_token_index
]).to(input_ids.device))
logits = self.llm.forward(attn_metadata, input_ids, position_ids,
inputs_embeds, return_context_logits)
logits = self.llm.forward(
attn_metadata=attn_metadata,
input_ids=input_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
return_context_logits=return_context_logits,
image_token_mask=mm_token_mask,
)
return logits
14 changes: 8 additions & 6 deletions tensorrt_llm/_torch/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

from ..attention_backend import (AttentionInputType, AttentionMetadata,
TrtllmAttention, TrtllmAttentionMetadata)
from ..attention_backend.interface import (PositionalEmbeddingParams,
from ..attention_backend.interface import (AttentionMask,
PositionalEmbeddingParams,
PredefinedAttentionMask)
from ..attention_backend.utils import create_attention, get_attention_backend
from ..distributed import AllReduceParams
Expand Down Expand Up @@ -226,12 +227,12 @@ def forward(
position_ids: Optional[torch.IntTensor],
hidden_states: Union[torch.Tensor, Fp4QuantizedTensor],
attn_metadata: AttentionMetadata,
attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.
CAUSAL,
attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
mrope_config: Optional[dict] = None,
all_reduce_params: Optional[AllReduceParams] = None,
lora_params: Optional[dict] = None,
attention_window_size: Optional[int] = None,
attention_mask_data: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Expand All @@ -241,12 +242,12 @@ def forward(
position_ids (Optional[torch.IntTensor]): The position IDs.
hidden_states (torch.Tensor): The hidden states.
attn_metadata (AttentionMetadata): The attention metadata.
attention_mask (PredefinedAttentionMask): The attention mask type.
attention_mask (AttentionMask): The attention mask type.
mrope_config (Optional[dict]): The MROPE configuration.
all_reduce_params (Optional[AllReduceParams]): The all reduce parameters.
lora_params (Optional[dict]): The LoRA parameters.
attention_window_size (Optional[int]): The attention window size.

attention_mask_data (Optional[torch.Tensor]): The attention mask data.
Returns:
torch.Tensor: The output tensor.
"""
Expand Down Expand Up @@ -284,7 +285,8 @@ def forward(
out_scale_sf=out_scale_sf,
attention_mask=attention_mask,
mrope_config=mrope_config,
attention_window_size=attention_window_size)
attention_window_size=attention_window_size,
attention_mask_data=attention_mask_data)
hidden_states = attn_output
attn_output = self.o_proj(attn_output,
all_reduce_params=all_reduce_params,
Expand Down
Loading