//===- NVVMReflect.cpp - NVVM Emulate conditional compilation -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect
// with an integer.
//
// We choose the value we use by looking at metadata in the module itself. Note
// that we intentionally only have one way to choose these values, because other
// parts of LLVM (particularly, InstCombineCall) rely on being able to predict
// the values chosen by this pass.
//
// If we see an unknown string, we replace its call with 0.
//
//===----------------------------------------------------------------------===//
200b57cec5SDimitry Andric
210b57cec5SDimitry Andric #include "NVPTX.h"
220b57cec5SDimitry Andric #include "llvm/ADT/SmallVector.h"
23*0fca6ea1SDimitry Andric #include "llvm/Analysis/ConstantFolding.h"
240b57cec5SDimitry Andric #include "llvm/IR/Constants.h"
250b57cec5SDimitry Andric #include "llvm/IR/DerivedTypes.h"
260b57cec5SDimitry Andric #include "llvm/IR/Function.h"
270b57cec5SDimitry Andric #include "llvm/IR/InstIterator.h"
280b57cec5SDimitry Andric #include "llvm/IR/Instructions.h"
290b57cec5SDimitry Andric #include "llvm/IR/Intrinsics.h"
30480093f4SDimitry Andric #include "llvm/IR/IntrinsicsNVPTX.h"
310b57cec5SDimitry Andric #include "llvm/IR/Module.h"
32e8d8bef9SDimitry Andric #include "llvm/IR/PassManager.h"
330b57cec5SDimitry Andric #include "llvm/IR/Type.h"
340b57cec5SDimitry Andric #include "llvm/Pass.h"
350b57cec5SDimitry Andric #include "llvm/Support/CommandLine.h"
360b57cec5SDimitry Andric #include "llvm/Support/Debug.h"
370b57cec5SDimitry Andric #include "llvm/Support/raw_os_ostream.h"
380b57cec5SDimitry Andric #include "llvm/Support/raw_ostream.h"
390b57cec5SDimitry Andric #include "llvm/Transforms/Scalar.h"
40*0fca6ea1SDimitry Andric #include "llvm/Transforms/Utils/BasicBlockUtils.h"
41*0fca6ea1SDimitry Andric #include "llvm/Transforms/Utils/Local.h"
42*0fca6ea1SDimitry Andric #include <algorithm>
430b57cec5SDimitry Andric #include <sstream>
440b57cec5SDimitry Andric #include <string>
450b57cec5SDimitry Andric #define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
46bdd1243dSDimitry Andric #define NVVM_REFLECT_OCL_FUNCTION "__nvvm_reflect_ocl"
470b57cec5SDimitry Andric
480b57cec5SDimitry Andric using namespace llvm;
490b57cec5SDimitry Andric
500b57cec5SDimitry Andric #define DEBUG_TYPE "nvptx-reflect"
510b57cec5SDimitry Andric
520b57cec5SDimitry Andric namespace llvm { void initializeNVVMReflectPass(PassRegistry &); }
530b57cec5SDimitry Andric
540b57cec5SDimitry Andric namespace {
550b57cec5SDimitry Andric class NVVMReflect : public FunctionPass {
560b57cec5SDimitry Andric public:
570b57cec5SDimitry Andric static char ID;
580b57cec5SDimitry Andric unsigned int SmVersion;
NVVMReflect()590b57cec5SDimitry Andric NVVMReflect() : NVVMReflect(0) {}
NVVMReflect(unsigned int Sm)600b57cec5SDimitry Andric explicit NVVMReflect(unsigned int Sm) : FunctionPass(ID), SmVersion(Sm) {
610b57cec5SDimitry Andric initializeNVVMReflectPass(*PassRegistry::getPassRegistry());
620b57cec5SDimitry Andric }
630b57cec5SDimitry Andric
640b57cec5SDimitry Andric bool runOnFunction(Function &) override;
650b57cec5SDimitry Andric };
660b57cec5SDimitry Andric }
670b57cec5SDimitry Andric
createNVVMReflectPass(unsigned int SmVersion)680b57cec5SDimitry Andric FunctionPass *llvm::createNVVMReflectPass(unsigned int SmVersion) {
690b57cec5SDimitry Andric return new NVVMReflect(SmVersion);
700b57cec5SDimitry Andric }
710b57cec5SDimitry Andric
720b57cec5SDimitry Andric static cl::opt<bool>
730b57cec5SDimitry Andric NVVMReflectEnabled("nvvm-reflect-enable", cl::init(true), cl::Hidden,
740b57cec5SDimitry Andric cl::desc("NVVM reflection, enabled by default"));
750b57cec5SDimitry Andric
760b57cec5SDimitry Andric char NVVMReflect::ID = 0;
770b57cec5SDimitry Andric INITIALIZE_PASS(NVVMReflect, "nvvm-reflect",
780b57cec5SDimitry Andric "Replace occurrences of __nvvm_reflect() calls with 0/1", false,
790b57cec5SDimitry Andric false)
800b57cec5SDimitry Andric
runNVVMReflect(Function & F,unsigned SmVersion)81e8d8bef9SDimitry Andric static bool runNVVMReflect(Function &F, unsigned SmVersion) {
820b57cec5SDimitry Andric if (!NVVMReflectEnabled)
830b57cec5SDimitry Andric return false;
840b57cec5SDimitry Andric
85bdd1243dSDimitry Andric if (F.getName() == NVVM_REFLECT_FUNCTION ||
86bdd1243dSDimitry Andric F.getName() == NVVM_REFLECT_OCL_FUNCTION) {
870b57cec5SDimitry Andric assert(F.isDeclaration() && "_reflect function should not have a body");
880b57cec5SDimitry Andric assert(F.getReturnType()->isIntegerTy() &&
890b57cec5SDimitry Andric "_reflect's return type should be integer");
900b57cec5SDimitry Andric return false;
910b57cec5SDimitry Andric }
920b57cec5SDimitry Andric
930b57cec5SDimitry Andric SmallVector<Instruction *, 4> ToRemove;
94*0fca6ea1SDimitry Andric SmallVector<Instruction *, 4> ToSimplify;
950b57cec5SDimitry Andric
960b57cec5SDimitry Andric // Go through the calls in this function. Each call to __nvvm_reflect or
970b57cec5SDimitry Andric // llvm.nvvm.reflect should be a CallInst with a ConstantArray argument.
980b57cec5SDimitry Andric // First validate that. If the c-string corresponding to the ConstantArray can
990b57cec5SDimitry Andric // be found successfully, see if it can be found in VarMap. If so, replace the
1000b57cec5SDimitry Andric // uses of CallInst with the value found in VarMap. If not, replace the use
1010b57cec5SDimitry Andric // with value 0.
1020b57cec5SDimitry Andric
1030b57cec5SDimitry Andric // The IR for __nvvm_reflect calls differs between CUDA versions.
1040b57cec5SDimitry Andric //
1050b57cec5SDimitry Andric // CUDA 6.5 and earlier uses this sequence:
1060b57cec5SDimitry Andric // %ptr = tail call i8* @llvm.nvvm.ptr.constant.to.gen.p0i8.p4i8
1070b57cec5SDimitry Andric // (i8 addrspace(4)* getelementptr inbounds
1080b57cec5SDimitry Andric // ([8 x i8], [8 x i8] addrspace(4)* @str, i32 0, i32 0))
1090b57cec5SDimitry Andric // %reflect = tail call i32 @__nvvm_reflect(i8* %ptr)
1100b57cec5SDimitry Andric //
1110b57cec5SDimitry Andric // The value returned by Sym->getOperand(0) is a Constant with a
1120b57cec5SDimitry Andric // ConstantDataSequential operand which can be converted to string and used
1130b57cec5SDimitry Andric // for lookup.
1140b57cec5SDimitry Andric //
1150b57cec5SDimitry Andric // CUDA 7.0 does it slightly differently:
1160b57cec5SDimitry Andric // %reflect = call i32 @__nvvm_reflect(i8* addrspacecast
1170b57cec5SDimitry Andric // (i8 addrspace(1)* getelementptr inbounds
1180b57cec5SDimitry Andric // ([8 x i8], [8 x i8] addrspace(1)* @str, i32 0, i32 0) to i8*))
1190b57cec5SDimitry Andric //
1200b57cec5SDimitry Andric // In this case, we get a Constant with a GlobalVariable operand and we need
1210b57cec5SDimitry Andric // to dig deeper to find its initializer with the string we'll use for lookup.
1220b57cec5SDimitry Andric for (Instruction &I : instructions(F)) {
1230b57cec5SDimitry Andric CallInst *Call = dyn_cast<CallInst>(&I);
1240b57cec5SDimitry Andric if (!Call)
1250b57cec5SDimitry Andric continue;
1260b57cec5SDimitry Andric Function *Callee = Call->getCalledFunction();
1270b57cec5SDimitry Andric if (!Callee || (Callee->getName() != NVVM_REFLECT_FUNCTION &&
128bdd1243dSDimitry Andric Callee->getName() != NVVM_REFLECT_OCL_FUNCTION &&
1290b57cec5SDimitry Andric Callee->getIntrinsicID() != Intrinsic::nvvm_reflect))
1300b57cec5SDimitry Andric continue;
1310b57cec5SDimitry Andric
1320b57cec5SDimitry Andric // FIXME: Improve error handling here and elsewhere in this pass.
1330b57cec5SDimitry Andric assert(Call->getNumOperands() == 2 &&
1340b57cec5SDimitry Andric "Wrong number of operands to __nvvm_reflect function");
1350b57cec5SDimitry Andric
1360b57cec5SDimitry Andric // In cuda 6.5 and earlier, we will have an extra constant-to-generic
1370b57cec5SDimitry Andric // conversion of the string.
1380b57cec5SDimitry Andric const Value *Str = Call->getArgOperand(0);
1390b57cec5SDimitry Andric if (const CallInst *ConvCall = dyn_cast<CallInst>(Str)) {
1400b57cec5SDimitry Andric // FIXME: Add assertions about ConvCall.
1410b57cec5SDimitry Andric Str = ConvCall->getArgOperand(0);
1420b57cec5SDimitry Andric }
14381ad6265SDimitry Andric // Pre opaque pointers we have a constant expression wrapping the constant
14481ad6265SDimitry Andric // string.
14581ad6265SDimitry Andric Str = Str->stripPointerCasts();
14681ad6265SDimitry Andric assert(isa<Constant>(Str) &&
1470b57cec5SDimitry Andric "Format of __nvvm_reflect function not recognized");
1480b57cec5SDimitry Andric
14981ad6265SDimitry Andric const Value *Operand = cast<Constant>(Str)->getOperand(0);
1500b57cec5SDimitry Andric if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Operand)) {
1510b57cec5SDimitry Andric // For CUDA-7.0 style __nvvm_reflect calls, we need to find the operand's
1520b57cec5SDimitry Andric // initializer.
1530b57cec5SDimitry Andric assert(GV->hasInitializer() &&
1540b57cec5SDimitry Andric "Format of _reflect function not recognized");
1550b57cec5SDimitry Andric const Constant *Initializer = GV->getInitializer();
1560b57cec5SDimitry Andric Operand = Initializer;
1570b57cec5SDimitry Andric }
1580b57cec5SDimitry Andric
1590b57cec5SDimitry Andric assert(isa<ConstantDataSequential>(Operand) &&
1600b57cec5SDimitry Andric "Format of _reflect function not recognized");
1610b57cec5SDimitry Andric assert(cast<ConstantDataSequential>(Operand)->isCString() &&
1620b57cec5SDimitry Andric "Format of _reflect function not recognized");
1630b57cec5SDimitry Andric
1640b57cec5SDimitry Andric StringRef ReflectArg = cast<ConstantDataSequential>(Operand)->getAsString();
1650b57cec5SDimitry Andric ReflectArg = ReflectArg.substr(0, ReflectArg.size() - 1);
1660b57cec5SDimitry Andric LLVM_DEBUG(dbgs() << "Arg of _reflect : " << ReflectArg << "\n");
1670b57cec5SDimitry Andric
1680b57cec5SDimitry Andric int ReflectVal = 0; // The default value is 0
1690b57cec5SDimitry Andric if (ReflectArg == "__CUDA_FTZ") {
1700b57cec5SDimitry Andric // Try to pull __CUDA_FTZ from the nvvm-reflect-ftz module flag. Our
1710b57cec5SDimitry Andric // choice here must be kept in sync with AutoUpgrade, which uses the same
1720b57cec5SDimitry Andric // technique to detect whether ftz is enabled.
1730b57cec5SDimitry Andric if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
1740b57cec5SDimitry Andric F.getParent()->getModuleFlag("nvvm-reflect-ftz")))
1750b57cec5SDimitry Andric ReflectVal = Flag->getSExtValue();
1760b57cec5SDimitry Andric } else if (ReflectArg == "__CUDA_ARCH") {
1770b57cec5SDimitry Andric ReflectVal = SmVersion * 10;
1780b57cec5SDimitry Andric }
179*0fca6ea1SDimitry Andric
180*0fca6ea1SDimitry Andric // If the immediate user is a simple comparison we want to simplify it.
181*0fca6ea1SDimitry Andric for (User *U : Call->users())
182*0fca6ea1SDimitry Andric if (Instruction *I = dyn_cast<Instruction>(U))
183*0fca6ea1SDimitry Andric ToSimplify.push_back(I);
184*0fca6ea1SDimitry Andric
1850b57cec5SDimitry Andric Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal));
1860b57cec5SDimitry Andric ToRemove.push_back(Call);
1870b57cec5SDimitry Andric }
1880b57cec5SDimitry Andric
189*0fca6ea1SDimitry Andric // The code guarded by __nvvm_reflect may be invalid for the target machine.
190*0fca6ea1SDimitry Andric // Traverse the use-def chain, continually simplifying constant expressions
191*0fca6ea1SDimitry Andric // until we find a terminator that we can then remove.
192*0fca6ea1SDimitry Andric while (!ToSimplify.empty()) {
193*0fca6ea1SDimitry Andric Instruction *I = ToSimplify.pop_back_val();
194*0fca6ea1SDimitry Andric if (Constant *C =
195*0fca6ea1SDimitry Andric ConstantFoldInstruction(I, F.getDataLayout())) {
196*0fca6ea1SDimitry Andric for (User *U : I->users())
197*0fca6ea1SDimitry Andric if (Instruction *I = dyn_cast<Instruction>(U))
198*0fca6ea1SDimitry Andric ToSimplify.push_back(I);
199*0fca6ea1SDimitry Andric
200*0fca6ea1SDimitry Andric I->replaceAllUsesWith(C);
201*0fca6ea1SDimitry Andric if (isInstructionTriviallyDead(I)) {
202*0fca6ea1SDimitry Andric ToRemove.push_back(I);
203*0fca6ea1SDimitry Andric }
204*0fca6ea1SDimitry Andric } else if (I->isTerminator()) {
205*0fca6ea1SDimitry Andric ConstantFoldTerminator(I->getParent());
206*0fca6ea1SDimitry Andric }
207*0fca6ea1SDimitry Andric }
208*0fca6ea1SDimitry Andric
209*0fca6ea1SDimitry Andric // Removing via isInstructionTriviallyDead may add duplicates to the ToRemove
210*0fca6ea1SDimitry Andric // array. Filter out the duplicates before starting to erase from parent.
211*0fca6ea1SDimitry Andric std::sort(ToRemove.begin(), ToRemove.end());
212*0fca6ea1SDimitry Andric auto NewLastIter = llvm::unique(ToRemove);
213*0fca6ea1SDimitry Andric ToRemove.erase(NewLastIter, ToRemove.end());
214*0fca6ea1SDimitry Andric
2150b57cec5SDimitry Andric for (Instruction *I : ToRemove)
2160b57cec5SDimitry Andric I->eraseFromParent();
2170b57cec5SDimitry Andric
2180b57cec5SDimitry Andric return ToRemove.size() > 0;
2190b57cec5SDimitry Andric }
220e8d8bef9SDimitry Andric
runOnFunction(Function & F)221e8d8bef9SDimitry Andric bool NVVMReflect::runOnFunction(Function &F) {
222e8d8bef9SDimitry Andric return runNVVMReflect(F, SmVersion);
223e8d8bef9SDimitry Andric }
224e8d8bef9SDimitry Andric
NVVMReflectPass()225e8d8bef9SDimitry Andric NVVMReflectPass::NVVMReflectPass() : NVVMReflectPass(0) {}
226e8d8bef9SDimitry Andric
run(Function & F,FunctionAnalysisManager & AM)227e8d8bef9SDimitry Andric PreservedAnalyses NVVMReflectPass::run(Function &F,
228e8d8bef9SDimitry Andric FunctionAnalysisManager &AM) {
229e8d8bef9SDimitry Andric return runNVVMReflect(F, SmVersion) ? PreservedAnalyses::none()
230e8d8bef9SDimitry Andric : PreservedAnalyses::all();
231e8d8bef9SDimitry Andric }
232