LLVM: lib/Analysis/IR2Vec.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

15

29

30using namespace llvm;

31using namespace ir2vec;

32

33#define DEBUG_TYPE "ir2vec"

34

36 "Number of lookups to entities not present in the vocabulary");

37

38namespace llvm {

41

42

45 cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),

48 cl::desc("Weight for opcode embeddings"),

51 cl::desc("Weight for type embeddings"),

54 cl::desc("Weight for argument embeddings"),

59 "Generate symbolic embeddings"),

61 "Generate flow-aware embeddings")),

64

65}

66}

67

69

70

71

72

76 std::vector TempOut;

78 return false;

79 Out = Embedding(std::move(TempOut));

80 return true;

81}

82}

83

84

85

86

88 assert(this->size() == RHS.size() && "Vectors must have the same dimension");

89 std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),

90 std::plus());

91 return *this;

92}

93

96 Result += RHS;

97 return Result;

98}

99

101 assert(this->size() == RHS.size() && "Vectors must have the same dimension");

102 std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),

103 std::minus());

104 return *this;

105}

106

109 Result -= RHS;

110 return Result;

111}

112

114 std::transform(this->begin(), this->end(), this->begin(),

115 [Factor](double Elem) { return Elem * Factor; });

116 return *this;

117}

118

121 Result *= Factor;

122 return Result;

123}

124

126 assert(this->size() == Src.size() && "Vectors must have the same dimension");

127 for (size_t Itr = 0; Itr < this->size(); ++Itr)

128 (*this)[Itr] += Src[Itr] * Factor;

129 return *this;

130}

131

133 double Tolerance) const {

134 assert(this->size() == RHS.size() && "Vectors must have the same dimension");

135 for (size_t Itr = 0; Itr < this->size(); ++Itr)

136 if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance) {

137 LLVM_DEBUG(errs() << "Embedding mismatch at index " << Itr << ": "

138 << (*this)[Itr] << " vs " << RHS[Itr]

139 << "; Tolerance: " << Tolerance << "\n");

140 return false;

141 }

142 return true;

143}

144

146 OS << " [";

147 for (const auto &Elem : Data)

148 OS << " " << format("%.2f", Elem) << " ";

149 OS << "]\n";

150}

151

152

153

154

155

158 switch (Mode) {

160 return std::make_unique(F, Vocab);

162 return std::make_unique(F, Vocab);

163 }

164 return nullptr;

165}

166

169

170 if (F.isDeclaration())

171 return FuncVector;

172

173

176 return FuncVector;

177}

178

181

182

185 return BBVector;

186}

187

189

190

192 for (const auto &Op : I.operands())

194 auto InstVector =

195 Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;

197 InstVector += Vocab[IC->getPredicate()];

198 return InstVector;

199}

200

202

203 auto It = InstVecMap.find(&I);

204 if (It != InstVecMap.end())

205 return It->second;

206

207

208

210 for (const auto &Op : I.operands()) {

211

213 auto DefIt = InstVecMap.find(DefInst);

214

215

216

217

218

219

220

221

222

223

224

225 if (DefIt != InstVecMap.end())

226 ArgEmb += DefIt->second;

227 else

229 }

230

231

232 else {

233 LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "

234 << *Op << "=" << Vocab[*Op][0] << "\n");

236 }

237 }

238

239

240 auto InstVector =

241 Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;

243 InstVector += Vocab[IC->getPredicate()];

244 InstVecMap[&I] = InstVector;

245 return InstVector;

246}

247

248

249

250

251

253 : Sections(std::move(SectionData)), TotalSize([&] {

254 assert(!Sections.empty() && "Vocabulary has no sections");

255

256 size_t Size = 0;

257 for (const auto &Section : Sections) {

258 assert(!Section.empty() && "Vocabulary section is empty");

259 Size += Section.size();

260 }

262 }()),

263 Dimension([&] {

264

265

266 assert(!Sections.empty() && "Vocabulary has no sections");

267 assert(!Sections[0].empty() && "First section of vocabulary is empty");

268 unsigned ExpectedDim = static_cast<unsigned>(Sections[0][0].size());

269

270

271

272 [[maybe_unused]] auto allSameDim =

273 [ExpectedDim](const std::vector &Section) {

274 return std::all_of(Section.begin(), Section.end(),

275 [ExpectedDim](const Embedding &Emb) {

276 return Emb.size() == ExpectedDim;

277 });

278 };

279 assert(std::all_of(Sections.begin(), Sections.end(), allSameDim) &&

280 "All embeddings must have the same dimension");

281

282 return ExpectedDim;

283 }()) {}

