xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass modifies function signatures containing aggregate arguments
10 // and/or return value before IRTranslator. Information about the original
11 // signatures is stored in metadata. It is used during call lowering to
12 // restore correct SPIR-V types of function arguments and return values.
13 // This pass also substitutes some llvm intrinsic calls with calls to newly
14 // generated functions (as the Khronos LLVM/SPIR-V Translator does).
15 //
16 // NOTE: this pass is a module-level one due to the necessity to modify
17 // GVs/functions.
18 //
19 //===----------------------------------------------------------------------===//
20 
21 #include "SPIRV.h"
22 #include "SPIRVSubtarget.h"
23 #include "SPIRVTargetMachine.h"
24 #include "SPIRVUtils.h"
25 #include "llvm/Analysis/ValueTracking.h"
26 #include "llvm/CodeGen/IntrinsicLowering.h"
27 #include "llvm/IR/IRBuilder.h"
28 #include "llvm/IR/IntrinsicInst.h"
29 #include "llvm/IR/Intrinsics.h"
30 #include "llvm/IR/IntrinsicsSPIRV.h"
31 #include "llvm/Transforms/Utils/Cloning.h"
32 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
33 #include <charconv>
34 #include <regex>
35 
36 using namespace llvm;
37 
38 namespace llvm {
39 void initializeSPIRVPrepareFunctionsPass(PassRegistry &);
40 }
41 
42 namespace {
43 
44 class SPIRVPrepareFunctions : public ModulePass {
45   const SPIRVTargetMachine &TM;
46   bool substituteIntrinsicCalls(Function *F);
47   Function *removeAggregateTypesFromSignature(Function *F);
48 
49 public:
50   static char ID;
SPIRVPrepareFunctions(const SPIRVTargetMachine & TM)51   SPIRVPrepareFunctions(const SPIRVTargetMachine &TM) : ModulePass(ID), TM(TM) {
52     initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry());
53   }
54 
55   bool runOnModule(Module &M) override;
56 
getPassName() const57   StringRef getPassName() const override { return "SPIRV prepare functions"; }
58 
getAnalysisUsage(AnalysisUsage & AU) const59   void getAnalysisUsage(AnalysisUsage &AU) const override {
60     ModulePass::getAnalysisUsage(AU);
61   }
62 };
63 
64 } // namespace
65 
66 char SPIRVPrepareFunctions::ID = 0;
67 
68 INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
69                 "SPIRV prepare functions", false, false)
70 
lowerLLVMIntrinsicName(IntrinsicInst * II)71 std::string lowerLLVMIntrinsicName(IntrinsicInst *II) {
72   Function *IntrinsicFunc = II->getCalledFunction();
73   assert(IntrinsicFunc && "Missing function");
74   std::string FuncName = IntrinsicFunc->getName().str();
75   std::replace(FuncName.begin(), FuncName.end(), '.', '_');
76   FuncName = "spirv." + FuncName;
77   return FuncName;
78 }
79 
getOrCreateFunction(Module * M,Type * RetTy,ArrayRef<Type * > ArgTypes,StringRef Name)80 static Function *getOrCreateFunction(Module *M, Type *RetTy,
81                                      ArrayRef<Type *> ArgTypes,
82                                      StringRef Name) {
83   FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false);
84   Function *F = M->getFunction(Name);
85   if (F && F->getFunctionType() == FT)
86     return F;
87   Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M);
88   if (F)
89     NewF->setDSOLocal(F->isDSOLocal());
90   NewF->setCallingConv(CallingConv::SPIR_FUNC);
91   return NewF;
92 }
93 
lowerIntrinsicToFunction(IntrinsicInst * Intrinsic)94 static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) {
95   // For @llvm.memset.* intrinsic cases with constant value and length arguments
96   // are emulated via "storing" a constant array to the destination. For other
97   // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the
98   // intrinsic to a loop via expandMemSetAsLoop().
99   if (auto *MSI = dyn_cast<MemSetInst>(Intrinsic))
100     if (isa<Constant>(MSI->getValue()) && isa<ConstantInt>(MSI->getLength()))
101       return false; // It is handled later using OpCopyMemorySized.
102 
103   Module *M = Intrinsic->getModule();
104   std::string FuncName = lowerLLVMIntrinsicName(Intrinsic);
105   if (Intrinsic->isVolatile())
106     FuncName += ".volatile";
107   // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_*
108   Function *F = M->getFunction(FuncName);
109   if (F) {
110     Intrinsic->setCalledFunction(F);
111     return true;
112   }
113   // TODO copy arguments attributes: nocapture writeonly.
114   FunctionCallee FC =
115       M->getOrInsertFunction(FuncName, Intrinsic->getFunctionType());
116   auto IntrinsicID = Intrinsic->getIntrinsicID();
117   Intrinsic->setCalledFunction(FC);
118 
119   F = dyn_cast<Function>(FC.getCallee());
120   assert(F && "Callee must be a function");
121 
122   switch (IntrinsicID) {
123   case Intrinsic::memset: {
124     auto *MSI = static_cast<MemSetInst *>(Intrinsic);
125     Argument *Dest = F->getArg(0);
126     Argument *Val = F->getArg(1);
127     Argument *Len = F->getArg(2);
128     Argument *IsVolatile = F->getArg(3);
129     Dest->setName("dest");
130     Val->setName("val");
131     Len->setName("len");
132     IsVolatile->setName("isvolatile");
133     BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
134     IRBuilder<> IRB(EntryBB);
135     auto *MemSet = IRB.CreateMemSet(Dest, Val, Len, MSI->getDestAlign(),
136                                     MSI->isVolatile());
137     IRB.CreateRetVoid();
138     expandMemSetAsLoop(cast<MemSetInst>(MemSet));
139     MemSet->eraseFromParent();
140     break;
141   }
142   case Intrinsic::bswap: {
143     BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
144     IRBuilder<> IRB(EntryBB);
145     auto *BSwap = IRB.CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(),
146                                       F->getArg(0));
147     IRB.CreateRet(BSwap);
148     IntrinsicLowering IL(M->getDataLayout());
149     IL.LowerIntrinsicCall(BSwap);
150     break;
151   }
152   default:
153     break;
154   }
155   return true;
156 }
157 
getAnnotation(Value * AnnoVal,Value * OptAnnoVal)158 static std::string getAnnotation(Value *AnnoVal, Value *OptAnnoVal) {
159   if (auto *Ref = dyn_cast_or_null<GetElementPtrInst>(AnnoVal))
160     AnnoVal = Ref->getOperand(0);
161   if (auto *Ref = dyn_cast_or_null<BitCastInst>(OptAnnoVal))
162     OptAnnoVal = Ref->getOperand(0);
163 
164   std::string Anno;
165   if (auto *C = dyn_cast_or_null<Constant>(AnnoVal)) {
166     StringRef Str;
167     if (getConstantStringInfo(C, Str))
168       Anno = Str;
169   }
170   // handle optional annotation parameter in a way that Khronos Translator do
171   // (collect integers wrapped in a struct)
172   if (auto *C = dyn_cast_or_null<Constant>(OptAnnoVal);
173       C && C->getNumOperands()) {
174     Value *MaybeStruct = C->getOperand(0);
175     if (auto *Struct = dyn_cast<ConstantStruct>(MaybeStruct)) {
176       for (unsigned I = 0, E = Struct->getNumOperands(); I != E; ++I) {
177         if (auto *CInt = dyn_cast<ConstantInt>(Struct->getOperand(I)))
178           Anno += (I == 0 ? ": " : ", ") +
179                   std::to_string(CInt->getType()->getIntegerBitWidth() == 1
180                                      ? CInt->getZExtValue()
181                                      : CInt->getSExtValue());
182       }
183     } else if (auto *Struct = dyn_cast<ConstantAggregateZero>(MaybeStruct)) {
184       // { i32 i32 ... } zeroinitializer
185       for (unsigned I = 0, E = Struct->getType()->getStructNumElements();
186            I != E; ++I)
187         Anno += I == 0 ? ": 0" : ", 0";
188     }
189   }
190   return Anno;
191 }
192 
parseAnnotation(Value * I,const std::string & Anno,LLVMContext & Ctx,Type * Int32Ty)193 static SmallVector<Metadata *> parseAnnotation(Value *I,
194                                                const std::string &Anno,
195                                                LLVMContext &Ctx,
196                                                Type *Int32Ty) {
197   // Try to parse the annotation string according to the following rules:
198   // annotation := ({kind} | {kind:value,value,...})+
199   // kind := number
200   // value := number | string
201   static const std::regex R(
202       "\\{(\\d+)(?:[:,](\\d+|\"[^\"]*\")(?:,(\\d+|\"[^\"]*\"))*)?\\}");
203   SmallVector<Metadata *> MDs;
204   int Pos = 0;
205   for (std::sregex_iterator
206            It = std::sregex_iterator(Anno.begin(), Anno.end(), R),
207            ItEnd = std::sregex_iterator();
208        It != ItEnd; ++It) {
209     if (It->position() != Pos)
210       return SmallVector<Metadata *>{};
211     Pos = It->position() + It->length();
212     std::smatch Match = *It;
213     SmallVector<Metadata *> MDsItem;
214     for (std::size_t i = 1; i < Match.size(); ++i) {
215       std::ssub_match SMatch = Match[i];
216       std::string Item = SMatch.str();
217       if (Item.length() == 0)
218         break;
219       if (Item[0] == '"') {
220         Item = Item.substr(1, Item.length() - 2);
221         // Acceptable format of the string snippet is:
222         static const std::regex RStr("^(\\d+)(?:,(\\d+))*$");
223         if (std::smatch MatchStr; std::regex_match(Item, MatchStr, RStr)) {
224           for (std::size_t SubIdx = 1; SubIdx < MatchStr.size(); ++SubIdx)
225             if (std::string SubStr = MatchStr[SubIdx].str(); SubStr.length())
226               MDsItem.push_back(ConstantAsMetadata::get(
227                   ConstantInt::get(Int32Ty, std::stoi(SubStr))));
228         } else {
229           MDsItem.push_back(MDString::get(Ctx, Item));
230         }
231       } else if (int32_t Num;
232                  std::from_chars(Item.data(), Item.data() + Item.size(), Num)
233                      .ec == std::errc{}) {
234         MDsItem.push_back(
235             ConstantAsMetadata::get(ConstantInt::get(Int32Ty, Num)));
236       } else {
237         MDsItem.push_back(MDString::get(Ctx, Item));
238       }
239     }
240     if (MDsItem.size() == 0)
241       return SmallVector<Metadata *>{};
242     MDs.push_back(MDNode::get(Ctx, MDsItem));
243   }
244   return Pos == static_cast<int>(Anno.length()) ? MDs
245                                                 : SmallVector<Metadata *>{};
246 }
247 
lowerPtrAnnotation(IntrinsicInst * II)248 static void lowerPtrAnnotation(IntrinsicInst *II) {
249   LLVMContext &Ctx = II->getContext();
250   Type *Int32Ty = Type::getInt32Ty(Ctx);
251 
252   // Retrieve an annotation string from arguments.
253   Value *PtrArg = nullptr;
254   if (auto *BI = dyn_cast<BitCastInst>(II->getArgOperand(0)))
255     PtrArg = BI->getOperand(0);
256   else
257     PtrArg = II->getOperand(0);
258   std::string Anno =
259       getAnnotation(II->getArgOperand(1),
260                     4 < II->arg_size() ? II->getArgOperand(4) : nullptr);
261 
262   // Parse the annotation.
263   SmallVector<Metadata *> MDs = parseAnnotation(II, Anno, Ctx, Int32Ty);
264 
265   // If the annotation string is not parsed successfully we don't know the
266   // format used and output it as a general UserSemantic decoration.
267   // Otherwise MDs is a Metadata tuple (a decoration list) in the format
268   // expected by `spirv.Decorations`.
269   if (MDs.size() == 0) {
270     auto UserSemantic = ConstantAsMetadata::get(ConstantInt::get(
271         Int32Ty, static_cast<uint32_t>(SPIRV::Decoration::UserSemantic)));
272     MDs.push_back(MDNode::get(Ctx, {UserSemantic, MDString::get(Ctx, Anno)}));
273   }
274 
275   // Build the internal intrinsic function.
276   IRBuilder<> IRB(II->getParent());
277   IRB.SetInsertPoint(II);
278   IRB.CreateIntrinsic(
279       Intrinsic::spv_assign_decoration, {PtrArg->getType()},
280       {PtrArg, MetadataAsValue::get(Ctx, MDNode::get(Ctx, MDs))});
281   II->replaceAllUsesWith(II->getOperand(0));
282 }
283 
lowerFunnelShifts(IntrinsicInst * FSHIntrinsic)284 static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) {
285   // Get a separate function - otherwise, we'd have to rework the CFG of the
286   // current one. Then simply replace the intrinsic uses with a call to the new
287   // function.
288   // Generate LLVM IR for  i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c)
289   Module *M = FSHIntrinsic->getModule();
290   FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType();
291   Type *FSHRetTy = FSHFuncTy->getReturnType();
292   const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic);
293   Function *FSHFunc =
294       getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName);
295 
296   if (!FSHFunc->empty()) {
297     FSHIntrinsic->setCalledFunction(FSHFunc);
298     return;
299   }
300   BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc);
301   IRBuilder<> IRB(RotateBB);
302   Type *Ty = FSHFunc->getReturnType();
303   // Build the actual funnel shift rotate logic.
304   // In the comments, "int" is used interchangeably with "vector of int
305   // elements".
306   FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty);
307   Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty;
308   unsigned BitWidth = IntTy->getIntegerBitWidth();
309   ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth});
310   Value *BitWidthForInsts =
311       VectorTy
312           ? IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant)
313           : BitWidthConstant;
314   Value *RotateModVal =
315       IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts);
316   Value *FirstShift = nullptr, *SecShift = nullptr;
317   if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
318     // Shift the less significant number right, the "rotate" number of bits
319     // will be 0-filled on the left as a result of this regular shift.
320     FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal);
321   } else {
322     // Shift the more significant number left, the "rotate" number of bits
323     // will be 0-filled on the right as a result of this regular shift.
324     FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal);
325   }
326   // We want the "rotate" number of the more significant int's LSBs (MSBs) to
327   // occupy the leftmost (rightmost) "0 space" left by the previous operation.
328   // Therefore, subtract the "rotate" number from the integer bitsize...
329   Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal);
330   if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
331     // ...and left-shift the more significant int by this number, zero-filling
332     // the LSBs.
333     SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal);
334   } else {
335     // ...and right-shift the less significant int by this number, zero-filling
336     // the MSBs.
337     SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal);
338   }
339   // A simple binary addition of the shifted ints yields the final result.
340   IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift));
341 
342   FSHIntrinsic->setCalledFunction(FSHFunc);
343 }
344 
buildUMulWithOverflowFunc(Function * UMulFunc)345 static void buildUMulWithOverflowFunc(Function *UMulFunc) {
346   // The function body is already created.
347   if (!UMulFunc->empty())
348     return;
349 
350   BasicBlock *EntryBB = BasicBlock::Create(UMulFunc->getParent()->getContext(),
351                                            "entry", UMulFunc);
352   IRBuilder<> IRB(EntryBB);
353   // Build the actual unsigned multiplication logic with the overflow
354   // indication. Do unsigned multiplication Mul = A * B. Then check
355   // if unsigned division Div = Mul / A is not equal to B. If so,
356   // then overflow has happened.
357   Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1));
358   Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0));
359   Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(0), Div);
360 
361   // umul.with.overflow intrinsic return a structure, where the first element
362   // is the multiplication result, and the second is an overflow bit.
363   Type *StructTy = UMulFunc->getReturnType();
364   Value *Agg = IRB.CreateInsertValue(PoisonValue::get(StructTy), Mul, {0});
365   Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1});
366   IRB.CreateRet(Res);
367 }
368 
lowerExpectAssume(IntrinsicInst * II)369 static void lowerExpectAssume(IntrinsicInst *II) {
370   // If we cannot use the SPV_KHR_expect_assume extension, then we need to
371   // ignore the intrinsic and move on. It should be removed later on by LLVM.
372   // Otherwise we should lower the intrinsic to the corresponding SPIR-V
373   // instruction.
374   // For @llvm.assume we have OpAssumeTrueKHR.
375   // For @llvm.expect we have OpExpectKHR.
376   //
377   // We need to lower this into a builtin and then the builtin into a SPIR-V
378   // instruction.
379   if (II->getIntrinsicID() == Intrinsic::assume) {
380     Function *F = Intrinsic::getDeclaration(
381         II->getModule(), Intrinsic::SPVIntrinsics::spv_assume);
382     II->setCalledFunction(F);
383   } else if (II->getIntrinsicID() == Intrinsic::expect) {
384     Function *F = Intrinsic::getDeclaration(
385         II->getModule(), Intrinsic::SPVIntrinsics::spv_expect,
386         {II->getOperand(0)->getType()});
387     II->setCalledFunction(F);
388   } else {
389     llvm_unreachable("Unknown intrinsic");
390   }
391 
392   return;
393 }
394 
toSpvOverloadedIntrinsic(IntrinsicInst * II,Intrinsic::ID NewID,ArrayRef<unsigned> OpNos)395 static bool toSpvOverloadedIntrinsic(IntrinsicInst *II, Intrinsic::ID NewID,
396                                      ArrayRef<unsigned> OpNos) {
397   Function *F = nullptr;
398   if (OpNos.empty()) {
399     F = Intrinsic::getDeclaration(II->getModule(), NewID);
400   } else {
401     SmallVector<Type *, 4> Tys;
402     for (unsigned OpNo : OpNos)
403       Tys.push_back(II->getOperand(OpNo)->getType());
404     F = Intrinsic::getDeclaration(II->getModule(), NewID, Tys);
405   }
406   II->setCalledFunction(F);
407   return true;
408 }
409 
lowerUMulWithOverflow(IntrinsicInst * UMulIntrinsic)410 static void lowerUMulWithOverflow(IntrinsicInst *UMulIntrinsic) {
411   // Get a separate function - otherwise, we'd have to rework the CFG of the
412   // current one. Then simply replace the intrinsic uses with a call to the new
413   // function.
414   Module *M = UMulIntrinsic->getModule();
415   FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType();
416   Type *FSHLRetTy = UMulFuncTy->getReturnType();
417   const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic);
418   Function *UMulFunc =
419       getOrCreateFunction(M, FSHLRetTy, UMulFuncTy->params(), FuncName);
420   buildUMulWithOverflowFunc(UMulFunc);
421   UMulIntrinsic->setCalledFunction(UMulFunc);
422 }
423 
424 // Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics
425 // or calls to proper generated functions. Returns True if F was modified.
substituteIntrinsicCalls(Function * F)426 bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
427   bool Changed = false;
428   for (BasicBlock &BB : *F) {
429     for (Instruction &I : BB) {
430       auto Call = dyn_cast<CallInst>(&I);
431       if (!Call)
432         continue;
433       Function *CF = Call->getCalledFunction();
434       if (!CF || !CF->isIntrinsic())
435         continue;
436       auto *II = cast<IntrinsicInst>(Call);
437       switch (II->getIntrinsicID()) {
438       case Intrinsic::memset:
439       case Intrinsic::bswap:
440         Changed |= lowerIntrinsicToFunction(II);
441         break;
442       case Intrinsic::fshl:
443       case Intrinsic::fshr:
444         lowerFunnelShifts(II);
445         Changed = true;
446         break;
447       case Intrinsic::umul_with_overflow:
448         lowerUMulWithOverflow(II);
449         Changed = true;
450         break;
451       case Intrinsic::assume:
452       case Intrinsic::expect: {
453         const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(*F);
454         if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume))
455           lowerExpectAssume(II);
456         Changed = true;
457       } break;
458       case Intrinsic::lifetime_start:
459         Changed |= toSpvOverloadedIntrinsic(
460             II, Intrinsic::SPVIntrinsics::spv_lifetime_start, {1});
461         break;
462       case Intrinsic::lifetime_end:
463         Changed |= toSpvOverloadedIntrinsic(
464             II, Intrinsic::SPVIntrinsics::spv_lifetime_end, {1});
465         break;
466       case Intrinsic::ptr_annotation:
467         lowerPtrAnnotation(II);
468         Changed = true;
469         break;
470       }
471     }
472   }
473   return Changed;
474 }
475 
476 // Returns F if aggregate argument/return types are not present or cloned F
477 // function with the types replaced by i32 types. The change in types is
478 // noted in 'spv.cloned_funcs' metadata for later restoration.
479 Function *
removeAggregateTypesFromSignature(Function * F)480 SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
481   IRBuilder<> B(F->getContext());
482 
483   bool IsRetAggr = F->getReturnType()->isAggregateType();
484   bool HasAggrArg =
485       std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) {
486         return Arg.getType()->isAggregateType();
487       });
488   bool DoClone = IsRetAggr || HasAggrArg;
489   if (!DoClone)
490     return F;
491   SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
492   Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
493   if (IsRetAggr)
494     ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType()));
495   SmallVector<Type *, 4> ArgTypes;
496   for (const auto &Arg : F->args()) {
497     if (Arg.getType()->isAggregateType()) {
498       ArgTypes.push_back(B.getInt32Ty());
499       ChangedTypes.push_back(
500           std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
501     } else
502       ArgTypes.push_back(Arg.getType());
503   }
504   FunctionType *NewFTy =
505       FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
506   Function *NewF =
507       Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());
508 
509   ValueToValueMapTy VMap;
510   auto NewFArgIt = NewF->arg_begin();
511   for (auto &Arg : F->args()) {
512     StringRef ArgName = Arg.getName();
513     NewFArgIt->setName(ArgName);
514     VMap[&Arg] = &(*NewFArgIt++);
515   }
516   SmallVector<ReturnInst *, 8> Returns;
517 
518   CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
519                     Returns);
520   NewF->takeName(F);
521 
522   NamedMDNode *FuncMD =
523       F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
524   SmallVector<Metadata *, 2> MDArgs;
525   MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
526   for (auto &ChangedTyP : ChangedTypes)
527     MDArgs.push_back(MDNode::get(
528         B.getContext(),
529         {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
530          ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
531   MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
532   FuncMD->addOperand(ThisFuncMD);
533 
534   for (auto *U : make_early_inc_range(F->users())) {
535     if (auto *CI = dyn_cast<CallInst>(U))
536       CI->mutateFunctionType(NewF->getFunctionType());
537     U->replaceUsesOfWith(F, NewF);
538   }
539 
540   // register the mutation
541   if (RetType != F->getReturnType())
542     TM.getSubtarget<SPIRVSubtarget>(*F).getSPIRVGlobalRegistry()->addMutated(
543         NewF, F->getReturnType());
544   return NewF;
545 }
546 
runOnModule(Module & M)547 bool SPIRVPrepareFunctions::runOnModule(Module &M) {
548   bool Changed = false;
549   for (Function &F : M)
550     Changed |= substituteIntrinsicCalls(&F);
551 
552   std::vector<Function *> FuncsWorklist;
553   for (auto &F : M)
554     FuncsWorklist.push_back(&F);
555 
556   for (auto *F : FuncsWorklist) {
557     Function *NewF = removeAggregateTypesFromSignature(F);
558 
559     if (NewF != F) {
560       F->eraseFromParent();
561       Changed = true;
562     }
563   }
564   return Changed;
565 }
566 
567 ModulePass *
createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine & TM)568 llvm::createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM) {
569   return new SPIRVPrepareFunctions(TM);
570 }
571