LLVM: lib/Target/AArch64/AArch64SLSHardening.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

29#include

30#include

31#include

32

33using namespace llvm;

34

35#define DEBUG_TYPE "aarch64-sls-hardening"

36

37#define AARCH64_SLS_HARDENING_NAME "AArch64 sls hardening pass"

38

39

40

41

42

43

44

46

47namespace {

48

/// Static description of one flavour of indirect-branch thunk. Each BLR-family
/// call that is hardened is redirected to a thunk of the matching kind.
struct ThunkKind {
  /// Identifies the flavour; used as the key into ThunksSet bitmasks.
  enum ThunkKindId {
    ThunkBR,
    ThunkBRAA,
    ThunkBRAB,
    ThunkBRAAZ,
    ThunkBRABZ,
  };

  ThunkKindId Id;
  /// Prefix inserted into the thunk name for this flavour ("", "aa_", ...).
  /// This member was absent although every aggregate initialiser of ThunkKind
  /// supplies a string literal in this position.
  const char *NamePrefix;
  /// True when the thunk takes a second register operand (Xm) in addition
  /// to the branch target Xn.
  bool HasXmOperand;
  /// True when the thunk function must be compiled with the "+pauth" target
  /// feature.
  bool NeedsPAuth;

  /// Opcode of the indirect branch emitted inside the thunk body.
  unsigned BROpcode;

  static const ThunkKind BR;
  static const ThunkKind BRAA;
  static const ThunkKind BRAB;
  static const ThunkKind BRAAZ;
  static const ThunkKind BRABZ;
};

72

73

74class ThunksSet {

75public:

76 static constexpr unsigned NumXRegisters = 32;

77

78

79 static unsigned indexOfXReg(Register Xn);

80

81 static Register xRegByIndex(unsigned N);

82

84 BLRThunks |= Other.BLRThunks;

85 BLRAAZThunks |= Other.BLRAAZThunks;

86 BLRABZThunks |= Other.BLRABZThunks;

87 for (unsigned I = 0; I < NumXRegisters; ++I)

88 BLRAAThunks[I] |= Other.BLRAAThunks[I];

89 for (unsigned I = 0; I < NumXRegisters; ++I)

90 BLRABThunks[I] |= Other.BLRABThunks[I];

91

92 return *this;

93 }

94

96 reg_bitmask_t XnBit = reg_bitmask_t(1) << indexOfXReg(Xn);

97 return getBitmask(Kind, Xm) & XnBit;

98 }

99

100 void set(ThunkKind::ThunkKindId Kind, Register Xn, Register Xm) {

101 reg_bitmask_t XnBit = reg_bitmask_t(1) << indexOfXReg(Xn);

102 getBitmask(Kind, Xm) |= XnBit;

103 }

104

105private:

106 typedef uint32_t reg_bitmask_t;

107 static_assert(NumXRegisters <= sizeof(reg_bitmask_t) * CHAR_BIT,

108 "Bitmask is not wide enough to hold all Xn registers");

109

110

111

112

113

114

115 reg_bitmask_t BLRThunks = 0;

116 reg_bitmask_t BLRAAZThunks = 0;

117 reg_bitmask_t BLRABZThunks = 0;

118 reg_bitmask_t BLRAAThunks[NumXRegisters] = {};

119 reg_bitmask_t BLRABThunks[NumXRegisters] = {};

120

121 reg_bitmask_t &getBitmask(ThunkKind::ThunkKindId Kind, Register Xm) {

122 switch (Kind) {

123 case ThunkKind::ThunkBR:

124 return BLRThunks;

125 case ThunkKind::ThunkBRAAZ:

126 return BLRAAZThunks;

127 case ThunkKind::ThunkBRABZ:

128 return BLRABZThunks;

129 case ThunkKind::ThunkBRAA:

130 return BLRAAThunks[indexOfXReg(Xm)];

131 case ThunkKind::ThunkBRAB:

132 return BLRABThunks[indexOfXReg(Xm)];

133 }

135 }

136};

137

138struct SLSHardeningInserter : ThunkInserter<SLSHardeningInserter, ThunksSet> {

139public:

140 const char *getThunkPrefix() { return CommonNamePrefix.data(); }

143

144

147 }

149 ThunksSet ExistingThunks);

151

152private:

153 bool ComdatThunks = true;

154

157 ThunksSet &Thunks);

158

161 ThunksSet &Thunks);

162};

163

164}

165

166const ThunkKind ThunkKind::BR = {ThunkBR, "", false,

167 false, AArch64::BR};

168const ThunkKind ThunkKind::BRAA = {ThunkBRAA, "aa_", true,

169 true, AArch64::BRAA};

170const ThunkKind ThunkKind::BRAB = {ThunkBRAB, "ab_", true,

171 true, AArch64::BRAB};

