//===- ArgumentPromotion.cpp - Promote by-reference arguments -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass promotes "by reference" arguments to be "by value" arguments.  In
// practice, this means looking for internal functions that have pointer
// arguments.  If it can prove, through the use of alias analysis, that an
// argument is *only* loaded, then it can pass the value into the function
// instead of the address of the value.  This can cause recursive simplification
// of code and lead to the elimination of allocas (especially in C++ template
// code like the STL).
//
// This pass also handles aggregate arguments that are passed into a function,
// scalarizing them if the elements of the aggregate are only loaded.  Note that
// by default it refuses to scalarize aggregates which would require passing in
// more than three operands to the function, because passing thousands of
// operands for a large array or structure is unprofitable! This limit can be
// configured or disabled, however.
//
// Note that this transformation could also be done for arguments that are
// only stored to (returning the value instead), but this is not currently
// implemented.  This case would be best handled when and if LLVM begins
// supporting multiple return values from functions.
//
//===----------------------------------------------------------------------===//
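//
// For illustration only, a hypothetical sketch of the basic transformation
// (the ".0.val" names follow the naming convention used by the code below;
// this is not exact pass output):
//
//   Before:
//     define internal i32 @callee(ptr %p) {
//       %v = load i32, ptr %p
//       ret i32 %v
//     }
//     %r = call i32 @callee(ptr %a)
//
//   After:
//     define internal i32 @callee(i32 %p.0.val) {
//       ret i32 %p.0.val
//     }
//     %a.val = load i32, ptr %a   ; the load now happens in the caller
//     %r = call i32 @callee(i32 %a.val)
//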

#include "llvm/Transforms/IPO/ArgumentPromotion.h"

#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/PromoteMemToReg.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using namespace llvm;

#define DEBUG_TYPE "argpromotion"

STATISTIC(NumArgumentsPromoted, "Number of pointer arguments promoted");
STATISTIC(NumArgumentsDead, "Number of dead pointer args eliminated");

namespace {

struct ArgPart {
  Type *Ty;
  Align Alignment;
  /// A representative guaranteed-executed load or store instruction for use by
  /// metadata transfer.
  Instruction *MustExecInstr;
};

using OffsetAndArgPart = std::pair<int64_t, ArgPart>;

} // end anonymous namespace

static Value *createByteGEP(IRBuilderBase &IRB, const DataLayout &DL,
                            Value *Ptr, Type *ResElemTy, int64_t Offset) {
  if (Offset != 0) {
    APInt APOffset(DL.getIndexTypeSizeInBits(Ptr->getType()), Offset);
    Ptr = IRB.CreatePtrAdd(Ptr, IRB.getInt(APOffset));
  }
  return Ptr;
}

/// DoPromotion - This method actually performs the promotion of the specified
/// arguments, and returns the new function.  At this point, we know that it's
/// safe to do so.
static Function *
doPromotion(Function *F, FunctionAnalysisManager &FAM,
            const DenseMap<Argument *, SmallVector<OffsetAndArgPart, 4>>
                &ArgsToPromote) {
  // Start by computing a new prototype for the function, which is the same as
  // the old function, but has modified arguments.
  FunctionType *FTy = F->getFunctionType();
  std::vector<Type *> Params;

  // ArgAttrVec - Keep track of the parameter attributes for the arguments
  // that we are *not* promoting. For the ones that we do promote, the
  // parameter attributes are lost.
  SmallVector<AttributeSet, 8> ArgAttrVec;
  // Mapping from old to new argument indices. -1 for promoted or removed
  // arguments.
  SmallVector<unsigned> NewArgIndices;
  AttributeList PAL = F->getAttributes();

  // First, determine the new argument list
  unsigned ArgNo = 0, NewArgNo = 0;
  for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
       ++I, ++ArgNo) {
    if (!ArgsToPromote.count(&*I)) {
      // Unchanged argument
      Params.push_back(I->getType());
      ArgAttrVec.push_back(PAL.getParamAttrs(ArgNo));
      NewArgIndices.push_back(NewArgNo++);
    } else if (I->use_empty()) {
      // Dead argument (dead arguments are always marked as promotable)
      ++NumArgumentsDead;
      NewArgIndices.push_back((unsigned)-1);
    } else {
      const auto &ArgParts = ArgsToPromote.find(&*I)->second;
      for (const auto &Pair : ArgParts) {
        Params.push_back(Pair.second.Ty);
        ArgAttrVec.push_back(AttributeSet());
      }
      ++NumArgumentsPromoted;
      NewArgIndices.push_back((unsigned)-1);
      NewArgNo += ArgParts.size();
    }
  }

  Type *RetTy = FTy->getReturnType();

  // Construct the new function type using the new arguments.
  FunctionType *NFTy = FunctionType::get(RetTy, Params, FTy->isVarArg());

  // Create the new function body and insert it into the module.
  Function *NF = Function::Create(NFTy, F->getLinkage(), F->getAddressSpace(),
                                  F->getName());
  NF->copyAttributesFrom(F);
  NF->copyMetadata(F, 0);
  NF->setIsNewDbgInfoFormat(F->IsNewDbgInfoFormat);

  // The new function will have the !dbg metadata copied from the original
  // function. The original function may not be deleted, and !dbg metadata
  // needs to be unique, so we need to drop it from the original.
  F->setSubprogram(nullptr);

  LLVM_DEBUG(dbgs() << "ARG PROMOTION:  Promoting to:" << *NF << "\n"
                    << "From: " << *F);

  uint64_t LargestVectorWidth = 0;
  for (auto *I : Params)
    if (auto *VT = dyn_cast<llvm::VectorType>(I))
      LargestVectorWidth = std::max(
          LargestVectorWidth, VT->getPrimitiveSizeInBits().getKnownMinValue());

  // Recompute the parameter attributes list based on the new arguments for
  // the function.
  NF->setAttributes(AttributeList::get(F->getContext(), PAL.getFnAttrs(),
                                       PAL.getRetAttrs(), ArgAttrVec));

  // Remap argument indices in allocsize attribute.
  if (auto AllocSize = NF->getAttributes().getFnAttrs().getAllocSizeArgs()) {
    unsigned Arg1 = NewArgIndices[AllocSize->first];
    assert(Arg1 != (unsigned)-1 && "allocsize cannot be promoted argument");
    std::optional<unsigned> Arg2;
    if (AllocSize->second) {
      Arg2 = NewArgIndices[*AllocSize->second];
      assert(Arg2 != (unsigned)-1 && "allocsize cannot be promoted argument");
    }
    NF->addFnAttr(Attribute::getWithAllocSizeArgs(F->getContext(), Arg1, Arg2));
  }

  AttributeFuncs::updateMinLegalVectorWidthAttr(*NF, LargestVectorWidth);
  ArgAttrVec.clear();

  F->getParent()->getFunctionList().insert(F->getIterator(), NF);
  NF->takeName(F);

  // Loop over all the callers of the function, transforming the call sites to
  // pass in the loaded pointers.
  SmallVector<Value *, 16> Args;
  const DataLayout &DL = F->getDataLayout();
  SmallVector<WeakTrackingVH, 16> DeadArgs;

  while (!F->use_empty()) {
    CallBase &CB = cast<CallBase>(*F->user_back());
    assert(CB.getCalledFunction() == F);
    const AttributeList &CallPAL = CB.getAttributes();
    IRBuilder<NoFolder> IRB(&CB);

    // Loop over the operands, inserting GEPs and loads in the caller as
    // appropriate.
    auto *AI = CB.arg_begin();
    ArgNo = 0;
    for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
         ++I, ++AI, ++ArgNo) {
      if (!ArgsToPromote.count(&*I)) {
        Args.push_back(*AI); // Unmodified argument
        ArgAttrVec.push_back(CallPAL.getParamAttrs(ArgNo));
      } else if (!I->use_empty()) {
        Value *V = *AI;
        const auto &ArgParts = ArgsToPromote.find(&*I)->second;
        for (const auto &Pair : ArgParts) {
          LoadInst *LI = IRB.CreateAlignedLoad(
              Pair.second.Ty,
              createByteGEP(IRB, DL, V, Pair.second.Ty, Pair.first),
              Pair.second.Alignment, V->getName() + ".val");
          if (Pair.second.MustExecInstr) {
            LI->setAAMetadata(Pair.second.MustExecInstr->getAAMetadata());
            LI->copyMetadata(*Pair.second.MustExecInstr,
                             {LLVMContext::MD_dereferenceable,
                              LLVMContext::MD_dereferenceable_or_null,
                              LLVMContext::MD_noundef,
                              LLVMContext::MD_nontemporal});
            // Only transfer poison-generating metadata if we also have
            // !noundef.
            // TODO: Without !noundef, we could merge this metadata across
            // all promoted loads.
            if (LI->hasMetadata(LLVMContext::MD_noundef))
              LI->copyMetadata(*Pair.second.MustExecInstr,
                               {LLVMContext::MD_range, LLVMContext::MD_nonnull,
                                LLVMContext::MD_align});
          }
          Args.push_back(LI);
          ArgAttrVec.push_back(AttributeSet());
        }
      } else {
        assert(ArgsToPromote.count(&*I) && I->use_empty());
        DeadArgs.emplace_back(AI->get());
      }
    }

    // Push any varargs arguments on the list.
    for (; AI != CB.arg_end(); ++AI, ++ArgNo) {
      Args.push_back(*AI);
      ArgAttrVec.push_back(CallPAL.getParamAttrs(ArgNo));
    }

    SmallVector<OperandBundleDef, 1> OpBundles;
    CB.getOperandBundlesAsDefs(OpBundles);

    CallBase *NewCS = nullptr;
    if (InvokeInst *II = dyn_cast<InvokeInst>(&CB)) {
      NewCS = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(),
                                 Args, OpBundles, "", CB.getIterator());
    } else {
      auto *NewCall =
          CallInst::Create(NF, Args, OpBundles, "", CB.getIterator());
      NewCall->setTailCallKind(cast<CallInst>(&CB)->getTailCallKind());
      NewCS = NewCall;
    }
    NewCS->setCallingConv(CB.getCallingConv());
    NewCS->setAttributes(AttributeList::get(F->getContext(),
                                            CallPAL.getFnAttrs(),
                                            CallPAL.getRetAttrs(), ArgAttrVec));
    NewCS->copyMetadata(CB, {LLVMContext::MD_prof, LLVMContext::MD_dbg});
    Args.clear();
    ArgAttrVec.clear();

    AttributeFuncs::updateMinLegalVectorWidthAttr(*CB.getCaller(),
                                                  LargestVectorWidth);

    if (!CB.use_empty()) {
      CB.replaceAllUsesWith(NewCS);
      NewCS->takeName(&CB);
    }

    // Finally, remove the old call from the program, reducing the use-count of
    // F.
    CB.eraseFromParent();
  }

  RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadArgs);

  // Since we have now created the new function, splice the body of the old
  // function right into the new function, leaving the old rotting hulk of the
  // function empty.
  NF->splice(NF->begin(), F);

  // We will collect all the newly created allocas to promote them into
  // registers after the following loop.
  SmallVector<AllocaInst *, 4> Allocas;

  // Loop over the argument list, transferring uses of the old arguments over
  // to the new arguments, and transferring the names over as well.
  Function::arg_iterator I2 = NF->arg_begin();
  for (Argument &Arg : F->args()) {
    if (!ArgsToPromote.count(&Arg)) {
      // If this is an unmodified argument, move the name and users over to the
      // new version.
      Arg.replaceAllUsesWith(&*I2);
      I2->takeName(&Arg);
      ++I2;
      continue;
    }

    // There may be metadata uses for things like llvm.dbg.value.
    // Replace them with undef, after handling the other regular uses.
    auto RauwUndefMetadata = make_scope_exit(
        [&]() { Arg.replaceAllUsesWith(UndefValue::get(Arg.getType())); });

    if (Arg.use_empty())
      continue;

    // Otherwise, if we promoted this argument, we have to create an alloca in
    // the callee for every promotable part and store each of the new incoming
    // arguments into the corresponding alloca, which gives the old code (in
    // particular the store instructions, if they are allowed) a chance to
    // work as before.
    assert(Arg.getType()->isPointerTy() &&
           "Only arguments with a pointer type are promotable");

    IRBuilder<NoFolder> IRB(&NF->begin()->front());

    // Add only the promoted parts, i.e. the parts recorded in ArgsToPromote.
    SmallDenseMap<int64_t, AllocaInst *> OffsetToAlloca;
    for (const auto &Pair : ArgsToPromote.find(&Arg)->second) {
      int64_t Offset = Pair.first;
      const ArgPart &Part = Pair.second;

      Argument *NewArg = I2++;
      NewArg->setName(Arg.getName() + "." + Twine(Offset) + ".val");

      AllocaInst *NewAlloca = IRB.CreateAlloca(
          Part.Ty, nullptr, Arg.getName() + "." + Twine(Offset) + ".allc");
      NewAlloca->setAlignment(Pair.second.Alignment);
      IRB.CreateAlignedStore(NewArg, NewAlloca, Pair.second.Alignment);

      // Record the alloca so we can retarget the users to it.
      OffsetToAlloca.insert({Offset, NewAlloca});
    }

    auto GetAlloca = [&](Value *Ptr) {
      APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0);
      Ptr = Ptr->stripAndAccumulateConstantOffsets(DL, Offset,
                                                   /* AllowNonInbounds */ true);
      assert(Ptr == &Arg && "Not constant offset from arg?");
      return OffsetToAlloca.lookup(Offset.getSExtValue());
    };

    // Clean up the dead instructions (the GEPs and BitCasts between the
    // original argument and its end users, the loads and stores), and
    // retarget every user to the newly created alloca.
    SmallVector<Value *, 16> Worklist;
    SmallVector<Instruction *, 16> DeadInsts;
    append_range(Worklist, Arg.users());
    while (!Worklist.empty()) {
      Value *V = Worklist.pop_back_val();
      if (isa<BitCastInst>(V) || isa<GetElementPtrInst>(V)) {
        DeadInsts.push_back(cast<Instruction>(V));
        append_range(Worklist, V->users());
        continue;
      }

      if (auto *LI = dyn_cast<LoadInst>(V)) {
        Value *Ptr = LI->getPointerOperand();
        LI->setOperand(LoadInst::getPointerOperandIndex(), GetAlloca(Ptr));
        continue;
      }

      if (auto *SI = dyn_cast<StoreInst>(V)) {
        assert(!SI->isVolatile() && "Volatile operations can't be promoted.");
        Value *Ptr = SI->getPointerOperand();
        SI->setOperand(StoreInst::getPointerOperandIndex(), GetAlloca(Ptr));
        continue;
      }

      llvm_unreachable("Unexpected user");
    }

    for (Instruction *I : DeadInsts) {
      I->replaceAllUsesWith(PoisonValue::get(I->getType()));
      I->eraseFromParent();
    }

    // Collect the allocas for promotion
    for (const auto &Pair : OffsetToAlloca) {
      assert(isAllocaPromotable(Pair.second) &&
             "By design, only promotable allocas should be produced.");
      Allocas.push_back(Pair.second);
    }
  }

  LLVM_DEBUG(dbgs() << "ARG PROMOTION: " << Allocas.size()
                    << " alloca(s) are promotable by Mem2Reg\n");

  if (!Allocas.empty()) {
    // Promote the allocas into registers. Our earlier checks have ensured
    // that PromoteMemToReg() will succeed.
    auto &DT = FAM.getResult<DominatorTreeAnalysis>(*NF);
    auto &AC = FAM.getResult<AssumptionAnalysis>(*NF);
    PromoteMemToReg(Allocas, DT, &AC);
  }

  return NF;
}

/// Return true if we can prove that all callers pass in a valid pointer for
/// the specified function argument.
static bool allCallersPassValidPointerForArgument(
    Argument *Arg, SmallPtrSetImpl<CallBase *> &RecursiveCalls,
    Align NeededAlign, uint64_t NeededDerefBytes) {
  Function *Callee = Arg->getParent();
  const DataLayout &DL = Callee->getDataLayout();
  APInt Bytes(64, NeededDerefBytes);

  // Check if the argument itself is marked dereferenceable and aligned.
  if (isDereferenceableAndAlignedPointer(Arg, NeededAlign, Bytes, DL))
    return true;

  // Look at all call sites of the function.  At this point we know we only
  // have direct calls.
  return all_of(Callee->users(), [&](User *U) {
    CallBase &CB = cast<CallBase>(*U);
    // In the case of functions with recursive calls, this check
    // (isDereferenceableAndAlignedPointer) will fail when it tries to look at
    // the first caller of this function. The caller may or may not have a
    // load; in case it doesn't load the pointer being passed, this check will
    // fail. So it's safe to skip the check when we know that we are dealing
    // with a recursive call. For example, consider the IR given below.
    //
    // def fun(ptr %a) {
    //   ...
    //   %loadres = load i32, ptr %a, align 4
    //   %res = call i32 @fun(ptr %a)
    //   ...
    // }
    //
    // def bar(ptr %x) {
    //   ...
    //   %resbar = call i32 @fun(ptr %x)
    //   ...
    // }
    //
    // Since we record processed recursive calls, we check whether the current
    // CallBase has been processed before. If so, it is a recursive call, and
    // we can skip the check for just this call site and return true.
    if (RecursiveCalls.contains(&CB))
      return true;

    return isDereferenceableAndAlignedPointer(CB.getArgOperand(Arg->getArgNo()),
                                              NeededAlign, Bytes, DL);
  });
}

/// Determine whether this argument is safe to promote, and find the argument
/// parts it can be promoted into.
static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
                         unsigned MaxElements, bool IsRecursive,
                         SmallVectorImpl<OffsetAndArgPart> &ArgPartsVec) {
  // Quick exit for unused arguments
  if (Arg->use_empty())
    return true;

  // We can only promote this argument if all of its uses are loads at known
  // offsets.
  //
  // Promoting the argument causes it to be loaded in the caller
  // unconditionally. This is only safe if we can prove that either the load
  // would have happened in the callee anyway (i.e., there is a load in the
  // entry block) or the pointer passed in at every call site is guaranteed to
  // be valid.
  // In the former case, invalid loads can happen, but would have happened
  // anyway; in the latter case, invalid loads won't happen. This prevents us
  // from introducing an invalid load that wouldn't have happened in the
  // original code.
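  //
  // A hypothetical example of the case this rules out: if the callee loads
  // the pointer only under a branch, e.g.
  //
  //   define internal i32 @callee(ptr %p, i1 %c) {
  //   entry:
  //     br i1 %c, label %cond.load, label %exit
  //   cond.load:
  //     %v = load i32, ptr %p
  //     ...
  //   }
  //
  // then unconditionally loading %p at every call site would introduce a load
  // the original program might never execute, unless the pointer is known to
  // be dereferenceable at all call sites.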

  SmallDenseMap<int64_t, ArgPart, 4> ArgParts;
  Align NeededAlign(1);
  uint64_t NeededDerefBytes = 0;

  // If this is a byval argument, we also allow store instructions. We only
  // handle arguments with a specified alignment in this way; if the alignment
  // is unspecified, the actual alignment of the argument is target-specific.
  bool AreStoresAllowed = Arg->getParamByValType() && Arg->getParamAlign();
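  // For instance (an illustrative sketch, not taken from a test case): a
  // byval argument is a callee-local copy, so in
  //
  //   define internal void @f(ptr byval(i64) align 8 %p) {
  //     store i64 0, ptr %p
  //     ...
  //   }
  //
  // the store writes only the local copy, and after promotion it can be
  // redirected to a private alloca in the callee.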

  // An end user of a pointer argument is a load or store instruction.
  // HandleEndUser returns std::nullopt if the load or store is not based on
  // the argument, true if we can promote the instruction, and false otherwise.
  auto HandleEndUser = [&](auto *I, Type *Ty,
                           bool GuaranteedToExecute) -> std::optional<bool> {
    // Don't promote volatile or atomic instructions.
    if (!I->isSimple())
      return false;

    Value *Ptr = I->getPointerOperand();
    APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0);
    Ptr = Ptr->stripAndAccumulateConstantOffsets(DL, Offset,
                                                 /* AllowNonInbounds */ true);
    if (Ptr != Arg)
      return std::nullopt;

    if (Offset.getSignificantBits() >= 64)
      return false;

    TypeSize Size = DL.getTypeStoreSize(Ty);
    // Don't try to promote scalable types.
    if (Size.isScalable())
      return false;

    // If this is a recursive function and one of the types is a pointer,
    // then promoting it might lead to recursive promotion.
    if (IsRecursive && Ty->isPointerTy())
      return false;

    int64_t Off = Offset.getSExtValue();
    auto Pair = ArgParts.try_emplace(
        Off, ArgPart{Ty, I->getAlign(), GuaranteedToExecute ? I : nullptr});
    ArgPart &Part = Pair.first->second;
    bool OffsetNotSeenBefore = Pair.second;

    // We limit promotion to promoting at most a fixed number of elements of
    // the aggregate.
    if (MaxElements > 0 && ArgParts.size() > MaxElements) {
      LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                        << "more than " << MaxElements << " parts\n");
      return false;
    }

    // For now, we only support loading/storing one specific type at a given
    // offset.
    if (Part.Ty != Ty) {
      LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                        << "accessed as both " << *Part.Ty << " and " << *Ty
                        << " at offset " << Off << "\n");
      return false;
    }

    // If this instruction is not guaranteed to execute, and we haven't seen a
    // load or store at this offset before (or it had lower alignment), then we
    // need to remember that requirement.
    // Note that skipping instructions of previously seen offsets is only
    // correct because we only allow a single type for a given offset, which
    // also means that the number of accessed bytes will be the same.
    if (!GuaranteedToExecute &&
        (OffsetNotSeenBefore || Part.Alignment < I->getAlign())) {
      // We won't be able to prove dereferenceability for negative offsets.
      if (Off < 0)
        return false;

      // If the offset is not aligned, an aligned base pointer won't help.
      if (!isAligned(I->getAlign(), Off))
        return false;

      NeededDerefBytes = std::max(NeededDerefBytes, Off + Size.getFixedValue());
      NeededAlign = std::max(NeededAlign, I->getAlign());
    }

    Part.Alignment = std::max(Part.Alignment, I->getAlign());
    return true;
  };

  // Look for loads and stores that are guaranteed to execute on entry.
  for (Instruction &I : Arg->getParent()->getEntryBlock()) {
    std::optional<bool> Res{};
    if (LoadInst *LI = dyn_cast<LoadInst>(&I))
      Res = HandleEndUser(LI, LI->getType(), /* GuaranteedToExecute */ true);
    else if (StoreInst *SI = dyn_cast<StoreInst>(&I))
      Res = HandleEndUser(SI, SI->getValueOperand()->getType(),
                          /* GuaranteedToExecute */ true);
    if (Res && !*Res)
      return false;

    if (!isGuaranteedToTransferExecutionToSuccessor(&I))
      break;
  }

  // Now look at all loads of the argument. Remember the load instructions
  // for the aliasing check below.
  SmallVector<const Use *, 16> Worklist;
  SmallPtrSet<const Use *, 16> Visited;
  SmallVector<LoadInst *, 16> Loads;
  SmallPtrSet<CallBase *, 4> RecursiveCalls;
  auto AppendUses = [&](const Value *V) {
    for (const Use &U : V->uses())
      if (Visited.insert(&U).second)
        Worklist.push_back(&U);
  };
  AppendUses(Arg);
  while (!Worklist.empty()) {
    const Use *U = Worklist.pop_back_val();
    Value *V = U->getUser();
    if (isa<BitCastInst>(V)) {
      AppendUses(V);
      continue;
    }

    if (auto *GEP = dyn_cast<GetElementPtrInst>(V)) {
      if (!GEP->hasAllConstantIndices())
        return false;
      AppendUses(V);
      continue;
    }

    if (auto *LI = dyn_cast<LoadInst>(V)) {
      if (!*HandleEndUser(LI, LI->getType(), /* GuaranteedToExecute */ false))
        return false;
      Loads.push_back(LI);
      continue;
    }

    // Stores are allowed for byval arguments
    auto *SI = dyn_cast<StoreInst>(V);
    if (AreStoresAllowed && SI &&
        U->getOperandNo() == StoreInst::getPointerOperandIndex()) {
      if (!*HandleEndUser(SI, SI->getValueOperand()->getType(),
                          /* GuaranteedToExecute */ false))
        return false;
      continue;
      // Only stores TO the argument are allowed; all the other stores are
      // unknown users.
    }

    auto *CB = dyn_cast<CallBase>(V);
    Value *PtrArg = U->get();
    if (CB && CB->getCalledFunction() == CB->getFunction()) {
      if (PtrArg != Arg) {
        LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                          << "pointer offset is not equal to zero\n");
        return false;
      }

      unsigned int ArgNo = Arg->getArgNo();
      if (U->getOperandNo() != ArgNo) {
        LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                          << "arg position is different in callee\n");
        return false;
      }

      // We limit promotion to promoting at most a fixed number of elements
      // of the aggregate.
      if (MaxElements > 0 && ArgParts.size() > MaxElements) {
        LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                          << "more than " << MaxElements << " parts\n");
        return false;
      }

      RecursiveCalls.insert(CB);
      continue;
    }
    // Unknown user.
    LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                      << "unknown user " << *V << "\n");
    return false;
  }

  if (NeededDerefBytes || NeededAlign > 1) {
    // Try to prove the needed dereferenceability and alignment requirements.
    if (!allCallersPassValidPointerForArgument(Arg, RecursiveCalls, NeededAlign,
                                               NeededDerefBytes)) {
      LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                        << "not dereferenceable or aligned\n");
      return false;
    }
  }

  if (ArgParts.empty())
    return true; // No users, this is a dead argument.

  // Sort parts by offset.
  append_range(ArgPartsVec, ArgParts);
  sort(ArgPartsVec, llvm::less_first());

  // Make sure the parts are non-overlapping.
  int64_t Offset = ArgPartsVec[0].first;
  for (const auto &Pair : ArgPartsVec) {
    if (Pair.first < Offset)
      return false; // Overlap with previous part.

    Offset = Pair.first + DL.getTypeStoreSize(Pair.second.Ty);
  }

  // If store instructions are allowed, the path from the entry of the
  // function to each load may not be free of instructions that potentially
  // invalidate the load; this is an admissible situation.
  if (AreStoresAllowed)
    return true;

  // Okay, now we know that the argument is only used by load instructions, and
  // it is safe to unconditionally perform all of them. Use alias analysis to
  // check to see if the pointer is guaranteed to not be modified from entry of
  // the function to each of the load instructions.

  for (LoadInst *Load : Loads) {
    // Check to see if the load is invalidated from the start of the block to
    // the load itself.
    BasicBlock *BB = Load->getParent();

    MemoryLocation Loc = MemoryLocation::get(Load);
    if (AAR.canInstructionRangeModRef(BB->front(), *Load, Loc, ModRefInfo::Mod))
      return false; // Pointer is invalidated!

    // Now check every path from the entry block to the load for transparency.
    // To do this, we perform a depth first search on the inverse CFG from the
    // loading block.
    for (BasicBlock *P : predecessors(BB)) {
      for (BasicBlock *TranspBB : inverse_depth_first(P))
        if (AAR.canBasicBlockModify(*TranspBB, Loc))
          return false;
    }
  }

  // If the path from the entry of the function to each load is free of
  // instructions that potentially invalidate the load, we can make the
  // transformation!
  return true;
}

/// Check if callers and callee agree on how promoted arguments would be
/// passed.
static bool areTypesABICompatible(ArrayRef<Type *> Types, const Function &F,
                                  const TargetTransformInfo &TTI) {
  return all_of(F.uses(), [&](const Use &U) {
    CallBase *CB = dyn_cast<CallBase>(U.getUser());
    if (!CB)
      return false;

    const Function *Caller = CB->getCaller();
    const Function *Callee = CB->getCalledFunction();
    return TTI.areTypesABICompatible(Caller, Callee, Types);
  });
}

/// PromoteArguments - This method checks the specified function to see if there
/// are any promotable arguments and if it is safe to promote the function (for
/// example, all callers are direct).  If safe to promote some arguments, it
/// calls the DoPromotion method.
static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
                                  unsigned MaxElements, bool IsRecursive) {
  // Don't perform argument promotion for naked functions; otherwise we can end
  // up removing parameters that are seemingly 'not used' as they are referred
  // to in the assembly.
  if (F->hasFnAttribute(Attribute::Naked))
    return nullptr;

  // Make sure that it is local to this module.
  if (!F->hasLocalLinkage())
    return nullptr;

  // Don't promote arguments for variadic functions. Adding, removing, or
  // changing non-pack parameters can change the classification of pack
  // parameters. Frontends encode that classification at the call site in the
  // IR, while in the callee the classification is determined dynamically based
  // on the number of registers consumed so far.
  if (F->isVarArg())
    return nullptr;

  // Don't transform functions that receive inallocas, as the transformation may
  // not be safe depending on calling convention.
  if (F->getAttributes().hasAttrSomewhere(Attribute::InAlloca))
    return nullptr;

  // First check: see if there are any pointer arguments!  If not, quick exit.
  SmallVector<Argument *, 16> PointerArgs;
  for (Argument &I : F->args())
    if (I.getType()->isPointerTy())
      PointerArgs.push_back(&I);
  if (PointerArgs.empty())
    return nullptr;

  // Second check: make sure that all callers are direct callers.  We can't
  // transform functions that have indirect callers.  Also see if the function
  // is self-recursive.
  for (Use &U : F->uses()) {
    CallBase *CB = dyn_cast<CallBase>(U.getUser());
    // Must be a direct call.
    if (CB == nullptr || !CB->isCallee(&U) ||
        CB->getFunctionType() != F->getFunctionType())
      return nullptr;

    // Can't change signature of musttail callee
    if (CB->isMustTailCall())
      return nullptr;

    if (CB->getFunction() == F)
      IsRecursive = true;
  }

  // Can't change signature of musttail caller
  // FIXME: Support promoting whole chain of musttail functions
  for (BasicBlock &BB : *F)
    if (BB.getTerminatingMustTailCall())
      return nullptr;

  const DataLayout &DL = F->getDataLayout();
  auto &AAR = FAM.getResult<AAManager>(*F);
  const auto &TTI = FAM.getResult<TargetIRAnalysis>(*F);

  // Check to see which arguments are promotable.  If an argument is promotable,
  // add it to ArgsToPromote.
  DenseMap<Argument *, SmallVector<OffsetAndArgPart, 4>> ArgsToPromote;
  unsigned NumArgsAfterPromote = F->getFunctionType()->getNumParams();
  for (Argument *PtrArg : PointerArgs) {
    // Replace sret attribute with noalias. This reduces register pressure by
    // avoiding a register copy.
    if (PtrArg->hasStructRetAttr()) {
      unsigned ArgNo = PtrArg->getArgNo();
      F->removeParamAttr(ArgNo, Attribute::StructRet);
      F->addParamAttr(ArgNo, Attribute::NoAlias);
      for (Use &U : F->uses()) {
        CallBase &CB = cast<CallBase>(*U.getUser());
        CB.removeParamAttr(ArgNo, Attribute::StructRet);
        CB.addParamAttr(ArgNo, Attribute::NoAlias);
      }
    }

    // Check if we can promote the pointer to its value.
    SmallVector<OffsetAndArgPart, 4> ArgParts;

    if (findArgParts(PtrArg, DL, AAR, MaxElements, IsRecursive, ArgParts)) {
      SmallVector<Type *, 4> Types;
      for (const auto &Pair : ArgParts)
        Types.push_back(Pair.second.Ty);

      if (areTypesABICompatible(Types, *F, TTI)) {
        NumArgsAfterPromote += ArgParts.size() - 1;
        ArgsToPromote.insert({PtrArg, std::move(ArgParts)});
      }
    }
  }

  // No promotable pointer arguments.
  if (ArgsToPromote.empty())
    return nullptr;

  if (NumArgsAfterPromote > TTI.getMaxNumArgs())
    return nullptr;

  return doPromotion(F, FAM, ArgsToPromote);
}

PreservedAnalyses ArgumentPromotionPass::run(LazyCallGraph::SCC &C,
                                             CGSCCAnalysisManager &AM,
                                             LazyCallGraph &CG,
                                             CGSCCUpdateResult &UR) {
  bool Changed = false, LocalChange;

  // Iterate until we stop promoting from this SCC.
  do {
    LocalChange = false;

    FunctionAnalysisManager &FAM =
        AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();

    bool IsRecursive = C.size() > 1;
    for (LazyCallGraph::Node &N : C) {
      Function &OldF = N.getFunction();
      Function *NewF = promoteArguments(&OldF, FAM, MaxElements, IsRecursive);
      if (!NewF)
        continue;
      LocalChange = true;

      // Directly substitute the functions in the call graph. Note that this
      // requires the old function to be completely dead and completely
      // replaced by the new function. It does no call graph updates, it merely
      // swaps out the particular function mapped to a particular node in the
      // graph.
      C.getOuterRefSCC().replaceNodeFunction(N, *NewF);
      FAM.clear(OldF, OldF.getName());
      OldF.eraseFromParent();

      PreservedAnalyses FuncPA;
      FuncPA.preserveSet<CFGAnalyses>();
      for (auto *U : NewF->users()) {
        auto *UserF = cast<CallBase>(U)->getFunction();
        FAM.invalidate(*UserF, FuncPA);
      }
    }

    Changed |= LocalChange;
  } while (LocalChange);

  if (!Changed)
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  // We've cleared out analyses for deleted functions.
  PA.preserve<FunctionAnalysisManagerCGSCCProxy>();
  // We've manually invalidated analyses for functions we've modified.
  PA.preserveSet<AllAnalysesOn<Function>>();
  return PA;
}