Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
faf5d747f174cc9d714839f0d3bce1a783eac2ac
d698ede748e66f5519cb8481abc2df89a994a059
17 changes: 10 additions & 7 deletions lib/TPP/Dialect/Check/BufferizableOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,12 @@ struct ExpectTrueLayoutInterface
}

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
const BufferizationOptions &options,
BufferizationState &state) const {
check::ExpectTrueOp expectTrueOp = cast<check::ExpectTrueOp>(op);

FailureOr<Value> maybeSrcBuffer =
getBuffer(rewriter, expectTrueOp.getOperand(), options);
getBuffer(rewriter, expectTrueOp.getOperand(), options, state);
if (failed(maybeSrcBuffer))
return failure();
Value srcBuffer = *maybeSrcBuffer;
Expand Down Expand Up @@ -91,16 +92,17 @@ struct ExpectAlmostEqLayoutInterface
}

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
const BufferizationOptions &options,
BufferizationState &state) const {
check::ExpectAlmostEqOp almostEqOp = cast<check::ExpectAlmostEqOp>(op);
FailureOr<Value> maybeFirstBuffer =
getBuffer(rewriter, almostEqOp.getLhs(), options);
getBuffer(rewriter, almostEqOp.getLhs(), options, state);
if (failed(maybeFirstBuffer))
return failure();
Value firstBuffer = *maybeFirstBuffer;

FailureOr<Value> maybeSecondBuffer =
getBuffer(rewriter, almostEqOp.getRhs(), options);
getBuffer(rewriter, almostEqOp.getRhs(), options, state);
if (failed(maybeSecondBuffer))
return failure();
Value secondBuffer = *maybeSecondBuffer;
Expand Down Expand Up @@ -142,10 +144,11 @@ struct ExpectSaneLayoutInterface
}

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
const BufferizationOptions &options,
BufferizationState &state) const {
check::ExpectSaneOp saneOp = cast<check::ExpectSaneOp>(op);
FailureOr<Value> maybeBuffer =
getBuffer(rewriter, saneOp.getOperand(), options);
getBuffer(rewriter, saneOp.getOperand(), options, state);
if (failed(maybeBuffer)) {
return failure();
}
Expand Down
6 changes: 4 additions & 2 deletions lib/TPP/Dialect/Perf/BufferizableOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,12 @@ struct SinkLayoutInterface
}

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
const BufferizationOptions &options,
BufferizationState &state) const {
auto sink = cast<perf::SinkOp>(op);

FailureOr<Value> srcBuffer = getBuffer(rewriter, sink.getInput(), options);
FailureOr<Value> srcBuffer =
getBuffer(rewriter, sink.getInput(), options, state);
if (failed(srcBuffer))
return failure();

Expand Down
10 changes: 5 additions & 5 deletions lib/TPP/Transforms/LowerPacksAndUnpacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ static void fuseOrTilePacks(RewriterBase &rewriter, FunctionOpInterface func) {
forLoops);
if (!fusedProducer)
continue;
rewriter.replaceOp(consumerPackOp, tilingResult->mergeResult.replacements);
rewriter.replaceOp(consumerPackOp, tilingResult->replacements);
}

// Tile packs.
Expand All @@ -124,7 +124,7 @@ static void fuseOrTilePacks(RewriterBase &rewriter, FunctionOpInterface func) {
rewriter, cast<TilingInterface>(packOp.getOperation()), tileSizes);
if (failed(tilingResult))
continue;
rewriter.replaceOp(packOp, tilingResult->mergeResult.replacements);
rewriter.replaceOp(packOp, tilingResult->replacements);
}

// Tile unpacks.
Expand All @@ -136,7 +136,7 @@ static void fuseOrTilePacks(RewriterBase &rewriter, FunctionOpInterface func) {
rewriter, cast<TilingInterface>(unPackOp.getOperation()), tileSizes);
if (failed(tilingResult))
continue;
rewriter.replaceOp(unPackOp, tilingResult->mergeResult.replacements);
rewriter.replaceOp(unPackOp, tilingResult->replacements);
}
}

