diff --git a/include/imex/ExecutionEngine/ImexRunnerUtils.h b/include/imex/ExecutionEngine/ImexRunnerUtils.h
index 464ad6418..32cebb116 100644
--- a/include/imex/ExecutionEngine/ImexRunnerUtils.h
+++ b/include/imex/ExecutionEngine/ImexRunnerUtils.h
@@ -118,4 +118,19 @@ extern "C" IMEX_RUNNERUTILS_EXPORT void
 _mlir_ciface_printMaxErrorF32(UnrankedMemRefType<float> *M,
                               UnrankedMemRefType<float> *N);
 
+extern "C" IMEX_RUNNERUTILS_EXPORT void
+_mlir_ciface_gemmF16F16F16(UnrankedMemRefType<f16> *A,
+                           UnrankedMemRefType<f16> *B,
+                           UnrankedMemRefType<float> *C);
+
+extern "C" IMEX_RUNNERUTILS_EXPORT void
+_mlir_ciface_gemmF16F16F32(UnrankedMemRefType<f16> *A,
+                           UnrankedMemRefType<f16> *B,
+                           UnrankedMemRefType<float> *C);
+
+extern "C" IMEX_RUNNERUTILS_EXPORT void
+_mlir_ciface_gemmBF16BF16F32(UnrankedMemRefType<bf16> *A,
+                             UnrankedMemRefType<bf16> *B,
+                             UnrankedMemRefType<float> *C);
+
 #endif // IMEX_EXECUTIONENGINE_IMEXRUNNERUTILS_H
diff --git a/lib/ExecutionEngine/ImexRunnerUtils.cpp b/lib/ExecutionEngine/ImexRunnerUtils.cpp
index 91fe921e8..9d512981c 100644
--- a/lib/ExecutionEngine/ImexRunnerUtils.cpp
+++ b/lib/ExecutionEngine/ImexRunnerUtils.cpp
@@ -13,6 +13,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "imex/ExecutionEngine/ImexRunnerUtils.h"
+#include "mlir/ExecutionEngine/CRunnerUtils.h"
+#include "mlir/ExecutionEngine/Float16bits.h"
 #include <cmath>
 #include <cstdlib>
 #include <iostream>
@@ -239,6 +241,89 @@ void _mlir_ciface_printMaxError(UnrankedMemRefType<T> *M,
             << '\n';
 }
 
+template <typename T, typename CT>
+void _mlir_ciface_gemm(UnrankedMemRefType<T> *A, UnrankedMemRefType<T> *B,
+                       UnrankedMemRefType<float> *C) {
+  DynamicMemRefType<T> DA = DynamicMemRefType<T>(*A);
+  DynamicMemRefType<T> DB = DynamicMemRefType<T>(*B);
+  DynamicMemRefType<float> DC = DynamicMemRefType<float>(*C);
+  if (DA.rank != 2 || DB.rank != 2 || DC.rank != 2) {
+    std::cout << "Expecting 2D memrefs, got A rank: " << DA.rank
+              << ", B rank: " << DB.rank << ", C rank: " << DC.rank << '\n';
+    return;
+  }
+  if (DA.sizes[1] != DB.sizes[0] || DA.sizes[0] != DC.sizes[0] ||
+      DB.sizes[1] != DC.sizes[1]) {
+    std::cout << "Incompatible matrix dimensions: A: [" << DA.sizes[0] << ", "
+              << DA.sizes[1] << "], B: [" << DB.sizes[0] << ", " << DB.sizes[1]
+              << "], C: [" << DC.sizes[0] << ", " << DC.sizes[1] << "]\n";
+    return;
+  }
+  if ((DA.strides[0] != DA.sizes[1]) || (DA.strides[1] != 1) ||
+      (DB.strides[0] != DB.sizes[1]) || (DB.strides[1] != 1) ||
+      (DC.strides[0] != DC.sizes[1]) || (DC.strides[1] != 1)) {
+    std::cout << "Expecting row-major, densely packed memrefs: A strides "
+                 "[K, 1], B strides [N, 1], C strides [N, 1]\n";
+    return;
+  }
+  int M = DA.sizes[0];
+  int N = DB.sizes[1];
+  int K = DA.sizes[1];
+
+  float *storageA = (float *)malloc(M * K * sizeof(float));
+  float *storageB = (float *)malloc(K * N * sizeof(float));
+  float *storageC = (float *)malloc(M * N * sizeof(float));
+
+  DynamicMemRefIterator<T> aIt(DA);
+  for (int i = 0; i < M * K; ++i) {
+    storageA[i] = getFloat(*aIt);
+    ++aIt;
+  }
+
+  DynamicMemRefIterator<T> bIt(DB);
+  for (int i = 0; i < K * N; ++i) {
+    storageB[i] = getFloat(*bIt);
+    ++bIt;
+  }
+
+  DynamicMemRefIterator<float> cIt(DC);
+  for (int i = 0; i < M * N; ++i) {
+    storageC[i] = getFloat(*cIt);
+    ++cIt;
+  }
+
+  for (int i = 0; i < M; ++i) {
+    for (int k = 0; k < K; ++k) {
+      for (int j = 0; j < N; ++j) {
+        int a_idx = i * DA.strides[0] + k;
+        int b_idx = k * DB.strides[0] + j;
+        int c_idx = i * DC.strides[0] + j;
+        storageC[c_idx] += storageA[a_idx] * storageB[b_idx];
+      }
+    }
+  }
+
+  // Store the result back to C.
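+  // For CT = f16 or bf16, each f32 result is rounded through the narrower
+  // type before being written to the f32 reference buffer. This mirrors the
+  // truncf/extf pair in the MLIR @cpu_reference functions removed below, so
+  // the reference matches a GPU kernel that accumulates in f32 but writes
+  // f16/bf16 output.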
+  DynamicMemRefIterator<float> resultIt(DC);
+  for (int i = 0; i < M * N; ++i) {
+    float r = storageC[i];
+    if constexpr (std::is_same_v<CT, f16>) {
+      f16 casted(r);
+      *resultIt = getFloat(casted);
+    } else if constexpr (std::is_same_v<CT, bf16>) {
+      bf16 casted(r);
+      *resultIt = getFloat(casted);
+    } else if constexpr (std::is_same_v<CT, float>) {
+      *resultIt = storageC[i];
+    }
+    ++resultIt;
+  }
+
+  free(storageC);
+  free(storageA);
+  free(storageB);
+}
+
 extern "C" void _mlir_ciface_printMaxErrorF16(UnrankedMemRefType<f16> *M,
                                               UnrankedMemRefType<f16> *N) {
   _mlir_ciface_printMaxError(M, N);
@@ -284,4 +369,22 @@ extern "C" void _mlir_ciface_printAllcloseF32(UnrankedMemRefType<float> *M,
   _mlir_ciface_printAllclose(M, N);
 }
 
+extern "C" void _mlir_ciface_gemmF16F16F32(UnrankedMemRefType<f16> *A,
+                                           UnrankedMemRefType<f16> *B,
+                                           UnrankedMemRefType<float> *C) {
+  _mlir_ciface_gemm<f16, float>(A, B, C);
+}
+
+extern "C" void _mlir_ciface_gemmF16F16F16(UnrankedMemRefType<f16> *A,
+                                           UnrankedMemRefType<f16> *B,
+                                           UnrankedMemRefType<float> *C) {
+  _mlir_ciface_gemm<f16, f16>(A, B, C);
+}
+
+extern "C" void _mlir_ciface_gemmBF16BF16F32(UnrankedMemRefType<bf16> *A,
+                                             UnrankedMemRefType<bf16> *B,
+                                             UnrankedMemRefType<float> *C) {
+  _mlir_ciface_gemm<bf16, float>(A, B, C);
+}
+
 // NOLINTEND(*-identifier-naming)
diff --git a/test/Integration/Dialect/XeGPU/SG/gemm_4kx4kx4k_f16_f16_f16.mlir b/test/Integration/Dialect/XeGPU/SG/gemm_4kx4kx4k_f16_f16_f16.mlir
index 60cab8305..ef7771d28 100644
--- a/test/Integration/Dialect/XeGPU/SG/gemm_4kx4kx4k_f16_f16_f16.mlir
+++ b/test/Integration/Dialect/XeGPU/SG/gemm_4kx4kx4k_f16_f16_f16.mlir
@@ -380,38 +380,6 @@ module @gemm attributes {gpu.container_module} {
     }
   }
 
-  // compute CPU reference (takes minutes)
-  func.func @cpu_reference(%A : memref<4096x4096xf16>, %B : memref<4096x4096xf16>, %C : memref<4096x4096xf32>) {
-    %c4096 = arith.constant 4096 : index
-    %c16 = arith.constant 16 : index
-    %c1 = arith.constant 1 : index
-    %c0 = arith.constant 0 : index
-    %c64 = arith.constant 64 : index
-    scf.for %i = %c0 to %c4096 step %c1 {
-      scf.for %j = %c0 to %c4096 step %c1 {
-        %c_curr = memref.load %C[%i, %j] : memref<4096x4096xf32>
-        %c_val = scf.for %k_tile = %c0 to %c4096 step %c16 iter_args(%c_partial = %c_curr) -> f32 {
-          %c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 {
-            %k_dpas = arith.addi %k_tile, %k : index
-            %a_val = memref.load %A[%i, %k_dpas] : memref<4096x4096xf16>
-            %b_val = memref.load %B[%k_dpas, %j] : memref<4096x4096xf16>
-            %a_cast = arith.extf %a_val : f16 to f32
-            %b_cast = arith.extf %b_val : f16 to f32
-            %t = arith.mulf %a_cast, %b_cast : f32
-            // %t_cast = arith.extf %t : f16 to f16
-            %c_sum = arith.addf %t, %c_dpas_partial : f32
-            scf.yield %c_sum : f32
-          }
-          scf.yield %c_val_dpas : f32
-        }
-        %c_val_f16 = arith.truncf %c_val : f32 to f16
-        %c_val_ = arith.extf %c_val_f16 : f16 to f32
-        memref.store %c_val_ , %C[%i, %j] : memref<4096x4096xf32>
-      }
-    }
-    return
-  }
-
   func.func @main() attributes {llvm.emit_c_interface} {
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
@@ -463,7 +431,7 @@
     call @fillResource1DRandomF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> ()
 
-    // intialize matrix C and C_ref ; C[i, j] = 0
+    // Initialize matrix C and C_ref ; C[i, j] = 0
     %c0_f16 = arith.constant 0.0 : f16
     %c0_f32 = arith.constant 0.0 : f32
     scf.for %i = %c0 to %c4096 step %c1 {
@@ -472,22 +440,16 @@
         memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32>
       }
     }
-    // print input for debug
-    // %A_row_0 =
memref.subview %A[1, 0][1, 4096][1, 1] : memref<4096x4096xf16> to memref<1x4096xf16, strided<[4096, 1], offset: 4096>> - // %A_row_0_cast = memref.cast %A_row_0 : memref<1x4096xf16, strided<[4096, 1], offset: 4096>> to memref<*xf16> - // call @printMemrefF16(%A_row_0_cast) : (memref<*xf16>) -> () - // run GPU + // Run GPU version %2 = call @test(%A, %B, %C) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf16>) -> memref<4096x4096xf16> + %gpu_result_cast = memref.cast %2 : memref<4096x4096xf16> to memref<*xf16> - call @cpu_reference(%A, %B, %C_ref) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf32>) -> () - - // %cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16> - // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_C = memref.cast %2 : memref<4096x4096xf16> to memref<*xf16> - %cast_C_ref = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> - // call @printMemrefF16(%cast_C) : (memref<*xf16>) -> () - // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () + // Run CPU version. + %A_cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16> + %B_cast = memref.cast %B : memref<4096x4096xf16> to memref<*xf16> + %C_cast = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> + call @gemmF16F16F16(%A_cast, %B_cast, %C_cast) : (memref<*xf16>, memref<*xf16>, memref<*xf32>) -> () %C_row_0 = memref.subview %C_ref[0, 0][1, 4096][1, 1] : memref<4096x4096xf32> to memref<1x4096xf32, strided<[4096, 1], offset:0>> %C_row_0_cast = memref.cast %C_row_0 : memref<1x4096xf32, strided<[4096, 1], offset: 0>> to memref<*xf32> @@ -498,7 +460,7 @@ module @gemm attributes {gpu.container_module} { // call @printMemrefF16(%C_row_0_cast_gpu) : (memref<*xf16>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF16(%cast_C, %cast_C_ref) : (memref<*xf16>, memref<*xf32>) -> () + call @printAllcloseF16(%gpu_result_cast, %C_cast) : (memref<*xf16>, memref<*xf32>) -> () memref.dealloc %A : memref<4096x4096xf16> memref.dealloc %B : memref<4096x4096xf16> memref.dealloc %C : memref<4096x4096xf16> @@ -510,5 +472,5 @@ module @gemm attributes {gpu.container_module} { func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @printMaxErrorF16(memref<*xf16>, memref<*xf16>) attributes {llvm.emit_c_interface} func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface} - + func.func private @gemmF16F16F16(memref<*xf16>, memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/SG/simple_gemm.mlir b/test/Integration/Dialect/XeGPU/SG/simple_gemm.mlir index 5e9c0d5d4..fce53a210 100644 --- a/test/Integration/Dialect/XeGPU/SG/simple_gemm.mlir +++ b/test/Integration/Dialect/XeGPU/SG/simple_gemm.mlir @@ -54,37 +54,6 @@ module @gemm attributes {gpu.container_module} { return %c : memref<256x256xf32> } - // compute CPU reference (takes minutes) - func.func @cpu_reference(%A : memref<256x256xf16>, %B : memref<256x256xf16>, %C : memref<256x256xf32>) { - %c256 = arith.constant 256 : index - %c16 = arith.constant 16 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - scf.for %i = %c0 to %c256 step %c1 { - scf.for %j = %c0 to %c256 step %c1 { - %c_curr = memref.load %C[%i, %j] : memref<256x256xf32> - %c_val = scf.for %k_tile = %c0 to %c256 step %c16 iter_args(%c_partial = %c_curr) -> f32 { - %c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> 
f32 { - %k_dpas = arith.addi %k_tile, %k : index - %a_val = memref.load %A[%i, %k_dpas] : memref<256x256xf16> - %b_val = memref.load %B[%k_dpas, %j] : memref<256x256xf16> - %a_cast = arith.extf %a_val : f16 to f32 - %b_cast = arith.extf %b_val : f16 to f32 - %t = arith.mulf %a_cast, %b_cast : f32 - // %t_cast = arith.extf %t : f16 to f16 - %c_sum = arith.addf %t, %c_dpas_partial : f32 - scf.yield %c_sum : f32 - } - scf.yield %c_val_dpas : f32 - } - // %c_val_f16 = arith.truncf %c_val : f32 to f16 - // %c_val_ = arith.extf %c_val_f16 : f16 to f32 - memref.store %c_val , %C[%i, %j] : memref<256x256xf32> - } - } - return - } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index @@ -145,22 +114,16 @@ module @gemm attributes {gpu.container_module} { memref.store %c0_f32, %C_ref[%i, %j] : memref<256x256xf32> } } - // print input fror debug - // %A_row_0 = memref.subview %A[1, 0][1, 256][1, 1] : memref<256x256xf16> to memref<1x256xf16, strided<[256, 1], offset: 256>> - // %A_row_0_cast = memref.cast %A_row_0 : memref<1x256xf16, strided<[256, 1], offset: 256>> to memref<*xf16> - // call @printMemrefF16(%A_row_0_cast) : (memref<*xf16>) -> () - // run GPU + // Run GPU. %2 = call @test(%A, %B, %C) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32> - - call @cpu_reference(%A, %B, %C_ref) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> () - - // %cast = memref.cast %A : memref<256x256xf16> to memref<*xf16> - // call @printMemrefF16(%cast) : (memref<*xf16>) -> () %cast_C = memref.cast %2 : memref<256x256xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<256x256xf32> to memref<*xf32> - // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () - // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () + + // Run CPU. 
+ %A_cast = memref.cast %A : memref<256x256xf16> to memref<*xf16> + %B_cast = memref.cast %B : memref<256x256xf16> to memref<*xf16> + %C_ref_cast = memref.cast %C_ref : memref<256x256xf32> to memref<*xf32> + call @gemmF16F16F32(%A_cast, %B_cast, %C_ref_cast) : (memref<*xf16>, memref<*xf16>, memref<*xf32>) -> () %C_row_0 = memref.subview %C_ref[0, 0][1, 256][1, 1] : memref<256x256xf32> to memref<1x256xf32, strided<[256, 1], offset:0>> %C_row_0_cast = memref.cast %C_row_0 : memref<1x256xf32, strided<[256, 1], offset: 0>> to memref<*xf32> @@ -171,7 +134,7 @@ module @gemm attributes {gpu.container_module} { // call @printMemrefF32(%C_row_0_cast_gpu) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () + call @printAllcloseF32(%cast_C, %C_ref_cast) : (memref<*xf32>, memref<*xf32>) -> () memref.dealloc %A : memref<256x256xf16> memref.dealloc %B : memref<256x256xf16> memref.dealloc %C : memref<256x256xf32> @@ -183,4 +146,5 @@ module @gemm attributes {gpu.container_module} { func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface} + func.func private @gemmF16F16F32(memref<*xf16>, memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/SIMT/gemm_4kx4kx4k_f16_f16_f16.mlir b/test/Integration/Dialect/XeGPU/SIMT/gemm_4kx4kx4k_f16_f16_f16.mlir index d50dbdf50..572558c6a 100644 --- a/test/Integration/Dialect/XeGPU/SIMT/gemm_4kx4kx4k_f16_f16_f16.mlir +++ b/test/Integration/Dialect/XeGPU/SIMT/gemm_4kx4kx4k_f16_f16_f16.mlir @@ -392,38 +392,6 @@ module @gemm attributes {gpu.container_module} { } } - // compute CPU reference (takes minutes) - func.func @cpu_reference(%A : memref<4096x4096xf16>, %B : memref<4096x4096xf16>, %C : memref<4096x4096xf32>) { - %c4096 = arith.constant 4096 : index - %c16 = arith.constant 16 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %c64 = arith.constant 64 : index - scf.for %i = %c0 to %c4096 step %c1 { - scf.for %j = %c0 to %c4096 step %c1 { - %c_curr = memref.load %C[%i, %j] : memref<4096x4096xf32> - %c_val = scf.for %k_tile = %c0 to %c4096 step %c16 iter_args(%c_partial = %c_curr) -> f32 { - %c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 { - %k_dpas = arith.addi %k_tile, %k : index - %a_val = memref.load %A[%i, %k_dpas] : memref<4096x4096xf16> - %b_val = memref.load %B[%k_dpas, %j] : memref<4096x4096xf16> - %a_cast = arith.extf %a_val : f16 to f32 - %b_cast = arith.extf %b_val : f16 to f32 - %t = arith.mulf %a_cast, %b_cast : f32 - // %t_cast = arith.extf %t : f16 to f16 - %c_sum = arith.addf %t, %c_dpas_partial : f32 - scf.yield %c_sum : f32 - } - scf.yield %c_val_dpas : f32 - } - %c_val_f16 = arith.truncf %c_val : f32 to f16 - %c_val_ = arith.extf %c_val_f16 : f16 to f32 - memref.store %c_val_ , %C[%i, %j] : memref<4096x4096xf32> - } - } - return - } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -475,7 +443,7 @@ module @gemm attributes {gpu.container_module} { call @fillResource1DRandomF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - // intialize matrix C and C_ref ; C[i, j] = 0 + // Initialize matrix 
C and C_ref ; C[i, j] = 0 %c0_f16 = arith.constant 0.0 : f16 %c0_f32 = arith.constant 0.0 : f32 scf.for %i = %c0 to %c4096 step %c1 { @@ -484,22 +452,16 @@ module @gemm attributes {gpu.container_module} { memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32> } } - // print input fror debug - // %A_row_0 = memref.subview %A[1, 0][1, 4096][1, 1] : memref<4096x4096xf16> to memref<1x4096xf16, strided<[4096, 1], offset: 4096>> - // %A_row_0_cast = memref.cast %A_row_0 : memref<1x4096xf16, strided<[4096, 1], offset: 4096>> to memref<*xf16> - // call @printMemrefF16(%A_row_0_cast) : (memref<*xf16>) -> () - // run GPU + // Run GPU version %2 = call @test(%A, %B, %C) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf16>) -> memref<4096x4096xf16> + %gpu_result_cast = memref.cast %2 : memref<4096x4096xf16> to memref<*xf16> - call @cpu_reference(%A, %B, %C_ref) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf32>) -> () - - // %cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16> - // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_C = memref.cast %2 : memref<4096x4096xf16> to memref<*xf16> - %cast_C_ref = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> - // call @printMemrefF16(%cast_C) : (memref<*xf16>) -> () - // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () + // Run CPU version. + %A_cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16> + %B_cast = memref.cast %B : memref<4096x4096xf16> to memref<*xf16> + %C_cast = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> + call @gemmF16F16F16(%A_cast, %B_cast, %C_cast) : (memref<*xf16>, memref<*xf16>, memref<*xf32>) -> () %C_row_0 = memref.subview %C_ref[0, 0][1, 4096][1, 1] : memref<4096x4096xf32> to memref<1x4096xf32, strided<[4096, 1], offset:0>> %C_row_0_cast = memref.cast %C_row_0 : memref<1x4096xf32, strided<[4096, 1], offset: 0>> to memref<*xf32> @@ -510,7 +472,7 @@ module @gemm attributes {gpu.container_module} { // call @printMemrefF16(%C_row_0_cast_gpu) : (memref<*xf16>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF16(%cast_C, %cast_C_ref) : (memref<*xf16>, memref<*xf32>) -> () + call @printAllcloseF16(%gpu_result_cast, %C_cast) : (memref<*xf16>, memref<*xf32>) -> () memref.dealloc %A : memref<4096x4096xf16> memref.dealloc %B : memref<4096x4096xf16> memref.dealloc %C : memref<4096x4096xf16> @@ -522,5 +484,6 @@ module @gemm attributes {gpu.container_module} { func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @printMaxErrorF16(memref<*xf16>, memref<*xf16>) attributes {llvm.emit_c_interface} func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface} + func.func private @gemmF16F16F16(memref<*xf16>, memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/SIMT/gemm_4kx4kx4k_transpose_b.mlir b/test/Integration/Dialect/XeGPU/SIMT/gemm_4kx4kx4k_transpose_b.mlir index 1eaf1d147..c251811f4 100644 --- a/test/Integration/Dialect/XeGPU/SIMT/gemm_4kx4kx4k_transpose_b.mlir +++ b/test/Integration/Dialect/XeGPU/SIMT/gemm_4kx4kx4k_transpose_b.mlir @@ -417,38 +417,6 @@ module @gemm attributes {gpu.container_module} { } } - // compute CPU reference (takes minutes) - func.func @cpu_reference(%A : memref<4096x4096xf16>, %B : memref<4096x4096xf16>, %C : memref<4096x4096xf32>) { - %c4096 = arith.constant 4096 : index - %c16 = arith.constant 16 : index - %c1 = arith.constant 1 : index 
- %c0 = arith.constant 0 : index - %c64 = arith.constant 64 : index - scf.for %i = %c0 to %c4096 step %c1 { - scf.for %j = %c0 to %c4096 step %c1 { - %c_curr = memref.load %C[%i, %j] : memref<4096x4096xf32> - %c_val = scf.for %k_tile = %c0 to %c4096 step %c16 iter_args(%c_partial = %c_curr) -> f32 { - %c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 { - %k_dpas = arith.addi %k_tile, %k : index - %a_val = memref.load %A[%i, %k_dpas] : memref<4096x4096xf16> - %b_val = memref.load %B[%j, %k_dpas] : memref<4096x4096xf16> - %a_cast = arith.extf %a_val : f16 to f32 - %b_cast = arith.extf %b_val : f16 to f32 - %t = arith.mulf %a_cast, %b_cast : f32 - // %t_cast = arith.extf %t : f16 to f16 - %c_sum = arith.addf %t, %c_dpas_partial : f32 - scf.yield %c_sum : f32 - } - scf.yield %c_val_dpas : f32 - } - %c_val_f16 = arith.truncf %c_val : f32 to f16 - %c_val_ = arith.extf %c_val_f16 : f16 to f32 - memref.store %c_val_ , %C[%i, %j] : memref<4096x4096xf32> - } - } - return - } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -500,7 +468,7 @@ module @gemm attributes {gpu.container_module} { call @fillResource1DRandomF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - // intialize matrix C and C_ref ; C[i, j] = 0 + // Initialize matrix C and C_ref ; C[i, j] = 0 %c0_f16 = arith.constant 0.0 : f16 %c0_f32 = arith.constant 0.0 : f32 scf.for %i = %c0 to %c4096 step %c1 { @@ -509,22 +477,24 @@ module @gemm attributes {gpu.container_module} { memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32> } } - // print input fror debug - // %A_row_0 = memref.subview %A[1, 0][1, 4096][1, 1] : memref<4096x4096xf16> to memref<1x4096xf16, strided<[4096, 1], offset: 4096>> - // %A_row_0_cast = memref.cast %A_row_0 : memref<1x4096xf16, strided<[4096, 1], offset: 4096>> to memref<*xf16> - // call @printMemrefF16(%A_row_0_cast) : (memref<*xf16>) -> () - // run GPU + // Run GPU version. %2 = call @test(%A, %B, %C) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf16>) -> memref<4096x4096xf16> + %gpu_result_cast = memref.cast %2 : memref<4096x4096xf16> to memref<*xf16> - call @cpu_reference(%A, %B, %C_ref) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf32>) -> () - - // %cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16> - // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_C = memref.cast %2 : memref<4096x4096xf16> to memref<*xf16> - %cast_C_ref = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> - // call @printMemrefF16(%cast_C) : (memref<*xf16>) -> () - // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () + // Run CPU version. + // Construct a non transposed version of B for validating the results using imex runtime calls. 
+    %B_non_transposed = memref.alloc() : memref<4096x4096xf16>
+    scf.for %i = %c0 to %c4096 step %c1 {
+      scf.for %j = %c0 to %c4096 step %c1 {
+        %b_val = memref.load %B[%j, %i] : memref<4096x4096xf16>
+        memref.store %b_val, %B_non_transposed[%i, %j] : memref<4096x4096xf16>
+      }
+    }
+    %A_cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16>
+    %B_cast = memref.cast %B_non_transposed : memref<4096x4096xf16> to memref<*xf16>
+    %C_cast = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32>
+    call @gemmF16F16F16(%A_cast, %B_cast, %C_cast) : (memref<*xf16>, memref<*xf16>, memref<*xf32>) -> ()
 
     %C_row_0 = memref.subview %C_ref[0, 0][1, 4096][1, 1] : memref<4096x4096xf32> to memref<1x4096xf32, strided<[4096, 1], offset:0>>
     %C_row_0_cast = memref.cast %C_row_0 : memref<1x4096xf32, strided<[4096, 1], offset: 0>> to memref<*xf32>
@@ -535,11 +505,12 @@
     // call @printMemrefF16(%C_row_0_cast_gpu) : (memref<*xf16>) -> ()
 
     // CHECK: [ALLCLOSE: TRUE]
-    call @printAllcloseF16(%cast_C, %cast_C_ref) : (memref<*xf16>, memref<*xf32>) -> ()
+    call @printAllcloseF16(%gpu_result_cast, %C_cast) : (memref<*xf16>, memref<*xf32>) -> ()
     memref.dealloc %A : memref<4096x4096xf16>
     memref.dealloc %B : memref<4096x4096xf16>
     memref.dealloc %C : memref<4096x4096xf16>
     memref.dealloc %C_ref : memref<4096x4096xf32>
+    memref.dealloc %B_non_transposed : memref<4096x4096xf16>
     return
   }
   func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface}
@@ -547,5 +518,6 @@
   func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
   func.func private @printMaxErrorF16(memref<*xf16>, memref<*xf16>) attributes {llvm.emit_c_interface}
   func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface}
+  func.func private @gemmF16F16F16(memref<*xf16>, memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
 }
diff --git a/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir b/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir
index 6e6f8d31a..e48a60245 100644
--- a/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir
+++ b/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir
@@ -56,38 +56,6 @@ module @gemm attributes {gpu.container_module} {
     return %c : memref<256x256xf32>
   }
 
-  // compute CPU reference (takes minutes)
-  func.func @cpu_reference(%A : memref<256x256xf16>, %B : memref<256x256xf16>, %C : memref<256x256xf32>) {
-    %c256 = arith.constant 256 : index
-    %c16 = arith.constant 16 : index
-    %c1 = arith.constant 1 : index
-    %c0 = arith.constant 0 : index
-    scf.for %i = %c0 to %c256 step %c1 {
-      scf.for %j = %c0 to %c256 step %c1 {
-        %c_curr = memref.load %C[%i, %j] : memref<256x256xf32>
-        %c_val = scf.for %k_tile = %c0 to %c256 step %c16 iter_args(%c_partial = %c_curr) -> f32 {
-          %c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 {
-            %k_dpas = arith.addi %k_tile, %k : index
-            %a_val = memref.load %A[%i, %k_dpas] : memref<256x256xf16>
-            %b_val = memref.load %B[%k_dpas, %j] : memref<256x256xf16>
-            %a_cast = arith.extf %a_val : f16 to f32
-            %b_cast = arith.extf %b_val : f16 to f32
-            %t = arith.mulf %a_cast, %b_cast : f32
-            // %t_cast = arith.extf %t : f16 to f16
-            %c_sum = arith.addf %t, %c_dpas_partial : f32
-            scf.yield %c_sum : f32
-          }
-          scf.yield %c_val_dpas : f32
-        }
-        // %c_val_f16 = arith.truncf %c_val : f32 to f16
-        // %c_val_ = arith.extf %c_val_f16 : f16 to f32
-        memref.store
%c_val , %C[%i, %j] : memref<256x256xf32> - } - } - return - } - - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -139,7 +107,7 @@ module @gemm attributes {gpu.container_module} { call @fillResource1DRandomF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - // intialize matrix C and C_ref ; C[i, j] = 0 + // Initialize matrix C and C_ref ; C[i, j] = 0 %c0_f32 = arith.constant 0.0 : f32 scf.for %i = %c0 to %c256 step %c1 { scf.for %j = %c0 to %c256 step %c1 { @@ -147,22 +115,16 @@ module @gemm attributes {gpu.container_module} { memref.store %c0_f32, %C_ref[%i, %j] : memref<256x256xf32> } } - // print input fror debug - // %A_row_0 = memref.subview %A[1, 0][1, 256][1, 1] : memref<256x256xf16> to memref<1x256xf16, strided<[256, 1], offset: 256>> - // %A_row_0_cast = memref.cast %A_row_0 : memref<1x256xf16, strided<[256, 1], offset: 256>> to memref<*xf16> - // call @printMemrefF16(%A_row_0_cast) : (memref<*xf16>) -> () - // run GPU + // Run GPU version. %2 = call @test(%A, %B, %C) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32> + %gpu_result_cast = memref.cast %2 : memref<256x256xf32> to memref<*xf32> - call @cpu_reference(%A, %B, %C_ref) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> () - - // %cast = memref.cast %A : memref<256x256xf16> to memref<*xf16> - // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_C = memref.cast %2 : memref<256x256xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<256x256xf32> to memref<*xf32> - // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () - // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () + // Run CPU version. 
+ %A_cast = memref.cast %A : memref<256x256xf16> to memref<*xf16> + %B_cast = memref.cast %B : memref<256x256xf16> to memref<*xf16> + %C_cast = memref.cast %C_ref : memref<256x256xf32> to memref<*xf32> + call @gemmF16F16F32(%A_cast, %B_cast, %C_cast) : (memref<*xf16>, memref<*xf16>, memref<*xf32>) -> () %C_row_0 = memref.subview %C_ref[0, 0][1, 256][1, 1] : memref<256x256xf32> to memref<1x256xf32, strided<[256, 1], offset:0>> %C_row_0_cast = memref.cast %C_row_0 : memref<1x256xf32, strided<[256, 1], offset: 0>> to memref<*xf32> @@ -173,7 +135,7 @@ module @gemm attributes {gpu.container_module} { // call @printMemrefF32(%C_row_0_cast_gpu) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () + call @printAllcloseF32(%gpu_result_cast, %C_cast) : (memref<*xf32>, memref<*xf32>) -> () memref.dealloc %A : memref<256x256xf16> memref.dealloc %B : memref<256x256xf16> memref.dealloc %C : memref<256x256xf32> @@ -185,4 +147,5 @@ module @gemm attributes {gpu.container_module} { func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface} + func.func private @gemmF16F16F32(memref<*xf16>, memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/gemm_256x256x256_bf16_bf16_f32.mlir b/test/Integration/Dialect/XeGPU/VC/gemm_256x256x256_bf16_bf16_f32.mlir index 39143053d..b08a60638 100644 --- a/test/Integration/Dialect/XeGPU/VC/gemm_256x256x256_bf16_bf16_f32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gemm_256x256x256_bf16_bf16_f32.mlir @@ -500,34 +500,6 @@ module @gemm attributes {gpu.container_module} { } } - // compute CPU reference (takes minutes) - func.func @cpu_reference(%A : memref<256x256xbf16>, %B : memref<256x256xbf16>, %C : memref<256x256xf32>) { - %c256 = arith.constant 256 : index - %c16 = arith.constant 16 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - scf.for %i = %c0 to %c256 step %c1 { - scf.for %j = %c0 to %c256 step %c1 { - %c_curr = memref.load %C[%i, %j] : memref<256x256xf32> - %c_val = scf.for %k_tile = %c0 to %c256 step %c16 iter_args(%c_partial = %c_curr) -> f32 { - %c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 { - %k_dpas = arith.addi %k_tile, %k : index - %a_val = memref.load %A[%i, %k_dpas] : memref<256x256xbf16> - %b_val = memref.load %B[%k_dpas, %j] : memref<256x256xbf16> - %a_cast = arith.extf %a_val : bf16 to f32 - %b_cast = arith.extf %b_val : bf16 to f32 - %t = arith.mulf %a_cast, %b_cast : f32 - %c_sum = arith.addf %t, %c_dpas_partial : f32 - scf.yield %c_sum : f32 - } - scf.yield %c_val_dpas : f32 - } - memref.store %c_val , %C[%i, %j] : memref<256x256xf32> - } - } - return - } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -589,16 +561,18 @@ module @gemm attributes {gpu.container_module} { } } - // run GPU + // Run GPU %2 = call @test(%A, %B, %C) : (memref<256x256xbf16>, memref<256x256xbf16>, memref<256x256xf32>) -> memref<256x256xf32> + %cast_C = memref.cast %2 : memref<256x256xf32> to memref<*xf32> - // run CPU - call @cpu_reference(%A, %B, %C_ref) : (memref<256x256xbf16>, memref<256x256xbf16>, memref<256x256xf32>) -> () + // Run CPU + 
%A_cast = memref.cast %A : memref<256x256xbf16> to memref<*xbf16> + %B_cast = memref.cast %B : memref<256x256xbf16> to memref<*xbf16> + %C_ref_cast = memref.cast %C_ref : memref<256x256xf32> to memref<*xf32> + call @gemmBF16BF16F32(%A_cast, %B_cast, %C_ref_cast) : (memref<*xbf16>, memref<*xbf16>, memref<*xf32>) -> () - %cast_C = memref.cast %2 : memref<256x256xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<256x256xf32> to memref<*xf32> // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () + call @printAllcloseF32(%cast_C, %C_ref_cast) : (memref<*xf32>, memref<*xf32>) -> () memref.dealloc %A : memref<256x256xbf16> memref.dealloc %B : memref<256x256xbf16> memref.dealloc %C : memref<256x256xf32> @@ -611,5 +585,6 @@ module @gemm attributes {gpu.container_module} { func.func private @printAllcloseBF16(memref<*xbf16>, memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @fillResource1DRandomBF16(memref<*xbf16>, f32, f32, i1) attributes {llvm.emit_c_interface} + func.func private @gemmBF16BF16F32(memref<*xbf16>, memref<*xbf16>, memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_bf16_bf16_f32_xetla_like_load_store_prefetch.mlir b/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_bf16_bf16_f32_xetla_like_load_store_prefetch.mlir index 9f7616000..4139f5122 100644 --- a/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_bf16_bf16_f32_xetla_like_load_store_prefetch.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_bf16_bf16_f32_xetla_like_load_store_prefetch.mlir @@ -507,34 +507,6 @@ module @gemm attributes {gpu.container_module} { } } - // compute CPU reference (takes minutes) - func.func @cpu_reference(%A : memref<4096x4096xbf16>, %B : memref<4096x4096xbf16>, %C : memref<4096x4096xf32>) { - %c4096 = arith.constant 4096 : index - %c16 = arith.constant 16 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - scf.for %i = %c0 to %c4096 step %c1 { - scf.for %j = %c0 to %c4096 step %c1 { - %c_curr = memref.load %C[%i, %j] : memref<4096x4096xf32> - %c_val = scf.for %k_tile = %c0 to %c4096 step %c16 iter_args(%c_partial = %c_curr) -> f32 { - %c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 { - %k_dpas = arith.addi %k_tile, %k : index - %a_val = memref.load %A[%i, %k_dpas] : memref<4096x4096xbf16> - %b_val = memref.load %B[%k_dpas, %j] : memref<4096x4096xbf16> - %a_cast = arith.extf %a_val : bf16 to f32 - %b_cast = arith.extf %b_val : bf16 to f32 - %t = arith.mulf %a_cast, %b_cast : f32 - %c_sum = arith.addf %t, %c_dpas_partial : f32 - scf.yield %c_sum : f32 - } - scf.yield %c_val_dpas : f32 - } - memref.store %c_val , %C[%i, %j] : memref<4096x4096xf32> - } - } - return - } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -596,16 +568,18 @@ module @gemm attributes {gpu.container_module} { } } - // run GPU + // Run GPU %2 = call @test(%A, %B, %C) : (memref<4096x4096xbf16>, memref<4096x4096xbf16>, memref<4096x4096xf32>) -> memref<4096x4096xf32> + %cast_C = memref.cast %2 : memref<4096x4096xf32> to memref<*xf32> - // run CPU - call @cpu_reference(%A, %B, %C_ref) : (memref<4096x4096xbf16>, memref<4096x4096xbf16>, memref<4096x4096xf32>) -> () + // Run CPU. 
+ %A_cast = memref.cast %A : memref<4096x4096xbf16> to memref<*xbf16> + %B_cast = memref.cast %B : memref<4096x4096xbf16> to memref<*xbf16> + %C_cast = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> + call @gemmBF16BF16F32(%A_cast, %B_cast, %C_cast) : (memref<*xbf16>, memref<*xbf16>, memref<*xf32>) -> () - %cast_C = memref.cast %2 : memref<4096x4096xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () + call @printAllcloseF32(%cast_C, %C_cast) : (memref<*xf32>, memref<*xf32>) -> () memref.dealloc %A : memref<4096x4096xbf16> memref.dealloc %B : memref<4096x4096xbf16> memref.dealloc %C : memref<4096x4096xf32> @@ -617,5 +591,6 @@ module @gemm attributes {gpu.container_module} { func.func private @printAllcloseBF16(memref<*xbf16>, memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @fillResource1DRandomBF16(memref<*xbf16>, f32, f32, i1) attributes {llvm.emit_c_interface} + func.func private @gemmBF16BF16F32(memref<*xbf16>, memref<*xbf16>, memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_dpas_sized_loads_f16_f16_f32.mlir b/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_dpas_sized_loads_f16_f16_f32.mlir index e0ee10a29..fe71252d4 100644 --- a/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_dpas_sized_loads_f16_f16_f32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_dpas_sized_loads_f16_f16_f32.mlir @@ -384,37 +384,6 @@ module @gemm attributes {gpu.container_module} { } } - // compute CPU reference (takes minutes) - func.func @cpu_reference(%A : memref<4096x4096xf16>, %B : memref<4096x4096xf16>, %C : memref<4096x4096xf32>) { - %c4096 = arith.constant 4096 : index - %c16 = arith.constant 16 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - scf.for %i = %c0 to %c4096 step %c1 { - scf.for %j = %c0 to %c4096 step %c1 { - %c_curr = memref.load %C[%i, %j] : memref<4096x4096xf32> - %c_val = scf.for %k_tile = %c0 to %c4096 step %c16 iter_args(%c_partial = %c_curr) -> f32 { - %c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 { - %k_dpas = arith.addi %k_tile, %k : index - %a_val = memref.load %A[%i, %k_dpas] : memref<4096x4096xf16> - %b_val = memref.load %B[%k_dpas, %j] : memref<4096x4096xf16> - %a_cast = arith.extf %a_val : f16 to f32 - %b_cast = arith.extf %b_val : f16 to f32 - %t = arith.mulf %a_cast, %b_cast : f32 - // %t_cast = arith.extf %t : f16 to f16 - %c_sum = arith.addf %t, %c_dpas_partial : f32 - scf.yield %c_sum : f32 - } - scf.yield %c_val_dpas : f32 - } - // %c_val_f16 = arith.truncf %c_val : f32 to f16 - // %c_val_ = arith.extf %c_val_f16 : f16 to f32 - memref.store %c_val , %C[%i, %j] : memref<4096x4096xf32> - } - } - return - } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -474,28 +443,19 @@ module @gemm attributes {gpu.container_module} { memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32> } } - // print input fror debug - // %A_row_0 = memref.subview %A[1, 0][1, 4096][1, 1] : memref<4096x4096xf16> to memref<1x4096xf16, strided<[4096, 1], offset: 4096>> - // %A_row_0_cast = memref.cast %A_row_0 : memref<1x4096xf16, strided<[4096, 1], offset: 4096>> to memref<*xf16> - // call 
@printMemrefF16(%A_row_0_cast) : (memref<*xf16>) -> () - // run GPU + // Run GPU. %2 = call @test(%A, %B, %C) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf32>) -> memref<4096x4096xf32> + %cast_C = memref.cast %2 : memref<4096x4096xf32> to memref<*xf32> - // run CPU - call @cpu_reference(%A, %B, %C_ref) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf32>) -> () + // Run CPU. + %A_cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16> + %B_cast = memref.cast %B : memref<4096x4096xf16> to memref<*xf16> + %C_ref_cast = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> + call @gemmF16F16F32(%A_cast, %B_cast, %C_ref_cast) : (memref<*xf16>, memref<*xf16>, memref<*xf32>) -> () - // %cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16> - // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_C = memref.cast %2 : memref<4096x4096xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> - // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () - // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () - // %C_row_0 = memref.subview %2[1, 0][1, 4096][1, 1] : memref<4096x4096xf32> to memref<1x4096xf32, strided<[4096, 1], offset: 4096>> - // %C_row_0_cast = memref.cast %C_row_0 : memref<1x4096xf32, strided<[4096, 1], offset: 4096>> to memref<*xf32> - // call @printMemrefF32(%C_row_0_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () + call @printAllcloseF32(%cast_C, %C_ref_cast) : (memref<*xf32>, memref<*xf32>) -> () memref.dealloc %A : memref<4096x4096xf16> memref.dealloc %B : memref<4096x4096xf16> memref.dealloc %C : memref<4096x4096xf32> @@ -506,5 +466,6 @@ module @gemm attributes {gpu.container_module} { func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface} + func.func private @gemmF16F16F32(memref<*xf16>, memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16.mlir b/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16.mlir index 6d1c6f55e..a3e2b901d 100644 --- a/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16.mlir @@ -440,37 +440,6 @@ module @gemm attributes {gpu.container_module} { } } - // compute CPU reference (takes minutes) - func.func @cpu_reference(%A : memref<4096x4096xf16>, %B : memref<4096x4096xf16>, %C : memref<4096x4096xf32>) { - %c4096 = arith.constant 4096 : index - %c16 = arith.constant 16 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - scf.for %i = %c0 to %c4096 step %c1 { - scf.for %j = %c0 to %c4096 step %c1 { - %c_curr = memref.load %C[%i, %j] : memref<4096x4096xf32> - %c_val = scf.for %k_tile = %c0 to %c4096 step %c16 iter_args(%c_partial = %c_curr) -> f32 { - %c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 { - %k_dpas = arith.addi %k_tile, %k : index - %a_val = memref.load %A[%i, %k_dpas] : memref<4096x4096xf16> - %b_val = memref.load %B[%k_dpas, %j] : memref<4096x4096xf16> - %a_cast = arith.extf %a_val : f16 to f32 - %b_cast = arith.extf %b_val : f16 to f32 - %t = arith.mulf %a_cast, 
%b_cast : f32 - // %t_cast = arith.extf %t : f16 to f16 - %c_sum = arith.addf %t, %c_dpas_partial : f32 - scf.yield %c_sum : f32 - } - scf.yield %c_val_dpas : f32 - } - %c_val_f16 = arith.truncf %c_val : f32 to f16 - %c_val_ = arith.extf %c_val_f16 : f16 to f32 - memref.store %c_val_ , %C[%i, %j] : memref<4096x4096xf32> - } - } - return - } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -531,22 +500,17 @@ module @gemm attributes {gpu.container_module} { memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32> } } - // print input fror debug - // %A_row_0 = memref.subview %A[1, 0][1, 4096][1, 1] : memref<4096x4096xf16> to memref<1x4096xf16, strided<[4096, 1], offset: 4096>> - // %A_row_0_cast = memref.cast %A_row_0 : memref<1x4096xf16, strided<[4096, 1], offset: 4096>> to memref<*xf16> - // call @printMemrefF16(%A_row_0_cast) : (memref<*xf16>) -> () - // run GPU + // Run GPU. %2 = call @test(%A, %B, %C) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf16>) -> memref<4096x4096xf16> + %cast_C = memref.cast %2 : memref<4096x4096xf16> to memref<*xf16> - call @cpu_reference(%A, %B, %C_ref) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf32>) -> () + // Run CPU. + %A_cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16> + %B_cast = memref.cast %B : memref<4096x4096xf16> to memref<*xf16> + %C_ref_cast = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> + call @gemmF16F16F16(%A_cast, %B_cast, %C_ref_cast) : (memref<*xf16>, memref<*xf16>, memref<*xf32>) -> () - // %cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16> - // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_C = memref.cast %2 : memref<4096x4096xf16> to memref<*xf16> - %cast_C_ref = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> - // call @printMemrefF16(%cast_C) : (memref<*xf16>) -> () - // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () %C_row_0 = memref.subview %C_ref[0, 0][1, 4096][1, 1] : memref<4096x4096xf32> to memref<1x4096xf32, strided<[4096, 1], offset:0>> %C_row_0_cast = memref.cast %C_row_0 : memref<1x4096xf32, strided<[4096, 1], offset: 0>> to memref<*xf32> @@ -557,7 +521,7 @@ module @gemm attributes {gpu.container_module} { // call @printMemrefF16(%C_row_0_cast_gpu) : (memref<*xf16>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF16(%cast_C, %cast_C_ref) : (memref<*xf16>, memref<*xf32>) -> () + call @printAllcloseF16(%cast_C, %C_ref_cast) : (memref<*xf16>, memref<*xf32>) -> () memref.dealloc %A : memref<4096x4096xf16> memref.dealloc %B : memref<4096x4096xf16> memref.dealloc %C : memref<4096x4096xf16> @@ -568,5 +532,6 @@ module @gemm attributes {gpu.container_module} { func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface} + func.func private @gemmF16F16F16(memref<*xf16>, memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16_w_8x32xf16_stores.mlir b/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16_w_8x32xf16_stores.mlir index 0850e3f83..fcb53a4bb 100644 --- a/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16_w_8x32xf16_stores.mlir +++ 
b/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16_w_8x32xf16_stores.mlir @@ -458,37 +458,6 @@ module @gemm attributes {gpu.container_module} { } } - // compute CPU reference (takes minutes) - func.func @cpu_reference(%A : memref<4096x4096xf16>, %B : memref<4096x4096xf16>, %C : memref<4096x4096xf32>) { - %c4096 = arith.constant 4096 : index - %c16 = arith.constant 16 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - scf.for %i = %c0 to %c4096 step %c1 { - scf.for %j = %c0 to %c4096 step %c1 { - %c_curr = memref.load %C[%i, %j] : memref<4096x4096xf32> - %c_val = scf.for %k_tile = %c0 to %c4096 step %c16 iter_args(%c_partial = %c_curr) -> f32 { - %c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 { - %k_dpas = arith.addi %k_tile, %k : index - %a_val = memref.load %A[%i, %k_dpas] : memref<4096x4096xf16> - %b_val = memref.load %B[%k_dpas, %j] : memref<4096x4096xf16> - %a_cast = arith.extf %a_val : f16 to f32 - %b_cast = arith.extf %b_val : f16 to f32 - %t = arith.mulf %a_cast, %b_cast : f32 - // %t_cast = arith.extf %t : f16 to f16 - %c_sum = arith.addf %t, %c_dpas_partial : f32 - scf.yield %c_sum : f32 - } - scf.yield %c_val_dpas : f32 - } - %c_val_f16 = arith.truncf %c_val : f32 to f16 - %c_val_ = arith.extf %c_val_f16 : f16 to f32 - memref.store %c_val_ , %C[%i, %j] : memref<4096x4096xf32> - } - } - return - } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -549,22 +518,15 @@ module @gemm attributes {gpu.container_module} { memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32> } } - // print input fror debug - // %A_row_0 = memref.subview %A[1, 0][1, 4096][1, 1] : memref<4096x4096xf16> to memref<1x4096xf16, strided<[4096, 1], offset: 4096>> - // %A_row_0_cast = memref.cast %A_row_0 : memref<1x4096xf16, strided<[4096, 1], offset: 4096>> to memref<*xf16> - // call @printMemrefF16(%A_row_0_cast) : (memref<*xf16>) -> () - // run GPU + // Run GPU. %2 = call @test(%A, %B, %C) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf16>) -> memref<4096x4096xf16> - - call @cpu_reference(%A, %B, %C_ref) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf32>) -> () - - // %cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16> - // call @printMemrefF16(%cast) : (memref<*xf16>) -> () %cast_C = memref.cast %2 : memref<4096x4096xf16> to memref<*xf16> - %cast_C_ref = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> - // call @printMemrefF16(%cast_C) : (memref<*xf16>) -> () - // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () + // Run CPU. 
+ %A_cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16> + %B_cast = memref.cast %B : memref<4096x4096xf16> to memref<*xf16> + %C_cast = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> + call @gemmF16F16F16(%A_cast, %B_cast, %C_cast) : (memref<*xf16>, memref<*xf16>, memref<*xf32>) -> () %C_row_0 = memref.subview %C_ref[0, 0][1, 4096][1, 1] : memref<4096x4096xf32> to memref<1x4096xf32, strided<[4096, 1], offset:0>> %C_row_0_cast = memref.cast %C_row_0 : memref<1x4096xf32, strided<[4096, 1], offset: 0>> to memref<*xf32> @@ -575,7 +537,7 @@ module @gemm attributes {gpu.container_module} { // call @printMemrefF16(%C_row_0_cast_gpu) : (memref<*xf16>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF16(%cast_C, %cast_C_ref) : (memref<*xf16>, memref<*xf32>) -> () + call @printAllcloseF16(%cast_C, %C_cast) : (memref<*xf16>, memref<*xf32>) -> () memref.dealloc %A : memref<4096x4096xf16> memref.dealloc %B : memref<4096x4096xf16> memref.dealloc %C : memref<4096x4096xf16> @@ -586,5 +548,6 @@ module @gemm attributes {gpu.container_module} { func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface} + func.func private @gemmF16F16F16(memref<*xf16>, memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16_w_simple_B_prefetch.mlir b/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16_w_simple_B_prefetch.mlir index 9de6ec890..0be1c7753 100644 --- a/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16_w_simple_B_prefetch.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16_w_simple_B_prefetch.mlir @@ -454,37 +454,6 @@ module @gemm attributes {gpu.container_module} { } } - // compute CPU reference (takes minutes) - func.func @cpu_reference(%A : memref<4096x4096xf16>, %B : memref<4096x4096xf16>, %C : memref<4096x4096xf32>) { - %c4096 = arith.constant 4096 : index - %c16 = arith.constant 16 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - scf.for %i = %c0 to %c4096 step %c1 { - scf.for %j = %c0 to %c4096 step %c1 { - %c_curr = memref.load %C[%i, %j] : memref<4096x4096xf32> - %c_val = scf.for %k_tile = %c0 to %c4096 step %c16 iter_args(%c_partial = %c_curr) -> f32 { - %c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 { - %k_dpas = arith.addi %k_tile, %k : index - %a_val = memref.load %A[%i, %k_dpas] : memref<4096x4096xf16> - %b_val = memref.load %B[%k_dpas, %j] : memref<4096x4096xf16> - %a_cast = arith.extf %a_val : f16 to f32 - %b_cast = arith.extf %b_val : f16 to f32 - %t = arith.mulf %a_cast, %b_cast : f32 - // %t_cast = arith.extf %t : f16 to f16 - %c_sum = arith.addf %t, %c_dpas_partial : f32 - scf.yield %c_sum : f32 - } - scf.yield %c_val_dpas : f32 - } - %c_val_f16 = arith.truncf %c_val : f32 to f16 - %c_val_ = arith.extf %c_val_f16 : f16 to f32 - memref.store %c_val_ , %C[%i, %j] : memref<4096x4096xf32> - } - } - return - } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -545,22 +514,17 @@ module @gemm attributes {gpu.container_module} { memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32> } } - // print input fror debug - // %A_row_0 = memref.subview %A[1, 0][1, 4096][1, 1] : 
memref<4096x4096xf16> to memref<1x4096xf16, strided<[4096, 1], offset: 4096>> - // %A_row_0_cast = memref.cast %A_row_0 : memref<1x4096xf16, strided<[4096, 1], offset: 4096>> to memref<*xf16> - // call @printMemrefF16(%A_row_0_cast) : (memref<*xf16>) -> () - // run GPU + // Run GPU. %2 = call @test(%A, %B, %C) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf16>) -> memref<4096x4096xf16> + %cast_C = memref.cast %2 : memref<4096x4096xf16> to memref<*xf16> - call @cpu_reference(%A, %B, %C_ref) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf32>) -> () + // Run CPU. + %A_cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16> + %B_cast = memref.cast %B : memref<4096x4096xf16> to memref<*xf16> + %C_cast = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> + call @gemmF16F16F16(%A_cast, %B_cast, %C_cast) : (memref<*xf16>, memref<*xf16>, memref<*xf32>) -> () - // %cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16> - // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_C = memref.cast %2 : memref<4096x4096xf16> to memref<*xf16> - %cast_C_ref = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> - // call @printMemrefF16(%cast_C) : (memref<*xf16>) -> () - // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () %C_row_0 = memref.subview %C_ref[0, 0][1, 4096][1, 1] : memref<4096x4096xf32> to memref<1x4096xf32, strided<[4096, 1], offset:0>> %C_row_0_cast = memref.cast %C_row_0 : memref<1x4096xf32, strided<[4096, 1], offset: 0>> to memref<*xf32> @@ -571,7 +535,7 @@ module @gemm attributes {gpu.container_module} { // call @printMemrefF16(%C_row_0_cast_gpu) : (memref<*xf16>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF16(%cast_C, %cast_C_ref) : (memref<*xf16>, memref<*xf32>) -> () + call @printAllcloseF16(%cast_C, %C_cast) : (memref<*xf16>, memref<*xf32>) -> () memref.dealloc %A : memref<4096x4096xf16> memref.dealloc %B : memref<4096x4096xf16> memref.dealloc %C : memref<4096x4096xf16> @@ -582,5 +546,6 @@ module @gemm attributes {gpu.container_module} { func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface} + func.func private @gemmF16F16F16(memref<*xf16>, memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_f16_f16_f32.mlir b/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_f16_f16_f32.mlir index 8aa205287..9c55ad11e 100644 --- a/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_f16_f16_f32.mlir +++ b/test/Integration/Dialect/XeGPU/WG/gemm_4kx4kx4k_f16_f16_f32.mlir @@ -17,23 +17,20 @@ module @gemm attributes {gpu.container_module} { %c64 = arith.constant 64 : index %c128 = arith.constant 128 : index %c512 = arith.constant 512 : index - %t0 = gpu.wait async - %A_gpu, %t1 = gpu.alloc async [%t0] () : memref<4096x4096xf16> - %t2 = gpu.memcpy async [%t1] %A_gpu, %A : memref<4096x4096xf16>, memref<4096x4096xf16> - %B_gpu, %t3 = gpu.alloc async [%t2] () : memref<4096x4096xf16> - %t4 = gpu.memcpy async [%t3] %B_gpu, %B : memref<4096x4096xf16>, memref<4096x4096xf16> - %C_gpu, %t5 = gpu.alloc async [%t4] () : memref<4096x4096xf32> - %t6 = gpu.memcpy async [%t5] %C_gpu, %C : memref<4096x4096xf32>, memref<4096x4096xf32> + %A_gpu = gpu.alloc () : memref<4096x4096xf16> + gpu.memcpy %A_gpu, %A : 
memref<4096x4096xf16>, memref<4096x4096xf16> + %B_gpu = gpu.alloc () : memref<4096x4096xf16> + gpu.memcpy %B_gpu, %B : memref<4096x4096xf16>, memref<4096x4096xf16> + %C_gpu = gpu.alloc () : memref<4096x4096xf32> + gpu.memcpy %C_gpu, %C : memref<4096x4096xf32>, memref<4096x4096xf32> // NOTE: Here we can't use [8, 64] wi threads following the SG thread layout of [8, 4]. Because runtime will linearize the x dimension first (we need y dimension to be linearized first). // So just use linearized thread layout of [512, 1] wi threads. - %t7 = gpu.launch_func async [%t6] @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c512, %c1, %c1) args(%A_gpu : memref<4096x4096xf16>, %B_gpu : memref<4096x4096xf16>, %C_gpu : memref<4096x4096xf32>) - gpu.wait [%t7] // Wait for the kernel to finish. - %t12 = gpu.wait async - %t8 = gpu.memcpy async [%t12] %C, %C_gpu : memref<4096x4096xf32>, memref<4096x4096xf32> - %t9 = gpu.dealloc async [%t8] %A_gpu : memref<4096x4096xf16> - %t10 = gpu.dealloc async [%t9] %B_gpu : memref<4096x4096xf16> - %t11 = gpu.dealloc async [%t10] %C_gpu : memref<4096x4096xf32> - gpu.wait [%t11] + gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c512, %c1, %c1) args(%A_gpu : memref<4096x4096xf16>, %B_gpu : memref<4096x4096xf16>, %C_gpu : memref<4096x4096xf32>) + gpu.wait // Wait for the kernel to finish. + gpu.memcpy %C, %C_gpu : memref<4096x4096xf32>, memref<4096x4096xf32> + gpu.dealloc %A_gpu : memref<4096x4096xf16> + gpu.dealloc %B_gpu : memref<4096x4096xf16> + gpu.dealloc %C_gpu : memref<4096x4096xf32> return %C : memref<4096x4096xf32> } @@ -97,38 +94,6 @@ module @gemm attributes {gpu.container_module} { } } - // compute CPU reference (takes minutes) - func.func @cpu_reference(%A : memref<4096x4096xf16>, %B : memref<4096x4096xf16>, %C : memref<4096x4096xf32>) { - %c4096 = arith.constant 4096 : index - %c16 = arith.constant 16 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %c64 = arith.constant 64 : index - scf.for %i = %c0 to %c4096 step %c1 { - scf.for %j = %c0 to %c4096 step %c1 { - %c_curr = memref.load %C[%i, %j] : memref<4096x4096xf32> - %c_val = scf.for %k_tile = %c0 to %c4096 step %c16 iter_args(%c_partial = %c_curr) -> f32 { - %c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 { - %k_dpas = arith.addi %k_tile, %k : index - %a_val = memref.load %A[%i, %k_dpas] : memref<4096x4096xf16> - %b_val = memref.load %B[%k_dpas, %j] : memref<4096x4096xf16> - %a_cast = arith.extf %a_val : f16 to f32 - %b_cast = arith.extf %b_val : f16 to f32 - %t = arith.mulf %a_cast, %b_cast : f32 - // %t_cast = arith.extf %t : f16 to f16 - %c_sum = arith.addf %t, %c_dpas_partial : f32 - scf.yield %c_sum : f32 - } - scf.yield %c_val_dpas : f32 - } - %c_val_f16 = arith.truncf %c_val : f32 to f16 - %c_val_ = arith.extf %c_val_f16 : f16 to f32 - memref.store %c_val_ , %C[%i, %j] : memref<4096x4096xf32> - } - } - return - } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -180,7 +145,7 @@ module @gemm attributes {gpu.container_module} { call @fillResource1DRandomF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - // intialize matrix C and C_ref ; C[i, j] = 0 + // Initialize matrix C and C_ref ; C[i, j] = 0 %c0_f32 = arith.constant 0.0 : f32 scf.for %i = %c0 to %c4096 step %c1 { scf.for %j = %c0 to %c4096 step %c1 { @@ -188,32 +153,27 @@ module @gemm attributes 
@@ -180,7 +145,7 @@ module @gemm attributes {gpu.container_module} {
     call @fillResource1DRandomF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> ()
-    // intialize matrix C and C_ref ; C[i, j] = 0
+    // Initialize matrix C and C_ref ; C[i, j] = 0
     %c0_f32 = arith.constant 0.0 : f32
     scf.for %i = %c0 to %c4096 step %c1 {
       scf.for %j = %c0 to %c4096 step %c1 {
@@ -188,32 +153,27 @@ module @gemm attributes {gpu.container_module} {
         memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32>
       }
     }
-    // %A_row_0 = memref.subview %A[1, 0][1, 4096][1, 1] : memref<4096x4096xf16> to memref<1x4096xf16, strided<[4096, 1], offset: 4096>>
-    // %A_row_0_cast = memref.cast %A_row_0 : memref<1x4096xf16, strided<[4096, 1], offset: 4096>> to memref<*xf16>
-    // call @printMemrefF16(%A_row_0_cast) : (memref<*xf16>) -> ()
-    // run GPU
+    // Run GPU version.
     %2 = call @test(%A, %B, %C) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf32>) -> memref<4096x4096xf32>
+    %gpu_result_cast = memref.cast %2 : memref<4096x4096xf32> to memref<*xf32>
-    call @cpu_reference(%A, %B, %C_ref) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf32>) -> ()
-
-    // %cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16>
-    // call @printMemrefF16(%cast) : (memref<*xf16>) -> ()
-    %cast_C = memref.cast %2 : memref<4096x4096xf32> to memref<*xf32>
-    %cast_C_ref = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32>
-    // call @printMemrefF16(%cast_C) : (memref<*xf16>) -> ()
-    // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> ()
+    // Run CPU version.
+    %A_cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16>
+    %B_cast = memref.cast %B : memref<4096x4096xf16> to memref<*xf16>
+    %C_cast = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32>
+    call @gemmF16F16F32(%A_cast, %B_cast, %C_cast) : (memref<*xf16>, memref<*xf16>, memref<*xf32>) -> ()
     %C_row_0 = memref.subview %C_ref[0, 0][1, 4096][1, 1] : memref<4096x4096xf32> to memref<1x4096xf32, strided<[4096, 1], offset:0>>
     %C_row_0_cast = memref.cast %C_row_0 : memref<1x4096xf32, strided<[4096, 1], offset: 0>> to memref<*xf32>
-    // call @printMemrefF32(%C_row_0_cast) : (memref<*xf32>) -> ()
+    call @printMemrefF32(%C_row_0_cast) : (memref<*xf32>) -> ()
     %C_row_0_gpu = memref.subview %2[0, 0][1, 4096][1, 1] : memref<4096x4096xf32> to memref<1x4096xf32, strided<[4096, 1], offset:0>>
     %C_row_0_cast_gpu = memref.cast %C_row_0_gpu : memref<1x4096xf32, strided<[4096, 1], offset: 0>> to memref<*xf32>
-    // call @printMemrefF32(%C_row_0_cast_gpu) : (memref<*xf32>) -> ()
+    call @printMemrefF32(%C_row_0_cast_gpu) : (memref<*xf32>) -> ()
     // CHECK: [ALLCLOSE: TRUE]
-    call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> ()
+    call @printAllcloseF32(%gpu_result_cast, %C_cast) : (memref<*xf32>, memref<*xf32>) -> ()
     memref.dealloc %A : memref<4096x4096xf16>
     memref.dealloc %B : memref<4096x4096xf16>
     memref.dealloc %C : memref<4096x4096xf32>
@@ -224,5 +184,6 @@ module @gemm attributes {gpu.container_module} {
   func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
   func.func private @printMaxErrorF16(memref<*xf16>, memref<*xf16>) attributes {llvm.emit_c_interface}
   func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface}
+  func.func private @gemmF16F16F32(memref<*xf16>, memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
 }
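All three tests now gate on the [ALLCLOSE: TRUE] line emitted by the printAllclose* helpers rather than on printed matrices. For readers who want to reproduce the comparison outside the runner, an allclose predicate of the usual NumPy shape is sketched below; the rtol/atol defaults are hypothetical placeholders, not the tolerances printAllcloseF32 actually applies.

```cpp
#include <cmath>
#include <cstddef>

// NumPy-style allclose over two float buffers: element i passes when
// |a[i] - b[i]| <= atol + rtol * |b[i]|. Tolerances are illustrative.
bool allclose(const float *a, const float *b, size_t n,
              float rtol = 1e-3f, float atol = 1e-5f) {
  for (size_t i = 0; i < n; ++i)
    if (std::fabs(a[i] - b[i]) > atol + rtol * std::fabs(b[i]))
      return false;
  return true;
}
```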
diff --git a/test/Integration/Dialect/XeGPU/WG/xegpu-to-llvm.pp b/test/Integration/Dialect/XeGPU/WG/xegpu-to-llvm.pp
index e2cf8d10c..1fed796a6 100644
--- a/test/Integration/Dialect/XeGPU/WG/xegpu-to-llvm.pp
+++ b/test/Integration/Dialect/XeGPU/WG/xegpu-to-llvm.pp
@@ -21,6 +21,7 @@
     convert-xevm-to-llvm
     cse
 )
+    func.func(gpu-async-region)
     reconcile-unrealized-casts
     convert-vector-to-scf
     convert-scf-to-cf
diff --git a/test/Integration/Dialect/XeTile/wg_gemm_4k_8x4_sg_layout.mlir b/test/Integration/Dialect/XeTile/wg_gemm_4k_8x4_sg_layout.mlir
index b088561f4..ee34852da 100644
--- a/test/Integration/Dialect/XeTile/wg_gemm_4k_8x4_sg_layout.mlir
+++ b/test/Integration/Dialect/XeTile/wg_gemm_4k_8x4_sg_layout.mlir
@@ -166,34 +166,6 @@ module @gemm attributes {gpu.container_module} {
     }
   }
-  // compute CPU reference (takes minutes)
-  func.func @cpu_reference(%A : memref<4096x4096xbf16>, %B : memref<4096x4096xbf16>, %C : memref<4096x4096xf32>) {
-    %c4096 = arith.constant 4096 : index
-    %c16 = arith.constant 16 : index
-    %c1 = arith.constant 1 : index
-    %c0 = arith.constant 0 : index
-    scf.for %i = %c0 to %c4096 step %c1 {
-      scf.for %j = %c0 to %c4096 step %c1 {
-        %c_curr = memref.load %C[%i, %j] : memref<4096x4096xf32>
-        %c_val = scf.for %k_tile = %c0 to %c4096 step %c16 iter_args(%c_partial = %c_curr) -> f32 {
-          %c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 {
-            %k_dpas = arith.addi %k_tile, %k : index
-            %a_val = memref.load %A[%i, %k_dpas] : memref<4096x4096xbf16>
-            %b_val = memref.load %B[%k_dpas, %j] : memref<4096x4096xbf16>
-            %a_cast = arith.extf %a_val : bf16 to f32
-            %b_cast = arith.extf %b_val : bf16 to f32
-            %t = arith.mulf %a_cast, %b_cast : f32
-            %c_sum = arith.addf %t, %c_dpas_partial : f32
-            scf.yield %c_sum : f32
-          }
-          scf.yield %c_val_dpas : f32
-        }
-        memref.store %c_val , %C[%i, %j] : memref<4096x4096xf32>
-      }
-    }
-    return
-  }
-
   func.func @main() attributes {llvm.emit_c_interface} {
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
@@ -229,14 +201,16 @@ module @gemm attributes {gpu.container_module} {
       }
     }
-    // run GPU
+    // Run GPU.
     %2 = call @test(%A, %B, %C) : (memref<4096x4096xbf16>, memref<4096x4096xbf16>, memref<4096x4096xf32>) -> memref<4096x4096xf32>
-
-    // run CPU
-    call @cpu_reference(%A, %B, %C_ref) : (memref<4096x4096xbf16>, memref<4096x4096xbf16>, memref<4096x4096xf32>) -> ()
     %cast_C = memref.cast %2 : memref<4096x4096xf32> to memref<*xf32>
+
+    // Run CPU.
+    %A_cast = memref.cast %A : memref<4096x4096xbf16> to memref<*xbf16>
+    %B_cast = memref.cast %B : memref<4096x4096xbf16> to memref<*xbf16>
     %cast_C_ref = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32>
+    call @gemmBF16BF16F32(%A_cast, %B_cast, %cast_C_ref) : (memref<*xbf16>, memref<*xbf16>, memref<*xf32>) -> ()
+
     // CHECK: [ALLCLOSE: TRUE]
     call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> ()
     memref.dealloc %A : memref<4096x4096xbf16>
@@ -250,4 +224,5 @@ module @gemm attributes {gpu.container_module} {
   func.func private @printAllcloseBF16(memref<*xbf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
   func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
   func.func private @fillResource1DRandomBF16(memref<*xbf16>, f32, f32, i1) attributes {llvm.emit_c_interface}
+  func.func private @gemmBF16BF16F32(memref<*xbf16>, memref<*xbf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
 }
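The XeTile test exercises the bf16 path. One reason a bf16 reference is cheap to emulate on the host: bf16 is just the high half of an IEEE-754 f32, so the widening the deleted reference performed with arith.extf is exact. A self-contained sketch of the two conversions follows; it narrows by truncation only, whereas production converters typically round to nearest-even.

```cpp
#include <cstdint>
#include <cstring>

// bf16 occupies the upper 16 bits of an f32. Widening bf16 -> f32
// appends 16 zero bits and is exact, which is why f32 accumulation
// over bf16 inputs loses nothing on the load side.
uint16_t toBf16Truncated(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return static_cast<uint16_t>(u >> 16); // drop the low 16 mantissa bits
}

float fromBf16(uint16_t b) {
  uint32_t u = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}
```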