Skip to content

[CIR] Upstream splat op for VectorType #139827

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 11, 2025

Conversation

AmrDeveloper
Copy link
Member

This change adds support for splat op for VectorType

Issue #136487

@llvmbot llvmbot added clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project labels May 14, 2025
@llvmbot
Copy link
Member

llvmbot commented May 14, 2025

@llvm/pr-subscribers-clang

Author: Amr Hesham (AmrDeveloper)

Changes

This change adds support for splat op for VectorType

Issue #136487


Full diff: https://github.com/llvm/llvm-project/pull/139827.diff

7 Files Affected:

  • (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+33)
  • (modified) clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp (+8)
  • (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+34-1)
  • (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h (+10)
  • (modified) clang/test/CIR/CodeGen/vector-ext.cpp (+65-1)
  • (modified) clang/test/CIR/CodeGen/vector.cpp (+65-1)
  • (modified) clang/test/CIR/IR/vector.cir (+34)
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 9f5fa266742e8..463ef929509bd 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2037,4 +2037,37 @@ def VecExtractOp : CIR_Op<"vec.extract", [Pure,
   let hasFolder = 1;
 }
 
+
+//===----------------------------------------------------------------------===//
+// VecSplat
+//===----------------------------------------------------------------------===//
+
+// cir.vec.splat is a separate operation from cir.vec.create because more
+// efficient LLVM IR can be generated for it, and because some optimization and
+// analysis passes can benefit from knowing that all elements of the vector
+// have the same value.
+
+def VecSplatOp : CIR_Op<"vec.splat", [Pure,
+  TypesMatchWith<"type of 'value' matches element type of 'result'", "result",
+                 "value", "cast<VectorType>($_self).getElementType()">]> {
+
+  let summary = "Convert a scalar into a vector";
+  let description = [{
+    The `cir.vec.splat` operation creates a vector value from a scalar value.
+    All elements of the vector have the same value, that of the given scalar.
+
+    ```mlir
+   %value = cir.const #cir.int<3> : !s32i
+   %value_vec = cir.vec.splat %value : !s32i, !cir.vector<4 x !s32i>
+    ```
+  }];
+
+  let arguments = (ins CIR_AnyType:$value);
+  let results = (outs CIR_VectorType:$result);
+
+  let assemblyFormat = [{
+    $value `:` type($value) `,` qualified(type($result)) attr-dict
+  }];
+}
+
 #endif // CLANG_CIR_DIALECT_IR_CIROPS_TD
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
index 4158973f1054b..e13a4650ed6f9 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
@@ -1650,6 +1650,14 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *ce) {
                               cgf.convertType(destTy));
   }
 
+  case CK_VectorSplat: {
+    // Create a vector object and fill all elements with the same scalar value.
+    assert(destTy->isVectorType() && "CK_VectorSplat to non-vector type");
+    return cgf.getBuilder().create<cir::VecSplatOp>(
+        cgf.getLoc(subExpr->getSourceRange()), cgf.convertType(destTy),
+        Visit(subExpr));
+  }
+
   default:
     cgf.getCIRGenModule().errorNYI(subExpr->getSourceRange(),
                                    "CastExpr: ", ce->getCastKindName());
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 3c85bb4b6b41d..ccadcbfd19ca2 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1648,7 +1648,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
                CIRToLLVMUnaryOpLowering,
                CIRToLLVMVecCreateOpLowering,
                CIRToLLVMVecExtractOpLowering,
-               CIRToLLVMVecInsertOpLowering
+               CIRToLLVMVecInsertOpLowering,
+               CIRToLLVMVecSplatOpLowering
       // clang-format on
       >(converter, patterns.getContext());
 
@@ -1773,6 +1774,38 @@ mlir::LogicalResult CIRToLLVMVecInsertOpLowering::matchAndRewrite(
   return mlir::success();
 }
 