Expand Down Expand Up @@ -215,7 +215,7 @@ class LowerPacksAndUnPacks
unpackTilingOptions);
if (failed(tilingResult))
return signalPassFailure();
rewriter.replaceOp(unPackOp, tilingResult->mergeResult.replacements);
rewriter.replaceOp(unPackOp, tilingResult->replacements);
});
getOperation()->walk([&](linalg::PackOp packOp) {
SmallVector<int64_t> tiles(packOp.getSourceType().getRank(), 1);
Expand All @@ -226,7 +226,7 @@ class LowerPacksAndUnPacks
packTilingOptions);
if (failed(tilingResult))
return signalPassFailure();
rewriter.replaceOp(packOp, tilingResult->mergeResult.replacements);
rewriter.replaceOp(packOp, tilingResult->replacements);
});
RewritePatternSet patterns(&getContext());
patterns.add<linalg::DecomposeOuterUnitDimsUnPackOpPattern,
Expand Down
2 changes: 1 addition & 1 deletion lib/TPP/Transforms/RewriteBatchMatmulToMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ struct RewriteBatchMatmulToMatmul
tilingOpts);
if (failed(tilingResult))
return signalPassFailure();
rewriter.replaceOp(batchMatmulOp, tilingResult->mergeResult.replacements);
rewriter.replaceOp(batchMatmulOp, tilingResult->replacements);
});

// Step2:
Expand Down
2 changes: 1 addition & 1 deletion lib/TPP/Transforms/SplitReductionDim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ struct SplitContractionReduction
return rewriter.notifyMatchFailure(linalgOp,
"failed to tile contraction");

rewriter.replaceOp(linalgOp, tilingResult->mergeResult.replacements);
rewriter.replaceOp(linalgOp, tilingResult->replacements);

return success();
}
Expand Down
4 changes: 2 additions & 2 deletions lib/TPP/Transforms/VectorContractToAMX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,8 @@ struct VectorContractToAMXPattern
return rewriter.notifyMatchFailure(
op, "Accumulator defined by TransferReadOp");

if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroIndex) ||
!llvm::all_of(rhsDefiningOp.getIndices(), isZeroIndex))
if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroInteger) ||
!llvm::all_of(rhsDefiningOp.getIndices(), isZeroInteger))
return rewriter.notifyMatchFailure(
op, "Inputs are not whole tensor or subview");

Expand Down
4 changes: 2 additions & 2 deletions lib/TPP/Transforms/VectorContractToFMA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,8 @@ struct VectorContractToFMAPattern
return failure();

// Make sure the inputs being read are whole tensor or subview.
if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroIndex) ||
!llvm::all_of(rhsDefiningOp.getIndices(), isZeroIndex)) {
if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroInteger) ||
!llvm::all_of(rhsDefiningOp.getIndices(), isZeroInteger)) {
return failure();
}

Expand Down
4 changes: 2 additions & 2 deletions lib/TPP/Transforms/VectorContractToOuterproduct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ struct VectorContractToOuterproductPattern
return failure();

// Make sure the inputs being read are whole tensor or subview.
if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroIndex) ||
!llvm::all_of(rhsDefiningOp.getIndices(), isZeroIndex)) {
if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroInteger) ||
!llvm::all_of(rhsDefiningOp.getIndices(), isZeroInteger)) {
return failure();
}

Expand Down
14 changes: 7 additions & 7 deletions python/mlir/tpp/sched/bundles.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional, Sequence

