1 //===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file provides the implementation of the MIRSampleProfile loader, mainly 10 // for flow sensitive SampleFDO. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/CodeGen/MIRSampleProfile.h" 15 #include "llvm/ADT/DenseMap.h" 16 #include "llvm/ADT/DenseSet.h" 17 #include "llvm/Analysis/BlockFrequencyInfoImpl.h" 18 #include "llvm/CodeGen/MachineBlockFrequencyInfo.h" 19 #include "llvm/CodeGen/MachineBranchProbabilityInfo.h" 20 #include "llvm/CodeGen/MachineDominators.h" 21 #include "llvm/CodeGen/MachineInstr.h" 22 #include "llvm/CodeGen/MachineLoopInfo.h" 23 #include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h" 24 #include "llvm/CodeGen/MachinePostDominators.h" 25 #include "llvm/CodeGen/Passes.h" 26 #include "llvm/IR/Function.h" 27 #include "llvm/IR/PseudoProbe.h" 28 #include "llvm/InitializePasses.h" 29 #include "llvm/Support/CommandLine.h" 30 #include "llvm/Support/Debug.h" 31 #include "llvm/Support/VirtualFileSystem.h" 32 #include "llvm/Support/raw_ostream.h" 33 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h" 34 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h" 35 #include <optional> 36 37 using namespace llvm; 38 using namespace sampleprof; 39 using namespace llvm::sampleprofutil; 40 using ProfileCount = Function::ProfileCount; 41 42 #define DEBUG_TYPE "fs-profile-loader" 43 44 static cl::opt<bool> ShowFSBranchProb( 45 "show-fs-branchprob", cl::Hidden, cl::init(false), 46 cl::desc("Print setting flow sensitive branch probabilities")); 47 static cl::opt<unsigned> FSProfileDebugProbDiffThreshold( 48 "fs-profile-debug-prob-diff-threshold", cl::init(10), 49 cl::desc("Only show debug message if the branch probility is greater than " 50 "this value (in percentage).")); 51 52 static cl::opt<unsigned> FSProfileDebugBWThreshold( 53 "fs-profile-debug-bw-threshold", cl::init(10000), 54 cl::desc("Only show debug message if the source branch weight is greater " 55 " than this value.")); 56 57 static cl::opt<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden, 58 cl::init(false), 59 cl::desc("View BFI before MIR loader")); 60 static cl::opt<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden, 61 cl::init(false), 62 cl::desc("View BFI after MIR loader")); 63 64 extern cl::opt<bool> ImprovedFSDiscriminator; 65 char MIRProfileLoaderPass::ID = 0; 66 67 INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE, 68 "Load MIR Sample Profile", 69 /* cfg = */ false, /* is_analysis = */ false) 70 INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo) 71 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree) 72 INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTree) 73 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo) 74 INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass) 75 INITIALIZE_PASS_END(MIRProfileLoaderPass, DEBUG_TYPE, "Load MIR Sample Profile", 76 /* cfg = */ false, /* is_analysis = */ false) 77 78 char &llvm::MIRProfileLoaderPassID = MIRProfileLoaderPass::ID; 79 80 FunctionPass * 81 llvm::createMIRProfileLoaderPass(std::string File, std::string RemappingFile, 82 FSDiscriminatorPass P, 83 IntrusiveRefCntPtr<vfs::FileSystem> FS) { 84 return new MIRProfileLoaderPass(File, RemappingFile, P, std::move(FS)); 85 } 86 87 namespace llvm { 88 89 // Internal option used to control BFI display only after MBP pass. 90 // Defined in CodeGen/MachineBlockFrequencyInfo.cpp: 91 // -view-block-layout-with-bfi={none | fraction | integer | count} 92 extern cl::opt<GVDAGType> ViewBlockLayoutWithBFI; 93 94 // Command line option to specify the name of the function for CFG dump 95 // Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name= 96 extern cl::opt<std::string> ViewBlockFreqFuncName; 97 98 std::optional<PseudoProbe> extractProbe(const MachineInstr &MI) { 99 if (MI.isPseudoProbe()) { 100 PseudoProbe Probe; 101 Probe.Id = MI.getOperand(1).getImm(); 102 Probe.Type = MI.getOperand(2).getImm(); 103 Probe.Attr = MI.getOperand(3).getImm(); 104 Probe.Factor = 1; 105 DILocation *DebugLoc = MI.getDebugLoc(); 106 Probe.Discriminator = DebugLoc ? DebugLoc->getDiscriminator() : 0; 107 return Probe; 108 } 109 110 // Ignore callsite probes since they do not have FS discriminators. 111 return std::nullopt; 112 } 113 114 namespace afdo_detail { 115 template <> struct IRTraits<MachineBasicBlock> { 116 using InstructionT = MachineInstr; 117 using BasicBlockT = MachineBasicBlock; 118 using FunctionT = MachineFunction; 119 using BlockFrequencyInfoT = MachineBlockFrequencyInfo; 120 using LoopT = MachineLoop; 121 using LoopInfoPtrT = MachineLoopInfo *; 122 using DominatorTreePtrT = MachineDominatorTree *; 123 using PostDominatorTreePtrT = MachinePostDominatorTree *; 124 using PostDominatorTreeT = MachinePostDominatorTree; 125 using OptRemarkEmitterT = MachineOptimizationRemarkEmitter; 126 using OptRemarkAnalysisT = MachineOptimizationRemarkAnalysis; 127 using PredRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>; 128 using SuccRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>; 129 static Function &getFunction(MachineFunction &F) { return F.getFunction(); } 130 static const MachineBasicBlock *getEntryBB(const MachineFunction *F) { 131 return GraphTraits<const MachineFunction *>::getEntryNode(F); 132 } 133 static PredRangeT getPredecessors(MachineBasicBlock *BB) { 134 return BB->predecessors(); 135 } 136 static SuccRangeT getSuccessors(MachineBasicBlock *BB) { 137 return BB->successors(); 138 } 139 }; 140 } // namespace afdo_detail 141 142 class MIRProfileLoader final 143 : public SampleProfileLoaderBaseImpl<MachineFunction> { 144 public: 145 void setInitVals(MachineDominatorTree *MDT, MachinePostDominatorTree *MPDT, 146 MachineLoopInfo *MLI, MachineBlockFrequencyInfo *MBFI, 147 MachineOptimizationRemarkEmitter *MORE) { 148 DT = MDT; 149 PDT = MPDT; 150 LI = MLI; 151 BFI = MBFI; 152 ORE = MORE; 153 } 154 void setFSPass(FSDiscriminatorPass Pass) { 155 P = Pass; 156 LowBit = getFSPassBitBegin(P); 157 HighBit = getFSPassBitEnd(P); 158 assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit"); 159 } 160 161 MIRProfileLoader(StringRef Name, StringRef RemapName, 162 IntrusiveRefCntPtr<vfs::FileSystem> FS) 163 : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName), 164 std::move(FS)) {} 165 166 void setBranchProbs(MachineFunction &F); 167 bool runOnFunction(MachineFunction &F); 168 bool doInitialization(Module &M); 169 bool isValid() const { return ProfileIsValid; } 170 171 protected: 172 friend class SampleCoverageTracker; 173 174 /// Hold the information of the basic block frequency. 175 MachineBlockFrequencyInfo *BFI; 176 177 /// PassNum is the sequence number this pass is called, start from 1. 178 FSDiscriminatorPass P; 179 180 // LowBit in the FS discriminator used by this instance. Note the number is 181 // 0-based. Base discrimnator use bit 0 to bit 11. 182 unsigned LowBit; 183 // HighwBit in the FS discriminator used by this instance. Note the number 184 // is 0-based. 185 unsigned HighBit; 186 187 bool ProfileIsValid = true; 188 ErrorOr<uint64_t> getInstWeight(const MachineInstr &MI) override { 189 if (FunctionSamples::ProfileIsProbeBased) 190 return getProbeWeight(MI); 191 if (ImprovedFSDiscriminator && MI.isMetaInstruction()) 192 return std::error_code(); 193 return getInstWeightImpl(MI); 194 } 195 }; 196 197 template <> 198 void SampleProfileLoaderBaseImpl<MachineFunction>::computeDominanceAndLoopInfo( 199 MachineFunction &F) {} 200 201 void MIRProfileLoader::setBranchProbs(MachineFunction &F) { 202 LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n"); 203 for (auto &BI : F) { 204 MachineBasicBlock *BB = &BI; 205 if (BB->succ_size() < 2) 206 continue; 207 const MachineBasicBlock *EC = EquivalenceClass[BB]; 208 uint64_t BBWeight = BlockWeights[EC]; 209 uint64_t SumEdgeWeight = 0; 210 for (MachineBasicBlock *Succ : BB->successors()) { 211 Edge E = std::make_pair(BB, Succ); 212 SumEdgeWeight += EdgeWeights[E]; 213 } 214 215 if (BBWeight != SumEdgeWeight) { 216 LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight=" 217 << BBWeight << " SumEdgeWeight= " << SumEdgeWeight 218 << "\n"); 219 BBWeight = SumEdgeWeight; 220 } 221 if (BBWeight == 0) { 222 LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n"); 223 continue; 224 } 225 226 #ifndef NDEBUG 227 uint64_t BBWeightOrig = BBWeight; 228 #endif 229 uint32_t MaxWeight = std::numeric_limits<uint32_t>::max(); 230 uint32_t Factor = 1; 231 if (BBWeight > MaxWeight) { 232 Factor = BBWeight / MaxWeight + 1; 233 BBWeight /= Factor; 234 LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor << "\n"); 235 } 236 237 for (MachineBasicBlock::succ_iterator SI = BB->succ_begin(), 238 SE = BB->succ_end(); 239 SI != SE; ++SI) { 240 MachineBasicBlock *Succ = *SI; 241 Edge E = std::make_pair(BB, Succ); 242 uint64_t EdgeWeight = EdgeWeights[E]; 243 EdgeWeight /= Factor; 244 245 assert(BBWeight >= EdgeWeight && 246 "BBweight is larger than EdgeWeight -- should not happen.\n"); 247 248 BranchProbability OldProb = BFI->getMBPI()->getEdgeProbability(BB, SI); 249 BranchProbability NewProb(EdgeWeight, BBWeight); 250 if (OldProb == NewProb) 251 continue; 252 BB->setSuccProbability(SI, NewProb); 253 #ifndef NDEBUG 254 if (!ShowFSBranchProb) 255 continue; 256 bool Show = false; 257 BranchProbability Diff; 258 if (OldProb > NewProb) 259 Diff = OldProb - NewProb; 260 else 261 Diff = NewProb - OldProb; 262 Show = (Diff >= BranchProbability(FSProfileDebugProbDiffThreshold, 100)); 263 Show &= (BBWeightOrig >= FSProfileDebugBWThreshold); 264 265 auto DIL = BB->findBranchDebugLoc(); 266 auto SuccDIL = Succ->findBranchDebugLoc(); 267 if (Show) { 268 dbgs() << "Set branch fs prob: MBB (" << BB->getNumber() << " -> " 269 << Succ->getNumber() << "): "; 270 if (DIL) 271 dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":" 272 << DIL->getColumn(); 273 if (SuccDIL) 274 dbgs() << "-->" << SuccDIL->getFilename() << ":" << SuccDIL->getLine() 275 << ":" << SuccDIL->getColumn(); 276 dbgs() << " W=" << BBWeightOrig << " " << OldProb << " --> " << NewProb 277 << "\n"; 278 } 279 #endif 280 } 281 } 282 } 283 284 bool MIRProfileLoader::doInitialization(Module &M) { 285 auto &Ctx = M.getContext(); 286 287 auto ReaderOrErr = sampleprof::SampleProfileReader::create( 288 Filename, Ctx, *FS, P, RemappingFilename); 289 if (std::error_code EC = ReaderOrErr.getError()) { 290 std::string Msg = "Could not open profile: " + EC.message(); 291 Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg)); 292 return false; 293 } 294 295 Reader = std::move(ReaderOrErr.get()); 296 Reader->setModule(&M); 297 ProfileIsValid = (Reader->read() == sampleprof_error::success); 298 299 // Load pseudo probe descriptors for probe-based function samples. 300 if (Reader->profileIsProbeBased()) { 301 ProbeManager = std::make_unique<PseudoProbeManager>(M); 302 if (!ProbeManager->moduleIsProbed(M)) { 303 return false; 304 } 305 } 306 307 return true; 308 } 309 310 bool MIRProfileLoader::runOnFunction(MachineFunction &MF) { 311 // Do not load non-FS profiles. A line or probe can get a zero-valued 312 // discriminator at certain pass which could result in accidentally loading 313 // the corresponding base counter in the non-FS profile, while a non-zero 314 // discriminator would end up getting zero samples. This could in turn undo 315 // the sample distribution effort done by previous BFI maintenance and the 316 // probe distribution factor work for pseudo probes. 317 if (!Reader->profileIsFS()) 318 return false; 319 320 Function &Func = MF.getFunction(); 321 clearFunctionData(false); 322 Samples = Reader->getSamplesFor(Func); 323 if (!Samples || Samples->empty()) 324 return false; 325 326 if (FunctionSamples::ProfileIsProbeBased) { 327 if (!ProbeManager->profileIsValid(MF.getFunction(), *Samples)) 328 return false; 329 } else { 330 if (getFunctionLoc(MF) == 0) 331 return false; 332 } 333 334 DenseSet<GlobalValue::GUID> InlinedGUIDs; 335 bool Changed = computeAndPropagateWeights(MF, InlinedGUIDs); 336 337 // Set the new BPI, BFI. 338 setBranchProbs(MF); 339 340 return Changed; 341 } 342 343 } // namespace llvm 344 345 MIRProfileLoaderPass::MIRProfileLoaderPass( 346 std::string FileName, std::string RemappingFileName, FSDiscriminatorPass P, 347 IntrusiveRefCntPtr<vfs::FileSystem> FS) 348 : MachineFunctionPass(ID), ProfileFileName(FileName), P(P) { 349 LowBit = getFSPassBitBegin(P); 350 HighBit = getFSPassBitEnd(P); 351 352 auto VFS = FS ? std::move(FS) : vfs::getRealFileSystem(); 353 MIRSampleLoader = std::make_unique<MIRProfileLoader>( 354 FileName, RemappingFileName, std::move(VFS)); 355 assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit"); 356 } 357 358 bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction &MF) { 359 if (!MIRSampleLoader->isValid()) 360 return false; 361 362 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: " 363 << MF.getFunction().getName() << "\n"); 364 MBFI = &getAnalysis<MachineBlockFrequencyInfo>(); 365 MIRSampleLoader->setInitVals( 366 &getAnalysis<MachineDominatorTree>(), 367 &getAnalysis<MachinePostDominatorTree>(), &getAnalysis<MachineLoopInfo>(), 368 MBFI, &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE()); 369 370 MF.RenumberBlocks(); 371 if (ViewBFIBefore && ViewBlockLayoutWithBFI != GVDT_None && 372 (ViewBlockFreqFuncName.empty() || 373 MF.getFunction().getName().equals(ViewBlockFreqFuncName))) { 374 MBFI->view("MIR_Prof_loader_b." + MF.getName(), false); 375 } 376 377 bool Changed = MIRSampleLoader->runOnFunction(MF); 378 if (Changed) 379 MBFI->calculate(MF, *MBFI->getMBPI(), *&getAnalysis<MachineLoopInfo>()); 380 381 if (ViewBFIAfter && ViewBlockLayoutWithBFI != GVDT_None && 382 (ViewBlockFreqFuncName.empty() || 383 MF.getFunction().getName().equals(ViewBlockFreqFuncName))) { 384 MBFI->view("MIR_prof_loader_a." + MF.getName(), false); 385 } 386 387 return Changed; 388 } 389 390 bool MIRProfileLoaderPass::doInitialization(Module &M) { 391 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M.getName() 392 << "\n"); 393 394 MIRSampleLoader->setFSPass(P); 395 return MIRSampleLoader->doInitialization(M); 396 } 397 398 void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage &AU) const { 399 AU.setPreservesAll(); 400 AU.addRequired<MachineBlockFrequencyInfo>(); 401 AU.addRequired<MachineDominatorTree>(); 402 AU.addRequired<MachinePostDominatorTree>(); 403 AU.addRequiredTransitive<MachineLoopInfo>(); 404 AU.addRequired<MachineOptimizationRemarkEmitterPass>(); 405 MachineFunctionPass::getAnalysisUsage(AU); 406 } 407