
Commit 9b9c631

Code clean up; Add FP16 support.
Signed-off-by: Shiyu Li <[email protected]>
1 parent 8f720d7 commit 9b9c631

File tree

2 files changed: +133 -85 lines changed

cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu

Lines changed: 132 additions & 84 deletions
@@ -27,6 +27,10 @@
 
 namespace tensorrt_llm::kernels::mnnvl
 {
+
+// Guard for internal helper functions
+namespace
+{
 __device__ bool isNegZero(float v)
 {
     return v == 0.f && signbit(v);
@@ -49,6 +53,12 @@ inline __device__ float toFloat<__nv_bfloat16>(__nv_bfloat16 val)
     return __bfloat162float(val);
 }
 
+template <>
+inline __device__ float toFloat<__nv_half>(__nv_half val)
+{
+    return __half2float(val);
+}
+
 template <typename T>
 inline __device__ T fromFloat(float val)
 {
@@ -61,19 +71,77 @@ inline __device__ __nv_bfloat16 fromFloat<__nv_bfloat16>(float val)
     return __float2bfloat16(val);
 }
 
-__device__ __inline__ float2 loadfloat2(void const* ptr)
+template <>
+inline __device__ __nv_half fromFloat<__nv_half>(float val)
 {
+    return __float2half(val);
+}
 
-    float return_value[2];
-
-    asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n"
-        : "=f"(return_value[0]), "=f"(return_value[1])
-        : "l"(ptr)
-        : "memory");
+inline __device__ float2 loadfloat2(void const* ptr)
+{
+    float2 return_value;
+    asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n" : "=f"(return_value.x), "=f"(return_value.y) : "l"(ptr));
+    return return_value;
+}
 
-    return *(float2*) return_value;
+template <typename T>
+inline __device__ T divUp(T val, T divisor)
+{
+    return (val + divisor - 1) / divisor;
 }
 
+__device__ struct __attribute__((aligned(32))) LamportFlags
+{
+    uint32_t buffer_size;
+    uint32_t input_offset;
+    uint32_t clear_offset;
+    uint32_t num_tokens_prev;
+    uint32_t* offset_access_ptr;
+    uint32_t* buffer_flags;
+
+    __device__ explicit LamportFlags(uint32_t* buffer_flags)
+        : offset_access_ptr(&buffer_flags[4])
+        , buffer_flags(buffer_flags)
+    {
+        uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
+        buffer_size = flag.z;
+        input_offset = flag.x * (buffer_size << 1U);
+        clear_offset = flag.y * (buffer_size << 1U);
+        num_tokens_prev = flag.w;
+    }
+
+    __device__ void cta_arrive()
+    {
+        __syncthreads();
+        if (threadIdx.x == 0)
+        {
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
+            asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
+#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+            asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
+#else
+            atomicAdd(offset_access_ptr, 1);
+#endif
+        }
+    }
+
+    __device__ void wait_and_update(uint32_t num_tokens)
+    {
+        if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0)
+        {
+            while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
+            {
+            }
+            uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
+            buffer_flags[0] = (flag.x + 1) % 3;
+            buffer_flags[1] = (flag.y + 1) % 3;
+            buffer_flags[3] = num_tokens;
+            *(offset_access_ptr) = 0;
+        }
+    }
+};
+} // namespace
+
 template <int WORLD_SIZE, typename T>
 __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, int num_tokens,
     int buffer_M, int token_dim, int rank, uint32_t* buffer_flags, bool wait_for_results)
@@ -87,19 +155,14 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
     cudaGridDependencySynchronize();
 #endif
 
-    // [input_ptr, clear_ptr, buffer_size, access_counter]
-    uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
-    // Each buffer is M * N and we have 2 buffers in each group, one for reduce-scatter and one for allgather
-    uint32_t buffer_group_size = flag.z << 1;
-    uint32_t input_offset = flag.x * buffer_group_size;
-    uint32_t clear_offset = flag.y * buffer_group_size;
-    // Capture the number of tokens from the last call so that we can properly clear the buffer
-    // The scatter stage will use the buffer in WORLD_SIZE granularity, thus we need to round up
-    uint32_t num_tokens_to_clear = flag.w > num_tokens ? flag.w : num_tokens;
-    num_tokens_to_clear = (num_tokens_to_clear + WORLD_SIZE - 1) / WORLD_SIZE * WORLD_SIZE;
-    uint32_t* offset_access_ptr = &buffer_flags[4];
+    LamportFlags flags(buffer_flags);
 
-    uint32_t clear_tokens_per_cta = (num_tokens_to_clear + gridDim.x - 1) / gridDim.x;
+    // Capture the number of tokens in previous iteration so that we can properly clear the buffer
+    // The scatter stage will use the buffer in WORLD_SIZE granularity, thus we need to round up
+    uint32_t clr_toks_cta
+        = divUp<uint32_t>(flags.num_tokens_prev > num_tokens ? flags.num_tokens_prev : num_tokens, WORLD_SIZE)
+        * WORLD_SIZE;
+    clr_toks_cta = divUp<uint32_t>(clr_toks_cta, gridDim.x);
 
     if (elt < token_dim)
     {
@@ -109,16 +172,17 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
         T val = shard_ptr[token * token_dim + elt];
         if (isNegZero(val))
             val = fromFloat<T>(0.f);
-        input_ptrs[dest_rank][input_offset + dest_token_offset * token_dim * WORLD_SIZE + rank * token_dim + elt] = val;
+        input_ptrs[dest_rank][flags.input_offset + dest_token_offset * token_dim * WORLD_SIZE + rank * token_dim + elt]
+            = val;
 
         // Clear the buffer used by the previous call. Note the number of tokens to clear could be larger than the
         // number of tokens in the current call.
-        for (int clr_tok = 0; clr_tok < clear_tokens_per_cta; clr_tok++)
+        for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++)
         {
             uint32_t clr_token_idx = token + clr_tok * gridDim.x;
             if (clr_token_idx < buffer_M)
             {
-                input_ptrs[rank][clear_offset + clr_token_idx * token_dim + elt] = fromFloat<T>(-0.f);
+                input_ptrs[rank][flags.clear_offset + clr_token_idx * token_dim + elt] = fromFloat<T>(-0.f);
             }
         }
 
@@ -134,7 +198,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
             bool valid = true;
             for (int r = 0; r < WORLD_SIZE; r++)
             {
-                T volatile* lamport_ptr = (T volatile*) &input_ptrs[rank][input_offset
+                T volatile* lamport_ptr = (T volatile*) &input_ptrs[rank][flags.input_offset
                     + local_token * token_dim * WORLD_SIZE + r * token_dim + elt];
                 values[r] = *lamport_ptr;
                 valid &= !isNegZero(values[r]);
@@ -146,7 +210,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
                 {
                     accum += toFloat<T>(values[r]);
                 }
-                mcast_ptr[input_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat<T>(accum);
+                mcast_ptr[flags.input_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat<T>(accum);
             }
         }
 
@@ -155,12 +219,12 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
 #endif
 
     // Similarly clear broadcast buffer here
-    for (int clr_tok = 0; clr_tok < clear_tokens_per_cta; clr_tok++)
+    for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++)
     {
         uint32_t clr_token_idx = token + clr_tok * gridDim.x;
         if (clr_token_idx < buffer_M)
        {
-            input_ptrs[rank][clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + elt]
+            input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + elt]
                 = fromFloat<T>(-0.f);
         }
     }
@@ -169,26 +233,16 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
     if (wait_for_results)
     {
         // Update the atomic counter to indicate the block has read the offsets
-        __syncthreads();
+        flags.cta_arrive();
 
-        if (threadIdx.x == 0)
-        {
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
-            asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
-#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-            asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
-#else
-            atomicAdd(offset_access_ptr, 1);
-#endif
-        }
         // Only use a set of CTAs for lamport sync, reargange the grid
         constexpr int ELTS_PER_LOAD = sizeof(float2) / sizeof(T);
         // blockDim.x / ELTS_PER_LOAD should be at least the size of a warp (32)
         if (threadIdx.x < (blockDim.x / ELTS_PER_LOAD))
        {
             uint64_t current_pos = blockIdx.x * token_dim + blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD;
 
-            void* lamport_ptr = (void*) &input_ptrs[rank][input_offset + buffer_M * token_dim + current_pos];
+            void* lamport_ptr = (void*) &input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos];
             // We have 2 assumptions here:
             // 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be aligned to 8B
             // 2. The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32)
@@ -204,18 +258,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
         }
 
         // Update the buffer flags
-        if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0)
-        {
-            // Make sure all blocks have finished reading the offsets, 2-D grid
-            while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
-            {
-            }
-            buffer_flags[0] = (flag.x + 1) % 3;
-            buffer_flags[1] = (flag.y + 1) % 3;
-            // Update the flags with the number of tokens in the current call
-            buffer_flags[3] = num_tokens;
-            *(offset_access_ptr) = 0;
-        }
+        flags.wait_and_update(num_tokens);
     }
 }
 
@@ -281,12 +324,28 @@ void twoshot_allreduce_op(AllReduceParams const& params)
         default: TLLM_CHECK_WITH_INFO(false, "TwoShot AllReduce]: unsupported world_size.");
         }
     }
+    else if (dtype == nvinfer1::DataType::kHALF)
+    {
+        switch (world_size)
+        {
+        case 2: LAUNCH_ALL_REDUCE_KERNEL(2, __nv_half); break;
+        case 4: LAUNCH_ALL_REDUCE_KERNEL(4, __nv_half); break;
+        case 8: LAUNCH_ALL_REDUCE_KERNEL(8, __nv_half); break;
+        case 16: LAUNCH_ALL_REDUCE_KERNEL(16, __nv_half); break;
+        case 32: LAUNCH_ALL_REDUCE_KERNEL(32, __nv_half); break;
+        case 64: LAUNCH_ALL_REDUCE_KERNEL(64, __nv_half); break;
+        default: TLLM_CHECK_WITH_INFO(false, "TwoShot AllReduce]: unsupported world_size.");
+        }
+    }
     else
     {
         TLLM_CHECK_WITH_INFO(false, "TwoShot AllReduce]: unsupported dtype.");
     }
 }
 
+// Guard for internal helper functions
+namespace
+{
 template <typename T_IN>
 __device__ void copy_f4(T_IN* dst, T_IN const* src)
 {
@@ -338,14 +397,15 @@ inline __device__ float block_reduce_sum(float val)
 __device__ float4 loadfloat4(void const* ptr)
 {
 
-    float return_value[4];
+    float4 return_value;
 
     asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n"
-        : "=f"(return_value[0]), "=f"(return_value[1]), "=f"(return_value[2]), "=f"(return_value[3])
+        : "=f"(return_value.x), "=f"(return_value.y), "=f"(return_value.z), "=f"(return_value.w)
        : "l"(ptr));
 
-    return *(float4*) return_value;
+    return return_value;
 }
+} // namespace
 
 template <int DIM, int NUM_THREADS, int NUM_INPUTS, typename T_OUT, typename T_IN>
 __global__ void __launch_bounds__(128, 1)
@@ -373,12 +433,8 @@ __global__ void __launch_bounds__(128, 1)
 
     int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)];
 
-    uint32_t* offset_access_ptr = &buffer_flags[4];
-    uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
-    // Buffer size is M * N, and we need two buffers for reduce-scatter and allgather
-    uint32_t buffer_size = flag.z;
-    uint32_t buffer_offset = flag.x * (buffer_size << 1);
-    T_IN const* input = &buffer_input[buffer_offset + buffer_size];
+    LamportFlags flags(buffer_flags);
+    T_IN const* input = &buffer_input[flags.input_offset + flags.buffer_size];
 
     cudaTriggerProgrammaticLaunchCompletion();
 
@@ -408,17 +464,7 @@ __global__ void __launch_bounds__(128, 1)
     }
 
     __pipeline_commit();
-    __syncthreads();
-    if (threadIdx.x == 0)
-    {
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
-        asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
-#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-        asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
-#else
-        atomicAdd(offset_access_ptr, 1);
-#endif
-    }
+    flags.cta_arrive();
     // Load all inputs
     bool valid = false;
 
@@ -548,17 +594,7 @@ __global__ void __launch_bounds__(128, 1)
             = out4;
     }
     // Update the buffer pointers
-    if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0)
-    {
-        // Make sure all blocks have finished accessing the buffer
-        while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
-        {
-        }
-        buffer_flags[0] = (flag.x + 1) % 3;
-        buffer_flags[1] = (flag.y + 1) % 3;
-        buffer_flags[3] = batch_size;
-        *(offset_access_ptr) = 0;
-    }
+    flags.wait_and_update(batch_size);
 #endif
 }
 
@@ -569,8 +605,6 @@ void twoshot_rmsnorm(T* prenorm_output, T* normed_output, T const* input, T cons
 
     // input to rmsnorm is the buffer in the twoshot ar
     // We should use prenorm output to determine the actual used size
-    // int batch = normed_output.sizes()[0];
-    // int dim = normed_output.sizes()[1];
     float _epsilon{static_cast<float>(epsilon)};
 
     static constexpr int NUM_THREADS = 128;
@@ -633,6 +667,20 @@ void twoshot_rmsnorm_op(RMSNormParams const& params)
         default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim.");
         }
     }
+    else if (dtype == nvinfer1::DataType::kHALF)
+    {
+        switch (params.hidden_dim)
+        {
+        case 2048: LAUNCH_RMSNORM_KERNEL(__nv_half, 2048); break;
+        case 4096: LAUNCH_RMSNORM_KERNEL(__nv_half, 4096); break;
+        // Llama-4 Hidden Dimension
+        case 5120: LAUNCH_RMSNORM_KERNEL(__nv_half, 5120); break;
+        // DeepSeek Hidden Dimension
+        case 7168: LAUNCH_RMSNORM_KERNEL(__nv_half, 7168); break;
+        case 8192: LAUNCH_RMSNORM_KERNEL(__nv_half, 8192); break;
+        default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim.");
+        }
+    }
     else
     {
         TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported dtype.");
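
For orientation, below is a minimal host-side sketch of the buffer_flags protocol that the new LamportFlags helper wraps. It is not part of the commit: the struct name, comments, and the main() driver are illustrative, and only the field layout and the update rule mirror the constructor and wait_and_update() shown in the diff above.

#include <cstdint>
#include <cstdio>

// Host-side model of the five buffer_flags words read by LamportFlags:
//   [0] index of the buffer group used for input      (cycles 0..2)
//   [1] index of the buffer group being cleared       (cycles 0..2)
//   [2] size of one buffer (buffer_M * token_dim elements)
//   [3] number of tokens handled by the previous call
//   [4] CTA arrival counter polled before the update
struct BufferFlagsModel
{
    uint32_t buffer_flags[5];

    uint32_t buffer_size() const
    {
        return buffer_flags[2];
    }

    // A group holds two buffers (reduce-scatter + allgather), hence << 1.
    uint32_t input_offset() const
    {
        return buffer_flags[0] * (buffer_size() << 1U);
    }

    uint32_t clear_offset() const
    {
        return buffer_flags[1] * (buffer_size() << 1U);
    }

    // What wait_and_update(num_tokens) does once every CTA has arrived:
    // rotate both indices through the three buffer groups, remember how many
    // tokens were written, and reset the arrival counter for the next launch.
    void update(uint32_t num_tokens)
    {
        buffer_flags[0] = (buffer_flags[0] + 1) % 3;
        buffer_flags[1] = (buffer_flags[1] + 1) % 3;
        buffer_flags[3] = num_tokens;
        buffer_flags[4] = 0;
    }
};

int main()
{
    BufferFlagsModel f{{0, 2, 4096, 0, 0}}; // hypothetical initial state
    for (int call = 0; call < 4; ++call)
    {
        std::printf("call %d: input_offset=%u clear_offset=%u\n", call, f.input_offset(), f.clear_offset());
        f.update(128);
    }
    return 0;
}

On the device, cta_arrive() increments the counter in word [4] (red.async on SM 100+, red on SM 90, plain atomicAdd otherwise), and wait_and_update() spins until every CTA of the 2-D grid has arrived before rotating the indices, so no block can still be reading the old offsets when they change.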

tensorrt_llm/_torch/distributed/ops.py

Lines changed: 1 addition & 1 deletion
@@ -305,7 +305,7 @@ def __init__(self, mapping: Mapping, dtype: torch.dtype):
 
     @staticmethod
     def get_supported_dtypes():
-        return (torch.bfloat16, torch.float32)
+        return (torch.float16, torch.bfloat16, torch.float32)
 
     def forward(
         self,
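
With torch.float16 added to get_supported_dtypes, FP16 tensors take the same vectorized 8-byte Lamport load path as BF16 in the kernels above. The kernel comment only spells out ELTS_PER_LOAD for BF16 and FP32; the short sketch below (standard CUDA headers, not taken from the commit) shows that FP16 also packs four elements per float2 load, so the num_tokens * token_dim divisibility assumption applies to FP16 with a factor of 4 as well.

#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

// ELTS_PER_LOAD = sizeof(float2) / sizeof(T): elements covered by one 8-byte
// Lamport load for each dtype the op now accepts (__half is the 2-byte type
// the commit spells __nv_half).
static_assert(sizeof(float2) / sizeof(float) == 2, "FP32: 2 elements per 8B load");
static_assert(sizeof(float2) / sizeof(__nv_bfloat16) == 4, "BF16: 4 elements per 8B load");
static_assert(sizeof(float2) / sizeof(__half) == 4, "FP16: 4 elements per 8B load");

The 8-byte alignment assumption on each buffer in a buffer group is unchanged; FP16 only adds a new case to the divisibility requirement.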
