1 //===-- AMDGPULowerKernelAttributes.cpp ------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file This pass does attempts to make use of reqd_work_group_size metadata 10 /// to eliminate loads from the dispatch packet and to constant fold OpenCL 11 /// get_local_size-like functions. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "AMDGPU.h" 16 #include "llvm/Analysis/ValueTracking.h" 17 #include "llvm/CodeGen/Passes.h" 18 #include "llvm/CodeGen/TargetPassConfig.h" 19 #include "llvm/IR/Constants.h" 20 #include "llvm/IR/Function.h" 21 #include "llvm/IR/InstIterator.h" 22 #include "llvm/IR/Instructions.h" 23 #include "llvm/IR/IntrinsicsAMDGPU.h" 24 #include "llvm/IR/PatternMatch.h" 25 #include "llvm/Pass.h" 26 27 #define DEBUG_TYPE "amdgpu-lower-kernel-attributes" 28 29 using namespace llvm; 30 31 namespace { 32 33 // Field offsets in hsa_kernel_dispatch_packet_t. 34 enum DispatchPackedOffsets { 35 WORKGROUP_SIZE_X = 4, 36 WORKGROUP_SIZE_Y = 6, 37 WORKGROUP_SIZE_Z = 8, 38 39 GRID_SIZE_X = 12, 40 GRID_SIZE_Y = 16, 41 GRID_SIZE_Z = 20 42 }; 43 44 class AMDGPULowerKernelAttributes : public ModulePass { 45 public: 46 static char ID; 47 48 AMDGPULowerKernelAttributes() : ModulePass(ID) {} 49 50 bool runOnModule(Module &M) override; 51 52 StringRef getPassName() const override { 53 return "AMDGPU Kernel Attributes"; 54 } 55 56 void getAnalysisUsage(AnalysisUsage &AU) const override { 57 AU.setPreservesAll(); 58 } 59 }; 60 61 } // end anonymous namespace 62 63 static bool processUse(CallInst *CI) { 64 Function *F = CI->getParent()->getParent(); 65 66 auto MD = F->getMetadata("reqd_work_group_size"); 67 const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3; 68 69 const bool HasUniformWorkGroupSize = 70 F->getFnAttribute("uniform-work-group-size").getValueAsBool(); 71 72 if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize) 73 return false; 74 75 Value *WorkGroupSizeX = nullptr; 76 Value *WorkGroupSizeY = nullptr; 77 Value *WorkGroupSizeZ = nullptr; 78 79 Value *GridSizeX = nullptr; 80 Value *GridSizeY = nullptr; 81 Value *GridSizeZ = nullptr; 82 83 const DataLayout &DL = F->getParent()->getDataLayout(); 84 85 // We expect to see several GEP users, casted to the appropriate type and 86 // loaded. 87 for (User *U : CI->users()) { 88 if (!U->hasOneUse()) 89 continue; 90 91 int64_t Offset = 0; 92 if (GetPointerBaseWithConstantOffset(U, Offset, DL) != CI) 93 continue; 94 95 auto *BCI = dyn_cast<BitCastInst>(*U->user_begin()); 96 if (!BCI || !BCI->hasOneUse()) 97 continue; 98 99 auto *Load = dyn_cast<LoadInst>(*BCI->user_begin()); 100 if (!Load || !Load->isSimple()) 101 continue; 102 103 unsigned LoadSize = DL.getTypeStoreSize(Load->getType()); 104 105 // TODO: Handle merged loads. 106 switch (Offset) { 107 case WORKGROUP_SIZE_X: 108 if (LoadSize == 2) 109 WorkGroupSizeX = Load; 110 break; 111 case WORKGROUP_SIZE_Y: 112 if (LoadSize == 2) 113 WorkGroupSizeY = Load; 114 break; 115 case WORKGROUP_SIZE_Z: 116 if (LoadSize == 2) 117 WorkGroupSizeZ = Load; 118 break; 119 case GRID_SIZE_X: 120 if (LoadSize == 4) 121 GridSizeX = Load; 122 break; 123 case GRID_SIZE_Y: 124 if (LoadSize == 4) 125 GridSizeY = Load; 126 break; 127 case GRID_SIZE_Z: 128 if (LoadSize == 4) 129 GridSizeZ = Load; 130 break; 131 default: 132 break; 133 } 134 } 135 136 // Pattern match the code used to handle partial workgroup dispatches in the 137 // library implementation of get_local_size, so the entire function can be 138 // constant folded with a known group size. 139 // 140 // uint r = grid_size - group_id * group_size; 141 // get_local_size = (r < group_size) ? r : group_size; 142 // 143 // If we have uniform-work-group-size (which is the default in OpenCL 1.2), 144 // the grid_size is required to be a multiple of group_size). In this case: 145 // 146 // grid_size - (group_id * group_size) < group_size 147 // -> 148 // grid_size < group_size + (group_id * group_size) 149 // 150 // (grid_size / group_size) < 1 + group_id 151 // 152 // grid_size / group_size is at least 1, so we can conclude the select 153 // condition is false (except for group_id == 0, where the select result is 154 // the same). 155 156 bool MadeChange = false; 157 Value *WorkGroupSizes[3] = { WorkGroupSizeX, WorkGroupSizeY, WorkGroupSizeZ }; 158 Value *GridSizes[3] = { GridSizeX, GridSizeY, GridSizeZ }; 159 160 for (int I = 0; HasUniformWorkGroupSize && I < 3; ++I) { 161 Value *GroupSize = WorkGroupSizes[I]; 162 Value *GridSize = GridSizes[I]; 163 if (!GroupSize || !GridSize) 164 continue; 165 166 using namespace llvm::PatternMatch; 167 auto GroupIDIntrin = 168 I == 0 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>() 169 : (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>() 170 : m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>()); 171 172 for (User *U : GroupSize->users()) { 173 auto *ZextGroupSize = dyn_cast<ZExtInst>(U); 174 if (!ZextGroupSize) 175 continue; 176 177 for (User *UMin : ZextGroupSize->users()) { 178 if (match(UMin, 179 m_UMin(m_Sub(m_Specific(GridSize), 180 m_Mul(GroupIDIntrin, m_Specific(ZextGroupSize))), 181 m_Specific(ZextGroupSize)))) { 182 if (HasReqdWorkGroupSize) { 183 ConstantInt *KnownSize 184 = mdconst::extract<ConstantInt>(MD->getOperand(I)); 185 UMin->replaceAllUsesWith(ConstantExpr::getIntegerCast( 186 KnownSize, UMin->getType(), false)); 187 } else { 188 UMin->replaceAllUsesWith(ZextGroupSize); 189 } 190 191 MadeChange = true; 192 } 193 } 194 } 195 } 196 197 if (!HasReqdWorkGroupSize) 198 return MadeChange; 199 200 // Eliminate any other loads we can from the dispatch packet. 201 for (int I = 0; I < 3; ++I) { 202 Value *GroupSize = WorkGroupSizes[I]; 203 if (!GroupSize) 204 continue; 205 206 ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I)); 207 GroupSize->replaceAllUsesWith( 208 ConstantExpr::getIntegerCast(KnownSize, 209 GroupSize->getType(), 210 false)); 211 MadeChange = true; 212 } 213 214 return MadeChange; 215 } 216 217 // TODO: Move makeLIDRangeMetadata usage into here. Seem to not get 218 // TargetPassConfig for subtarget. 219 bool AMDGPULowerKernelAttributes::runOnModule(Module &M) { 220 StringRef DispatchPtrName 221 = Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr); 222 223 Function *DispatchPtr = M.getFunction(DispatchPtrName); 224 if (!DispatchPtr) // Dispatch ptr not used. 225 return false; 226 227 bool MadeChange = false; 228 229 SmallPtrSet<Instruction *, 4> HandledUses; 230 for (auto *U : DispatchPtr->users()) { 231 CallInst *CI = cast<CallInst>(U); 232 if (HandledUses.insert(CI).second) { 233 if (processUse(CI)) 234 MadeChange = true; 235 } 236 } 237 238 return MadeChange; 239 } 240 241 INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE, 242 "AMDGPU Kernel Attributes", false, false) 243 INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE, 244 "AMDGPU Kernel Attributes", false, false) 245 246 char AMDGPULowerKernelAttributes::ID = 0; 247 248 ModulePass *llvm::createAMDGPULowerKernelAttributesPass() { 249 return new AMDGPULowerKernelAttributes(); 250 } 251 252 PreservedAnalyses 253 AMDGPULowerKernelAttributesPass::run(Function &F, FunctionAnalysisManager &AM) { 254 StringRef DispatchPtrName = 255 Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr); 256 257 Function *DispatchPtr = F.getParent()->getFunction(DispatchPtrName); 258 if (!DispatchPtr) // Dispatch ptr not used. 259 return PreservedAnalyses::all(); 260 261 for (Instruction &I : instructions(F)) { 262 if (CallInst *CI = dyn_cast<CallInst>(&I)) { 263 if (CI->getCalledFunction() == DispatchPtr) 264 processUse(CI); 265 } 266 } 267 268 return PreservedAnalyses::all(); 269 } 270