def __init__(
    self,
    model: nn.Module,
    in_spec: TreeSpec,
    out_spec: TreeSpec,
    cuda_graph_batch_sizes: List[int],
    num_batched_inputs: Optional[int] = 1,  # number of batched, dynamic inputs
):
    """Initialize the CUDA-graph wrapper around *model*.

    Args:
        model: module whose forward pass will be captured into CUDA graphs.
        in_spec: pytree spec used to flatten call arguments.
        out_spec: pytree spec used to unflatten the flat output buffers.
        cuda_graph_batch_sizes: non-empty list of batch sizes to capture
            graphs for; its maximum also defines ``max_batch_size``.
        num_batched_inputs: number of batched, dynamic inputs; ``None`` is
            normalized to 1.

    Raises:
        ValueError: if ``cuda_graph_batch_sizes`` is empty.
    """
    super().__init__()
    # Fail early with a clear message instead of letting max() raise
    # "max() arg is an empty sequence" below.
    if not cuda_graph_batch_sizes:
        raise ValueError("cuda_graph_batch_sizes must be a non-empty list")
    self._in_spec = in_spec
    self._out_spec = out_spec
    self.model = model
    # max batch size is derived from the requested capture sizes
    self.max_batch_size = max(cuda_graph_batch_sizes)
    ad_logger.info(f"Setting max batch size to {self.max_batch_size}")
    self.num_batched_inputs = num_batched_inputs if num_batched_inputs is not None else 1
    # one captured graph per tuple key of batch sizes
    self.graphs: Dict[Tuple[int, ...], CUDAGraph] = {}
    # placeholder static input buffers, one per batched input
    self._input_buffers: List[torch.Tensor] = [
        torch.empty(0, 1) for _ in range(self.num_batched_inputs)
    ]
    # fixed: annotation was List[torch.Tensor] although the value starts as None
    self._out_buffer_flat: Optional[List[torch.Tensor]] = None
    self._args_hash: Optional[Tuple[int, ...]] = None
    # sorted descending — presumably so the largest batch is captured first; confirm
    # against capture_graph before relying on this ordering elsewhere
    self.cuda_graph_batch_sizes = sorted(cuda_graph_batch_sizes, reverse=True)
    self._cuda_graph_mem_pool = None
def _get_hash (self , flat_args : List [Any ]) -> Tuple [int , ...]:
@@ -77,20 +73,6 @@ def _capture_one_graph(self, *args, **kwargs) -> torch.cuda.CUDAGraph:
77
73
self ._cuda_graph_mem_pool = self ._cuda_graph_mem_pool or graph .pool ()
78
74
return graph
79
75
80
- @staticmethod
81
- def _get_graph_batch_sizes (
82
- max_bs : int , extra : Optional [List [int ]] = None , multiplier : int = 128
83
- ) -> List [int ]:
84
- """Heuristic to set batch sizes for graph capture."""
85
- # do 1, max_bs, and extra as special batch sizes
86
- batch_sizes = {1 , max_bs , * (extra or [])}
87
-
88
- # add all multiples of multiplier up to max_bs
89
- batch_sizes .update (range (multiplier , max_bs + 1 , multiplier ))
90
-
91
- # return as sorted list
92
- return sorted (batch_sizes , reverse = True )
93
-
94
76
def capture_graph (self , * args , ** kwargs ):
95
77
"""Capture and pre-fetch the graph for variable batch size."""
96
78
# flatten args, kwargs
@@ -177,15 +159,21 @@ def forward(self, *args, **kwargs) -> Any:
177
159
class TorchCudagraphCompiler (BackendCompiler ):
178
160
"""Compiler that uses only CUDA graphs."""
179
161
162
def __init__(self, *args, **kwargs):
    """Initialize the compiler and resolve the CUDA graph batch sizes.

    Uses ``cuda_graph_batch_sizes`` from ``compiler_kwargs`` when provided;
    otherwise falls back to the heuristic based on ``max_batch_size``.
    """
    super().__init__(*args, **kwargs)
    requested = self.compiler_kwargs.get("cuda_graph_batch_sizes")
    if requested:
        self.cuda_graph_batch_sizes = requested
    else:
        self.cuda_graph_batch_sizes = self._get_graph_batch_sizes(self.max_batch_size)
        ad_logger.info(f"Setting cuda_graph_batch_sizes to {self.cuda_graph_batch_sizes}")
180
169
def _init_captured_graph(
    self, gm: nn.Module, in_spec: TreeSpec, out_spec: TreeSpec
) -> CapturedGraph:
    """Wrap *gm* in a :class:`CapturedGraph` configured from this compiler's settings."""
    num_batched_inputs = self.compiler_kwargs.get("num_batched_inputs")
    return CapturedGraph(
        gm,
        in_spec=in_spec,
        out_spec=out_spec,
        cuda_graph_batch_sizes=self.cuda_graph_batch_sizes,
        num_batched_inputs=num_batched_inputs,
    )
191
179
@@ -198,3 +186,17 @@ def compile(self) -> CapturedGraph:
198
186
captured_model .capture_graph (* self .args , ** self .kwargs )
199
187
200
188
return captured_model
189
+
190
+ @staticmethod
191
+ def _get_graph_batch_sizes (
192
+ max_bs : int , extra : Optional [List [int ]] = None , multiplier : int = 128
193
+ ) -> List [int ]:
194
+ """Heuristic to set batch sizes for graph capture."""
195
+ # do 1, max_bs, and extra as special batch sizes
196
+ batch_sizes = {1 , max_bs , * (extra or [])}
197
+
198
+ # add all multiples of multiplier up to max_bs
199
+ batch_sizes .update (range (multiplier , max_bs + 1 , multiplier ))
200
+
201
+ # return as sorted list
202
+ return sorted (batch_sizes , reverse = True )
0 commit comments