1 //===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file provides the implementation of the MIRSampleProfile loader, mainly 10 // for flow sensitive SampleFDO. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/CodeGen/MIRSampleProfile.h" 15 #include "llvm/ADT/DenseMap.h" 16 #include "llvm/ADT/DenseSet.h" 17 #include "llvm/Analysis/BlockFrequencyInfoImpl.h" 18 #include "llvm/CodeGen/MachineBlockFrequencyInfo.h" 19 #include "llvm/CodeGen/MachineBranchProbabilityInfo.h" 20 #include "llvm/CodeGen/MachineDominators.h" 21 #include "llvm/CodeGen/MachineLoopInfo.h" 22 #include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h" 23 #include "llvm/CodeGen/MachinePostDominators.h" 24 #include "llvm/CodeGen/Passes.h" 25 #include "llvm/IR/Function.h" 26 #include "llvm/InitializePasses.h" 27 #include "llvm/Support/CommandLine.h" 28 #include "llvm/Support/Debug.h" 29 #include "llvm/Support/raw_ostream.h" 30 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h" 31 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h" 32 33 using namespace llvm; 34 using namespace sampleprof; 35 using namespace llvm::sampleprofutil; 36 using ProfileCount = Function::ProfileCount; 37 38 #define DEBUG_TYPE "fs-profile-loader" 39 40 static cl::opt<bool> ShowFSBranchProb( 41 "show-fs-branchprob", cl::Hidden, cl::init(false), 42 cl::desc("Print setting flow sensitive branch probabilities")); 43 static cl::opt<unsigned> FSProfileDebugProbDiffThreshold( 44 "fs-profile-debug-prob-diff-threshold", cl::init(10), 45 cl::desc("Only show debug message if the branch probility is greater than " 46 "this value (in percentage).")); 47 48 static cl::opt<unsigned> FSProfileDebugBWThreshold( 49 "fs-profile-debug-bw-threshold", cl::init(10000), 50 cl::desc("Only show debug message if the source branch weight is greater " 51 " than this value.")); 52 53 static cl::opt<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden, 54 cl::init(false), 55 cl::desc("View BFI before MIR loader")); 56 static cl::opt<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden, 57 cl::init(false), 58 cl::desc("View BFI after MIR loader")); 59 60 char MIRProfileLoaderPass::ID = 0; 61 62 INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE, 63 "Load MIR Sample Profile", 64 /* cfg = */ false, /* is_analysis = */ false) 65 INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo) 66 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree) 67 INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTree) 68 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo) 69 INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass) 70 INITIALIZE_PASS_END(MIRProfileLoaderPass, DEBUG_TYPE, "Load MIR Sample Profile", 71 /* cfg = */ false, /* is_analysis = */ false) 72 73 char &llvm::MIRProfileLoaderPassID = MIRProfileLoaderPass::ID; 74 75 FunctionPass *llvm::createMIRProfileLoaderPass(std::string File, 76 std::string RemappingFile, 77 FSDiscriminatorPass P) { 78 return new MIRProfileLoaderPass(File, RemappingFile, P); 79 } 80 81 namespace llvm { 82 83 // Internal option used to control BFI display only after MBP pass. 84 // Defined in CodeGen/MachineBlockFrequencyInfo.cpp: 85 // -view-block-layout-with-bfi={none | fraction | integer | count} 86 extern cl::opt<GVDAGType> ViewBlockLayoutWithBFI; 87 88 // Command line option to specify the name of the function for CFG dump 89 // Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name= 90 extern cl::opt<std::string> ViewBlockFreqFuncName; 91 92 namespace afdo_detail { 93 template <> struct IRTraits<MachineBasicBlock> { 94 using InstructionT = MachineInstr; 95 using BasicBlockT = MachineBasicBlock; 96 using FunctionT = MachineFunction; 97 using BlockFrequencyInfoT = MachineBlockFrequencyInfo; 98 using LoopT = MachineLoop; 99 using LoopInfoPtrT = MachineLoopInfo *; 100 using DominatorTreePtrT = MachineDominatorTree *; 101 using PostDominatorTreePtrT = MachinePostDominatorTree *; 102 using PostDominatorTreeT = MachinePostDominatorTree; 103 using OptRemarkEmitterT = MachineOptimizationRemarkEmitter; 104 using OptRemarkAnalysisT = MachineOptimizationRemarkAnalysis; 105 using PredRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>; 106 using SuccRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>; 107 static Function &getFunction(MachineFunction &F) { return F.getFunction(); } 108 static const MachineBasicBlock *getEntryBB(const MachineFunction *F) { 109 return GraphTraits<const MachineFunction *>::getEntryNode(F); 110 } 111 static PredRangeT getPredecessors(MachineBasicBlock *BB) { 112 return BB->predecessors(); 113 } 114 static SuccRangeT getSuccessors(MachineBasicBlock *BB) { 115 return BB->successors(); 116 } 117 }; 118 } // namespace afdo_detail 119 120 class MIRProfileLoader final 121 : public SampleProfileLoaderBaseImpl<MachineBasicBlock> { 122 public: 123 void setInitVals(MachineDominatorTree *MDT, MachinePostDominatorTree *MPDT, 124 MachineLoopInfo *MLI, MachineBlockFrequencyInfo *MBFI, 125 MachineOptimizationRemarkEmitter *MORE) { 126 DT = MDT; 127 PDT = MPDT; 128 LI = MLI; 129 BFI = MBFI; 130 ORE = MORE; 131 } 132 void setFSPass(FSDiscriminatorPass Pass) { 133 P = Pass; 134 LowBit = getFSPassBitBegin(P); 135 HighBit = getFSPassBitEnd(P); 136 assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit"); 137 } 138 139 MIRProfileLoader(StringRef Name, StringRef RemapName) 140 : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName)) { 141 } 142 143 void setBranchProbs(MachineFunction &F); 144 bool runOnFunction(MachineFunction &F); 145 bool doInitialization(Module &M); 146 bool isValid() const { return ProfileIsValid; } 147 148 protected: 149 friend class SampleCoverageTracker; 150 151 /// Hold the information of the basic block frequency. 152 MachineBlockFrequencyInfo *BFI; 153 154 /// PassNum is the sequence number this pass is called, start from 1. 155 FSDiscriminatorPass P; 156 157 // LowBit in the FS discriminator used by this instance. Note the number is 158 // 0-based. Base discrimnator use bit 0 to bit 11. 159 unsigned LowBit; 160 // HighwBit in the FS discriminator used by this instance. Note the number 161 // is 0-based. 162 unsigned HighBit; 163 164 bool ProfileIsValid = true; 165 }; 166 167 template <> 168 void SampleProfileLoaderBaseImpl< 169 MachineBasicBlock>::computeDominanceAndLoopInfo(MachineFunction &F) {} 170 171 void MIRProfileLoader::setBranchProbs(MachineFunction &F) { 172 LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n"); 173 for (auto &BI : F) { 174 MachineBasicBlock *BB = &BI; 175 if (BB->succ_size() < 2) 176 continue; 177 const MachineBasicBlock *EC = EquivalenceClass[BB]; 178 uint64_t BBWeight = BlockWeights[EC]; 179 uint64_t SumEdgeWeight = 0; 180 for (MachineBasicBlock *Succ : BB->successors()) { 181 Edge E = std::make_pair(BB, Succ); 182 SumEdgeWeight += EdgeWeights[E]; 183 } 184 185 if (BBWeight != SumEdgeWeight) { 186 LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight=" 187 << BBWeight << " SumEdgeWeight= " << SumEdgeWeight 188 << "\n"); 189 BBWeight = SumEdgeWeight; 190 } 191 if (BBWeight == 0) { 192 LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n"); 193 continue; 194 } 195 196 #ifndef NDEBUG 197 uint64_t BBWeightOrig = BBWeight; 198 #endif 199 uint32_t MaxWeight = std::numeric_limits<uint32_t>::max(); 200 uint32_t Factor = 1; 201 if (BBWeight > MaxWeight) { 202 Factor = BBWeight / MaxWeight + 1; 203 BBWeight /= Factor; 204 LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor << "\n"); 205 } 206 207 for (MachineBasicBlock::succ_iterator SI = BB->succ_begin(), 208 SE = BB->succ_end(); 209 SI != SE; ++SI) { 210 MachineBasicBlock *Succ = *SI; 211 Edge E = std::make_pair(BB, Succ); 212 uint64_t EdgeWeight = EdgeWeights[E]; 213 EdgeWeight /= Factor; 214 215 assert(BBWeight >= EdgeWeight && 216 "BBweight is larger than EdgeWeight -- should not happen.\n"); 217 218 BranchProbability OldProb = BFI->getMBPI()->getEdgeProbability(BB, SI); 219 BranchProbability NewProb(EdgeWeight, BBWeight); 220 if (OldProb == NewProb) 221 continue; 222 BB->setSuccProbability(SI, NewProb); 223 #ifndef NDEBUG 224 if (!ShowFSBranchProb) 225 continue; 226 bool Show = false; 227 BranchProbability Diff; 228 if (OldProb > NewProb) 229 Diff = OldProb - NewProb; 230 else 231 Diff = NewProb - OldProb; 232 Show = (Diff >= BranchProbability(FSProfileDebugProbDiffThreshold, 100)); 233 Show &= (BBWeightOrig >= FSProfileDebugBWThreshold); 234 235 auto DIL = BB->findBranchDebugLoc(); 236 auto SuccDIL = Succ->findBranchDebugLoc(); 237 if (Show) { 238 dbgs() << "Set branch fs prob: MBB (" << BB->getNumber() << " -> " 239 << Succ->getNumber() << "): "; 240 if (DIL) 241 dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":" 242 << DIL->getColumn(); 243 if (SuccDIL) 244 dbgs() << "-->" << SuccDIL->getFilename() << ":" << SuccDIL->getLine() 245 << ":" << SuccDIL->getColumn(); 246 dbgs() << " W=" << BBWeightOrig << " " << OldProb << " --> " << NewProb 247 << "\n"; 248 } 249 #endif 250 } 251 } 252 } 253 254 bool MIRProfileLoader::doInitialization(Module &M) { 255 auto &Ctx = M.getContext(); 256 257 auto ReaderOrErr = sampleprof::SampleProfileReader::create(Filename, Ctx, P, 258 RemappingFilename); 259 if (std::error_code EC = ReaderOrErr.getError()) { 260 std::string Msg = "Could not open profile: " + EC.message(); 261 Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg)); 262 return false; 263 } 264 265 Reader = std::move(ReaderOrErr.get()); 266 Reader->setModule(&M); 267 ProfileIsValid = (Reader->read() == sampleprof_error::success); 268 Reader->getSummary(); 269 270 return true; 271 } 272 273 bool MIRProfileLoader::runOnFunction(MachineFunction &MF) { 274 Function &Func = MF.getFunction(); 275 clearFunctionData(false); 276 Samples = Reader->getSamplesFor(Func); 277 if (!Samples || Samples->empty()) 278 return false; 279 280 if (getFunctionLoc(MF) == 0) 281 return false; 282 283 DenseSet<GlobalValue::GUID> InlinedGUIDs; 284 bool Changed = computeAndPropagateWeights(MF, InlinedGUIDs); 285 286 // Set the new BPI, BFI. 287 setBranchProbs(MF); 288 289 return Changed; 290 } 291 292 } // namespace llvm 293 294 MIRProfileLoaderPass::MIRProfileLoaderPass(std::string FileName, 295 std::string RemappingFileName, 296 FSDiscriminatorPass P) 297 : MachineFunctionPass(ID), ProfileFileName(FileName), P(P), 298 MIRSampleLoader( 299 std::make_unique<MIRProfileLoader>(FileName, RemappingFileName)) { 300 LowBit = getFSPassBitBegin(P); 301 HighBit = getFSPassBitEnd(P); 302 assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit"); 303 } 304 305 bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction &MF) { 306 if (!MIRSampleLoader->isValid()) 307 return false; 308 309 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: " 310 << MF.getFunction().getName() << "\n"); 311 MBFI = &getAnalysis<MachineBlockFrequencyInfo>(); 312 MIRSampleLoader->setInitVals( 313 &getAnalysis<MachineDominatorTree>(), 314 &getAnalysis<MachinePostDominatorTree>(), &getAnalysis<MachineLoopInfo>(), 315 MBFI, &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE()); 316 317 MF.RenumberBlocks(); 318 if (ViewBFIBefore && ViewBlockLayoutWithBFI != GVDT_None && 319 (ViewBlockFreqFuncName.empty() || 320 MF.getFunction().getName().equals(ViewBlockFreqFuncName))) { 321 MBFI->view("MIR_Prof_loader_b." + MF.getName(), false); 322 } 323 324 bool Changed = MIRSampleLoader->runOnFunction(MF); 325 if (Changed) 326 MBFI->calculate(MF, *MBFI->getMBPI(), *&getAnalysis<MachineLoopInfo>()); 327 328 if (ViewBFIAfter && ViewBlockLayoutWithBFI != GVDT_None && 329 (ViewBlockFreqFuncName.empty() || 330 MF.getFunction().getName().equals(ViewBlockFreqFuncName))) { 331 MBFI->view("MIR_prof_loader_a." + MF.getName(), false); 332 } 333 334 return Changed; 335 } 336 337 bool MIRProfileLoaderPass::doInitialization(Module &M) { 338 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M.getName() 339 << "\n"); 340 341 MIRSampleLoader->setFSPass(P); 342 return MIRSampleLoader->doInitialization(M); 343 } 344 345 void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage &AU) const { 346 AU.setPreservesAll(); 347 AU.addRequired<MachineBlockFrequencyInfo>(); 348 AU.addRequired<MachineDominatorTree>(); 349 AU.addRequired<MachinePostDominatorTree>(); 350 AU.addRequiredTransitive<MachineLoopInfo>(); 351 AU.addRequired<MachineOptimizationRemarkEmitterPass>(); 352 MachineFunctionPass::getAnalysisUsage(AU); 353 } 354