1 //===- Standard pass instrumentations handling ----------------*- C++ -*--===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// \file 9 /// 10 /// This file defines IR-printing pass instrumentation callbacks as well as 11 /// StandardInstrumentations class that manages standard pass instrumentations. 12 /// 13 //===----------------------------------------------------------------------===// 14 15 #include "llvm/Passes/StandardInstrumentations.h" 16 #include "llvm/ADT/Optional.h" 17 #include "llvm/Analysis/CallGraphSCCPass.h" 18 #include "llvm/Analysis/LazyCallGraph.h" 19 #include "llvm/Analysis/LoopInfo.h" 20 #include "llvm/IR/Function.h" 21 #include "llvm/IR/IRPrintingPasses.h" 22 #include "llvm/IR/Module.h" 23 #include "llvm/IR/PassInstrumentation.h" 24 #include "llvm/Support/Debug.h" 25 #include "llvm/Support/FormatVariadic.h" 26 #include "llvm/Support/raw_ostream.h" 27 28 using namespace llvm; 29 30 namespace { 31 32 /// Extracting Module out of \p IR unit. Also fills a textual description 33 /// of \p IR for use in header when printing. 34 Optional<std::pair<const Module *, std::string>> unwrapModule(Any IR) { 35 if (any_isa<const Module *>(IR)) 36 return std::make_pair(any_cast<const Module *>(IR), std::string()); 37 38 if (any_isa<const Function *>(IR)) { 39 const Function *F = any_cast<const Function *>(IR); 40 if (!llvm::isFunctionInPrintList(F->getName())) 41 return None; 42 const Module *M = F->getParent(); 43 return std::make_pair(M, formatv(" (function: {0})", F->getName()).str()); 44 } 45 46 if (any_isa<const LazyCallGraph::SCC *>(IR)) { 47 const LazyCallGraph::SCC *C = any_cast<const LazyCallGraph::SCC *>(IR); 48 for (const LazyCallGraph::Node &N : *C) { 49 const Function &F = N.getFunction(); 50 if (!F.isDeclaration() && isFunctionInPrintList(F.getName())) { 51 const Module *M = F.getParent(); 52 return std::make_pair(M, formatv(" (scc: {0})", C->getName()).str()); 53 } 54 } 55 return None; 56 } 57 58 if (any_isa<const Loop *>(IR)) { 59 const Loop *L = any_cast<const Loop *>(IR); 60 const Function *F = L->getHeader()->getParent(); 61 if (!isFunctionInPrintList(F->getName())) 62 return None; 63 const Module *M = F->getParent(); 64 std::string LoopName; 65 raw_string_ostream ss(LoopName); 66 L->getHeader()->printAsOperand(ss, false); 67 return std::make_pair(M, formatv(" (loop: {0})", ss.str()).str()); 68 } 69 70 llvm_unreachable("Unknown IR unit"); 71 } 72 73 void printIR(const Module *M, StringRef Banner, StringRef Extra = StringRef()) { 74 dbgs() << Banner << Extra << "\n"; 75 M->print(dbgs(), nullptr, false); 76 } 77 void printIR(const Function *F, StringRef Banner, 78 StringRef Extra = StringRef()) { 79 if (!llvm::isFunctionInPrintList(F->getName())) 80 return; 81 dbgs() << Banner << Extra << "\n" << static_cast<const Value &>(*F); 82 } 83 void printIR(const LazyCallGraph::SCC *C, StringRef Banner, 84 StringRef Extra = StringRef()) { 85 bool BannerPrinted = false; 86 for (const LazyCallGraph::Node &N : *C) { 87 const Function &F = N.getFunction(); 88 if (!F.isDeclaration() && llvm::isFunctionInPrintList(F.getName())) { 89 if (!BannerPrinted) { 90 dbgs() << Banner << Extra << "\n"; 91 BannerPrinted = true; 92 } 93 F.print(dbgs()); 94 } 95 } 96 } 97 void printIR(const Loop *L, StringRef Banner) { 98 const Function *F = L->getHeader()->getParent(); 99 if (!llvm::isFunctionInPrintList(F->getName())) 100 return; 101 llvm::printLoop(const_cast<Loop &>(*L), dbgs(), Banner); 102 } 103 104 /// Generic IR-printing helper that unpacks a pointer to IRUnit wrapped into 105 /// llvm::Any and does actual print job. 106 void unwrapAndPrint(Any IR, StringRef Banner, bool ForceModule = false) { 107 if (ForceModule) { 108 if (auto UnwrappedModule = unwrapModule(IR)) 109 printIR(UnwrappedModule->first, Banner, UnwrappedModule->second); 110 return; 111 } 112 113 if (any_isa<const Module *>(IR)) { 114 const Module *M = any_cast<const Module *>(IR); 115 assert(M && "module should be valid for printing"); 116 printIR(M, Banner); 117 return; 118 } 119 120 if (any_isa<const Function *>(IR)) { 121 const Function *F = any_cast<const Function *>(IR); 122 assert(F && "function should be valid for printing"); 123 printIR(F, Banner); 124 return; 125 } 126 127 if (any_isa<const LazyCallGraph::SCC *>(IR)) { 128 const LazyCallGraph::SCC *C = any_cast<const LazyCallGraph::SCC *>(IR); 129 assert(C && "scc should be valid for printing"); 130 std::string Extra = formatv(" (scc: {0})", C->getName()); 131 printIR(C, Banner, Extra); 132 return; 133 } 134 135 if (any_isa<const Loop *>(IR)) { 136 const Loop *L = any_cast<const Loop *>(IR); 137 assert(L && "Loop should be valid for printing"); 138 printIR(L, Banner); 139 return; 140 } 141 llvm_unreachable("Unknown wrapped IR type"); 142 } 143 144 } // namespace 145 146 PrintIRInstrumentation::~PrintIRInstrumentation() { 147 assert(ModuleDescStack.empty() && "ModuleDescStack is not empty at exit"); 148 } 149 150 void PrintIRInstrumentation::pushModuleDesc(StringRef PassID, Any IR) { 151 assert(StoreModuleDesc); 152 const Module *M = nullptr; 153 std::string Extra; 154 if (auto UnwrappedModule = unwrapModule(IR)) 155 std::tie(M, Extra) = UnwrappedModule.getValue(); 156 ModuleDescStack.emplace_back(M, Extra, PassID); 157 } 158 159 PrintIRInstrumentation::PrintModuleDesc 160 PrintIRInstrumentation::popModuleDesc(StringRef PassID) { 161 assert(!ModuleDescStack.empty() && "empty ModuleDescStack"); 162 PrintModuleDesc ModuleDesc = ModuleDescStack.pop_back_val(); 163 assert(std::get<2>(ModuleDesc).equals(PassID) && "malformed ModuleDescStack"); 164 return ModuleDesc; 165 } 166 167 bool PrintIRInstrumentation::printBeforePass(StringRef PassID, Any IR) { 168 if (PassID.startswith("PassManager<") || PassID.contains("PassAdaptor<")) 169 return true; 170 171 // Saving Module for AfterPassInvalidated operations. 172 // Note: here we rely on a fact that we do not change modules while 173 // traversing the pipeline, so the latest captured module is good 174 // for all print operations that has not happen yet. 175 if (StoreModuleDesc && llvm::shouldPrintAfterPass(PassID)) 176 pushModuleDesc(PassID, IR); 177 178 if (!llvm::shouldPrintBeforePass(PassID)) 179 return true; 180 181 SmallString<20> Banner = formatv("*** IR Dump Before {0} ***", PassID); 182 unwrapAndPrint(IR, Banner, llvm::forcePrintModuleIR()); 183 return true; 184 } 185 186 void PrintIRInstrumentation::printAfterPass(StringRef PassID, Any IR) { 187 if (PassID.startswith("PassManager<") || PassID.contains("PassAdaptor<")) 188 return; 189 190 if (!llvm::shouldPrintAfterPass(PassID)) 191 return; 192 193 if (StoreModuleDesc) 194 popModuleDesc(PassID); 195 196 SmallString<20> Banner = formatv("*** IR Dump After {0} ***", PassID); 197 unwrapAndPrint(IR, Banner, llvm::forcePrintModuleIR()); 198 } 199 200 void PrintIRInstrumentation::printAfterPassInvalidated(StringRef PassID) { 201 if (!StoreModuleDesc || !llvm::shouldPrintAfterPass(PassID)) 202 return; 203 204 if (PassID.startswith("PassManager<") || PassID.contains("PassAdaptor<")) 205 return; 206 207 const Module *M; 208 std::string Extra; 209 StringRef StoredPassID; 210 std::tie(M, Extra, StoredPassID) = popModuleDesc(PassID); 211 // Additional filtering (e.g. -filter-print-func) can lead to module 212 // printing being skipped. 213 if (!M) 214 return; 215 216 SmallString<20> Banner = 217 formatv("*** IR Dump After {0} *** invalidated: ", PassID); 218 printIR(M, Banner, Extra); 219 } 220 221 void PrintIRInstrumentation::registerCallbacks( 222 PassInstrumentationCallbacks &PIC) { 223 // BeforePass callback is not just for printing, it also saves a Module 224 // for later use in AfterPassInvalidated. 225 StoreModuleDesc = llvm::forcePrintModuleIR() && llvm::shouldPrintAfterPass(); 226 if (llvm::shouldPrintBeforePass() || StoreModuleDesc) 227 PIC.registerBeforePassCallback( 228 [this](StringRef P, Any IR) { return this->printBeforePass(P, IR); }); 229 230 if (llvm::shouldPrintAfterPass()) { 231 PIC.registerAfterPassCallback( 232 [this](StringRef P, Any IR) { this->printAfterPass(P, IR); }); 233 PIC.registerAfterPassInvalidatedCallback( 234 [this](StringRef P) { this->printAfterPassInvalidated(P); }); 235 } 236 } 237 238 void StandardInstrumentations::registerCallbacks( 239 PassInstrumentationCallbacks &PIC) { 240 PrintIR.registerCallbacks(PIC); 241 TimePasses.registerCallbacks(PIC); 242 } 243