Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,14 @@ void threeStepBuildExpertMapsSortFirstToken(int const* token_selected_experts, i
int64_t const num_tokens, int64_t const num_experts_per_node, int64_t const num_experts_per_token,
int const start_expert_id, cudaStream_t stream);

template <class InputActivationsType, class ExpandedActivationsType>
template <class InputActivationsType, class ExpandedActivationsType, bool PRE_QUANT_AWQ = false>
void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales,
int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const cols, int const k,
int const num_experts_per_node, float const* fc1_act_global_scale, bool use_per_expert_act_scale,
int64_t* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, cudaStream_t stream);
TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, cudaStream_t stream,
void const* prequant_scales = nullptr);

template <class OutputType, class GemmOutputType, class ScaleBiasType>
void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_rows,
Expand Down
82 changes: 65 additions & 17 deletions cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1422,13 +1422,14 @@ __host__ __device__ constexpr static U arrayConvert(T const& input)

constexpr static int EXPAND_THREADS_PER_BLOCK = 256;

template <class InputActivationsType, class ExpandedActivationsType>
template <class InputActivationsType, class ExpandedActivationsType, bool PRE_QUANT_AWQ>
__global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_input,
ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales,
int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const cols, int64_t const k,
float const* fc1_act_global_scale, bool use_per_expert_act_scale, int64_t const* expert_first_token_offset,
TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, int64_t const num_experts_per_node)
TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, int64_t const num_experts_per_node,
InputActivationsType const* prequant_scales = nullptr)
{
#ifdef ENABLE_FP4
constexpr bool is_fp4 = std::is_same_v<ExpandedActivationsType, __nv_fp4_e2m1>;
Expand All @@ -1440,8 +1441,8 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
constexpr bool need_fp4_quant = false;
#endif

static_assert(need_fp4_quant || std::is_same_v<InputActivationsType, ExpandedActivationsType>,
"Only FP4 quantization supports outputting a different format as part of the expansion");
static_assert(need_fp4_quant || PRE_QUANT_AWQ || std::is_same_v<InputActivationsType, ExpandedActivationsType>,
"Only FP4 and WINT4_AFP8 supports outputting a different format as part of the expansion");

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
Expand Down Expand Up @@ -1499,10 +1500,35 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
assert(act_scale_idx == 0 && "Cannot use per-expert act scale for pre-quantized activations");
writeSF(num_tokens_before_expert, expert, source_row, permuted_row, elem_index, cols, num_rows,
fc1_act_sf_flat, input_sf);
dest_row_ptr[elem_index] = in_vec;
dest_row_ptr[elem_index] = reinterpret_cast<OutputElem>(in_vec);
}
}
}
else if constexpr (PRE_QUANT_AWQ)
{
using InputElem = cutlass::Array<InputActivationsType, ELEM_PER_THREAD>;
using OutputElem_ = cutlass::Array<ExpandedActivationsType, ELEM_PER_THREAD>;
using OutputElem_AWQ = std::conditional_t<is_fp4, uint32_t, OutputElem_>;
auto const* source_row_ptr_awq
= reinterpret_cast<InputElem const*>(unpermuted_input + source_row * cols / ELEM_PER_BYTE);
auto* dest_row_ptr_awq
= reinterpret_cast<OutputElem_AWQ*>(permuted_output) + permuted_row * cols / ELEM_PER_THREAD;
cutlass::NumericArrayConverter<ExpandedActivationsType, InputActivationsType, ELEM_PER_THREAD> converter;
InputElem frag_elems;

for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride)
{
frag_elems = source_row_ptr_awq[elem_index];

CUTLASS_PRAGMA_UNROLL
for (int e = 0; e < ELEM_PER_THREAD; e++)
{
frag_elems[e] = frag_elems[e] * prequant_scales[elem_index * ELEM_PER_THREAD + e];
}

dest_row_ptr_awq[elem_index] = converter(frag_elems);
}
}
else
{
for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride)
Expand All @@ -1522,13 +1548,14 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
#endif
}

template <class InputActivationsType, class ExpandedActivationsType>
template <class InputActivationsType, class ExpandedActivationsType, bool PRE_QUANT_AWQ = false>
void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales,
int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const cols, int const k,
int const num_experts_per_node, float const* fc1_act_global_scale, bool use_per_expert_act_scale,
int64_t* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, cudaStream_t stream)
TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, cudaStream_t stream,
void const* prequant_scales = nullptr)
{
#ifdef ENABLE_FP4
// TODO Currently this is a bit hacky because we assume we are in FP8_MXFP4 mode if activations are FP8.
Expand All @@ -1552,7 +1579,7 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
// Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
int64_t const blocks = smCount * 8;
int64_t const threads = EXPAND_THREADS_PER_BLOCK;
auto func = expandInputRowsKernel<InputActivationsType, ExpandedActivationsType>;
auto func = expandInputRowsKernel<InputActivationsType, ExpandedActivationsType, PRE_QUANT_AWQ>;

cudaLaunchConfig_t config;
config.gridDim = blocks;
Expand All @@ -1566,17 +1593,19 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
config.attrs = attrs;
cudaLaunchKernelEx(&config, func, unpermuted_input, permuted_output, unpermuted_scales, permuted_scales,
permuted_row_to_unpermuted_row, num_rows, cols, k, fc1_act_global_scale, use_per_expert_act_scale,
expert_first_token_offset, fc1_act_sf_flat, input_sf, num_experts_per_node);
expert_first_token_offset, fc1_act_sf_flat, input_sf, num_experts_per_node,
reinterpret_cast<InputActivationsType const*>(prequant_scales));
}

