LLVM: include/llvm/CodeGen/MIR2Vec.h Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40#ifndef LLVM_CODEGEN_MIR2VEC_H

41#define LLVM_CODEGEN_MIR2VEC_H

42

#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

59

60namespace llvm {

61

67

69

71

72

75

78

83

84

85

88 using VocabMap = std::map<std::string, ir2vec::Embedding>;

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

/// Flat-index layout of the vocabulary. Each *Base member is the offset that
/// is added to a section-local index to obtain an entity ID (see the
/// getEntityIDFor* accessors). All members start at zero so the layout is
/// well-defined before it is populated.
// NOTE(review): the member list is rebuilt from this file's cross-reference
// index (definition lines 107-111); confirm against the upstream header.
struct {
  size_t OpcodeBase = 0;        ///< Offset added to a canonical opcode index.
  size_t CommonOperandBase = 0; ///< Offset added to a common-operand index.
  size_t PhyRegBase = 0;        ///< Offset added to a physical-register index.
  size_t VirtRegBase = 0;       ///< Offset added to a virtual-register index.
  size_t TotalEntries = 0;      ///< Presumably the total entry count — confirm.
} Layout;

113

114

115

116

117

118

/// Sections of the vocabulary storage. An enumerator's numeric value is used
/// (via static_cast<unsigned>) as the first index into Storage.
enum class Section : unsigned {
  Opcodes = 0,
  CommonOperands = 1,
  PhyRegisters = 2,
  VirtRegisters = 3,
  MaxSections // Implicitly 4: one past the last real section.
};

126

/// Ordered, duplicate-free collection of base opcode name strings.
std::set<std::string> UniqueBaseOpcodeNames;

130

131

132

134

135

136

137 static constexpr StringLiteral CommonOperandNames[] = {

138 "Immediate", "CImmediate", "FPImmediate", "MBB",

139 "FrameIndex", "ConstantPoolIndex", "TargetIndex", "JumpTableIndex",

140 "ExternalSymbol", "GlobalAddress", "BlockAddress", "RegisterMask",

141 "RegisterLiveOut", "Metadata", "MCSymbol", "CFIIndex",

142 "IntrinsicID", "Predicate", "ShuffleMask", "LaneMask"};

144 "Common operand names size changed, update accordingly");

145

146 const TargetInstrInfo &TII;

147 const TargetRegisterInfo &TRI;

148 const MachineRegisterInfo &MRI;

149

150 void generateStorage(const VocabMap &OpcodeMap,

151 const VocabMap &CommonOperandMap,

152 const VocabMap &PhyRegMap, const VocabMap &VirtRegMap);

153 void buildCanonicalOpcodeMapping();

154 void buildRegisterOperandMapping();

155

156

157 LLVM_ABI unsigned getCanonicalOpcodeIndex(unsigned Opcode) const;

158

159

160 unsigned

162

163

165

166

169 unsigned LocalIndex = getCommonOperandIndex(OperandType);

170 return Storage[static_cast<unsigned>(Section::CommonOperands)][LocalIndex];

171 }

172

174

175

176 if (Reg.isValid())

177 return ZeroEmbedding;

178

179

180

181

182

183

184 if (Reg.isStack())

185 return ZeroEmbedding;

186

187 unsigned LocalIndex = getRegisterOperandIndex(Reg);

188 auto SectionID =

189 Reg.isPhysical() ? Section::PhyRegisters : Section::VirtRegisters;

190 return Storage[static_cast<unsigned>(SectionID)][LocalIndex];

191 }

192

193

194

195 LLVM_ABI unsigned getEntityIDForCommonOperand(

197 return Layout.CommonOperandBase + getCommonOperandIndex(OperandType);

198 }

199

200

201

202 unsigned getEntityIDForRegister(Register Reg) const {

203 if (Reg.isValid() || Reg.isStack())

204 return Layout

205 .VirtRegBase;

206 unsigned LocalIndex = getRegisterOperandIndex(Reg);

207 size_t BaseOffset =

208 Reg.isPhysical() ? Layout.PhyRegBase : Layout.VirtRegBase;

209 return BaseOffset + LocalIndex;

210 }

