Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions sgl-kernel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ if("${CUDA_VERSION}" VERSION_EQUAL "12.8")
elseif("${CUDA_VERSION}" VERSION_EQUAL "12.9")
set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM")
set(DeepGEMM_TAG "blackwell")
elseif("${CUDA_VERSION}" VERSION_EQUAL "13.0")
set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM")
set(DeepGEMM_TAG "blackwell")
else()
set(DeepGEMM_REPO "https://github.com/deepseek-ai/DeepGEMM")
set(DeepGEMM_TAG "391755ada0ffefa9a6a52b6f14dcaf22d1a463e0")
Expand All @@ -83,7 +86,7 @@ FetchContent_Populate(repo-triton)
FetchContent_Declare(
repo-flashinfer
GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
GIT_TAG 9220fb3443b5a5d274f00ca5552f798e225239b7
GIT_TAG 018b551825c8e5579206e6eb9d3229fa679202b3
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-flashinfer)
Expand Down Expand Up @@ -179,11 +182,28 @@ if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_100,code=sm_100"
"-gencode=arch=compute_100a,code=sm_100a"
"-gencode=arch=compute_101,code=sm_101"
"-gencode=arch=compute_101a,code=sm_101a"
"-gencode=arch=compute_103,code=sm_103"
"-gencode=arch=compute_103a,code=sm_103a"
"-gencode=arch=compute_120,code=sm_120"
"-gencode=arch=compute_120a,code=sm_120a"
)

  # Refer to the descriptions of sm_121, sm_110 and sm_101 in https://github.com/pytorch/pytorch/pull/156176
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0")
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_110,code=sm_110"
"-gencode=arch=compute_110a,code=sm_110a"
"-gencode=arch=compute_121,code=sm_121"
"-gencode=arch=compute_121a,code=sm_121a"
"--compress-mode=size"
)
else()
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_101,code=sm_101"
"-gencode=arch=compute_101a,code=sm_101a"
)
endif()

else()
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-use_fast_math"
Expand Down Expand Up @@ -266,12 +286,6 @@ set(SOURCES
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
"csrc/moe/marlin_moe_wna16/ops.cu"
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu"
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu"
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu"
"csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu"
"csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu"
"csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu"
"csrc/moe/moe_align_kernel.cu"
"csrc/moe/moe_fused_gate.cu"
"csrc/moe/moe_topk_softmax_kernels.cu"
Expand Down
27 changes: 25 additions & 2 deletions sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
FILE_HEAD = """
// auto generated by generate.py
// clang-format off
#pragma once

#include "kernel.h"
#include "marlin_template.h"
Expand All @@ -33,6 +34,17 @@
"( MARLIN_KERNEL_PARAMS );"
)

KERNEL_FILE_TEMPLATE = (
"// auto generated by generate.py\n"
"// clang-format off\n"
"#pragma once\n\n"
"{% for kernel_file in kernel_files %}"
'#include "{{ kernel_file }}"\n'
"{% endfor %}"
)

KERNEL_FILE_NAME = "kernel_marlin.cuh"

# int8 with zero point case (sglang::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"]
Expand All @@ -48,11 +60,12 @@


def remove_old_kernels():
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cuh"):
subprocess.call(["rm", "-f", filename])


def generate_new_kernels():
kernel_files = set()
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
has_zp = "B" not in scalar_type
all_template_str_list = []
Expand Down Expand Up @@ -95,10 +108,20 @@ def generate_new_kernels():

file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cu"
filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cuh"

with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content)
kernel_files.add(filename)

kernel_files = list(kernel_files)
kernel_files.sort()

file_content = jinja2.Template(KERNEL_FILE_TEMPLATE).render(
kernel_files=kernel_files
)
with open(os.path.join(os.path.dirname(__file__), KERNEL_FILE_NAME), "w") as f:
f.write(file_content)


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#pragma once

#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// auto generated by generate.py
// clang-format off
#pragma once

#include "kernel.h"
#include "marlin_template.h"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// auto generated by generate.py
// clang-format off
#pragma once

#include "kernel.h"
#include "marlin_template.h"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// auto generated by generate.py
// clang-format off
#pragma once

#include "kernel.h"
#include "marlin_template.h"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// auto generated by generate.py
// clang-format off
#pragma once

#include "kernel.h"
#include "marlin_template.h"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// auto generated by generate.py
// clang-format off
#pragma once

#include "kernel.h"
#include "marlin_template.h"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// auto generated by generate.py
// clang-format off
#pragma once

#include "kernel.h"
#include "marlin_template.h"
Expand Down
10 changes: 10 additions & 0 deletions sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// auto generated by generate.py
// clang-format off
#pragma once

#include "kernel_bf16_ku4.cuh"
#include "kernel_bf16_ku4b8.cuh"
#include "kernel_bf16_ku8b128.cuh"
#include "kernel_fp16_ku4.cuh"
#include "kernel_fp16_ku4b8.cuh"
#include "kernel_fp16_ku8b128.cuh"
2 changes: 2 additions & 0 deletions sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
/*
* Adapted from https://github.com/IST-DASLab/marlin
*/
#pragma once

#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
Expand Down
1 change: 1 addition & 0 deletions sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#endif

#include "kernel.h"
#include "kernel_marlin.cuh"

#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert( \
Expand Down
16 changes: 13 additions & 3 deletions sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#ifndef USE_ROCM
#include <cub/cub.cuh>
#include <cub/util_type.cuh>
#include <cuda/functional>
#else
#include <hipcub/hipcub.hpp>
#include <hipcub/util_type.hpp>
Expand All @@ -33,6 +34,16 @@ limitations under the License.
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

// Define reduction operators based on the CUDA version.
// CUDA 12.9+ (CCCL) deprecated cub::Max/cub::Min in favor of cuda::maximum/cuda::minimum.
#if CUDA_VERSION >= 12090
using MaxReduceOp = cuda::maximum<>;
using MinReduceOp = cuda::minimum<>;
#else
using MaxReduceOp = cub::Max;
using MinReduceOp = cub::Min;
#endif

/// Aligned array type
template <
typename T,
Expand Down Expand Up @@ -72,7 +83,6 @@ __launch_bounds__(TPB) __global__

const int thread_row_offset = blockIdx.x * num_cols;

cub::Sum sum;
float threadData(-FLT_MAX);

// Don't touch finished rows.
Expand All @@ -85,7 +95,7 @@ __launch_bounds__(TPB) __global__
threadData = max(convert_to_float<T>(input[idx]), threadData);
}

const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, MaxReduceOp());

if (threadIdx.x == 0) {
float_max = maxElem;
Expand All @@ -99,7 +109,7 @@ __launch_bounds__(TPB) __global__
threadData += exp((convert_to_float<T>(input[idx]) - float_max));
}

const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
const auto Z = BlockReduce(tmpStorage).Sum(threadData);

if (threadIdx.x == 0) {
normalizing_factor = 1.f / Z;
Expand Down
Loading