36#define DEBUG_TYPE "aarch64-sls-hardening"
38#define AARCH64_SLS_HARDENING_NAME "AArch64 sls hardening pass"
67 static const ThunkKind BR;
68 static const ThunkKind BRAA;
69 static const ThunkKind BRAB;
70 static const ThunkKind BRAAZ;
71 static const ThunkKind BRABZ;
77 static constexpr unsigned NumXRegisters = 32;
80 static unsigned indexOfXReg(
Register Xn);
85 BLRThunks |=
Other.BLRThunks;
86 BLRAAZThunks |=
Other.BLRAAZThunks;
87 BLRABZThunks |=
Other.BLRABZThunks;
88 for (
unsigned I = 0;
I < NumXRegisters; ++
I)
89 BLRAAThunks[
I] |=
Other.BLRAAThunks[
I];
90 for (
unsigned I = 0;
I < NumXRegisters; ++
I)
91 BLRABThunks[
I] |=
Other.BLRABThunks[
I];
97 reg_bitmask_t XnBit = reg_bitmask_t(1) << indexOfXReg(Xn);
98 return getBitmask(Kind, Xm) & XnBit;
102 reg_bitmask_t XnBit = reg_bitmask_t(1) << indexOfXReg(Xn);
103 getBitmask(Kind, Xm) |= XnBit;
108 static_assert(NumXRegisters <=
sizeof(reg_bitmask_t) * CHAR_BIT,
109 "Bitmask is not wide enough to hold all Xn registers");
116 reg_bitmask_t BLRThunks = 0;
117 reg_bitmask_t BLRAAZThunks = 0;
118 reg_bitmask_t BLRABZThunks = 0;
119 reg_bitmask_t BLRAAThunks[NumXRegisters] = {};
120 reg_bitmask_t BLRABThunks[NumXRegisters] = {};
122 reg_bitmask_t &getBitmask(ThunkKind::ThunkKindId Kind,
Register Xm) {
124 case ThunkKind::ThunkBR:
126 case ThunkKind::ThunkBRAAZ:
128 case ThunkKind::ThunkBRABZ:
130 case ThunkKind::ThunkBRAA:
131 return BLRAAThunks[indexOfXReg(Xm)];
132 case ThunkKind::ThunkBRAB:
133 return BLRABThunks[indexOfXReg(Xm)];
139struct SLSHardeningInserter :
ThunkInserter<SLSHardeningInserter, ThunksSet> {
150 ThunksSet ExistingThunks);
154 bool ComdatThunks =
true;
167const ThunkKind ThunkKind::BR = {ThunkBR,
"",
false,
169const ThunkKind ThunkKind::BRAA = {ThunkBRAA,
"aa_",
true,
170 true, AArch64::BRAA};
171const ThunkKind ThunkKind::BRAB = {ThunkBRAB,
"ab_",
true,
172 true, AArch64::BRAB};
173const ThunkKind ThunkKind::BRAAZ = {ThunkBRAAZ,
"aaz_",
false,
174 true, AArch64::BRAAZ};
175const ThunkKind ThunkKind::BRABZ = {ThunkBRABZ,
"abz_",
false,
176 true, AArch64::BRABZ};
180 switch (OriginalOpcode) {
182 case AArch64::BLRNoIP:
183 return &ThunkKind::BR;
185 return &ThunkKind::BRAA;
187 return &ThunkKind::BRAB;
188 case AArch64::BLRAAZ:
189 return &ThunkKind::BRAAZ;
190 case AArch64::BLRABZ:
191 return &ThunkKind::BRABZ;
202 assert(
Reg != AArch64::X16 &&
Reg != AArch64::X17 &&
Reg != AArch64::LR);
205 unsigned Result = (unsigned)
Reg - (
unsigned)AArch64::X0;
206 if (
Reg == AArch64::FP)
208 else if (
Reg == AArch64::XZR)
211 assert(Result < NumXRegisters &&
"Internal register numbering changed");
212 assert(AArch64::GPR64RegClass.getRegister(Result).
id() ==
Reg &&
213 "Internal register numbering changed");
218Register ThunksSet::xRegByIndex(
unsigned N) {
219 return AArch64::GPR64RegClass.getRegister(
N);
228 "Must not insert SpeculationBarrierEndBB as only instruction in MBB.");
230 "SpeculationBarrierEndBB must only follow unconditional control flow "
233 "SpeculationBarrierEndBB must only follow terminators.");
236 ? AArch64::SpeculationBarrierSBEndBB
237 : AArch64::SpeculationBarrierISBDSBEndBB;
239 (
MBBI->getOpcode() != AArch64::SpeculationBarrierSBEndBB &&
240 MBBI->getOpcode() != AArch64::SpeculationBarrierISBDSBEndBB))
244ThunksSet SLSHardeningInserter::insertThunks(MachineModuleInfo &MMI,
246 ThunksSet ExistingThunks) {
247 const AArch64Subtarget *
ST = &MF.
getSubtarget<AArch64Subtarget>();
249 for (
auto &
MBB : MF) {
250 if (
ST->hardenSlsRetBr())
251 hardenReturnsAndBRs(MMI,
MBB);
252 if (
ST->hardenSlsBlr())
253 hardenBLRs(MMI,
MBB, ExistingThunks);
255 return ExistingThunks;
258bool SLSHardeningInserter::hardenReturnsAndBRs(MachineModuleInfo &MMI,
259 MachineBasicBlock &
MBB) {
260 const AArch64Subtarget *
ST =
267 NextMBBI = std::next(
MBBI);
282 unsigned N = ThunksSet::indexOfXReg(Xn);
283 if (!Kind.HasXmOperand)
286 unsigned M = ThunksSet::indexOfXReg(Xm);
290static std::tuple<const ThunkKind &, Register, Register>
293 "Should be filtered out by ThunkInserter");
308 assert(Name.starts_with(
"x") &&
"xN register name expected");
309 bool Fail = Name.drop_front(1).getAsInteger(10,
N);
310 assert(!
Fail &&
N < ThunksSet::NumXRegisters &&
"Unexpected register");
313 return ThunksSet::xRegByIndex(
N);
319 std::tie(XnStr, XmStr) = RegsStr.
split(
'_');
323 Register Xm = Kind.HasXmOperand ? ParseRegName(XmStr) : AArch64::NoRegister;
325 return std::make_tuple(std::ref(Kind), Xn, Xm);
328void SLSHardeningInserter::populateThunk(MachineFunction &MF) {
330 "ComdatThunks value changed since MF creation");
333 const ThunkKind &
Kind = std::get<0>(KindAndRegs);
334 std::tie(std::ignore, Xn, Xm) = KindAndRegs;
336 const TargetInstrInfo *
TII =
343 if (MF.
size() == 1) {
363 Entry->addLiveIn(Xn);
369 MachineInstrBuilder Builder =
371 if (Xm != AArch64::NoRegister) {
372 Entry->addLiveIn(Xm);
384void SLSHardeningInserter::convertBLRToBL(
385 MachineModuleInfo &MMI, MachineBasicBlock &
MBB,
423 MachineInstr &BLR = *
MBBI;
427 unsigned NumRegOperands =
Kind.HasXmOperand ? 2 : 1;
429 "Expected one or two register inputs");
436 MachineFunction &MF = *
MBBI->getMF();
443 if (!Thunks.get(
Kind.Id, Xn, Xm)) {
444 StringRef TargetAttrs =
Kind.NeedsPAuth ?
"+pauth" :
"";
445 Thunks.set(
Kind.Id, Xn, Xm);
446 createThunkFunction(MMI, ThunkName, ComdatThunks, TargetAttrs);
466 if (
Op.getReg() == AArch64::LR &&
Op.isDef())
468 if (
Op.getReg() == AArch64::SP && !
Op.isDef())
473 int FirstOpIdxToRemove = std::max(ImpLROpIdx, ImpSPOpIdx);
474 int SecondOpIdxToRemove = std::min(ImpLROpIdx, ImpSPOpIdx);
491bool SLSHardeningInserter::hardenBLRs(MachineModuleInfo &MMI,
492 MachineBasicBlock &
MBB,
500 NextMBBI = std::next(
MBBI);
502 convertBLRToBL(MMI,
MBB,
MBBI, Thunks);
510class AArch64SLSHardeningLegacy
511 :
public ThunkInserterPass<SLSHardeningInserter> {
515 AArch64SLSHardeningLegacy() : ThunkInserterPass(
ID) {}
522char AArch64SLSHardeningLegacy::ID = 0;
528 return new AArch64SLSHardeningLegacy();
536 .getCachedResult<MachineModuleAnalysis>(
538 assert(MMI &&
"MachineModuleAnalysis must be available");
540 SLSHardeningInserter Inserter;
543 if (Inserter.run(MMI->
getMMI(), MF)) {
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
#define AARCH64_SLS_HARDENING_NAME
static void insertSpeculationBarrier(const AArch64Subtarget *ST, MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, DebugLoc DL, bool AlwaysUseISBDSB=false)
static constexpr StringRef CommonNamePrefix
static SmallString< 32 > createThunkName(const ThunkKind &Kind, Register Xn, Register Xm)
static std::tuple< const ThunkKind &, Register, Register > parseThunkName(StringRef ThunkName)
static const ThunkKind * getThunkKind(unsigned OriginalOpcode)
static bool isBLR(const MachineInstr &MI)
MachineBasicBlock MachineBasicBlock::iterator DebugLoc bool AlwaysUseISBDSB
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
MachineBasicBlock MachineBasicBlock::iterator MBBI
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
const HexagonInstrInfo * TII
Contains a base ThunkInserter class that simplifies injection of MI thunks as well as a default imple...
Promote Memory to Register
MachineInstr unsigned OpIdx
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
This file declares the machine register scavenger class.
static bool contains(SmallPtrSetImpl< ConstantExpr * > &Cache, ConstantExpr *Expr, Constant *C)
This file implements the StringSwitch template, which mimics a switch() statement whose cases are string literals.
PreservedAnalyses run(MachineFunction &MF, MachineFunctionAnalysisManager &MFAM)
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represents analyses that only rely on functions' control flow.
FunctionPass class - This class is used to implement most global optimizations.
Module * getParent()
Get the module that this global value is contained inside of...
instr_iterator instr_begin()
LLVM_ABI iterator getFirstTerminator()
Returns an iterator to the first terminator instruction of this basic block.
Instructions::iterator instr_iterator
instr_iterator instr_end()
const MachineFunction * getParent() const
Return the MachineFunction containing this basic block.
LLVM_ABI instr_iterator erase(instr_iterator I)
Remove an instruction from the instruction list and delete it.
MachineInstrBundleIterator< MachineInstr > iterator
void moveAdditionalCallInfo(const MachineInstr *Old, const MachineInstr *New)
Move the call site info from the Old call site to the New call site.
const TargetSubtargetInfo & getSubtarget() const
getSubtarget - Return the subtarget for which this machine code is being compiled.
StringRef getName() const
getName - Return the name of the corresponding LLVM function.
void push_back(MachineBasicBlock *MBB)
MCContext & getContext() const
Function & getFunction()
Return the LLVM function that this machine code represents.
const MachineBasicBlock & front() const
MachineBasicBlock * CreateMachineBasicBlock(const BasicBlock *BB=nullptr, std::optional< UniqueBBID > BBID=std::nullopt)
CreateMachineBasicBlock - Allocate a new MachineBasicBlock.
const MachineInstrBuilder & addReg(Register RegNo, RegState Flags={}, unsigned SubReg=0) const
Add a new virtual register operand.
const MachineInstrBuilder & addImm(int64_t Val) const
Add a new immediate operand.
const MachineInstrBuilder & addSym(MCSymbol *Sym, unsigned char TargetFlags=0) const
Representation of each machine instruction.
unsigned getOpcode() const
Returns the opcode of this MachineInstr.
LLVM_ABI void addOperand(MachineFunction &MF, const MachineOperand &Op)
Add the specified operand to the instruction.
LLVM_ABI unsigned getNumExplicitOperands() const
Returns the number of non-implicit operands.
LLVM_ABI void copyImplicitOps(MachineFunction &MF, const MachineInstr &MI)
Copy implicit register operands from specified instruction to this instruction.
const DebugLoc & getDebugLoc() const
Returns the debug location id of this MachineInstr.
LLVM_ABI void removeOperand(unsigned OpNo)
Erase an operand from an instruction, leaving it with one fewer operand than it started with.
const MachineOperand & getOperand(unsigned i) const
MachineModuleInfo & getMMI()
This class contains meta information specific to a module.
Register getReg() const
getReg - Returns the register number.
static MachineOperand CreateReg(Register Reg, bool isDef, bool isImp=false, bool isKill=false, bool isDead=false, bool isUndef=false, bool isEarlyClobber=false, unsigned SubReg=0, bool isDebug=false, bool isInternalRead=false, bool isRenamable=false)
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses & preserveSet()
Mark an analysis set as preserved.
Wrapper class representing virtual and physical registers.
SmallString - A SmallString is just a SmallVector with methods and accessors that make it work better...
Represent a constant reference to a string, i.e.
std::pair< StringRef, StringRef > split(char Separator) const
Split into two substrings around the first occurrence of a separator character.
bool starts_with(StringRef Prefix) const
Check if this string starts with the given Prefix.
StringRef drop_front(size_t N=1) const
Return a StringRef equal to 'this' but with the first N elements dropped.
A switch()-like statement whose cases are string literals.
StringSwitch & StartsWith(StringLiteral S, T Value)
TargetInstrInfo - Interface to description of machine instruction set.
virtual const TargetInstrInfo * getInstrInfo() const
This class assists in inserting MI thunk functions into the module and rewriting the existing machine...
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
This is an optimization pass for GlobalISel generic memory operations.
OuterAnalysisManagerProxy< ModuleAnalysisManager, MachineFunction > ModuleAnalysisManagerMachineFunctionProxy
Provide the ModuleAnalysisManager to Function proxy.
MachineInstrBuilder BuildMI(MachineFunction &MF, const MIMetadata &MIMD, const MCInstrDesc &MCID)
Builder interface. Specify how to create the initial instruction itself.
static bool isIndirectBranchOpcode(int Opc)
FunctionPass * createAArch64SLSHardeningLegacyPass()
AnalysisManager< MachineFunction > MachineFunctionAnalysisManager
LLVM_ABI PreservedAnalyses getMachineFunctionPassPreservedAnalyses()
Returns the minimum set of Analyses that all machine function passes must preserve.
auto formatv(bool Validate, const char *Fmt, Ts &&...Vals)
decltype(auto) get(const PointerIntPair< PointerTy, IntBits, IntType, PtrTraits, Info > &Pair)
DWARFExpression::Operation Op
bool operator|=(SparseBitVector< ElementSize > &LHS, const SparseBitVector< ElementSize > *RHS)