
Commit f4463a5

remove unused ENUM
Signed-off-by: Frida Hou <[email protected]>
1 parent 40ef068 commit f4463a5

3 files changed: 1 addition, 23 deletions

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py

Lines changed: 0 additions & 9 deletions
@@ -8,14 +8,6 @@
 
 from .quant import QUANT_LINEAR_OPS, QUANT_OPS
 
-# ===== Enums =====
-FORMAT_FP8 = 0
-FORMAT_NVFP4 = 1
-
-# scale layouts
-PER_TENSOR = 0
-PER_CHANNEL_OUT = 1
-
 # FP4 tables (E2M1)
 e2m1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])
 e2m1_values = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6])
@@ -179,7 +171,6 @@ def torch_fake_quant_fp8_linear(
     For FP8:
     - input_scale[0] and weight_scale[0] are required (amax/448 style)
     - input_zp / weight_zp ignored
-    - supports PER_TENSOR and PER_CHANNEL_OUT for weights
     """
     if weight_quantized.dtype != torch.float8_e4m3fn:
         raise TypeError("FP8 path requires weight_quantized.dtype == float8_e4m3fn")
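The surviving docstring still states the FP8 contract: input_scale[0] and weight_scale[0] are required in the amax/448 style and the zero-point arguments are ignored. Purely as an illustration of that scaling convention (a hypothetical helper, not the actual torch_fake_quant_fp8_linear implementation), a per-tensor FP8 round trip could look like this:

import torch

FP8_E4M3_MAX = 448.0  # largest finite magnitude of torch.float8_e4m3fn

def fp8_per_tensor_roundtrip(x: torch.Tensor):
    """Hypothetical sketch of the amax/448 scale convention named in the docstring."""
    scale = x.abs().amax().float() / FP8_E4M3_MAX   # the "amax/448 style" scale
    x_fp8 = (x / scale).to(torch.float8_e4m3fn)     # quantize to E4M3
    return x_fp8, x_fp8.to(x.dtype) * scale         # dequantized fake-quant result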

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py

Lines changed: 0 additions & 9 deletions
@@ -27,10 +27,6 @@
 except ImportError:
     float4_sf_dtype = None
 
-# TODO: put the ENUMs in the same place and import it
-FORMAT_FP8 = 0
-FORMAT_NVFP4 = 1
-
 
 def modelopt_fp4_scale_to_cutlass_fp4_scale(modelopt_scale: torch.Tensor) -> torch.Tensor:
     """Converts the modelopt FP4 per-block weight scale to the cutlass format (padded and swizzled)."""
@@ -185,14 +181,11 @@ def default_scales(original_weight_shape: Tuple) -> Dict[str, torch.Tensor]:
     def build_custom_kwargs_for_linear(
         scale_getattrs: Dict[str, Node],
     ) -> Dict[str, object]:
-        # FP8 custom op contract:
-        # input_scale=[tensor], weight_scale=[tensor], input_zp=[], weight_zp=[], format_type=FORMAT_FP8
         return dict(
             input_scale=[scale_getattrs["input_scale"]],
             weight_scale=[scale_getattrs["weight_scale"]],
             input_zp=[],
             weight_zp=[],
-            # format_type=FORMAT_FP8,
         )
 
     @staticmethod
@@ -280,15 +273,13 @@ def build_custom_kwargs_for_linear(
             weight_scale=[weight_scale_cutlass_uint8, alpha_fused],
             input_zp=[],
             weight_zp=[],
-            format_type=FORMAT_NVFP4
         )
         """
         return dict(
             input_scale=[scale_getattrs["input_scale"]],
             weight_scale=[scale_getattrs["weight_scale"], scale_getattrs["alpha"]],
             input_zp=[],
             weight_zp=[],
-            # format_type=FORMAT_NVFP4,
         )
 
     @staticmethod
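With format_type dropped, both kwargs builders now hand the custom ops only scale and zero-point lists. Purely as an illustration of the resulting shapes (plain strings stand in for the torch.fx get_attr Nodes the real code wires in), the two contracts look roughly like this:

# Hypothetical illustration of the kwargs produced by the two
# build_custom_kwargs_for_linear variants after this change.
fp8_kwargs = dict(
    input_scale=["input_scale_node"],     # [tensor]: amax/448-style per-tensor scale
    weight_scale=["weight_scale_node"],   # [tensor]
    input_zp=[],                          # zero points unused
    weight_zp=[],
)

nvfp4_kwargs = dict(
    input_scale=["input_scale_node"],
    weight_scale=["weight_scale_node", "alpha_node"],  # cutlass uint8 block scales + fused alpha
    input_zp=[],
    weight_zp=[],
)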

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py

Lines changed: 1 addition & 5 deletions
@@ -8,10 +8,6 @@
 
 torch.manual_seed(0)
 
-scaling_vector_size = 16
-FORMAT_FP8 = 0
-FORMAT_NVFP4 = 1
-
 SCALING_VECTOR_SIZE = 16  # NVFP4 block size along K
 
 
@@ -51,7 +47,7 @@ def test_fp4_linear():
     weight_scale_2 = fp4_global_scale(weight)
 
     weight_fp4, weight_scale = torch.ops.trtllm.fp4_quantize(
-        weight, weight_scale_2, scaling_vector_size, False
+        weight, weight_scale_2, SCALING_VECTOR_SIZE, False
     )
 
     output_fp4_gemm = torch.ops.auto_deploy.torch_quant_fp4_linear(
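The test keeps only the upper-case SCALING_VECTOR_SIZE constant (16, the NVFP4 block size along K) and drops its lowercase duplicate. As a loose, illustration-only reference for what one scale per 16-wide block along K means (torch.ops.trtllm.fp4_quantize itself returns packed FP4 data and scales in its own layout), a plain-PyTorch sketch:

import torch

SCALING_VECTOR_SIZE = 16  # NVFP4 block size along K

def per_block_amax(weight: torch.Tensor) -> torch.Tensor:
    """Illustrative only: one amax per 16-wide block along K (the last dim)."""
    n, k = weight.shape
    assert k % SCALING_VECTOR_SIZE == 0
    blocks = weight.reshape(n, k // SCALING_VECTOR_SIZE, SCALING_VECTOR_SIZE)
    return blocks.abs().amax(dim=-1)  # shape [N, K / 16]

w = torch.randn(128, 64)
print(per_block_amax(w).shape)  # torch.Size([128, 4])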

0 commit comments
