xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUSetWavePriority.cpp (revision 4fbb9c43aa44d9145151bb5f77d302ba01fb7551)
1 //===- AMDGPUSetWavePriority.cpp - Set wave priority ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
/// Pass to temporarily raise the wave priority from the start of the shader
/// function until its last VMEM instructions, to allow younger waves to issue
/// their VMEM instructions as well.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "AMDGPU.h"
17 #include "GCNSubtarget.h"
18 #include "MCTargetDesc/AMDGPUMCTargetDesc.h"
19 #include "SIInstrInfo.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/CodeGen/MachineFunctionPass.h"
22 #include "llvm/InitializePasses.h"
23 #include "llvm/Support/Allocator.h"
24 
25 using namespace llvm;
26 
27 #define DEBUG_TYPE "amdgpu-set-wave-priority"
28 
29 static cl::opt<unsigned> DefaultVALUInstsThreshold(
30     "amdgpu-set-wave-priority-valu-insts-threshold",
31     cl::desc("VALU instruction count threshold for adjusting wave priority"),
32     cl::init(100), cl::Hidden);
33 
34 namespace {
35 
36 struct MBBInfo {
37   MBBInfo() = default;
38   unsigned NumVALUInstsAtStart = 0;
39   bool MayReachVMEMLoad = false;
40   MachineInstr *LastVMEMLoad = nullptr;
41 };
42 
43 using MBBInfoSet = DenseMap<const MachineBasicBlock *, MBBInfo>;
44 
45 class AMDGPUSetWavePriority : public MachineFunctionPass {
46 public:
47   static char ID;
48 
49   AMDGPUSetWavePriority() : MachineFunctionPass(ID) {}
50 
51   StringRef getPassName() const override { return "Set wave priority"; }
52 
53   bool runOnMachineFunction(MachineFunction &MF) override;
54 
55 private:
56   MachineInstr *BuildSetprioMI(MachineBasicBlock &MBB,
57                                MachineBasicBlock::iterator I,
58                                unsigned priority) const;
59 
60   const SIInstrInfo *TII;
61 };
62 
63 } // End anonymous namespace.
64 
65 INITIALIZE_PASS(AMDGPUSetWavePriority, DEBUG_TYPE, "Set wave priority", false,
66                 false)
67 
68 char AMDGPUSetWavePriority::ID = 0;
69 
70 FunctionPass *llvm::createAMDGPUSetWavePriorityPass() {
71   return new AMDGPUSetWavePriority();
72 }
73 
74 MachineInstr *
75 AMDGPUSetWavePriority::BuildSetprioMI(MachineBasicBlock &MBB,
76                                       MachineBasicBlock::iterator I,
77                                       unsigned priority) const {
78   return BuildMI(MBB, I, DebugLoc(), TII->get(AMDGPU::S_SETPRIO))
79       .addImm(priority);
80 }
81 
82 // Checks that for every predecessor Pred that can reach a VMEM load,
83 // none of Pred's successors can reach a VMEM load.
84 static bool CanLowerPriorityDirectlyInPredecessors(const MachineBasicBlock &MBB,
85                                                    MBBInfoSet &MBBInfos) {
86   for (const MachineBasicBlock *Pred : MBB.predecessors()) {
87     if (!MBBInfos[Pred].MayReachVMEMLoad)
88       continue;
89     for (const MachineBasicBlock *Succ : Pred->successors()) {
90       if (MBBInfos[Succ].MayReachVMEMLoad)
91         return false;
92     }
93   }
94   return true;
95 }
96 
97 static bool isVMEMLoad(const MachineInstr &MI) {
98   return SIInstrInfo::isVMEM(MI) && MI.mayLoad();
99 }
100 
101 bool AMDGPUSetWavePriority::runOnMachineFunction(MachineFunction &MF) {
102   const unsigned HighPriority = 3;
103   const unsigned LowPriority = 0;
104 
105   Function &F = MF.getFunction();
106   if (skipFunction(F) || !AMDGPU::isEntryFunctionCC(F.getCallingConv()))
107     return false;
108 
109   const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
110   TII = ST.getInstrInfo();
111 
112   unsigned VALUInstsThreshold = DefaultVALUInstsThreshold;
113   Attribute A = F.getFnAttribute("amdgpu-wave-priority-threshold");
114   if (A.isValid())
115     A.getValueAsString().getAsInteger(0, VALUInstsThreshold);
116 
117   // Find VMEM loads that may be executed before long-enough sequences of
118   // VALU instructions. We currently assume that backedges/loops, branch
119   // probabilities and other details can be ignored, so we essentially
120   // determine the largest number of VALU instructions along every
121   // possible path from the start of the function that may potentially be
122   // executed provided no backedge is ever taken.
123   MBBInfoSet MBBInfos;
124   for (MachineBasicBlock *MBB : post_order(&MF)) {
125     bool AtStart = true;
126     unsigned MaxNumVALUInstsInMiddle = 0;
127     unsigned NumVALUInstsAtEnd = 0;
128     for (MachineInstr &MI : *MBB) {
129       if (isVMEMLoad(MI)) {
130         AtStart = false;
131         MBBInfo &Info = MBBInfos[MBB];
132         Info.NumVALUInstsAtStart = 0;
133         MaxNumVALUInstsInMiddle = 0;
134         NumVALUInstsAtEnd = 0;
135         Info.LastVMEMLoad = &MI;
136       } else if (SIInstrInfo::isDS(MI)) {
137         AtStart = false;
138         MaxNumVALUInstsInMiddle =
139             std::max(MaxNumVALUInstsInMiddle, NumVALUInstsAtEnd);
140         NumVALUInstsAtEnd = 0;
141       } else if (SIInstrInfo::isVALU(MI)) {
142         if (AtStart)
143           ++MBBInfos[MBB].NumVALUInstsAtStart;
144         ++NumVALUInstsAtEnd;
145       }
146     }
147 
148     bool SuccsMayReachVMEMLoad = false;
149     unsigned NumFollowingVALUInsts = 0;
150     for (const MachineBasicBlock *Succ : MBB->successors()) {
151       SuccsMayReachVMEMLoad |= MBBInfos[Succ].MayReachVMEMLoad;
152       NumFollowingVALUInsts =
153           std::max(NumFollowingVALUInsts, MBBInfos[Succ].NumVALUInstsAtStart);
154     }
155     MBBInfo &Info = MBBInfos[MBB];
156     if (AtStart)
157       Info.NumVALUInstsAtStart += NumFollowingVALUInsts;
158     NumVALUInstsAtEnd += NumFollowingVALUInsts;
159 
160     unsigned MaxNumVALUInsts =
161         std::max(MaxNumVALUInstsInMiddle, NumVALUInstsAtEnd);
162     Info.MayReachVMEMLoad =
163         SuccsMayReachVMEMLoad ||
164         (Info.LastVMEMLoad && MaxNumVALUInsts >= VALUInstsThreshold);
165   }
166 
167   MachineBasicBlock &Entry = MF.front();
168   if (!MBBInfos[&Entry].MayReachVMEMLoad)
169     return false;
170 
171   // Raise the priority at the beginning of the shader.
172   MachineBasicBlock::iterator I = Entry.begin(), E = Entry.end();
173   while (I != E && !SIInstrInfo::isVALU(*I) && !I->isTerminator())
174     ++I;
175   BuildSetprioMI(Entry, I, HighPriority);
176 
177   // Lower the priority on edges where control leaves blocks from which
178   // the VMEM loads are reachable.
179   SmallSet<MachineBasicBlock *, 16> PriorityLoweringBlocks;
180   for (MachineBasicBlock &MBB : MF) {
181     if (MBBInfos[&MBB].MayReachVMEMLoad) {
182       if (MBB.succ_empty())
183         PriorityLoweringBlocks.insert(&MBB);
184       continue;
185     }
186 
187     if (CanLowerPriorityDirectlyInPredecessors(MBB, MBBInfos)) {
188       for (MachineBasicBlock *Pred : MBB.predecessors()) {
189         if (MBBInfos[Pred].MayReachVMEMLoad)
190           PriorityLoweringBlocks.insert(Pred);
191       }
192       continue;
193     }
194 
195     // Where lowering the priority in predecessors is not possible, the
196     // block receiving control either was not part of a loop in the first
197     // place or the loop simplification/canonicalization pass should have
198     // already tried to split the edge and insert a preheader, and if for
199     // whatever reason it failed to do so, then this leaves us with the
200     // only option of lowering the priority within the loop.
201     PriorityLoweringBlocks.insert(&MBB);
202   }
203 
204   for (MachineBasicBlock *MBB : PriorityLoweringBlocks) {
205     BuildSetprioMI(
206         *MBB,
207         MBBInfos[MBB].LastVMEMLoad
208             ? std::next(MachineBasicBlock::iterator(MBBInfos[MBB].LastVMEMLoad))
209             : MBB->begin(),
210         LowPriority);
211   }
212 
213   return true;
214 }
215