
Commit f5e8dff

feat: Improve Dynamo partitioning system
- Upgrade Dynamo partitioning to use a custom version of the Torch _SplitterBase for efficiency and optimized usage in the Dynamo case
- Validate existing use cases are still functional, with the same partitioning schema as before
- Upgrade qualified name checking
- Update testing for new partitioner
- Add new directory to store available partitioners
1 parent 0527edd commit f5e8dff
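
For orientation, the sketch below shows how the new partitioning entry points introduced in this commit (`partition` and `get_submod_inputs` from `torch_tensorrt.dynamo.partitioning`) are meant to be driven, mirroring the updated loop in `backends.py`. The toy module, the `torch.fx.symbolic_trace` call, and the sample inputs are illustrative assumptions only; in the real backend the graph arrives already lowered to aten ops via Dynamo/AOT tracing.

# Illustrative sketch only; the toy model and symbolic_trace call are assumptions,
# while partition/get_submod_inputs and the "_run_on_acc" check come from this commit.
import torch
from torch_tensorrt.dynamo.partitioning import partition, get_submod_inputs

class ToyModel(torch.nn.Module):
    def forward(self, x, y):
        return torch.relu(x + y) * 2

gm = torch.fx.symbolic_trace(ToyModel())            # stand-in for the Dynamo/AOT-lowered graph
sample_inputs = [torch.randn(2, 3), torch.randn(2, 3)]

# Split into "_run_on_acc_*" (TRT-convertible) and "_run_on_gpu_*" (Torch fallback) children
partitioned_module = partition(gm, verbose=True, min_block_size=1)

for name, _ in partitioned_module.named_children():
    # Only "_run_on_acc" children are candidates for TRT conversion
    if "_run_on_acc" not in name:
        continue
    submodule = getattr(partitioned_module, name)
    # Capture the tensors that flow into this submodule during one parent forward pass
    submodule_inputs = get_submod_inputs(partitioned_module, submodule, sample_inputs)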

9 files changed: +345 additions, -14 deletions


py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 6 additions & 1 deletion
@@ -11,7 +11,7 @@
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import (
     pre_aot_substitutions,
 )
-from torch_tensorrt.dynamo.lowering._partition import (
+from torch_tensorrt.dynamo.partitioning import (
     partition,
     get_submod_inputs,
 )
@@ -131,6 +131,11 @@ def _compile_module(
     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those
     for name, _ in partitioned_module.named_children():
+
+        # Criteria for a module to be convertible to TRT
+        if "_run_on_acc" not in name:
+            continue
+
         submodule = getattr(partitioned_module, name)

         # Get submodule inputs
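
This guard relies on the submodule naming scheme of the new splitter-based partitioner: accelerated segments carry an `_run_on_acc` prefix, while Torch-fallback segments use the `_run_on_gpu_` prefix set via `non_acc_submodule_name` in `_adjacency_partitioner.py` below. A minimal sketch of the filtering idiom, with hypothetical child names:

# Hypothetical child-module names for a graph with one Torch-fallback segment;
# the exact indices depend on the graph, only the prefix check matters here.
children = ["_run_on_acc_0", "_run_on_gpu_1", "_run_on_acc_2"]

convertible = [name for name in children if "_run_on_acc" in name]
assert convertible == ["_run_on_acc_0", "_run_on_acc_2"]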

py/torch_tensorrt/dynamo/conversion/converter_registry.py

Lines changed: 2 additions & 1 deletion
@@ -305,7 +305,8 @@ def unique_targets(self):
         """Returns the set of unique converter targets stored across all registries"""
         return set.union(*[set(registry.keys()) for registry in self.registries])

-    def qualified_name_or_str(self, target: Target) -> str:
+    @staticmethod
+    def qualified_name_or_str(target: Target) -> str:
         """Returns string representation of an FX Node target"""
         if isinstance(target, str):
             return target
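
With `qualified_name_or_str` now a `@staticmethod`, callers can resolve a node target to a printable name without holding a registry instance, which is exactly how the new `OpSupportTester` invokes it (`ConverterRegistry.qualified_name_or_str(node.target)`). A brief sketch of the two call patterns; the aten op chosen is an arbitrary example, not taken from this commit:

import torch
from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry

# Callable targets resolve to their fully qualified names, no registry instance required
op_name = ConverterRegistry.qualified_name_or_str(torch.ops.aten.relu.default)

# String targets (e.g. graph "output" nodes) pass through unchanged
assert ConverterRegistry.qualified_name_or_str("output") == "output"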

py/torch_tensorrt/dynamo/lowering/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -5,6 +5,5 @@
     SUBSTITUTION_REGISTRY,
     register_substitution,
 )
-from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS
 from .substitutions import *
 from ._fusers import *
py/torch_tensorrt/dynamo/partitioning/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+from ._adjacency_partitioner import (
+    partition,
+    get_submod_inputs,
+    DEFAULT_SINGLE_NODE_PARTITIONS,
+)
py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 276 additions & 0 deletions
@@ -0,0 +1,276 @@
+import logging
+from typing import Dict, List, Optional, Sequence, Set, Tuple
+
+import torch
+
+from torch.fx.passes.splitter_base import (
+    Subgraph,
+    _SplitterBase,
+    _SplitterSettingBase,
+    FxNetAccNodesFinder,
+    FxNetAccFusionsFinder,
+)
+import torch.fx.passes.operator_support as ops
+from torch.fx.passes.tools_common import NodeSet, CALLABLE_NODE_OPS
+from torch.fx.node import Target
+
+from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry
+from torch_tensorrt.dynamo.lowering import SUBSTITUTION_REGISTRY
+from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE
+from torch.fx.node import _get_qualified_name
+
+from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS
+
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set(
+    _get_qualified_name(to_replace.new_operator)
+    for to_replace in SUBSTITUTION_REGISTRY.values()
+)
+
+
+class OpSupportTester(ops.OperatorSupportBase):
+    """Class to determine whether operators within a module are supported"""
+
+    def __init__(self, torch_executed_ops: Sequence[Target] = set()) -> None:
+        super().__init__()
+
+        # Initialize sets of supported/unsupported operators
+        self.supported_operators = {}
+        self.unsupported_operators = {}
+        self.torch_executed_ops = torch_executed_ops
+
+    def is_node_supported(
+        self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node
+    ) -> bool:
+        node_name = ConverterRegistry.qualified_name_or_str(node.target)
+
+        if node in CONVERTERS and node_name not in self.torch_executed_ops:
+            # If node is a proper, supported computational node, store the operator
+            if not node.is_impure():
+                if node_name not in self.supported_operators:
+                    self.supported_operators[node_name] = 1
+                else:
+                    self.supported_operators[node_name] += 1
+
+            return True
+        else:
+            if not node.is_impure():
+                if node_name not in self.unsupported_operators:
+                    self.unsupported_operators[node_name] = 1
+                else:
+                    self.unsupported_operators[node_name] += 1
+
+            return False
+
+    def print_support_overview(self, num_trt_blocks: Optional[int] = None):
+        if num_trt_blocks is not None:
+            logger.debug(
+                f"\nNumber of TensorRT-Accelerated Engines Generated: {num_trt_blocks}"
+            )
+
+        # Reformat support messages for debugger to print node overview as a single string
+        supported_nodes_str = "\nSupported Nodes:\n"
+        for node_name, count in self.supported_operators.items():
+            supported_nodes_str += f"- {node_name} + Operator Count: {count}\n"
+
+        logger.debug(supported_nodes_str)
+
+        if self.unsupported_operators:
+            unsupported_nodes_str = "\nUnsupported or Excluded Nodes:\n"
+            for node_name, count in self.unsupported_operators.items():
+                unsupported_nodes_str += f"- {node_name} + Operator Count: {count}\n"
+
+            logger.debug(unsupported_nodes_str)
+        else:
+            logger.debug("\nAll Nodes Supported\n")
+
+
+class TRTPartitioner(_SplitterBase):
+    """Partitioner to split an FX graph into subgraphs based on operator support
+
+    Adapted from, and modified for the Torch-TensorRT Dynamo case:
+    https://github.com/pytorch/pytorch/blob/93f538db355ea10c684a57f7a632ed03292ef98f/torch/fx/passes/splitter_base.py#L256C9-L871
+
+    Args:
+        module: FX GraphModule to partition
+        operator_support: OperatorSupport class describing allowed operators
+        allowed_single_node_partition_ops: Nodes which can be included in single-node partitions.
+            Generally useful for module-level exclusion ops which are intensive despite being single functions
+        min_block_size: Minimum number of computational operators per block
+    Returns:
+        torch.fx.GraphModule
+    """
+
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        operator_support: ops.OperatorSupportBase,
+        allowed_single_node_partition_ops: Optional[
+            Sequence[str]
+        ] = DEFAULT_SINGLE_NODE_PARTITIONS,
+        min_block_size: int = MIN_BLOCK_SIZE,
+    ):
+        """
+        Preprocesses graph before splitting:
+        - finds nodes supported by ACC,
+        - finds fusion groups for ACC nodes having non-tensor IO,
+        - builds a graph of direct dependencies,
+        - builds a map of fused nodes to their fusions.
+        As a result we get self.acc_nodes, self.deps and self.fusions.
+        """
+        assert isinstance(module, torch.fx.GraphModule)
+
+        self.module = module
+
+        self.settings = _SplitterSettingBase(
+            min_acc_module_size=min_block_size,
+            allow_non_tensor=True,
+        )
+        self.operator_support = operator_support
+
+        # Get all accelerated nodes based on operator support conditions
+        self.acc_nodes = FxNetAccNodesFinder(
+            self.module, self.operator_support, self.settings.allow_non_tensor
+        )()
+
+        if self.settings.skip_fusion:
+            self.fusions = {}
+        else:
+            self.fusions = FxNetAccFusionsFinder(module, set(self.acc_nodes))()
+
+        # Modify deps to add more deps for fused nodes
+        self.deps = self.find_deps()
+        self.update_deps_for_fusions()
+
+        self.non_acc_submodule_name = "_run_on_gpu_"
+        self._node_submodule_map: Dict[str, str] = {}
+
+        self.num_trt_accelerated_subgraphs = None
+        self.allowed_single_node_partition_ops = allowed_single_node_partition_ops
+
+    def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
+        """
+        This pass finds ACC submodules with less than specified size and merges
+        them with adjacent GPU submodules.
+        """
+        result: List[Subgraph] = []
+        for subgraph in subgraphs:
+            if subgraph.is_acc:
+                if len(subgraph.nodes) >= self.settings.min_acc_module_size or any(
+                    ConverterRegistry.qualified_name_or_str(node.target)
+                    in self.allowed_single_node_partition_ops
+                    for node in subgraph.nodes
+                ):
+                    result.append(subgraph)
+                else:
+                    logger.debug(
+                        "Eliminating acc subgraph because it's smaller than the threshold: "
+                        f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}"
+                    )
+                    if result:
+                        result[-1].nodes.extend(subgraph.nodes)
+                    else:
+                        subgraph.is_acc = False
+                        result.append(subgraph)
+            else:
+                if result and not result[-1].is_acc:
+                    result[-1].nodes.extend(subgraph.nodes)
+                else:
+                    result.append(subgraph)
+        return result
+
+    def partition_graph(self) -> torch.fx.GraphModule:
+        """Partitions the GraphModule into subgraphs based on operator support
+
+        Returns a GraphModule with submodules for each segment
+        """
+        # Delegate nodes based on operator coverage
+        subgraphs = self.put_nodes_into_subgraphs()
+
+        # Remove segments smaller than the block size (with exceptions)
+        subgraphs = self.remove_small_acc_subgraphs(subgraphs)
+
+        # Set the number of TRT engines to be generated
+        self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc])
+
+        # Tag the accelerated nodes and split the graph accordingly
+        self.tag(subgraphs)
+        return self.split()
+
+    def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
+        """Generates starter nodes for partitioning + segmentation"""
+        # Starter accelerated nodes are all callable accelerated ops
+        starter_acc_nodes = {
+            node for node in self.acc_nodes if node.op in CALLABLE_NODE_OPS
+        }
+
+        # Starter non-accelerated nodes are the rest of the callable nodes
+        starter_non_acc_nodes = {
+            node
+            for node in self.module.graph.nodes
+            if (node not in starter_acc_nodes and node.op in CALLABLE_NODE_OPS)
+        }
+
+        return starter_non_acc_nodes, starter_acc_nodes
+
+
+def partition(
+    gm: torch.fx.GraphModule,
+    verbose: bool = True,
+    min_block_size: int = MIN_BLOCK_SIZE,
+    torch_executed_ops: Sequence[Target] = set(),
+) -> torch.fx.GraphModule:
+    """Partition an FX GraphModule with aten ops into TRT engines
+    Partitioning is based on converter operator support
+
+    Args:
+        gm: FX GraphModule to partition
+        verbose: Bool representing whether to print operator support
+        min_block_size: Minimum number of operators per TRT-Engine Block
+        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
+    Returns:
+        torch.fx.GraphModule
+    """
+    # Ensure graph is clean prior to partitioning
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+
+    # Construct the operator support checker and partitioner
+    supported_ops = OpSupportTester(torch_executed_ops=torch_executed_ops)
+    partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size)
+
+    partitioned_graph = partitioner.partition_graph()
+
+    if verbose:
+        supported_ops.print_support_overview(partitioner.num_trt_accelerated_subgraphs)
+
+    return partitioned_graph
+
+
+def get_submod_inputs(
+    mod: torch.fx.GraphModule,
+    submod: torch.fx.GraphModule,
+    inputs: Sequence[torch.Tensor],
+) -> Sequence[torch.Tensor]:
+    """Helper function to get inputs to a Torch submodule
+
+    Args:
+        mod: Parent FX GraphModule
+        submod: Child FX GraphModule
+        inputs: Sample inputs to parent module
+    Returns:
+        Sequence of Tensors representing inputs to child module
+    """
+    acc_inputs = None
+
+    def get_input(self, inputs):
+        nonlocal acc_inputs
+        acc_inputs = inputs
+
+    handle = submod.register_forward_pre_hook(get_input)
+    mod(*inputs)
+    handle.remove()
+    return acc_inputs
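
`get_submod_inputs` above captures the child module's inputs with a standard forward pre-hook: the hook records whatever positional arguments reach the submodule during a single parent forward pass, then is removed. A standalone sketch of the same mechanism on plain `nn.Module`s (the toy parent/child modules are assumptions for illustration):

import torch

class Child(torch.nn.Module):
    def forward(self, x):
        return x * 2

class Parent(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.child = Child()

    def forward(self, x):
        return self.child(x + 1)

parent = Parent()
captured_inputs = None

def capture(module, inputs):
    # Forward pre-hooks are called as hook(module, positional_inputs) before forward runs
    global captured_inputs
    captured_inputs = inputs

handle = parent.child.register_forward_pre_hook(capture)
parent(torch.ones(2))    # one parent pass is enough to record the child's inputs
handle.remove()

assert torch.equal(captured_inputs[0], torch.ones(2) + 1)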

tests/py/dynamo/backend/test_backend_compiler.py

Lines changed: 22 additions & 4 deletions
@@ -1,6 +1,6 @@
 import torch
 import torch_tensorrt
-from torch_tensorrt.dynamo.lowering import partition
+from torch_tensorrt.dynamo.partitioning import partition
 from torch.testing._internal.common_utils import run_tests, TestCase
 from copy import deepcopy
 from utils import lower_graph_testing, DECIMALS_OF_AGREEMENT
@@ -20,7 +20,13 @@ def forward(self, x, y):
         partitioned_graph = partition(deepcopy(fx_graph), min_block_size=3)

         self.assertEquals(
-            len(list(partitioned_graph.named_children())),
+            len(
+                [
+                    1
+                    for submod in list(partitioned_graph.named_children())
+                    if "_run_on_acc" in submod[0]
+                ]
+            ),
             1,
             "All operators are supported, there should be one segment",
         )
@@ -94,7 +100,13 @@ def forward(self, x, y):
             "Without control flow breaks, there should only be a single graph",
         )
         self.assertEquals(
-            len(list(partitioned_graphs[0].named_children())),
+            len(
+                [
+                    1
+                    for submod in list(partitioned_graphs[0].named_children())
+                    if "_run_on_acc" in submod[0]
+                ]
+            ),
             2,
             "Certain operators are set to run in Torch, expected 2 segments",
         )
@@ -253,7 +265,13 @@ def forward(self, x, y):
             "Without control flow breaks, there should only be a single graph",
         )
         self.assertEquals(
-            len(list(partitioned_graphs[0].named_children())),
+            len(
+                [
+                    1
+                    for submod in list(partitioned_graphs[0].named_children())
+                    if "_run_on_acc" in submod[0]
+                ]
+            ),
             1,
             "Certain operators are set to run in Torch, expected 1 segment",
         )
