xref: /freebsd/contrib/llvm-project/llvm/include/llvm/Analysis/IR2Vec.h (revision 700637cbb5e582861067a11aaca4d053546871d2)
//===- IR2Vec.h - Implementation of IR2Vec ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See the LICENSE file for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file defines the IR2Vec vocabulary analysis (IR2VecVocabAnalysis),
/// the core ir2vec::Embedder interface for generating IR embeddings,
/// and related utilities like the IR2VecPrinterPass.
///
/// Program embeddings are typically learned representations of a program, or
/// are derived from such learned representations. These embeddings are used to
/// represent programs as input to machine learning algorithms. IR2Vec
/// represents the LLVM IR as embeddings.
///
/// The IR2Vec algorithm is described in the following paper:
///
///   IR2Vec: LLVM IR Based Scalable Program Embeddings, S. VenkataKeerthy,
///   Rohit Aggarwal, Shalini Jain, Maunendra Sankar Desarkar, Ramakrishna
///   Upadrasta, and Y. N. Srikant, ACM Transactions on Architecture and
///   Code Optimization (TACO), 2020. https://doi.org/10.1145/3418463.
///   https://arxiv.org/abs/1909.06228
///
//===----------------------------------------------------------------------===//
28 
29 #ifndef LLVM_ANALYSIS_IR2VEC_H
30 #define LLVM_ANALYSIS_IR2VEC_H
31 
#include "llvm/ADT/DenseMap.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/JSON.h"
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
40 
41 namespace llvm {
42 
// Forward declarations: this header only refers to these types through
// pointers and references, so the full definitions are not required here.
class Module;
class BasicBlock;
class Instruction;
class Function;
class Value;
class raw_ostream;
class LLVMContext;
class IR2VecVocabAnalysis;
51 
/// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
/// Symbolic embeddings capture the "syntactic" and "statistical correlation"
/// of the IR entities. Flow-aware embeddings build on top of symbolic
/// embeddings and additionally capture the flow information in the IR.
/// IR2VecKind is used to specify the type of embeddings to generate.
/// Currently, only Symbolic embeddings are supported, so Symbolic is the sole
/// enumerator.
enum class IR2VecKind { Symbolic };
59 
60 namespace ir2vec {
61 
62 LLVM_ABI extern cl::opt<float> OpcWeight;
63 LLVM_ABI extern cl::opt<float> TypeWeight;
64 LLVM_ABI extern cl::opt<float> ArgWeight;
65 
66 /// Embedding is a datatype that wraps std::vector<double>. It provides
67 /// additional functionality for arithmetic and comparison operations.
68 /// It is meant to be used *like* std::vector<double> but is more restrictive
69 /// in the sense that it does not allow the user to change the size of the
70 /// embedding vector. The dimension of the embedding is fixed at the time of
71 /// construction of Embedding object. But the elements can be modified in-place.
72 struct Embedding {
73 private:
74   std::vector<double> Data;
75 
76 public:
77   Embedding() = default;
EmbeddingEmbedding78   Embedding(const std::vector<double> &V) : Data(V) {}
EmbeddingEmbedding79   Embedding(std::vector<double> &&V) : Data(std::move(V)) {}
EmbeddingEmbedding80   Embedding(std::initializer_list<double> IL) : Data(IL) {}
81 
EmbeddingEmbedding82   explicit Embedding(size_t Size) : Data(Size) {}
EmbeddingEmbedding83   Embedding(size_t Size, double InitialValue) : Data(Size, InitialValue) {}
84 
sizeEmbedding85   size_t size() const { return Data.size(); }
emptyEmbedding86   bool empty() const { return Data.empty(); }
87 
88   double &operator[](size_t Itr) {
89     assert(Itr < Data.size() && "Index out of bounds");
90     return Data[Itr];
91   }
92 
93   const double &operator[](size_t Itr) const {
94     assert(Itr < Data.size() && "Index out of bounds");
95     return Data[Itr];
96   }
97 
98   using iterator = typename std::vector<double>::iterator;
99   using const_iterator = typename std::vector<double>::const_iterator;
100 
beginEmbedding101   iterator begin() { return Data.begin(); }
endEmbedding102   iterator end() { return Data.end(); }
beginEmbedding103   const_iterator begin() const { return Data.begin(); }
endEmbedding104   const_iterator end() const { return Data.end(); }
cbeginEmbedding105   const_iterator cbegin() const { return Data.cbegin(); }
cendEmbedding106   const_iterator cend() const { return Data.cend(); }
107 
getDataEmbedding108   const std::vector<double> &getData() const { return Data; }
109 
110   /// Arithmetic operators
111   LLVM_ABI Embedding &operator+=(const Embedding &RHS);
112   LLVM_ABI Embedding operator+(const Embedding &RHS) const;
113   LLVM_ABI Embedding &operator-=(const Embedding &RHS);
114   LLVM_ABI Embedding operator-(const Embedding &RHS) const;
115   LLVM_ABI Embedding &operator*=(double Factor);
116   LLVM_ABI Embedding operator*(double Factor) const;
117 
118   /// Adds Src Embedding scaled by Factor with the called Embedding.
119   /// Called_Embedding += Src * Factor
120   LLVM_ABI Embedding &scaleAndAdd(const Embedding &Src, float Factor);
121 
122   /// Returns true if the embedding is approximately equal to the RHS embedding
123   /// within the specified tolerance.
124   LLVM_ABI bool approximatelyEquals(const Embedding &RHS,
125                                     double Tolerance = 1e-4) const;
126 
127   LLVM_ABI void print(raw_ostream &OS) const;
128 };
129 
130 using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
131 using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
132 
133 /// Class for storing and accessing the IR2Vec vocabulary.
134 /// Encapsulates all vocabulary-related constants, logic, and access methods.
135 class Vocabulary {
136   friend class llvm::IR2VecVocabAnalysis;
137   using VocabVector = std::vector<ir2vec::Embedding>;
138   VocabVector Vocab;
139   bool Valid = false;
140 
141   /// Operand kinds supported by IR2Vec Vocabulary
142   enum class OperandKind : unsigned {
143     FunctionID,
144     PointerID,
145     ConstantID,
146     VariableID,
147     MaxOperandKind
148   };
149   /// String mappings for OperandKind values
150   static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer",
151                                                        "Constant", "Variable"};
152   static_assert(std::size(OperandKindNames) ==
153                     static_cast<unsigned>(OperandKind::MaxOperandKind),
154                 "OperandKindNames array size must match MaxOperandKind");
155 
156   /// Vocabulary layout constants
157 #define LAST_OTHER_INST(NUM) static constexpr unsigned MaxOpcodes = NUM;
158 #include "llvm/IR/Instruction.def"
159 #undef LAST_OTHER_INST
160 
161   static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1;
162   static constexpr unsigned MaxOperandKinds =
163       static_cast<unsigned>(OperandKind::MaxOperandKind);
164 
165 public:
166   Vocabulary() = default;
167   Vocabulary(VocabVector &&Vocab);
168 
169   bool isValid() const;
170   unsigned getDimension() const;
171   size_t size() const;
172 
173   /// Helper function to get vocabulary key for a given Opcode
174   static StringRef getVocabKeyForOpcode(unsigned Opcode);
175 
176   /// Helper function to get vocabulary key for a given TypeID
177   static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
178 
179   /// Helper function to get vocabulary key for a given OperandKind
180   static StringRef getVocabKeyForOperandKind(OperandKind Kind);
181 
182   /// Helper function to classify an operand into OperandKind
183   static OperandKind getOperandKind(const Value *Op);
184 
185   /// Accessors to get the embedding for a given entity.
186   const ir2vec::Embedding &operator[](unsigned Opcode) const;
187   const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
188   const ir2vec::Embedding &operator[](const Value *Arg) const;
189 
190   /// Const Iterator type aliases
191   using const_iterator = VocabVector::const_iterator;
begin()192   const_iterator begin() const {
193     assert(Valid && "IR2Vec Vocabulary is invalid");
194     return Vocab.begin();
195   }
196 
cbegin()197   const_iterator cbegin() const {
198     assert(Valid && "IR2Vec Vocabulary is invalid");
199     return Vocab.cbegin();
200   }
201 
end()202   const_iterator end() const {
203     assert(Valid && "IR2Vec Vocabulary is invalid");
204     return Vocab.end();
205   }
206 
cend()207   const_iterator cend() const {
208     assert(Valid && "IR2Vec Vocabulary is invalid");
209     return Vocab.cend();
210   }
211 
212   /// Returns the string key for a given index position in the vocabulary.
213   /// This is useful for debugging or printing the vocabulary. Do not use this
214   /// for embedding generation as string based lookups are inefficient.
215   static StringRef getStringKey(unsigned Pos);
216 
217   /// Create a dummy vocabulary for testing purposes.
218   static VocabVector createDummyVocabForTest(unsigned Dim = 1);
219 
220   bool invalidate(Module &M, const PreservedAnalyses &PA,
221                   ModuleAnalysisManager::Invalidator &Inv) const;
222 };
223 
224 /// Embedder provides the interface to generate embeddings (vector
225 /// representations) for instructions, basic blocks, and functions. The
226 /// vector representations are generated using IR2Vec algorithms.
227 ///
228 /// The Embedder class is an abstract class and it is intended to be
229 /// subclassed for different IR2Vec algorithms like Symbolic and Flow-aware.
230 class Embedder {
231 protected:
232   const Function &F;
233   const Vocabulary &Vocab;
234 
235   /// Dimension of the vector representation; captured from the input vocabulary
236   const unsigned Dimension;
237 
238   /// Weights for different entities (like opcode, arguments, types)
239   /// in the IR instructions to generate the vector representation.
240   const float OpcWeight, TypeWeight, ArgWeight;
241 
242   // Utility maps - these are used to store the vector representations of
243   // instructions, basic blocks and functions.
244   mutable Embedding FuncVector;
245   mutable BBEmbeddingsMap BBVecMap;
246   mutable InstEmbeddingsMap InstVecMap;
247 
248   LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab);
249 
250   /// Helper function to compute embeddings. It generates embeddings for all
251   /// the instructions and basic blocks in the function F. Logic of computing
252   /// the embeddings is specific to the kind of embeddings being computed.
253   virtual void computeEmbeddings() const = 0;
254 
255   /// Helper function to compute the embedding for a given basic block.
256   /// Specific to the kind of embeddings being computed.
257   virtual void computeEmbeddings(const BasicBlock &BB) const = 0;
258 
259 public:
260   virtual ~Embedder() = default;
261 
262   /// Factory method to create an Embedder object.
263   LLVM_ABI static std::unique_ptr<Embedder>
264   create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab);
265 
266   /// Returns a map containing instructions and the corresponding embeddings for
267   /// the function F if it has been computed. If not, it computes the embeddings
268   /// for the function and returns the map.
269   LLVM_ABI const InstEmbeddingsMap &getInstVecMap() const;
270 
271   /// Returns a map containing basic block and the corresponding embeddings for
272   /// the function F if it has been computed. If not, it computes the embeddings
273   /// for the function and returns the map.
274   LLVM_ABI const BBEmbeddingsMap &getBBVecMap() const;
275 
276   /// Returns the embedding for a given basic block in the function F if it has
277   /// been computed. If not, it computes the embedding for the basic block and
278   /// returns it.
279   LLVM_ABI const Embedding &getBBVector(const BasicBlock &BB) const;
280 
281   /// Computes and returns the embedding for the current function.
282   LLVM_ABI const Embedding &getFunctionVector() const;
283 };
284 
285 /// Class for computing the Symbolic embeddings of IR2Vec.
286 /// Symbolic embeddings are constructed based on the entity-level
287 /// representations obtained from the Vocabulary.
288 class LLVM_ABI SymbolicEmbedder : public Embedder {
289 private:
290   void computeEmbeddings() const override;
291   void computeEmbeddings(const BasicBlock &BB) const override;
292 
293 public:
SymbolicEmbedder(const Function & F,const Vocabulary & Vocab)294   SymbolicEmbedder(const Function &F, const Vocabulary &Vocab)
295       : Embedder(F, Vocab) {
296     FuncVector = Embedding(Dimension, 0);
297   }
298 };
299 
300 } // namespace ir2vec
301 
302 /// This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
303 /// mapping between an entity of the IR (like opcode, type, argument, etc.) and
304 /// its corresponding embedding.
305 class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
306   using VocabVector = std::vector<ir2vec::Embedding>;
307   using VocabMap = std::map<std::string, ir2vec::Embedding>;
308   VocabMap OpcVocab, TypeVocab, ArgVocab;
309   VocabVector Vocab;
310 
311   Error readVocabulary();
312   Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue,
313                           VocabMap &TargetVocab, unsigned &Dim);
314   void generateNumMappedVocab();
315   void emitError(Error Err, LLVMContext &Ctx);
316 
317 public:
318   LLVM_ABI static AnalysisKey Key;
319   IR2VecVocabAnalysis() = default;
320   LLVM_ABI explicit IR2VecVocabAnalysis(const VocabVector &Vocab);
321   LLVM_ABI explicit IR2VecVocabAnalysis(VocabVector &&Vocab);
322   using Result = ir2vec::Vocabulary;
323   LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM);
324 };
325 
326 /// This pass prints the IR2Vec embeddings for instructions, basic blocks, and
327 /// functions.
328 class IR2VecPrinterPass : public PassInfoMixin<IR2VecPrinterPass> {
329   raw_ostream &OS;
330 
331 public:
IR2VecPrinterPass(raw_ostream & OS)332   explicit IR2VecPrinterPass(raw_ostream &OS) : OS(OS) {}
333   LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
isRequired()334   static bool isRequired() { return true; }
335 };
336 
337 /// This pass prints the embeddings in the vocabulary
338 class IR2VecVocabPrinterPass : public PassInfoMixin<IR2VecVocabPrinterPass> {
339   raw_ostream &OS;
340 
341 public:
IR2VecVocabPrinterPass(raw_ostream & OS)342   explicit IR2VecVocabPrinterPass(raw_ostream &OS) : OS(OS) {}
343   LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
isRequired()344   static bool isRequired() { return true; }
345 };
346 
347 } // namespace llvm
348 
349 #endif // LLVM_ANALYSIS_IR2VEC_H
350