diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py index adfdd462e7..bb387d5cd9 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py @@ -16,7 +16,6 @@ from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import ( fused_single_block_cumsum_and_segmented_arange, - nvfp4_fused_padding_cumsum_and_segmented_arange, triton_nvfp4_quant_stacked, triton_quantize_mx4_unpack, triton_scale_nvfp4_quant, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip index 9280ce9f26..44ad30c283 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip @@ -16,7 +16,7 @@ #include #include -#include +#include #include #include @@ -24,7 +24,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp" #include "kernels/fp8_rowwise_grouped_kernel_manifest.h" -namespace fbgemm_gpu { +namespace { template using RowwiseGroupedKernel = std::function; @@ -257,7 +268,8 @@ void set_static_kernel_args( } } -__global__ void set_kernel_args_m_sizes_kernel( +// Supports using either M_sizes or offsets. +__global__ void set_kernel_args( KernelArguments* kernel_args, ADataType* XQ, BDataType* WQ, @@ -265,10 +277,16 @@ __global__ void set_kernel_args_m_sizes_kernel( D1DataType* x_scale, EDataType* output, int64_t* M_sizes, + int32_t* offsets, int64_t M, int64_t N, int64_t K, - int64_t group_count) { + int64_t group_count, + std::optional input_type = std::nullopt) { + // The "message" part seems not working on AMD currently :( + CUDA_KERNEL_ASSERT_MSG((M_sizes == nullptr && offsets == nullptr) || (M_sizes == nullptr ^ offsets == nullptr), "Cannot set both M_sizes and offsets"); + CUDA_KERNEL_ASSERT_MSG(input_type.has_value() || M_sizes != nullptr, "M_sizes should not be used with input_type"); + int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; // Each thread is responsible for setting up the arguments for one group. if (thread_idx < group_count) { @@ -295,29 +313,93 @@ __global__ void set_kernel_args_m_sizes_kernel( kernel_args[thread_idx] = default_group_args; // Sync threads to get consistent state. __syncthreads(); - // Get M information for this group. - int64_t kernel_M = M_sizes[thread_idx]; + + // Offset information for this group. + int64_t XQ_offset; + int64_t WQ_offset; + int64_t x_scale_offset; + int64_t w_scale_offset; + int64_t output_offset; + // Strides for this group + int64_t A_stride = K; + int64_t B_stride = K; + int64_t output_stride = N; + // Problem size for this group. Dynamic dimension for the group would updated below. + int64_t M_group = M; + int64_t N_group = N; + int64_t K_group = K; + + // M_sizes API implies 2D-3D inputs + if (M_sizes != nullptr) { + M_group = M_sizes[thread_idx]; + if (M_group > 0) { + // Offset is computed by finding the sum of previous group Ms. 
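+        // (Exclusive prefix sum over M_sizes, computed independently by each thread.)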
+ int64_t offset_M = 0; + for (int i = 0; i < thread_idx; i++) { + offset_M += M_sizes[i]; + } + + XQ_offset = offset_M * K; + WQ_offset = thread_idx * N * K; + x_scale_offset = offset_M; + w_scale_offset = thread_idx * N; + output_offset = offset_M * N; + } + } else { + if (input_type == GroupedGemmInputType::_2D3D) { + const int32_t offset_M = thread_idx == 0 ? 0 : offsets[thread_idx - 1]; + M_group = offsets[thread_idx] - offset_M; + + XQ_offset = offset_M * K; + WQ_offset = thread_idx * N * K; + x_scale_offset = offset_M; + w_scale_offset = thread_idx * N; + output_offset = offset_M * N; + } else if (input_type == GroupedGemmInputType::_3D2D) { + const int32_t offset_N = thread_idx == 0 ? 0 : offsets[thread_idx - 1]; + N_group = offsets[thread_idx] - offset_N; + + XQ_offset = thread_idx * M * K; + WQ_offset = offset_N * K; + x_scale_offset = thread_idx * M; + w_scale_offset = offset_N; + // Offset of offset_N as the N dimension of the output is across the input gemm problems. + output_offset = offset_N; + } else if (input_type == GroupedGemmInputType::_2D2D) { + const int32_t offset_K = thread_idx == 0 ? 0 : offsets[thread_idx - 1]; + K_group = offsets[thread_idx] - offset_K; + + XQ_offset = offset_K; + WQ_offset = offset_K; + x_scale_offset = thread_idx * M; + w_scale_offset = thread_idx * N; + output_offset = thread_idx * M * N; + } else { + XQ_offset = thread_idx * M * K; + WQ_offset = thread_idx * N * K; + x_scale_offset = thread_idx * M; + w_scale_offset = thread_idx * N; + output_offset = thread_idx * M * N; + } + } + // Only write actual group information if this group is nonzero. - if (kernel_M > 0) { + if (M_group > 0 && N_group > 0 && K_group > 0) { // Get index automatically for this group. - int non_zero_idx = atomicAdd(&non_zero_counter, 1); - int64_t offset_M = 0; - // Offset is computed by finding the sum of previous group Ms. - for (int i = 0; i < thread_idx; i++) { - offset_M += M_sizes[i]; - } + const int non_zero_idx = atomicAdd(&non_zero_counter, 1); KernelArguments kernel_group_args = { - XQ + (offset_M * K), - WQ + (thread_idx * N * K), - {w_scale + (thread_idx * N), x_scale + offset_M}, - output + (offset_M * N), - int(kernel_M), - int(N), - int(K), - int(K), - int(K), - {0, 0}, - int(N)}; + XQ + XQ_offset, // A + WQ + WQ_offset, // B + {w_scale + w_scale_offset, x_scale + x_scale_offset}, // Ds + output + output_offset, // E + int(M_group), // M + int(N_group), // N + int(K_group), // K + int(A_stride), // StrideA + int(B_stride), // StrideB + {0, 0}, // StrideDs + int(output_stride) // StrideE + }; // Write kernel args to memory. kernel_args[non_zero_idx] = kernel_group_args; } @@ -450,7 +532,7 @@ void set_dynamic_kernel_args( // Depending on the mode, use appropriate setup kernel. 
if (M_sizes.has_value()) { - set_kernel_args_m_sizes_kernel<<<1, group_count, 0, stream>>>( + set_kernel_args<<<1, group_count, 0, stream>>>( reinterpret_cast(kernel_args.data_ptr()), reinterpret_cast(XQ.data_ptr()), reinterpret_cast(WQ.data_ptr()), @@ -458,6 +540,7 @@ void set_dynamic_kernel_args( reinterpret_cast(x_scale.data_ptr()), reinterpret_cast(output.data_ptr()), reinterpret_cast(M_sizes.value().data_ptr()), + nullptr, M, N, K, @@ -599,6 +682,33 @@ OutputType _f8f8bf16_rowwise_grouped( } } +void validate_inputs_common( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale +) { + TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); + TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); + TORCH_CHECK(x_scale.is_cuda()); + TORCH_CHECK(w_scale.is_cuda()); + + static auto float8_dtype = get_float8_e4m3_dtype(); + TORCH_CHECK( + XQ.dtype() == float8_dtype.first, + "Input XQ must be type ", float8_dtype.second); + TORCH_CHECK( + WQ.dtype() == float8_dtype.first, + "Input WQ must be type ", float8_dtype.second); + + TORCH_CHECK(x_scale.dtype() == at::kFloat, "Scales must be float32."); + TORCH_CHECK(w_scale.dtype() == at::kFloat, "Scales must be float32."); +} + +} // namespace + +namespace fbgemm_gpu { + std::vector f8f8bf16_rowwise_grouped( at::TensorList XQ, at::TensorList WQ, @@ -631,30 +741,20 @@ at::Tensor f8f8bf16_rowwise_grouped_stacked( // WQ is expected to be shape [G, N, K]. int64_t N = WQ.size(1); int64_t K = XQ.size(1); + + validate_inputs_common(XQ, WQ, x_scale, w_scale); + TORCH_CHECK( WQ.size(0) == group_count && x_scale.numel() == total_M && w_scale.numel() / group_count == N, "All inputs must have the same number of groups."); // Iterate over inputs and check they are valid. - TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); TORCH_CHECK(XQ.dim() == 2, "Input XQ must be 2D (total_M,K)."); - static auto float8_dtype = get_float8_e4m3_dtype(); - TORCH_CHECK( - XQ.dtype() == float8_dtype.first, - "Input XQ must be type ", float8_dtype.second); - - TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); TORCH_CHECK(WQ.dim() == 3, "Input WQ must be 3D (G,N,K)."); - TORCH_CHECK( - WQ.dtype() == float8_dtype.first, - "Input WQ must be type ", float8_dtype.second); TORCH_CHECK( WQ.size(1) >= 512 && WQ.size(2) >= 512, "N and K must be at least 512 for grouped gemm. For smaller inputs, consider unrolling."); - TORCH_CHECK(x_scale.dtype() == at::kFloat, "Scales must be float32."); - TORCH_CHECK(w_scale.dtype() == at::kFloat, "Scales must be float32."); - // Allocate an empty output array. We will set its values to zero as part // of kernel setup. at::Tensor Y = at::empty({total_M, N}, XQ.options().dtype(at::kBFloat16)); @@ -694,26 +794,15 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic( WQ.size(0) == group_count && x_scale.numel() / group_count == M && w_scale.numel() / group_count == N, "All inputs must have the same number of groups."); + + validate_inputs_common(XQ, WQ, x_scale, w_scale); // Iterate over inputs and check they are valid. 
- TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); TORCH_CHECK(XQ.dim() == 3, "Input XQ must be 3D (G,M,K)."); - static auto float8_dtype = get_float8_e4m3_dtype(); - TORCH_CHECK( - XQ.dtype() == float8_dtype.first, - "Input XQ must be type ", float8_dtype.second); - - TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); TORCH_CHECK(WQ.dim() == 3, "Input WQ must be 3D (G,N,K)."); - TORCH_CHECK( - WQ.dtype() == float8_dtype.first, - "Input WQ must be type ", float8_dtype.second); TORCH_CHECK( WQ.size(1) >= 512 && WQ.size(2) >= 512, "N and K must be at least 512 for grouped gemm. For smaller inputs, consider unrolling."); - TORCH_CHECK(x_scale.dtype() == at::kFloat, "Scales must be float32."); - TORCH_CHECK(w_scale.dtype() == at::kFloat, "Scales must be float32."); - // Allocate an empty output array. We will set its values to zero as part // of kernel setup. at::Tensor Y = @@ -744,4 +833,113 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic( return selected_kernel(XQ, WQ, x_scale, w_scale, kernel_args, Y); } +/** + * PyTorch compliant grouped GEMM API. + * Supports 2D-2D (K dynamic), 2D-3D (M dynamic), 3D-2D (N dynamic), and 3D-3D (BMM). + */ +at::Tensor f8f8bf16_rowwise_grouped_mm( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, // FP32 + at::Tensor w_scale, // FP32 + std::optional offsets, // int32 + at::Tensor& out) { + validate_inputs_common(XQ, WQ, x_scale, w_scale); + + // M, N, K could be the "total" dimension in the case of 2D inputs. + int64_t G; + int64_t M; + int64_t N; + int64_t K; + std::optional inputType; + + if (XQ.dim() == 2 && WQ.dim() == 3) { + TORCH_CHECK(offsets.has_value(), "Must pass offsets for 2D input XQ."); + TORCH_CHECK(offsets->dtype() == at::kInt, "offsets must be int32."); + + G = offsets->size(0); + M = XQ.size(0); + N = WQ.size(1); + K = WQ.size(2); + inputType = GroupedGemmInputType::_2D3D; + + TORCH_CHECK(XQ.size(1) == K && WQ.size(0) == G, "XQ shape must be (total_M, K) and WQ shape must be (G, N, K)."); + TORCH_CHECK(x_scale.size(0) == M, "x_scale shape must be (total_M)."); + TORCH_CHECK(w_scale.size(0) == G && w_scale.size(1) == N, "w_scale shape must be (G, N)."); + TORCH_CHECK(out.dim() == 2 && out.size(0) == M && out.size(1) == N, "out shape must be (total_M, N)."); + } else if (XQ.dim() == 3 && WQ.dim() == 2) { + TORCH_CHECK(offsets.has_value(), "Must pass offsets for 2D input WQ."); + TORCH_CHECK(offsets->dtype() == at::kInt, "offsets must be int32."); + + G = offsets->size(0); + M = XQ.size(1); + N = WQ.size(0); + K = WQ.size(1); + inputType = GroupedGemmInputType::_3D2D; + + TORCH_CHECK(XQ.size(0) == G && XQ.size(2) == K, "XQ shape must be (G, M, K) and WQ shape must be (total_N, K)."); + TORCH_CHECK(x_scale.size(0) == G && x_scale.size(1) == M, "x_scale shape must be (G, M)."); + TORCH_CHECK(w_scale.size(0) == N, "w_scale shape must be (total_N)."); + TORCH_CHECK(out.dim() == 2 && out.size(0) == M && out.size(1) == N, "out shape must be (M, total_N)."); + } else if (XQ.dim() == 3 && WQ.dim() == 3) { + TORCH_CHECK(!offsets.has_value(), "Offsets should not be passed for 3D-3D inputs."); + + G = XQ.size(0); + M = XQ.size(1); + N = WQ.size(1); + K = XQ.size(2); + inputType = GroupedGemmInputType::_3D3D; + + TORCH_CHECK(WQ.size(0) == G && WQ.size(2) == K, "XQ shape must be (G, M, K) and WQ shape must be (G, N, K)."); + TORCH_CHECK(x_scale.size(0) == G && x_scale.size(1) == M, "x_scale shape must be (G, M)."); + TORCH_CHECK(w_scale.size(0) == G && w_scale.size(1) == N, "w_scale shape must be (G, N)."); + TORCH_CHECK(out.dim() == 3 && 
out.size(0) == G && out.size(1) == M && out.size(2) == N, "out shape must be (G, M, N)."); + } else if (XQ.dim() == 2 && WQ.dim() == 2) { + TORCH_CHECK(offsets.has_value(), "Must pass offsets for 2D inputs XQ nd WQ."); + TORCH_CHECK(offsets->dtype() == at::kInt, "offsets must be int32."); + + G = offsets->size(0); + M = XQ.size(0); + N = WQ.size(0); + K = XQ.size(1); + inputType = GroupedGemmInputType::_2D2D; + + TORCH_CHECK(XQ.dim() == 2 && WQ.dim() == 2 && WQ.size(1) == K, "XQ shape must be (M, total_K) and WQ shape must be (N, total_K)."); + TORCH_CHECK(x_scale.size(0) == G * M, "x_scale shape must be (G * M)."); + TORCH_CHECK(w_scale.size(0) == G * N, "w_scale shape must be (G * N)."); + TORCH_CHECK(out.dim() == 3 && out.size(0) == G && out.size(1) == M && out.size(2) == N, "out shape must be (G, M, N)."); + } else { + TORCH_CHECK(false, "Invalid input shapes. Must be one of 2D-2D, 3D-3D, 2D-3D, 3D-2D."); + } + + // Early exit for empty input. + if (out.numel() == 0) { + return out; + } + + at::Tensor kernel_args = at::empty( + {static_cast(G * sizeof(KernelArguments))}, + XQ.options().dtype(at::kByte)); + + auto stream = at::cuda::getCurrentHIPStream().stream(); + set_kernel_args<<<1, G, 0, stream>>>( + reinterpret_cast(kernel_args.data_ptr()), + reinterpret_cast(XQ.data_ptr()), + reinterpret_cast(WQ.data_ptr()), + reinterpret_cast(w_scale.data_ptr()), + reinterpret_cast(x_scale.data_ptr()), + reinterpret_cast(out.data_ptr()), + nullptr, + offsets.has_value() ? reinterpret_cast(offsets.value().data_ptr()) : nullptr, + M, + N, + K, + G, + inputType); + + RowwiseGroupedKernel selected_kernel = + rowwise_grouped_heuristic_dispatch(G, M, N, K); + return selected_kernel(XQ, WQ, x_scale, w_scale, kernel_args, out); +} + } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip index ffe5a82384..b7cbb349ed 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip @@ -19,17 +19,18 @@ fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_i OutputType Y) { // Check if this input needs to be padded. -#if 0 - int M = XQ.size(1); - int N = WQ.size(1); - int K = WQ.size(2); - bool pad = (M % 128 != 0) || (N % 128 != 0) || (K % (128 * KBatch) != 0); -#else - // disable padding for packed tensor bool pad = false; -#endif - if (pad) - { + if constexpr (std::is_same::value) { + if (XQ.dim() == 3) { + pad = XQ.size(2) % 128 != 0; + } else if (WQ.dim() == 3) { + pad = WQ.size(2) % 128 != 0; + } else { + // For 2D-2D host does not know the K size for each problem so always pad. 
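+      // The padded branch below uses the KPadding specialization, which pads only the K dimension.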
+ pad = true; + } + } + if (pad) { // pad using DeviceGemmInstance = DeviceGemmHelper< 256, @@ -48,14 +49,10 @@ fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_i 1, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::MNKPadding>; + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, kernel_args, Y); - - // pad - } - else - { + } else { // no pad using DeviceGemmInstance = DeviceGemmHelper< 256, @@ -77,8 +74,6 @@ fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_i ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, kernel_args, Y); - - // no pad } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_common.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_common.h index a27721d74f..0232273fa2 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_common.h +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_common.h @@ -133,14 +133,15 @@ OutputType f8f8bf16_rowwise_grouped_impl( // Get input information. int group_count; if constexpr (std::is_same_v) { - // Two different modes when inputs are tensors. - // If XQ is 3D then its shape is [G, M, K]. - // If its 2D then its shape is [total_M, K]. - if (XQ.dim() == 2) { - // group count is the min of total_M and G. + if (WQ.dim() == 3) { + // If WQ is 3D the group count is the min of G and total_M (if XQ is 2D). + group_count = std::min(WQ.size(0), XQ.size(0)); + } else if (XQ.dim() == 3) { + // If XQ is 3D the group count is the min of G and total_N (if WQ is 2D). group_count = std::min(XQ.size(0), WQ.size(0)); } else { - group_count = WQ.size(0); + // XQ and WQ are 2D. The group count is G. + group_count = Y.size(0); } } else { group_count = XQ.size(); @@ -172,15 +173,15 @@ OutputType f8f8bf16_rowwise_grouped_impl( D1DataType* d1_ptr; // Populate arguments. for (int i = 0; i < group_count; i++) { - // Compute appropriate data pointers. - // Set the shape arguments for this gemm. + // Compute appropriate data pointers. The host problem shape and data + // pointers below are unused, as the device memory contains the correct + // data. if constexpr (std::is_same_v) { - M = XQ.size(XQ.dim() - 2); - N = WQ.size(1); - K = WQ.size(2); - // These pointers dont seem to actually be used since the kernel arguments - // contains the correct version. For simplicity, we just point to the - // start of the tensor. + // Set these to 0 as placeholders, they are unsused. + M = 0; + N = 0; + K = 0; + // For simplicity, we just point to the start of the tensor. 
a_ptr = reinterpret_cast(XQ.data_ptr()); b_ptr = reinterpret_cast(WQ.data_ptr()); d0_ptr = reinterpret_cast(w_scale.data_ptr()); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp index 1bcb389292..a45630d1d0 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp @@ -31,6 +31,15 @@ namespace fbgemm_gpu { #ifdef USE_ROCM // flush icache void flush_icache_ck(); + +// Generic PyTorch grouped GEMM API is only available on AMD for now. +at::Tensor f8f8bf16_rowwise_grouped_mm( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + std::optional offsets, + at::Tensor& output); #endif // SmoothQuant kernels @@ -339,6 +348,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { #ifdef USE_ROCM m.impl("flush_icache_hip", flush_icache_ck); + m.impl("f8f8bf16_rowwise_grouped_mm", f8f8bf16_rowwise_grouped_mm); #endif #ifdef USE_ROCM m.impl("f8f8f16_rowwise", f8f8f16_rowwise); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp index 86ab705a57..ae3ab5851a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp @@ -115,6 +115,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { "f8f8bf16_rowwise_preshuffle(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor? bias=None, bool use_fast_accum=True) -> Tensor"); m.def( "f8f8f16_rowwise_preshuffle(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor? bias=None, bool use_fast_accum=True) -> Tensor"); + // Generic PyTorch grouped GEMM API is only available on AMD for now. + m.def( + "f8f8bf16_rowwise_grouped_mm(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor? offsets, Tensor(a!) output) -> Tensor"); #endif } diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py index 913c41e1c0..d3d73fe0d1 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py @@ -10,7 +10,7 @@ import os import unittest -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import fbgemm_gpu.experimental.gen_ai # noqa: F401 @@ -41,6 +41,19 @@ running_on_github: bool = os.getenv("GITHUB_ENV") is not None + +def evaluate_platform_supports_fp8(): + if torch.cuda.is_available(): + if torch.version.hip: + return supports_float8_fnuz(throw_on_hip_incompatibility=False) + else: + # Only SM90 or later is supported + return torch.cuda.get_device_capability() >= (9, 0) + return False + + +SUPPORTS_FP8 = evaluate_platform_supports_fp8() + if torch.cuda.is_available() and supports_float8_fnuz( throw_on_hip_incompatibility=(not running_on_github) ): @@ -815,9 +828,7 @@ def test_quantize_fp8_per_tensor_with_ub( zq_ref = (x @ w.T).to(torch.bfloat16) torch.testing.assert_close(zq, zq_ref, atol=1.0e-3, rtol=1.0e-3) - @unittest.skipIf( - not torch.version.cuda, "Skip on AMD: BMM ops are not yet suported." 
- ) + @unittest.skipIf(not SUPPORTS_FP8, "FP8 not supported on this platform") @settings(deadline=None) @given( B=st.sampled_from([1, 4]), @@ -825,7 +836,10 @@ def test_quantize_fp8_per_tensor_with_ub( N=st.sampled_from([128, 256]), K=st.sampled_from([256, 512]), use_loopover=st.sampled_from([True, False]), - Bias=st.sampled_from([True, False]), + Bias=st.sampled_from([False] + ([True] if torch.version.cuda else [])), + mode=st.sampled_from( + ["default"] + (["torch_3d3d"] if torch.version.hip else []) + ), ) def test_fp8_batched_gemm( self, @@ -835,7 +849,13 @@ def test_fp8_batched_gemm( K: int, Bias: bool, use_loopover: bool, + mode: str, ) -> None: + # AMD CK FP8 batched gemm does not support N < 512 or K < 512. + # Funny enough, grouped gemm does not have this restriction. + if mode == "default" and torch.version.hip and (N < 512 or K < 512): + return + x = ( torch.rand( size=(B, M, K), @@ -897,37 +917,42 @@ def fp8_loopover_bmm( if use_loopover: y_fp8 = fp8_loopover_bmm(xq, wq, x_scale, w_scale, bias) else: - y_fp8 = torch.ops.fbgemm.f8f8bf16_rowwise_batched( - xq, wq, x_scale, w_scale, bias - ) + if mode == "default": + y_fp8 = torch.ops.fbgemm.f8f8bf16_rowwise_batched( + xq, wq, x_scale, w_scale, bias + ) + elif mode == "torch_3d3d": + y_fp8_ = torch.empty( + (B, M, N), dtype=torch.bfloat16, device=xq[0].device + ) + y_fp8 = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_mm( + xq, + wq, + x_scale, + w_scale, + None, + y_fp8_, + ) + torch.testing.assert_close(y_ref, y_fp8, atol=8.0e-2, rtol=8.0e-2) - @unittest.skipIf( - torch.version.hip is not None and running_on_github, - "type fp8e4b8 not supported in this architecture. The supported fp8 dtypes are ('fp8e5',)", - ) - @unittest.skipIf( - not torch.version.cuda and torch.version.hip < "6.2", - "Skip on AMD with < RoCM 6.2", - ) + @unittest.skipIf(not SUPPORTS_FP8, "FP8 not supported on this platform") @settings(deadline=None) @given( G=st.sampled_from([1, 4, 5, 16]), M=st.sampled_from([0, 2048, 3584]), - N=st.sampled_from([1024, 6144]), - K=st.sampled_from([512, 3584]), + N=st.sampled_from([256, 1024, 6144]), + K=st.sampled_from([256, 512, 3584]), use_cudagraph=st.booleans(), - mode=st.sampled_from(["default", "cat", "padded", "stacked"]), + mode=st.sampled_from(["default", "cat", "padded"]), ) - def test_grouped_gemm( - self, - G: int, - M: int, - N: int, - K: int, - use_cudagraph: bool, - mode: str, - ) -> None: + def test_grouped_gemm_fbgemm_api( + self, G: int, M: int, N: int, K: int, use_cudagraph: bool, mode: str + ): + # TODO remove this restriction. 
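+        # The CK grouped kernels require N and K to be at least 512
+        # (see the TORCH_CHECKs in fp8_rowwise_grouped_gemm.hip).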
+ if N < 512 or K < 512: + return + if M > 0: ms = ( torch.randint( @@ -995,16 +1020,7 @@ def test_grouped_gemm( wq_group = torch.stack(wq_group, dim=0).contiguous() x_scale_group = torch.stack(x_scale_group, dim=0).contiguous() w_scale_group = torch.stack(w_scale_group, dim=0).contiguous() - elif mode == "stacked": - x_group = torch.cat(x_group, dim=0).contiguous() - w_group = torch.stack(w_group, dim=0).contiguous() - xq_group = torch.cat(xq_group, dim=0).contiguous() - wq_group = torch.stack(wq_group, dim=0).contiguous() - x_scale_group = torch.cat(x_scale_group, dim=0).contiguous() - w_scale_group = torch.stack(w_scale_group, dim=0).contiguous() - # FP8 grouped gemm kernel - if mode == "padded": fp8_op = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic bf16_op = torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic fp8_args = [ @@ -1015,12 +1031,6 @@ def test_grouped_gemm( zero_start_index_M, ] bf16_args = [x_group, w_group, zero_start_index_M] - elif mode == "stacked": - fp8_op = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked - M_sizes = ms.to(device=self.device, dtype=torch.int64) - fp8_args = [xq_group, wq_group, x_scale_group, w_scale_group, M_sizes] - bf16_op = torch.ops.fbgemm.bf16bf16bf16_grouped_stacked - bf16_args = [x_group, w_group, M_sizes] else: if mode == "cat": fp8_op = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_cat @@ -1067,26 +1077,266 @@ def test_grouped_gemm( else: y_bf16_group = torch.unbind(y_bf16_group) - # BF16 loopover gemm reference - # unstack input to make it compatible with loopover. - if mode == "stacked": - x_group = torch.split(x_group, tuple(ms.tolist()), dim=0) - y_group_ref = [] - for i in range(len(x_group)): - y = torch.matmul(x_group[i], w_group[i].t()) - y_group_ref.append(y) + self.bf16_loopover_validate( + x_group, + w_group, + y_fp8_group, + y_bf16_group, + # default mode is worse for some reason + rtol_fp8=2.0e-1 if mode == "default" else 8.0e-2, + ) - # Assert FP8 outputs - for i in range(len(y_group_ref)): + def bf16_loopover_validate( + self, + x: Union[torch.Tensor, list[torch.Tensor]], + w: Union[torch.Tensor, list[torch.Tensor]], + out_fp8: Union[torch.tensor, list[torch.Tensor]], + out_bf16: Union[torch.tensor, list[torch.Tensor], None] = None, + atol_fp8=8.0e-2, + rtol_fp8=8.0e-2, + atol_bf16=8.0e-3, + rtol_bf16=8.0e-3, + ): + out_ref = [torch.matmul(x[i], w[i].t()) for i in range(len(x))] + + for i in range(len(out_fp8)): torch.testing.assert_close( - y_fp8_group[i], y_group_ref[i], atol=8.0e-2, rtol=2.0e-1 + out_fp8[i], out_ref[i], atol=atol_fp8, rtol=rtol_fp8 ) - # Assert BF16 outputs - for i in range(len(y_group_ref)): - torch.testing.assert_close( - y_bf16_group[i], y_group_ref[i], atol=8.0e-3, rtol=8.0e-3 + if out_bf16: + for i in range(len(out_bf16)): + torch.testing.assert_close( + out_bf16[i], out_ref[i], atol=atol_bf16, rtol=rtol_bf16 + ) + + @unittest.skipIf(not SUPPORTS_FP8, "FP8 not supported on this platform") + @settings(deadline=None) + @given( + G=st.sampled_from([1, 4, 16]), + M=st.sampled_from([0, 2048, 3584]), + N=st.sampled_from([256, 1024, 6144]), + K=st.sampled_from([256, 512, 3584]), + use_cudagraph=st.booleans(), + mode=st.sampled_from(["stacked", "torch_2d3d"]), + ) + def test_grouped_gemm_2d_3d( + self, + G: int, + M: int, + N: int, + K: int, + use_cudagraph: bool, + mode: str, + ) -> None: + # TODO remove this restriction. 
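+        # Only the stacked op has the N/K >= 512 restriction; the torch_2d3d
+        # (grouped_mm) path is still exercised for the smaller sizes.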
+ if (N < 512 or K < 512) and mode == "stacked": + return + + if M > 0: + M_sizes = ( + torch.randint( + 1, + (M // 64) + 1, + (G,), + dtype=torch.int, + ) + * 64 ) + else: + M_sizes = torch.zeros((G,), dtype=torch.int) + + M = torch.sum(M_sizes).item() + X = torch.randn((M, K), dtype=torch.bfloat16, device=self.device) * 0.1 + W = torch.randn((G, N, K), dtype=torch.bfloat16, device=self.device) * 0.01 + + xq, x_scale = quantize_fp8_row(X) + wq, w_scale = quantize_fp8_row(W) + + # FP8 grouped gemm kernel + if mode == "stacked": + fp8_op = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked + M_sizes_gpu = M_sizes.clone().to(device=self.device, dtype=torch.int64) + fp8_args = [xq, wq, x_scale, w_scale, M_sizes_gpu] + + bf16_op = torch.ops.fbgemm.bf16bf16bf16_grouped_stacked + bf16_args = [X, W, M_sizes_gpu] + elif mode == "torch_2d3d": + fp8_op = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_mm + M_offsets = torch.cumsum(M_sizes, dim=0).to( + device=self.device, dtype=torch.int32 + ) + out = torch.empty(M, N).to(device=self.device, dtype=torch.bfloat16) + fp8_args = [ + xq, + wq, + x_scale, + w_scale, + M_offsets, + out, + ] + + bf16_op = None + bf16_args = None + + if use_cudagraph: + # warmup + fp8_op(*fp8_args) + # With cudagraph + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + y_fp8_group = fp8_op(*fp8_args) + g.replay() + else: + y_fp8_group = fp8_op(*fp8_args) + + # Massage output into proper format. + y_fp8_group = torch.split(y_fp8_group, tuple(M_sizes.tolist()), dim=0) + + # unstack input to make it compatible with loopover. + x_group = torch.split(X, tuple(M_sizes.tolist()), dim=0) + + y_bf16_group = None + if bf16_op is not None: + if use_cudagraph: + # warmup + bf16_op(*bf16_args) + # With cudagraph + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + y_bf16_group = bf16_op(*bf16_args) + g.replay() + else: + y_bf16_group = bf16_op(*bf16_args) + + y_bf16_group = torch.split(y_bf16_group, tuple(M_sizes.tolist()), dim=0) + + # BF16 loopover gemm reference + self.bf16_loopover_validate(x_group, W, y_fp8_group, y_bf16_group) + + @unittest.skipIf( + not torch.version.hip, + "Only AMD supports torch 3D-2D grouped gemm API", + ) + @unittest.skipIf(not SUPPORTS_FP8, "FP8 not supported on this platform") + @settings(deadline=None) + @given( + G=st.sampled_from([1, 4, 16]), + M=st.sampled_from([0, 64, 2048, 3584]), + N=st.sampled_from([64, 256, 1024, 6144]), + K=st.sampled_from([64, 256, 512, 3584]), + ) + def test_grouped_gemm_3d_2d( + self, + G: int, + M: int, + N: int, + K: int, + ) -> None: + N_sizes = ( + torch.randint( + 1, + (N // 64) + 1, + (G,), + dtype=torch.int, + ) + * 64 + ) + N = torch.sum(N_sizes).item() + N_offsets = torch.cumsum(N_sizes, dim=0).to( + device=self.device, dtype=torch.int32 + ) + + X = torch.randn((G, M, K), dtype=torch.bfloat16, device=self.device) * 0.1 + W = torch.randn((N, K), dtype=torch.bfloat16, device=self.device) * 0.01 + out = torch.empty((M, N), dtype=torch.bfloat16, device=self.device) + + xq, x_scale = quantize_fp8_row(X) + wq, w_scale = quantize_fp8_row(W) + + y = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_mm( + xq, wq, x_scale, w_scale, N_offsets, out + ) + + # Compare using loopover BF16 gemm + y_fp8 = torch.split(y, tuple(N_sizes), dim=1) + W_split = torch.split(W, tuple(N_sizes), dim=0) + self.bf16_loopover_validate(X, W_split, y_fp8) + + @unittest.skipIf( + not torch.version.hip, + "Only AMD supports torch 2D-2D grouped gemm API", + ) + @unittest.skipIf(not SUPPORTS_FP8, "FP8 not supported on this platform") + 
@settings(deadline=None) + @given( + G=st.sampled_from([1, 4, 16]), + M=st.sampled_from([16, 2048, 3584]), + N=st.sampled_from([16, 256, 1024, 6144]), + K=st.sampled_from([16, 256, 512, 3584]), + use_cudagraph=st.booleans(), + ) + def test_grouped_gemm_2d_2d( + self, + G: int, + M: int, + N: int, + K: int, + use_cudagraph: bool, + ) -> None: + K_sizes = torch.ones((G,), dtype=torch.int, device=self.device) * K + K_offsets = torch.cumsum(K_sizes, dim=0).to( + device=self.device, dtype=torch.int32 + ) + + # Each group should be quantized rowwise separately + X_list = [] + W_list = [] + xq_list = [] + wq_list = [] + x_scale_list = [] + w_scale_list = [] + for k_size in K_sizes.tolist(): + X = torch.randn((M, k_size), dtype=torch.bfloat16, device=self.device) * 0.1 + W = ( + torch.randn((N, k_size), dtype=torch.bfloat16, device=self.device) + * 0.01 + ) + xq, x_scale = quantize_fp8_row(X) + wq, w_scale = quantize_fp8_row(W) + + X_list.append(X) + W_list.append(W) + xq_list.append(xq) + wq_list.append(wq) + x_scale_list.append(x_scale) + w_scale_list.append(w_scale) + + xq = torch.cat(xq_list, dim=1) + wq = torch.cat(wq_list, dim=1) + x_scale = torch.cat(x_scale_list, dim=0) + w_scale = torch.cat(w_scale_list, dim=0) + + out = torch.empty((G, M, N), dtype=torch.bfloat16, device=self.device) + + if use_cudagraph: + # warmup + torch.ops.fbgemm.f8f8bf16_rowwise_grouped_mm( + xq, wq, x_scale, w_scale, K_offsets, out + ) + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + y = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_mm( + xq, wq, x_scale, w_scale, K_offsets, out + ) + g.replay() + else: + y = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_mm( + xq, wq, x_scale, w_scale, K_offsets, out + ) + + # Compare using loopover BF16 gemm + self.bf16_loopover_validate(X_list, W_list, y) @unittest.skipIf(not torch.version.cuda, "Currently not supported on AMD.") @settings(deadline=None)