211

212public:

213

216

217

218

225 bool IsPhysical = true) const;

226

227

229

230 unsigned getDimension() const { return Storage.getDimension(); }

231

232

233

235 return Layout.OpcodeBase + getCanonicalOpcodeIndex(Opcode);

236 }

237

238

239

242 return getEntityIDForRegister(MO.getReg());

243 return getEntityIDForCommonOperand(MO.getType());

244 }

245

246

248 unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode);

249 return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex];

250 }

251

253 auto OperandType = Operand.getType();

255 return operator[](Operand.getReg());

256 else

257 return operator[](OperandType);

258 }

259

260

263

265

267

268

270 create(VocabMap &&OpcMap, VocabMap &&CommonOperandsMap, VocabMap &&PhyRegMap,

273

274

279

280

282

283private:

284 MIRVocabulary(VocabMap &&OpcMap, VocabMap &&CommonOperandsMap,

285 VocabMap &&PhyRegMap, VocabMap &&VirtRegMap,

288};

289

290

292protected:

295

296

298

299

301

307

308

310

311

313

314

315

317

318public:

320

321

322

323 LLVM_ABI static std::unique_ptr

326

327

328

332

333

334

338

339

345};

346

347

348

349

351private:

353

354public:

358};

359

360}

361

362

363

364

365

366

367

368

369

371 using VocabMap = std::map<std::string, mir2vec::Embedding>;

372

373public:

375

377

378private:

379 Error readVocabulary(VocabMap &OpcVocab, VocabMap &CommonOperandVocab,

380 VocabMap &PhyRegVocabMap, VocabMap &VirtRegVocabMap);

382};

383

384

386 using VocabVector = std::vectormir2vec::Embedding;

387 using VocabMap = std::map<std::string, mir2vec::Embedding>;

388

390

391protected:

396 std::unique_ptr Provider;

397

398public:

401

406 Provider = std::make_unique(MMI);

407 return Provider->getVocabulary(M);

408 }

409

414};

415

416

419

420public:

424

432

434 return "MIR2Vec Vocabulary Printer Pass";

435 }

436};

437

438

439

442

443public:

447

454

456 return "MIR2Vec Embedder Printer Pass";

457 }

458};

459

460

462

463}

464

465#endif

unsigned const MachineRegisterInfo * MRI

assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")

const TargetInstrInfo & TII

#define LLVM_ABI_FOR_TEST

Provides ErrorOr smart pointer.

This file defines the IR2Vec vocabulary analysis(IR2VecVocabAnalysis), the core ir2vec::Embedder inte...

This header defines various interfaces for pass management in LLVM.

Register const TargetRegisterInfo * TRI

Promote Memory to Register

static cl::opt< RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode > Mode("regalloc-enable-advisor", cl::Hidden, cl::init(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Default), cl::desc("Enable regalloc advisor mode"), cl::values(clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Default, "default", "Default"), clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Release, "release", "precompiled"), clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Development, "development", "for training")))

Represent the analysis usage information of a pass.

AnalysisUsage & addRequired()

void setPreservesAll()

Set by analyses that do not transform their input at all.

Lightweight error class with error context and mandatory checking.

Tagged union holding either a T or a Error.

This is an important class for using LLVM in a threaded context.

MIR2VecPrinterLegacyPass(raw_ostream &OS)

Definition MIR2Vec.h:445

void getAnalysisUsage(AnalysisUsage &AU) const override

getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.

Definition MIR2Vec.h:449

bool runOnMachineFunction(MachineFunction &MF) override

runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...

static char ID

Definition MIR2Vec.h:444

StringRef getPassName() const override

getPassName - Return a nice clean name for a pass.

Definition MIR2Vec.h:455

Pass to analyze and populate MIR2Vec vocabulary from a module.

Definition MIR2Vec.h:385

static char ID

Definition MIR2Vec.h:399