#define INSTANTIATE_EXPAND_INPUT_ROWS(InputActivationsType, ExpandedActivationsType) \
template void expandInputRowsKernelLauncher<InputActivationsType, ExpandedActivationsType>( \
template void expandInputRowsKernelLauncher<InputActivationsType, ExpandedActivationsType, false>( \
InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output, \
float const* unpermuted_scales, float* permuted_scales, int const* permuted_row_to_unpermuted_row, \
int64_t const num_rows, int64_t const cols, int const k, int const num_experts_per_node, \
float const* fc1_act_global_scale, bool use_per_expert_act_scale, int64_t* expert_first_token_offset, \
TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, \
TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, cudaStream_t stream);
TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, cudaStream_t stream, \
void const* prequant_scales);

INSTANTIATE_EXPAND_INPUT_ROWS(half, half);
INSTANTIATE_EXPAND_INPUT_ROWS(float, float);
Expand Down Expand Up @@ -3341,10 +3370,26 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
using ExpandedActivationsType = std::conditional_t<use_w4afp8, BackBoneType, T>;
// Only NVFP4xNVFP4 supports FC1 per-expert act scale
bool use_per_expert_act_scale = use_fp4 ? quant_params.fp4.fc1.use_per_expert_act_scale : false;
expandInputRowsKernelLauncher(input_activations, reinterpret_cast<ExpandedActivationsType*>(permuted_data_),
token_topk_unpermuted_scales, permuted_token_final_scales_, permuted_row_to_unpermuted_row_, num_rows,
hidden_size, experts_per_token, num_experts_per_node, quant_params.fp4.fc1.act_global_scale,
use_per_expert_act_scale, expert_first_token_offset_, fc1_fp4_act_scale_, input_sf, stream);
T const* gemm1_input;
if constexpr (use_w4afp8)
{
// FP16/BF16 input_activations -> FP8 smoothed_act
expandInputRowsKernelLauncher<InputType, T, true>(input_activations, reinterpret_cast<T*>(smoothed_act_),
token_topk_unpermuted_scales, permuted_token_final_scales_, permuted_row_to_unpermuted_row_, num_rows,
hidden_size, experts_per_token, num_experts_per_node, quant_params.fp4.fc1.act_global_scale,
use_per_expert_act_scale, expert_first_token_offset_, fc1_fp4_act_scale_, input_sf, stream,
quant_params.groupwise.fc1.act_scales);

gemm1_input = reinterpret_cast<T const*>(smoothed_act_);
}
else
{
expandInputRowsKernelLauncher(input_activations, reinterpret_cast<ExpandedActivationsType*>(permuted_data_),
token_topk_unpermuted_scales, permuted_token_final_scales_, permuted_row_to_unpermuted_row_, num_rows,
hidden_size, experts_per_token, num_experts_per_node, quant_params.fp4.fc1.act_global_scale,
use_per_expert_act_scale, expert_first_token_offset_, fc1_fp4_act_scale_, input_sf, stream);
gemm1_input = reinterpret_cast<T const*>(permuted_data_);
}

sync_check_cuda_error(stream);

Expand All @@ -3371,8 +3416,11 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
}
}

auto gemm1_input = applyPrequantScale(smoothed_act_, permuted_data_, quant_params.groupwise.fc1.act_scales,
num_valid_tokens_ptr, expanded_num_rows, hidden_size, use_awq, stream);
if constexpr (!use_w4afp8)
{
gemm1_input = applyPrequantScale(smoothed_act_, permuted_data_, quant_params.groupwise.fc1.act_scales,
num_valid_tokens_ptr, expanded_num_rows, hidden_size, use_awq, stream);
}
sync_check_cuda_error(stream);
Self::gemm1(moe_gemm_runner_, blockscale_gemm_runner, gemm1_input, fc1_result_, glu_inter_result_,
expert_first_token_offset_, gemm1_tma_ws_input, fc1_expert_weights, fc1_expert_biases, num_valid_tokens_ptr,
Expand Down
1 change: 0 additions & 1 deletion tests/unittest/_torch/modules/test_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,6 @@ def test_fused_moe_nvfp4(dtype):
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)


@pytest.mark.skip(reason="https://nvbugs/5325653")
@skip_neither_ada_nor_hopper_unittest
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_moe_w4afp8(dtype):
Expand Down