xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/IR2Vec.cpp (revision 1342eb5a832fa10e689a29faab3acb6054e4778c)
1 //===- IR2Vec.cpp - Implementation of IR2Vec -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // Exceptions. See the LICENSE file for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements the IR2Vec algorithm.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Analysis/IR2Vec.h"
15 
16 #include "llvm/ADT/DepthFirstIterator.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/Statistic.h"
19 #include "llvm/IR/CFG.h"
20 #include "llvm/IR/Module.h"
21 #include "llvm/IR/PassManager.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Errc.h"
24 #include "llvm/Support/Error.h"
25 #include "llvm/Support/ErrorHandling.h"
26 #include "llvm/Support/Format.h"
27 #include "llvm/Support/MemoryBuffer.h"
28 
29 using namespace llvm;
30 using namespace ir2vec;
31 
32 #define DEBUG_TYPE "ir2vec"
33 
34 STATISTIC(VocabMissCounter,
35           "Number of lookups to entites not present in the vocabulary");
36 
namespace llvm {
namespace ir2vec {
// Category under which all IR2Vec command-line options are grouped.
static cl::OptionCategory IR2VecCategory("IR2Vec Options");

// Path of the vocabulary file. The option is optional and defaults to the
// empty string.
// FIXME: Use a default vocab when not specified
static cl::opt<std::string>
    VocabFile("ir2vec-vocab-path", cl::Optional,
              cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
              cl::cat(IR2VecCategory));
// Weights for the opcode, type and argument embeddings (defaults 1.0, 0.5 and
// 0.2). Unlike the options above, these are not declared `static`.
cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional, cl::init(1.0),
                         cl::desc("Weight for opcode embeddings"),
                         cl::cat(IR2VecCategory));
cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
                          cl::desc("Weight for type embeddings"),
                          cl::cat(IR2VecCategory));
cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),
                         cl::desc("Weight for argument embeddings"),
                         cl::cat(IR2VecCategory));
} // namespace ir2vec
} // namespace llvm
57 
// Out-of-line definition of the analysis key for IR2VecVocabAnalysis.
AnalysisKey IR2VecVocabAnalysis::Key;
59 
//===----------------------------------------------------------------------===//
61 // Local helper functions
62 //===----------------------------------------------------------------------===//
63 namespace llvm::json {
64 inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
65                      llvm::json::Path P) {
66   std::vector<double> TempOut;
67   if (!llvm::json::fromJSON(E, TempOut, P))
68     return false;
69   Out = Embedding(std::move(TempOut));
70   return true;
71 }
72 } // namespace llvm::json
73 
//===----------------------------------------------------------------------===//
75 // Embedding
76 //===----------------------------------------------------------------------===//
77 Embedding &Embedding::operator+=(const Embedding &RHS) {
78   assert(this->size() == RHS.size() && "Vectors must have the same dimension");
79   std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
80                  std::plus<double>());
81   return *this;
82 }
83 
84 Embedding Embedding::operator+(const Embedding &RHS) const {
85   Embedding Result(*this);
86   Result += RHS;
87   return Result;
88 }
89 
90 Embedding &Embedding::operator-=(const Embedding &RHS) {
91   assert(this->size() == RHS.size() && "Vectors must have the same dimension");
92   std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
93                  std::minus<double>());
94   return *this;
95 }
96 
97 Embedding Embedding::operator-(const Embedding &RHS) const {
98   Embedding Result(*this);
99   Result -= RHS;
100   return Result;
101 }
102 
103 Embedding &Embedding::operator*=(double Factor) {
104   std::transform(this->begin(), this->end(), this->begin(),
105                  [Factor](double Elem) { return Elem * Factor; });
106   return *this;
107 }
108 
109 Embedding Embedding::operator*(double Factor) const {
110   Embedding Result(*this);
111   Result *= Factor;
112   return Result;
113 }
114 
115 Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
116   assert(this->size() == Src.size() && "Vectors must have the same dimension");
117   for (size_t Itr = 0; Itr < this->size(); ++Itr)
118     (*this)[Itr] += Src[Itr] * Factor;
119   return *this;
120 }
121 
122 bool Embedding::approximatelyEquals(const Embedding &RHS,
123                                     double Tolerance) const {
124   assert(this->size() == RHS.size() && "Vectors must have the same dimension");
125   for (size_t Itr = 0; Itr < this->size(); ++Itr)
126     if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance)
127       return false;
128   return true;
129 }
130 
131 void Embedding::print(raw_ostream &OS) const {
132   OS << " [";
133   for (const auto &Elem : Data)
134     OS << " " << format("%.2f", Elem) << " ";
135   OS << "]\n";
136 }
137 
//===----------------------------------------------------------------------===//
139 // Embedder and its subclasses
140 //===----------------------------------------------------------------------===//
141 
// Binds the embedder to function \p F and vocabulary \p Vocab, caches the
// vocabulary's dimension, and captures the current values of the global
// opcode/type/argument weight options.
Embedder::Embedder(const Function &F, const Vocabulary &Vocab)
    : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
      OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {
}
146 
147 std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
148                                            const Vocabulary &Vocab) {
149   switch (Mode) {
150   case IR2VecKind::Symbolic:
151     return std::make_unique<SymbolicEmbedder>(F, Vocab);
152   }
153   return nullptr;
154 }
155 
156 const InstEmbeddingsMap &Embedder::getInstVecMap() const {
157   if (InstVecMap.empty())
158     computeEmbeddings();
159   return InstVecMap;
160 }
161 
162 const BBEmbeddingsMap &Embedder::getBBVecMap() const {
163   if (BBVecMap.empty())
164     computeEmbeddings();
165   return BBVecMap;
166 }
167 
168 const Embedding &Embedder::getBBVector(const BasicBlock &BB) const {
169   auto It = BBVecMap.find(&BB);
170   if (It != BBVecMap.end())
171     return It->second;
172   computeEmbeddings(BB);
173   return BBVecMap[&BB];
174 }
175 
// Returns the embedding of the whole function. Every call invokes
// computeEmbeddings() before returning FuncVector; nothing is cached here.
const Embedding &Embedder::getFunctionVector() const {
  // Currently, we always (re)compute the embeddings for the function.
  // This is cheaper than caching the vector.
  computeEmbeddings();
  return FuncVector;
}
182 
183 void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
184   Embedding BBVector(Dimension, 0);
185 
186   // We consider only the non-debug and non-pseudo instructions
187   for (const auto &I : BB.instructionsWithoutDebug()) {
188     Embedding ArgEmb(Dimension, 0);
189     for (const auto &Op : I.operands())
190       ArgEmb += Vocab[Op];
191     auto InstVector =
192         Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
193     InstVecMap[&I] = InstVector;
194     BBVector += InstVector;
195   }
196   BBVecMap[&BB] = BBVector;
197 }
198 
199 void SymbolicEmbedder::computeEmbeddings() const {
200   if (F.isDeclaration())
201     return;
202 
203   // Consider only the basic blocks that are reachable from entry
204   for (const BasicBlock *BB : depth_first(&F)) {
205     computeEmbeddings(*BB);
206     FuncVector += BBVecMap[BB];
207   }
208 }
209 
//===----------------------------------------------------------------------===//
211 // Vocabulary
212 //===----------------------------------------------------------------------===//
213 
// Takes ownership of \p Vocab by moving it into the object and marks the
// vocabulary as valid.
Vocabulary::Vocabulary(VocabVector &&Vocab)
    : Vocab(std::move(Vocab)), Valid(true) {}
