xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/MIRSampleProfile.cpp (revision d54a7d337331d991e039e4f42f6b4dc64aedce08)
1  //===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===//
2  //
3  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4  // See https://llvm.org/LICENSE.txt for license information.
5  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6  //
7  //===----------------------------------------------------------------------===//
8  //
9  // This file provides the implementation of the MIRSampleProfile loader, mainly
10  // for flow sensitive SampleFDO.
11  //
12  //===----------------------------------------------------------------------===//
13  
14  #include "llvm/CodeGen/MIRSampleProfile.h"
15  #include "llvm/ADT/DenseMap.h"
16  #include "llvm/ADT/DenseSet.h"
17  #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
18  #include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
19  #include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
20  #include "llvm/CodeGen/MachineDominators.h"
21  #include "llvm/CodeGen/MachineLoopInfo.h"
22  #include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h"
23  #include "llvm/CodeGen/MachinePostDominators.h"
24  #include "llvm/CodeGen/Passes.h"
25  #include "llvm/IR/Function.h"
26  #include "llvm/InitializePasses.h"
27  #include "llvm/Support/CommandLine.h"
28  #include "llvm/Support/Debug.h"
29  #include "llvm/Support/raw_ostream.h"
30  #include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h"
31  #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h"
32  
33  using namespace llvm;
34  using namespace sampleprof;
35  using namespace llvm::sampleprofutil;
36  using ProfileCount = Function::ProfileCount;
37  
38  #define DEBUG_TYPE "fs-profile-loader"
39  
40  static cl::opt<bool> ShowFSBranchProb(
41      "show-fs-branchprob", cl::Hidden, cl::init(false),
42      cl::desc("Print setting flow sensitive branch probabilities"));
43  static cl::opt<unsigned> FSProfileDebugProbDiffThreshold(
44      "fs-profile-debug-prob-diff-threshold", cl::init(10),
45      cl::desc("Only show debug message if the branch probility is greater than "
46               "this value (in percentage)."));
47  
48  static cl::opt<unsigned> FSProfileDebugBWThreshold(
49      "fs-profile-debug-bw-threshold", cl::init(10000),
50      cl::desc("Only show debug message if the source branch weight is greater "
51               " than this value."));
52  
53  static cl::opt<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden,
54                                     cl::init(false),
55                                     cl::desc("View BFI before MIR loader"));
56  static cl::opt<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden,
57                                    cl::init(false),
58                                    cl::desc("View BFI after MIR loader"));
59  
// Storage for the legacy pass identifier of MIRProfileLoaderPass; the
// registration macros below and MIRProfileLoaderPassID refer to it.
char MIRProfileLoaderPass::ID = 0;

// Register the pass with the legacy pass registry under DEBUG_TYPE
// ("fs-profile-loader"), listing the machine analyses it requests in
// getAnalysisUsage() as dependencies. It neither only looks at the CFG nor is
// it an analysis.
INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE,
                      "Load MIR Sample Profile",
                      /* cfg = */ false, /* is_analysis = */ false)
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTree)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass)
INITIALIZE_PASS_END(MIRProfileLoaderPass, DEBUG_TYPE, "Load MIR Sample Profile",
                    /* cfg = */ false, /* is_analysis = */ false)

