1 //===- AMDGPUSetWavePriority.cpp - Set wave priority ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
/// Pass to temporarily raise the wave priority from the start of the
/// shader function until its last VMEM load instructions, to allow younger
/// waves to issue their VMEM instructions as well.
13 //
14 //===----------------------------------------------------------------------===//
15
16 #include "AMDGPU.h"
17 #include "GCNSubtarget.h"
18 #include "MCTargetDesc/AMDGPUMCTargetDesc.h"
19 #include "SIInstrInfo.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/CodeGen/MachineFunctionPass.h"
22 #include "llvm/CodeGen/MachinePassManager.h"
23
24 using namespace llvm;
25
26 #define DEBUG_TYPE "amdgpu-set-wave-priority"
27
28 static cl::opt<unsigned> DefaultVALUInstsThreshold(
29 "amdgpu-set-wave-priority-valu-insts-threshold",
30 cl::desc("VALU instruction count threshold for adjusting wave priority"),
31 cl::init(100), cl::Hidden);
32
33 namespace {
34
35 struct MBBInfo {
36 MBBInfo() = default;
37 unsigned NumVALUInstsAtStart = 0;
38 bool MayReachVMEMLoad = false;
39 MachineInstr *LastVMEMLoad = nullptr;
40 };
41
42 using MBBInfoSet = DenseMap<const MachineBasicBlock *, MBBInfo>;
43
44 class AMDGPUSetWavePriority {
45 public:
46 bool run(MachineFunction &MF);
47
48 private:
49 MachineInstr *BuildSetprioMI(MachineBasicBlock &MBB,
50 MachineBasicBlock::iterator I,
51 unsigned priority) const;
52
53 const SIInstrInfo *TII;
54 };
55
56 class AMDGPUSetWavePriorityLegacy : public MachineFunctionPass {
57 public:
58 static char ID;
59
AMDGPUSetWavePriorityLegacy()60 AMDGPUSetWavePriorityLegacy() : MachineFunctionPass(ID) {}
61
getPassName() const62 StringRef getPassName() const override { return "Set wave priority"; }
63
runOnMachineFunction(MachineFunction & MF)64 bool runOnMachineFunction(MachineFunction &MF) override {
65 if (skipFunction(MF.getFunction()))
66 return false;
67
68 return AMDGPUSetWavePriority().run(MF);
69 }
70 };
71
72 } // End anonymous namespace.
73
74 INITIALIZE_PASS(AMDGPUSetWavePriorityLegacy, DEBUG_TYPE, "Set wave priority",
75 false, false)
76
77 char AMDGPUSetWavePriorityLegacy::ID = 0;
78
createAMDGPUSetWavePriorityPass()79 FunctionPass *llvm::createAMDGPUSetWavePriorityPass() {
80 return new AMDGPUSetWavePriorityLegacy();
81 }
82
83 MachineInstr *
BuildSetprioMI(MachineBasicBlock & MBB,MachineBasicBlock::iterator I,unsigned priority) const84 AMDGPUSetWavePriority::BuildSetprioMI(MachineBasicBlock &MBB,
85 MachineBasicBlock::iterator I,
86 unsigned priority) const {
87 return BuildMI(MBB, I, DebugLoc(), TII->get(AMDGPU::S_SETPRIO))
88 .addImm(priority);
89 }
90
91 // Checks that for every predecessor Pred that can reach a VMEM load,
92 // none of Pred's successors can reach a VMEM load.
CanLowerPriorityDirectlyInPredecessors(const MachineBasicBlock & MBB,MBBInfoSet & MBBInfos)93 static bool CanLowerPriorityDirectlyInPredecessors(const MachineBasicBlock &MBB,
94 MBBInfoSet &MBBInfos) {
95 for (const MachineBasicBlock *Pred : MBB.predecessors()) {
96 if (!MBBInfos[Pred].MayReachVMEMLoad)
97 continue;
98 for (const MachineBasicBlock *Succ : Pred->successors()) {
99 if (MBBInfos[Succ].MayReachVMEMLoad)
100 return false;
101 }
102 }
103 return true;
104 }
105
isVMEMLoad(const MachineInstr & MI)106 static bool isVMEMLoad(const MachineInstr &MI) {
107 return SIInstrInfo::isVMEM(MI) && MI.mayLoad();
108 }
109
110 PreservedAnalyses
run(MachineFunction & MF,MachineFunctionAnalysisManager & MFAM)111 llvm::AMDGPUSetWavePriorityPass::run(MachineFunction &MF,
112 MachineFunctionAnalysisManager &MFAM) {
113 if (!AMDGPUSetWavePriority().run(MF))
114 return PreservedAnalyses::all();
115
116 return getMachineFunctionPassPreservedAnalyses();
117 }
118
run(MachineFunction & MF)119 bool AMDGPUSetWavePriority::run(MachineFunction &MF) {
120 const unsigned HighPriority = 3;
121 const unsigned LowPriority = 0;
122
123 Function &F = MF.getFunction();
124 if (!AMDGPU::isEntryFunctionCC(F.getCallingConv()))
125 return false;
126
127 const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
128 TII = ST.getInstrInfo();
129
130 unsigned VALUInstsThreshold = DefaultVALUInstsThreshold;
131 Attribute A = F.getFnAttribute("amdgpu-wave-priority-threshold");
132 if (A.isValid())
133 A.getValueAsString().getAsInteger(0, VALUInstsThreshold);
134
135 // Find VMEM loads that may be executed before long-enough sequences of
136 // VALU instructions. We currently assume that backedges/loops, branch
137 // probabilities and other details can be ignored, so we essentially
138 // determine the largest number of VALU instructions along every
139 // possible path from the start of the function that may potentially be
140 // executed provided no backedge is ever taken.
141 MBBInfoSet MBBInfos;
142 for (MachineBasicBlock *MBB : post_order(&MF)) {
143 bool AtStart = true;
144 unsigned MaxNumVALUInstsInMiddle = 0;
145 unsigned NumVALUInstsAtEnd = 0;
146 for (MachineInstr &MI : *MBB) {
147 if (isVMEMLoad(MI)) {
148 AtStart = false;
149 MBBInfo &Info = MBBInfos[MBB];
150 Info.NumVALUInstsAtStart = 0;
151 MaxNumVALUInstsInMiddle = 0;
152 NumVALUInstsAtEnd = 0;
153 Info.LastVMEMLoad = &MI;
154 } else if (SIInstrInfo::isDS(MI)) {
155 AtStart = false;
156 MaxNumVALUInstsInMiddle =
157 std::max(MaxNumVALUInstsInMiddle, NumVALUInstsAtEnd);
158 NumVALUInstsAtEnd = 0;
159 } else if (SIInstrInfo::isVALU(MI)) {
160 if (AtStart)
161 ++MBBInfos[MBB].NumVALUInstsAtStart;
162 ++NumVALUInstsAtEnd;
163 }
164 }
165
166 bool SuccsMayReachVMEMLoad = false;
167 unsigned NumFollowingVALUInsts = 0;
168 for (const MachineBasicBlock *Succ : MBB->successors()) {
169 const MBBInfo &SuccInfo = MBBInfos[Succ];
170 SuccsMayReachVMEMLoad |= SuccInfo.MayReachVMEMLoad;
171 NumFollowingVALUInsts =
172 std::max(NumFollowingVALUInsts, SuccInfo.NumVALUInstsAtStart);
173 }
174 MBBInfo &Info = MBBInfos[MBB];
175 if (AtStart)
176 Info.NumVALUInstsAtStart += NumFollowingVALUInsts;
177 NumVALUInstsAtEnd += NumFollowingVALUInsts;
178
179 unsigned MaxNumVALUInsts =
180 std::max(MaxNumVALUInstsInMiddle, NumVALUInstsAtEnd);
181 Info.MayReachVMEMLoad =
182 SuccsMayReachVMEMLoad ||
183 (Info.LastVMEMLoad && MaxNumVALUInsts >= VALUInstsThreshold);
184 }
185
186 MachineBasicBlock &Entry = MF.front();
187 if (!MBBInfos[&Entry].MayReachVMEMLoad)
188 return false;
189
190 // Raise the priority at the beginning of the shader.
191 MachineBasicBlock::iterator I = Entry.begin(), E = Entry.end();
192 while (I != E && !SIInstrInfo::isVALU(*I) && !I->isTerminator())
193 ++I;
194 BuildSetprioMI(Entry, I, HighPriority);
195
196 // Lower the priority on edges where control leaves blocks from which
197 // the VMEM loads are reachable.
198 SmallSet<MachineBasicBlock *, 16> PriorityLoweringBlocks;
199 for (MachineBasicBlock &MBB : MF) {
200 if (MBBInfos[&MBB].MayReachVMEMLoad) {
201 if (MBB.succ_empty())
202 PriorityLoweringBlocks.insert(&MBB);
203 continue;
204 }
205
206 if (CanLowerPriorityDirectlyInPredecessors(MBB, MBBInfos)) {
207 for (MachineBasicBlock *Pred : MBB.predecessors()) {
208 if (MBBInfos[Pred].MayReachVMEMLoad)
209 PriorityLoweringBlocks.insert(Pred);
210 }
211 continue;
212 }
213
214 // Where lowering the priority in predecessors is not possible, the
215 // block receiving control either was not part of a loop in the first
216 // place or the loop simplification/canonicalization pass should have
217 // already tried to split the edge and insert a preheader, and if for
218 // whatever reason it failed to do so, then this leaves us with the
219 // only option of lowering the priority within the loop.
220 PriorityLoweringBlocks.insert(&MBB);
221 }
222
223 for (MachineBasicBlock *MBB : PriorityLoweringBlocks) {
224 MachineInstr *LastVMEMLoad = MBBInfos[MBB].LastVMEMLoad;
225 BuildSetprioMI(*MBB,
226 LastVMEMLoad
227 ? std::next(MachineBasicBlock::iterator(LastVMEMLoad))
228 : MBB->begin(),
229 LowPriority);
230 }
231
232 return true;
233 }
234