284

286 assert(SectionId < Storage->Sections.size() && "Invalid section ID");

287 assert(LocalIndex < Storage->Sections[SectionId].size() &&

288 "Local index out of range");

289 return Storage->Sections[SectionId][LocalIndex];

290}

291

293 ++LocalIndex;

294

296 LocalIndex >= Storage->Sections[SectionId].size()) {

297 assert(LocalIndex == Storage->Sections[SectionId].size() &&

298 "Local index should be at the end of the current section");

299 LocalIndex = 0;

300 ++SectionId;

301 }

302 return *this;

303}

304

307 return Storage == Other.Storage && SectionId == Other.SectionId &&

308 LocalIndex == Other.LocalIndex;

309}

310

313 return !(*this == Other);

314}

315

318 VocabMap &TargetVocab, unsigned &Dim) {

321 if (!RootObj)

323 "JSON root is not an object");

324

326 if (!SectionValue)

328 "Missing '" + std::string(Key) +

329 "' section in vocabulary file");

// json::fromJSON returns true on success; only a failed parse is an error.
330 if (!json::fromJSON(*SectionValue, TargetVocab, Path))

332 "Unable to parse '" + std::string(Key) +

333 "' section from vocabulary");

334

335 Dim = TargetVocab.begin()->second.size();

336 if (Dim == 0)

338 "Dimension of '" + std::string(Key) +

339 "' section of the vocabulary is zero");

340

341 if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),

342 [Dim](const std::pair<StringRef, Embedding> &Entry) {

343 return Entry.second.size() == Dim;

344 }))

347 "All vectors in the '" + std::string(Key) +

348 "' section of the vocabulary are not of the same dimension");

349

351}

352

353

354

355

356

358 assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");

359#define HANDLE_INST(NUM, OPCODE, CLASS) \

360 if (Opcode == NUM) { \

361 return #OPCODE; \

362 }

363#include "llvm/IR/Instruction.def"

364#undef HANDLE_INST

365 return "UnknownOpcode";

366}

367

368

378

382 else

385}

386

387CmpInst::Predicate Vocabulary::getPredicateFromLocalIndex(unsigned LocalIndex) {

388 unsigned fcmpRange =

390 if (LocalIndex < fcmpRange)

392 LocalIndex);

393 else

395 LocalIndex - fcmpRange);

396}

397

401 PredNameBuffer = "FCMP_";

402 else

403 PredNameBuffer = "ICMP_";

405 return PredNameBuffer;

406}

407

409 assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");

410

411 if (Pos < MaxOpcodes)

413

414 if (Pos < OperandBaseOffset)

415 return getVocabKeyForCanonicalTypeID(

417

418 if (Pos < PredicateBaseOffset)

420 static_cast<OperandKind>(Pos - OperandBaseOffset));

421

423}

424

425

427 ModuleAnalysisManager::Invalidator &Inv) const {

429 return !(PAC.preservedWhenStateless());

430}

431

433 float DummyVal = 0.1f;

434

435

436

437 std::vector<std::vector> Sections;

438 Sections.reserve(4);

439

440

441 std::vector OpcodeSec;

442 OpcodeSec.reserve(MaxOpcodes);

443 for (unsigned I = 0; I < MaxOpcodes; ++I) {

444 OpcodeSec.emplace_back(Dim, DummyVal);

445 DummyVal += 0.1f;

446 }

447 Sections.push_back(std::move(OpcodeSec));

448

449

450 std::vector TypeSec;

453 TypeSec.emplace_back(Dim, DummyVal);

454 DummyVal += 0.1f;

455 }

456 Sections.push_back(std::move(TypeSec));

457

458

459 std::vector OperandSec;

462 OperandSec.emplace_back(Dim, DummyVal);

463 DummyVal += 0.1f;

464 }

465 Sections.push_back(std::move(OperandSec));

466

467

468 std::vector PredicateSec;

471 PredicateSec.emplace_back(Dim, DummyVal);

472 DummyVal += 0.1f;

473 }

474 Sections.push_back(std::move(PredicateSec));

475

477}

478

479

480

481

482

483

484

485Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab,

489 if (!BufOrError)

491

492 auto Content = BufOrError.get()->getBuffer();

493

495 if (!ParsedVocabValue)

496 return ParsedVocabValue.takeError();

497

498 unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0;

500 OpcVocab, OpcodeDim))

501 return Err;

502

504 TypeVocab, TypeDim))

505 return Err;

506

508 ArgVocab, ArgDim))

509 return Err;

510

511 if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))

513 "Vocabulary sections have different dimensions");

514

516}

517

