Skip to content

Commit 1e8dd14

Browse files
authored
[test] Improve CPU gemm validation for faster CI. (#1107)
1 parent 4f5569b commit 1e8dd14

16 files changed

+262
-579
lines changed

include/imex/ExecutionEngine/ImexRunnerUtils.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,19 @@ extern "C" IMEX_RUNNERUTILS_EXPORT void
118118
_mlir_ciface_printMaxErrorF32(UnrankedMemRefType<float> *M,
119119
UnrankedMemRefType<float> *N);
120120

121+
extern "C" IMEX_RUNNERUTILS_EXPORT void
122+
_mlir_ciface_gemmF16F16F16(UnrankedMemRefType<f16> *A,
123+
UnrankedMemRefType<f16> *B,
124+
UnrankedMemRefType<float> *C);
125+
126+
extern "C" IMEX_RUNNERUTILS_EXPORT void
127+
_mlir_ciface_gemmF16F16F32(UnrankedMemRefType<f16> *A,
128+
UnrankedMemRefType<f16> *B,
129+
UnrankedMemRefType<float> *C);
130+
131+
extern "C" IMEX_RUNNERUTILS_EXPORT void
132+
_mlir_ciface_gemmBF16BF16F32(UnrankedMemRefType<bf16> *A,
133+
UnrankedMemRefType<bf16> *B,
134+
UnrankedMemRefType<float> *C);
135+
121136
#endif // IMEX_EXECUTIONENGINE_IMEXRUNNERUTILS_H

lib/ExecutionEngine/ImexRunnerUtils.cpp

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
//===----------------------------------------------------------------------===//
1414

1515
#include "imex/ExecutionEngine/ImexRunnerUtils.h"
16+
#include "mlir/ExecutionEngine/CRunnerUtils.h"
17+
#include "mlir/ExecutionEngine/Float16bits.h"
1618
#include <cmath>
1719
#include <cstdlib>
1820
#include <cstring>
@@ -239,6 +241,89 @@ void _mlir_ciface_printMaxError(UnrankedMemRefType<T> *M,
239241
<< '\n';
240242
}
241243

244+
template <typename ABTy, typename CTy>
245+
void _mlir_ciface_gemm(UnrankedMemRefType<ABTy> *A, UnrankedMemRefType<ABTy> *B,
246+
UnrankedMemRefType<float> *C) {
247+
DynamicMemRefType<ABTy> DA = DynamicMemRefType<ABTy>(*A);
248+
DynamicMemRefType<ABTy> DB = DynamicMemRefType<ABTy>(*B);
249+
DynamicMemRefType<float> DC = DynamicMemRefType<float>(*C);
250+
if (DA.rank != 2 || DB.rank != 2 || DC.rank != 2) {
251+
std::cout << "Expecting 2D memrefs, got A rank: " << DA.rank
252+
<< ", B rank: " << DB.rank << ", C rank: " << DC.rank << '\n';
253+
return;
254+
}
255+
if (DA.sizes[1] != DB.sizes[0] || DA.sizes[0] != DC.sizes[0] ||
256+
DB.sizes[1] != DC.sizes[1]) {
257+
std::cout << "Incompatible matrix dimensions: A: [" << DA.sizes[0] << ", "
258+
<< DA.sizes[1] << "], B: [" << DB.sizes[0] << ", " << DB.sizes[1]
259+
<< "], C: [" << DC.sizes[0] << ", " << DC.sizes[1] << "]\n";
260+
return;
261+
}
262+
if ((DA.strides[0] != DA.sizes[1]) || (DA.strides[1] != 1) ||
263+
(DB.strides[0] != DB.sizes[1]) || (DB.strides[1] != 1) ||
264+
(DC.strides[0] != DC.sizes[1]) || (DC.strides[1] != 1)) {
265+
std::cout << "Expecting A strides to be [M, 1], B strides to be [K, 1], C "
266+
"strides to be [M, 1]\n";
267+
return;
268+
}
269+
int M = DA.sizes[0];
270+
int N = DB.sizes[1];
271+
int K = DA.sizes[1];
272+
273+
float *storageA = (float *)malloc(M * K * sizeof(float));
274+
float *storageB = (float *)malloc(K * N * sizeof(float));
275+
float *storageC = (float *)malloc(M * N * sizeof(float));
276+
277+
DynamicMemRefIterator<ABTy> aIt(DA);
278+
for (int i = 0; i < M * K; ++i) {
279+
storageA[i] = getFloat(*aIt);
280+
++aIt;
281+
}
282+
283+
DynamicMemRefIterator<ABTy> bIt(DB);
284+
for (int i = 0; i < K * N; ++i) {
285+
storageB[i] = getFloat(*bIt);
286+
++bIt;
287+
}
288+
289+
DynamicMemRefIterator<float> cIt(DC);
290+
for (int i = 0; i < M * N; ++i) {
291+
storageC[i] = getFloat(*cIt);
292+
++cIt;
293+
}
294+
295+
for (int i = 0; i < M; ++i) {
296+
for (int k = 0; k < K; ++k) {
297+
for (int j = 0; j < N; ++j) {
298+
int a_idx = i * DA.strides[0] + k;
299+
int b_idx = k * DB.strides[0] + j;
300+
int c_idx = i * DC.strides[0] + j;
301+
storageC[c_idx] += storageA[a_idx] * storageB[b_idx];
302+
}
303+
}
304+
}
305+
306+
// Store the result back to C.
307+
DynamicMemRefIterator<float> resultIt(DC);
308+
for (int i = 0; i < M * N; ++i) {
309+
float r = storageC[i];
310+
if constexpr (std::is_same_v<CTy, f16>) {
311+
f16 casted(r);
312+
*resultIt = getFloat(casted);
313+
} else if constexpr (std::is_same_v<CTy, bf16>) {
314+
bf16 casted(r);
315+
*resultIt = getFloat(casted);
316+
} else if constexpr (std::is_same_v<CTy, float>) {
317+
*resultIt = storageC[i];
318+
}
319+
++resultIt;
320+
}
321+
322+
free(storageC);
323+
free(storageA);
324+
free(storageB);
325+
}
326+
242327
extern "C" void _mlir_ciface_printMaxErrorF16(UnrankedMemRefType<f16> *M,
243328
UnrankedMemRefType<f16> *N) {
244329
_mlir_ciface_printMaxError(M, N);
@@ -284,4 +369,22 @@ extern "C" void _mlir_ciface_printAllcloseF32(UnrankedMemRefType<float> *M,
284369
_mlir_ciface_printAllclose(M, N);
285370
}
286371

372+
extern "C" void _mlir_ciface_gemmF16F16F32(UnrankedMemRefType<f16> *A,
373+
UnrankedMemRefType<f16> *B,
374+
UnrankedMemRefType<float> *C) {
375+
_mlir_ciface_gemm<f16, float>(A, B, C);
376+
}
377+
378+
extern "C" void _mlir_ciface_gemmF16F16F16(UnrankedMemRefType<f16> *A,
379+
UnrankedMemRefType<f16> *B,
380+
UnrankedMemRefType<float> *C) {
381+
_mlir_ciface_gemm<f16, f16>(A, B, C);
382+
}
383+
384+
extern "C" void _mlir_ciface_gemmBF16BF16F32(UnrankedMemRefType<bf16> *A,
385+
UnrankedMemRefType<bf16> *B,
386+
UnrankedMemRefType<float> *C) {
387+
_mlir_ciface_gemm<bf16, float>(A, B, C);
388+
}
389+
287390
// NOLINTEND(*-identifier-naming)

test/Integration/Dialect/XeGPU/SG/gemm_4kx4kx4k_f16_f16_f16.mlir

Lines changed: 10 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -380,38 +380,6 @@ module @gemm attributes {gpu.container_module} {
380380
}
381381
}
382382

383-
// compute CPU reference (takes minutes)
384-
func.func @cpu_reference(%A : memref<4096x4096xf16>, %B : memref<4096x4096xf16>, %C : memref<4096x4096xf32>) {
385-
%c4096 = arith.constant 4096 : index
386-
%c16 = arith.constant 16 : index
387-
%c1 = arith.constant 1 : index
388-
%c0 = arith.constant 0 : index
389-
%c64 = arith.constant 64 : index
390-
scf.for %i = %c0 to %c4096 step %c1 {
391-
scf.for %j = %c0 to %c4096 step %c1 {
392-
%c_curr = memref.load %C[%i, %j] : memref<4096x4096xf32>
393-
%c_val = scf.for %k_tile = %c0 to %c4096 step %c16 iter_args(%c_partial = %c_curr) -> f32 {
394-
%c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 {
395-
%k_dpas = arith.addi %k_tile, %k : index
396-
%a_val = memref.load %A[%i, %k_dpas] : memref<4096x4096xf16>
397-
%b_val = memref.load %B[%k_dpas, %j] : memref<4096x4096xf16>
398-
%a_cast = arith.extf %a_val : f16 to f32
399-
%b_cast = arith.extf %b_val : f16 to f32
400-
%t = arith.mulf %a_cast, %b_cast : f32
401-
// %t_cast = arith.extf %t : f16 to f16
402-
%c_sum = arith.addf %t, %c_dpas_partial : f32
403-
scf.yield %c_sum : f32
404-
}
405-
scf.yield %c_val_dpas : f32
406-
}
407-
%c_val_f16 = arith.truncf %c_val : f32 to f16
408-
%c_val_ = arith.extf %c_val_f16 : f16 to f32
409-
memref.store %c_val_ , %C[%i, %j] : memref<4096x4096xf32>
410-
}
411-
}
412-
return
413-
}
414-
415383
func.func @main() attributes {llvm.emit_c_interface} {
416384
%c0 = arith.constant 0 : index
417385
%c1 = arith.constant 1 : index
@@ -463,7 +431,7 @@ module @gemm attributes {gpu.container_module} {
463431
call @fillResource1DRandomF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> ()
464432

465433

466-
// intialize matrix C and C_ref ; C[i, j] = 0
434+
// Initialize matrix C and C_ref ; C[i, j] = 0
467435
%c0_f16 = arith.constant 0.0 : f16
468436
%c0_f32 = arith.constant 0.0 : f32
469437
scf.for %i = %c0 to %c4096 step %c1 {
@@ -472,22 +440,16 @@ module @gemm attributes {gpu.container_module} {
472440
memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32>
473441
}
474442
}
475-
// print input fror debug
476-
// %A_row_0 = memref.subview %A[1, 0][1, 4096][1, 1] : memref<4096x4096xf16> to memref<1x4096xf16, strided<[4096, 1], offset: 4096>>
477-
// %A_row_0_cast = memref.cast %A_row_0 : memref<1x4096xf16, strided<[4096, 1], offset: 4096>> to memref<*xf16>
478-
// call @printMemrefF16(%A_row_0_cast) : (memref<*xf16>) -> ()
479443

480-
// run GPU
444+
// Run GPU version
481445
%2 = call @test(%A, %B, %C) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf16>) -> memref<4096x4096xf16>
446+
%gpu_result_cast = memref.cast %2 : memref<4096x4096xf16> to memref<*xf16>
482447

483-
call @cpu_reference(%A, %B, %C_ref) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf32>) -> ()
484-
485-
// %cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16>
486-
// call @printMemrefF16(%cast) : (memref<*xf16>) -> ()
487-
%cast_C = memref.cast %2 : memref<4096x4096xf16> to memref<*xf16>
488-
%cast_C_ref = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32>
489-
// call @printMemrefF16(%cast_C) : (memref<*xf16>) -> ()
490-
// call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> ()
448+
// Run CPU version.
449+
%A_cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16>
450+
%B_cast = memref.cast %B : memref<4096x4096xf16> to memref<*xf16>
451+
%C_cast = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32>
452+
call @gemmF16F16F16(%A_cast, %B_cast, %C_cast) : (memref<*xf16>, memref<*xf16>, memref<*xf32>) -> ()
491453

492454
%C_row_0 = memref.subview %C_ref[0, 0][1, 4096][1, 1] : memref<4096x4096xf32> to memref<1x4096xf32, strided<[4096, 1], offset:0>>
493455
%C_row_0_cast = memref.cast %C_row_0 : memref<1x4096xf32, strided<[4096, 1], offset: 0>> to memref<*xf32>
@@ -498,7 +460,7 @@ module @gemm attributes {gpu.container_module} {
498460
// call @printMemrefF16(%C_row_0_cast_gpu) : (memref<*xf16>) -> ()
499461

500462
// CHECK: [ALLCLOSE: TRUE]
501-
call @printAllcloseF16(%cast_C, %cast_C_ref) : (memref<*xf16>, memref<*xf32>) -> ()
463+
call @printAllcloseF16(%gpu_result_cast, %C_cast) : (memref<*xf16>, memref<*xf32>) -> ()
502464
memref.dealloc %A : memref<4096x4096xf16>
503465
memref.dealloc %B : memref<4096x4096xf16>
504466
memref.dealloc %C : memref<4096x4096xf16>
@@ -510,5 +472,5 @@ module @gemm attributes {gpu.container_module} {
510472
func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
511473
func.func private @printMaxErrorF16(memref<*xf16>, memref<*xf16>) attributes {llvm.emit_c_interface}
512474
func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface}
513-
475+
func.func private @gemmF16F16F16(memref<*xf16>, memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
514476
}

test/Integration/Dialect/XeGPU/SG/simple_gemm.mlir

Lines changed: 9 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -54,37 +54,6 @@ module @gemm attributes {gpu.container_module} {
5454
return %c : memref<256x256xf32>
5555
}
5656

57-
// compute CPU reference (takes minutes)
58-
func.func @cpu_reference(%A : memref<256x256xf16>, %B : memref<256x256xf16>, %C : memref<256x256xf32>) {
59-
%c256 = arith.constant 256 : index
60-
%c16 = arith.constant 16 : index
61-
%c1 = arith.constant 1 : index
62-
%c0 = arith.constant 0 : index
63-
scf.for %i = %c0 to %c256 step %c1 {
64-
scf.for %j = %c0 to %c256 step %c1 {
65-
%c_curr = memref.load %C[%i, %j] : memref<256x256xf32>
66-
%c_val = scf.for %k_tile = %c0 to %c256 step %c16 iter_args(%c_partial = %c_curr) -> f32 {
67-
%c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 {
68-
%k_dpas = arith.addi %k_tile, %k : index
69-
%a_val = memref.load %A[%i, %k_dpas] : memref<256x256xf16>
70-
%b_val = memref.load %B[%k_dpas, %j] : memref<256x256xf16>
71-
%a_cast = arith.extf %a_val : f16 to f32
72-
%b_cast = arith.extf %b_val : f16 to f32
73-
%t = arith.mulf %a_cast, %b_cast : f32
74-
// %t_cast = arith.extf %t : f16 to f16
75-
%c_sum = arith.addf %t, %c_dpas_partial : f32
76-
scf.yield %c_sum : f32
77-
}
78-
scf.yield %c_val_dpas : f32
79-
}
80-
// %c_val_f16 = arith.truncf %c_val : f32 to f16
81-
// %c_val_ = arith.extf %c_val_f16 : f16 to f32
82-
memref.store %c_val , %C[%i, %j] : memref<256x256xf32>
83-
}
84-
}
85-
return
86-
}
87-
8857

8958
func.func @main() attributes {llvm.emit_c_interface} {
9059
%c0 = arith.constant 0 : index
@@ -145,22 +114,16 @@ module @gemm attributes {gpu.container_module} {
145114
memref.store %c0_f32, %C_ref[%i, %j] : memref<256x256xf32>
146115
}
147116
}
148-
// print input fror debug
149-
// %A_row_0 = memref.subview %A[1, 0][1, 256][1, 1] : memref<256x256xf16> to memref<1x256xf16, strided<[256, 1], offset: 256>>
150-
// %A_row_0_cast = memref.cast %A_row_0 : memref<1x256xf16, strided<[256, 1], offset: 256>> to memref<*xf16>
151-
// call @printMemrefF16(%A_row_0_cast) : (memref<*xf16>) -> ()
152117

153-
// run GPU
118+
// Run GPU.
154119
%2 = call @test(%A, %B, %C) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32>
155-
156-
call @cpu_reference(%A, %B, %C_ref) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> ()
157-
158-
// %cast = memref.cast %A : memref<256x256xf16> to memref<*xf16>
159-
// call @printMemrefF16(%cast) : (memref<*xf16>) -> ()
160120
%cast_C = memref.cast %2 : memref<256x256xf32> to memref<*xf32>
161-
%cast_C_ref = memref.cast %C_ref : memref<256x256xf32> to memref<*xf32>
162-
// call @printMemrefF32(%cast_C) : (memref<*xf32>) -> ()
163-
// call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> ()
121+
122+
// Run CPU.
123+
%A_cast = memref.cast %A : memref<256x256xf16> to memref<*xf16>
124+
%B_cast = memref.cast %B : memref<256x256xf16> to memref<*xf16>
125+
%C_ref_cast = memref.cast %C_ref : memref<256x256xf32> to memref<*xf32>
126+
call @gemmF16F16F32(%A_cast, %B_cast, %C_ref_cast) : (memref<*xf16>, memref<*xf16>, memref<*xf32>) -> ()
164127

165128
%C_row_0 = memref.subview %C_ref[0, 0][1, 256][1, 1] : memref<256x256xf32> to memref<1x256xf32, strided<[256, 1], offset:0>>
166129
%C_row_0_cast = memref.cast %C_row_0 : memref<1x256xf32, strided<[256, 1], offset: 0>> to memref<*xf32>
@@ -171,7 +134,7 @@ module @gemm attributes {gpu.container_module} {
171134
// call @printMemrefF32(%C_row_0_cast_gpu) : (memref<*xf32>) -> ()
172135

173136
// CHECK: [ALLCLOSE: TRUE]
174-
call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> ()
137+
call @printAllcloseF32(%cast_C, %C_ref_cast) : (memref<*xf32>, memref<*xf32>) -> ()
175138
memref.dealloc %A : memref<256x256xf16>
176139
memref.dealloc %B : memref<256x256xf16>
177140
memref.dealloc %C : memref<256x256xf32>
@@ -183,4 +146,5 @@ module @gemm attributes {gpu.container_module} {
183146
func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
184147
func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
185148
func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface}
149+
func.func private @gemmF16F16F32(memref<*xf16>, memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
186150
}

0 commit comments

Comments
 (0)