From a46234b47de6f38b49c6348e6b0ce2e96c8d4421 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Fri, 6 Jun 2025 03:23:47 -0700 Subject: [PATCH 1/8] Bump llvm - apply_registered_pass options param --- build_tools/llvm_version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index b1c947181..f85612b40 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1 +1 @@ -faf5d747f174cc9d714839f0d3bce1a783eac2ac +4eeee41f52d08dc544812a2c3a37e0adf686251a From 19b73202d9be48356eeb52d749c451744eb6634f Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Wed, 11 Jun 2025 09:39:42 -0700 Subject: [PATCH 2/8] Bufferization state fixes and isZeroInteger fixes --- .../Check/BufferizableOpInterfaceImpl.cpp | 17 ++++++++++------- .../Perf/BufferizableOpInterfaceImpl.cpp | 6 ++++-- lib/TPP/Transforms/VectorContractToAMX.cpp | 4 ++-- lib/TPP/Transforms/VectorContractToFMA.cpp | 4 ++-- .../Transforms/VectorContractToOuterproduct.cpp | 4 ++-- 5 files changed, 20 insertions(+), 15 deletions(-) diff --git a/lib/TPP/Dialect/Check/BufferizableOpInterfaceImpl.cpp b/lib/TPP/Dialect/Check/BufferizableOpInterfaceImpl.cpp index 964fdcf48..cac0980eb 100644 --- a/lib/TPP/Dialect/Check/BufferizableOpInterfaceImpl.cpp +++ b/lib/TPP/Dialect/Check/BufferizableOpInterfaceImpl.cpp @@ -48,11 +48,12 @@ struct ExpectTrueLayoutInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { check::ExpectTrueOp expectTrueOp = cast(op); FailureOr maybeSrcBuffer = - getBuffer(rewriter, expectTrueOp.getOperand(), options); + getBuffer(rewriter, expectTrueOp.getOperand(), options, state); if (failed(maybeSrcBuffer)) return failure(); Value srcBuffer = *maybeSrcBuffer; @@ -91,16 +92,17 @@ struct ExpectAlmostEqLayoutInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { check::ExpectAlmostEqOp almostEqOp = cast(op); FailureOr maybeFirstBuffer = - getBuffer(rewriter, almostEqOp.getLhs(), options); + getBuffer(rewriter, almostEqOp.getLhs(), options, state); if (failed(maybeFirstBuffer)) return failure(); Value firstBuffer = *maybeFirstBuffer; FailureOr maybeSecondBuffer = - getBuffer(rewriter, almostEqOp.getRhs(), options); + getBuffer(rewriter, almostEqOp.getRhs(), options, state); if (failed(maybeSecondBuffer)) return failure(); Value secondBuffer = *maybeSecondBuffer; @@ -142,10 +144,11 @@ struct ExpectSaneLayoutInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { check::ExpectSaneOp saneOp = cast(op); FailureOr maybeBuffer = - getBuffer(rewriter, saneOp.getOperand(), options); + getBuffer(rewriter, saneOp.getOperand(), options, state); if (failed(maybeBuffer)) { return failure(); } diff --git a/lib/TPP/Dialect/Perf/BufferizableOpInterfaceImpl.cpp b/lib/TPP/Dialect/Perf/BufferizableOpInterfaceImpl.cpp index 8c38d76b2..e0aa16d6b 100644 --- a/lib/TPP/Dialect/Perf/BufferizableOpInterfaceImpl.cpp +++ b/lib/TPP/Dialect/Perf/BufferizableOpInterfaceImpl.cpp @@ -55,10 +55,12 @@ struct SinkLayoutInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto sink = cast(op); - FailureOr srcBuffer = getBuffer(rewriter, sink.getInput(), options); + FailureOr srcBuffer = + getBuffer(rewriter, sink.getInput(), options, state); if (failed(srcBuffer)) return failure(); diff --git a/lib/TPP/Transforms/VectorContractToAMX.cpp b/lib/TPP/Transforms/VectorContractToAMX.cpp index 9355c43d8..ad95b56a8 100644 --- a/lib/TPP/Transforms/VectorContractToAMX.cpp +++ b/lib/TPP/Transforms/VectorContractToAMX.cpp @@ -344,8 +344,8 @@ struct VectorContractToAMXPattern return rewriter.notifyMatchFailure( op, "Accumulator defined by TransferReadOp"); - if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroIndex) || - !llvm::all_of(rhsDefiningOp.getIndices(), isZeroIndex)) + if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroInteger) || + !llvm::all_of(rhsDefiningOp.getIndices(), isZeroInteger)) return rewriter.notifyMatchFailure( op, "Inputs are not whole tensor or subview"); diff --git a/lib/TPP/Transforms/VectorContractToFMA.cpp b/lib/TPP/Transforms/VectorContractToFMA.cpp index 15802112f..d9771e873 100644 --- a/lib/TPP/Transforms/VectorContractToFMA.cpp +++ b/lib/TPP/Transforms/VectorContractToFMA.cpp @@ -174,8 +174,8 @@ struct VectorContractToFMAPattern return failure(); // Make sure the inputs being read are whole tensor or subview. - if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroIndex) || - !llvm::all_of(rhsDefiningOp.getIndices(), isZeroIndex)) { + if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroInteger) || + !llvm::all_of(rhsDefiningOp.getIndices(), isZeroInteger)) { return failure(); } diff --git a/lib/TPP/Transforms/VectorContractToOuterproduct.cpp b/lib/TPP/Transforms/VectorContractToOuterproduct.cpp index edc5bc143..4b109dc45 100644 --- a/lib/TPP/Transforms/VectorContractToOuterproduct.cpp +++ b/lib/TPP/Transforms/VectorContractToOuterproduct.cpp @@ -133,8 +133,8 @@ struct VectorContractToOuterproductPattern return failure(); // Make sure the inputs being read are whole tensor or subview. - if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroIndex) || - !llvm::all_of(rhsDefiningOp.getIndices(), isZeroIndex)) { + if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroInteger) || + !llvm::all_of(rhsDefiningOp.getIndices(), isZeroInteger)) { return failure(); } From 3279f676148ddecfe3096ed52376ebce969c4919 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Wed, 11 Jun 2025 09:49:54 -0700 Subject: [PATCH 3/8] Bump to llvm commit with updated apply_registered_pass --- build_tools/llvm_version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index f85612b40..ad07313d7 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1 +1 @@ -4eeee41f52d08dc544812a2c3a37e0adf686251a +fe7bf4b90b1a835418bddd2b2aa63b4977a9f6d2 From 2381573a00176e2d53de3fc41fde4ca1212fb672 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Wed, 11 Jun 2025 09:50:49 -0700 Subject: [PATCH 4/8] Update options passing to apply_registered_pass in Python --- python/mlir/tpp/sched/bundles.py | 26 +++++++++++++++++++------- python/mlir/tpp/sched/common.py | 2 +- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/python/mlir/tpp/sched/bundles.py b/python/mlir/tpp/sched/bundles.py index ed9c2487d..d88ec2694 100755 --- a/python/mlir/tpp/sched/bundles.py +++ b/python/mlir/tpp/sched/bundles.py @@ -1,5 +1,6 @@ from typing import Optional, Sequence +from mlir import ir from mlir.dialects import transform from .common import apply_registered_pass, match from .utils import GpuBackend, PipelineInterrupt @@ -67,7 +68,7 @@ def linalg_lowering(mod, /, *, skip_operations: Sequence[str] = (), **_config): func = apply_registered_pass( func, "convert-linalg-to-xsmm", - options="skip-operations=" + ",".join(skip_operations), + options={"skip-operations": ir.StringAttr.get(",".join(skip_operations))}, ) func = apply_registered_pass(func, "combine-xsmm-op-optimization") func = apply_registered_pass(func, "fold-xsmm-flags") @@ -130,7 +131,11 @@ def low_level_parallel( # Run cleanup after LICM to allow CSE to eliminate common operations now # that they are hoisted out of loops. mod = cleanup(mod) - options = "parallel-loop-tile-sizes=" + ",".join(map(str, parallel_task_grid)) + options = { + "parallel-loop-tile-sizes": ir.StringAttr.get( + ",".join(map(str, parallel_task_grid)) + ) + } mod = apply_registered_pass(mod, "scf-parallel-loop-tiling", options=options) return mod @@ -228,7 +233,11 @@ def default_tpp_passes( mod = linalg_lowering(mod, skip_operations=skip_ops, **config) if linalg_to_vector or force_linalg_to_vector: func = match(mod, ops={"func.func"}) - options = "registerTileShape=" + ",".join(map(str, register_blocking)) + options = { + "registerTileShape": ir.StringAttr.get( + ",".join(map(str, register_blocking)) + ) + } func = apply_registered_pass(func, "brgemm-linalg-tiling", options=options) func = apply_registered_pass(func, "loop-invariant-code-motion") apply_registered_pass(func, "vectorization-pass") @@ -315,7 +324,11 @@ def default_pipeline( # #if defined(__x86_64__) # options.x86Vector = true; # #endif - options = f"enable-amx={int(xsmm_utils.has_amx())}" + options = { + "enable-amx": ir.IntegerAttr.get( + ir.IntegerType.get_signless(64), int(xsmm_utils.has_amx()) + ) + } mod = apply_registered_pass(mod, "convert-vector-to-llvm", options=options) mod = apply_registered_pass(mod, "finalize-memref-to-llvm") mod = apply_registered_pass(mod, "convert-scf-to-cf") @@ -327,9 +340,8 @@ def default_pipeline( # gpu-to-llvm cannot be invoked from transform-interpreter as it # tries to load ... something while multi-threaded PassManager is running. mod = apply_registered_pass(mod, "gpu-to-llvm") - mod = apply_registered_pass( - mod, "gpu-module-to-binary", options="compilation-target=fatbin" - ) + options = {"compilation-target": ir.StringAttr.get("fatbin")} + mod = apply_registered_pass(mod, "gpu-module-to-binary", options=options) mod = apply_registered_pass(mod, "convert-math-to-llvm") if gpu_backend: mod = apply_registered_pass(mod, "async-to-async-runtime") diff --git a/python/mlir/tpp/sched/common.py b/python/mlir/tpp/sched/common.py index dc307e019..e51ebb3c5 100644 --- a/python/mlir/tpp/sched/common.py +++ b/python/mlir/tpp/sched/common.py @@ -4,7 +4,7 @@ # Wrapper to addresss verbosity. def apply_registered_pass(*args, **kwargs): - return transform.ApplyRegisteredPassOp(transform.AnyOpType.get(), *args, **kwargs) + return transform.apply_registered_pass(transform.AnyOpType.get(), *args, **kwargs) # Wrapper to addresss verbosity. From edb753560a78f8f4eb9b626a9050fbd5ae403c40 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Wed, 11 Jun 2025 13:20:46 -0700 Subject: [PATCH 5/8] Bump (for apply_registered_pass arg order fix) and tileresult fix --- build_tools/llvm_version.txt | 2 +- lib/TPP/Transforms/LowerPacksAndUnpacks.cpp | 10 +++++----- lib/TPP/Transforms/RewriteBatchMatmulToMatmul.cpp | 2 +- lib/TPP/Transforms/SplitReductionDim.cpp | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index ad07313d7..2bf3e65a0 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1 +1 @@ -fe7bf4b90b1a835418bddd2b2aa63b4977a9f6d2 +fb761aa38b0bc01ab911f5dbbfb474b70aaafbb4 diff --git a/lib/TPP/Transforms/LowerPacksAndUnpacks.cpp b/lib/TPP/Transforms/LowerPacksAndUnpacks.cpp index 6c65cb6b7..3d287c498 100644 --- a/lib/TPP/Transforms/LowerPacksAndUnpacks.cpp +++ b/lib/TPP/Transforms/LowerPacksAndUnpacks.cpp @@ -112,7 +112,7 @@ static void fuseOrTilePacks(RewriterBase &rewriter, FunctionOpInterface func) { forLoops); if (!fusedProducer) continue; - rewriter.replaceOp(consumerPackOp, tilingResult->mergeResult.replacements); + rewriter.replaceOp(consumerPackOp, tilingResult->replacements); } // Tile packs. @@ -124,7 +124,7 @@ static void fuseOrTilePacks(RewriterBase &rewriter, FunctionOpInterface func) { rewriter, cast(packOp.getOperation()), tileSizes); if (failed(tilingResult)) continue; - rewriter.replaceOp(packOp, tilingResult->mergeResult.replacements); + rewriter.replaceOp(packOp, tilingResult->replacements); } // Tile unpacks. @@ -136,7 +136,7 @@ static void fuseOrTilePacks(RewriterBase &rewriter, FunctionOpInterface func) { rewriter, cast(unPackOp.getOperation()), tileSizes); if (failed(tilingResult)) continue; - rewriter.replaceOp(unPackOp, tilingResult->mergeResult.replacements); + rewriter.replaceOp(unPackOp, tilingResult->replacements); } } @@ -215,7 +215,7 @@ class LowerPacksAndUnPacks unpackTilingOptions); if (failed(tilingResult)) return signalPassFailure(); - rewriter.replaceOp(unPackOp, tilingResult->mergeResult.replacements); + rewriter.replaceOp(unPackOp, tilingResult->replacements); }); getOperation()->walk([&](linalg::PackOp packOp) { SmallVector tiles(packOp.getSourceType().getRank(), 1); @@ -226,7 +226,7 @@ class LowerPacksAndUnPacks packTilingOptions); if (failed(tilingResult)) return signalPassFailure(); - rewriter.replaceOp(packOp, tilingResult->mergeResult.replacements); + rewriter.replaceOp(packOp, tilingResult->replacements); }); RewritePatternSet patterns(&getContext()); patterns.addmergeResult.replacements); + rewriter.replaceOp(batchMatmulOp, tilingResult->replacements); }); // Step2: diff --git a/lib/TPP/Transforms/SplitReductionDim.cpp b/lib/TPP/Transforms/SplitReductionDim.cpp index edba1bf66..b3f9fbf17 100644 --- a/lib/TPP/Transforms/SplitReductionDim.cpp +++ b/lib/TPP/Transforms/SplitReductionDim.cpp @@ -81,7 +81,7 @@ struct SplitContractionReduction return rewriter.notifyMatchFailure(linalgOp, "failed to tile contraction"); - rewriter.replaceOp(linalgOp, tilingResult->mergeResult.replacements); + rewriter.replaceOp(linalgOp, tilingResult->replacements); return success(); } From bd41253c407754a42c4d30be008eafa3cc00b5a8 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Wed, 11 Jun 2025 13:37:51 -0700 Subject: [PATCH 6/8] Remove attribute wrapping for auto-converted options --- python/mlir/tpp/sched/bundles.py | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/python/mlir/tpp/sched/bundles.py b/python/mlir/tpp/sched/bundles.py index d88ec2694..5edd1d713 100755 --- a/python/mlir/tpp/sched/bundles.py +++ b/python/mlir/tpp/sched/bundles.py @@ -68,7 +68,7 @@ def linalg_lowering(mod, /, *, skip_operations: Sequence[str] = (), **_config): func = apply_registered_pass( func, "convert-linalg-to-xsmm", - options={"skip-operations": ir.StringAttr.get(",".join(skip_operations))}, + options={"skip-operations": ",".join(skip_operations)}, ) func = apply_registered_pass(func, "combine-xsmm-op-optimization") func = apply_registered_pass(func, "fold-xsmm-flags") @@ -131,11 +131,7 @@ def low_level_parallel( # Run cleanup after LICM to allow CSE to eliminate common operations now # that they are hoisted out of loops. mod = cleanup(mod) - options = { - "parallel-loop-tile-sizes": ir.StringAttr.get( - ",".join(map(str, parallel_task_grid)) - ) - } + options = {"parallel-loop-tile-sizes": ",".join(map(str, parallel_task_grid))} mod = apply_registered_pass(mod, "scf-parallel-loop-tiling", options=options) return mod @@ -233,11 +229,7 @@ def default_tpp_passes( mod = linalg_lowering(mod, skip_operations=skip_ops, **config) if linalg_to_vector or force_linalg_to_vector: func = match(mod, ops={"func.func"}) - options = { - "registerTileShape": ir.StringAttr.get( - ",".join(map(str, register_blocking)) - ) - } + options = {"registerTileShape": ",".join(map(str, register_blocking))} func = apply_registered_pass(func, "brgemm-linalg-tiling", options=options) func = apply_registered_pass(func, "loop-invariant-code-motion") apply_registered_pass(func, "vectorization-pass") @@ -324,11 +316,7 @@ def default_pipeline( # #if defined(__x86_64__) # options.x86Vector = true; # #endif - options = { - "enable-amx": ir.IntegerAttr.get( - ir.IntegerType.get_signless(64), int(xsmm_utils.has_amx()) - ) - } + options = {"enable-amx": int(xsmm_utils.has_amx())} mod = apply_registered_pass(mod, "convert-vector-to-llvm", options=options) mod = apply_registered_pass(mod, "finalize-memref-to-llvm") mod = apply_registered_pass(mod, "convert-scf-to-cf") @@ -340,7 +328,7 @@ def default_pipeline( # gpu-to-llvm cannot be invoked from transform-interpreter as it # tries to load ... something while multi-threaded PassManager is running. mod = apply_registered_pass(mod, "gpu-to-llvm") - options = {"compilation-target": ir.StringAttr.get("fatbin")} + options = {"compilation-target": "fatbin"} mod = apply_registered_pass(mod, "gpu-module-to-binary", options=options) mod = apply_registered_pass(mod, "convert-math-to-llvm") if gpu_backend: From fb9deb8007187c331f85dac739bbca3a6821c17d Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Wed, 11 Jun 2025 13:57:35 -0700 Subject: [PATCH 7/8] amx refactor fix and constant name change --- test/Integration/tpp-run-splat-mlp.mlir | 2 +- test/Passes/DefaultPipeline/amx-initialization.mlir | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/test/Integration/tpp-run-splat-mlp.mlir b/test/Integration/tpp-run-splat-mlp.mlir index 83a1924de..26eae3976 100644 --- a/test/Integration/tpp-run-splat-mlp.mlir +++ b/test/Integration/tpp-run-splat-mlp.mlir @@ -53,7 +53,7 @@ func.func @entry(%arg0: tensor<8x8xf32>, %output: tensor<8x8xf32>) -> tensor<8x8 // CHECK-DAG: memref.global "private" constant @__constant_1x1x8x8xf32 : memref<1x1x8x8xf32> // CHECK-DAG: memref.global "private" constant @__constant_1x1x8x8xf32_0 : memref<1x1x8x8xf32> // CHECK-DAG: memref.global "private" constant @__constant_8xf32 : memref<8xf32> -// CHECK-DAG: memref.global "private" constant @__constant_8xf32_0 : memref<8xf32> +// CHECK-DAG: memref.global "private" constant @__constant_8xf32_1 : memref<8xf32> // Randomized input. // CHECK-DAG: memref.global "private" @__wrapper_0 : memref<8x8xf32> diff --git a/test/Passes/DefaultPipeline/amx-initialization.mlir b/test/Passes/DefaultPipeline/amx-initialization.mlir index 82e21e094..14224ea8d 100644 --- a/test/Passes/DefaultPipeline/amx-initialization.mlir +++ b/test/Passes/DefaultPipeline/amx-initialization.mlir @@ -5,9 +5,10 @@ // CHECK-AMX-BF16-LABEL: llvm.func @entry -// CHECK-AMX-BF16: amx.tileloadd64 -// CHECK-AMX-BF16: amx.tdpbf16ps -// CHECK-AMX-BF16: amx.tilestored64 +// CHECK-AMX-BF16: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"{{.*}} -> !llvm.x86_amx +// CHECK-AMX-BF16: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"{{.*}} -> !llvm.x86_amx +// CHECK-AMX-BF16: llvm.call_intrinsic "llvm.x86.tilezero.internal"{{.*}} -> !llvm.x86_amx +// CHECK-AMX-BF16: llvm.call_intrinsic "llvm.x86.tdpbf16ps.internal"{{.*}} -> !llvm.x86_amx func.func @entry(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16>, %arg2: memref<16x16xf32>) { From 51827fab3671fad58b74c8059b1e1f94fdd0011c Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Thu, 12 Jun 2025 04:58:53 -0700 Subject: [PATCH 8/8] Bump (for amx conversion fix) and remove duplicate tests --- build_tools/llvm_version.txt | 2 +- .../vector-contract-to-amx-gemm.mlir | 20 ---------- .../vector-contract-to-amx-mlp.mlir | 39 ------------------- 3 files changed, 1 insertion(+), 60 deletions(-) delete mode 100644 test/BF16/Integration/avx512bf16/vector-contract-to-amx-gemm.mlir delete mode 100644 test/BF16/Integration/avx512bf16/vector-contract-to-amx-mlp.mlir diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index 2bf3e65a0..10c7980b6 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1 +1 @@ -fb761aa38b0bc01ab911f5dbbfb474b70aaafbb4 +d698ede748e66f5519cb8481abc2df89a994a059 diff --git a/test/BF16/Integration/avx512bf16/vector-contract-to-amx-gemm.mlir b/test/BF16/Integration/avx512bf16/vector-contract-to-amx-gemm.mlir deleted file mode 100644 index 840e1bb14..000000000 --- a/test/BF16/Integration/avx512bf16/vector-contract-to-amx-gemm.mlir +++ /dev/null @@ -1,20 +0,0 @@ -// RUN: tpp-run -e entry --entry-point-result=void -print --splat-to-random --init-type normal -seed 123 %s > %t.1 -// RUN: tpp-run %s -e entry --entry-point-result=void --vector-to-kernels --registerBlocking=32,32,32 -print --splat-to-random --init-type normal -seed 123 > %t.2 -// RUN: fpcmp -r 0.001 %t.1 %t.2 - - -#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6, d3)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6, d5, d3)> -#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)> -module { - func.func @entry(%arg0: tensor<1x2x32x32xbf16>, %arg1: tensor<2x2x16x32x2xbf16>, %arg2: tensor<1x2x32x32xbf16>) -> tensor<1x2x32x32xbf16> { - %expanded = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [1, 2, 32, 16, 2] : tensor<1x2x32x32xbf16> into tensor<1x2x32x16x2xbf16> - %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%expanded, %arg1 : tensor<1x2x32x16x2xbf16>, tensor<2x2x16x32x2xbf16>) outs(%arg2 : tensor<1x2x32x32xbf16>) { - ^bb0(%in: bf16, %in_0: bf16, %out: bf16): - %1 = arith.mulf %in, %in_0 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } -> tensor<1x2x32x32xbf16> - return %0 : tensor<1x2x32x32xbf16> - } -} diff --git a/test/BF16/Integration/avx512bf16/vector-contract-to-amx-mlp.mlir b/test/BF16/Integration/avx512bf16/vector-contract-to-amx-mlp.mlir deleted file mode 100644 index 9f8a84591..000000000 --- a/test/BF16/Integration/avx512bf16/vector-contract-to-amx-mlp.mlir +++ /dev/null @@ -1,39 +0,0 @@ -// RUN: tpp-run -e entry --entry-point-result=void -print --splat-to-random --init-type normal -seed 123 %s > %t.1 -// RUN: tpp-run %s -e entry --entry-point-result=void --vector-to-kernels --registerBlocking=32,32,32 -print --splat-to-random --init-type normal -seed 123 > %t.2 -// RUN: fpcmp -r 0.001 %t.1 %t.2 - - -#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6, d3)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6, d5, d3)> -#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)> -#map3 = affine_map<(d0, d1, d2, d3) -> (d1, d3)> -#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -module { - func.func @entry(%arg0: tensor<1x2x32x32xbf16>, %arg1: tensor<2x2x16x32x2xbf16>, %arg2: tensor<2x32xbf16>, %arg3: tensor<1x2x32x32xbf16>, %arg4: tensor<2x2x16x32x2xbf16>, %arg5: tensor<2x32xbf16>, %arg6: tensor<1x2x32x32xbf16>) -> tensor<1x2x32x32xbf16> { - %expanded = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [1, 2, 32, 16, 2] : tensor<1x2x32x32xbf16> into tensor<1x2x32x16x2xbf16> - %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%expanded, %arg1 : tensor<1x2x32x16x2xbf16>, tensor<2x2x16x32x2xbf16>) outs(%arg3 : tensor<1x2x32x32xbf16>) { - ^bb0(%in: bf16, %in_1: bf16, %out: bf16): - %4 = arith.mulf %in, %in_1 : bf16 - %5 = arith.addf %out, %4 : bf16 - linalg.yield %5 : bf16 - } -> tensor<1x2x32x32xbf16> - %1 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<2x32xbf16>) outs(%0 : tensor<1x2x32x32xbf16>) { - ^bb0(%in: bf16, %out: bf16): - %4 = arith.addf %in, %out : bf16 - linalg.yield %4 : bf16 - } -> tensor<1x2x32x32xbf16> - %expanded_0 = tensor.expand_shape %1 [[0], [1], [2], [3, 4]] output_shape [1, 2, 32, 16, 2] : tensor<1x2x32x32xbf16> into tensor<1x2x32x16x2xbf16> - %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%expanded_0, %arg4 : tensor<1x2x32x16x2xbf16>, tensor<2x2x16x32x2xbf16>) outs(%arg6 : tensor<1x2x32x32xbf16>) { - ^bb0(%in: bf16, %in_1: bf16, %out: bf16): - %4 = arith.mulf %in, %in_1 : bf16 - %5 = arith.addf %out, %4 : bf16 - linalg.yield %5 : bf16 - } -> tensor<1x2x32x32xbf16> - %3 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg5 : tensor<2x32xbf16>) outs(%2 : tensor<1x2x32x32xbf16>) { - ^bb0(%in: bf16, %out: bf16): - %4 = arith.addf %in, %out : bf16 - linalg.yield %4 : bf16 - } -> tensor<1x2x32x32xbf16> - return %3 : tensor<1x2x32x32xbf16> - } -}