xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp (revision a90b9d0159070121c221b966469c3e36d912bf82)
1 //===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass modifies function signatures containing aggregate arguments
10 // and/or return value before IRTranslator. Information about the original
11 // signatures is stored in metadata. It is used during call lowering to
12 // restore correct SPIR-V types of function arguments and return values.
13 // This pass also substitutes some llvm intrinsic calls with calls to newly
14 // generated functions (as the Khronos LLVM/SPIR-V Translator does).
15 //
16 // NOTE: this pass is a module-level one due to the necessity to modify
17 // GVs/functions.
18 //
19 //===----------------------------------------------------------------------===//
20 
21 #include "SPIRV.h"
22 #include "SPIRVSubtarget.h"
23 #include "SPIRVTargetMachine.h"
24 #include "SPIRVUtils.h"
25 #include "llvm/CodeGen/IntrinsicLowering.h"
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/IR/IntrinsicInst.h"
28 #include "llvm/IR/Intrinsics.h"
29 #include "llvm/IR/IntrinsicsSPIRV.h"
30 #include "llvm/Transforms/Utils/Cloning.h"
31 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
32 
33 using namespace llvm;
34 
namespace llvm {
// Forward declaration of the pass-registry hook generated by the
// INITIALIZE_PASS macro below; called from the pass constructor.
void initializeSPIRVPrepareFunctionsPass(PassRegistry &);
} // namespace llvm
38 
39 namespace {
40 
// Module pass that (1) substitutes selected LLVM intrinsic calls with calls
// to generated functions or SPIR-V intrinsics and (2) rewrites function
// signatures that contain aggregate argument/return types (see the file
// header comment for the overall scheme).
class SPIRVPrepareFunctions : public ModulePass {
  // Target machine; used to query the subtarget for available extensions.
  const SPIRVTargetMachine &TM;
  // Lowers supported intrinsic calls inside F. Returns true if F changed.
  bool substituteIntrinsicCalls(Function *F);
  // Returns F unchanged, or a clone of F with aggregate argument/return
  // types replaced by i32; original types are recorded in module metadata.
  Function *removeAggregateTypesFromSignature(Function *F);

public:
  static char ID; // Pass identification, replacement for typeid.
  SPIRVPrepareFunctions(const SPIRVTargetMachine &TM) : ModulePass(ID), TM(TM) {
    initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry());
  }

  bool runOnModule(Module &M) override;

  StringRef getPassName() const override { return "SPIRV prepare functions"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    ModulePass::getAnalysisUsage(AU);
  }
};
60 
61 } // namespace
62 
char SPIRVPrepareFunctions::ID = 0;

// Register the pass under the command-line name "prepare-functions".
INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
                "SPIRV prepare functions", false, false)
