LLVM: include/llvm/CodeGen/MIR2Vec.h Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40#ifndef LLVM_CODEGEN_MIR2VEC_H
41#define LLVM_CODEGEN_MIR2VEC_H
42
56#include
57#include
58#include
59
60namespace llvm {
61
67
69
71
72
75
78
83
84
85
88 using VocabMap = std::map<std::string, ir2vec::Embedding>;
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106 struct {
112 } Layout;
113
114
115
116
117
118
119 enum class Section : unsigned {
120 Opcodes = 0,
121 CommonOperands = 1,
122 PhyRegisters = 2,
123 VirtRegisters = 3,
124 MaxSections
125 };
126
128 std::set<std::string> UniqueBaseOpcodeNames;
130
131
132
134
135
136
137 static constexpr StringLiteral CommonOperandNames[] = {
138 "Immediate", "CImmediate", "FPImmediate", "MBB",
139 "FrameIndex", "ConstantPoolIndex", "TargetIndex", "JumpTableIndex",
140 "ExternalSymbol", "GlobalAddress", "BlockAddress", "RegisterMask",
141 "RegisterLiveOut", "Metadata", "MCSymbol", "CFIIndex",
142 "IntrinsicID", "Predicate", "ShuffleMask", "LaneMask"};
144 "Common operand names size changed, update accordingly");
145
146 const TargetInstrInfo &TII;
147 const TargetRegisterInfo &TRI;
148 const MachineRegisterInfo &MRI;
149
150 void generateStorage(const VocabMap &OpcodeMap,
151 const VocabMap &CommonOperandMap,
152 const VocabMap &PhyRegMap, const VocabMap &VirtRegMap);
153 void buildCanonicalOpcodeMapping();
154 void buildRegisterOperandMapping();
155
156
157 LLVM_ABI unsigned getCanonicalOpcodeIndex(unsigned Opcode) const;
158
159
160 unsigned
162
163
165
166
169 unsigned LocalIndex = getCommonOperandIndex(OperandType);
170 return Storage[static_cast<unsigned>(Section::CommonOperands)][LocalIndex];
171 }
172
174
175
176 if (!Reg.isValid())
177 return ZeroEmbedding;
178
179
180
181
182
183
184 if (Reg.isStack())
185 return ZeroEmbedding;
186
187 unsigned LocalIndex = getRegisterOperandIndex(Reg);
188 auto SectionID =
189 Reg.isPhysical() ? Section::PhyRegisters : Section::VirtRegisters;
190 return Storage[static_cast<unsigned>(SectionID)][LocalIndex];
191 }
192
193
194
195 LLVM_ABI unsigned getEntityIDForCommonOperand(
197 return Layout.CommonOperandBase + getCommonOperandIndex(OperandType);
198 }
199
200
201
202 unsigned getEntityIDForRegister(Register Reg) const {
203 if (!Reg.isValid() || Reg.isStack())
204 return Layout
205 .VirtRegBase;
206 unsigned LocalIndex = getRegisterOperandIndex(Reg);
207 size_t BaseOffset =
208 Reg.isPhysical() ? Layout.PhyRegBase : Layout.VirtRegBase;
209 return BaseOffset + LocalIndex;
210 }
211
212public:
213
216
217
218
225 bool IsPhysical = true) const;
226
227
229
230 unsigned getDimension() const { return Storage.getDimension(); }
231
232
233
235 return Layout.OpcodeBase + getCanonicalOpcodeIndex(Opcode);
236 }
237
238
239
242 return getEntityIDForRegister(MO.getReg());
243 return getEntityIDForCommonOperand(MO.getType());
244 }
245
246
248 unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode);
249 return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex];
250 }
251
253 auto OperandType = Operand.getType();
255 return operator[](Operand.getReg());
256 else
257 return operator[](OperandType);
258 }
259
260
263
265
267
268
270 create(VocabMap &&OpcMap, VocabMap &&CommonOperandsMap, VocabMap &&PhyRegMap,
273
274
279
280
282
283private:
284 MIRVocabulary(VocabMap &&OpcMap, VocabMap &&CommonOperandsMap,
285 VocabMap &&PhyRegMap, VocabMap &&VirtRegMap,
288};
289
290
292protected:
295
296
298
299
301
307
308
310
311
313
314
315
317
318public:
320
321
322
323 LLVM_ABI static std::unique_ptr<MIREmbedder>
326
327
328
332
333
334
338
339
345};
346
347
348
349
351private:
353
354public:
358};
359
360}
361
362
363
364
365
366
367
368
369
371 using VocabMap = std::map<std::string, mir2vec::Embedding>;
372
373public:
375
377
378private:
379 Error readVocabulary(VocabMap &OpcVocab, VocabMap &CommonOperandVocab,
380 VocabMap &PhyRegVocabMap, VocabMap &VirtRegVocabMap);
382};
383
384
386 using VocabVector = std::vector<mir2vec::Embedding>;
387 using VocabMap = std::map<std::string, mir2vec::Embedding>;
388
390
391protected:
397
398public:
401
406 Provider = std::make_unique<MIR2VecVocabProvider>(MMI);
407 return Provider->getVocabulary(M);
408 }
409
414};
415
416
419
420public:
424
432
434 return "MIR2Vec Vocabulary Printer Pass";
435 }
436};
437
438
439
442
443public:
447
454
456 return "MIR2Vec Embedder Printer Pass";
457 }
458};
459
460
462
463}
464
465#endif
unsigned const MachineRegisterInfo * MRI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
const TargetInstrInfo & TII
#define LLVM_ABI_FOR_TEST
Provides ErrorOr smart pointer.
This file defines the IR2Vec vocabulary analysis(IR2VecVocabAnalysis), the core ir2vec::Embedder inte...
This header defines various interfaces for pass management in LLVM.
Register const TargetRegisterInfo * TRI
Promote Memory to Register
static cl::opt< RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode > Mode("regalloc-enable-advisor", cl::Hidden, cl::init(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Default), cl::desc("Enable regalloc advisor mode"), cl::values(clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Default, "default", "Default"), clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Release, "release", "precompiled"), clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Development, "development", "for training")))
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesAll()
Set by analyses that do not transform their input at all.
Lightweight error class with error context and mandatory checking.
Tagged union holding either a T or a Error.
This is an important class for using LLVM in a threaded context.
MIR2VecPrinterLegacyPass(raw_ostream &OS)
Definition MIR2Vec.h:445
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
Definition MIR2Vec.h:449
bool runOnMachineFunction(MachineFunction &MF) override
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
static char ID
Definition MIR2Vec.h:444
StringRef getPassName() const override
getPassName - Return a nice clean name for a pass.
Definition MIR2Vec.h:455
Pass to analyze and populate MIR2Vec vocabulary from a module.
Definition MIR2Vec.h:385
static char ID
Definition MIR2Vec.h:399
MIR2VecVocabProvider & getProvider()
Definition MIR2Vec.h:410
Expected< mir2vec::MIRVocabulary > getMIR2VecVocabulary(const Module &M)
Definition MIR2Vec.h:402
std::unique_ptr< MIR2VecVocabProvider > Provider
Definition MIR2Vec.h:396
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
Definition MIR2Vec.h:392
MIR2VecVocabLegacyAnalysis()
Definition MIR2Vec.h:400
StringRef getPassName() const override
getPassName - Return a nice clean name for a pass.
Definition MIR2Vec.h:433
bool doFinalization(Module &M) override
doFinalization - Virtual method overridden by subclasses to do any necessary clean up after all passes...
static char ID
Definition MIR2Vec.h:421
bool runOnMachineFunction(MachineFunction &MF) override
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
MIR2VecVocabPrinterLegacyPass(raw_ostream &OS)
Definition MIR2Vec.h:422
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
Definition MIR2Vec.h:427
MIR2Vec vocabulary provider used by pass managers and standalone tools.
Definition MIR2Vec.h:370
MIR2VecVocabProvider(const MachineModuleInfo &MMI)
Definition MIR2Vec.h:374
LLVM_ABI Expected< mir2vec::MIRVocabulary > getVocabulary(const Module &M)
MachineFunctionPass(char &ID)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
Representation of each machine instruction.
This class contains meta information specific to a module.
MachineOperand class - Representation of each machine instruction operand.
MachineOperandType getType() const
getType - Returns the MachineOperandType for this operand.
Register getReg() const
getReg - Returns the register number.
@ MO_Register
Register operand.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
A Module instance is used to store all the information related to an LLVM module.
AnalysisType & getAnalysis() const
getAnalysis() - This function is used by subclasses to get to the analysis information ...
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
A wrapper around a string literal that serves as a proxy for constructing global tables of StringRefs...
StringRef - Represent a constant reference to a string, i.e.
TargetInstrInfo - Interface to description of machine instruction set.
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
Iterator support for section-based access.
Generic storage class for section-based vocabularies.
Base class for MIR embedders.
Definition MIR2Vec.h:291
const unsigned Dimension
Dimension of the embeddings; Captured from the vocabulary.
Definition MIR2Vec.h:297
Embedding getMFunctionVector() const
Computes and returns the embedding for the current machine function.
Definition MIR2Vec.h:340
const MIRVocabulary & Vocab
Definition MIR2Vec.h:294
Embedding getMInstVector(const MachineInstr &MI) const
Computes and returns the embedding for a given machine instruction MI in the machine function MF.
Definition MIR2Vec.h:329
virtual Embedding computeEmbeddings(const MachineInstr &MI) const =0
Function to compute the embedding for a given machine instruction.
Embedding getMBBVector(const MachineBasicBlock &MBB) const
Computes and returns the embedding for a given machine basic block in the machine function MF.
Definition MIR2Vec.h:335
const float RegOperandWeight
Definition MIR2Vec.h:300
MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab)
Definition MIR2Vec.h:302
const float CommonOperandWeight
Definition MIR2Vec.h:300
LLVM_ABI Embedding computeEmbeddings() const
Function to compute embeddings.
const float OpcWeight
Weight for opcode embeddings.
Definition MIR2Vec.h:300
const MachineFunction & MF
Definition MIR2Vec.h:293
virtual ~MIREmbedder()=default
static LLVM_ABI std::unique_ptr< MIREmbedder > create(MIR2VecKind Mode, const MachineFunction &MF, const MIRVocabulary &Vocab)
Factory method to create an Embedder object of the specified kind. Returns nullptr if the requested ki...
Class for storing and accessing the MIR2Vec vocabulary.
Definition MIR2Vec.h:86
size_t OpcodeBase
Definition MIR2Vec.h:107
unsigned getDimension() const
Definition MIR2Vec.h:230
unsigned getEntityIDForOpcode(unsigned Opcode) const
Get entity ID (flat index) for an opcode. This is used for triplet generation.
Definition MIR2Vec.h:234
const_iterator end() const
Definition MIR2Vec.h:264
LLVM_ABI_FOR_TEST unsigned getCanonicalIndexForOperandName(StringRef OperandName) const
const Embedding & operator[](MachineOperand Operand) const
Definition MIR2Vec.h:252
LLVM_ABI_FOR_TEST unsigned getCanonicalIndexForRegisterClass(StringRef RegName, bool IsPhysical=true) const
static LLVM_ABI_FOR_TEST Expected< MIRVocabulary > create(VocabMap &&OpcMap, VocabMap &&CommonOperandsMap, VocabMap &&PhyRegMap, VocabMap &&VirtRegMap, const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI)
Factory method to create MIRVocabulary from vocabulary map.
static LLVM_ABI_FOR_TEST std::string extractBaseOpcodeName(StringRef InstrName)
Static method for extracting base opcode names (public for testing)
ir2vec::VocabStorage::const_iterator const_iterator
Definition MIR2Vec.h:261
const_iterator begin() const
Definition MIR2Vec.h:262
const Embedding & operator[](unsigned Opcode) const
Definition MIR2Vec.h:247
size_t getCanonicalSize() const
Total number of entries in the vocabulary.
Definition MIR2Vec.h:281
static LLVM_ABI Expected< MIRVocabulary > createDummyVocabForTest(const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI, unsigned Dim=1)
Create a dummy vocabulary for testing purposes.
unsigned getEntityIDForMachineOperand(const MachineOperand &MO) const
Get entity ID (flat index) for a machine operand. This is used for triplet generation.
Definition MIR2Vec.h:240
size_t PhyRegBase
Definition MIR2Vec.h:109
LLVM_ABI std::string getStringKey(unsigned Pos) const
Get the string key for a vocabulary entry at the given position.
size_t CommonOperandBase
Definition MIR2Vec.h:108
size_t VirtRegBase
Definition MIR2Vec.h:110
size_t TotalEntries
Definition MIR2Vec.h:111
LLVM_ABI_FOR_TEST unsigned getCanonicalIndexForBaseName(StringRef BaseName) const
Get indices from opcode or operand names.
Class for computing Symbolic embeddings. Symbolic embeddings are constructed based on the entity-level...
Definition MIR2Vec.h:350
static LLVM_ABI_FOR_TEST std::unique_ptr< SymbolicMIREmbedder > create(const MachineFunction &MF, const MIRVocabulary &Vocab)
SymbolicMIREmbedder(const MachineFunction &F, const MIRVocabulary &Vocab)
This class implements an extremely fast bulk output stream that can only output to a stream.
DenseMap< const MachineInstr *, Embedding > MachineInstEmbeddingsMap
Definition MIR2Vec.h:80
LLVM_ABI llvm::cl::OptionCategory MIR2VecCategory
LLVM_ABI cl::opt< float > OpcWeight
LLVM_ABI cl::opt< float > RegOperandWeight
Definition MIR2Vec.h:77
ir2vec::Embedding Embedding
Definition MIR2Vec.h:79
DenseMap< const MachineBasicBlock *, Embedding > MachineBlockEmbeddingsMap
Definition MIR2Vec.h:81
LLVM_ABI cl::opt< float > CommonOperandWeight
Definition MIR2Vec.h:77
This is an optimization pass for GlobalISel generic memory operations.
LLVM_ABI MachineFunctionPass * createMIR2VecPrinterLegacyPass(raw_ostream &OS)
Create a machine pass that prints MIR2Vec embeddings.
MIR2VecKind
Definition MIR2Vec.h:68
Embedding is a datatype that wraps std::vector.