Skip to content

Commit 3c21cec

Browse files
committed
Use PoisonAttr
1 parent 06a8d18 commit 3c21cec

File tree

8 files changed

+60
-41
lines changed

8 files changed

+60
-41
lines changed

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,21 @@ def UndefAttr : CIR_Attr<"Undef", "undef", [TypedAttrInterface]> {
169169
let assemblyFormat = [{}];
170170
}
171171

172+
//===----------------------------------------------------------------------===//
173+
// PoisonAttr
174+
//===----------------------------------------------------------------------===//
175+
176+
def PoisonAttr : CIR_Attr<"Poison", "poison", [TypedAttrInterface]> {
177+
let summary = "Represent an poison constant";
178+
let description = [{
179+
The PoisonAttr represents an poison constant, corresponding to LLVM's notion
180+
of poison.
181+
}];
182+
183+
let parameters = (ins AttributeSelfTypeParameter<"">:$type);
184+
let assemblyFormat = [{}];
185+
}
186+
172187
//===----------------------------------------------------------------------===//
173188
// ConstArrayAttr
174189
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -379,30 +379,6 @@ def PtrStrideOp : CIR_Op<"ptr_stride",
379379
let hasVerifier = 0;
380380
}
381381

382-
//===----------------------------------------------------------------------===//
383-
// PoisonOp
384-
//===----------------------------------------------------------------------===//
385-
def PoisonOp : CIR_Op<"poison", [Pure]> {
386-
let summary = "Creates a poison value of CIR type";
387-
388-
let description = [{
389-
Unlike LLVM IR, MLIR does not have first-class poison values. Such values
390-
must be created as SSA values using a dialect operation. This operation
391-
has no operands or attributes. It creates a poison value of the specified
392-
CIR type.
393-
394-
Example:
395-
396-
```mlir
397-
%0 = cir.poison : !cir.vector<!s32i x 2>
398-
```
399-
}];
400-
401-
let results = (outs CIR_AnyType:$res);
402-
let assemblyFormat = "attr-dict `:` type($res)";
403-
let llvmOp = "PoisonOp";
404-
}
405-
406382
//===----------------------------------------------------------------------===//
407383
// ConstantOp
408384
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2275,8 +2275,10 @@ static void vecExtendIntValue(CIRGenFunction &cgf, cir::VectorType argVTy,
22752275
// it before inserting.
22762276
arg = builder.createIntCast(arg, eltTy);
22772277
mlir::Value zero = builder.getConstInt(loc, cgf.SizeTy, 0);
2278+
mlir::Value poison = builder.create<cir::ConstantOp>(
2279+
loc, eltTy, builder.getAttr<cir::PoisonAttr>(eltTy));
22782280
arg = builder.create<cir::VecInsertOp>(
2279-
loc, builder.create<cir::PoisonOp>(loc, argVTy), arg, zero);
2281+
loc, builder.create<cir::VecSplatOp>(loc, argVTy, poison), arg, zero);
22802282
}
22812283

22822284
/// Reduce vector type value to scalar, usually for result of a

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,12 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
389389
return op->emitOpError("undef expects non-void type");
390390
}
391391

392+
if (isa<cir::PoisonAttr>(attrType)) {
393+
if (!::mlir::isa<cir::VoidType>(opType))
394+
return success();
395+
return op->emitOpError("poison expects non-void type");
396+
}
397+
392398
if (mlir::isa<cir::BoolAttr>(attrType)) {
393399
if (!mlir::isa<cir::BoolType>(opType))
394400
return op->emitOpError("result type (")

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,16 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::UndefAttr undefAttr,
425425
loc, converter->convertType(undefAttr.getType()));
426426
}
427427

428+
/// PoisonAttr visitor.
429+
static mlir::Value
430+
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::PoisonAttr poisonAttr,
431+
mlir::ConversionPatternRewriter &rewriter,
432+
const mlir::TypeConverter *converter) {
433+
auto loc = parentOp->getLoc();
434+
return rewriter.create<mlir::LLVM::PoisonOp>(
435+
loc, converter->convertType(poisonAttr.getType()));
436+
}
437+
428438
/// ConstStruct visitor.
429439
static mlir::Value
430440
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstStructAttr constStruct,
@@ -644,6 +654,8 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr,
644654
return lowerCirAttrAsValue(parentOp, zeroAttr, rewriter, converter);
645655
if (const auto undefAttr = mlir::dyn_cast<cir::UndefAttr>(attr))
646656
return lowerCirAttrAsValue(parentOp, undefAttr, rewriter, converter);
657+
if (const auto poisonAttr = mlir::dyn_cast<cir::PoisonAttr>(attr))
658+
return lowerCirAttrAsValue(parentOp, poisonAttr, rewriter, converter);
647659
if (const auto globalAttr = mlir::dyn_cast<cir::GlobalViewAttr>(attr))
648660
return lowerCirAttrAsValue(parentOp, globalAttr, rewriter, converter);
649661
if (const auto vtableAttr = mlir::dyn_cast<cir::VTableAttr>(attr))
@@ -1555,6 +1567,14 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
15551567
mlir::ConversionPatternRewriter &rewriter) const {
15561568
mlir::Attribute attr = op.getValue();
15571569

1570+
// Regardless of the type, we should lower the constant of poison value
1571+
// into PoisonOp.
1572+
if (mlir::isa<cir::PoisonAttr>(attr)) {
1573+
rewriter.replaceOp(
1574+
op, lowerCirAttrAsValue(op, attr, rewriter, getTypeConverter()));
1575+
return mlir::success();
1576+
}
1577+
15581578
if (mlir::isa<mlir::IntegerType>(op.getType())) {
15591579
// Verified cir.const operations cannot actually be of these types, but the
15601580
// lowering pass may generate temporary cir.const operations with these
@@ -1695,6 +1715,7 @@ mlir::LogicalResult CIRToLLVMVecCreateOpLowering::matchAndRewrite(
16951715
mlir::Value result = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
16961716
assert(vecTy.getSize() == op.getElements().size() &&
16971717
"cir.vec.create op count doesn't match vector type elements count");
1718+
16981719
for (uint64_t i = 0; i < vecTy.getSize(); ++i) {
16991720
mlir::Value indexValue =
17001721
rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i);
@@ -1745,15 +1766,21 @@ mlir::LogicalResult CIRToLLVMVecSplatOpLowering::matchAndRewrite(
17451766
assert(vecTy && "result type of cir.vec.splat op is not VectorType");
17461767
auto llvmTy = typeConverter->convertType(vecTy);
17471768
auto loc = op.getLoc();
1748-
mlir::Value undef = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
1769+
mlir::Value poison = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
17491770
mlir::Value indexValue =
17501771
rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0);
17511772
mlir::Value elementValue = adaptor.getValue();
1773+
if (mlir::isa<mlir::LLVM::PoisonOp>(elementValue.getDefiningOp())) {
1774+
// If the splat value is poison, then we can just use poison value
1775+
// for the entire vector.
1776+
rewriter.replaceOp(op, poison);
1777+
return mlir::success();
1778+
}
17521779
mlir::Value oneElement = rewriter.create<mlir::LLVM::InsertElementOp>(
1753-
loc, undef, elementValue, indexValue);
1780+
loc, poison, elementValue, indexValue);
17541781
SmallVector<int32_t> zeroValues(vecTy.getSize(), 0);
17551782
mlir::Value shuffled = rewriter.create<mlir::LLVM::ShuffleVectorOp>(
1756-
loc, oneElement, undef, zeroValues);
1783+
loc, oneElement, poison, zeroValues);
17571784
rewriter.replaceOp(op, shuffled);
17581785
return mlir::success();
17591786
}

