1 //===- IR2Vec.h - Implementation of IR2Vec ----------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM 4 // Exceptions. See the LICENSE file for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// This file defines the IR2Vec vocabulary analysis(IR2VecVocabAnalysis), 11 /// the core ir2vec::Embedder interface for generating IR embeddings, 12 /// and related utilities like the IR2VecPrinterPass. 13 /// 14 /// Program Embeddings are typically or derived-from a learned 15 /// representation of the program. Such embeddings are used to represent the 16 /// programs as input to machine learning algorithms. IR2Vec represents the 17 /// LLVM IR as embeddings. 18 /// 19 /// The IR2Vec algorithm is described in the following paper: 20 /// 21 /// IR2Vec: LLVM IR Based Scalable Program Embeddings, S. VenkataKeerthy, 22 /// Rohit Aggarwal, Shalini Jain, Maunendra Sankar Desarkar, Ramakrishna 23 /// Upadrasta, and Y. N. Srikant, ACM Transactions on Architecture and 24 /// Code Optimization (TACO), 2020. https://doi.org/10.1145/3418463. 25 /// https://arxiv.org/abs/1909.06228 26 /// 27 //===----------------------------------------------------------------------===// 28 29 #ifndef LLVM_ANALYSIS_IR2VEC_H 30 #define LLVM_ANALYSIS_IR2VEC_H 31 32 #include "llvm/ADT/DenseMap.h" 33 #include "llvm/IR/PassManager.h" 34 #include "llvm/IR/Type.h" 35 #include "llvm/Support/CommandLine.h" 36 #include "llvm/Support/Compiler.h" 37 #include "llvm/Support/ErrorOr.h" 38 #include "llvm/Support/JSON.h" 39 #include <map> 40 41 namespace llvm { 42 43 class Module; 44 class BasicBlock; 45 class Instruction; 46 class Function; 47 class Value; 48 class raw_ostream; 49 class LLVMContext; 50 class IR2VecVocabAnalysis; 51 52 /// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware. 53 /// Symbolic embeddings capture the "syntactic" and "statistical correlation" 54 /// of the IR entities. Flow-aware embeddings build on top of symbolic 55 /// embeddings and additionally capture the flow information in the IR. 56 /// IR2VecKind is used to specify the type of embeddings to generate. 57 /// Currently, only Symbolic embeddings are supported. 58 enum class IR2VecKind { Symbolic }; 59 60 namespace ir2vec { 61 62 LLVM_ABI extern cl::opt<float> OpcWeight; 63 LLVM_ABI extern cl::opt<float> TypeWeight; 64 LLVM_ABI extern cl::opt<float> ArgWeight; 65 66 /// Embedding is a datatype that wraps std::vector<double>. It provides 67 /// additional functionality for arithmetic and comparison operations. 68 /// It is meant to be used *like* std::vector<double> but is more restrictive 69 /// in the sense that it does not allow the user to change the size of the 70 /// embedding vector. The dimension of the embedding is fixed at the time of 71 /// construction of Embedding object. But the elements can be modified in-place. 72 struct Embedding { 73 private: 74 std::vector<double> Data; 75 76 public: 77 Embedding() = default; EmbeddingEmbedding78 Embedding(const std::vector<double> &V) : Data(V) {} EmbeddingEmbedding79 Embedding(std::vector<double> &&V) : Data(std::move(V)) {} EmbeddingEmbedding80 Embedding(std::initializer_list<double> IL) : Data(IL) {} 81 EmbeddingEmbedding82 explicit Embedding(size_t Size) : Data(Size) {} EmbeddingEmbedding83 Embedding(size_t Size, double InitialValue) : Data(Size, InitialValue) {} 84 sizeEmbedding85 size_t size() const { return Data.size(); } emptyEmbedding86 bool empty() const { return Data.empty(); } 87 88 double &operator[](size_t Itr) { 89 assert(Itr < Data.size() && "Index out of bounds"); 90 return Data[Itr]; 91 } 92 93 const double &operator[](size_t Itr) const { 94 assert(Itr < Data.size() && "Index out of bounds"); 95 return Data[Itr]; 96 } 97 98 using iterator = typename std::vector<double>::iterator; 99 using const_iterator = typename std::vector<double>::const_iterator; 100 beginEmbedding101 iterator begin() { return Data.begin(); } endEmbedding102 iterator end() { return Data.end(); } beginEmbedding103 const_iterator begin() const { return Data.begin(); } endEmbedding104 const_iterator end() const { return Data.end(); } cbeginEmbedding105 const_iterator cbegin() const { return Data.cbegin(); } cendEmbedding106 const_iterator cend() const { return Data.cend(); } 107 getDataEmbedding108 const std::vector<double> &getData() const { return Data; } 109 110 /// Arithmetic operators 111 LLVM_ABI Embedding &operator+=(const Embedding &RHS); 112 LLVM_ABI Embedding operator+(const Embedding &RHS) const; 113 LLVM_ABI Embedding &operator-=(const Embedding &RHS); 114 LLVM_ABI Embedding operator-(const Embedding &RHS) const; 115 LLVM_ABI Embedding &operator*=(double Factor); 116 LLVM_ABI Embedding operator*(double Factor) const; 117 118 /// Adds Src Embedding scaled by Factor with the called Embedding. 119 /// Called_Embedding += Src * Factor 120 LLVM_ABI Embedding &scaleAndAdd(const Embedding &Src, float Factor); 121 122 /// Returns true if the embedding is approximately equal to the RHS embedding 123 /// within the specified tolerance. 124 LLVM_ABI bool approximatelyEquals(const Embedding &RHS, 125 double Tolerance = 1e-4) const; 126 127 LLVM_ABI void print(raw_ostream &OS) const; 128 }; 129 130 using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>; 131 using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>; 132 133 /// Class for storing and accessing the IR2Vec vocabulary. 134 /// Encapsulates all vocabulary-related constants, logic, and access methods. 135 class Vocabulary { 136 friend class llvm::IR2VecVocabAnalysis; 137 using VocabVector = std::vector<ir2vec::Embedding>; 138 VocabVector Vocab; 139 bool Valid = false; 140 141 /// Operand kinds supported by IR2Vec Vocabulary 142 enum class OperandKind : unsigned { 143 FunctionID, 144 PointerID, 145 ConstantID, 146 VariableID, 147 MaxOperandKind 148 }; 149 /// String mappings for OperandKind values 150 static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer", 151 "Constant", "Variable"}; 152 static_assert(std::size(OperandKindNames) == 153 static_cast<unsigned>(OperandKind::MaxOperandKind), 154 "OperandKindNames array size must match MaxOperandKind"); 155 156 /// Vocabulary layout constants 157 #define LAST_OTHER_INST(NUM) static constexpr unsigned MaxOpcodes = NUM; 158 #include "llvm/IR/Instruction.def" 159 #undef LAST_OTHER_INST 160 161 static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1; 162 static constexpr unsigned MaxOperandKinds = 163 static_cast<unsigned>(OperandKind::MaxOperandKind); 164 165 public: 166 Vocabulary() = default; 167 Vocabulary(VocabVector &&Vocab); 168 169 bool isValid() const; 170 unsigned getDimension() const; 171 size_t size() const; 172 173 /// Helper function to get vocabulary key for a given Opcode 174 static StringRef getVocabKeyForOpcode(unsigned Opcode); 175 176 /// Helper function to get vocabulary key for a given TypeID 177 static StringRef getVocabKeyForTypeID(Type::TypeID TypeID); 178 179 /// Helper function to get vocabulary key for a given OperandKind 180 static StringRef getVocabKeyForOperandKind(OperandKind Kind); 181 182 /// Helper function to classify an operand into OperandKind 183 static OperandKind getOperandKind(const Value *Op); 184 185 /// Accessors to get the embedding for a given entity. 186 const ir2vec::Embedding &operator[](unsigned Opcode) const; 187 const ir2vec::Embedding &operator[](Type::TypeID TypeId) const; 188 const ir2vec::Embedding &operator[](const Value *Arg) const; 189 190 /// Const Iterator type aliases 191 using const_iterator = VocabVector::const_iterator; begin()192 const_iterator begin() const { 193 assert(Valid && "IR2Vec Vocabulary is invalid"); 194 return Vocab.begin(); 195 } 196 cbegin()197 const_iterator cbegin() const { 198 assert(Valid && "IR2Vec Vocabulary is invalid"); 199 return Vocab.cbegin(); 200 } 201 end()202 const_iterator end() const { 203 assert(Valid && "IR2Vec Vocabulary is invalid"); 204 return Vocab.end(); 205 } 206 cend()207 const_iterator cend() const { 208 assert(Valid && "IR2Vec Vocabulary is invalid"); 209 return Vocab.cend(); 210 } 211 212 /// Returns the string key for a given index position in the vocabulary. 213 /// This is useful for debugging or printing the vocabulary. Do not use this 214 /// for embedding generation as string based lookups are inefficient. 215 static StringRef getStringKey(unsigned Pos); 216 217 /// Create a dummy vocabulary for testing purposes. 218 static VocabVector createDummyVocabForTest(unsigned Dim = 1); 219 220 bool invalidate(Module &M, const PreservedAnalyses &PA, 221 ModuleAnalysisManager::Invalidator &Inv) const; 222 }; 223 224 /// Embedder provides the interface to generate embeddings (vector 225 /// representations) for instructions, basic blocks, and functions. The 226 /// vector representations are generated using IR2Vec algorithms. 227 /// 228 /// The Embedder class is an abstract class and it is intended to be 229 /// subclassed for different IR2Vec algorithms like Symbolic and Flow-aware. 230 class Embedder { 231 protected: 232 const Function &F; 233 const Vocabulary &Vocab; 234 235 /// Dimension of the vector representation; captured from the input vocabulary 236 const unsigned Dimension; 237 238 /// Weights for different entities (like opcode, arguments, types) 239 /// in the IR instructions to generate the vector representation. 240 const float OpcWeight, TypeWeight, ArgWeight; 241 242 // Utility maps - these are used to store the vector representations of 243 // instructions, basic blocks and functions. 244 mutable Embedding FuncVector; 245 mutable BBEmbeddingsMap BBVecMap; 246 mutable InstEmbeddingsMap InstVecMap; 247 248 LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab); 249 250 /// Helper function to compute embeddings. It generates embeddings for all 251 /// the instructions and basic blocks in the function F. Logic of computing 252 /// the embeddings is specific to the kind of embeddings being computed. 253 virtual void computeEmbeddings() const = 0; 254 255 /// Helper function to compute the embedding for a given basic block. 256 /// Specific to the kind of embeddings being computed. 257 virtual void computeEmbeddings(const BasicBlock &BB) const = 0; 258 259 public: 260 virtual ~Embedder() = default; 261 262 /// Factory method to create an Embedder object. 263 LLVM_ABI static std::unique_ptr<Embedder> 264 create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab); 265 266 /// Returns a map containing instructions and the corresponding embeddings for 267 /// the function F if it has been computed. If not, it computes the embeddings 268 /// for the function and returns the map. 269 LLVM_ABI const InstEmbeddingsMap &getInstVecMap() const; 270 271 /// Returns a map containing basic block and the corresponding embeddings for 272 /// the function F if it has been computed. If not, it computes the embeddings 273 /// for the function and returns the map. 274 LLVM_ABI const BBEmbeddingsMap &getBBVecMap() const; 275 276 /// Returns the embedding for a given basic block in the function F if it has 277 /// been computed. If not, it computes the embedding for the basic block and 278 /// returns it. 279 LLVM_ABI const Embedding &getBBVector(const BasicBlock &BB) const; 280 281 /// Computes and returns the embedding for the current function. 282 LLVM_ABI const Embedding &getFunctionVector() const; 283 }; 284 285 /// Class for computing the Symbolic embeddings of IR2Vec. 286 /// Symbolic embeddings are constructed based on the entity-level 287 /// representations obtained from the Vocabulary. 288 class LLVM_ABI SymbolicEmbedder : public Embedder { 289 private: 290 void computeEmbeddings() const override; 291 void computeEmbeddings(const BasicBlock &BB) const override; 292 293 public: SymbolicEmbedder(const Function & F,const Vocabulary & Vocab)294 SymbolicEmbedder(const Function &F, const Vocabulary &Vocab) 295 : Embedder(F, Vocab) { 296 FuncVector = Embedding(Dimension, 0); 297 } 298 }; 299 300 } // namespace ir2vec 301 302 /// This analysis provides the vocabulary for IR2Vec. The vocabulary provides a 303 /// mapping between an entity of the IR (like opcode, type, argument, etc.) and 304 /// its corresponding embedding. 305 class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> { 306 using VocabVector = std::vector<ir2vec::Embedding>; 307 using VocabMap = std::map<std::string, ir2vec::Embedding>; 308 VocabMap OpcVocab, TypeVocab, ArgVocab; 309 VocabVector Vocab; 310 311 Error readVocabulary(); 312 Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue, 313 VocabMap &TargetVocab, unsigned &Dim); 314 void generateNumMappedVocab(); 315 void emitError(Error Err, LLVMContext &Ctx); 316 317 public: 318 LLVM_ABI static AnalysisKey Key; 319 IR2VecVocabAnalysis() = default; 320 LLVM_ABI explicit IR2VecVocabAnalysis(const VocabVector &Vocab); 321 LLVM_ABI explicit IR2VecVocabAnalysis(VocabVector &&Vocab); 322 using Result = ir2vec::Vocabulary; 323 LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM); 324 }; 325 326 /// This pass prints the IR2Vec embeddings for instructions, basic blocks, and 327 /// functions. 328 class IR2VecPrinterPass : public PassInfoMixin<IR2VecPrinterPass> { 329 raw_ostream &OS; 330 331 public: IR2VecPrinterPass(raw_ostream & OS)332 explicit IR2VecPrinterPass(raw_ostream &OS) : OS(OS) {} 333 LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); isRequired()334 static bool isRequired() { return true; } 335 }; 336 337 /// This pass prints the embeddings in the vocabulary 338 class IR2VecVocabPrinterPass : public PassInfoMixin<IR2VecVocabPrinterPass> { 339 raw_ostream &OS; 340 341 public: IR2VecVocabPrinterPass(raw_ostream & OS)342 explicit IR2VecVocabPrinterPass(raw_ostream &OS) : OS(OS) {} 343 LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); isRequired()344 static bool isRequired() { return true; } 345 }; 346 347 } // namespace llvm 348 349 #endif // LLVM_ANALYSIS_IR2VEC_H 350