xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- CallPromotionUtils.cpp - Utilities for call promotion ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities useful for promoting indirect call sites to
10 // direct call sites.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Transforms/Utils/CallPromotionUtils.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/Analysis/Loads.h"
17 #include "llvm/Analysis/TypeMetadataUtils.h"
18 #include "llvm/IR/AttributeMask.h"
19 #include "llvm/IR/Constant.h"
20 #include "llvm/IR/IRBuilder.h"
21 #include "llvm/IR/Instructions.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
24 
25 using namespace llvm;
26 
27 #define DEBUG_TYPE "call-promotion-utils"
28 
29 /// Fix-up phi nodes in an invoke instruction's normal destination.
30 ///
31 /// After versioning an invoke instruction, values coming from the original
32 /// block will now be coming from the "merge" block. For example, in the code
33 /// below:
34 ///
35 ///   then_bb:
36 ///     %t0 = invoke i32 %ptr() to label %merge_bb unwind label %unwind_dst
37 ///
38 ///   else_bb:
39 ///     %t1 = invoke i32 %ptr() to label %merge_bb unwind label %unwind_dst
40 ///
41 ///   merge_bb:
42 ///     %t2 = phi i32 [ %t0, %then_bb ], [ %t1, %else_bb ]
43 ///     br %normal_dst
44 ///
45 ///   normal_dst:
46 ///     %t3 = phi i32 [ %x, %orig_bb ], ...
47 ///
48 /// "orig_bb" is no longer a predecessor of "normal_dst", so the phi nodes in
49 /// "normal_dst" must be fixed to refer to "merge_bb":
50 ///
51 ///    normal_dst:
52 ///      %t3 = phi i32 [ %x, %merge_bb ], ...
53 ///
fixupPHINodeForNormalDest(InvokeInst * Invoke,BasicBlock * OrigBlock,BasicBlock * MergeBlock)54 static void fixupPHINodeForNormalDest(InvokeInst *Invoke, BasicBlock *OrigBlock,
55                                       BasicBlock *MergeBlock) {
56   for (PHINode &Phi : Invoke->getNormalDest()->phis()) {
57     int Idx = Phi.getBasicBlockIndex(OrigBlock);
58     if (Idx == -1)
59       continue;
60     Phi.setIncomingBlock(Idx, MergeBlock);
61   }
62 }
63 
64 /// Fix-up phi nodes in an invoke instruction's unwind destination.
65 ///
66 /// After versioning an invoke instruction, values coming from the original
67 /// block will now be coming from either the "then" block or the "else" block.
68 /// For example, in the code below:
69 ///
70 ///   then_bb:
71 ///     %t0 = invoke i32 %ptr() to label %merge_bb unwind label %unwind_dst
72 ///
73 ///   else_bb:
74 ///     %t1 = invoke i32 %ptr() to label %merge_bb unwind label %unwind_dst
75 ///
76 ///   unwind_dst:
77 ///     %t3 = phi i32 [ %x, %orig_bb ], ...
78 ///
79 /// "orig_bb" is no longer a predecessor of "unwind_dst", so the phi nodes in
80 /// "unwind_dst" must be fixed to refer to "then_bb" and "else_bb":
81 ///
82 ///   unwind_dst:
83 ///     %t3 = phi i32 [ %x, %then_bb ], [ %x, %else_bb ], ...
84 ///
fixupPHINodeForUnwindDest(InvokeInst * Invoke,BasicBlock * OrigBlock,BasicBlock * ThenBlock,BasicBlock * ElseBlock)85 static void fixupPHINodeForUnwindDest(InvokeInst *Invoke, BasicBlock *OrigBlock,
86                                       BasicBlock *ThenBlock,
87                                       BasicBlock *ElseBlock) {
88   for (PHINode &Phi : Invoke->getUnwindDest()->phis()) {
89     int Idx = Phi.getBasicBlockIndex(OrigBlock);
90     if (Idx == -1)
91       continue;
92     auto *V = Phi.getIncomingValue(Idx);
93     Phi.setIncomingBlock(Idx, ThenBlock);
94     Phi.addIncoming(V, ElseBlock);
95   }
96 }
97 
98 /// Create a phi node for the returned value of a call or invoke instruction.
99 ///
100 /// After versioning a call or invoke instruction that returns a value, we have
101 /// to merge the value of the original and new instructions. We do this by
102 /// creating a phi node and replacing uses of the original instruction with this
103 /// phi node.
104 ///
105 /// For example, if \p OrigInst is defined in "else_bb" and \p NewInst is
106 /// defined in "then_bb", we create the following phi node:
107 ///
108 ///   ; Uses of the original instruction are replaced by uses of the phi node.
109 ///   %t0 = phi i32 [ %orig_inst, %else_bb ], [ %new_inst, %then_bb ],
110 ///
createRetPHINode(Instruction * OrigInst,Instruction * NewInst,BasicBlock * MergeBlock,IRBuilder<> & Builder)111 static void createRetPHINode(Instruction *OrigInst, Instruction *NewInst,
112                              BasicBlock *MergeBlock, IRBuilder<> &Builder) {
113 
114   if (OrigInst->getType()->isVoidTy() || OrigInst->use_empty())
115     return;
116 
117   Builder.SetInsertPoint(MergeBlock, MergeBlock->begin());
118   PHINode *Phi = Builder.CreatePHI(OrigInst->getType(), 0);
119   SmallVector<User *, 16> UsersToUpdate(OrigInst->users());
120   for (User *U : UsersToUpdate)
121     U->replaceUsesOfWith(OrigInst, Phi);
122   Phi->addIncoming(OrigInst, OrigInst->getParent());
123   Phi->addIncoming(NewInst, NewInst->getParent());
124 }
125 
126 /// Cast a call or invoke instruction to the given type.
127 ///
128 /// When promoting a call site, the return type of the call site might not match
129 /// that of the callee. If this is the case, we have to cast the returned value
130 /// to the correct type. The location of the cast depends on if we have a call
131 /// or invoke instruction.
132 ///
133 /// For example, if the call instruction below requires a bitcast after
134 /// promotion:
135 ///
136 ///   orig_bb:
137 ///     %t0 = call i32 @func()
138 ///     ...
139 ///
140 /// The bitcast is placed after the call instruction:
141 ///
142 ///   orig_bb:
143 ///     ; Uses of the original return value are replaced by uses of the bitcast.
144 ///     %t0 = call i32 @func()
145 ///     %t1 = bitcast i32 %t0 to ...
146 ///     ...
147 ///
148 /// A similar transformation is performed for invoke instructions. However,
149 /// since invokes are terminating, a new block is created for the bitcast. For
150 /// example, if the invoke instruction below requires a bitcast after promotion:
151 ///
152 ///   orig_bb:
153 ///     %t0 = invoke i32 @func() to label %normal_dst unwind label %unwind_dst
154 ///
155 /// The edge between the original block and the invoke's normal destination is
156 /// split, and the bitcast is placed there:
157 ///
158 ///   orig_bb:
159 ///     %t0 = invoke i32 @func() to label %split_bb unwind label %unwind_dst
160 ///
161 ///   split_bb:
162 ///     ; Uses of the original return value are replaced by uses of the bitcast.
163 ///     %t1 = bitcast i32 %t0 to ...
164 ///     br label %normal_dst
165 ///
createRetBitCast(CallBase & CB,Type * RetTy,CastInst ** RetBitCast)166 static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) {
167 
168   // Save the users of the calling instruction. These uses will be changed to
169   // use the bitcast after we create it.
170   SmallVector<User *, 16> UsersToUpdate(CB.users());
171 
172   // Determine an appropriate location to create the bitcast for the return
173   // value. The location depends on if we have a call or invoke instruction.
174   BasicBlock::iterator InsertBefore;
175   if (auto *Invoke = dyn_cast<InvokeInst>(&CB))
176     InsertBefore =
177         SplitEdge(Invoke->getParent(), Invoke->getNormalDest())->begin();
178   else
179     InsertBefore = std::next(CB.getIterator());
180 
181   // Bitcast the return value to the correct type.
182   auto *Cast = CastInst::CreateBitOrPointerCast(&CB, RetTy, "", InsertBefore);
183   if (RetBitCast)
184     *RetBitCast = Cast;
185 
186   // Replace all the original uses of the calling instruction with the bitcast.
187   for (User *U : UsersToUpdate)
188     U->replaceUsesOfWith(&CB, Cast);
189 }
190 
191 /// Predicate and clone the given call site.
192 ///
193 /// This function creates an if-then-else structure at the location of the call
194 /// site. The "if" condition is specified by `Cond`.
195 /// The original call site is moved into the "else" block, and a clone of the
196 /// call site is placed in the "then" block. The cloned instruction is returned.
197 ///
198 /// For example, the call instruction below:
199 ///
200 ///   orig_bb:
201 ///     %t0 = call i32 %ptr()
202 ///     ...
203 ///
204 /// Is replace by the following:
205 ///
206 ///   orig_bb:
207 ///     %cond = Cond
208 ///     br i1 %cond, %then_bb, %else_bb
209 ///
210 ///   then_bb:
211 ///     ; The clone of the original call instruction is placed in the "then"
212 ///     ; block. It is not yet promoted.
213 ///     %t1 = call i32 %ptr()
214 ///     br merge_bb
215 ///
216 ///   else_bb:
217 ///     ; The original call instruction is moved to the "else" block.
218 ///     %t0 = call i32 %ptr()
219 ///     br merge_bb
220 ///
221 ///   merge_bb:
222 ///     ; Uses of the original call instruction are replaced by uses of the phi
223 ///     ; node.
224 ///     %t2 = phi i32 [ %t0, %else_bb ], [ %t1, %then_bb ]
225 ///     ...
226 ///
227 /// A similar transformation is performed for invoke instructions. However,
228 /// since invokes are terminating, more work is required. For example, the
229 /// invoke instruction below:
230 ///
231 ///   orig_bb:
232 ///     %t0 = invoke %ptr() to label %normal_dst unwind label %unwind_dst
233 ///
234 /// Is replace by the following:
235 ///
236 ///   orig_bb:
237 ///     %cond = Cond
238 ///     br i1 %cond, %then_bb, %else_bb
239 ///
240 ///   then_bb:
241 ///     ; The clone of the original invoke instruction is placed in the "then"
242 ///     ; block, and its normal destination is set to the "merge" block. It is
243 ///     ; not yet promoted.
244 ///     %t1 = invoke i32 %ptr() to label %merge_bb unwind label %unwind_dst
245 ///
246 ///   else_bb:
247 ///     ; The original invoke instruction is moved into the "else" block, and
248 ///     ; its normal destination is set to the "merge" block.
249 ///     %t0 = invoke i32 %ptr() to label %merge_bb unwind label %unwind_dst
250 ///
251 ///   merge_bb:
252 ///     ; Uses of the original invoke instruction are replaced by uses of the
253 ///     ; phi node, and the merge block branches to the normal destination.
254 ///     %t2 = phi i32 [ %t0, %else_bb ], [ %t1, %then_bb ]
255 ///     br %normal_dst
256 ///
257 /// An indirect musttail call is processed slightly differently in that:
258 /// 1. No merge block needed for the orginal and the cloned callsite, since
259 ///    either one ends the flow. No phi node is needed either.
260 /// 2. The return statement following the original call site is duplicated too
261 ///    and placed immediately after the cloned call site per the IR convention.
262 ///
263 /// For example, the musttail call instruction below:
264 ///
265 ///   orig_bb:
266 ///     %t0 = musttail call i32 %ptr()
267 ///     ...
268 ///
269 /// Is replaced by the following:
270 ///
271 ///   cond_bb:
272 ///     %cond = Cond
273 ///     br i1 %cond, %then_bb, %orig_bb
274 ///
275 ///   then_bb:
276 ///     ; The clone of the original call instruction is placed in the "then"
277 ///     ; block. It is not yet promoted.
278 ///     %t1 = musttail call i32 %ptr()
279 ///     ret %t1
280 ///
281 ///   orig_bb:
282 ///     ; The original call instruction stays in its original block.
283 ///     %t0 = musttail call i32 %ptr()
284 ///     ret %t0
versionCallSiteWithCond(CallBase & CB,Value * Cond,MDNode * BranchWeights)285 static CallBase &versionCallSiteWithCond(CallBase &CB, Value *Cond,
286                                          MDNode *BranchWeights) {
287 
288   IRBuilder<> Builder(&CB);
289   CallBase *OrigInst = &CB;
290   BasicBlock *OrigBlock = OrigInst->getParent();
291 
292   if (OrigInst->isMustTailCall()) {
293     // Create an if-then structure. The original instruction stays in its block,
294     // and a clone of the original instruction is placed in the "then" block.
295     Instruction *ThenTerm =
296         SplitBlockAndInsertIfThen(Cond, &CB, false, BranchWeights);
297     BasicBlock *ThenBlock = ThenTerm->getParent();
298     ThenBlock->setName("if.true.direct_targ");
299     CallBase *NewInst = cast<CallBase>(OrigInst->clone());
300     NewInst->insertBefore(ThenTerm);
301 
302     // Place a clone of the optional bitcast after the new call site.
303     Value *NewRetVal = NewInst;
304     auto Next = OrigInst->getNextNode();
305     if (auto *BitCast = dyn_cast_or_null<BitCastInst>(Next)) {
306       assert(BitCast->getOperand(0) == OrigInst &&
307              "bitcast following musttail call must use the call");
308       auto NewBitCast = BitCast->clone();
309       NewBitCast->replaceUsesOfWith(OrigInst, NewInst);
310       NewBitCast->insertBefore(ThenTerm);
311       NewRetVal = NewBitCast;
312       Next = BitCast->getNextNode();
313     }
314 
315     // Place a clone of the return instruction after the new call site.
316     ReturnInst *Ret = dyn_cast_or_null<ReturnInst>(Next);
317     assert(Ret && "musttail call must precede a ret with an optional bitcast");
318     auto NewRet = Ret->clone();
319     if (Ret->getReturnValue())
320       NewRet->replaceUsesOfWith(Ret->getReturnValue(), NewRetVal);
321     NewRet->insertBefore(ThenTerm);
322 
323     // A return instructions is terminating, so we don't need the terminator
324     // instruction just created.
325     ThenTerm->eraseFromParent();
326 
327     return *NewInst;
328   }
329 
330   // Create an if-then-else structure. The original instruction is moved into
331   // the "else" block, and a clone of the original instruction is placed in the
332   // "then" block.
333   Instruction *ThenTerm = nullptr;
334   Instruction *ElseTerm = nullptr;
335   SplitBlockAndInsertIfThenElse(Cond, &CB, &ThenTerm, &ElseTerm, BranchWeights);
336   BasicBlock *ThenBlock = ThenTerm->getParent();
337   BasicBlock *ElseBlock = ElseTerm->getParent();
338   BasicBlock *MergeBlock = OrigInst->getParent();
339 
340   ThenBlock->setName("if.true.direct_targ");
341   ElseBlock->setName("if.false.orig_indirect");
342   MergeBlock->setName("if.end.icp");
343 
344   CallBase *NewInst = cast<CallBase>(OrigInst->clone());
345   OrigInst->moveBefore(ElseTerm);
346   NewInst->insertBefore(ThenTerm);
347 
348   // If the original call site is an invoke instruction, we have extra work to
349   // do since invoke instructions are terminating. We have to fix-up phi nodes
350   // in the invoke's normal and unwind destinations.
351   if (auto *OrigInvoke = dyn_cast<InvokeInst>(OrigInst)) {
352     auto *NewInvoke = cast<InvokeInst>(NewInst);
353 
354     // Invoke instructions are terminating, so we don't need the terminator
355     // instructions that were just created.
356     ThenTerm->eraseFromParent();
357     ElseTerm->eraseFromParent();
358 
359     // Branch from the "merge" block to the original normal destination.
360     Builder.SetInsertPoint(MergeBlock);
361     Builder.CreateBr(OrigInvoke->getNormalDest());
362 
363     // Fix-up phi nodes in the original invoke's normal and unwind destinations.
364     fixupPHINodeForNormalDest(OrigInvoke, OrigBlock, MergeBlock);
365     fixupPHINodeForUnwindDest(OrigInvoke, MergeBlock, ThenBlock, ElseBlock);
366 
367     // Now set the normal destinations of the invoke instructions to be the
368     // "merge" block.
369     OrigInvoke->setNormalDest(MergeBlock);
370     NewInvoke->setNormalDest(MergeBlock);
371   }
372 
373   // Create a phi node for the returned value of the call site.
374   createRetPHINode(OrigInst, NewInst, MergeBlock, Builder);
375 
376   return *NewInst;
377 }
378 
379 // Predicate and clone the given call site using condition `CB.callee ==
380 // Callee`. See the comment `versionCallSiteWithCond` for the transformation.
versionCallSite(CallBase & CB,Value * Callee,MDNode * BranchWeights)381 CallBase &llvm::versionCallSite(CallBase &CB, Value *Callee,
382                                 MDNode *BranchWeights) {
383 
384   IRBuilder<> Builder(&CB);
385 
386   // Create the compare. The called value and callee must have the same type to
387   // be compared.
388   if (CB.getCalledOperand()->getType() != Callee->getType())
389     Callee = Builder.CreateBitCast(Callee, CB.getCalledOperand()->getType());
390   auto *Cond = Builder.CreateICmpEQ(CB.getCalledOperand(), Callee);
391 
392   return versionCallSiteWithCond(CB, Cond, BranchWeights);
393 }
394 
isLegalToPromote(const CallBase & CB,Function * Callee,const char ** FailureReason)395 bool llvm::isLegalToPromote(const CallBase &CB, Function *Callee,
396                             const char **FailureReason) {
397   assert(!CB.getCalledFunction() && "Only indirect call sites can be promoted");
398 
399   auto &DL = Callee->getDataLayout();
400 
401   // Check the return type. The callee's return value type must be bitcast
402   // compatible with the call site's type.
403   Type *CallRetTy = CB.getType();
404   Type *FuncRetTy = Callee->getReturnType();
405   if (CallRetTy != FuncRetTy)
406     if (!CastInst::isBitOrNoopPointerCastable(FuncRetTy, CallRetTy, DL)) {
407       if (FailureReason)
408         *FailureReason = "Return type mismatch";
409       return false;
410     }
411 
412   // The number of formal arguments of the callee.
413   unsigned NumParams = Callee->getFunctionType()->getNumParams();
414 
415   // The number of actual arguments in the call.
416   unsigned NumArgs = CB.arg_size();
417 
418   // Check the number of arguments. The callee and call site must agree on the
419   // number of arguments.
420   if (NumArgs != NumParams && !Callee->isVarArg()) {
421     if (FailureReason)
422       *FailureReason = "The number of arguments mismatch";
423     return false;
424   }
425 
426   // Check the argument types. The callee's formal argument types must be
427   // bitcast compatible with the corresponding actual argument types of the call
428   // site.
429   unsigned I = 0;
430   for (; I < NumParams; ++I) {
431     // Make sure that the callee and call agree on byval/inalloca. The types do
432     // not have to match.
433     if (Callee->hasParamAttribute(I, Attribute::ByVal) !=
434         CB.getAttributes().hasParamAttr(I, Attribute::ByVal)) {
435       if (FailureReason)
436         *FailureReason = "byval mismatch";
437       return false;
438     }
439     if (Callee->hasParamAttribute(I, Attribute::InAlloca) !=
440         CB.getAttributes().hasParamAttr(I, Attribute::InAlloca)) {
441       if (FailureReason)
442         *FailureReason = "inalloca mismatch";
443       return false;
444     }
445 
446     Type *FormalTy = Callee->getFunctionType()->getFunctionParamType(I);
447     Type *ActualTy = CB.getArgOperand(I)->getType();
448     if (FormalTy == ActualTy)
449       continue;
450     if (!CastInst::isBitOrNoopPointerCastable(ActualTy, FormalTy, DL)) {
451       if (FailureReason)
452         *FailureReason = "Argument type mismatch";
453       return false;
454     }
455 
456     // MustTail call needs stricter type match. See
457     // Verifier::verifyMustTailCall().
458     if (CB.isMustTailCall()) {
459       PointerType *PF = dyn_cast<PointerType>(FormalTy);
460       PointerType *PA = dyn_cast<PointerType>(ActualTy);
461       if (!PF || !PA || PF->getAddressSpace() != PA->getAddressSpace()) {
462         if (FailureReason)
463           *FailureReason = "Musttail call Argument type mismatch";
464         return false;
465       }
466     }
467   }
468   for (; I < NumArgs; I++) {
469     // Vararg functions can have more arguments than parameters.
470     assert(Callee->isVarArg());
471     if (CB.paramHasAttr(I, Attribute::StructRet)) {
472       if (FailureReason)
473         *FailureReason = "SRet arg to vararg function";
474       return false;
475     }
476   }
477 
478   return true;
479 }
480 
promoteCall(CallBase & CB,Function * Callee,CastInst ** RetBitCast)481 CallBase &llvm::promoteCall(CallBase &CB, Function *Callee,
482                             CastInst **RetBitCast) {
483   assert(!CB.getCalledFunction() && "Only indirect call sites can be promoted");
484 
485   // Set the called function of the call site to be the given callee (but don't
486   // change the type).
487   CB.setCalledOperand(Callee);
488 
489   // Since the call site will no longer be direct, we must clear metadata that
490   // is only appropriate for indirect calls. This includes !prof and !callees
491   // metadata.
492   CB.setMetadata(LLVMContext::MD_prof, nullptr);
493   CB.setMetadata(LLVMContext::MD_callees, nullptr);
494 
495   // If the function type of the call site matches that of the callee, no
496   // additional work is required.
497   if (CB.getFunctionType() == Callee->getFunctionType())
498     return CB;
499 
500   // Save the return types of the call site and callee.
501   Type *CallSiteRetTy = CB.getType();
502   Type *CalleeRetTy = Callee->getReturnType();
503 
504   // Change the function type of the call site the match that of the callee.
505   CB.mutateFunctionType(Callee->getFunctionType());
506 
507   // Inspect the arguments of the call site. If an argument's type doesn't
508   // match the corresponding formal argument's type in the callee, bitcast it
509   // to the correct type.
510   auto CalleeType = Callee->getFunctionType();
511   auto CalleeParamNum = CalleeType->getNumParams();
512 
513   LLVMContext &Ctx = Callee->getContext();
514   const AttributeList &CallerPAL = CB.getAttributes();
515   // The new list of argument attributes.
516   SmallVector<AttributeSet, 4> NewArgAttrs;
517   bool AttributeChanged = false;
518 
519   for (unsigned ArgNo = 0; ArgNo < CalleeParamNum; ++ArgNo) {
520     auto *Arg = CB.getArgOperand(ArgNo);
521     Type *FormalTy = CalleeType->getParamType(ArgNo);
522     Type *ActualTy = Arg->getType();
523     if (FormalTy != ActualTy) {
524       auto *Cast =
525           CastInst::CreateBitOrPointerCast(Arg, FormalTy, "", CB.getIterator());
526       CB.setArgOperand(ArgNo, Cast);
527 
528       // Remove any incompatible attributes for the argument.
529       AttrBuilder ArgAttrs(Ctx, CallerPAL.getParamAttrs(ArgNo));
530       ArgAttrs.remove(AttributeFuncs::typeIncompatible(FormalTy));
531 
532       // We may have a different byval/inalloca type.
533       if (ArgAttrs.getByValType())
534         ArgAttrs.addByValAttr(Callee->getParamByValType(ArgNo));
535       if (ArgAttrs.getInAllocaType())
536         ArgAttrs.addInAllocaAttr(Callee->getParamInAllocaType(ArgNo));
537 
538       NewArgAttrs.push_back(AttributeSet::get(Ctx, ArgAttrs));
539       AttributeChanged = true;
540     } else
541       NewArgAttrs.push_back(CallerPAL.getParamAttrs(ArgNo));
542   }
543 
544   // If the return type of the call site doesn't match that of the callee, cast
545   // the returned value to the appropriate type.
546   // Remove any incompatible return value attribute.
547   AttrBuilder RAttrs(Ctx, CallerPAL.getRetAttrs());
548   if (!CallSiteRetTy->isVoidTy() && CallSiteRetTy != CalleeRetTy) {
549     createRetBitCast(CB, CallSiteRetTy, RetBitCast);
550     RAttrs.remove(AttributeFuncs::typeIncompatible(CalleeRetTy));
551     AttributeChanged = true;
552   }
553 
554   // Set the new callsite attribute.
555   if (AttributeChanged)
556     CB.setAttributes(AttributeList::get(Ctx, CallerPAL.getFnAttrs(),
557                                         AttributeSet::get(Ctx, RAttrs),
558                                         NewArgAttrs));
559 
560   return CB;
561 }
562 
promoteCallWithIfThenElse(CallBase & CB,Function * Callee,MDNode * BranchWeights)563 CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
564                                           MDNode *BranchWeights) {
565 
566   // Version the indirect call site. If the called value is equal to the given
567   // callee, 'NewInst' will be executed, otherwise the original call site will
568   // be executed.
569   CallBase &NewInst = versionCallSite(CB, Callee, BranchWeights);
570 
571   // Promote 'NewInst' so that it directly calls the desired function.
572   return promoteCall(NewInst, Callee);
573 }
574 
promoteCallWithVTableCmp(CallBase & CB,Instruction * VPtr,Function * Callee,ArrayRef<Constant * > AddressPoints,MDNode * BranchWeights)575 CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
576                                          Function *Callee,
577                                          ArrayRef<Constant *> AddressPoints,
578                                          MDNode *BranchWeights) {
579   assert(!AddressPoints.empty() && "Caller should guarantee");
580   IRBuilder<> Builder(&CB);
581   SmallVector<Value *, 2> ICmps;
582   for (auto &AddressPoint : AddressPoints)
583     ICmps.push_back(Builder.CreateICmpEQ(VPtr, AddressPoint));
584 
585   // TODO: Perform tree height reduction if the number of ICmps is high.
586   Value *Cond = Builder.CreateOr(ICmps);
587 
588   // Version the indirect call site. If Cond is true, 'NewInst' will be
589   // executed, otherwise the original call site will be executed.
590   CallBase &NewInst = versionCallSiteWithCond(CB, Cond, BranchWeights);
591 
592   // Promote 'NewInst' so that it directly calls the desired function.
593   return promoteCall(NewInst, Callee);
594 }
595 
tryPromoteCall(CallBase & CB)596 bool llvm::tryPromoteCall(CallBase &CB) {
597   assert(!CB.getCalledFunction());
598   Module *M = CB.getCaller()->getParent();
599   const DataLayout &DL = M->getDataLayout();
600   Value *Callee = CB.getCalledOperand();
601 
602   LoadInst *VTableEntryLoad = dyn_cast<LoadInst>(Callee);
603   if (!VTableEntryLoad)
604     return false; // Not a vtable entry load.
605   Value *VTableEntryPtr = VTableEntryLoad->getPointerOperand();
606   APInt VTableOffset(DL.getTypeSizeInBits(VTableEntryPtr->getType()), 0);
607   Value *VTableBasePtr = VTableEntryPtr->stripAndAccumulateConstantOffsets(
608       DL, VTableOffset, /* AllowNonInbounds */ true);
609   LoadInst *VTablePtrLoad = dyn_cast<LoadInst>(VTableBasePtr);
610   if (!VTablePtrLoad)
611     return false; // Not a vtable load.
612   Value *Object = VTablePtrLoad->getPointerOperand();
613   APInt ObjectOffset(DL.getTypeSizeInBits(Object->getType()), 0);
614   Value *ObjectBase = Object->stripAndAccumulateConstantOffsets(
615       DL, ObjectOffset, /* AllowNonInbounds */ true);
616   if (!(isa<AllocaInst>(ObjectBase) && ObjectOffset == 0))
617     // Not an Alloca or the offset isn't zero.
618     return false;
619 
620   // Look for the vtable pointer store into the object by the ctor.
621   BasicBlock::iterator BBI(VTablePtrLoad);
622   Value *VTablePtr = FindAvailableLoadedValue(
623       VTablePtrLoad, VTablePtrLoad->getParent(), BBI, 0, nullptr, nullptr);
624   if (!VTablePtr)
625     return false; // No vtable found.
626   APInt VTableOffsetGVBase(DL.getTypeSizeInBits(VTablePtr->getType()), 0);
627   Value *VTableGVBase = VTablePtr->stripAndAccumulateConstantOffsets(
628       DL, VTableOffsetGVBase, /* AllowNonInbounds */ true);
629   GlobalVariable *GV = dyn_cast<GlobalVariable>(VTableGVBase);
630   if (!(GV && GV->isConstant() && GV->hasDefinitiveInitializer()))
631     // Not in the form of a global constant variable with an initializer.
632     return false;
633 
634   APInt VTableGVOffset = VTableOffsetGVBase + VTableOffset;
635   if (!(VTableGVOffset.getActiveBits() <= 64))
636     return false; // Out of range.
637 
638   Function *DirectCallee = nullptr;
639   std::tie(DirectCallee, std::ignore) =
640       getFunctionAtVTableOffset(GV, VTableGVOffset.getZExtValue(), *M);
641   if (!DirectCallee)
642     return false; // No function pointer found.
643 
644   if (!isLegalToPromote(CB, DirectCallee))
645     return false;
646 
647   // Success.
648   promoteCall(CB, DirectCallee);
649   return true;
650 }
651 
652 #undef DEBUG_TYPE
653