[TRTLLM-4279] feat: Multistream initial support for torch compile flow #5847
Conversation
Force-pushed from 5627900 to 95852c9
/bot run
PR_Github #11380 [ run ] triggered by Bot
/bot kill
PR_Github #11385 [ kill ] triggered by Bot
PR_Github #11380 [ run ] completed with state
PR_Github #11385 [ kill ] completed with state
Force-pushed from 95852c9 to ad9d60a
/bot run
PR_Github #11390 [ run ] triggered by Bot
PR_Github #11390 [ run ] completed with state
Force-pushed from ad9d60a to 446a9da
/bot run
PR_Github #11405 [ run ] triggered by Bot
PR_Github #11405 [ run ] completed with state
Force-pushed from 446a9da to 59e9d98
/bot run
Force-pushed from 59e9d98 to 28baa13
/bot run
PR_Github #11749 [ run ] triggered by Bot
/bot kill
Actionable comments posted: 0
🧹 Nitpick comments (4)
tensorrt_llm/_torch/compilation/multi_stream/auto_multi_stream.py (4)
71-80: Consider making operation weights configurable. The hardcoded weights (20 for MOE, 10 for GEMM, 1 for others) could be made configurable through a class or configuration object. This would make it easier to tune the scheduler without modifying code.
For example:

```python
@dataclass
class SchedulerConfig:
    moe_weight: int = 20
    gemm_weight: int = 10
    default_weight: int = 1
```
169-184: Consider extracting argument flattening logic. The nested while loop for flattening arguments could be extracted into a separate helper function for better readability and testability.

```python
def flatten_node_args(node: Node) -> List:
    """Flatten nested arguments from node into a flat list."""
    args = [a for a in node.args] + [a for a in node.kwargs.values()]
    changed = True
    while changed:
        changed = False
        args_new = []
        for arg in args:
            if isinstance(arg, dict):
                for v in arg.values():
                    args_new.append(v)
                changed = True
            elif isinstance(arg, (list, tuple)):
                for item in arg:
                    args_new.append(item)
                changed = True
            else:
                args_new.append(arg)
        args = args_new
    return args
```
426-427: Avoid global state in debug utilities. The global `call_cnt` variable could lead to issues in multi-threaded environments. Consider passing a filename or counter as a parameter instead.

```python
def dump_dag_as_dot(dag: MultiStreamDAG, filename: str = None) -> None:
    if filename is None:
        import time
        filename = f"dag_{int(time.time())}.dot"
    COLORS = [
        "red", "chocolate", "cyan", "gold", "coral",
        "green", "blue", "orange", "purple", "brown"
    ]
    with open(filename, 'w') as f:
        # ... rest of the function
```
444-444: Fix line length issue. Line exceeds the 120 character limit.

```diff
- f"id_{dag.node_to_id[node.node]} [label=\"node={node.node}, distance={node.distance}, weight={node.weight}\", color={color}, shape=oval]\n"
+ f"id_{dag.node_to_id[node.node]} [label=\"node={node.node}, "
+ f"distance={node.distance}, weight={node.weight}\", "
+ f"color={color}, shape=oval]\n"
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- tensorrt_llm/_torch/compilation/multi_stream/auto_multi_stream.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/compilation/multi_stream/auto_multi_stream.py
444-444: Line too long (155 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (5)
tensorrt_llm/_torch/compilation/multi_stream/auto_multi_stream.py (5)
1-13: LGTM! The imports are well-organized and all appear to be used in the implementation.
83-125: Well-designed data structures for stream scheduling. The `Stream` dataclass and `MultiStreamNode` class appropriately model the scheduling requirements with clear separation of concerns.
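As a rough illustration of what such structures might look like, here is a minimal sketch; the field names and shapes are assumptions for illustration, not the PR's actual definitions:

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Stream:
    """One logical execution stream tracked by the scheduler (illustrative)."""
    stream_id: int
    busy_until: int = 0                              # estimated tick when the stream frees up
    nodes: List[int] = field(default_factory=list)   # ids of nodes placed on this stream

class MultiStreamNode:
    """One graph node plus the metadata a stream scheduler needs (illustrative)."""
    def __init__(self, node_id: int, weight: int):
        self.node_id = node_id
        self.weight = weight                  # estimated cost, e.g. 20 for MOE, 10 for GEMM
        self.distance = 0                     # critical-path distance used as scheduling priority
        self.stream: Optional[Stream] = None  # assigned during scheduling
        self.deps: List["MultiStreamNode"] = []
```

Keeping the per-stream bookkeeping (`busy_until`, placed nodes) separate from per-node metadata (`weight`, `distance`, dependencies) is the separation of concerns the review is pointing at.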
236-313: Solid stream scheduling algorithm. The priority-based scheduling with distance heuristic and synchronization event generation is well-implemented. The deterministic ordering through node IDs is a nice touch.
315-400: Correct graph reconstruction with proper synchronization. The implementation properly handles stream switching, event synchronization, and tensor lifetime management through `record_stream` operations.
413-422: Clean public API design. The function provides a simple interface with clear documentation about in-place modification and a meaningful return value.
PR_Github #12401 [ kill ] triggered by Bot
PR_Github #12400 [ run ] completed with state
PR_Github #12401 [ kill ] completed with state
Force-pushed from de5ba23 to 71df703
/bot run
Actionable comments posted: 1
♻️ Duplicate comments (1)
tensorrt_llm/_torch/compilation/multi_stream/auto_multi_stream.py (1)
27-27: Fix grammatical error in comment.

```diff
- # This is a symint call happen on host. No need to count time on stream.
+ # This is a symint call that happens on host. No need to count time on stream.
```
🧹 Nitpick comments (6)
tensorrt_llm/_torch/compilation/multi_stream/auto_multi_stream.py (6)
17-17: Fix grammatical error in comment. The comment has a grammatical error that should be corrected.

```diff
- # This is a symint call happen on host. No need to count time on stream.
+ # This is a symint call that happens on host. No need to count time on stream.
```
73-73: Consider defining cost values as named constants. The hardcoded cost values (20, 10, 1) would be more maintainable as named constants, making it easier to tune the scheduling algorithm.
Add constants at the module level:

```python
# Cost model constants
MOE_OP_COST = 20
GEMM_OP_COST = 10
DEFAULT_OP_COST = 1
```

Then update the function:

```diff
 if node.op == "call_function" and node.target in moe_ops:
-    return 20
+    return MOE_OP_COST
 # GEMM ops
 if node.op == "call_function" and node.target in gemm_ops:
-    return 10
+    return GEMM_OP_COST
 # Refine the estimation of time for nodes.
-return 1
+return DEFAULT_OP_COST
```

Also applies to: 77-77, 80-80
169-185: Extract argument flattening logic into a helper method. The nested while loop for flattening arguments reduces readability. Consider extracting this into a dedicated helper method.

```diff
+    def _flatten_args(self, args):
+        """Recursively flatten nested arguments into a flat list."""
+        result = []
+        stack = list(args)
+        while stack:
+            arg = stack.pop()
+            if isinstance(arg, dict):
+                stack.extend(arg.values())
+            elif isinstance(arg, (list, tuple)):
+                stack.extend(arg)
+            else:
+                result.append(arg)
+        return result
+
     def create_dag_from_gm(self, gm: GraphModule) -> None:
         # ... existing code ...
         for node in gm.graph.nodes:
             # ... existing code ...
             args = [a for a in node.args] + [a for a in node.kwargs.values()]
-
-            changed = True
-            while changed:
-                changed = False
-                args_new = []
-                for arg in args:
-                    if isinstance(arg, dict):
-                        for v in arg.values():
-                            args_new.append(v)
-                        changed = True
-                    elif isinstance(arg, (list, tuple)):
-                        for item in arg:
-                            args_new.append(item)
-                        changed = True
-                    else:
-                        args_new.append(arg)
-                args = args_new
+            args = self._flatten_args(args)
```
315-401: Consider breaking down this method for better maintainability. This method is quite long (86 lines) and handles multiple responsibilities: graph creation, stream switching, event synchronization, and node remapping. Consider extracting some logic into helper methods.
For example, you could extract:
- Event synchronization logic (lines 378-396) into `_insert_event_synchronization`
- Stream switching logic (lines 361-366) into `_insert_stream_switch`
- Node processing logic into `_process_node`
This would improve readability and make the code easier to test and maintain.
426-427: Avoid using global variable for call counter. Using a global variable for tracking calls is not thread-safe and makes the function harder to test.
Consider passing a filename parameter or using a timestamp:

```diff
-call_cnt = 0
-
-def dump_dag_as_dot(dag: MultiStreamDAG) -> None:
-    global call_cnt
+def dump_dag_as_dot(dag: MultiStreamDAG, filename: str = None) -> None:
+    if filename is None:
+        import time
+        filename = f"dag_{int(time.time())}.dot"
     # ... rest of the code ...
-    with open(f"{call_cnt}.dot", 'w') as f:
+    with open(filename, 'w') as f:
         # ... rest of the code ...
-    call_cnt += 1
```

Also applies to: 455-455
450-452: Make node limit configurable for debug function. The hardcoded limit of 500 nodes might be too restrictive for debugging large graphs.

```diff
-def dump_dag_as_dot(dag: MultiStreamDAG) -> None:
+def dump_dag_as_dot(dag: MultiStreamDAG, max_nodes: int = 500) -> None:
     # ... existing code ...
-    if cnt > 500:
+    if cnt > max_nodes:
         break
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (15)
- tensorrt_llm/_torch/compilation/backend.py (5 hunks)
- tensorrt_llm/_torch/compilation/multi_stream/auto_multi_stream.py (1 hunks)
- tensorrt_llm/_torch/compilation/piecewise_optimizer.py (7 hunks)
- tensorrt_llm/_torch/compilation/remove_copy_pass.py (2 hunks)
- tensorrt_llm/_torch/compilation/utils.py (1 hunks)
- tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py (1 hunks)
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (1 hunks)
- tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py (5 hunks)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (1 hunks)
- tensorrt_llm/_torch/pyexecutor/config.py (1 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (2 hunks)
- tensorrt_llm/_torch/utils.py (1 hunks)
- tensorrt_llm/llmapi/llm_args.py (2 hunks)
- tests/integration/defs/accuracy/test_llm_api_pytorch.py (8 hunks)
- tests/unittest/_torch/thop/test_moe.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (14)
- tensorrt_llm/_torch/pyexecutor/config.py
- tensorrt_llm/_torch/compilation/remove_copy_pass.py
- tensorrt_llm/_torch/compilation/utils.py
- tensorrt_llm/_torch/utils.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
- tests/unittest/_torch/thop/test_moe.py
- tests/integration/defs/accuracy/test_llm_api_pytorch.py
- tensorrt_llm/llmapi/llm_args.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tensorrt_llm/_torch/compilation/backend.py
- tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
- tensorrt_llm/_torch/compilation/piecewise_optimizer.py
- tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/compilation/multi_stream/auto_multi_stream.py
444-444: Line too long (155 > 120)
(E501)
🔇 Additional comments (1)
tensorrt_llm/_torch/compilation/multi_stream/auto_multi_stream.py (1)
256-262: Edge Case Verification: No Zero-Weight or Symint-Only Dependencies Found. I ran searches for any `estimate_time()` implementations returning 0 and for uses of `getitem` with only symint inputs; both returned no matches in the codebase. It appears there are no zero-weight nodes or pure-symint dependency chains in practice, so the fallback to `streams[0]` should never actually be exercised. However, absence of evidence isn't evidence of absence; please manually confirm that:
- No future or generated nodes can end up with zero estimated weight.
- Falling back to `streams[0]` won't inadvertently introduce unwanted stream switches.
PR_Github #12405 [ run ] triggered by Bot
Signed-off-by: Jin Li <[email protected]>
PR_Github #12405 [ run ] completed with state
Signed-off-by: Jin Li <[email protected]>
Force-pushed from 71df703 to facac55
/bot run
PR_Github #12416 [ run ] triggered by Bot
PR_Github #12416 [ run ] completed with state
NVIDIA#5847) Signed-off-by: Jin Li <[email protected]>
NVIDIA#5847) Signed-off-by: Jin Li <[email protected]> Signed-off-by: Shreyas Misra <[email protected]>
NVIDIA#5847) Signed-off-by: Jin Li <[email protected]> Signed-off-by: Ransiki Zhang <[email protected]>
PR title
Please write the PR title following this template:
[JIRA ticket link/nvbug link/github issue link][fix/feat/doc/infra/...] <summary of this PR>
For example, for a PR that adds a new cache manager feature tracked by Jira ticket TRTLLM-1000, the title would be:
[TRTLLM-1000][feat] Support a new feature about cache manager
Description
Please explain the issue and the solution in short.
Test Coverage
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...
Provide a user-friendly way for developers to interact with a Jenkins server.
Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.
run [--disable-fail-fast --skip-test --stage-list "A10-1, xxx" --gpu-type "A30, H100_PCIe" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-[Post-Merge]-1, xxx"]
Launch build/test pipelines. All previously running jobs will be killed.
- --disable-fail-fast (OPTIONAL): Disable fail fast on build/tests/infra failures.
- --skip-test (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
- --stage-list "A10-1, xxx" (OPTIONAL): Only run the specified test stages. Examples: "A10-1, xxx". Note: Does NOT update GitHub check status.
- --gpu-type "A30, H100_PCIe" (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- --only-multi-gpu-test (OPTIONAL): Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- --disable-multi-gpu-test (OPTIONAL): Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- --add-multi-gpu-test (OPTIONAL): Force run the multi-GPU tests. Will also run the L0 pre-merge pipeline.
- --post-merge (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- --extra-stage "H100_PCIe-[Post-Merge]-1, xxx" (OPTIONAL): Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-[Post-Merge]-1, xxx".
For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md.
kill
kill
Kill all running builds associated with the pull request.
skip
skip --comment COMMENT
Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
reuse-pipeline
reuse-pipeline
Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
Summary by CodeRabbit
New Features
Improvements
Bug Fixes
Tests