@@ -87,8 +87,12 @@ def _insert_quantized_bmm(
87
87
node : Node ,
88
88
quantization_impl : QuantizationImpl ,
89
89
is_quantized_graph : bool = False ,
90
- ):
91
- """Replaces the bmm node with a new quantized bmm node."""
90
+ ) -> bool :
91
+ """Replace a bmm op with its quantized equivalent and wire scales/state_dict hooks.
92
+
93
+ Returns:
94
+ True if quantization was applied; False if skipped (e.g., unknown shape).
95
+ """
92
96
weight_node = node .args [1 ]
93
97
94
98
# Weight is a parameter
@@ -140,7 +144,7 @@ def get_scale_name(scale_name):
140
144
141
145
else :
142
146
# If we can't determine the shape, skip quantization
143
- return
147
+ return False
144
148
145
149
# Common logic for both parameter and dynamic tensor cases
146
150
# Register scales in the target module
@@ -163,16 +167,12 @@ def get_scale_name(scale_name):
163
167
# Update node arguments and kwargs
164
168
scale_values = [scales [scale_name ] for scale_name in quantization_impl .scale_names ()]
165
169
node .args = (* node .args , * scale_values )
170
+ return True
166
171
167
172
168
- @TransformRegistry .register ("quantize_from_config" )
169
- class QuantizationFromConfig (BaseTransform ):
170
- """
171
- Quantize linear and BMM ops using a quantization config.
172
-
173
- Replaces eligible ops with quantized equivalents based on the quantization algorithm
174
- and exclude patterns defined in the config.
175
- """
173
+ @TransformRegistry .register ("quantize_linear_from_config" )
174
+ class LinearQuantizationFromConfig (BaseTransform ):
175
+ """Quantize eligible linear ops per quant config (algo + exclude patterns)."""
176
176
177
177
def _apply (
178
178
self ,
@@ -182,38 +182,69 @@ def _apply(
182
182
shared_config : SharedConfig ,
183
183
) -> Tuple [GraphModule , TransformInfo ]:
184
184
quant_config = factory .get_quant_config ()
185
- if not quant_config :
185
+ if not quant_config or not quant_config . get ( "quant_algo" ) :
186
186
return gm , TransformInfo (
187
187
skipped = True , num_matches = 0 , is_clean = True , has_valid_shapes = True
188
188
)
189
- quant_algo = quant_config .get ("quant_algo" , None )
190
- excluded_patterns = quant_config .get ("exclude_modules" , [])
191
- if not quant_algo :
189
+
190
+ quant_algo = quant_config ["quant_algo" ]
191
+ excluded = quant_config .get ("exclude_modules" , [])
192
+
193
+ num_matches = 0
194
+ impl = QuantizationImpl .create (quant_algo , is_bmm = False )
195
+
196
+ for n in gm .graph .nodes :
197
+ # Only consider linear ops; skip if excluded
198
+ if not is_linear_op (n , include_quantization = False ):
199
+ continue
200
+ if should_skip_quantization (n , excluded ):
201
+ continue
202
+
203
+ _insert_quantized_linear (gm , n , impl , is_quantized_graph = False )
204
+ num_matches += 1
205
+
206
+ info = TransformInfo (
207
+ skipped = False , num_matches = num_matches , is_clean = False , has_valid_shapes = True
208
+ )
209
+ return gm , info
210
+
211
+
212
+ @TransformRegistry .register ("quantize_bmm_from_config" )
213
+ class BMMQuantizationFromConfig (BaseTransform ):
214
+ """Quantize eligible BMM ops per quant config (algo + exclude patterns)."""
215
+
216
+ def _apply (
217
+ self ,
218
+ gm : GraphModule ,
219
+ cm : CachedSequenceInterface ,
220
+ factory : ModelFactory ,
221
+ shared_config : SharedConfig ,
222
+ ) -> Tuple [GraphModule , TransformInfo ]:
223
+ quant_config = factory .get_quant_config ()
224
+ if not quant_config or not quant_config .get ("quant_algo" ):
192
225
return gm , TransformInfo (
193
226
skipped = True , num_matches = 0 , is_clean = True , has_valid_shapes = True
194
227
)
195
228
229
+ quant_algo = quant_config ["quant_algo" ]
230
+ excluded = quant_config .get ("exclude_modules" , [])
231
+
196
232
num_matches = 0
233
+ impl = QuantizationImpl .create (quant_algo , is_bmm = True )
197
234
198
235
for n in gm .graph .nodes :
199
- if should_skip_quantization (n , excluded_patterns ):
236
+ if not is_bmm_op (n ):
237
+ continue
238
+ # Reuse common exclusion rule (supports Node or param-name string)
239
+ if should_skip_quantization (n , excluded ):
200
240
continue
201
241
202
- if is_linear_op (n , include_quantization = False ):
203
- impl = QuantizationImpl .create (quant_algo , is_bmm = False )
204
- _insert_quantized_linear (gm , n , impl , False )
205
- num_matches += 1
206
-
207
- # TODO: Make _insert_quantized_bmm return a bool and increment only on success
208
- elif is_bmm_op (n ):
209
- impl = QuantizationImpl .create (quant_algo , is_bmm = True )
210
- _insert_quantized_bmm (gm , n , impl , False )
242
+ if _insert_quantized_bmm (gm , n , impl , is_quantized_graph = False ):
211
243
num_matches += 1
212
244
213
245
info = TransformInfo (
214
246
skipped = False , num_matches = num_matches , is_clean = False , has_valid_shapes = True
215
247
)
216
-
217
248
return gm , info
218
249
219
250
0 commit comments