// Exported handle that lets other code refer to this pass by its ID.
char &llvm::MIRProfileLoaderPassID = MIRProfileLoaderPass::ID;
74  
75  FunctionPass *llvm::createMIRProfileLoaderPass(std::string File,
76                                                 std::string RemappingFile,
77                                                 FSDiscriminatorPass P) {
78    return new MIRProfileLoaderPass(File, RemappingFile, P);
79  }
80  
81  namespace llvm {
82  
83  // Internal option used to control BFI display only after MBP pass.
84  // Defined in CodeGen/MachineBlockFrequencyInfo.cpp:
85  // -view-block-layout-with-bfi={none | fraction | integer | count}
86  extern cl::opt<GVDAGType> ViewBlockLayoutWithBFI;
87  
88  // Command line option to specify the name of the function for CFG dump
89  // Defined in Analysis/BlockFrequencyInfo.cpp:  -view-bfi-func-name=
90  extern cl::opt<std::string> ViewBlockFreqFuncName;
91  
92  namespace afdo_detail {
93  template <> struct IRTraits<MachineBasicBlock> {
94    using InstructionT = MachineInstr;
95    using BasicBlockT = MachineBasicBlock;
96    using FunctionT = MachineFunction;
97    using BlockFrequencyInfoT = MachineBlockFrequencyInfo;
98    using LoopT = MachineLoop;
99    using LoopInfoPtrT = MachineLoopInfo *;
100    using DominatorTreePtrT = MachineDominatorTree *;
101    using PostDominatorTreePtrT = MachinePostDominatorTree *;
102    using PostDominatorTreeT = MachinePostDominatorTree;
103    using OptRemarkEmitterT = MachineOptimizationRemarkEmitter;
104    using OptRemarkAnalysisT = MachineOptimizationRemarkAnalysis;
105    using PredRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>;
106    using SuccRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>;
107    static Function &getFunction(MachineFunction &F) { return F.getFunction(); }
108    static const MachineBasicBlock *getEntryBB(const MachineFunction *F) {
109      return GraphTraits<const MachineFunction *>::getEntryNode(F);
110    }
111    static PredRangeT getPredecessors(MachineBasicBlock *BB) {
112      return BB->predecessors();
113    }
114    static SuccRangeT getSuccessors(MachineBasicBlock *BB) {
115      return BB->successors();
116    }
117  };
118  } // namespace afdo_detail
119  
120  class MIRProfileLoader final
121      : public SampleProfileLoaderBaseImpl<MachineBasicBlock> {
122  public:
123    void setInitVals(MachineDominatorTree *MDT, MachinePostDominatorTree *MPDT,
124                     MachineLoopInfo *MLI, MachineBlockFrequencyInfo *MBFI,
125                     MachineOptimizationRemarkEmitter *MORE) {
126      DT = MDT;
127      PDT = MPDT;
128      LI = MLI;
129      BFI = MBFI;
130      ORE = MORE;
131    }
132    void setFSPass(FSDiscriminatorPass Pass) {
133      P = Pass;
134      LowBit = getFSPassBitBegin(P);
135      HighBit = getFSPassBitEnd(P);
136      assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
137    }
138  
139    MIRProfileLoader(StringRef Name, StringRef RemapName)
140        : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName)) {
141    }
142  
143    void setBranchProbs(MachineFunction &F);
144    bool runOnFunction(MachineFunction &F);
145    bool doInitialization(Module &M);
146    bool isValid() const { return ProfileIsValid; }
147  
148  protected:
149    friend class SampleCoverageTracker;
150  
151    /// Hold the information of the basic block frequency.
152    MachineBlockFrequencyInfo *BFI;
153  
154    /// PassNum is the sequence number this pass is called, start from 1.
155    FSDiscriminatorPass P;
156  
157    // LowBit in the FS discriminator used by this instance. Note the number is
158    // 0-based. Base discrimnator use bit 0 to bit 11.
159    unsigned LowBit;
160    // HighwBit in the FS discriminator used by this instance. Note the number
161    // is 0-based.
162    unsigned HighBit;
163  
164    bool ProfileIsValid = true;
165  };
166  
167  template <>
168  void SampleProfileLoaderBaseImpl<
169      MachineBasicBlock>::computeDominanceAndLoopInfo(MachineFunction &F) {}
170  
171  void MIRProfileLoader::setBranchProbs(MachineFunction &F) {
172    LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n");
173    for (auto &BI : F) {
174      MachineBasicBlock *BB = &BI;
175      if (BB->succ_size() < 2)
176        continue;
177      const MachineBasicBlock *EC = EquivalenceClass[BB];
178      uint64_t BBWeight = BlockWeights[EC];
179      uint64_t SumEdgeWeight = 0;
180      for (MachineBasicBlock *Succ : BB->successors()) {
181        Edge E = std::make_pair(BB, Succ);
182        SumEdgeWeight += EdgeWeights[E];
183      }
184  
185      if (BBWeight != SumEdgeWeight) {
186        LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight="
187                          << BBWeight << " SumEdgeWeight= " << SumEdgeWeight
188                          << "\n");
189        BBWeight = SumEdgeWeight;
190      }
191      if (BBWeight == 0) {
192        LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n");
193        continue;
194      }
195  
196  #ifndef NDEBUG
197      uint64_t BBWeightOrig = BBWeight;
198  #endif
199      uint32_t MaxWeight = std::numeric_limits<uint32_t>::max();
200      uint32_t Factor = 1;
201      if (BBWeight > MaxWeight) {
202        Factor = BBWeight / MaxWeight + 1;
203        BBWeight /= Factor;
204        LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor << "\n");
205      }
206  
207      for (MachineBasicBlock::succ_iterator SI = BB->succ_begin(),
208                                            SE = BB->succ_end();
209           SI != SE; ++SI) {
210        MachineBasicBlock *Succ = *SI;
211        Edge E = std::make_pair(BB, Succ);
212        uint64_t EdgeWeight = EdgeWeights[E];
213        EdgeWeight /= Factor;
214  
215        assert(BBWeight >= EdgeWeight &&
216               "BBweight is larger than EdgeWeight -- should not happen.\n");
217  
218        BranchProbability OldProb = BFI->getMBPI()->getEdgeProbability(BB, SI);
219        BranchProbability NewProb(EdgeWeight, BBWeight);
220        if (OldProb == NewProb)
221          continue;
222        BB->setSuccProbability(SI, NewProb);
223  #ifndef NDEBUG
224        if (!ShowFSBranchProb)
225          continue;
226        bool Show = false;
227        BranchProbability Diff;
228        if (OldProb > NewProb)
229          Diff = OldProb - NewProb;
230        else
231          Diff = NewProb - OldProb;
232        Show = (Diff >= BranchProbability(FSProfileDebugProbDiffThreshold, 100));
233        Show &= (BBWeightOrig >= FSProfileDebugBWThreshold);
234  
235        auto DIL = BB->findBranchDebugLoc();
236        auto SuccDIL = Succ->findBranchDebugLoc();
237        if (Show) {
238          dbgs() << "Set branch fs prob: MBB (" << BB->getNumber() << " -> "
239                 << Succ->getNumber() << "): ";
240          if (DIL)
241            dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":"
242                   << DIL->getColumn();
243          if (SuccDIL)
244            dbgs() << "-->" << SuccDIL->getFilename() << ":" << SuccDIL->getLine()
245                   << ":" << SuccDIL->getColumn();
246          dbgs() << " W=" << BBWeightOrig << "  " << OldProb << " --> " << NewProb
247                 << "\n";
248        }
249  #endif
250      }
251    }
252  }
253  
254  bool MIRProfileLoader::doInitialization(Module &M) {
255    auto &Ctx = M.getContext();
256  
257    auto ReaderOrErr = sampleprof::SampleProfileReader::create(Filename, Ctx, P,
258                                                               RemappingFilename);
259    if (std::error_code EC = ReaderOrErr.getError()) {
260      std::string Msg = "Could not open profile: " + EC.message();
261      Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg));
262      return false;
263    }
264  
265    Reader = std::move(ReaderOrErr.get());
266    Reader->setModule(&M);
267    ProfileIsValid = (Reader->read() == sampleprof_error::success);
268    Reader->getSummary();
269  
270    return true;
271  }
272  
273  bool MIRProfileLoader::runOnFunction(MachineFunction &MF) {
274    Function &Func = MF.getFunction();
275    clearFunctionData(false);
276    Samples = Reader->getSamplesFor(Func);
277    if (!Samples || Samples->empty())
278      return false;
279  
280    if (getFunctionLoc(MF) == 0)
281      return false;
282  
283    DenseSet<GlobalValue::GUID> InlinedGUIDs;
284    bool Changed = computeAndPropagateWeights(MF, InlinedGUIDs);
285  
286    // Set the new BPI, BFI.
287    setBranchProbs(MF);
288  
289    return Changed;
290  }
291  
292  } // namespace llvm
293  
294  MIRProfileLoaderPass::MIRProfileLoaderPass(std::string FileName,
295                                             std::string RemappingFileName,
296                                             FSDiscriminatorPass P)
297      : MachineFunctionPass(ID), ProfileFileName(FileName), P(P),
298        MIRSampleLoader(
299            std::make_unique<MIRProfileLoader>(FileName, RemappingFileName)) {
300    LowBit = getFSPassBitBegin(P);
301    HighBit = getFSPassBitEnd(P);
302    assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
303  }
304  
305  bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction &MF) {
306    if (!MIRSampleLoader->isValid())
307      return false;
308  
309    LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: "
310                      << MF.getFunction().getName() << "\n");
311    MBFI = &getAnalysis<MachineBlockFrequencyInfo>();
312    MIRSampleLoader->setInitVals(
313        &getAnalysis<MachineDominatorTree>(),
314        &getAnalysis<MachinePostDominatorTree>(), &getAnalysis<MachineLoopInfo>(),
315        MBFI, &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE());
316  
317    MF.RenumberBlocks();
318    if (ViewBFIBefore && ViewBlockLayoutWithBFI != GVDT_None &&
319        (ViewBlockFreqFuncName.empty() ||
320         MF.getFunction().getName().equals(ViewBlockFreqFuncName))) {
321      MBFI->view("MIR_Prof_loader_b." + MF.getName(), false);
322    }
323  
324    bool Changed = MIRSampleLoader->runOnFunction(MF);
325    if (Changed)
326      MBFI->calculate(MF, *MBFI->getMBPI(), *&getAnalysis<MachineLoopInfo>());
327  
328    if (ViewBFIAfter && ViewBlockLayoutWithBFI != GVDT_None &&
329        (ViewBlockFreqFuncName.empty() ||
330         MF.getFunction().getName().equals(ViewBlockFreqFuncName))) {
331      MBFI->view("MIR_prof_loader_a." + MF.getName(), false);
332    }
333  
334    return Changed;
335  }
336  
337  bool MIRProfileLoaderPass::doInitialization(Module &M) {
338    LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M.getName()
339                      << "\n");
340  
341    MIRSampleLoader->setFSPass(P);
342    return MIRSampleLoader->doInitialization(M);
343  }
344  
/// Declare the machine analyses this pass uses. The pass reports that it
/// preserves everything; runOnMachineFunction recalculates
/// MachineBlockFrequencyInfo in place when branch probabilities change.
void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequired<MachineBlockFrequencyInfo>();
  AU.addRequired<MachineDominatorTree>();
  AU.addRequired<MachinePostDominatorTree>();
  // NOTE(review): presumably transitive because runOnMachineFunction
  // recalculates MBFI using this MachineLoopInfo, so it must outlive MBFI's
  // users; confirm.
  AU.addRequiredTransitive<MachineLoopInfo>();
  AU.addRequired<MachineOptimizationRemarkEmitterPass>();
  MachineFunctionPass::getAnalysisUsage(AU);
}
354