MIR2VecVocabProvider & getProvider()

Definition MIR2Vec.h:410

Expected< mir2vec::MIRVocabulary > getMIR2VecVocabulary(const Module &M)

Definition MIR2Vec.h:402

std::unique_ptr< MIR2VecVocabProvider > Provider

Definition MIR2Vec.h:396

void getAnalysisUsage(AnalysisUsage &AU) const override

getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...

Definition MIR2Vec.h:392

MIR2VecVocabLegacyAnalysis()

Definition MIR2Vec.h:400

StringRef getPassName() const override

getPassName - Return a nice clean name for a pass.

Definition MIR2Vec.h:433

bool doFinalization(Module &M) override

doFinalization - Virtual method overridden by subclasses to do any necessary clean up after all passes...

static char ID

Definition MIR2Vec.h:421

bool runOnMachineFunction(MachineFunction &MF) override

runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...

MIR2VecVocabPrinterLegacyPass(raw_ostream &OS)

Definition MIR2Vec.h:422

void getAnalysisUsage(AnalysisUsage &AU) const override

getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.

Definition MIR2Vec.h:427

MIR2Vec vocabulary provider used by pass managers and standalone tools.

Definition MIR2Vec.h:370

MIR2VecVocabProvider(const MachineModuleInfo &MMI)

Definition MIR2Vec.h:374

LLVM_ABI Expected< mir2vec::MIRVocabulary > getVocabulary(const Module &M)

MachineFunctionPass(char &ID)

void getAnalysisUsage(AnalysisUsage &AU) const override

getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.

Representation of each machine instruction.

This class contains meta information specific to a module.

MachineOperand class - Representation of each machine instruction operand.

MachineOperandType getType() const

getType - Returns the MachineOperandType for this operand.

Register getReg() const

getReg - Returns the register number.

@ MO_Register

Register operand.

MachineRegisterInfo - Keep track of information for virtual and physical registers,...

A Module instance is used to store all the information related to an LLVM module.

AnalysisType & getAnalysis() const

getAnalysis() - This function is used by subclasses to get to the analysis information ...

virtual StringRef getPassName() const

getPassName - Return a nice clean name for a pass.

This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.

A wrapper around a string literal that serves as a proxy for constructing global tables of StringRefs...

StringRef - Represent a constant reference to a string, i.e.

TargetInstrInfo - Interface to description of machine instruction set.

TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...

Iterator support for section-based access.

Generic storage class for section-based vocabularies.

Base class for MIR embedders.

Definition MIR2Vec.h:291

const unsigned Dimension

Dimension of the embeddings; Captured from the vocabulary.

Definition MIR2Vec.h:297

Embedding getMFunctionVector() const

Computes and returns the embedding for the current machine function.

Definition MIR2Vec.h:340

const MIRVocabulary & Vocab

Definition MIR2Vec.h:294

Embedding getMInstVector(const MachineInstr &MI) const

Computes and returns the embedding for a given machine instruction MI in the machine function MF.

Definition MIR2Vec.h:329

virtual Embedding computeEmbeddings(const MachineInstr &MI) const =0

Function to compute the embedding for a given machine instruction.

Embedding getMBBVector(const MachineBasicBlock &MBB) const

Computes and returns the embedding for a given machine basic block in the machine function MF.

Definition MIR2Vec.h:335

const float RegOperandWeight

Definition MIR2Vec.h:300

MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab)

Definition MIR2Vec.h:302

const float CommonOperandWeight

Definition MIR2Vec.h:300

LLVM_ABI Embedding computeEmbeddings() const

Function to compute embeddings.

const float OpcWeight

Weight for opcode embeddings.

Definition MIR2Vec.h:300

const MachineFunction & MF

Definition MIR2Vec.h:293

virtual ~MIREmbedder()=default

static LLVM_ABI std::unique_ptr< MIREmbedder > create(MIR2VecKind Mode, const MachineFunction &MF, const MIRVocabulary &Vocab)

Factory method to create an Embedder object of the specified kind. Returns nullptr if the requested ki...

