From 223d7d781410b4f2a0a6bb8dd1a5c0a5df38c88c Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 3 Jul 2025 11:10:09 +0000 Subject: [PATCH] fix lint Signed-off-by: jiqing-feng --- csrc/pythonInterface.cpp | 76 +++---- csrc/xpu_kernels.cpp | 431 ++++++++++++++++++--------------------- csrc/xpu_kernels.h | 89 ++++---- csrc/xpu_ops.cpp | 150 +++++++------- csrc/xpu_ops.h | 43 ++-- 5 files changed, 375 insertions(+), 414 deletions(-) diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index aa577d853..b5d9afc6b 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -314,81 +314,83 @@ void spmm_coo_very_sparse_naive_int8( #if BUILD_XPU void dequantizeBlockwise_fp16( - float *code, unsigned char *A, float *absmax, sycl::half *out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream ) { dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_fp16_fp4( - float *code, unsigned char *A, float *absmax, sycl::half *out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream ) { dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_fp16_nf4( - float *code, unsigned char *A, float *absmax, sycl::half *out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream ) { dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_fp32( - float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream ) { dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); } void 
dequantizeBlockwise_fp32_fp4( - float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream ) { dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_fp32_nf4( - float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream ) { dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_bf16( - float *code, unsigned char *A, float *absmax, sycl::ext::oneapi::bfloat16 *out, int blocksize, const int n, + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, sycl::queue* stream ) { dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_bf16_fp4( - float *code, unsigned char *A, float *absmax, sycl::ext::oneapi::bfloat16 *out, int blocksize, const int n, + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, sycl::queue* stream ) { dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_bf16_nf4( - float *code, unsigned char *A, float *absmax, sycl::ext::oneapi::bfloat16 *out, int blocksize, const int n, + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, sycl::queue* stream -) { +) { dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); } void gemv_4bit_inference_fp16( - int m, int n, int k, sycl::half * A, unsigned char* B, float *absmax, float *datatype, sycl::half * out, - int lda, int ldb, int ldc, int blocksize, sycl::queue* stream + int m, int n, int k, sycl::half* A, unsigned char* B, float* absmax, float* datatype, sycl::half* out, int lda, + 
int ldb, int ldc, int blocksize, sycl::queue* stream ) { - gemv_4bit_inference(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); + gemv_4bit_inference(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } void gemv_4bit_inference_bf16( - int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, float *absmax, float *datatype, - sycl::ext::oneapi::bfloat16 * out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream + int m, int n, int k, sycl::ext::oneapi::bfloat16* A, unsigned char* B, float* absmax, float* datatype, + sycl::ext::oneapi::bfloat16* out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream ) { - gemv_4bit_inference(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); + gemv_4bit_inference( + m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream + ); } void gemv_4bit_inference_fp32( - int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, - int ldb, int ldc, int blocksize, sycl::queue* stream + int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb, + int ldc, int blocksize, sycl::queue* stream ) { - gemv_4bit_inference(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); + gemv_4bit_inference(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } #endif @@ -746,81 +748,81 @@ void cgemm_4bit_inference_naive_fp32( #if BUILD_XPU void cdequantize_blockwise_fp16_fp4( - float *code, unsigned char *A, float *absmax, sycl::half *out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream ) { dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_fp16( - float *code, unsigned char *A, float *absmax, sycl::half *out, int blocksize, const int n, sycl::queue* 
stream + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream ) { dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_fp16_nf4( - float *code, unsigned char *A, float *absmax, sycl::half *out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream ) { dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_fp32( - float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream ) { dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_fp32_fp4( - float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream ) { dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_fp32_nf4( - float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, sycl::queue* stream + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream ) { dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_bf16( - float *code, unsigned char *A, float *absmax, sycl::ext::oneapi::bfloat16 *out, int blocksize, const int n, + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, sycl::queue* stream ) { dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_bf16_fp4( - float *code, unsigned char *A, float *absmax, sycl::ext::oneapi::bfloat16 *out, 
int blocksize, const int n, + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, sycl::queue* stream ) { dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_bf16_nf4( - float *code, unsigned char *A, float *absmax, sycl::ext::oneapi::bfloat16 *out, int blocksize, const int n, + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, sycl::queue* stream ) { dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream); } void cgemv_4bit_inference_fp16( - int m, int n, int k, sycl::half * A, unsigned char* B, float *absmax, float *datatype, sycl::half * out, - int lda, int ldb, int ldc, int blocksize, sycl::queue* stream + int m, int n, int k, sycl::half* A, unsigned char* B, float* absmax, float* datatype, sycl::half* out, int lda, + int ldb, int ldc, int blocksize, sycl::queue* stream ) { - gemv_4bit_inference_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); + gemv_4bit_inference_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } void cgemv_4bit_inference_bf16( - int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, float *absmax, float *datatype, - sycl::ext::oneapi::bfloat16 * out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream + int m, int n, int k, sycl::ext::oneapi::bfloat16* A, unsigned char* B, float* absmax, float* datatype, + sycl::ext::oneapi::bfloat16* out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream ) { - gemv_4bit_inference_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); + gemv_4bit_inference_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } void cgemv_4bit_inference_fp32( - int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, - int ldb, int ldc, int blocksize, 
sycl::queue* stream + int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb, + int ldc, int blocksize, sycl::queue* stream ) { - gemv_4bit_inference_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); + gemv_4bit_inference_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } #endif diff --git a/csrc/xpu_kernels.cpp b/csrc/xpu_kernels.cpp index 9bdbd6e31..efc5e6fbe 100644 --- a/csrc/xpu_kernels.cpp +++ b/csrc/xpu_kernels.cpp @@ -6,281 +6,258 @@ #include inline float dDequantizeFP4(unsigned char val) { - if ((val & 0b1000) == 8) - if ((val & 0b0100) == 4) - if ((val & 0b0010) == 2) - if ((val & 0b0001) == 1) - return -0.25000000f; + if ((val & 0b1000) == 8) + if ((val & 0b0100) == 4) + if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return -0.25000000f; + else + return -0.16666667f; + else if ((val & 0b0001) == 1) + return -0.50000000f; + else + return -0.33333333f; + else if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return -1.00000000f; + else + return -0.66666667f; + else if ((val & 0b0001) == 1) + return -5.208333333e-03f; + else + return 0.00000000f; + else if ((val & 0b0100) == 4) + if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return 0.25000000f; + else + return 0.16666667f; + else if ((val & 0b0001) == 1) + return 0.50000000f; else - return -0.16666667f; - else if ((val & 0b0001) == 1) - return -0.50000000f; - else - return -0.33333333f; + return 0.33333333f; else if ((val & 0b0010) == 2) - if ((val & 0b0001) == 1) - return -1.00000000f; - else - return -0.66666667f; - else if ((val & 0b0001) == 1) - return -5.208333333e-03f; - else - return 0.00000000f; - else if ((val & 0b0100) == 4) - if ((val & 0b0010) == 2) - if ((val & 0b0001) == 1) - return 0.25000000f; - else - return 0.16666667f; + if ((val & 0b0001) == 1) + return 1.00000000f; + else + return 0.66666667f; else if ((val & 0b0001) == 1) - return 0.50000000f; + return 
5.208333333e-03f; else - return 0.33333333f; - else if ((val & 0b0010) == 2) - if ((val & 0b0001) == 1) - return 1.00000000f; - else - return 0.66666667f; - else if ((val & 0b0001) == 1) - return 5.208333333e-03f; - else - return 0.00000000f; + return 0.00000000f; } inline float dDequantizeNF4(unsigned char val) { - // the values for this tree was generated by test_normal_map_tree - // in the file tests/test_functional.py - if ((val & 0b1000) == 8) - if ((val & 0b0100) == 4) // 1 - if ((val & 0b0010) == 2) // 11 - if ((val & 0b0001) == 1) // 111 - return 1.0f; //*1111 + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if ((val & 0b1000) == 8) + if ((val & 0b0100) == 4) // 1 + if ((val & 0b0010) == 2) // 11 + if ((val & 0b0001) == 1) // 111 + return 1.0f; //*1111 + else + return 0.7229568362236023f; //*1110 + else if ((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; //*1101 + else + return 0.44070982933044434f; //*1100 + else if ((val & 0b0010) == 2) // 10 + if ((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; //*1011 + else + return 0.24611230194568634f; //*1010 + else if ((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; //*1001 else - return 0.7229568362236023f; //*1110 - else if ((val & 0b0001) == 1) // 110 - return 0.5626170039176941f; //*1101 - else - return 0.44070982933044434f; //*1100 - else if ((val & 0b0010) == 2) // 10 - if ((val & 0b0001) == 1) // 101 - return 0.33791524171829224f; //*1011 - else - return 0.24611230194568634f; //*1010 - else if ((val & 0b0001) == 1) // 100 - return 0.16093020141124725f; //*1001 - else - return 0.07958029955625534f; //*1000 + return 0.07958029955625534f; //*1000 - else if ((val & 0b0100) == 4) // 0 - if ((val & 0b0010) == 2) // 01 - if ((val & 0b0001) == 1) // 011 - return 0.0f; //*0111 - else - return -0.09105003625154495f; //*0110 - else if ((val & 0b0001) == 1) // 010 - return -0.18477343022823334f; //*0101 - else - return 
-0.28444138169288635f; //*0100 - else if ((val & 0b0010) == 2) // 00 - if ((val & 0b0001) == 1) // 001 - return -0.39491748809814453f; //*0011 + else if ((val & 0b0100) == 4) // 0 + if ((val & 0b0010) == 2) // 01 + if ((val & 0b0001) == 1) // 011 + return 0.0f; //*0111 + else + return -0.09105003625154495f; //*0110 + else if ((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; //*0101 + else + return -0.28444138169288635f; //*0100 + else if ((val & 0b0010) == 2) // 00 + if ((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; //*0011 + else + return -0.5250730514526367f; //*0010 + else if ((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; //*0001 else - return -0.5250730514526367f; //*0010 - else if ((val & 0b0001) == 1) // 000 - return -0.6961928009986877f; //*0001 - else - return -1.0f; //*0000 + return -1.0f; //*0000 } template -SYCL_EXTERNAL void -kDequantizeBlockwise::operator()( - sycl::nd_item<1> item) const { - const int base_idx = item.get_group(0) * TILE_SIZE; - size_t local_idx = item.get_local_id(0) * NUM_PER_TH; - float local_abs_max = -FLT_MAX; - int local_load_idx = 0; - int local_store_idx = 0; +SYCL_EXTERNAL void kDequantizeBlockwise::operator()(sycl::nd_item<1> item) const { + const int base_idx = item.get_group(0) * TILE_SIZE; + size_t local_idx = item.get_local_id(0) * NUM_PER_TH; + float local_abs_max = -FLT_MAX; + int local_load_idx = 0; + int local_store_idx = 0; - uint8_t qvals[NUM_PER_TH]; - T vals[NUM_PER_TH * ((DATA_TYPE > 0) ? 2 : 1)]; + uint8_t qvals[NUM_PER_TH]; + T vals[NUM_PER_TH * ((DATA_TYPE > 0) ?
2 : 1)]; - if (DATA_TYPE > 0) { - local_load_idx = sycl::min(TILE_SIZE, (n + 1) / 2 - base_idx); - local_store_idx = sycl::min(TILE_SIZE * 2, n - base_idx * 2); - } else { - local_load_idx = sycl::min(TILE_SIZE, n - base_idx); - local_store_idx = local_load_idx; - } + if (DATA_TYPE > 0) { + local_load_idx = sycl::min(TILE_SIZE, (n + 1) / 2 - base_idx); + local_store_idx = sycl::min(TILE_SIZE * 2, n - base_idx * 2); + } else { + local_load_idx = sycl::min(TILE_SIZE, n - base_idx); + local_store_idx = local_load_idx; + } - // Avoid expensive divsion by the blocksize (as blocksize will always be a - // power-of-2) - local_abs_max = absmax[(base_idx + local_idx) >> - (31 - std::countl_zero(blocksize))]; + // Avoid expensive division by the blocksize (as blocksize will always be a + // power-of-2) + local_abs_max = absmax[(base_idx + local_idx) >> (31 - std::countl_zero(blocksize))]; - if (local_idx + NUM_PER_TH < local_load_idx) { - reinterpret_cast(&)[NUM_PER_TH]>(qvals)[0] = - reinterpret_cast *>( - A)[(base_idx + local_idx) / NUM_PER_TH]; - } else { + if (local_idx + NUM_PER_TH < local_load_idx) { + reinterpret_cast(&)[NUM_PER_TH]>(qvals)[0] = + reinterpret_cast*>(A)[(base_idx + local_idx) / NUM_PER_TH]; + } else { #pragma unroll NUM_PER_TH - for (int i = 0; i < NUM_PER_TH; i++) { - if (local_idx + i < local_load_idx) { - qvals[i] = A[base_idx + local_idx + i]; - } else { - qvals[i] = (uint8_t)0; - } + for (int i = 0; i < NUM_PER_TH; i++) { + if (local_idx + i < local_load_idx) { + qvals[i] = A[base_idx + local_idx + i]; + } else { + qvals[i] = (uint8_t)0; + } + } } - } - switch (DATA_TYPE) { - case General8bit: + switch (DATA_TYPE) { + case General8bit: #pragma unroll NUM_PER_TH - for (int j = 0; j < NUM_PER_TH; j++) - vals[j] = code[qvals[j]] * local_abs_max; - break; - case FP4: + for (int j = 0; j < NUM_PER_TH; j++) + vals[j] = code[qvals[j]] * local_abs_max; + break; + case FP4: #pragma unroll NUM_PER_TH - for (int j = 0; j < NUM_PER_TH; j++) { - vals[j * 2] = 
dDequantizeFP4(qvals[j] >> 4) * local_abs_max; - vals[j * 2 + 1] = dDequantizeFP4(qvals[j] & 0x0F) * local_abs_max; - } - break; - case NF4: + for (int j = 0; j < NUM_PER_TH; j++) { + vals[j * 2] = dDequantizeFP4(qvals[j] >> 4) * local_abs_max; + vals[j * 2 + 1] = dDequantizeFP4(qvals[j] & 0x0F) * local_abs_max; + } + break; + case NF4: #pragma unroll NUM_PER_TH - for (int j = 0; j < NUM_PER_TH; j++) { - vals[j * 2] = dDequantizeNF4(qvals[j] >> 4) * local_abs_max; - vals[j * 2 + 1] = dDequantizeNF4(qvals[j] & 0x0F) * local_abs_max; + for (int j = 0; j < NUM_PER_TH; j++) { + vals[j * 2] = dDequantizeNF4(qvals[j] >> 4) * local_abs_max; + vals[j * 2 + 1] = dDequantizeNF4(qvals[j] & 0x0F) * local_abs_max; + } + break; } - break; - } - const int local_dst_size = (DATA_TYPE > 0) ? NUM_PER_TH * 2 : NUM_PER_TH; - int local_dst_idx = (DATA_TYPE > 0) ? local_idx * 2 : local_idx; + const int local_dst_size = (DATA_TYPE > 0) ? NUM_PER_TH * 2 : NUM_PER_TH; + int local_dst_idx = (DATA_TYPE > 0) ? local_idx * 2 : local_idx; - if (local_dst_idx + local_dst_size < local_store_idx) { - reinterpret_cast *>( - out)[(((DATA_TYPE > 0) ? base_idx * 2 : base_idx) + local_dst_idx) / - local_dst_size] = - reinterpret_cast(&)[local_dst_size]>( - vals)[0]; - } else { + if (local_dst_idx + local_dst_size < local_store_idx) { + reinterpret_cast*>( + out + )[(((DATA_TYPE > 0) ? base_idx * 2 : base_idx) + local_dst_idx) / local_dst_size] = + reinterpret_cast(&)[local_dst_size]>(vals)[0]; + } else { #pragma unroll NUM_PER_TH - for (int i = 0; i < local_dst_size; i++) { - if (local_dst_idx + i < local_store_idx) { - out[((DATA_TYPE > 0) ? base_idx * 2 : base_idx) + local_dst_idx + i] = - vals[i]; - } + for (int i = 0; i < local_dst_size; i++) { + if (local_dst_idx + i < local_store_idx) { + out[((DATA_TYPE > 0) ? 
base_idx * 2 : base_idx) + local_dst_idx + i] = vals[i]; + } + } } - } } -template +template SYCL_EXTERNAL void -kgemv_4bit_inference::operator()(sycl::nd_item<1> item) const { - size_t idx = item.get_local_id(); - const int sg_idx = idx / SUBG_SIZE; - const int sg_lane = idx % SUBG_SIZE; - const int num_values_4bit = SUBG_SIZE; - const int row_B = NUM_PER_THREAD * item.get_group().get_group_id() + sg_idx; - const int offset_B = ldb * row_B; - const int num_values_8bit = num_values_4bit / 2; - float local_C = 0.0f; + kgemv_4bit_inference::operator()(sycl::nd_item<1> item) const { + size_t idx = item.get_local_id(); + const int sg_idx = idx / SUBG_SIZE; + const int sg_lane = idx % SUBG_SIZE; + const int num_values_4bit = SUBG_SIZE; + const int row_B = NUM_PER_THREAD * item.get_group().get_group_id() + sg_idx; + const int offset_B = ldb * row_B; + const int num_values_8bit = num_values_4bit / 2; + float local_C = 0.0f; - unsigned char local_B_4bit[num_values_8bit]; - T local_B[num_values_4bit / 4]; - T local_A[num_values_4bit / 4]; - T local_absmax = T(0.0f); + unsigned char local_B_4bit[num_values_8bit]; + T local_B[num_values_4bit / 4]; + T local_A[num_values_4bit / 4]; + T local_absmax = T(0.0f); - if (idx < 16) { - quant_map[idx] = T(datatype[idx]); - } + if (idx < 16) { + quant_map[idx] = T(datatype[idx]); + } - item.barrier(sycl::access::fence_space::local_space); + item.barrier(sycl::access::fence_space::local_space); - for (int inner_idx = sg_lane * num_values_4bit; inner_idx < K; - inner_idx += SUBG_SIZE * num_values_4bit) { - const int inner_idx_halved = inner_idx / 2; + for (int inner_idx = sg_lane * num_values_4bit; inner_idx < K; inner_idx += SUBG_SIZE * num_values_4bit) { + const int inner_idx_halved = inner_idx / 2; - // Avoid expensive divsion by the blocksize (as blocksize will always be a - // power-of-2) - const int absidx = ((2 * offset_B) + inner_idx) >> - (31 - std::countl_zero((unsigned int)blocksize)); - local_absmax = absmax[absidx]; + //
Avoid expensive division by the blocksize (as blocksize will always be a + // power-of-2) + const int absidx = ((2 * offset_B) + inner_idx) >> (31 - std::countl_zero((unsigned int)blocksize)); + local_absmax = absmax[absidx]; - if (row_B < N) { - if ((inner_idx_halved + num_values_8bit) < (K / 2)) { - reinterpret_cast(&)[num_values_8bit]>( - local_B_4bit)[0] = - reinterpret_cast *>( - B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)]; - } else { + if (row_B < N) { + if ((inner_idx_halved + num_values_8bit) < (K / 2)) { + reinterpret_cast(&)[num_values_8bit]>(local_B_4bit)[0] = + reinterpret_cast*>(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)]; + } else { #pragma unroll - for (int j = 0; j < (num_values_8bit); j++) - if ((inner_idx_halved) + j < (K / 2)) - local_B_4bit[j] = B[offset_B + inner_idx_halved + j]; - else - local_B_4bit[j] = 0b01110111; - } - } else { + for (int j = 0; j < (num_values_8bit); j++) + if ((inner_idx_halved) + j < (K / 2)) + local_B_4bit[j] = B[offset_B + inner_idx_halved + j]; + else + local_B_4bit[j] = 0b01110111; + } + } else { #pragma unroll - for (int j = 0; j < (num_values_8bit); j++) - local_B_4bit[j] = 0b01110111; - } + for (int j = 0; j < (num_values_8bit); j++) + local_B_4bit[j] = 0b01110111; + } - for (int i = 0; i < 4; i++) { + for (int i = 0; i < 4; i++) { #pragma unroll - for (int k = 0; k < num_values_8bit / 4; k++) { - local_B[k * 2] = - quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * - local_absmax; - local_B[k * 2 + 1] = - quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * - local_absmax; - } + for (int k = 0; k < num_values_8bit / 4; k++) { + local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax; + local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax; + } - if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) { - if (BITS == 16) { - reinterpret_cast(&)[num_values_4bit / 4]>( - 
local_A)[0] = - reinterpret_cast *>( - A)[inner_idx / (num_values_4bit / 4) + i]; - } else { - reinterpret_cast(&)[num_values_4bit / 4]>( - local_A)[0] = - reinterpret_cast *>( - A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0]; - reinterpret_cast(&)[num_values_4bit / 4]>( - local_A)[1] = - reinterpret_cast *>( - A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1]; - } + if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) { + if (BITS == 16) { + reinterpret_cast(&)[num_values_4bit / 4]>(local_A)[0] = + reinterpret_cast*>(A)[inner_idx / (num_values_4bit / 4) + i]; + } else { + reinterpret_cast(&)[num_values_4bit / 4]>(local_A)[0] = + reinterpret_cast*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0]; + reinterpret_cast(&)[num_values_4bit / 4]>(local_A)[1] = + reinterpret_cast*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1]; + } - } else { + } else { #pragma unroll - for (int k = 0; k < num_values_4bit / 4; k++) - if (inner_idx + (i * num_values_4bit / 4) + k < K) - local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)]; - else - local_A[k] = T(0.0f); - } + for (int k = 0; k < num_values_4bit / 4; k++) + if (inner_idx + (i * num_values_4bit / 4) + k < K) + local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)]; + else + local_A[k] = T(0.0f); + } // accumulate in float for accuracy; #pragma unroll - for (int k = 0; k < num_values_4bit / 4; k++) { - local_C += (float)(local_A[k] * local_B[k]); - } + for (int k = 0; k < num_values_4bit / 4; k++) { + local_C += (float)(local_A[k] * local_B[k]); + } + } } - } - local_C = - sycl::reduce_over_group(item.get_sub_group(), local_C, sycl::plus<>()); + local_C = sycl::reduce_over_group(item.get_sub_group(), local_C, sycl::plus<>()); - if (row_B < N && sg_lane == 0) - out[row_B] = T(local_C); + if (row_B < N && sg_lane == 0) + out[row_B] = T(local_C); } //============================================================== @@ -296,11 +273,9 @@ template class kDequantizeBlockwise; template class 
kDequantizeBlockwise; template class kDequantizeBlockwise; -template class kDequantizeBlockwise; +template class kDequantizeBlockwise; template class kDequantizeBlockwise; template class kgemv_4bit_inference; -template class kgemv_4bit_inference; +template class kgemv_4bit_inference; template class kgemv_4bit_inference; diff --git a/csrc/xpu_kernels.h b/csrc/xpu_kernels.h index e5a115ced..bad6d4ca8 100644 --- a/csrc/xpu_kernels.h +++ b/csrc/xpu_kernels.h @@ -4,56 +4,49 @@ #ifndef xpu_kernels #define xpu_kernels -template -class kDequantizeBlockwise { -public: - SYCL_EXTERNAL void operator()(sycl::nd_item<1> item) const; - - kDequantizeBlockwise(float *code_, uint8_t *A_, float *absmax_, T *out_, - const int blocksize_, const int n_) - : code(code_), A(A_), absmax(absmax_), out(out_), blocksize(blocksize_), - n(n_) {} - -private: - float *code; - uint8_t *A; - float *absmax; - T *out; - const int blocksize; - const int n; +template class kDequantizeBlockwise { + public: + SYCL_EXTERNAL void operator()(sycl::nd_item<1> item) const; + + kDequantizeBlockwise(float* code_, uint8_t* A_, float* absmax_, T* out_, const int blocksize_, const int n_) + : code(code_), A(A_), absmax(absmax_), out(out_), blocksize(blocksize_), n(n_) {} + + private: + float* code; + uint8_t* A; + float* absmax; + T* out; + const int blocksize; + const int n; }; -template -class kgemv_4bit_inference { -public: - SYCL_EXTERNAL void operator()(sycl::nd_item<1> item) const; - - kgemv_4bit_inference(int M_, int N_, int K_, T *A_, unsigned char *B_, - float *absmax_, const float *datatype_, T *out_, - int lda_, int ldb_, int ldc_, int blocksize_) - : M(M_), N(N_), K(K_), A(A_), B(B_), absmax(absmax_), datatype(datatype_), - out(out_), lda(lda_), ldb(ldb_), ldc(ldc_), blocksize(blocksize_), - quant_map() {} - - void sycl_ker_local_memory_creation(sycl::handler &cgh) { - quant_map = sycl::local_accessor(16, cgh); - } - -private: - int M; - int N; - int K; - T *A; - unsigned char *B; - float *absmax; -
const float *datatype; - T *out; - int lda; - int ldb; - int ldc; - int blocksize; - sycl::local_accessor quant_map; +template class kgemv_4bit_inference { + public: + SYCL_EXTERNAL void operator()(sycl::nd_item<1> item) const; + + kgemv_4bit_inference( + int M_, int N_, int K_, T* A_, unsigned char* B_, float* absmax_, const float* datatype_, T* out_, int lda_, + int ldb_, int ldc_, int blocksize_ + ) + : M(M_), N(N_), K(K_), A(A_), B(B_), absmax(absmax_), datatype(datatype_), out(out_), lda(lda_), ldb(ldb_), + ldc(ldc_), blocksize(blocksize_), quant_map() {} + + void sycl_ker_local_memory_creation(sycl::handler& cgh) { quant_map = sycl::local_accessor(16, cgh); } + + private: + int M; + int N; + int K; + T* A; + unsigned char* B; + float* absmax; + const float* datatype; + T* out; + int lda; + int ldb; + int ldc; + int blocksize; + sycl::local_accessor quant_map; }; #endif diff --git a/csrc/xpu_ops.cpp b/csrc/xpu_ops.cpp index c1feb3996..37ef92973 100644 --- a/csrc/xpu_ops.cpp +++ b/csrc/xpu_ops.cpp @@ -3,54 +3,52 @@ #include template -void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, - int blocksize, const int n, sycl::queue *stream) { - auto &queue = *stream; - const int workgroup_size = 128; - const int num_per_th = 4; - const int tile_size = workgroup_size * num_per_th; - if (DATA_TYPE > 0) { - const int workgroup_num = (n + tile_size * 2 - 1) / (tile_size * 2); - sycl::range<1> local_range{(size_t)workgroup_size}; - sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size}; - kDequantizeBlockwise kfn( - code, A, absmax, out, blocksize / 2, n); - sycl_kernel_submit( - sycl::nd_range<1>(sycl::range<1>(global_range), - sycl::range<1>(local_range)), - queue, kfn); - } else { - const int workgroup_num = (n + tile_size - 1) / tile_size; - sycl::range<1> local_range{(size_t)workgroup_size}; - sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size}; - kDequantizeBlockwise kfn( - code, A, absmax,
out, blocksize, n); - sycl_kernel_submit( - sycl::nd_range<1>(sycl::range<1>(global_range), - sycl::range<1>(local_range)), - queue, kfn); - } +void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, T* out, int blocksize, const int n, sycl::queue* stream +) { + auto& queue = *stream; + const int workgroup_size = 128; + const int num_per_th = 4; + const int tile_size = workgroup_size * num_per_th; + if (DATA_TYPE > 0) { + const int workgroup_num = (n + tile_size * 2 - 1) / (tile_size * 2); + sycl::range<1> local_range{(size_t)workgroup_size}; + sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size}; + kDequantizeBlockwise kfn(code, A, absmax, out, blocksize / 2, n); + sycl_kernel_submit( + sycl::nd_range<1>(sycl::range<1>(global_range), sycl::range<1>(local_range)), queue, kfn + ); + } else { + const int workgroup_num = (n + tile_size - 1) / tile_size; + sycl::range<1> local_range{(size_t)workgroup_size}; + sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size}; + kDequantizeBlockwise kfn(code, A, absmax, out, blocksize, n); + sycl_kernel_submit( + sycl::nd_range<1>(sycl::range<1>(global_range), sycl::range<1>(local_range)), queue, kfn + ); + } } template -void gemv_4bit_inference(int m, int n, int k, T *A, unsigned char *B, - float *absmax, float *datatype, T *out, int lda, - int ldb, int ldc, int blocksize, sycl::queue *stream) { +void gemv_4bit_inference( + int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, + int blocksize, sycl::queue* stream +) { - auto &queue = *stream; + auto& queue = *stream; - const size_t GROUP_SIZE = 128; // workgroup_size - const size_t SUBG_SIZE = 32; // subgroup_size - const size_t NUM_PER_THREAD = GROUP_SIZE / SUBG_SIZE; - size_t workgroup_num = (n + NUM_PER_THREAD - 1) / NUM_PER_THREAD; + const size_t GROUP_SIZE = 128; // workgroup_size + const size_t SUBG_SIZE = 32; // subgroup_size + const size_t
NUM_PER_THREAD = GROUP_SIZE / SUBG_SIZE; + size_t workgroup_num = (n + NUM_PER_THREAD - 1) / NUM_PER_THREAD; - kgemv_4bit_inference kfn( - m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); + kgemv_4bit_inference kfn( + m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize + ); - sycl_comp_kernel_submit( - sycl::nd_range<1>(sycl::range<1>(GROUP_SIZE * workgroup_num), - sycl::range<1>(GROUP_SIZE)), - queue, kfn); + sycl_comp_kernel_submit( + sycl::nd_range<1>(sycl::range<1>(GROUP_SIZE * workgroup_num), sycl::range<1>(GROUP_SIZE)), queue, kfn + ); } //============================================================== @@ -58,51 +56,47 @@ void gemv_4bit_inference(int m, int n, int k, T *A, unsigned char *B, //============================================================== template void dequantizeBlockwise( - float *code, unsigned char *A, float *absmax, float *out, int blocksize, - const int n, sycl::queue *stream); -template void dequantizeBlockwise(float *code, unsigned char *A, - float *absmax, float *out, - int blocksize, const int n, - sycl::queue *stream); -template void dequantizeBlockwise(float *code, unsigned char *A, - float *absmax, float *out, - int blocksize, const int n, - sycl::queue *stream); + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +); template void dequantizeBlockwise( - float *code, unsigned char *A, float *absmax, sycl::half *out, - int blocksize, const int n, sycl::queue *stream); + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +); template void dequantizeBlockwise( - float *code, unsigned char *A, float
*absmax, sycl::half *out, - int blocksize, const int n, sycl::queue *stream); + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +); template void dequantizeBlockwise( - float *code, unsigned char *A, float *absmax, sycl::half *out, - int blocksize, const int n, sycl::queue *stream); + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +); template void dequantizeBlockwise( - float *code, unsigned char *A, float *absmax, - sycl::ext::oneapi::bfloat16 *out, int blocksize, const int n, - sycl::queue *stream); + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +); template void dequantizeBlockwise( - float *code, unsigned char *A, float *absmax, - sycl::ext::oneapi::bfloat16 *out, int blocksize, const int n, - sycl::queue *stream); + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +); template void dequantizeBlockwise( - float *code, unsigned char *A, float *absmax, - sycl::ext::oneapi::bfloat16 *out, int blocksize, const int n, - sycl::queue *stream); + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +); template void gemv_4bit_inference( - int m, int n, int k, sycl::half *A, unsigned char *B, float *absmax, - float *datatype, sycl::half *out, int lda, int ldb, int ldc, int blocksize, - sycl::queue *stream); + int m, int n, int k, sycl::half* A, unsigned char* B, float* absmax, float* datatype, sycl::half* out, int lda, + int ldb, int ldc, int blocksize, sycl::queue* stream +); template void gemv_4bit_inference( - int m, int n, int k, sycl::ext::oneapi::bfloat16 *A, unsigned char *B, - float *absmax, float *datatype, sycl::ext::oneapi::bfloat16 *out, int lda, - int ldb, int ldc, int blocksize, sycl::queue 
*stream); -template void gemv_4bit_inference(int m, int n, int k, float *A, - unsigned char *B, float *absmax, - float *datatype, float *out, - int lda, int ldb, int ldc, - int blocksize, - sycl::queue *stream); + int m, int n, int k, sycl::ext::oneapi::bfloat16* A, unsigned char* B, float* absmax, float* datatype, + sycl::ext::oneapi::bfloat16* out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream +); +template void gemv_4bit_inference( + int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb, + int ldc, int blocksize, sycl::queue* stream +); diff --git a/csrc/xpu_ops.h b/csrc/xpu_ops.h index 3045283a9..fa395fcc4 100644 --- a/csrc/xpu_ops.h +++ b/csrc/xpu_ops.h @@ -12,38 +12,35 @@ #include template -static inline void sycl_kernel_submit(sycl::nd_range range, sycl::queue q, - ker_t ker) { - auto cgf = [&](::sycl::handler & cgh) - [[sycl::reqd_sub_group_size(subgroup_size)]] { - cgh.parallel_for(range, ker); - }; - q.submit(cgf); +static inline void sycl_kernel_submit(sycl::nd_range range, sycl::queue q, ker_t ker) { + auto cgf = [&](::sycl::handler& cgh) + [[sycl::reqd_sub_group_size(subgroup_size)]] { cgh.parallel_for(range, ker); }; + q.submit(cgf); } template -static inline void sycl_comp_kernel_submit(sycl::nd_range range, - sycl::queue q, ker_t ker) { - auto cgf = [&](::sycl::handler & cgh) - [[sycl::reqd_sub_group_size(subgroup_size)]] { - ker.sycl_ker_local_memory_creation(cgh); - cgh.parallel_for(range, ker); - }; - q.submit(cgf); +static inline void sycl_comp_kernel_submit(sycl::nd_range range, sycl::queue q, ker_t ker) { + auto cgf = [&](::sycl::handler& cgh) [[sycl::reqd_sub_group_size(subgroup_size)]] { + ker.sycl_ker_local_memory_creation(cgh); + cgh.parallel_for(range, ker); + }; + q.submit(cgf); } typedef enum DataType_t { - General8bit = 0, - FP4 = 1, - NF4 = 2, + General8bit = 0, + FP4 = 1, + NF4 = 2, } DataType_t; template -void dequantizeBlockwise(float *code, unsigned char *A, 
float *absmax, T *out, - int workgroup_size, const int n, sycl::queue *stream); +void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, T* out, int workgroup_size, const int n, sycl::queue* stream +); template -void gemv_4bit_inference(int m, int n, int k, T *A, unsigned char *B, - float *absmax, float *datatype, T *out, int lda, - int ldb, int ldc, int blocksize, sycl::queue *stream); +void gemv_4bit_inference( + int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, + int blocksize, sycl::queue* stream +); #endif