xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPULowerKernelAttributes.cpp (revision fe6060f10f634930ff71b7c50291ddc610da2475)
10b57cec5SDimitry Andric //===-- AMDGPULowerKernelAttributes.cpp ------------------------------------------===//
20b57cec5SDimitry Andric //
30b57cec5SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40b57cec5SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
50b57cec5SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60b57cec5SDimitry Andric //
70b57cec5SDimitry Andric //===----------------------------------------------------------------------===//
80b57cec5SDimitry Andric //
/// \file This pass attempts to make use of reqd_work_group_size metadata
100b57cec5SDimitry Andric /// to eliminate loads from the dispatch packet and to constant fold OpenCL
110b57cec5SDimitry Andric /// get_local_size-like functions.
120b57cec5SDimitry Andric //
130b57cec5SDimitry Andric //===----------------------------------------------------------------------===//
140b57cec5SDimitry Andric 
150b57cec5SDimitry Andric #include "AMDGPU.h"
160b57cec5SDimitry Andric #include "llvm/Analysis/ValueTracking.h"
170b57cec5SDimitry Andric #include "llvm/CodeGen/Passes.h"
180b57cec5SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
190b57cec5SDimitry Andric #include "llvm/IR/Constants.h"
200b57cec5SDimitry Andric #include "llvm/IR/Function.h"
21e8d8bef9SDimitry Andric #include "llvm/IR/InstIterator.h"
220b57cec5SDimitry Andric #include "llvm/IR/Instructions.h"
23e8d8bef9SDimitry Andric #include "llvm/IR/IntrinsicsAMDGPU.h"
240b57cec5SDimitry Andric #include "llvm/IR/PatternMatch.h"
250b57cec5SDimitry Andric #include "llvm/Pass.h"
260b57cec5SDimitry Andric 
270b57cec5SDimitry Andric #define DEBUG_TYPE "amdgpu-lower-kernel-attributes"
280b57cec5SDimitry Andric 
290b57cec5SDimitry Andric using namespace llvm;
300b57cec5SDimitry Andric 
310b57cec5SDimitry Andric namespace {
320b57cec5SDimitry Andric 
330b57cec5SDimitry Andric // Field offsets in hsa_kernel_dispatch_packet_t.
340b57cec5SDimitry Andric enum DispatchPackedOffsets {
350b57cec5SDimitry Andric   WORKGROUP_SIZE_X = 4,
360b57cec5SDimitry Andric   WORKGROUP_SIZE_Y = 6,
370b57cec5SDimitry Andric   WORKGROUP_SIZE_Z = 8,
380b57cec5SDimitry Andric 
390b57cec5SDimitry Andric   GRID_SIZE_X = 12,
400b57cec5SDimitry Andric   GRID_SIZE_Y = 16,
410b57cec5SDimitry Andric   GRID_SIZE_Z = 20
420b57cec5SDimitry Andric };
430b57cec5SDimitry Andric 
440b57cec5SDimitry Andric class AMDGPULowerKernelAttributes : public ModulePass {
450b57cec5SDimitry Andric public:
460b57cec5SDimitry Andric   static char ID;
470b57cec5SDimitry Andric 
480b57cec5SDimitry Andric   AMDGPULowerKernelAttributes() : ModulePass(ID) {}
490b57cec5SDimitry Andric 
500b57cec5SDimitry Andric   bool runOnModule(Module &M) override;
510b57cec5SDimitry Andric 
520b57cec5SDimitry Andric   StringRef getPassName() const override {
530b57cec5SDimitry Andric     return "AMDGPU Kernel Attributes";
540b57cec5SDimitry Andric   }
550b57cec5SDimitry Andric 
560b57cec5SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
570b57cec5SDimitry Andric     AU.setPreservesAll();
580b57cec5SDimitry Andric  }
590b57cec5SDimitry Andric };
600b57cec5SDimitry Andric 
610b57cec5SDimitry Andric } // end anonymous namespace
620b57cec5SDimitry Andric 
63e8d8bef9SDimitry Andric static bool processUse(CallInst *CI) {
640b57cec5SDimitry Andric   Function *F = CI->getParent()->getParent();
650b57cec5SDimitry Andric 
660b57cec5SDimitry Andric   auto MD = F->getMetadata("reqd_work_group_size");
670b57cec5SDimitry Andric   const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3;
680b57cec5SDimitry Andric 
690b57cec5SDimitry Andric   const bool HasUniformWorkGroupSize =
70*fe6060f1SDimitry Andric     F->getFnAttribute("uniform-work-group-size").getValueAsBool();
710b57cec5SDimitry Andric 
720b57cec5SDimitry Andric   if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize)
730b57cec5SDimitry Andric     return false;
740b57cec5SDimitry Andric 
750b57cec5SDimitry Andric   Value *WorkGroupSizeX = nullptr;
760b57cec5SDimitry Andric   Value *WorkGroupSizeY = nullptr;
770b57cec5SDimitry Andric   Value *WorkGroupSizeZ = nullptr;
780b57cec5SDimitry Andric 
790b57cec5SDimitry Andric   Value *GridSizeX = nullptr;
800b57cec5SDimitry Andric   Value *GridSizeY = nullptr;
810b57cec5SDimitry Andric   Value *GridSizeZ = nullptr;
820b57cec5SDimitry Andric 
83e8d8bef9SDimitry Andric   const DataLayout &DL = F->getParent()->getDataLayout();
840b57cec5SDimitry Andric 
850b57cec5SDimitry Andric   // We expect to see several GEP users, casted to the appropriate type and
860b57cec5SDimitry Andric   // loaded.
870b57cec5SDimitry Andric   for (User *U : CI->users()) {
880b57cec5SDimitry Andric     if (!U->hasOneUse())
890b57cec5SDimitry Andric       continue;
900b57cec5SDimitry Andric 
910b57cec5SDimitry Andric     int64_t Offset = 0;
920b57cec5SDimitry Andric     if (GetPointerBaseWithConstantOffset(U, Offset, DL) != CI)
930b57cec5SDimitry Andric       continue;
940b57cec5SDimitry Andric 
950b57cec5SDimitry Andric     auto *BCI = dyn_cast<BitCastInst>(*U->user_begin());
960b57cec5SDimitry Andric     if (!BCI || !BCI->hasOneUse())
970b57cec5SDimitry Andric       continue;
980b57cec5SDimitry Andric 
990b57cec5SDimitry Andric     auto *Load = dyn_cast<LoadInst>(*BCI->user_begin());
1000b57cec5SDimitry Andric     if (!Load || !Load->isSimple())
1010b57cec5SDimitry Andric       continue;
1020b57cec5SDimitry Andric 
1030b57cec5SDimitry Andric     unsigned LoadSize = DL.getTypeStoreSize(Load->getType());
1040b57cec5SDimitry Andric 
1050b57cec5SDimitry Andric     // TODO: Handle merged loads.
1060b57cec5SDimitry Andric     switch (Offset) {
1070b57cec5SDimitry Andric     case WORKGROUP_SIZE_X:
1080b57cec5SDimitry Andric       if (LoadSize == 2)
1090b57cec5SDimitry Andric         WorkGroupSizeX = Load;
1100b57cec5SDimitry Andric       break;
1110b57cec5SDimitry Andric     case WORKGROUP_SIZE_Y:
1120b57cec5SDimitry Andric       if (LoadSize == 2)
1130b57cec5SDimitry Andric         WorkGroupSizeY = Load;
1140b57cec5SDimitry Andric       break;
1150b57cec5SDimitry Andric     case WORKGROUP_SIZE_Z:
1160b57cec5SDimitry Andric       if (LoadSize == 2)
1170b57cec5SDimitry Andric         WorkGroupSizeZ = Load;
1180b57cec5SDimitry Andric       break;
1190b57cec5SDimitry Andric     case GRID_SIZE_X:
1200b57cec5SDimitry Andric       if (LoadSize == 4)
1210b57cec5SDimitry Andric         GridSizeX = Load;
1220b57cec5SDimitry Andric       break;
1230b57cec5SDimitry Andric     case GRID_SIZE_Y:
1240b57cec5SDimitry Andric       if (LoadSize == 4)
1250b57cec5SDimitry Andric         GridSizeY = Load;
1260b57cec5SDimitry Andric       break;
1270b57cec5SDimitry Andric     case GRID_SIZE_Z:
1280b57cec5SDimitry Andric       if (LoadSize == 4)
1290b57cec5SDimitry Andric         GridSizeZ = Load;
1300b57cec5SDimitry Andric       break;
1310b57cec5SDimitry Andric     default:
1320b57cec5SDimitry Andric       break;
1330b57cec5SDimitry Andric     }
1340b57cec5SDimitry Andric   }
1350b57cec5SDimitry Andric 
1360b57cec5SDimitry Andric   // Pattern match the code used to handle partial workgroup dispatches in the
1370b57cec5SDimitry Andric   // library implementation of get_local_size, so the entire function can be
1380b57cec5SDimitry Andric   // constant folded with a known group size.
1390b57cec5SDimitry Andric   //
1400b57cec5SDimitry Andric   // uint r = grid_size - group_id * group_size;
1410b57cec5SDimitry Andric   // get_local_size = (r < group_size) ? r : group_size;
1420b57cec5SDimitry Andric   //
1430b57cec5SDimitry Andric   // If we have uniform-work-group-size (which is the default in OpenCL 1.2),
1440b57cec5SDimitry Andric   // the grid_size is required to be a multiple of group_size). In this case:
1450b57cec5SDimitry Andric   //
1460b57cec5SDimitry Andric   // grid_size - (group_id * group_size) < group_size
1470b57cec5SDimitry Andric   // ->
1480b57cec5SDimitry Andric   // grid_size < group_size + (group_id * group_size)
1490b57cec5SDimitry Andric   //
1500b57cec5SDimitry Andric   // (grid_size / group_size) < 1 + group_id
1510b57cec5SDimitry Andric   //
1520b57cec5SDimitry Andric   // grid_size / group_size is at least 1, so we can conclude the select
1530b57cec5SDimitry Andric   // condition is false (except for group_id == 0, where the select result is
1540b57cec5SDimitry Andric   // the same).
1550b57cec5SDimitry Andric 
1560b57cec5SDimitry Andric   bool MadeChange = false;
1570b57cec5SDimitry Andric   Value *WorkGroupSizes[3] = { WorkGroupSizeX, WorkGroupSizeY, WorkGroupSizeZ };
1580b57cec5SDimitry Andric   Value *GridSizes[3] = { GridSizeX, GridSizeY, GridSizeZ };
1590b57cec5SDimitry Andric 
1600b57cec5SDimitry Andric   for (int I = 0; HasUniformWorkGroupSize && I < 3; ++I) {
1610b57cec5SDimitry Andric     Value *GroupSize = WorkGroupSizes[I];
1620b57cec5SDimitry Andric     Value *GridSize = GridSizes[I];
1630b57cec5SDimitry Andric     if (!GroupSize || !GridSize)
1640b57cec5SDimitry Andric       continue;
1650b57cec5SDimitry Andric 
1660b57cec5SDimitry Andric     for (User *U : GroupSize->users()) {
1670b57cec5SDimitry Andric       auto *ZextGroupSize = dyn_cast<ZExtInst>(U);
1680b57cec5SDimitry Andric       if (!ZextGroupSize)
1690b57cec5SDimitry Andric         continue;
1700b57cec5SDimitry Andric 
1710b57cec5SDimitry Andric       for (User *ZextUser : ZextGroupSize->users()) {
1720b57cec5SDimitry Andric         auto *SI = dyn_cast<SelectInst>(ZextUser);
1730b57cec5SDimitry Andric         if (!SI)
1740b57cec5SDimitry Andric           continue;
1750b57cec5SDimitry Andric 
1760b57cec5SDimitry Andric         using namespace llvm::PatternMatch;
1770b57cec5SDimitry Andric         auto GroupIDIntrin = I == 0 ?
1780b57cec5SDimitry Andric           m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>() :
1790b57cec5SDimitry Andric             (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>() :
1800b57cec5SDimitry Andric                       m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());
1810b57cec5SDimitry Andric 
1820b57cec5SDimitry Andric         auto SubExpr = m_Sub(m_Specific(GridSize),
1830b57cec5SDimitry Andric                              m_Mul(GroupIDIntrin, m_Specific(ZextGroupSize)));
1840b57cec5SDimitry Andric 
1850b57cec5SDimitry Andric         ICmpInst::Predicate Pred;
1860b57cec5SDimitry Andric         if (match(SI,
1870b57cec5SDimitry Andric                   m_Select(m_ICmp(Pred, SubExpr, m_Specific(ZextGroupSize)),
1880b57cec5SDimitry Andric                            SubExpr,
1890b57cec5SDimitry Andric                            m_Specific(ZextGroupSize))) &&
1900b57cec5SDimitry Andric             Pred == ICmpInst::ICMP_ULT) {
1910b57cec5SDimitry Andric           if (HasReqdWorkGroupSize) {
1920b57cec5SDimitry Andric             ConstantInt *KnownSize
1930b57cec5SDimitry Andric               = mdconst::extract<ConstantInt>(MD->getOperand(I));
1940b57cec5SDimitry Andric             SI->replaceAllUsesWith(ConstantExpr::getIntegerCast(KnownSize,
1950b57cec5SDimitry Andric                                                                 SI->getType(),
1960b57cec5SDimitry Andric                                                                 false));
1970b57cec5SDimitry Andric           } else {
1980b57cec5SDimitry Andric             SI->replaceAllUsesWith(ZextGroupSize);
1990b57cec5SDimitry Andric           }
2000b57cec5SDimitry Andric 
2010b57cec5SDimitry Andric           MadeChange = true;
2020b57cec5SDimitry Andric         }
2030b57cec5SDimitry Andric       }
2040b57cec5SDimitry Andric     }
2050b57cec5SDimitry Andric   }
2060b57cec5SDimitry Andric 
2070b57cec5SDimitry Andric   if (!HasReqdWorkGroupSize)
2080b57cec5SDimitry Andric     return MadeChange;
2090b57cec5SDimitry Andric 
2100b57cec5SDimitry Andric   // Eliminate any other loads we can from the dispatch packet.
2110b57cec5SDimitry Andric   for (int I = 0; I < 3; ++I) {
2120b57cec5SDimitry Andric     Value *GroupSize = WorkGroupSizes[I];
2130b57cec5SDimitry Andric     if (!GroupSize)
2140b57cec5SDimitry Andric       continue;
2150b57cec5SDimitry Andric 
2160b57cec5SDimitry Andric     ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I));
2170b57cec5SDimitry Andric     GroupSize->replaceAllUsesWith(
2180b57cec5SDimitry Andric       ConstantExpr::getIntegerCast(KnownSize,
2190b57cec5SDimitry Andric                                    GroupSize->getType(),
2200b57cec5SDimitry Andric                                    false));
2210b57cec5SDimitry Andric     MadeChange = true;
2220b57cec5SDimitry Andric   }
2230b57cec5SDimitry Andric 
2240b57cec5SDimitry Andric   return MadeChange;
2250b57cec5SDimitry Andric }
2260b57cec5SDimitry Andric 
2270b57cec5SDimitry Andric // TODO: Move makeLIDRangeMetadata usage into here. Seem to not get
2280b57cec5SDimitry Andric // TargetPassConfig for subtarget.
2290b57cec5SDimitry Andric bool AMDGPULowerKernelAttributes::runOnModule(Module &M) {
2300b57cec5SDimitry Andric   StringRef DispatchPtrName
2310b57cec5SDimitry Andric     = Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr);
2320b57cec5SDimitry Andric 
233e8d8bef9SDimitry Andric   Function *DispatchPtr = M.getFunction(DispatchPtrName);
2340b57cec5SDimitry Andric   if (!DispatchPtr) // Dispatch ptr not used.
2350b57cec5SDimitry Andric     return false;
2360b57cec5SDimitry Andric 
2370b57cec5SDimitry Andric   bool MadeChange = false;
2380b57cec5SDimitry Andric 
2390b57cec5SDimitry Andric   SmallPtrSet<Instruction *, 4> HandledUses;
2400b57cec5SDimitry Andric   for (auto *U : DispatchPtr->users()) {
2410b57cec5SDimitry Andric     CallInst *CI = cast<CallInst>(U);
2420b57cec5SDimitry Andric     if (HandledUses.insert(CI).second) {
2430b57cec5SDimitry Andric       if (processUse(CI))
2440b57cec5SDimitry Andric         MadeChange = true;
2450b57cec5SDimitry Andric     }
2460b57cec5SDimitry Andric   }
2470b57cec5SDimitry Andric 
2480b57cec5SDimitry Andric   return MadeChange;
2490b57cec5SDimitry Andric }
2500b57cec5SDimitry Andric 
2510b57cec5SDimitry Andric INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE,
252*fe6060f1SDimitry Andric                       "AMDGPU Kernel Attributes", false, false)
253*fe6060f1SDimitry Andric INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE,
254*fe6060f1SDimitry Andric                     "AMDGPU Kernel Attributes", false, false)
2550b57cec5SDimitry Andric 
2560b57cec5SDimitry Andric char AMDGPULowerKernelAttributes::ID = 0;
2570b57cec5SDimitry Andric 
2580b57cec5SDimitry Andric ModulePass *llvm::createAMDGPULowerKernelAttributesPass() {
2590b57cec5SDimitry Andric   return new AMDGPULowerKernelAttributes();
2600b57cec5SDimitry Andric }
261e8d8bef9SDimitry Andric 
262e8d8bef9SDimitry Andric PreservedAnalyses
263e8d8bef9SDimitry Andric AMDGPULowerKernelAttributesPass::run(Function &F, FunctionAnalysisManager &AM) {
264e8d8bef9SDimitry Andric   StringRef DispatchPtrName =
265e8d8bef9SDimitry Andric       Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr);
266e8d8bef9SDimitry Andric 
267e8d8bef9SDimitry Andric   Function *DispatchPtr = F.getParent()->getFunction(DispatchPtrName);
268e8d8bef9SDimitry Andric   if (!DispatchPtr) // Dispatch ptr not used.
269e8d8bef9SDimitry Andric     return PreservedAnalyses::all();
270e8d8bef9SDimitry Andric 
271e8d8bef9SDimitry Andric   for (Instruction &I : instructions(F)) {
272e8d8bef9SDimitry Andric     if (CallInst *CI = dyn_cast<CallInst>(&I)) {
273e8d8bef9SDimitry Andric       if (CI->getCalledFunction() == DispatchPtr)
274e8d8bef9SDimitry Andric         processUse(CI);
275e8d8bef9SDimitry Andric     }
276e8d8bef9SDimitry Andric   }
277e8d8bef9SDimitry Andric 
278e8d8bef9SDimitry Andric   return PreservedAnalyses::all();
279e8d8bef9SDimitry Andric }
280