
[float8] add _auto_filter_for_recipe to float8 #2410

Merged
merged 4 commits on Jun 24, 2025
6 changes: 5 additions & 1 deletion torchao/float8/__init__.py
@@ -6,7 +6,10 @@
     ScalingGranularity,
     ScalingType,
 )
-from torchao.float8.float8_linear_utils import convert_to_float8_training
+from torchao.float8.float8_linear_utils import (
+    _auto_filter_for_recipe,
+    convert_to_float8_training,
+)
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
@@ -44,6 +47,7 @@
     # top level UX
     "convert_to_float8_training",
     "precompute_float8_dynamic_scale_for_fsdp",
+    "_auto_filter_for_recipe",
     # types
     "FP8Granularity",
     # note: Float8Tensor and Float8Linear are not public APIs
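
With this change, _auto_filter_for_recipe is re-exported from the top-level torchao.float8 namespace and added to __all__, next to convert_to_float8_training. A minimal sketch of the new import surface (nothing here beyond what the diff exports; the leading underscore reflects that the API is experimental, per the docstring below):

from torchao.float8 import _auto_filter_for_recipe, convert_to_float8_training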
87 changes: 85 additions & 2 deletions torchao/float8/float8_linear_utils.py
@@ -4,11 +4,12 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import logging
-from typing import Callable, Optional
+from functools import partial
+from typing import Callable, List, Optional, Union
 
 import torch.nn as nn
 
-from torchao.float8.config import Float8LinearConfig
+from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName
 from torchao.float8.float8_linear import Float8Linear
 
 log = logging.getLogger(__name__)
@@ -113,3 +114,85 @@ def convert_to_float8_training(
         from_float,
         module_filter_fn=module_filter_fn,
     )
+
+
+def _auto_filter_for_recipe(
+    recipe: Union[str, Float8LinearRecipeName], filter_fqns: List[str]
+) -> Callable[[nn.Module, str], bool]:
+    """Returns a function that filters out (excludes from float8 conversion) any
+    nn.Linear module meeting at least one of the following criteria:
+
+    1. Dims not divisible by 16 (hardware requirement for float8).
+    2. Dim sizes below certain thresholds, which may result in worse performance.
+
+    NOTE: the thresholds are simple heuristics based on performance testing, and may
+    not be optimal for your model. For the best performance, we recommend defining
+    your own module_filter_fn customized for your model, using the performance tables
+    for the given float8 recipe here:
+    https://github.com/pytorch/ao/tree/main/torchao/float8#performance. The benchmarks
+    referenced for auto-filtering layers were run on H100 GPUs, and may not be
+    representative of other hardware.
+
+    This is an experimental API; the design may change in the future.
+    """
+    if isinstance(recipe, str):
+        recipe = Float8LinearRecipeName(recipe)
+    if recipe == Float8LinearRecipeName.TENSORWISE:
+        return partial(_auto_filter_for_tensorwise, filter_fqns=filter_fqns)
+    elif recipe == Float8LinearRecipeName.ROWWISE:
+        return partial(_auto_filter_for_rowwise, filter_fqns=filter_fqns)
+    elif recipe == Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
+        raise NotImplementedError(f"Unsupported recipe: {recipe}")
+    else:
+        raise ValueError(f"Invalid recipe: {recipe}")
+
+
+def _auto_filter_for_rowwise(mod: nn.Module, fqn: str, filter_fqns: List[str]) -> bool:
+    if not isinstance(mod, nn.Linear):
+        return False
+
+    # If the fqn matches any filtered fqn, then we should not convert this module.
+    is_filtered_fqn = any(filter_fqn in fqn for filter_fqn in filter_fqns)
+    if is_filtered_fqn:
+        return False
+
+    # All dims must be divisible by 16 due to float8 hardware requirements.
+    N, K = mod.weight.shape
+    dims_multiples_of_16 = K % 16 == 0 and N % 16 == 0
+    if not dims_multiples_of_16:
+        return False
+
+    # Dims below these thresholds may result in worse performance
+    # (see https://github.com/pytorch/ao/tree/main/torchao/float8#rowwise-scaling).
+    # Note that the benchmarks referenced for auto-filtering layers were run on
+    # H100 GPUs, and may not be representative of other hardware.
+    if N <= 2048:
+        return False
+    elif K <= 1024:
+        return False
+    elif N <= 4096 and K <= 2048:
+        return False
+    return True
+
+
+def _auto_filter_for_tensorwise(
+    mod: nn.Module, fqn: str, filter_fqns: List[str]
+) -> bool:
+    if not isinstance(mod, nn.Linear):
+        return False
+
+    # If the fqn matches any filtered fqn, then we should not convert this module.
+    is_filtered_fqn = any(filter_fqn in fqn for filter_fqn in filter_fqns)
+    if is_filtered_fqn:
+        return False
+
+    # All dims must be divisible by 16 due to float8 hardware requirements.
+    N, K = mod.weight.shape
+    dims_multiples_of_16 = K % 16 == 0 and N % 16 == 0
+    if not dims_multiples_of_16:
+        return False
+
+    # Dims below these thresholds may result in worse performance
+    # (see https://github.com/pytorch/ao/tree/main/torchao/float8#tensorwise-scaling).
+    # Note that the benchmarks referenced for auto-filtering layers were run on
+    # H100 GPUs, and may not be representative of other hardware.
+    if K <= 4096 and N <= 1024:
+        return False
+    return True
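
Taken together, a typical call site pairs the returned filter with convert_to_float8_training. The sketch below is illustrative, not part of the PR: ToyModel and its fqns are hypothetical, and it assumes Float8LinearConfig.from_recipe_name (provided by torchao) to build a matching rowwise config.

import torch.nn as nn

from torchao.float8 import (
    Float8LinearConfig,
    _auto_filter_for_recipe,
    convert_to_float8_training,
)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Passes the rowwise heuristics: N=8192 > 2048, K=4096 > 1024,
        # and both dims are divisible by 16.
        self.proj = nn.Linear(4096, 8192)
        # Skipped twice over: "head" is in filter_fqns, and N=128 <= 2048.
        self.head = nn.Linear(8192, 128)

    def forward(self, x):
        return self.head(self.proj(x))


model = ToyModel()

# "rowwise" is parsed into Float8LinearRecipeName.ROWWISE by the helper.
module_filter_fn = _auto_filter_for_recipe("rowwise", filter_fqns=["head"])

# Only model.proj is swapped to Float8Linear.
convert_to_float8_training(
    model,
    config=Float8LinearConfig.from_recipe_name("rowwise"),
    module_filter_fn=module_filter_fn,
)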