Added quantization codecs to utils #1219

Closed
wants to merge 1 commit into from
110 changes: 110 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/quantize_comm.py
@@ -0,0 +1,110 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import logging
from typing import Optional

import torch

from fbgemm_gpu.quantize_utils import (
    bf16_to_fp32,
    fp16_to_fp32,
    fp32_to_bf16_with_clamp,
    fp32_to_fp16_with_clamp,
    fp32_to_hfp8_with_clamp,
    hfp8_to_fp32,
)
from fbgemm_gpu.split_embedding_configs import SparseType
from torch.autograd.profiler import record_function

logger: logging.Logger = logging.getLogger()


def _quantize_tensor(
    input_tensor: torch.Tensor,
    comm_precision: SparseType,
) -> torch.Tensor:
    if comm_precision == SparseType.FP32:
        return input_tensor
    elif comm_precision == SparseType.FP16:
        return fp32_to_fp16_with_clamp(input_tensor)
    elif comm_precision == SparseType.BF16:
        return fp32_to_bf16_with_clamp(input_tensor)
    elif comm_precision == SparseType.FP8:
        return fp32_to_hfp8_with_clamp(input_tensor)
    else:
        raise ValueError(f"comm_precision={comm_precision} is not supported")


def _dequantize_tensor(
    quantized_tensor: torch.Tensor,
    comm_precision: SparseType,
) -> torch.Tensor:
    if comm_precision == SparseType.FP32:
        assert quantized_tensor.dtype == torch.float
        return quantized_tensor
    elif comm_precision == SparseType.FP16:
        assert quantized_tensor.dtype == torch.half
        return fp16_to_fp32(quantized_tensor)
    elif comm_precision == SparseType.BF16:
        assert quantized_tensor.dtype == torch.bfloat16
        return bf16_to_fp32(quantized_tensor)
    elif comm_precision == SparseType.FP8:
        assert quantized_tensor.dtype == torch.uint8
        return hfp8_to_fp32(quantized_tensor)
    else:
        raise ValueError(f"comm_precision={comm_precision} is not supported")


class QuantizedCommCodec:
    def __init__(
        self,
        comm_precision: SparseType,
        loss_scale: Optional[float] = None,
    ) -> None:

        if loss_scale is not None:
            if comm_precision not in [SparseType.FP16, SparseType.BF16]:
                logger.warning(
                    f"Setting loss scale for comm_precision={comm_precision} is not supported. Overriding to None"
                )
                loss_scale = None

        logger.info(
            f"Creating QuantizedCommsCodec comm_precision:{comm_precision}, loss_scale:{loss_scale}"
        )

        self._comm_precision = comm_precision
        self._loss_scale = loss_scale

    def encode(self, input_tensor: torch.Tensor) -> torch.Tensor:
        if self._loss_scale is not None:
            input_tensor = self._loss_scale * input_tensor
        with record_function(
            f"## encoder {self._comm_precision} {self._loss_scale} ##"
        ):
            return _quantize_tensor(input_tensor, self._comm_precision)

    def decode(self, input_grad: torch.Tensor) -> torch.Tensor:
        if self._loss_scale is not None:
            input_grad = input_grad / self._loss_scale
        with record_function(
            f"## decoder {self._comm_precision} {self._loss_scale} ##"
        ):
            dequantized_tensor = _dequantize_tensor(input_grad, self._comm_precision)
        return dequantized_tensor

    @property
    def quantized_dtype(self) -> torch.dtype:
        if self._comm_precision == SparseType.FP16:
            return torch.half
        elif self._comm_precision == SparseType.BF16:
            return torch.bfloat16
        elif self._comm_precision == SparseType.FP8:
            return torch.uint8
        return torch.float
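
For context, a minimal round-trip sketch of how the codec added above might be used. This is not part of the diff; the precision, loss scale, and tensor shape are illustrative only.

    import torch
    from fbgemm_gpu.quantize_comm import QuantizedCommCodec
    from fbgemm_gpu.split_embedding_configs import SparseType

    # Hypothetical example: compress a gradient tensor to FP16 before a collective,
    # then restore it to FP32 on the receiving side.
    codec = QuantizedCommCodec(SparseType.FP16, loss_scale=4.0)
    grad = torch.rand(128, 64)          # illustrative shape
    quantized = codec.encode(grad)      # scaled by loss_scale, cast to codec.quantized_dtype (torch.half)
    restored = codec.decode(quantized)  # divided by loss_scale, cast back to torch.float
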
67 changes: 67 additions & 0 deletions fbgemm_gpu/test/quantize_comm_test.py
@@ -0,0 +1,67 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Optional, Tuple

import hypothesis.strategies as st
import torch
from fbgemm_gpu.quantize_comm import QuantizedCommCodec
from fbgemm_gpu.split_embedding_configs import SparseType
from hypothesis import assume, given, settings


class QuantizedCommCodecTest(unittest.TestCase):
    @settings(deadline=2000)
    # pyre-ignore
    @given(
        comm_precisions_loss_scale=st.sampled_from(
            [
                (SparseType.FP32, None),
                (SparseType.FP16, None),
                (SparseType.FP16, 4.0),
                (SparseType.BF16, None),
                (SparseType.BF16, 2.0),
                (SparseType.FP8, None),
                (SparseType.FP8, 3.0),
            ]
        ),
        row_size=st.integers(4, 256),
        col_size=st.integers(4, 256),
        rand_seed=st.integers(0, 65534),
    )
    def test_quantized_comm_codec(
        self,
        comm_precisions_loss_scale: Tuple[SparseType, Optional[float]],
        row_size: int,
        col_size: int,
        rand_seed: int,
    ) -> None:

        (comm_precision, loss_scale) = comm_precisions_loss_scale
        if comm_precision == SparseType.FP8:
            assume(col_size % 4 == 0)

        torch.manual_seed(rand_seed)
        shape = (row_size, col_size)

        quant_codec = QuantizedCommCodec(comm_precision, loss_scale)

        input_tensor = torch.rand(shape, requires_grad=True)

        quant_tensor = quant_codec.encode(input_tensor)
        output_tensor = quant_codec.decode(quant_tensor)

        rtol = 0.005
        atol = 0.005
        if comm_precision == SparseType.FP8:
            rtol = 0.05
            atol = 0.05

        torch.testing.assert_close(
            input_tensor.detach(), output_tensor.detach(), rtol=rtol, atol=atol
        )
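
The new test is Hypothesis-driven and defines no __main__ entry point, so, assuming fbgemm_gpu is installed and pytest is available, it can presumably be run with:

    python -m pytest fbgemm_gpu/test/quantize_comm_test.py
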