Fix Per Tensor 3d reshape #2293


Merged
1 commit merged on Jun 9, 2025
47 changes: 46 additions & 1 deletion test/dtypes/test_affine_quantized_float.py
@@ -25,7 +25,7 @@
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal import common_utils

from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale
from torchao.float8.float8_utils import compute_error
from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig,
@@ -595,6 +595,51 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
error = compute_error(ref_output, quant_output)
self.assertGreater(error, 15, f"Quantization SQNR too low: {error}")

def test_preprocess_scale_3d_reshape(self):
"""Test that preprocess_scale correctly handles 3D scale tensors"""
device = "cpu" # Use CPU for basic functionality test

# Test 1: PerTensor scale (scalar) - should reshape to (1, 1)
per_tensor_scale = torch.tensor(0.5, device=device)
result = preprocess_scale(per_tensor_scale, (2, 4, 8))
expected_shape = (1, 1)
self.assertEqual(result.shape, expected_shape)
self.assertEqual(result.item(), 0.5)

# Test 2: 1D scale tensor with one element - should reshape to (1, 1)
one_element_scale = torch.tensor([0.3], device=device)
result = preprocess_scale(one_element_scale, (2, 4, 8))
expected_shape = (1, 1)
self.assertEqual(result.shape, expected_shape)
self.assertEqual(result.item(), 0.3)

# Test 3: 3D scale tensor for per-row quantization - should flatten first N-1 dims
# This is the key test for the 3D reshape fix
scale_3d = torch.randn(
2, 4, device=device
) # Shape matches first 2 dims of (2, 4, 8)
result = preprocess_scale(scale_3d, (2, 4, 8))
expected_shape = (8, 1) # Flattened (2*4, 1)
self.assertEqual(result.shape, expected_shape)

# Verify the values are preserved correctly
expected_values = scale_3d.flatten().unsqueeze(-1)
self.assertTrue(torch.allclose(result, expected_values))

# Test 4: 2D scale tensor (already correct shape) - should just add last dimension
scale_2d = torch.randn(8, device=device)
result = preprocess_scale(scale_2d, (8, 16))
expected_shape = (8, 1)
self.assertEqual(result.shape, expected_shape)

# Test 5: Edge case with higher dimensions (4D)
scale_4d = torch.randn(
2, 2, 2, device=device
) # Shape matches first 3 dims of (2, 2, 2, 8)
result = preprocess_scale(scale_4d, (2, 2, 2, 8))
expected_shape = (8, 1) # Flattened (2*2*2, 1)
self.assertEqual(result.shape, expected_shape)


common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)

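For context, here is a minimal end-to-end sketch (not part of this diff) of the scenario the new test guards: a per-tensor FP8-quantized linear layer fed a 3D activation. It assumes torchao's public `quantize_` API, the `PerTensor` granularity, and an FP8-capable GPU; the model and shapes are illustrative.

```python
# Illustrative only: exercises the per-tensor scale + 3D activation path that
# preprocess_scale now handles. Assumes an FP8-capable GPU (e.g. H100).
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerTensor,
    quantize_,
)

model = torch.nn.Sequential(torch.nn.Linear(8, 16)).to("cuda", torch.bfloat16)
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()))

# A (batch, seq, features) input: the activation is flattened to (2*4, 8) for
# _scaled_mm, and its scalar per-tensor scale must be reshaped to (1, 1) to match.
x = torch.randn(2, 4, 8, device="cuda", dtype=torch.bfloat16)
y = model(x)
print(y.shape)  # torch.Size([2, 4, 16])
```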
39 changes: 22 additions & 17 deletions torchao/dtypes/floatx/float8_layout.py
@@ -370,10 +370,18 @@ def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
return check_aqt(input_tensor) and check_aqt(weight_tensor)


def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]):
"""Ensures input tensor is correctly formated for _scaled_mm"""
def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int, ...]):
"""Ensures input tensor is correctly formatted for _scaled_mm"""

# For PerTensor quantization, scale should be a scalar or have shape [1]
if input_scale.numel() == 1:
# Already a scalar, ensure it has the right shape for _scaled_mm
return input_scale.reshape(1, 1)

# For per-row/block quantization, we need to handle the reshaping
input_scale = input_scale.unsqueeze(-1)

# Match: #input_data.reshape(-1, input_data.shape[-1])
if input_scale.dim() > 2:
input_scale = input_scale.reshape(-1, input_scale.shape[-1])
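As a quick illustration of the contract above (a sketch mirroring the unit test added in this PR, not part of the diff): the returned scale is always 2D so it lines up with the flattened activation data.

```python
import torch
from torchao.dtypes.floatx.float8_layout import preprocess_scale

# Per-tensor: a scalar scale is reshaped to (1, 1).
assert preprocess_scale(torch.tensor(0.5), (2, 4, 8)).shape == (1, 1)

# Per-row scale for a (2, 4, 8) input: the (2, 4) scale flattens to (8, 1),
# matching input_data.reshape(-1, input_data.shape[-1]), which is (8, 8).
assert preprocess_scale(torch.randn(2, 4), (2, 4, 8)).shape == (8, 1)

# Already effectively 2D: an (8,) scale for an (8, 16) input just gains a last dim.
assert preprocess_scale(torch.randn(8), (8, 16)).shape == (8, 1)
```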

@@ -388,31 +396,28 @@ def _linear_fp8_act_fp8_weight_impl(
"""Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
scaled_mm_config = weight_tensor._layout.mm_config
assert scaled_mm_config is not None
out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
assert not weight_tensor.tensor_impl.transposed, "Weight tensor must be contiguous"

# Weight tensor preprocessing
w_tensor_impl = weight_tensor.tensor_impl
assert not w_tensor_impl.transposed, "Weight tensor must be contiguous"
w_data = w_tensor_impl.float8_data
w_scale = w_tensor_impl.scale
out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)

# Input tensor preprocessing
inpt_data = input_tensor.tensor_impl.float8_data
# Extract tensor data and scales
inpt_data = input_tensor.tensor_impl.float8_data.reshape(
-1, input_tensor.tensor_impl.float8_data.shape[-1]
)
w_data = weight_tensor.tensor_impl.float8_data
input_scale = input_tensor.tensor_impl.scale
# Handle case where input tensor is more than 2D
inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1])
# Handle rowwise case
w_scale = weight_tensor.tensor_impl.scale

# Handle rowwise scaling
if _is_rowwise_scaled(weight_tensor):
assert _is_rowwise_scaled(input_tensor), (
"Input tensor must be rowwise block size"
)
w_scale = w_scale.T
input_scale = preprocess_scale(input_scale, input_tensor.shape)
w_scale = w_scale.transpose(-1, -2)

# Preprocess data
input_scale = preprocess_scale(input_scale, input_tensor.shape)
inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)

# Perform the computation
return addmm_float8_unwrapped_inference(
inpt_data,
input_scale,
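To make the shape bookkeeping in the rowwise branch above concrete, here is a small stand-alone sketch with plain tensors (no FP8, names illustrative): the activation data and its per-row scale are both flattened over the leading dims, while the weight scale is transposed to broadcast across output columns.

```python
import torch

B, S, K, N = 2, 4, 8, 16                       # illustrative 3D-activation shapes

inpt_data = torch.randn(B, S, K)               # stands in for the fp8 activation data
w_data = torch.randn(N, K)                     # stands in for the fp8 weight data
input_scale = torch.rand(B, S) + 0.1           # one scale per activation row
w_scale = torch.rand(N, 1) + 0.1               # one scale per weight row

# What the kernel does: flatten every leading dim of the activation...
inpt_2d = inpt_data.reshape(-1, inpt_data.shape[-1])     # (B*S, K)
# ...and flatten the scale the same way (preprocess_scale's job).
scale_2d = input_scale.unsqueeze(-1).reshape(-1, 1)      # (B*S, 1)
# The weight scale is transposed so it broadcasts across output columns.
w_scale_t = w_scale.transpose(-1, -2)                    # (1, N)

# _scaled_mm consumes inpt_2d @ w_data.T with these 2D scales; emulated in fp32:
out_2d = (inpt_2d * scale_2d) @ (w_data.T * w_scale_t)   # (B*S, N)
out = out_2d.reshape(B, S, N)                            # back to out_shape
print(out.shape)  # torch.Size([2, 4, 16])
```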
6 changes: 2 additions & 4 deletions torchao/float8/inference.py
@@ -94,9 +94,8 @@ def addmm_float8_unwrapped_inference(
out_dtype=output_dtype,
use_fast_accum=use_fast_accum,
)
output += bias
return output
output = torch._scaled_mm(
return output + bias
return torch._scaled_mm(
Review comment on lines +97 to +98: trimming down those lines!

a_data,
b_data,
scale_a=a_scale,
@@ -106,7 +105,6 @@
out_dtype=output_dtype,
use_fast_accum=use_fast_accum,
)
return output


def _is_rowwise_scaled(x) -> bool: