@@ -17,53 +17,72 @@ static __global__ void cumsum_kernel(
1717 const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
1818 const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) {
1919
20- // Shared memory to store warp sums (always use float for accumulation)
21- extern __shared__ float shmem[];
20+ const int tid = threadIdx .x ;
21+ const int lane = tid & (WARP_SIZE - 1 );
22+ const int warp = tid / WARP_SIZE ;
23+ const int warps_per_block = blockDim .x / WARP_SIZE ;
24+
25+ extern __shared__ float smem[];
26+ float * s_vals = smem;
27+ float * s_warp_sums = smem + blockDim .x ;
28+ float * s_carry = smem + blockDim .x + warps_per_block;
29+ float * s_chunk_total = s_carry + 1 ;
30+
31+ // Initialize carry
32+ if (tid == 0 ) {
33+ *s_carry = 0 .0f ;
34+ }
35+ __syncthreads ();
2236
2337 const int64_t i3 = blockIdx .z ;
2438 const int64_t i2 = blockIdx .y ;
2539 const int64_t i1 = blockIdx .x ;
26-
2740 if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
2841 return ;
2942 }
3043
31- const T * src_row = src + i1 * nb01 + i2* nb02 + i3* nb03;
32- T * dst_row = dst + i1 * nb1 + i2* nb2 + i3* nb3;
44+ const T * src_row = src + i1 * nb01 + i2 * nb02 + i3 * nb03;
45+ T * dst_row = dst + i1 * nb1 + i2 * nb2 + i3 * nb3;
3346
34- const int tid = threadIdx .x ;
35- const int lane_id = tid % WARP_SIZE ;
36-
37- if (tid >= ne00) {
38- return ;
39- }
47+ for (int64_t start = 0 ; start < ne00; start += blockDim .x ) {
48+ int64_t idx = start + tid;
49+ float val = (idx < ne00) ? static_cast <float >(src_row[idx]) : 0 .0f ;
4050
41- // Phase 1: Each thread processes elements at stride blockDim.x
42- // Compute warp-level prefix sums
43- for (int64_t i0 = tid; i0 < ne00; i0 += blockDim .x ) {
44- // Load value and compute prefix sum within warp
45- float val = static_cast <float >(src_row[i0]);
51+ // 1. Warp inclusive scan
4652 val = warp_prefix_inclusive_sum (val);
47- dst_row[i0 ] = static_cast <T>( val) ;
53+ s_vals[tid ] = val;
4854
49- // Last thread of warp stores its sum to shared memory at position based on data index
50- if (lane_id == WARP_SIZE - 1 || i0 == ne00 - 1 ) {
51- const int shmem_idx = i0 / WARP_SIZE ;
52- shmem[shmem_idx] = val;
55+ // Store warp total
56+ if (lane == WARP_SIZE - 1 ) {
57+ s_warp_sums[warp] = val;
5358 }
54- }
59+ __syncthreads ();
60+
61+ // 2. Exclusive scan of warp sums (warp 0 only)
62+ if (warp == 0 ) {
63+ float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0 .0f ;
64+ float inc = warp_prefix_inclusive_sum (w);
65+ if (tid < warps_per_block) {
66+ s_warp_sums[tid] = inc - w; // exclusive sum
67+ }
68+ if (tid == warps_per_block - 1 ) {
69+ *s_chunk_total = inc; // total sum of this chunk
70+ }
71+ }
72+ __syncthreads ();
5573
56- // Sync once after all warp prefix sums are computed
57- __syncthreads ();
74+ float carry = *s_carry;
75+ float final_val = s_vals[tid] + s_warp_sums[warp] + carry;
76+ if (idx < ne00) {
77+ dst_row[idx] = static_cast <T>(final_val);
78+ }
79+ __syncthreads ();
5880
59- // Phase 2: Add the sum of all preceding warp groups to each element
60- for (int64_t i0 = tid; i0 < ne00; i0 += blockDim .x ) {
61- const int shmem_idx = i0 / WARP_SIZE ;
62- float sum = 0 .0f ;
63- for (int j = 0 ; j < shmem_idx; ++j) {
64- sum += shmem[j];
81+ // Update carry for next chunk
82+ if (tid == 0 ) {
83+ *s_carry += *s_chunk_total;
6584 }
66- dst_row[i0] = static_cast <T>( static_cast < float >(dst_row[i0]) + sum );
85+ __syncthreads ( );
6786 }
6887}
6988
@@ -76,15 +95,13 @@ static void cumsum_cuda(
7695 cudaStream_t stream) {
7796
7897 dim3 grid_dims (ne01, ne02, ne03);
79-
80- // Shared memory size: one float per warp
8198 const int num_warps = (ne00 + WARP_SIZE - 1 ) / WARP_SIZE ;
82- const size_t shmem_size = num_warps * sizeof (float );
83- const size_t type_size = sizeof (T);
84-
8599 int block_size = num_warps * WARP_SIZE ;
86100 block_size = std::min (block_size, CUDA_CUMSUM_BLOCK_SIZE );
87101 dim3 block_dims (block_size, 1 , 1 );
102+ const int warps_per_block = block_size / WARP_SIZE ;
103+ const size_t shmem_size = (block_size + warps_per_block + 2 ) * sizeof (float );
104+ const size_t type_size = sizeof (T);
88105
89106 cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>> (
90107 src, dst,
0 commit comments