From d6c9bbf62107a801ff7701cff2c75353a87b1119 Mon Sep 17 00:00:00 2001 From: Zenithal Date: Tue, 28 Jan 2025 05:32:20 +0000 Subject: [PATCH] [mlir] Integrate OpAsmTypeInterface with AsmPrinter --- mlir/lib/IR/AsmPrinter.cpp | 30 ++++++++++++++++++++++ mlir/test/IR/op-asm-interface.mlir | 36 +++++++++++++++++++++++++++ mlir/test/lib/Dialect/Test/TestOps.td | 19 ++++++++++++++ 3 files changed, 85 insertions(+) diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index eea4f7fa5c4be..9f17699be29e5 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1536,10 +1536,13 @@ StringRef maybeGetValueNameFromLoc(Value value, StringRef name) { } // namespace void SSANameState::numberValuesInRegion(Region ®ion) { + // Indicates whether OpAsmOpInterface set a name. + bool opAsmOpInterfaceUsed = false; auto setBlockArgNameFn = [&](Value arg, StringRef name) { assert(!valueIDs.count(arg) && "arg numbered multiple times"); assert(llvm::cast(arg).getOwner()->getParent() == ®ion && "arg not defined in current region"); + opAsmOpInterfaceUsed = true; if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix())) name = maybeGetValueNameFromLoc(arg, name); setValueName(arg, name); @@ -1549,6 +1552,15 @@ void SSANameState::numberValuesInRegion(Region ®ion) { if (Operation *op = region.getParentOp()) { if (auto asmInterface = dyn_cast(op)) asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn); + // If the OpAsmOpInterface didn't set a name, get name from the type. + if (!opAsmOpInterfaceUsed) { + for (BlockArgument arg : region.getArguments()) { + if (auto interface = dyn_cast(arg.getType())) { + interface.getAsmName( + [&](StringRef name) { setBlockArgNameFn(arg, name); }); + } + } + } } } @@ -1598,9 +1610,12 @@ void SSANameState::numberValuesInBlock(Block &block) { void SSANameState::numberValuesInOp(Operation &op) { // Function used to set the special result names for the operation. SmallVector resultGroups(/*Size=*/1, /*Value=*/0); + // Indicates whether OpAsmOpInterface set a name. + bool opAsmOpInterfaceUsed = false; auto setResultNameFn = [&](Value result, StringRef name) { assert(!valueIDs.count(result) && "result numbered multiple times"); assert(result.getDefiningOp() == &op && "result not defined by 'op'"); + opAsmOpInterfaceUsed = true; if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix())) name = maybeGetValueNameFromLoc(result, name); setValueName(result, name); @@ -1630,6 +1645,21 @@ void SSANameState::numberValuesInOp(Operation &op) { asmInterface.getAsmBlockNames(setBlockNameFn); asmInterface.getAsmResultNames(setResultNameFn); } + if (!opAsmOpInterfaceUsed) { + // If the OpAsmOpInterface didn't set a name, and all results have + // OpAsmTypeInterface, get names from types. + bool allHaveOpAsmTypeInterface = + llvm::all_of(op.getResultTypes(), [&](Type type) { + return isa(type); + }); + if (allHaveOpAsmTypeInterface) { + for (OpResult result : op.getResults()) { + auto interface = cast(result.getType()); + interface.getAsmName( + [&](StringRef name) { setResultNameFn(result, name); }); + } + } + } } unsigned numResults = op.getNumResults(); diff --git a/mlir/test/IR/op-asm-interface.mlir b/mlir/test/IR/op-asm-interface.mlir index a9c199e3dc973..6cce9572f4fc2 100644 --- a/mlir/test/IR/op-asm-interface.mlir +++ b/mlir/test/IR/op-asm-interface.mlir @@ -22,3 +22,39 @@ func.func @block_argument_name_from_op_asm_type_interface() { } return } + +// ----- + +//===----------------------------------------------------------------------===// +// Test OpAsmTypeInterface +//===----------------------------------------------------------------------===// + +func.func @result_name_from_op_asm_type_interface_asmprinter() { + // CHECK-LABEL: @result_name_from_op_asm_type_interface_asmprinter + // CHECK: %op_asm_type_interface + %0 = "test.result_name_from_type_interface"() : () -> !test.op_asm_type_interface + return +} + +// ----- + +// i1 does not have OpAsmTypeInterface, should not get named. +func.func @result_name_from_op_asm_type_interface_not_all() { + // CHECK-LABEL: @result_name_from_op_asm_type_interface_not_all + // CHECK-NOT: %op_asm_type_interface + // CHECK: %0:2 + %0:2 = "test.result_name_from_type_interface"() : () -> (!test.op_asm_type_interface, i1) + return +} + +// ----- + +func.func @block_argument_name_from_op_asm_type_interface_asmprinter() { + // CHECK-LABEL: @block_argument_name_from_op_asm_type_interface_asmprinter + // CHECK: ^bb0(%op_asm_type_interface + test.block_argument_name_from_type_interface { + ^bb0(%arg0: !test.op_asm_type_interface): + "test.terminator"() : ()->() + } + return +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 2aa0658ab0e5d..cdc1237ec8c5a 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -955,6 +955,25 @@ def BlockArgumentNameFromTypeOp let assemblyFormat = "regions attr-dict-with-keyword"; } +// This is used to test OpAsmTypeInterface::getAsmName's integration with AsmPrinter +// for op result name when OpAsmOpInterface::getAsmResultNames is the default implementation +// i.e. does nothing. +def ResultNameFromTypeInterfaceOp + : TEST_Op<"result_name_from_type_interface", + [OpAsmOpInterface]> { + let results = (outs Variadic:$r); +} + +// This is used to test OpAsmTypeInterface::getAsmName's integration with AsmPrinter +// for block argument name when OpAsmOpInterface::getAsmBlockArgumentNames is the default implementation +// i.e. does nothing. +def BlockArgumentNameFromTypeInterfaceOp + : TEST_Op<"block_argument_name_from_type_interface", + [OpAsmOpInterface]> { + let regions = (region AnyRegion:$body); + let assemblyFormat = "regions attr-dict-with-keyword"; +} + // This is used to test the OpAsmOpInterface::getDefaultDialect() feature: // operations nested in a region under this op will drop the "test." dialect // prefix.