xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp (revision 9c8bf69a53f628b62fb196182ea55fb34c1c19e1)
1 //===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass modifies function signatures containing aggregate arguments
10 // and/or return value before IRTranslator. Information about the original
11 // signatures is stored in metadata. It is used during call lowering to
12 // restore correct SPIR-V types of function arguments and return values.
13 // This pass also substitutes some llvm intrinsic calls with calls to newly
14 // generated functions (as the Khronos LLVM/SPIR-V Translator does).
15 //
16 // NOTE: this pass is a module-level one due to the necessity to modify
17 // GVs/functions.
18 //
19 //===----------------------------------------------------------------------===//
20 
21 #include "SPIRV.h"
22 #include "SPIRVTargetMachine.h"
23 #include "SPIRVUtils.h"
24 #include "llvm/CodeGen/IntrinsicLowering.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/IntrinsicInst.h"
27 #include "llvm/Transforms/Utils/Cloning.h"
28 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
29 
30 using namespace llvm;
31 
32 namespace llvm {
33 void initializeSPIRVPrepareFunctionsPass(PassRegistry &);
34 }
35 
36 namespace {
37 
38 class SPIRVPrepareFunctions : public ModulePass {
39   bool substituteIntrinsicCalls(Function *F);
40   Function *removeAggregateTypesFromSignature(Function *F);
41 
42 public:
43   static char ID;
44   SPIRVPrepareFunctions() : ModulePass(ID) {
45     initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry());
46   }
47 
48   bool runOnModule(Module &M) override;
49 
50   StringRef getPassName() const override { return "SPIRV prepare functions"; }
51 
52   void getAnalysisUsage(AnalysisUsage &AU) const override {
53     ModulePass::getAnalysisUsage(AU);
54   }
55 };
56 
57 } // namespace
58 
59 char SPIRVPrepareFunctions::ID = 0;
60 
61 INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
62                 "SPIRV prepare functions", false, false)
63 
64 std::string lowerLLVMIntrinsicName(IntrinsicInst *II) {
65   Function *IntrinsicFunc = II->getCalledFunction();
66   assert(IntrinsicFunc && "Missing function");
67   std::string FuncName = IntrinsicFunc->getName().str();
68   std::replace(FuncName.begin(), FuncName.end(), '.', '_');
69   FuncName = "spirv." + FuncName;
70   return FuncName;
71 }
72 
73 static Function *getOrCreateFunction(Module *M, Type *RetTy,
74                                      ArrayRef<Type *> ArgTypes,
75                                      StringRef Name) {
76   FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false);
77   Function *F = M->getFunction(Name);
78   if (F && F->getFunctionType() == FT)
79     return F;
80   Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M);
81   if (F)
82     NewF->setDSOLocal(F->isDSOLocal());
83   NewF->setCallingConv(CallingConv::SPIR_FUNC);
84   return NewF;
85 }
86 
87 static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) {
88   // For @llvm.memset.* intrinsic cases with constant value and length arguments
89   // are emulated via "storing" a constant array to the destination. For other
90   // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the
91   // intrinsic to a loop via expandMemSetAsLoop().
92   if (auto *MSI = dyn_cast<MemSetInst>(Intrinsic))
93     if (isa<Constant>(MSI->getValue()) && isa<ConstantInt>(MSI->getLength()))
94       return false; // It is handled later using OpCopyMemorySized.
95 
96   Module *M = Intrinsic->getModule();
97   std::string FuncName = lowerLLVMIntrinsicName(Intrinsic);
98   if (Intrinsic->isVolatile())
99     FuncName += ".volatile";
100   // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_*
101   Function *F = M->getFunction(FuncName);
102   if (F) {
103     Intrinsic->setCalledFunction(F);
104     return true;
105   }
106   // TODO copy arguments attributes: nocapture writeonly.
107   FunctionCallee FC =
108       M->getOrInsertFunction(FuncName, Intrinsic->getFunctionType());
109   auto IntrinsicID = Intrinsic->getIntrinsicID();
110   Intrinsic->setCalledFunction(FC);
111 
112   F = dyn_cast<Function>(FC.getCallee());
113   assert(F && "Callee must be a function");
114 
115   switch (IntrinsicID) {
116   case Intrinsic::memset: {
117     auto *MSI = static_cast<MemSetInst *>(Intrinsic);
118     Argument *Dest = F->getArg(0);
119     Argument *Val = F->getArg(1);
120     Argument *Len = F->getArg(2);
121     Argument *IsVolatile = F->getArg(3);
122     Dest->setName("dest");
123     Val->setName("val");
124     Len->setName("len");
125     IsVolatile->setName("isvolatile");
126     BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
127     IRBuilder<> IRB(EntryBB);
128     auto *MemSet = IRB.CreateMemSet(Dest, Val, Len, MSI->getDestAlign(),
129                                     MSI->isVolatile());
130     IRB.CreateRetVoid();
131     expandMemSetAsLoop(cast<MemSetInst>(MemSet));
132     MemSet->eraseFromParent();
133     break;
134   }
135   case Intrinsic::bswap: {
136     BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
137     IRBuilder<> IRB(EntryBB);
138     auto *BSwap = IRB.CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(),
139                                       F->getArg(0));
140     IRB.CreateRet(BSwap);
141     IntrinsicLowering IL(M->getDataLayout());
142     IL.LowerIntrinsicCall(BSwap);
143     break;
144   }
145   default:
146     break;
147   }
148   return true;
149 }
150 
151 static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) {
152   // Get a separate function - otherwise, we'd have to rework the CFG of the
153   // current one. Then simply replace the intrinsic uses with a call to the new
154   // function.
155   // Generate LLVM IR for  i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c)
156   Module *M = FSHIntrinsic->getModule();
157   FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType();
158   Type *FSHRetTy = FSHFuncTy->getReturnType();
159   const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic);
160   Function *FSHFunc =
161       getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName);
162 
163   if (!FSHFunc->empty()) {
164     FSHIntrinsic->setCalledFunction(FSHFunc);
165     return;
166   }
167   BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc);
168   IRBuilder<> IRB(RotateBB);
169   Type *Ty = FSHFunc->getReturnType();
170   // Build the actual funnel shift rotate logic.
171   // In the comments, "int" is used interchangeably with "vector of int
172   // elements".
173   FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty);
174   Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty;
175   unsigned BitWidth = IntTy->getIntegerBitWidth();
176   ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth});
177   Value *BitWidthForInsts =
178       VectorTy
179           ? IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant)
180           : BitWidthConstant;
181   Value *RotateModVal =
182       IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts);
183   Value *FirstShift = nullptr, *SecShift = nullptr;
184   if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
185     // Shift the less significant number right, the "rotate" number of bits
186     // will be 0-filled on the left as a result of this regular shift.
187     FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal);
188   } else {
189     // Shift the more significant number left, the "rotate" number of bits
190     // will be 0-filled on the right as a result of this regular shift.
191     FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal);
192   }
193   // We want the "rotate" number of the more significant int's LSBs (MSBs) to
194   // occupy the leftmost (rightmost) "0 space" left by the previous operation.
195   // Therefore, subtract the "rotate" number from the integer bitsize...
196   Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal);
197   if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
198     // ...and left-shift the more significant int by this number, zero-filling
199     // the LSBs.
200     SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal);
201   } else {
202     // ...and right-shift the less significant int by this number, zero-filling
203     // the MSBs.
204     SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal);
205   }
206   // A simple binary addition of the shifted ints yields the final result.
207   IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift));
208 
209   FSHIntrinsic->setCalledFunction(FSHFunc);
210 }
211 
212 static void buildUMulWithOverflowFunc(Function *UMulFunc) {
213   // The function body is already created.
214   if (!UMulFunc->empty())
215     return;
216 
217   BasicBlock *EntryBB = BasicBlock::Create(UMulFunc->getParent()->getContext(),
218                                            "entry", UMulFunc);
219   IRBuilder<> IRB(EntryBB);
220   // Build the actual unsigned multiplication logic with the overflow
221   // indication. Do unsigned multiplication Mul = A * B. Then check
222   // if unsigned division Div = Mul / A is not equal to B. If so,
223   // then overflow has happened.
224   Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1));
225   Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0));
226   Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(0), Div);
227 
228   // umul.with.overflow intrinsic return a structure, where the first element
229   // is the multiplication result, and the second is an overflow bit.
230   Type *StructTy = UMulFunc->getReturnType();
231   Value *Agg = IRB.CreateInsertValue(PoisonValue::get(StructTy), Mul, {0});
232   Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1});
233   IRB.CreateRet(Res);
234 }
235 
236 static void lowerUMulWithOverflow(IntrinsicInst *UMulIntrinsic) {
237   // Get a separate function - otherwise, we'd have to rework the CFG of the
238   // current one. Then simply replace the intrinsic uses with a call to the new
239   // function.
240   Module *M = UMulIntrinsic->getModule();
241   FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType();
242   Type *FSHLRetTy = UMulFuncTy->getReturnType();
243   const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic);
244   Function *UMulFunc =
245       getOrCreateFunction(M, FSHLRetTy, UMulFuncTy->params(), FuncName);
246   buildUMulWithOverflowFunc(UMulFunc);
247   UMulIntrinsic->setCalledFunction(UMulFunc);
248 }
249 
250 // Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics
251 // or calls to proper generated functions. Returns True if F was modified.
252 bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
253   bool Changed = false;
254   for (BasicBlock &BB : *F) {
255     for (Instruction &I : BB) {
256       auto Call = dyn_cast<CallInst>(&I);
257       if (!Call)
258         continue;
259       Function *CF = Call->getCalledFunction();
260       if (!CF || !CF->isIntrinsic())
261         continue;
262       auto *II = cast<IntrinsicInst>(Call);
263       if (II->getIntrinsicID() == Intrinsic::memset ||
264           II->getIntrinsicID() == Intrinsic::bswap)
265         Changed |= lowerIntrinsicToFunction(II);
266       else if (II->getIntrinsicID() == Intrinsic::fshl ||
267                II->getIntrinsicID() == Intrinsic::fshr) {
268         lowerFunnelShifts(II);
269         Changed = true;
270       } else if (II->getIntrinsicID() == Intrinsic::umul_with_overflow) {
271         lowerUMulWithOverflow(II);
272         Changed = true;
273       }
274     }
275   }
276   return Changed;
277 }
278 
279 // Returns F if aggregate argument/return types are not present or cloned F
280 // function with the types replaced by i32 types. The change in types is
281 // noted in 'spv.cloned_funcs' metadata for later restoration.
282 Function *
283 SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
284   IRBuilder<> B(F->getContext());
285 
286   bool IsRetAggr = F->getReturnType()->isAggregateType();
287   bool HasAggrArg =
288       std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) {
289         return Arg.getType()->isAggregateType();
290       });
291   bool DoClone = IsRetAggr || HasAggrArg;
292   if (!DoClone)
293     return F;
294   SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
295   Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
296   if (IsRetAggr)
297     ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType()));
298   SmallVector<Type *, 4> ArgTypes;
299   for (const auto &Arg : F->args()) {
300     if (Arg.getType()->isAggregateType()) {
301       ArgTypes.push_back(B.getInt32Ty());
302       ChangedTypes.push_back(
303           std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
304     } else
305       ArgTypes.push_back(Arg.getType());
306   }
307   FunctionType *NewFTy =
308       FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
309   Function *NewF =
310       Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());
311 
312   ValueToValueMapTy VMap;
313   auto NewFArgIt = NewF->arg_begin();
314   for (auto &Arg : F->args()) {
315     StringRef ArgName = Arg.getName();
316     NewFArgIt->setName(ArgName);
317     VMap[&Arg] = &(*NewFArgIt++);
318   }
319   SmallVector<ReturnInst *, 8> Returns;
320 
321   CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
322                     Returns);
323   NewF->takeName(F);
324 
325   NamedMDNode *FuncMD =
326       F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
327   SmallVector<Metadata *, 2> MDArgs;
328   MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
329   for (auto &ChangedTyP : ChangedTypes)
330     MDArgs.push_back(MDNode::get(
331         B.getContext(),
332         {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
333          ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
334   MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
335   FuncMD->addOperand(ThisFuncMD);
336 
337   for (auto *U : make_early_inc_range(F->users())) {
338     if (auto *CI = dyn_cast<CallInst>(U))
339       CI->mutateFunctionType(NewF->getFunctionType());
340     U->replaceUsesOfWith(F, NewF);
341   }
342   return NewF;
343 }
344 
345 bool SPIRVPrepareFunctions::runOnModule(Module &M) {
346   bool Changed = false;
347   for (Function &F : M)
348     Changed |= substituteIntrinsicCalls(&F);
349 
350   std::vector<Function *> FuncsWorklist;
351   for (auto &F : M)
352     FuncsWorklist.push_back(&F);
353 
354   for (auto *F : FuncsWorklist) {
355     Function *NewF = removeAggregateTypesFromSignature(F);
356 
357     if (NewF != F) {
358       F->eraseFromParent();
359       Changed = true;
360     }
361   }
362   return Changed;
363 }
364 
365 ModulePass *llvm::createSPIRVPrepareFunctionsPass() {
366   return new SPIRVPrepareFunctions();
367 }
368