27#define DEBUG_TYPE "mir2vec"
30 "Number of lookups to MIR entities not present in the vocabulary");
42 cl::desc(
"Weight for machine opcode embeddings"),
54 buildCanonicalOpcodeMapping();
56 unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
57 assert(CanonicalOpcodeCount > 0 &&
58 "No canonical opcodes found for target - invalid vocabulary");
59 Layout.OperandBase = CanonicalOpcodeCount;
60 generateStorage(OpcodeEntries);
61 Layout.TotalEntries = Storage.size();
68 "Empty vocabulary entries provided");
85 assert(!InstrName.
empty() &&
"Instruction name should not be empty");
88 static const Regex BaseOpcodeRegex(
"([a-zA-Z_]+)");
91 if (BaseOpcodeRegex.
match(InstrName, &Matches) && Matches.
size() > 1) {
94 while (!Match.
empty() && Match.
back() ==
'_')
100 return InstrName.
str();
104 assert(!UniqueBaseOpcodeNames.empty() &&
"Canonical mapping not built");
105 auto It = std::find(UniqueBaseOpcodeNames.begin(),
106 UniqueBaseOpcodeNames.end(), BaseName.
str());
107 assert(It != UniqueBaseOpcodeNames.end() &&
108 "Base name not found in unique opcodes");
109 return std::distance(UniqueBaseOpcodeNames.begin(), It);
112unsigned MIRVocabulary::getCanonicalOpcodeIndex(
unsigned Opcode)
const {
113 auto BaseOpcode = extractBaseOpcodeName(
TII.getName(Opcode));
114 return getCanonicalIndexForBaseName(BaseOpcode);
118 assert(Pos < Layout.TotalEntries &&
"Position out of bounds in vocabulary");
121 if (Pos < Layout.OperandBase && Pos < UniqueBaseOpcodeNames.size()) {
123 auto It = UniqueBaseOpcodeNames.begin();
124 std::advance(It, Pos);
132void MIRVocabulary::generateStorage(
const VocabMap &OpcodeMap) {
139 <<
"; using zero vector. This will result in an error "
141 ++MIRVocabMissCounter;
145 unsigned EmbeddingDim = OpcodeMap.begin()->second.size();
146 std::vector<Embedding> OpcodeEmbeddings(Layout.OperandBase,
150 for (
auto COpcodeName : UniqueBaseOpcodeNames) {
151 if (
auto It = OpcodeMap.find(COpcodeName); It != OpcodeMap.end()) {
152 auto COpcodeIndex = getCanonicalIndexForBaseName(COpcodeName);
153 assert(COpcodeIndex < Layout.OperandBase &&
154 "Canonical index out of bounds");
155 OpcodeEmbeddings[COpcodeIndex] = It->second;
157 handleMissingEntity(COpcodeName);
165 auto scaleVocabSection = [](std::vector<Embedding> &Embeddings,
170 scaleVocabSection(OpcodeEmbeddings,
OpcWeight);
172 std::vector<std::vector<Embedding>> Sections(1);
173 Sections[0] = std::move(OpcodeEmbeddings);
178void MIRVocabulary::buildCanonicalOpcodeMapping() {
180 if (!UniqueBaseOpcodeNames.empty())
184 for (
unsigned Opcode = 0; Opcode <
TII.getNumOpcodes(); ++Opcode) {
185 std::string BaseOpcode = extractBaseOpcodeName(
TII.getName(Opcode));
186 UniqueBaseOpcodeNames.insert(BaseOpcode);
189 LLVM_DEBUG(
dbgs() <<
"MIR2Vec: Built canonical mapping for target with "
190 << UniqueBaseOpcodeNames.size()
191 <<
" unique base opcodes\n");
200 "MIR2Vec Vocabulary Analysis",
false,
true)
206 return "MIR2Vec Vocabulary Analysis";
209Error MIR2VecVocabLegacyAnalysis::readVocabulary() {
215 "MIR2Vec vocabulary file path not specified; set it "
216 "using --mir2vec-vocab-path");
222 auto Content = BufOrError.get()->getBuffer();
224 Expected<json::Value> ParsedVocabValue =
json::parse(Content);
225 if (!ParsedVocabValue)
230 "entities", *ParsedVocabValue, StrVocabMap, Dim))
238 if (StrVocabMap.empty()) {
239 if (
Error Err = readVocabulary()) {
240 return std::move(Err);
248 for (
const auto &
F : M) {
249 if (
F.isDeclaration())
260 "No machine functions found in module");
269 "MIR2Vec Vocabulary Printer Pass",
false,
true)
281 auto MIR2VecVocabOrErr =
Analysis.getMIR2VecVocabulary(M);
283 if (!MIR2VecVocabOrErr) {
284 OS <<
"MIR2Vec Vocabulary Printer: Failed to get vocabulary - "
285 <<
toString(MIR2VecVocabOrErr.takeError()) <<
"\n";
289 auto &MIR2VecVocab = *MIR2VecVocabOrErr;
291 for (
const auto &Entry : MIR2VecVocab) {
292 OS <<
"Key: " << MIR2VecVocab.getStringKey(Pos++) <<
": ";
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
block Block Frequency Analysis
const HexagonInstrInfo * TII
Module.h This file contains the declarations for the Module class.
This file defines the MIR2Vec vocabulary analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::MIRE...
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Lightweight error class with error context and mandatory checking.
static ErrorSuccess success()
Create a success value.
Tagged union holding either a T or a Error.
Error takeError()
Take ownership of the stored error.
Pass to analyze and populate MIR2Vec vocabulary from a module.
Expected< mir2vec::MIRVocabulary > getMIR2VecVocabulary(const Module &M)
This pass prints the embeddings in the MIR2Vec vocabulary.
bool doFinalization(Module &M) override
doFinalization - Virtual method overriden by subclasses to do any necessary clean up after all passes...
bool runOnMachineFunction(MachineFunction &MF) override
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
MIR2VecVocabPrinterLegacyPass(raw_ostream &OS)
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
This class contains meta information specific to a module.
LLVM_ABI MachineFunction * getMachineFunction(const Function &F) const
Returns the MachineFunction associated to IR function F if there is one, otherwise nullptr.
static ErrorOr< std::unique_ptr< MemoryBuffer > > getFileOrSTDIN(const Twine &Filename, bool IsText=false, bool RequiresNullTerminator=true, std::optional< Align > Alignment=std::nullopt)
Open the specified file as a MemoryBuffer, or open stdin if the Filename is "-".
A Module instance is used to store all the information related to an LLVM module.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
LLVM_ABI bool match(StringRef String, SmallVectorImpl< StringRef > *Matches=nullptr, std::string *Error=nullptr) const
matches - Match the regex against a given String.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
std::string str() const
str - Get the contents as an std::string.
constexpr bool empty() const
empty - Check if the string is empty.
char back() const
back - Get the last character in the string.
StringRef drop_back(size_t N=1) const
Return a StringRef equal to 'this' but with the last N elements dropped.
TargetInstrInfo - Interface to description of machine instruction set.
Generic storage class for section-based vocabularies.
static Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab, unsigned &Dim)
Parse a vocabulary section from JSON and populate the target vocabulary map.
static std::string extractBaseOpcodeName(StringRef InstrName)
Static method for extracting base opcode names (public for testing)
static Expected< MIRVocabulary > create(VocabMap &&Entries, const TargetInstrInfo &TII)
Factory method to create MIRVocabulary from vocabulary map.
std::string getStringKey(unsigned Pos) const
Get the string key for a vocabulary entry at the given position.
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const
Get canonical index for base name (public for testing)
This class implements an extremely fast bulk output stream that can only output to a stream.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
initializer< Ty > init(const Ty &Val)
LLVM_ABI llvm::Expected< Value > parse(llvm::StringRef JSON)
Parses the provided JSON source, or returns a ParseError.
llvm::cl::OptionCategory MIR2VecCategory
cl::opt< float > OpcWeight
static cl::opt< std::string > VocabFile("mir2vec-vocab-path", cl::Optional, cl::desc("Path to the vocabulary file for MIR2Vec"), cl::init(""), cl::cat(MIR2VecCategory))
ir2vec::Embedding Embedding
This is an optimization pass for GlobalISel generic memory operations.
Error createFileError(const Twine &F, Error E)
Concatenate a source file path and/or name with an Error.
Error createStringError(std::error_code EC, char const *Fmt, const Ts &... Vals)
Create formatted StringError object.
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
LLVM_ABI MachineFunctionPass * createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS)
MIR2VecVocabPrinter pass - This pass prints out the MIR2Vec vocabulary contents to the given stream a...
std::string toString(const APInt &I, unsigned Radix, bool Signed, bool formatAsCLiteral=false, bool UpperCase=true, bool InsertSeparators=false)