Class for storing and accessing the MIR2Vec vocabulary.

Definition MIR2Vec.h:86

size_t OpcodeBase

Definition MIR2Vec.h:107

unsigned getDimension() const

Definition MIR2Vec.h:230

unsigned getEntityIDForOpcode(unsigned Opcode) const

Get entity ID (flat index) for an opcode. This is used for triplet generation.

Definition MIR2Vec.h:234

const_iterator end() const

Definition MIR2Vec.h:264

LLVM_ABI_FOR_TEST unsigned getCanonicalIndexForOperandName(StringRef OperandName) const

const Embedding & operator[](MachineOperand Operand) const

Definition MIR2Vec.h:252

LLVM_ABI_FOR_TEST unsigned getCanonicalIndexForRegisterClass(StringRef RegName, bool IsPhysical=true) const

static LLVM_ABI_FOR_TEST Expected< MIRVocabulary > create(VocabMap &&OpcMap, VocabMap &&CommonOperandsMap, VocabMap &&PhyRegMap, VocabMap &&VirtRegMap, const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI)

Factory method to create MIRVocabulary from vocabulary map.

static LLVM_ABI_FOR_TEST std::string extractBaseOpcodeName(StringRef InstrName)

Static method for extracting base opcode names (public for testing)

ir2vec::VocabStorage::const_iterator const_iterator

Definition MIR2Vec.h:261

const_iterator begin() const

Definition MIR2Vec.h:262

const Embedding & operator[](unsigned Opcode) const

Definition MIR2Vec.h:247

size_t getCanonicalSize() const

Total number of entries in the vocabulary.

Definition MIR2Vec.h:281

static LLVM_ABI Expected< MIRVocabulary > createDummyVocabForTest(const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI, unsigned Dim=1)

Create a dummy vocabulary for testing purposes.

unsigned getEntityIDForMachineOperand(const MachineOperand &MO) const

Get entity ID (flat index) for a machine operand. This is used for triplet generation.

Definition MIR2Vec.h:240

size_t PhyRegBase

Definition MIR2Vec.h:109

LLVM_ABI std::string getStringKey(unsigned Pos) const

Get the string key for a vocabulary entry at the given position.

size_t CommonOperandBase

Definition MIR2Vec.h:108

size_t VirtRegBase

Definition MIR2Vec.h:110

size_t TotalEntries

Definition MIR2Vec.h:111

LLVM_ABI_FOR_TEST unsigned getCanonicalIndexForBaseName(StringRef BaseName) const

Get indices from opcode or operand names.

Class for computing Symbolic embeddings. Symbolic embeddings are constructed based on the entity-level...

Definition MIR2Vec.h:350

static LLVM_ABI_FOR_TEST std::unique_ptr< SymbolicMIREmbedder > create(const MachineFunction &MF, const MIRVocabulary &Vocab)

SymbolicMIREmbedder(const MachineFunction &F, const MIRVocabulary &Vocab)

This class implements an extremely fast bulk output stream that can only output to a stream.

DenseMap< const MachineInstr *, Embedding > MachineInstEmbeddingsMap

Definition MIR2Vec.h:80

LLVM_ABI llvm::cl::OptionCategory MIR2VecCategory

LLVM_ABI cl::opt< float > OpcWeight

LLVM_ABI cl::opt< float > RegOperandWeight

Definition MIR2Vec.h:77

ir2vec::Embedding Embedding

Definition MIR2Vec.h:79

DenseMap< const MachineBasicBlock *, Embedding > MachineBlockEmbeddingsMap

Definition MIR2Vec.h:81

LLVM_ABI cl::opt< float > CommonOperandWeight

Definition MIR2Vec.h:77

This is an optimization pass for GlobalISel generic memory operations.

LLVM_ABI MachineFunctionPass * createMIR2VecPrinterLegacyPass(raw_ostream &OS)

Create a machine pass that prints MIR2Vec embeddings.

MIR2VecKind

Definition MIR2Vec.h:68

Embedding is a datatype that wraps std::vector.