1 //===- IR2Vec.cpp - Implementation of IR2Vec -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // Exceptions. See the LICENSE file for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements the IR2Vec algorithm.
11 ///
12 //===----------------------------------------------------------------------===//
13
14 #include "llvm/Analysis/IR2Vec.h"
15
16 #include "llvm/ADT/DepthFirstIterator.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/Statistic.h"
19 #include "llvm/IR/CFG.h"
20 #include "llvm/IR/Module.h"
21 #include "llvm/IR/PassManager.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Errc.h"
24 #include "llvm/Support/Error.h"
25 #include "llvm/Support/ErrorHandling.h"
26 #include "llvm/Support/Format.h"
27 #include "llvm/Support/MemoryBuffer.h"
28
29 using namespace llvm;
30 using namespace ir2vec;
31
32 #define DEBUG_TYPE "ir2vec"
33
34 STATISTIC(VocabMissCounter,
35 "Number of lookups to entites not present in the vocabulary");
36
37 namespace llvm {
38 namespace ir2vec {
39 static cl::OptionCategory IR2VecCategory("IR2Vec Options");
40
41 // FIXME: Use a default vocab when not specified
42 static cl::opt<std::string>
43 VocabFile("ir2vec-vocab-path", cl::Optional,
44 cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
45 cl::cat(IR2VecCategory));
46 cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional, cl::init(1.0),
47 cl::desc("Weight for opcode embeddings"),
48 cl::cat(IR2VecCategory));
49 cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
50 cl::desc("Weight for type embeddings"),
51 cl::cat(IR2VecCategory));
52 cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),
53 cl::desc("Weight for argument embeddings"),
54 cl::cat(IR2VecCategory));
55 } // namespace ir2vec
56 } // namespace llvm
57
58 AnalysisKey IR2VecVocabAnalysis::Key;
59
//===----------------------------------------------------------------------===//
61 // Local helper functions
62 //===----------------------------------------------------------------------===//
63 namespace llvm::json {
fromJSON(const llvm::json::Value & E,Embedding & Out,llvm::json::Path P)64 inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
65 llvm::json::Path P) {
66 std::vector<double> TempOut;
67 if (!llvm::json::fromJSON(E, TempOut, P))
68 return false;
69 Out = Embedding(std::move(TempOut));
70 return true;
71 }
72 } // namespace llvm::json
73
//===----------------------------------------------------------------------===//
75 // Embedding
76 //===----------------------------------------------------------------------===//
// Adds \p RHS to this embedding element by element and returns *this.
// Both embeddings must have the same dimension (checked by assertion).
Embedding &Embedding::operator+=(const Embedding &RHS) {
  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
  for (size_t Idx = 0, NumElems = this->size(); Idx != NumElems; ++Idx)
    (*this)[Idx] += RHS[Idx];
  return *this;
}
83
// Returns the element-wise sum of this embedding and \p RHS as a new
// embedding; neither operand is modified.
Embedding Embedding::operator+(const Embedding &RHS) const {
  Embedding Sum(*this);
  Sum += RHS;
  return Sum;
}
89
// Subtracts \p RHS from this embedding element by element and returns *this.
// Both embeddings must have the same dimension (checked by assertion).
Embedding &Embedding::operator-=(const Embedding &RHS) {
  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
  for (size_t Idx = 0, NumElems = this->size(); Idx != NumElems; ++Idx)
    (*this)[Idx] -= RHS[Idx];
  return *this;
}
96
// Returns the element-wise difference of this embedding and \p RHS as a new
// embedding; neither operand is modified.
Embedding Embedding::operator-(const Embedding &RHS) const {
  Embedding Difference(*this);
  Difference -= RHS;
  return Difference;
}
102
operator *=(double Factor)103 Embedding &Embedding::operator*=(double Factor) {
104 std::transform(this->begin(), this->end(), this->begin(),
105 [Factor](double Elem) { return Elem * Factor; });
106 return *this;
107 }
108
operator *(double Factor) const109 Embedding Embedding::operator*(double Factor) const {
110 Embedding Result(*this);
111 Result *= Factor;
112 return Result;
113 }
114
// Adds \p Src scaled by \p Factor to this embedding, element by element, and
// returns *this. Both embeddings must have the same dimension (checked by
// assertion).
Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
  assert(this->size() == Src.size() && "Vectors must have the same dimension");
  std::transform(this->begin(), this->end(), Src.begin(), this->begin(),
                 [Factor](double DstElem, double SrcElem) {
                   return DstElem + SrcElem * Factor;
                 });
  return *this;
}
121
approximatelyEquals(const Embedding & RHS,double Tolerance) const122 bool Embedding::approximatelyEquals(const Embedding &RHS,
123 double Tolerance) const {
124 assert(this->size() == RHS.size() && "Vectors must have the same dimension");
125 for (size_t Itr = 0; Itr < this->size(); ++Itr)
126 if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance)
127 return false;
128 return true;
129 }
130
print(raw_ostream & OS) const131 void Embedding::print(raw_ostream &OS) const {
132 OS << " [";
133 for (const auto &Elem : Data)
134 OS << " " << format("%.2f", Elem) << " ";
135 OS << "]\n";
136 }
137
//===----------------------------------------------------------------------===//
139 // Embedder and its subclasses
140 //===----------------------------------------------------------------------===//
141
Embedder(const Function & F,const Vocabulary & Vocab)142 Embedder::Embedder(const Function &F, const Vocabulary &Vocab)
143 : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
144 OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {
145 }
146
create(IR2VecKind Mode,const Function & F,const Vocabulary & Vocab)147 std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
148 const Vocabulary &Vocab) {
149 switch (Mode) {
150 case IR2VecKind::Symbolic:
151 return std::make_unique<SymbolicEmbedder>(F, Vocab);
152 }
153 return nullptr;
154 }
155
getInstVecMap() const156 const InstEmbeddingsMap &Embedder::getInstVecMap() const {
157 if (InstVecMap.empty())
158 computeEmbeddings();
159 return InstVecMap;
160 }
161
getBBVecMap() const162 const BBEmbeddingsMap &Embedder::getBBVecMap() const {
163 if (BBVecMap.empty())
164 computeEmbeddings();
165 return BBVecMap;
166 }
167
getBBVector(const BasicBlock & BB) const168 const Embedding &Embedder::getBBVector(const BasicBlock &BB) const {
169 auto It = BBVecMap.find(&BB);
170 if (It != BBVecMap.end())
171 return It->second;
172 computeEmbeddings(BB);
173 return BBVecMap[&BB];
174 }
175
getFunctionVector() const176 const Embedding &Embedder::getFunctionVector() const {
177 // Currently, we always (re)compute the embeddings for the function.
178 // This is cheaper than caching the vector.
179 computeEmbeddings();
180 return FuncVector;
181 }
182
computeEmbeddings(const BasicBlock & BB) const183 void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
184 Embedding BBVector(Dimension, 0);
185
186 // We consider only the non-debug and non-pseudo instructions
187 for (const auto &I : BB.instructionsWithoutDebug()) {
188 Embedding ArgEmb(Dimension, 0);
189 for (const auto &Op : I.operands())
190 ArgEmb += Vocab[Op];
191 auto InstVector =
192 Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
193 InstVecMap[&I] = InstVector;
194 BBVector += InstVector;
195 }
196 BBVecMap[&BB] = BBVector;
197 }
198
computeEmbeddings() const199 void SymbolicEmbedder::computeEmbeddings() const {
200 if (F.isDeclaration())
201 return;
202
203 // Consider only the basic blocks that are reachable from entry
204 for (const BasicBlock *BB : depth_first(&F)) {
205 computeEmbeddings(*BB);
206 FuncVector += BBVecMap[BB];
207 }
208 }
209
//===----------------------------------------------------------------------===//
211 // Vocabulary
212 //===----------------------------------------------------------------------===//
213
Vocabulary(VocabVector && Vocab)214 Vocabulary::Vocabulary(VocabVector &&Vocab)
215 : Vocab(std::move(Vocab)), Valid(true) {}
216
isValid() const217 bool Vocabulary::isValid() const {
218 return Vocab.size() == (MaxOpcodes + MaxTypeIDs + MaxOperandKinds) && Valid;
219 }
220
size() const221 size_t Vocabulary::size() const {
222 assert(Valid && "IR2Vec Vocabulary is invalid");
223 return Vocab.size();
224 }
225
getDimension() const226 unsigned Vocabulary::getDimension() const {
227 assert(Valid && "IR2Vec Vocabulary is invalid");
228 return Vocab[0].size();
229 }
230
operator [](unsigned Opcode) const231 const Embedding &Vocabulary::operator[](unsigned Opcode) const {
232 assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
233 return Vocab[Opcode - 1];
234 }
235
operator [](Type::TypeID TypeId) const236 const Embedding &Vocabulary::operator[](Type::TypeID TypeId) const {
237 assert(static_cast<unsigned>(TypeId) < MaxTypeIDs && "Invalid type ID");
238 return Vocab[MaxOpcodes + static_cast<unsigned>(TypeId)];
239 }
240
operator [](const Value * Arg) const241 const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
242 OperandKind ArgKind = getOperandKind(Arg);
243 return Vocab[MaxOpcodes + MaxTypeIDs + static_cast<unsigned>(ArgKind)];
244 }
245
getVocabKeyForOpcode(unsigned Opcode)246 StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
247 assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
248 #define HANDLE_INST(NUM, OPCODE, CLASS) \
249 if (Opcode == NUM) { \
250 return #OPCODE; \
251 }
252 #include "llvm/IR/Instruction.def"
253 #undef HANDLE_INST
254 return "UnknownOpcode";
255 }
256
getVocabKeyForTypeID(Type::TypeID TypeID)257 StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
258 switch (TypeID) {
259 case Type::VoidTyID:
260 return "VoidTy";
261 case Type::HalfTyID:
262 case Type::BFloatTyID:
263 case Type::FloatTyID:
264 case Type::DoubleTyID:
265 case Type::X86_FP80TyID:
266 case Type::FP128TyID:
267 case Type::PPC_FP128TyID:
268 return "FloatTy";
269 case Type::IntegerTyID:
270 return "IntegerTy";
271 case Type::FunctionTyID:
272 return "FunctionTy";
273 case Type::StructTyID:
274 return "StructTy";
275 case Type::ArrayTyID:
276 return "ArrayTy";
277 case Type::PointerTyID:
278 case Type::TypedPointerTyID:
279 return "PointerTy";
280 case Type::FixedVectorTyID:
281 case Type::ScalableVectorTyID:
282 return "VectorTy";
283 case Type::LabelTyID:
284 return "LabelTy";
285 case Type::TokenTyID:
286 return "TokenTy";
287 case Type::MetadataTyID:
288 return "MetadataTy";
289 case Type::X86_AMXTyID:
290 case Type::TargetExtTyID:
291 return "UnknownTy";
292 }
293 return "UnknownTy";
294 }
295
getVocabKeyForOperandKind(Vocabulary::OperandKind Kind)296 StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
297 unsigned Index = static_cast<unsigned>(Kind);
298 assert(Index < MaxOperandKinds && "Invalid OperandKind");
299 return OperandKindNames[Index];
300 }
301
createDummyVocabForTest(unsigned Dim)302 Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
303 VocabVector DummyVocab;
304 float DummyVal = 0.1f;
305 // Create a dummy vocabulary with entries for all opcodes, types, and
306 // operand
307 for (unsigned _ : seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxTypeIDs +
308 Vocabulary::MaxOperandKinds)) {
309 DummyVocab.push_back(Embedding(Dim, DummyVal));
310 DummyVal += 0.1;
311 }
312 return DummyVocab;
313 }
314
315 // Helper function to classify an operand into OperandKind
getOperandKind(const Value * Op)316 Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
317 if (isa<Function>(Op))
318 return OperandKind::FunctionID;
319 if (isa<PointerType>(Op->getType()))
320 return OperandKind::PointerID;
321 if (isa<Constant>(Op))
322 return OperandKind::ConstantID;
323 return OperandKind::VariableID;
324 }
325
getStringKey(unsigned Pos)326 StringRef Vocabulary::getStringKey(unsigned Pos) {
327 assert(Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds &&
328 "Position out of bounds in vocabulary");
329 // Opcode
330 if (Pos < MaxOpcodes)
331 return getVocabKeyForOpcode(Pos + 1);
332 // Type
333 if (Pos < MaxOpcodes + MaxTypeIDs)
334 return getVocabKeyForTypeID(static_cast<Type::TypeID>(Pos - MaxOpcodes));
335 // Operand
336 return getVocabKeyForOperandKind(
337 static_cast<OperandKind>(Pos - MaxOpcodes - MaxTypeIDs));
338 }
339
340 // For now, assume vocabulary is stable unless explicitly invalidated.
invalidate(Module & M,const PreservedAnalyses & PA,ModuleAnalysisManager::Invalidator & Inv) const341 bool Vocabulary::invalidate(Module &M, const PreservedAnalyses &PA,
342 ModuleAnalysisManager::Invalidator &Inv) const {
343 auto PAC = PA.getChecker<IR2VecVocabAnalysis>();
344 return !(PAC.preservedWhenStateless());
345 }
346
//===----------------------------------------------------------------------===//
348 // IR2VecVocabAnalysis
349 //===----------------------------------------------------------------------===//
350
parseVocabSection(StringRef Key,const json::Value & ParsedVocabValue,VocabMap & TargetVocab,unsigned & Dim)351 Error IR2VecVocabAnalysis::parseVocabSection(
352 StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab,
353 unsigned &Dim) {
354 json::Path::Root Path("");
355 const json::Object *RootObj = ParsedVocabValue.getAsObject();
356 if (!RootObj)
357 return createStringError(errc::invalid_argument,
358 "JSON root is not an object");
359
360 const json::Value *SectionValue = RootObj->get(Key);
361 if (!SectionValue)
362 return createStringError(errc::invalid_argument,
363 "Missing '" + std::string(Key) +
364 "' section in vocabulary file");
365 if (!json::fromJSON(*SectionValue, TargetVocab, Path))
366 return createStringError(errc::illegal_byte_sequence,
367 "Unable to parse '" + std::string(Key) +
368 "' section from vocabulary");
369
370 Dim = TargetVocab.begin()->second.size();
371 if (Dim == 0)
372 return createStringError(errc::illegal_byte_sequence,
373 "Dimension of '" + std::string(Key) +
374 "' section of the vocabulary is zero");
375
376 if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
377 [Dim](const std::pair<StringRef, Embedding> &Entry) {
378 return Entry.second.size() == Dim;
379 }))
380 return createStringError(
381 errc::illegal_byte_sequence,
382 "All vectors in the '" + std::string(Key) +
383 "' section of the vocabulary are not of the same dimension");
384
385 return Error::success();
386 }
387
388 // FIXME: Make this optional. We can avoid file reads
389 // by auto-generating a default vocabulary during the build time.
readVocabulary()390 Error IR2VecVocabAnalysis::readVocabulary() {
391 auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
392 if (!BufOrError)
393 return createFileError(VocabFile, BufOrError.getError());
394
395 auto Content = BufOrError.get()->getBuffer();
396
397 Expected<json::Value> ParsedVocabValue = json::parse(Content);
398 if (!ParsedVocabValue)
399 return ParsedVocabValue.takeError();
400
401 unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0;
402 if (auto Err =
403 parseVocabSection("Opcodes", *ParsedVocabValue, OpcVocab, OpcodeDim))
404 return Err;
405
406 if (auto Err =
407 parseVocabSection("Types", *ParsedVocabValue, TypeVocab, TypeDim))
408 return Err;
409
410 if (auto Err =
411 parseVocabSection("Arguments", *ParsedVocabValue, ArgVocab, ArgDim))
412 return Err;
413
414 if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))
415 return createStringError(errc::illegal_byte_sequence,
416 "Vocabulary sections have different dimensions");
417
418 return Error::success();
419 }
420
generateNumMappedVocab()421 void IR2VecVocabAnalysis::generateNumMappedVocab() {
422
423 // Helper for handling missing entities in the vocabulary.
424 // Currently, we use a zero vector. In the future, we will throw an error to
425 // ensure that *all* known entities are present in the vocabulary.
426 auto handleMissingEntity = [](const std::string &Val) {
427 LLVM_DEBUG(errs() << Val
428 << " is not in vocabulary, using zero vector; This "
429 "would result in an error in future.\n");
430 ++VocabMissCounter;
431 };
432
433 unsigned Dim = OpcVocab.begin()->second.size();
434 assert(Dim > 0 && "Vocabulary dimension must be greater than zero");
435
436 // Handle Opcodes
437 std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
438 Embedding(Dim, 0));
439 for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
440 StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
441 auto It = OpcVocab.find(VocabKey.str());
442 if (It != OpcVocab.end())
443 NumericOpcodeEmbeddings[Opcode] = It->second;
444 else
445 handleMissingEntity(VocabKey.str());
446 }
447 Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(),
448 NumericOpcodeEmbeddings.end());
449
450 // Handle Types
451 std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxTypeIDs,
452 Embedding(Dim, 0));
453 for (unsigned TypeID : seq(0u, Vocabulary::MaxTypeIDs)) {
454 StringRef VocabKey =
455 Vocabulary::getVocabKeyForTypeID(static_cast<Type::TypeID>(TypeID));
456 if (auto It = TypeVocab.find(VocabKey.str()); It != TypeVocab.end()) {
457 NumericTypeEmbeddings[TypeID] = It->second;
458 continue;
459 }
460 handleMissingEntity(VocabKey.str());
461 }
462 Vocab.insert(Vocab.end(), NumericTypeEmbeddings.begin(),
463 NumericTypeEmbeddings.end());
464
465 // Handle Arguments/Operands
466 std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,
467 Embedding(Dim, 0));
468 for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {
469 Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind);
470 StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind);
471 auto It = ArgVocab.find(VocabKey.str());
472 if (It != ArgVocab.end()) {
473 NumericArgEmbeddings[OpKind] = It->second;
474 continue;
475 }
476 handleMissingEntity(VocabKey.str());
477 }
478 Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(),
479 NumericArgEmbeddings.end());
480 }
481
IR2VecVocabAnalysis(const VocabVector & Vocab)482 IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab)
483 : Vocab(Vocab) {}
484
IR2VecVocabAnalysis(VocabVector && Vocab)485 IR2VecVocabAnalysis::IR2VecVocabAnalysis(VocabVector &&Vocab)
486 : Vocab(std::move(Vocab)) {}
487
emitError(Error Err,LLVMContext & Ctx)488 void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
489 handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
490 Ctx.emitError("Error reading vocabulary: " + EI.message());
491 });
492 }
493
494 IR2VecVocabAnalysis::Result
run(Module & M,ModuleAnalysisManager & AM)495 IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
496 auto Ctx = &M.getContext();
497 // If vocabulary is already populated by the constructor, use it.
498 if (!Vocab.empty())
499 return Vocabulary(std::move(Vocab));
500
501 // Otherwise, try to read from the vocabulary file.
502 if (VocabFile.empty()) {
503 // FIXME: Use default vocabulary
504 Ctx->emitError("IR2Vec vocabulary file path not specified; You may need to "
505 "set it using --ir2vec-vocab-path");
506 return Vocabulary(); // Return invalid result
507 }
508 if (auto Err = readVocabulary()) {
509 emitError(std::move(Err), *Ctx);
510 return Vocabulary();
511 }
512
513 // Scale the vocabulary sections based on the provided weights
514 auto scaleVocabSection = [](VocabMap &Vocab, double Weight) {
515 for (auto &Entry : Vocab)
516 Entry.second *= Weight;
517 };
518 scaleVocabSection(OpcVocab, OpcWeight);
519 scaleVocabSection(TypeVocab, TypeWeight);
520 scaleVocabSection(ArgVocab, ArgWeight);
521
522 // Generate the numeric lookup vocabulary
523 generateNumMappedVocab();
524
525 return Vocabulary(std::move(Vocab));
526 }
527
//===----------------------------------------------------------------------===//
529 // Printer Passes
530 //===----------------------------------------------------------------------===//
531
run(Module & M,ModuleAnalysisManager & MAM)532 PreservedAnalyses IR2VecPrinterPass::run(Module &M,
533 ModuleAnalysisManager &MAM) {
534 auto Vocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
535 assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid");
536
537 for (Function &F : M) {
538 std::unique_ptr<Embedder> Emb =
539 Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
540 if (!Emb) {
541 OS << "Error creating IR2Vec embeddings \n";
542 continue;
543 }
544
545 OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
546 OS << "Function vector: ";
547 Emb->getFunctionVector().print(OS);
548
549 OS << "Basic block vectors:\n";
550 const auto &BBMap = Emb->getBBVecMap();
551 for (const BasicBlock &BB : F) {
552 auto It = BBMap.find(&BB);
553 if (It != BBMap.end()) {
554 OS << "Basic block: " << BB.getName() << ":\n";
555 It->second.print(OS);
556 }
557 }
558
559 OS << "Instruction vectors:\n";
560 const auto &InstMap = Emb->getInstVecMap();
561 for (const BasicBlock &BB : F) {
562 for (const Instruction &I : BB) {
563 auto It = InstMap.find(&I);
564 if (It != InstMap.end()) {
565 OS << "Instruction: ";
566 I.print(OS);
567 It->second.print(OS);
568 }
569 }
570 }
571 }
572 return PreservedAnalyses::all();
573 }
574
run(Module & M,ModuleAnalysisManager & MAM)575 PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,
576 ModuleAnalysisManager &MAM) {
577 auto IR2VecVocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
578 assert(IR2VecVocabulary.isValid() && "IR2Vec Vocabulary is invalid");
579
580 // Print each entry
581 unsigned Pos = 0;
582 for (const auto &Entry : IR2VecVocabulary) {
583 OS << "Key: " << IR2VecVocabulary.getStringKey(Pos++) << ": ";
584 Entry.print(OS);
585 }
586 return PreservedAnalyses::all();
587 }
588