1 //===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass modifies function signatures containing aggregate arguments 10 // and/or return value before IRTranslator. Information about the original 11 // signatures is stored in metadata. It is used during call lowering to 12 // restore correct SPIR-V types of function arguments and return values. 13 // This pass also substitutes some llvm intrinsic calls with calls to newly 14 // generated functions (as the Khronos LLVM/SPIR-V Translator does). 15 // 16 // NOTE: this pass is a module-level one due to the necessity to modify 17 // GVs/functions. 18 // 19 //===----------------------------------------------------------------------===// 20 21 #include "SPIRV.h" 22 #include "SPIRVTargetMachine.h" 23 #include "SPIRVUtils.h" 24 #include "llvm/CodeGen/IntrinsicLowering.h" 25 #include "llvm/IR/IRBuilder.h" 26 #include "llvm/IR/IntrinsicInst.h" 27 #include "llvm/Transforms/Utils/Cloning.h" 28 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h" 29 30 using namespace llvm; 31 32 namespace llvm { 33 void initializeSPIRVPrepareFunctionsPass(PassRegistry &); 34 } 35 36 namespace { 37 38 class SPIRVPrepareFunctions : public ModulePass { 39 bool substituteIntrinsicCalls(Function *F); 40 Function *removeAggregateTypesFromSignature(Function *F); 41 42 public: 43 static char ID; 44 SPIRVPrepareFunctions() : ModulePass(ID) { 45 initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry()); 46 } 47 48 bool runOnModule(Module &M) override; 49 50 StringRef getPassName() const override { return "SPIRV prepare functions"; } 51 52 void getAnalysisUsage(AnalysisUsage &AU) const override { 53 
ModulePass::getAnalysisUsage(AU); 54 } 55 }; 56 57 } // namespace 58 59 char SPIRVPrepareFunctions::ID = 0; 60 61 INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions", 62 "SPIRV prepare functions", false, false) 63 64 std::string lowerLLVMIntrinsicName(IntrinsicInst *II) { 65 Function *IntrinsicFunc = II->getCalledFunction(); 66 assert(IntrinsicFunc && "Missing function"); 67 std::string FuncName = IntrinsicFunc->getName().str(); 68 std::replace(FuncName.begin(), FuncName.end(), '.', '_'); 69 FuncName = "spirv." + FuncName; 70 return FuncName; 71 } 72 73 static Function *getOrCreateFunction(Module *M, Type *RetTy, 74 ArrayRef<Type *> ArgTypes, 75 StringRef Name) { 76 FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false); 77 Function *F = M->getFunction(Name); 78 if (F && F->getFunctionType() == FT) 79 return F; 80 Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M); 81 if (F) 82 NewF->setDSOLocal(F->isDSOLocal()); 83 NewF->setCallingConv(CallingConv::SPIR_FUNC); 84 return NewF; 85 } 86 87 static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) { 88 // For @llvm.memset.* intrinsic cases with constant value and length arguments 89 // are emulated via "storing" a constant array to the destination. For other 90 // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the 91 // intrinsic to a loop via expandMemSetAsLoop(). 92 if (auto *MSI = dyn_cast<MemSetInst>(Intrinsic)) 93 if (isa<Constant>(MSI->getValue()) && isa<ConstantInt>(MSI->getLength())) 94 return false; // It is handled later using OpCopyMemorySized. 
95 96 Module *M = Intrinsic->getModule(); 97 std::string FuncName = lowerLLVMIntrinsicName(Intrinsic); 98 if (Intrinsic->isVolatile()) 99 FuncName += ".volatile"; 100 // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_* 101 Function *F = M->getFunction(FuncName); 102 if (F) { 103 Intrinsic->setCalledFunction(F); 104 return true; 105 } 106 // TODO copy arguments attributes: nocapture writeonly. 107 FunctionCallee FC = 108 M->getOrInsertFunction(FuncName, Intrinsic->getFunctionType()); 109 auto IntrinsicID = Intrinsic->getIntrinsicID(); 110 Intrinsic->setCalledFunction(FC); 111 112 F = dyn_cast<Function>(FC.getCallee()); 113 assert(F && "Callee must be a function"); 114 115 switch (IntrinsicID) { 116 case Intrinsic::memset: { 117 auto *MSI = static_cast<MemSetInst *>(Intrinsic); 118 Argument *Dest = F->getArg(0); 119 Argument *Val = F->getArg(1); 120 Argument *Len = F->getArg(2); 121 Argument *IsVolatile = F->getArg(3); 122 Dest->setName("dest"); 123 Val->setName("val"); 124 Len->setName("len"); 125 IsVolatile->setName("isvolatile"); 126 BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F); 127 IRBuilder<> IRB(EntryBB); 128 auto *MemSet = IRB.CreateMemSet(Dest, Val, Len, MSI->getDestAlign(), 129 MSI->isVolatile()); 130 IRB.CreateRetVoid(); 131 expandMemSetAsLoop(cast<MemSetInst>(MemSet)); 132 MemSet->eraseFromParent(); 133 break; 134 } 135 case Intrinsic::bswap: { 136 BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F); 137 IRBuilder<> IRB(EntryBB); 138 auto *BSwap = IRB.CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(), 139 F->getArg(0)); 140 IRB.CreateRet(BSwap); 141 IntrinsicLowering IL(M->getDataLayout()); 142 IL.LowerIntrinsicCall(BSwap); 143 break; 144 } 145 default: 146 break; 147 } 148 return true; 149 } 150 151 static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) { 152 // Get a separate function - otherwise, we'd have to rework the CFG of the 153 // current one. 
Then simply replace the intrinsic uses with a call to the new 154 // function. 155 // Generate LLVM IR for i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c) 156 Module *M = FSHIntrinsic->getModule(); 157 FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType(); 158 Type *FSHRetTy = FSHFuncTy->getReturnType(); 159 const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic); 160 Function *FSHFunc = 161 getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName); 162 163 if (!FSHFunc->empty()) { 164 FSHIntrinsic->setCalledFunction(FSHFunc); 165 return; 166 } 167 BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc); 168 IRBuilder<> IRB(RotateBB); 169 Type *Ty = FSHFunc->getReturnType(); 170 // Build the actual funnel shift rotate logic. 171 // In the comments, "int" is used interchangeably with "vector of int 172 // elements". 173 FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty); 174 Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty; 175 unsigned BitWidth = IntTy->getIntegerBitWidth(); 176 ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth}); 177 Value *BitWidthForInsts = 178 VectorTy 179 ? IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant) 180 : BitWidthConstant; 181 Value *RotateModVal = 182 IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts); 183 Value *FirstShift = nullptr, *SecShift = nullptr; 184 if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) { 185 // Shift the less significant number right, the "rotate" number of bits 186 // will be 0-filled on the left as a result of this regular shift. 187 FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal); 188 } else { 189 // Shift the more significant number left, the "rotate" number of bits 190 // will be 0-filled on the right as a result of this regular shift. 
191 FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal); 192 } 193 // We want the "rotate" number of the more significant int's LSBs (MSBs) to 194 // occupy the leftmost (rightmost) "0 space" left by the previous operation. 195 // Therefore, subtract the "rotate" number from the integer bitsize... 196 Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal); 197 if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) { 198 // ...and left-shift the more significant int by this number, zero-filling 199 // the LSBs. 200 SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal); 201 } else { 202 // ...and right-shift the less significant int by this number, zero-filling 203 // the MSBs. 204 SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal); 205 } 206 // A simple binary addition of the shifted ints yields the final result. 207 IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift)); 208 209 FSHIntrinsic->setCalledFunction(FSHFunc); 210 } 211 212 static void buildUMulWithOverflowFunc(Function *UMulFunc) { 213 // The function body is already created. 214 if (!UMulFunc->empty()) 215 return; 216 217 BasicBlock *EntryBB = BasicBlock::Create(UMulFunc->getParent()->getContext(), 218 "entry", UMulFunc); 219 IRBuilder<> IRB(EntryBB); 220 // Build the actual unsigned multiplication logic with the overflow 221 // indication. Do unsigned multiplication Mul = A * B. Then check 222 // if unsigned division Div = Mul / A is not equal to B. If so, 223 // then overflow has happened. 224 Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1)); 225 Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0)); 226 Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(0), Div); 227 228 // umul.with.overflow intrinsic return a structure, where the first element 229 // is the multiplication result, and the second is an overflow bit. 
230 Type *StructTy = UMulFunc->getReturnType(); 231 Value *Agg = IRB.CreateInsertValue(PoisonValue::get(StructTy), Mul, {0}); 232 Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1}); 233 IRB.CreateRet(Res); 234 } 235 236 static void lowerUMulWithOverflow(IntrinsicInst *UMulIntrinsic) { 237 // Get a separate function - otherwise, we'd have to rework the CFG of the 238 // current one. Then simply replace the intrinsic uses with a call to the new 239 // function. 240 Module *M = UMulIntrinsic->getModule(); 241 FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType(); 242 Type *FSHLRetTy = UMulFuncTy->getReturnType(); 243 const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic); 244 Function *UMulFunc = 245 getOrCreateFunction(M, FSHLRetTy, UMulFuncTy->params(), FuncName); 246 buildUMulWithOverflowFunc(UMulFunc); 247 UMulIntrinsic->setCalledFunction(UMulFunc); 248 } 249 250 // Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics 251 // or calls to proper generated functions. Returns True if F was modified. 
bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
  bool Changed = false;
  for (BasicBlock &BB : *F) {
    for (Instruction &I : BB) {
      auto Call = dyn_cast<CallInst>(&I);
      if (!Call)
        continue;
      // Only direct intrinsic calls are of interest; indirect calls have no
      // statically known called function.
      Function *CF = Call->getCalledFunction();
      if (!CF || !CF->isIntrinsic())
        continue;
      auto *II = cast<IntrinsicInst>(Call);
      // None of the lowering helpers erase the call instruction itself (they
      // only retarget its callee), so iterating BB while rewriting is safe.
      if (II->getIntrinsicID() == Intrinsic::memset ||
          II->getIntrinsicID() == Intrinsic::bswap)
        Changed |= lowerIntrinsicToFunction(II);
      else if (II->getIntrinsicID() == Intrinsic::fshl ||
               II->getIntrinsicID() == Intrinsic::fshr) {
        lowerFunnelShifts(II);
        Changed = true;
      } else if (II->getIntrinsicID() == Intrinsic::umul_with_overflow) {
        lowerUMulWithOverflow(II);
        Changed = true;
      }
    }
  }
  return Changed;
}

