LLVM: include/llvm/Analysis/IR2Vec.h Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35#ifndef LLVM_ANALYSIS_IR2VEC_H

36#define LLVM_ANALYSIS_IR2VEC_H

37

46#include

47#include

48#include

49

50namespace llvm {

51

60

61

62

63

64

65

66

67

68

69

70

72

74

80

81

82

83

84

85

86

88private:

89 std::vector Data;

90

91public:

93 Embedding(const std::vector &V) : Data(V) {}

95 Embedding(std::initializer_list IL) : Data(IL) {}

96

99

100 size_t size() const { return Data.size(); }

101 bool empty() const { return Data.empty(); }

102

104 assert(Itr < Data.size() && "Index out of bounds");

105 return Data[Itr];

106 }

107

109 assert(Itr < Data.size() && "Index out of bounds");

110 return Data[Itr];

111 }

112

113 using iterator = std::vector::iterator;

115

122

123 const std::vector &getData() const { return Data; }

124

125

132

133

134

136

137

138

140 double Tolerance = 1e-4) const;

141

143};

144

147

148

149

150

152private:

153

154 std::vector<std::vector> Sections;

155

156

157

158

159 size_t TotalSize = 0;

160 unsigned Dimension = 0;

161

162public:

163

165

166

168

171

174

175

176 size_t size() const { return TotalSize; }

177

178

180 return static_cast<unsigned>(Sections.size());

181 }

182

183

184 const std::vector &operator[](unsigned SectionId) const {

185 assert(SectionId < Sections.size() && "Invalid section ID");

186 return Sections[SectionId];

187 }

188

189

191

192

193 bool isValid() const { return TotalSize > 0; }

194

195

198 unsigned SectionId = 0;

199 size_t LocalIndex = 0;

200

201 public:

203 size_t LocalIndex)

204 : Storage(Storage), SectionId(SectionId), LocalIndex(LocalIndex) {}

205

210 };

211

216

217 using VocabMap = std::map<std::string, Embedding>;

218

219

222 VocabMap &TargetVocab, unsigned &Dim);

223};

224

225

226

227

228

229

230

231

232

233

234

235

236

237

238

239

240

241

244

245

246

247

248

249

250

251

252

253

254

255

256

257

258

259

260

261 enum class Section : unsigned {

262 Opcodes = 0,

263 CanonicalTypes = 1,

264 Operands = 2,

265 Predicates = 3,

266 MaxSections

267 };

268

269

271

272 static constexpr unsigned NumICmpPredicates =

275 static constexpr unsigned NumFCmpPredicates =

278

279public:

280

296

297

305

306

307#define LAST_OTHER_INST(NUM) static constexpr unsigned MaxOpcodes = NUM;

308#include "llvm/IR/Instruction.def"

309#undef LAST_OTHER_INST

310

311 static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1;

316

317

319 NumICmpPredicates + NumFCmpPredicates;

320

323

326

329

331 return Storage.size() == NumCanonicalEntries;

332 }

333

335 assert(isValid() && "IR2Vec Vocabulary is invalid");

336 return Storage.getDimension();

337 }

338

339

340

341 static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }

342

343

345

346

348 return getVocabKeyForCanonicalTypeID(getCanonicalTypeID(TypeID));

349 }

350

351

353 unsigned Index = static_cast<unsigned>(Kind);

355 return OperandKindNames[Index];

356 }

357

358

360

361

363

364

366 assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");

367 return Opcode - 1;

368 }

369

372 return MaxOpcodes + static_cast<unsigned>(getCanonicalTypeID(TypeID));

373 }

374

378 return OperandBaseOffset + Index;

379 }

380

382 return PredicateBaseOffset + getPredicateLocalIndex(P);

383 }

384

385

387 assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");

388 return Storage[static_cast<unsigned>(Section::Opcodes)][Opcode - 1];

389 }

390

393 unsigned LocalIndex = static_cast<unsigned>(getCanonicalTypeID(TypeID));

394 return Storage[static_cast<unsigned>(Section::CanonicalTypes)][LocalIndex];

395 }

396

398 unsigned LocalIndex = static_cast<unsigned>(getOperandKind(&Arg));

