Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2155,6 +2155,8 @@ def VecCmpOp : CIR_Op<"vec.cmp", [Pure, SameTypeOperands]> {
`(` $kind `,` $lhs `,` $rhs `)` `:` qualified(type($lhs)) `,`
qualified(type($result)) attr-dict
}];

let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
98 changes: 98 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1589,6 +1589,104 @@ OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
return elements[index];
}

//===----------------------------------------------------------------------===//
// VecCmpOp
//===----------------------------------------------------------------------===//

OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
auto lhsVecAttr =
mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getLhs());
auto rhsVecAttr =
mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getRhs());
if (!lhsVecAttr || !rhsVecAttr)
return {};

mlir::Type inputElemTy =
mlir::cast<cir::VectorType>(lhsVecAttr.getType()).getElementType();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't use auto here. The cast gives the misleading impression that it isn't needed, but this is actually the result of getElementType(), right?

if (!isAnyIntegerOrFloatingPointType(inputElemTy))
return {};

cir::CmpOpKind opKind = adaptor.getKind();
mlir::ArrayAttr lhsVecElhs = lhsVecAttr.getElts();
mlir::ArrayAttr rhsVecElhs = rhsVecAttr.getElts();
uint64_t vecSize = lhsVecElhs.size();

SmallVector<mlir::Attribute, 16> elements(vecSize);
bool isIntAttr = vecSize && mlir::isa<cir::IntAttr>(lhsVecElhs[0]);
for (uint64_t i = 0; i < vecSize; i++) {
mlir::Attribute lhsAttr = lhsVecElhs[i];
mlir::Attribute rhsAttr = rhsVecElhs[i];
int cmpResult = 0;
switch (opKind) {
case cir::CmpOpKind::lt: {
if (isIntAttr) {
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <
mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
} else {
cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <
mlir::cast<cir::FPAttr>(rhsAttr).getValue();
}
break;
}
case cir::CmpOpKind::le: {
if (isIntAttr) {
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <=
mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
} else {
cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <=
mlir::cast<cir::FPAttr>(rhsAttr).getValue();
}
break;
}
case cir::CmpOpKind::gt: {
if (isIntAttr) {
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >
mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
} else {
cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >
mlir::cast<cir::FPAttr>(rhsAttr).getValue();
}
break;
}
case cir::CmpOpKind::ge: {
if (isIntAttr) {
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >=
mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
} else {
cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >=
mlir::cast<cir::FPAttr>(rhsAttr).getValue();
}
break;
}
case cir::CmpOpKind::eq: {
if (isIntAttr) {
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() ==
mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
} else {
cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() ==
mlir::cast<cir::FPAttr>(rhsAttr).getValue();
}
break;
}
case cir::CmpOpKind::ne: {
if (isIntAttr) {
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() !=
mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
} else {
cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() !=
mlir::cast<cir::FPAttr>(rhsAttr).getValue();
}
break;
}
}

elements[i] = cir::IntAttr::get(getType().getElementType(), cmpResult);
}

return cir::ConstVectorAttr::get(
getType(), mlir::ArrayAttr::get(getContext(), elements));
}

//===----------------------------------------------------------------------===//
// VecShuffleOp
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ void CIRCanonicalizePass::runOnOperation() {
// Many operations are here to perform a manual `fold` in
// applyOpPatternsGreedily.
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp,
ComplexCreateOp, VecCreateOp, VecExtractOp, VecShuffleOp,
ComplexCreateOp, VecCmpOp, VecCreateOp, VecExtractOp, VecShuffleOp,
VecShuffleDynamicOp, VecTernaryOp>(op))
ops.push_back(op);
});
Expand Down
227 changes: 227 additions & 0 deletions clang/test/CIR/Transforms/vector-cmp-fold.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
// RUN: cir-opt %s -cir-canonicalize -o - -split-input-file | FileCheck %s

!s32i = !cir.int<s, 32>

module {
cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
%vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
%vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
%new_vec = cir.vec.cmp(eq, %vec_1, %vec_2) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
cir.return %new_vec : !cir.vector<4 x !s32i>
}

// CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
// CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<0> : !s32i, #cir.int<0> : !s32i,
// CHECK-SAME: #cir.int<0> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
}

// -----

!s32i = !cir.int<s, 32>

module {
cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
%vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
%vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
%new_vec = cir.vec.cmp(ne, %vec_1, %vec_2) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
cir.return %new_vec : !cir.vector<4 x !s32i>
}

// CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
// CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
// CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
}

// -----

!s32i = !cir.int<s, 32>

module {
cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
%vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
%vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
%new_vec = cir.vec.cmp(lt, %vec_1, %vec_2) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
cir.return %new_vec : !cir.vector<4 x !s32i>
}

// CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
// CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
// CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
}

// -----

!s32i = !cir.int<s, 32>

module {
cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
%vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
%vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
%new_vec = cir.vec.cmp(le, %vec_1, %vec_2) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
cir.return %new_vec : !cir.vector<4 x !s32i>
}

// CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
// CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
// CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
}

// -----

!s32i = !cir.int<s, 32>

module {
cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
%vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
%vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
%new_vec = cir.vec.cmp(gt, %vec_1, %vec_2) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
cir.return %new_vec : !cir.vector<4 x !s32i>
}

// CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
// CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<0> : !s32i, #cir.int<0> : !s32i,
// CHECK-SAME: #cir.int<0> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
}

// -----

!s32i = !cir.int<s, 32>

module {
cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
%vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
%vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
%new_vec = cir.vec.cmp(gt, %vec_1, %vec_2) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
cir.return %new_vec : !cir.vector<4 x !s32i>
}

// CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
// CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<0> : !s32i, #cir.int<0> : !s32i,
// CHECK-SAME: #cir.int<0> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
}

// -----

!s32i = !cir.int<s, 32>

module {
cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
%vec_1 = cir.const #cir.const_vector<[#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00>
: !cir.float, #cir.fp<3.000000e+00> : !cir.float, #cir.fp<4.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
%vec_2 = cir.const #cir.const_vector<[#cir.fp<5.000000e+00> : !cir.float, #cir.fp<6.000000e+00>
: !cir.float, #cir.fp<7.000000e+00> : !cir.float, #cir.fp<8.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
%new_vec = cir.vec.cmp(eq, %vec_1, %vec_2) : !cir.vector<4 x !cir.float>, !cir.vector<4 x !s32i>
cir.return %new_vec : !cir.vector<4 x !s32i>
}

// CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
// CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<0> : !s32i, #cir.int<0> : !s32i,
// CHECK-SAME: #cir.int<0> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
}

// -----

!s32i = !cir.int<s, 32>

module {
cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
%vec_1 = cir.const #cir.const_vector<[#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00>
: !cir.float, #cir.fp<3.000000e+00> : !cir.float, #cir.fp<4.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
%vec_2 = cir.const #cir.const_vector<[#cir.fp<5.000000e+00> : !cir.float, #cir.fp<6.000000e+00>
: !cir.float, #cir.fp<7.000000e+00> : !cir.float, #cir.fp<8.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
%new_vec = cir.vec.cmp(ne, %vec_1, %vec_2) : !cir.vector<4 x !cir.float>, !cir.vector<4 x !s32i>
cir.return %new_vec : !cir.vector<4 x !s32i>
}

// CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
// CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
// CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
}

// -----

!s32i = !cir.int<s, 32>

module {
cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
%vec_1 = cir.const #cir.const_vector<[#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00>
: !cir.float, #cir.fp<3.000000e+00> : !cir.float, #cir.fp<4.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
%vec_2 = cir.const #cir.const_vector<[#cir.fp<5.000000e+00> : !cir.float, #cir.fp<6.000000e+00>
: !cir.float, #cir.fp<7.000000e+00> : !cir.float, #cir.fp<8.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
%new_vec = cir.vec.cmp(lt, %vec_1, %vec_2) : !cir.vector<4 x !cir.float>, !cir.vector<4 x !s32i>
cir.return %new_vec : !cir.vector<4 x !s32i>
}

// CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
// CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
// CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
}

// -----

!s32i = !cir.int<s, 32>

module {
cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
%vec_1 = cir.const #cir.const_vector<[#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00>
: !cir.float, #cir.fp<3.000000e+00> : !cir.float, #cir.fp<4.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
%vec_2 = cir.const #cir.const_vector<[#cir.fp<5.000000e+00> : !cir.float, #cir.fp<6.000000e+00>
: !cir.float, #cir.fp<7.000000e+00> : !cir.float, #cir.fp<8.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
%new_vec = cir.vec.cmp(le, %vec_1, %vec_2) : !cir.vector<4 x !cir.float>, !cir.vector<4 x !s32i>
cir.return %new_vec : !cir.vector<4 x !s32i>
}

// CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
// CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
// CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
}

// -----

!s32i = !cir.int<s, 32>

module {
cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
%vec_1 = cir.const #cir.const_vector<[#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00>
: !cir.float, #cir.fp<3.000000e+00> : !cir.float, #cir.fp<4.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
%vec_2 = cir.const #cir.const_vector<[#cir.fp<5.000000e+00> : !cir.float, #cir.fp<6.000000e+00>
: !cir.float, #cir.fp<7.000000e+00> : !cir.float, #cir.fp<8.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
%new_vec = cir.vec.cmp(gt, %vec_1, %vec_2) : !cir.vector<4 x !cir.float>, !cir.vector<4 x !s32i>
cir.return %new_vec : !cir.vector<4 x !s32i>
}

// CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
// CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<0> : !s32i, #cir.int<0> : !s32i,
// CHECK-SAME: #cir.int<0> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
}

// -----

!s32i = !cir.int<s, 32>

module {
cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
%vec_1 = cir.const #cir.const_vector<[#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00>
: !cir.float, #cir.fp<3.000000e+00> : !cir.float, #cir.fp<4.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
%vec_2 = cir.const #cir.const_vector<[#cir.fp<5.000000e+00> : !cir.float, #cir.fp<6.000000e+00>
: !cir.float, #cir.fp<7.000000e+00> : !cir.float, #cir.fp<8.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
%new_vec = cir.vec.cmp(ge, %vec_1, %vec_2) : !cir.vector<4 x !cir.float>, !cir.vector<4 x !s32i>
cir.return %new_vec : !cir.vector<4 x !s32i>
}

// CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
// CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<0> : !s32i, #cir.int<0> : !s32i,
// CHECK-SAME: #cir.int<0> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
}