// Returns F if aggregate argument/return types are not present or cloned F
// function with the types replaced by i32 types. The change in types is
// noted in 'spv.cloned_funcs' metadata for later restoration.
Function *
SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
  // B is used only as a convenient factory for types/constants in F's context.
  IRBuilder<> B(F->getContext());

  bool IsRetAggr = F->getReturnType()->isAggregateType();
  bool HasAggrArg =
      std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) {
        return Arg.getType()->isAggregateType();
      });
  bool DoClone = IsRetAggr || HasAggrArg;
  if (!DoClone)
    return F;
  // Pairs of (argument index, original type); index -1 denotes the return
  // type. This is exactly what gets encoded into the metadata below.
  SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
  Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
  if (IsRetAggr)
    ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType()));
  SmallVector<Type *, 4> ArgTypes;
  for (const auto &Arg : F->args()) {
    if (Arg.getType()->isAggregateType()) {
      ArgTypes.push_back(B.getInt32Ty());
      ChangedTypes.push_back(
          std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
    } else
      ArgTypes.push_back(Arg.getType());
  }
  FunctionType *NewFTy =
      FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
  Function *NewF =
      Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());

  // Map each old argument to its new counterpart (preserving names) so the
  // body can be cloned into NewF.
  ValueToValueMapTy VMap;
  auto NewFArgIt = NewF->arg_begin();
  for (auto &Arg : F->args()) {
    StringRef ArgName = Arg.getName();
    NewFArgIt->setName(ArgName);
    VMap[&Arg] = &(*NewFArgIt++);
  }
  SmallVector<ReturnInst *, 8> Returns;

  CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
                    Returns);
  // Steal F's name so the clone is externally indistinguishable; F itself is
  // erased by the caller (runOnModule).
  NewF->takeName(F);

  // Record one metadata node per cloned function:
  // !{!"name", !{i32 index, original-type-null-value}, ...}
  // Call lowering reads this to restore the original SPIR-V types.
  NamedMDNode *FuncMD =
      F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
  SmallVector<Metadata *, 2> MDArgs;
  MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
  for (auto &ChangedTyP : ChangedTypes)
    MDArgs.push_back(MDNode::get(
        B.getContext(),
        {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
         ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
  MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
  FuncMD->addOperand(ThisFuncMD);

  // Redirect every use of F to NewF. For direct calls, mutate the call's
  // function type first so the call stays internally consistent once the
  // callee operand is swapped. early_inc: replaceUsesOfWith edits the use
  // list being walked.
  for (auto *U : make_early_inc_range(F->users())) {
    if (auto *CI = dyn_cast<CallInst>(U))
      CI->mutateFunctionType(NewF->getFunctionType());
    U->replaceUsesOfWith(F, NewF);
  }
  return NewF;
}

bool SPIRVPrepareFunctions::runOnModule(Module &M) {
  bool Changed = false;
  for (Function &F : M)
    Changed |= substituteIntrinsicCalls(&F);

  // Snapshot the function list first: removeAggregateTypesFromSignature
  // inserts clones (and the loop below erases originals), which would
  // invalidate direct iteration over M.
  std::vector<Function *> FuncsWorklist;
  for (auto &F : M)
    FuncsWorklist.push_back(&F);

  for (auto *F : FuncsWorklist) {
    Function *NewF = removeAggregateTypesFromSignature(F);

    // A returned clone replaces the original, which is now unused.
    if (NewF != F) {
      F->eraseFromParent();
      Changed = true;
    }
  }
  return Changed;
}

// Factory entry point used by the SPIR-V target to register this pass.
ModulePass *llvm::createSPIRVPrepareFunctionsPass() {
  return new SPIRVPrepareFunctions();
}