518void IR2VecVocabAnalysis::generateVocabStorage(VocabMap &OpcVocab,

521

522

523

524

525 auto handleMissingEntity = [](const std::string &Val) {

527 << " is not in vocabulary, using zero vector; This "

528 "would result in an error in future.\n");

529 ++VocabMissCounter;

530 };

531

532 unsigned Dim = OpcVocab.begin()->second.size();

533 assert(Dim > 0 && "Vocabulary dimension must be greater than zero");

534

535

536 std::vector NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,

538 for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {

540 auto It = OpcVocab.find(VocabKey.str());

541 if (It != OpcVocab.end())

542 NumericOpcodeEmbeddings[Opcode] = It->second;

543 else

544 handleMissingEntity(VocabKey.str());

545 }

546

547

551 StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID(

553 if (auto It = TypeVocab.find(VocabKey.str()); It != TypeVocab.end()) {

554 NumericTypeEmbeddings[CTypeID] = It->second;

555 continue;

556 }

557 handleMissingEntity(VocabKey.str());

558 }

559

560

566 auto It = ArgVocab.find(VocabKey.str());

567 if (It != ArgVocab.end()) {

568 NumericArgEmbeddings[OpKind] = It->second;

569 continue;

570 }

571 handleMissingEntity(VocabKey.str());

572 }

573

574

575

579 StringRef VocabKey =

581 auto It = ArgVocab.find(VocabKey.str());

582 if (It != ArgVocab.end()) {

583 NumericPredEmbeddings[PK] = It->second;

584 continue;

585 }

586 handleMissingEntity(VocabKey.str());

587 }

588

589

590

591 std::vector<std::vector> Sections(4);

592 Sections[static_cast<unsigned>(Vocabulary::Section::Opcodes)] =

593 std::move(NumericOpcodeEmbeddings);

594 Sections[static_cast<unsigned>(Vocabulary::Section::CanonicalTypes)] =

595 std::move(NumericTypeEmbeddings);

596 Sections[static_cast<unsigned>(Vocabulary::Section::Operands)] =

597 std::move(NumericArgEmbeddings);

598 Sections[static_cast<unsigned>(Vocabulary::Section::Predicates)] =

599 std::move(NumericPredEmbeddings);

600

601

602 Vocab.emplace(std::move(Sections));

603}

604

605void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {

606 handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {

608 });

609}

610

613 auto Ctx = &M.getContext();

614

615 if (Vocab.has_value())

616 return Vocabulary(std::move(Vocab.value()));

617

618

620

621 Ctx->emitError("IR2Vec vocabulary file path not specified; You may need to "

622 "set it using --ir2vec-vocab-path");

623 return Vocabulary();

624 }

625

626 VocabMap OpcVocab, TypeVocab, ArgVocab;

627 if (auto Err = readVocabulary(OpcVocab, TypeVocab, ArgVocab)) {

628 emitError(std::move(Err), *Ctx);

630 }

631

632

633 auto scaleVocabSection = [](VocabMap &Vocab, double Weight) {

634 for (auto &Entry : Vocab)

635 Entry.second *= Weight;

636 };

637 scaleVocabSection(OpcVocab, OpcWeight);

638 scaleVocabSection(TypeVocab, TypeWeight);

639 scaleVocabSection(ArgVocab, ArgWeight);

640

641

642 generateVocabStorage(OpcVocab, TypeVocab, ArgVocab);

643

644 return Vocabulary(std::move(Vocab.value()));

645}

646

647

648

649

650

655

658 if (!Emb) {

659 OS << "Error creating IR2Vec embeddings \n";

660 continue;

661 }

662

663 OS << "IR2Vec embeddings for function " << F.getName() << ":\n";

664 OS << "Function vector: ";

665 Emb->getFunctionVector().print(OS);

666

667 OS << "Basic block vectors:\n";

669 OS << "Basic block: " << BB.getName() << ":\n";

670 Emb->getBBVector(BB).print(OS);

671 }

672

673 OS << "Instruction vectors:\n";

676 OS << "Instruction: ";

677 I.print(OS);

678 Emb->getInstVector(I).print(OS);

679 }

680 }

681 }

683}

684

688 assert(IR2VecVocabulary.isValid() && "IR2Vec Vocabulary is invalid");

689

690

691 unsigned Pos = 0;

692 for (const auto &Entry : IR2VecVocabulary) {

693 OS << "Key: " << IR2VecVocabulary.getStringKey(Pos++) << ": ";

694 Entry.print(OS);

695 }

697}

assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")

#define clEnumValN(ENUMVAL, FLAGNAME, DESC)