216 
217 bool Vocabulary::isValid() const {
218   return Vocab.size() == (MaxOpcodes + MaxTypeIDs + MaxOperandKinds) && Valid;
219 }
220 
// Returns the number of entries in the vocabulary. Asserts that the
// vocabulary has been marked valid.
size_t Vocabulary::size() const {
  assert(Valid && "IR2Vec Vocabulary is invalid");
  return Vocab.size();
}
225 
226 unsigned Vocabulary::getDimension() const {
227   assert(Valid && "IR2Vec Vocabulary is invalid");
228   return Vocab[0].size();
229 }
230 
231 const Embedding &Vocabulary::operator[](unsigned Opcode) const {
232   assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
233   return Vocab[Opcode - 1];
234 }
235 
236 const Embedding &Vocabulary::operator[](Type::TypeID TypeId) const {
237   assert(static_cast<unsigned>(TypeId) < MaxTypeIDs && "Invalid type ID");
238   return Vocab[MaxOpcodes + static_cast<unsigned>(TypeId)];
239 }
240 
241 const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
242   OperandKind ArgKind = getOperandKind(Arg);
243   return Vocab[MaxOpcodes + MaxTypeIDs + static_cast<unsigned>(ArgKind)];
244 }
245 
246 StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
247   assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
248 #define HANDLE_INST(NUM, OPCODE, CLASS)                                        \
249   if (Opcode == NUM) {                                                         \
250     return #OPCODE;                                                            \
251   }
252 #include "llvm/IR/Instruction.def"
253 #undef HANDLE_INST
254   return "UnknownOpcode";
255 }
256 
257 StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
258   switch (TypeID) {
259   case Type::VoidTyID:
260     return "VoidTy";
261   case Type::HalfTyID:
262   case Type::BFloatTyID:
263   case Type::FloatTyID:
264   case Type::DoubleTyID:
265   case Type::X86_FP80TyID:
266   case Type::FP128TyID:
267   case Type::PPC_FP128TyID:
268     return "FloatTy";
269   case Type::IntegerTyID:
270     return "IntegerTy";
271   case Type::FunctionTyID:
272     return "FunctionTy";
273   case Type::StructTyID:
274     return "StructTy";
275   case Type::ArrayTyID:
276     return "ArrayTy";
277   case Type::PointerTyID:
278   case Type::TypedPointerTyID:
279     return "PointerTy";
280   case Type::FixedVectorTyID:
281   case Type::ScalableVectorTyID:
282     return "VectorTy";
283   case Type::LabelTyID:
284     return "LabelTy";
285   case Type::TokenTyID:
286     return "TokenTy";
287   case Type::MetadataTyID:
288     return "MetadataTy";
289   case Type::X86_AMXTyID:
290   case Type::TargetExtTyID:
291     return "UnknownTy";
292   }
293   return "UnknownTy";
294 }
295 
296 StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
297   unsigned Index = static_cast<unsigned>(Kind);
298   assert(Index < MaxOperandKinds && "Invalid OperandKind");
299   return OperandKindNames[Index];
300 }
301 
302 Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
303   VocabVector DummyVocab;
304   float DummyVal = 0.1f;
305   // Create a dummy vocabulary with entries for all opcodes, types, and
306   // operand
307   for (unsigned _ : seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxTypeIDs +
308                                 Vocabulary::MaxOperandKinds)) {
309     DummyVocab.push_back(Embedding(Dim, DummyVal));
310     DummyVal += 0.1;
311   }
312   return DummyVocab;
313 }
314 
315 // Helper function to classify an operand into OperandKind
316 Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
317   if (isa<Function>(Op))
318     return OperandKind::FunctionID;
319   if (isa<PointerType>(Op->getType()))
320     return OperandKind::PointerID;
321   if (isa<Constant>(Op))
322     return OperandKind::ConstantID;
323   return OperandKind::VariableID;
324 }
325 
326 StringRef Vocabulary::getStringKey(unsigned Pos) {
327   assert(Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds &&
328          "Position out of bounds in vocabulary");
329   // Opcode
330   if (Pos < MaxOpcodes)
331     return getVocabKeyForOpcode(Pos + 1);
332   // Type
333   if (Pos < MaxOpcodes + MaxTypeIDs)
334     return getVocabKeyForTypeID(static_cast<Type::TypeID>(Pos - MaxOpcodes));
335   // Operand
336   return getVocabKeyForOperandKind(
337       static_cast<OperandKind>(Pos - MaxOpcodes - MaxTypeIDs));
338 }
339 
340 // For now, assume vocabulary is stable unless explicitly invalidated.
341 bool Vocabulary::invalidate(Module &M, const PreservedAnalyses &PA,
342                             ModuleAnalysisManager::Invalidator &Inv) const {
343   auto PAC = PA.getChecker<IR2VecVocabAnalysis>();
344   return !(PAC.preservedWhenStateless());
345 }
346 
//===----------------------------------------------------------------------===//
348 // IR2VecVocabAnalysis
349 //===----------------------------------------------------------------------===//
350 
351 Error IR2VecVocabAnalysis::parseVocabSection(
352     StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab,
353     unsigned &Dim) {
354   json::Path::Root Path("");
355   const json::Object *RootObj = ParsedVocabValue.getAsObject();
356   if (!RootObj)
357     return createStringError(errc::invalid_argument,
358                              "JSON root is not an object");
359 
360   const json::Value *SectionValue = RootObj->get(Key);
361   if (!SectionValue)
362     return createStringError(errc::invalid_argument,
363                              "Missing '" + std::string(Key) +
364                                  "' section in vocabulary file");
365   if (!json::fromJSON(*SectionValue, TargetVocab, Path))
366     return createStringError(errc::illegal_byte_sequence,
367                              "Unable to parse '" + std::string(Key) +
368                                  "' section from vocabulary");
369 
370   Dim = TargetVocab.begin()->second.size();
371   if (Dim == 0)
372     return createStringError(errc::illegal_byte_sequence,
373                              "Dimension of '" + std::string(Key) +
374                                  "' section of the vocabulary is zero");
375 
376   if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
377                    [Dim](const std::pair<StringRef, Embedding> &Entry) {
378                      return Entry.second.size() == Dim;
379                    }))
380     return createStringError(
381         errc::illegal_byte_sequence,
382         "All vectors in the '" + std::string(Key) +
383             "' section of the vocabulary are not of the same dimension");
384 
385   return Error::success();
386 }
387 
388 // FIXME: Make this optional. We can avoid file reads
389 // by auto-generating a default vocabulary during the build time.
390 Error IR2VecVocabAnalysis::readVocabulary() {
391   auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
392   if (!BufOrError)
393     return createFileError(VocabFile, BufOrError.getError());
394 
395   auto Content = BufOrError.get()->getBuffer();
396 
397   Expected<json::Value> ParsedVocabValue = json::parse(Content);
398   if (!ParsedVocabValue)
399     return ParsedVocabValue.takeError();
400 
401   unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0;
402   if (auto Err =
403           parseVocabSection("Opcodes", *ParsedVocabValue, OpcVocab, OpcodeDim))
404     return Err;
405 
406   if (auto Err =
407           parseVocabSection("Types", *ParsedVocabValue, TypeVocab, TypeDim))
408     return Err;
409 
410   if (auto Err =
411           parseVocabSection("Arguments", *ParsedVocabValue, ArgVocab, ArgDim))
412     return Err;
413 
414   if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))
415     return createStringError(errc::illegal_byte_sequence,
416                              "Vocabulary sections have different dimensions");
417 
418   return Error::success();
419 }
420 
421 void IR2VecVocabAnalysis::generateNumMappedVocab() {
422 
423   // Helper for handling missing entities in the vocabulary.
424   // Currently, we use a zero vector. In the future, we will throw an error to
425   // ensure that *all* known entities are present in the vocabulary.
426   auto handleMissingEntity = [](const std::string &Val) {
427     LLVM_DEBUG(errs() << Val
428                       << " is not in vocabulary, using zero vector; This "
429                          "would result in an error in future.\n");
430     ++VocabMissCounter;
431   };
432 
433   unsigned Dim = OpcVocab.begin()->second.size();
434   assert(Dim > 0 && "Vocabulary dimension must be greater than zero");
435 
436   // Handle Opcodes
437   std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
438                                                  Embedding(Dim, 0));
439   for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
440     StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
441     auto It = OpcVocab.find(VocabKey.str());
442     if (It != OpcVocab.end())
443       NumericOpcodeEmbeddings[Opcode] = It->second;
444     else
445       handleMissingEntity(VocabKey.str());
446   }
447   Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(),
448                NumericOpcodeEmbeddings.end());
449 
450   // Handle Types
451   std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxTypeIDs,
452                                                Embedding(Dim, 0));
453   for (unsigned TypeID : seq(0u, Vocabulary::MaxTypeIDs)) {
454     StringRef VocabKey =
455         Vocabulary::getVocabKeyForTypeID(static_cast<Type::TypeID>(TypeID));
456     if (auto It = TypeVocab.find(VocabKey.str()); It != TypeVocab.end()) {
457       NumericTypeEmbeddings[TypeID] = It->second;
458       continue;
459     }
460     handleMissingEntity(VocabKey.str());
461   }
462   Vocab.insert(Vocab.end(), NumericTypeEmbeddings.begin(),
463                NumericTypeEmbeddings.end());
464 
465   // Handle Arguments/Operands
466   std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,
467                                               Embedding(Dim, 0));
468   for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {
469     Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind);
470     StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind);
471     auto It = ArgVocab.find(VocabKey.str());
472     if (It != ArgVocab.end()) {
473       NumericArgEmbeddings[OpKind] = It->second;
474       continue;
475     }
476     handleMissingEntity(VocabKey.str());
477   }
478   Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(),
479                NumericArgEmbeddings.end());
480 }
481 
// Constructs the analysis with a pre-populated vocabulary, copying \p Vocab.
IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab)
    : Vocab(Vocab) {}
