@@ -23,6 +23,7 @@
 import pytest
 import torch
 from torch._inductor.test_case import TestCase as InductorTestCase
+from torch.profiler import ProfilerActivity, profile
 from torch.testing._internal import common_utils

 from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale
@@ -718,45 +719,74 @@ def test_preprocess_scale_3d_reshape(self):
         expected_shape = (8, 1)  # Flattened (2*2*2, 1)
         self.assertEqual(result.shape, expected_shape)

-    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
-    @common_utils.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
-    def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype):
-        quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8
-        dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8
-        input = torch.randn(10, 10)
-        with torch.no_grad():
-            torch._dynamo.reset()
-            expected_scale = torch.tensor(2.0)
-            expected_quantized = quantize_affine_float8(
-                input,
-                expected_scale,
-                float8_dtype=float8_dtype,
-            )
-            expected_dequantized = dequantize_affine_float8(
-                expected_quantized,
-                expected_scale,
-                output_dtype=hp_dtype,
-            )
-            test_q, (code_q,) = torch._inductor.utils.run_and_get_code(
-                torch.compile(quantize_affine_float8),
-                input,
-                expected_scale,
-                float8_dtype=float8_dtype,
-            )
-            torch.testing.FileCheck().check(
-                "torch.ops.torchao.quantize_affine_float8.default"
-            ).run(code_q)
-            test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code(
-                torch.compile(dequantize_affine_float8),
-                test_q,
-                expected_scale,
-                hp_dtype,
-            )
-            torch.testing.FileCheck().check(
-                "torch.ops.torchao.dequantize_affine_float8.default"
-            ).run(code_dq)
-            torch.testing.assert_close(expected_quantized, test_q)
-            torch.testing.assert_close(expected_dequantized, test_dq)
+    @torch.no_grad()
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_90(), "Requires GPU with compute capability >= 9.0"
+    )
+    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @common_utils.parametrize(
+        "torch_compile_mode",
+        [
+            "default",
+            "reduce-overhead",
+        ],
+    )
+    def test_expected_kernels_on_gpu(self, granularity, torch_compile_mode):
+        """
+        Verify that float8 quantization + torch.compile results in the
+        expected number of kernels in the GPU trace.
+        """
+
+        M, K, N = 128, 256, 512
+        m = torch.nn.Sequential(
+            torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
+        )
+        quantize_(m, Float8DynamicActivationFloat8WeightConfig(granularity=granularity))
+        m = torch.compile(m, mode=torch_compile_mode)
+        x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+
+        # warm up
+        _ = m(x)
+        # capture trace
+        with profile(activities=[ProfilerActivity.CUDA]) as prof:
+            _ = m(x)
+
+        cuda_kernel_events = [x for x in prof.key_averages() if x.cuda_time > 0]
+
+        if granularity == PerTensor():
+            # kernel 1: x_max_tmp = max(x, ...)
+            # kernel 2: x_max = max(x_max_tmp)
+            # kernel 3: x_float8 = to_float8(x, x_max)
+            # kernel 4: gemm
+            if torch_compile_mode == "default":
+                assert len(cuda_kernel_events) == 4, (
+                    f"unexpected number of cuda kernels: {cuda_kernel_events}"
+                )
+            elif torch_compile_mode == "reduce-overhead":
+                # two extra kernels with reduce-overhead:
+                #   void at::native::(anonymous namespace)::multi_tensor...
+                #   void at::native::vectorized_elementwise_kernel<2, at...
+                # TODO(future): debug and remove these
+                assert len(cuda_kernel_events) == 6, (
+                    f"unexpected number of cuda kernels: {cuda_kernel_events}"
+                )
+        else:
+            assert granularity == PerRow()
+            # kernel 1: x_float8 = to_float8(x)
+            # kernel 2: gemm
+            if torch_compile_mode == "default":
+                assert len(cuda_kernel_events) == 2, (
+                    f"unexpected number of cuda kernels: {cuda_kernel_events}"
+                )
+            elif torch_compile_mode == "reduce-overhead":
+                # two extra kernels with reduce-overhead:
+                #   void at::native::(anonymous namespace)::multi_tensor...
+                #   void at::native::vectorized_elementwise_kernel<2, at...
+                # TODO(future): debug and remove these
+                assert len(cuda_kernel_events) == 4, (
+                    f"unexpected number of cuda kernels: {cuda_kernel_events}"
+                )


 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
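
Note: the deleted test exercised the run_and_get_code + FileCheck pattern, which compiles a function, returns the generated Inductor source, and asserts that an expected op name survives in it. A minimal standalone sketch of that pattern (the toy function f and the "sin" check string are ours, not part of this diff):

import torch
from torch._inductor.utils import run_and_get_code

def f(x):
    return torch.sin(x) + 1

x = torch.randn(8)
# run_and_get_code compiles and runs f, returning both the result and
# the generated Inductor source for each compiled graph
result, (code,) = run_and_get_code(torch.compile(f), x)
# FileCheck asserts that the given substring appears in the generated code
torch.testing.FileCheck().check("sin").run(code)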
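The new test replaces string matching on generated code with kernel counting from a profiler trace: warm up once so compilation and autotuning stay out of the trace, profile a single forward pass, and keep only events with nonzero GPU time. A minimal sketch of the same pattern, factored into a helper (count_cuda_kernels is a hypothetical name, not part of this diff; assumes a CUDA device):

import torch
from torch.profiler import ProfilerActivity, profile

def count_cuda_kernels(fn, *args):
    # warm-up call: lets torch.compile trace and codegen outside the trace
    _ = fn(*args)
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        _ = fn(*args)
    # key_averages() aggregates profiler events by name; keep only the
    # entries that actually spent time on the GPU
    return [e for e in prof.key_averages() if e.cuda_time > 0]

# usage sketch:
#   m = torch.compile(torch.nn.Linear(256, 512, device="cuda", dtype=torch.bfloat16))
#   x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
#   kernels = count_cuda_kernels(m, x)
#   print(len(kernels), [k.key for k in kernels])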