67 
68 std::string lowerLLVMIntrinsicName(IntrinsicInst *II) {
69   Function *IntrinsicFunc = II->getCalledFunction();
70   assert(IntrinsicFunc && "Missing function");
71   std::string FuncName = IntrinsicFunc->getName().str();
72   std::replace(FuncName.begin(), FuncName.end(), '.', '_');
73   FuncName = "spirv." + FuncName;
74   return FuncName;
75 }
76 
77 static Function *getOrCreateFunction(Module *M, Type *RetTy,
78                                      ArrayRef<Type *> ArgTypes,
79                                      StringRef Name) {
80   FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false);
81   Function *F = M->getFunction(Name);
82   if (F && F->getFunctionType() == FT)
83     return F;
84   Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M);
85   if (F)
86     NewF->setDSOLocal(F->isDSOLocal());
87   NewF->setCallingConv(CallingConv::SPIR_FUNC);
88   return NewF;
89 }
90 
// Lowers a call to llvm.memset.* / llvm.bswap.* by redirecting it to a
// @spirv.llvm_<name>_* wrapper function whose body is generated here, as the
// Khronos LLVM/SPIR-V Translator does. Returns true if the call was changed.
static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) {
  // For @llvm.memset.* intrinsic cases with constant value and length arguments
  // are emulated via "storing" a constant array to the destination. For other
  // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the
  // intrinsic to a loop via expandMemSetAsLoop().
  if (auto *MSI = dyn_cast<MemSetInst>(Intrinsic))
    if (isa<Constant>(MSI->getValue()) && isa<ConstantInt>(MSI->getLength()))
      return false; // It is handled later using OpCopyMemorySized.

  Module *M = Intrinsic->getModule();
  std::string FuncName = lowerLLVMIntrinsicName(Intrinsic);
  // Volatile variants get a distinct wrapper so the flag is not lost.
  if (Intrinsic->isVolatile())
    FuncName += ".volatile";
  // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_*
  Function *F = M->getFunction(FuncName);
  if (F) {
    // Wrapper already built on a previous call; just retarget this one.
    Intrinsic->setCalledFunction(F);
    return true;
  }
  // TODO copy arguments attributes: nocapture writeonly.
  FunctionCallee FC =
      M->getOrInsertFunction(FuncName, Intrinsic->getFunctionType());
  // Save the ID before setCalledFunction(): afterwards the call no longer
  // refers to an intrinsic and getIntrinsicID() would not be meaningful.
  auto IntrinsicID = Intrinsic->getIntrinsicID();
  Intrinsic->setCalledFunction(FC);

  F = dyn_cast<Function>(FC.getCallee());
  assert(F && "Callee must be a function");

  // Generate the wrapper's body, which re-creates the intrinsic and then
  // expands it to plain IR.
  switch (IntrinsicID) {
  case Intrinsic::memset: {
    auto *MSI = static_cast<MemSetInst *>(Intrinsic);
    Argument *Dest = F->getArg(0);
    Argument *Val = F->getArg(1);
    Argument *Len = F->getArg(2);
    Argument *IsVolatile = F->getArg(3);
    Dest->setName("dest");
    Val->setName("val");
    Len->setName("len");
    IsVolatile->setName("isvolatile");
    BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
    IRBuilder<> IRB(EntryBB);
    // Re-create the memset inside the wrapper, expand it to a store loop,
    // then erase the (now redundant) intrinsic call itself.
    auto *MemSet = IRB.CreateMemSet(Dest, Val, Len, MSI->getDestAlign(),
                                    MSI->isVolatile());
    IRB.CreateRetVoid();
    expandMemSetAsLoop(cast<MemSetInst>(MemSet));
    MemSet->eraseFromParent();
    break;
  }
  case Intrinsic::bswap: {
    BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
    IRBuilder<> IRB(EntryBB);
    // Re-create the bswap and let IntrinsicLowering expand it to shifts/ors.
    auto *BSwap = IRB.CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(),
                                      F->getArg(0));
    IRB.CreateRet(BSwap);
    IntrinsicLowering IL(M->getDataLayout());
    IL.LowerIntrinsicCall(BSwap);
    break;
  }
  default:
    break;
  }
  return true;
}
154 
// Lowers llvm.fshl / llvm.fshr by generating a @spirv.llvm_fsh?_i* function
// that implements the funnel shift with plain shift/or arithmetic, then
// redirecting the intrinsic call to it.
static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) {
  // Get a separate function - otherwise, we'd have to rework the CFG of the
  // current one. Then simply replace the intrinsic uses with a call to the new
  // function.
  // Generate LLVM IR for  i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c)
  Module *M = FSHIntrinsic->getModule();
  FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType();
  Type *FSHRetTy = FSHFuncTy->getReturnType();
  const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic);
  Function *FSHFunc =
      getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName);

  if (!FSHFunc->empty()) {
    // Body already generated by a previous call; just retarget this one.
    FSHIntrinsic->setCalledFunction(FSHFunc);
    return;
  }
  BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc);
  IRBuilder<> IRB(RotateBB);
  Type *Ty = FSHFunc->getReturnType();
  // Build the actual funnel shift rotate logic.
  // In the comments, "int" is used interchangeably with "vector of int
  // elements".
  FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty);
  Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty;
  unsigned BitWidth = IntTy->getIntegerBitWidth();
  // APInt{BitWidth, BitWidth}: a BitWidth-wide constant whose value is the
  // bit width itself, used to reduce the shift amount modulo the width.
  ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth});
  // For vectors, splat the width constant across all lanes.
  Value *BitWidthForInsts =
      VectorTy
          ? IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant)
          : BitWidthConstant;
  Value *RotateModVal =
      IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts);
  Value *FirstShift = nullptr, *SecShift = nullptr;
  if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
    // Shift the less significant number right, the "rotate" number of bits
    // will be 0-filled on the left as a result of this regular shift.
    FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal);
  } else {
    // Shift the more significant number left, the "rotate" number of bits
    // will be 0-filled on the right as a result of this regular shift.
    FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal);
  }
  // We want the "rotate" number of the more significant int's LSBs (MSBs) to
  // occupy the leftmost (rightmost) "0 space" left by the previous operation.
  // Therefore, subtract the "rotate" number from the integer bitsize...
  Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal);
  if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
    // ...and left-shift the more significant int by this number, zero-filling
    // the LSBs.
    SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal);
  } else {
    // ...and right-shift the less significant int by this number, zero-filling
    // the MSBs.
    SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal);
  }
  // A simple binary addition of the shifted ints yields the final result.
  IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift));

  FSHIntrinsic->setCalledFunction(FSHFunc);
}
215 
216 static void buildUMulWithOverflowFunc(Function *UMulFunc) {
217   // The function body is already created.
218   if (!UMulFunc->empty())
219     return;
220 
221   BasicBlock *EntryBB = BasicBlock::Create(UMulFunc->getParent()->getContext(),
222                                            "entry", UMulFunc);
223   IRBuilder<> IRB(EntryBB);
224   // Build the actual unsigned multiplication logic with the overflow
225   // indication. Do unsigned multiplication Mul = A * B. Then check
226   // if unsigned division Div = Mul / A is not equal to B. If so,
227   // then overflow has happened.
228   Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1));
229   Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0));
230   Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(0), Div);
231 
232   // umul.with.overflow intrinsic return a structure, where the first element
233   // is the multiplication result, and the second is an overflow bit.
234   Type *StructTy = UMulFunc->getReturnType();
235   Value *Agg = IRB.CreateInsertValue(PoisonValue::get(StructTy), Mul, {0});
236   Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1});
237   IRB.CreateRet(Res);
238 }
239 
240 static void lowerExpectAssume(IntrinsicInst *II) {
241   // If we cannot use the SPV_KHR_expect_assume extension, then we need to
242   // ignore the intrinsic and move on. It should be removed later on by LLVM.
243   // Otherwise we should lower the intrinsic to the corresponding SPIR-V
244   // instruction.
245   // For @llvm.assume we have OpAssumeTrueKHR.
246   // For @llvm.expect we have OpExpectKHR.
247   //
248   // We need to lower this into a builtin and then the builtin into a SPIR-V
249   // instruction.
250   if (II->getIntrinsicID() == Intrinsic::assume) {
251     Function *F = Intrinsic::getDeclaration(
252         II->getModule(), Intrinsic::SPVIntrinsics::spv_assume);
253     II->setCalledFunction(F);
254   } else if (II->getIntrinsicID() == Intrinsic::expect) {
255     Function *F = Intrinsic::getDeclaration(
256         II->getModule(), Intrinsic::SPVIntrinsics::spv_expect,
257         {II->getOperand(0)->getType()});
258     II->setCalledFunction(F);
259   } else {
260     llvm_unreachable("Unknown intrinsic");
261   }
262 
263   return;
264 }
265 
266 static void lowerUMulWithOverflow(IntrinsicInst *UMulIntrinsic) {
267   // Get a separate function - otherwise, we'd have to rework the CFG of the
268   // current one. Then simply replace the intrinsic uses with a call to the new
269   // function.
270   Module *M = UMulIntrinsic->getModule();
271   FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType();
272   Type *FSHLRetTy = UMulFuncTy->getReturnType();
273   const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic);
274   Function *UMulFunc =
275       getOrCreateFunction(M, FSHLRetTy, UMulFuncTy->params(), FuncName);
276   buildUMulWithOverflowFunc(UMulFunc);
277   UMulIntrinsic->setCalledFunction(UMulFunc);
278 }
279 
280 // Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics
281 // or calls to proper generated functions. Returns True if F was modified.
282 bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
283   bool Changed = false;
284   for (BasicBlock &BB : *F) {
285     for (Instruction &I : BB) {
286       auto Call = dyn_cast<CallInst>(&I);
287       if (!Call)
288         continue;
289       Function *CF = Call->getCalledFunction();
290       if (!CF || !CF->isIntrinsic())
291         continue;
292       auto *II = cast<IntrinsicInst>(Call);
293       if (II->getIntrinsicID() == Intrinsic::memset ||
294           II->getIntrinsicID() == Intrinsic::bswap)
295         Changed |= lowerIntrinsicToFunction(II);
296       else if (II->getIntrinsicID() == Intrinsic::fshl ||
297                II->getIntrinsicID() == Intrinsic::fshr) {
298         lowerFunnelShifts(II);
299         Changed = true;
300       } else if (II->getIntrinsicID() == Intrinsic::umul_with_overflow) {
301         lowerUMulWithOverflow(II);
302         Changed = true;
303       } else if (II->getIntrinsicID() == Intrinsic::assume ||
304                  II->getIntrinsicID() == Intrinsic::expect) {
305         const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(*F);
306         if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume))
307           lowerExpectAssume(II);
308         Changed = true;
309       }
310     }
311   }
312   return Changed;
313 }
314 
// Returns F if aggregate argument/return types are not present or cloned F
// function with the types replaced by i32 types. The change in types is
// noted in 'spv.cloned_funcs' metadata for later restoration.
Function *
SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
  IRBuilder<> B(F->getContext());

  bool IsRetAggr = F->getReturnType()->isAggregateType();
  bool HasAggrArg =
      std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) {
        return Arg.getType()->isAggregateType();
      });
  bool DoClone = IsRetAggr || HasAggrArg;
  if (!DoClone)
    return F;
  // Pairs of (argument index, original type); index -1 denotes the return
  // type. This is the payload later stored in 'spv.cloned_funcs'.
  SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
  Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
  if (IsRetAggr)
    ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType()));
  SmallVector<Type *, 4> ArgTypes;
  for (const auto &Arg : F->args()) {
    if (Arg.getType()->isAggregateType()) {
      ArgTypes.push_back(B.getInt32Ty());
      ChangedTypes.push_back(
          std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
    } else
      ArgTypes.push_back(Arg.getType());
  }
  FunctionType *NewFTy =
      FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
  Function *NewF =
      Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());

  // Map old arguments to new ones (preserving names) so the body can be
  // cloned into the replacement function.
  ValueToValueMapTy VMap;
  auto NewFArgIt = NewF->arg_begin();
  for (auto &Arg : F->args()) {
    StringRef ArgName = Arg.getName();
    NewFArgIt->setName(ArgName);
    VMap[&Arg] = &(*NewFArgIt++);
  }
  SmallVector<ReturnInst *, 8> Returns;

  CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
                    Returns);
  // Steal F's name; F itself is erased by the caller (runOnModule).
  NewF->takeName(F);

  // Record the original types in module-level metadata, as one node per
  // cloned function: !{name, {index, null-of-original-type}...}. Call
  // lowering reads this to restore the proper SPIR-V types.
  NamedMDNode *FuncMD =
      F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
  SmallVector<Metadata *, 2> MDArgs;
  MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
  for (auto &ChangedTyP : ChangedTypes)
    MDArgs.push_back(MDNode::get(
        B.getContext(),
        {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
         ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
  MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
  FuncMD->addOperand(ThisFuncMD);

  // Retarget direct calls: mutate each call's type to match the new
  // signature, then swap the callee. early_inc because replaceUsesOfWith
  // modifies the use list being iterated.
  for (auto *U : make_early_inc_range(F->users())) {
    if (auto *CI = dyn_cast<CallInst>(U))
      CI->mutateFunctionType(NewF->getFunctionType());
    U->replaceUsesOfWith(F, NewF);
  }
  return NewF;
}
380 
381 bool SPIRVPrepareFunctions::runOnModule(Module &M) {
382   bool Changed = false;
383   for (Function &F : M)
384     Changed |= substituteIntrinsicCalls(&F);
385 
386   std::vector<Function *> FuncsWorklist;
387   for (auto &F : M)
388     FuncsWorklist.push_back(&F);
389 
390   for (auto *F : FuncsWorklist) {
391     Function *NewF = removeAggregateTypesFromSignature(F);
392 
393     if (NewF != F) {
394       F->eraseFromParent();
395       Changed = true;
396     }
397   }
398   return Changed;
399 }
400 
// Factory used by the SPIR-V target to insert this pass into the codegen
// pipeline (declared in SPIRV.h).
ModulePass *
llvm::createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM) {
  return new SPIRVPrepareFunctions(TM);
}
405