172const ThunkKind ThunkKind::BRAAZ = {ThunkBRAAZ, "aaz_", false,

173 true, AArch64::BRAAZ};

174const ThunkKind ThunkKind::BRABZ = {ThunkBRABZ, "abz_", false,

175 true, AArch64::BRABZ};

176

177

178static const ThunkKind *getThunkKind(unsigned OriginalOpcode) {

179 switch (OriginalOpcode) {

180 case AArch64::BLR:

181 case AArch64::BLRNoIP:

182 return &ThunkKind::BR;

183 case AArch64::BLRAA:

184 return &ThunkKind::BRAA;

185 case AArch64::BLRAB:

186 return &ThunkKind::BRAB;

187 case AArch64::BLRAAZ:

188 return &ThunkKind::BRAAZ;

189 case AArch64::BLRABZ:

190 return &ThunkKind::BRABZ;

191 }

192 return nullptr;

193}

194

198

199unsigned ThunksSet::indexOfXReg(Register Reg) {

201 assert(Reg != AArch64::X16 && Reg != AArch64::X17 && Reg != AArch64::LR);

202

203

204 unsigned Result = (unsigned)Reg - (unsigned)AArch64::X0;

205 if (Reg == AArch64::FP)

207 else if (Reg == AArch64::XZR)

209

210 assert(Result < NumXRegisters && "Internal register numbering changed");

211 assert(AArch64::GPR64RegClass.getRegister(Result).id() == Reg &&

212 "Internal register numbering changed");

213

215}

216

217Register ThunksSet::xRegByIndex(unsigned N) {

218 return AArch64::GPR64RegClass.getRegister(N);

219}

220

227 "Must not insert SpeculationBarrierEndBB as only instruction in MBB.");

228 assert(std::prev(MBBI)->isBarrier() &&

229 "SpeculationBarrierEndBB must only follow unconditional control flow "

230 "instructions.");

231 assert(std::prev(MBBI)->isTerminator() &&

232 "SpeculationBarrierEndBB must only follow terminators.");

235 ? AArch64::SpeculationBarrierSBEndBB

236 : AArch64::SpeculationBarrierISBDSBEndBB;

238 (MBBI->getOpcode() != AArch64::SpeculationBarrierSBEndBB &&

239 MBBI->getOpcode() != AArch64::SpeculationBarrierISBDSBEndBB))

241}

242

243ThunksSet SLSHardeningInserter::insertThunks(MachineModuleInfo &MMI,

244 MachineFunction &MF,

245 ThunksSet ExistingThunks) {

246 const AArch64Subtarget *ST = &MF.getSubtarget();

247

248 for (auto &MBB : MF) {

249 if (ST->hardenSlsRetBr())

250 hardenReturnsAndBRs(MMI, MBB);

251 if (ST->hardenSlsBlr())

252 hardenBLRs(MMI, MBB, ExistingThunks);

253 }

254 return ExistingThunks;

255}

256

257bool SLSHardeningInserter::hardenReturnsAndBRs(MachineModuleInfo &MMI,

258 MachineBasicBlock &MBB) {

259 const AArch64Subtarget *ST =

264 for (; MBBI != E; MBBI = NextMBBI) {

265 MachineInstr &MI = *MBBI;

266 NextMBBI = std::next(MBBI);

271 }

272 }

274}

275

276

277

278

281 unsigned N = ThunksSet::indexOfXReg(Xn);

282 if (!Kind.HasXmOperand)

284

285 unsigned M = ThunksSet::indexOfXReg(Xm);

287}

288

289static std::tuple<const ThunkKind &, Register, Register>

292 "Should be filtered out by ThunkInserter");

293

295

296

300 .StartsWith("aaz_", &ThunkKind::BRAAZ)

301 .StartsWith("abz_", &ThunkKind::BRABZ)

302 .Default(&ThunkKind::BR);

303

304 auto ParseRegName = [](StringRef Name) {

305 unsigned N;

306

307 assert(Name.starts_with("x") && "xN register name expected");

308 bool Fail = Name.drop_front(1).getAsInteger(10, N);

309 assert(Fail && N < ThunksSet::NumXRegisters && "Unexpected register");

311

312 return ThunksSet::xRegByIndex(N);

313 };

314

315

318 std::tie(XnStr, XmStr) = RegsStr.split('_');

319

320

321 Register Xn = ParseRegName(XnStr);

322 Register Xm = Kind.HasXmOperand ? ParseRegName(XmStr) : AArch64::NoRegister;

323

324 return std::make_tuple(std::ref(Kind), Xn, Xm);

325}

326

327void SLSHardeningInserter::populateThunk(MachineFunction &MF) {

329 "ComdatThunks value changed since MF creation");

