xref: /freebsd/contrib/llvm-project/llvm/lib/Target/NVPTX/NVVMIntrRange.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- NVVMIntrRange.cpp - Set range attributes for NVVM intrinsics -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass adds appropriate range attributes for calls to NVVM
10 // intrinsics that return a limited range of values.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "NVPTX.h"
15 #include "NVPTXUtilities.h"
16 #include "llvm/IR/InstIterator.h"
17 #include "llvm/IR/Instructions.h"
18 #include "llvm/IR/IntrinsicInst.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/IntrinsicsNVPTX.h"
21 #include "llvm/IR/PassManager.h"
22 #include <cstdint>
23 
24 using namespace llvm;
25 
26 #define DEBUG_TYPE "nvvm-intr-range"
27 
28 namespace {
29 class NVVMIntrRange : public FunctionPass {
30 public:
31   static char ID;
NVVMIntrRange()32   NVVMIntrRange() : FunctionPass(ID) {}
33 
34   bool runOnFunction(Function &) override;
35 };
36 } // namespace
37 
createNVVMIntrRangePass()38 FunctionPass *llvm::createNVVMIntrRangePass() { return new NVVMIntrRange(); }
39 
// Storage for the pass identifier that the NVVMIntrRange constructor passes
// to FunctionPass.
char NVVMIntrRange::ID = 0;
// Pass registration: command-line name, human-readable description, then two
// boolean flags (both false here; see INITIALIZE_PASS for their meaning).
// NOTE(review): the description still mentions "!range metadata", whereas
// addRangeAttr() in this file attaches a range return attribute.
INITIALIZE_PASS(NVVMIntrRange, "nvvm-intr-range",
                "Add !range metadata to NVVM intrinsics.", false, false)
43 
44 // Adds the passed-in [Low,High) range information as metadata to the
45 // passed-in call instruction.
addRangeAttr(uint64_t Low,uint64_t High,IntrinsicInst * II)46 static bool addRangeAttr(uint64_t Low, uint64_t High, IntrinsicInst *II) {
47   if (II->getMetadata(LLVMContext::MD_range))
48     return false;
49 
50   const uint64_t BitWidth = II->getType()->getIntegerBitWidth();
51   ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High));
52 
53   if (auto CurrentRange = II->getRange())
54     Range = Range.intersectWith(CurrentRange.value());
55 
56   II->addRangeRetAttr(Range);
57   return true;
58 }
59 
runNVVMIntrRange(Function & F)60 static bool runNVVMIntrRange(Function &F) {
61   struct Vector3 {
62     unsigned X, Y, Z;
63   };
64 
65   // All these annotations are only valid for kernel functions.
66   if (!isKernelFunction(F))
67     return false;
68 
69   const auto OverallReqNTID = getOverallReqNTID(F);
70   const auto OverallMaxNTID = getOverallMaxNTID(F);
71   const auto OverallClusterRank = getOverallClusterRank(F);
72 
73   // If this function lacks any range information, do nothing.
74   if (!(OverallReqNTID || OverallMaxNTID || OverallClusterRank))
75     return false;
76 
77   const unsigned FunctionNTID = OverallReqNTID.value_or(
78       OverallMaxNTID.value_or(std::numeric_limits<unsigned>::max()));
79 
80   const unsigned FunctionClusterRank =
81       OverallClusterRank.value_or(std::numeric_limits<unsigned>::max());
82 
83   const Vector3 MaxBlockSize{std::min(1024u, FunctionNTID),
84                              std::min(1024u, FunctionNTID),
85                              std::min(64u, FunctionNTID)};
86 
87   // We conservatively use the maximum grid size as an upper bound for the
88   // cluster rank.
89   const Vector3 MaxClusterRank{std::min(0x7fffffffu, FunctionClusterRank),
90                                std::min(0xffffu, FunctionClusterRank),
91                                std::min(0xffffu, FunctionClusterRank)};
92 
93   const auto ProccessIntrinsic = [&](IntrinsicInst *II) -> bool {
94     switch (II->getIntrinsicID()) {
95     // Index within block
96     case Intrinsic::nvvm_read_ptx_sreg_tid_x:
97       return addRangeAttr(0, MaxBlockSize.X, II);
98     case Intrinsic::nvvm_read_ptx_sreg_tid_y:
99       return addRangeAttr(0, MaxBlockSize.Y, II);
100     case Intrinsic::nvvm_read_ptx_sreg_tid_z:
101       return addRangeAttr(0, MaxBlockSize.Z, II);
102 
103     // Block size
104     case Intrinsic::nvvm_read_ptx_sreg_ntid_x:
105       return addRangeAttr(1, MaxBlockSize.X + 1, II);
106     case Intrinsic::nvvm_read_ptx_sreg_ntid_y:
107       return addRangeAttr(1, MaxBlockSize.Y + 1, II);
108     case Intrinsic::nvvm_read_ptx_sreg_ntid_z:
109       return addRangeAttr(1, MaxBlockSize.Z + 1, II);
110 
111     // Cluster size
112     case Intrinsic::nvvm_read_ptx_sreg_cluster_ctaid_x:
113       return addRangeAttr(0, MaxClusterRank.X, II);
114     case Intrinsic::nvvm_read_ptx_sreg_cluster_ctaid_y:
115       return addRangeAttr(0, MaxClusterRank.Y, II);
116     case Intrinsic::nvvm_read_ptx_sreg_cluster_ctaid_z:
117       return addRangeAttr(0, MaxClusterRank.Z, II);
118     case Intrinsic::nvvm_read_ptx_sreg_cluster_nctaid_x:
119       return addRangeAttr(1, MaxClusterRank.X + 1, II);
120     case Intrinsic::nvvm_read_ptx_sreg_cluster_nctaid_y:
121       return addRangeAttr(1, MaxClusterRank.Y + 1, II);
122     case Intrinsic::nvvm_read_ptx_sreg_cluster_nctaid_z:
123       return addRangeAttr(1, MaxClusterRank.Z + 1, II);
124 
125     case Intrinsic::nvvm_read_ptx_sreg_cluster_ctarank:
126       if (OverallClusterRank)
127         return addRangeAttr(0, FunctionClusterRank, II);
128       break;
129     case Intrinsic::nvvm_read_ptx_sreg_cluster_nctarank:
130       if (OverallClusterRank)
131         return addRangeAttr(1, FunctionClusterRank + 1, II);
132       break;
133     default:
134       return false;
135     }
136     return false;
137   };
138 
139   // Go through the calls in this function.
140   bool Changed = false;
141   for (Instruction &I : instructions(F))
142     if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I))
143       Changed |= ProccessIntrinsic(II);
144 
145   return Changed;
146 }
147 
runOnFunction(Function & F)148 bool NVVMIntrRange::runOnFunction(Function &F) { return runNVVMIntrRange(F); }
149 
run(Function & F,FunctionAnalysisManager & AM)150 PreservedAnalyses NVVMIntrRangePass::run(Function &F,
151                                          FunctionAnalysisManager &AM) {
152   return runNVVMIntrRange(F) ? PreservedAnalyses::none()
153                              : PreservedAnalyses::all();
154 }
155