1 //===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass modifies function signatures containing aggregate arguments 10 // and/or return value before IRTranslator. Information about the original 11 // signatures is stored in metadata. It is used during call lowering to 12 // restore correct SPIR-V types of function arguments and return values. 13 // This pass also substitutes some llvm intrinsic calls with calls to newly 14 // generated functions (as the Khronos LLVM/SPIR-V Translator does). 15 // 16 // NOTE: this pass is a module-level one due to the necessity to modify 17 // GVs/functions. 18 // 19 //===----------------------------------------------------------------------===// 20 21 #include "SPIRV.h" 22 #include "SPIRVSubtarget.h" 23 #include "SPIRVTargetMachine.h" 24 #include "SPIRVUtils.h" 25 #include "llvm/CodeGen/IntrinsicLowering.h" 26 #include "llvm/IR/IRBuilder.h" 27 #include "llvm/IR/IntrinsicInst.h" 28 #include "llvm/IR/Intrinsics.h" 29 #include "llvm/IR/IntrinsicsSPIRV.h" 30 #include "llvm/Transforms/Utils/Cloning.h" 31 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h" 32 33 using namespace llvm; 34 35 namespace llvm { 36 void initializeSPIRVPrepareFunctionsPass(PassRegistry &); 37 } 38 39 namespace { 40 41 class SPIRVPrepareFunctions : public ModulePass { 42 const SPIRVTargetMachine &TM; 43 bool substituteIntrinsicCalls(Function *F); 44 Function *removeAggregateTypesFromSignature(Function *F); 45 46 public: 47 static char ID; 48 SPIRVPrepareFunctions(const SPIRVTargetMachine &TM) : ModulePass(ID), TM(TM) { 49 initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry()); 50 } 51 52 bool 
runOnModule(Module &M) override; 53 54 StringRef getPassName() const override { return "SPIRV prepare functions"; } 55 56 void getAnalysisUsage(AnalysisUsage &AU) const override { 57 ModulePass::getAnalysisUsage(AU); 58 } 59 }; 60 61 } // namespace 62 63 char SPIRVPrepareFunctions::ID = 0; 64 65 INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions", 66 "SPIRV prepare functions", false, false) 67 68 std::string lowerLLVMIntrinsicName(IntrinsicInst *II) { 69 Function *IntrinsicFunc = II->getCalledFunction(); 70 assert(IntrinsicFunc && "Missing function"); 71 std::string FuncName = IntrinsicFunc->getName().str(); 72 std::replace(FuncName.begin(), FuncName.end(), '.', '_'); 73 FuncName = "spirv." + FuncName; 74 return FuncName; 75 } 76 77 static Function *getOrCreateFunction(Module *M, Type *RetTy, 78 ArrayRef<Type *> ArgTypes, 79 StringRef Name) { 80 FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false); 81 Function *F = M->getFunction(Name); 82 if (F && F->getFunctionType() == FT) 83 return F; 84 Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M); 85 if (F) 86 NewF->setDSOLocal(F->isDSOLocal()); 87 NewF->setCallingConv(CallingConv::SPIR_FUNC); 88 return NewF; 89 } 90 91 static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) { 92 // For @llvm.memset.* intrinsic cases with constant value and length arguments 93 // are emulated via "storing" a constant array to the destination. For other 94 // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the 95 // intrinsic to a loop via expandMemSetAsLoop(). 96 if (auto *MSI = dyn_cast<MemSetInst>(Intrinsic)) 97 if (isa<Constant>(MSI->getValue()) && isa<ConstantInt>(MSI->getLength())) 98 return false; // It is handled later using OpCopyMemorySized. 
99 100 Module *M = Intrinsic->getModule(); 101 std::string FuncName = lowerLLVMIntrinsicName(Intrinsic); 102 if (Intrinsic->isVolatile()) 103 FuncName += ".volatile"; 104 // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_* 105 Function *F = M->getFunction(FuncName); 106 if (F) { 107 Intrinsic->setCalledFunction(F); 108 return true; 109 } 110 // TODO copy arguments attributes: nocapture writeonly. 111 FunctionCallee FC = 112 M->getOrInsertFunction(FuncName, Intrinsic->getFunctionType()); 113 auto IntrinsicID = Intrinsic->getIntrinsicID(); 114 Intrinsic->setCalledFunction(FC); 115 116 F = dyn_cast<Function>(FC.getCallee()); 117 assert(F && "Callee must be a function"); 118 119 switch (IntrinsicID) { 120 case Intrinsic::memset: { 121 auto *MSI = static_cast<MemSetInst *>(Intrinsic); 122 Argument *Dest = F->getArg(0); 123 Argument *Val = F->getArg(1); 124 Argument *Len = F->getArg(2); 125 Argument *IsVolatile = F->getArg(3); 126 Dest->setName("dest"); 127 Val->setName("val"); 128 Len->setName("len"); 129 IsVolatile->setName("isvolatile"); 130 BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F); 131 IRBuilder<> IRB(EntryBB); 132 auto *MemSet = IRB.CreateMemSet(Dest, Val, Len, MSI->getDestAlign(), 133 MSI->isVolatile()); 134 IRB.CreateRetVoid(); 135 expandMemSetAsLoop(cast<MemSetInst>(MemSet)); 136 MemSet->eraseFromParent(); 137 break; 138 } 139 case Intrinsic::bswap: { 140 BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F); 141 IRBuilder<> IRB(EntryBB); 142 auto *BSwap = IRB.CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(), 143 F->getArg(0)); 144 IRB.CreateRet(BSwap); 145 IntrinsicLowering IL(M->getDataLayout()); 146 IL.LowerIntrinsicCall(BSwap); 147 break; 148 } 149 default: 150 break; 151 } 152 return true; 153 } 154 155 static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) { 156 // Get a separate function - otherwise, we'd have to rework the CFG of the 157 // current one. 
Then simply replace the intrinsic uses with a call to the new 158 // function. 159 // Generate LLVM IR for i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c) 160 Module *M = FSHIntrinsic->getModule(); 161 FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType(); 162 Type *FSHRetTy = FSHFuncTy->getReturnType(); 163 const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic); 164 Function *FSHFunc = 165 getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName); 166 167 if (!FSHFunc->empty()) { 168 FSHIntrinsic->setCalledFunction(FSHFunc); 169 return; 170 } 171 BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc); 172 IRBuilder<> IRB(RotateBB); 173 Type *Ty = FSHFunc->getReturnType(); 174 // Build the actual funnel shift rotate logic. 175 // In the comments, "int" is used interchangeably with "vector of int 176 // elements". 177 FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty); 178 Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty; 179 unsigned BitWidth = IntTy->getIntegerBitWidth(); 180 ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth}); 181 Value *BitWidthForInsts = 182 VectorTy 183 ? IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant) 184 : BitWidthConstant; 185 Value *RotateModVal = 186 IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts); 187 Value *FirstShift = nullptr, *SecShift = nullptr; 188 if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) { 189 // Shift the less significant number right, the "rotate" number of bits 190 // will be 0-filled on the left as a result of this regular shift. 191 FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal); 192 } else { 193 // Shift the more significant number left, the "rotate" number of bits 194 // will be 0-filled on the right as a result of this regular shift. 
195 FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal); 196 } 197 // We want the "rotate" number of the more significant int's LSBs (MSBs) to 198 // occupy the leftmost (rightmost) "0 space" left by the previous operation. 199 // Therefore, subtract the "rotate" number from the integer bitsize... 200 Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal); 201 if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) { 202 // ...and left-shift the more significant int by this number, zero-filling 203 // the LSBs. 204 SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal); 205 } else { 206 // ...and right-shift the less significant int by this number, zero-filling 207 // the MSBs. 208 SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal); 209 } 210 // A simple binary addition of the shifted ints yields the final result. 211 IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift)); 212 213 FSHIntrinsic->setCalledFunction(FSHFunc); 214 } 215 216 static void buildUMulWithOverflowFunc(Function *UMulFunc) { 217 // The function body is already created. 218 if (!UMulFunc->empty()) 219 return; 220 221 BasicBlock *EntryBB = BasicBlock::Create(UMulFunc->getParent()->getContext(), 222 "entry", UMulFunc); 223 IRBuilder<> IRB(EntryBB); 224 // Build the actual unsigned multiplication logic with the overflow 225 // indication. Do unsigned multiplication Mul = A * B. Then check 226 // if unsigned division Div = Mul / A is not equal to B. If so, 227 // then overflow has happened. 228 Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1)); 229 Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0)); 230 Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(0), Div); 231 232 // umul.with.overflow intrinsic return a structure, where the first element 233 // is the multiplication result, and the second is an overflow bit. 
234 Type *StructTy = UMulFunc->getReturnType(); 235 Value *Agg = IRB.CreateInsertValue(PoisonValue::get(StructTy), Mul, {0}); 236 Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1}); 237 IRB.CreateRet(Res); 238 } 239 240 static void lowerExpectAssume(IntrinsicInst *II) { 241 // If we cannot use the SPV_KHR_expect_assume extension, then we need to 242 // ignore the intrinsic and move on. It should be removed later on by LLVM. 243 // Otherwise we should lower the intrinsic to the corresponding SPIR-V 244 // instruction. 245 // For @llvm.assume we have OpAssumeTrueKHR. 246 // For @llvm.expect we have OpExpectKHR. 247 // 248 // We need to lower this into a builtin and then the builtin into a SPIR-V 249 // instruction. 250 if (II->getIntrinsicID() == Intrinsic::assume) { 251 Function *F = Intrinsic::getDeclaration( 252 II->getModule(), Intrinsic::SPVIntrinsics::spv_assume); 253 II->setCalledFunction(F); 254 } else if (II->getIntrinsicID() == Intrinsic::expect) { 255 Function *F = Intrinsic::getDeclaration( 256 II->getModule(), Intrinsic::SPVIntrinsics::spv_expect, 257 {II->getOperand(0)->getType()}); 258 II->setCalledFunction(F); 259 } else { 260 llvm_unreachable("Unknown intrinsic"); 261 } 262 263 return; 264 } 265 266 static void lowerUMulWithOverflow(IntrinsicInst *UMulIntrinsic) { 267 // Get a separate function - otherwise, we'd have to rework the CFG of the 268 // current one. Then simply replace the intrinsic uses with a call to the new 269 // function. 
270 Module *M = UMulIntrinsic->getModule(); 271 FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType(); 272 Type *FSHLRetTy = UMulFuncTy->getReturnType(); 273 const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic); 274 Function *UMulFunc = 275 getOrCreateFunction(M, FSHLRetTy, UMulFuncTy->params(), FuncName); 276 buildUMulWithOverflowFunc(UMulFunc); 277 UMulIntrinsic->setCalledFunction(UMulFunc); 278 } 279 280 // Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics 281 // or calls to proper generated functions. Returns True if F was modified. 282 bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) { 283 bool Changed = false; 284 for (BasicBlock &BB : *F) { 285 for (Instruction &I : BB) { 286 auto Call = dyn_cast<CallInst>(&I); 287 if (!Call) 288 continue; 289 Function *CF = Call->getCalledFunction(); 290 if (!CF || !CF->isIntrinsic()) 291 continue; 292 auto *II = cast<IntrinsicInst>(Call); 293 if (II->getIntrinsicID() == Intrinsic::memset || 294 II->getIntrinsicID() == Intrinsic::bswap) 295 Changed |= lowerIntrinsicToFunction(II); 296 else if (II->getIntrinsicID() == Intrinsic::fshl || 297 II->getIntrinsicID() == Intrinsic::fshr) { 298 lowerFunnelShifts(II); 299 Changed = true; 300 } else if (II->getIntrinsicID() == Intrinsic::umul_with_overflow) { 301 lowerUMulWithOverflow(II); 302 Changed = true; 303 } else if (II->getIntrinsicID() == Intrinsic::assume || 304 II->getIntrinsicID() == Intrinsic::expect) { 305 const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(*F); 306 if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) 307 lowerExpectAssume(II); 308 Changed = true; 309 } 310 } 311 } 312 return Changed; 313 } 314 315 // Returns F if aggregate argument/return types are not present or cloned F 316 // function with the types replaced by i32 types. The change in types is 317 // noted in 'spv.cloned_funcs' metadata for later restoration. 
Function *
SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
  IRBuilder<> B(F->getContext());

  bool IsRetAggr = F->getReturnType()->isAggregateType();
  bool HasAggrArg =
      std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) {
        return Arg.getType()->isAggregateType();
      });
  bool DoClone = IsRetAggr || HasAggrArg;
  if (!DoClone)
    // Nothing aggregate in the signature; keep the original function.
    return F;
  // (position, original type) pairs; position -1 denotes the return type,
  // otherwise it is the argument number.
  SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
  Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
  if (IsRetAggr)
    ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType()));
  // Build the new parameter list, swapping each aggregate type for i32 and
  // remembering what it originally was.
  SmallVector<Type *, 4> ArgTypes;
  for (const auto &Arg : F->args()) {
    if (Arg.getType()->isAggregateType()) {
      ArgTypes.push_back(B.getInt32Ty());
      ChangedTypes.push_back(
          std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
    } else
      ArgTypes.push_back(Arg.getType());
  }
  FunctionType *NewFTy =
      FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
  Function *NewF =
      Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());

  // Map old arguments to new ones (preserving names) so CloneFunctionInto
  // can rewrite the body's argument uses.
  ValueToValueMapTy VMap;
  auto NewFArgIt = NewF->arg_begin();
  for (auto &Arg : F->args()) {
    StringRef ArgName = Arg.getName();
    NewFArgIt->setName(ArgName);
    VMap[&Arg] = &(*NewFArgIt++);
  }
  SmallVector<ReturnInst *, 8> Returns;

  CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
                    Returns);
  // Steal the original name so callers resolve to the clone.
  NewF->takeName(F);

  // Record "<name>, {pos, null-of-original-type}..." under the
  // 'spv.cloned_funcs' named metadata; call lowering reads this back to
  // restore the SPIR-V types. A null constant is used purely as a carrier
  // for the original type.
  NamedMDNode *FuncMD =
      F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
  SmallVector<Metadata *, 2> MDArgs;
  MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
  for (auto &ChangedTyP : ChangedTypes)
    MDArgs.push_back(MDNode::get(
        B.getContext(),
        {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
         ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
  MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
  FuncMD->addOperand(ThisFuncMD);

  // Retarget existing call sites: first fix the call's function type, then
  // swap the callee. early_inc_range because replaceUsesOfWith mutates the
  // use list being iterated.
  for (auto *U : make_early_inc_range(F->users())) {
    if (auto *CI = dyn_cast<CallInst>(U))
      CI->mutateFunctionType(NewF->getFunctionType());
    U->replaceUsesOfWith(F, NewF);
  }
  return NewF;
}

// Runs intrinsic substitution over every function, then rewrites aggregate
// signatures, erasing each original function that was cloned.
bool SPIRVPrepareFunctions::runOnModule(Module &M) {
  bool Changed = false;
  for (Function &F : M)
    Changed |= substituteIntrinsicCalls(&F);

  // Snapshot the function list first: removeAggregateTypesFromSignature
  // inserts clones and we erase originals while iterating.
  std::vector<Function *> FuncsWorklist;
  for (auto &F : M)
    FuncsWorklist.push_back(&F);

  for (auto *F : FuncsWorklist) {
    Function *NewF = removeAggregateTypesFromSignature(F);

    if (NewF != F) {
      // F was cloned; all uses already point at NewF, so drop the original.
      F->eraseFromParent();
      Changed = true;
    }
  }
  return Changed;
}

// Factory used by the SPIR-V target to register this pass.
ModulePass *
llvm::createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM) {
  return new SPIRVPrepareFunctions(TM);
}