From 3a83e2d5cfd5aae8c35fde4050886a96b61edd3f Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 1 May 2025 22:10:03 +0000 Subject: [PATCH 01/10] Add linearization pattern for vector.create_mask --- .../Vector/Transforms/VectorLinearize.cpp | 65 ++++++++++++++++++- mlir/test/Dialect/Vector/linearize.mlir | 33 ++++++++++ .../Dialect/Vector/TestVectorTransforms.cpp | 3 +- 3 files changed, 97 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index b9cef003fa365..cdd937eed6569 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -445,6 +445,64 @@ struct LinearizeVectorSplat final } }; +/// This pattern converts the CreateMaskOp to work on a +/// linearized vector. The pattern currently +/// supports only 2D masks with a unit outer dimension. +/// Following, +/// vector.create_mask %dims : vector<1x4xi1> +/// is converted to: +/// %out_1d = vector.create_mask %dims : vector<4xi1> +/// %out_nd = vector.shape_cast %out_1d : vector<4xi1> to vector<1x4xi1> +struct LinearizeVectorCreateMask final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LinearizeVectorCreateMask(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcTy = createMaskOp.getType(); + auto srcShape = srcTy.getShape(); + if (srcShape.size() != 2) + return rewriter.notifyMatchFailure(createMaskOp, + "only 2D mask is supported."); + + if (srcShape[0] != 1) + return rewriter.notifyMatchFailure( + createMaskOp, "only unit outer dimension is supported."); + + auto dstTy = getTypeConverter()->convertType(srcTy); + if (!dstTy) + return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type."); + + // Compare the first operand with 0. If it's less than or equal to 0, + // create a zero mask, else strip the first operand and create a mask + // using the second operand. + auto firstOperand = adaptor.getOperands().front(); + auto zero = + rewriter.create(createMaskOp.getLoc(), 0); + auto isZeroOrNegative = rewriter.create( + createMaskOp.getLoc(), mlir::arith::CmpIPredicate::sle, firstOperand, + zero); + auto isZeroOrNegativeSplat = rewriter.create( + createMaskOp.getLoc(), dstTy, isZeroOrNegative); + + // Use a select operation to choose between the masks. + auto zeroMask = rewriter.create( + createMaskOp.getLoc(), dstTy, rewriter.getZeroAttr(dstTy)); + auto newMask = rewriter.create( + createMaskOp.getLoc(), dstTy, adaptor.getOperands().back()); + auto result = rewriter.create( + createMaskOp.getLoc(), isZeroOrNegativeSplat, zeroMask, newMask); + + rewriter.replaceOp(createMaskOp, result.getResult()); + return success(); + } +}; + } // namespace /// Return true if the operation `op` does not support scalable vectors and @@ -530,9 +588,10 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter, void mlir::vector::populateVectorLinearizeBasePatterns( const TypeConverter &typeConverter, const ConversionTarget &target, RewritePatternSet &patterns) { - patterns.add( - typeConverter, patterns.getContext()); + patterns + .add( + typeConverter, patterns.getContext()); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 20169c15eb2c1..2b802eed64595 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -447,3 +447,36 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> { %0 = vector.splat %arg0 : vector<4x[2]xi32> return %0 : vector<4x[2]xi32> } + +// ----- +// ALL-LABEL: test_create_mask +func.func @test_create_mask() -> vector<1x16xi1> { + // DEFAULT: %[[C0:.*]] = arith.constant 0 : index + // BW-128: %[[C0:.*]] = arith.constant 0 : index + // DEFAULT: %[[C20:.*]] = arith.constant 20 : index + // BW-128: %[[C20:.*]] = arith.constant 20 : index + // DEFAULT: %[[C0_0:.*]] = arith.constant 0 : index + // BW-128: %[[C0_0:.*]] = arith.constant 0 : index + // DEFAULT: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index + // BW-128: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index + // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1> + // BW-128: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1> + // DEFAULT: %[[CST:.*]] = arith.constant dense : vector<16xi1> + // BW-128: %[[CST:.*]] = arith.constant dense : vector<16xi1> + // DEFAULT: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1> + // BW-128: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1> + // DEFAULT: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1> + // BW-128: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1> + // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1> + // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1> + // DEFAULT: return %[[CAST]] : vector<1x16xi1> + // BW-128: return %[[CAST]] : vector<1x16xi1> + + // BW-0: %[[C0:.*]] = arith.constant 0 : index + // BW-0: %[[C20:.*]] = arith.constant 20 : index + // BW-0: %[[MASK:.*]] = vector.create_mask %[[C0]], %[[C20]] : vector<1x16xi1> + %c0 = arith.constant 0 : index + %c20 = arith.constant 20 : index + %0 = vector.create_mask %c0, %c20 : vector<1x16xi1> + return %0 : vector<1x16xi1> +} diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index eda2594fbc7c7..2d5e90908d4d0 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -973,7 +973,8 @@ struct TestVectorLinearize final return "Linearizes ND vectors for N >= 2 into 1D vectors"; } void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } void runOnOperation() override { From 24b0739da64b109564abbe85bb68706c0ad6d101 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 1 May 2025 22:56:42 +0000 Subject: [PATCH 02/10] Use CHECKS --- mlir/test/Dialect/Vector/linearize.mlir | 38 ++++++++----------------- 1 file changed, 12 insertions(+), 26 deletions(-) diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index d5d4cfa4f9aa1..cc5ec1a5c036c 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -347,32 +347,18 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> { } // ----- -// ALL-LABEL: test_create_mask -func.func @test_create_mask() -> vector<1x16xi1> { - // DEFAULT: %[[C0:.*]] = arith.constant 0 : index - // BW-128: %[[C0:.*]] = arith.constant 0 : index - // DEFAULT: %[[C20:.*]] = arith.constant 20 : index - // BW-128: %[[C20:.*]] = arith.constant 20 : index - // DEFAULT: %[[C0_0:.*]] = arith.constant 0 : index - // BW-128: %[[C0_0:.*]] = arith.constant 0 : index - // DEFAULT: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index - // BW-128: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index - // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1> - // BW-128: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1> - // DEFAULT: %[[CST:.*]] = arith.constant dense : vector<16xi1> - // BW-128: %[[CST:.*]] = arith.constant dense : vector<16xi1> - // DEFAULT: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1> - // BW-128: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1> - // DEFAULT: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1> - // BW-128: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1> - // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1> - // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1> - // DEFAULT: return %[[CAST]] : vector<1x16xi1> - // BW-128: return %[[CAST]] : vector<1x16xi1> - - // BW-0: %[[C0:.*]] = arith.constant 0 : index - // BW-0: %[[C20:.*]] = arith.constant 20 : index - // BW-0: %[[MASK:.*]] = vector.create_mask %[[C0]], %[[C20]] : vector<1x16xi1> +// ALL-LABEL: linearize_create_mask +func.func @linearize_create_mask() -> vector<1x16xi1> { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[C20:.*]] = arith.constant 20 : index + // CHECK: %[[C0_0:.*]] = arith.constant 0 : index + // CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index + // CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1> + // CHECK: %[[CST:.*]] = arith.constant dense : vector<16xi1> + // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1> + // CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1> + // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1> + // CHECK: return %[[CAST]] : vector<1x16xi1> %c0 = arith.constant 0 : index %c20 = arith.constant 20 : index %0 = vector.create_mask %c0, %c20 : vector<1x16xi1> From 2b8a653c279b3be08fd6426316cd57dbf3fd54eb Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 1 May 2025 23:00:13 +0000 Subject: [PATCH 03/10] Add test case for scalable vector --- mlir/test/Dialect/Vector/linearize.mlir | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index cc5ec1a5c036c..01872426c77bb 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -364,3 +364,22 @@ func.func @linearize_create_mask() -> vector<1x16xi1> { %0 = vector.create_mask %c0, %c20 : vector<1x16xi1> return %0 : vector<1x16xi1> } + +// ----- +// ALL-LABEL: linearize_scalable_create_mask +func.func @linearize_scalable_create_mask() -> vector<1x[16]xi1> { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[C20:.*]] = arith.constant 20 : index + // CHECK: %[[C0_0:.*]] = arith.constant 0 : index + // CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index + // CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<[16]xi1> + // CHECK: %[[CST:.*]] = arith.constant dense : vector<[16]xi1> + // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<[16]xi1> + // CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<[16]xi1>, vector<[16]xi1> + // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<[16]xi1> to vector<1x[16]xi1> + // CHECK: return %[[CAST]] : vector<1x[16]xi1> + %c0 = arith.constant 0 : index + %c20 = arith.constant 20 : index + %0 = vector.create_mask %c0, %c20 : vector<1x[16]xi1> + return %0 : vector<1x[16]xi1> +} From 8e8de7af27934145383a7ee504e35f18a5787abd Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 1 May 2025 23:07:38 +0000 Subject: [PATCH 04/10] Clean up --- mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 2d5e90908d4d0..54defd949c264 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -973,8 +973,7 @@ struct TestVectorLinearize final return "Linearizes ND vectors for N >= 2 into 1D vectors"; } void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } void runOnOperation() override { From 528f91324c54730a68fde3bb1f3f94c8a258bfce Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 8 May 2025 23:39:59 +0000 Subject: [PATCH 05/10] Address Feedback --- .../Vector/Transforms/VectorLinearize.cpp | 35 ++++++++++--------- mlir/test/Dialect/Vector/linearize.mlir | 16 ++++----- 2 files changed, 27 insertions(+), 24 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index cdd937eed6569..7e03b073fb369 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -445,14 +445,18 @@ struct LinearizeVectorSplat final } }; -/// This pattern converts the CreateMaskOp to work on a -/// linearized vector. The pattern currently -/// supports only 2D masks with a unit outer dimension. +/// This pattern converts the CreateMaskOp to work on a linearized vector. +// The pattern currently supports only 2D masks with a unit outer dimension. /// Following, -/// vector.create_mask %dims : vector<1x4xi1> +/// vector.create_mask %arg0, %arg1 : vector<1x4xi1> /// is converted to: -/// %out_1d = vector.create_mask %dims : vector<4xi1> -/// %out_nd = vector.shape_cast %out_1d : vector<4xi1> to vector<1x4xi1> +/// %zero = arith.constant 0 : index +/// %cmpi = arith.cmpi sle, %arg0, %zero : index +/// %splat = vector.splat %cmpi : vector<4xi1> +/// %cst = arith.constant dense : vector<4xi1> +/// %mask = vector.create_mask %arg1 : vector<4xi1> +/// %out = arith.select %splat, %cst, %mask : vector<4xi1> +/// %out_1d = vector.shape_cast %out : vector<4xi1> to vector<1x4xi1> struct LinearizeVectorCreateMask final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -464,7 +468,8 @@ struct LinearizeVectorCreateMask final LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto srcTy = createMaskOp.getType(); + Location loc = createMaskOp.getLoc(); + VectorType srcTy = createMaskOp.getType(); auto srcShape = srcTy.getShape(); if (srcShape.size() != 2) return rewriter.notifyMatchFailure(createMaskOp, @@ -482,21 +487,19 @@ struct LinearizeVectorCreateMask final // create a zero mask, else strip the first operand and create a mask // using the second operand. auto firstOperand = adaptor.getOperands().front(); - auto zero = - rewriter.create(createMaskOp.getLoc(), 0); + auto zero = rewriter.create(loc, 0); auto isZeroOrNegative = rewriter.create( - createMaskOp.getLoc(), mlir::arith::CmpIPredicate::sle, firstOperand, - zero); - auto isZeroOrNegativeSplat = rewriter.create( - createMaskOp.getLoc(), dstTy, isZeroOrNegative); + loc, mlir::arith::CmpIPredicate::sle, firstOperand, zero); + auto isZeroOrNegativeSplat = + rewriter.create(loc, dstTy, isZeroOrNegative); // Use a select operation to choose between the masks. auto zeroMask = rewriter.create( - createMaskOp.getLoc(), dstTy, rewriter.getZeroAttr(dstTy)); + loc, dstTy, rewriter.getZeroAttr(dstTy)); auto newMask = rewriter.create( - createMaskOp.getLoc(), dstTy, adaptor.getOperands().back()); + loc, dstTy, adaptor.getOperands().back()); auto result = rewriter.create( - createMaskOp.getLoc(), isZeroOrNegativeSplat, zeroMask, newMask); + loc, isZeroOrNegativeSplat, zeroMask, newMask); rewriter.replaceOp(createMaskOp, result.getResult()); return success(); diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 01872426c77bb..55fad7b1704c9 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -350,18 +350,18 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> { // ALL-LABEL: linearize_create_mask func.func @linearize_create_mask() -> vector<1x16xi1> { // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[C20:.*]] = arith.constant 20 : index + // CHECK: %[[C10:.*]] = arith.constant 10 : index // CHECK: %[[C0_0:.*]] = arith.constant 0 : index // CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index // CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1> // CHECK: %[[CST:.*]] = arith.constant dense : vector<16xi1> - // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1> + // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C10]] : vector<16xi1> // CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1> // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1> // CHECK: return %[[CAST]] : vector<1x16xi1> %c0 = arith.constant 0 : index - %c20 = arith.constant 20 : index - %0 = vector.create_mask %c0, %c20 : vector<1x16xi1> + %c10 = arith.constant 10 : index + %0 = vector.create_mask %c0, %c10 : vector<1x16xi1> return %0 : vector<1x16xi1> } @@ -369,17 +369,17 @@ func.func @linearize_create_mask() -> vector<1x16xi1> { // ALL-LABEL: linearize_scalable_create_mask func.func @linearize_scalable_create_mask() -> vector<1x[16]xi1> { // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[C20:.*]] = arith.constant 20 : index + // CHECK: %[[C10:.*]] = arith.constant 10 : index // CHECK: %[[C0_0:.*]] = arith.constant 0 : index // CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index // CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<[16]xi1> // CHECK: %[[CST:.*]] = arith.constant dense : vector<[16]xi1> - // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<[16]xi1> + // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C10]] : vector<[16]xi1> // CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<[16]xi1>, vector<[16]xi1> // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<[16]xi1> to vector<1x[16]xi1> // CHECK: return %[[CAST]] : vector<1x[16]xi1> %c0 = arith.constant 0 : index - %c20 = arith.constant 20 : index - %0 = vector.create_mask %c0, %c20 : vector<1x[16]xi1> + %c10 = arith.constant 10 : index + %0 = vector.create_mask %c0, %c10 : vector<1x[16]xi1> return %0 : vector<1x[16]xi1> } From c2c1a22a16b1271307620d743378b57d673a3889 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Fri, 9 May 2025 01:15:14 +0000 Subject: [PATCH 06/10] Replace select with mul --- .../Vector/Transforms/VectorLinearize.cpp | 43 +++++++++---------- mlir/test/Dialect/Vector/linearize.mlir | 43 +++++-------------- 2 files changed, 31 insertions(+), 55 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 7e03b073fb369..e10483bd1a862 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -446,17 +446,16 @@ struct LinearizeVectorSplat final }; /// This pattern converts the CreateMaskOp to work on a linearized vector. -// The pattern currently supports only 2D masks with a unit outer dimension. +/// It currently supports only 2D masks with a unit outer dimension. /// Following, /// vector.create_mask %arg0, %arg1 : vector<1x4xi1> /// is converted to: /// %zero = arith.constant 0 : index -/// %cmpi = arith.cmpi sle, %arg0, %zero : index -/// %splat = vector.splat %cmpi : vector<4xi1> -/// %cst = arith.constant dense : vector<4xi1> -/// %mask = vector.create_mask %arg1 : vector<4xi1> -/// %out = arith.select %splat, %cst, %mask : vector<4xi1> -/// %out_1d = vector.shape_cast %out : vector<4xi1> to vector<1x4xi1> +/// %cmpi = arith.cmpi sgt, %arg0, %zero : index +/// %index = arith.index_cast %cmpi : i1 to index +/// %mul = arith.muli %index, %arg1 : index +/// %mask = vector.create_mask %mul : vector<4xi1> +/// %out_1d = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1> struct LinearizeVectorCreateMask final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -483,25 +482,23 @@ struct LinearizeVectorCreateMask final if (!dstTy) return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type."); - // Compare the first operand with 0. If it's less than or equal to 0, - // create a zero mask, else strip the first operand and create a mask - // using the second operand. + // Compare the first operand with 0. If it is greater than 0, the + // corresponding mask element is set to true, otherwise false. + // The result of the comparison is then multiplied with + // the second operand of create_mask to get the 1D mask. auto firstOperand = adaptor.getOperands().front(); auto zero = rewriter.create(loc, 0); - auto isZeroOrNegative = rewriter.create( - loc, mlir::arith::CmpIPredicate::sle, firstOperand, zero); - auto isZeroOrNegativeSplat = - rewriter.create(loc, dstTy, isZeroOrNegative); - - // Use a select operation to choose between the masks. - auto zeroMask = rewriter.create( - loc, dstTy, rewriter.getZeroAttr(dstTy)); - auto newMask = rewriter.create( - loc, dstTy, adaptor.getOperands().back()); - auto result = rewriter.create( - loc, isZeroOrNegativeSplat, zeroMask, newMask); + auto isNonZero = rewriter.create( + loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero); + auto isNonZeroIndex = rewriter.create( + loc, rewriter.getIndexType(), isNonZero); + auto secondOperand = adaptor.getOperands().back(); + auto maskSize = rewriter.create( + loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand); - rewriter.replaceOp(createMaskOp, result.getResult()); + auto newMask = rewriter.create( + loc, dstTy, maskSize.getResult()); + rewriter.replaceOp(createMaskOp, newMask); return success(); } }; diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 55fad7b1704c9..3ca2721dc1201 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -347,39 +347,18 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> { } // ----- -// ALL-LABEL: linearize_create_mask -func.func @linearize_create_mask() -> vector<1x16xi1> { + +// CHECK-LABEL: linearize_create_mask +// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1> +func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> { + // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[C10:.*]] = arith.constant 10 : index - // CHECK: %[[C0_0:.*]] = arith.constant 0 : index - // CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index - // CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1> - // CHECK: %[[CST:.*]] = arith.constant dense : vector<16xi1> - // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C10]] : vector<16xi1> - // CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1> - // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1> + // CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[ARG0]], %[[C0]] : index + // CHECK: %[[INDEXCAST:.*]] = arith.index_cast %[[CMP]] : i1 to index + // CHECK: %[[MULI:.*]] = arith.muli %[[INDEXCAST]], %[[ARG1]] : index + // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[MULI]] : vector<16xi1> + // CHECK: %[[CAST:.*]] = vector.shape_cast %[[MASK_1D]] : vector<16xi1> to vector<1x16xi1> // CHECK: return %[[CAST]] : vector<1x16xi1> - %c0 = arith.constant 0 : index - %c10 = arith.constant 10 : index - %0 = vector.create_mask %c0, %c10 : vector<1x16xi1> + %0 = vector.create_mask %arg0, %arg1 : vector<1x16xi1> return %0 : vector<1x16xi1> } - -// ----- -// ALL-LABEL: linearize_scalable_create_mask -func.func @linearize_scalable_create_mask() -> vector<1x[16]xi1> { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[C10:.*]] = arith.constant 10 : index - // CHECK: %[[C0_0:.*]] = arith.constant 0 : index - // CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index - // CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<[16]xi1> - // CHECK: %[[CST:.*]] = arith.constant dense : vector<[16]xi1> - // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C10]] : vector<[16]xi1> - // CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<[16]xi1>, vector<[16]xi1> - // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<[16]xi1> to vector<1x[16]xi1> - // CHECK: return %[[CAST]] : vector<1x[16]xi1> - %c0 = arith.constant 0 : index - %c10 = arith.constant 10 : index - %0 = vector.create_mask %c0, %c10 : vector<1x[16]xi1> - return %0 : vector<1x[16]xi1> -} From c5b2e81af1fed9be5eece1966127ffc4ff87af92 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Fri, 9 May 2025 04:08:32 +0000 Subject: [PATCH 07/10] Fix typo --- mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index e10483bd1a862..eb18891772e80 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -455,7 +455,7 @@ struct LinearizeVectorSplat final /// %index = arith.index_cast %cmpi : i1 to index /// %mul = arith.muli %index, %arg1 : index /// %mask = vector.create_mask %mul : vector<4xi1> -/// %out_1d = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1> +/// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1> struct LinearizeVectorCreateMask final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; From 8fca9c1c6510f1cefe1ff64c87629ef36d31e5ee Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 13 May 2025 02:42:25 +0000 Subject: [PATCH 08/10] Address comments --- .../Dialect/Vector/Transforms/VectorLinearize.cpp | 8 ++++---- mlir/test/Dialect/Vector/linearize.mlir | 13 ++++++++++++- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index eb18891772e80..4844549a6b25b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -488,16 +488,16 @@ struct LinearizeVectorCreateMask final // the second operand of create_mask to get the 1D mask. auto firstOperand = adaptor.getOperands().front(); auto zero = rewriter.create(loc, 0); - auto isNonZero = rewriter.create( + auto isNonZero = rewriter.createOrFold( loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero); - auto isNonZeroIndex = rewriter.create( + auto isNonZeroIndex = rewriter.createOrFold( loc, rewriter.getIndexType(), isNonZero); auto secondOperand = adaptor.getOperands().back(); - auto maskSize = rewriter.create( + auto maskSize = rewriter.createOrFold( loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand); auto newMask = rewriter.create( - loc, dstTy, maskSize.getResult()); + loc, dstTy, maskSize); rewriter.replaceOp(createMaskOp, newMask); return success(); } diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 3ca2721dc1201..a2bb9fc4509ec 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -355,10 +355,21 @@ func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1 // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[ARG0]], %[[C0]] : index // CHECK: %[[INDEXCAST:.*]] = arith.index_cast %[[CMP]] : i1 to index - // CHECK: %[[MULI:.*]] = arith.muli %[[INDEXCAST]], %[[ARG1]] : index + // CHECK: %[[MULI:.*]] = arith.andi %[[INDEXCAST]], %[[ARG1]] : index // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[MULI]] : vector<16xi1> // CHECK: %[[CAST:.*]] = vector.shape_cast %[[MASK_1D]] : vector<16xi1> to vector<1x16xi1> // CHECK: return %[[CAST]] : vector<1x16xi1> %0 = vector.create_mask %arg0, %arg1 : vector<1x16xi1> return %0 : vector<1x16xi1> } + +// ----- +// CHECK-LABEL: linearize_scalable_create_mask +// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x[16]xi1> +func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vector<1x[16]xi1> { + + // CHECK: %[[MASK_1D:.*]] = vector.create_mask {{%.*}} : vector<[16]xi1> + // CHECK: %[[CAST:.*]] = vector.shape_cast %[[MASK_1D]] : vector<[16]xi1> to vector<1x[16]xi1> + %0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1> + return %0 : vector<1x[16]xi1> +} From 2b7e06e8d907981ca382896076e470583c67a4c9 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 13 May 2025 02:54:08 +0000 Subject: [PATCH 09/10] Fix doc --- mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 4844549a6b25b..7d3ccbe930e38 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -453,7 +453,7 @@ struct LinearizeVectorSplat final /// %zero = arith.constant 0 : index /// %cmpi = arith.cmpi sgt, %arg0, %zero : index /// %index = arith.index_cast %cmpi : i1 to index -/// %mul = arith.muli %index, %arg1 : index +/// %mul = arith.andi %index, %arg1 : index /// %mask = vector.create_mask %mul : vector<4xi1> /// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1> struct LinearizeVectorCreateMask final From db484c9bd4283f35a52ba7e3f0a71182162b8934 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 13 May 2025 02:55:25 +0000 Subject: [PATCH 10/10] Clang-format --- mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 7d3ccbe930e38..8c50ad96681f0 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -496,8 +496,8 @@ struct LinearizeVectorCreateMask final auto maskSize = rewriter.createOrFold( loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand); - auto newMask = rewriter.create( - loc, dstTy, maskSize); + auto newMask = + rewriter.create(loc, dstTy, maskSize); rewriter.replaceOp(createMaskOp, newMask); return success(); }