400 return Storage[static_cast<unsigned>(Section::Operands)][LocalIndex];

401 }

402

404 unsigned LocalIndex = getPredicateLocalIndex(P);

405 return Storage[static_cast<unsigned>(Section::Predicates)][LocalIndex];

406 }

407

408

410

412 assert(isValid() && "IR2Vec Vocabulary is invalid");

413 return Storage.begin();

414 }

415

417

419 assert(isValid() && "IR2Vec Vocabulary is invalid");

420 return Storage.end();

421 }

422

424

425

426

427

429

430

432

434 ModuleAnalysisManager::Invalidator &Inv) const;

435

436private:

437 constexpr static unsigned NumCanonicalEntries =

439

440

441 constexpr static unsigned OperandBaseOffset =

443 constexpr static unsigned PredicateBaseOffset =

445

446

448 static CmpInst::Predicate getPredicateFromLocalIndex(unsigned LocalIndex);

449

450

451 static constexpr StringLiteral CanonicalTypeNames[] = {

452 "FloatTy", "VoidTy", "LabelTy", "MetadataTy",

453 "VectorTy", "TokenTy", "IntegerTy", "FunctionTy",

454 "PointerTy", "StructTy", "ArrayTy", "UnknownTy"};

455 static_assert(std::size(CanonicalTypeNames) ==

457 "CanonicalTypeNames array size must match MaxCanonicalType");

458

459

460 static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer",

461 "Constant", "Variable"};

462 static_assert(std::size(OperandKindNames) ==

464 "OperandKindNames array size must match MaxOperandKind");

465

466

467

468 static constexpr std::array<CanonicalTypeID, MaxTypeIDs> TypeIDMapping = {{

490 }};

491 static_assert(TypeIDMapping.size() == MaxTypeIDs,

492 "TypeIDMapping must cover all Type::TypeID values");

493

494

497 unsigned Index = static_cast<unsigned>(CType);

499 return CanonicalTypeNames[Index];

500 }

501

502

504 unsigned Index = static_cast<unsigned>(TypeID);

506 return TypeIDMapping[Index];

507 }

508

509

510

511

514 return getPredicateFromLocalIndex(Index);

515 }

516};

517

518

519

520

521

522

523

525protected:

528

529

531

532

533

535

540

541

543

544

546

547

548

550

551public:

553

554

555 LLVM_ABI static std::unique_ptr

557

558

559

563

564

565

569

570

572

573

574

575

576

577

579};

580

581

582

583

585private:

587

588public:

591};

592

593

594

595

597private:

598

599

602

603public:

607};

608

609}

610

611

612

613

615 using VocabMap = std::map<std::string, ir2vec::Embedding>;

616 std::optionalir2vec::VocabStorage Vocab;

617

618 Error readVocabulary(VocabMap &OpcVocab, VocabMap &TypeVocab,

619 VocabMap &ArgVocab);

620 void generateVocabStorage(VocabMap &OpcVocab, VocabMap &TypeVocab,

621 VocabMap &ArgVocab);

623

624public:

631};

632

633

634

643

644

653

654}

655

656#endif

assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")

This file defines the DenseMap class.

Provides ErrorOr smart pointer.

This header defines various interfaces for pass management in LLVM.

This file supports working with JSON data.

ModuleAnalysisManager MAM

static cl::opt< RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode > Mode("regalloc-enable-advisor", cl::Hidden, cl::init(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Default), cl::desc("Enable regalloc advisor mode"), cl::values(clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Default, "default", "Default"), clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Release, "release", "precompiled"), clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Development, "development", "for training")))

LLVM Basic Block Representation.

Predicate

This enumeration lists the possible predicates for CmpInst subclasses.

Lightweight error class with error context and mandatory checking.

IR2VecPrinterPass(raw_ostream &OS)

Definition IR2Vec.h:639

static bool isRequired()

Definition IR2Vec.h:641

LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)

This analysis provides the vocabulary for IR2Vec.

Definition IR2Vec.h:614

IR2VecVocabAnalysis()=default

ir2vec::Vocabulary Result

Definition IR2Vec.h:629

LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM)

LLVM_ABI IR2VecVocabAnalysis(ir2vec::VocabStorage &&Vocab)

Definition IR2Vec.h:627

static LLVM_ABI AnalysisKey Key

Definition IR2Vec.h:625

static bool isRequired()

Definition IR2Vec.h:651

LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)

IR2VecVocabPrinterPass(raw_ostream &OS)

Definition IR2Vec.h:649

This is an important class for using LLVM in a threaded context.

A Module instance is used to store all the information related to an LLVM module.

A set of analyses that are preserved following a run of a transformation pass.

A wrapper around a string literal that serves as a proxy for constructing global tables of StringRefs...

StringRef - Represent a constant reference to a string, i.e.

TypeID

Definitions of all of the base types for the Type system.

LLVM Value Representation.

static LLVM_ABI std::unique_ptr< Embedder > create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab)

Factory method to create an Embedder object.

virtual Embedding computeEmbeddings(const Instruction &I) const =0

Function to compute the embedding for a given instruction.

LLVM_ABI Embedding getInstVector(const Instruction &I) const

Computes and returns the embedding for a given instruction in the function F.

Definition IR2Vec.h:560

const Vocabulary & Vocab

Definition IR2Vec.h:527

virtual ~Embedder()=default

const float TypeWeight

Definition IR2Vec.h:534

const float OpcWeight

Weights for different entities (like opcode, arguments, types) in the IR instructions to generate the...

Definition IR2Vec.h:534

LLVM_ABI Embedding getFunctionVector() const

Computes and returns the embedding for the current function.

Definition IR2Vec.h:571

LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab)

Definition IR2Vec.h:536

const unsigned Dimension

Dimension of the vector representation; captured from the input vocabulary.

Definition IR2Vec.h:530

virtual void invalidateEmbeddings()

Invalidate embeddings if cached.

Definition IR2Vec.h:578

Embedding computeEmbeddings() const

Function to compute embeddings.

const float ArgWeight

Definition IR2Vec.h:534

const Function & F

Definition IR2Vec.h:526

LLVM_ABI Embedding getBBVector(const BasicBlock &BB) const

Computes and returns the embedding for a given basic block in the function F.

Definition IR2Vec.h:566

void invalidateEmbeddings() override

Invalidate embeddings if cached.

Definition IR2Vec.h:606

FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab)

Definition IR2Vec.h:604

SymbolicEmbedder(const Function &F, const Vocabulary &Vocab)

Definition IR2Vec.h:589

Iterator support for section-based access.

Definition IR2Vec.h:196

const_iterator(const VocabStorage *Storage, unsigned SectionId, size_t LocalIndex)

Definition IR2Vec.h:202

LLVM_ABI bool operator!=(const const_iterator &Other) const

LLVM_ABI const_iterator & operator++()

LLVM_ABI const Embedding & operator*() const

LLVM_ABI bool operator==(const const_iterator &Other) const

Generic storage class for section-based vocabularies.

Definition IR2Vec.h:151

static Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab, unsigned &Dim)

Parse a vocabulary section from JSON and populate the target vocabulary map.

VocabStorage & operator=(VocabStorage &&)=default

const_iterator end() const

Definition IR2Vec.h:213

unsigned getNumSections() const

Get number of sections.

Definition IR2Vec.h:179

VocabStorage & operator=(const VocabStorage &)=delete

unsigned getDimension() const

Get vocabulary dimension.

Definition IR2Vec.h:190

size_t size() const

Get total number of entries across all sections.

Definition IR2Vec.h:176

VocabStorage()=default

Default constructor creates empty storage (invalid state)

const_iterator begin() const

Definition IR2Vec.h:212

bool isValid() const

Check if vocabulary is valid (has data)

Definition IR2Vec.h:193

VocabStorage(VocabStorage &&)=default

std::map< std::string, Embedding > VocabMap

Definition IR2Vec.h:217

const std::vector< Embedding > & operator[](unsigned SectionId) const

Section-based access: Storage[sectionId][localIndex].

Definition IR2Vec.h:184

VocabStorage(const VocabStorage &)=delete

