1 //===- IR2Vec.cpp - Implementation of IR2Vec -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM 4 // Exceptions. See the LICENSE file for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// This file implements the IR2Vec algorithm. 11 /// 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/Analysis/IR2Vec.h" 15 16 #include "llvm/ADT/DepthFirstIterator.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Statistic.h" 19 #include "llvm/IR/CFG.h" 20 #include "llvm/IR/Module.h" 21 #include "llvm/IR/PassManager.h" 22 #include "llvm/Support/Debug.h" 23 #include "llvm/Support/Errc.h" 24 #include "llvm/Support/Error.h" 25 #include "llvm/Support/ErrorHandling.h" 26 #include "llvm/Support/Format.h" 27 #include "llvm/Support/MemoryBuffer.h" 28 29 using namespace llvm; 30 using namespace ir2vec; 31 32 #define DEBUG_TYPE "ir2vec" 33 34 STATISTIC(VocabMissCounter, 35 "Number of lookups to entites not present in the vocabulary"); 36 37 namespace llvm { 38 namespace ir2vec { 39 static cl::OptionCategory IR2VecCategory("IR2Vec Options"); 40 41 // FIXME: Use a default vocab when not specified 42 static cl::opt<std::string> 43 VocabFile("ir2vec-vocab-path", cl::Optional, 44 cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""), 45 cl::cat(IR2VecCategory)); 46 cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional, cl::init(1.0), 47 cl::desc("Weight for opcode embeddings"), 48 cl::cat(IR2VecCategory)); 49 cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5), 50 cl::desc("Weight for type embeddings"), 51 cl::cat(IR2VecCategory)); 52 cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2), 53 cl::desc("Weight for argument embeddings"), 54 cl::cat(IR2VecCategory)); 55 } // namespace ir2vec 56 } // namespace llvm 57 58 AnalysisKey IR2VecVocabAnalysis::Key; 59 60 // ==----------------------------------------------------------------------===// 61 // Local helper functions 62 //===----------------------------------------------------------------------===// 63 namespace llvm::json { 64 inline bool fromJSON(const llvm::json::Value &E, Embedding &Out, 65 llvm::json::Path P) { 66 std::vector<double> TempOut; 67 if (!llvm::json::fromJSON(E, TempOut, P)) 68 return false; 69 Out = Embedding(std::move(TempOut)); 70 return true; 71 } 72 } // namespace llvm::json 73 74 // ==----------------------------------------------------------------------===// 75 // Embedding 76 //===----------------------------------------------------------------------===// 77 Embedding &Embedding::operator+=(const Embedding &RHS) { 78 assert(this->size() == RHS.size() && "Vectors must have the same dimension"); 79 std::transform(this->begin(), this->end(), RHS.begin(), this->begin(), 80 std::plus<double>()); 81 return *this; 82 } 83 84 Embedding Embedding::operator+(const Embedding &RHS) const { 85 Embedding Result(*this); 86 Result += RHS; 87 return Result; 88 } 89 90 Embedding &Embedding::operator-=(const Embedding &RHS) { 91 assert(this->size() == RHS.size() && "Vectors must have the same dimension"); 92 std::transform(this->begin(), this->end(), RHS.begin(), this->begin(), 93 std::minus<double>()); 94 return *this; 95 } 96 97 Embedding Embedding::operator-(const Embedding &RHS) const { 98 Embedding Result(*this); 99 Result -= RHS; 100 return Result; 101 } 102 103 Embedding &Embedding::operator*=(double Factor) { 104 std::transform(this->begin(), this->end(), this->begin(), 105 [Factor](double Elem) { return Elem * Factor; }); 106 return *this; 107 } 108 109 Embedding Embedding::operator*(double Factor) const { 110 Embedding Result(*this); 111 Result *= Factor; 112 return Result; 113 } 114 115 Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) { 116 assert(this->size() == Src.size() && "Vectors must have the same dimension"); 117 for (size_t Itr = 0; Itr < this->size(); ++Itr) 118 (*this)[Itr] += Src[Itr] * Factor; 119 return *this; 120 } 121 122 bool Embedding::approximatelyEquals(const Embedding &RHS, 123 double Tolerance) const { 124 assert(this->size() == RHS.size() && "Vectors must have the same dimension"); 125 for (size_t Itr = 0; Itr < this->size(); ++Itr) 126 if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance) 127 return false; 128 return true; 129 } 130 131 void Embedding::print(raw_ostream &OS) const { 132 OS << " ["; 133 for (const auto &Elem : Data) 134 OS << " " << format("%.2f", Elem) << " "; 135 OS << "]\n"; 136 } 137 138 // ==----------------------------------------------------------------------===// 139 // Embedder and its subclasses 140 //===----------------------------------------------------------------------===// 141 142 Embedder::Embedder(const Function &F, const Vocabulary &Vocab) 143 : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()), 144 OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) { 145 } 146 147 std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F, 148 const Vocabulary &Vocab) { 149 switch (Mode) { 150 case IR2VecKind::Symbolic: 151 return std::make_unique<SymbolicEmbedder>(F, Vocab); 152 } 153 return nullptr; 154 } 155 156 const InstEmbeddingsMap &Embedder::getInstVecMap() const { 157 if (InstVecMap.empty()) 158 computeEmbeddings(); 159 return InstVecMap; 160 } 161 162 const BBEmbeddingsMap &Embedder::getBBVecMap() const { 163 if (BBVecMap.empty()) 164 computeEmbeddings(); 165 return BBVecMap; 166 } 167 168 const Embedding &Embedder::getBBVector(const BasicBlock &BB) const { 169 auto It = BBVecMap.find(&BB); 170 if (It != BBVecMap.end()) 171 return It->second; 172 computeEmbeddings(BB); 173 return BBVecMap[&BB]; 174 } 175 176 const Embedding &Embedder::getFunctionVector() const { 177 // Currently, we always (re)compute the embeddings for the function. 178 // This is cheaper than caching the vector. 179 computeEmbeddings(); 180 return FuncVector; 181 } 182 183 void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const { 184 Embedding BBVector(Dimension, 0); 185 186 // We consider only the non-debug and non-pseudo instructions 187 for (const auto &I : BB.instructionsWithoutDebug()) { 188 Embedding ArgEmb(Dimension, 0); 189 for (const auto &Op : I.operands()) 190 ArgEmb += Vocab[Op]; 191 auto InstVector = 192 Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb; 193 InstVecMap[&I] = InstVector; 194 BBVector += InstVector; 195 } 196 BBVecMap[&BB] = BBVector; 197 } 198 199 void SymbolicEmbedder::computeEmbeddings() const { 200 if (F.isDeclaration()) 201 return; 202 203 // Consider only the basic blocks that are reachable from entry 204 for (const BasicBlock *BB : depth_first(&F)) { 205 computeEmbeddings(*BB); 206 FuncVector += BBVecMap[BB]; 207 } 208 } 209 210 // ==----------------------------------------------------------------------===// 211 // Vocabulary 212 //===----------------------------------------------------------------------===// 213 214 Vocabulary::Vocabulary(VocabVector &&Vocab) 215 : Vocab(std::move(Vocab)), Valid(true) {} 216 217 bool Vocabulary::isValid() const { 218 return Vocab.size() == (MaxOpcodes + MaxTypeIDs + MaxOperandKinds) && Valid; 219 } 220 221 size_t Vocabulary::size() const { 222 assert(Valid && "IR2Vec Vocabulary is invalid"); 223 return Vocab.size(); 224 } 225 226 unsigned Vocabulary::getDimension() const { 227 assert(Valid && "IR2Vec Vocabulary is invalid"); 228 return Vocab[0].size(); 229 } 230 231 const Embedding &Vocabulary::operator[](unsigned Opcode) const { 232 assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode"); 233 return Vocab[Opcode - 1]; 234 } 235 236 const Embedding &Vocabulary::operator[](Type::TypeID TypeId) const { 237 assert(static_cast<unsigned>(TypeId) < MaxTypeIDs && "Invalid type ID"); 238 return Vocab[MaxOpcodes + static_cast<unsigned>(TypeId)]; 239 } 240 241 const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const { 242 OperandKind ArgKind = getOperandKind(Arg); 243 return Vocab[MaxOpcodes + MaxTypeIDs + static_cast<unsigned>(ArgKind)]; 244 } 245 246 StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) { 247 assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode"); 248 #define HANDLE_INST(NUM, OPCODE, CLASS) \ 249 if (Opcode == NUM) { \ 250 return #OPCODE; \ 251 } 252 #include "llvm/IR/Instruction.def" 253 #undef HANDLE_INST 254 return "UnknownOpcode"; 255 } 256 257 StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) { 258 switch (TypeID) { 259 case Type::VoidTyID: 260 return "VoidTy"; 261 case Type::HalfTyID: 262 case Type::BFloatTyID: 263 case Type::FloatTyID: 264 case Type::DoubleTyID: 265 case Type::X86_FP80TyID: 266 case Type::FP128TyID: 267 case Type::PPC_FP128TyID: 268 return "FloatTy"; 269 case Type::IntegerTyID: 270 return "IntegerTy"; 271 case Type::FunctionTyID: 272 return "FunctionTy"; 273 case Type::StructTyID: 274 return "StructTy"; 275 case Type::ArrayTyID: 276 return "ArrayTy"; 277 case Type::PointerTyID: 278 case Type::TypedPointerTyID: 279 return "PointerTy"; 280 case Type::FixedVectorTyID: 281 case Type::ScalableVectorTyID: 282 return "VectorTy"; 283 case Type::LabelTyID: 284 return "LabelTy"; 285 case Type::TokenTyID: 286 return "TokenTy"; 287 case Type::MetadataTyID: 288 return "MetadataTy"; 289 case Type::X86_AMXTyID: 290 case Type::TargetExtTyID: 291 return "UnknownTy"; 292 } 293 return "UnknownTy"; 294 } 295 296 StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) { 297 unsigned Index = static_cast<unsigned>(Kind); 298 assert(Index < MaxOperandKinds && "Invalid OperandKind"); 299 return OperandKindNames[Index]; 300 } 301 302 Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) { 303 VocabVector DummyVocab; 304 float DummyVal = 0.1f; 305 // Create a dummy vocabulary with entries for all opcodes, types, and 306 // operand 307 for (unsigned _ : seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxTypeIDs + 308 Vocabulary::MaxOperandKinds)) { 309 DummyVocab.push_back(Embedding(Dim, DummyVal)); 310 DummyVal += 0.1; 311 } 312 return DummyVocab; 313 } 314 315 // Helper function to classify an operand into OperandKind 316 Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) { 317 if (isa<Function>(Op)) 318 return OperandKind::FunctionID; 319 if (isa<PointerType>(Op->getType())) 320 return OperandKind::PointerID; 321 if (isa<Constant>(Op)) 322 return OperandKind::ConstantID; 323 return OperandKind::VariableID; 324 } 325 326 StringRef Vocabulary::getStringKey(unsigned Pos) { 327 assert(Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds && 328 "Position out of bounds in vocabulary"); 329 // Opcode 330 if (Pos < MaxOpcodes) 331 return getVocabKeyForOpcode(Pos + 1); 332 // Type 333 if (Pos < MaxOpcodes + MaxTypeIDs) 334 return getVocabKeyForTypeID(static_cast<Type::TypeID>(Pos - MaxOpcodes)); 335 // Operand 336 return getVocabKeyForOperandKind( 337 static_cast<OperandKind>(Pos - MaxOpcodes - MaxTypeIDs)); 338 } 339 340 // For now, assume vocabulary is stable unless explicitly invalidated. 341 bool Vocabulary::invalidate(Module &M, const PreservedAnalyses &PA, 342 ModuleAnalysisManager::Invalidator &Inv) const { 343 auto PAC = PA.getChecker<IR2VecVocabAnalysis>(); 344 return !(PAC.preservedWhenStateless()); 345 } 346 347 // ==----------------------------------------------------------------------===// 348 // IR2VecVocabAnalysis 349 //===----------------------------------------------------------------------===// 350 351 Error IR2VecVocabAnalysis::parseVocabSection( 352 StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab, 353 unsigned &Dim) { 354 json::Path::Root Path(""); 355 const json::Object *RootObj = ParsedVocabValue.getAsObject(); 356 if (!RootObj) 357 return createStringError(errc::invalid_argument, 358 "JSON root is not an object"); 359 360 const json::Value *SectionValue = RootObj->get(Key); 361 if (!SectionValue) 362 return createStringError(errc::invalid_argument, 363 "Missing '" + std::string(Key) + 364 "' section in vocabulary file"); 365 if (!json::fromJSON(*SectionValue, TargetVocab, Path)) 366 return createStringError(errc::illegal_byte_sequence, 367 "Unable to parse '" + std::string(Key) + 368 "' section from vocabulary"); 369 370 Dim = TargetVocab.begin()->second.size(); 371 if (Dim == 0) 372 return createStringError(errc::illegal_byte_sequence, 373 "Dimension of '" + std::string(Key) + 374 "' section of the vocabulary is zero"); 375 376 if (!std::all_of(TargetVocab.begin(), TargetVocab.end(), 377 [Dim](const std::pair<StringRef, Embedding> &Entry) { 378 return Entry.second.size() == Dim; 379 })) 380 return createStringError( 381 errc::illegal_byte_sequence, 382 "All vectors in the '" + std::string(Key) + 383 "' section of the vocabulary are not of the same dimension"); 384 385 return Error::success(); 386 } 387 388 // FIXME: Make this optional. We can avoid file reads 389 // by auto-generating a default vocabulary during the build time. 390 Error IR2VecVocabAnalysis::readVocabulary() { 391 auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true); 392 if (!BufOrError) 393 return createFileError(VocabFile, BufOrError.getError()); 394 395 auto Content = BufOrError.get()->getBuffer(); 396 397 Expected<json::Value> ParsedVocabValue = json::parse(Content); 398 if (!ParsedVocabValue) 399 return ParsedVocabValue.takeError(); 400 401 unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0; 402 if (auto Err = 403 parseVocabSection("Opcodes", *ParsedVocabValue, OpcVocab, OpcodeDim)) 404 return Err; 405 406 if (auto Err = 407 parseVocabSection("Types", *ParsedVocabValue, TypeVocab, TypeDim)) 408 return Err; 409 410 if (auto Err = 411 parseVocabSection("Arguments", *ParsedVocabValue, ArgVocab, ArgDim)) 412 return Err; 413 414 if (!(OpcodeDim == TypeDim && TypeDim == ArgDim)) 415 return createStringError(errc::illegal_byte_sequence, 416 "Vocabulary sections have different dimensions"); 417 418 return Error::success(); 419 } 420 421 void IR2VecVocabAnalysis::generateNumMappedVocab() { 422 423 // Helper for handling missing entities in the vocabulary. 424 // Currently, we use a zero vector. In the future, we will throw an error to 425 // ensure that *all* known entities are present in the vocabulary. 426 auto handleMissingEntity = [](const std::string &Val) { 427 LLVM_DEBUG(errs() << Val 428 << " is not in vocabulary, using zero vector; This " 429 "would result in an error in future.\n"); 430 ++VocabMissCounter; 431 }; 432 433 unsigned Dim = OpcVocab.begin()->second.size(); 434 assert(Dim > 0 && "Vocabulary dimension must be greater than zero"); 435 436 // Handle Opcodes 437 std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes, 438 Embedding(Dim, 0)); 439 for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) { 440 StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1); 441 auto It = OpcVocab.find(VocabKey.str()); 442 if (It != OpcVocab.end()) 443 NumericOpcodeEmbeddings[Opcode] = It->second; 444 else 445 handleMissingEntity(VocabKey.str()); 446 } 447 Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(), 448 NumericOpcodeEmbeddings.end()); 449 450 // Handle Types 451 std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxTypeIDs, 452 Embedding(Dim, 0)); 453 for (unsigned TypeID : seq(0u, Vocabulary::MaxTypeIDs)) { 454 StringRef VocabKey = 455 Vocabulary::getVocabKeyForTypeID(static_cast<Type::TypeID>(TypeID)); 456 if (auto It = TypeVocab.find(VocabKey.str()); It != TypeVocab.end()) { 457 NumericTypeEmbeddings[TypeID] = It->second; 458 continue; 459 } 460 handleMissingEntity(VocabKey.str()); 461 } 462 Vocab.insert(Vocab.end(), NumericTypeEmbeddings.begin(), 463 NumericTypeEmbeddings.end()); 464 465 // Handle Arguments/Operands 466 std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds, 467 Embedding(Dim, 0)); 468 for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) { 469 Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind); 470 StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind); 471 auto It = ArgVocab.find(VocabKey.str()); 472 if (It != ArgVocab.end()) { 473 NumericArgEmbeddings[OpKind] = It->second; 474 continue; 475 } 476 handleMissingEntity(VocabKey.str()); 477 } 478 Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(), 479 NumericArgEmbeddings.end()); 480 } 481 482 IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab) 483 : Vocab(Vocab) {} 484 485 IR2VecVocabAnalysis::IR2VecVocabAnalysis(VocabVector &&Vocab) 486 : Vocab(std::move(Vocab)) {} 487 488 void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) { 489 handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) { 490 Ctx.emitError("Error reading vocabulary: " + EI.message()); 491 }); 492 } 493 494 IR2VecVocabAnalysis::Result 495 IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) { 496 auto Ctx = &M.getContext(); 497 // If vocabulary is already populated by the constructor, use it. 498 if (!Vocab.empty()) 499 return Vocabulary(std::move(Vocab)); 500 501 // Otherwise, try to read from the vocabulary file. 502 if (VocabFile.empty()) { 503 // FIXME: Use default vocabulary 504 Ctx->emitError("IR2Vec vocabulary file path not specified; You may need to " 505 "set it using --ir2vec-vocab-path"); 506 return Vocabulary(); // Return invalid result 507 } 508 if (auto Err = readVocabulary()) { 509 emitError(std::move(Err), *Ctx); 510 return Vocabulary(); 511 } 512 513 // Scale the vocabulary sections based on the provided weights 514 auto scaleVocabSection = [](VocabMap &Vocab, double Weight) { 515 for (auto &Entry : Vocab) 516 Entry.second *= Weight; 517 }; 518 scaleVocabSection(OpcVocab, OpcWeight); 519 scaleVocabSection(TypeVocab, TypeWeight); 520 scaleVocabSection(ArgVocab, ArgWeight); 521 522 // Generate the numeric lookup vocabulary 523 generateNumMappedVocab(); 524 525 return Vocabulary(std::move(Vocab)); 526 } 527 528 // ==----------------------------------------------------------------------===// 529 // Printer Passes 530 //===----------------------------------------------------------------------===// 531 532 PreservedAnalyses IR2VecPrinterPass::run(Module &M, 533 ModuleAnalysisManager &MAM) { 534 auto Vocabulary = MAM.getResult<IR2VecVocabAnalysis>(M); 535 assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid"); 536 537 for (Function &F : M) { 538 std::unique_ptr<Embedder> Emb = 539 Embedder::create(IR2VecKind::Symbolic, F, Vocabulary); 540 if (!Emb) { 541 OS << "Error creating IR2Vec embeddings \n"; 542 continue; 543 } 544 545 OS << "IR2Vec embeddings for function " << F.getName() << ":\n"; 546 OS << "Function vector: "; 547 Emb->getFunctionVector().print(OS); 548 549 OS << "Basic block vectors:\n"; 550 const auto &BBMap = Emb->getBBVecMap(); 551 for (const BasicBlock &BB : F) { 552 auto It = BBMap.find(&BB); 553 if (It != BBMap.end()) { 554 OS << "Basic block: " << BB.getName() << ":\n"; 555 It->second.print(OS); 556 } 557 } 558 559 OS << "Instruction vectors:\n"; 560 const auto &InstMap = Emb->getInstVecMap(); 561 for (const BasicBlock &BB : F) { 562 for (const Instruction &I : BB) { 563 auto It = InstMap.find(&I); 564 if (It != InstMap.end()) { 565 OS << "Instruction: "; 566 I.print(OS); 567 It->second.print(OS); 568 } 569 } 570 } 571 } 572 return PreservedAnalyses::all(); 573 } 574 575 PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M, 576 ModuleAnalysisManager &MAM) { 577 auto IR2VecVocabulary = MAM.getResult<IR2VecVocabAnalysis>(M); 578 assert(IR2VecVocabulary.isValid() && "IR2Vec Vocabulary is invalid"); 579 580 // Print each entry 581 unsigned Pos = 0; 582 for (const auto &Entry : IR2VecVocabulary) { 583 OS << "Key: " << IR2VecVocabulary.getStringKey(Pos++) << ": "; 584 Entry.print(OS); 585 } 586 return PreservedAnalyses::all(); 587 } 588