35#ifndef LLVM_ANALYSIS_IR2VEC_H
36#define LLVM_ANALYSIS_IR2VEC_H
89 std::vector<double> Data;
93 Embedding(
const std::vector<double> &V) : Data(V) {}
95 Embedding(std::initializer_list<double> IL) : Data(IL) {}
100 size_t size()
const {
return Data.size(); }
101 bool empty()
const {
return Data.empty(); }
104 assert(Itr < Data.size() &&
"Index out of bounds");
109 assert(Itr < Data.size() &&
"Index out of bounds");
113 using iterator =
typename std::vector<double>::iterator;
123 const std::vector<double> &
getData()
const {
return Data; }
140 double Tolerance = 1e-4)
const;
154 std::vector<std::vector<Embedding>> Sections;
156 const size_t TotalSize;
157 const unsigned Dimension;
164 VocabStorage(std::vector<std::vector<Embedding>> &&SectionData);
173 size_t size()
const {
return TotalSize; }
177 return static_cast<unsigned>(Sections.size());
181 const std::vector<Embedding> &
operator[](
unsigned SectionId)
const {
182 assert(SectionId < Sections.size() &&
"Invalid section ID");
183 return Sections[SectionId];
190 bool isValid()
const {
return TotalSize > 0; }
195 unsigned SectionId = 0;
196 size_t LocalIndex = 0;
201 : Storage(Storage), SectionId(SectionId), LocalIndex(LocalIndex) {}
251 enum class Section :
unsigned {
262 static constexpr unsigned NumICmpPredicates =
265 static constexpr unsigned NumFCmpPredicates =
297#define LAST_OTHER_INST(NUM) static constexpr unsigned MaxOpcodes = NUM;
298#include "llvm/IR/Instruction.def"
299#undef LAST_OTHER_INST
301 static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1;
309 NumICmpPredicates + NumFCmpPredicates;
321 return Storage.size() == NumCanonicalEntries;
326 return Storage.getDimension();
338 return getVocabKeyForCanonicalTypeID(getCanonicalTypeID(
TypeID));
343 unsigned Index =
static_cast<unsigned>(Kind);
345 return OperandKindNames[Index];
356 assert(Opcode >= 1 && Opcode <= MaxOpcodes &&
"Invalid opcode");
362 return MaxOpcodes +
static_cast<unsigned>(getCanonicalTypeID(
TypeID));
368 return OperandBaseOffset + Index;
372 return PredicateBaseOffset + getPredicateLocalIndex(
P);
377 assert(Opcode >= 1 && Opcode <= MaxOpcodes &&
"Invalid opcode");
378 return Storage[
static_cast<unsigned>(Section::Opcodes)][Opcode - 1];
383 unsigned LocalIndex =
static_cast<unsigned>(getCanonicalTypeID(
TypeID));
384 return Storage[
static_cast<unsigned>(Section::CanonicalTypes)][LocalIndex];
388 unsigned LocalIndex =
static_cast<unsigned>(
getOperandKind(&Arg));
390 return Storage[
static_cast<unsigned>(Section::Operands)][LocalIndex];
394 unsigned LocalIndex = getPredicateLocalIndex(
P);
395 return Storage[
static_cast<unsigned>(Section::Predicates)][LocalIndex];
403 return Storage.begin();
410 return Storage.end();
424 ModuleAnalysisManager::Invalidator &Inv)
const;
427 constexpr static unsigned NumCanonicalEntries =
431 constexpr static unsigned OperandBaseOffset =
433 constexpr static unsigned PredicateBaseOffset =
442 "FloatTy",
"VoidTy",
"LabelTy",
"MetadataTy",
443 "VectorTy",
"TokenTy",
"IntegerTy",
"FunctionTy",
444 "PointerTy",
"StructTy",
"ArrayTy",
"UnknownTy"};
445 static_assert(std::size(CanonicalTypeNames) ==
447 "CanonicalTypeNames array size must match MaxCanonicalType");
450 static constexpr StringLiteral OperandKindNames[] = {
"Function",
"Pointer",
451 "Constant",
"Variable"};
452 static_assert(std::size(OperandKindNames) ==
454 "OperandKindNames array size must match MaxOperandKind");
458 static constexpr std::array<CanonicalTypeID, MaxTypeIDs> TypeIDMapping = {{
481 static_assert(TypeIDMapping.size() ==
MaxTypeIDs,
482 "TypeIDMapping must cover all Type::TypeID values");
487 unsigned Index =
static_cast<unsigned>(CType);
489 return CanonicalTypeNames[
Index];
496 return TypeIDMapping[
Index];
504 return getPredicateFromLocalIndex(Index);
546 LLVM_ABI static std::unique_ptr<Embedder>
573 void computeEmbeddings(
const BasicBlock &BB)
const override;
585 void computeEmbeddings(
const BasicBlock &BB)
const override;
598 using VocabMap = std::map<std::string, ir2vec::Embedding>;
599 std::optional<ir2vec::VocabStorage> Vocab;
601 Error readVocabulary(VocabMap &OpcVocab, VocabMap &TypeVocab,
604 VocabMap &TargetVocab,
unsigned &Dim);
605 void generateVocabStorage(VocabMap &OpcVocab, VocabMap &TypeVocab,
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
This file defines the DenseMap class.
Provides ErrorOr<T> smart pointer.
This header defines various interfaces for pass management in LLVM.
This file supports working with JSON data.
ModuleAnalysisManager MAM
static cl::opt< RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode > Mode("regalloc-enable-advisor", cl::Hidden, cl::init(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Default), cl::desc("Enable regalloc advisor mode"), cl::values(clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Default, "default", "Default"), clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Release, "release", "precompiled"), clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Development, "development", "for training")))
LLVM Basic Block Representation.
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Lightweight error class with error context and mandatory checking.
IR2VecPrinterPass(raw_ostream &OS)
LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
This analysis provides the vocabulary for IR2Vec.
IR2VecVocabAnalysis()=default
ir2vec::Vocabulary Result
LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM)
LLVM_ABI IR2VecVocabAnalysis(ir2vec::VocabStorage &&Vocab)
static LLVM_ABI AnalysisKey Key
LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
IR2VecVocabPrinterPass(raw_ostream &OS)
This is an important class for using LLVM in a threaded context.
A Module instance is used to store all the information related to an LLVM module.
A set of analyses that are preserved following a run of a transformation pass.
A wrapper around a string literal that serves as a proxy for constructing global tables of StringRefs...
StringRef - Represent a constant reference to a string, i.e.
TypeID
Definitions of all of the base types for the Type system.
LLVM Value Representation.
LLVM_ABI const Embedding & getBBVector(const BasicBlock &BB) const
Returns the embedding for a given basic block in the function F if it has been computed.
static LLVM_ABI std::unique_ptr< Embedder > create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab)
Factory method to create an Embedder object.
LLVM_ABI const BBEmbeddingsMap & getBBVecMap() const
Returns a map containing basic block and the corresponding embeddings for the function F if it has be...
void computeEmbeddings() const
Function to compute embeddings.
virtual ~Embedder()=default
LLVM_ABI const InstEmbeddingsMap & getInstVecMap() const
Returns a map containing instructions and the corresponding embeddings for the function F if it has b...
const float OpcWeight
Weights for different entities (like opcode, arguments, types) in the IR instructions to generate the...
const unsigned Dimension
Dimension of the vector representation; captured from the input vocabulary.
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab)
virtual void computeEmbeddings(const BasicBlock &BB) const =0
Function to compute the embedding for a given basic block.
LLVM_ABI const Embedding & getFunctionVector() const
Computes and returns the embedding for the current function.
InstEmbeddingsMap InstVecMap
FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab)
SymbolicEmbedder(const Function &F, const Vocabulary &Vocab)
Iterator support for section-based access.
const_iterator(const VocabStorage *Storage, unsigned SectionId, size_t LocalIndex)
LLVM_ABI bool operator!=(const const_iterator &Other) const
LLVM_ABI const_iterator & operator++()
LLVM_ABI const Embedding & operator*() const
LLVM_ABI bool operator==(const const_iterator &Other) const
Generic storage class for section-based vocabularies.
const_iterator end() const
unsigned getNumSections() const
Get number of sections.
VocabStorage()
Default constructor creates empty storage (invalid state)
VocabStorage & operator=(VocabStorage &&)=delete
VocabStorage & operator=(const VocabStorage &)=delete
unsigned getDimension() const
Get vocabulary dimension.
size_t size() const
Get total number of entries across all sections.
const_iterator begin() const
bool isValid() const
Check if vocabulary is valid (has data)
VocabStorage(VocabStorage &&)=default
const std::vector< Embedding > & operator[](unsigned SectionId) const
Section-based access: Storage[sectionId][localIndex].
VocabStorage(const VocabStorage &)=delete
Class for storing and accessing the IR2Vec vocabulary.
static LLVM_ABI StringRef getVocabKeyForOperandKind(OperandKind Kind)
Function to get vocabulary key for a given OperandKind.
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA, ModuleAnalysisManager::Invalidator &Inv) const
const_iterator begin() const
LLVM_ABI unsigned getDimension() const
Vocabulary(Vocabulary &&)=default
static LLVM_ABI OperandKind getOperandKind(const Value *Op)
Function to classify an operand into OperandKind.
static LLVM_ABI unsigned getIndex(CmpInst::Predicate P)
Vocabulary & operator=(const Vocabulary &)=delete
static LLVM_ABI StringRef getStringKey(unsigned Pos)
Returns the string key for a given index position in the vocabulary.
static constexpr unsigned MaxCanonicalTypeIDs
LLVM_ABI const ir2vec::Embedding & operator[](CmpInst::Predicate P) const
static constexpr unsigned MaxOperandKinds
Vocabulary(const Vocabulary &)=delete
const_iterator cbegin() const
OperandKind
Operand kinds supported by IR2Vec Vocabulary.
static constexpr size_t getCanonicalSize()
Total number of entries (opcodes + canonicalized types + operand kinds + predicates)
static LLVM_ABI unsigned getIndex(const Value &Op)
static LLVM_ABI StringRef getVocabKeyForPredicate(CmpInst::Predicate P)
Function to get vocabulary key for a given predicate.
static constexpr unsigned MaxTypeIDs
LLVM_ABI Vocabulary(VocabStorage &&Storage)
LLVM_ABI const ir2vec::Embedding & operator[](Type::TypeID TypeID) const
static LLVM_ABI unsigned getIndex(Type::TypeID TypeID)
const_iterator end() const
static LLVM_ABI StringRef getVocabKeyForOpcode(unsigned Opcode)
Function to get vocabulary key for a given Opcode.
static LLVM_ABI StringRef getVocabKeyForTypeID(Type::TypeID TypeID)
Function to get vocabulary key for a given TypeID.
VocabStorage::const_iterator const_iterator
Const Iterator type aliases.
const_iterator cend() const
static LLVM_ABI unsigned getIndex(unsigned Opcode)
Functions to return flat index.
LLVM_ABI bool isValid() const
Vocabulary & operator=(Vocabulary &&Other)=delete
LLVM_ABI const ir2vec::Embedding & operator[](unsigned Opcode) const
Accessors to get the embedding for a given entity.
static LLVM_ABI VocabStorage createDummyVocabForTest(unsigned Dim=1)
Create a dummy vocabulary for testing purposes.
static constexpr unsigned MaxPredicateKinds
CanonicalTypeID
Canonical type IDs supported by IR2Vec Vocabulary.
LLVM_ABI const ir2vec::Embedding & operator[](const Value &Arg) const
A Value is an JSON value of unknown type.
This class implements an extremely fast bulk output stream that can only output to a stream.
DenseMap< const Instruction *, Embedding > InstEmbeddingsMap
LLVM_ABI cl::opt< float > ArgWeight
DenseMap< const BasicBlock *, Embedding > BBEmbeddingsMap
LLVM_ABI cl::opt< float > OpcWeight
LLVM_ABI cl::opt< float > TypeWeight
LLVM_ABI cl::opt< IR2VecKind > IR2VecEmbeddingKind
llvm::cl::OptionCategory IR2VecCategory
This is an optimization pass for GlobalISel generic memory operations.
IR2VecKind
IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
DWARFExpression::Operation Op
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
AnalysisManager< Module > ModuleAnalysisManager
Convenience typedef for the Module analysis manager.
Implement std::hash so that hash_code can be used in STL containers.
A CRTP mix-in that provides informational APIs needed for analysis passes.
A special type used by analysis passes to provide an address that identifies that particular analysis...
A CRTP mix-in to automatically provide informational APIs needed for passes.
Embedding is a datatype that wraps std::vector<double>.
const_iterator end() const
LLVM_ABI bool approximatelyEquals(const Embedding &RHS, double Tolerance=1e-4) const
Returns true if the embedding is approximately equal to the RHS embedding within the specified tolera...
const_iterator cbegin() const
LLVM_ABI Embedding & operator+=(const Embedding &RHS)
Arithmetic operators.
LLVM_ABI Embedding operator-(const Embedding &RHS) const
const std::vector< double > & getData() const
typename std::vector< double >::const_iterator const_iterator
Embedding(size_t Size, double InitialValue)
LLVM_ABI Embedding & operator-=(const Embedding &RHS)
const_iterator cend() const
LLVM_ABI Embedding operator*(double Factor) const
LLVM_ABI Embedding & operator*=(double Factor)
Embedding(std::initializer_list< double > IL)
Embedding(const std::vector< double > &V)
LLVM_ABI Embedding operator+(const Embedding &RHS) const
typename std::vector< double >::iterator iterator
LLVM_ABI Embedding & scaleAndAdd(const Embedding &Src, float Factor)
Adds Src Embedding scaled by Factor with the called Embedding.
Embedding(std::vector< double > &&V)
const double & operator[](size_t Itr) const
LLVM_ABI void print(raw_ostream &OS) const
const_iterator begin() const
double & operator[](size_t Itr)