Commit ded0931

add tests
1 parent b133535 commit ded0931

File tree

1 file changed: +14 -13 lines changed

torchao/float8/float8_linear_utils.py

Lines changed: 14 additions & 13 deletions
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import logging
 from functools import partial
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Union
 
 import torch.nn as nn
 
@@ -117,27 +117,28 @@ def convert_to_float8_training(
 
 
 def _auto_filter_for_recipe(
-    recipe: Float8LinearRecipeName, filter_fqns: List[str]
+    recipe: Union[str, Float8LinearRecipeName], filter_fqns: List[str]
 ) -> Callable[[nn.Module, str], bool]:
-    """Automatically filters nn.Linear modules that meet at least one of the following criteria:
+    """Returns function which automatically filters nn.Linear modules that meet at least one of the following criteria:
 
     1. Dims not divisible by 16 (hardware requirement for float8).
-    2. Dim sizes below certain thresholds, which will result in worse performance.
+    2. Dim sizes below certain thresholds, which may result in worse performance.
 
     NOTE: the thresholds are simple heuristics based on performance testing, and may not be optimal
     for your model. For the best performance, we recommend defining your own module_filter_fn customized for
     your module, using the performance tables for the given float8 recipe here:
-    https://github.com/pytorch/ao/tree/main/torchao/float8#performance). Note that the benchmarks referenced
-    for auto filtering layers were run on H100 GPUs, and may not be representative of other hardware.
+    https://github.com/pytorch/ao/tree/main/torchao/float8#performance). These benchmarks referenced for
+    auto filtering layers were run on H100 GPUs, and may not be representative of other hardware.
 
-
-    The design of this function may change in the future.
+    This is an experimental API, the design may change in the future.
     """
-    if recipe == Float8LinearRecipeName.TENSORWISE.value:
+    if isinstance(recipe, str):
+        recipe = Float8LinearRecipeName(recipe)
+    if recipe == Float8LinearRecipeName.TENSORWISE:
         return partial(_auto_filter_for_tensorwise, filter_fqns=filter_fqns)
-    elif recipe == Float8LinearRecipeName.ROWWISE.value:
+    elif recipe == Float8LinearRecipeName.ROWWISE:
         return partial(_auto_filter_for_rowwise, filter_fqns=filter_fqns)
-    elif recipe == Float8LinearRecipeName.ROWWISE_WITH_GW_HP.value:
+    elif recipe == Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
         raise NotImplementedError(f"Unsupported recipe: {recipe}")
     else:
         raise ValueError(f"Invalid recipe: {recipe}")
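For reference, a minimal usage sketch of the changed signature, assuming the new string path. `_auto_filter_for_recipe`, `filter_fqns`, and `convert_to_float8_training` come from this file; the recipe string "rowwise" (assumed to match a Float8LinearRecipeName enum value) and the `module_filter_fn` keyword on `convert_to_float8_training` are assumptions not shown in this diff:

import torch.nn as nn

from torchao.float8.float8_linear_utils import (
    _auto_filter_for_recipe,
    convert_to_float8_training,
)

# Toy model; layer sizes chosen to pass the divisibility-by-16 check.
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))

# After this change a plain string is accepted and coerced via
# Float8LinearRecipeName(recipe); "rowwise" is assumed to be a valid enum value.
filter_fn = _auto_filter_for_recipe("rowwise", filter_fqns=["skip_me"])

# Hypothetical call; the module_filter_fn keyword is assumed from torchao's
# float8 conversion API and is not part of this diff.
convert_to_float8_training(model, module_filter_fn=filter_fn)
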
@@ -153,7 +154,7 @@ def _auto_filter_for_rowwise(mod: nn.Module, fqn: str, filter_fqns: List[str]) -
         return False
 
     # All dims must be divisible by 16 due to float8 hardware requirements.
-    K, N = mod.weight.shape
+    N, K = mod.weight.shape
     dims_multiples_of_16 = K % 16 == 0 and N % 16 == 0
     if not dims_multiples_of_16:
         return False
@@ -183,7 +184,7 @@ def _auto_filter_for_tensorwise(
         return False
 
     # All dims must be divisible by 16 due to float8 hardware requirements.
-    K, N = mod.weight.shape
+    N, K = mod.weight.shape
     dims_multiples_of_16 = K % 16 == 0 and N % 16 == 0
     if not dims_multiples_of_16:
         return False
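
The `K, N` -> `N, K` swap in both filters reflects PyTorch's weight layout: nn.Linear(in_features, out_features) stores its weight with shape (out_features, in_features). The divisibility-by-16 check is symmetric in the two dims, so that part is unaffected, but the order presumably matters for the per-dim size thresholds the docstring mentions. A minimal sketch; the layer sizes are made up for illustration:

import torch.nn as nn

# nn.Linear stores weight as (out_features, in_features), so the corrected
# unpacking yields N = out_features and K = in_features.
mod = nn.Linear(1024, 4096)
N, K = mod.weight.shape
assert (N, K) == (4096, 1024)

# The hardware check itself is order-independent.
assert K % 16 == 0 and N % 16 == 0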
