xref: /freebsd/contrib/llvm-project/llvm/lib/Target/NVPTX/NVVMReflect.cpp (revision 0b57cec536236d46e3dba9bd041533462f33dbb7)
//===- NVVMReflect.cpp - NVVM Emulate conditional compilation -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect
// with an integer.
//
// We choose the value we use by looking at metadata in the module itself.  Note
// that we intentionally only have one way to choose these values, because other
// parts of LLVM (particularly, InstCombineCall) rely on being able to predict
// the values chosen by this pass.
//
// If we see an unknown string, we replace its call with 0.
//
//===----------------------------------------------------------------------===//
20*0b57cec5SDimitry Andric 
21*0b57cec5SDimitry Andric #include "NVPTX.h"
22*0b57cec5SDimitry Andric #include "llvm/ADT/SmallVector.h"
23*0b57cec5SDimitry Andric #include "llvm/ADT/StringMap.h"
24*0b57cec5SDimitry Andric #include "llvm/IR/Constants.h"
25*0b57cec5SDimitry Andric #include "llvm/IR/DerivedTypes.h"
26*0b57cec5SDimitry Andric #include "llvm/IR/Function.h"
27*0b57cec5SDimitry Andric #include "llvm/IR/InstIterator.h"
28*0b57cec5SDimitry Andric #include "llvm/IR/Instructions.h"
29*0b57cec5SDimitry Andric #include "llvm/IR/Intrinsics.h"
30*0b57cec5SDimitry Andric #include "llvm/IR/Module.h"
31*0b57cec5SDimitry Andric #include "llvm/IR/Type.h"
32*0b57cec5SDimitry Andric #include "llvm/Pass.h"
33*0b57cec5SDimitry Andric #include "llvm/Support/CommandLine.h"
34*0b57cec5SDimitry Andric #include "llvm/Support/Debug.h"
35*0b57cec5SDimitry Andric #include "llvm/Support/raw_os_ostream.h"
36*0b57cec5SDimitry Andric #include "llvm/Support/raw_ostream.h"
37*0b57cec5SDimitry Andric #include "llvm/Transforms/Scalar.h"
38*0b57cec5SDimitry Andric #include <sstream>
39*0b57cec5SDimitry Andric #include <string>
40*0b57cec5SDimitry Andric #define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
41*0b57cec5SDimitry Andric 
42*0b57cec5SDimitry Andric using namespace llvm;
43*0b57cec5SDimitry Andric 
44*0b57cec5SDimitry Andric #define DEBUG_TYPE "nvptx-reflect"
45*0b57cec5SDimitry Andric 
46*0b57cec5SDimitry Andric namespace llvm { void initializeNVVMReflectPass(PassRegistry &); }
47*0b57cec5SDimitry Andric 
48*0b57cec5SDimitry Andric namespace {
49*0b57cec5SDimitry Andric class NVVMReflect : public FunctionPass {
50*0b57cec5SDimitry Andric public:
51*0b57cec5SDimitry Andric   static char ID;
52*0b57cec5SDimitry Andric   unsigned int SmVersion;
53*0b57cec5SDimitry Andric   NVVMReflect() : NVVMReflect(0) {}
54*0b57cec5SDimitry Andric   explicit NVVMReflect(unsigned int Sm) : FunctionPass(ID), SmVersion(Sm) {
55*0b57cec5SDimitry Andric     initializeNVVMReflectPass(*PassRegistry::getPassRegistry());
56*0b57cec5SDimitry Andric   }
57*0b57cec5SDimitry Andric 
58*0b57cec5SDimitry Andric   bool runOnFunction(Function &) override;
59*0b57cec5SDimitry Andric };
60*0b57cec5SDimitry Andric }
61*0b57cec5SDimitry Andric 
62*0b57cec5SDimitry Andric FunctionPass *llvm::createNVVMReflectPass(unsigned int SmVersion) {
63*0b57cec5SDimitry Andric   return new NVVMReflect(SmVersion);
64*0b57cec5SDimitry Andric }
65*0b57cec5SDimitry Andric 
66*0b57cec5SDimitry Andric static cl::opt<bool>
67*0b57cec5SDimitry Andric NVVMReflectEnabled("nvvm-reflect-enable", cl::init(true), cl::Hidden,
68*0b57cec5SDimitry Andric                    cl::desc("NVVM reflection, enabled by default"));
69*0b57cec5SDimitry Andric 
70*0b57cec5SDimitry Andric char NVVMReflect::ID = 0;
71*0b57cec5SDimitry Andric INITIALIZE_PASS(NVVMReflect, "nvvm-reflect",
72*0b57cec5SDimitry Andric                 "Replace occurrences of __nvvm_reflect() calls with 0/1", false,
73*0b57cec5SDimitry Andric                 false)
74*0b57cec5SDimitry Andric 
75*0b57cec5SDimitry Andric bool NVVMReflect::runOnFunction(Function &F) {
76*0b57cec5SDimitry Andric   if (!NVVMReflectEnabled)
77*0b57cec5SDimitry Andric     return false;
78*0b57cec5SDimitry Andric 
79*0b57cec5SDimitry Andric   if (F.getName() == NVVM_REFLECT_FUNCTION) {
80*0b57cec5SDimitry Andric     assert(F.isDeclaration() && "_reflect function should not have a body");
81*0b57cec5SDimitry Andric     assert(F.getReturnType()->isIntegerTy() &&
82*0b57cec5SDimitry Andric            "_reflect's return type should be integer");
83*0b57cec5SDimitry Andric     return false;
84*0b57cec5SDimitry Andric   }
85*0b57cec5SDimitry Andric 
86*0b57cec5SDimitry Andric   SmallVector<Instruction *, 4> ToRemove;
87*0b57cec5SDimitry Andric 
88*0b57cec5SDimitry Andric   // Go through the calls in this function.  Each call to __nvvm_reflect or
89*0b57cec5SDimitry Andric   // llvm.nvvm.reflect should be a CallInst with a ConstantArray argument.
90*0b57cec5SDimitry Andric   // First validate that. If the c-string corresponding to the ConstantArray can
91*0b57cec5SDimitry Andric   // be found successfully, see if it can be found in VarMap. If so, replace the
92*0b57cec5SDimitry Andric   // uses of CallInst with the value found in VarMap. If not, replace the use
93*0b57cec5SDimitry Andric   // with value 0.
94*0b57cec5SDimitry Andric 
95*0b57cec5SDimitry Andric   // The IR for __nvvm_reflect calls differs between CUDA versions.
96*0b57cec5SDimitry Andric   //
97*0b57cec5SDimitry Andric   // CUDA 6.5 and earlier uses this sequence:
98*0b57cec5SDimitry Andric   //    %ptr = tail call i8* @llvm.nvvm.ptr.constant.to.gen.p0i8.p4i8
99*0b57cec5SDimitry Andric   //        (i8 addrspace(4)* getelementptr inbounds
100*0b57cec5SDimitry Andric   //           ([8 x i8], [8 x i8] addrspace(4)* @str, i32 0, i32 0))
101*0b57cec5SDimitry Andric   //    %reflect = tail call i32 @__nvvm_reflect(i8* %ptr)
102*0b57cec5SDimitry Andric   //
103*0b57cec5SDimitry Andric   // The value returned by Sym->getOperand(0) is a Constant with a
104*0b57cec5SDimitry Andric   // ConstantDataSequential operand which can be converted to string and used
105*0b57cec5SDimitry Andric   // for lookup.
106*0b57cec5SDimitry Andric   //
107*0b57cec5SDimitry Andric   // CUDA 7.0 does it slightly differently:
108*0b57cec5SDimitry Andric   //   %reflect = call i32 @__nvvm_reflect(i8* addrspacecast
109*0b57cec5SDimitry Andric   //        (i8 addrspace(1)* getelementptr inbounds
110*0b57cec5SDimitry Andric   //           ([8 x i8], [8 x i8] addrspace(1)* @str, i32 0, i32 0) to i8*))
111*0b57cec5SDimitry Andric   //
112*0b57cec5SDimitry Andric   // In this case, we get a Constant with a GlobalVariable operand and we need
113*0b57cec5SDimitry Andric   // to dig deeper to find its initializer with the string we'll use for lookup.
114*0b57cec5SDimitry Andric   for (Instruction &I : instructions(F)) {
115*0b57cec5SDimitry Andric     CallInst *Call = dyn_cast<CallInst>(&I);
116*0b57cec5SDimitry Andric     if (!Call)
117*0b57cec5SDimitry Andric       continue;
118*0b57cec5SDimitry Andric     Function *Callee = Call->getCalledFunction();
119*0b57cec5SDimitry Andric     if (!Callee || (Callee->getName() != NVVM_REFLECT_FUNCTION &&
120*0b57cec5SDimitry Andric                     Callee->getIntrinsicID() != Intrinsic::nvvm_reflect))
121*0b57cec5SDimitry Andric       continue;
122*0b57cec5SDimitry Andric 
123*0b57cec5SDimitry Andric     // FIXME: Improve error handling here and elsewhere in this pass.
124*0b57cec5SDimitry Andric     assert(Call->getNumOperands() == 2 &&
125*0b57cec5SDimitry Andric            "Wrong number of operands to __nvvm_reflect function");
126*0b57cec5SDimitry Andric 
127*0b57cec5SDimitry Andric     // In cuda 6.5 and earlier, we will have an extra constant-to-generic
128*0b57cec5SDimitry Andric     // conversion of the string.
129*0b57cec5SDimitry Andric     const Value *Str = Call->getArgOperand(0);
130*0b57cec5SDimitry Andric     if (const CallInst *ConvCall = dyn_cast<CallInst>(Str)) {
131*0b57cec5SDimitry Andric       // FIXME: Add assertions about ConvCall.
132*0b57cec5SDimitry Andric       Str = ConvCall->getArgOperand(0);
133*0b57cec5SDimitry Andric     }
134*0b57cec5SDimitry Andric     assert(isa<ConstantExpr>(Str) &&
135*0b57cec5SDimitry Andric            "Format of __nvvm__reflect function not recognized");
136*0b57cec5SDimitry Andric     const ConstantExpr *GEP = cast<ConstantExpr>(Str);
137*0b57cec5SDimitry Andric 
138*0b57cec5SDimitry Andric     const Value *Sym = GEP->getOperand(0);
139*0b57cec5SDimitry Andric     assert(isa<Constant>(Sym) &&
140*0b57cec5SDimitry Andric            "Format of __nvvm_reflect function not recognized");
141*0b57cec5SDimitry Andric 
142*0b57cec5SDimitry Andric     const Value *Operand = cast<Constant>(Sym)->getOperand(0);
143*0b57cec5SDimitry Andric     if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Operand)) {
144*0b57cec5SDimitry Andric       // For CUDA-7.0 style __nvvm_reflect calls, we need to find the operand's
145*0b57cec5SDimitry Andric       // initializer.
146*0b57cec5SDimitry Andric       assert(GV->hasInitializer() &&
147*0b57cec5SDimitry Andric              "Format of _reflect function not recognized");
148*0b57cec5SDimitry Andric       const Constant *Initializer = GV->getInitializer();
149*0b57cec5SDimitry Andric       Operand = Initializer;
150*0b57cec5SDimitry Andric     }
151*0b57cec5SDimitry Andric 
152*0b57cec5SDimitry Andric     assert(isa<ConstantDataSequential>(Operand) &&
153*0b57cec5SDimitry Andric            "Format of _reflect function not recognized");
154*0b57cec5SDimitry Andric     assert(cast<ConstantDataSequential>(Operand)->isCString() &&
155*0b57cec5SDimitry Andric            "Format of _reflect function not recognized");
156*0b57cec5SDimitry Andric 
157*0b57cec5SDimitry Andric     StringRef ReflectArg = cast<ConstantDataSequential>(Operand)->getAsString();
158*0b57cec5SDimitry Andric     ReflectArg = ReflectArg.substr(0, ReflectArg.size() - 1);
159*0b57cec5SDimitry Andric     LLVM_DEBUG(dbgs() << "Arg of _reflect : " << ReflectArg << "\n");
160*0b57cec5SDimitry Andric 
161*0b57cec5SDimitry Andric     int ReflectVal = 0; // The default value is 0
162*0b57cec5SDimitry Andric     if (ReflectArg == "__CUDA_FTZ") {
163*0b57cec5SDimitry Andric       // Try to pull __CUDA_FTZ from the nvvm-reflect-ftz module flag.  Our
164*0b57cec5SDimitry Andric       // choice here must be kept in sync with AutoUpgrade, which uses the same
165*0b57cec5SDimitry Andric       // technique to detect whether ftz is enabled.
166*0b57cec5SDimitry Andric       if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
167*0b57cec5SDimitry Andric               F.getParent()->getModuleFlag("nvvm-reflect-ftz")))
168*0b57cec5SDimitry Andric         ReflectVal = Flag->getSExtValue();
169*0b57cec5SDimitry Andric     } else if (ReflectArg == "__CUDA_ARCH") {
170*0b57cec5SDimitry Andric       ReflectVal = SmVersion * 10;
171*0b57cec5SDimitry Andric     }
172*0b57cec5SDimitry Andric     Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal));
173*0b57cec5SDimitry Andric     ToRemove.push_back(Call);
174*0b57cec5SDimitry Andric   }
175*0b57cec5SDimitry Andric 
176*0b57cec5SDimitry Andric   for (Instruction *I : ToRemove)
177*0b57cec5SDimitry Andric     I->eraseFromParent();
178*0b57cec5SDimitry Andric 
179*0b57cec5SDimitry Andric   return ToRemove.size() > 0;
180*0b57cec5SDimitry Andric }
181