from mlir import ir
from mlir.dialects import transform
from .common import apply_registered_pass, match
from .utils import GpuBackend, PipelineInterrupt
Expand Down Expand Up @@ -67,7 +68,7 @@ def linalg_lowering(mod, /, *, skip_operations: Sequence[str] = (), **_config):
func = apply_registered_pass(
func,
"convert-linalg-to-xsmm",
options="skip-operations=" + ",".join(skip_operations),
options={"skip-operations": ",".join(skip_operations)},
)
func = apply_registered_pass(func, "combine-xsmm-op-optimization")
func = apply_registered_pass(func, "fold-xsmm-flags")
Expand Down Expand Up @@ -130,7 +131,7 @@ def low_level_parallel(
# Run cleanup after LICM to allow CSE to eliminate common operations now
# that they are hoisted out of loops.
mod = cleanup(mod)
options = "parallel-loop-tile-sizes=" + ",".join(map(str, parallel_task_grid))
options = {"parallel-loop-tile-sizes": ",".join(map(str, parallel_task_grid))}
mod = apply_registered_pass(mod, "scf-parallel-loop-tiling", options=options)
return mod

Expand Down Expand Up @@ -228,7 +229,7 @@ def default_tpp_passes(
mod = linalg_lowering(mod, skip_operations=skip_ops, **config)
if linalg_to_vector or force_linalg_to_vector:
func = match(mod, ops={"func.func"})
options = "registerTileShape=" + ",".join(map(str, register_blocking))
options = {"registerTileShape": ",".join(map(str, register_blocking))}
func = apply_registered_pass(func, "brgemm-linalg-tiling", options=options)
func = apply_registered_pass(func, "loop-invariant-code-motion")
apply_registered_pass(func, "vectorization-pass")
Expand Down Expand Up @@ -315,7 +316,7 @@ def default_pipeline(
# #if defined(__x86_64__)
# options.x86Vector = true;
# #endif
options = f"enable-amx={int(xsmm_utils.has_amx())}"
options = {"enable-amx": int(xsmm_utils.has_amx())}
mod = apply_registered_pass(mod, "convert-vector-to-llvm", options=options)
mod = apply_registered_pass(mod, "finalize-memref-to-llvm")
mod = apply_registered_pass(mod, "convert-scf-to-cf")
Expand All @@ -327,9 +328,8 @@ def default_pipeline(
# gpu-to-llvm cannot be invoked from transform-interpreter as it
# tries to load ... something while multi-threaded PassManager is running.
mod = apply_registered_pass(mod, "gpu-to-llvm")
mod = apply_registered_pass(
mod, "gpu-module-to-binary", options="compilation-target=fatbin"
)
options = {"compilation-target": "fatbin"}
mod = apply_registered_pass(mod, "gpu-module-to-binary", options=options)
mod = apply_registered_pass(mod, "convert-math-to-llvm")
if gpu_backend:
mod = apply_registered_pass(mod, "async-to-async-runtime")
Expand Down
2 changes: 1 addition & 1 deletion python/mlir/tpp/sched/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

# Wrapper to address verbosity.
def apply_registered_pass(*args, **kwargs):
return transform.ApplyRegisteredPassOp(transform.AnyOpType.get(), *args, **kwargs)
return transform.apply_registered_pass(transform.AnyOpType.get(), *args, **kwargs)


# Wrapper to address verbosity.
Expand Down
20 changes: 0 additions & 20 deletions test/BF16/Integration/avx512bf16/vector-contract-to-amx-gemm.mlir

This file was deleted.

39 changes: 0 additions & 39 deletions test/BF16/Integration/avx512bf16/vector-contract-to-amx-mlp.mlir

This file was deleted.

2 changes: 1 addition & 1 deletion test/Integration/tpp-run-splat-mlp.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func.func @entry(%arg0: tensor<8x8xf32>, %output: tensor<8x8xf32>) -> tensor<8x8
// CHECK-DAG: memref.global "private" constant @__constant_1x1x8x8xf32 : memref<1x1x8x8xf32>
// CHECK-DAG: memref.global "private" constant @__constant_1x1x8x8xf32_0 : memref<1x1x8x8xf32>
// CHECK-DAG: memref.global "private" constant @__constant_8xf32 : memref<8xf32>
// CHECK-DAG: memref.global "private" constant @__constant_8xf32_0 : memref<8xf32>
// CHECK-DAG: memref.global "private" constant @__constant_8xf32_1 : memref<8xf32>

// Randomized input.
// CHECK-DAG: memref.global "private" @__wrapper_0 : memref<8x8xf32>
Expand Down
7 changes: 4 additions & 3 deletions test/Passes/DefaultPipeline/amx-initialization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@


// CHECK-AMX-BF16-LABEL: llvm.func @entry
// CHECK-AMX-BF16: amx.tileloadd64
// CHECK-AMX-BF16: amx.tdpbf16ps
// CHECK-AMX-BF16: amx.tilestored64
// CHECK-AMX-BF16: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"{{.*}} -> !llvm.x86_amx
// CHECK-AMX-BF16: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"{{.*}} -> !llvm.x86_amx
// CHECK-AMX-BF16: llvm.call_intrinsic "llvm.x86.tilezero.internal"{{.*}} -> !llvm.x86_amx
// CHECK-AMX-BF16: llvm.call_intrinsic "llvm.x86.tdpbf16ps.internal"{{.*}} -> !llvm.x86_amx
func.func @entry(%arg0: memref<16x32xbf16>,
%arg1: memref<16x32xbf16>,
%arg2: memref<16x16xf32>) {
Expand Down