332 const ThunkKind &Kind = std::get<0>(KindAndRegs);

333 std::tie(std::ignore, Xn, Xm) = KindAndRegs;

334

335 const TargetInstrInfo *TII =

336 MF.getSubtarget().getInstrInfo();

337

338

339

340

341

342 if (MF.size() == 1) {

346 } else {

349 }

350

351 MachineBasicBlock *Entry = &MF.front();

353

354

355

356

357

358

359

360

361

362 Entry->addLiveIn(Xn);

363

365 .addReg(AArch64::XZR)

368 MachineInstrBuilder Builder =

370 if (Xm != AArch64::NoRegister) {

371 Entry->addLiveIn(Xm);

373 }

374

375

376

377

378

380 Entry->end(), DebugLoc(), true );

381}

382

383void SLSHardeningInserter::convertBLRToBL(

384 MachineModuleInfo &MMI, MachineBasicBlock &MBB,

386

387

388

389

390

391

392

393

394

395

396

397

398

399

400

401

402

403

404

405

406

407

408

409

410

411

412

413

414

415

416

417

418

419

420

421

422 MachineInstr &BLR = *MBBI;

425

426 unsigned NumRegOperands = Kind.HasXmOperand ? 2 : 1;

428 "Expected one or two register inputs");

432

434

435 MachineFunction &MF = *MBBI->getMF();

438

441

442 if (!Thunks.get(Kind.Id, Xn, Xm)) {

443 StringRef TargetAttrs = Kind.NeedsPAuth ? "+pauth" : "";

444 Thunks.set(Kind.Id, Xn, Xm);

445 createThunkFunction(MMI, ThunkName, ComdatThunks, TargetAttrs);

446 }

447

449

450

451

452

453

454

455

456

457

458 int ImpLROpIdx = -1;

459 int ImpSPOpIdx = -1;

463 if (Op.isReg())

464 continue;

465 if (Op.getReg() == AArch64::LR && Op.isDef())

466 ImpLROpIdx = OpIdx;

467 if (Op.getReg() == AArch64::SP && Op.isDef())

468 ImpSPOpIdx = OpIdx;

469 }

470 assert(ImpLROpIdx != -1);

471 assert(ImpSPOpIdx != -1);

472 int FirstOpIdxToRemove = std::max(ImpLROpIdx, ImpSPOpIdx);

473 int SecondOpIdxToRemove = std::min(ImpLROpIdx, ImpSPOpIdx);

476

479

480

481 for (unsigned OpIdx = 0; OpIdx < NumRegOperands; ++OpIdx) {

484 true, Op.isKill()));

485 }

486

488}

489

490bool SLSHardeningInserter::hardenBLRs(MachineModuleInfo &MMI,

491 MachineBasicBlock &MBB,

492 ThunksSet &Thunks) {

497 for (; MBBI != E; MBBI = NextMBBI) {

498 MachineInstr &MI = *MBBI;

499 NextMBBI = std::next(MBBI);

501 convertBLRToBL(MMI, MBB, MBBI, Thunks);

503 }

504 }

506}

507

508namespace {

509class AArch64SLSHardening : public ThunkInserterPass {

510public:

511 static char ID;

512

513 AArch64SLSHardening() : ThunkInserterPass(ID) {}

514

516};

517

518}

519

520char AArch64SLSHardening::ID = 0;

521

524

526 return new AArch64SLSHardening();

527}

assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")

const TargetInstrInfo & TII

#define AARCH64_SLS_HARDENING_NAME

Definition AArch64SLSHardening.cpp:37

static void insertSpeculationBarrier(const AArch64Subtarget *ST, MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, DebugLoc DL, bool AlwaysUseISBDSB=false)

Definition AArch64SLSHardening.cpp:221

static constexpr StringRef CommonNamePrefix

Definition AArch64SLSHardening.cpp:45

static SmallString< 32 > createThunkName(const ThunkKind &Kind, Register Xn, Register Xm)

Definition AArch64SLSHardening.cpp:279

static std::tuple< const ThunkKind &, Register, Register > parseThunkName(StringRef ThunkName)

Definition AArch64SLSHardening.cpp:290

static const ThunkKind * getThunkKind(unsigned OriginalOpcode)

Definition AArch64SLSHardening.cpp:178

static bool isBLR(const MachineInstr &MI)

Definition AArch64SLSHardening.cpp:195

MachineBasicBlock MachineBasicBlock::iterator DebugLoc bool AlwaysUseISBDSB

MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL

MachineBasicBlock MachineBasicBlock::iterator MBBI

static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")

Contains a base ThunkInserter class that simplifies injection of MI thunks as well as a default imple...

Promote Memory to Register

MachineInstr unsigned OpIdx

#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)

This file declares the machine register scavenger class.

