28#define DEBUG_TYPE "mir2vec"
31 "Number of lookups to MIR entities not present in the vocabulary");
43 cl::desc(
"Weight for machine opcode embeddings"),
50 cl::desc(
"Weight for register operand embeddings"),
55 "Generate symbolic embeddings for MIR")),
61 cl::desc(
"Print all vocabulary entries including zero embeddings"),
72 VocabMap &&PhysicalRegisterMap,
73 VocabMap &&VirtualRegisterMap,
78 buildCanonicalOpcodeMapping();
79 unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
80 assert(CanonicalOpcodeCount > 0 &&
81 "No canonical opcodes found for target - invalid vocabulary");
83 buildRegisterOperandMapping();
86 Layout.OpcodeBase = 0;
87 Layout.CommonOperandBase = CanonicalOpcodeCount;
89 Layout.PhyRegBase = Layout.CommonOperandBase + std::size(CommonOperandNames);
90 Layout.VirtRegBase = Layout.PhyRegBase + RegisterOperandNames.size();
92 generateStorage(OpcodeMap, CommonOperandMap, PhysicalRegisterMap,
94 Layout.TotalEntries = Storage.size();
102 if (OpcodeMap.empty() || CommonOperandMap.empty() || PhyRegMap.empty() ||
105 "Empty vocabulary entries provided");
107 MIRVocabulary Vocab(std::move(OpcodeMap), std::move(CommonOperandMap),
108 std::move(PhyRegMap), std::move(
VirtRegMap), TII, TRI,
114 "Failed to create valid vocabulary storage");
116 return std::move(Vocab);
131 assert(!InstrName.
empty() &&
"Instruction name should not be empty");
134 static const Regex BaseOpcodeRegex(
"([a-zA-Z_]+)");
137 if (BaseOpcodeRegex.
match(InstrName, &Matches) && Matches.
size() > 1) {
140 while (!Match.
empty() && Match.
back() ==
'_')
146 return InstrName.
str();
150 assert(!UniqueBaseOpcodeNames.empty() &&
"Canonical mapping not built");
151 auto It = std::find(UniqueBaseOpcodeNames.begin(),
152 UniqueBaseOpcodeNames.end(), BaseName.
str());
153 assert(It != UniqueBaseOpcodeNames.end() &&
154 "Base name not found in unique opcodes");
155 return std::distance(UniqueBaseOpcodeNames.begin(), It);
158unsigned MIRVocabulary::getCanonicalOpcodeIndex(
unsigned Opcode)
const {
159 auto BaseOpcode = extractBaseOpcodeName(
TII.getName(Opcode));
160 return getCanonicalIndexForBaseName(BaseOpcode);
165 auto It = std::find(std::begin(CommonOperandNames),
166 std::end(CommonOperandNames), OperandName);
167 assert(It != std::end(CommonOperandNames) &&
168 "Operand name not found in common operands");
169 return Layout.CommonOperandBase +
170 std::distance(std::begin(CommonOperandNames), It);
175 bool IsPhysical)
const {
176 auto It = std::find(RegisterOperandNames.begin(), RegisterOperandNames.end(),
178 assert(It != RegisterOperandNames.end() &&
179 "Register name not found in register operands");
180 unsigned LocalIndex = std::distance(RegisterOperandNames.begin(), It);
181 return (IsPhysical ? Layout.PhyRegBase : Layout.VirtRegBase) + LocalIndex;
185 assert(Pos < Layout.TotalEntries &&
"Position out of bounds in vocabulary");
188 if (Pos < Layout.CommonOperandBase) {
190 auto It = UniqueBaseOpcodeNames.begin();
191 std::advance(It, Pos);
192 assert(It != UniqueBaseOpcodeNames.end() &&
193 "Canonical index out of bounds in opcode section");
197 auto getLocalIndex = [](
unsigned Pos,
size_t BaseOffset,
size_t Bound,
199 unsigned LocalIndex = Pos - BaseOffset;
200 assert(LocalIndex < Bound && Msg);
205 if (Pos < Layout.PhyRegBase) {
206 unsigned LocalIndex = getLocalIndex(
207 Pos, Layout.CommonOperandBase, std::size(CommonOperandNames),
208 "Local index out of bounds in common operands");
209 return CommonOperandNames[LocalIndex].str();
213 if (Pos < Layout.VirtRegBase) {
214 unsigned LocalIndex =
215 getLocalIndex(Pos, Layout.PhyRegBase, RegisterOperandNames.size(),
216 "Local index out of bounds in physical registers");
217 return "PhyReg_" + RegisterOperandNames[LocalIndex];
221 unsigned LocalIndex =
222 getLocalIndex(Pos, Layout.VirtRegBase, RegisterOperandNames.size(),
223 "Local index out of bounds in virtual registers");
224 return "VirtReg_" + RegisterOperandNames[LocalIndex];
227void MIRVocabulary::generateStorage(
const VocabMap &OpcodeMap,
228 const VocabMap &CommonOperandsMap,
229 const VocabMap &PhyRegMap,
237 <<
"; using zero vector. This will result in an error "
239 ++MIRVocabMissCounter;
243 unsigned EmbeddingDim = OpcodeMap.begin()->second.size();
244 std::vector<Embedding> OpcodeEmbeddings(Layout.CommonOperandBase,
248 for (
auto COpcodeName : UniqueBaseOpcodeNames) {
249 if (
auto It = OpcodeMap.find(COpcodeName); It != OpcodeMap.end()) {
250 auto COpcodeIndex = getCanonicalIndexForBaseName(COpcodeName);
251 assert(COpcodeIndex < Layout.CommonOperandBase &&
252 "Canonical index out of bounds");
253 OpcodeEmbeddings[COpcodeIndex] = It->second;
255 handleMissingEntity(COpcodeName);
260 std::vector<Embedding> CommonOperandEmbeddings(std::size(CommonOperandNames),
262 unsigned OperandIndex = 0;
263 for (
const auto &CommonOperandName : CommonOperandNames) {
264 if (
auto It = CommonOperandsMap.find(CommonOperandName.str());
265 It != CommonOperandsMap.end()) {
266 CommonOperandEmbeddings[OperandIndex] = It->second;
268 handleMissingEntity(CommonOperandName);
274 auto createRegisterEmbeddings = [&](
const VocabMap &RegMap) {
275 std::vector<Embedding> RegEmbeddings(
TRI.getNumRegClasses(),
277 unsigned RegOperandIndex = 0;
278 for (
const auto &RegOperandName : RegisterOperandNames) {
279 if (
auto It = RegMap.find(RegOperandName); It != RegMap.end())
280 RegEmbeddings[RegOperandIndex] = It->second;
282 handleMissingEntity(RegOperandName);
285 return RegEmbeddings;
289 std::vector<Embedding> PhyRegEmbeddings = createRegisterEmbeddings(PhyRegMap);
290 std::vector<Embedding> VirtRegEmbeddings =
294 auto scaleVocabSection = [](std::vector<Embedding> &Embeddings,
299 scaleVocabSection(OpcodeEmbeddings,
OpcWeight);
304 std::vector<std::vector<Embedding>> Sections(
305 static_cast<unsigned>(Section::MaxSections));
306 Sections[
static_cast<unsigned>(Section::Opcodes)] =
307 std::move(OpcodeEmbeddings);
308 Sections[
static_cast<unsigned>(Section::CommonOperands)] =
309 std::move(CommonOperandEmbeddings);
310 Sections[
static_cast<unsigned>(Section::PhyRegisters)] =
311 std::move(PhyRegEmbeddings);
312 Sections[
static_cast<unsigned>(Section::VirtRegisters)] =
313 std::move(VirtRegEmbeddings);
318void MIRVocabulary::buildCanonicalOpcodeMapping() {
320 if (!UniqueBaseOpcodeNames.empty())
324 for (
unsigned Opcode = 0; Opcode <
TII.getNumOpcodes(); ++Opcode) {
325 std::string BaseOpcode = extractBaseOpcodeName(
TII.getName(Opcode));
326 UniqueBaseOpcodeNames.insert(BaseOpcode);
329 LLVM_DEBUG(
dbgs() <<
"MIR2Vec: Built canonical mapping for target with "
330 << UniqueBaseOpcodeNames.size()
331 <<
" unique base opcodes\n");
334void MIRVocabulary::buildRegisterOperandMapping() {
336 if (!RegisterOperandNames.empty())
339 for (
unsigned RC = 0; RC <
TRI.getNumRegClasses(); ++RC) {
346 RegisterOperandNames.push_back(ClassName.
str());
350unsigned MIRVocabulary::getCommonOperandIndex(
353 "Expected non-register operand type");
359unsigned MIRVocabulary::getRegisterOperandIndex(
Register Reg)
const {
360 assert(!RegisterOperandNames.empty() &&
"Register operand mapping not built");
363 "Expected a physical or virtual register");
371 RegClass =
TRI.getMinimalPhysRegClass(
Reg);
373 RegClass =
MRI.getRegClass(
Reg);
376 return RegClass->
getID();
385 assert(Dim > 0 &&
"Dimension must be greater than zero");
387 float DummyVal = 0.1f;
389 VocabMap DummyOpcMap, DummyOperandMap, DummyPhyRegMap, DummyVirtRegMap;
392 for (
unsigned Opcode = 0; Opcode < TII.getNumOpcodes(); ++Opcode) {
394 if (DummyOpcMap.count(BaseOpcode) == 0) {
395 DummyOpcMap[BaseOpcode] =
Embedding(Dim, DummyVal);
401 for (
const auto &CommonOperandName : CommonOperandNames) {
402 DummyOperandMap[CommonOperandName.str()] =
Embedding(Dim, DummyVal);
407 for (
unsigned RC = 0; RC < TRI.getNumRegClasses(); ++RC) {
412 std::string ClassName = TRI.getRegClassName(RegClass);
413 DummyPhyRegMap[ClassName] =
Embedding(Dim, DummyVal);
414 DummyVirtRegMap[ClassName] =
Embedding(Dim, DummyVal);
420 std::move(DummyOpcMap), std::move(DummyOperandMap),
421 std::move(DummyPhyRegMap), std::move(DummyVirtRegMap), TII, TRI, MRI);
430 VocabMap OpcVocab, CommonOperandVocab, PhyRegVocabMap, VirtRegVocabMap;
432 if (
Error Err = readVocabulary(OpcVocab, CommonOperandVocab, PhyRegVocabMap,
434 return std::move(Err);
436 for (
const auto &
F : M) {
437 if (
F.isDeclaration())
440 if (
auto *MF = MMI.getMachineFunction(
F)) {
441 auto &Subtarget = MF->getSubtarget();
442 if (
const auto *
TII = Subtarget.getInstrInfo())
443 if (
const auto *
TRI = Subtarget.getRegisterInfo())
445 std::move(OpcVocab), std::move(CommonOperandVocab),
446 std::move(PhyRegVocabMap), std::move(VirtRegVocabMap), *
TII, *
TRI,
451 "No machine functions found in module");
454Error MIR2VecVocabProvider::readVocabulary(VocabMap &OpcodeVocab,
455 VocabMap &CommonOperandVocab,
456 VocabMap &PhyRegVocabMap,
457 VocabMap &VirtRegVocabMap) {
461 "MIR2Vec vocabulary file path not specified; set it "
462 "using --mir2vec-vocab-path");
468 auto Content = BufOrError.get()->getBuffer();
471 if (!ParsedVocabValue)
474 unsigned OpcodeDim = 0, CommonOperandDim = 0, PhyRegOperandDim = 0,
475 VirtRegOperandDim = 0;
477 "Opcodes", *ParsedVocabValue, OpcodeVocab, OpcodeDim))
481 "CommonOperands", *ParsedVocabValue, CommonOperandVocab,
486 "PhysicalRegisters", *ParsedVocabValue, PhyRegVocabMap,
491 "VirtualRegisters", *ParsedVocabValue, VirtRegVocabMap,
496 if (!(OpcodeDim == CommonOperandDim && CommonOperandDim == PhyRegOperandDim &&
497 PhyRegOperandDim == VirtRegOperandDim)) {
500 "MIR2Vec vocabulary sections have different dimensions");
508 "MIR2Vec Vocabulary Analysis",
false,
true)
514 return "MIR2Vec Vocabulary Analysis";
526 return std::make_unique<SymbolicMIREmbedder>(
MF,
Vocab);
535 const auto &Subtarget =
MF.getSubtarget();
536 const auto *
TII = Subtarget.getInstrInfo();
538 MF.getFunction().getContext().emitError(
539 "MIR2Vec: No TargetInstrInfo available; cannot compute embeddings");
544 for (
const auto &
MI :
MBB) {
546 if (
MI.isDebugInstr())
567std::unique_ptr<SymbolicMIREmbedder>
570 return std::make_unique<SymbolicMIREmbedder>(
MF,
Vocab);
575 if (
MI.isDebugInstr())
583 InstructionEmbedding +=
Vocab[MO];
585 return InstructionEmbedding;
594 "MIR2Vec Vocabulary Printer Pass",
false,
true)
606 auto MIR2VecVocabOrErr =
Analysis.getMIR2VecVocabulary(M);
608 if (!MIR2VecVocabOrErr) {
609 OS <<
"MIR2Vec Vocabulary Printer: Failed to get vocabulary - "
610 <<
toString(MIR2VecVocabOrErr.takeError()) <<
"\n";
614 auto &MIR2VecVocab = *MIR2VecVocabOrErr;
616 for (
const auto &Entry : MIR2VecVocab) {
620 OS <<
"Key: " << MIR2VecVocab.getStringKey(Pos) <<
": ";
636 "MIR2Vec Embedder Printer Pass",
false,
true)
640 "MIR2Vec Embedder Printer Pass",
false,
true)
645 Analysis.getMIR2VecVocabulary(*MF.getFunction().getParent());
646 assert(VocabOrErr &&
"Failed to get MIR2Vec vocabulary");
647 auto &MIRVocab = *VocabOrErr;
651 OS <<
"Error creating MIR2Vec embeddings for function " << MF.getName()
656 OS <<
"MIR2Vec embeddings for machine function " << MF.getName() <<
":\n";
657 OS <<
"Machine Function vector: ";
658 Emb->getMFunctionVector().print(OS);
660 OS <<
"Machine basic block vectors:\n";
662 OS <<
"Machine basic block: " <<
MBB.getFullName() <<
":\n";
663 Emb->getMBBVector(
MBB).print(OS);
666 OS <<
"Machine instruction vectors:\n";
671 if (
MI.isDebugInstr())
674 OS <<
"Machine instruction: ";
676 Emb->getMInstVector(
MI).print(OS);
unsigned const MachineRegisterInfo * MRI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
block Block Frequency Analysis
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
const HexagonInstrInfo * TII
Module.h This file contains the declarations for the Module class.
This file defines the MIR2Vec framework for generating Machine IR embeddings.
Register const TargetRegisterInfo * TRI
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
SmallVector< MachineBasicBlock *, 4 > MBBVector
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Lightweight error class with error context and mandatory checking.
static ErrorSuccess success()
Create a success value.
Tagged union holding either a T or an Error.
Error takeError()
Take ownership of the stored error.
This pass prints the MIR2Vec embeddings for machine functions, basic blocks, and instructions.
MIR2VecPrinterLegacyPass(raw_ostream &OS)
bool runOnMachineFunction(MachineFunction &MF) override
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
Pass to analyze and populate MIR2Vec vocabulary from a module.
This pass prints the embeddings in the MIR2Vec vocabulary.
bool doFinalization(Module &M) override
doFinalization - Virtual method overridden by subclasses to do any necessary clean up after all passes...
bool runOnMachineFunction(MachineFunction &MF) override
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
MIR2VecVocabPrinterLegacyPass(raw_ostream &OS)
LLVM_ABI Expected< mir2vec::MIRVocabulary > getVocabulary(const Module &M)
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
Representation of each machine instruction.
MachineOperand class - Representation of each machine instruction operand.
@ MO_Register
Register operand.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
static ErrorOr< std::unique_ptr< MemoryBuffer > > getFileOrSTDIN(const Twine &Filename, bool IsText=false, bool RequiresNullTerminator=true, std::optional< Align > Alignment=std::nullopt)
Open the specified file as a MemoryBuffer, or open stdin if the Filename is "-".
A Module instance is used to store all the information related to an LLVM module.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
LLVM_ABI bool match(StringRef String, SmallVectorImpl< StringRef > *Matches=nullptr, std::string *Error=nullptr) const
matches - Match the regex against a given String.
Wrapper class representing virtual and physical registers.
constexpr bool isValid() const
constexpr bool isVirtual() const
Return true if the specified register number is in the virtual register namespace.
constexpr bool isPhysical() const
Return true if the specified register number is in the physical register namespace.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
std::string str() const
str - Get the contents as an std::string.
constexpr bool empty() const
empty - Check if the string is empty.
char back() const
back - Get the last character in the string.
StringRef drop_back(size_t N=1) const
Return a StringRef equal to 'this' but with the last N elements dropped.
TargetInstrInfo - Interface to description of machine instruction set.
unsigned getID() const
Return the register class ID number.
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
Generic storage class for section-based vocabularies.
static Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab, unsigned &Dim)
Parse a vocabulary section from JSON and populate the target vocabulary map.
unsigned getDimension() const
Get vocabulary dimension.
bool isValid() const
Check if vocabulary is valid (has data)
const unsigned Dimension
Dimension of the embeddings; Captured from the vocabulary.
const MIRVocabulary & Vocab
MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab)
LLVM_ABI Embedding computeEmbeddings() const
Function to compute embeddings.
const MachineFunction & MF
static LLVM_ABI std::unique_ptr< MIREmbedder > create(MIR2VecKind Mode, const MachineFunction &MF, const MIRVocabulary &Vocab)
Factory method to create an Embedder object of the specified kind. Returns nullptr if the requested ki...
Class for storing and accessing the MIR2Vec vocabulary.
LLVM_ABI_FOR_TEST unsigned getCanonicalIndexForOperandName(StringRef OperandName) const
LLVM_ABI_FOR_TEST unsigned getCanonicalIndexForRegisterClass(StringRef RegName, bool IsPhysical=true) const
static LLVM_ABI_FOR_TEST Expected< MIRVocabulary > create(VocabMap &&OpcMap, VocabMap &&CommonOperandsMap, VocabMap &&PhyRegMap, VocabMap &&VirtRegMap, const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI)
Factory method to create MIRVocabulary from vocabulary map.
static LLVM_ABI_FOR_TEST std::string extractBaseOpcodeName(StringRef InstrName)
Static method for extracting base opcode names (public for testing)
static LLVM_ABI Expected< MIRVocabulary > createDummyVocabForTest(const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI, unsigned Dim=1)
Create a dummy vocabulary for testing purposes.
LLVM_ABI std::string getStringKey(unsigned Pos) const
Get the string key for a vocabulary entry at the given position.
LLVM_ABI_FOR_TEST unsigned getCanonicalIndexForBaseName(StringRef BaseName) const
Get indices from opcode or operand names.
static LLVM_ABI_FOR_TEST std::unique_ptr< SymbolicMIREmbedder > create(const MachineFunction &MF, const MIRVocabulary &Vocab)
SymbolicMIREmbedder(const MachineFunction &F, const MIRVocabulary &Vocab)
This class implements an extremely fast bulk output stream that can only output to a stream.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
OperandType
Operands are tagged with one of the values of this enum.
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
initializer< Ty > init(const Ty &Val)
LLVM_ABI llvm::Expected< Value > parse(llvm::StringRef JSON)
Parses the provided JSON source, or returns a ParseError.
LLVM_ABI llvm::cl::OptionCategory MIR2VecCategory
LLVM_ABI cl::opt< float > OpcWeight
static cl::opt< std::string > VocabFile("mir2vec-vocab-path", cl::Optional, cl::desc("Path to the vocabulary file for MIR2Vec"), cl::init(""), cl::cat(MIR2VecCategory))
LLVM_ABI cl::opt< float > RegOperandWeight
static cl::opt< bool > PrintAllVocabEntries("mir2vec-print-all-vocab-entries", cl::Optional, cl::init(false), cl::desc("Print all vocabulary entries including zero embeddings"), cl::cat(MIR2VecCategory))
ir2vec::Embedding Embedding
LLVM_ABI cl::opt< float > CommonOperandWeight
cl::opt< MIR2VecKind > MIR2VecEmbeddingKind("mir2vec-kind", cl::Optional, cl::values(clEnumValN(MIR2VecKind::Symbolic, "symbolic", "Generate symbolic embeddings for MIR")), cl::init(MIR2VecKind::Symbolic), cl::desc("MIR2Vec embedding kind"), cl::cat(MIR2VecCategory))
This is an optimization pass for GlobalISel generic memory operations.
Error createFileError(const Twine &F, Error E)
Concatenate a source file path and/or name with an Error.
Error createStringError(std::error_code EC, char const *Fmt, const Ts &... Vals)
Create formatted StringError object.
LLVM_ABI MachineFunctionPass * createMIR2VecPrinterLegacyPass(raw_ostream &OS)
Create a machine pass that prints MIR2Vec embeddings.
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
LLVM_ABI MachineFunctionPass * createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS)
MIR2VecVocabPrinter pass - This pass prints out the MIR2Vec vocabulary contents to the given stream a...
std::string toString(const APInt &I, unsigned Radix, bool Signed, bool formatAsCLiteral=false, bool UpperCase=true, bool InsertSeparators=false)
iterator_range< df_iterator< T > > depth_first(const T &G)