Commit 0e00df3

Add operation utils.
Differential Revision: D77760997
Pull Request resolved: #2540
1 parent 11ce634 commit 0e00df3

1 file changed: 126 additions, 0 deletions
@@ -0,0 +1,126 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <torchao/experimental/ops/groupwise_lowbit_weight_lut/kernel_config.h>
#include <optional>
#include <vector>

namespace torchao::ops::groupwise_lowbit_weight_lut {

/**
 * @brief Orchestrates the packing of quantized weights into a kernel-specific
 * memory layout.
 *
 * @details This function acts as a high-level operator that parallelizes the
 * weight packing process across the N dimension. It partitions the work into
 * tiles, calculates the correct memory offsets for each tile's source and
 * destination pointers, and then invokes the low-level `pack_weights` function
 * provided by the kernel configuration (`uk`).
 *
 * @param uk The kernel configuration, providing layout details, function
 * pointers, and dimension constraints (nr, kr).
 * @param packed_weights_ptr [out] The destination buffer for the packed weight
 * data.
 * @param n The N dimension of the weight matrix (e.g., output channels).
 * @param k The K dimension of the weight matrix (e.g., input channels).
 * @param scale_group_size The group size for weight quantization scales.
 * @param lut_group_size The group size for weight lookup tables (LUTs).
 * @param weight_qval_indices [in] Pointer to the raw quantized weight indices.
 * @param weight_scales [in] Pointer to the raw weight quantization scales.
 * @param weight_luts [in] Pointer to the raw weight lookup tables.
 * @param bias [in] Pointer to the raw bias values; can be nullptr if the kernel
 * configuration indicates no bias is used.
 */
void pack_weights_operator(
    const UKernelConfig& uk,
    // Outputs
    void* packed_weights_ptr,
    // Inputs
    int n,
    int k,
    int scale_group_size,
    int lut_group_size,
    const uint8_t* weight_qval_indices,
    const float* weight_scales,
    const float* weight_luts,
    const float* bias);
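
// A minimal, hypothetical usage sketch for pack_weights_operator. The kernel
// config `uk` and the packed-buffer size are assumed to be obtained from the
// kernel registry / layout helpers (not shown here); the group sizes and the
// raw weight buffers are illustrative placeholders, not values from this commit.
//
//   std::vector<char> packed_weights(packed_weights_size);  // size from the kernel layout (assumed)
//   pack_weights_operator(
//       uk,
//       /*packed_weights_ptr=*/packed_weights.data(),
//       /*n=*/n,
//       /*k=*/k,
//       /*scale_group_size=*/32,
//       /*lut_group_size=*/256,
//       weight_qval_indices,  // const uint8_t*, raw quantized indices
//       weight_scales,        // const float*, one scale per scale group
//       weight_luts,          // const float*, LUT entries per LUT group
//       /*bias=*/nullptr);    // nullptr when the kernel config uses no bias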

struct GroupwiseTilingParams {
  int mc;
  int nc;

  /**
   * @brief Calculates groupwise tiling parameters based on a target number of
   * tiles per thread.
   *
   * @details This function implements a heuristic to determine optimal tile
   * sizes (`mc`, `nc`) for balancing a computational workload across multiple
   * threads. It calculates the number of tiles needed to cover the M dimension
   * and uses this, along with the target number of tiles per thread, to derive
   * a suitable tile count in the N dimension. This count is then scaled by
   * `n_step` to get the final `nc` value. The resulting tile sizes are clamped
   * to not exceed the original problem dimensions.
   *
   * @param m The total size of the M dimension (e.g., rows).
   * @param m_step The required step size for tiling in the M dimension.
   * @param n The total size of the N dimension (e.g., columns).
   * @param n_step The required step size for tiling in the N dimension.
   * @param target_tiles_per_thread A tuning parameter that suggests how many
   * tiles each thread should ideally process, influencing the calculated tile
   * sizes.
   * @return A `GroupwiseTilingParams` struct containing the computed `mc` and
   * `nc`.
   */
  static GroupwiseTilingParams from_target_tiles_per_thread(
      int m,
      int m_step,
      int n,
      int n_step,
      int target_tiles_per_thread);
};
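
// A hypothetical sketch of deriving tiling parameters before calling the
// parallel operator below. The step sizes are assumed to come from the kernel
// configuration, and the target of 5 tiles per thread is only an illustrative
// tuning choice.
//
//   GroupwiseTilingParams tiling =
//       GroupwiseTilingParams::from_target_tiles_per_thread(
//           /*m=*/m,
//           /*m_step=*/m_step,
//           /*n=*/n,
//           /*n_step=*/n_step,
//           /*target_tiles_per_thread=*/5);
//   // tiling.mc and tiling.nc are clamped so they never exceed m and n.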

/**
 * @brief Executes a parallel linear operation using a groupwise low-bit LUT
 * kernel.
 *
 * @details This function acts as a high-level operator for performing a linear
 * operation (GEMM-like) with quantized weights.
 *
 * @param uk The kernel configuration, providing layout details and function
 * pointers.
 * @param tiling_params [in] Optional. User-provided tiling parameters (mc, nc).
 * If not provided, the operator will calculate them dynamically.
 * @param output [out] The destination buffer for the output matrix.
 * @param m The M dimension of the output matrix (e.g., rows).
 * @param n The N dimension of the output matrix (e.g., columns).
 * @param k The K dimension, shared between the weights and activations.
 * @param scale_group_size The group size for weight quantization scales.
 * @param lut_group_size The group size for weight lookup tables (LUTs).
 * @param packed_weights [in] Pointer to the pre-packed weight data.
 * @param activations [in] Pointer to the raw activation data.
 * @param has_clamp A boolean flag indicating whether to apply clamping to the
 * output.
 * @param clamp_min The minimum value for output clamping.
 * @param clamp_max The maximum value for output clamping.
 */
void groupwise_lowbit_weight_lut_parallel_operator(
    const UKernelConfig& uk,
    const std::optional<GroupwiseTilingParams>& tiling_params,
    // Outputs
    float* output,
    // Inputs
    int m,
    int n,
    int k,
    int scale_group_size,
    int lut_group_size,
    const void* packed_weights,
    const float* activations,
    bool has_clamp,
    float clamp_min,
    float clamp_max);
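
// A minimal, hypothetical end-to-end sketch: weights are packed once with
// pack_weights_operator above, then reused across calls to the parallel linear
// operator. Buffer names, group sizes, and the m x k activation layout are
// assumptions for illustration, not values defined by this header.
//
//   std::vector<float> output(static_cast<size_t>(m) * n);
//   groupwise_lowbit_weight_lut_parallel_operator(
//       uk,
//       /*tiling_params=*/std::nullopt,  // omit to let the operator choose tiling
//       output.data(),
//       /*m=*/m,
//       /*n=*/n,
//       /*k=*/k,
//       /*scale_group_size=*/32,
//       /*lut_group_size=*/256,
//       packed_weights.data(),
//       activations,          // const float*, m x k activation matrix
//       /*has_clamp=*/false,
//       /*clamp_min=*/0.0f,
//       /*clamp_max=*/0.0f);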
} // namespace torchao::ops::groupwise_lowbit_weight_lut
