xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp (revision 22d7dd834bc5cd189810e414701e3ad1e98102e4)
1 //===- ArgumentPromotion.cpp - Promote by-reference arguments -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass promotes "by reference" arguments to be "by value" arguments.  In
10 // practice, this means looking for internal functions that have pointer
11 // arguments.  If it can prove, through the use of alias analysis, that an
12 // argument is *only* loaded, then it can pass the value into the function
13 // instead of the address of the value.  This can cause recursive simplification
14 // of code and lead to the elimination of allocas (especially in C++ template
15 // code like the STL).
16 //
17 // This pass also handles aggregate arguments that are passed into a function,
18 // scalarizing them if the elements of the aggregate are only loaded.  Note that
19 // by default it refuses to scalarize aggregates which would require passing in
20 // more than three operands to the function, because passing thousands of
21 // operands for a large array or structure is unprofitable! This limit can be
22 // configured or disabled, however.
23 //
24 // Note that this transformation could also be done for arguments that are only
25 // stored to (returning the value instead), but does not currently.  This case
26 // would be best handled when and if LLVM begins supporting multiple return
27 // values from functions.
28 //
29 //===----------------------------------------------------------------------===//
30 
31 #include "llvm/Transforms/IPO/ArgumentPromotion.h"
32 
33 #include "llvm/ADT/DepthFirstIterator.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/ScopeExit.h"
36 #include "llvm/ADT/SmallPtrSet.h"
37 #include "llvm/ADT/SmallVector.h"
38 #include "llvm/ADT/Statistic.h"
39 #include "llvm/ADT/Twine.h"
40 #include "llvm/Analysis/AssumptionCache.h"
41 #include "llvm/Analysis/BasicAliasAnalysis.h"
42 #include "llvm/Analysis/CallGraph.h"
43 #include "llvm/Analysis/Loads.h"
44 #include "llvm/Analysis/MemoryLocation.h"
45 #include "llvm/Analysis/TargetTransformInfo.h"
46 #include "llvm/Analysis/ValueTracking.h"
47 #include "llvm/IR/Argument.h"
48 #include "llvm/IR/Attributes.h"
49 #include "llvm/IR/BasicBlock.h"
50 #include "llvm/IR/CFG.h"
51 #include "llvm/IR/Constants.h"
52 #include "llvm/IR/DataLayout.h"
53 #include "llvm/IR/DerivedTypes.h"
54 #include "llvm/IR/Dominators.h"
55 #include "llvm/IR/Function.h"
56 #include "llvm/IR/IRBuilder.h"
57 #include "llvm/IR/InstrTypes.h"
58 #include "llvm/IR/Instruction.h"
59 #include "llvm/IR/Instructions.h"
60 #include "llvm/IR/Metadata.h"
61 #include "llvm/IR/NoFolder.h"
62 #include "llvm/IR/PassManager.h"
63 #include "llvm/IR/Type.h"
64 #include "llvm/IR/Use.h"
65 #include "llvm/IR/User.h"
66 #include "llvm/IR/Value.h"
67 #include "llvm/Support/Casting.h"
68 #include "llvm/Support/Debug.h"
69 #include "llvm/Support/raw_ostream.h"
70 #include "llvm/Transforms/Utils/PromoteMemToReg.h"
71 #include <algorithm>
72 #include <cassert>
73 #include <cstdint>
74 #include <utility>
75 #include <vector>
76 
77 using namespace llvm;
78 
79 #define DEBUG_TYPE "argpromotion"
80 
81 STATISTIC(NumArgumentsPromoted, "Number of pointer arguments promoted");
82 STATISTIC(NumArgumentsDead, "Number of dead pointer args eliminated");
83 
84 namespace {
85 
86 struct ArgPart {
87   Type *Ty;
88   Align Alignment;
89   /// A representative guaranteed-executed load or store instruction for use by
90   /// metadata transfer.
91   Instruction *MustExecInstr;
92 };
93 
94 using OffsetAndArgPart = std::pair<int64_t, ArgPart>;
95 
96 } // end anonymous namespace
97 
98 static Value *createByteGEP(IRBuilderBase &IRB, const DataLayout &DL,
99                             Value *Ptr, Type *ResElemTy, int64_t Offset) {
100   // For non-opaque pointers, try to create a "nice" GEP if possible, otherwise
101   // fall back to an i8 GEP to a specific offset.
102   unsigned AddrSpace = Ptr->getType()->getPointerAddressSpace();
103   APInt OrigOffset(DL.getIndexTypeSizeInBits(Ptr->getType()), Offset);
104   if (!Ptr->getType()->isOpaquePointerTy()) {
105     Type *OrigElemTy = Ptr->getType()->getNonOpaquePointerElementType();
106     if (OrigOffset == 0 && OrigElemTy == ResElemTy)
107       return Ptr;
108 
109     if (OrigElemTy->isSized()) {
110       APInt TmpOffset = OrigOffset;
111       Type *TmpTy = OrigElemTy;
112       SmallVector<APInt> IntIndices =
113           DL.getGEPIndicesForOffset(TmpTy, TmpOffset);
114       if (TmpOffset == 0) {
115         // Try to add trailing zero indices to reach the right type.
116         while (TmpTy != ResElemTy) {
117           Type *NextTy = GetElementPtrInst::getTypeAtIndex(TmpTy, (uint64_t)0);
118           if (!NextTy)
119             break;
120 
121           IntIndices.push_back(APInt::getZero(
122               isa<StructType>(TmpTy) ? 32 : OrigOffset.getBitWidth()));
123           TmpTy = NextTy;
124         }
125 
126         SmallVector<Value *> Indices;
127         for (const APInt &Index : IntIndices)
128           Indices.push_back(IRB.getInt(Index));
129 
130         if (OrigOffset != 0 || TmpTy == ResElemTy) {
131           Ptr = IRB.CreateGEP(OrigElemTy, Ptr, Indices);
132           return IRB.CreateBitCast(Ptr, ResElemTy->getPointerTo(AddrSpace));
133         }
134       }
135     }
136   }
137 
138   if (OrigOffset != 0) {
139     Ptr = IRB.CreateBitCast(Ptr, IRB.getInt8PtrTy(AddrSpace));
140     Ptr = IRB.CreateGEP(IRB.getInt8Ty(), Ptr, IRB.getInt(OrigOffset));
141   }
142   return IRB.CreateBitCast(Ptr, ResElemTy->getPointerTo(AddrSpace));
143 }
144 
145 /// DoPromotion - This method actually performs the promotion of the specified
146 /// arguments, and returns the new function.  At this point, we know that it's
147 /// safe to do so.
148 static Function *
149 doPromotion(Function *F, FunctionAnalysisManager &FAM,
150             const DenseMap<Argument *, SmallVector<OffsetAndArgPart, 4>>
151                 &ArgsToPromote) {
152   // Start by computing a new prototype for the function, which is the same as
153   // the old function, but has modified arguments.
154   FunctionType *FTy = F->getFunctionType();
155   std::vector<Type *> Params;
156 
157   // Attribute - Keep track of the parameter attributes for the arguments
158   // that we are *not* promoting. For the ones that we do promote, the parameter
159   // attributes are lost
160   SmallVector<AttributeSet, 8> ArgAttrVec;
161   AttributeList PAL = F->getAttributes();
162 
163   // First, determine the new argument list
164   unsigned ArgNo = 0;
165   for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
166        ++I, ++ArgNo) {
167     if (!ArgsToPromote.count(&*I)) {
168       // Unchanged argument
169       Params.push_back(I->getType());
170       ArgAttrVec.push_back(PAL.getParamAttrs(ArgNo));
171     } else if (I->use_empty()) {
172       // Dead argument (which are always marked as promotable)
173       ++NumArgumentsDead;
174     } else {
175       const auto &ArgParts = ArgsToPromote.find(&*I)->second;
176       for (const auto &Pair : ArgParts) {
177         Params.push_back(Pair.second.Ty);
178         ArgAttrVec.push_back(AttributeSet());
179       }
180       ++NumArgumentsPromoted;
181     }
182   }
183 
184   Type *RetTy = FTy->getReturnType();
185 
186   // Construct the new function type using the new arguments.
187   FunctionType *NFTy = FunctionType::get(RetTy, Params, FTy->isVarArg());
188 
189   // Create the new function body and insert it into the module.
190   Function *NF = Function::Create(NFTy, F->getLinkage(), F->getAddressSpace(),
191                                   F->getName());
192   NF->copyAttributesFrom(F);
193   NF->copyMetadata(F, 0);
194 
195   // The new function will have the !dbg metadata copied from the original
196   // function. The original function may not be deleted, and dbg metadata need
197   // to be unique, so we need to drop it.
198   F->setSubprogram(nullptr);
199 
200   LLVM_DEBUG(dbgs() << "ARG PROMOTION:  Promoting to:" << *NF << "\n"
201                     << "From: " << *F);
202 
203   uint64_t LargestVectorWidth = 0;
204   for (auto *I : Params)
205     if (auto *VT = dyn_cast<llvm::VectorType>(I))
206       LargestVectorWidth = std::max(
207           LargestVectorWidth, VT->getPrimitiveSizeInBits().getKnownMinValue());
208 
209   // Recompute the parameter attributes list based on the new arguments for
210   // the function.
211   NF->setAttributes(AttributeList::get(F->getContext(), PAL.getFnAttrs(),
212                                        PAL.getRetAttrs(), ArgAttrVec));
213   AttributeFuncs::updateMinLegalVectorWidthAttr(*NF, LargestVectorWidth);
214   ArgAttrVec.clear();
215 
216   F->getParent()->getFunctionList().insert(F->getIterator(), NF);
217   NF->takeName(F);
218 
219   // Loop over all the callers of the function, transforming the call sites to
220   // pass in the loaded pointers.
221   SmallVector<Value *, 16> Args;
222   const DataLayout &DL = F->getParent()->getDataLayout();
223   while (!F->use_empty()) {
224     CallBase &CB = cast<CallBase>(*F->user_back());
225     assert(CB.getCalledFunction() == F);
226     const AttributeList &CallPAL = CB.getAttributes();
227     IRBuilder<NoFolder> IRB(&CB);
228 
229     // Loop over the operands, inserting GEP and loads in the caller as
230     // appropriate.
231     auto *AI = CB.arg_begin();
232     ArgNo = 0;
233     for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
234          ++I, ++AI, ++ArgNo) {
235       if (!ArgsToPromote.count(&*I)) {
236         Args.push_back(*AI); // Unmodified argument
237         ArgAttrVec.push_back(CallPAL.getParamAttrs(ArgNo));
238       } else if (!I->use_empty()) {
239         Value *V = *AI;
240         const auto &ArgParts = ArgsToPromote.find(&*I)->second;
241         for (const auto &Pair : ArgParts) {
242           LoadInst *LI = IRB.CreateAlignedLoad(
243               Pair.second.Ty,
244               createByteGEP(IRB, DL, V, Pair.second.Ty, Pair.first),
245               Pair.second.Alignment, V->getName() + ".val");
246           if (Pair.second.MustExecInstr) {
247             LI->setAAMetadata(Pair.second.MustExecInstr->getAAMetadata());
248             LI->copyMetadata(*Pair.second.MustExecInstr,
249                              {LLVMContext::MD_range, LLVMContext::MD_nonnull,
250                               LLVMContext::MD_dereferenceable,
251                               LLVMContext::MD_dereferenceable_or_null,
252                               LLVMContext::MD_align, LLVMContext::MD_noundef,
253                               LLVMContext::MD_nontemporal});
254           }
255           Args.push_back(LI);
256           ArgAttrVec.push_back(AttributeSet());
257         }
258       }
259     }
260 
261     // Push any varargs arguments on the list.
262     for (; AI != CB.arg_end(); ++AI, ++ArgNo) {
263       Args.push_back(*AI);
264       ArgAttrVec.push_back(CallPAL.getParamAttrs(ArgNo));
265     }
266 
267     SmallVector<OperandBundleDef, 1> OpBundles;
268     CB.getOperandBundlesAsDefs(OpBundles);
269 
270     CallBase *NewCS = nullptr;
271     if (InvokeInst *II = dyn_cast<InvokeInst>(&CB)) {
272       NewCS = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(),
273                                  Args, OpBundles, "", &CB);
274     } else {
275       auto *NewCall = CallInst::Create(NF, Args, OpBundles, "", &CB);
276       NewCall->setTailCallKind(cast<CallInst>(&CB)->getTailCallKind());
277       NewCS = NewCall;
278     }
279     NewCS->setCallingConv(CB.getCallingConv());
280     NewCS->setAttributes(AttributeList::get(F->getContext(),
281                                             CallPAL.getFnAttrs(),
282                                             CallPAL.getRetAttrs(), ArgAttrVec));
283     NewCS->copyMetadata(CB, {LLVMContext::MD_prof, LLVMContext::MD_dbg});
284     Args.clear();
285     ArgAttrVec.clear();
286 
287     AttributeFuncs::updateMinLegalVectorWidthAttr(*CB.getCaller(),
288                                                   LargestVectorWidth);
289 
290     if (!CB.use_empty()) {
291       CB.replaceAllUsesWith(NewCS);
292       NewCS->takeName(&CB);
293     }
294 
295     // Finally, remove the old call from the program, reducing the use-count of
296     // F.
297     CB.eraseFromParent();
298   }
299 
300   // Since we have now created the new function, splice the body of the old
301   // function right into the new function, leaving the old rotting hulk of the
302   // function empty.
303   NF->splice(NF->begin(), F);
304 
305   // We will collect all the new created allocas to promote them into registers
306   // after the following loop
307   SmallVector<AllocaInst *, 4> Allocas;
308 
309   // Loop over the argument list, transferring uses of the old arguments over to
310   // the new arguments, also transferring over the names as well.
311   Function::arg_iterator I2 = NF->arg_begin();
312   for (Argument &Arg : F->args()) {
313     if (!ArgsToPromote.count(&Arg)) {
314       // If this is an unmodified argument, move the name and users over to the
315       // new version.
316       Arg.replaceAllUsesWith(&*I2);
317       I2->takeName(&Arg);
318       ++I2;
319       continue;
320     }
321 
322     // There potentially are metadata uses for things like llvm.dbg.value.
323     // Replace them with undef, after handling the other regular uses.
324     auto RauwUndefMetadata = make_scope_exit(
325         [&]() { Arg.replaceAllUsesWith(UndefValue::get(Arg.getType())); });
326 
327     if (Arg.use_empty())
328       continue;
329 
330     // Otherwise, if we promoted this argument, we have to create an alloca in
331     // the callee for every promotable part and store each of the new incoming
332     // arguments into the corresponding alloca, what lets the old code (the
333     // store instructions if they are allowed especially) a chance to work as
334     // before.
335     assert(Arg.getType()->isPointerTy() &&
336            "Only arguments with a pointer type are promotable");
337 
338     IRBuilder<NoFolder> IRB(&NF->begin()->front());
339 
340     // Add only the promoted elements, so parts from ArgsToPromote
341     SmallDenseMap<int64_t, AllocaInst *> OffsetToAlloca;
342     for (const auto &Pair : ArgsToPromote.find(&Arg)->second) {
343       int64_t Offset = Pair.first;
344       const ArgPart &Part = Pair.second;
345 
346       Argument *NewArg = I2++;
347       NewArg->setName(Arg.getName() + "." + Twine(Offset) + ".val");
348 
349       AllocaInst *NewAlloca = IRB.CreateAlloca(
350           Part.Ty, nullptr, Arg.getName() + "." + Twine(Offset) + ".allc");
351       NewAlloca->setAlignment(Pair.second.Alignment);
352       IRB.CreateAlignedStore(NewArg, NewAlloca, Pair.second.Alignment);
353 
354       // Collect the alloca to retarget the users to
355       OffsetToAlloca.insert({Offset, NewAlloca});
356     }
357 
358     auto GetAlloca = [&](Value *Ptr) {
359       APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0);
360       Ptr = Ptr->stripAndAccumulateConstantOffsets(DL, Offset,
361                                                    /* AllowNonInbounds */ true);
362       assert(Ptr == &Arg && "Not constant offset from arg?");
363       return OffsetToAlloca.lookup(Offset.getSExtValue());
364     };
365 
366     // Cleanup the code from the dead instructions: GEPs and BitCasts in between
367     // the original argument and its users: loads and stores. Retarget every
368     // user to the new created alloca.
369     SmallVector<Value *, 16> Worklist;
370     SmallVector<Instruction *, 16> DeadInsts;
371     append_range(Worklist, Arg.users());
372     while (!Worklist.empty()) {
373       Value *V = Worklist.pop_back_val();
374       if (isa<BitCastInst>(V) || isa<GetElementPtrInst>(V)) {
375         DeadInsts.push_back(cast<Instruction>(V));
376         append_range(Worklist, V->users());
377         continue;
378       }
379 
380       if (auto *LI = dyn_cast<LoadInst>(V)) {
381         Value *Ptr = LI->getPointerOperand();
382         LI->setOperand(LoadInst::getPointerOperandIndex(), GetAlloca(Ptr));
383         continue;
384       }
385 
386       if (auto *SI = dyn_cast<StoreInst>(V)) {
387         assert(!SI->isVolatile() && "Volatile operations can't be promoted.");
388         Value *Ptr = SI->getPointerOperand();
389         SI->setOperand(StoreInst::getPointerOperandIndex(), GetAlloca(Ptr));
390         continue;
391       }
392 
393       llvm_unreachable("Unexpected user");
394     }
395 
396     for (Instruction *I : DeadInsts) {
397       I->replaceAllUsesWith(PoisonValue::get(I->getType()));
398       I->eraseFromParent();
399     }
400 
401     // Collect the allocas for promotion
402     for (const auto &Pair : OffsetToAlloca) {
403       assert(isAllocaPromotable(Pair.second) &&
404              "By design, only promotable allocas should be produced.");
405       Allocas.push_back(Pair.second);
406     }
407   }
408 
409   LLVM_DEBUG(dbgs() << "ARG PROMOTION: " << Allocas.size()
410                     << " alloca(s) are promotable by Mem2Reg\n");
411 
412   if (!Allocas.empty()) {
413     // And we are able to call the `promoteMemoryToRegister()` function.
414     // Our earlier checks have ensured that PromoteMemToReg() will
415     // succeed.
416     auto &DT = FAM.getResult<DominatorTreeAnalysis>(*NF);
417     auto &AC = FAM.getResult<AssumptionAnalysis>(*NF);
418     PromoteMemToReg(Allocas, DT, &AC);
419   }
420 
421   return NF;
422 }
423 
424 /// Return true if we can prove that all callees pass in a valid pointer for the
425 /// specified function argument.
426 static bool allCallersPassValidPointerForArgument(Argument *Arg,
427                                                   Align NeededAlign,
428                                                   uint64_t NeededDerefBytes) {
429   Function *Callee = Arg->getParent();
430   const DataLayout &DL = Callee->getParent()->getDataLayout();
431   APInt Bytes(64, NeededDerefBytes);
432 
433   // Check if the argument itself is marked dereferenceable and aligned.
434   if (isDereferenceableAndAlignedPointer(Arg, NeededAlign, Bytes, DL))
435     return true;
436 
437   // Look at all call sites of the function.  At this point we know we only have
438   // direct callees.
439   return all_of(Callee->users(), [&](User *U) {
440     CallBase &CB = cast<CallBase>(*U);
441     return isDereferenceableAndAlignedPointer(CB.getArgOperand(Arg->getArgNo()),
442                                               NeededAlign, Bytes, DL);
443   });
444 }
445 
446 /// Determine that this argument is safe to promote, and find the argument
447 /// parts it can be promoted into.
448 static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
449                          unsigned MaxElements, bool IsRecursive,
450                          SmallVectorImpl<OffsetAndArgPart> &ArgPartsVec) {
451   // Quick exit for unused arguments
452   if (Arg->use_empty())
453     return true;
454 
455   // We can only promote this argument if all the uses are loads at known
456   // offsets.
457   //
458   // Promoting the argument causes it to be loaded in the caller
459   // unconditionally. This is only safe if we can prove that either the load
460   // would have happened in the callee anyway (ie, there is a load in the entry
461   // block) or the pointer passed in at every call site is guaranteed to be
462   // valid.
463   // In the former case, invalid loads can happen, but would have happened
464   // anyway, in the latter case, invalid loads won't happen. This prevents us
465   // from introducing an invalid load that wouldn't have happened in the
466   // original code.
467 
468   SmallDenseMap<int64_t, ArgPart, 4> ArgParts;
469   Align NeededAlign(1);
470   uint64_t NeededDerefBytes = 0;
471 
472   // And if this is a byval argument we also allow to have store instructions.
473   // Only handle in such way arguments with specified alignment;
474   // if it's unspecified, the actual alignment of the argument is
475   // target-specific.
476   bool AreStoresAllowed = Arg->getParamByValType() && Arg->getParamAlign();
477 
478   // An end user of a pointer argument is a load or store instruction.
479   // Returns std::nullopt if this load or store is not based on the argument.
480   // Return true if we can promote the instruction, false otherwise.
481   auto HandleEndUser = [&](auto *I, Type *Ty,
482                            bool GuaranteedToExecute) -> std::optional<bool> {
483     // Don't promote volatile or atomic instructions.
484     if (!I->isSimple())
485       return false;
486 
487     Value *Ptr = I->getPointerOperand();
488     APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0);
489     Ptr = Ptr->stripAndAccumulateConstantOffsets(DL, Offset,
490                                                  /* AllowNonInbounds */ true);
491     if (Ptr != Arg)
492       return std::nullopt;
493 
494     if (Offset.getSignificantBits() >= 64)
495       return false;
496 
497     TypeSize Size = DL.getTypeStoreSize(Ty);
498     // Don't try to promote scalable types.
499     if (Size.isScalable())
500       return false;
501 
502     // If this is a recursive function and one of the types is a pointer,
503     // then promoting it might lead to recursive promotion.
504     if (IsRecursive && Ty->isPointerTy())
505       return false;
506 
507     int64_t Off = Offset.getSExtValue();
508     auto Pair = ArgParts.try_emplace(
509         Off, ArgPart{Ty, I->getAlign(), GuaranteedToExecute ? I : nullptr});
510     ArgPart &Part = Pair.first->second;
511     bool OffsetNotSeenBefore = Pair.second;
512 
513     // We limit promotion to only promoting up to a fixed number of elements of
514     // the aggregate.
515     if (MaxElements > 0 && ArgParts.size() > MaxElements) {
516       LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
517                         << "more than " << MaxElements << " parts\n");
518       return false;
519     }
520 
521     // For now, we only support loading/storing one specific type at a given
522     // offset.
523     if (Part.Ty != Ty) {
524       LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
525                         << "accessed as both " << *Part.Ty << " and " << *Ty
526                         << " at offset " << Off << "\n");
527       return false;
528     }
529 
530     // If this instruction is not guaranteed to execute, and we haven't seen a
531     // load or store at this offset before (or it had lower alignment), then we
532     // need to remember that requirement.
533     // Note that skipping instructions of previously seen offsets is only
534     // correct because we only allow a single type for a given offset, which
535     // also means that the number of accessed bytes will be the same.
536     if (!GuaranteedToExecute &&
537         (OffsetNotSeenBefore || Part.Alignment < I->getAlign())) {
538       // We won't be able to prove dereferenceability for negative offsets.
539       if (Off < 0)
540         return false;
541 
542       // If the offset is not aligned, an aligned base pointer won't help.
543       if (!isAligned(I->getAlign(), Off))
544         return false;
545 
546       NeededDerefBytes = std::max(NeededDerefBytes, Off + Size.getFixedValue());
547       NeededAlign = std::max(NeededAlign, I->getAlign());
548     }
549 
550     Part.Alignment = std::max(Part.Alignment, I->getAlign());
551     return true;
552   };
553 
554   // Look for loads and stores that are guaranteed to execute on entry.
555   for (Instruction &I : Arg->getParent()->getEntryBlock()) {
556     std::optional<bool> Res{};
557     if (LoadInst *LI = dyn_cast<LoadInst>(&I))
558       Res = HandleEndUser(LI, LI->getType(), /* GuaranteedToExecute */ true);
559     else if (StoreInst *SI = dyn_cast<StoreInst>(&I))
560       Res = HandleEndUser(SI, SI->getValueOperand()->getType(),
561                           /* GuaranteedToExecute */ true);
562     if (Res && !*Res)
563       return false;
564 
565     if (!isGuaranteedToTransferExecutionToSuccessor(&I))
566       break;
567   }
568 
569   // Now look at all loads of the argument. Remember the load instructions
570   // for the aliasing check below.
571   SmallVector<const Use *, 16> Worklist;
572   SmallPtrSet<const Use *, 16> Visited;
573   SmallVector<LoadInst *, 16> Loads;
574   auto AppendUses = [&](const Value *V) {
575     for (const Use &U : V->uses())
576       if (Visited.insert(&U).second)
577         Worklist.push_back(&U);
578   };
579   AppendUses(Arg);
580   while (!Worklist.empty()) {
581     const Use *U = Worklist.pop_back_val();
582     Value *V = U->getUser();
583     if (isa<BitCastInst>(V)) {
584       AppendUses(V);
585       continue;
586     }
587 
588     if (auto *GEP = dyn_cast<GetElementPtrInst>(V)) {
589       if (!GEP->hasAllConstantIndices())
590         return false;
591       AppendUses(V);
592       continue;
593     }
594 
595     if (auto *LI = dyn_cast<LoadInst>(V)) {
596       if (!*HandleEndUser(LI, LI->getType(), /* GuaranteedToExecute */ false))
597         return false;
598       Loads.push_back(LI);
599       continue;
600     }
601 
602     // Stores are allowed for byval arguments
603     auto *SI = dyn_cast<StoreInst>(V);
604     if (AreStoresAllowed && SI &&
605         U->getOperandNo() == StoreInst::getPointerOperandIndex()) {
606       if (!*HandleEndUser(SI, SI->getValueOperand()->getType(),
607                           /* GuaranteedToExecute */ false))
608         return false;
609       continue;
610       // Only stores TO the argument is allowed, all the other stores are
611       // unknown users
612     }
613 
614     // Unknown user.
615     LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
616                       << "unknown user " << *V << "\n");
617     return false;
618   }
619 
620   if (NeededDerefBytes || NeededAlign > 1) {
621     // Try to prove a required deref / aligned requirement.
622     if (!allCallersPassValidPointerForArgument(Arg, NeededAlign,
623                                                NeededDerefBytes)) {
624       LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
625                         << "not dereferenceable or aligned\n");
626       return false;
627     }
628   }
629 
630   if (ArgParts.empty())
631     return true; // No users, this is a dead argument.
632 
633   // Sort parts by offset.
634   append_range(ArgPartsVec, ArgParts);
635   sort(ArgPartsVec, llvm::less_first());
636 
637   // Make sure the parts are non-overlapping.
638   int64_t Offset = ArgPartsVec[0].first;
639   for (const auto &Pair : ArgPartsVec) {
640     if (Pair.first < Offset)
641       return false; // Overlap with previous part.
642 
643     Offset = Pair.first + DL.getTypeStoreSize(Pair.second.Ty);
644   }
645 
646   // If store instructions are allowed, the path from the entry of the function
647   // to each load may be not free of instructions that potentially invalidate
648   // the load, and this is an admissible situation.
649   if (AreStoresAllowed)
650     return true;
651 
652   // Okay, now we know that the argument is only used by load instructions, and
653   // it is safe to unconditionally perform all of them. Use alias analysis to
654   // check to see if the pointer is guaranteed to not be modified from entry of
655   // the function to each of the load instructions.
656 
657   // Because there could be several/many load instructions, remember which
658   // blocks we know to be transparent to the load.
659   df_iterator_default_set<BasicBlock *, 16> TranspBlocks;
660 
661   for (LoadInst *Load : Loads) {
662     // Check to see if the load is invalidated from the start of the block to
663     // the load itself.
664     BasicBlock *BB = Load->getParent();
665 
666     MemoryLocation Loc = MemoryLocation::get(Load);
667     if (AAR.canInstructionRangeModRef(BB->front(), *Load, Loc, ModRefInfo::Mod))
668       return false; // Pointer is invalidated!
669 
670     // Now check every path from the entry block to the load for transparency.
671     // To do this, we perform a depth first search on the inverse CFG from the
672     // loading block.
673     for (BasicBlock *P : predecessors(BB)) {
674       for (BasicBlock *TranspBB : inverse_depth_first_ext(P, TranspBlocks))
675         if (AAR.canBasicBlockModify(*TranspBB, Loc))
676           return false;
677     }
678   }
679 
680   // If the path from the entry of the function to each load is free of
681   // instructions that potentially invalidate the load, we can make the
682   // transformation!
683   return true;
684 }
685 
686 /// Check if callers and callee agree on how promoted arguments would be
687 /// passed.
688 static bool areTypesABICompatible(ArrayRef<Type *> Types, const Function &F,
689                                   const TargetTransformInfo &TTI) {
690   return all_of(F.uses(), [&](const Use &U) {
691     CallBase *CB = dyn_cast<CallBase>(U.getUser());
692     if (!CB)
693       return false;
694 
695     const Function *Caller = CB->getCaller();
696     const Function *Callee = CB->getCalledFunction();
697     return TTI.areTypesABICompatible(Caller, Callee, Types);
698   });
699 }
700 
701 /// PromoteArguments - This method checks the specified function to see if there
702 /// are any promotable arguments and if it is safe to promote the function (for
703 /// example, all callers are direct).  If safe to promote some arguments, it
704 /// calls the DoPromotion method.
705 static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
706                                   unsigned MaxElements, bool IsRecursive) {
707   // Don't perform argument promotion for naked functions; otherwise we can end
708   // up removing parameters that are seemingly 'not used' as they are referred
709   // to in the assembly.
710   if (F->hasFnAttribute(Attribute::Naked))
711     return nullptr;
712 
713   // Make sure that it is local to this module.
714   if (!F->hasLocalLinkage())
715     return nullptr;
716 
717   // Don't promote arguments for variadic functions. Adding, removing, or
718   // changing non-pack parameters can change the classification of pack
719   // parameters. Frontends encode that classification at the call site in the
720   // IR, while in the callee the classification is determined dynamically based
721   // on the number of registers consumed so far.
722   if (F->isVarArg())
723     return nullptr;
724 
725   // Don't transform functions that receive inallocas, as the transformation may
726   // not be safe depending on calling convention.
727   if (F->getAttributes().hasAttrSomewhere(Attribute::InAlloca))
728     return nullptr;
729 
730   // First check: see if there are any pointer arguments!  If not, quick exit.
731   SmallVector<Argument *, 16> PointerArgs;
732   for (Argument &I : F->args())
733     if (I.getType()->isPointerTy())
734       PointerArgs.push_back(&I);
735   if (PointerArgs.empty())
736     return nullptr;
737 
738   // Second check: make sure that all callers are direct callers.  We can't
739   // transform functions that have indirect callers.  Also see if the function
740   // is self-recursive.
741   for (Use &U : F->uses()) {
742     CallBase *CB = dyn_cast<CallBase>(U.getUser());
743     // Must be a direct call.
744     if (CB == nullptr || !CB->isCallee(&U) ||
745         CB->getFunctionType() != F->getFunctionType())
746       return nullptr;
747 
748     // Can't change signature of musttail callee
749     if (CB->isMustTailCall())
750       return nullptr;
751 
752     if (CB->getFunction() == F)
753       IsRecursive = true;
754   }
755 
756   // Can't change signature of musttail caller
757   // FIXME: Support promoting whole chain of musttail functions
758   for (BasicBlock &BB : *F)
759     if (BB.getTerminatingMustTailCall())
760       return nullptr;
761 
762   const DataLayout &DL = F->getParent()->getDataLayout();
763   auto &AAR = FAM.getResult<AAManager>(*F);
764   const auto &TTI = FAM.getResult<TargetIRAnalysis>(*F);
765 
766   // Check to see which arguments are promotable.  If an argument is promotable,
767   // add it to ArgsToPromote.
768   DenseMap<Argument *, SmallVector<OffsetAndArgPart, 4>> ArgsToPromote;
769   for (Argument *PtrArg : PointerArgs) {
770     // Replace sret attribute with noalias. This reduces register pressure by
771     // avoiding a register copy.
772     if (PtrArg->hasStructRetAttr()) {
773       unsigned ArgNo = PtrArg->getArgNo();
774       F->removeParamAttr(ArgNo, Attribute::StructRet);
775       F->addParamAttr(ArgNo, Attribute::NoAlias);
776       for (Use &U : F->uses()) {
777         CallBase &CB = cast<CallBase>(*U.getUser());
778         CB.removeParamAttr(ArgNo, Attribute::StructRet);
779         CB.addParamAttr(ArgNo, Attribute::NoAlias);
780       }
781     }
782 
783     // If we can promote the pointer to its value.
784     SmallVector<OffsetAndArgPart, 4> ArgParts;
785 
786     if (findArgParts(PtrArg, DL, AAR, MaxElements, IsRecursive, ArgParts)) {
787       SmallVector<Type *, 4> Types;
788       for (const auto &Pair : ArgParts)
789         Types.push_back(Pair.second.Ty);
790 
791       if (areTypesABICompatible(Types, *F, TTI)) {
792         ArgsToPromote.insert({PtrArg, std::move(ArgParts)});
793       }
794     }
795   }
796 
797   // No promotable pointer arguments.
798   if (ArgsToPromote.empty())
799     return nullptr;
800 
801   return doPromotion(F, FAM, ArgsToPromote);
802 }
803 
804 PreservedAnalyses ArgumentPromotionPass::run(LazyCallGraph::SCC &C,
805                                              CGSCCAnalysisManager &AM,
806                                              LazyCallGraph &CG,
807                                              CGSCCUpdateResult &UR) {
808   bool Changed = false, LocalChange;
809 
810   // Iterate until we stop promoting from this SCC.
811   do {
812     LocalChange = false;
813 
814     FunctionAnalysisManager &FAM =
815         AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
816 
817     bool IsRecursive = C.size() > 1;
818     for (LazyCallGraph::Node &N : C) {
819       Function &OldF = N.getFunction();
820       Function *NewF = promoteArguments(&OldF, FAM, MaxElements, IsRecursive);
821       if (!NewF)
822         continue;
823       LocalChange = true;
824 
825       // Directly substitute the functions in the call graph. Note that this
826       // requires the old function to be completely dead and completely
827       // replaced by the new function. It does no call graph updates, it merely
828       // swaps out the particular function mapped to a particular node in the
829       // graph.
830       C.getOuterRefSCC().replaceNodeFunction(N, *NewF);
831       FAM.clear(OldF, OldF.getName());
832       OldF.eraseFromParent();
833 
834       PreservedAnalyses FuncPA;
835       FuncPA.preserveSet<CFGAnalyses>();
836       for (auto *U : NewF->users()) {
837         auto *UserF = cast<CallBase>(U)->getFunction();
838         FAM.invalidate(*UserF, FuncPA);
839       }
840     }
841 
842     Changed |= LocalChange;
843   } while (LocalChange);
844 
845   if (!Changed)
846     return PreservedAnalyses::all();
847 
848   PreservedAnalyses PA;
849   // We've cleared out analyses for deleted functions.
850   PA.preserve<FunctionAnalysisManagerCGSCCProxy>();
851   // We've manually invalidated analyses for functions we've modified.
852   PA.preserveSet<AllAnalysesOn<Function>>();
853   return PA;
854 }
855