1 //===- Standard pass instrumentations handling ----------------*- C++ -*--===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// \file 9 /// 10 /// This file defines IR-printing pass instrumentation callbacks as well as 11 /// StandardInstrumentations class that manages standard pass instrumentations. 12 /// 13 //===----------------------------------------------------------------------===// 14 15 #include "llvm/Passes/StandardInstrumentations.h" 16 #include "llvm/ADT/Optional.h" 17 #include "llvm/Analysis/CallGraphSCCPass.h" 18 #include "llvm/Analysis/LazyCallGraph.h" 19 #include "llvm/Analysis/LoopInfo.h" 20 #include "llvm/IR/Function.h" 21 #include "llvm/IR/IRPrintingPasses.h" 22 #include "llvm/IR/Module.h" 23 #include "llvm/IR/PassInstrumentation.h" 24 #include "llvm/Support/Debug.h" 25 #include "llvm/Support/FormatVariadic.h" 26 #include "llvm/Support/raw_ostream.h" 27 28 using namespace llvm; 29 30 namespace { 31 32 /// Extracting Module out of \p IR unit. Also fills a textual description 33 /// of \p IR for use in header when printing. 34 Optional<std::pair<const Module *, std::string>> unwrapModule(Any IR) { 35 if (any_isa<const Module *>(IR)) 36 return std::make_pair(any_cast<const Module *>(IR), std::string()); 37 38 if (any_isa<const Function *>(IR)) { 39 const Function *F = any_cast<const Function *>(IR); 40 if (!llvm::isFunctionInPrintList(F->getName())) 41 return None; 42 const Module *M = F->getParent(); 43 return std::make_pair(M, formatv(" (function: {0})", F->getName()).str()); 44 } 45 46 if (any_isa<const LazyCallGraph::SCC *>(IR)) { 47 const LazyCallGraph::SCC *C = any_cast<const LazyCallGraph::SCC *>(IR); 48 for (const LazyCallGraph::Node &N : *C) { 49 const Function &F = N.getFunction(); 50 if (!F.isDeclaration() && isFunctionInPrintList(F.getName())) { 51 const Module *M = F.getParent(); 52 return std::make_pair(M, formatv(" (scc: {0})", C->getName()).str()); 53 } 54 } 55 return None; 56 } 57 58 if (any_isa<const Loop *>(IR)) { 59 const Loop *L = any_cast<const Loop *>(IR); 60 const Function *F = L->getHeader()->getParent(); 61 if (!isFunctionInPrintList(F->getName())) 62 return None; 63 const Module *M = F->getParent(); 64 std::string LoopName; 65 raw_string_ostream ss(LoopName); 66 L->getHeader()->printAsOperand(ss, false); 67 return std::make_pair(M, formatv(" (loop: {0})", ss.str()).str()); 68 } 69 70 llvm_unreachable("Unknown IR unit"); 71 } 72 73 void printIR(const Function *F, StringRef Banner, 74 StringRef Extra = StringRef()) { 75 if (!llvm::isFunctionInPrintList(F->getName())) 76 return; 77 dbgs() << Banner << Extra << "\n" << static_cast<const Value &>(*F); 78 } 79 80 void printIR(const Module *M, StringRef Banner, StringRef Extra = StringRef()) { 81 if (llvm::isFunctionInPrintList("*") || llvm::forcePrintModuleIR()) { 82 dbgs() << Banner << Extra << "\n"; 83 M->print(dbgs(), nullptr, false); 84 } else { 85 for (const auto &F : M->functions()) { 86 printIR(&F, Banner, Extra); 87 } 88 } 89 } 90 91 void printIR(const LazyCallGraph::SCC *C, StringRef Banner, 92 StringRef Extra = StringRef()) { 93 bool BannerPrinted = false; 94 for (const LazyCallGraph::Node &N : *C) { 95 const Function &F = N.getFunction(); 96 if (!F.isDeclaration() && llvm::isFunctionInPrintList(F.getName())) { 97 if (!BannerPrinted) { 98 dbgs() << Banner << Extra << "\n"; 99 BannerPrinted = true; 100 } 101 F.print(dbgs()); 102 } 103 } 104 } 105 void printIR(const Loop *L, StringRef Banner) { 106 const Function *F = L->getHeader()->getParent(); 107 if (!llvm::isFunctionInPrintList(F->getName())) 108 return; 109 llvm::printLoop(const_cast<Loop &>(*L), dbgs(), std::string(Banner)); 110 } 111 112 /// Generic IR-printing helper that unpacks a pointer to IRUnit wrapped into 113 /// llvm::Any and does actual print job. 114 void unwrapAndPrint(Any IR, StringRef Banner, bool ForceModule = false) { 115 if (ForceModule) { 116 if (auto UnwrappedModule = unwrapModule(IR)) 117 printIR(UnwrappedModule->first, Banner, UnwrappedModule->second); 118 return; 119 } 120 121 if (any_isa<const Module *>(IR)) { 122 const Module *M = any_cast<const Module *>(IR); 123 assert(M && "module should be valid for printing"); 124 printIR(M, Banner); 125 return; 126 } 127 128 if (any_isa<const Function *>(IR)) { 129 const Function *F = any_cast<const Function *>(IR); 130 assert(F && "function should be valid for printing"); 131 printIR(F, Banner); 132 return; 133 } 134 135 if (any_isa<const LazyCallGraph::SCC *>(IR)) { 136 const LazyCallGraph::SCC *C = any_cast<const LazyCallGraph::SCC *>(IR); 137 assert(C && "scc should be valid for printing"); 138 std::string Extra = std::string(formatv(" (scc: {0})", C->getName())); 139 printIR(C, Banner, Extra); 140 return; 141 } 142 143 if (any_isa<const Loop *>(IR)) { 144 const Loop *L = any_cast<const Loop *>(IR); 145 assert(L && "Loop should be valid for printing"); 146 printIR(L, Banner); 147 return; 148 } 149 llvm_unreachable("Unknown wrapped IR type"); 150 } 151 152 } // namespace 153 154 PrintIRInstrumentation::~PrintIRInstrumentation() { 155 assert(ModuleDescStack.empty() && "ModuleDescStack is not empty at exit"); 156 } 157 158 void PrintIRInstrumentation::pushModuleDesc(StringRef PassID, Any IR) { 159 assert(StoreModuleDesc); 160 const Module *M = nullptr; 161 std::string Extra; 162 if (auto UnwrappedModule = unwrapModule(IR)) 163 std::tie(M, Extra) = UnwrappedModule.getValue(); 164 ModuleDescStack.emplace_back(M, Extra, PassID); 165 } 166 167 PrintIRInstrumentation::PrintModuleDesc 168 PrintIRInstrumentation::popModuleDesc(StringRef PassID) { 169 assert(!ModuleDescStack.empty() && "empty ModuleDescStack"); 170 PrintModuleDesc ModuleDesc = ModuleDescStack.pop_back_val(); 171 assert(std::get<2>(ModuleDesc).equals(PassID) && "malformed ModuleDescStack"); 172 return ModuleDesc; 173 } 174 175 bool PrintIRInstrumentation::printBeforePass(StringRef PassID, Any IR) { 176 if (PassID.startswith("PassManager<") || PassID.contains("PassAdaptor<")) 177 return true; 178 179 // Saving Module for AfterPassInvalidated operations. 180 // Note: here we rely on a fact that we do not change modules while 181 // traversing the pipeline, so the latest captured module is good 182 // for all print operations that has not happen yet. 183 if (StoreModuleDesc && llvm::shouldPrintAfterPass(PassID)) 184 pushModuleDesc(PassID, IR); 185 186 if (!llvm::shouldPrintBeforePass(PassID)) 187 return true; 188 189 SmallString<20> Banner = formatv("*** IR Dump Before {0} ***", PassID); 190 unwrapAndPrint(IR, Banner, llvm::forcePrintModuleIR()); 191 return true; 192 } 193 194 void PrintIRInstrumentation::printAfterPass(StringRef PassID, Any IR) { 195 if (PassID.startswith("PassManager<") || PassID.contains("PassAdaptor<")) 196 return; 197 198 if (!llvm::shouldPrintAfterPass(PassID)) 199 return; 200 201 if (StoreModuleDesc) 202 popModuleDesc(PassID); 203 204 SmallString<20> Banner = formatv("*** IR Dump After {0} ***", PassID); 205 unwrapAndPrint(IR, Banner, llvm::forcePrintModuleIR()); 206 } 207 208 void PrintIRInstrumentation::printAfterPassInvalidated(StringRef PassID) { 209 if (!StoreModuleDesc || !llvm::shouldPrintAfterPass(PassID)) 210 return; 211 212 if (PassID.startswith("PassManager<") || PassID.contains("PassAdaptor<")) 213 return; 214 215 const Module *M; 216 std::string Extra; 217 StringRef StoredPassID; 218 std::tie(M, Extra, StoredPassID) = popModuleDesc(PassID); 219 // Additional filtering (e.g. -filter-print-func) can lead to module 220 // printing being skipped. 221 if (!M) 222 return; 223 224 SmallString<20> Banner = 225 formatv("*** IR Dump After {0} *** invalidated: ", PassID); 226 printIR(M, Banner, Extra); 227 } 228 229 void PrintIRInstrumentation::registerCallbacks( 230 PassInstrumentationCallbacks &PIC) { 231 // BeforePass callback is not just for printing, it also saves a Module 232 // for later use in AfterPassInvalidated. 233 StoreModuleDesc = llvm::forcePrintModuleIR() && llvm::shouldPrintAfterPass(); 234 if (llvm::shouldPrintBeforePass() || StoreModuleDesc) 235 PIC.registerBeforePassCallback( 236 [this](StringRef P, Any IR) { return this->printBeforePass(P, IR); }); 237 238 if (llvm::shouldPrintAfterPass()) { 239 PIC.registerAfterPassCallback( 240 [this](StringRef P, Any IR) { this->printAfterPass(P, IR); }); 241 PIC.registerAfterPassInvalidatedCallback( 242 [this](StringRef P) { this->printAfterPassInvalidated(P); }); 243 } 244 } 245 246 void StandardInstrumentations::registerCallbacks( 247 PassInstrumentationCallbacks &PIC) { 248 PrintIR.registerCallbacks(PIC); 249 TimePasses.registerCallbacks(PIC); 250 } 251