484 
// Constructs the analysis with a pre-populated vocabulary, moving from
// \p Vocab.
IR2VecVocabAnalysis::IR2VecVocabAnalysis(VocabVector &&Vocab)
    : Vocab(std::move(Vocab)) {}
487 
488 void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
489   handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
490     Ctx.emitError("Error reading vocabulary: " + EI.message());
491   });
492 }
493 
494 IR2VecVocabAnalysis::Result
495 IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
496   auto Ctx = &M.getContext();
497   // If vocabulary is already populated by the constructor, use it.
498   if (!Vocab.empty())
499     return Vocabulary(std::move(Vocab));
500 
501   // Otherwise, try to read from the vocabulary file.
502   if (VocabFile.empty()) {
503     // FIXME: Use default vocabulary
504     Ctx->emitError("IR2Vec vocabulary file path not specified; You may need to "
505                    "set it using --ir2vec-vocab-path");
506     return Vocabulary(); // Return invalid result
507   }
508   if (auto Err = readVocabulary()) {
509     emitError(std::move(Err), *Ctx);
510     return Vocabulary();
511   }
512 
513   // Scale the vocabulary sections based on the provided weights
514   auto scaleVocabSection = [](VocabMap &Vocab, double Weight) {
515     for (auto &Entry : Vocab)
516       Entry.second *= Weight;
517   };
518   scaleVocabSection(OpcVocab, OpcWeight);
519   scaleVocabSection(TypeVocab, TypeWeight);
520   scaleVocabSection(ArgVocab, ArgWeight);
521 
522   // Generate the numeric lookup vocabulary
523   generateNumMappedVocab();
524 
525   return Vocabulary(std::move(Vocab));
526 }
527 
//===----------------------------------------------------------------------===//
529 // Printer Passes
530 //===----------------------------------------------------------------------===//
531 
532 PreservedAnalyses IR2VecPrinterPass::run(Module &M,
533                                          ModuleAnalysisManager &MAM) {
534   auto Vocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
535   assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid");
536 
537   for (Function &F : M) {
538     std::unique_ptr<Embedder> Emb =
539         Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
540     if (!Emb) {
541       OS << "Error creating IR2Vec embeddings \n";
542       continue;
543     }
544 
545     OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
546     OS << "Function vector: ";
547     Emb->getFunctionVector().print(OS);
548 
549     OS << "Basic block vectors:\n";
550     const auto &BBMap = Emb->getBBVecMap();
551     for (const BasicBlock &BB : F) {
552       auto It = BBMap.find(&BB);
553       if (It != BBMap.end()) {
554         OS << "Basic block: " << BB.getName() << ":\n";
555         It->second.print(OS);
556       }
557     }
558 
559     OS << "Instruction vectors:\n";
560     const auto &InstMap = Emb->getInstVecMap();
561     for (const BasicBlock &BB : F) {
562       for (const Instruction &I : BB) {
563         auto It = InstMap.find(&I);
564         if (It != InstMap.end()) {
565           OS << "Instruction: ";
566           I.print(OS);
567           It->second.print(OS);
568         }
569       }
570     }
571   }
572   return PreservedAnalyses::all();
573 }
574 
575 PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,
576                                               ModuleAnalysisManager &MAM) {
577   auto IR2VecVocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
578   assert(IR2VecVocabulary.isValid() && "IR2Vec Vocabulary is invalid");
579 
580   // Print each entry
581   unsigned Pos = 0;
582   for (const auto &Entry : IR2VecVocabulary) {
583     OS << "Key: " << IR2VecVocabulary.getStringKey(Pos++) << ": ";
584     Entry.print(OS);
585   }
586   return PreservedAnalyses::all();
587 }
588