//===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass modifies function signatures that contain aggregate arguments
// and/or an aggregate return value. It also substitutes some LLVM intrinsic
// calls with calls to functions that it generates, mirroring what the
// LLVM/SPIR-V Translator does.
//
// NOTE: this is a module pass because it needs to modify GVs/functions.
//
//===----------------------------------------------------------------------===//

#include "SPIRV.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"

using namespace llvm;

namespace llvm {
void initializeSPIRVPrepareFunctionsPass(PassRegistry &);
}

namespace {

class SPIRVPrepareFunctions : public ModulePass {
  Function *processFunctionSignature(Function *F);

public:
  static char ID;
  SPIRVPrepareFunctions() : ModulePass(ID) {
    initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry());
  }

  bool runOnModule(Module &M) override;

  StringRef getPassName() const override { return "SPIRV prepare functions"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    ModulePass::getAnalysisUsage(AU);
  }
};

} // namespace

char SPIRVPrepareFunctions::ID = 0;

INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
                "SPIRV prepare functions", false, false)

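// Illustration (a sketch, not verbatim pass output; %struct.S and @foo are
// hypothetical names): processFunctionSignature below rewrites
//
//   %struct.S = type { i32, i32 }
//   define %struct.S @foo(%struct.S %x, i32 %y)
//
// into a clone with aggregate types replaced by i32 placeholders,
//
//   define i32 @foo(i32 %x, i32 %y)
//
// and records the original types in module-level metadata, roughly:
//
//   !spv.cloned_funcs = !{!0}
//   !0 = !{!"foo", !1, !2}
//   !1 = !{i32 -1, %struct.S zeroinitializer} ; -1 denotes the return type
//   !2 = !{i32 0, %struct.S zeroinitializer}  ; 0 is the argument index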
Function *SPIRVPrepareFunctions::processFunctionSignature(Function *F) {
  IRBuilder<> B(F->getContext());

  bool IsRetAggr = F->getReturnType()->isAggregateType();
  bool HasAggrArg =
      std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) {
        return Arg.getType()->isAggregateType();
      });
  bool DoClone = IsRetAggr || HasAggrArg;
  if (!DoClone)
    return F;
  SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
  Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
  if (IsRetAggr)
    ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType()));
  SmallVector<Type *, 4> ArgTypes;
  for (const auto &Arg : F->args()) {
    if (Arg.getType()->isAggregateType()) {
      ArgTypes.push_back(B.getInt32Ty());
      ChangedTypes.push_back(
          std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
    } else
      ArgTypes.push_back(Arg.getType());
  }
  FunctionType *NewFTy =
      FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
  Function *NewF =
      Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());

  ValueToValueMapTy VMap;
  auto NewFArgIt = NewF->arg_begin();
  for (auto &Arg : F->args()) {
    StringRef ArgName = Arg.getName();
    NewFArgIt->setName(ArgName);
    VMap[&Arg] = &(*NewFArgIt++);
  }
  SmallVector<ReturnInst *, 8> Returns;

  CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
                    Returns);
  NewF->takeName(F);

  NamedMDNode *FuncMD =
      F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
  SmallVector<Metadata *, 2> MDArgs;
  MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
  for (auto &ChangedTyP : ChangedTypes)
    MDArgs.push_back(MDNode::get(
        B.getContext(),
        {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
         ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
  MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
  FuncMD->addOperand(ThisFuncMD);

  for (auto *U : make_early_inc_range(F->users())) {
    if (auto *CI = dyn_cast<CallInst>(U))
      CI->mutateFunctionType(NewF->getFunctionType());
    U->replaceUsesOfWith(F, NewF);
  }
  return NewF;
}

std::string lowerLLVMIntrinsicName(IntrinsicInst *II) {
  Function *IntrinsicFunc = II->getCalledFunction();
  assert(IntrinsicFunc && "Missing function");
  std::string FuncName = IntrinsicFunc->getName().str();
  std::replace(FuncName.begin(), FuncName.end(), '.', '_');
  FuncName = "spirv." + FuncName;
  return FuncName;
}

static Function *getOrCreateFunction(Module *M, Type *RetTy,
                                     ArrayRef<Type *> ArgTypes,
                                     StringRef Name) {
  FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false);
  Function *F = M->getFunction(Name);
  if (F && F->getFunctionType() == FT)
    return F;
  Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M);
  if (F)
    NewF->setDSOLocal(F->isDSOLocal());
  NewF->setCallingConv(CallingConv::SPIR_FUNC);
  return NewF;
}

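// Illustration (a sketch; value names are hypothetical): for a call to
// llvm.fshl.i32, lowerFunnelShifts below redirects the call to a generated
// helper of roughly this shape:
//
//   define i32 @spirv.llvm_fshl_i32(i32 %a, i32 %b, i32 %c) {
//   rotate:
//     %rot = urem i32 %c, 32
//     %hi  = shl i32 %a, %rot
//     %sub = sub i32 32, %rot
//     %lo  = lshr i32 %b, %sub
//     %res = or i32 %hi, %lo
//     ret i32 %res
//   }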
static void lowerFunnelShifts(Module *M, IntrinsicInst *FSHIntrinsic) {
  // Get a separate function - otherwise, we'd have to rework the CFG of the
  // current one. Then simply replace the intrinsic uses with a call to the new
  // function.
  // Generate LLVM IR for i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c).
  FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType();
  Type *FSHRetTy = FSHFuncTy->getReturnType();
  const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic);
  Function *FSHFunc =
      getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName);

  if (!FSHFunc->empty()) {
    FSHIntrinsic->setCalledFunction(FSHFunc);
    return;
  }
  BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc);
  IRBuilder<> IRB(RotateBB);
  Type *Ty = FSHFunc->getReturnType();
  // Build the actual funnel shift rotate logic.
  // In the comments, "int" is used interchangeably with "vector of int
  // elements".
  FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty);
  Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty;
  unsigned BitWidth = IntTy->getIntegerBitWidth();
  ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth});
  Value *BitWidthForInsts =
      VectorTy
          ? IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant)
          : BitWidthConstant;
  Value *RotateModVal =
      IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts);
  Value *FirstShift = nullptr, *SecShift = nullptr;
  if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
    // Shift the less significant number right; the "rotate" number of bits
    // will be 0-filled on the left as a result of this regular shift.
    FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal);
  } else {
    // Shift the more significant number left; the "rotate" number of bits
    // will be 0-filled on the right as a result of this regular shift.
    FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal);
  }
  // We want the "rotate" number of the more significant int's LSBs (MSBs) to
  // occupy the leftmost (rightmost) "0 space" left by the previous operation.
  // Therefore, subtract the "rotate" number from the integer bitsize...
  Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal);
  if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
    // ...and left-shift the more significant int by this number, zero-filling
    // the LSBs.
    SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal);
  } else {
    // ...and right-shift the less significant int by this number, zero-filling
    // the MSBs.
    SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal);
  }
  // A bitwise OR of the shifted halves yields the final result.
  IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift));

  FSHIntrinsic->setCalledFunction(FSHFunc);
}

static void buildUMulWithOverflowFunc(Module *M, Function *UMulFunc) {
  // The function body is already created.
  if (!UMulFunc->empty())
    return;

  BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", UMulFunc);
  IRBuilder<> IRB(EntryBB);
  // Build the actual unsigned multiplication logic with the overflow
  // indication. Do unsigned multiplication Mul = A * B. Then check
  // if unsigned division Div = Mul / A is not equal to B. If so,
  // then overflow has happened.
  Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1));
  Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0));
  Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(1), Div);

  // The umul.with.overflow intrinsic returns a structure whose first element
  // is the multiplication result and whose second element is the overflow bit.
  Type *StructTy = UMulFunc->getReturnType();
  Value *Agg = IRB.CreateInsertValue(UndefValue::get(StructTy), Mul, {0});
  Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1});
  IRB.CreateRet(Res);
}

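// Illustration (a sketch; value names are hypothetical): for a call to
// llvm.umul.with.overflow.i32, lowerUMulWithOverflow below redirects the call
// to a helper, built by buildUMulWithOverflowFunc above, of roughly this
// shape:
//
//   define { i32, i1 } @spirv.llvm_umul_with_overflow_i32(i32 %a, i32 %b) {
//   entry:
//     %mul = mul nuw i32 %a, %b
//     %div = udiv i32 %mul, %a
//     %ov  = icmp ne i32 %b, %div
//     %agg = insertvalue { i32, i1 } undef, i32 %mul, 0
//     %res = insertvalue { i32, i1 } %agg, i1 %ov, 1
//     ret { i32, i1 } %res
//   }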
static void lowerUMulWithOverflow(Module *M, IntrinsicInst *UMulIntrinsic) {
  // Get a separate function - otherwise, we'd have to rework the CFG of the
  // current one. Then simply replace the intrinsic uses with a call to the new
  // function.
  FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType();
  Type *UMulRetTy = UMulFuncTy->getReturnType();
  const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic);
  Function *UMulFunc =
      getOrCreateFunction(M, UMulRetTy, UMulFuncTy->params(), FuncName);
  buildUMulWithOverflowFunc(M, UMulFunc);
  UMulIntrinsic->setCalledFunction(UMulFunc);
}

static void substituteIntrinsicCalls(Module *M, Function *F) {
  for (BasicBlock &BB : *F) {
    for (Instruction &I : BB) {
      auto *Call = dyn_cast<CallInst>(&I);
      if (!Call)
        continue;
      Call->setTailCall(false);
      Function *CF = Call->getCalledFunction();
      if (!CF || !CF->isIntrinsic())
        continue;
      auto *II = cast<IntrinsicInst>(Call);
      if (II->getIntrinsicID() == Intrinsic::fshl ||
          II->getIntrinsicID() == Intrinsic::fshr)
        lowerFunnelShifts(M, II);
      else if (II->getIntrinsicID() == Intrinsic::umul_with_overflow)
        lowerUMulWithOverflow(M, II);
    }
  }
}

bool SPIRVPrepareFunctions::runOnModule(Module &M) {
  for (Function &F : M)
    substituteIntrinsicCalls(&M, &F);

  std::vector<Function *> FuncsWorklist;
  bool Changed = false;
  for (auto &F : M)
    FuncsWorklist.push_back(&F);

  for (auto *Func : FuncsWorklist) {
    Function *F = processFunctionSignature(Func);

    bool CreatedNewF = F != Func;
    Changed |= CreatedNewF;

    if (Func->isDeclaration())
      continue;

    if (CreatedNewF)
      Func->eraseFromParent();
  }

  return Changed;
}

ModulePass *llvm::createSPIRVPrepareFunctionsPass() {
  return new SPIRVPrepareFunctions();
}