static bool contains(SmallPtrSetImpl< ConstantExpr * > &Cache, ConstantExpr *Expr, Constant *C)

This file implements the StringSwitch template, which mimics a switch() statement whose cases are str...

FunctionPass class - This class is used to implement most global optimizations.

const MCInstrDesc & get(unsigned Opcode) const

Return the machine instruction descriptor that corresponds to the specified instruction opcode.

instr_iterator instr_begin()

LLVM_ABI iterator getFirstTerminator()

Returns an iterator to the first terminator instruction of this basic block.

Instructions::iterator instr_iterator

instr_iterator instr_end()

const MachineFunction * getParent() const

Return the MachineFunction containing this basic block.

LLVM_ABI instr_iterator erase(instr_iterator I)

Remove an instruction from the instruction list and delete it.

MachineInstrBundleIterator< MachineInstr > iterator

void moveAdditionalCallInfo(const MachineInstr *Old, const MachineInstr *New)

Move the call site info from Old to \New call site info.

const TargetSubtargetInfo & getSubtarget() const

getSubtarget - Return the subtarget for which this machine code is being compiled.

StringRef getName() const

getName - Return the name of the corresponding LLVM function.

void push_back(MachineBasicBlock *MBB)

MCContext & getContext() const

Function & getFunction()

Return the LLVM function that this machine code represents.

const MachineBasicBlock & front() const

MachineBasicBlock * CreateMachineBasicBlock(const BasicBlock *BB=nullptr, std::optional< UniqueBBID > BBID=std::nullopt)

CreateMachineBasicBlock - Allocate a new MachineBasicBlock.

const MachineInstrBuilder & addImm(int64_t Val) const

Add a new immediate operand.

const MachineInstrBuilder & addSym(MCSymbol *Sym, unsigned char TargetFlags=0) const

const MachineInstrBuilder & addReg(Register RegNo, unsigned flags=0, unsigned SubReg=0) const

Add a new virtual register operand.

Representation of each machine instruction.

unsigned getOpcode() const

Returns the opcode of this MachineInstr.

LLVM_ABI void addOperand(MachineFunction &MF, const MachineOperand &Op)

Add the specified operand to the instruction.

LLVM_ABI unsigned getNumExplicitOperands() const

Returns the number of non-implicit operands.

LLVM_ABI void copyImplicitOps(MachineFunction &MF, const MachineInstr &MI)

Copy implicit register operands from specified instruction to this instruction.

const DebugLoc & getDebugLoc() const

Returns the debug location id of this MachineInstr.

LLVM_ABI void removeOperand(unsigned OpNo)

Erase an operand from an instruction, leaving it with one fewer operand than it started with.

const MachineOperand & getOperand(unsigned i) const

This class contains meta information specific to a module.

Register getReg() const

getReg - Returns the register number.

static MachineOperand CreateReg(Register Reg, bool isDef, bool isImp=false, bool isKill=false, bool isDead=false, bool isUndef=false, bool isEarlyClobber=false, unsigned SubReg=0, bool isDebug=false, bool isInternalRead=false, bool isRenamable=false)

Wrapper class representing virtual and physical registers.

SmallString - A SmallString is just a SmallVector with methods and accessors that make it work better...

StringRef - Represent a constant reference to a string, i.e.

std::pair< StringRef, StringRef > split(char Separator) const

Split into two substrings around the first occurrence of a separator character.

bool starts_with(StringRef Prefix) const

Check if this string starts with the given Prefix.

StringRef drop_front(size_t N=1) const

Return a StringRef equal to 'this' but with the first N elements dropped.

A switch()-like statement whose cases are string literals.

StringSwitch & StartsWith(StringLiteral S, T Value)

TargetInstrInfo - Interface to description of machine instruction set.

virtual const TargetInstrInfo * getInstrInfo() const

This class assists in inserting MI thunk functions into the module and rewriting the existing machine...

#define llvm_unreachable(msg)

Marks that the current location is not supposed to be reachable.

unsigned ID

LLVM IR allows to use arbitrary numbers as calling convention identifiers.

This is an optimization pass for GlobalISel generic memory operations.

MachineInstrBuilder BuildMI(MachineFunction &MF, const MIMetadata &MIMD, const MCInstrDesc &MCID)

Builder interface. Specify how to create the initial instruction itself.

static bool isIndirectBranchOpcode(int Opc)

auto formatv(bool Validate, const char *Fmt, Ts &&...Vals)

decltype(auto) get(const PointerIntPair< PointerTy, IntBits, IntType, PtrTraits, Info > &Pair)

FunctionPass * createAArch64SLSHardeningPass()

DWARFExpression::Operation Op

bool operator|=(SparseBitVector< ElementSize > &LHS, const SparseBitVector< ElementSize > *RHS)