Conversation

@brb-nv commented Jul 8, 2025

Description

This MR:

  • Adds masking utilities needed for context-phase requests with image tokens in Gemma3 VLMs.
  • Enables custom_mask usage in the FlashInfer backend.

Background about custom mask:

  • The function get_flashinfer_attention_mask is only called for a batch when at least one context request in the batch has image tokens.
  • In the context phase, each sample's input_ids may mix image tokens (image_token_idx) and text tokens, where the tokens corresponding to an image appear as one contiguous blob.
    Example: torch.IntTensor([2, 3, 4, 5, img_idx, img_idx, img_idx, ..., img_idx, 100])
  • While text tokens attend to other tokens in a causal fashion, image tokens attend causally as well as bidirectionally to the other image tokens. Hence the need for custom masking.
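A minimal, plain-Python sketch of this masking rule (illustrative only, not the MR's actual get_flashinfer_attention_mask; the per-image grouping detail, where tokens of different images do not attend to each other, is an assumption here):

```python
IMG_IDX = 262144  # Gemma3's image placeholder token id

def build_context_mask(input_ids, img_idx=IMG_IDX):
    """Return a seq_len x seq_len boolean mask; mask[q][k] == True means q may attend to k."""
    n = len(input_ids)
    # Label each contiguous blob of image tokens so distinct images stay separate.
    group = [0] * n  # 0 = text token, 1..N = image blob id
    gid = 0
    for i, tok in enumerate(input_ids):
        if tok == img_idx:
            if i == 0 or input_ids[i - 1] != img_idx:
                gid += 1
            group[i] = gid
    mask = [[False] * n for _ in range(n)]
    for q in range(n):
        for k in range(n):
            causal = k <= q
            same_image = group[q] != 0 and group[q] == group[k]
            mask[q][k] = causal or same_image
    return mask

ids = [2, 3, IMG_IDX, IMG_IDX, IMG_IDX, 100]
m = build_context_mask(ids)
```

With this mask, position 2 (the first image token) may attend to position 4 (a later token of the same image), while text tokens remain strictly causal.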

This Hugging Face Transformers PR has a nice visualization of the attention mask for global attention and sliding window attention:
huggingface/transformers#38295

Request for reviewers:
I'd appreciate your comments on two strong assumptions made to get the bidirectional masking right: disabling chunked prefill and disabling KV cache reuse. My thoughts:

  • Chunked prefill must be disabled (unless chunk_size equals max_input_len, which defeats the purpose of chunking) because image tokens can appear anywhere in input_ids, and bidirectionality is lost if chunking splits an image's token blob across chunks.
  • KV cache reuse is disabled to avoid partially matched image tokens: either all tokens of an image must be reused or none at all.

Example end-to-end run:

$ python3 examples/pytorch/quickstart_multimodal.py --model_dir ../random/hf_models/gemma-3-27b-it/ --modality image --image_format pil --attention_backend FLASHINFER --disable_kv_cache_reuse
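To make the reuse concern concrete, here is a toy check (hypothetical helper; the real scheduler logic is more involved): reusing a block-aligned prefix is safe only if the reuse boundary does not fall inside an image-token blob.

```python
IMG_IDX = 262144  # Gemma3's image placeholder token id

def reuse_splits_image(input_ids, num_reused_tokens, img_idx=IMG_IDX):
    """True if reusing the first `num_reused_tokens` tokens would cut an image blob."""
    if num_reused_tokens <= 0 or num_reused_tokens >= len(input_ids):
        return False
    # The boundary splits an image iff image tokens sit on both sides of it.
    return (input_ids[num_reused_tokens - 1] == img_idx
            and input_ids[num_reused_tokens] == img_idx)

seq = [1, 2, IMG_IDX, IMG_IDX, IMG_IDX, IMG_IDX, 7, 9]
```

Here, reusing 4 tokens of `seq` cuts the image in half (unsafe), whereas reusing 2 or 6 tokens keeps the image blob intact.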

Test Coverage

The following tests validate the masking utils.

$ pytest tests/unittest/_torch/modeling/test_modeling_gemma3.py::TestGemma3::test_gemma3_local_context_mask -s -v
$ pytest tests/unittest/_torch/modeling/test_modeling_gemma3.py::TestGemma3::test_gemma3_global_context_mask -s -v
$ pytest tests/unittest/_torch/modeling/test_modeling_gemma3.py::TestGemma3::test_gemma3_flashinfer_mask -s -v

The following tests validate the masking utils as well as custom-mask usage by the FlashInfer backend.

$ pytest tests/unittest/_torch/modeling/test_modeling_gemma3.py::TestGemma3::test_gemma3_allclose_to_hf[backend:flashinfer_config:1b] -s -v
$ pytest tests/unittest/_torch/modeling/test_modeling_gemma3.py::TestGemma3::test_gemma3_allclose_to_hf[backend:flashinfer_config:27b] -s -v

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--disable-fail-fast --skip-test --stage-list "A10-1, xxx" --gpu-type "A30, H100_PCIe" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-[Post-Merge]-1, xxx"]

Launch build/test pipelines. All previously running jobs will be killed.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests. Will also run L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-[Post-Merge]-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-[Post-Merge]-1, xxx".

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@brb-nv brb-nv requested a review from a team as a code owner July 8, 2025 17:08
@brb-nv brb-nv requested review from symphonylyh and nv-yilinf July 8, 2025 17:08
@brb-nv brb-nv force-pushed the user/brb/gemma3-masking-utils branch 2 times, most recently from b3a2e3b to d79f8d6 Compare July 8, 2025 17:30
@brb-nv brb-nv changed the title feat: Masking utils for Gemma3 VLM feat: Custom masking utils for Gemma3 VLM Jul 8, 2025
@brb-nv brb-nv force-pushed the user/brb/gemma3-masking-utils branch 2 times, most recently from 9e7e465 to fe2237d Compare July 8, 2025 17:38
@brb-nv brb-nv requested a review from amukkara July 8, 2025 19:18
@brb-nv commented Jul 8, 2025

/bot run

@tensorrt-cicd

PR_Github #11341 [ run ] triggered by Bot

@schetlur-nv

RE: image tokens can appear anywhere in the input_ids
Is this actually true? Don't we know that the image tokens will always be the first k of input_ids, and all subsequent ones will be text?
At a high level, I agree with disabling the feature to get the MR in (decouple functional support from perf).
Do we know the max number of tokens an image can have? If chunk_size > max_image_size, we should be ok.
Longer term, should we have a parameter that disallows chunking up to the first k tokens of the input? It really hinges on how common this pattern is.
CC @symphonylyh @chang-l

@brb-nv commented Jul 8, 2025

> RE: image tokens can appear anywhere in the input_ids
> Is this actually true? Don't we know that the image tokens will always be the first k of input_ids, and all subsequent ones will be text?
> At a high level, I agree with disabling the feature to get the MR in (decouple functional support from perf).
> Do we know the max number of tokens an image can have? If chunk_size > max_image_size, we should be ok.
> Longer term, should we have a parameter that disallows chunking up to the first k tokens of the input? It really hinges on how common this pattern is. CC @symphonylyh @chang-l

Thank you, @schetlur-nv! I've noticed that where the image tokens start can vary based on the length of the system prompt (more text tokens before the image tokens for a longer system prompt) and on whether there is a single image or multiple images (the second image's tokens start after the first's). But it's true that mm_tokens_per_image is already known from the config.

input_ids.txt

262144 is the image_token_idx for Gemma3.

from PIL import Image
import requests
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

model_dir = "/home/bbuddharaju/scratch/random/hf_models/gemma-3-4b-it/"
model = Gemma3ForConditionalGeneration.from_pretrained(model_dir)

processor = AutoProcessor.from_pretrained(model_dir)
messages = [
    {
        "role": "system",
        "content": [
            {"type": "text", "text": "You are a helpful assistant."}
        ]
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
            {"type": "text", "text": "Compare above images."},
        ]
    },
]

inputs = processor.apply_chat_template(messages, tokenize=True, return_dict=True, return_tensors="pt", add_generation_prompt=True)
print(f"input_ids: {inputs.input_ids.shape} \n {inputs.input_ids}")
generate_ids = model.generate(**inputs, do_sample=False, num_beams=1, top_p=None, top_k=None, max_new_tokens=200)
outputs = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(outputs)

@chang-l commented Jul 8, 2025

> Longer term, should we have a parameter that disallows chunking up to the first k tokens of the input? It really hinges on how common this pattern is.

I think chunked prefill may work if we respect the image token boundaries — that is, each multimodal item (e.g., an image) must occupy a single chunk

@brb-nv Regarding your comment on KV cache reuse — currently, KV cache reuse is not supported for multimodal models in the PyTorch flow (see here). I have a pending PR #5444 to enable it initially.

Also, I think your assumption — either all tokens of an image must be reused or none at all — might not always hold. For example, consider the input: [1, 2, image_token, image_token, image_token, image_token, 7, 9] with a block size of 4. This would split into:

  • Block 1: [1, 2, image_token, image_token]
  • Block 2: [image_token, image_token, 7, 9]

Now, if another sequence shares the same image but ends slightly differently, e.g., [1, 2, image_token, image_token, image_token, image_token, 11, 13], we can only reuse Block 1, which only partially covers the image tokens.
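This partial-reuse scenario can be sketched mechanically (toy code; block_size of 4 as in the example above):

```python
IMG = 262144  # stand-in for image_token

def split_into_blocks(tokens, block_size=4):
    """Split a token list into fixed-size KV cache blocks."""
    return [tokens[i:i + block_size] for i in range(0, len(tokens), block_size)]

seq_a = [1, 2, IMG, IMG, IMG, IMG, 7, 9]
seq_b = [1, 2, IMG, IMG, IMG, IMG, 11, 13]

# Leading blocks that match between the two sequences are the reusable prefix.
reusable = []
for a, b in zip(split_into_blocks(seq_a), split_into_blocks(seq_b)):
    if a != b:
        break
    reusable.append(a)
# Only Block 1 matches, so reuse covers just 2 of the image's 4 tokens.
```

The reusable prefix ends mid-image, which is exactly the case that breaks the bidirectional mask.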

@brb-nv commented Jul 8, 2025

> Longer term, should we have a parameter that disallows chunking up to the first k tokens of the input? It really hinges on how common this pattern is.
>
> I think chunked prefill may work if we respect the image token boundaries — that is, each multimodal item (e.g., an image) must occupy a single chunk; but yeah, it's unclear how common this pattern is in practice
>
> @brb-nv Regarding your comment on KV cache reuse — currently, KV cache reuse is not supported for multimodal models in the PyTorch flow (see here). I have a pending PR #5444 to enable it initially.
>
> Also, I think your assumption — either all tokens of an image must be reused or none at all — might not always hold. For example, consider the input: [1, 2, image_token, image_token, image_token, image_token, 7, 9] with a block size of 4. This would split into:
>
>   • Block 1: [1, 2, image_token, image_token]
>   • Block 2: [image_token, image_token, 7, 9]
>
> Now, if another sequence shares the same image but ends slightly differently, e.g., [1, 2, image_token, image_token, image_token, image_token, 11, 13], we can only reuse Block 1, which only partially covers the image tokens.

Thank you, @chang-l! The example you give is exactly why I said that either all tokens of an image must be reused or none at all for the bidirectional mask to be correct. If we can't guarantee this, KV cache reuse must be disabled for this model.

@tensorrt-cicd

PR_Github #11341 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #8392 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@brb-nv brb-nv force-pushed the user/brb/gemma3-masking-utils branch 2 times, most recently from c6d8db8 to 163b097 Compare July 9, 2025 08:06
@brb-nv commented Jul 9, 2025

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #11411 [ run ] triggered by Bot

@tensorrt-cicd

PR_Github #11411 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #8440 completed with status: 'FAILURE'

@brb-nv commented Jul 9, 2025

/bot run

@tensorrt-cicd

PR_Github #11448 [ run ] triggered by Bot

@brb-nv brb-nv force-pushed the user/brb/gemma3-masking-utils branch from 06c354b to b7d62a6 Compare July 9, 2025 15:16
@brb-nv commented Jul 9, 2025

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #11453 [ run ] triggered by Bot

@tensorrt-cicd

PR_Github #11448 [ run ] completed with state ABORTED

@brb-nv brb-nv requested a review from qixiang-99 July 9, 2025 17:00
@tensorrt-cicd

PR_Github #11453 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #8473 completed with status: 'SUCCESS'

@qixiang-99 left a comment

Overall LGTM, great work.

@brb-nv brb-nv requested a review from chang-l July 9, 2025 20:15
@brb-nv brb-nv force-pushed the user/brb/gemma3-masking-utils branch from f4bae6e to 355389c Compare July 9, 2025 20:28
@brb-nv commented Jul 9, 2025

/bot reuse-pipeline

@brb-nv commented Jul 9, 2025

Reusing pipeline because changes are cosmetic. f4bae6e

I reran formatting too.

@tensorrt-cicd

PR_Github #11468 [ reuse-pipeline ] triggered by Bot

@amukkara amukkara enabled auto-merge (squash) July 9, 2025 20:44
@amukkara amukkara force-pushed the user/brb/gemma3-masking-utils branch from 355389c to 5908b24 Compare July 9, 2025 20:47
@brb-nv commented Jul 9, 2025

/bot reuse

@github-actions bot commented Jul 9, 2025

GitHub Bot Help


@brb-nv commented Jul 9, 2025

/bot reuse-pipeline

@tensorrt-cicd

PR_Github #11469 [ reuse-pipeline ] triggered by Bot

@tensorrt-cicd

PR_Github #11468 [ reuse-pipeline ] completed with state ABORTED
Can't reuse PR_Github #11453 with status: SUCCESS

@amukkara commented Jul 9, 2025

/bot reuse-pipeline

@tensorrt-cicd

PR_Github #11470 [ reuse-pipeline ] triggered by Bot

@tensorrt-cicd

PR_Github #11469 [ reuse-pipeline ] completed with state ABORTED
Can't reuse PR_Github #11453 with status: SUCCESS

@tensorrt-cicd

PR_Github #11470 [ reuse-pipeline ] completed with state SUCCESS
Reusing PR_Github #11453 for commit 5908b24

@amukkara amukkara merged commit 3209b31 into NVIDIA:main Jul 9, 2025
3 checks passed
@brb-nv brb-nv deleted the user/brb/gemma3-masking-utils branch July 11, 2025 23:29
zhou-yuxin pushed a commit to zhou-yuxin/TensorRT-LLM that referenced this pull request Jul 15, 2025