[TRTLLM-6308][feat] Support Aggregate mode for phi4-mm #6184
Conversation
Walkthrough

This update introduces comprehensive multimodal support for the Phi-4-MM model within the TensorRT-LLM PyTorch codebase. It adds new configuration, embedding, and utility classes for handling image and audio modalities, implements a full SigLIP vision model, and refactors the input processing and encoder pipeline to integrate these multimodal components.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant Phi4MMInputProcessor
    participant HFPhi4MultimodalEncoder
    participant ImageEmbedding
    participant AudioEmbedding
    User->>Phi4MMInputProcessor: Provide input_ids, image/audio embeds, masks
    Phi4MMInputProcessor->>HFPhi4MultimodalEncoder: forward(input_ids, image/audio embeds, masks)
    alt Image embeddings present
        HFPhi4MultimodalEncoder->>ImageEmbedding: Process image tokens/embeds
        ImageEmbedding-->>HFPhi4MultimodalEncoder: Projected image features
    end
    alt Audio embeddings present
        HFPhi4MultimodalEncoder->>AudioEmbedding: Process audio tokens/embeds
        AudioEmbedding-->>HFPhi4MultimodalEncoder: Projected audio features
    end
    HFPhi4MultimodalEncoder-->>Phi4MMInputProcessor: Hidden states (with multimodal features)
    Phi4MMInputProcessor-->>User: Extracted multimodal embeddings
```
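The branch structure in the diagram can be sketched in plain Python as follows. This is a tensor-free illustration only: the placeholder token ids, the `project` helper, and the scaling factors are invented for the sketch and are not the model's real API, which operates on torch tensors and learned projection layers.

```python
# Hypothetical, tensor-free sketch of the encoder dispatch shown above.
IMAGE_TOKEN_ID = -1  # placeholder ids; the real model defines its own special ids
AUDIO_TOKEN_ID = -2

def project(features, scale):
    """Stand-in for a learned projection (ImageEmbedding / AudioEmbedding)."""
    return [x * scale for x in features]

def encoder_forward(input_ids, image_embeds=None, audio_embeds=None):
    """Replace modality placeholder tokens with projected multimodal features."""
    hidden_states = []
    img_iter = iter(project(image_embeds, 2.0)) if image_embeds else iter([])
    aud_iter = iter(project(audio_embeds, 3.0)) if audio_embeds else iter([])
    for tok in input_ids:
        if tok == IMAGE_TOKEN_ID:
            hidden_states.append(next(img_iter))   # image branch of the diagram
        elif tok == AUDIO_TOKEN_ID:
            hidden_states.append(next(aud_iter))   # audio branch of the diagram
        else:
            hidden_states.append(float(tok))       # ordinary text token
    return hidden_states

out = encoder_forward([5, IMAGE_TOKEN_ID, 7, AUDIO_TOKEN_ID],
                      image_embeds=[1.0], audio_embeds=[1.0])
# → [5.0, 2.0, 7.0, 3.0]
```

The real encoder additionally handles masks and batched inputs; only the per-modality dispatch shape is mirrored here.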
Actionable comments posted: 5
🧹 Nitpick comments (9)
tensorrt_llm/_torch/models/configuration_phi4mm.py (1)

224-227: Consider splitting long error messages for better readability

The error messages exceed the 120 character line limit. While this doesn't affect functionality, splitting them would improve code readability.

```diff
 if not len(rope_scaling_short_factor) == rotary_ndims // 2:
     raise ValueError(
-        f"`rope_scaling`'s short_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_short_factor)}"
+        f"`rope_scaling`'s short_factor field must have length {rotary_ndims // 2}, "
+        f"got {len(rope_scaling_short_factor)}"
     )
 if not len(rope_scaling_long_factor) == rotary_ndims // 2:
     raise ValueError(
-        f"`rope_scaling`'s long_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_long_factor)}"
+        f"`rope_scaling`'s long_factor field must have length {rotary_ndims // 2}, "
+        f"got {len(rope_scaling_long_factor)}"
     )
```

Also applies to: 235-238
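As a runnable illustration of the split (standalone stand-in values replace the config fields here), adjacent f-string literals concatenate implicitly, so the raised message is unchanged while each source line stays within the limit:

```python
rotary_ndims = 8
rope_scaling_short_factor = [1.0, 1.0, 1.0]  # wrong length on purpose

try:
    if not len(rope_scaling_short_factor) == rotary_ndims // 2:
        raise ValueError(
            # Two adjacent f-strings concatenate into a single message.
            f"`rope_scaling`'s short_factor field must have length {rotary_ndims // 2}, "
            f"got {len(rope_scaling_short_factor)}"
        )
except ValueError as err:
    message = str(err)
# message ends with "length 4, got 3"
```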
tensorrt_llm/_torch/models/utils_phi4mm.py (1)

128-128: Improve assertion error message

The assertion could provide a more descriptive error message to help users understand the requirement.

```diff
-assert self.use_hd_transform == self.with_learnable_separator, 'use_hd_transform and with_learnable_separator should have same value'
+assert self.use_hd_transform == self.with_learnable_separator, \
+    f'use_hd_transform ({self.use_hd_transform}) and with_learnable_separator ({self.with_learnable_separator}) must have the same value'
```

tensorrt_llm/_torch/models/vision_siglip_navit.py (7)
268-268: Fix capitalization in log message

The log message should start with a capital letter for consistency.

```diff
-logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.")
+logger.info("`vision_config` is `None`. Initializing the `SiglipVisionConfig` with default values.")
```
729-729: Improve comment clarity

The comment uses informal language that could be more professional.

```diff
-self.is_causal = False  # Hack to make sure we don't use a causal mask
+self.is_causal = False  # Siglip uses bidirectional attention, not causal
```
758-765: Remove commented-out code

This commented-out code for rotary embeddings and KV cache should be removed if it's not needed for Siglip. Consider removing these lines entirely if rotary embeddings are not part of the Siglip architecture.
766-767: Address TODO about transpose inefficiency

This TODO indicates a known performance issue with the transposes required for Flash Attention. The inefficiency stems from Flash Attention requiring a different tensor layout than the standard transformer implementation. Would you like me to create an issue to track this optimization opportunity?
995-995: Use math.sqrt for consistency

The code uses np.sqrt here but math.sqrt elsewhere in the file.

```diff
-nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
+nn.init.normal_(module.position_embedding.weight, std=1 / math.sqrt(width))
```
1687-1688: Implement SigLIP loss function

The loss function raises NotImplementedError, which means the model cannot be trained. The SigLIP loss is a key component for training. Would you like me to help implement the SigLIP loss function or create an issue to track this?
1705-1720: Consider externalizing model configuration

The model configuration is hardcoded in the function. Consider moving it to a configuration file or class constant for better maintainability.

```python
# Define as a class constant or in a separate config file
SIGLIP_VISION_BASE_CONFIG = {
    "hidden_size": 1152,
    "image_size": 448,
    "intermediate_size": 4304,
    "model_type": "siglip_vision_model",
    "num_attention_heads": 16,
    "num_hidden_layers": 27,
    "patch_size": 14,
}

def get_siglip_vision_model(_flash_attn_2_enabled=True, **kwargs):
    model_config = SiglipVisionConfig(**SIGLIP_VISION_BASE_CONFIG,
                                      _flash_attn_2_enabled=_flash_attn_2_enabled,
                                      **kwargs)
    vision_model = SiglipVisionModel(model_config).vision_model
    return vision_model
```
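The suggestion reduces to merging per-call overrides onto a shared base mapping before constructing the config object. A stdlib-only sketch, with plain dicts standing in for SiglipVisionConfig:

```python
SIGLIP_VISION_BASE_CONFIG = {
    "hidden_size": 1152,
    "image_size": 448,
    "intermediate_size": 4304,
    "model_type": "siglip_vision_model",
    "num_attention_heads": 16,
    "num_hidden_layers": 27,
    "patch_size": 14,
}

def build_vision_config(**overrides):
    """Later keys win in a dict merge, so per-call overrides shadow base values."""
    return {**SIGLIP_VISION_BASE_CONFIG, **overrides}

cfg = build_vision_config(patch_size=16)
# cfg["patch_size"] == 16 while the shared base dict is left untouched
```

Keeping the base dict at module level means every call site sees the same defaults, and a diff touching one hyperparameter is a one-line change.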
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
tensorrt_llm/_torch/models/configuration_phi4mm.py (1 hunks)
tensorrt_llm/_torch/models/modeling_phi4mm.py (3 hunks)
tensorrt_llm/_torch/models/utils_phi4mm.py (1 hunks)
tensorrt_llm/_torch/models/vision_siglip_navit.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
tensorrt_llm/_torch/models/configuration_phi4mm.py (1)
tensorrt_llm/models/modeling_utils.py (1): PretrainedConfig (361-562)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/models/configuration_phi4mm.py
30-30: Line too long (121 > 120)
(E501)
226-226: Line too long (129 > 120)
(E501)
237-237: Line too long (127 > 120)
(E501)
tensorrt_llm/_torch/models/utils_phi4mm.py
105-105: Line too long (123 > 120)
(E501)
108-108: Line too long (141 > 120)
(E501)
166-166: Line too long (137 > 120)
(E501)
172-172: Line too long (133 > 120)
(E501)
190-190: Line too long (125 > 120)
(E501)
199-199: Line too long (125 > 120)
(E501)
246-246: Line too long (133 > 120)
(E501)
299-299: Line too long (130 > 120)
(E501)
306-306: Line too long (173 > 120)
(E501)
316-316: Line too long (250 > 120)
(E501)
338-338: Line too long (366 > 120)
(E501)
342-342: Line too long (140 > 120)
(E501)
351-351: Line too long (301 > 120)
(E501)
352-352: Line too long (334 > 120)
(E501)
355-355: Line too long (333 > 120)
(E501)
360-360: Line too long (161 > 120)
(E501)
363-363: Line too long (123 > 120)
(E501)
365-365: Line too long (140 > 120)
(E501)
377-377: Line too long (142 > 120)
(E501)
380-380: Local variable num_img_tokens is assigned to but never used. Remove assignment to unused variable num_img_tokens.
(F841)
385-385: Line too long (138 > 120)
(E501)
386-386: Line too long (171 > 120)
(E501)
395-395: Line too long (123 > 120)
(E501)
404-404: Line too long (164 > 120)
(E501)
409-409: Line too long (133 > 120)
(E501)
419-419: Line too long (150 > 120)
(E501)
436-436: Line too long (122 > 120)
(E501)
539-539: Line too long (139 > 120)
(E501)
556-556: Line too long (198 > 120)
(E501)
569-569: Local variable MAX_INPUT_ID is assigned to but never used. Remove assignment to unused variable MAX_INPUT_ID.
(F841)
603-603: Line too long (310 > 120)
(E501)
627-627: Line too long (131 > 120)
(E501)
tensorrt_llm/_torch/models/vision_siglip_navit.py
75-75: Line too long (122 > 120)
(E501)
169-169: Line too long (124 > 120)
(E501)
303-303: Module level import not at top of file
(E402)
304-304: Module level import not at top of file
(E402)
305-305: Module level import not at top of file
(E402)
306-306: Module level import not at top of file
(E402)
308-308: Module level import not at top of file
(E402)
309-309: Module level import not at top of file
(E402)
310-310: Module level import not at top of file
(E402)
311-311: Module level import not at top of file
(E402)
312-312: Module level import not at top of file
(E402)
313-313: Module level import not at top of file
(E402)
315-315: Module level import not at top of file
(E402)
316-316: Module level import not at top of file
(E402)
317-317: Module level import not at top of file
(E402)
318-318: Module level import not at top of file
(E402)
319-326: Module level import not at top of file
(E402)
372-372: Ambiguous variable name: l
(E741)
467-467: Line too long (153 > 120)
(E501)
471-471: Line too long (159 > 120)
(E501)
475-475: Line too long (150 > 120)
(E501)
494-494: Line too long (152 > 120)
(E501)
498-498: Line too long (159 > 120)
(E501)
502-502: Line too long (150 > 120)
(E501)
531-531: Line too long (121 > 120)
(E501)
697-697: Line too long (125 > 120)
(E501)
766-766: Line too long (184 > 120)
(E501)
834-834: Line too long (170 > 120)
(E501)
948-948: Line too long (140 > 120)
(E501)
tensorrt_llm/_torch/models/modeling_phi4mm.py
127-127: Line too long (121 > 120)
(E501)
193-193: Line too long (153 > 120)
(E501)
196-196: Line too long (165 > 120)
(E501)
🔇 Additional comments (6)
tensorrt_llm/_torch/models/configuration_phi4mm.py (1)
1-2: Configuration file source is properly documented

Good practice to document the source of copied code with a clear explanation of why it was necessary.
tensorrt_llm/_torch/models/modeling_phi4mm.py (3)
154-160: Good backward compatibility implementation

The token ID remapping ensures compatibility with legacy token ranges while maintaining clear special token IDs.
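The remapping described above can be illustrated with a small sketch. The special-token id and legacy range below are hypothetical placeholders, not the values used by the model:

```python
# Hypothetical ids -- the real special-token ids and legacy ranges live in the
# model code; these stand-ins only illustrate the remapping shape.
IMAGE_SPECIAL_TOKEN_ID = 200010        # hypothetical canonical id
LEGACY_IMAGE_RANGE = range(-9999, 0)   # hypothetical legacy placeholder range

def remap_token(tok):
    """Map a legacy placeholder id onto the canonical special-token id."""
    if tok in LEGACY_IMAGE_RANGE:
        return IMAGE_SPECIAL_TOKEN_ID
    return tok

remapped = [remap_token(t) for t in [42, -5, 7]]
# → [42, 200010, 7]
```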
233-245: Clean refactoring to use dedicated multimodal encoder

The change from AutoModelForCausalLM to HFPhi4MultimodalEncoder simplifies the pipeline and improves code clarity. Setting trust_remote_code=False is a good security practice.
1-3: Clear roadmap for AGGREGATE mode implementation

Good to see the implementation plan clearly documented. This PR provides the foundational multimodal support needed before implementing AGGREGATE mode in the next step.
tensorrt_llm/_torch/models/vision_siglip_navit.py (2)
1415-1439: Clarify relationship to Aggregate mode feature

The PR objective mentions "Support Aggregate mode for phi4-mm", but this file doesn't appear to have any explicit aggregate mode implementation. The SiglipMultiheadAttentionPoolingHead performs attention-based pooling, which might be related to the aggregate feature. Could you clarify:
- Is this multihead attention pooling the "Aggregate mode" mentioned in the PR?
- If not, where is the aggregate mode implemented?
- Should there be a configuration option to enable/disable aggregate mode?
887-890: Clarify cu_seqlens allocation and buffer-reuse suggestion

Torch's arange(..., device=...) allocates and fills the tensor directly on the GPU (no host-to-device memcpy). If profiling shows this per-forward allocation is a bottleneck, consider pre-allocating a maximum-length cu_seqlens buffer in __init__ (via register_buffer) and slicing it each call.

File: tensorrt_llm/_torch/models/vision_siglip_navit.py, lines 887-890

```diff
-cu_seqlens_q = torch.arange(
-    batch_size + 1, dtype=torch.int32, device=query_layer.device
-)  # There is a memcpy here, that is very bad.
+# Direct on-device fill; no host-device memcpy
+cu_seqlens_q = torch.arange(
+    batch_size + 1, dtype=torch.int32, device=query_layer.device
+)
```

Example buffer-reuse pattern:

```python
# in __init__
self.register_buffer(
    "_cu_seqlens_buffer",
    torch.arange(self.max_batch_size + 1, dtype=torch.int32, device=self.device),
)

# in forward
cu_seqlens_q = self._cu_seqlens_buffer[: batch_size + 1]
```

Please profile this allocation and, if it shows up in your GPU-allocation metrics, switch to the buffer-reuse approach.
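Stripped of torch, the buffer-reuse idea reduces to slicing a precomputed sequence instead of rebuilding it on every call. A plain-Python sketch, where MAX_BATCH and the module-level list are illustrative stand-ins for the registered buffer:

```python
MAX_BATCH = 64

# Precomputed once at construction time, analogous to register_buffer in __init__.
_CU_SEQLENS_BUFFER = list(range(MAX_BATCH + 1))

def cu_seqlens_for(batch_size):
    """Per-forward call: slice the precomputed buffer instead of reallocating."""
    assert batch_size <= MAX_BATCH, "batch exceeds preallocated buffer"
    return _CU_SEQLENS_BUFFER[: batch_size + 1]

seqlens = cu_seqlens_for(4)
# → [0, 1, 2, 3, 4]
```

In the torch version the slice is a view over the registered buffer, so the forward pass performs no allocation at all for this tensor.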
/bot run
PR_Github #13653 [ run ] triggered by Bot
PR_Github #13653 [ run ] completed with state
Thx for the work. Left some comments.
LGTM. Thx for the work.
Signed-off-by: Wanli Jiang <[email protected]>
GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.

run

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL): Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL): Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL): Only run the multi-GPU tests. Note: does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL): Disable the multi-GPU tests. Note: does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.

--post-merge (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL): Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md.

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
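The subcommand layout above maps naturally onto a standard argparse structure. The sketch below is a hypothetical reconstruction for illustration, covering only a few of the flags; it is not the bot's actual implementation:

```python
import argparse

def build_parser():
    """Illustrative parser mirroring the /bot subcommand layout documented above."""
    parser = argparse.ArgumentParser(
        prog="/bot",
        description="Interact with a Jenkins server via PR comments.")
    sub = parser.add_subparsers(dest="command", required=True)

    run = sub.add_parser("run", help="Launch build/test pipelines.")
    # nargs="?" lets --reuse-test accept an optional pipeline-id, defaulting
    # to the last pipeline when the id is omitted.
    run.add_argument("--reuse-test", nargs="?", const="last", metavar="PIPELINE_ID")
    run.add_argument("--disable-fail-fast", action="store_true")
    run.add_argument("--skip-test", action="store_true")
    run.add_argument("--stage-list", metavar='"A10-PyTorch-1, xxx"')
    run.add_argument("--gpu-type", metavar='"A30, H100_PCIe"')
    run.add_argument("--post-merge", action="store_true")

    sub.add_parser("kill", help="Kill all running builds for this PR.")

    skip = sub.add_parser("skip", help="Skip testing for the latest commit.")
    skip.add_argument("--comment", required=True)

    sub.add_parser("reuse-pipeline",
                   help="Validate the current commit with a previous pipeline.")
    return parser

args = build_parser().parse_args(
    ["run", "--disable-fail-fast", "--stage-list", "A10-PyTorch-1"])
```

Hyphenated flags become underscored attributes (args.disable_fail_fast), and the required subparser rejects comments that name no subcommand, matching the usage string at the top of the help.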