//===- ArgumentPromotion.cpp - Promote by-reference arguments -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass promotes "by reference" arguments to be "by value" arguments. In
// practice, this means looking for internal functions that have pointer
// arguments. If it can prove, through the use of alias analysis, that an
// argument is *only* loaded, then it can pass the value into the function
// instead of the address of the value. This can cause recursive simplification
// of code and lead to the elimination of allocas (especially in C++ template
// code like the STL).
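//
// As an illustrative sketch (not taken from an actual test), a callee such as
//
//   define internal i32 @callee(ptr %p) {
//     %v = load i32, ptr %p
//     ret i32 %v
//   }
//   ...
//   %r = call i32 @callee(ptr %a)
//
// can be rewritten so that the load happens in the caller instead:
//
//   define internal i32 @callee(i32 %p.val) {
//     ret i32 %p.val
//   }
//   ...
//   %a.val = load i32, ptr %a
//   %r = call i32 @callee(i32 %a.val)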
//
// This pass also handles aggregate arguments that are passed into a function,
// scalarizing them if the elements of the aggregate are only loaded. Note that
// by default it refuses to scalarize aggregates which would require passing in
// more than three operands to the function, because passing thousands of
// operands for a large array or structure is unprofitable! This limit can be
// configured or disabled, however.
//
// Note that this transformation could also be done for arguments that are only
// stored to (returning the value instead), but this is not currently
// implemented. This case would be best handled when and if LLVM begins
// supporting multiple return values from functions.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/ArgumentPromotion.h"

#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/PromoteMemToReg.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using namespace llvm;

#define DEBUG_TYPE "argpromotion"

STATISTIC(NumArgumentsPromoted, "Number of pointer arguments promoted");
STATISTIC(NumArgumentsDead, "Number of dead pointer args eliminated");

namespace {

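/// Information about a single promotable part of an argument: the type that is
/// accessed there and the alignment to use when accessing it. The byte offset
/// of the part is tracked separately in OffsetAndArgPart.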
struct ArgPart {
  Type *Ty;
  Align Alignment;
  /// A representative guaranteed-executed load or store instruction for use by
  /// metadata transfer.
  Instruction *MustExecInstr;
};

using OffsetAndArgPart = std::pair<int64_t, ArgPart>;

} // end anonymous namespace

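/// Return a pointer that is \p Offset bytes past \p Ptr, emitting a byte-wise
/// GEP (ptradd) only when the offset is non-zero.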
static Value *createByteGEP(IRBuilderBase &IRB, const DataLayout &DL,
                            Value *Ptr, Type *ResElemTy, int64_t Offset) {
  if (Offset != 0) {
    APInt APOffset(DL.getIndexTypeSizeInBits(Ptr->getType()), Offset);
    Ptr = IRB.CreatePtrAdd(Ptr, IRB.getInt(APOffset));
  }
  return Ptr;
}

/// DoPromotion - This method actually performs the promotion of the specified
/// arguments, and returns the new function. At this point, we know that it's
/// safe to do so.
static Function *
doPromotion(Function *F, FunctionAnalysisManager &FAM,
            const DenseMap<Argument *, SmallVector<OffsetAndArgPart, 4>>
                &ArgsToPromote) {
  // Start by computing a new prototype for the function, which is the same as
  // the old function, but has modified arguments.
  FunctionType *FTy = F->getFunctionType();
  std::vector<Type *> Params;

  // Attribute - Keep track of the parameter attributes for the arguments
  // that we are *not* promoting. For the ones that we do promote, the parameter
  // attributes are lost
  SmallVector<AttributeSet, 8> ArgAttrVec;
  // Mapping from old to new argument indices. -1 for promoted or removed
  // arguments.
  SmallVector<unsigned> NewArgIndices;
  AttributeList PAL = F->getAttributes();

  // First, determine the new argument list
  unsigned ArgNo = 0, NewArgNo = 0;
  for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
       ++I, ++ArgNo) {
    if (!ArgsToPromote.count(&*I)) {
      // Unchanged argument
      Params.push_back(I->getType());
      ArgAttrVec.push_back(PAL.getParamAttrs(ArgNo));
      NewArgIndices.push_back(NewArgNo++);
    } else if (I->use_empty()) {
      // Dead argument (dead arguments are always marked as promotable)
      ++NumArgumentsDead;
      NewArgIndices.push_back((unsigned)-1);
    } else {
      const auto &ArgParts = ArgsToPromote.find(&*I)->second;
      for (const auto &Pair : ArgParts) {
        Params.push_back(Pair.second.Ty);
        ArgAttrVec.push_back(AttributeSet());
      }
      ++NumArgumentsPromoted;
      NewArgIndices.push_back((unsigned)-1);
      NewArgNo += ArgParts.size();
    }
  }

  Type *RetTy = FTy->getReturnType();

  // Construct the new function type using the new arguments.
  FunctionType *NFTy = FunctionType::get(RetTy, Params, FTy->isVarArg());

  // Create the new function body and insert it into the module.
  Function *NF = Function::Create(NFTy, F->getLinkage(), F->getAddressSpace(),
                                  F->getName());
  NF->copyAttributesFrom(F);
  NF->copyMetadata(F, 0);
  NF->setIsNewDbgInfoFormat(F->IsNewDbgInfoFormat);

  // The new function will have the !dbg metadata copied from the original
  // function. The original function may not be deleted, and dbg metadata need
  // to be unique, so we need to drop it.
  F->setSubprogram(nullptr);

  LLVM_DEBUG(dbgs() << "ARG PROMOTION: Promoting to:" << *NF << "\n"
                    << "From: " << *F);

  uint64_t LargestVectorWidth = 0;
  for (auto *I : Params)
    if (auto *VT = dyn_cast<llvm::VectorType>(I))
      LargestVectorWidth = std::max(
          LargestVectorWidth, VT->getPrimitiveSizeInBits().getKnownMinValue());

  // Recompute the parameter attributes list based on the new arguments for
  // the function.
  NF->setAttributes(AttributeList::get(F->getContext(), PAL.getFnAttrs(),
                                       PAL.getRetAttrs(), ArgAttrVec));

  // Remap argument indices in allocsize attribute.
  if (auto AllocSize = NF->getAttributes().getFnAttrs().getAllocSizeArgs()) {
    unsigned Arg1 = NewArgIndices[AllocSize->first];
    assert(Arg1 != (unsigned)-1 && "allocsize cannot be promoted argument");
    std::optional<unsigned> Arg2;
    if (AllocSize->second) {
      Arg2 = NewArgIndices[*AllocSize->second];
      assert(Arg2 != (unsigned)-1 && "allocsize cannot be promoted argument");
    }
    NF->addFnAttr(Attribute::getWithAllocSizeArgs(F->getContext(), Arg1, Arg2));
  }

  AttributeFuncs::updateMinLegalVectorWidthAttr(*NF, LargestVectorWidth);
  ArgAttrVec.clear();

  F->getParent()->getFunctionList().insert(F->getIterator(), NF);
  NF->takeName(F);

  // Loop over all the callers of the function, transforming the call sites to
  // pass in the loaded pointers.
  SmallVector<Value *, 16> Args;
  const DataLayout &DL = F->getDataLayout();
  SmallVector<WeakTrackingVH, 16> DeadArgs;

  while (!F->use_empty()) {
    CallBase &CB = cast<CallBase>(*F->user_back());
    assert(CB.getCalledFunction() == F);
    const AttributeList &CallPAL = CB.getAttributes();
    IRBuilder<NoFolder> IRB(&CB);

    // Loop over the operands, inserting GEP and loads in the caller as
    // appropriate.
    auto *AI = CB.arg_begin();
    ArgNo = 0;
    for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
         ++I, ++AI, ++ArgNo) {
      if (!ArgsToPromote.count(&*I)) {
        Args.push_back(*AI); // Unmodified argument
        ArgAttrVec.push_back(CallPAL.getParamAttrs(ArgNo));
      } else if (!I->use_empty()) {
        Value *V = *AI;
        const auto &ArgParts = ArgsToPromote.find(&*I)->second;
        for (const auto &Pair : ArgParts) {
          LoadInst *LI = IRB.CreateAlignedLoad(
              Pair.second.Ty,
              createByteGEP(IRB, DL, V, Pair.second.Ty, Pair.first),
              Pair.second.Alignment, V->getName() + ".val");
          if (Pair.second.MustExecInstr) {
            LI->setAAMetadata(Pair.second.MustExecInstr->getAAMetadata());
            LI->copyMetadata(*Pair.second.MustExecInstr,
                             {LLVMContext::MD_dereferenceable,
                              LLVMContext::MD_dereferenceable_or_null,
                              LLVMContext::MD_noundef,
                              LLVMContext::MD_nontemporal});
            // Only transfer poison-generating metadata if we also have
            // !noundef.
            // TODO: Without !noundef, we could merge this metadata across
            // all promoted loads.
            if (LI->hasMetadata(LLVMContext::MD_noundef))
              LI->copyMetadata(*Pair.second.MustExecInstr,
                               {LLVMContext::MD_range, LLVMContext::MD_nonnull,
                                LLVMContext::MD_align});
          }
          Args.push_back(LI);
          ArgAttrVec.push_back(AttributeSet());
        }
      } else {
        assert(ArgsToPromote.count(&*I) && I->use_empty());
        DeadArgs.emplace_back(AI->get());
      }
    }

    // Push any varargs arguments on the list.
    for (; AI != CB.arg_end(); ++AI, ++ArgNo) {
      Args.push_back(*AI);
      ArgAttrVec.push_back(CallPAL.getParamAttrs(ArgNo));
    }

    SmallVector<OperandBundleDef, 1> OpBundles;
    CB.getOperandBundlesAsDefs(OpBundles);

    CallBase *NewCS = nullptr;
    if (InvokeInst *II = dyn_cast<InvokeInst>(&CB)) {
      NewCS = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(),
                                 Args, OpBundles, "", CB.getIterator());
    } else {
      auto *NewCall =
          CallInst::Create(NF, Args, OpBundles, "", CB.getIterator());
      NewCall->setTailCallKind(cast<CallInst>(&CB)->getTailCallKind());
      NewCS = NewCall;
    }
    NewCS->setCallingConv(CB.getCallingConv());
    NewCS->setAttributes(AttributeList::get(F->getContext(),
                                            CallPAL.getFnAttrs(),
                                            CallPAL.getRetAttrs(), ArgAttrVec));
    NewCS->copyMetadata(CB, {LLVMContext::MD_prof, LLVMContext::MD_dbg});
    Args.clear();
    ArgAttrVec.clear();

    AttributeFuncs::updateMinLegalVectorWidthAttr(*CB.getCaller(),
                                                  LargestVectorWidth);

    if (!CB.use_empty()) {
      CB.replaceAllUsesWith(NewCS);
      NewCS->takeName(&CB);
    }

    // Finally, remove the old call from the program, reducing the use-count of
    // F.
    CB.eraseFromParent();
  }

  RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadArgs);

  // Since we have now created the new function, splice the body of the old
  // function right into the new function, leaving the old rotting hulk of the
  // function empty.
  NF->splice(NF->begin(), F);

  // We will collect all the newly created allocas so we can promote them into
  // registers after the following loop.
  SmallVector<AllocaInst *, 4> Allocas;

  // Loop over the argument list, transferring uses of the old arguments to the
  // new arguments, and transferring the argument names over as well.
  Function::arg_iterator I2 = NF->arg_begin();
  for (Argument &Arg : F->args()) {
    if (!ArgsToPromote.count(&Arg)) {
      // If this is an unmodified argument, move the name and users over to the
      // new version.
      Arg.replaceAllUsesWith(&*I2);
      I2->takeName(&Arg);
      ++I2;
      continue;
    }

    // There potentially are metadata uses for things like llvm.dbg.value.
    // Replace them with undef, after handling the other regular uses.
    auto RauwUndefMetadata = make_scope_exit(
        [&]() { Arg.replaceAllUsesWith(UndefValue::get(Arg.getType())); });

    if (Arg.use_empty())
      continue;

    // Otherwise, if we promoted this argument, we have to create an alloca in
    // the callee for every promotable part and store each of the new incoming
    // arguments into the corresponding alloca. This gives the old code (in
    // particular any store instructions, where they are allowed) a chance to
    // keep working as before.
    assert(Arg.getType()->isPointerTy() &&
           "Only arguments with a pointer type are promotable");

    IRBuilder<NoFolder> IRB(&NF->begin()->front());

    // Create an alloca only for the promoted parts, i.e. those recorded in
    // ArgsToPromote.
    SmallDenseMap<int64_t, AllocaInst *> OffsetToAlloca;
    for (const auto &Pair : ArgsToPromote.find(&Arg)->second) {
      int64_t Offset = Pair.first;
      const ArgPart &Part = Pair.second;

      Argument *NewArg = I2++;
      NewArg->setName(Arg.getName() + "." + Twine(Offset) + ".val");

      AllocaInst *NewAlloca = IRB.CreateAlloca(
          Part.Ty, nullptr, Arg.getName() + "." + Twine(Offset) + ".allc");
      NewAlloca->setAlignment(Pair.second.Alignment);
      IRB.CreateAlignedStore(NewArg, NewAlloca, Pair.second.Alignment);

      // Collect the alloca to retarget the users to
      OffsetToAlloca.insert({Offset, NewAlloca});
    }

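    // Map a pointer that is at a constant offset from the original argument to
    // the alloca created above for that offset.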
    auto GetAlloca = [&](Value *Ptr) {
      APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0);
      Ptr = Ptr->stripAndAccumulateConstantOffsets(DL, Offset,
                                                   /* AllowNonInbounds */ true);
      assert(Ptr == &Arg && "Not constant offset from arg?");
      return OffsetToAlloca.lookup(Offset.getSExtValue());
    };

    // Clean up dead instructions (GEPs and BitCasts) between the original
    // argument and its end users (loads and stores), and retarget every user
    // to the newly created alloca.
    SmallVector<Value *, 16> Worklist;
    SmallVector<Instruction *, 16> DeadInsts;
    append_range(Worklist, Arg.users());
    while (!Worklist.empty()) {
      Value *V = Worklist.pop_back_val();
      if (isa<BitCastInst>(V) || isa<GetElementPtrInst>(V)) {
        DeadInsts.push_back(cast<Instruction>(V));
        append_range(Worklist, V->users());
        continue;
      }

      if (auto *LI = dyn_cast<LoadInst>(V)) {
        Value *Ptr = LI->getPointerOperand();
        LI->setOperand(LoadInst::getPointerOperandIndex(), GetAlloca(Ptr));
        continue;
      }

      if (auto *SI = dyn_cast<StoreInst>(V)) {
        assert(!SI->isVolatile() && "Volatile operations can't be promoted.");
        Value *Ptr = SI->getPointerOperand();
        SI->setOperand(StoreInst::getPointerOperandIndex(), GetAlloca(Ptr));
        continue;
      }

      llvm_unreachable("Unexpected user");
    }

    for (Instruction *I : DeadInsts) {
      I->replaceAllUsesWith(PoisonValue::get(I->getType()));
      I->eraseFromParent();
    }

    // Collect the allocas for promotion
    for (const auto &Pair : OffsetToAlloca) {
      assert(isAllocaPromotable(Pair.second) &&
             "By design, only promotable allocas should be produced.");
      Allocas.push_back(Pair.second);
    }
  }

  LLVM_DEBUG(dbgs() << "ARG PROMOTION: " << Allocas.size()
                    << " alloca(s) are promotable by Mem2Reg\n");

  if (!Allocas.empty()) {
    // And we are able to call the `promoteMemoryToRegister()` function.
    // Our earlier checks have ensured that PromoteMemToReg() will
    // succeed.
    auto &DT = FAM.getResult<DominatorTreeAnalysis>(*NF);
    auto &AC = FAM.getResult<AssumptionAnalysis>(*NF);
    PromoteMemToReg(Allocas, DT, &AC);
  }

  return NF;
}

/// Return true if we can prove that all callers pass in a valid pointer for
/// the specified function argument.
static bool allCallersPassValidPointerForArgument(
    Argument *Arg, SmallPtrSetImpl<CallBase *> &RecursiveCalls,
    Align NeededAlign, uint64_t NeededDerefBytes) {
  Function *Callee = Arg->getParent();
  const DataLayout &DL = Callee->getDataLayout();
  APInt Bytes(64, NeededDerefBytes);

  // Check if the argument itself is marked dereferenceable and aligned.
  if (isDereferenceableAndAlignedPointer(Arg, NeededAlign, Bytes, DL))
    return true;

  // Look at all call sites of the function. At this point we know we only have
  // direct callees.
  return all_of(Callee->users(), [&](User *U) {
    CallBase &CB = cast<CallBase>(*U);
    // In case of functions with recursive calls, this check
    // (isDereferenceableAndAlignedPointer) will fail when it tries to look at
    // the first caller of this function. The caller may or may not have a
    // load; in case it doesn't load the pointer being passed, this check will
    // fail. So, it's safe to skip the check in case we know that we are
    // dealing with a recursive call. For example, consider the IR below.
    //
    // def fun(ptr %a) {
    //   ...
    //   %loadres = load i32, ptr %a, align 4
    //   %res = call i32 @fun(ptr %a)
    //   ...
    // }
    //
    // def bar(ptr %x) {
    //   ...
    //   %resbar = call i32 @fun(ptr %x)
    //   ...
    // }
    //
    // Since we record processed recursive calls, we check if the current
    // CallBase has been processed before. If yes it means that it is a
    // recursive call and we can skip the check just for this call. So, just
    // return true.
    if (RecursiveCalls.contains(&CB))
      return true;

    return isDereferenceableAndAlignedPointer(CB.getArgOperand(Arg->getArgNo()),
                                              NeededAlign, Bytes, DL);
  });
}

/// Determine that this argument is safe to promote, and find the argument
/// parts it can be promoted into.
static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
                         unsigned MaxElements, bool IsRecursive,
                         SmallVectorImpl<OffsetAndArgPart> &ArgPartsVec) {
  // Quick exit for unused arguments
  if (Arg->use_empty())
    return true;

  // We can only promote this argument if all the uses are loads at known
  // offsets.
  //
  // Promoting the argument causes it to be loaded in the caller
  // unconditionally. This is only safe if we can prove that either the load
  // would have happened in the callee anyway (ie, there is a load in the entry
  // block) or the pointer passed in at every call site is guaranteed to be
  // valid.
  // In the former case, invalid loads can happen, but would have happened
  // anyway, in the latter case, invalid loads won't happen. This prevents us
  // from introducing an invalid load that wouldn't have happened in the
  // original code.

  SmallDenseMap<int64_t, ArgPart, 4> ArgParts;
  Align NeededAlign(1);
  uint64_t NeededDerefBytes = 0;

  // If this is a byval argument, we also allow store instructions. We only
  // handle arguments with a specified alignment this way; if the alignment is
  // unspecified, the actual alignment of the argument is target-specific.
  bool AreStoresAllowed = Arg->getParamByValType() && Arg->getParamAlign();

  // An end user of a pointer argument is a load or store instruction.
  // Returns std::nullopt if this load or store is not based on the argument.
  // Returns true if we can promote the instruction, false otherwise.
  auto HandleEndUser = [&](auto *I, Type *Ty,
                           bool GuaranteedToExecute) -> std::optional<bool> {
    // Don't promote volatile or atomic instructions.
    if (!I->isSimple())
      return false;

    Value *Ptr = I->getPointerOperand();
    APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0);
    Ptr = Ptr->stripAndAccumulateConstantOffsets(DL, Offset,
                                                 /* AllowNonInbounds */ true);
    if (Ptr != Arg)
      return std::nullopt;

    if (Offset.getSignificantBits() >= 64)
      return false;

    TypeSize Size = DL.getTypeStoreSize(Ty);
    // Don't try to promote scalable types.
    if (Size.isScalable())
      return false;

    // If this is a recursive function and one of the types is a pointer,
    // then promoting it might lead to recursive promotion.
    if (IsRecursive && Ty->isPointerTy())
      return false;

    int64_t Off = Offset.getSExtValue();
    auto Pair = ArgParts.try_emplace(
        Off, ArgPart{Ty, I->getAlign(), GuaranteedToExecute ? I : nullptr});
    ArgPart &Part = Pair.first->second;
    bool OffsetNotSeenBefore = Pair.second;

    // We limit promotion to only promoting up to a fixed number of elements of
    // the aggregate.
    if (MaxElements > 0 && ArgParts.size() > MaxElements) {
      LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                        << "more than " << MaxElements << " parts\n");
      return false;
    }

    // For now, we only support loading/storing one specific type at a given
    // offset.
    if (Part.Ty != Ty) {
      LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                        << "accessed as both " << *Part.Ty << " and " << *Ty
                        << " at offset " << Off << "\n");
      return false;
    }

    // If this instruction is not guaranteed to execute, and we haven't seen a
    // load or store at this offset before (or it had lower alignment), then we
    // need to remember that requirement.
    // Note that skipping instructions of previously seen offsets is only
    // correct because we only allow a single type for a given offset, which
    // also means that the number of accessed bytes will be the same.
    if (!GuaranteedToExecute &&
        (OffsetNotSeenBefore || Part.Alignment < I->getAlign())) {
      // We won't be able to prove dereferenceability for negative offsets.
      if (Off < 0)
        return false;

      // If the offset is not aligned, an aligned base pointer won't help.
      if (!isAligned(I->getAlign(), Off))
        return false;

      NeededDerefBytes = std::max(NeededDerefBytes, Off + Size.getFixedValue());
      NeededAlign = std::max(NeededAlign, I->getAlign());
    }

    Part.Alignment = std::max(Part.Alignment, I->getAlign());
    return true;
  };

  // Look for loads and stores that are guaranteed to execute on entry.
  for (Instruction &I : Arg->getParent()->getEntryBlock()) {
    std::optional<bool> Res{};
    if (LoadInst *LI = dyn_cast<LoadInst>(&I))
      Res = HandleEndUser(LI, LI->getType(), /* GuaranteedToExecute */ true);
    else if (StoreInst *SI = dyn_cast<StoreInst>(&I))
      Res = HandleEndUser(SI, SI->getValueOperand()->getType(),
                          /* GuaranteedToExecute */ true);
    if (Res && !*Res)
      return false;

    if (!isGuaranteedToTransferExecutionToSuccessor(&I))
      break;
  }

  // Now look at all loads of the argument. Remember the load instructions
  // for the aliasing check below.
  SmallVector<const Use *, 16> Worklist;
  SmallPtrSet<const Use *, 16> Visited;
  SmallVector<LoadInst *, 16> Loads;
  SmallPtrSet<CallBase *, 4> RecursiveCalls;
  auto AppendUses = [&](const Value *V) {
    for (const Use &U : V->uses())
      if (Visited.insert(&U).second)
        Worklist.push_back(&U);
  };
  AppendUses(Arg);
  while (!Worklist.empty()) {
    const Use *U = Worklist.pop_back_val();
    Value *V = U->getUser();
    if (isa<BitCastInst>(V)) {
      AppendUses(V);
      continue;
    }

    if (auto *GEP = dyn_cast<GetElementPtrInst>(V)) {
      if (!GEP->hasAllConstantIndices())
        return false;
      AppendUses(V);
      continue;
    }

    if (auto *LI = dyn_cast<LoadInst>(V)) {
      if (!*HandleEndUser(LI, LI->getType(), /* GuaranteedToExecute */ false))
        return false;
      Loads.push_back(LI);
      continue;
    }

    // Stores are allowed for byval arguments.
    auto *SI = dyn_cast<StoreInst>(V);
    if (AreStoresAllowed && SI &&
        U->getOperandNo() == StoreInst::getPointerOperandIndex()) {
      if (!*HandleEndUser(SI, SI->getValueOperand()->getType(),
                          /* GuaranteedToExecute */ false))
        return false;
      continue;
      // Only stores TO the argument are allowed; all other stores are
      // unknown users.
    }

    auto *CB = dyn_cast<CallBase>(V);
    Value *PtrArg = U->get();
    if (CB && CB->getCalledFunction() == CB->getFunction()) {
      if (PtrArg != Arg) {
        LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                          << "pointer offset is not equal to zero\n");
        return false;
      }

      unsigned int ArgNo = Arg->getArgNo();
      if (U->getOperandNo() != ArgNo) {
        LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                          << "arg position is different in callee\n");
        return false;
      }

      // We limit promotion to only promoting up to a fixed number of elements
      // of the aggregate.
      if (MaxElements > 0 && ArgParts.size() > MaxElements) {
        LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                          << "more than " << MaxElements << " parts\n");
        return false;
      }

      RecursiveCalls.insert(CB);
      continue;
    }
    // Unknown user.
    LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                      << "unknown user " << *V << "\n");
    return false;
  }

  if (NeededDerefBytes || NeededAlign > 1) {
    // Try to prove that the required dereferenceability and alignment hold at
    // every call site.
    if (!allCallersPassValidPointerForArgument(Arg, RecursiveCalls, NeededAlign,
                                               NeededDerefBytes)) {
      LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                        << "not dereferenceable or aligned\n");
      return false;
    }
  }

  if (ArgParts.empty())
    return true; // No users, this is a dead argument.

  // Sort parts by offset.
  append_range(ArgPartsVec, ArgParts);
  sort(ArgPartsVec, llvm::less_first());

  // Make sure the parts are non-overlapping.
  int64_t Offset = ArgPartsVec[0].first;
  for (const auto &Pair : ArgPartsVec) {
    if (Pair.first < Offset)
      return false; // Overlap with previous part.

    Offset = Pair.first + DL.getTypeStoreSize(Pair.second.Ty);
  }

  // If store instructions are allowed, the path from the entry of the function
  // to each load may not be free of instructions that potentially invalidate
  // the load, and this is an admissible situation.
  if (AreStoresAllowed)
    return true;

  // Okay, now we know that the argument is only used by load instructions, and
  // it is safe to unconditionally perform all of them. Use alias analysis to
  // check to see if the pointer is guaranteed to not be modified from entry of
  // the function to each of the load instructions.

  for (LoadInst *Load : Loads) {
    // Check to see if the load is invalidated from the start of the block to
    // the load itself.
    BasicBlock *BB = Load->getParent();

    MemoryLocation Loc = MemoryLocation::get(Load);
    if (AAR.canInstructionRangeModRef(BB->front(), *Load, Loc, ModRefInfo::Mod))
      return false; // Pointer is invalidated!

    // Now check every path from the entry block to the load for transparency.
    // To do this, we perform a depth first search on the inverse CFG from the
    // loading block.
    for (BasicBlock *P : predecessors(BB)) {
      for (BasicBlock *TranspBB : inverse_depth_first(P))
        if (AAR.canBasicBlockModify(*TranspBB, Loc))
          return false;
    }
  }

  // If the path from the entry of the function to each load is free of
  // instructions that potentially invalidate the load, we can make the
  // transformation!
  return true;
}

/// Check if callers and callee agree on how promoted arguments would be
/// passed.
static bool areTypesABICompatible(ArrayRef<Type *> Types, const Function &F,
                                  const TargetTransformInfo &TTI) {
  return all_of(F.uses(), [&](const Use &U) {
    CallBase *CB = dyn_cast<CallBase>(U.getUser());
    if (!CB)
      return false;

    const Function *Caller = CB->getCaller();
    const Function *Callee = CB->getCalledFunction();
    return TTI.areTypesABICompatible(Caller, Callee, Types);
  });
}

/// PromoteArguments - This method checks the specified function to see if
/// there are any promotable arguments and if it is safe to promote the
/// function (for example, all callers are direct). If it is safe to promote
/// some arguments, it calls doPromotion to perform the promotion.
static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
                                  unsigned MaxElements, bool IsRecursive) {
  // Don't perform argument promotion for naked functions; otherwise we can end
  // up removing parameters that are seemingly 'not used' as they are referred
  // to in the assembly.
  if (F->hasFnAttribute(Attribute::Naked))
    return nullptr;

  // Make sure that it is local to this module.
  if (!F->hasLocalLinkage())
    return nullptr;

  // Don't promote arguments for variadic functions. Adding, removing, or
  // changing non-pack parameters can change the classification of pack
  // parameters. Frontends encode that classification at the call site in the
  // IR, while in the callee the classification is determined dynamically based
  // on the number of registers consumed so far.
  if (F->isVarArg())
    return nullptr;

  // Don't transform functions that receive inallocas, as the transformation may
  // not be safe depending on calling convention.
  if (F->getAttributes().hasAttrSomewhere(Attribute::InAlloca))
    return nullptr;

  // First check: see if there are any pointer arguments! If not, quick exit.
  SmallVector<Argument *, 16> PointerArgs;
  for (Argument &I : F->args())
    if (I.getType()->isPointerTy())
      PointerArgs.push_back(&I);
  if (PointerArgs.empty())
    return nullptr;

  // Second check: make sure that all callers are direct callers. We can't
  // transform functions that have indirect callers. Also see if the function
  // is self-recursive.
  for (Use &U : F->uses()) {
    CallBase *CB = dyn_cast<CallBase>(U.getUser());
    // Must be a direct call.
    if (CB == nullptr || !CB->isCallee(&U) ||
        CB->getFunctionType() != F->getFunctionType())
      return nullptr;

    // Can't change signature of musttail callee
    if (CB->isMustTailCall())
      return nullptr;

    if (CB->getFunction() == F)
      IsRecursive = true;
  }

  // Can't change signature of musttail caller
  // FIXME: Support promoting whole chain of musttail functions
  for (BasicBlock &BB : *F)
    if (BB.getTerminatingMustTailCall())
      return nullptr;

  const DataLayout &DL = F->getDataLayout();
  auto &AAR = FAM.getResult<AAManager>(*F);
  const auto &TTI = FAM.getResult<TargetIRAnalysis>(*F);

  // Check to see which arguments are promotable. If an argument is promotable,
  // add it to ArgsToPromote.
  DenseMap<Argument *, SmallVector<OffsetAndArgPart, 4>> ArgsToPromote;
  unsigned NumArgsAfterPromote = F->getFunctionType()->getNumParams();
  for (Argument *PtrArg : PointerArgs) {
    // Replace sret attribute with noalias. This reduces register pressure by
    // avoiding a register copy.
    if (PtrArg->hasStructRetAttr()) {
      unsigned ArgNo = PtrArg->getArgNo();
      F->removeParamAttr(ArgNo, Attribute::StructRet);
      F->addParamAttr(ArgNo, Attribute::NoAlias);
      for (Use &U : F->uses()) {
        CallBase &CB = cast<CallBase>(*U.getUser());
        CB.removeParamAttr(ArgNo, Attribute::StructRet);
        CB.addParamAttr(ArgNo, Attribute::NoAlias);
      }
    }

    // If we can promote the pointer to its value.
    SmallVector<OffsetAndArgPart, 4> ArgParts;

    if (findArgParts(PtrArg, DL, AAR, MaxElements, IsRecursive, ArgParts)) {
      SmallVector<Type *, 4> Types;
      for (const auto &Pair : ArgParts)
        Types.push_back(Pair.second.Ty);

      if (areTypesABICompatible(Types, *F, TTI)) {
        NumArgsAfterPromote += ArgParts.size() - 1;
        ArgsToPromote.insert({PtrArg, std::move(ArgParts)});
      }
    }
  }

  // No promotable pointer arguments.
  if (ArgsToPromote.empty())
    return nullptr;

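  // Don't promote if the resulting function would exceed the maximum number of
  // arguments the target can handle.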
  if (NumArgsAfterPromote > TTI.getMaxNumArgs())
    return nullptr;

  return doPromotion(F, FAM, ArgsToPromote);
}

PreservedAnalyses ArgumentPromotionPass::run(LazyCallGraph::SCC &C,
                                             CGSCCAnalysisManager &AM,
                                             LazyCallGraph &CG,
                                             CGSCCUpdateResult &UR) {
  bool Changed = false, LocalChange;

  // Iterate until we stop promoting from this SCC.
  do {
    LocalChange = false;

    FunctionAnalysisManager &FAM =
        AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();

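    // An SCC with more than one node contains mutually recursive functions;
    // direct self-recursion within a single function is detected separately in
    // promoteArguments().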
    bool IsRecursive = C.size() > 1;
    for (LazyCallGraph::Node &N : C) {
      Function &OldF = N.getFunction();
      Function *NewF = promoteArguments(&OldF, FAM, MaxElements, IsRecursive);
      if (!NewF)
        continue;
      LocalChange = true;

      // Directly substitute the functions in the call graph. Note that this
      // requires the old function to be completely dead and completely
      // replaced by the new function. It does no call graph updates, it merely
      // swaps out the particular function mapped to a particular node in the
      // graph.
      C.getOuterRefSCC().replaceNodeFunction(N, *NewF);
      FAM.clear(OldF, OldF.getName());
      OldF.eraseFromParent();

      PreservedAnalyses FuncPA;
      FuncPA.preserveSet<CFGAnalyses>();
      for (auto *U : NewF->users()) {
        auto *UserF = cast<CallBase>(U)->getFunction();
        FAM.invalidate(*UserF, FuncPA);
      }
    }

    Changed |= LocalChange;
  } while (LocalChange);

  if (!Changed)
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  // We've cleared out analyses for deleted functions.
  PA.preserve<FunctionAnalysisManagerCGSCCProxy>();
  // We've manually invalidated analyses for functions we've modified.
  PA.preserveSet<AllAnalysesOn<Function>>();
  return PA;
}