+mlir::LogicalResult CIRToLLVMVecSplatOpLowering::matchAndRewrite(
+    cir::VecSplatOp op, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  // Vector splat can be implemented with an `insertelement` and a
+  // `shufflevector`, which is better than an `insertelement` for each
+  // element in the vector. Start with an undef vector. Insert the value into
+  // the first element. Then use a `shufflevector` with a mask of all 0 to
+  // fill out the entire vector with that value.
+  const auto vecTy = mlir::cast<cir::VectorType>(op.getType());
+  const mlir::Type llvmTy = typeConverter->convertType(vecTy);
+  const mlir::Location loc = op.getLoc();
+  const mlir::Value poison = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
+
+  const mlir::Value elementValue = adaptor.getValue();
+  if (mlir::isa<mlir::LLVM::PoisonOp>(elementValue.getDefiningOp())) {
+    // If the splat value is poison, then we can just use poison value
+    // for the entire vector.
+    rewriter.replaceOp(op, poison);
+    return mlir::success();
+  }
+
+  const mlir::Value indexValue =
+      rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0);
+  const mlir::Value oneElement = rewriter.create<mlir::LLVM::InsertElementOp>(
+      loc, poison, elementValue, indexValue);
+  const SmallVector<int32_t> zeroValues(vecTy.getSize(), 0);
+  const mlir::Value shuffled = rewriter.create<mlir::LLVM::ShuffleVectorOp>(
+      loc, oneElement, poison, zeroValues);
+  rewriter.replaceOp(op, shuffled);
+  return mlir::success();
+}
+
 std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
   return std::make_unique<ConvertCIRToLLVMPass>();
 }
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
index bd077e3d1d1e0..9eb44c9d60d6e 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
@@ -332,6 +332,16 @@ class CIRToLLVMVecInsertOpLowering
                   mlir::ConversionPatternRewriter &) const override;
 };
 
+class CIRToLLVMVecSplatOpLowering
+    : public mlir::OpConversionPattern<cir::VecSplatOp> {
+public:
+  using mlir::OpConversionPattern<cir::VecSplatOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(cir::VecSplatOp op, OpAdaptor,
+                  mlir::ConversionPatternRewriter &) const override;
+};
+
 } // namespace direct
 } // namespace cir
 
diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp
index a16ef42f113df..bf7dbbe7fa579 100644
--- a/clang/test/CIR/CodeGen/vector-ext.cpp
+++ b/clang/test/CIR/CodeGen/vector-ext.cpp
@@ -6,6 +6,7 @@
 // RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG
 
 typedef int vi4 __attribute__((ext_vector_type(4)));
+typedef unsigned int uvi4 __attribute__((ext_vector_type(4)));
 typedef int vi3 __attribute__((ext_vector_type(3)));
 typedef int vi2 __attribute__((ext_vector_type(2)));
 typedef double vd2 __attribute__((ext_vector_type(2)));
