|
7 | 7 | #include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
8 | 8 | #include "core/scalar_type.hpp"
|
9 | 9 |
|
10 |
| -#define MARLIN_KERNEL_PARAMS \ |
11 |
| - const int4 *__restrict__ A, const int4 *__restrict__ B, \ |
12 |
| - int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ |
13 |
| - const int4 *__restrict__ scales_ptr, \ |
14 |
| - const uint16_t *__restrict__ scale2_ptr, \ |
15 |
| - const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ |
16 |
| - const int32_t *__restrict__ sorted_token_ids_ptr, \ |
17 |
| - const int32_t *__restrict__ expert_ids_ptr, \ |
18 |
| - const int32_t *__restrict__ num_tokens_past_padded_ptr, \ |
19 |
| - const float *__restrict__ topk_weights_ptr, int top_k, \ |
20 |
| - bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ |
21 |
| - int prob_n, int prob_k, int *locks, bool use_atomic_add, \ |
| 10 | +#define MARLIN_KERNEL_PARAMS \ |
| 11 | + const int4 *__restrict__ A, const int4 *__restrict__ B, \ |
| 12 | + int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ |
| 13 | + const int4 *__restrict__ b_bias_ptr, \ |
| 14 | + const int4 *__restrict__ scales_ptr, \ |
| 15 | + const uint16_t *__restrict__ scale2_ptr, \ |
| 16 | + const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ |
| 17 | + const int32_t *__restrict__ sorted_token_ids_ptr, \ |
| 18 | + const int32_t *__restrict__ expert_ids_ptr, \ |
| 19 | + const int32_t *__restrict__ num_tokens_past_padded_ptr, \ |
| 20 | + const float *__restrict__ topk_weights_ptr, int top_k, \ |
| 21 | + bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ |
| 22 | + int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \ |
22 | 23 | bool use_fp32_reduce, int max_shared_mem
|
23 | 24 |
|
24 | 25 | namespace MARLIN_NAMESPACE_NAME {
|
25 | 26 | template <typename scalar_t, // compute dtype, half or nv_float16
|
26 | 27 | const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
| 28 | + const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id |
27 | 29 | const int threads, // number of threads in a threadblock
|
28 | 30 | const int thread_m_blocks, // number of 16x16 blocks in the m
|
29 | 31 | // dimension (batchsize) of the
|
|
0 commit comments