1 //===- NVVMReflect.cpp - NVVM Emulate conditional compilation -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass replaces occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect
10 // with an integer.
11 //
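// The value for each string is chosen deterministically: __CUDA_FTZ comes from
// the "nvvm-reflect-ftz" module flag, __CUDA_ARCH is the SM version the pass
// was created with multiplied by 10, and -nvvm-reflect-add command-line pairs
// override both. Note that we intentionally keep this choice predictable,
// because other parts of LLVM (particularly, InstCombineCall) rely on being
// able to predict the values chosen by this pass.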
16 //
17 // If we see an unknown string, we replace its call with 0.
18 //
19 //===----------------------------------------------------------------------===//
20
#include "NVPTX.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/CodeGen/CommandFlags.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
// Names of the reflect functions whose calls this pass replaces with constants.
#define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
#define NVVM_REFLECT_OCL_FUNCTION "__nvvm_reflect_ocl"
// Argument of reflect call to retrieve arch number
#define CUDA_ARCH_NAME "__CUDA_ARCH"
// Argument of reflect call to retrieve ftz mode
#define CUDA_FTZ_NAME "__CUDA_FTZ"
// Name of the module flag (read with Module::getModuleFlag) where ftz mode is
// stored
#define CUDA_FTZ_MODULE_NAME "nvvm-reflect-ftz"
50
51 using namespace llvm;
52
53 #define DEBUG_TYPE "nvvm-reflect"
54
55 namespace {
56 class NVVMReflect {
57 // Map from reflect function call arguments to the value to replace the call
58 // with. Should include __CUDA_FTZ and __CUDA_ARCH values.
59 StringMap<unsigned> ReflectMap;
60 bool handleReflectFunction(Module &M, StringRef ReflectName);
61 void populateReflectMap(Module &M);
62 void foldReflectCall(CallInst *Call, Constant *NewValue);
63
64 public:
65 // __CUDA_FTZ is assigned in `runOnModule` by checking nvvm-reflect-ftz module
66 // metadata.
NVVMReflect(unsigned SmVersion)67 explicit NVVMReflect(unsigned SmVersion)
68 : ReflectMap({{CUDA_ARCH_NAME, SmVersion * 10}}) {}
69 bool runOnModule(Module &M);
70 };
71
72 class NVVMReflectLegacyPass : public ModulePass {
73 NVVMReflect Impl;
74
75 public:
76 static char ID;
NVVMReflectLegacyPass(unsigned SmVersion)77 NVVMReflectLegacyPass(unsigned SmVersion) : ModulePass(ID), Impl(SmVersion) {}
78 bool runOnModule(Module &M) override;
79 };
80 } // namespace
81
createNVVMReflectPass(unsigned SmVersion)82 ModulePass *llvm::createNVVMReflectPass(unsigned SmVersion) {
83 return new NVVMReflectLegacyPass(SmVersion);
84 }
85
86 static cl::opt<bool>
87 NVVMReflectEnabled("nvvm-reflect-enable", cl::init(true), cl::Hidden,
88 cl::desc("NVVM reflection, enabled by default"));
89
90 char NVVMReflectLegacyPass::ID = 0;
91 INITIALIZE_PASS(NVVMReflectLegacyPass, "nvvm-reflect",
92 "Replace occurrences of __nvvm_reflect() calls with 0/1", false,
93 false)
94
95 // Allow users to specify additional key/value pairs to reflect. These key/value
96 // pairs are the last to be added to the ReflectMap, and therefore will take
97 // precedence over initial values (i.e. __CUDA_FTZ from module medadata and
98 // __CUDA_ARCH from SmVersion).
99 static cl::list<std::string> ReflectList(
100 "nvvm-reflect-add", cl::value_desc("name=<int>"), cl::Hidden,
101 cl::desc("A key=value pair. Replace __nvvm_reflect(name) with value."),
102 cl::ValueRequired);
103
104 // Set the ReflectMap with, first, the value of __CUDA_FTZ from module metadata,
105 // and then the key/value pairs from the command line.
populateReflectMap(Module & M)106 void NVVMReflect::populateReflectMap(Module &M) {
107 if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
108 M.getModuleFlag(CUDA_FTZ_MODULE_NAME)))
109 ReflectMap[CUDA_FTZ_NAME] = Flag->getSExtValue();
110
111 for (auto &Option : ReflectList) {
112 LLVM_DEBUG(dbgs() << "ReflectOption : " << Option << "\n");
113 StringRef OptionRef(Option);
114 auto [Name, Val] = OptionRef.split('=');
115 if (Name.empty())
116 report_fatal_error(Twine("Empty name in nvvm-reflect-add option '") +
117 Option + "'");
118 if (Val.empty())
119 report_fatal_error(Twine("Missing value in nvvm-reflect-add option '") +
120 Option + "'");
121 unsigned ValInt;
122 if (!to_integer(Val.trim(), ValInt, 10))
123 report_fatal_error(
124 Twine("integer value expected in nvvm-reflect-add option '") +
125 Option + "'");
126 ReflectMap[Name] = ValInt;
127 }
128 }
129
130 /// Process a reflect function by finding all its calls and replacing them with
131 /// appropriate constant values. For __CUDA_FTZ, uses the module flag value.
132 /// For __CUDA_ARCH, uses SmVersion * 10. For all other strings, uses 0.
handleReflectFunction(Module & M,StringRef ReflectName)133 bool NVVMReflect::handleReflectFunction(Module &M, StringRef ReflectName) {
134 Function *F = M.getFunction(ReflectName);
135 if (!F)
136 return false;
137 assert(F->isDeclaration() && "_reflect function should not have a body");
138 assert(F->getReturnType()->isIntegerTy() &&
139 "_reflect's return type should be integer");
140
141 const bool Changed = !F->use_empty();
142 for (User *U : make_early_inc_range(F->users())) {
143 // Reflect function calls look like:
144 // @arch = private unnamed_addr addrspace(1) constant [12 x i8]
145 // c"__CUDA_ARCH\00" call i32 @__nvvm_reflect(ptr addrspacecast (ptr
146 // addrspace(1) @arch to ptr)) We need to extract the string argument from
147 // the call (i.e. "__CUDA_ARCH")
148 auto *Call = dyn_cast<CallInst>(U);
149 if (!Call)
150 report_fatal_error(
151 "__nvvm_reflect can only be used in a call instruction");
152 if (Call->getNumOperands() != 2)
153 report_fatal_error("__nvvm_reflect requires exactly one argument");
154
155 auto *GlobalStr =
156 dyn_cast<Constant>(Call->getArgOperand(0)->stripPointerCasts());
157 if (!GlobalStr)
158 report_fatal_error("__nvvm_reflect argument must be a constant string");
159
160 auto *ConstantStr =
161 dyn_cast<ConstantDataSequential>(GlobalStr->getOperand(0));
162 if (!ConstantStr)
163 report_fatal_error("__nvvm_reflect argument must be a string constant");
164 if (!ConstantStr->isCString())
165 report_fatal_error(
166 "__nvvm_reflect argument must be a null-terminated string");
167
168 StringRef ReflectArg = ConstantStr->getAsString().drop_back();
169 if (ReflectArg.empty())
170 report_fatal_error("__nvvm_reflect argument cannot be empty");
171 // Now that we have extracted the string argument, we can look it up in the
172 // ReflectMap
173 unsigned ReflectVal = 0; // The default value is 0
174 if (ReflectMap.contains(ReflectArg))
175 ReflectVal = ReflectMap[ReflectArg];
176
177 LLVM_DEBUG(dbgs() << "Replacing call of reflect function " << F->getName()
178 << "(" << ReflectArg << ") with value " << ReflectVal
179 << "\n");
180 auto *NewValue = ConstantInt::get(Call->getType(), ReflectVal);
181 foldReflectCall(Call, NewValue);
182 Call->eraseFromParent();
183 }
184
185 // Remove the __nvvm_reflect function from the module
186 F->eraseFromParent();
187 return Changed;
188 }
189
foldReflectCall(CallInst * Call,Constant * NewValue)190 void NVVMReflect::foldReflectCall(CallInst *Call, Constant *NewValue) {
191 SmallVector<Instruction *, 8> Worklist;
192 // Replace an instruction with a constant and add all users of the instruction
193 // to the worklist
194 auto ReplaceInstructionWithConst = [&](Instruction *I, Constant *C) {
195 for (auto *U : I->users())
196 if (auto *UI = dyn_cast<Instruction>(U))
197 Worklist.push_back(UI);
198 I->replaceAllUsesWith(C);
199 };
200
201 ReplaceInstructionWithConst(Call, NewValue);
202
203 auto &DL = Call->getModule()->getDataLayout();
204 while (!Worklist.empty()) {
205 auto *I = Worklist.pop_back_val();
206 if (auto *C = ConstantFoldInstruction(I, DL)) {
207 ReplaceInstructionWithConst(I, C);
208 if (isInstructionTriviallyDead(I))
209 I->eraseFromParent();
210 } else if (I->isTerminator()) {
211 ConstantFoldTerminator(I->getParent());
212 }
213 }
214 }
215
runOnModule(Module & M)216 bool NVVMReflect::runOnModule(Module &M) {
217 if (!NVVMReflectEnabled)
218 return false;
219 populateReflectMap(M);
220 bool Changed = true;
221 Changed |= handleReflectFunction(M, NVVM_REFLECT_FUNCTION);
222 Changed |= handleReflectFunction(M, NVVM_REFLECT_OCL_FUNCTION);
223 Changed |=
224 handleReflectFunction(M, Intrinsic::getName(Intrinsic::nvvm_reflect));
225 return Changed;
226 }
227
runOnModule(Module & M)228 bool NVVMReflectLegacyPass::runOnModule(Module &M) {
229 return Impl.runOnModule(M);
230 }
231
run(Module & M,ModuleAnalysisManager & AM)232 PreservedAnalyses NVVMReflectPass::run(Module &M, ModuleAnalysisManager &AM) {
233 return NVVMReflect(SmVersion).runOnModule(M) ? PreservedAnalyses::none()
234 : PreservedAnalyses::all();
235 }
236