//===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass modifies function signatures containing aggregate arguments
// and/or return value. Also it substitutes some llvm intrinsic calls by
// function calls, generating these functions as the translator does.
//
// NOTE: this pass is a module-level one due to the necessity to modify
// GVs/functions.
//
//===----------------------------------------------------------------------===//

#include "SPIRV.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/CodeGen/IntrinsicLowering.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"

using namespace llvm;

// Forward declaration of the pass initializer generated by INITIALIZE_PASS.
namespace llvm {
void initializeSPIRVPrepareFunctionsPass(PassRegistry &);
}

namespace {

// Legacy-PM module pass: rewrites aggregate-typed signatures and replaces
// selected intrinsic calls with generated emulation functions.
class SPIRVPrepareFunctions : public ModulePass {
  // Clones F with aggregate argument/return types replaced by i32 (recording
  // the original types in "spv.cloned_funcs" metadata). Returns the clone, or
  // F itself when no aggregate types are involved.
  Function *processFunctionSignature(Function *F);

public:
  static char ID;
  SPIRVPrepareFunctions() : ModulePass(ID) {
    initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry());
  }

  bool runOnModule(Module &M) override;

  StringRef getPassName() const override { return "SPIRV prepare functions"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    ModulePass::getAnalysisUsage(AU);
  }
};

} // namespace

char SPIRVPrepareFunctions::ID = 0;

INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
                "SPIRV prepare functions", false, false)

