-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[Arith][MemRef] add AtomicRMWKind::xori to enum #151701
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Add missing xor AtomicRMWKind enum in arith. Also add support for xor to memref.atomic_rmw so the change can be tested. This does NOT add it for all users of the enum (e.g. Affine, Vector)
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-arith Author: Scott Manley (rscottmanley) ChangesAdd missing xor AtomicRMWKind enum in arith. Also add support for xor to memref.atomic_rmw so the change can be tested. This does NOT add it for all users of the enum (e.g. Affine, Vector) Full diff: https://github.com/llvm/llvm-project/pull/151701.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index 19a2ade2e95a0..e51c20498746b 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -76,27 +76,29 @@ def Arith_CmpIPredicateAttr : I64EnumAttr<
def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>;
def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>;
-def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>;
-def ATOMIC_RMW_KIND_MAXIMUMF : I64EnumAttrCase<"maximumf", 3>;
-def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>;
-def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>;
-def ATOMIC_RMW_KIND_MINIMUMF : I64EnumAttrCase<"minimumf", 6>;
-def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>;
-def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>;
-def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
-def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
-def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>;
-def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>;
-def ATOMIC_RMW_KIND_MAXNUMF : I64EnumAttrCase<"maxnumf", 13>;
-def ATOMIC_RMW_KIND_MINNUMF : I64EnumAttrCase<"minnumf", 14>;
+def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 2>;
+def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 3>;
+def ATOMIC_RMW_KIND_MAXIMUMF : I64EnumAttrCase<"maximumf", 4>;
+def ATOMIC_RMW_KIND_MAXNUMF : I64EnumAttrCase<"maxnumf", 5>;
+def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 6>;
+def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 7>;
+def ATOMIC_RMW_KIND_MINIMUMF : I64EnumAttrCase<"minimumf", 8>;
+def ATOMIC_RMW_KIND_MINNUMF : I64EnumAttrCase<"minnumf", 9>;
+def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 10>;
+def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 11>;
+def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 12>;
+def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 13>;
+def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 14>;
+def ATOMIC_RMW_KIND_XORI : I64EnumAttrCase<"xori", 15>;
def AtomicRMWKindAttr : I64EnumAttr<
"AtomicRMWKind", "",
- [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
- ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
- ATOMIC_RMW_KIND_MINIMUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
+ [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ANDI,
+ ATOMIC_RMW_KIND_ASSIGN, ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXNUMF,
+ ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU, ATOMIC_RMW_KIND_MINIMUMF,
+ ATOMIC_RMW_KIND_MINNUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
- ATOMIC_RMW_KIND_ANDI, ATOMIC_RMW_KIND_MAXNUMF, ATOMIC_RMW_KIND_MINNUMF]> {
+ ATOMIC_RMW_KIND_XORI]> {
let cppNamespace = "::mlir::arith";
}
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index dc2035b0700d0..a45fde4b85ab8 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1871,6 +1871,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
return LLVM::AtomicBinOp::umin;
case arith::AtomicRMWKind::ori:
return LLVM::AtomicBinOp::_or;
+ case arith::AtomicRMWKind::xori:
+ return LLVM::AtomicBinOp::_xor;
case arith::AtomicRMWKind::andi:
return LLVM::AtomicBinOp::_and;
default:
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 488c3c369afed..7d4d818ee448b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2678,6 +2678,7 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
case AtomicRMWKind::addi:
case AtomicRMWKind::maxu:
case AtomicRMWKind::ori:
+ case AtomicRMWKind::xori:
return builder.getZeroAttr(resultType);
case AtomicRMWKind::andi:
return builder.getIntegerAttr(
@@ -2736,7 +2737,7 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
// Integer operations.
.Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
.Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
- .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
+ .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; })
.Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
.Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
.Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
@@ -2806,6 +2807,8 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
return arith::OrIOp::create(builder, loc, lhs, rhs);
case AtomicRMWKind::andi:
return arith::AndIOp::create(builder, loc, lhs, rhs);
+ case AtomicRMWKind::xori:
+ return arith::XOrIOp::create(builder, loc, lhs, rhs);
// TODO: Add remaining reduction operations.
default:
(void)emitOptionalError(loc, "Reduction operation type not supported");
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 74b968c27a62a..b59d73d1291c8 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3558,6 +3558,7 @@ LogicalResult AtomicRMWOp::verify() {
case arith::AtomicRMWKind::minu:
case arith::AtomicRMWKind::muli:
case arith::AtomicRMWKind::ori:
+ case arith::AtomicRMWKind::xori:
case arith::AtomicRMWKind::andi:
if (!llvm::isa<IntegerType>(getValue().getType()))
return emitOpError() << "with kind '"
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 51d56389dac9e..12a46daffcb9e 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -464,7 +464,9 @@ func.func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fv
// CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
- // CHECK-INTERFACE-COUNT-13: llvm.atomicrmw
+ memref.atomic_rmw xori %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+ // CHECK: llvm.atomicrmw _xor %{{.*}}, %{{.*}} acq_rel
+ // CHECK-INTERFACE-COUNT-14: llvm.atomicrmw
return
}
|
@llvm/pr-subscribers-mlir-memref Author: Scott Manley (rscottmanley) ChangesAdd missing xor AtomicRMWKind enum in arith. Also add support for xor to memref.atomic_rmw so the change can be tested. This does NOT add it for all users of the enum (e.g. Affine, Vector) Full diff: https://github.com/llvm/llvm-project/pull/151701.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index 19a2ade2e95a0..e51c20498746b 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -76,27 +76,29 @@ def Arith_CmpIPredicateAttr : I64EnumAttr<
def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>;
def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>;
-def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>;
-def ATOMIC_RMW_KIND_MAXIMUMF : I64EnumAttrCase<"maximumf", 3>;
-def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>;
-def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>;
-def ATOMIC_RMW_KIND_MINIMUMF : I64EnumAttrCase<"minimumf", 6>;
-def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>;
-def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>;
-def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
-def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
-def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>;
-def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>;
-def ATOMIC_RMW_KIND_MAXNUMF : I64EnumAttrCase<"maxnumf", 13>;
-def ATOMIC_RMW_KIND_MINNUMF : I64EnumAttrCase<"minnumf", 14>;
+def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 2>;
+def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 3>;
+def ATOMIC_RMW_KIND_MAXIMUMF : I64EnumAttrCase<"maximumf", 4>;
+def ATOMIC_RMW_KIND_MAXNUMF : I64EnumAttrCase<"maxnumf", 5>;
+def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 6>;
+def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 7>;
+def ATOMIC_RMW_KIND_MINIMUMF : I64EnumAttrCase<"minimumf", 8>;
+def ATOMIC_RMW_KIND_MINNUMF : I64EnumAttrCase<"minnumf", 9>;
+def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 10>;
+def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 11>;
+def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 12>;
+def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 13>;
+def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 14>;
+def ATOMIC_RMW_KIND_XORI : I64EnumAttrCase<"xori", 15>;
def AtomicRMWKindAttr : I64EnumAttr<
"AtomicRMWKind", "",
- [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
- ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
- ATOMIC_RMW_KIND_MINIMUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
+ [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ANDI,
+ ATOMIC_RMW_KIND_ASSIGN, ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXNUMF,
+ ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU, ATOMIC_RMW_KIND_MINIMUMF,
+ ATOMIC_RMW_KIND_MINNUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
- ATOMIC_RMW_KIND_ANDI, ATOMIC_RMW_KIND_MAXNUMF, ATOMIC_RMW_KIND_MINNUMF]> {
+ ATOMIC_RMW_KIND_XORI]> {
let cppNamespace = "::mlir::arith";
}
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index dc2035b0700d0..a45fde4b85ab8 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1871,6 +1871,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
return LLVM::AtomicBinOp::umin;
case arith::AtomicRMWKind::ori:
return LLVM::AtomicBinOp::_or;
+ case arith::AtomicRMWKind::xori:
+ return LLVM::AtomicBinOp::_xor;
case arith::AtomicRMWKind::andi:
return LLVM::AtomicBinOp::_and;
default:
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 488c3c369afed..7d4d818ee448b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2678,6 +2678,7 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
case AtomicRMWKind::addi:
case AtomicRMWKind::maxu:
case AtomicRMWKind::ori:
+ case AtomicRMWKind::xori:
return builder.getZeroAttr(resultType);
case AtomicRMWKind::andi:
return builder.getIntegerAttr(
@@ -2736,7 +2737,7 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
// Integer operations.
.Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
.Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
- .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
+ .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; })
.Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
.Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
.Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
@@ -2806,6 +2807,8 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
return arith::OrIOp::create(builder, loc, lhs, rhs);
case AtomicRMWKind::andi:
return arith::AndIOp::create(builder, loc, lhs, rhs);
+ case AtomicRMWKind::xori:
+ return arith::XOrIOp::create(builder, loc, lhs, rhs);
// TODO: Add remaining reduction operations.
default:
(void)emitOptionalError(loc, "Reduction operation type not supported");
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 74b968c27a62a..b59d73d1291c8 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3558,6 +3558,7 @@ LogicalResult AtomicRMWOp::verify() {
case arith::AtomicRMWKind::minu:
case arith::AtomicRMWKind::muli:
case arith::AtomicRMWKind::ori:
+ case arith::AtomicRMWKind::xori:
case arith::AtomicRMWKind::andi:
if (!llvm::isa<IntegerType>(getValue().getType()))
return emitOpError() << "with kind '"
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 51d56389dac9e..12a46daffcb9e 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -464,7 +464,9 @@ func.func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fv
// CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
- // CHECK-INTERFACE-COUNT-13: llvm.atomicrmw
+ memref.atomic_rmw xori %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+ // CHECK: llvm.atomicrmw _xor %{{.*}}, %{{.*}} acq_rel
+ // CHECK-INTERFACE-COUNT-14: llvm.atomicrmw
return
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SPIR-V conversion seems missing:
llvm-project/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
Lines 424 to 432 in 7c2f109
switch (atomicOp.getKind()) { | |
ATOMIC_CASE(addi, AtomicIAddOp); | |
ATOMIC_CASE(maxs, AtomicSMaxOp); | |
ATOMIC_CASE(maxu, AtomicUMaxOp); | |
ATOMIC_CASE(mins, AtomicSMinOp); | |
ATOMIC_CASE(minu, AtomicUMinOp); | |
ATOMIC_CASE(ori, AtomicOrOp); | |
ATOMIC_CASE(andi, AtomicAndOp); | |
default: |
It's already missing several others and IMO that's up to whoever cares about this conversion to fix and test it properly. I did point out in the description that I was not intending to modify every user of AtomicRMWKind. |
Add missing xor AtomicRMWKind enum in arith. Also add support for xor to memref.atomic_rmw so the change can be tested.
This does NOT add it for all users of the enum (e.g. Affine, Vector)