This file builds on the ADT/GraphTraits.h file to build a generic depth-first graph iterator.

This file defines the IR2Vec vocabulary analysis(IR2VecVocabAnalysis), the core ir2vec::Embedder inte...

This file provides various utilities for inspecting and working with the control flow graph in LLVM I...

Module.h This file contains the declarations for the Module class.

This header defines various interfaces for pass management in LLVM.

ModuleAnalysisManager MAM

Provides some synthesis utilities to produce sequences of values.

This file defines the SmallVector class.

This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...

#define STATISTIC(VARNAME, DESC)

LLVM Basic Block Representation.

LLVM_ABI iterator_range< filter_iterator< BasicBlock::const_iterator, std::function< bool(const Instruction &)> > > instructionsWithoutDebug(bool SkipPseudoOp=true) const

Return a const iterator range over the instructions in the block, skipping any debug instructions.

Predicate

This enumeration lists the possible predicates for CmpInst subclasses.

static LLVM_ABI StringRef getPredicateName(Predicate P)

iterator find(const_arg_type_t< KeyT > Val)

virtual std::string message() const

Return the error message as a string.

Lightweight error class with error context and mandatory checking.

static ErrorSuccess success()

Create a success value.

Tagged union holding either a T or an Error.

Error takeError()

Take ownership of the stored error.

LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)

Definition IR2Vec.cpp:651

This analysis provides the vocabulary for IR2Vec.

ir2vec::Vocabulary Result

LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM)

Definition IR2Vec.cpp:612

static LLVM_ABI AnalysisKey Key

LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)

Definition IR2Vec.cpp:685

LLVM_ABI void emitError(const Instruction *I, const Twine &ErrorStr)

emitError - Emit an error message to the currently installed error handler with optional location inf...

static ErrorOr< std::unique_ptr< MemoryBuffer > > getFileOrSTDIN(const Twine &Filename, bool IsText=false, bool RequiresNullTerminator=true, std::optional< Align > Alignment=std::nullopt)

Open the specified file as a MemoryBuffer, or open stdin if the Filename is "-".

A Module instance is used to store all the information related to an LLVM module.

A set of analyses that are preserved following a run of a transformation pass.

static PreservedAnalyses all()

Construct a special preserved set that preserves all passes.

PreservedAnalysisChecker getChecker() const

Build a checker for this PreservedAnalyses and the specified analysis type.

SmallString - A SmallString is just a SmallVector with methods and accessors that make it work better...

StringRef - Represent a constant reference to a string, i.e.

std::string str() const

str - Get the contents as an std::string.

LLVM Value Representation.

static LLVM_ABI std::unique_ptr< Embedder > create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab)

Factory method to create an Embedder object.

Definition IR2Vec.cpp:156

const unsigned Dimension

Dimension of the vector representation; captured from the input vocabulary.

Embedding computeEmbeddings() const

Function to compute embeddings.

Definition IR2Vec.cpp:167

Iterator support for section-based access.

const_iterator(const VocabStorage *Storage, unsigned SectionId, size_t LocalIndex)

LLVM_ABI bool operator!=(const const_iterator &Other) const

Definition IR2Vec.cpp:311

LLVM_ABI const_iterator & operator++()

Definition IR2Vec.cpp:292

LLVM_ABI const Embedding & operator*() const

Definition IR2Vec.cpp:285

LLVM_ABI bool operator==(const const_iterator &Other) const

Definition IR2Vec.cpp:305

Generic storage class for section-based vocabularies.

static Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab, unsigned &Dim)

Parse a vocabulary section from JSON and populate the target vocabulary map.

Definition IR2Vec.cpp:316

unsigned getNumSections() const

Get number of sections.

size_t size() const

Get total number of entries across all sections.

VocabStorage()=default

Default constructor creates empty storage (invalid state)

std::map< std::string, Embedding > VocabMap

Class for storing and accessing the IR2Vec vocabulary.

static LLVM_ABI StringRef getVocabKeyForOperandKind(OperandKind Kind)

Function to get vocabulary key for a given OperandKind.

LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA, ModuleAnalysisManager::Invalidator &Inv) const

Definition IR2Vec.cpp:426

static LLVM_ABI OperandKind getOperandKind(const Value *Op)

Function to classify an operand into OperandKind.

Definition IR2Vec.cpp:369

friend class llvm::IR2VecVocabAnalysis

static LLVM_ABI StringRef getStringKey(unsigned Pos)

Returns the string key for a given index position in the vocabulary.

Definition IR2Vec.cpp:408

static constexpr unsigned MaxCanonicalTypeIDs

