diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td index 755396c8b9023..63f9ff1def4e1 100644 --- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td +++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td @@ -311,12 +311,12 @@ def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> { `f[k] = F(omega[n]^k) ; k = {0, ..., n-1}` - The choice of primitive root may be optionally specified. + The choice of primitive root is specified in the primitiveRootAttr of RingAttr. + Its degree affects the behavior of ntt performed, with n-th primitive root + performing cyclic convolution and 2n-th primitive root performing negacyclic + convolution. }]; - let arguments = (ins - Polynomial_PolynomialType:$input, - OptionalAttr:$root - ); + let arguments = (ins Polynomial_PolynomialType:$input); let results = (outs RankedTensorOf<[AnyInteger]>:$output); let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)"; let hasCanonicalizer = 1; @@ -335,12 +335,12 @@ def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> { `polynomial.ntt`). The ring of the polynomial is taken from the required encoding attribute of the tensor. - The choice of primitive root may be optionally specified. + The choice of primitive root is specified in the primitiveRootAttr of RingAttr. + Its degree affects the behavior of ntt performed, with n-th primitive root + performing cyclic convolution and 2n-th primitive root performing negacyclic + convolution. }]; - let arguments = ( - ins RankedTensorOf<[AnyInteger]>:$input, - OptionalAttr:$root - ); + let arguments = (ins RankedTensorOf<[AnyInteger]>:$input); let results = (outs Polynomial_PolynomialType:$output); let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)"; let hasCanonicalizer = 1; diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td index 7d59add3d37c2..00c9239fc6369 100644 --- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td +++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td @@ -126,6 +126,26 @@ def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr< }]; } +def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> { + let summary = "an attribute containing an integer and its degree as a root of unity"; + let description = [{ + A primitive root attribute stores an integer root `value` and an integer + `degree`, corresponding to a primitive root of unity of the given degree in + an unspecified ring. + + Example: + + ```mlir + #poly = #polynomial.primitive_root + ``` + }]; + let parameters = (ins + "::mlir::IntegerAttr":$value, + "::mlir::IntegerAttr":$degree + ); + let assemblyFormat = "`<` struct(params) `>`"; +} + def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> { let summary = "an attribute specifying a polynomial ring"; let description = [{ @@ -142,6 +162,9 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> { modulus. For single-variable polynomials, an "polynomialModulus" is always specificed via a single polynomial, which we call `polynomialModulus`. + For ntt/intt and mul to ntt/intt optimization to work, an n-th or 2n-th + _primitiveRoot_ should be specified. + An expressive example is polynomials with i32 coefficients, whose coefficients are taken modulo `2**32 - 5`, with a polynomial modulus of `x**1024 - 1`. @@ -177,7 +200,8 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> { let parameters = (ins "Type": $coefficientType, OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus, - OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus + OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus, + OptionalParameter<"::mlir::polynomial::PrimitiveRootAttr">: $primitiveRoot ); let genVerifyDecl = 1; let assemblyFormat = "`<` struct(params) `>`"; @@ -185,38 +209,16 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> { AttrBuilderWithInferredContext< (ins "::mlir::Type":$coefficientTy, CArg<"::mlir::IntegerAttr", "nullptr"> :$coefficientModulusAttr, - CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr), [{ + CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr, + CArg<"::mlir::polynomial::PrimitiveRootAttr", "nullptr"> :$primitiveRootAttr), [{ return $_get( coefficientTy.getContext(), coefficientTy, coefficientModulusAttr, - polynomialModulusAttr); + polynomialModulusAttr, + primitiveRootAttr); }]>, ]; } -def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> { - let summary = "an attribute containing an integer and its degree as a root of unity"; - let description = [{ - A primitive root attribute stores an integer root `value` and an integer - `degree`, corresponding to a primitive root of unity of the given degree in - an unspecified ring. - - This is used as an attribute on `polynomial.ntt` and `polynomial.intt` ops - to specify the root of unity used in lowering the transform. - - Example: - - ```mlir - #poly = #polynomial.primitive_root - ``` - }]; - let parameters = (ins - "::mlir::IntegerAttr":$value, - "::mlir::IntegerAttr":$degree - ); - let assemblyFormat = "`<` struct(params) `>`"; -} - - #endif // POLYNOMIAL_ATTRIBUTES diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp index cd7789a2e9531..f3f6afdee9950 100644 --- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp @@ -206,7 +206,8 @@ Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) { LogicalResult RingAttr::verify(function_ref emitError, Type coefficientType, IntegerAttr coefficientModulus, - IntPolynomialAttr polynomialModulus) { + IntPolynomialAttr polynomialModulus, + PrimitiveRootAttr primitiveRoot) { if (coefficientModulus) { auto coeffIntType = llvm::dyn_cast(coefficientType); if (!coeffIntType) { diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td index 28c45e6846380..a26b34e29d561 100644 --- a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td @@ -14,8 +14,6 @@ include "mlir/Dialect/Polynomial/IR/Polynomial.td" include "mlir/IR/OpBase.td" include "mlir/IR/PatternBase.td" -def Equal : Constraint>; - // Get a -1 integer attribute of the same type as the polynomial SSA value's // ring coefficient type. def getMinusOne @@ -30,15 +28,13 @@ def SubAsAdd : Pat< (Arith_ConstantOp (getMinusOne $g))))>; def INTTAfterNTT : Pat< - (Polynomial_INTTOp (Polynomial_NTTOp $poly, $r1), $r2), - (replaceWithValue $poly), - [(Equal $r1, $r2)] + (Polynomial_INTTOp (Polynomial_NTTOp $poly)), + (replaceWithValue $poly) >; def NTTAfterINTT : Pat< - (Polynomial_NTTOp (Polynomial_INTTOp $tensor, $r1), $r2), - (replaceWithValue $tensor), - [(Equal $r1, $r2)] + (Polynomial_NTTOp (Polynomial_INTTOp $tensor)), + (replaceWithValue $tensor) >; #endif // POLYNOMIAL_CANONICALIZATION diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp index 460ef17167e80..30a6a004c50af 100644 --- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp @@ -134,8 +134,7 @@ bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n, /// Verify that the types involved in an NTT or INTT operation are /// compatible. static LogicalResult verifyNTTOp(Operation *op, RingAttr ring, - RankedTensorType tensorType, - std::optional root) { + RankedTensorType tensorType) { Attribute encoding = tensorType.getEncoding(); if (!encoding) { return op->emitOpError() @@ -166,9 +165,10 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring, return diag; } - if (root.has_value()) { - APInt rootValue = root.value().getValue().getValue(); - APInt rootDegree = root.value().getDegree().getValue(); + auto root = ring.getPrimitiveRoot(); + if (root) { + APInt rootValue = root.getValue().getValue(); + APInt rootDegree = root.getDegree().getValue(); APInt cmod = ring.getCoefficientModulus().getValue(); if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) { return op->emitOpError() @@ -177,6 +177,9 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring, << "of unity mod " << cmod.getZExtValue() << ", with the specified degree " << rootDegree.getZExtValue(); } + } else { + return op->emitOpError() + << "primitive root not provided but ntt/intt op called"; } return success(); @@ -184,12 +187,12 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring, LogicalResult NTTOp::verify() { return verifyNTTOp(this->getOperation(), getInput().getType().getRing(), - getOutput().getType(), getRoot()); + getOutput().getType()); } LogicalResult INTTOp::verify() { return verifyNTTOp(this->getOperation(), getOutput().getType().getRing(), - getInput().getType(), getRoot()); + getInput().getType()); } ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) { diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir index c0ee514daab64..5a517a5e1ed9b 100644 --- a/mlir/test/Dialect/Polynomial/canonicalization.mlir +++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt -canonicalize %s | FileCheck %s #ntt_poly = #polynomial.int_polynomial<-1 + x**8> -#ntt_ring = #polynomial.ring #root = #polynomial.primitive_root +#ntt_ring = #polynomial.ring !ntt_poly_ty = !polynomial.polynomial !tensor_ty = tensor<8xi32, #ntt_ring> @@ -11,8 +11,8 @@ func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty // CHECK-NOT: polynomial.ntt // CHECK-NOT: polynomial.intt // CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]] : [[T]] - %t0 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty - %p1 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty + %t0 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty + %p1 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty %p2 = polynomial.add %p1, %p1 : !ntt_poly_ty // CHECK: return %[[RESULT]] : [[T]] return %p2 : !ntt_poly_ty @@ -24,8 +24,8 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty { // CHECK-NOT: polynomial.intt // CHECK-NOT: polynomial.ntt // CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]] - %p0 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty - %t1 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty + %p0 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty + %t1 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty %t2 = arith.addi %t1, %t1 : !tensor_ty // CHECK: return %[[RESULT]] : [[T]] return %t2 : !tensor_ty diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir index faeb68a8b2c09..4998730c80c7e 100644 --- a/mlir/test/Dialect/Polynomial/ops.mlir +++ b/mlir/test/Dialect/Polynomial/ops.mlir @@ -15,12 +15,13 @@ !poly_ty = !polynomial.polynomial #ntt_poly = #polynomial.int_polynomial<-1 + x**8> -#ntt_ring = #polynomial.ring +#ntt_ring_root = #polynomial.primitive_root +#ntt_ring = #polynomial.ring !ntt_poly_ty = !polynomial.polynomial #ntt_poly_2 = #polynomial.int_polynomial<1 + x**65536> -#ntt_ring_2 = #polynomial.ring #ntt_ring_2_root = #polynomial.primitive_root +#ntt_ring_2 = #polynomial.ring !ntt_poly_ty_2 = !polynomial.polynomial module { @@ -96,17 +97,17 @@ module { } func.func @test_ntt(%0 : !ntt_poly_ty) { - %1 = polynomial.ntt %0 {root=#polynomial.primitive_root} : !ntt_poly_ty -> tensor<8xi32, #ntt_ring> + %1 = polynomial.ntt %0 : !ntt_poly_ty -> tensor<8xi32, #ntt_ring> return } func.func @test_ntt_with_overflowing_root(%0 : !ntt_poly_ty_2) { - %1 = polynomial.ntt %0 {root=#ntt_ring_2_root} : !ntt_poly_ty_2 -> tensor<65536xi32, #ntt_ring_2> + %1 = polynomial.ntt %0 : !ntt_poly_ty_2 -> tensor<65536xi32, #ntt_ring_2> return } func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) { - %1 = polynomial.intt %0 {root=#polynomial.primitive_root} : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty + %1 = polynomial.intt %0 : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty return } } diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir index 4937e17027afa..003967e3f4228 100644 --- a/mlir/test/Dialect/Polynomial/ops_errors.mlir +++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir @@ -55,36 +55,39 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty { // ----- #my_poly = #polynomial.int_polynomial<-1 + x**1024> -#ring = #polynomial.ring +#root = #polynomial.primitive_root +#ring = #polynomial.ring !poly_ty = !polynomial.polynomial // CHECK-NOT: @test_invalid_ntt // CHECK-NOT: polynomial.ntt func.func @test_invalid_ntt(%0 : !poly_ty) { // expected-error@below {{expects a ring encoding to be provided to the tensor}} - %1 = polynomial.ntt %0 {root=#polynomial.primitive_root} : !poly_ty -> tensor<1024xi32> + %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32> return } // ----- #my_poly = #polynomial.int_polynomial<-1 + x**1024> -#ring = #polynomial.ring +#root = #polynomial.primitive_root +#ring = #polynomial.ring !poly_ty = !polynomial.polynomial // CHECK-NOT: @test_invalid_ntt // CHECK-NOT: polynomial.ntt func.func @test_invalid_ntt(%0 : !poly_ty) { // expected-error@below {{tensor encoding is not a ring attribute}} - %1 = polynomial.ntt %0 {root=#polynomial.primitive_root} : !poly_ty -> tensor<1024xi32, #my_poly> + %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #my_poly> return } // ----- #my_poly = #polynomial.int_polynomial<-1 + x**1024> +#root = #polynomial.primitive_root #ring = #polynomial.ring -#ring1 = #polynomial.ring +#ring1 = #polynomial.ring !poly_ty = !polynomial.polynomial // CHECK-NOT: @test_invalid_intt @@ -98,7 +101,8 @@ func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) { // ----- #my_poly = #polynomial.int_polynomial<-1 + x**1024> -#ring = #polynomial.ring +#root = #polynomial.primitive_root +#ring = #polynomial.ring !poly_ty = !polynomial.polynomial // CHECK-NOT: @test_invalid_intt @@ -106,7 +110,7 @@ func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) { func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) { // expected-error@below {{does not match output type}} // expected-note@below {{exactly the degree of the polynomialModulus of the polynomial type's ring attribute}} - %1 = polynomial.intt %0 {root=#polynomial.primitive_root} : tensor<1025xi32, #ring> -> !poly_ty + %1 = polynomial.intt %0 : tensor<1025xi32, #ring> -> !poly_ty return } @@ -114,13 +118,28 @@ func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) { #my_poly = #polynomial.int_polynomial<-1 + x**8> // A valid root is 31 -#ring = #polynomial.ring +#root = #polynomial.primitive_root +#ring = #polynomial.ring !poly_ty = !polynomial.polynomial // CHECK-NOT: @test_invalid_intt // CHECK-NOT: polynomial.intt func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) { // expected-error@below {{provided root 32 is not a primitive root of unity mod 256, with the specified degree 8}} - %1 = polynomial.intt %0 {root=#polynomial.primitive_root} : tensor<8xi32, #ring> -> !poly_ty + %1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty + return +} + +// ----- + +#my_poly = #polynomial.int_polynomial<-1 + x**8> +#ring = #polynomial.ring +!poly_ty = !polynomial.polynomial + +// CHECK-NOT: @test_invalid_intt +// CHECK-NOT: polynomial.intt +func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) { + // expected-error@below {{primitive root not provided but ntt/intt op called}} + %1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty return }