diff --git a/mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h b/mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h new file mode 100644 index 0000000000000..4b7292c054ec2 --- /dev/null +++ b/mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h @@ -0,0 +1,37 @@ +//===- IRDLSymbols.h - IRDL-related symbol logic ----------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Manages lookup logic for IRDL dialect-absolute symbols. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_IRDL_IRDLSYMBOLS_H +#define MLIR_DIALECT_IRDL_IRDLSYMBOLS_H + +#include "mlir/IR/Operation.h" +#include "mlir/IR/SymbolTable.h" + +namespace mlir { +namespace irdl { + +/// Looks up a symbol from the symbol table containing the source operation's +/// dialect definition operation. The source operation must be nested within an +/// IRDL dialect definition operation. This exploits SymbolTableCollection for +/// better symbol table lookup. +Operation *lookupSymbolNearDialect(SymbolTableCollection &symbolTable, + Operation *source, SymbolRefAttr symbol); + +/// Looks up a symbol from the symbol table containing the source operation's +/// dialect definition operation. The source operation must be nested within an +/// IRDL dialect definition operation. +Operation *lookupSymbolNearDialect(Operation *source, SymbolRefAttr symbol); + +} // namespace irdl +} // namespace mlir + +#endif // MLIR_DIALECT_IRDL_IRDLSYMBOLS_H diff --git a/mlir/lib/Dialect/IRDL/CMakeLists.txt b/mlir/lib/Dialect/IRDL/CMakeLists.txt index d25760e5d29bc..db4b98ef5308e 100644 --- a/mlir/lib/Dialect/IRDL/CMakeLists.txt +++ b/mlir/lib/Dialect/IRDL/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRIRDL IR/IRDL.cpp IR/IRDLOps.cpp IRDLLoading.cpp + IRDLSymbols.cpp IRDLVerifiers.cpp DEPENDS diff --git a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp index e4728f55b49d7..1f5584fa30c27 100644 --- a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp +++ b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/IRDL/IR/IRDL.h" +#include "mlir/Dialect/IRDL/IRDLSymbols.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Diagnostics.h" @@ -132,10 +133,14 @@ LogicalResult BaseOp::verify() { return success(); } +/// Finds whether the provided symbol is an IRDL type or attribute definition. +/// The source operation must be within a DialectOp. static LogicalResult checkSymbolIsTypeOrAttribute(SymbolTableCollection &symbolTable, Operation *source, SymbolRefAttr symbol) { - Operation *targetOp = symbolTable.lookupNearestSymbolFrom(source, symbol); + Operation *targetOp = + irdl::lookupSymbolNearDialect(symbolTable, source, symbol); + if (!targetOp) return source->emitOpError() << "symbol '" << symbol << "' not found"; diff --git a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp index 0895306b8bce1..7ec3aa2741023 100644 --- a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp +++ b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/IRDL/IR/IRDL.h" +#include "mlir/Dialect/IRDL/IRDLSymbols.h" #include "mlir/IR/ValueRange.h" #include @@ -47,8 +48,9 @@ std::unique_ptr BaseOp::getVerifier( // Case where the input is a symbol reference. // This corresponds to the case where the base is an IRDL type or attribute. if (auto baseRef = getBaseRef()) { + // The verifier for BaseOp guarantees it is within a dialect. Operation *defOp = - SymbolTable::lookupNearestSymbolFrom(getOperation(), baseRef.value()); + irdl::lookupSymbolNearDialect(getOperation(), baseRef.value()); // Type case. if (auto typeOp = dyn_cast(defOp)) { @@ -99,10 +101,10 @@ std::unique_ptr ParametricOp::getVerifier( SmallVector constraints = getConstraintIndicesForArgs(getArgs(), valueToConstr); - // Symbol reference case for the base + // Symbol reference case for the base. + // The verifier for ParametricOp guarantees it is within a dialect. SymbolRefAttr symRef = getBaseType(); - Operation *defOp = - SymbolTable::lookupNearestSymbolFrom(getOperation(), symRef); + Operation *defOp = irdl::lookupSymbolNearDialect(getOperation(), symRef); if (!defOp) { emitError() << symRef << " does not refer to any existing symbol"; return nullptr; diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp index 5df2b45d8037b..5f623e8845d10 100644 --- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp +++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/IRDL/IRDLLoading.h" #include "mlir/Dialect/IRDL/IR/IRDL.h" #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h" +#include "mlir/Dialect/IRDL/IRDLSymbols.h" #include "mlir/Dialect/IRDL/IRDLVerifiers.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" @@ -523,7 +524,7 @@ static bool getBases(Operation *op, SmallPtrSet ¶mIds, // For `irdl.parametric`, we get directly the base from the operation. if (auto params = dyn_cast(op)) { SymbolRefAttr symRef = params.getBaseType(); - Operation *defOp = SymbolTable::lookupNearestSymbolFrom(op, symRef); + Operation *defOp = irdl::lookupSymbolNearDialect(op, symRef); assert(defOp && "symbol reference should refer to an existing operation"); paramIrdlOps.insert(defOp); return false; diff --git a/mlir/lib/Dialect/IRDL/IRDLSymbols.cpp b/mlir/lib/Dialect/IRDL/IRDLSymbols.cpp new file mode 100644 index 0000000000000..ff2136df364d9 --- /dev/null +++ b/mlir/lib/Dialect/IRDL/IRDLSymbols.cpp @@ -0,0 +1,38 @@ +//===- IRDLSymbols.cpp - IRDL-related symbol logic --------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/IRDL/IRDLSymbols.h" +#include "mlir/Dialect/IRDL/IR/IRDL.h" + +using namespace mlir; +using namespace mlir::irdl; + +static Operation *lookupDialectOp(Operation *source) { + Operation *dialectOp = source; + while (dialectOp && !isa(dialectOp)) + dialectOp = dialectOp->getParentOp(); + + if (!dialectOp) + llvm_unreachable("symbol lookup near dialect must originate from " + "within a dialect definition"); + + return dialectOp; +} + +Operation * +mlir::irdl::lookupSymbolNearDialect(SymbolTableCollection &symbolTable, + Operation *source, SymbolRefAttr symbol) { + return symbolTable.lookupNearestSymbolFrom( + lookupDialectOp(source)->getParentOp(), symbol); +} + +Operation *mlir::irdl::lookupSymbolNearDialect(Operation *source, + SymbolRefAttr symbol) { + return SymbolTable::lookupNearestSymbolFrom( + lookupDialectOp(source)->getParentOp(), symbol); +} diff --git a/mlir/test/Dialect/IRDL/cmath.irdl.mlir b/mlir/test/Dialect/IRDL/cmath.irdl.mlir index 997af08d24733..0b7e220ceb90c 100644 --- a/mlir/test/Dialect/IRDL/cmath.irdl.mlir +++ b/mlir/test/Dialect/IRDL/cmath.irdl.mlir @@ -19,13 +19,13 @@ module { // CHECK: irdl.operation @norm { // CHECK: %[[v0:[^ ]*]] = irdl.any - // CHECK: %[[v1:[^ ]*]] = irdl.parametric @complex<%[[v0]]> + // CHECK: %[[v1:[^ ]*]] = irdl.parametric @cmath::@complex<%[[v0]]> // CHECK: irdl.operands(%[[v1]]) // CHECK: irdl.results(%[[v0]]) // CHECK: } irdl.operation @norm { %0 = irdl.any - %1 = irdl.parametric @complex<%0> + %1 = irdl.parametric @cmath::@complex<%0> irdl.operands(%1) irdl.results(%0) } @@ -34,7 +34,7 @@ module { // CHECK: %[[v0:[^ ]*]] = irdl.is f32 // CHECK: %[[v1:[^ ]*]] = irdl.is f64 // CHECK: %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]]) - // CHECK: %[[v3:[^ ]*]] = irdl.parametric @complex<%[[v2]]> + // CHECK: %[[v3:[^ ]*]] = irdl.parametric @cmath::@complex<%[[v2]]> // CHECK: irdl.operands(%[[v3]], %[[v3]]) // CHECK: irdl.results(%[[v3]]) // CHECK: } @@ -42,7 +42,7 @@ module { %0 = irdl.is f32 %1 = irdl.is f64 %2 = irdl.any_of(%0, %1) - %3 = irdl.parametric @complex<%2> + %3 = irdl.parametric @cmath::@complex<%2> irdl.operands(%3, %3) irdl.results(%3) } diff --git a/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir b/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir index db8dfc5cb36ca..cbcc248bf00b1 100644 --- a/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir +++ b/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir @@ -6,14 +6,14 @@ irdl.dialect @testd { // CHECK: irdl.type @self_referencing { // CHECK: %[[v0:[^ ]*]] = irdl.any - // CHECK: %[[v1:[^ ]*]] = irdl.parametric @self_referencing<%[[v0]]> + // CHECK: %[[v1:[^ ]*]] = irdl.parametric @testd::@self_referencing<%[[v0]]> // CHECK: %[[v2:[^ ]*]] = irdl.is i32 // CHECK: %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]]) // CHECK: irdl.parameters(%[[v3]]) // CHECK: } irdl.type @self_referencing { %0 = irdl.any - %1 = irdl.parametric @self_referencing<%0> + %1 = irdl.parametric @testd::@self_referencing<%0> %2 = irdl.is i32 %3 = irdl.any_of(%1, %2) irdl.parameters(%3) @@ -22,13 +22,13 @@ irdl.dialect @testd { // CHECK: irdl.type @type1 { // CHECK: %[[v0:[^ ]*]] = irdl.any - // CHECK: %[[v1:[^ ]*]] = irdl.parametric @type2<%[[v0]]> + // CHECK: %[[v1:[^ ]*]] = irdl.parametric @testd::@type2<%[[v0]]> // CHECK: %[[v2:[^ ]*]] = irdl.is i32 // CHECK: %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]]) // CHECK: irdl.parameters(%[[v3]]) irdl.type @type1 { %0 = irdl.any - %1 = irdl.parametric @type2<%0> + %1 = irdl.parametric @testd::@type2<%0> %2 = irdl.is i32 %3 = irdl.any_of(%1, %2) irdl.parameters(%3) @@ -36,13 +36,13 @@ irdl.dialect @testd { // CHECK: irdl.type @type2 { // CHECK: %[[v0:[^ ]*]] = irdl.any - // CHECK: %[[v1:[^ ]*]] = irdl.parametric @type1<%[[v0]]> + // CHECK: %[[v1:[^ ]*]] = irdl.parametric @testd::@type1<%[[v0]]> // CHECK: %[[v2:[^ ]*]] = irdl.is i32 // CHECK: %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]]) // CHECK: irdl.parameters(%[[v3]]) irdl.type @type2 { %0 = irdl.any - %1 = irdl.parametric @type1<%0> + %1 = irdl.parametric @testd::@type1<%0> %2 = irdl.is i32 %3 = irdl.any_of(%1, %2) irdl.parameters(%3) diff --git a/mlir/test/Dialect/IRDL/invalid.irdl.mlir b/mlir/test/Dialect/IRDL/invalid.irdl.mlir index f207d31cf158b..93ad619358750 100644 --- a/mlir/test/Dialect/IRDL/invalid.irdl.mlir +++ b/mlir/test/Dialect/IRDL/invalid.irdl.mlir @@ -2,8 +2,6 @@ // Testing invalid IRDL IRs -func.func private @foo() - irdl.dialect @testd { irdl.type @type { // expected-error@+1 {{symbol '@foo' not found}} @@ -44,15 +42,12 @@ irdl.dialect @testd { // ----- +func.func private @not_a_type_or_attr() + irdl.dialect @invalid_parametric { irdl.operation @foo { // expected-error@+1 {{symbol '@not_a_type_or_attr' does not refer to a type or attribute definition}} %param = irdl.parametric @not_a_type_or_attr<> irdl.results(%param) } - - irdl.operation @not_a_type_or_attr { - %param = irdl.is i1 - irdl.results(%param) - } } diff --git a/mlir/test/Dialect/IRDL/testd.irdl.mlir b/mlir/test/Dialect/IRDL/testd.irdl.mlir index f828d95bdb81d..aeb1a83747ecc 100644 --- a/mlir/test/Dialect/IRDL/testd.irdl.mlir +++ b/mlir/test/Dialect/IRDL/testd.irdl.mlir @@ -76,20 +76,20 @@ irdl.dialect @testd { } // CHECK: irdl.operation @dyn_type_base { - // CHECK: %[[v1:[^ ]*]] = irdl.base @parametric + // CHECK: %[[v1:[^ ]*]] = irdl.base @testd::@parametric // CHECK: irdl.results(%[[v1]]) // CHECK: } irdl.operation @dyn_type_base { - %0 = irdl.base @parametric + %0 = irdl.base @testd::@parametric irdl.results(%0) } // CHECK: irdl.operation @dyn_attr_base { - // CHECK: %[[v1:[^ ]*]] = irdl.base @parametric_attr + // CHECK: %[[v1:[^ ]*]] = irdl.base @testd::@parametric_attr // CHECK: irdl.attributes {"attr1" = %[[v1]]} // CHECK: } irdl.operation @dyn_attr_base { - %0 = irdl.base @parametric_attr + %0 = irdl.base @testd::@parametric_attr irdl.attributes {"attr1" = %0} } @@ -115,14 +115,14 @@ irdl.dialect @testd { // CHECK: %[[v0:[^ ]*]] = irdl.is i32 // CHECK: %[[v1:[^ ]*]] = irdl.is i64 // CHECK: %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]]) - // CHECK: %[[v3:[^ ]*]] = irdl.parametric @parametric<%[[v2]]> + // CHECK: %[[v3:[^ ]*]] = irdl.parametric @testd::@parametric<%[[v2]]> // CHECK: irdl.results(%[[v3]]) // CHECK: } irdl.operation @dynparams { %0 = irdl.is i32 %1 = irdl.is i64 %2 = irdl.any_of(%0, %1) - %3 = irdl.parametric @parametric<%2> + %3 = irdl.parametric @testd::@parametric<%2> irdl.results(%3) }