[CIR] Upstream splat op for VectorType #139827

Conversation
@llvm/pr-subscribers-clang

Author: Amr Hesham (AmrDeveloper)

Changes

This change adds support for splat op for VectorType

Issue #136487

Full diff: https://github.com/llvm/llvm-project/pull/139827.diff

7 Files Affected:
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 9f5fa266742e8..463ef929509bd 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2037,4 +2037,37 @@ def VecExtractOp : CIR_Op<"vec.extract", [Pure,
let hasFolder = 1;
}
+
+//===----------------------------------------------------------------------===//
+// VecSplat
+//===----------------------------------------------------------------------===//
+
+// cir.vec.splat is a separate operation from cir.vec.create because more
+// efficient LLVM IR can be generated for it, and because some optimization and
+// analysis passes can benefit from knowing that all elements of the vector
+// have the same value.
+
+def VecSplatOp : CIR_Op<"vec.splat", [Pure,
+ TypesMatchWith<"type of 'value' matches element type of 'result'", "result",
+ "value", "cast<VectorType>($_self).getElementType()">]> {
+
+ let summary = "Convert a scalar into a vector";
+ let description = [{
+ The `cir.vec.splat` operation creates a vector value from a scalar value.
+ All elements of the vector have the same value, that of the given scalar.
+
+ ```mlir
+ %value = cir.const #cir.int<3> : !s32i
+ %value_vec = cir.vec.splat %value : !s32i, !cir.vector<4 x !s32i>
+ ```
+ }];
+
+ let arguments = (ins CIR_AnyType:$value);
+ let results = (outs CIR_VectorType:$result);
+
+ let assemblyFormat = [{
+ $value `:` type($value) `,` qualified(type($result)) attr-dict
+ }];
+}
+
#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
index 4158973f1054b..e13a4650ed6f9 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
@@ -1650,6 +1650,14 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *ce) {
cgf.convertType(destTy));
}
+ case CK_VectorSplat: {
+ // Create a vector object and fill all elements with the same scalar value.
+ assert(destTy->isVectorType() && "CK_VectorSplat to non-vector type");
+ return cgf.getBuilder().create<cir::VecSplatOp>(
+ cgf.getLoc(subExpr->getSourceRange()), cgf.convertType(destTy),
+ Visit(subExpr));
+ }
+
default:
cgf.getCIRGenModule().errorNYI(subExpr->getSourceRange(),
"CastExpr: ", ce->getCastKindName());
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 3c85bb4b6b41d..ccadcbfd19ca2 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1648,7 +1648,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
CIRToLLVMUnaryOpLowering,
CIRToLLVMVecCreateOpLowering,
CIRToLLVMVecExtractOpLowering,
- CIRToLLVMVecInsertOpLowering
+ CIRToLLVMVecInsertOpLowering,
+ CIRToLLVMVecSplatOpLowering
// clang-format on
>(converter, patterns.getContext());
@@ -1773,6 +1774,38 @@ mlir::LogicalResult CIRToLLVMVecInsertOpLowering::matchAndRewrite(
return mlir::success();
}
+mlir::LogicalResult CIRToLLVMVecSplatOpLowering::matchAndRewrite(
+ cir::VecSplatOp op, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const {
+ // Vector splat can be implemented with an `insertelement` and a
+ // `shufflevector`, which is better than an `insertelement` for each
+ // element in the vector. Start with an undef vector. Insert the value into
+ // the first element. Then use a `shufflevector` with a mask of all 0 to
+ // fill out the entire vector with that value.
+ const auto vecTy = mlir::cast<cir::VectorType>(op.getType());
+ const mlir::Type llvmTy = typeConverter->convertType(vecTy);
+ const mlir::Location loc = op.getLoc();
+ const mlir::Value poison = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
+
+ const mlir::Value elementValue = adaptor.getValue();
+ if (mlir::isa<mlir::LLVM::PoisonOp>(elementValue.getDefiningOp())) {
+ // If the splat value is poison, then we can just use poison value
+ // for the entire vector.
+ rewriter.replaceOp(op, poison);
+ return mlir::success();
+ }
+
+ const mlir::Value indexValue =
+ rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0);
+ const mlir::Value oneElement = rewriter.create<mlir::LLVM::InsertElementOp>(
+ loc, poison, elementValue, indexValue);
+ const SmallVector<int32_t> zeroValues(vecTy.getSize(), 0);
+ const mlir::Value shuffled = rewriter.create<mlir::LLVM::ShuffleVectorOp>(
+ loc, oneElement, poison, zeroValues);
+ rewriter.replaceOp(op, shuffled);
+ return mlir::success();
+}
+
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
return std::make_unique<ConvertCIRToLLVMPass>();
}
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
index bd077e3d1d1e0..9eb44c9d60d6e 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
@@ -332,6 +332,16 @@ class CIRToLLVMVecInsertOpLowering
mlir::ConversionPatternRewriter &) const override;
};
+class CIRToLLVMVecSplatOpLowering
+ : public mlir::OpConversionPattern<cir::VecSplatOp> {
+public:
+ using mlir::OpConversionPattern<cir::VecSplatOp>::OpConversionPattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(cir::VecSplatOp op, OpAdaptor,
+ mlir::ConversionPatternRewriter &) const override;
+};
+
} // namespace direct
} // namespace cir
diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp
index a16ef42f113df..bf7dbbe7fa579 100644
--- a/clang/test/CIR/CodeGen/vector-ext.cpp
+++ b/clang/test/CIR/CodeGen/vector-ext.cpp
@@ -6,6 +6,7 @@
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG
typedef int vi4 __attribute__((ext_vector_type(4)));
+typedef unsigned int uvi4 __attribute__((ext_vector_type(4)));
typedef int vi3 __attribute__((ext_vector_type(3)));
typedef int vi2 __attribute__((ext_vector_type(2)));
typedef double vd2 __attribute__((ext_vector_type(2)));
@@ -400,4 +401,67 @@ void foo9() {
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
// OGCG: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
-// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
\ No newline at end of file
+// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
+
+void foo11() {
+ vi4 a = {1, 2, 3, 4};
+ vi4 shl = a << 3;
+
+ uvi4 b = {1u, 2u, 3u, 4u};
+ uvi4 shr = b >> 3u;
+}
+
+// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["b", init]
+// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["shr", init]
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !s32i, !cir.vector<4 x !s32i>
+// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CIR: cir.store %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !u32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !u32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !u32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !u32i
+// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !u32i, !u32i, !u32i, !u32i) : !cir.vector<4 x !u32i>
+// CIR: cir.store %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
+// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !u32i>>, !cir.vector<4 x !u32i>
+// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !u32i
+// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !u32i, !cir.vector<4 x !u32i>
+// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_B]] : !cir.vector<4 x !u32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !u32i>) -> !cir.vector<4 x !u32i>
+// CIR: cir.store %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
+
+// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
+// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
+// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// LLVM: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
+// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
+
+// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
+// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
+// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// OGCG: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
+// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
diff --git a/clang/test/CIR/CodeGen/vector.cpp b/clang/test/CIR/CodeGen/vector.cpp
index 4546215865095..6f1622eb12c27 100644
--- a/clang/test/CIR/CodeGen/vector.cpp
+++ b/clang/test/CIR/CodeGen/vector.cpp
@@ -6,6 +6,7 @@
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG
typedef int vi4 __attribute__((vector_size(16)));
+typedef unsigned int uvi4 __attribute__((vector_size(16)));
typedef double vd2 __attribute__((vector_size(16)));
typedef long long vll2 __attribute__((vector_size(16)));
@@ -388,4 +389,67 @@ void foo9() {
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
// OGCG: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
-// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
\ No newline at end of file
+// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
+
+void foo11() {
+ vi4 a = {1, 2, 3, 4};
+ vi4 shl = a << 3;
+
+ uvi4 b = {1u, 2u, 3u, 4u};
+ uvi4 shr = b >> 3u;
+}
+
+// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["b", init]
+// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["shr", init]
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !s32i, !cir.vector<4 x !s32i>
+// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CIR: cir.store %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !u32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !u32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !u32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !u32i
+// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !u32i, !u32i, !u32i, !u32i) : !cir.vector<4 x !u32i>
+// CIR: cir.store %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
+// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !u32i>>, !cir.vector<4 x !u32i>
+// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !u32i
+// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !u32i, !cir.vector<4 x !u32i>
+// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_B]] : !cir.vector<4 x !u32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !u32i>) -> !cir.vector<4 x !u32i>
+// CIR: cir.store %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
+
+// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
+// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
+// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// LLVM: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
+// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
+
+// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
+// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
+// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// OGCG: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
+// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
diff --git a/clang/test/CIR/IR/vector.cir b/clang/test/CIR/IR/vector.cir
index 21a1f0a7559c4..bfc87350ecba1 100644
--- a/clang/test/CIR/IR/vector.cir
+++ b/clang/test/CIR/IR/vector.cir
@@ -135,4 +135,38 @@ cir.func @vector_insert_element_test() {
// CHECK: cir.return
// CHECK: }
+cir.func @vector_splat_test() {
+ %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+ %1 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+ %2 = cir.const #cir.int<1> : !s32i
+ %3 = cir.const #cir.int<2> : !s32i
+ %4 = cir.const #cir.int<3> : !s32i
+ %5 = cir.const #cir.int<4> : !s32i
+ %6 = cir.vec.create(%2, %3, %4, %5 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+ cir.store %6, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+ %7 = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+ %8 = cir.const #cir.int<3> : !s32i
+ %9 = cir.vec.splat %8 : !s32i, !cir.vector<4 x !s32i>
+ %10 = cir.shift(left, %7 : !cir.vector<4 x !s32i>, %9 : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+ cir.store %10, %1 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+ cir.return
+}
+
+// CHECK: cir.func @vector_splat_test() {
+// CHECK: %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CHECK: %1 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+// CHECK: %2 = cir.const #cir.int<1> : !s32i
+// CHECK: %3 = cir.const #cir.int<2> : !s32i
+// CHECK: %4 = cir.const #cir.int<3> : !s32i
+// CHECK: %5 = cir.const #cir.int<4> : !s32i
+// CHECK: %6 = cir.vec.create(%2, %3, %4, %5 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CHECK: cir.store %6, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CHECK: %7 = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CHECK: %8 = cir.const #cir.int<3> : !s32i
+// CHECK: %9 = cir.vec.splat %8 : !s32i, !cir.vector<4 x !s32i>
+// CHECK: %10 = cir.shift(left, %7 : !cir.vector<4 x !s32i>, %9 : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CHECK: cir.store %10, %1 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CHECK: cir.return
+// CHECK: }
+
}
@llvm/pr-subscribers-clangir

Author: Amr Hesham (AmrDeveloper)

Changes

This change adds support for splat op for VectorType

Issue #136487

Full diff: https://github.com/llvm/llvm-project/pull/139827.diff
More test cases will be added when the binary and comparison operators are upstreamed for VectorType.
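For reference, a minimal C++ sketch of the kind of source that already exercises this path, modeled on the shift tests in this PR (the function name is illustrative, not taken from the tests): the scalar shift amount is broadcast through a CK_VectorSplat cast, which now lowers to cir.vec.splat.

```cpp
// Modeled on the PR's vector shift tests; `shift_all` is an illustrative name.
typedef int vi4 __attribute__((vector_size(16)));

vi4 shift_all(vi4 v) {
  // The scalar 3 is splatted to <3, 3, 3, 3> before the element-wise shift,
  // producing a cir.vec.splat in CIR.
  return v << 3;
}
```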
case CK_VectorSplat: {
  // Create a vector object and fill all elements with the same scalar value.
  assert(destTy->isVectorType() && "CK_VectorSplat to non-vector type");
  return cgf.getBuilder().create<cir::VecSplatOp>(
Suggested change:
-  return cgf.getBuilder().create<cir::VecSplatOp>(
+  return builder.create<cir::VecSplatOp>(
  const auto vecTy = mlir::cast<cir::VectorType>(op.getType());
  const mlir::Type llvmTy = typeConverter->convertType(vecTy);
  const mlir::Location loc = op.getLoc();
  const mlir::Value poison = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
Because of the way MLIR value semantics work, using `const` here doesn't really make sense. This applies throughout this function for mlir::Value.
  const mlir::Location loc = op.getLoc();
  const mlir::Value poison = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);

  const mlir::Value elementValue = adaptor.getValue();
Can we use something like `llvm.mlir.constant(dense<1.0> : vector<4xf32>) : vector<4xf32>` if elementValue is a constant? I see that this will end up getting lowered to the correct splat constant in LLVM IR, but it seems like we should be representing that somehow in the LLVM dialect also.
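A rough C++ sketch of the constant fast path the reviewer seems to be asking for is below. It is not code from this PR: the helper name is made up, it assumes the surrounding `matchAndRewrite` context (`op`, the adaptor's `elementValue`, `rewriter`, and the converted `llvmTy`), and it assumes the converted vector type is a `ShapedType`. The later revision quoted further down, which builds an LLVM constant from a `denseVec` attribute, takes roughly this shape.

```cpp
// Hypothetical helper, not part of the PR: if the splat element comes from an
// llvm.mlir.constant, emit one dense splat constant instead of the
// insertelement + shufflevector sequence.
static bool tryLowerConstantSplat(cir::VecSplatOp op, mlir::Value elementValue,
                                  mlir::Type llvmTy,
                                  mlir::ConversionPatternRewriter &rewriter) {
  auto constOp = elementValue.getDefiningOp<mlir::LLVM::ConstantOp>();
  if (!constOp)
    return false;
  // A single attribute value is treated as a splat over all lanes.
  mlir::Attribute scalar = constOp.getValue();
  auto denseVec = mlir::DenseElementsAttr::get(
      mlir::cast<mlir::ShapedType>(llvmTy), {scalar});
  rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(op, denseVec.getType(),
                                                      denseVec);
  return true;
}
```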
clang/test/CIR/IR/vector.cir (Outdated)

}

// CHECK: cir.func @vector_splat_test() {
// CHECK: %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
All these SSA values should use FileCheck regexes (like you did in the other test).
I will update it. I think we need to update the other tests in CIR/IR too.
If you are talking about pre-existing tests (not related to this PR), it's fine if you update just the relevant ones here (and in a follow-up NFC PR you can update the others).
Yes, I was talking about the other tests. I will update them later as an NFC change.
Can you update the checks in this PR before committing?
  const mlir::Value shuffled = rewriter.create<mlir::LLVM::ShuffleVectorOp>(
      loc, oneElement, poison, zeroValues);
  rewriter.replaceOp(op, shuffled);
Suggested change:
-  const mlir::Value shuffled = rewriter.create<mlir::LLVM::ShuffleVectorOp>(
-      loc, oneElement, poison, zeroValues);
-  rewriter.replaceOp(op, shuffled);
+  rewriter.replaceOpWithNewOp<mlir::LLVM::ShuffleVectorOp>(op, oneElement,
+                                                           poison, zeroValues);
Force-pushed from e9b7f55 to 4216c4b.
  }];
}

#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD
I think it's saying there is no newline at the end of this file.
Not sure about the icon, but I think yes; I will check and update it now.
Force-pushed from 402ff5a to 29cc489.
LGTM once all comments are addressed (looks like you did already?)
Since #142222 was merged, please mirror additional changes from llvm/clangir#1626 before merging.
Force-pushed from 29cc489 to 6097632.
I mirrored the changes related to the splat op.
  const mlir::Value indexValue = rewriter.create<mlir::LLVM::ConstantOp>(
      loc, denseVec.getType(), denseVec);
  rewriter.replaceOp(op, indexValue);
Why not `replaceOpWithNewOp`?
  // element in the vector. Start with an undef vector. Insert the value into
  // the first element. Then use a `shufflevector` with a mask of all 0 to
  // fill out the entire vector with that value.
  const auto vecTy = mlir::cast<cir::VectorType>(op.getType());
I believe the cast is no longer needed:

Suggested change:
-  const auto vecTy = mlir::cast<cir::VectorType>(op.getType());
+  const cir::VectorType vecTy = op.getType();
Force-pushed from bc1d1ec to c6df5f8.
lgtm
Force-pushed from 04a6dd6 to 93e1737.
Other than pattern matching the value identifiers in the test, this looks good to me.
clang/test/CIR/IR/vector.cir (Outdated)

}

// CHECK: cir.func @vector_splat_test() {
// CHECK: %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
Can you update the checks in this PR before committing?
Sure, I will update the splat test function in this PR, and in an NFC PR I will update the old tests.
Force-pushed from e9bef3a to 28f80c6.
This change adds support for splat op for VectorType
Issue #136487