@@ -400,4 +401,67 @@ void foo9() {
 // OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
 // OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
 // OGCG: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
-// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
\ No newline at end of file
+// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
+
+void foo11() {
+  vi4 a = {1, 2, 3, 4};
+  vi4 shl = a << 3;
+
+  uvi4 b = {1u, 2u, 3u, 4u};
+  uvi4 shr = b >> 3u;
+}
+
+// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["b", init]
+// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["shr", init]
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !s32i, !cir.vector<4 x !s32i>
+// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CIR: cir.store %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !u32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !u32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !u32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !u32i
+// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !u32i, !u32i, !u32i, !u32i) : !cir.vector<4 x !u32i>
+// CIR: cir.store %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
+// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !u32i>>, !cir.vector<4 x !u32i>
+// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !u32i
+// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !u32i, !cir.vector<4 x !u32i>
+// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_B]] : !cir.vector<4 x !u32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !u32i>) -> !cir.vector<4 x !u32i>
+// CIR: cir.store %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
+
+// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
+// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
+// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// LLVM: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
+// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
+
+// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
+// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
+// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// OGCG: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
+// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
diff --git a/clang/test/CIR/CodeGen/vector.cpp b/clang/test/CIR/CodeGen/vector.cpp
index 4546215865095..6f1622eb12c27 100644
--- a/clang/test/CIR/CodeGen/vector.cpp
+++ b/clang/test/CIR/CodeGen/vector.cpp
@@ -6,6 +6,7 @@
 // RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG
 
 typedef int vi4 __attribute__((vector_size(16)));
+typedef unsigned int uvi4 __attribute__((vector_size(16)));
 typedef double vd2 __attribute__((vector_size(16)));
 typedef long long vll2 __attribute__((vector_size(16)));
 
@@ -388,4 +389,67 @@ void foo9() {
 // OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
 // OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
 // OGCG: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
-// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
\ No newline at end of file
+// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
+
+void foo11() {
+  vi4 a = {1, 2, 3, 4};
+  vi4 shl = a << 3;
+
+  uvi4 b = {1u, 2u, 3u, 4u};
+  uvi4 shr = b >> 3u;
+}
+
+// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["b", init]
+// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["shr", init]
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !s32i, !cir.vector<4 x !s32i>
+// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CIR: cir.store %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !u32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !u32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !u32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !u32i
+// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !u32i, !u32i, !u32i, !u32i) : !cir.vector<4 x !u32i>
+// CIR: cir.store %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
+// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !u32i>>, !cir.vector<4 x !u32i>
+// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !u32i
+// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !u32i, !cir.vector<4 x !u32i>
+// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_B]] : !cir.vector<4 x !u32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !u32i>) -> !cir.vector<4 x !u32i>
+// CIR: cir.store %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
+
+// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
+// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
+// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// LLVM: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
+// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
+
+// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
+// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
+// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// OGCG: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
+// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
diff --git a/clang/test/CIR/IR/vector.cir b/clang/test/CIR/IR/vector.cir
index 21a1f0a7559c4..bfc87350ecba1 100644
--- a/clang/test/CIR/IR/vector.cir
+++ b/clang/test/CIR/IR/vector.cir
@@ -135,4 +135,38 @@ cir.func @vector_insert_element_test() {
 // CHECK:    cir.return
 // CHECK:  }
 
+cir.func @vector_splat_test() {
+    %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+    %1 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+    %2 = cir.const #cir.int<1> : !s32i
+    %3 = cir.const #cir.int<2> : !s32i
+    %4 = cir.const #cir.int<3> : !s32i
+    %5 = cir.const #cir.int<4> : !s32i
+    %6 = cir.vec.create(%2, %3, %4, %5 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+    cir.store %6, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+    %7 = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+    %8 = cir.const #cir.int<3> : !s32i
+    %9 = cir.vec.splat %8 : !s32i, !cir.vector<4 x !s32i>
+    %10 = cir.shift(left, %7 : !cir.vector<4 x !s32i>, %9 : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+    cir.store %10, %1 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+    cir.return
+}
+
+// CHECK: cir.func @vector_splat_test() {
+// CHECK:    %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CHECK:    %1 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+// CHECK:    %2 = cir.const #cir.int<1> : !s32i
+// CHECK:    %3 = cir.const #cir.int<2> : !s32i
+// CHECK:    %4 = cir.const #cir.int<3> : !s32i
+// CHECK:    %5 = cir.const #cir.int<4> : !s32i
+// CHECK:    %6 = cir.vec.create(%2, %3, %4, %5 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CHECK:    cir.store %6, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CHECK:    %7 = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CHECK:    %8 = cir.const #cir.int<3> : !s32i
+// CHECK:    %9 = cir.vec.splat %8 : !s32i, !cir.vector<4 x !s32i>
+// CHECK:    %10 = cir.shift(left, %7 : !cir.vector<4 x !s32i>, %9 : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CHECK:    cir.store %10, %1 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CHECK:    cir.return
+// CHECK: }
+
 }

@llvmbot
Copy link
Member

llvmbot commented May 14, 2025

@llvm/pr-subscribers-clangir

Author: Amr Hesham (AmrDeveloper)

Changes

This change adds support for splat op for VectorType

Issue #136487


Full diff: https://github.com/llvm/llvm-project/pull/139827.diff

7 Files Affected:

  • (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+33)
  • (modified) clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp (+8)
  • (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+34-1)
  • (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h (+10)
  • (modified) clang/test/CIR/CodeGen/vector-ext.cpp (+65-1)
  • (modified) clang/test/CIR/CodeGen/vector.cpp (+65-1)
  • (modified) clang/test/CIR/IR/vector.cir (+34)
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 9f5fa266742e8..463ef929509bd 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2037,4 +2037,37 @@ def VecExtractOp : CIR_Op<"vec.extract", [Pure,
   let hasFolder = 1;
 }
 
+
+//===----------------------------------------------------------------------===//
+// VecSplat
+//===----------------------------------------------------------------------===//
+
+// cir.vec.splat is a separate operation from cir.vec.create because more
+// efficient LLVM IR can be generated for it, and because some optimization and
+// analysis passes can benefit from knowing that all elements of the vector
+// have the same value.
+
+def VecSplatOp : CIR_Op<"vec.splat", [Pure,
+  TypesMatchWith<"type of 'value' matches element type of 'result'", "result",
+                 "value", "cast<VectorType>($_self).getElementType()">]> {
+
+  let summary = "Convert a scalar into a vector";
+  let description = [{
+    The `cir.vec.splat` operation creates a vector value from a scalar value.
+    All elements of the vector have the same value, that of the given scalar.
+
+    ```mlir
+   %value = cir.const #cir.int<3> : !s32i
+   %value_vec = cir.vec.splat %value : !s32i, !cir.vector<4 x !s32i>
+    ```
+  }];
+
+  let arguments = (ins CIR_AnyType:$value);
+  let results = (outs CIR_VectorType:$result);
+
+  let assemblyFormat = [{
+    $value `:` type($value) `,` qualified(type($result)) attr-dict
+  }];
+}
+
 #endif // CLANG_CIR_DIALECT_IR_CIROPS_TD
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
index 4158973f1054b..e13a4650ed6f9 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
@@ -1650,6 +1650,14 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *ce) {
                               cgf.convertType(destTy));
   }
 
+  case CK_VectorSplat: {
+    // Create a vector object and fill all elements with the same scalar value.
+    assert(destTy->isVectorType() && "CK_VectorSplat to non-vector type");
+    return cgf.getBuilder().create<cir::VecSplatOp>(
+        cgf.getLoc(subExpr->getSourceRange()), cgf.convertType(destTy),
+        Visit(subExpr));
+  }
+
   default:
     cgf.getCIRGenModule().errorNYI(subExpr->getSourceRange(),
                                    "CastExpr: ", ce->getCastKindName());
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 3c85bb4b6b41d..ccadcbfd19ca2 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1648,7 +1648,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
                CIRToLLVMUnaryOpLowering,
                CIRToLLVMVecCreateOpLowering,
                CIRToLLVMVecExtractOpLowering,
-               CIRToLLVMVecInsertOpLowering
+               CIRToLLVMVecInsertOpLowering,
+               CIRToLLVMVecSplatOpLowering
       // clang-format on
       >(converter, patterns.getContext());
 
@@ -1773,6 +1774,38 @@ mlir::LogicalResult CIRToLLVMVecInsertOpLowering::matchAndRewrite(
   return mlir::success();
 }
 
+mlir::LogicalResult CIRToLLVMVecSplatOpLowering::matchAndRewrite(
+    cir::VecSplatOp op, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  // Vector splat can be implemented with an `insertelement` and a
+  // `shufflevector`, which is better than an `insertelement` for each
+  // element in the vector. Start with an undef vector. Insert the value into
+  // the first element. Then use a `shufflevector` with a mask of all 0 to
+  // fill out the entire vector with that value.
+  const auto vecTy = mlir::cast<cir::VectorType>(op.getType());
+  const mlir::Type llvmTy = typeConverter->convertType(vecTy);
+  const mlir::Location loc = op.getLoc();
+  const mlir::Value poison = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
+
+  const mlir::Value elementValue = adaptor.getValue();
+  if (mlir::isa<mlir::LLVM::PoisonOp>(elementValue.getDefiningOp())) {
+    // If the splat value is poison, then we can just use poison value
+    // for the entire vector.
+    rewriter.replaceOp(op, poison);
+    return mlir::success();
+  }
+
+  const mlir::Value indexValue =
+      rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0);
+  const mlir::Value oneElement = rewriter.create<mlir::LLVM::InsertElementOp>(
+      loc, poison, elementValue, indexValue);
+  const SmallVector<int32_t> zeroValues(vecTy.getSize(), 0);
+  const mlir::Value shuffled = rewriter.create<mlir::LLVM::ShuffleVectorOp>(
+      loc, oneElement, poison, zeroValues);
+  rewriter.replaceOp(op, shuffled);
+  return mlir::success();
+}
+
 std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
   return std::make_unique<ConvertCIRToLLVMPass>();
 }
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
index bd077e3d1d1e0..9eb44c9d60d6e 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
@@ -332,6 +332,16 @@ class CIRToLLVMVecInsertOpLowering
                   mlir::ConversionPatternRewriter &) const override;
 };
 