Class for storing and accessing the IR2Vec vocabulary.

Definition IR2Vec.h:242

static LLVM_ABI StringRef getVocabKeyForOperandKind(OperandKind Kind)

Function to get vocabulary key for a given OperandKind.

Definition IR2Vec.h:352

LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA, ModuleAnalysisManager::Invalidator &Inv) const

const_iterator begin() const

Definition IR2Vec.h:411

LLVM_ABI unsigned getDimension() const

Definition IR2Vec.h:334

Vocabulary(Vocabulary &&)=default

static LLVM_ABI OperandKind getOperandKind(const Value *Op)

Function to classify an operand into OperandKind.

static LLVM_ABI unsigned getIndex(CmpInst::Predicate P)

Definition IR2Vec.h:381

Vocabulary & operator=(const Vocabulary &)=delete

static LLVM_ABI StringRef getStringKey(unsigned Pos)

Returns the string key for a given index position in the vocabulary.

static constexpr unsigned MaxCanonicalTypeIDs

Definition IR2Vec.h:312

LLVM_ABI const ir2vec::Embedding & operator[](CmpInst::Predicate P) const

Definition IR2Vec.h:403

static constexpr unsigned MaxOperandKinds

Definition IR2Vec.h:314

Vocabulary(const Vocabulary &)=delete

const_iterator cbegin() const

Definition IR2Vec.h:416

OperandKind

Operand kinds supported by IR2Vec Vocabulary.

Definition IR2Vec.h:298

@ MaxOperandKind

Definition IR2Vec.h:303

@ ConstantID

Definition IR2Vec.h:301

@ PointerID

Definition IR2Vec.h:300

@ FunctionID

Definition IR2Vec.h:299

@ VariableID

Definition IR2Vec.h:302

static constexpr size_t getCanonicalSize()

Total number of entries (opcodes + canonicalized types + operand kinds + predicates)

Definition IR2Vec.h:341

static LLVM_ABI unsigned getIndex(const Value &Op)

Definition IR2Vec.h:375

static LLVM_ABI StringRef getVocabKeyForPredicate(CmpInst::Predicate P)

Function to get vocabulary key for a given predicate.

static constexpr unsigned MaxTypeIDs

Definition IR2Vec.h:311

LLVM_ABI Vocabulary(VocabStorage &&Storage)

Definition IR2Vec.h:322

LLVM_ABI const ir2vec::Embedding & operator[](Type::TypeID TypeID) const

Definition IR2Vec.h:391

static LLVM_ABI unsigned getIndex(Type::TypeID TypeID)

Definition IR2Vec.h:370

const_iterator end() const

Definition IR2Vec.h:418

static LLVM_ABI StringRef getVocabKeyForOpcode(unsigned Opcode)

Function to get vocabulary key for a given Opcode.

static LLVM_ABI StringRef getVocabKeyForTypeID(Type::TypeID TypeID)

Function to get vocabulary key for a given TypeID.

Definition IR2Vec.h:347

VocabStorage::const_iterator const_iterator

Const Iterator type aliases.

Definition IR2Vec.h:409

const_iterator cend() const

Definition IR2Vec.h:423

static LLVM_ABI unsigned getIndex(unsigned Opcode)

Functions to return flat index.

Definition IR2Vec.h:365

LLVM_ABI bool isValid() const

Definition IR2Vec.h:330

Vocabulary & operator=(Vocabulary &&Other)=delete

LLVM_ABI const ir2vec::Embedding & operator[](unsigned Opcode) const

Accessors to get the embedding for a given entity.

Definition IR2Vec.h:386

static LLVM_ABI VocabStorage createDummyVocabForTest(unsigned Dim=1)

Create a dummy vocabulary for testing purposes.

static constexpr unsigned MaxPredicateKinds

Definition IR2Vec.h:318

CanonicalTypeID

Canonical type IDs supported by IR2Vec Vocabulary.

Definition IR2Vec.h:281

@ VoidTy

Definition IR2Vec.h:283

@ FloatTy

Definition IR2Vec.h:282

@ TokenTy

Definition IR2Vec.h:287

@ ArrayTy

Definition IR2Vec.h:292

@ UnknownTy

