diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py
index 408e6e6ce0..0a1f837b33 100644
--- a/test/dtypes/test_affine_quantized_float.py
+++ b/test/dtypes/test_affine_quantized_float.py
@@ -25,7 +25,7 @@ from torch._inductor.test_case import TestCase as InductorTestCase
 from torch.testing._internal import common_utils
-from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
+from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale
 from torchao.float8.float8_utils import compute_error
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
@@ -595,6 +595,51 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
         error = compute_error(ref_output, quant_output)
         self.assertGreater(error, 15, f"Quantization SQNR too low: {error}")
 
+    def test_preprocess_scale_3d_reshape(self):
+        """Test that preprocess_scale correctly reshapes scales for N-D inputs"""
+        device = "cpu"  # Use CPU for basic functionality test
+
+        # Test 1: PerTensor scale (scalar) - should reshape to (1, 1)
+        per_tensor_scale = torch.tensor(0.5, device=device)
+        result = preprocess_scale(per_tensor_scale, (2, 4, 8))
+        expected_shape = (1, 1)
+        self.assertEqual(result.shape, expected_shape)
+        self.assertEqual(result.item(), 0.5)
+
+        # Test 2: 1D scale tensor with one element - should reshape to (1, 1)
+        one_element_scale = torch.tensor([0.3], device=device)
+        result = preprocess_scale(one_element_scale, (2, 4, 8))
+        expected_shape = (1, 1)
+        self.assertEqual(result.shape, expected_shape)
+        self.assertEqual(result.item(), 0.3)
+
+        # Test 3: 2D per-row scale for a 3D input - should flatten the first N-1 dims
+        # This is the key test for the 3D reshape fix
+        scale_3d = torch.randn(
+            2, 4, device=device
+        )  # Shape matches first 2 dims of (2, 4, 8)
+        result = preprocess_scale(scale_3d, (2, 4, 8))
+        expected_shape = (8, 1)  # Flattened (2*4, 1)
+        self.assertEqual(result.shape, expected_shape)
+
+        # Verify the values are preserved correctly
+        expected_values = scale_3d.flatten().unsqueeze(-1)
+        self.assertTrue(torch.allclose(result, expected_values))
+
+        # Test 4: 1D per-row scale for a 2D input - should just gain a trailing dimension
+        scale_2d = torch.randn(8, device=device)
+        result = preprocess_scale(scale_2d, (8, 16))
+        expected_shape = (8, 1)
+        self.assertEqual(result.shape, expected_shape)
+
+        # Test 5: Edge case with a higher-dimensional (4D) input
+        scale_4d = torch.randn(
+            2, 2, 2, device=device
+        )  # Shape matches first 3 dims of (2, 2, 2, 8)
+        result = preprocess_scale(scale_4d, (2, 2, 2, 8))
+        expected_shape = (8, 1)  # Flattened (2*2*2, 1)
+        self.assertEqual(result.shape, expected_shape)
+
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
 
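Reviewer note: for anyone sanity-checking the expected shapes in the test above without a torchao build, here is a minimal standalone sketch of the reshaping rule the new test exercises. `preprocess_scale_sketch` is a local stand-in that mirrors the patched function in the next file; it is not the torchao import.

```python
import torch

def preprocess_scale_sketch(input_scale: torch.Tensor, input_shape) -> torch.Tensor:
    # Per-tensor: a single scale element becomes a (1, 1) tensor for _scaled_mm.
    if input_scale.numel() == 1:
        return input_scale.reshape(1, 1)
    # Per-row: add a trailing dim, then flatten all leading dims to match the
    # activation, which is reshaped to (-1, input_shape[-1]) before the matmul.
    input_scale = input_scale.unsqueeze(-1)
    if input_scale.dim() > 2:
        input_scale = input_scale.reshape(-1, input_scale.shape[-1])
    return input_scale

assert preprocess_scale_sketch(torch.tensor(0.5), (2, 4, 8)).shape == (1, 1)
assert preprocess_scale_sketch(torch.randn(2, 4), (2, 4, 8)).shape == (8, 1)
assert preprocess_scale_sketch(torch.randn(2, 2, 2), (2, 2, 2, 8)).shape == (8, 1)
```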
diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py
index 799832a5ea..543bd5002b 100644
--- a/torchao/dtypes/floatx/float8_layout.py
+++ b/torchao/dtypes/floatx/float8_layout.py
@@ -370,10 +370,18 @@ def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
     return check_aqt(input_tensor) and check_aqt(weight_tensor)
 
 
-def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]):
-    """Ensures input tensor is correctly formated for _scaled_mm"""
+def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int, ...]):
+    """Ensures input tensor is correctly formatted for _scaled_mm"""
+
+    # For PerTensor quantization, scale should be a scalar or have shape [1]
+    if input_scale.numel() == 1:
+        # Already a scalar, ensure it has the right shape for _scaled_mm
+        return input_scale.reshape(1, 1)
+
+    # For per-row/block quantization, we need to handle the reshaping
     input_scale = input_scale.unsqueeze(-1)
 
+    # Match: input_data.reshape(-1, input_data.shape[-1])
     if input_scale.dim() > 2:
         input_scale = input_scale.reshape(-1, input_scale.shape[-1])
 
@@ -388,31 +396,28 @@ def _linear_fp8_act_fp8_weight_impl(
     """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
     scaled_mm_config = weight_tensor._layout.mm_config
     assert scaled_mm_config is not None
-    out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
+    assert not weight_tensor.tensor_impl.transposed, "Weight tensor must be contiguous"
 
-    # Weight tensor preprocessing
-    w_tensor_impl = weight_tensor.tensor_impl
-    assert not w_tensor_impl.transposed, "Weight tensor must be contiguous"
-    w_data = w_tensor_impl.float8_data
-    w_scale = w_tensor_impl.scale
+    out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
 
-    # Input tensor preprocessing
-    inpt_data = input_tensor.tensor_impl.float8_data
+    # Extract tensor data and scales
+    inpt_data = input_tensor.tensor_impl.float8_data.reshape(
+        -1, input_tensor.tensor_impl.float8_data.shape[-1]
+    )
+    w_data = weight_tensor.tensor_impl.float8_data
     input_scale = input_tensor.tensor_impl.scale
-    # Handle case where input tensor is more than 2D
-    inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1])
-    # Handle rowwise case
+    w_scale = weight_tensor.tensor_impl.scale
+
+    # Handle rowwise scaling
     if _is_rowwise_scaled(weight_tensor):
         assert _is_rowwise_scaled(input_tensor), (
             "Input tensor must be rowwise block size"
         )
-        w_scale = w_scale.T
-        input_scale = preprocess_scale(input_scale, input_tensor.shape)
+        w_scale = w_scale.transpose(-1, -2)
 
-    # Preprocess data
+    input_scale = preprocess_scale(input_scale, input_tensor.shape)
     inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)
 
-    # Perform the computation
     return addmm_float8_unwrapped_inference(
         inpt_data,
         input_scale,
diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py
index 00c905f3d8..274ac5d258 100644
--- a/torchao/float8/inference.py
+++ b/torchao/float8/inference.py
@@ -94,9 +94,8 @@ def addmm_float8_unwrapped_inference(
             out_dtype=output_dtype,
             use_fast_accum=use_fast_accum,
         )
-        output += bias
-        return output
-    output = torch._scaled_mm(
+        return output + bias
+    return torch._scaled_mm(
         a_data,
         b_data,
         scale_a=a_scale,
@@ -106,7 +105,6 @@
         out_dtype=output_dtype,
         use_fast_accum=use_fast_accum,
     )
-    return output
 
 
 def _is_rowwise_scaled(x) -> bool:
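Reviewer note: a short shape walk-through of the 3D path this patch fixes, under assumed sizes (batch=2, seq=4, in_features=8, out_features=16); all names here are illustrative stand-ins, not torchao API. The `.T` -> `transpose(-1, -2)` swap appears motivated by `transpose(-1, -2)` staying well-defined if the scale ever carries extra leading dims, where `.T` on a >2D tensor does not.

```python
import torch

B, S, K, N = 2, 4, 8, 16       # assumed batch, seq, in_features, out_features
x = torch.randn(B, S, K)       # activation as seen by the linear op
x_scale = torch.randn(B, S)    # one scale per row under rowwise quantization
w_scale = torch.randn(N, 1)    # weight scale, one per output channel

# _linear_fp8_act_fp8_weight_impl flattens the activation to 2D for _scaled_mm ...
x_2d = x.reshape(-1, x.shape[-1])            # (8, 8) == (B*S, K)
# ... so preprocess_scale must flatten the activation scale the same way:
s_2d = x_scale.unsqueeze(-1).reshape(-1, 1)  # (8, 1) == (B*S, 1)
# and the weight scale is laid out as one scale per output column:
w_t = w_scale.transpose(-1, -2)              # (1, 16) == (1, N)

assert x_2d.shape == (B * S, K)
assert s_2d.shape == (B * S, 1)
assert w_t.shape == (1, N)
```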