diff --git a/clang/lib/CodeGen/CGPointerAuth.cpp b/clang/lib/CodeGen/CGPointerAuth.cpp index dcef01a5eb6d3..0cfeb9ec26a02 100644 --- a/clang/lib/CodeGen/CGPointerAuth.cpp +++ b/clang/lib/CodeGen/CGPointerAuth.cpp @@ -440,9 +440,9 @@ CodeGenModule::getConstantSignedPointer(llvm::Constant *Pointer, unsigned Key, IntegerDiscriminator = llvm::ConstantInt::get(Int64Ty, 0); } - return llvm::ConstantPtrAuth::get(Pointer, - llvm::ConstantInt::get(Int32Ty, Key), - IntegerDiscriminator, AddressDiscriminator); + return llvm::ConstantPtrAuth::get( + Pointer, llvm::ConstantInt::get(Int32Ty, Key), IntegerDiscriminator, + AddressDiscriminator, llvm::Constant::getNullValue(UnqualPtrTy)); } /// Does a given PointerAuthScheme require us to sign a value diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h index dc78eb4164acf..b02891a7b8aa4 100644 --- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h +++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h @@ -437,6 +437,8 @@ enum ConstantsCodes { CST_CODE_CE_GEP_WITH_INRANGE = 31, // [opty, flags, range, n x operands] CST_CODE_CE_GEP = 32, // [opty, flags, n x operands] CST_CODE_PTRAUTH = 33, // [ptr, key, disc, addrdisc] + CST_CODE_PTRAUTH2 = 34, // [ptr, key, disc, addrdisc, + // deactivation_symbol] }; /// CastOpcodes - These are values used in the bitcode files to encode which diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h index 9c9fc8892bdbf..8862ef916656b 100644 --- a/llvm/include/llvm/IR/Constants.h +++ b/llvm/include/llvm/IR/Constants.h @@ -1033,10 +1033,10 @@ class ConstantPtrAuth final : public Constant { friend struct ConstantPtrAuthKeyType; friend class Constant; - constexpr static IntrusiveOperandsAllocMarker AllocMarker{4}; + constexpr static IntrusiveOperandsAllocMarker AllocMarker{5}; ConstantPtrAuth(Constant *Ptr, ConstantInt *Key, ConstantInt *Disc, - Constant *AddrDisc); + Constant *AddrDisc, Constant *DeactivationSymbol); void *operator new(size_t s) { return User::operator new(s, AllocMarker); } @@ -1046,7 +1046,8 @@ class ConstantPtrAuth final : public Constant { public: /// Return a pointer signed with the specified parameters. LLVM_ABI static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key, - ConstantInt *Disc, Constant *AddrDisc); + ConstantInt *Disc, Constant *AddrDisc, + Constant *DeactivationSymbol); /// Produce a new ptrauth expression signing the given value using /// the same schema as is stored in one. @@ -1078,6 +1079,10 @@ class ConstantPtrAuth final : public Constant { return !getAddrDiscriminator()->isNullValue(); } + Constant *getDeactivationSymbol() const { + return cast(Op<4>().get()); + } + /// A constant value for the address discriminator which has special /// significance to ctors/dtors lowering. Regular address discrimination can't /// be applied for them since uses of llvm.global_{c|d}tors are disallowed @@ -1106,7 +1111,7 @@ class ConstantPtrAuth final : public Constant { template <> struct OperandTraits - : public FixedNumOperandTraits {}; + : public FixedNumOperandTraits {}; DEFINE_TRANSPARENT_OPERAND_ACCESSORS(ConstantPtrAuth, Constant) diff --git a/llvm/include/llvm/SandboxIR/Constant.h b/llvm/include/llvm/SandboxIR/Constant.h index 6f682a7059d10..2fe923f6c3866 100644 --- a/llvm/include/llvm/SandboxIR/Constant.h +++ b/llvm/include/llvm/SandboxIR/Constant.h @@ -1363,7 +1363,8 @@ class ConstantPtrAuth final : public Constant { public: /// Return a pointer signed with the specified parameters. LLVM_ABI static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key, - ConstantInt *Disc, Constant *AddrDisc); + ConstantInt *Disc, Constant *AddrDisc, + Constant *DeactivationSymbol); /// The pointer that is signed in this ptrauth signed pointer. LLVM_ABI Constant *getPointer() const; @@ -1378,6 +1379,8 @@ class ConstantPtrAuth final : public Constant { /// the only global-initializer user of the ptrauth signed pointer. LLVM_ABI Constant *getAddrDiscriminator() const; + Constant *getDeactivationSymbol() const; + /// Whether there is any non-null address discriminator. bool hasAddressDiscriminator() const { return cast(Val)->hasAddressDiscriminator(); diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp index 13bef1f62f1a9..f2ef8749510b6 100644 --- a/llvm/lib/AsmParser/LLParser.cpp +++ b/llvm/lib/AsmParser/LLParser.cpp @@ -4216,11 +4216,12 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) { } case lltok::kw_ptrauth: { // ValID ::= 'ptrauth' '(' ptr @foo ',' i32 - // (',' i64 (',' ptr addrdisc)? )? ')' + // (',' i64 (',' ptr addrdisc (',' ptr ds)? )? )? ')' Lex.Lex(); Constant *Ptr, *Key; - Constant *Disc = nullptr, *AddrDisc = nullptr; + Constant *Disc = nullptr, *AddrDisc = nullptr, + *DeactivationSymbol = nullptr; if (parseToken(lltok::lparen, "expected '(' in constant ptrauth expression") || @@ -4229,11 +4230,14 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) { "expected comma in constant ptrauth expression") || parseGlobalTypeAndValue(Key)) return true; - // If present, parse the optional disc/addrdisc. - if (EatIfPresent(lltok::comma)) - if (parseGlobalTypeAndValue(Disc) || - (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(AddrDisc))) - return true; + // If present, parse the optional disc/addrdisc/ds. + if (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(Disc)) + return true; + if (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(AddrDisc)) + return true; + if (EatIfPresent(lltok::comma) && + parseGlobalTypeAndValue(DeactivationSymbol)) + return true; if (parseToken(lltok::rparen, "expected ')' in constant ptrauth expression")) return true; @@ -4264,7 +4268,16 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) { AddrDisc = ConstantPointerNull::get(PointerType::get(Context, 0)); } - ID.ConstantVal = ConstantPtrAuth::get(Ptr, KeyC, DiscC, AddrDisc); + if (DeactivationSymbol) { + if (!DeactivationSymbol->getType()->isPointerTy()) + return error( + ID.Loc, "constant ptrauth deactivation symbol must be a pointer"); + } else { + DeactivationSymbol = ConstantPointerNull::get(PointerType::get(Context, 0)); + } + + ID.ConstantVal = + ConstantPtrAuth::get(Ptr, KeyC, DiscC, AddrDisc, DeactivationSymbol); ID.Kind = ValID::t_Constant; return false; } diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp index 290d873c632c9..03df10410a8d9 100644 --- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp +++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp @@ -1608,7 +1608,13 @@ Expected BitcodeReader::materializeValue(unsigned StartValID, if (!Disc) return error("ptrauth disc operand must be ConstantInt"); - C = ConstantPtrAuth::get(ConstOps[0], Key, Disc, ConstOps[3]); + auto *DeactivationSymbol = + ConstOps.size() > 4 ? ConstOps[4] + : ConstantPointerNull::get(cast( + ConstOps[3]->getType())); + + C = ConstantPtrAuth::get(ConstOps[0], Key, Disc, ConstOps[3], + DeactivationSymbol); break; } case BitcodeConstant::NoCFIOpcode: { @@ -3808,6 +3814,16 @@ Error BitcodeReader::parseConstants() { (unsigned)Record[2], (unsigned)Record[3]}); break; } + case bitc::CST_CODE_PTRAUTH2: { + if (Record.size() < 4) + return error("Invalid ptrauth record"); + // Ptr, Key, Disc, AddrDisc, DeactivationSymbol + V = BitcodeConstant::create( + Alloc, CurTy, BitcodeConstant::ConstantPtrAuthOpcode, + {(unsigned)Record[0], (unsigned)Record[1], (unsigned)Record[2], + (unsigned)Record[3], (unsigned)Record[4]}); + break; + } } assert(V->getType() == getTypeByID(CurTyID) && "Incorrect result type ID"); diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp index 7e0d81ff4b196..adf79b96da943 100644 --- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp +++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp @@ -3010,11 +3010,12 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal, Record.push_back(VE.getTypeID(NC->getGlobalValue()->getType())); Record.push_back(VE.getValueID(NC->getGlobalValue())); } else if (const auto *CPA = dyn_cast(C)) { - Code = bitc::CST_CODE_PTRAUTH; + Code = bitc::CST_CODE_PTRAUTH2; Record.push_back(VE.getValueID(CPA->getPointer())); Record.push_back(VE.getValueID(CPA->getKey())); Record.push_back(VE.getValueID(CPA->getDiscriminator())); Record.push_back(VE.getValueID(CPA->getAddrDiscriminator())); + Record.push_back(VE.getValueID(CPA->getDeactivationSymbol())); } else { #ifndef NDEBUG C->dump(); diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp index e5a4e1e6b7ce5..a1facc9ba1e19 100644 --- a/llvm/lib/IR/AsmWriter.cpp +++ b/llvm/lib/IR/AsmWriter.cpp @@ -1667,12 +1667,14 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV, if (const ConstantPtrAuth *CPA = dyn_cast(CV)) { Out << "ptrauth ("; - // ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC]?]?) + // ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC[, ptr DS]?]?]?) unsigned NumOpsToWrite = 2; if (!CPA->getOperand(2)->isNullValue()) NumOpsToWrite = 3; if (!CPA->getOperand(3)->isNullValue()) NumOpsToWrite = 4; + if (!CPA->getOperand(4)->isNullValue()) + NumOpsToWrite = 5; ListSeparator LS; for (unsigned i = 0, e = NumOpsToWrite; i != e; ++i) { diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp index a3c725b2af62a..d87bc6c11400e 100644 --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -2060,19 +2060,22 @@ Value *NoCFIValue::handleOperandChangeImpl(Value *From, Value *To) { // ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key, - ConstantInt *Disc, Constant *AddrDisc) { - Constant *ArgVec[] = {Ptr, Key, Disc, AddrDisc}; + ConstantInt *Disc, Constant *AddrDisc, + Constant *DeactivationSymbol) { + Constant *ArgVec[] = {Ptr, Key, Disc, AddrDisc, DeactivationSymbol}; ConstantPtrAuthKeyType MapKey(ArgVec); LLVMContextImpl *pImpl = Ptr->getContext().pImpl; return pImpl->ConstantPtrAuths.getOrCreate(Ptr->getType(), MapKey); } ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const { - return get(Pointer, getKey(), getDiscriminator(), getAddrDiscriminator()); + return get(Pointer, getKey(), getDiscriminator(), getAddrDiscriminator(), + getDeactivationSymbol()); } ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key, - ConstantInt *Disc, Constant *AddrDisc) + ConstantInt *Disc, Constant *AddrDisc, + Constant *DeactivationSymbol) : Constant(Ptr->getType(), Value::ConstantPtrAuthVal, AllocMarker) { assert(Ptr->getType()->isPointerTy()); assert(Key->getBitWidth() == 32); @@ -2082,6 +2085,7 @@ ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key, setOperand(1, Key); setOperand(2, Disc); setOperand(3, AddrDisc); + setOperand(4, DeactivationSymbol); } /// Remove the constant from the constant table. diff --git a/llvm/lib/IR/ConstantsContext.h b/llvm/lib/IR/ConstantsContext.h index 51fb40bad201d..584391b53800a 100644 --- a/llvm/lib/IR/ConstantsContext.h +++ b/llvm/lib/IR/ConstantsContext.h @@ -539,7 +539,8 @@ struct ConstantPtrAuthKeyType { ConstantPtrAuth *create(TypeClass *Ty) const { return new ConstantPtrAuth(Operands[0], cast(Operands[1]), - cast(Operands[2]), Operands[3]); + cast(Operands[2]), Operands[3], + Operands[4]); } }; diff --git a/llvm/lib/IR/Core.cpp b/llvm/lib/IR/Core.cpp index f7ef4aa473ef5..904111e742be0 100644 --- a/llvm/lib/IR/Core.cpp +++ b/llvm/lib/IR/Core.cpp @@ -1699,7 +1699,9 @@ LLVMValueRef LLVMConstantPtrAuth(LLVMValueRef Ptr, LLVMValueRef Key, LLVMValueRef Disc, LLVMValueRef AddrDisc) { return wrap(ConstantPtrAuth::get( unwrap(Ptr), unwrap(Key), - unwrap(Disc), unwrap(AddrDisc))); + unwrap(Disc), unwrap(AddrDisc), + ConstantPointerNull::get( + cast(unwrap(AddrDisc)->getType())))); } /*-- Opcode mapping */ diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp index 3ff9895e161c4..3478c2c450ae7 100644 --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -2627,6 +2627,11 @@ void Verifier::visitConstantPtrAuth(const ConstantPtrAuth *CPA) { Check(CPA->getDiscriminator()->getBitWidth() == 64, "signed ptrauth constant discriminator must be i64 constant integer"); + + Check(isa(CPA->getDeactivationSymbol()) || + CPA->getDeactivationSymbol()->isNullValue(), + "signed ptrauth constant deactivation symbol must be a global value " + "or null"); } bool Verifier::verifyAttributeCount(AttributeList Attrs, unsigned Params) { diff --git a/llvm/lib/SandboxIR/Constant.cpp b/llvm/lib/SandboxIR/Constant.cpp index 9de88ef2cf0a0..eb14797af081c 100644 --- a/llvm/lib/SandboxIR/Constant.cpp +++ b/llvm/lib/SandboxIR/Constant.cpp @@ -412,10 +412,12 @@ PointerType *NoCFIValue::getType() const { } ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key, - ConstantInt *Disc, Constant *AddrDisc) { + ConstantInt *Disc, Constant *AddrDisc, + Constant *DeactivationSymbol) { auto *LLVMC = llvm::ConstantPtrAuth::get( cast(Ptr->Val), cast(Key->Val), - cast(Disc->Val), cast(AddrDisc->Val)); + cast(Disc->Val), cast(AddrDisc->Val), + cast(DeactivationSymbol->Val)); return cast(Ptr->getContext().getOrCreateConstant(LLVMC)); } @@ -439,6 +441,11 @@ Constant *ConstantPtrAuth::getAddrDiscriminator() const { cast(Val)->getAddrDiscriminator()); } +Constant *ConstantPtrAuth::getDeactivationSymbol() const { + return Ctx.getOrCreateConstant( + cast(Val)->getDeactivationSymbol()); +} + ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const { auto *LLVMC = cast(Val)->getWithSameSchema( cast(Pointer->Val)); diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp index 910d020057ccb..b5922c97dafed 100644 --- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp +++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp @@ -222,7 +222,7 @@ class AArch64AsmPrinter : public AsmPrinter { const MCExpr *emitPAuthRelocationAsIRelative( const MCExpr *Target, uint16_t Disc, AArch64PACKey::ID KeyID, - bool HasAddressDiversity, bool IsDSOLocal); + bool HasAddressDiversity, bool IsDSOLocal, const MCExpr *DSExpr); /// tblgen'erated driver function for lowering simple MI->MC /// pseudo instructions. @@ -2343,15 +2343,17 @@ static void emitAddress(MCStreamer &Streamer, MCRegister Reg, } static bool targetSupportsPAuthRelocation(const Triple &TT, - const MCExpr *Target) { + const MCExpr *Target, + const MCExpr *DSExpr) { // No released version of glibc supports PAuth relocations. if (TT.isOSGlibc()) return false; // We emit PAuth constants as IRELATIVE relocations in cases where the // constant cannot be represented as a PAuth relocation: - // 1) The signed value is not a symbol. - return !isa(Target); + // 1) There is a deactivation symbol. + // 2) The signed value is not a symbol. + return !DSExpr && !isa(Target); } static bool targetSupportsIRelativeRelocation(const Triple &TT) { @@ -2368,12 +2370,12 @@ static bool targetSupportsIRelativeRelocation(const Triple &TT) { const MCExpr *AArch64AsmPrinter::emitPAuthRelocationAsIRelative( const MCExpr *Target, uint16_t Disc, AArch64PACKey::ID KeyID, - bool HasAddressDiversity, bool IsDSOLocal) { + bool HasAddressDiversity, bool IsDSOLocal, const MCExpr *DSExpr) { const Triple &TT = TM.getTargetTriple(); // We only emit an IRELATIVE relocation if the target supports IRELATIVE and // does not support the kind of PAuth relocation that we are trying to emit. - if (targetSupportsPAuthRelocation(TT, Target) || + if (targetSupportsPAuthRelocation(TT, Target, DSExpr) || !targetSupportsIRelativeRelocation(TT)) return nullptr; @@ -2417,6 +2419,16 @@ const MCExpr *AArch64AsmPrinter::emitPAuthRelocationAsIRelative( emitMOVZ(AArch64::X1, Disc, 0); } + if (DSExpr) { + MCSymbol *PrePACInst = OutStreamer->getContext().createTempSymbol(); + OutStreamer->emitLabel(PrePACInst); + + auto *PrePACInstExpr = + MCSymbolRefExpr::create(PrePACInst, OutStreamer->getContext()); + OutStreamer->emitRelocDirective(*PrePACInstExpr, "R_AARCH64_PATCHINST", + DSExpr, SMLoc()); + } + // We don't know the subtarget because this is being emitted for a global // initializer. Because the performance of IFUNC resolvers is unimportant, we // always call the EmuPAC runtime, which will end up using the PAC instruction @@ -2427,6 +2439,12 @@ const MCExpr *AArch64AsmPrinter::emitPAuthRelocationAsIRelative( MCSymbolRefExpr::create(EmuPAC, OutStreamer->getContext()); OutStreamer->emitInstruction(MCInstBuilder(AArch64::B).addExpr(EmuPACRef), *STI); + + // We need a RET despite the above tail call because the deactivation symbol + // may replace it with a NOP. + if (DSExpr) + OutStreamer->emitInstruction( + MCInstBuilder(AArch64::RET).addReg(AArch64::LR), *STI); OutStreamer->popSection(); return MCSymbolRefExpr::create(IRelativeSym, AArch64::S_FUNCINIT, @@ -2458,6 +2476,13 @@ AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) { Sym = MCConstantExpr::create(Offset.getSExtValue(), Ctx); } + const MCExpr *DSExpr = nullptr; + if (auto *DS = dyn_cast(CPA.getDeactivationSymbol())) { + if (isa(DS)) + return Sym; + DSExpr = MCSymbolRefExpr::create(getSymbol(DS), Ctx); + } + uint64_t KeyID = CPA.getKey()->getZExtValue(); // We later rely on valid KeyID value in AArch64PACKeyIDToString call from // AArch64AuthMCExpr::printImpl, so fail fast. @@ -2478,9 +2503,13 @@ AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) { // Check if we need to represent this with an IRELATIVE and emit it if so. if (auto *IFuncSym = emitPAuthRelocationAsIRelative( Sym, Disc, AArch64PACKey::ID(KeyID), CPA.hasAddressDiscriminator(), - BaseGVB && BaseGVB->isDSOLocal())) + BaseGVB && BaseGVB->isDSOLocal(), DSExpr)) return IFuncSym; + if (DSExpr) + report_fatal_error("deactivation symbols unsupported in constant " + "expressions on this target"); + // Finally build the complete @AUTH expr. return AArch64AuthMCExpr::create(Sym, Disc, AArch64PACKey::ID(KeyID), CPA.hasAddressDiscriminator(), Ctx); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 47e017e17092b..20bb1bc43e598 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -3042,9 +3042,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (NeedSign && isa(II->getArgOperand(4))) { auto *SignKey = cast(II->getArgOperand(3)); auto *SignDisc = cast(II->getArgOperand(4)); - auto *SignAddrDisc = ConstantPointerNull::get(Builder.getPtrTy()); + auto *Null = ConstantPointerNull::get(Builder.getPtrTy()); auto *NewCPA = ConstantPtrAuth::get(CPA->getPointer(), SignKey, - SignDisc, SignAddrDisc); + SignDisc, Null, Null); replaceInstUsesWith( *II, ConstantExpr::getPointerCast(NewCPA, II->getType())); return eraseInstFromFunction(*II); diff --git a/llvm/lib/Transforms/Utils/ValueMapper.cpp b/llvm/lib/Transforms/Utils/ValueMapper.cpp index 8d8a60b6918fe..fca17893f47f4 100644 --- a/llvm/lib/Transforms/Utils/ValueMapper.cpp +++ b/llvm/lib/Transforms/Utils/ValueMapper.cpp @@ -526,8 +526,9 @@ Value *Mapper::mapValue(const Value *V) { if (isa(C)) return getVM()[V] = ConstantVector::get(Ops); if (isa(C)) - return getVM()[V] = ConstantPtrAuth::get(Ops[0], cast(Ops[1]), - cast(Ops[2]), Ops[3]); + return getVM()[V] = + ConstantPtrAuth::get(Ops[0], cast(Ops[1]), + cast(Ops[2]), Ops[3], Ops[4]); // If this is a no-operand constant, it must be because the type was remapped. if (isa(C)) return getVM()[V] = PoisonValue::get(NewTy); diff --git a/llvm/test/Bitcode/compatibility.ll b/llvm/test/Bitcode/compatibility.ll index 0b5ce08c00a23..9b490f1105912 100644 --- a/llvm/test/Bitcode/compatibility.ll +++ b/llvm/test/Bitcode/compatibility.ll @@ -217,9 +217,13 @@ declare void @g.f1() ; CHECK: @g.sanitize_address_dyninit = global i32 0, sanitize_address_dyninit ; CHECK: @g.sanitize_multiple = global i32 0, sanitize_memtag, sanitize_address_dyninit +@ds = external global i32 + ; ptrauth constant @auth_var = global ptr ptrauth (ptr @g1, i32 0, i64 65535, ptr null) ; CHECK: @auth_var = global ptr ptrauth (ptr @g1, i32 0, i64 65535) +@auth_var.ds = global ptr ptrauth (ptr @g1, i32 0, i64 65535, ptr null, ptr @ds) +; CHECK: @auth_var.ds = global ptr ptrauth (ptr @g1, i32 0, i64 65535, ptr null, ptr @ds) ;; Aliases ; Format: @ = [Linkage] [Visibility] [DLLStorageClass] [ThreadLocal] diff --git a/llvm/test/CodeGen/AArch64/ptrauth-irelative.ll b/llvm/test/CodeGen/AArch64/ptrauth-irelative.ll index 7857051668dfb..4ee1c19a86490 100644 --- a/llvm/test/CodeGen/AArch64/ptrauth-irelative.ll +++ b/llvm/test/CodeGen/AArch64/ptrauth-irelative.ll @@ -25,6 +25,23 @@ ; CHECK-NEXT: .xword [[FUNC]]@FUNCINIT @dsolocalref = constant ptr ptrauth (ptr @dsolocal, i32 2, i64 2, ptr null), align 8 +@ds = external global i8 + +; CHECK: dsolocalrefds: +; CHECK-NEXT: [[PLACE:.*]]: +; CHECK-NEXT: .section .text.startup +; CHECK-NEXT: [[FUNC:.*]]: +; CHECK-NEXT: adrp x0, dsolocal +; CHECK-NEXT: add x0, x0, :lo12:dsolocal +; CHECK-NEXT: mov x1, #2 +; CHECK-NEXT: [[LABEL:.L.*]]: +; CHECK-NEXT: .reloc [[LABEL]], R_AARCH64_PATCHINST, ds +; CHECK-NEXT: b __emupac_pacda +; CHECK-NEXT: ret +; CHECK-NEXT: .section .rodata +; CHECK-NEXT: .xword [[FUNC]]@FUNCINIT +@dsolocalrefds = constant ptr ptrauth (ptr @dsolocal, i32 2, i64 2, ptr null, ptr @ds), align 8 + ; CHECK: dsolocalref8: ; CHECK-NEXT: [[PLACE:.*]]: ; CHECK-NEXT: .section .text.startup diff --git a/llvm/test/Verifier/ptrauth-constant.ll b/llvm/test/Verifier/ptrauth-constant.ll new file mode 100644 index 0000000000000..fdd6352cf8469 --- /dev/null +++ b/llvm/test/Verifier/ptrauth-constant.ll @@ -0,0 +1,6 @@ +; RUN: not opt -passes=verify < %s 2>&1 | FileCheck %s + +@g = external global i8 + +; CHECK: signed ptrauth constant deactivation symbol must be a global variable or null +@ptr = global ptr ptrauth (ptr @g, i32 0, i64 65535, ptr null, ptr inttoptr (i64 16 to ptr)) diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp index 33928ac118e0c..168502f89cbf8 100644 --- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp +++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp @@ -1385,7 +1385,7 @@ define ptr @foo() { // Check get(), getKey(), getDiscriminator(), getAddrDiscriminator(). auto *NewPtrAuth = sandboxir::ConstantPtrAuth::get( &F, PtrAuth->getKey(), PtrAuth->getDiscriminator(), - PtrAuth->getAddrDiscriminator()); + PtrAuth->getAddrDiscriminator(), PtrAuth->getDeactivationSymbol()); EXPECT_EQ(NewPtrAuth, PtrAuth); // Check hasAddressDiscriminator(). EXPECT_EQ(PtrAuth->hasAddressDiscriminator(), diff --git a/llvm/unittests/Transforms/Utils/ValueMapperTest.cpp b/llvm/unittests/Transforms/Utils/ValueMapperTest.cpp index 86ad41fa7ad50..90b58236060d2 100644 --- a/llvm/unittests/Transforms/Utils/ValueMapperTest.cpp +++ b/llvm/unittests/Transforms/Utils/ValueMapperTest.cpp @@ -450,6 +450,10 @@ TEST(ValueMapperTest, mapValuePtrAuth) { PtrTy, false, GlobalValue::ExternalLinkage, nullptr, "Storage0"); std::unique_ptr Storage1 = std::make_unique( PtrTy, false, GlobalValue::ExternalLinkage, nullptr, "Storage1"); + std::unique_ptr DS0 = std::make_unique( + PtrTy, false, GlobalValue::ExternalLinkage, nullptr, "DS0"); + std::unique_ptr DS1 = std::make_unique( + PtrTy, false, GlobalValue::ExternalLinkage, nullptr, "DS1"); ConstantInt *ConstKey = ConstantInt::get(Int32Ty, 1); ConstantInt *ConstDisc = ConstantInt::get(Int64Ty, 1234); @@ -457,11 +461,12 @@ TEST(ValueMapperTest, mapValuePtrAuth) { ValueToValueMapTy VM; VM[Var0.get()] = Var1.get(); VM[Storage0.get()] = Storage1.get(); + VM[DS0.get()] = DS1.get(); - ConstantPtrAuth *Value = - ConstantPtrAuth::get(Var0.get(), ConstKey, ConstDisc, Storage0.get()); - ConstantPtrAuth *MappedValue = - ConstantPtrAuth::get(Var1.get(), ConstKey, ConstDisc, Storage1.get()); + ConstantPtrAuth *Value = ConstantPtrAuth::get(Var0.get(), ConstKey, ConstDisc, + Storage0.get(), DS0.get()); + ConstantPtrAuth *MappedValue = ConstantPtrAuth::get( + Var1.get(), ConstKey, ConstDisc, Storage1.get(), DS1.get()); EXPECT_EQ(ValueMapper(VM).mapValue(*Value), MappedValue); }