Commit 60a5c2b

YLGH authored and facebook-github-bot committed

Added quantization codecs to utils (#1219)

Summary: Pull Request resolved: #1219. This provides a generic codec interface that can be used for quantized comms.

Reviewed By: colin2328, jianyuh

Differential Revision: D38125284

fbshipit-source-id: 12a3a1e05e2893e1931e4782c47a75aab9840e52

1 parent a6f5488, commit 60a5c2b

File tree

2 files changed: +177 -0 lines changed

fbgemm_gpu/fbgemm_gpu/quantize_comm.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import logging
from typing import Optional

import torch

from fbgemm_gpu.quantize_utils import (
    bf16_to_fp32,
    fp16_to_fp32,
    fp32_to_bf16_with_clamp,
    fp32_to_fp16_with_clamp,
    fp32_to_hfp8_with_clamp,
    hfp8_to_fp32,
)
from fbgemm_gpu.split_embedding_configs import SparseType
from torch.autograd.profiler import record_function

logger: logging.Logger = logging.getLogger()


def _quantize_tensor(
    input_tensor: torch.Tensor,
    comm_precision: SparseType,
) -> torch.Tensor:
    if comm_precision == SparseType.FP32:
        return input_tensor
    elif comm_precision == SparseType.FP16:
        return fp32_to_fp16_with_clamp(input_tensor)
    elif comm_precision == SparseType.BF16:
        return fp32_to_bf16_with_clamp(input_tensor)
    elif comm_precision == SparseType.FP8:
        return fp32_to_hfp8_with_clamp(input_tensor)
    else:
        raise ValueError(f"comm_precision={comm_precision} is not supported")


def _dequantize_tensor(
    quantized_tensor: torch.Tensor,
    comm_precision: SparseType,
) -> torch.Tensor:
    if comm_precision == SparseType.FP32:
        assert quantized_tensor.dtype == torch.float
        return quantized_tensor
    elif comm_precision == SparseType.FP16:
        assert quantized_tensor.dtype == torch.half
        return fp16_to_fp32(quantized_tensor)
    elif comm_precision == SparseType.BF16:
        assert quantized_tensor.dtype == torch.bfloat16
        return bf16_to_fp32(quantized_tensor)
    elif comm_precision == SparseType.FP8:
        assert quantized_tensor.dtype == torch.uint8
        return hfp8_to_fp32(quantized_tensor)
    else:
        raise ValueError(f"comm_precision={comm_precision} is not supported")


class QuantizedCommCodec:
    def __init__(
        self,
        comm_precision: SparseType,
        loss_scale: Optional[float] = None,
    ) -> None:

        if loss_scale is not None:
            if comm_precision not in [SparseType.FP16, SparseType.BF16]:
                logger.warning(
                    f"Setting loss scale for comm_precision={comm_precision} is not supported. Overriding to None"
                )
                loss_scale = None

        logger.info(
            f"Creating QuantizedCommsCodec comm_precision:{comm_precision}, loss_scale:{loss_scale}"
        )

        self._comm_precision = comm_precision
        self._loss_scale = loss_scale

    def encode(self, input_tensor: torch.Tensor) -> torch.Tensor:
        if self._loss_scale is not None:
            input_tensor = self._loss_scale * input_tensor
        with record_function(
            f"## encoder {self._comm_precision} {self._loss_scale} ##"
        ):
            return _quantize_tensor(input_tensor, self._comm_precision)

    def decode(self, input_grad: torch.Tensor) -> torch.Tensor:
        if self._loss_scale is not None:
            input_grad = input_grad / self._loss_scale
        with record_function(
            f"## decoder {self._comm_precision} {self._loss_scale} ##"
        ):
            dequantized_tensor = _dequantize_tensor(input_grad, self._comm_precision)
        return dequantized_tensor

    @property
    def quantized_dtype(self) -> torch.dtype:
        if self._comm_precision == SparseType.FP16:
            return torch.half
        elif self._comm_precision == SparseType.BF16:
            return torch.bfloat16
        elif self._comm_precision == SparseType.FP8:
            return torch.uint8
        return torch.float
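
For orientation, here is a minimal usage sketch of the codec above (not part of the commit; it assumes an environment where fbgemm_gpu is installed so the module imports as fbgemm_gpu.quantize_comm, exactly as the test below does):

# Illustrative round-trip of a gradient tensor through the codec.
import torch

from fbgemm_gpu.quantize_comm import QuantizedCommCodec
from fbgemm_gpu.split_embedding_configs import SparseType

codec = QuantizedCommCodec(SparseType.FP16, loss_scale=4.0)

grad = torch.rand(8, 16)                  # fp32 tensor to be communicated
q = codec.encode(grad)                    # scaled by loss_scale, then clamped and cast to fp16
assert q.dtype == codec.quantized_dtype   # torch.half for FP16
out = codec.decode(q)                     # loss_scale divided out, then cast back to fp32
print((out - grad).abs().max())           # small quantization error

Note that encode multiplies by loss_scale before quantizing and decode divides it back out before dequantizing, so the scale only changes the wire representation, not the reconstructed values.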

fbgemm_gpu/test/quantize_comm_test.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Optional, Tuple

import hypothesis.strategies as st
import torch
from fbgemm_gpu.quantize_comm import QuantizedCommCodec
from fbgemm_gpu.split_embedding_configs import SparseType
from hypothesis import assume, given, settings


class QuantizedCommCodecTest(unittest.TestCase):
    @settings(deadline=2000)
    # pyre-ignore
    @given(
        comm_precisions_loss_scale=st.sampled_from(
            [
                (SparseType.FP32, None),
                (SparseType.FP16, None),
                (SparseType.FP16, 4.0),
                (SparseType.BF16, None),
                (SparseType.BF16, 2.0),
                (SparseType.FP8, None),
                (SparseType.FP8, 3.0),
            ]
        ),
        row_size=st.integers(4, 256),
        col_size=st.integers(4, 256),
        rand_seed=st.integers(0, 65534),
    )
    def test_quantized_comm_codec(
        self,
        comm_precisions_loss_scale: Tuple[SparseType, Optional[float]],
        row_size: int,
        col_size: int,
        rand_seed: int,
    ) -> None:

        (comm_precision, loss_scale) = comm_precisions_loss_scale
        if comm_precision == SparseType.FP8:
            assume(col_size % 4 == 0)

        torch.manual_seed(rand_seed)
        shape = (row_size, col_size)

        quant_codec = QuantizedCommCodec(comm_precision, loss_scale)

        input_tensor = torch.rand(shape, requires_grad=True)

        quant_tensor = quant_codec.encode(input_tensor)
        output_tensor = quant_codec.decode(quant_tensor)

        rtol = 0.005
        atol = 0.005
        if comm_precision == SparseType.FP8:
            rtol = 0.05
            atol = 0.05

        torch.testing.assert_close(
            input_tensor.detach(), output_tensor.detach(), rtol=rtol, atol=atol
        )
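
A second hedged sketch mirrors the FP8 case the test exercises (again not part of the commit): the test only runs FP8 with col_size % 4 == 0, so the inner dimension below is a multiple of 4, and the looser 0.05 tolerance from the test is reused.

# Illustrative FP8 round-trip; assumes the fbgemm quantization ops are available in the environment.
import torch

from fbgemm_gpu.quantize_comm import QuantizedCommCodec
from fbgemm_gpu.split_embedding_configs import SparseType

codec = QuantizedCommCodec(SparseType.FP8)   # a loss_scale passed here would be warned about and reset to None
x = torch.rand(16, 32)                       # inner dimension kept a multiple of 4, as in the test
q = codec.encode(x)
assert q.dtype == torch.uint8                # hfp8 values travel as uint8
y = codec.decode(q)
torch.testing.assert_close(x, y, rtol=0.05, atol=0.05)

To run the new test itself, one option (assuming hypothesis and a working fbgemm_gpu build are available) is: python -m pytest fbgemm_gpu/test/quantize_comm_test.py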
