namespace tensorrt_llm::kernels::mnnvl
{
+
+// Guard for internal helper functions
+namespace
+{
__device__ bool isNegZero(float v)
{
    return v == 0.f && signbit(v);
@@ -49,6 +53,12 @@ inline __device__ float toFloat<__nv_bfloat16>(__nv_bfloat16 val)
    return __bfloat162float(val);
}

+template <>
+inline __device__ float toFloat<__nv_half>(__nv_half val)
+{
+    return __half2float(val);
+}
+
template <typename T>
inline __device__ T fromFloat(float val)
{
@@ -61,19 +71,77 @@ inline __device__ __nv_bfloat16 fromFloat<__nv_bfloat16>(float val)
    return __float2bfloat16(val);
}

-__device__ __inline__ float2 loadfloat2(void const* ptr)
+template <>
+inline __device__ __nv_half fromFloat<__nv_half>(float val)
{
+    return __float2half(val);
+}

-    float return_value[2];
-
-    asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n"
-        : "=f"(return_value[0]), "=f"(return_value[1])
-        : "l"(ptr)
-        : "memory");
+inline __device__ float2 loadfloat2(void const* ptr)
+{
+    float2 return_value;
+    asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n" : "=f"(return_value.x), "=f"(return_value.y) : "l"(ptr));
+    return return_value;
+}

-    return *(float2*) return_value;
+template <typename T>
+inline __device__ T divUp(T val, T divisor)
+{
+    return (val + divisor - 1) / divisor;
}

+__device__ struct __attribute__((aligned(32))) LamportFlags
+{
+    uint32_t buffer_size;
+    uint32_t input_offset;
+    uint32_t clear_offset;
+    uint32_t num_tokens_prev;
+    uint32_t* offset_access_ptr;
+    uint32_t* buffer_flags;
+
+    __device__ explicit LamportFlags(uint32_t* buffer_flags)
+        : offset_access_ptr(&buffer_flags[4])
+        , buffer_flags(buffer_flags)
+    {
+        uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
+        buffer_size = flag.z;
+        input_offset = flag.x * (buffer_size << 1U);
+        clear_offset = flag.y * (buffer_size << 1U);
+        num_tokens_prev = flag.w;
+    }
+
+    __device__ void cta_arrive()
+    {
+        __syncthreads();
+        if (threadIdx.x == 0)
+        {
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
+            asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
+#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+            asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
+#else
+            atomicAdd(offset_access_ptr, 1);
+#endif
+        }
+    }
+
+    __device__ void wait_and_update(uint32_t num_tokens)
+    {
+        if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0)
+        {
+            while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
+            {
+            }
+            uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
+            buffer_flags[0] = (flag.x + 1) % 3;
+            buffer_flags[1] = (flag.y + 1) % 3;
+            buffer_flags[3] = num_tokens;
+            *(offset_access_ptr) = 0;
+        }
+    }
+};
+} // namespace
+
template <int WORLD_SIZE, typename T>
__global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, int num_tokens,
    int buffer_M, int token_dim, int rank, uint32_t* buffer_flags, bool wait_for_results)
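For orientation, here is a minimal sketch of the five uint32_t words behind buffer_flags (the kernel's flag-array parameter) as the new LamportFlags helper reads them; the struct and field names below are illustrative only and not part of this change:

// Hypothetical view of buffer_flags[0..4]; mirrors LamportFlags' constructor and wait_and_update().
struct BufferFlagsWords
{
    uint32_t input_buf_idx;   // flag.x: index (cycled mod 3) of the buffer group written by the current call
    uint32_t clear_buf_idx;   // flag.y: index (cycled mod 3) of the buffer group being reset to -0.f sentinels
    uint32_t buffer_size;     // flag.z: elements per buffer; each group holds two (reduce-scatter + allgather)
    uint32_t num_tokens_prev; // flag.w: token count of the previous call, bounds the clearing loops
    uint32_t cta_counter;     // buffer_flags[4]: CTAs arrived via cta_arrive(); reset by wait_and_update()
};
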
@@ -87,19 +155,14 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
    cudaGridDependencySynchronize();
#endif

-    // [input_ptr, clear_ptr, buffer_size, access_counter]
-    uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
-    // Each buffer is M * N and we have 2 buffers in each group, one for reduce-scatter and one for allgather
-    uint32_t buffer_group_size = flag.z << 1;
-    uint32_t input_offset = flag.x * buffer_group_size;
-    uint32_t clear_offset = flag.y * buffer_group_size;
-    // Capture the number of tokens from the last call so that we can properly clear the buffer
-    // The scatter stage will use the buffer in WORLD_SIZE granularity, thus we need to round up
-    uint32_t num_tokens_to_clear = flag.w > num_tokens ? flag.w : num_tokens;
-    num_tokens_to_clear = (num_tokens_to_clear + WORLD_SIZE - 1) / WORLD_SIZE * WORLD_SIZE;
-    uint32_t* offset_access_ptr = &buffer_flags[4];
+    LamportFlags flags(buffer_flags);

-    uint32_t clear_tokens_per_cta = (num_tokens_to_clear + gridDim.x - 1) / gridDim.x;
+    // Capture the number of tokens in previous iteration so that we can properly clear the buffer
+    // The scatter stage will use the buffer in WORLD_SIZE granularity, thus we need to round up
+    uint32_t clr_toks_cta
+        = divUp<uint32_t>(flags.num_tokens_prev > num_tokens ? flags.num_tokens_prev : num_tokens, WORLD_SIZE)
+        * WORLD_SIZE;
+    clr_toks_cta = divUp<uint32_t>(clr_toks_cta, gridDim.x);

    if (elt < token_dim)
    {
@@ -109,16 +172,17 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
        T val = shard_ptr[token * token_dim + elt];
        if (isNegZero(val))
            val = fromFloat<T>(0.f);
-        input_ptrs[dest_rank][input_offset + dest_token_offset * token_dim * WORLD_SIZE + rank * token_dim + elt] = val;
+        input_ptrs[dest_rank][flags.input_offset + dest_token_offset * token_dim * WORLD_SIZE + rank * token_dim + elt]
+            = val;

        // Clear the buffer used by the previous call. Note the number of tokens to clear could be larger than the
        // number of tokens in the current call.
-        for (int clr_tok = 0; clr_tok < clear_tokens_per_cta; clr_tok++)
+        for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++)
        {
            uint32_t clr_token_idx = token + clr_tok * gridDim.x;
            if (clr_token_idx < buffer_M)
            {
-                input_ptrs[rank][clear_offset + clr_token_idx * token_dim + elt] = fromFloat<T>(-0.f);
+                input_ptrs[rank][flags.clear_offset + clr_token_idx * token_dim + elt] = fromFloat<T>(-0.f);
            }
        }

@@ -134,7 +198,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
            bool valid = true;
            for (int r = 0; r < WORLD_SIZE; r++)
            {
-                T volatile* lamport_ptr = (T volatile*) &input_ptrs[rank][input_offset
+                T volatile* lamport_ptr = (T volatile*) &input_ptrs[rank][flags.input_offset
                    + local_token * token_dim * WORLD_SIZE + r * token_dim + elt];
                values[r] = *lamport_ptr;
                valid &= !isNegZero(values[r]);
@@ -146,7 +210,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
                {
                    accum += toFloat<T>(values[r]);
                }
-                mcast_ptr[input_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat<T>(accum);
+                mcast_ptr[flags.input_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat<T>(accum);
            }
        }

@@ -155,12 +219,12 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
#endif

    // Similarly clear broadcast buffer here
-    for (int clr_tok = 0; clr_tok < clear_tokens_per_cta; clr_tok++)
+    for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++)
    {
        uint32_t clr_token_idx = token + clr_tok * gridDim.x;
        if (clr_token_idx < buffer_M)
        {
-            input_ptrs[rank][clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + elt]
+            input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + elt]
                = fromFloat<T>(-0.f);
        }
    }
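
As a quick sanity check on the clearing math above (illustrative numbers only, not values from the change): the larger of the previous and current token counts is rounded up to a multiple of WORLD_SIZE, then split across the gridDim.x CTAs.

#include <cstdint>
// Hypothetical values; mirrors the clr_toks_cta computation in the kernel.
constexpr uint32_t WORLD_SIZE = 8, grid_x = 4;
constexpr uint32_t num_tokens_prev = 10, num_tokens = 7;
constexpr uint32_t worst_case = num_tokens_prev > num_tokens ? num_tokens_prev : num_tokens; // 10
constexpr uint32_t rounded = (worst_case + WORLD_SIZE - 1) / WORLD_SIZE * WORLD_SIZE;        // 16
constexpr uint32_t clr_toks_cta = (rounded + grid_x - 1) / grid_x;                           // 4 token rows per CTA
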
@@ -169,26 +233,16 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
    if (wait_for_results)
    {
        // Update the atomic counter to indicate the block has read the offsets
-        __syncthreads();
+        flags.cta_arrive();

-        if (threadIdx.x == 0)
-        {
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
-            asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
-#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-            asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
-#else
-            atomicAdd(offset_access_ptr, 1);
-#endif
-        }
        // Only use a set of CTAs for lamport sync, reargange the grid
        constexpr int ELTS_PER_LOAD = sizeof(float2) / sizeof(T);
        // blockDim.x / ELTS_PER_LOAD should be at least the size of a warp (32)
        if (threadIdx.x < (blockDim.x / ELTS_PER_LOAD))
        {
            uint64_t current_pos = blockIdx.x * token_dim + blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD;

-            void* lamport_ptr = (void*) &input_ptrs[rank][input_offset + buffer_M * token_dim + current_pos];
+            void* lamport_ptr = (void*) &input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos];
            // We have 2 assumptions here:
            // 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be aligned to 8B
            // 2. The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32)
@@ -204,18 +258,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
        }

        // Update the buffer flags
-        if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0)
-        {
-            // Make sure all blocks have finished reading the offsets, 2-D grid
-            while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
-            {
-            }
-            buffer_flags[0] = (flag.x + 1) % 3;
-            buffer_flags[1] = (flag.y + 1) % 3;
-            // Update the flags with the number of tokens in the current call
-            buffer_flags[3] = num_tokens;
-            *(offset_access_ptr) = 0;
-        }
+        flags.wait_and_update(num_tokens);
    }
}

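The readiness handshake in this kernel relies on -0.0f as a "not yet written" sentinel: buffers are cleared to -0.f, producers flush an incoming -0.0 payload to +0.0 before storing it, and consumers poll with isNegZero() until every lane reads real data. A standalone sketch of that convention (function names here are illustrative, not part of the change):

// -0.0f compares equal to 0.0f but has the sign bit set (0x80000000), so the sanitized writes
// below can never produce it, which makes it a safe "empty slot" marker.
__device__ inline bool slot_is_empty(float v)
{
    return v == 0.f && signbit(v);
}

__device__ inline float sanitize_payload(float v)
{
    // A genuine -0.0 input is stored as +0.0, which still reads back as "written".
    return slot_is_empty(v) ? 0.f : v;
}
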
@@ -281,12 +324,28 @@ void twoshot_allreduce_op(AllReduceParams const& params)
        default: TLLM_CHECK_WITH_INFO(false, "TwoShot AllReduce]: unsupported world_size.");
        }
    }
+    else if (dtype == nvinfer1::DataType::kHALF)
+    {
+        switch (world_size)
+        {
+        case 2: LAUNCH_ALL_REDUCE_KERNEL(2, __nv_half); break;
+        case 4: LAUNCH_ALL_REDUCE_KERNEL(4, __nv_half); break;
+        case 8: LAUNCH_ALL_REDUCE_KERNEL(8, __nv_half); break;
+        case 16: LAUNCH_ALL_REDUCE_KERNEL(16, __nv_half); break;
+        case 32: LAUNCH_ALL_REDUCE_KERNEL(32, __nv_half); break;
+        case 64: LAUNCH_ALL_REDUCE_KERNEL(64, __nv_half); break;
+        default: TLLM_CHECK_WITH_INFO(false, "TwoShot AllReduce]: unsupported world_size.");
+        }
+    }
    else
    {
        TLLM_CHECK_WITH_INFO(false, "TwoShot AllReduce]: unsupported dtype.");
    }
}

+// Guard for internal helper functions
+namespace
+{
template <typename T_IN>
__device__ void copy_f4(T_IN* dst, T_IN const* src)
{
@@ -338,14 +397,15 @@ inline __device__ float block_reduce_sum(float val)
__device__ float4 loadfloat4(void const* ptr)
{

-    float return_value[4];
+    float4 return_value;

    asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n"
-        : "=f"(return_value[0]), "=f"(return_value[1]), "=f"(return_value[2]), "=f"(return_value[3])
+        : "=f"(return_value.x), "=f"(return_value.y), "=f"(return_value.z), "=f"(return_value.w)
        : "l"(ptr));

-    return *(float4*) return_value;
+    return return_value;
}
+} // namespace

template <int DIM, int NUM_THREADS, int NUM_INPUTS, typename T_OUT, typename T_IN>
__global__ void __launch_bounds__(128, 1)
@@ -373,12 +433,8 @@ __global__ void __launch_bounds__(128, 1)

    int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)];

-    uint32_t* offset_access_ptr = &buffer_flags[4];
-    uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
-    // Buffer size is M * N, and we need two buffers for reduce-scatter and allgather
-    uint32_t buffer_size = flag.z;
-    uint32_t buffer_offset = flag.x * (buffer_size << 1);
-    T_IN const* input = &buffer_input[buffer_offset + buffer_size];
+    LamportFlags flags(buffer_flags);
+    T_IN const* input = &buffer_input[flags.input_offset + flags.buffer_size];

    cudaTriggerProgrammaticLaunchCompletion();

@@ -408,17 +464,7 @@ __global__ void __launch_bounds__(128, 1)
    }

    __pipeline_commit();
-    __syncthreads();
-    if (threadIdx.x == 0)
-    {
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
-        asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
-#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-        asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
-#else
-        atomicAdd(offset_access_ptr, 1);
-#endif
-    }
+    flags.cta_arrive();
    // Load all inputs
    bool valid = false;

@@ -548,17 +594,7 @@ __global__ void __launch_bounds__(128, 1)
            = out4;
    }
    // Update the buffer pointers
-    if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0)
-    {
-        // Make sure all blocks have finished accessing the buffer
-        while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
-        {
-        }
-        buffer_flags[0] = (flag.x + 1) % 3;
-        buffer_flags[1] = (flag.y + 1) % 3;
-        buffer_flags[3] = batch_size;
-        *(offset_access_ptr) = 0;
-    }
+    flags.wait_and_update(batch_size);
#endif
}

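One way to read the pointer setup in the RMSNorm kernel above (a sketch under the assumption that buffer_size equals buffer_M * token_dim, matching the allreduce kernel): the active buffer group has two halves, and RMSNorm consumes the second, allgather/broadcast half. Names here are hypothetical.

// Illustrative only; mirrors the "input = &buffer_input[flags.input_offset + flags.buffer_size]" line.
template <typename T_IN>
__device__ T_IN const* allgather_section(T_IN const* buffer_input, LamportFlags const& flags)
{
    // First buffer_size elements of the group: reduce-scatter staging.
    // Next buffer_size elements: broadcast results, which the RMSNorm kernel reads.
    return &buffer_input[flags.input_offset + flags.buffer_size];
}
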
@@ -569,8 +605,6 @@ void twoshot_rmsnorm(T* prenorm_output, T* normed_output, T const* input, T cons

    // input to rmsnorm is the buffer in the twoshot ar
    // We should use prenorm output to determine the actual used size
-    // int batch = normed_output.sizes()[0];
-    // int dim = normed_output.sizes()[1];
    float _epsilon{static_cast<float>(epsilon)};

    static constexpr int NUM_THREADS = 128;
@@ -633,6 +667,20 @@ void twoshot_rmsnorm_op(RMSNormParams const& params)
        default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim.");
        }
    }
+    else if (dtype == nvinfer1::DataType::kHALF)
+    {
+        switch (params.hidden_dim)
+        {
+        case 2048: LAUNCH_RMSNORM_KERNEL(__nv_half, 2048); break;
+        case 4096: LAUNCH_RMSNORM_KERNEL(__nv_half, 4096); break;
+        // Llama-4 Hidden Dimension
+        case 5120: LAUNCH_RMSNORM_KERNEL(__nv_half, 5120); break;
+        // DeepSeek Hidden Dimension
+        case 7168: LAUNCH_RMSNORM_KERNEL(__nv_half, 7168); break;
+        case 8192: LAUNCH_RMSNORM_KERNEL(__nv_half, 8192); break;
+        default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim.");
+        }
+    }
    else
    {
        TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported dtype.");