
Commit cb34335

Registry interface for custom quantization functional backend (#683)
## What does this PR do?

**Type of change:** New feature

**Overview:** Add a registry interface for custom quantization functional backends.

## Usage

See `tests/unit/torch/quantization/test_custom_backend.py` for a usage example (a rough sketch is also included below).

## Testing

Covered by the new unit test `tests/unit/torch/quantization/test_custom_backend.py`.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes
- **Did you add or update any necessary documentation?**: Yes
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes

## Additional Information

Signed-off-by: realAsma <[email protected]>
1 parent 3fd8b80 commit cb34335
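As a rough end-to-end sketch of the workflow this PR enables (register an entrypoint, then select it by name per quantizer): the backend name, the rounding logic, and the exact config contents below are illustrative assumptions, not taken from the referenced test.

```python
import torch

from modelopt.torch.quantization.nn.modules.tensor_quantizer import (
    TensorQuantizer,
    register_quant_backend,
)


def my_int8_backend(inputs: torch.Tensor, quantizer: TensorQuantizer) -> torch.Tensor:
    """Hypothetical emulated int8 backend: scale to [-128, 127], round, scale back."""
    amax = inputs.abs().amax().clamp(min=1e-8)
    scale = 127.0 / amax
    return (inputs * scale).round_().clamp_(-128, 127) / scale


# Register the entrypoint under a custom name.
register_quant_backend("my_int8", my_int8_backend)

# Select it by name per quantizer; layout follows the usual modelopt quant_cfg pattern.
MY_INT8_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": "my_int8_fmt", "backend": "my_int8"},
        "*input_quantizer": {"enable": False},
    },
    "algorithm": "max",
}

# import modelopt.torch.quantization as mtq
# model = mtq.quantize(model, MY_INT8_CFG, forward_loop=calib_loop)
```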

File tree

14 files changed: +328 -77 lines

.github/workflows/example_tests.yml

Lines changed: 1 addition & 1 deletion

```diff
@@ -93,7 +93,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        example: [llm_ptq, vlm_ptq]
+        example: [llm_ptq] # vlm_ptq temporarily disabled due to pipeline error
     uses: ./.github/workflows/_example_tests_runner.yml
     secrets: inherit
     with:
```

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions

```diff
@@ -12,6 +12,7 @@ NVIDIA Model Optimizer Changelog (Linux)
 - Add support for KV Cache Quantization for vLLM FakeQuant PTQ script. See `examples/vllm_serve/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/vllm_serve#Calibrate-and-serve-fake-quant-model-in-vLLM>`__ for more details.
 - Add support for subgraphs in ONNX autocast.
 - Add support for parallel draft heads in Eagle speculative decoding.
+- Add support for enabling a custom emulated quantization backend. See :meth:`register_quant_backend <modelopt.torch.quantization.nn.modules.tensor_quantizer.register_quant_backend>` for more details and ``tests/unit/torch/quantization/test_custom_backend.py`` for an example.
 
 **Deprecations**
 
```

modelopt/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -15,6 +15,7 @@
 
 """Nvidia Model Optimizer (modelopt)."""
 
+import warnings as _warnings
 from importlib.metadata import version as _version
 
 __version__ = _version("nvidia-modelopt")
```

modelopt/torch/__init__.py

Lines changed: 6 additions & 0 deletions

```diff
@@ -39,3 +39,9 @@
     )
 except ImportError:
     pass
+
+# Initialize modelopt_internal if available
+with utils.import_plugin(
+    "modelopt_internal", success_msg="modelopt_internal successfully initialized", verbose=True
+):
+    import modelopt_internal
```

modelopt/torch/quantization/config.py

Lines changed: 32 additions & 3 deletions

```diff
@@ -667,7 +667,7 @@ class QuantizerAttributeConfig(ModeloptBaseConfig):
         description="""If True, enables the quantizer. If False, by-pass the quantizer and returns the input tensor.""",
     )
 
-    num_bits: int | tuple[int, int] = ModeloptField(
+    num_bits: int | tuple[int, int] | str = ModeloptField(
         default=8,
         title="An integer or a tuple of two integers specifying the number of quantization bits.",
         description="""`num_bits` can be:
@@ -677,7 +677,9 @@ class QuantizerAttributeConfig(ModeloptBaseConfig):
 
         #. Constant integer tuple (E,M) for floating point quantization emulating
            Nvidia's FPx quantization. E is the number of exponent bits and M is the number
-           of mantissa bits. Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6(E3M2, E2M3), FP4(E2M1).""",
+           of mantissa bits. Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6(E3M2, E2M3), FP4(E2M1).
+
+        #. String specifying the quantization format. This is currently used only for custom backends.""",
     )
 
     @model_validator(mode="before")
@@ -709,10 +711,16 @@ def _validate_recursive(value):
     @model_validator(mode="after")
     def validate_num_bits(self):
         """Validate `num_bits`."""
+        if self.backend is not None:
+            # For custom backends, we don't need to validate num_bits
+            return self
+
         num_bits = self.num_bits
 
         if isinstance(num_bits, int) and num_bits < 1:
-            raise ValueError("num_bits must be a positive integer or a tuple of positive integers.")
+            raise ValueError(
+                f"num_bits must be a positive integer or a tuple of positive integers. {num_bits}"
+            )
 
         if not isinstance(num_bits, tuple):
             return self
@@ -954,6 +962,27 @@ def validate_calibrator(cls, v, info: ValidationInfo):
         """,
     )
 
+    backend: str | None = ModeloptField(
+        default=None,
+        title="Name of custom quantization functional backend.",
+        description="""
+        Selects a non-default quantization functional backend by name. See
+        :meth:`register_quant_backend <modelopt.torch.quantization.nn.modules.tensor_quantizer.register_quant_backend>`
+        for more details on how to register a custom quantization backend.
+        """,
+    )
+    backend_extra_args: dict | None = ModeloptField(
+        default=None,
+        title="Extra arguments for the selected backend.",
+        description="""The extra arguments are saved onto the quantizer instance - they are not
+        passed directly to the backend entrypoint. Can be any serializable dictionary.
+
+        Please use `backend_extra_args` to pass arguments that are not already supported by
+        `QuantizerAttributeConfig`. This ensures maximum compatibility with other modelopt
+        features such as modelopt's calibration algorithms.
+        """,
+    )
+
 
 class QuantizeAlgorithmConfig(ModeloptBaseConfig):
     """Calibration algorithm config base."""
```

modelopt/torch/quantization/model_quant.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -229,7 +229,7 @@ def forward_loop(model) -> None:
     Returns: A pytorch model which has been quantized and calibrated.
     """
     model = apply_mode(model, mode=[("quantize", config)], registry=QuantizeModeRegistry)
-    return calibrate(model, config["algorithm"], forward_loop=forward_loop)
+    return calibrate(model, config.get("algorithm"), forward_loop=forward_loop)
 
 
 # TODO: create a config interface for auto_quantize and expose setting
```
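The switch from `config["algorithm"]` to `config.get("algorithm")` means a config dict that omits the `algorithm` key no longer raises a `KeyError`; `calibrate` simply receives `None`. A small illustrative sketch (the config contents are assumptions):

```python
# Hypothetical config with no "algorithm" entry: quantize() no longer fails on the lookup.
cfg_without_algorithm = {
    "quant_cfg": {"*weight_quantizer": {"num_bits": 8, "backend": "my_int8"}},
}
# model = mtq.quantize(model, cfg_without_algorithm, forward_loop=None)
```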

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 97 additions & 13 deletions

```diff
@@ -18,7 +18,8 @@
 import contextlib
 import math
 import warnings
-from typing import TYPE_CHECKING, Any
+from collections.abc import Callable
+from typing import Any, Protocol
 
 import torch
 import torch.distributed as dist
@@ -36,7 +37,7 @@
 import torch.nn.functional as F
 from torch import nn
 
-from modelopt.torch.utils import standardize_constructor_args
+from modelopt.torch.utils import same_device_as, standardize_constructor_args
 from modelopt.torch.utils.distributed import DistributedProcessGroup
 
 from ... import calib
@@ -56,10 +57,63 @@
 from ...utils import is_torch_export_mode
 from ..functional import normalized_hadamard_transform
 
-if TYPE_CHECKING:
-    from collections.abc import Callable
+__all__ = [
+    "SequentialQuantizer",
+    "TensorQuantizer",
+    "TensorQuantizerCache",
+    "is_registered_quant_backend",
+    "register_quant_backend",
+    "unregister_quant_backend",
+]
 
-__all__ = ["SequentialQuantizer", "TensorQuantizer"]
+
+QuantBackendEntrypoint = Callable[[torch.Tensor, "TensorQuantizer"], torch.Tensor]
+
+_QUANT_FUNCTIONAL_BACKENDS: dict[str, QuantBackendEntrypoint] = {}
+
+
+def register_quant_backend(name: str, entrypoint: QuantBackendEntrypoint) -> None:
+    """Register a custom quantization backend.
+
+    Args:
+        name: The name of the backend.
+        entrypoint: The entrypoint of the backend. The entrypoint should be a callable that takes in
+            the inputs and the tensor quantizer as arguments and returns the quantized tensor.
+            See :class:`modelopt.torch.quantization.config.QuantizerAttributeConfig`
+            for details on choosing from the registered backends via the ``backend`` and
+            ``backend_extra_args`` fields.
+    """
+    if not isinstance(name, str) or not name:
+        raise ValueError("Backend name must be a non-empty string.")
+    if not callable(entrypoint):
+        raise TypeError("Entrypoint must be callable.")
+    if name in _QUANT_FUNCTIONAL_BACKENDS:
+        warnings.warn(f"Overwriting existing backend: {name}")
+    _QUANT_FUNCTIONAL_BACKENDS[name] = entrypoint
+
+
+def unregister_quant_backend(name: str) -> None:
+    """Unregister a custom quantization backend.
+
+    Args:
+        name: The name of the backend to unregister.
+    """
+    if not isinstance(name, str) or not name:
+        raise ValueError("Backend name must be a non-empty string.")
+    _QUANT_FUNCTIONAL_BACKENDS.pop(name, None)
+
+
+def is_registered_quant_backend(name: str) -> bool:
+    """Check if a custom quantization backend is registered.
+
+    Args:
+        name: The name of the backend to check.
+    """
+    return name in _QUANT_FUNCTIONAL_BACKENDS
+
+
+class TensorQuantizerCache(Protocol):
+    """A protocol for a cache interface for TensorQuantizer."""
 
 
 class TensorQuantizer(nn.Module):
@@ -104,6 +158,8 @@ class TensorQuantizer(nn.Module):
         "ds_grads_remaining",
         "ds_id",
         "pre_bwd_fn",
+        # quantizer cache for custom backends, like luts
+        "_quantizer_cache",
     }
 
     def __init__(
@@ -132,6 +188,9 @@ def __init__(
         # Lazy initialize the bias calibrator for KV cache quantization
        self._bias_calibrator = None
 
+        # Optional quantizer cache for caching quantizer related encoding or tensors.
+        self._quantizer_cache = None
+
     def set_from_attribute_config(self, attribute_cfg: QuantizerAttributeConfig | dict):
         """Set quantizer attributes from attribute_dict.
 
@@ -153,6 +212,8 @@ def _calibrator_setter(val):
             "enable": ("_disabled", lambda val: val is False),
             "type": ("_dynamic", lambda val: val == "dynamic"),
             "calibrator": ("_calibrator", _calibrator_setter),
+            "backend": ("backend", lambda val: val),
+            "backend_extra_args": ("backend_extra_args", lambda val: val or {}),
         }
 
         for attribute, val in attribute_cfg.items():
@@ -632,6 +693,12 @@ def _real_quantize(self, inputs):
 
     def _fake_quantize(self, inputs):
         """Fake quantization."""
+        if self.backend is not None:
+            if self.backend not in _QUANT_FUNCTIONAL_BACKENDS:
+                raise KeyError(f"Quant backend '{self.backend}' is not registered.")
+            entrypoint = _QUANT_FUNCTIONAL_BACKENDS[self.backend]
+            return entrypoint(inputs, self)
+
         amax = None
         if not self.is_mx_format:
             amax = self._get_amax(inputs)
@@ -934,7 +1001,8 @@ def forward(self, inputs):
         if hasattr(inputs, "is_contiguous") and not inputs.is_contiguous():
             inputs.data = inputs.data.contiguous()
         if self.fake_quant:
-            outputs = self._fake_quantize(inputs)
+            with same_device_as(inputs):
+                outputs = self._fake_quantize(inputs)
         elif not self._dequantize:
             outputs = self._real_quantize(inputs)
         else:
@@ -964,16 +1032,23 @@ def _short_amax(self, fmt=".4f"):
             return "None"
         if self._amax.is_meta:
             return "meta"
-        if self._amax.numel() == 1:
-            return f"{self._amax.item():{fmt}}"
-        return (
-            f"[{self._amax.min().item():{fmt}},"
-            f" {self._amax.max().item():{fmt}}]({self._amax.numel()})"
-        )
+        return self._short_tensor(self._amax, fmt)
+
+    def _short_tensor(self, tensor: torch.Tensor, fmt=".4f"):
+        """Short description of tensor."""
+        if tensor.numel() == 1:
+            return f"{tensor.item():{fmt}}"
+        return f"[{tensor.min().item():{fmt}}, {tensor.max().item():{fmt}}]({tensor.numel()})"
 
     def extra_repr(self):
         """Set the extra information about this module."""
         if self._disabled:
+            s = "disabled"
+            s += (
+                f" pre_quant_scale={self._short_tensor(self.pre_quant_scale)}"
+                if self.pre_quant_scale is not None
+                else ""
+            )
             return "disabled"
         s = f"{'unsigned ' if self._unsigned else ''}{self._num_bits} bit"
         s += " narrow" if (self._narrow_range) else ""
@@ -983,7 +1058,11 @@ def extra_repr(self):
         else:
             s += f" axis={self._axis}" if self._axis is not None else " per-tensor"
         s += f" amax={self._short_amax()}"
-        s += " pre_quant_scale" if self.pre_quant_scale is not None else ""
+        s += (
+            f" pre_quant_scale={self._short_tensor(self.pre_quant_scale)}"
+            if self.pre_quant_scale is not None
+            else ""
+        )
         s += " rotated" if self._rotate else ""
         s += (
             f" calibrator={self._calibrator.__class__.__name__}"
@@ -995,6 +1074,11 @@ def extra_repr(self):
 
         s += " quant" if (self._if_quant) else ""
         s += " calib" if (self._if_calib) else ""
+        s += (
+            f" backend={self.backend}, extra_args={self.backend_extra_args}"
+            if self.backend is not None
+            else ""
+        )
         return s
 
     def _get_properties_for_modelopt_state(self):
```
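A sketch of how a backend entrypoint can use the state that this file wires up for it: `backend_extra_args` populated via `set_from_attribute_config`, and the new `_quantizer_cache` slot. The lookup-table logic and names below are hypothetical, not part of the PR:

```python
import torch

from modelopt.torch.quantization.nn.modules.tensor_quantizer import (
    TensorQuantizer,
    is_registered_quant_backend,
    register_quant_backend,
    unregister_quant_backend,
)


def lut_backend(inputs: torch.Tensor, quantizer: TensorQuantizer) -> torch.Tensor:
    """Hypothetical LUT-style emulation using backend_extra_args and _quantizer_cache."""
    levels = (getattr(quantizer, "backend_extra_args", None) or {}).get("codebook_size", 16)
    if quantizer._quantizer_cache is None:
        # Memoize the codebook on the quantizer so it is only built once.
        quantizer._quantizer_cache = torch.linspace(-1.0, 1.0, levels)
    codebook = quantizer._quantizer_cache.to(inputs.device)
    scale = inputs.abs().amax().clamp(min=1e-8)
    # Snap each normalized value to its nearest codebook entry.
    idx = torch.argmin((inputs.unsqueeze(-1) / scale - codebook).abs(), dim=-1)
    return codebook[idx] * scale


register_quant_backend("lut_emulator", lut_backend)
assert is_registered_quant_backend("lut_emulator")
unregister_quant_backend("lut_emulator")  # unregistering an unknown name is a no-op
```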

modelopt/torch/quantization/tensor_quant.py

Lines changed: 28 additions & 37 deletions

```diff
@@ -79,14 +79,11 @@ def scaled_e4m3_impl(
     if cuda_ext_fp8 is None:
         return fp8_eager(inputs, amax)
 
-    with torch.cuda.device(
-        None if inputs.device.index == torch.cuda.current_device() else inputs.device.index
-    ):
-        if amax.numel() == 1:
-            outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax)
-        elif amax.squeeze().ndim == 1:
-            axis = amax.shape.index(amax.numel())
-            outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
+    if amax.numel() == 1:
+        outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax)
+    elif amax.squeeze().ndim == 1:
+        axis = amax.shape.index(amax.numel())
+        outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
     return outputs
 
 
@@ -100,17 +97,14 @@ def fake_quant_impl(
     """Implementation of fake quantizing input according to number of bits."""
     cuda_ext = get_cuda_ext()
 
-    with torch.cuda.device(
-        None if inputs.device.index == torch.cuda.current_device() else inputs.device.index
-    ):
-        if amax.numel() == 1:
-            outputs = cuda_ext.fake_tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)
-        else:
-            axis = amax.shape.index(amax.numel())
-            outputs = cuda_ext.fake_tensor_quant_with_axis(
-                inputs, amax.squeeze(), axis, num_bits, unsigned, narrow_range
-            )
-    return outputs
+    if amax.numel() == 1:
+        outputs = cuda_ext.fake_tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)
+    else:
+        axis = amax.shape.index(amax.numel())
+        outputs = cuda_ext.fake_tensor_quant_with_axis(
+            inputs, amax.squeeze(), axis, num_bits, unsigned, narrow_range
+        )
+    return outputs
 
 
 def _quantize_impl(
@@ -173,25 +167,22 @@ def _dynamic_block_quantize_impl(
         assert amax.is_cuda, "amax must be a CUDA tensor for dynamic block quantization."
         if amax.numel() != 1:
            amax = amax.amax()
-        with torch.cuda.device(
-            None if inputs.device.index == torch.cuda.current_device() else inputs.device.index
+        if (
+            num_bits == (2, 1)  # type: ignore[comparison-overlap]
+            and scale_bits == (4, 3)
+            and triton_kernel.IS_AVAILABLE
+            and not DISABLE_TRITON_KERNEL
+            and amax is not None
         ):
-            if (
-                num_bits == (2, 1)  # type: ignore[comparison-overlap]
-                and scale_bits == (4, 3)
-                and triton_kernel.IS_AVAILABLE
-                and not DISABLE_TRITON_KERNEL
-                and amax is not None
-            ):
-                return triton_kernel.fp4_fake_quant_block(inputs, amax)
-            cuda_ext_mx = get_cuda_ext_mx(raise_if_failed=True)
-            return cuda_ext_mx.fused_amax_convert(
-                inputs,
-                block_size,
-                getattr(cuda_ext_mx.Types, mx_format_map[num_bits]),
-                getattr(cuda_ext_mx.Types, mx_format_map[scale_bits]),
-                amax,
-            )
+            return triton_kernel.fp4_fake_quant_block(inputs, amax)
+        cuda_ext_mx = get_cuda_ext_mx(raise_if_failed=True)
+        return cuda_ext_mx.fused_amax_convert(
+            inputs,
+            block_size,
+            getattr(cuda_ext_mx.Types, mx_format_map[num_bits]),
+            getattr(cuda_ext_mx.Types, mx_format_map[scale_bits]),
+            amax,
+        )
     else:
         raise NotImplementedError(
             f"Unsupported num_bits: {num_bits}, scale_bits: {scale_bits} for dynamic block quantization."
```