
Commit 6a74e34

fix float8 rowwise inference perf with torch.compile (#2672)

In #2379, logic was added which prevented torchinductor from fusing the activation quantization for float8 inference. This PR reverts most of #2379, and adds a test to ensure we see the correct number of GPU kernels for float8 tensorwise and rowwise quantization. We'll have to re-do #2379 without breaking this test.

Test Plan:

```bash
TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 pytest test/dtypes/test_affine_quantized_float.py -s -k expected_kernels_on_gpu
```
1 parent 5f3ab63 commit 6a74e34
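
The kernel-count check described in the test plan boils down to profiling one compiled forward pass and counting events with nonzero CUDA time. Below is a minimal standalone sketch of that technique; the model and shapes are illustrative, not the exact test code, which lives in test/dtypes/test_affine_quantized_float.py.

```python
import torch
from torch.profiler import ProfilerActivity, profile

from torchao.quantization.quant_api import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

# illustrative model and shapes (the real test parametrizes granularity and compile mode)
m = torch.nn.Sequential(torch.nn.Linear(256, 512, device="cuda", dtype=torch.bfloat16))
quantize_(m, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
m = torch.compile(m)
x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)

with torch.no_grad():
    _ = m(x)  # warm up / trigger compilation
    # profile a single compiled forward pass
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        _ = m(x)

# with rowwise scaling, the expectation after this fix is two kernels:
# one fused activation-quantization kernel plus the gemm
cuda_kernels = [e for e in prof.key_averages() if e.cuda_time > 0]
print("num CUDA kernels:", len(cuda_kernels))
```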

File tree

3 files changed, +109 -49 lines

Lines changed: 40 additions & 0 deletions

```diff
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+import fire
+import torch
+import torch.nn as nn
+from torch._inductor.utils import do_bench_using_profiling
+
+from torchao.quantization.quant_api import (
+    Float8DynamicActivationFloat8WeightConfig,
+    PerRow,
+    quantize_,
+)
+
+
+def benchmark_fn_in_usec(f, *args, **kwargs):
+    no_args = lambda: f(*args, **kwargs)
+    time = do_bench_using_profiling(no_args)
+    return time * 1e3
+
+
+def run(torch_compile_mode: str = "default"):
+    M, K, N = 1024, 2048, 4096
+    x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+    m = nn.Sequential(nn.Linear(K, N, device="cuda", dtype=torch.bfloat16))
+    quantize_(m, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
+    m = torch.compile(m, mode=torch_compile_mode)
+    # warm up
+    with torch.no_grad():
+        _ = m(x)
+    # measure
+    with torch.no_grad():
+        time_us = benchmark_fn_in_usec(m, x)
+    print("time_us", time_us)
+
+
+if __name__ == "__main__":
+    fire.Fire(run)
```
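
A hypothetical way to run this benchmark from a shell (the file name below is an assumption, since this view does not show the new file's path; Fire exposes `torch_compile_mode` as a CLI flag):

```bash
# file name is assumed for illustration; substitute wherever this new script lives
python bench_float8_inference.py --torch_compile_mode default
python bench_float8_inference.py --torch_compile_mode reduce-overhead
```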

test/dtypes/test_affine_quantized_float.py

Lines changed: 69 additions & 39 deletions

```diff
@@ -23,6 +23,7 @@
 import pytest
 import torch
 from torch._inductor.test_case import TestCase as InductorTestCase
+from torch.profiler import ProfilerActivity, profile
 from torch.testing._internal import common_utils
 
 from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale
@@ -718,45 +719,74 @@ def test_preprocess_scale_3d_reshape(self):
         expected_shape = (8, 1)  # Flattened (2*2*2, 1)
         self.assertEqual(result.shape, expected_shape)
 
-    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
-    @common_utils.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
-    def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype):
-        quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8
-        dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8
-        input = torch.randn(10, 10)
-        with torch.no_grad():
-            torch._dynamo.reset()
-            expected_scale = torch.tensor(2.0)
-            expected_quantized = quantize_affine_float8(
-                input,
-                expected_scale,
-                float8_dtype=float8_dtype,
-            )
-            expected_dequantized = dequantize_affine_float8(
-                expected_quantized,
-                expected_scale,
-                output_dtype=hp_dtype,
-            )
-            test_q, (code_q,) = torch._inductor.utils.run_and_get_code(
-                torch.compile(quantize_affine_float8),
-                input,
-                expected_scale,
-                float8_dtype=float8_dtype,
-            )
-            torch.testing.FileCheck().check(
-                "torch.ops.torchao.quantize_affine_float8.default"
-            ).run(code_q)
-            test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code(
-                torch.compile(dequantize_affine_float8),
-                test_q,
-                expected_scale,
-                hp_dtype,
-            )
-            torch.testing.FileCheck().check(
-                "torch.ops.torchao.dequantize_affine_float8.default"
-            ).run(code_dq)
-            torch.testing.assert_close(expected_quantized, test_q)
-            torch.testing.assert_close(expected_dequantized, test_dq)
+    @torch.no_grad()
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_90(), "Requires GPU with compute capability >= 9.0"
+    )
+    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @common_utils.parametrize(
+        "torch_compile_mode",
+        [
+            "default",
+            "reduce-overhead",
+        ],
+    )
+    def test_expected_kernels_on_gpu(self, granularity, torch_compile_mode):
+        """
+        Verify that float8 quantization + torch.compile results in the
+        expected number of kernels in the GPU trace.
+        """
+
+        M, K, N = 128, 256, 512
+        m = torch.nn.Sequential(
+            torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
+        )
+        quantize_(m, Float8DynamicActivationFloat8WeightConfig(granularity=granularity))
+        m = torch.compile(m, mode=torch_compile_mode)
+        x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+
+        # warm up
+        _ = m(x)
+        # capture trace
+        with profile(activities=[ProfilerActivity.CUDA]) as prof:
+            _ = m(x)
+
+        cuda_kernel_events = [x for x in prof.key_averages() if x.cuda_time > 0]
+
+        if granularity == PerTensor():
+            # kernel 1: x_max_tmp = max(x, ...)
+            # kernel 2: x_max = max(x_max_tmp)
+            # kernel 3: x_float8 = to_float8(x, x_max)
+            # kernel 4: gemm
+            if torch_compile_mode == "default":
+                assert len(cuda_kernel_events) == 4, (
+                    f"too many cuda kernels: {cuda_kernel_events}"
+                )
+            elif torch_compile_mode == "reduce-overhead":
+                # two extra kernels with reduce-overhead:
+                # void at::native::(anonymous namespace)::multi_tensor...
+                # void at::native::vectorized_elementwise_kernel<2, at...
+                # TODO(future): debug and remove these
+                assert len(cuda_kernel_events) == 6, (
+                    f"too many cuda kernels: {cuda_kernel_events}"
+                )
+        else:
+            assert granularity == PerRow()
+            # kernel 1: x_float8 = to_float8(x)
+            # kernel 2: gemm
+            if torch_compile_mode == "default":
+                assert len(cuda_kernel_events) == 2, (
+                    f"too many cuda kernels: {cuda_kernel_events}"
+                )
+            elif torch_compile_mode == "reduce-overhead":
+                # two extra kernels with reduce-overhead:
+                # void at::native::(anonymous namespace)::multi_tensor...
+                # void at::native::vectorized_elementwise_kernel<2, at...
+                # TODO(future): debug and remove these
+                assert len(cuda_kernel_events) == 4, (
+                    f"too many cuda kernels: {cuda_kernel_events}"
+                )
 
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
```
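
If the kernel-count assertions above start failing, printing the profiled kernels is a quick way to see what changed. A small hypothetical debugging aid, reusing `prof` and `cuda_kernel_events` from the test body (not part of the test itself):

```python
# list the CUDA kernel events the profiler saw, largest first;
# cuda_time is reported in microseconds
for evt in sorted(cuda_kernel_events, key=lambda e: e.cuda_time, reverse=True):
    print(f"{evt.cuda_time:10.1f} us  {evt.key}")
```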

torchao/quantization/quant_primitives.py

Lines changed: 0 additions & 10 deletions

```diff
@@ -2279,7 +2279,6 @@ def _expand_scale_to_tensor_shape(
     return expanded_scale
 
 
-@_register_custom_op(quant_lib, False)
 def _quantize_affine_float8(
     tensor: torch.Tensor,
     scale: torch.Tensor,
@@ -2300,15 +2299,6 @@ def _quantize_affine_float8(
     return fp8_tensor
 
 
-@_register_meta_op(quant_lib, "quantize_affine_float8")
-def _quantize_affine_float8_meta(
-    tensor: torch.Tensor,
-    scale: torch.Tensor,
-    float8_dtype: torch.dtype = torch.float8_e4m3fn,
-) -> torch.Tensor:
-    return torch.empty_like(tensor, dtype=float8_dtype)
-
-
 @_register_custom_op(quant_lib, False)
 def _dequantize_affine_float8(
     tensor: torch.Tensor,
```