15 changes: 15 additions & 0 deletions include/imex/ExecutionEngine/ImexRunnerUtils.h
@@ -118,4 +118,19 @@ extern "C" IMEX_RUNNERUTILS_EXPORT void
_mlir_ciface_printMaxErrorF32(UnrankedMemRefType<float> *M,
UnrankedMemRefType<float> *N);

extern "C" IMEX_RUNNERUTILS_EXPORT void
_mlir_ciface_gemmF16F16F16(UnrankedMemRefType<f16> *A,
UnrankedMemRefType<f16> *B,
UnrankedMemRefType<float> *C);

extern "C" IMEX_RUNNERUTILS_EXPORT void
_mlir_ciface_gemmF16F16F32(UnrankedMemRefType<f16> *A,
UnrankedMemRefType<f16> *B,
UnrankedMemRefType<float> *C);

extern "C" IMEX_RUNNERUTILS_EXPORT void
_mlir_ciface_gemmBF16BF16F32(UnrankedMemRefType<bf16> *A,
UnrankedMemRefType<bf16> *B,
UnrankedMemRefType<float> *C);

#endif // IMEX_EXECUTIONENGINE_IMEXRUNNERUTILS_H
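For context, here is a minimal host-side sketch of how these new entry points could be driven directly from C++. This is a hypothetical harness, not part of this PR; it assumes linking against the IMEX runner-utils library, and the descriptor layout follows MLIR's `CRunnerUtils.h`.

```cpp
// Hypothetical C++ harness for the new reference-GEMM entry points.
// Only the symbol names mirror the header; sizes and values are illustrative.
#include "imex/ExecutionEngine/ImexRunnerUtils.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "mlir/ExecutionEngine/Float16bits.h"
#include <cstdio>
#include <vector>

int main() {
  constexpr int64_t M = 4, N = 4, K = 4;
  std::vector<f16> a(M * K, f16(1.0f));
  std::vector<f16> b(K * N, f16(0.5f));
  std::vector<float> c(M * N, 0.0f);

  // Row-major, contiguous descriptors, as the runtime's stride guard requires.
  StridedMemRefType<f16, 2> aDesc{a.data(), a.data(), 0, {M, K}, {K, 1}};
  StridedMemRefType<f16, 2> bDesc{b.data(), b.data(), 0, {K, N}, {N, 1}};
  StridedMemRefType<float, 2> cDesc{c.data(), c.data(), 0, {M, N}, {N, 1}};
  UnrankedMemRefType<f16> aU{2, &aDesc};
  UnrankedMemRefType<f16> bU{2, &bDesc};
  UnrankedMemRefType<float> cU{2, &cDesc};

  _mlir_ciface_gemmF16F16F32(&aU, &bU, &cU);
  std::printf("c[0] = %f\n", c[0]); // expect K * 1.0 * 0.5 = 2.0
}
```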
103 changes: 103 additions & 0 deletions lib/ExecutionEngine/ImexRunnerUtils.cpp
@@ -13,6 +13,8 @@
//===----------------------------------------------------------------------===//

#include "imex/ExecutionEngine/ImexRunnerUtils.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "mlir/ExecutionEngine/Float16bits.h"
#include <cmath>
#include <cstdlib>
#include <cstring>
@@ -239,6 +241,89 @@ void _mlir_ciface_printMaxError(UnrankedMemRefType<T> *M,
<< '\n';
}

template <typename ABTy, typename CTy>
void _mlir_ciface_gemm(UnrankedMemRefType<ABTy> *A, UnrankedMemRefType<ABTy> *B,
UnrankedMemRefType<float> *C) {
DynamicMemRefType<ABTy> DA = DynamicMemRefType<ABTy>(*A);
DynamicMemRefType<ABTy> DB = DynamicMemRefType<ABTy>(*B);
DynamicMemRefType<float> DC = DynamicMemRefType<float>(*C);
if (DA.rank != 2 || DB.rank != 2 || DC.rank != 2) {
std::cout << "Expecting 2D memrefs, got A rank: " << DA.rank
<< ", B rank: " << DB.rank << ", C rank: " << DC.rank << '\n';
return;
}
if (DA.sizes[1] != DB.sizes[0] || DA.sizes[0] != DC.sizes[0] ||
DB.sizes[1] != DC.sizes[1]) {
std::cout << "Incompatible matrix dimensions: A: [" << DA.sizes[0] << ", "
<< DA.sizes[1] << "], B: [" << DB.sizes[0] << ", " << DB.sizes[1]
<< "], C: [" << DC.sizes[0] << ", " << DC.sizes[1] << "]\n";
return;
}
if ((DA.strides[0] != DA.sizes[1]) || (DA.strides[1] != 1) ||
(DB.strides[0] != DB.sizes[1]) || (DB.strides[1] != 1) ||
(DC.strides[0] != DC.sizes[1]) || (DC.strides[1] != 1)) {
std::cout << "Expecting A strides to be [M, 1], B strides to be [K, 1], C "
"strides to be [M, 1]\n";
return;
}
int M = DA.sizes[0];
int N = DB.sizes[1];
int K = DA.sizes[1];

float *storageA = (float *)malloc(M * K * sizeof(float));
float *storageB = (float *)malloc(K * N * sizeof(float));
float *storageC = (float *)malloc(M * N * sizeof(float));

DynamicMemRefIterator<ABTy> aIt(DA);
for (int i = 0; i < M * K; ++i) {
storageA[i] = getFloat(*aIt);
++aIt;
}

DynamicMemRefIterator<ABTy> bIt(DB);
for (int i = 0; i < K * N; ++i) {
storageB[i] = getFloat(*bIt);
++bIt;
}

DynamicMemRefIterator<float> cIt(DC);
for (int i = 0; i < M * N; ++i) {
storageC[i] = getFloat(*cIt);
++cIt;
}

// Plain i-k-j GEMM over the densely packed f32 copies; the stride guard
// above guarantees row strides of exactly K, N, and N, so the packed
// copies can be indexed directly by M/N/K.
for (int i = 0; i < M; ++i) {
for (int k = 0; k < K; ++k) {
for (int j = 0; j < N; ++j) {
int a_idx = i * K + k;
int b_idx = k * N + j;
int c_idx = i * N + j;
storageC[c_idx] += storageA[a_idx] * storageB[b_idx];
}
}
}

// Store the result back to C; for f16/bf16 result types, round the f32
// accumulator through CTy to match a kernel that writes narrow results.
DynamicMemRefIterator<float> resultIt(DC);
for (int i = 0; i < M * N; ++i) {
float r = storageC[i];
if constexpr (std::is_same_v<CTy, f16>) {
f16 casted(r);
*resultIt = getFloat(casted);
} else if constexpr (std::is_same_v<CTy, bf16>) {
bf16 casted(r);
*resultIt = getFloat(casted);
} else if constexpr (std::is_same_v<CTy, float>) {
*resultIt = r;
}
++resultIt;
}

free(storageC);
free(storageA);
free(storageB);
}

extern "C" void _mlir_ciface_printMaxErrorF16(UnrankedMemRefType<f16> *M,
UnrankedMemRefType<f16> *N) {
_mlir_ciface_printMaxError(M, N);
@@ -284,4 +369,22 @@ extern "C" void _mlir_ciface_printAllcloseF32(UnrankedMemRefType<float> *M,
_mlir_ciface_printAllclose(M, N);
}

extern "C" void _mlir_ciface_gemmF16F16F32(UnrankedMemRefType<f16> *A,
UnrankedMemRefType<f16> *B,
UnrankedMemRefType<float> *C) {
_mlir_ciface_gemm<f16, float>(A, B, C);
}

extern "C" void _mlir_ciface_gemmF16F16F16(UnrankedMemRefType<f16> *A,
UnrankedMemRefType<f16> *B,
UnrankedMemRefType<float> *C) {
_mlir_ciface_gemm<f16, f16>(A, B, C);
}

extern "C" void _mlir_ciface_gemmBF16BF16F32(UnrankedMemRefType<bf16> *A,
UnrankedMemRefType<bf16> *B,
UnrankedMemRefType<float> *C) {
_mlir_ciface_gemm<bf16, float>(A, B, C);
}

// NOLINTEND(*-identifier-naming)
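A note on the `CTy` template parameter: `gemmF16F16F16` still takes an f32 C memref, but the accumulated f32 result is rounded through f16 before being stored, emulating a kernel that produces half-precision outputs. A standalone sketch of that precision loss is below; bf16 is emulated here by simple truncation of the low mantissa bits, whereas the actual `f16`/`bf16` constructors in `Float16bits` may round to nearest rather than truncate.

```cpp
// Standalone illustration of the narrow-type round-trip the reference GEMM
// applies when CTy is f16/bf16. This is a sketch, not the runner utils'
// exact conversion.
#include <cstdint>
#include <cstring>
#include <iostream>

static float roundTripBF16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFF0000u; // keep sign, 8 exponent bits, top 7 mantissa bits
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

int main() {
  float acc = 1.001f; // a plausible f32 accumulator value
  // bf16 has only 7 explicit mantissa bits, so the round trip snaps
  // the value to 1.0: the GPU result legitimately differs from pure f32.
  std::cout << acc << " -> " << roundTripBF16(acc) << '\n'; // 1.001 -> 1
}
```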
58 changes: 10 additions & 48 deletions test/Integration/Dialect/XeGPU/SG/gemm_4kx4kx4k_f16_f16_f16.mlir
@@ -380,38 +380,6 @@ module @gemm attributes {gpu.container_module} {
}
}

// compute CPU reference (takes minutes)
func.func @cpu_reference(%A : memref<4096x4096xf16>, %B : memref<4096x4096xf16>, %C : memref<4096x4096xf32>) {
%c4096 = arith.constant 4096 : index
%c16 = arith.constant 16 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
scf.for %i = %c0 to %c4096 step %c1 {
scf.for %j = %c0 to %c4096 step %c1 {
%c_curr = memref.load %C[%i, %j] : memref<4096x4096xf32>
%c_val = scf.for %k_tile = %c0 to %c4096 step %c16 iter_args(%c_partial = %c_curr) -> f32 {
%c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 {
%k_dpas = arith.addi %k_tile, %k : index
%a_val = memref.load %A[%i, %k_dpas] : memref<4096x4096xf16>
%b_val = memref.load %B[%k_dpas, %j] : memref<4096x4096xf16>
%a_cast = arith.extf %a_val : f16 to f32
%b_cast = arith.extf %b_val : f16 to f32
%t = arith.mulf %a_cast, %b_cast : f32
// %t_cast = arith.extf %t : f16 to f16
%c_sum = arith.addf %t, %c_dpas_partial : f32
scf.yield %c_sum : f32
}
scf.yield %c_val_dpas : f32
}
%c_val_f16 = arith.truncf %c_val : f32 to f16
%c_val_ = arith.extf %c_val_f16 : f16 to f32
memref.store %c_val_ , %C[%i, %j] : memref<4096x4096xf32>
}
}
return
}

func.func @main() attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -463,7 +431,7 @@
call @fillResource1DRandomF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> ()


// intialize matrix C and C_ref ; C[i, j] = 0
// Initialize matrix C and C_ref ; C[i, j] = 0
%c0_f16 = arith.constant 0.0 : f16
%c0_f32 = arith.constant 0.0 : f32
scf.for %i = %c0 to %c4096 step %c1 {
@@ -472,22 +440,16 @@
memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32>
}
}
// print input fror debug
// %A_row_0 = memref.subview %A[1, 0][1, 4096][1, 1] : memref<4096x4096xf16> to memref<1x4096xf16, strided<[4096, 1], offset: 4096>>
// %A_row_0_cast = memref.cast %A_row_0 : memref<1x4096xf16, strided<[4096, 1], offset: 4096>> to memref<*xf16>
// call @printMemrefF16(%A_row_0_cast) : (memref<*xf16>) -> ()

// run GPU
// Run GPU version
%2 = call @test(%A, %B, %C) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf16>) -> memref<4096x4096xf16>
%gpu_result_cast = memref.cast %2 : memref<4096x4096xf16> to memref<*xf16>

call @cpu_reference(%A, %B, %C_ref) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf32>) -> ()

// %cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16>
// call @printMemrefF16(%cast) : (memref<*xf16>) -> ()
%cast_C = memref.cast %2 : memref<4096x4096xf16> to memref<*xf16>
%cast_C_ref = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32>
// call @printMemrefF16(%cast_C) : (memref<*xf16>) -> ()
// call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> ()
// Run CPU version.
%A_cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16>
%B_cast = memref.cast %B : memref<4096x4096xf16> to memref<*xf16>
%C_cast = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32>
call @gemmF16F16F16(%A_cast, %B_cast, %C_cast) : (memref<*xf16>, memref<*xf16>, memref<*xf32>) -> ()

%C_row_0 = memref.subview %C_ref[0, 0][1, 4096][1, 1] : memref<4096x4096xf32> to memref<1x4096xf32, strided<[4096, 1], offset:0>>
%C_row_0_cast = memref.cast %C_row_0 : memref<1x4096xf32, strided<[4096, 1], offset: 0>> to memref<*xf32>
@@ -498,7 +460,7 @@ module @gemm attributes {gpu.container_module} {
// call @printMemrefF16(%C_row_0_cast_gpu) : (memref<*xf16>) -> ()

// CHECK: [ALLCLOSE: TRUE]
call @printAllcloseF16(%cast_C, %cast_C_ref) : (memref<*xf16>, memref<*xf32>) -> ()
call @printAllcloseF16(%gpu_result_cast, %C_cast) : (memref<*xf16>, memref<*xf32>) -> ()
memref.dealloc %A : memref<4096x4096xf16>
memref.dealloc %B : memref<4096x4096xf16>
memref.dealloc %C : memref<4096x4096xf16>
@@ -510,5 +472,5 @@ module @gemm attributes {gpu.container_module} {
func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
func.func private @printMaxErrorF16(memref<*xf16>, memref<*xf16>) attributes {llvm.emit_c_interface}
func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface}

func.func private @gemmF16F16F16(memref<*xf16>, memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
}
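The tests above rely on the runtime's shape and stride guards rather than assertions. For completeness, a hypothetical negative example (same harness assumptions as the sketch after the header diff) showing what a mismatched descriptor does:

```cpp
// Hypothetical negative test: column-major A strides trip the contiguity
// guard, so the reference GEMM prints a diagnostic and leaves C untouched.
#include "imex/ExecutionEngine/ImexRunnerUtils.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "mlir/ExecutionEngine/Float16bits.h"
#include <vector>

int main() {
  constexpr int64_t M = 2, N = 2, K = 2;
  std::vector<f16> a(M * K), b(K * N);
  std::vector<float> c(M * N, 0.0f);

  // Strides [1, M] describe a transposed (column-major) layout for A,
  // not the required row-major [K, 1].
  StridedMemRefType<f16, 2> aDesc{a.data(), a.data(), 0, {M, K}, {1, M}};
  StridedMemRefType<f16, 2> bDesc{b.data(), b.data(), 0, {K, N}, {N, 1}};
  StridedMemRefType<float, 2> cDesc{c.data(), c.data(), 0, {M, N}, {N, 1}};
  UnrankedMemRefType<f16> aU{2, &aDesc};
  UnrankedMemRefType<f16> bU{2, &bDesc};
  UnrankedMemRefType<float> cU{2, &cDesc};

  _mlir_ciface_gemmF16F16F32(&aU, &bU, &cU); // prints the strides message
}
```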
54 changes: 9 additions & 45 deletions test/Integration/Dialect/XeGPU/SG/simple_gemm.mlir
@@ -54,37 +54,6 @@ module @gemm attributes {gpu.container_module} {
return %c : memref<256x256xf32>
}

// compute CPU reference (takes minutes)
func.func @cpu_reference(%A : memref<256x256xf16>, %B : memref<256x256xf16>, %C : memref<256x256xf32>) {
%c256 = arith.constant 256 : index
%c16 = arith.constant 16 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %i = %c0 to %c256 step %c1 {
scf.for %j = %c0 to %c256 step %c1 {
%c_curr = memref.load %C[%i, %j] : memref<256x256xf32>
%c_val = scf.for %k_tile = %c0 to %c256 step %c16 iter_args(%c_partial = %c_curr) -> f32 {
%c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 {
%k_dpas = arith.addi %k_tile, %k : index
%a_val = memref.load %A[%i, %k_dpas] : memref<256x256xf16>
%b_val = memref.load %B[%k_dpas, %j] : memref<256x256xf16>
%a_cast = arith.extf %a_val : f16 to f32
%b_cast = arith.extf %b_val : f16 to f32
%t = arith.mulf %a_cast, %b_cast : f32
// %t_cast = arith.extf %t : f16 to f16
%c_sum = arith.addf %t, %c_dpas_partial : f32
scf.yield %c_sum : f32
}
scf.yield %c_val_dpas : f32
}
// %c_val_f16 = arith.truncf %c_val : f32 to f16
// %c_val_ = arith.extf %c_val_f16 : f16 to f32
memref.store %c_val , %C[%i, %j] : memref<256x256xf32>
}
}
return
}


func.func @main() attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
@@ -145,22 +114,16 @@
memref.store %c0_f32, %C_ref[%i, %j] : memref<256x256xf32>
}
}
// print input fror debug
// %A_row_0 = memref.subview %A[1, 0][1, 256][1, 1] : memref<256x256xf16> to memref<1x256xf16, strided<[256, 1], offset: 256>>
// %A_row_0_cast = memref.cast %A_row_0 : memref<1x256xf16, strided<[256, 1], offset: 256>> to memref<*xf16>
// call @printMemrefF16(%A_row_0_cast) : (memref<*xf16>) -> ()

// run GPU
// Run GPU.
%2 = call @test(%A, %B, %C) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32>

call @cpu_reference(%A, %B, %C_ref) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> ()

// %cast = memref.cast %A : memref<256x256xf16> to memref<*xf16>
// call @printMemrefF16(%cast) : (memref<*xf16>) -> ()
%cast_C = memref.cast %2 : memref<256x256xf32> to memref<*xf32>
%cast_C_ref = memref.cast %C_ref : memref<256x256xf32> to memref<*xf32>
// call @printMemrefF32(%cast_C) : (memref<*xf32>) -> ()
// call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> ()

// Run CPU.
%A_cast = memref.cast %A : memref<256x256xf16> to memref<*xf16>
%B_cast = memref.cast %B : memref<256x256xf16> to memref<*xf16>
%C_ref_cast = memref.cast %C_ref : memref<256x256xf32> to memref<*xf32>
call @gemmF16F16F32(%A_cast, %B_cast, %C_ref_cast) : (memref<*xf16>, memref<*xf16>, memref<*xf32>) -> ()

%C_row_0 = memref.subview %C_ref[0, 0][1, 256][1, 1] : memref<256x256xf32> to memref<1x256xf32, strided<[256, 1], offset:0>>
%C_row_0_cast = memref.cast %C_row_0 : memref<1x256xf32, strided<[256, 1], offset: 0>> to memref<*xf32>
@@ -171,7 +134,7 @@ module @gemm attributes {gpu.container_module} {
// call @printMemrefF32(%C_row_0_cast_gpu) : (memref<*xf32>) -> ()

// CHECK: [ALLCLOSE: TRUE]
call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> ()
call @printAllcloseF32(%cast_C, %C_ref_cast) : (memref<*xf32>, memref<*xf32>) -> ()
memref.dealloc %A : memref<256x256xf16>
memref.dealloc %B : memref<256x256xf16>
memref.dealloc %C : memref<256x256xf32>
@@ -183,4 +146,5 @@ module @gemm attributes {gpu.container_module} {
func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface}
func.func private @gemmF16F16F32(memref<*xf16>, memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
}