static constexpr unsigned MaxOperandKinds

OperandKind

Operand kinds supported by IR2Vec Vocabulary.

static LLVM_ABI StringRef getVocabKeyForPredicate(CmpInst::Predicate P)

Function to get vocabulary key for a given predicate.

Definition IR2Vec.cpp:398

static LLVM_ABI StringRef getVocabKeyForOpcode(unsigned Opcode)

Function to get vocabulary key for a given Opcode.

Definition IR2Vec.cpp:357

LLVM_ABI bool isValid() const

static LLVM_ABI VocabStorage createDummyVocabForTest(unsigned Dim=1)

Create a dummy vocabulary for testing purposes.

Definition IR2Vec.cpp:432

static constexpr unsigned MaxPredicateKinds

CanonicalTypeID

Canonical type IDs supported by IR2Vec Vocabulary.

An Object is a JSON object, which maps strings to heterogeneous JSON values.

LLVM_ABI Value * get(StringRef K)

The root is the trivial Path to the root value.

A "cursor" marking a position within a Value.

A Value is a JSON value of unknown type.

const json::Object * getAsObject() const

This class implements an extremely fast bulk output stream that can only output to a stream.

ValuesClass values(OptsTy... Options)

Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...

initializer< Ty > init(const Ty &Val)

static cl::opt< std::string > VocabFile("ir2vec-vocab-path", cl::Optional, cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""), cl::cat(IR2VecCategory))

LLVM_ABI cl::opt< float > ArgWeight

LLVM_ABI cl::opt< float > OpcWeight

LLVM_ABI cl::opt< float > TypeWeight

LLVM_ABI cl::opt< IR2VecKind > IR2VecEmbeddingKind

LLVM_ABI llvm::cl::OptionCategory IR2VecCategory

LLVM_ABI llvm::Expected< Value > parse(llvm::StringRef JSON)

Parses the provided JSON source, or returns a ParseError.

bool fromJSON(const Value &E, std::string &Out, Path P)

ir2vec::Embedding Embedding

This is an optimization pass for GlobalISel generic memory operations.

Error createFileError(const Twine &F, Error E)

Concatenate a source file path and/or name with an Error.

decltype(auto) dyn_cast(const From &Val)

dyn_cast - Return the argument parameter cast to the specified type.

void handleAllErrors(Error E, HandlerTs &&... Handlers)

Behaves the same as handleErrors, except that by contract all errors must be handled by the given han...

Error createStringError(std::error_code EC, char const *Fmt, const Ts &... Vals)

Create formatted StringError object.

IR2VecKind

IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.

bool isa(const From &Val)

isa - Return true if the parameter to the template is an instance of one of the template type argu...

format_object< Ts... > format(const char *Fmt, const Ts &... Vals)

These are helper functions used to produce formatted output.

LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key

LLVM_ABI raw_fd_ostream & errs()

This returns a reference to a raw_ostream for standard error.

DWARFExpression::Operation Op

OutputIt move(R &&Range, OutputIt Out)

Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.

iterator_range< df_iterator< T > > depth_first(const T &G)

auto seq(T Begin, T End)

Iterate over an integral type from Begin up to - but not including - End.

AnalysisManager< Module > ModuleAnalysisManager

Convenience typedef for the Module analysis manager.

Implement std::hash so that hash_code can be used in STL containers.

A special type used by analysis passes to provide an address that identifies that particular analysis...

Embedding is a datatype that wraps std::vector.

LLVM_ABI bool approximatelyEquals(const Embedding &RHS, double Tolerance=1e-4) const

Returns true if the embedding is approximately equal to the RHS embedding within the specified tolera...

Definition IR2Vec.cpp:132

LLVM_ABI Embedding & operator+=(const Embedding &RHS)

Arithmetic operators.

Definition IR2Vec.cpp:87

LLVM_ABI Embedding operator-(const Embedding &RHS) const

Definition IR2Vec.cpp:107

LLVM_ABI Embedding & operator-=(const Embedding &RHS)

Definition IR2Vec.cpp:100

LLVM_ABI Embedding operator*(double Factor) const

Definition IR2Vec.cpp:119

LLVM_ABI Embedding & operator*=(double Factor)

Definition IR2Vec.cpp:113

LLVM_ABI Embedding operator+(const Embedding &RHS) const

Definition IR2Vec.cpp:94

LLVM_ABI Embedding & scaleAndAdd(const Embedding &Src, float Factor)

Adds Src Embedding scaled by Factor with the called Embedding.

Definition IR2Vec.cpp:125

LLVM_ABI void print(raw_ostream &OS) const

Definition IR2Vec.cpp:145