Function *SPIRVPrepareFunctions::processFunctionSignature(Function
*F) {
  IRBuilder<> B(F->getContext());

  // Decide whether a clone is needed: any aggregate in the return type or in
  // the parameter list triggers a signature rewrite.
  bool IsRetAggr = F->getReturnType()->isAggregateType();
  bool HasAggrArg =
      std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) {
        return Arg.getType()->isAggregateType();
      });
  bool DoClone = IsRetAggr || HasAggrArg;
  if (!DoClone)
    return F;
  // Record (position, original type) pairs; -1 denotes the return type.
  SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
  Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
  if (IsRetAggr)
    ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType()));
  // Build the new parameter list: each aggregate becomes i32, others are kept.
  SmallVector<Type *, 4> ArgTypes;
  for (const auto &Arg : F->args()) {
    if (Arg.getType()->isAggregateType()) {
      ArgTypes.push_back(B.getInt32Ty());
      ChangedTypes.push_back(
          std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
    } else
      ArgTypes.push_back(Arg.getType());
  }
  FunctionType *NewFTy =
      FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
  Function *NewF =
      Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());

  // Map old arguments to new ones (preserving names) so the body can be
  // cloned into the replacement function.
  ValueToValueMapTy VMap;
  auto NewFArgIt = NewF->arg_begin();
  for (auto &Arg : F->args()) {
    StringRef ArgName = Arg.getName();
    NewFArgIt->setName(ArgName);
    VMap[&Arg] = &(*NewFArgIt++);
  }
  SmallVector<ReturnInst *, 8> Returns;

  CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
                    Returns);
  // Steal the original's name so callers resolve to the clone.
  NewF->takeName(F);

  // Preserve the original aggregate types in module-level metadata
  // ("spv.cloned_funcs") so later stages can recover them.
  NamedMDNode *FuncMD =
      F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
  SmallVector<Metadata *, 2> MDArgs;
  MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
  for (auto &ChangedTyP : ChangedTypes)
    MDArgs.push_back(MDNode::get(
        B.getContext(),
        {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
         ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
  MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
  FuncMD->addOperand(ThisFuncMD);

  // Redirect every user of F to NewF, mutating call types in place.
  for (auto *U : make_early_inc_range(F->users())) {
    if (auto *CI
= dyn_cast<CallInst>(U)) 116 CI->mutateFunctionType(NewF->getFunctionType()); 117 U->replaceUsesOfWith(F, NewF); 118 } 119 return NewF; 120 } 121 122 std::string lowerLLVMIntrinsicName(IntrinsicInst *II) { 123 Function *IntrinsicFunc = II->getCalledFunction(); 124 assert(IntrinsicFunc && "Missing function"); 125 std::string FuncName = IntrinsicFunc->getName().str(); 126 std::replace(FuncName.begin(), FuncName.end(), '.', '_'); 127 FuncName = "spirv." + FuncName; 128 return FuncName; 129 } 130 131 static Function *getOrCreateFunction(Module *M, Type *RetTy, 132 ArrayRef<Type *> ArgTypes, 133 StringRef Name) { 134 FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false); 135 Function *F = M->getFunction(Name); 136 if (F && F->getFunctionType() == FT) 137 return F; 138 Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M); 139 if (F) 140 NewF->setDSOLocal(F->isDSOLocal()); 141 NewF->setCallingConv(CallingConv::SPIR_FUNC); 142 return NewF; 143 } 144 145 static void lowerIntrinsicToFunction(Module *M, IntrinsicInst *Intrinsic) { 146 // For @llvm.memset.* intrinsic cases with constant value and length arguments 147 // are emulated via "storing" a constant array to the destination. For other 148 // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the 149 // intrinsic to a loop via expandMemSetAsLoop(). 150 if (auto *MSI = dyn_cast<MemSetInst>(Intrinsic)) 151 if (isa<Constant>(MSI->getValue()) && isa<ConstantInt>(MSI->getLength())) 152 return; // It is handled later using OpCopyMemorySized. 153 154 std::string FuncName = lowerLLVMIntrinsicName(Intrinsic); 155 if (Intrinsic->isVolatile()) 156 FuncName += ".volatile"; 157 // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_* 158 Function *F = M->getFunction(FuncName); 159 if (F) { 160 Intrinsic->setCalledFunction(F); 161 return; 162 } 163 // TODO copy arguments attributes: nocapture writeonly. 
164 FunctionCallee FC = 165 M->getOrInsertFunction(FuncName, Intrinsic->getFunctionType()); 166 auto IntrinsicID = Intrinsic->getIntrinsicID(); 167 Intrinsic->setCalledFunction(FC); 168 169 F = dyn_cast<Function>(FC.getCallee()); 170 assert(F && "Callee must be a function"); 171 172 switch (IntrinsicID) { 173 case Intrinsic::memset: { 174 auto *MSI = static_cast<MemSetInst *>(Intrinsic); 175 Argument *Dest = F->getArg(0); 176 Argument *Val = F->getArg(1); 177 Argument *Len = F->getArg(2); 178 Argument *IsVolatile = F->getArg(3); 179 Dest->setName("dest"); 180 Val->setName("val"); 181 Len->setName("len"); 182 IsVolatile->setName("isvolatile"); 183 BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F); 184 IRBuilder<> IRB(EntryBB); 185 auto *MemSet = IRB.CreateMemSet(Dest, Val, Len, MSI->getDestAlign(), 186 MSI->isVolatile()); 187 IRB.CreateRetVoid(); 188 expandMemSetAsLoop(cast<MemSetInst>(MemSet)); 189 MemSet->eraseFromParent(); 190 break; 191 } 192 case Intrinsic::bswap: { 193 BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F); 194 IRBuilder<> IRB(EntryBB); 195 auto *BSwap = IRB.CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(), 196 F->getArg(0)); 197 IRB.CreateRet(BSwap); 198 IntrinsicLowering IL(M->getDataLayout()); 199 IL.LowerIntrinsicCall(BSwap); 200 break; 201 } 202 default: 203 break; 204 } 205 return; 206 } 207 208 static void lowerFunnelShifts(Module *M, IntrinsicInst *FSHIntrinsic) { 209 // Get a separate function - otherwise, we'd have to rework the CFG of the 210 // current one. Then simply replace the intrinsic uses with a call to the new 211 // function. 
  // Generate LLVM IR for i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c)
  FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType();
  Type *FSHRetTy = FSHFuncTy->getReturnType();
  const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic);
  Function *FSHFunc =
      getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName);

  // Body already emitted for a previous call site — just retarget this call.
  if (!FSHFunc->empty()) {
    FSHIntrinsic->setCalledFunction(FSHFunc);
    return;
  }
  BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc);
  IRBuilder<> IRB(RotateBB);
  Type *Ty = FSHFunc->getReturnType();
  // Build the actual funnel shift rotate logic.
  // In the comments, "int" is used interchangeably with "vector of int
  // elements".
  FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty);
  Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty;
  unsigned BitWidth = IntTy->getIntegerBitWidth();
  // {BitWidth, BitWidth} braced-initializes an APInt: a constant of width
  // BitWidth whose *value* is also BitWidth.
  ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth});
  // Splat the bit width across lanes for the vector case.
  Value *BitWidthForInsts =
      VectorTy
          ? IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant)
          : BitWidthConstant;
  // Reduce the shift amount modulo the bit width, as fshl/fshr require.
  Value *RotateModVal =
      IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts);
  Value *FirstShift = nullptr, *SecShift = nullptr;
  if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
    // Shift the less significant number right, the "rotate" number of bits
    // will be 0-filled on the left as a result of this regular shift.
    FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal);
  } else {
    // Shift the more significant number left, the "rotate" number of bits
    // will be 0-filled on the right as a result of this regular shift.
    FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal);
  }
  // We want the "rotate" number of the more significant int's LSBs (MSBs) to
  // occupy the leftmost (rightmost) "0 space" left by the previous operation.