clang/test/CIR/CodeGen/AArch64/neon.c

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14617,11 +14617,12 @@ int16_t test_vqmovns_s32(int32_t a) {
1461714617
// CIR-LABEL: vqmovns_s32
1461814618
// CIR: [[A:%.*]] = cir.load {{.*}} : !cir.ptr<!s32i>, !s32i
1461914619
// CIR: [[VQMOVNS_S32_ZERO1:%.*]] = cir.const #cir.int<0> : !u64i
14620-
// CIR: [[POISON:%.*]] = cir.poison : !cir.vector<!s32i x 4>
14621-
// CIR: [[TMP0:%.*]] = cir.vec.insert [[A]], [[POISON]][[[VQMOVNS_S32_ZERO1]] : !u64i] : !cir.vector<!s32i x 4>
14620+
// CIR: [[POISON:%.*]] = cir.const #cir.poison : !s32i
14621+
// CIR: [[POISON_VEC:%.*]] = cir.vec.splat [[POISON]] : !s32i, !cir.vector<!s32i x 4>
14622+
// CIR: [[TMP0:%.*]] = cir.vec.insert [[A]], [[POISON_VEC]][[[VQMOVNS_S32_ZERO1]] : !u64i] : !cir.vector<!s32i x 4>
1462214623
// CIR: [[VQMOVNS_S32_I:%.*]] = cir.llvm.intrinsic "aarch64.neon.sqxtn" [[TMP0]] : (!cir.vector<!s32i x 4>) -> !cir.vector<!s16i x 4>
1462314624
// CIR: [[VQMOVNS_S32_ZERO2:%.*]] = cir.const #cir.int<0> : !u64i
14624-
// CIR: [[TMP1:%.*]] = cir.vec.extract [[VQMOVNS_S32_I]][[[VQMOVNS_S32_ZERO2]] : !u64i] : !cir.vector<!s16i x 4> loc(#loc4503)
14625+
// CIR: [[TMP1:%.*]] = cir.vec.extract [[VQMOVNS_S32_I]][[[VQMOVNS_S32_ZERO2]] : !u64i] : !cir.vector<!s16i x 4>
1462514626

1462614627
// LLVM: {{.*}}@test_vqmovns_s32(i32{{.*}}[[a:%.*]])
1462714628
// LLVM: [[TMP0:%.*]] = insertelement <4 x i32> poison, i32 [[a]], i64 0

clang/test/CIR/IR/cir-ops.cir

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,6 @@ module {
6565
%3 = cir.shift(left, %1 : !cir.vector<!s32i x 2>, %2 : !cir.vector<!s32i x 2>) -> !cir.vector<!s32i x 2>
6666
cir.return
6767
}
68-
69-
cir.func @poisonvalue() {
70-
%0 = cir.poison : !cir.vector<!s32i x 2>
71-
cir.return
72-
}
7368
}
7469

7570
// CHECK: module {
@@ -123,9 +118,4 @@ module {
123118
// CHECK-NEXT: cir.return
124119
// CHECK-NEXT: }
125120

126-
// CHECK: cir.func @poisonvalue() {
127-
// CHECK-NEXT: %0 = cir.poison : !cir.vector<!s32i x 2>
128-
// CHECK-NEXT: cir.return
129-
// CHECK-NEXT: }
130-
131121
// CHECK: }

clang/test/CIR/Lowering/const.cir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ module {
1717
// CHECK: llvm.mlir.zero : !llvm.array<3 x i32>
1818
%5 = cir.const #cir.undef : !cir.array<!s32i x 3>
1919
// CHECK: llvm.mlir.undef : !llvm.array<3 x i32>
20+
%6 = cir.const #cir.poison : !s32i
21+
// CHECK: llvm.mlir.poison : i32
2022
cir.return
2123
}
2224

0 commit comments

Comments
 (0)