From 135bbdc6d76ef4b178df19d63fd7dd671a0ef7f4 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Fri, 24 May 2024 21:33:35 -0700 Subject: [PATCH] [mlir][polynomial] ensure primitive root calculation doesn't overflow --- mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 14 +++++++++----- mlir/test/Dialect/Polynomial/ops.mlir | 10 ++++++++++ 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp index 3117721a94152..40fa97dd2597e 100644 --- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp @@ -114,16 +114,20 @@ LogicalResult MulScalarOp::verify() { /// Test if a value is a primitive nth root of unity modulo cmod. bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n, const APInt &cmod) { - // Root bitwidth may be 1 less then cmod. - APInt r = APInt(root).zext(cmod.getBitWidth()); - assert(r.ule(cmod) && "root must be less than cmod"); - unsigned upperBound = n.getZExtValue(); + // The first or subsequent multiplications, may overflow the input bit width, + // so scale them up to ensure they do not overflow. + unsigned requiredBitWidth = + std::max(root.getActiveBits() * 2, cmod.getActiveBits() * 2); + APInt r = APInt(root).zextOrTrunc(requiredBitWidth); + APInt cmodExt = APInt(cmod).zextOrTrunc(requiredBitWidth); + assert(r.ule(cmodExt) && "root must be less than cmod"); + uint64_t upperBound = n.getZExtValue(); APInt a = r; for (size_t k = 1; k < upperBound; k++) { if (a.isOne()) return false; - a = (a * r).urem(cmod); + a = (a * r).urem(cmodExt); } return a.isOne(); } diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir index 8c134ab789d60..faeb68a8b2c09 100644 --- a/mlir/test/Dialect/Polynomial/ops.mlir +++ b/mlir/test/Dialect/Polynomial/ops.mlir @@ -18,6 +18,11 @@ #ntt_ring = #polynomial.ring !ntt_poly_ty = !polynomial.polynomial +#ntt_poly_2 = #polynomial.int_polynomial<1 + x**65536> +#ntt_ring_2 = #polynomial.ring +#ntt_ring_2_root = #polynomial.primitive_root +!ntt_poly_ty_2 = !polynomial.polynomial + module { func.func @test_multiply() -> !polynomial.polynomial { %c0 = arith.constant 0 : index @@ -95,6 +100,11 @@ module { return } + func.func @test_ntt_with_overflowing_root(%0 : !ntt_poly_ty_2) { + %1 = polynomial.ntt %0 {root=#ntt_ring_2_root} : !ntt_poly_ty_2 -> tensor<65536xi32, #ntt_ring_2> + return + } + func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) { %1 = polynomial.intt %0 {root=#polynomial.primitive_root} : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty return