diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index c7169c5297d9a..90e0479a515d5 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -566,6 +566,64 @@ struct LinearizeVectorSplat final } }; +/// This pattern converts the CreateMaskOp to work on a linearized vector. +/// It currently supports only 2D masks with a unit outer dimension. +/// Following, +/// vector.create_mask %arg0, %arg1 : vector<1x4xi1> +/// is converted to: +/// %zero = arith.constant 0 : index +/// %cmpi = arith.cmpi sgt, %arg0, %zero : index +/// %index = arith.index_cast %cmpi : i1 to index +/// %mul = arith.andi %index, %arg1 : index +/// %mask = vector.create_mask %mul : vector<4xi1> +/// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1> +struct LinearizeVectorCreateMask final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LinearizeVectorCreateMask(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = createMaskOp.getLoc(); + VectorType srcTy = createMaskOp.getType(); + auto srcShape = srcTy.getShape(); + if (srcShape.size() != 2) + return rewriter.notifyMatchFailure(createMaskOp, + "only 2D mask is supported."); + + if (srcShape[0] != 1) + return rewriter.notifyMatchFailure( + createMaskOp, "only unit outer dimension is supported."); + + auto dstTy = getTypeConverter()->convertType(srcTy); + if (!dstTy) + return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type."); + + // Compare the first operand with 0. If it is greater than 0, the + // corresponding mask element is set to true, otherwise false. + // The result of the comparison is then multiplied with + // the second operand of create_mask to get the 1D mask. + auto firstOperand = adaptor.getOperands().front(); + auto zero = rewriter.create(loc, 0); + auto isNonZero = rewriter.createOrFold( + loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero); + auto isNonZeroIndex = rewriter.createOrFold( + loc, rewriter.getIndexType(), isNonZero); + auto secondOperand = adaptor.getOperands().back(); + auto maskSize = rewriter.createOrFold( + loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand); + + auto newMask = + rewriter.create(loc, dstTy, maskSize); + rewriter.replaceOp(createMaskOp, newMask); + return success(); + } +}; + } // namespace /// Return true if the operation `op` does not support scalable vectors and @@ -651,9 +709,10 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter, void mlir::vector::populateVectorLinearizeBasePatterns( const TypeConverter &typeConverter, const ConversionTarget &target, RewritePatternSet &patterns) { - patterns.add( - typeConverter, patterns.getContext()); + patterns + .add( + typeConverter, patterns.getContext()); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 3cdbef8db604b..40445d3781228 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -416,3 +416,28 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> { return %0 : vector<4x[2]xi32> } +// ----- + +// CHECK-LABEL: linearize_create_mask +// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1> +func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> { + + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[ARG0]], %[[C0]] : index + // CHECK: %[[INDEXCAST:.*]] = arith.index_cast %[[CMP]] : i1 to index + // CHECK: %[[MULI:.*]] = arith.andi %[[INDEXCAST]], %[[ARG1]] : index + // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[MULI]] : vector<16xi1> + // CHECK: %[[CAST:.*]] = vector.shape_cast %[[MASK_1D]] : vector<16xi1> to vector<1x16xi1> + // CHECK: return %[[CAST]] : vector<1x16xi1> + %0 = vector.create_mask %arg0, %arg1 : vector<1x16xi1> + return %0 : vector<1x16xi1> +} + +// ----- +// CHECK-LABEL: linearize_scalable_create_mask +func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vector<1x[16]xi1> { + + // CHECK: %[[MASK_1D:.*]] = vector.create_mask {{%.*}} : vector<[16]xi1> + %0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1> + return %0 : vector<1x[16]xi1> +} diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index eda2594fbc7c7..54defd949c264 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -973,7 +973,7 @@ struct TestVectorLinearize final return "Linearizes ND vectors for N >= 2 into 1D vectors"; } void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } void runOnOperation() override {