
[MLIR][Vector] Add Lowering for vector.step #113655


Merged
merged 1 commit into from
Nov 1, 2024

Conversation

manupak
Contributor

@manupak manupak commented Oct 25, 2024

Currently, the lowering for vector.step is
implemented as a folder. This is not ideal
if we want to transform the op and defer the
materialization of its constants until much later.

This commit adds a rewrite pattern that
can be applied via the
transform.structured.vectorize_children_and_apply_patterns
transform dialect operation.

Moreover, the vector.step rewriter is now
also used in the -convert-vector-to-llvm pass,
where it handles scalable and non-scalable
types as LLVM expects.

As a consequence of removing the vector.step
lowering from its folder, linalg vectorization
will keep vector.step intact.
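
For illustration (an editorial sketch, not part of the patch), the non-scalable rewrite replaces the op with the same dense index constant that the removed folder used to produce:

```mlir
// Before lowering: the "f(i) = i" structure is still visible to transforms.
%0 = vector.step : vector<4xindex>

// After StepToArithOps (non-scalable vectors only):
%0 = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
```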

@llvmbot
Member

llvmbot commented Oct 25, 2024

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir-vector

Author: Manupa Karunaratne (manupak)

Changes

Currently, the lowering for vector.step is
implemented as a folder. This is not ideal
if we want to transform the op and defer the
materialization of its constants until much later.

This commit adds a rewrite pattern + transform
op to do this instead, thus enabling more
control over the lowering.


Patch is 25.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113655.diff

13 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (-1)
  • (modified) mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td (+9)
  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+7)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (+2)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (-14)
  • (modified) mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (+5)
  • (modified) mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp (+64)
  • (modified) mlir/test/Dialect/Linalg/vectorization-scalable.mlir (+1)
  • (modified) mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir (+15)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+73-6)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (-9)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index c02b16ea931706..5e7b6659548203 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2940,7 +2940,6 @@ def Vector_StepOp : Vector_Op<"step", [Pure]> {
     %1 = vector.step : vector<[4]xindex> ; [0, 1, .., <vscale * 4 - 1>]
     ```
   }];
-  let hasFolder = 1;
   let results = (outs VectorOfRankAndType<[1], [Index]>:$result);
   let assemblyFormat = "attr-dict `:` type($result)";
 }
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index c973eca0132a92..3262aa37a81877 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -453,4 +453,13 @@ def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyLowerStepToArithOps : Op<Transform_Dialect,
+    "apply_patterns.vector.step_to_arith",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Convert vector.step to arith if not using scalable vectors.
+  }];
+  let assemblyFormat = "attr-dict";
+}
+
 #endif // VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 1976b8399c7f9c..27581443814322 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -235,6 +235,13 @@ void populateVectorTransferPermutationMapLoweringPatterns(
 void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
                                         PatternBenefit benefit = 1);
 
+/// Populate the pattern set with the following patterns:
+///
+/// [StepToArithOps]
+/// Convert vector.step op into arith ops if not scalable
+void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
+                                        PatternBenefit benefit = 1);
+
 /// Populate the pattern set with the following patterns:
 ///
 /// [FlattenGather]
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 984af50a7b0a51..15a545dbb42180 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1886,6 +1886,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
   MLIRContext *ctx = converter.getDialect()->getContext();
   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
   populateVectorInsertExtractStridedSliceTransforms(patterns);
+  populateVectorStepLoweringPatterns(patterns);
   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
   patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
   patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index d1c95dabd88a5e..b2eca539194a87 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/IR/Matchers.h"
 
 using namespace mlir;
@@ -664,6 +665,7 @@ void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                                bool enableVLAVectorization,
                                                bool enableSIMDIndex32) {
   assert(vectorLength > 0);
+  vector::populateVectorStepLoweringPatterns(patterns);
   patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                               enableVLAVectorization, enableSIMDIndex32);
   patterns.add<ReducChainRewriter<vector::InsertElementOp>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a2abe1619454f2..2daf3e8a29ff9c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6423,20 +6423,6 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
   return SplatElementsAttr::get(getType(), {constOperand});
 }
 
-//===----------------------------------------------------------------------===//
-// StepOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult StepOp::fold(FoldAdaptor adaptor) {
-  auto resultType = cast<VectorType>(getType());
-  if (resultType.isScalable())
-    return nullptr;
-  SmallVector<APInt> indices;
-  for (unsigned i = 0; i < resultType.getNumElements(); i++)
-    indices.push_back(APInt(/*width=*/64, i));
-  return DenseElementsAttr::get(resultType, indices);
-}
-
 //===----------------------------------------------------------------------===//
 // WarpExecuteOnLane0Op
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 241e83e234d621..b6f49b85c6205a 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -207,6 +207,11 @@ void transform::ApplyTransferToScfPatternsOp::populatePatterns(
   populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
 }
 
+void transform::ApplyLowerStepToArithOps::populatePatterns(
+    RewritePatternSet &patterns) {
+  populateVectorStepLoweringPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index b7e8724c3c2582..9a3bd5d4593d63 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorMultiReduction.cpp
   LowerVectorScan.cpp
   LowerVectorShapeCast.cpp
+  LowerVectorStep.cpp
   LowerVectorTransfer.cpp
   LowerVectorTranspose.cpp
   SubsetOpInterfaceImpl.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp
new file mode 100644
index 00000000000000..fb7a516e4b41c3
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp
@@ -0,0 +1,64 @@
+//===- LowerVectorStep.cpp - Lower 'vector.step' operation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.step' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+
+#define DEBUG_TYPE "vector-step-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+struct StepToArithOps : public OpRewritePattern<vector::StepOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StepOp stepOp,
+                                PatternRewriter &rewriter) const override {
+    auto resultType = cast<VectorType>(stepOp.getType());
+    if (!resultType.isScalable()) {
+      SmallVector<APInt> indices;
+      for (unsigned i = 0; i < resultType.getNumElements(); i++)
+        indices.push_back(APInt(/*width=*/64, i));
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+          stepOp, DenseElementsAttr::get(resultType, indices));
+      return success();
+    }
+    return failure();
+  }
+};
+} // namespace
+
+void mlir::vector::populateVectorStepLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<StepToArithOps>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
index c3a30e3ee209e8..96866885df4d49 100644
--- a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
@@ -180,6 +180,7 @@ module attributes {transform.with_named_sequence} {
     %func = transform.structured.match ops{["func.func"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
+      transform.apply_patterns.vector.step_to_arith
       transform.apply_patterns.canonicalization
       transform.apply_patterns.linalg.tiling_canonicalization
     } : !transform.any_op
diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index 1c6a786bfa436d..8a3dbe6765ebd8 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -346,6 +346,11 @@ module attributes {transform.with_named_sequence} {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
      %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
      %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+     %func = transform.structured.match ops{["func.func"]} in %arg1
+       : (!transform.any_op) -> !transform.any_op
+     transform.apply_patterns to %func {
+       transform.apply_patterns.vector.step_to_arith
+     } : !transform.any_op
      transform.yield
   }
 }
@@ -474,6 +479,11 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    %func = transform.structured.match ops{["func.func"]} in %arg1
+       : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+       transform.apply_patterns.vector.step_to_arith
+    } : !transform.any_op
     transform.yield
   }
 }
@@ -505,6 +515,11 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    %func = transform.structured.match ops{["func.func"]} in %arg1
+       : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.step_to_arith
+    } : !transform.any_op
     transform.yield
   }
 }
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index e611a8e22ee23f..2cf7804264ef8c 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -32,6 +32,11 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    %func = transform.structured.match ops{["func.func"]} in %arg1
+       : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+       transform.apply_patterns.vector.step_to_arith
+    } : !transform.any_op
     transform.yield
   }
 }
@@ -172,6 +177,11 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+    %func = transform.structured.match ops{["func.func"]} in %arg1
+       : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+       transform.apply_patterns.vector.step_to_arith
+    } : !transform.any_op
     transform.yield
   }
 }
@@ -207,8 +217,9 @@ func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16
 // CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 79 : index
-// CHECK:           %[[VAL_11:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
-// CHECK:           %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : index to vector<1x4xindex>
+// CHECK-DAG:       %[[VAL_1_BCAST:.*]] = vector.broadcast %[[VAL_1]] : index to vector<1x4xindex>
+// CHECK-DAG:       %[[VAL_2_BCAST:.*]] = vector.broadcast %[[VAL_2]] : index to vector<1x4xindex>
+// CHECK:           %[[VAL_12:.*]] = arith.addi %[[VAL_1_BCAST]], %[[VAL_2_BCAST]] : vector<1x4xindex>
 // CHECK:           %[[VAL_13:.*]] = vector.broadcast %[[VAL_3]] : index to vector<4xindex>
 // CHECK:           %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_6]] : vector<4xindex>
 // CHECK:           %[[VAL_15:.*]] = vector.broadcast %[[VAL_4]] : index to vector<4xindex>
@@ -226,6 +237,11 @@ module attributes {transform.with_named_sequence} {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
      %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
      %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+     %func = transform.structured.match ops{["func.func"]} in %arg1
+       : (!transform.any_op) -> !transform.any_op
+     transform.apply_patterns to %func {
+       transform.apply_patterns.vector.step_to_arith
+     } : !transform.any_op
      transform.yield
    }
 }
@@ -306,6 +322,11 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+    %func = transform.structured.match ops{["func.func"]} in %arg0
+       : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.step_to_arith
+    } : !transform.any_op
     transform.yield
   }
 }
@@ -321,8 +342,9 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
 // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
 // CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
-// CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
-// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<1xindex>
+// CHECK: %[[ARG1_BCAST0:.*]] = vector.broadcast %arg1 : index to vector<1xindex>
+// CHECK: %[[ARG1_BCAST1:.*]] = vector.broadcast %arg1 : index to vector<1xindex>
+// CHECK: %[[B2:.*]] = arith.addi %[[ARG1_BCAST0]], %[[ARG1_BCAST1]] : vector<1xindex>
 // CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
 // CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
 // CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
@@ -357,17 +379,22 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+    %func = transform.structured.match ops{["func.func"]} in %arg2
+       : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.step_to_arith
+    } : !transform.any_op
     transform.yield
   }
 }
 
 // CHECK-LABEL:   func.func @index_from_output_column_vector_gather_load(
 // CHECK-SAME:      %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
+// CHECK:           %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
 // CHECK:           %[[C128:.*]] = arith.constant dense<128> : vector<1x8xindex>
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
 // CHECK:           %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
-// CHECK:           %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
 // CHECK:           %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32>
 // CHECK:           %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex>
 // CHECK:           %[[MUL:.*]] = arith.muli %[[B]], %[[C128]] : vector<1x8xindex>
@@ -404,16 +431,21 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+    %func = transform.structured.match ops{["func.func"]} in %arg2
+       : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.step_to_arith
+    } : !transform.any_op
     transform.yield
   }
 }
 
 // CHECK-LABEL:   func.func @index_from_output_column_vector_contiguous_load(
 // CHECK-SAME:      %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
+// CHECK:           %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
 // CHECK:           %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
-// CHECK:           %[[IDX_VEC:.*]] = arith.constant dense<[0,...
[truncated]

@c-rhodes
Collaborator

This is not ideal if we want
to transform the op and defer the
materialization of the constants until much later.

Makes sense. The fact that it's a simple step is lost after canonicalization, and I didn't consider the value of keeping that around when I added the operation.

@manupak
Contributor Author

manupak commented Oct 25, 2024

cc: @Groverkss

@banach-space
Contributor

Thanks for working on this! I have a few high level comments.

This is not ideal if we want to transform the op and defer the materialization of the constants until much later.

Do we? Is there a use case for this? I just want to make sure that we are not moving code for the sake of ... moving code around 😅

Re layering, transformations in "mlir/lib/Dialect/Vector/Transforms/LowerVector{.*}.cpp" lower from higher-level Vector Ops to lower-level Vector ops. Rewriting vector.step as arith.constant feels more like dialect conversion to me. But then we don't have VectorToArith. So, to me, there are actual benefits to keeping this as a folder.

Re changes in tests, we should avoid adding more patterns and TD Ops to the vectorizer tests that already use transform.structured.vectorize_children_and_apply_patterns (this TD Op should encapsulate everything that's needed).

What happens if you don't add transform.apply_patterns.vector.step_to_arith to the tests?

@Groverkss
Member

Groverkss commented Oct 28, 2024

Do we? Is there a use case for this? I just want to make that we are not moving code for the sake of ... moving code around 😅

We do actually have a use case downstream. Sometimes, we see patterns like:

%a = vector.step : vector<128xindex>
%b = vector.extract_strided_slice %thread_id, %a : vector<4xindex>

If you know it's a vector.step, you can fold the step + slice into:

%a = vector.step : vector<4xindex>
%b = arith.constant ... : vector<4xindex>
%c = arith.addi %b, %a
%d = arith.addi %c, %thread_id

So you don't have to materialize a big constant like vector<128xindex> into private memory. Having vector.step allows you to know how the arith.constant materializes.

@manupak
Contributor Author

manupak commented Oct 28, 2024

Thanks all, I'll be working on addressing the specific comments.

Replying to higher-level comments from @banach-space,

Do we? Is there a use case for this? I just want to make that we are not moving code for the sake of ... moving code around

Yes, I think @Groverkss is spot on about the use case here.
At a higher level, vector.step carries the semantics of a specialized constant with the property f(i) = i. This is immediately lost if it is folded away. Therefore I'd consider this a lowering to a lower-level abstraction rather than a 'folder'.

Re layering, transformations in "mlir/lib/Dialect/Vector/Transforms/LowerVector{.*}.cpp" lower from higher-level Vector Ops to lower-level Vector ops

Is this true?
I can cite a few examples that violate the notion that LowerVector*.cpp exclusively contains vector -> vector rewrites, e.g. LowerVectorContract. I suppose the arith dialect is a bit special in that it intermixes with most dialects. If we were to push for a separation of dialects via a dialect conversion, I feel we would have to replicate most operations in vector (such as a vector.add). Therefore, I'm not fully following the case for a VectorToArith conversion, because it feels impossible for vector to be practically used without arith.

So, to me, there are actual benefits to keeping this as a folder.

I'm not following the logical inference from the absence of VectorToArith to keeping this as a folder. Can you help me understand this a bit more?
At least to me personally, a folder should be a transformation that is provably always beneficial, whereas here we are presenting a counterexample to folding vector.step.

Re changes in tests, we should avoid adding more patterns and TD Ops to the vectorizer tests that already use transform.structured.vectorize_children_and_apply_patterns (this TD Op should encapsulate everything that's needed).

What happens if you don't add transform.apply_patterns.vector.step_to_arith to the tests?

This part I somewhat agree with.

So there were three options:
A1) Introduce and use transform.apply_patterns.vector.step_to_arith to show that the expectations of the tests are still (more or less) met using the lowering. -- I took this one but am happy to change.
A2) Include the rewriter in transform.structured.vectorize_children_and_apply_patterns.
A3) Keep it as vector.step.

I can switch to A2 or A3.
Having read again what transform.structured.vectorize_children_and_apply_patterns entails, I feel A2 is the better choice. WDYT?
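
As a hedged sketch of what option A2 would mean for users (a hypothetical script, mirroring the tests touched by this patch): with the pattern folded into the vectorization TD op, no separate apply_patterns step would be needed:

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    // Under A2, this single op would also lower vector.step to arith.constant,
    // with no extra transform.apply_patterns.vector.step_to_arith block.
    %1 = transform.structured.vectorize_children_and_apply_patterns %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
```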

Contributor

@dcaballe dcaballe left a comment


Thanks! LG!

Collaborator

@c-rhodes c-rhodes left a comment


LGTM cheers

@banach-space
Contributor

Thanks for replying to my comments!

I can cite few examples that violates the notion that LowerVector*.cpp exclusively contains vector -> vector rewrites such as e.g. : LowerVectorContract.

Can you link a specific example from that file that doesn't lower to other Vector Ops?

Bear in mind that there are many examples in-tree that e.g. don't follow our coding guidelines. However, that doesn't mean it's OK to violate them; it merely means there's code that should be re-formatted ;-)

I'm not following the logical inference from the absence of VectorToArith to keeping this as a folder.

Without a use-case, this felt like moving code around (and creating a new file to maintain and to compile). But now that there is a use-case, we can park that discussion.

We do actually have a use case downstream. Sometimes, we see patterns like:

%a = vector.step : vector<128xindex>
%b = vector.extract_strided_slice %thread_id, %a : vector<4xindex>
If you know it's a vector.step, you can fold the step + slice into:

%a = vector.step : vector<4xindex>
%b = arith.constant ... : vector<4xindex>
%c = arith.addi %b, %a
%d = arith.addi %c, %thread_id
So you don't have to materialize a big constant like vector<128xindex> into private memory. Having vector.step allows you to know how the arith.constant materializes.

Cool, why not upstream that? That's a good optimisation and nice justification for this change.

This part I somewhat agree.

So there were three options :

A2 is better choice.

+1

As for testing, I'd add this pattern to TestVectorScanLowering (and rename it as TestVectorToArithLowering). That would be the least intrusive and the most straightforward path for now.

Having a dedicated TD op for such a small pattern feels like overkill, and we should probably find other patterns to group this one with. I don't really have a good suggestion just now, so I'd leave it as a TODO. These things become clearer once we start combining stuff.

@Groverkss
Member

Groverkss commented Oct 29, 2024

Cool, why not upstream that? That's a good optimisation and nice justification for this change.

It would eventually be upstreamed. But upstreaming a transformation that can never actually be used, because the op folds itself away, may not be useful, which is why this patch comes first.

I'm not sure this change by itself needs justification though. In the original review for vector.step, moving this transformation out of canonicalization was already requested but not addressed https://github.com/llvm/llvm-project/pull/96776/files#r1667009554 .

@manupak
Contributor Author

manupak commented Oct 29, 2024

Thanks @banach-space for taking another look.

Can you link a specific example from that file that doesn't lower to other Vector Ops?

Bare in mind that there are many examples in-tree that e.g. don't follow our coding guidelines. However, that doesn't mean that it's OK to violate them, it merely means that there's code that should be re-formatted ;-)

Value res =
    isa<IntegerType>(elementType)
        ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
        : static_cast<Value>(
              rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
return res;

I did not mean to cite it as "just because there is other code that does the same thing". What I wanted to bring up is that the arith dialect needs to be intertwined with shaped-type dialects such as vector and tensor. Therefore a dialect conversion from VectorToArith did not make sense to me.

A2 is better choice.

I can do A2.

As for testing, I'd add this pattern to TestVectorScanLowering (and rename it as TestVectorToArithLowering). That would be the least intrusive and the most straightforward path for now.

May I ask why vector.scan is special enough to be combined with vector.step? The only reason I see is that its rewrite pattern is named "ScanToArithOps", but that seems like an arbitrary choice. For example, vector.mask uses arith ops as well, as do many other ops in the LowerVector*.cpp files. Most of the lowering patterns involve at least creating an arith.constant. There are a few others that involve a partial lowering to arith, such as TestVectorGatherLowering.

Therefore, I personally feel it's better to keep them separate rather than combining the two, unless you have a deeper reason why vector.scan and vector.step should be tested in a combined fashion.

Sorry, the least intrusive option (I suppose you are worried about intrusion into the already existing tests) to me is to add this to transform.structured.vectorize_children_and_apply_patterns, not to call the combined test pass for vector.scan and vector.step lowering.

Having a dedicated TD op for such a small pattern feels like an overkill and we should probably find other patterns to group this one with. I don't really have a good suggestion just now, so I'd leave it as a TODO.

This is what I meant by A2, isn't it?
A2) include the rewriter in transform.structured.vectorize_children_and_apply_patterns
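For reference, a minimal sketch of what A2 looks like from the user's side (the surrounding named-sequence scaffolding and SSA names here are illustrative assumptions, not the actual test file):

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %arg0: !transform.any_op {transform.readonly}) {
    // Match the payload function, then vectorize its children and apply
    // the accompanying patterns, which (with this PR) include the
    // vector.step lowering rewrite.
    %func = transform.structured.match ops{["func.func"]} in %arg0
        : (!transform.any_op) -> !transform.any_op
    %vectorized =
        transform.structured.vectorize_children_and_apply_patterns %func
        : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
```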

@banach-space
Contributor

I can do A2.

+1

Please, could you also reply to this question:

Cool, why not upstream that? That's a good optimisation and nice justification for this change.

?

@manupak
Contributor Author

manupak commented Oct 29, 2024

Cool, why not upstream that? That's a good optimisation and nice justification for this change.

We may do that, but for the ordering of things we need vector.step to survive the folder first :) if that makes sense.

@manupak
Contributor Author

manupak commented Oct 29, 2024

@banach-space I've removed the TD op now and added the rewrite into transform.structured.vectorize_children_and_apply_patterns.

Member

@kuhar kuhar left a comment


Overall +1, we definitely don't want to expand vector.step for large vectors

```diff
@@ -1886,6 +1886,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
   MLIRContext *ctx = converter.getDialect()->getContext();
   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
   populateVectorInsertExtractStridedSliceTransforms(patterns);
+  populateVectorStepLoweringPatterns(patterns);
```
Member


Is it valid to add a RewritePattern here, since this pass uses the dialect conversion driver?

Contributor Author


Yes, it should be. AFAIK, the difference is that the conversion driver imposes legality post-rewrite, as defined by the target.
I've added a test to make sure, which was useful in exposing that the already existing rewrite pattern should be scoped to scalable vectors, as per https://llvm.org/docs/LangRef.html#llvm-stepvector-intrinsic. So thanks!
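To illustrate the two cases the test distinguishes (a hedged sketch; the exact constant shape, result types, and the MLIR spelling of the intrinsic are assumptions based on the linked LangRef entry, not copied from the patch):

```mlir
// Fixed-length vectors: the step sequence is statically known, so the
// rewrite can materialize it as a plain constant.
%fixed = vector.step : vector<4xindex>
// becomes
%fixed = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>

// Scalable vectors: the lane count is unknown at compile time, so the
// lowering must target the LLVM stepvector intrinsic instead.
%scalable = vector.step : vector<[4]xindex>
// becomes (during -convert-vector-to-llvm)
%scalable = llvm.intr.stepvector : vector<[4]xi64>
```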

Copy link
Member

@kuhar kuhar Oct 31, 2024


@matthias-springer could you confirm this? I can never remember what the rules are.

Contributor


Strictly speaking, this is lowering "Vector -> Arith" and not "Vector -> LLVM". What happens if you don't include this here?

IIUC, the whole point of this PR is not to run this pattern auto-magically as part of some larger pipeline.

Contributor Author

@manupak manupak Nov 1, 2024


the whole point of this PR is not to run this pattern auto-magically as part of some larger pipeline.

No, the whole point of the PR is to defer the materialization of the constants until later; the latest point is when we go to LLVM.

Strictly speaking, this is lowering "Vector -> Arith" and not "Vector -> LLVM". What happens if you don't include this here?

I need to include it somewhere in the VectorToLLVM pass, so what do you mean by "here"? I can move the pattern anywhere within the pass, which is the question I'm asking: where would people ideally like it to be?

Moreover, arith is a legal dialect for this pass. See here:

```cpp
LLVMConversionTarget target(getContext());
target.addLegalDialect<arith::ArithDialect>();
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
```
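Concretely (an illustrative sketch of the legality argument; SSA names and the constant value are made up), the output of the step-lowering pattern stays within a dialect the target declares legal, so the post-rewrite legality check of the conversion driver is satisfied:

```mlir
// Before: vector.step is not legal for the LLVM conversion target.
%s = vector.step : vector<4xindex>

// After the rewrite pattern fires: arith is a legal dialect above,
// so the driver accepts this result without a dedicated
// Vector -> LLVM pattern for vector.step.
%s = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
```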

Member

@Groverkss Groverkss left a comment


LGTM, but please wait for @banach-space 's approval before landing

@manupak
Contributor Author

manupak commented Oct 31, 2024

@banach-space when you have some time, PTAL. Thank you!

Contributor

@banach-space banach-space left a comment


We may do that but for ordering of things we need vector.step to survive the folder first :)

Right, so the use-case remains hypothetical? ;-)

I've left some suggestions re the updated test. The comment is long, but that's me trying to save you work 😅 It's borderline out-of-scope, but why not take this refactor one step further?

Currently, the lowering for vector.step lives under a folder. This is not ideal if we want to do transformations on it and defer the materialization of the constants until much later.

This commit adds a rewrite pattern that can be applied via the `transform.structured.vectorize_children_and_apply_patterns` transform dialect operation.

Moreover, the rewriter for vector.step is now also used in the -convert-vector-to-llvm pass, where it handles scalable and non-scalable types as LLVM expects.

As a consequence of removing the vector.step lowering from its folder, linalg vectorization will keep vector.step intact.

```mlir
%func = transform.structured.match ops{["func.func"]} in %arg1
    : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
  transform.apply_patterns.canonicalization
  transform.apply_patterns.linalg.tiling_canonicalization
} : !transform.any_op
```
Contributor


All of the canonicalization can be removed, right?

Contributor Author


As I said, I did not remove tiling_canonicalization, as it changes the test beyond the scope of the PR.

Contributor Author

@manupak manupak Oct 31, 2024


(Prior to this PR, canonicalization was folding the vector.step, so I don't see why we have to remove tiling_canonicalization here and now.)

Contributor


Removing the canonicalization is not required for this change, true. But you have already re-written the test, and this additional change could save us having to send/review another PR. Either way it's noise, whether here or in a new PR.

Contributor

@banach-space banach-space left a comment


LGTM, thanks!

Please resolve the thread on rewrite patterns started by @kuhar before landing this (I have "un-resolved" it - looks like it's still ongoing).

@manupak
Contributor Author

manupak commented Nov 1, 2024

Thanks @banach-space!

@kuhar if you have any suggestions, lmk.
(I personally think it's fine, and the test shows that it is, but I can still move to the non-conversion driver.)

@kuhar
Member

kuhar commented Nov 1, 2024

No blockers from me; I just pointed out something we should be confident about instead of guessing. And I am personally not, so I'm just flagging it.

@manupak
Contributor Author

manupak commented Nov 1, 2024

I hope the test that I added inspires the confidence :)

@manupak
Contributor Author

manupak commented Nov 1, 2024

PS: I don't have write access, so I'd appreciate it if someone could land this if there are no further concerns.

@Groverkss Groverkss merged commit a6e72f9 into llvm:main Nov 1, 2024
8 checks passed
smallp-o-p pushed a commit to smallp-o-p/llvm-project that referenced this pull request Nov 3, 2024
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024