#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import logging
from typing import Optional

import torch

from fbgemm_gpu.quantize_utils import (
    bf16_to_fp32,
    fp16_to_fp32,
    fp32_to_bf16_with_clamp,
    fp32_to_fp16_with_clamp,
    fp32_to_hfp8_with_clamp,
    hfp8_to_fp32,
)
from fbgemm_gpu.split_embedding_configs import SparseType
from torch.autograd.profiler import record_function

logger: logging.Logger = logging.getLogger()


def _quantize_tensor(
    input_tensor: torch.Tensor,
    comm_precision: SparseType,
) -> torch.Tensor:
    """Quantize an FP32 tensor into the wire dtype for the given comm
    precision, clamping values to the target format's representable range."""
    if comm_precision == SparseType.FP32:
        return input_tensor
    elif comm_precision == SparseType.FP16:
        return fp32_to_fp16_with_clamp(input_tensor)
    elif comm_precision == SparseType.BF16:
        return fp32_to_bf16_with_clamp(input_tensor)
    elif comm_precision == SparseType.FP8:
        return fp32_to_hfp8_with_clamp(input_tensor)
    else:
        raise ValueError(f"comm_precision={comm_precision} is not supported")


def _dequantize_tensor(
    quantized_tensor: torch.Tensor,
    comm_precision: SparseType,
) -> torch.Tensor:
    """Dequantize a tensor received at the given comm precision back to FP32,
    asserting that its dtype matches the expected wire dtype."""
    if comm_precision == SparseType.FP32:
        assert quantized_tensor.dtype == torch.float
        return quantized_tensor
    elif comm_precision == SparseType.FP16:
        assert quantized_tensor.dtype == torch.half
        return fp16_to_fp32(quantized_tensor)
    elif comm_precision == SparseType.BF16:
        assert quantized_tensor.dtype == torch.bfloat16
        return bf16_to_fp32(quantized_tensor)
    elif comm_precision == SparseType.FP8:
        assert quantized_tensor.dtype == torch.uint8
        return hfp8_to_fp32(quantized_tensor)
    else:
        raise ValueError(f"comm_precision={comm_precision} is not supported")


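# A minimal round-trip sketch (illustrative only; not part of the original
# file). The two helpers above are inverses up to the precision loss of the
# target format; note that for FP8 the quantized payload travels as raw uint8:
#
#   t = torch.randn(16)
#   q = _quantize_tensor(t, SparseType.FP8)    # q.dtype == torch.uint8
#   d = _dequantize_tensor(q, SparseType.FP8)  # d.dtype == torch.float32

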
class QuantizedCommCodec:
    """Encodes tensors into a lower-precision wire format before a collective
    communication and decodes them back to FP32 afterwards, optionally
    applying a loss scale (FP16/BF16 only) to preserve small gradient
    values."""

    def __init__(
        self,
        comm_precision: SparseType,
        loss_scale: Optional[float] = None,
    ) -> None:
        if loss_scale is not None:
            if comm_precision not in [SparseType.FP16, SparseType.BF16]:
                logger.warning(
                    f"Setting loss scale for comm_precision={comm_precision} is not supported. Overriding to None"
                )
                loss_scale = None

        logger.info(
            f"Creating QuantizedCommCodec comm_precision:{comm_precision}, loss_scale:{loss_scale}"
        )

        self._comm_precision = comm_precision
        self._loss_scale = loss_scale

    def encode(self, input_tensor: torch.Tensor) -> torch.Tensor:
        if self._loss_scale is not None:
            input_tensor = self._loss_scale * input_tensor
        with record_function(
            f"## encoder {self._comm_precision} {self._loss_scale} ##"
        ):
            return _quantize_tensor(input_tensor, self._comm_precision)

    def decode(self, input_grad: torch.Tensor) -> torch.Tensor:
        if self._loss_scale is not None:
            input_grad = input_grad / self._loss_scale
        with record_function(
            f"## decoder {self._comm_precision} {self._loss_scale} ##"
        ):
            return _dequantize_tensor(input_grad, self._comm_precision)

    @property
    def quantized_dtype(self) -> torch.dtype:
        if self._comm_precision == SparseType.FP16:
            return torch.half
        elif self._comm_precision == SparseType.BF16:
            return torch.bfloat16
        elif self._comm_precision == SparseType.FP8:
            return torch.uint8
        return torch.float
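

# Hedged usage sketch (added for illustration; not part of the original file):
# round-trip a gradient tensor through the codec. The shape, values, and
# tolerances below are assumptions chosen for the demo.
if __name__ == "__main__":
    codec = QuantizedCommCodec(SparseType.FP16, loss_scale=128.0)
    grad = torch.randn(2, 8, dtype=torch.float32)
    encoded = codec.encode(grad)  # scaled by loss_scale, then cast to half
    assert encoded.dtype == codec.quantized_dtype
    decoded = codec.decode(encoded)  # unscaled, then cast back to float32
    # FP16 round-trip is lossy; expect agreement only to ~half precision.
    torch.testing.assert_close(decoded, grad, rtol=1e-2, atol=1e-2)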