LLVM 22.0.0git
MIR2Vec.h
Go to the documentation of this file.
1//===- MIR2Vec.h - Implementation of MIR2Vec ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file defines the MIR2Vec vocabulary
11/// analysis (MIR2VecVocabLegacyAnalysis), the core mir2vec::MIREmbedder
12/// interface for generating Machine IR embeddings, and related utilities.
13///
14/// MIR2Vec extends IR2Vec to support Machine IR embeddings. It represents the
15/// LLVM Machine IR as embeddings which can be used as input to machine learning
16/// algorithms.
17///
18/// The original idea of MIR2Vec is described in the following paper:
19///
20/// RL4ReAl: Reinforcement Learning for Register Allocation. S. VenkataKeerthy,
21/// Siddharth Jain, Anilava Kundu, Rohit Aggarwal, Albert Cohen, and Ramakrishna
22/// Upadrasta. 2023. RL4ReAl: Reinforcement Learning for Register Allocation.
23/// Proceedings of the 32nd ACM SIGPLAN International Conference on Compiler
24/// Construction (CC 2023). https://doi.org/10.1145/3578360.3580273.
25/// https://arxiv.org/abs/2204.02013
26///
27//===----------------------------------------------------------------------===//
28
29#ifndef LLVM_CODEGEN_MIR2VEC_H
30#define LLVM_CODEGEN_MIR2VEC_H
31
38#include "llvm/IR/PassManager.h"
39#include "llvm/Pass.h"
41#include "llvm/Support/Error.h"
43#include <map>
44#include <set>
45#include <string>
46
47namespace llvm {
48
49class Module;
50class raw_ostream;
51class LLVMContext;
53class TargetInstrInfo;
54
55namespace mir2vec {
58
60
61/// Class for storing and accessing the MIR2Vec vocabulary.
62/// The MIRVocabulary class manages seed embeddings for LLVM Machine IR
65 using VocabMap = std::map<std::string, ir2vec::Embedding>;
66
67private:
68 // Define vocabulary layout - adapted for MIR
// Unnamed struct type with a single instance, `Layout`. All three fields are
// size_t and default to 0 through their in-class initializers.
// NOTE(review): the names suggest OpcodeBase/OperandBase are section start
// offsets and TotalEntries an overall entry count, but this listing does not
// show where they are written or read -- confirm against MIR2Vec.cpp.
69 struct {
70 size_t OpcodeBase = 0;
71 size_t OperandBase = 0;
72 size_t TotalEntries = 0;
73 } Layout;
74
// Identifies a section of the vocabulary storage; the underlying type is
// unsigned. Opcodes (0) is the only named section, so MaxSections evaluates
// to 1. operator[] casts Section::Opcodes to unsigned to select the section
// inside Storage.
75 enum class Section : unsigned { Opcodes = 0, MaxSections };
76
78 mutable std::set<std::string> UniqueBaseOpcodeNames;
79 const TargetInstrInfo &TII;
80 void generateStorage(const VocabMap &OpcodeMap);
81 void buildCanonicalOpcodeMapping();
82
83 /// Get canonical index for a machine opcode
84 unsigned getCanonicalOpcodeIndex(unsigned Opcode) const;
85
86public:
87 /// Static method for extracting base opcode names (public for testing)
88 static std::string extractBaseOpcodeName(StringRef InstrName);
89
90 /// Get canonical index for base name (public for testing)
91 unsigned getCanonicalIndexForBaseName(StringRef BaseName) const;
92
93 /// Get the string key for a vocabulary entry at the given position
94 std::string getStringKey(unsigned Pos) const;
95
/// Returns whatever Storage.getDimension() reports.
/// NOTE(review): presumably the length of each embedding vector; the Storage
/// member's declaration is not visible in this listing -- confirm against
/// ir2vec::VocabStorage.
96 unsigned getDimension() const { return Storage.getDimension(); }
97
98 // Accessor methods
/// Looks up the embedding for a machine opcode.
///
/// \param Opcode Machine opcode; it is first translated by
///        getCanonicalOpcodeIndex() into an index local to the opcode section.
/// \returns A const reference to the entry at that index within the
///          Section::Opcodes section of Storage.
///
/// No bounds check is performed in this function itself; any validation would
/// have to happen in getCanonicalOpcodeIndex() or in Storage's own indexing.
99 const Embedding &operator[](unsigned Opcode) const {
100 unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode);
101 return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex];
102 }
103
104 // Iterator access
// Both accessors forward directly to Storage, so iteration order and
// iterator validity are whatever Storage provides.
/// Returns Storage.begin().
106 const_iterator begin() const { return Storage.begin(); }
107
/// Returns Storage.end().
108 const_iterator end() const { return Storage.end(); }
109
110 /// Total number of entries in the vocabulary
/// The value is obtained by forwarding to Storage.size().
111 size_t getCanonicalSize() const { return Storage.size(); }
112
113 MIRVocabulary() = delete;
114
115 /// Factory method to create MIRVocabulary from vocabulary map
116 static Expected<MIRVocabulary> create(VocabMap &&Entries,
117 const TargetInstrInfo &TII);
118
119private:
120 MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII);
121};
122
123} // namespace mir2vec
124
125/// Pass to analyze and populate MIR2Vec vocabulary from a module
127 using VocabVector = std::vector<mir2vec::Embedding>;
128 using VocabMap = std::map<std::string, mir2vec::Embedding>;
129 VocabMap StrVocabMap;
130 VocabVector Vocab;
131
132 StringRef getPassName() const override;
133 Error readVocabulary();
134
135protected:
/// Reports this pass's analysis requirements to the pass manager.
/// The visible body marks every analysis as preserved via setPreservesAll().
/// NOTE(review): listing line 137 is absent from this rendering, so the
/// function may contain an additional statement (possibly an addRequired
/// call) that is not shown here -- check the original header.
136 void getAnalysisUsage(AnalysisUsage &AU) const override {
138 AU.setPreservesAll();
139 }
140
141public:
142 static char ID;
145};
146
147/// This pass prints the embeddings in the MIR2Vec vocabulary
149 raw_ostream &OS;
150
151public:
152 static char ID;
155
156 bool runOnMachineFunction(MachineFunction &MF) override;
157 bool doFinalization(Module &M) override;
163
/// Returns the fixed, human-readable name of this printer pass.
164 StringRef getPassName() const override {
165 return "MIR2Vec Vocabulary Printer Pass";
166 }
167};
168
169} // namespace llvm
170
171#endif // LLVM_CODEGEN_MIR2VEC_H
Provides ErrorOr<T> smart pointer.
const HexagonInstrInfo * TII
This file defines the IR2Vec vocabulary analysis(IR2VecVocabAnalysis), the core ir2vec::Embedder inte...
This header defines various interfaces for pass management in LLVM.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesAll()
Set by analyses that do not transform their input at all.
Lightweight error class with error context and mandatory checking.
Definition Error.h:159
Tagged union holding either a T or a Error.
Definition Error.h:485
ImmutablePass(char &pid)
Definition Pass.h:287
This is an important class for using LLVM in a threaded context.
Definition LLVMContext.h:68
Pass to analyze and populate MIR2Vec vocabulary from a module.
Definition MIR2Vec.h:126
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
Definition MIR2Vec.h:136
Expected< mir2vec::MIRVocabulary > getMIR2VecVocabulary(const Module &M)
Definition MIR2Vec.cpp:237
StringRef getPassName() const override
getPassName - Return a nice clean name for a pass.
Definition MIR2Vec.h:164
bool doFinalization(Module &M) override
doFinalization - Virtual method overridden by subclasses to do any necessary clean up after all passes...
Definition MIR2Vec.cpp:279
bool runOnMachineFunction(MachineFunction &MF) override
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
Definition MIR2Vec.cpp:275
MIR2VecVocabPrinterLegacyPass(raw_ostream &OS)
Definition MIR2Vec.h:153
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
Definition MIR2Vec.h:158
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition Pass.cpp:85
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
TargetInstrInfo - Interface to description of machine instruction set.
Iterator support for section-based access.
Definition IR2Vec.h:196
Generic storage class for section-based vocabularies.
Definition IR2Vec.h:151
unsigned getDimension() const
Definition MIR2Vec.h:96
const_iterator end() const
Definition MIR2Vec.h:108
static std::string extractBaseOpcodeName(StringRef InstrName)
Static method for extracting base opcode names (public for testing)
Definition MIR2Vec.cpp:73
ir2vec::VocabStorage::const_iterator const_iterator
Definition MIR2Vec.h:105
const_iterator begin() const
Definition MIR2Vec.h:106
const Embedding & operator[](unsigned Opcode) const
Definition MIR2Vec.h:99
size_t getCanonicalSize() const
Total number of entries in the vocabulary.
Definition MIR2Vec.h:111
static Expected< MIRVocabulary > create(VocabMap &&Entries, const TargetInstrInfo &TII)
Factory method to create MIRVocabulary from vocabulary map.
Definition MIR2Vec.cpp:64
std::string getStringKey(unsigned Pos) const
Get the string key for a vocabulary entry at the given position.
Definition MIR2Vec.cpp:117
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const
Get canonical index for base name (public for testing)
Definition MIR2Vec.cpp:103
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
llvm::cl::OptionCategory MIR2VecCategory
cl::opt< float > OpcWeight
ir2vec::Embedding Embedding
Definition MIR2Vec.h:59
This is an optimization pass for GlobalISel generic memory operations.
Embedding is a datatype that wraps std::vector<double>.
Definition IR2Vec.h:87