Definition IR2Vec.h:293

@ LabelTy

Definition IR2Vec.h:284

@ VectorTy

Definition IR2Vec.h:286

@ FunctionTy

Definition IR2Vec.h:289

@ MaxCanonicalType

Definition IR2Vec.h:294

@ MetadataTy

Definition IR2Vec.h:285

@ IntegerTy

Definition IR2Vec.h:288

@ StructTy

Definition IR2Vec.h:291

@ PointerTy

Definition IR2Vec.h:290

LLVM_ABI const ir2vec::Embedding & operator[](const Value &Arg) const

Definition IR2Vec.h:397

A Value is an JSON value of unknown type.

This class implements an extremely fast bulk output stream that can only output to a stream.

DenseMap< const Instruction *, Embedding > InstEmbeddingsMap

Definition IR2Vec.h:145

LLVM_ABI cl::opt< float > ArgWeight

DenseMap< const BasicBlock *, Embedding > BBEmbeddingsMap

Definition IR2Vec.h:146

LLVM_ABI cl::opt< float > OpcWeight

LLVM_ABI cl::opt< float > TypeWeight

LLVM_ABI cl::opt< IR2VecKind > IR2VecEmbeddingKind

LLVM_ABI llvm:🆑:OptionCategory IR2VecCategory

This is an optimization pass for GlobalISel generic memory operations.

IR2VecKind

IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.

Definition IR2Vec.h:71

@ FlowAware

Definition IR2Vec.h:71

@ Symbolic

Definition IR2Vec.h:71

LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key

DWARFExpression::Operation Op

OutputIt move(R &&Range, OutputIt Out)

Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.

AnalysisManager< Module > ModuleAnalysisManager

Convenience typedef for the Module analysis manager.

Implement std::hash so that hash_code can be used in STL containers.

A CRTP mix-in that provides informational APIs needed for analysis passes.

A special type used by analysis passes to provide an address that identifies that particular analysis...

A CRTP mix-in to automatically provide informational APIs needed for passes.

Embedding is a datatype that wraps std::vector.

Definition IR2Vec.h:87

iterator end()

Definition IR2Vec.h:117

const_iterator end() const

Definition IR2Vec.h:119

iterator begin()

Definition IR2Vec.h:116

LLVM_ABI bool approximatelyEquals(const Embedding &RHS, double Tolerance=1e-4) const

Returns true if the embedding is approximately equal to the RHS embedding within the specified tolera...

const_iterator cbegin() const

Definition IR2Vec.h:120

std::vector< double >::iterator iterator

Definition IR2Vec.h:113

LLVM_ABI Embedding & operator+=(const Embedding &RHS)

Arithmetic operators.

std::vector< double >::const_iterator const_iterator

Definition IR2Vec.h:114

LLVM_ABI Embedding operator-(const Embedding &RHS) const

const std::vector< double > & getData() const

Definition IR2Vec.h:123

Embedding(size_t Size, double InitialValue)

Definition IR2Vec.h:98

LLVM_ABI Embedding & operator-=(const Embedding &RHS)

const_iterator cend() const

Definition IR2Vec.h:121

LLVM_ABI Embedding operator*(double Factor) const

size_t size() const

Definition IR2Vec.h:100

LLVM_ABI Embedding & operator*=(double Factor)

Embedding(std::initializer_list< double > IL)

Definition IR2Vec.h:95

Embedding(const std::vector< double > &V)

Definition IR2Vec.h:93

LLVM_ABI Embedding operator+(const Embedding &RHS) const

bool empty() const

Definition IR2Vec.h:101

LLVM_ABI Embedding & scaleAndAdd(const Embedding &Src, float Factor)

Adds Src Embedding scaled by Factor with the called Embedding.

Embedding(std::vector< double > &&V)

Definition IR2Vec.h:94

const double & operator[](size_t Itr) const

Definition IR2Vec.h:108

Embedding(size_t Size)

Definition IR2Vec.h:97

LLVM_ABI void print(raw_ostream &OS) const

const_iterator begin() const

Definition IR2Vec.h:118

double & operator[](size_t Itr)

Definition IR2Vec.h:103