xref: /freebsd/contrib/llvm-project/llvm/lib/Target/NVPTX/NVVMIntrRange.cpp (revision 95eb4b873b6a8b527c5bd78d7191975dfca38998)
1 //===- NVVMIntrRange.cpp - Set !range metadata for NVVM intrinsics --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass adds appropriate !range metadata for calls to NVVM
10 // intrinsics that return a limited range of values.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "NVPTX.h"
15 #include "llvm/IR/Constants.h"
16 #include "llvm/IR/InstIterator.h"
17 #include "llvm/IR/Instructions.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/IntrinsicsNVPTX.h"
20 #include "llvm/IR/PassManager.h"
21 #include "llvm/Support/CommandLine.h"
22 
23 using namespace llvm;
24 
25 #define DEBUG_TYPE "nvvm-intr-range"
26 
27 namespace llvm { void initializeNVVMIntrRangePass(PassRegistry &); }
28 
29 // Add !range metadata based on limits of given SM variant.
30 static cl::opt<unsigned> NVVMIntrRangeSM("nvvm-intr-range-sm", cl::init(20),
31                                          cl::Hidden, cl::desc("SM variant"));
32 
33 namespace {
34 class NVVMIntrRange : public FunctionPass {
35  private:
36    unsigned SmVersion;
37 
38  public:
39    static char ID;
40    NVVMIntrRange() : NVVMIntrRange(NVVMIntrRangeSM) {}
41    NVVMIntrRange(unsigned int SmVersion)
42        : FunctionPass(ID), SmVersion(SmVersion) {
43 
44      initializeNVVMIntrRangePass(*PassRegistry::getPassRegistry());
45    }
46 
47    bool runOnFunction(Function &) override;
48 };
49 }
50 
51 FunctionPass *llvm::createNVVMIntrRangePass(unsigned int SmVersion) {
52   return new NVVMIntrRange(SmVersion);
53 }
54 
55 char NVVMIntrRange::ID = 0;
56 INITIALIZE_PASS(NVVMIntrRange, "nvvm-intr-range",
57                 "Add !range metadata to NVVM intrinsics.", false, false)
58 
59 // Adds the passed-in [Low,High) range information as metadata to the
60 // passed-in call instruction.
61 static bool addRangeMetadata(uint64_t Low, uint64_t High, CallInst *C) {
62   // This call already has range metadata, nothing to do.
63   if (C->getMetadata(LLVMContext::MD_range))
64     return false;
65 
66   LLVMContext &Context = C->getParent()->getContext();
67   IntegerType *Int32Ty = Type::getInt32Ty(Context);
68   Metadata *LowAndHigh[] = {
69       ConstantAsMetadata::get(ConstantInt::get(Int32Ty, Low)),
70       ConstantAsMetadata::get(ConstantInt::get(Int32Ty, High))};
71   C->setMetadata(LLVMContext::MD_range, MDNode::get(Context, LowAndHigh));
72   return true;
73 }
74 
75 static bool runNVVMIntrRange(Function &F, unsigned SmVersion) {
76   struct {
77     unsigned x, y, z;
78   } MaxBlockSize, MaxGridSize;
79   MaxBlockSize.x = 1024;
80   MaxBlockSize.y = 1024;
81   MaxBlockSize.z = 64;
82 
83   MaxGridSize.x = SmVersion >= 30 ? 0x7fffffff : 0xffff;
84   MaxGridSize.y = 0xffff;
85   MaxGridSize.z = 0xffff;
86 
87   // Go through the calls in this function.
88   bool Changed = false;
89   for (Instruction &I : instructions(F)) {
90     CallInst *Call = dyn_cast<CallInst>(&I);
91     if (!Call)
92       continue;
93 
94     if (Function *Callee = Call->getCalledFunction()) {
95       switch (Callee->getIntrinsicID()) {
96       // Index within block
97       case Intrinsic::nvvm_read_ptx_sreg_tid_x:
98         Changed |= addRangeMetadata(0, MaxBlockSize.x, Call);
99         break;
100       case Intrinsic::nvvm_read_ptx_sreg_tid_y:
101         Changed |= addRangeMetadata(0, MaxBlockSize.y, Call);
102         break;
103       case Intrinsic::nvvm_read_ptx_sreg_tid_z:
104         Changed |= addRangeMetadata(0, MaxBlockSize.z, Call);
105         break;
106 
107       // Block size
108       case Intrinsic::nvvm_read_ptx_sreg_ntid_x:
109         Changed |= addRangeMetadata(1, MaxBlockSize.x+1, Call);
110         break;
111       case Intrinsic::nvvm_read_ptx_sreg_ntid_y:
112         Changed |= addRangeMetadata(1, MaxBlockSize.y+1, Call);
113         break;
114       case Intrinsic::nvvm_read_ptx_sreg_ntid_z:
115         Changed |= addRangeMetadata(1, MaxBlockSize.z+1, Call);
116         break;
117 
118       // Index within grid
119       case Intrinsic::nvvm_read_ptx_sreg_ctaid_x:
120         Changed |= addRangeMetadata(0, MaxGridSize.x, Call);
121         break;
122       case Intrinsic::nvvm_read_ptx_sreg_ctaid_y:
123         Changed |= addRangeMetadata(0, MaxGridSize.y, Call);
124         break;
125       case Intrinsic::nvvm_read_ptx_sreg_ctaid_z:
126         Changed |= addRangeMetadata(0, MaxGridSize.z, Call);
127         break;
128 
129       // Grid size
130       case Intrinsic::nvvm_read_ptx_sreg_nctaid_x:
131         Changed |= addRangeMetadata(1, MaxGridSize.x+1, Call);
132         break;
133       case Intrinsic::nvvm_read_ptx_sreg_nctaid_y:
134         Changed |= addRangeMetadata(1, MaxGridSize.y+1, Call);
135         break;
136       case Intrinsic::nvvm_read_ptx_sreg_nctaid_z:
137         Changed |= addRangeMetadata(1, MaxGridSize.z+1, Call);
138         break;
139 
140       // warp size is constant 32.
141       case Intrinsic::nvvm_read_ptx_sreg_warpsize:
142         Changed |= addRangeMetadata(32, 32+1, Call);
143         break;
144 
145       // Lane ID is [0..warpsize)
146       case Intrinsic::nvvm_read_ptx_sreg_laneid:
147         Changed |= addRangeMetadata(0, 32, Call);
148         break;
149 
150       default:
151         break;
152       }
153     }
154   }
155 
156   return Changed;
157 }
158 
159 bool NVVMIntrRange::runOnFunction(Function &F) {
160   return runNVVMIntrRange(F, SmVersion);
161 }
162 
163 NVVMIntrRangePass::NVVMIntrRangePass() : NVVMIntrRangePass(NVVMIntrRangeSM) {}
164 
165 PreservedAnalyses NVVMIntrRangePass::run(Function &F,
166                                          FunctionAnalysisManager &AM) {
167   return runNVVMIntrRange(F, SmVersion) ? PreservedAnalyses::none()
168                                         : PreservedAnalyses::all();
169 }
170