//===-- AMDGPULowerKernelAttributes.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This pass attempts to make use of reqd_work_group_size metadata
/// to eliminate loads from the dispatch packet and to constant fold OpenCL
/// get_local_size-like functions.
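///
/// For example (illustrative only, not taken from a specific test): in a
/// kernel carrying !reqd_work_group_size !{i32 64, i32 1, i32 1}, a 16-bit
/// load of the dispatch packet's workgroup_size_x field can be replaced with
/// the constant 64.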
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "Utils/AMDGPUBaseInfo.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
#include <limits>

#define DEBUG_TYPE "amdgpu-lower-kernel-attributes"

using namespace llvm;

namespace {

// Field offsets in hsa_kernel_dispatch_packet_t.
enum DispatchPackedOffsets {
  WORKGROUP_SIZE_X = 4,
  WORKGROUP_SIZE_Y = 6,
  WORKGROUP_SIZE_Z = 8,

  GRID_SIZE_X = 12,
  GRID_SIZE_Y = 16,
  GRID_SIZE_Z = 20
};
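
// For reference: in hsa_kernel_dispatch_packet_t the workgroup_size_* fields
// are 16-bit and the grid_size_* fields are 32-bit, which is why the code
// below looks for 2- and 4-byte loads respectively.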

// Field offsets (in bytes) from the implicit kernel argument pointer.
enum ImplicitArgOffsets {
  HIDDEN_BLOCK_COUNT_X = 0,
  HIDDEN_BLOCK_COUNT_Y = 4,
  HIDDEN_BLOCK_COUNT_Z = 8,

  HIDDEN_GROUP_SIZE_X = 12,
  HIDDEN_GROUP_SIZE_Y = 14,
  HIDDEN_GROUP_SIZE_Z = 16,

  HIDDEN_REMAINDER_X = 18,
  HIDDEN_REMAINDER_Y = 20,
  HIDDEN_REMAINDER_Z = 22,
};
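
// For reference: the hidden block counts are 32-bit values, while the hidden
// group sizes and remainders are 16-bit, matching the 4- and 2-byte load
// sizes checked below.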

class AMDGPULowerKernelAttributes : public ModulePass {
public:
  static char ID;

  AMDGPULowerKernelAttributes() : ModulePass(ID) {}

  bool runOnModule(Module &M) override;

  StringRef getPassName() const override {
    return "AMDGPU Kernel Attributes";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesAll();
  }
};

Function *getBasePtrIntrinsic(Module &M, bool IsV5OrAbove) {
  auto IntrinsicId = IsV5OrAbove ? Intrinsic::amdgcn_implicitarg_ptr
                                 : Intrinsic::amdgcn_dispatch_ptr;
  return Intrinsic::getDeclarationIfExists(&M, IntrinsicId);
}

} // end anonymous namespace

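// Annotate a 32-bit load of a hidden block-count field with !range metadata:
// a dispatch always launches at least one workgroup per dimension, and the
// "amdgpu-max-num-workgroups" attribute (when nonzero) gives an upper bound.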
static void annotateGridSizeLoadWithRangeMD(LoadInst *Load,
                                            uint32_t MaxNumGroups) {
  if (MaxNumGroups == 0 ||
      MaxNumGroups == std::numeric_limits<uint32_t>::max())
    return;

  if (!Load->getType()->isIntegerTy(32))
    return;

  // TODO: If there is existing range metadata, preserve it if it is stricter.
  MDBuilder MDB(Load->getContext());
  MDNode *Range = MDB.createRange(APInt(32, 1), APInt(32, MaxNumGroups + 1));
  Load->setMetadata(LLVMContext::MD_range, Range);
}

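// Inspect the users of a call to the base-pointer intrinsic and try to fold
// the loads of dispatch-packet / implicit-argument fields they feed. With
// reqd_work_group_size the group-size loads become constants; with
// uniform-work-group-size the partial-workgroup handling in the device
// library's get_local_size implementation can be folded away.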
static bool processUse(CallInst *CI, bool IsV5OrAbove) {
  Function *F = CI->getParent()->getParent();

  auto *MD = F->getMetadata("reqd_work_group_size");
  const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3;

  const bool HasUniformWorkGroupSize =
      F->getFnAttribute("uniform-work-group-size").getValueAsBool();

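  // Query the optional "amdgpu-max-num-workgroups" attribute; a 0 in any
  // dimension means no bound is known for that dimension.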
  SmallVector<unsigned> MaxNumWorkgroups =
      AMDGPU::getIntegerVecAttribute(*F, "amdgpu-max-num-workgroups",
                                     /*Size=*/3, /*DefaultVal=*/0);

  if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize &&
      none_of(MaxNumWorkgroups, [](unsigned X) { return X != 0; }))
    return false;

  Value *BlockCounts[3] = {nullptr, nullptr, nullptr};
  Value *GroupSizes[3] = {nullptr, nullptr, nullptr};
  Value *Remainders[3] = {nullptr, nullptr, nullptr};
  Value *GridSizes[3] = {nullptr, nullptr, nullptr};

  const DataLayout &DL = F->getDataLayout();

  // We expect to see several GEP users, casted to the appropriate type and
  // loaded.
  for (User *U : CI->users()) {
    if (!U->hasOneUse())
      continue;

    int64_t Offset = 0;
    auto *Load = dyn_cast<LoadInst>(U); // Load from ImplicitArgPtr/DispatchPtr?
    auto *BCI = dyn_cast<BitCastInst>(U);
    if (!Load && !BCI) {
      if (GetPointerBaseWithConstantOffset(U, Offset, DL) != CI)
        continue;
      Load = dyn_cast<LoadInst>(*U->user_begin()); // Load from GEP?
      BCI = dyn_cast<BitCastInst>(*U->user_begin());
    }

    if (BCI) {
      if (!BCI->hasOneUse())
        continue;
      Load = dyn_cast<LoadInst>(*BCI->user_begin()); // Load from BCI?
    }

    if (!Load || !Load->isSimple())
      continue;

    unsigned LoadSize = DL.getTypeStoreSize(Load->getType());

    // TODO: Handle merged loads.
    if (IsV5OrAbove) { // Base is ImplicitArgPtr.
      switch (Offset) {
      case HIDDEN_BLOCK_COUNT_X:
        if (LoadSize == 4) {
          BlockCounts[0] = Load;
          annotateGridSizeLoadWithRangeMD(Load, MaxNumWorkgroups[0]);
        }
        break;
      case HIDDEN_BLOCK_COUNT_Y:
        if (LoadSize == 4) {
          BlockCounts[1] = Load;
          annotateGridSizeLoadWithRangeMD(Load, MaxNumWorkgroups[1]);
        }
        break;
      case HIDDEN_BLOCK_COUNT_Z:
        if (LoadSize == 4) {
          BlockCounts[2] = Load;
          annotateGridSizeLoadWithRangeMD(Load, MaxNumWorkgroups[2]);
        }
        break;
      case HIDDEN_GROUP_SIZE_X:
        if (LoadSize == 2)
          GroupSizes[0] = Load;
        break;
      case HIDDEN_GROUP_SIZE_Y:
        if (LoadSize == 2)
          GroupSizes[1] = Load;
        break;
      case HIDDEN_GROUP_SIZE_Z:
        if (LoadSize == 2)
          GroupSizes[2] = Load;
        break;
      case HIDDEN_REMAINDER_X:
        if (LoadSize == 2)
          Remainders[0] = Load;
        break;
      case HIDDEN_REMAINDER_Y:
        if (LoadSize == 2)
          Remainders[1] = Load;
        break;
      case HIDDEN_REMAINDER_Z:
        if (LoadSize == 2)
          Remainders[2] = Load;
        break;
      default:
        break;
      }
    } else { // Base is DispatchPtr.
      switch (Offset) {
      case WORKGROUP_SIZE_X:
        if (LoadSize == 2)
          GroupSizes[0] = Load;
        break;
      case WORKGROUP_SIZE_Y:
        if (LoadSize == 2)
          GroupSizes[1] = Load;
        break;
      case WORKGROUP_SIZE_Z:
        if (LoadSize == 2)
          GroupSizes[2] = Load;
        break;
      case GRID_SIZE_X:
        if (LoadSize == 4)
          GridSizes[0] = Load;
        break;
      case GRID_SIZE_Y:
        if (LoadSize == 4)
          GridSizes[1] = Load;
        break;
      case GRID_SIZE_Z:
        if (LoadSize == 4)
          GridSizes[2] = Load;
        break;
      default:
        break;
      }
    }
  }

  bool MadeChange = false;
  if (IsV5OrAbove && HasUniformWorkGroupSize) {
    // Under v5, __ockl_get_local_size returns the value computed by the
    // expression:
    //
    // workgroup_id < hidden_block_count ? hidden_group_size : hidden_remainder
    //
    // For functions with the attribute uniform-work-group-size=true, we can
    // evaluate workgroup_id < hidden_block_count as true, and thus
    // hidden_group_size is returned for __ockl_get_local_size.
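    //
    // Illustratively (an IR sketch, not taken from a specific test):
    //
    //   %cmp = icmp ult i32 %workgroup.id.x, %block.count.x
    //
    // folds to 'true' when %block.count.x is one of the block-count loads
    // matched above and %workgroup.id.x comes from llvm.amdgcn.workgroup.id.x.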
    for (int I = 0; I < 3; ++I) {
      Value *BlockCount = BlockCounts[I];
      if (!BlockCount)
        continue;

      using namespace llvm::PatternMatch;
      auto GroupIDIntrin =
          I == 0 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>()
                 : (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>()
                           : m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());

      for (User *ICmp : BlockCount->users()) {
        if (match(ICmp, m_SpecificICmp(ICmpInst::ICMP_ULT, GroupIDIntrin,
                                       m_Specific(BlockCount)))) {
          ICmp->replaceAllUsesWith(llvm::ConstantInt::getTrue(ICmp->getType()));
          MadeChange = true;
        }
      }
    }

    // All remainders should be 0 with uniform work group size.
    for (Value *Remainder : Remainders) {
      if (!Remainder)
        continue;
      Remainder->replaceAllUsesWith(
          Constant::getNullValue(Remainder->getType()));
      MadeChange = true;
    }
  } else if (HasUniformWorkGroupSize) { // Pre-V5.
    // Pattern match the code used to handle partial workgroup dispatches in
    // the library implementation of get_local_size, so the entire function can
    // be constant folded with a known group size.
    //
    // uint r = grid_size - group_id * group_size;
    // get_local_size = (r < group_size) ? r : group_size;
    //
    // If we have uniform-work-group-size (which is the default in OpenCL 1.2),
    // the grid_size is required to be a multiple of group_size. In this case:
    //
    // grid_size - (group_id * group_size) < group_size
    // ->
    // grid_size < group_size + (group_id * group_size)
    //
    // (grid_size / group_size) < 1 + group_id
    //
    // grid_size / group_size is at least 1, so we can conclude the select
    // condition is false (except for group_id == 0, where the select result is
    // the same).
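    //
    // A concrete (illustrative) example: with group_size = 64 and
    // grid_size = 256, r = 256 - group_id * 64 is 256, 192, 128 or 64 across
    // the four possible group_ids, so r < group_size never holds and the
    // select always produces group_size.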
    for (int I = 0; I < 3; ++I) {
      Value *GroupSize = GroupSizes[I];
      Value *GridSize = GridSizes[I];
      if (!GroupSize || !GridSize)
        continue;

      using namespace llvm::PatternMatch;
      auto GroupIDIntrin =
          I == 0 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>()
                 : (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>()
                           : m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());

      for (User *U : GroupSize->users()) {
        auto *ZextGroupSize = dyn_cast<ZExtInst>(U);
        if (!ZextGroupSize)
          continue;

        for (User *UMin : ZextGroupSize->users()) {
          if (match(UMin,
                    m_UMin(m_Sub(m_Specific(GridSize),
                                 m_Mul(GroupIDIntrin,
                                       m_Specific(ZextGroupSize))),
                           m_Specific(ZextGroupSize)))) {
            if (HasReqdWorkGroupSize) {
              ConstantInt *KnownSize =
                  mdconst::extract<ConstantInt>(MD->getOperand(I));
              UMin->replaceAllUsesWith(ConstantFoldIntegerCast(
                  KnownSize, UMin->getType(), false, DL));
            } else {
              UMin->replaceAllUsesWith(ZextGroupSize);
            }

            MadeChange = true;
          }
        }
      }
    }
  }

  // If reqd_work_group_size is set, we can replace work group size with it.
  if (!HasReqdWorkGroupSize)
    return MadeChange;

  for (int I = 0; I < 3; I++) {
    Value *GroupSize = GroupSizes[I];
    if (!GroupSize)
      continue;

    ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I));
    GroupSize->replaceAllUsesWith(
        ConstantFoldIntegerCast(KnownSize, GroupSize->getType(), false, DL));
    MadeChange = true;
  }

  return MadeChange;
}

// TODO: Move makeLIDRangeMetadata usage into here. Seems we do not get
// TargetPassConfig for the subtarget.
bool AMDGPULowerKernelAttributes::runOnModule(Module &M) {
  bool MadeChange = false;
  bool IsV5OrAbove =
      AMDGPU::getAMDHSACodeObjectVersion(M) >= AMDGPU::AMDHSA_COV5;
  Function *BasePtr = getBasePtrIntrinsic(M, IsV5OrAbove);

  if (!BasePtr) // ImplicitArgPtr/DispatchPtr not used.
    return false;

  SmallPtrSet<Instruction *, 4> HandledUses;
  for (auto *U : BasePtr->users()) {
    CallInst *CI = cast<CallInst>(U);
    if (HandledUses.insert(CI).second) {
      if (processUse(CI, IsV5OrAbove))
        MadeChange = true;
    }
  }

  return MadeChange;
}

INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE,
                      "AMDGPU Kernel Attributes", false, false)
INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE,
                    "AMDGPU Kernel Attributes", false, false)

char AMDGPULowerKernelAttributes::ID = 0;

ModulePass *llvm::createAMDGPULowerKernelAttributesPass() {
  return new AMDGPULowerKernelAttributes();
}

PreservedAnalyses
AMDGPULowerKernelAttributesPass::run(Function &F,
                                     FunctionAnalysisManager &AM) {
  bool IsV5OrAbove =
      AMDGPU::getAMDHSACodeObjectVersion(*F.getParent()) >= AMDGPU::AMDHSA_COV5;
  Function *BasePtr = getBasePtrIntrinsic(*F.getParent(), IsV5OrAbove);

  if (!BasePtr) // ImplicitArgPtr/DispatchPtr not used.
    return PreservedAnalyses::all();

  for (Instruction &I : instructions(F)) {
    if (CallInst *CI = dyn_cast<CallInst>(&I)) {
      if (CI->getCalledFunction() == BasePtr)
        processUse(CI, IsV5OrAbove);
    }
  }

  return PreservedAnalyses::all();
}