Skip to content

Commit df917cc

Browse files
committed
Optimizations + big case performance tests
1 parent f422ba8 commit df917cc

2 files changed

Lines changed: 55 additions & 37 deletions

File tree

ggml/src/ggml-cuda/cumsum.cu

Lines changed: 53 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,53 +17,72 @@ static __global__ void cumsum_kernel(
1717
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
1818
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) {
1919

20-
// Shared memory to store warp sums (always use float for accumulation)
21-
extern __shared__ float shmem[];
20+
const int tid = threadIdx.x;
21+
const int lane = tid & (WARP_SIZE - 1);
22+
const int warp = tid / WARP_SIZE;
23+
const int warps_per_block = blockDim.x / WARP_SIZE;
24+
25+
extern __shared__ float smem[];
26+
float* s_vals = smem;
27+
float* s_warp_sums = smem + blockDim.x;
28+
float* s_carry = smem + blockDim.x + warps_per_block;
29+
float* s_chunk_total = s_carry + 1;
30+
31+
// Initialize carry
32+
if (tid == 0) {
33+
*s_carry = 0.0f;
34+
}
35+
__syncthreads();
2236

2337
const int64_t i3 = blockIdx.z;
2438
const int64_t i2 = blockIdx.y;
2539
const int64_t i1 = blockIdx.x;
26-
2740
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
2841
return;
2942
}
3043

31-
const T * src_row = src + i1 * nb01 + i2*nb02 + i3*nb03;
32-
T * dst_row = dst + i1 * nb1 + i2*nb2 + i3*nb3;
44+
const T * src_row = src + i1 * nb01 + i2 * nb02 + i3 * nb03;
45+
T * dst_row = dst + i1 * nb1 + i2 * nb2 + i3 * nb3;
3346

34-
const int tid = threadIdx.x;
35-
const int lane_id = tid % WARP_SIZE;
36-
37-
if (tid >= ne00) {
38-
return;
39-
}
47+
for (int64_t start = 0; start < ne00; start += blockDim.x) {
48+
int64_t idx = start + tid;
49+
float val = (idx < ne00) ? static_cast<float>(src_row[idx]) : 0.0f;
4050

41-
// Phase 1: Each thread processes elements at stride blockDim.x
42-
// Compute warp-level prefix sums
43-
for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
44-
// Load value and compute prefix sum within warp
45-
float val = static_cast<float>(src_row[i0]);
51+
// 1. Warp inclusive scan
4652
val = warp_prefix_inclusive_sum(val);
47-
dst_row[i0] = static_cast<T>(val);
53+
s_vals[tid] = val;
4854

49-
// Last thread of warp stores its sum to shared memory at position based on data index
50-
if (lane_id == WARP_SIZE - 1 || i0 == ne00 - 1) {
51-
const int shmem_idx = i0 / WARP_SIZE;
52-
shmem[shmem_idx] = val;
55+
// Store warp total
56+
if (lane == WARP_SIZE - 1) {
57+
s_warp_sums[warp] = val;
5358
}
54-
}
59+
__syncthreads();
60+
61+
// 2. Exclusive scan of warp sums (warp 0 only)
62+
if (warp == 0) {
63+
float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
64+
float inc = warp_prefix_inclusive_sum(w);
65+
if (tid < warps_per_block) {
66+
s_warp_sums[tid] = inc - w; // exclusive sum
67+
}
68+
if (tid == warps_per_block - 1) {
69+
*s_chunk_total = inc; // total sum of this chunk
70+
}
71+
}
72+
__syncthreads();
5573

56-
// Sync once after all warp prefix sums are computed
57-
__syncthreads();
74+
float carry = *s_carry;
75+
float final_val = s_vals[tid] + s_warp_sums[warp] + carry;
76+
if (idx < ne00) {
77+
dst_row[idx] = static_cast<T>(final_val);
78+
}
79+
__syncthreads();
5880

59-
// Phase 2: Add the sum of all preceding warp groups to each element
60-
for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
61-
const int shmem_idx = i0 / WARP_SIZE;
62-
float sum = 0.0f;
63-
for (int j = 0; j < shmem_idx; ++j) {
64-
sum += shmem[j];
81+
// Update carry for next chunk
82+
if (tid == 0) {
83+
*s_carry += *s_chunk_total;
6584
}
66-
dst_row[i0] = static_cast<T>(static_cast<float>(dst_row[i0]) + sum);
85+
__syncthreads();
6786
}
6887
}
6988

@@ -76,15 +95,13 @@ static void cumsum_cuda(
7695
cudaStream_t stream) {
7796

7897
dim3 grid_dims(ne01, ne02, ne03);
79-
80-
// Shared memory size: one float per warp
8198
const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE;
82-
const size_t shmem_size = num_warps * sizeof(float);
83-
const size_t type_size = sizeof(T);
84-
8599
int block_size = num_warps * WARP_SIZE;
86100
block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE);
87101
dim3 block_dims(block_size, 1, 1);
102+
const int warps_per_block = block_size / WARP_SIZE;
103+
const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);
104+
const size_t type_size = sizeof(T);
88105

89106
cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
90107
src, dst,

tests/test-backend-ops.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7942,7 +7942,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
79427942
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));
79437943

79447944
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 1 }));
7945-
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 1 }));
7945+
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2048, 16, 5, 4 }));
7946+
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 20000, 10, 4, 1 }));
79467947

79477948
for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
79487949
for (ggml_type type_a : all_types) {

0 commit comments

Comments
 (0)