251 // Therefore, subtract the "rotate" number from the integer bitsize... 252 Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal); 253 if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) { 254 // ...and left-shift the more significant int by this number, zero-filling 255 // the LSBs. 256 SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal); 257 } else { 258 // ...and right-shift the less significant int by this number, zero-filling 259 // the MSBs. 260 SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal); 261 } 262 // A simple binary addition of the shifted ints yields the final result. 263 IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift)); 264 265 FSHIntrinsic->setCalledFunction(FSHFunc); 266 } 267 268 static void buildUMulWithOverflowFunc(Module *M, Function *UMulFunc) { 269 // The function body is already created. 270 if (!UMulFunc->empty()) 271 return; 272 273 BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", UMulFunc); 274 IRBuilder<> IRB(EntryBB); 275 // Build the actual unsigned multiplication logic with the overflow 276 // indication. Do unsigned multiplication Mul = A * B. Then check 277 // if unsigned division Div = Mul / A is not equal to B. If so, 278 // then overflow has happened. 279 Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1)); 280 Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0)); 281 Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(0), Div); 282 283 // umul.with.overflow intrinsic return a structure, where the first element 284 // is the multiplication result, and the second is an overflow bit. 
285 Type *StructTy = UMulFunc->getReturnType(); 286 Value *Agg = IRB.CreateInsertValue(PoisonValue::get(StructTy), Mul, {0}); 287 Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1}); 288 IRB.CreateRet(Res); 289 } 290 291 static void lowerUMulWithOverflow(Module *M, IntrinsicInst *UMulIntrinsic) { 292 // Get a separate function - otherwise, we'd have to rework the CFG of the 293 // current one. Then simply replace the intrinsic uses with a call to the new 294 // function. 295 FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType(); 296 Type *FSHLRetTy = UMulFuncTy->getReturnType(); 297 const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic); 298 Function *UMulFunc = 299 getOrCreateFunction(M, FSHLRetTy, UMulFuncTy->params(), FuncName); 300 buildUMulWithOverflowFunc(M, UMulFunc); 301 UMulIntrinsic->setCalledFunction(UMulFunc); 302 } 303 304 static void substituteIntrinsicCalls(Module *M, Function *F) { 305 for (BasicBlock &BB : *F) { 306 for (Instruction &I : BB) { 307 auto Call = dyn_cast<CallInst>(&I); 308 if (!Call) 309 continue; 310 Call->setTailCall(false); 311 Function *CF = Call->getCalledFunction(); 312 if (!CF || !CF->isIntrinsic()) 313 continue; 314 auto *II = cast<IntrinsicInst>(Call); 315 if (II->getIntrinsicID() == Intrinsic::memset || 316 II->getIntrinsicID() == Intrinsic::bswap) 317 lowerIntrinsicToFunction(M, II); 318 else if (II->getIntrinsicID() == Intrinsic::fshl || 319 II->getIntrinsicID() == Intrinsic::fshr) 320 lowerFunnelShifts(M, II); 321 else if (II->getIntrinsicID() == Intrinsic::umul_with_overflow) 322 lowerUMulWithOverflow(M, II); 323 } 324 } 325 } 326 327 bool SPIRVPrepareFunctions::runOnModule(Module &M) { 328 for (Function &F : M) 329 substituteIntrinsicCalls(&M, &F); 330 331 std::vector<Function *> FuncsWorklist; 332 bool Changed = false; 333 for (auto &F : M) 334 FuncsWorklist.push_back(&F); 335 336 for (auto *Func : FuncsWorklist) { 337 Function *F = processFunctionSignature(Func); 338 339 bool CreatedNewF = F != 
Func; 340 341 if (Func->isDeclaration()) { 342 Changed |= CreatedNewF; 343 continue; 344 } 345 346 if (CreatedNewF) 347 Func->eraseFromParent(); 348 } 349 350 return Changed; 351 } 352 353 ModulePass *llvm::createSPIRVPrepareFunctionsPass() { 354 return new SPIRVPrepareFunctions(); 355 } 356