 # LICENSE file in the root directory of this source tree.
 import logging
 from functools import partial
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Union
 
 import torch.nn as nn
 
@@ -117,27 +117,28 @@ def convert_to_float8_training(
 
 
 def _auto_filter_for_recipe(
-    recipe: Float8LinearRecipeName, filter_fqns: List[str]
+    recipe: Union[str, Float8LinearRecipeName], filter_fqns: List[str]
 ) -> Callable[[nn.Module, str], bool]:
-    """Automatically filters nn.Linear modules that meet at least one of the following criteria:
+    """Returns a function which automatically filters nn.Linear modules that meet at least one of the following criteria:
 
     1. Dims not divisible by 16 (hardware requirement for float8).
-    2. Dim sizes below certain thresholds, which will result in worse performance.
+    2. Dim sizes below certain thresholds, which may result in worse performance.
 
     NOTE: the thresholds are simple heuristics based on performance testing, and may not be optimal
     for your model. For the best performance, we recommend defining your own module_filter_fn customized for
     your module, using the performance tables for the given float8 recipe here:
-    https://github.com/pytorch/ao/tree/main/torchao/float8#performance). Note that the benchmarks referenced
-    for auto filtering layers were run on H100 GPUs, and may not be representative of other hardware.
+    https://github.com/pytorch/ao/tree/main/torchao/float8#performance). The benchmarks referenced for
+    auto filtering layers were run on H100 GPUs, and may not be representative of other hardware.
 
-
-    The design of this function may change in the future.
+    This is an experimental API; the design may change in the future.
     """
-    if recipe == Float8LinearRecipeName.TENSORWISE.value:
+    if isinstance(recipe, str):
+        recipe = Float8LinearRecipeName(recipe)
+    if recipe == Float8LinearRecipeName.TENSORWISE:
         return partial(_auto_filter_for_tensorwise, filter_fqns=filter_fqns)
-    elif recipe == Float8LinearRecipeName.ROWWISE.value:
+    elif recipe == Float8LinearRecipeName.ROWWISE:
         return partial(_auto_filter_for_rowwise, filter_fqns=filter_fqns)
-    elif recipe == Float8LinearRecipeName.ROWWISE_WITH_GW_HP.value:
+    elif recipe == Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
         raise NotImplementedError(f"Unsupported recipe: {recipe}")
     else:
         raise ValueError(f"Invalid recipe: {recipe}")
@@ -153,7 +154,7 @@ def _auto_filter_for_rowwise(mod: nn.Module, fqn: str, filter_fqns: List[str]) -
         return False
 
     # All dims must be divisible by 16 due to float8 hardware requirements.
-    K, N = mod.weight.shape
+    N, K = mod.weight.shape
     dims_multiples_of_16 = K % 16 == 0 and N % 16 == 0
     if not dims_multiples_of_16:
         return False
@@ -183,7 +184,7 @@ def _auto_filter_for_tensorwise(
         return False
 
     # All dims must be divisible by 16 due to float8 hardware requirements.
-    K, N = mod.weight.shape
+    N, K = mod.weight.shape
     dims_multiples_of_16 = K % 16 == 0 and N % 16 == 0
     if not dims_multiples_of_16:
         return False
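
With this change the recipe can be passed either as a `Float8LinearRecipeName` member or as its string name. Below is a minimal usage sketch, not a confirmed example from this PR: it assumes `_auto_filter_for_recipe` is importable from `torchao.float8` and that `"tensorwise"` is the string value of `Float8LinearRecipeName.TENSORWISE`.

```python
import torch.nn as nn

# Assumed import location; the diff does not show the module path.
from torchao.float8 import (
    Float8LinearRecipeName,
    _auto_filter_for_recipe,
    convert_to_float8_training,
)

model = nn.Sequential(
    # nn.Linear weights are (out_features, in_features) = (N, K), hence the N, K fix above.
    nn.Linear(4096, 4096),
    # out_features=13 is not divisible by 16, so the auto filter should skip this layer.
    nn.Linear(4096, 13),
)

# The recipe may now be an enum member or (assumed) its string name.
filter_fn = _auto_filter_for_recipe("tensorwise", filter_fqns=["example_fqn_to_skip"])
# Equivalent: _auto_filter_for_recipe(Float8LinearRecipeName.TENSORWISE, filter_fqns=[...])

# module_filter_fn decides, per (module, fqn), whether to swap that nn.Linear to float8.
# Actually training in float8 still requires supported hardware (e.g. H100).
convert_to_float8_training(model, module_filter_fn=filter_fn)
```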