diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt index 0c7937dfd69e5..846547ff131e3 100644 --- a/mlir/include/mlir/IR/CMakeLists.txt +++ b/mlir/include/mlir/IR/CMakeLists.txt @@ -2,6 +2,8 @@ add_mlir_interface(SymbolInterfaces) add_mlir_interface(RegionKindInterface) set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td) +mlir_tablegen(OpAsmAttrInterface.h.inc -gen-attr-interface-decls) +mlir_tablegen(OpAsmAttrInterface.cpp.inc -gen-attr-interface-defs) mlir_tablegen(OpAsmOpInterface.h.inc -gen-op-interface-decls) mlir_tablegen(OpAsmOpInterface.cpp.inc -gen-op-interface-defs) mlir_tablegen(OpAsmTypeInterface.h.inc -gen-type-interface-decls) diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td index 34c830a12856f..c3e84bccc5dee 100644 --- a/mlir/include/mlir/IR/OpAsmInterface.td +++ b/mlir/include/mlir/IR/OpAsmInterface.td @@ -130,6 +130,28 @@ def OpAsmTypeInterface : TypeInterface<"OpAsmTypeInterface"> { ]; } +//===----------------------------------------------------------------------===// +// OpAsmAttrInterface +//===----------------------------------------------------------------------===// + +def OpAsmAttrInterface : AttrInterface<"OpAsmAttrInterface"> { + let description = [{ + This interface provides hooks to interact with the AsmPrinter and AsmParser + classes. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod<[{ + Get a name to use when generating an alias for this attribute. + }], + "::mlir::OpAsmDialectInterface::AliasResult", "getAlias", + (ins "::llvm::raw_ostream&":$os), "", + "return ::mlir::OpAsmDialectInterface::AliasResult::NoAlias;" + >, + ]; +} + //===----------------------------------------------------------------------===// // ResourceHandleParameter //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 5eb8b4a5cff5b..a863e881ee7c8 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -1825,6 +1825,7 @@ ParseResult parseDimensionList(OpAsmParser &parser, //===--------------------------------------------------------------------===// /// The OpAsmOpInterface, see OpAsmInterface.td for more details. +#include "mlir/IR/OpAsmAttrInterface.h.inc" #include "mlir/IR/OpAsmOpInterface.h.inc" #include "mlir/IR/OpAsmTypeInterface.h.inc" diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 03b8fca0fa7ab..cc578eae3ee36 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -125,6 +125,7 @@ void OpAsmPrinter::printFunctionalType(Operation *op) { //===----------------------------------------------------------------------===// /// The OpAsmOpInterface, see OpAsmInterface.td for more details. +#include "mlir/IR/OpAsmAttrInterface.cpp.inc" #include "mlir/IR/OpAsmOpInterface.cpp.inc" #include "mlir/IR/OpAsmTypeInterface.cpp.inc" @@ -1159,15 +1160,31 @@ template void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias, bool canBeDeferred) { SmallString<32> nameBuffer; - for (const auto &interface : interfaces) { - OpAsmDialectInterface::AliasResult result = - interface.getAlias(symbol, aliasOS); - if (result == OpAsmDialectInterface::AliasResult::NoAlias) - continue; - nameBuffer = std::move(aliasBuffer); - assert(!nameBuffer.empty() && "expected valid alias name"); - if (result == OpAsmDialectInterface::AliasResult::FinalAlias) - break; + + OpAsmDialectInterface::AliasResult symbolInterfaceResult = + OpAsmDialectInterface::AliasResult::NoAlias; + if constexpr (std::is_base_of_v) { + if (auto symbolInterface = dyn_cast(symbol)) { + symbolInterfaceResult = symbolInterface.getAlias(aliasOS); + if (symbolInterfaceResult != + OpAsmDialectInterface::AliasResult::NoAlias) { + nameBuffer = std::move(aliasBuffer); + assert(!nameBuffer.empty() && "expected valid alias name"); + } + } + } + + if (symbolInterfaceResult != OpAsmDialectInterface::AliasResult::FinalAlias) { + for (const auto &interface : interfaces) { + OpAsmDialectInterface::AliasResult result = + interface.getAlias(symbol, aliasOS); + if (result == OpAsmDialectInterface::AliasResult::NoAlias) + continue; + nameBuffer = std::move(aliasBuffer); + assert(!nameBuffer.empty() && "expected valid alias name"); + if (result == OpAsmDialectInterface::AliasResult::FinalAlias) + break; + } } if (nameBuffer.empty()) diff --git a/mlir/test/IR/op-asm-interface.mlir b/mlir/test/IR/op-asm-interface.mlir index 6cce9572f4fc2..44a6e7afece03 100644 --- a/mlir/test/IR/op-asm-interface.mlir +++ b/mlir/test/IR/op-asm-interface.mlir @@ -58,3 +58,17 @@ func.func @block_argument_name_from_op_asm_type_interface_asmprinter() { } return } + +// ----- + +//===----------------------------------------------------------------------===// +// Test OpAsmAttrInterface +//===----------------------------------------------------------------------===// + +// CHECK: #op_asm_attr_interface_test +#attr = #test.op_asm_attr_interface + +func.func @test_op_asm_attr_interface() { + %1 = "test.result_name_from_type"() {attr = #attr} : () -> !test.op_asm_type_interface + return +} diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td index 0fd272f85d39b..4b809c1c0a765 100644 --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -395,4 +395,14 @@ def TestCustomLocationAttr : Test_LocAttr<"TestCustomLocation"> { let assemblyFormat = "`<` $file `*` $line `>`"; } +// Test OpAsmAttrInterface. +def TestOpAsmAttrInterfaceAttr : Test_Attr<"TestOpAsmAttrInterface", + [DeclareAttrInterfaceMethods]> { + let mnemonic = "op_asm_attr_interface"; + let parameters = (ins "mlir::StringAttr":$value); + let assemblyFormat = [{ + `<` struct(params) `>` + }]; +} + #endif // TEST_ATTRDEFS diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp index e09ea10906164..7c467308386f1 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -67,7 +67,7 @@ void CompoundAAttr::print(AsmPrinter &printer) const { //===----------------------------------------------------------------------===// Attribute TestDecimalShapeAttr::parse(AsmParser &parser, Type type) { - if (parser.parseLess()){ + if (parser.parseLess()) { return Attribute(); } SmallVector shape; @@ -316,6 +316,17 @@ static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr, return success(); } +//===----------------------------------------------------------------------===// +// TestOpAsmAttrInterfaceAttr +//===----------------------------------------------------------------------===// + +::mlir::OpAsmDialectInterface::AliasResult +TestOpAsmAttrInterfaceAttr::getAlias(::llvm::raw_ostream &os) const { + os << "op_asm_attr_interface_"; + os << getValue().getValue(); + return ::mlir::OpAsmDialectInterface::AliasResult::FinalAlias; +} + //===----------------------------------------------------------------------===// // Tablegen Generated Definitions //===----------------------------------------------------------------------===//