@@ -18,23 +18,27 @@ def __init__(
         model: nn.Module,
         in_spec: TreeSpec,
         out_spec: TreeSpec,
-        cuda_graph_batch_sizes: List[int],
+        max_batch_size: int,
+        cuda_graph_batch_sizes: Optional[List[int]] = None,
         num_batched_inputs: Optional[int] = 1,  # number of batched, dynamic inputs...
     ):
         super().__init__()
         self._in_spec = in_spec
         self._out_spec = out_spec
         self.model = model
-        self.max_batch_size = max(cuda_graph_batch_sizes)
-        ad_logger.info(f"Setting max batch size to {self.max_batch_size}")
+        self.max_batch_size = max_batch_size
         self.num_batched_inputs = num_batched_inputs if num_batched_inputs is not None else 1
         self.graphs: Dict[Tuple[int, ...], CUDAGraph] = {}
         self._input_buffers: List[torch.Tensor] = [
             torch.empty(0, 1) for _ in range(self.num_batched_inputs)
         ]
         self._out_buffer_flat: List[torch.Tensor] = None
         self._args_hash: Optional[Tuple[int, ...]] = None
-        self.cuda_graph_batch_sizes = sorted(cuda_graph_batch_sizes, reverse=True)
+        self.cuda_graph_batch_sizes = (
+            sorted(cuda_graph_batch_sizes, reverse=True)
+            if cuda_graph_batch_sizes is not None
+            else self._get_graph_batch_sizes(self.max_batch_size)
+        )
         self._cuda_graph_mem_pool = None
 
     def _get_hash(self, flat_args: List[Any]) -> Tuple[int, ...]:
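
The constructor hunk above inverts the old relationship: `max_batch_size` was previously derived as `max(cuda_graph_batch_sizes)`, whereas it is now an explicit required argument and the batch-size list is optional. A minimal construction sketch under that reading; `gm`, `in_spec`, and `out_spec` are placeholders for a traced module and its pytree specs, not values from this diff:

```python
# Sketch only: gm, in_spec, and out_spec are placeholders for a traced
# module and its pytree specs; they are not defined in this PR.
captured = CapturedGraph(
    gm,
    in_spec=in_spec,
    out_spec=out_spec,
    max_batch_size=512,           # now passed in explicitly
    cuda_graph_batch_sizes=None,  # optional; the heuristic below kicks in
)
# With the default multiplier of 128:
# captured.cuda_graph_batch_sizes == [512, 384, 256, 128, 1]
```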
@@ -73,6 +77,20 @@ def _capture_one_graph(self, *args, **kwargs) -> torch.cuda.CUDAGraph:
         self._cuda_graph_mem_pool = self._cuda_graph_mem_pool or graph.pool()
         return graph
 
+    @staticmethod
+    def _get_graph_batch_sizes(
+        max_bs: int, extra: Optional[List[int]] = None, multiplier: int = 128
+    ) -> List[int]:
+        """Heuristic to set batch sizes for graph capture."""
+        # do 1, max_bs, and extra as special batch sizes
+        batch_sizes = {1, max_bs, *(extra or [])}
+
+        # add all multiples of multiplier up to max_bs
+        batch_sizes.update(range(multiplier, max_bs + 1, multiplier))
+
+        # return as sorted list
+        return sorted(batch_sizes, reverse=True)
+
     def capture_graph(self, *args, **kwargs):
         """Capture and pre-fetch the graph for variable batch size."""
         # flatten args, kwargs
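
Since `_get_graph_batch_sizes` is a pure function of its arguments, its behavior is easy to check in isolation. A standalone copy of the heuristic added above (renamed and lifted out of the class, purely for illustration) with a few worked outputs:

```python
from typing import List, Optional

def get_graph_batch_sizes(
    max_bs: int, extra: Optional[List[int]] = None, multiplier: int = 128
) -> List[int]:
    """Standalone copy of the heuristic above, for illustration."""
    # 1, max_bs, and any caller-supplied extras are always captured
    batch_sizes = {1, max_bs, *(extra or [])}
    # plus every multiple of `multiplier` up to max_bs
    batch_sizes.update(range(multiplier, max_bs + 1, multiplier))
    return sorted(batch_sizes, reverse=True)

print(get_graph_batch_sizes(500))             # [500, 384, 256, 128, 1]
print(get_graph_batch_sizes(64))              # [64, 1]  (multiplier > max_bs)
print(get_graph_batch_sizes(256, extra=[8]))  # [256, 128, 8, 1]
```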
@@ -159,21 +177,15 @@ def forward(self, *args, **kwargs) -> Any:
 class TorchCudagraphCompiler(BackendCompiler):
     """Compiler that uses only CUDA graphs."""
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.cuda_graph_batch_sizes = self.compiler_kwargs.get("cuda_graph_batch_sizes")
-        if not self.cuda_graph_batch_sizes:
-            self.cuda_graph_batch_sizes = self._get_graph_batch_sizes(self.max_batch_size)
-        ad_logger.info(f"Setting cuda_graph_batch_sizes to {self.cuda_graph_batch_sizes}")
-
     def _init_captured_graph(
         self, gm: nn.Module, in_spec: TreeSpec, out_spec: TreeSpec
     ) -> CapturedGraph:
         return CapturedGraph(
             gm,
             in_spec=in_spec,
             out_spec=out_spec,
-            cuda_graph_batch_sizes=self.cuda_graph_batch_sizes,
+            max_batch_size=self.max_batch_size,
+            cuda_graph_batch_sizes=self.compiler_kwargs.get("cuda_graph_batch_sizes"),
             num_batched_inputs=self.compiler_kwargs.get("num_batched_inputs"),
         )
 
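
One detail worth noting in the compiler hunk: `dict.get` returns `None` when `"cuda_graph_batch_sizes"` is absent from `compiler_kwargs`, which is exactly the sentinel the new `CapturedGraph` constructor interprets as "use the heuristic". A small illustration with hypothetical kwargs:

```python
# Hypothetical compiler_kwargs; only the .get() behavior is the point.
compiler_kwargs = {"num_batched_inputs": 1}
assert compiler_kwargs.get("cuda_graph_batch_sizes") is None
# -> CapturedGraph falls back to _get_graph_batch_sizes(max_batch_size)
```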
@@ -186,17 +198,3 @@ def compile(self) -> CapturedGraph:
         captured_model.capture_graph(*self.args, **self.kwargs)
 
         return captured_model
-
-    @staticmethod
-    def _get_graph_batch_sizes(
-        max_bs: int, extra: Optional[List[int]] = None, multiplier: int = 128
-    ) -> List[int]:
-        """Heuristic to set batch sizes for graph capture."""
-        # do 1, max_bs, and extra as special batch sizes
-        batch_sizes = {1, max_bs, *(extra or [])}
-
-        # add all multiples of multiplier up to max_bs
-        batch_sizes.update(range(multiplier, max_bs + 1, multiplier))
-
-        # return as sorted list
-        return sorted(batch_sizes, reverse=True)