1 //===-- AMDGPULowerKernelAttributes.cpp ------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file This pass does attempts to make use of reqd_work_group_size metadata 10 /// to eliminate loads from the dispatch packet and to constant fold OpenCL 11 /// get_local_size-like functions. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "AMDGPU.h" 16 #include "llvm/Analysis/ValueTracking.h" 17 #include "llvm/CodeGen/Passes.h" 18 #include "llvm/CodeGen/TargetPassConfig.h" 19 #include "llvm/IR/Constants.h" 20 #include "llvm/IR/Function.h" 21 #include "llvm/IR/InstIterator.h" 22 #include "llvm/IR/Instructions.h" 23 #include "llvm/IR/IntrinsicsAMDGPU.h" 24 #include "llvm/IR/PatternMatch.h" 25 #include "llvm/Pass.h" 26 27 #define DEBUG_TYPE "amdgpu-lower-kernel-attributes" 28 29 using namespace llvm; 30 31 namespace { 32 33 // Field offsets in hsa_kernel_dispatch_packet_t. 34 enum DispatchPackedOffsets { 35 WORKGROUP_SIZE_X = 4, 36 WORKGROUP_SIZE_Y = 6, 37 WORKGROUP_SIZE_Z = 8, 38 39 GRID_SIZE_X = 12, 40 GRID_SIZE_Y = 16, 41 GRID_SIZE_Z = 20 42 }; 43 44 class AMDGPULowerKernelAttributes : public ModulePass { 45 public: 46 static char ID; 47 48 AMDGPULowerKernelAttributes() : ModulePass(ID) {} 49 50 bool runOnModule(Module &M) override; 51 52 StringRef getPassName() const override { 53 return "AMDGPU Kernel Attributes"; 54 } 55 56 void getAnalysisUsage(AnalysisUsage &AU) const override { 57 AU.setPreservesAll(); 58 } 59 }; 60 61 } // end anonymous namespace 62 63 static bool processUse(CallInst *CI) { 64 Function *F = CI->getParent()->getParent(); 65 66 auto MD = F->getMetadata("reqd_work_group_size"); 67 const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3; 68 69 const bool HasUniformWorkGroupSize = 70 F->getFnAttribute("uniform-work-group-size").getValueAsString() == "true"; 71 72 if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize) 73 return false; 74 75 Value *WorkGroupSizeX = nullptr; 76 Value *WorkGroupSizeY = nullptr; 77 Value *WorkGroupSizeZ = nullptr; 78 79 Value *GridSizeX = nullptr; 80 Value *GridSizeY = nullptr; 81 Value *GridSizeZ = nullptr; 82 83 const DataLayout &DL = F->getParent()->getDataLayout(); 84 85 // We expect to see several GEP users, casted to the appropriate type and 86 // loaded. 87 for (User *U : CI->users()) { 88 if (!U->hasOneUse()) 89 continue; 90 91 int64_t Offset = 0; 92 if (GetPointerBaseWithConstantOffset(U, Offset, DL) != CI) 93 continue; 94 95 auto *BCI = dyn_cast<BitCastInst>(*U->user_begin()); 96 if (!BCI || !BCI->hasOneUse()) 97 continue; 98 99 auto *Load = dyn_cast<LoadInst>(*BCI->user_begin()); 100 if (!Load || !Load->isSimple()) 101 continue; 102 103 unsigned LoadSize = DL.getTypeStoreSize(Load->getType()); 104 105 // TODO: Handle merged loads. 106 switch (Offset) { 107 case WORKGROUP_SIZE_X: 108 if (LoadSize == 2) 109 WorkGroupSizeX = Load; 110 break; 111 case WORKGROUP_SIZE_Y: 112 if (LoadSize == 2) 113 WorkGroupSizeY = Load; 114 break; 115 case WORKGROUP_SIZE_Z: 116 if (LoadSize == 2) 117 WorkGroupSizeZ = Load; 118 break; 119 case GRID_SIZE_X: 120 if (LoadSize == 4) 121 GridSizeX = Load; 122 break; 123 case GRID_SIZE_Y: 124 if (LoadSize == 4) 125 GridSizeY = Load; 126 break; 127 case GRID_SIZE_Z: 128 if (LoadSize == 4) 129 GridSizeZ = Load; 130 break; 131 default: 132 break; 133 } 134 } 135 136 // Pattern match the code used to handle partial workgroup dispatches in the 137 // library implementation of get_local_size, so the entire function can be 138 // constant folded with a known group size. 139 // 140 // uint r = grid_size - group_id * group_size; 141 // get_local_size = (r < group_size) ? r : group_size; 142 // 143 // If we have uniform-work-group-size (which is the default in OpenCL 1.2), 144 // the grid_size is required to be a multiple of group_size). In this case: 145 // 146 // grid_size - (group_id * group_size) < group_size 147 // -> 148 // grid_size < group_size + (group_id * group_size) 149 // 150 // (grid_size / group_size) < 1 + group_id 151 // 152 // grid_size / group_size is at least 1, so we can conclude the select 153 // condition is false (except for group_id == 0, where the select result is 154 // the same). 155 156 bool MadeChange = false; 157 Value *WorkGroupSizes[3] = { WorkGroupSizeX, WorkGroupSizeY, WorkGroupSizeZ }; 158 Value *GridSizes[3] = { GridSizeX, GridSizeY, GridSizeZ }; 159 160 for (int I = 0; HasUniformWorkGroupSize && I < 3; ++I) { 161 Value *GroupSize = WorkGroupSizes[I]; 162 Value *GridSize = GridSizes[I]; 163 if (!GroupSize || !GridSize) 164 continue; 165 166 for (User *U : GroupSize->users()) { 167 auto *ZextGroupSize = dyn_cast<ZExtInst>(U); 168 if (!ZextGroupSize) 169 continue; 170 171 for (User *ZextUser : ZextGroupSize->users()) { 172 auto *SI = dyn_cast<SelectInst>(ZextUser); 173 if (!SI) 174 continue; 175 176 using namespace llvm::PatternMatch; 177 auto GroupIDIntrin = I == 0 ? 178 m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>() : 179 (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>() : 180 m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>()); 181 182 auto SubExpr = m_Sub(m_Specific(GridSize), 183 m_Mul(GroupIDIntrin, m_Specific(ZextGroupSize))); 184 185 ICmpInst::Predicate Pred; 186 if (match(SI, 187 m_Select(m_ICmp(Pred, SubExpr, m_Specific(ZextGroupSize)), 188 SubExpr, 189 m_Specific(ZextGroupSize))) && 190 Pred == ICmpInst::ICMP_ULT) { 191 if (HasReqdWorkGroupSize) { 192 ConstantInt *KnownSize 193 = mdconst::extract<ConstantInt>(MD->getOperand(I)); 194 SI->replaceAllUsesWith(ConstantExpr::getIntegerCast(KnownSize, 195 SI->getType(), 196 false)); 197 } else { 198 SI->replaceAllUsesWith(ZextGroupSize); 199 } 200 201 MadeChange = true; 202 } 203 } 204 } 205 } 206 207 if (!HasReqdWorkGroupSize) 208 return MadeChange; 209 210 // Eliminate any other loads we can from the dispatch packet. 211 for (int I = 0; I < 3; ++I) { 212 Value *GroupSize = WorkGroupSizes[I]; 213 if (!GroupSize) 214 continue; 215 216 ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I)); 217 GroupSize->replaceAllUsesWith( 218 ConstantExpr::getIntegerCast(KnownSize, 219 GroupSize->getType(), 220 false)); 221 MadeChange = true; 222 } 223 224 return MadeChange; 225 } 226 227 // TODO: Move makeLIDRangeMetadata usage into here. Seem to not get 228 // TargetPassConfig for subtarget. 229 bool AMDGPULowerKernelAttributes::runOnModule(Module &M) { 230 StringRef DispatchPtrName 231 = Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr); 232 233 Function *DispatchPtr = M.getFunction(DispatchPtrName); 234 if (!DispatchPtr) // Dispatch ptr not used. 235 return false; 236 237 bool MadeChange = false; 238 239 SmallPtrSet<Instruction *, 4> HandledUses; 240 for (auto *U : DispatchPtr->users()) { 241 CallInst *CI = cast<CallInst>(U); 242 if (HandledUses.insert(CI).second) { 243 if (processUse(CI)) 244 MadeChange = true; 245 } 246 } 247 248 return MadeChange; 249 } 250 251 INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE, 252 "AMDGPU IR optimizations", false, false) 253 INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE, "AMDGPU IR optimizations", 254 false, false) 255 256 char AMDGPULowerKernelAttributes::ID = 0; 257 258 ModulePass *llvm::createAMDGPULowerKernelAttributesPass() { 259 return new AMDGPULowerKernelAttributes(); 260 } 261 262 PreservedAnalyses 263 AMDGPULowerKernelAttributesPass::run(Function &F, FunctionAnalysisManager &AM) { 264 StringRef DispatchPtrName = 265 Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr); 266 267 Function *DispatchPtr = F.getParent()->getFunction(DispatchPtrName); 268 if (!DispatchPtr) // Dispatch ptr not used. 269 return PreservedAnalyses::all(); 270 271 for (Instruction &I : instructions(F)) { 272 if (CallInst *CI = dyn_cast<CallInst>(&I)) { 273 if (CI->getCalledFunction() == DispatchPtr) 274 processUse(CI); 275 } 276 } 277 278 return PreservedAnalyses::all(); 279 } 280