diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 237daed32532a..853a321b29309 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -2194,4 +2194,40 @@ def VecShuffleDynamicOp : CIR_Op<"vec.shuffle.dynamic", let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// VecTernaryOp +//===----------------------------------------------------------------------===// + +def VecTernaryOp : CIR_Op<"vec.ternary", + [Pure, AllTypesMatch<["result", "lhs", "rhs"]>]> { + let summary = "The `cond ? a : b` ternary operator for vector types"; + let description = [{ + The `cir.vec.ternary` operation represents the C/C++ ternary operator, + `?:`, for vector types, which does a `select` on individual elements of the + vectors. Unlike a regular `?:` operator, there is no short circuiting. All + three arguments are always evaluated. Because there is no short + circuiting, there are no regions in this operation, unlike cir.ternary. + + The first argument is a vector of integral type. The second and third + arguments are vectors of the same type and have the same number of elements + as the first argument. + + The result is a vector of the same type as the second and third arguments. + Each element of the result is `(bool)a[n] ? b[n] : c[n]`. + }]; + + let arguments = (ins + CIR_VectorOfIntType:$cond, + CIR_VectorType:$lhs, + CIR_VectorType:$rhs + ); + + let results = (outs CIR_VectorType:$result); + let assemblyFormat = [{ + `(` $cond `,` $lhs`,` $rhs `)` `:` qualified(type($cond)) `,` + qualified(type($lhs)) attr-dict + }]; + let hasVerifier = 1; +} + #endif // CLANG_CIR_DIALECT_IR_CIROPS_TD diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp index 77287ec45972d..31264ef1a6693 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp @@ -1956,19 +1956,28 @@ mlir::Value ScalarExprEmitter::VisitAbstractConditionalOperator( } } + QualType condType = condExpr->getType(); + // OpenCL: If the condition is a vector, we can treat this condition like // the select function. - if ((cgf.getLangOpts().OpenCL && condExpr->getType()->isVectorType()) || - condExpr->getType()->isExtVectorType()) { + if ((cgf.getLangOpts().OpenCL && condType->isVectorType()) || + condType->isExtVectorType()) { assert(!cir::MissingFeatures::vectorType()); cgf.cgm.errorNYI(e->getSourceRange(), "vector ternary op"); } - if (condExpr->getType()->isVectorType() || - condExpr->getType()->isSveVLSBuiltinType()) { - assert(!cir::MissingFeatures::vecTernaryOp()); - cgf.cgm.errorNYI(e->getSourceRange(), "vector ternary op"); - return {}; + if (condType->isVectorType() || condType->isSveVLSBuiltinType()) { + if (!condType->isVectorType()) { + assert(!cir::MissingFeatures::vecTernaryOp()); + cgf.cgm.errorNYI(loc, "TernaryOp for SVE vector"); + return {}; + } + + mlir::Value condValue = Visit(condExpr); + mlir::Value lhsValue = Visit(lhsExpr); + mlir::Value rhsValue = Visit(rhsExpr); + return builder.create(loc, condValue, lhsValue, + rhsValue); } // If this is a really simple expression (like x ? 4 : 5), emit this as a diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 36f050de9f8bb..5c41211d130bf 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -1589,6 +1589,23 @@ LogicalResult cir::VecShuffleDynamicOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// VecTernaryOp +//===----------------------------------------------------------------------===// + +LogicalResult cir::VecTernaryOp::verify() { + // Verify that the condition operand has the same number of elements as the + // other operands. (The automatic verification already checked that all + // operands are vector types and that the second and third operands are the + // same type.) + if (getCond().getType().getSize() != getLhs().getType().getSize()) { + return emitOpError() << ": the number of elements in " + << getCond().getType() << " and " << getLhs().getType() + << " don't match"; + } + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index f61e85ce3ccec..048e6a604bf01 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -1730,7 +1730,8 @@ void ConvertCIRToLLVMPass::runOnOperation() { CIRToLLVMVecExtractOpLowering, CIRToLLVMVecInsertOpLowering, CIRToLLVMVecCmpOpLowering, - CIRToLLVMVecShuffleDynamicOpLowering + CIRToLLVMVecShuffleDynamicOpLowering, + CIRToLLVMVecTernaryOpLowering // clang-format on >(converter, patterns.getContext()); @@ -1934,6 +1935,20 @@ mlir::LogicalResult CIRToLLVMVecShuffleDynamicOpLowering::matchAndRewrite( return mlir::success(); } +mlir::LogicalResult CIRToLLVMVecTernaryOpLowering::matchAndRewrite( + cir::VecTernaryOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // Convert `cond` into a vector of i1, then use that in a `select` op. + mlir::Value bitVec = rewriter.create( + op.getLoc(), mlir::LLVM::ICmpPredicate::ne, adaptor.getCond(), + rewriter.create( + op.getCond().getLoc(), + typeConverter->convertType(op.getCond().getType()))); + rewriter.replaceOpWithNewOp( + op, bitVec, adaptor.getLhs(), adaptor.getRhs()); + return mlir::success(); +} + std::unique_ptr createConvertCIRToLLVMPass() { return std::make_unique(); } diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h index de043dfba77b5..d1efa4e181a07 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -368,6 +368,16 @@ class CIRToLLVMVecShuffleDynamicOpLowering mlir::ConversionPatternRewriter &) const override; }; +class CIRToLLVMVecTernaryOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::VecTernaryOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + } // namespace direct } // namespace cir diff --git a/clang/test/CIR/CodeGen/vector.cpp b/clang/test/CIR/CodeGen/vector.cpp index 4c50f68a56162..586438f1fed2e 100644 --- a/clang/test/CIR/CodeGen/vector.cpp +++ b/clang/test/CIR/CodeGen/vector.cpp @@ -1069,4 +1069,59 @@ void foo17() { // OGCG: %[[VEC_A:.*]] = alloca <2 x double>, align 16 // OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16 -// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16> \ No newline at end of file +// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16> + +void foo20() { + vi4 a; + vi4 b; + vi4 c; + vi4 r = c ? a : b; +} + +// CIR: %[[RES:.*]] = cir.vec.ternary({{.*}}, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i> + +// LLVM: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer +// LLVM: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}} + +// OGCG: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer +// OGCG: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}} + +void foo21() { + vi4 a; + vi4 b; + vi4 r = (a > b) ? (a - b) : (b - a); +} + +// CIR: %[[VEC_COND:.*]] = cir.vec.cmp(gt, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i> +// CIR: %[[LHS:.*]] = cir.binop(sub, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i> +// CIR: %[[RHS:.*]] = cir.binop(sub, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i> +// CIR: %[[RES:.*]] = cir.vec.ternary(%[[VEC_COND]], %[[LHS]], %[[RHS]]) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i> + +// LLVM: %[[CMP:.*]] = icmp sgt <4 x i32> {{.*}}, {{.*}} +// LLVM: %[[SEXT:.*]] = sext <4 x i1> %[[CMP]] to <4 x i32> +// LLVM: %[[LHS:.*]] = sub <4 x i32> {{.*}}, {{.*}} +// LLVM: %[[RHS:.*]] = sub <4 x i32> {{.*}}, {{.*}} +// LLVM: %[[VEC_COND:.*]] = icmp ne <4 x i32> %[[SEXT]], zeroinitializer +// LLVM: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> %[[LHS]], <4 x i32> %[[RHS]] + +// OGCG: %[[CMP:.*]] = icmp sgt <4 x i32> {{.*}}, {{.*}} +// OGCG: %[[SEXT:.*]] = sext <4 x i1> %[[CMP]] to <4 x i32> +// OGCG: %[[LHS:.*]] = sub <4 x i32> {{.*}}, {{.*}} +// OGCG: %[[RHS:.*]] = sub <4 x i32> {{.*}}, {{.*}} +// OGCG: %[[VEC_COND:.*]] = icmp ne <4 x i32> %[[SEXT]], zeroinitializer +// OGCG: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> %[[LHS]], <4 x i32> %[[RHS]] + +void foo22() { + vf4 a; + vf4 b; + vi4 c; + vf4 r = c ? a : b; +} + +// CIR: %[[RES:.*]] = cir.vec.ternary({{.*}}, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i>, !cir.vector<4 x !cir.float> + +// LLVM: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer +// LLVM: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x float> {{.*}}, <4 x float> {{.*}} + +// OGCG: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer +// OGCG: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x float> {{.*}}, <4 x float> {{.*}}