+class CIRToLLVMVecSplatOpLowering
+    : public mlir::OpConversionPattern<cir::VecSplatOp> {
+public:
+  using mlir::OpConversionPattern<cir::VecSplatOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(cir::VecSplatOp op, OpAdaptor,
+                  mlir::ConversionPatternRewriter &) const override;
+};
+
 } // namespace direct
 } // namespace cir
 
diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp
index a16ef42f113df..bf7dbbe7fa579 100644
--- a/clang/test/CIR/CodeGen/vector-ext.cpp
+++ b/clang/test/CIR/CodeGen/vector-ext.cpp
@@ -6,6 +6,7 @@
 // RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG
 
 typedef int vi4 __attribute__((ext_vector_type(4)));
+typedef unsigned int uvi4 __attribute__((ext_vector_type(4)));
 typedef int vi3 __attribute__((ext_vector_type(3)));
 typedef int vi2 __attribute__((ext_vector_type(2)));
 typedef double vd2 __attribute__((ext_vector_type(2)));
@@ -400,4 +401,67 @@ void foo9() {
 // OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
 // OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
 // OGCG: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
-// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
\ No newline at end of file
+// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
+
+void foo11() {
+  vi4 a = {1, 2, 3, 4};
+  vi4 shl = a << 3;
+
+  uvi4 b = {1u, 2u, 3u, 4u};
+  uvi4 shr = b >> 3u;
+}
+
+// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["b", init]
+// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["shr", init]
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !s32i, !cir.vector<4 x !s32i>
+// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CIR: cir.store %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !u32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !u32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !u32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !u32i
+// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !u32i, !u32i, !u32i, !u32i) : !cir.vector<4 x !u32i>
+// CIR: cir.store %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
+// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !u32i>>, !cir.vector<4 x !u32i>
+// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !u32i
+// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !u32i, !cir.vector<4 x !u32i>
+// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_B]] : !cir.vector<4 x !u32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !u32i>) -> !cir.vector<4 x !u32i>
+// CIR: cir.store %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
+
+// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
+// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
+// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// LLVM: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
+// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
+
+// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
+// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
+// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// OGCG: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
+// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
diff --git a/clang/test/CIR/CodeGen/vector.cpp b/clang/test/CIR/CodeGen/vector.cpp
index 4546215865095..6f1622eb12c27 100644
--- a/clang/test/CIR/CodeGen/vector.cpp
+++ b/clang/test/CIR/CodeGen/vector.cpp
@@ -6,6 +6,7 @@
 // RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG
 
 typedef int vi4 __attribute__((vector_size(16)));
+typedef unsigned int uvi4 __attribute__((vector_size(16)));
 typedef double vd2 __attribute__((vector_size(16)));
 typedef long long vll2 __attribute__((vector_size(16)));
 
@@ -388,4 +389,67 @@ void foo9() {
 // OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
 // OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
 // OGCG: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
-// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
\ No newline at end of file
+// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
+
+void foo11() {
+  vi4 a = {1, 2, 3, 4};
+  vi4 shl = a << 3;
+
+  uvi4 b = {1u, 2u, 3u, 4u};
+  uvi4 shr = b >> 3u;
+}
+
+// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["b", init]
+// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["shr", init]
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !s32i, !cir.vector<4 x !s32i>
+// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CIR: cir.store %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !u32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !u32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !u32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !u32i
+// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !u32i, !u32i, !u32i, !u32i) : !cir.vector<4 x !u32i>
+// CIR: cir.store %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
+// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !u32i>>, !cir.vector<4 x !u32i>
+// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !u32i
+// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !u32i, !cir.vector<4 x !u32i>
+// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_B]] : !cir.vector<4 x !u32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !u32i>) -> !cir.vector<4 x !u32i>
+// CIR: cir.store %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
+
+// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
+// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
+// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// LLVM: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
+// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
+
+// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
+// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
+// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// OGCG: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
+// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
diff --git a/clang/test/CIR/IR/vector.cir b/clang/test/CIR/IR/vector.cir
index 21a1f0a7559c4..bfc87350ecba1 100644
--- a/clang/test/CIR/IR/vector.cir
+++ b/clang/test/CIR/IR/vector.cir
@@ -135,4 +135,38 @@ cir.func @vector_insert_element_test() {
 // CHECK:    cir.return
 // CHECK:  }
 
+cir.func @vector_splat_test() {
+    %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+    %1 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+    %2 = cir.const #cir.int<1> : !s32i
+    %3 = cir.const #cir.int<2> : !s32i
+    %4 = cir.const #cir.int<3> : !s32i
+    %5 = cir.const #cir.int<4> : !s32i
+    %6 = cir.vec.create(%2, %3, %4, %5 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+    cir.store %6, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+    %7 = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+    %8 = cir.const #cir.int<3> : !s32i
+    %9 = cir.vec.splat %8 : !s32i, !cir.vector<4 x !s32i>
+    %10 = cir.shift(left, %7 : !cir.vector<4 x !s32i>, %9 : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+    cir.store %10, %1 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+    cir.return
+}
+
+// CHECK: cir.func @vector_splat_test() {
+// CHECK:    %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CHECK:    %1 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+// CHECK:    %2 = cir.const #cir.int<1> : !s32i
+// CHECK:    %3 = cir.const #cir.int<2> : !s32i
+// CHECK:    %4 = cir.const #cir.int<3> : !s32i
+// CHECK:    %5 = cir.const #cir.int<4> : !s32i
+// CHECK:    %6 = cir.vec.create(%2, %3, %4, %5 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CHECK:    cir.store %6, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CHECK:    %7 = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CHECK:    %8 = cir.const #cir.int<3> : !s32i
+// CHECK:    %9 = cir.vec.splat %8 : !s32i, !cir.vector<4 x !s32i>
+// CHECK:    %10 = cir.shift(left, %7 : !cir.vector<4 x !s32i>, %9 : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CHECK:    cir.store %10, %1 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CHECK:    cir.return
+// CHECK: }
+
 }

@AmrDeveloper
Copy link
Member Author

More test cases will be added when Bin and Comp operators are upstreamed for Vector

case CK_VectorSplat: {
// Create a vector object and fill all elements with the same scalar value.
assert(destTy->isVectorType() && "CK_VectorSplat to non-vector type");
return cgf.getBuilder().create<cir::VecSplatOp>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return cgf.getBuilder().create<cir::VecSplatOp>(
return builder.create<cir::VecSplatOp>(

const auto vecTy = mlir::cast<cir::VectorType>(op.getType());
const mlir::Type llvmTy = typeConverter->convertType(vecTy);
const mlir::Location loc = op.getLoc();
const mlir::Value poison = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because of the way MLIR value semantics work, using const here doesn't really make sense. This applies throughout this function for mlir::Value.

See https://mlir.llvm.org/docs/Rationale/UsageOfConst/

const mlir::Location loc = op.getLoc();
const mlir::Value poison = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);

const mlir::Value elementValue = adaptor.getValue();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use something like llvm.mlir.constant(dense<1.0> : vector<4xf32>) : vector<4xf32> if elementValue is a constant?

I see that this will end up getting lowered to the correct splat constant in LLVM IR, but it seems like we should be representing that somehow in the LLVM dialect also.

}

// CHECK: cir.func @vector_splat_test() {
// CHECK: %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All theses SSA values should use FileCheck regexes (like you did in the other test).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will update it, I think we need to update other tests in CIR/IR too

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you are talking about pre-existing tests (not related to this PR), it's fine if you update just the relevant ones here (and in a follow up NFC PR you can update others)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I was talking about the other tests, I will update them later as NFC

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update the checks in this PR before committing?

Comment on lines 1803 to 1917
const mlir::Value shuffled = rewriter.create<mlir::LLVM::ShuffleVectorOp>(
loc, oneElement, poison, zeroValues);
rewriter.replaceOp(op, shuffled);
Copy link
Contributor

@xlauko xlauko May 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const mlir::Value shuffled = rewriter.create<mlir::LLVM::ShuffleVectorOp>(
loc, oneElement, poison, zeroValues);
rewriter.replaceOp(op, shuffled);
rewriter.replaceOpWithNewOp<mlir::LLVM::ShuffleVectorOp>(op, oneElement,
poison, zeroValues);

@AmrDeveloper AmrDeveloper force-pushed the cir_upstream_splat_vec_op branch from e9b7f55 to 4216c4b Compare May 29, 2025 16:05
}];
}

#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's saying there is no new line at the end of this file.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about the icon, but i think yes, i will check and update it now

@AmrDeveloper AmrDeveloper force-pushed the cir_upstream_splat_vec_op branch 2 times, most recently from 402ff5a to 29cc489 Compare June 3, 2025 17:22
Copy link
Member

@bcardosolopes bcardosolopes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM once all comments are address (looks like you did already?)

@xlauko
Copy link
Contributor

xlauko commented Jun 5, 2025

Since #142222 was merged, please mirror additional changes from llvm/clangir#1626 before merging.

@AmrDeveloper AmrDeveloper force-pushed the cir_upstream_splat_vec_op branch from 29cc489 to 6097632 Compare June 5, 2025 10:20
@AmrDeveloper
Copy link
Member Author

Since #142222 was merged, please mirror additional changes from llvm/clangir#1626 before merging.

I mirrored changes related to the Splat op

Comment on lines 1914 to 1916
const mlir::Value indexValue = rewriter.create<mlir::LLVM::ConstantOp>(
loc, denseVec.getType(), denseVec);
rewriter.replaceOp(op, indexValue);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not replaceOpWithNewOp?

Comment on lines 1924 to 1926
const mlir::Value indexValue = rewriter.create<mlir::LLVM::ConstantOp>(
loc, denseVec.getType(), denseVec);
rewriter.replaceOp(op, indexValue);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not replaceOpWithNewOp?

// element in the vector. Start with an undef vector. Insert the value into
// the first element. Then use a `shufflevector` with a mask of all 0 to
// fill out the entire vector with that value.
const auto vecTy = mlir::cast<cir::VectorType>(op.getType());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the cast is no longer needed:

Suggested change
const auto vecTy = mlir::cast<cir::VectorType>(op.getType());
const cir::VectorType vecTy = op.getType();

@AmrDeveloper AmrDeveloper requested review from xlauko and andykaylor June 6, 2025 22:26
@AmrDeveloper AmrDeveloper force-pushed the cir_upstream_splat_vec_op branch from bc1d1ec to c6df5f8 Compare June 8, 2025 14:26
Copy link
Contributor

@xlauko xlauko left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@AmrDeveloper AmrDeveloper force-pushed the cir_upstream_splat_vec_op branch from 04a6dd6 to 93e1737 Compare June 10, 2025 16:42
Copy link
Contributor

@andykaylor andykaylor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other than pattern matching the value identifiers in the test, this looks good to me.

}

// CHECK: cir.func @vector_splat_test() {
// CHECK: %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update the checks in this PR before committing?

@AmrDeveloper
Copy link
Member Author

vector_splat_test
Can you update the checks in this PR before committing?

Sure, I will update the splat test function in this PR, and in NFS PR i will update the old tests

@AmrDeveloper AmrDeveloper force-pushed the cir_upstream_splat_vec_op branch from e9bef3a to 28f80c6 Compare June 10, 2025 21:34
@AmrDeveloper AmrDeveloper merged commit 370e54d into llvm:main Jun 11, 2025
7 checks passed
tomtor pushed a commit to tomtor/llvm-project that referenced this pull request Jun 14, 2025
This change adds support for splat op for VectorType

Issue llvm#136487
akuhlens pushed a commit to akuhlens/llvm-project that referenced this pull request Jun 24, 2025
This change adds support for splat op for VectorType

Issue llvm#136487
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants