LLVM 22.0.0git
MIR2Vec.cpp
Go to the documentation of this file.
1//===- MIR2Vec.cpp - Implementation of MIR2Vec ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements the MIR2Vec algorithm for Machine IR embeddings.
11///
12//===----------------------------------------------------------------------===//
13
15#include "llvm/ADT/Statistic.h"
17#include "llvm/IR/Module.h"
19#include "llvm/Pass.h"
20#include "llvm/Support/Errc.h"
22#include "llvm/Support/Regex.h"
23
24using namespace llvm;
25using namespace mir2vec;
26
27#define DEBUG_TYPE "mir2vec"
28
29STATISTIC(MIRVocabMissCounter,
30 "Number of lookups to MIR entities not present in the vocabulary");
31
32namespace llvm {
33namespace mir2vec {
35
36// FIXME: Use a default vocab when not specified
38 VocabFile("mir2vec-vocab-path", cl::Optional,
39 cl::desc("Path to the vocabulary file for MIR2Vec"), cl::init(""),
41cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
42 cl::desc("Weight for machine opcode embeddings"),
44} // namespace mir2vec
45} // namespace llvm
46
47//===----------------------------------------------------------------------===//
48// Vocabulary Implementation
49//===----------------------------------------------------------------------===//
50
51MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
52 const TargetInstrInfo &TII)
53 : TII(TII) {
54 buildCanonicalOpcodeMapping();
55
56 unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
57 assert(CanonicalOpcodeCount > 0 &&
58 "No canonical opcodes found for target - invalid vocabulary");
59 Layout.OperandBase = CanonicalOpcodeCount;
60 generateStorage(OpcodeEntries);
61 Layout.TotalEntries = Storage.size();
62}
63
65 const TargetInstrInfo &TII) {
66 if (Entries.empty())
68 "Empty vocabulary entries provided");
69
70 return MIRVocabulary(std::move(Entries), TII);
71}
72
74 // Extract base instruction name using regex to capture letters and
75 // underscores Examples: "ADD32rr" -> "ADD", "ARITH_FENCE" -> "ARITH_FENCE"
76 //
77 // TODO: Consider more sophisticated extraction:
78 // - Handle complex prefixes like "AVX1_SETALLONES" correctly (Currently, it
79 // would naively map to "AVX")
80 // - Extract width suffixes (8,16,32,64) as separate features
81 // - Capture addressing mode suffixes (r,i,m,ri,etc.) for better analysis
82 // (Currently, instances like "MOV32mi" map to "MOV", but "ADDPDrr" would map
83 // to "ADDPDrr")
84
85 assert(!InstrName.empty() && "Instruction name should not be empty");
86
87 // Use regex to extract initial sequence of letters and underscores
88 static const Regex BaseOpcodeRegex("([a-zA-Z_]+)");
90
91 if (BaseOpcodeRegex.match(InstrName, &Matches) && Matches.size() > 1) {
92 StringRef Match = Matches[1];
93 // Trim trailing underscores
94 while (!Match.empty() && Match.back() == '_')
95 Match = Match.drop_back();
96 return Match.str();
97 }
98
99 // Fallback to original name if no pattern matches
100 return InstrName.str();
101}
102
104 assert(!UniqueBaseOpcodeNames.empty() && "Canonical mapping not built");
105 auto It = std::find(UniqueBaseOpcodeNames.begin(),
106 UniqueBaseOpcodeNames.end(), BaseName.str());
107 assert(It != UniqueBaseOpcodeNames.end() &&
108 "Base name not found in unique opcodes");
109 return std::distance(UniqueBaseOpcodeNames.begin(), It);
110}
111
112unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const {
113 auto BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
114 return getCanonicalIndexForBaseName(BaseOpcode);
115}
116
117std::string MIRVocabulary::getStringKey(unsigned Pos) const {
118 assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary");
119
120 // For now, all entries are opcodes since we only have one section
121 if (Pos < Layout.OperandBase && Pos < UniqueBaseOpcodeNames.size()) {
122 // Convert canonical index back to base opcode name
123 auto It = UniqueBaseOpcodeNames.begin();
124 std::advance(It, Pos);
125 return *It;
126 }
127
128 llvm_unreachable("Invalid position in vocabulary");
129 return "";
130}
131
132void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap) {
133
134 // Helper for handling missing entities in the vocabulary.
135 // Currently, we use a zero vector. In the future, we will throw an error to
136 // ensure that *all* known entities are present in the vocabulary.
137 auto handleMissingEntity = [](StringRef Key) {
138 LLVM_DEBUG(errs() << "MIR2Vec: Missing vocabulary entry for " << Key
139 << "; using zero vector. This will result in an error "
140 "in the future.\n");
141 ++MIRVocabMissCounter;
142 };
143
144 // Initialize opcode embeddings section
145 unsigned EmbeddingDim = OpcodeMap.begin()->second.size();
146 std::vector<Embedding> OpcodeEmbeddings(Layout.OperandBase,
147 Embedding(EmbeddingDim));
148
149 // Populate opcode embeddings using canonical mapping
150 for (auto COpcodeName : UniqueBaseOpcodeNames) {
151 if (auto It = OpcodeMap.find(COpcodeName); It != OpcodeMap.end()) {
152 auto COpcodeIndex = getCanonicalIndexForBaseName(COpcodeName);
153 assert(COpcodeIndex < Layout.OperandBase &&
154 "Canonical index out of bounds");
155 OpcodeEmbeddings[COpcodeIndex] = It->second;
156 } else {
157 handleMissingEntity(COpcodeName);
158 }
159 }
160
161 // TODO: Add operand/argument embeddings as additional sections
162 // This will require extending the vocabulary format and layout
163
164 // Scale the vocabulary sections based on the provided weights
165 auto scaleVocabSection = [](std::vector<Embedding> &Embeddings,
166 double Weight) {
167 for (auto &Embedding : Embeddings)
168 Embedding *= Weight;
169 };
170 scaleVocabSection(OpcodeEmbeddings, OpcWeight);
171
172 std::vector<std::vector<Embedding>> Sections(1);
173 Sections[0] = std::move(OpcodeEmbeddings);
174
175 Storage = ir2vec::VocabStorage(std::move(Sections));
176}
177
178void MIRVocabulary::buildCanonicalOpcodeMapping() {
179 // Check if already built
180 if (!UniqueBaseOpcodeNames.empty())
181 return;
182
183 // Build mapping from opcodes to canonical base opcode indices
184 for (unsigned Opcode = 0; Opcode < TII.getNumOpcodes(); ++Opcode) {
185 std::string BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
186 UniqueBaseOpcodeNames.insert(BaseOpcode);
187 }
188
189 LLVM_DEBUG(dbgs() << "MIR2Vec: Built canonical mapping for target with "
190 << UniqueBaseOpcodeNames.size()
191 << " unique base opcodes\n");
192}
193
194//===----------------------------------------------------------------------===//
195// MIR2VecVocabLegacyAnalysis Implementation
196//===----------------------------------------------------------------------===//
197
200 "MIR2Vec Vocabulary Analysis", false, true)
203 "MIR2Vec Vocabulary Analysis", false, true)
204
205StringRef MIR2VecVocabLegacyAnalysis::getPassName() const {
206 return "MIR2Vec Vocabulary Analysis";
207}
208
209Error MIR2VecVocabLegacyAnalysis::readVocabulary() {
210 // TODO: Extend vocabulary format to support multiple sections
211 // (opcodes, operands, etc.) similar to IR2Vec structure
212 if (VocabFile.empty())
213 return createStringError(
215 "MIR2Vec vocabulary file path not specified; set it "
216 "using --mir2vec-vocab-path");
217
218 auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
219 if (!BufOrError)
220 return createFileError(VocabFile, BufOrError.getError());
221
222 auto Content = BufOrError.get()->getBuffer();
223
224 Expected<json::Value> ParsedVocabValue = json::parse(Content);
225 if (!ParsedVocabValue)
226 return ParsedVocabValue.takeError();
227
228 unsigned Dim = 0;
230 "entities", *ParsedVocabValue, StrVocabMap, Dim))
231 return Err;
232
233 return Error::success();
234}
235
238 if (StrVocabMap.empty()) {
239 if (Error Err = readVocabulary()) {
240 return std::move(Err);
241 }
242 }
243
244 // Get machine module info to access machine functions and target info
246
247 // Find first available machine function to get target instruction info
248 for (const auto &F : M) {
249 if (F.isDeclaration())
250 continue;
251
252 if (auto *MF = MMI.getMachineFunction(F)) {
253 const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
254 return mir2vec::MIRVocabulary::create(std::move(StrVocabMap), *TII);
255 }
256 }
257
258 // No machine functions available - return error
260 "No machine functions found in module");
261}
262
263//===----------------------------------------------------------------------===//
264// Printer Passes Implementation
265//===----------------------------------------------------------------------===//
266
269 "MIR2Vec Vocabulary Printer Pass", false, true)
273 "MIR2Vec Vocabulary Printer Pass", false, true)
274
278
281 auto MIR2VecVocabOrErr = Analysis.getMIR2VecVocabulary(M);
282
283 if (!MIR2VecVocabOrErr) {
284 OS << "MIR2Vec Vocabulary Printer: Failed to get vocabulary - "
285 << toString(MIR2VecVocabOrErr.takeError()) << "\n";
286 return false;
287 }
288
289 auto &MIR2VecVocab = *MIR2VecVocabOrErr;
290 unsigned Pos = 0;
291 for (const auto &Entry : MIR2VecVocab) {
292 OS << "Key: " << MIR2VecVocab.getStringKey(Pos++) << ": ";
293 Entry.print(OS);
294 }
295
296 return false;
297}
298
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
block Block Frequency Analysis
const HexagonInstrInfo * TII
Module.h This file contains the declarations for the Module class.
#define F(x, y, z)
Definition MD5.cpp:55
This file defines the MIR2Vec vocabulary analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::MIRE...
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
#define LLVM_DEBUG(...)
Definition Debug.h:114
Lightweight error class with error context and mandatory checking.
Definition Error.h:159
static ErrorSuccess success()
Create a success value.
Definition Error.h:336
Tagged union holding either a T or a Error.
Definition Error.h:485
Error takeError()
Take ownership of the stored error.
Definition Error.h:612
Pass to analyze and populate MIR2Vec vocabulary from a module.
Definition MIR2Vec.h:126
Expected< mir2vec::MIRVocabulary > getMIR2VecVocabulary(const Module &M)
Definition MIR2Vec.cpp:237
This pass prints the embeddings in the MIR2Vec vocabulary.
Definition MIR2Vec.h:148
bool doFinalization(Module &M) override
doFinalization - Virtual method overridden by subclasses to do any necessary clean up after all passes...
Definition MIR2Vec.cpp:279
bool runOnMachineFunction(MachineFunction &MF) override
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
Definition MIR2Vec.cpp:275
MIR2VecVocabPrinterLegacyPass(raw_ostream &OS)
Definition MIR2Vec.h:153
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
This class contains meta information specific to a module.
LLVM_ABI MachineFunction * getMachineFunction(const Function &F) const
Returns the MachineFunction associated to IR function F if there is one, otherwise nullptr.
static ErrorOr< std::unique_ptr< MemoryBuffer > > getFileOrSTDIN(const Twine &Filename, bool IsText=false, bool RequiresNullTerminator=true, std::optional< Align > Alignment=std::nullopt)
Open the specified file as a MemoryBuffer, or open stdin if the Filename is "-".
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
LLVM_ABI bool match(StringRef String, SmallVectorImpl< StringRef > *Matches=nullptr, std::string *Error=nullptr) const
matches - Match the regex against a given String.
Definition Regex.cpp:83
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
std::string str() const
str - Get the contents as an std::string.
Definition StringRef.h:225
constexpr bool empty() const
empty - Check if the string is empty.
Definition StringRef.h:143
char back() const
back - Get the last character in the string.
Definition StringRef.h:155
StringRef drop_back(size_t N=1) const
Return a StringRef equal to 'this' but with the last N elements dropped.
Definition StringRef.h:618
TargetInstrInfo - Interface to description of machine instruction set.
Generic storage class for section-based vocabularies.
Definition IR2Vec.h:151
static Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab, unsigned &Dim)
Parse a vocabulary section from JSON and populate the target vocabulary map.
Definition IR2Vec.cpp:316
static std::string extractBaseOpcodeName(StringRef InstrName)
Static method for extracting base opcode names (public for testing)
Definition MIR2Vec.cpp:73
static Expected< MIRVocabulary > create(VocabMap &&Entries, const TargetInstrInfo &TII)
Factory method to create MIRVocabulary from vocabulary map.
Definition MIR2Vec.cpp:64
std::string getStringKey(unsigned Pos) const
Get the string key for a vocabulary entry at the given position.
Definition MIR2Vec.cpp:117
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const
Get canonical index for base name (public for testing)
Definition MIR2Vec.cpp:103
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
initializer< Ty > init(const Ty &Val)
LLVM_ABI llvm::Expected< Value > parse(llvm::StringRef JSON)
Parses the provided JSON source, or returns a ParseError.
Definition JSON.cpp:675
llvm::cl::OptionCategory MIR2VecCategory
cl::opt< float > OpcWeight
static cl::opt< std::string > VocabFile("mir2vec-vocab-path", cl::Optional, cl::desc("Path to the vocabulary file for MIR2Vec"), cl::init(""), cl::cat(MIR2VecCategory))
ir2vec::Embedding Embedding
Definition MIR2Vec.h:59
This is an optimization pass for GlobalISel generic memory operations.
Error createFileError(const Twine &F, Error E)
Concatenate a source file path and/or name with an Error.
Definition Error.h:1399
Error createStringError(std::error_code EC, char const *Fmt, const Ts &... Vals)
Create formatted StringError object.
Definition Error.h:1305
@ invalid_argument
Definition Errc.h:56
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
LLVM_ABI MachineFunctionPass * createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS)
MIR2VecVocabPrinter pass - This pass prints out the MIR2Vec vocabulary contents to the given stream a...
Definition MIR2Vec.cpp:300
std::string toString(const APInt &I, unsigned Radix, bool Signed, bool formatAsCLiteral=false, bool UpperCase=true, bool InsertSeparators=false)