xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeExtractor.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===- CodeExtractor.cpp - Pull code region into a new function -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the interface to tear out a code region, such as an
10 // individual loop or a parallel section, into a new function, replacing it with
11 // a call to the new function.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Transforms/Utils/CodeExtractor.h"
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/DenseMap.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Analysis/AssumptionCache.h"
23 #include "llvm/Analysis/BlockFrequencyInfo.h"
24 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
25 #include "llvm/Analysis/BranchProbabilityInfo.h"
26 #include "llvm/IR/Argument.h"
27 #include "llvm/IR/Attributes.h"
28 #include "llvm/IR/BasicBlock.h"
29 #include "llvm/IR/CFG.h"
30 #include "llvm/IR/Constant.h"
31 #include "llvm/IR/Constants.h"
32 #include "llvm/IR/DIBuilder.h"
33 #include "llvm/IR/DataLayout.h"
34 #include "llvm/IR/DebugInfo.h"
35 #include "llvm/IR/DebugInfoMetadata.h"
36 #include "llvm/IR/DerivedTypes.h"
37 #include "llvm/IR/Dominators.h"
38 #include "llvm/IR/Function.h"
39 #include "llvm/IR/GlobalValue.h"
40 #include "llvm/IR/InstIterator.h"
41 #include "llvm/IR/InstrTypes.h"
42 #include "llvm/IR/Instruction.h"
43 #include "llvm/IR/Instructions.h"
44 #include "llvm/IR/IntrinsicInst.h"
45 #include "llvm/IR/Intrinsics.h"
46 #include "llvm/IR/LLVMContext.h"
47 #include "llvm/IR/MDBuilder.h"
48 #include "llvm/IR/Module.h"
49 #include "llvm/IR/PatternMatch.h"
50 #include "llvm/IR/Type.h"
51 #include "llvm/IR/User.h"
52 #include "llvm/IR/Value.h"
53 #include "llvm/IR/Verifier.h"
54 #include "llvm/Support/BlockFrequency.h"
55 #include "llvm/Support/BranchProbability.h"
56 #include "llvm/Support/Casting.h"
57 #include "llvm/Support/CommandLine.h"
58 #include "llvm/Support/Debug.h"
59 #include "llvm/Support/ErrorHandling.h"
60 #include "llvm/Support/raw_ostream.h"
61 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
62 #include <cassert>
63 #include <cstdint>
64 #include <iterator>
65 #include <map>
66 #include <utility>
67 #include <vector>
68 
69 using namespace llvm;
70 using namespace llvm::PatternMatch;
71 using ProfileCount = Function::ProfileCount;
72 
73 #define DEBUG_TYPE "code-extractor"
74 
// Provide a command-line option to aggregate function arguments into a struct
// for functions produced by the code extractor. This is useful when converting
// extracted functions to pthread-based code, as only one argument (void*) can
// be passed in to pthread_create().
// NOTE: this flag is OR'ed with the AggregateArgs constructor parameter, so
// enabling it forces aggregation even for callers that requested otherwise.
static cl::opt<bool>
AggregateArgsOpt("aggregate-extracted-args", cl::Hidden,
                 cl::desc("Aggregate arguments to code-extracted functions"));
82 
/// Test whether a block is valid for extraction.
///
/// \p Result is the full candidate region; \p AllowVarArgs and \p AllowAlloca
/// relax the checks for llvm.va_start and alloca instructions respectively.
/// Returns false if the block references block addresses, contains EH
/// constructs whose handlers escape the region, or uses intrinsics/musttail
/// calls that cannot survive outlining.
static bool isBlockValidForExtraction(const BasicBlock &BB,
                                      const SetVector<BasicBlock *> &Result,
                                      bool AllowVarArgs, bool AllowAlloca) {
  // taking the address of a basic block moved to another function is illegal
  if (BB.hasAddressTaken())
    return false;

  // don't hoist code that uses another basicblock address, as it's likely to
  // lead to unexpected behavior, like cross-function jumps
  SmallPtrSet<User const *, 16> Visited;
  SmallVector<User const *, 16> ToVisit(llvm::make_pointer_range(BB));

  while (!ToVisit.empty()) {
    User const *Curr = ToVisit.pop_back_val();
    if (!Visited.insert(Curr).second)
      continue;
    if (isa<BlockAddress const>(Curr))
      return false; // even a reference to self is likely to be not compatible

    // Only chase operands of values rooted in this block; instructions that
    // live in other blocks are checked when their own block is processed.
    if (isa<Instruction>(Curr) && cast<Instruction>(Curr)->getParent() != &BB)
      continue;

    for (auto const &U : Curr->operands()) {
      if (auto *UU = dyn_cast<User>(U))
        ToVisit.push_back(UU);
    }
  }

  // If explicitly requested, allow vastart and alloca. For invoke instructions
  // verify that extraction is valid.
  for (BasicBlock::const_iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
    if (isa<AllocaInst>(I)) {
       if (!AllowAlloca)
         return false;
       continue;
    }

    if (const auto *II = dyn_cast<InvokeInst>(I)) {
      // Unwind destination (either a landingpad, catchswitch, or cleanuppad)
      // must be a part of the subgraph which is being extracted.
      if (auto *UBB = II->getUnwindDest())
        if (!Result.count(UBB))
          return false;
      continue;
    }

    // All catch handlers of a catchswitch instruction as well as the unwind
    // destination must be in the subgraph.
    if (const auto *CSI = dyn_cast<CatchSwitchInst>(I)) {
      if (auto *UBB = CSI->getUnwindDest())
        if (!Result.count(UBB))
          return false;
      for (const auto *HBB : CSI->handlers())
        if (!Result.count(const_cast<BasicBlock*>(HBB)))
          return false;
      continue;
    }

    // Make sure that entire catch handler is within subgraph. It is sufficient
    // to check that catch return's block is in the list.
    if (const auto *CPI = dyn_cast<CatchPadInst>(I)) {
      for (const auto *U : CPI->users())
        if (const auto *CRI = dyn_cast<CatchReturnInst>(U))
          if (!Result.count(const_cast<BasicBlock*>(CRI->getParent())))
            return false;
      continue;
    }

    // And do similar checks for cleanup handler - the entire handler must be
    // in subgraph which is going to be extracted. For cleanup return should
    // additionally check that the unwind destination is also in the subgraph.
    if (const auto *CPI = dyn_cast<CleanupPadInst>(I)) {
      for (const auto *U : CPI->users())
        if (const auto *CRI = dyn_cast<CleanupReturnInst>(U))
          if (!Result.count(const_cast<BasicBlock*>(CRI->getParent())))
            return false;
      continue;
    }
    if (const auto *CRI = dyn_cast<CleanupReturnInst>(I)) {
      if (auto *UBB = CRI->getUnwindDest())
        if (!Result.count(UBB))
          return false;
      continue;
    }

    if (const CallInst *CI = dyn_cast<CallInst>(I)) {
      // musttail calls have several restrictions, generally enforcing matching
      // calling conventions between the caller parent and musttail callee.
      // We can't usually honor them, because the extracted function has a
      // different signature altogether, taking inputs/outputs and returning
      // a control-flow identifier rather than the actual return value.
      if (CI->isMustTailCall())
        return false;

      if (const Function *F = CI->getCalledFunction()) {
        auto IID = F->getIntrinsicID();
        if (IID == Intrinsic::vastart) {
          if (AllowVarArgs)
            continue;
          else
            return false;
        }

        // Currently, we miscompile outlined copies of eh_typeid_for. There are
        // proposals for fixing this in llvm.org/PR39545.
        if (IID == Intrinsic::eh_typeid_for)
          return false;
      }
    }
  }

  return true;
}
197 
198 /// Build a set of blocks to extract if the input blocks are viable.
199 static SetVector<BasicBlock *>
200 buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
201                         bool AllowVarArgs, bool AllowAlloca) {
202   assert(!BBs.empty() && "The set of blocks to extract must be non-empty");
203   SetVector<BasicBlock *> Result;
204 
205   // Loop over the blocks, adding them to our set-vector, and aborting with an
206   // empty set if we encounter invalid blocks.
207   for (BasicBlock *BB : BBs) {
208     // If this block is dead, don't process it.
209     if (DT && !DT->isReachableFromEntry(BB))
210       continue;
211 
212     if (!Result.insert(BB))
213       llvm_unreachable("Repeated basic blocks in extraction input");
214   }
215 
216   LLVM_DEBUG(dbgs() << "Region front block: " << Result.front()->getName()
217                     << '\n');
218 
219   for (auto *BB : Result) {
220     if (!isBlockValidForExtraction(*BB, Result, AllowVarArgs, AllowAlloca))
221       return {};
222 
223     // Make sure that the first block is not a landing pad.
224     if (BB == Result.front()) {
225       if (BB->isEHPad()) {
226         LLVM_DEBUG(dbgs() << "The first block cannot be an unwind block\n");
227         return {};
228       }
229       continue;
230     }
231 
232     // All blocks other than the first must not have predecessors outside of
233     // the subgraph which is being extracted.
234     for (auto *PBB : predecessors(BB))
235       if (!Result.count(PBB)) {
236         LLVM_DEBUG(dbgs() << "No blocks in this region may have entries from "
237                              "outside the region except for the first block!\n"
238                           << "Problematic source BB: " << BB->getName() << "\n"
239                           << "Problematic destination BB: " << PBB->getName()
240                           << "\n");
241         return {};
242       }
243   }
244 
245   return Result;
246 }
247 
248 /// isAlignmentPreservedForAddrCast - Return true if the cast operation
249 /// for specified target preserves original alignment
250 static bool isAlignmentPreservedForAddrCast(const Triple &TargetTriple) {
251   switch (TargetTriple.getArch()) {
252   case Triple::ArchType::amdgcn:
253   case Triple::ArchType::r600:
254     return true;
255   // TODO: Add other architectures for which we are certain that alignment
256   // is preserved during address space cast operations.
257   default:
258     return false;
259   }
260   return false;
261 }
262 
/// Construct a CodeExtractor for the given blocks. buildExtractionBlockSet
/// validates the input; on failure Blocks is left empty, which later makes
/// isEligible() return false. Note AggregateArgs is OR'ed with the
/// -aggregate-extracted-args command-line override.
CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
                             bool AggregateArgs, BlockFrequencyInfo *BFI,
                             BranchProbabilityInfo *BPI, AssumptionCache *AC,
                             bool AllowVarArgs, bool AllowAlloca,
                             BasicBlock *AllocationBlock, std::string Suffix,
                             bool ArgsInZeroAddressSpace)
    : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
      BPI(BPI), AC(AC), AllocationBlock(AllocationBlock),
      AllowVarArgs(AllowVarArgs),
      Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
      Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace) {}
274 
275 /// definedInRegion - Return true if the specified value is defined in the
276 /// extracted region.
277 static bool definedInRegion(const SetVector<BasicBlock *> &Blocks, Value *V) {
278   if (Instruction *I = dyn_cast<Instruction>(V))
279     if (Blocks.count(I->getParent()))
280       return true;
281   return false;
282 }
283 
284 /// definedInCaller - Return true if the specified value is defined in the
285 /// function being code extracted, but not in the region being extracted.
286 /// These values must be passed in as live-ins to the function.
287 static bool definedInCaller(const SetVector<BasicBlock *> &Blocks, Value *V) {
288   if (isa<Argument>(V)) return true;
289   if (Instruction *I = dyn_cast<Instruction>(V))
290     if (!Blocks.count(I->getParent()))
291       return true;
292   return false;
293 }
294 
295 static BasicBlock *getCommonExitBlock(const SetVector<BasicBlock *> &Blocks) {
296   BasicBlock *CommonExitBlock = nullptr;
297   auto hasNonCommonExitSucc = [&](BasicBlock *Block) {
298     for (auto *Succ : successors(Block)) {
299       // Internal edges, ok.
300       if (Blocks.count(Succ))
301         continue;
302       if (!CommonExitBlock) {
303         CommonExitBlock = Succ;
304         continue;
305       }
306       if (CommonExitBlock != Succ)
307         return true;
308     }
309     return false;
310   };
311 
312   if (any_of(Blocks, hasNonCommonExitSucc))
313     return nullptr;
314 
315   return CommonExitBlock;
316 }
317 
/// Build the per-function cache: record every alloca once and precompute
/// per-block side-effect information, so repeated extractCodeRegion() calls
/// don't rescan all instructions.
CodeExtractorAnalysisCache::CodeExtractorAnalysisCache(Function &F) {
  for (BasicBlock &BB : F) {
    for (Instruction &II : BB.instructionsWithoutDebug())
      if (auto *AI = dyn_cast<AllocaInst>(&II))
        Allocas.push_back(AI);

    findSideEffectInfoForBlock(BB);
  }
}
327 
328 void CodeExtractorAnalysisCache::findSideEffectInfoForBlock(BasicBlock &BB) {
329   for (Instruction &II : BB.instructionsWithoutDebug()) {
330     unsigned Opcode = II.getOpcode();
331     Value *MemAddr = nullptr;
332     switch (Opcode) {
333     case Instruction::Store:
334     case Instruction::Load: {
335       if (Opcode == Instruction::Store) {
336         StoreInst *SI = cast<StoreInst>(&II);
337         MemAddr = SI->getPointerOperand();
338       } else {
339         LoadInst *LI = cast<LoadInst>(&II);
340         MemAddr = LI->getPointerOperand();
341       }
342       // Global variable can not be aliased with locals.
343       if (isa<Constant>(MemAddr))
344         break;
345       Value *Base = MemAddr->stripInBoundsConstantOffsets();
346       if (!isa<AllocaInst>(Base)) {
347         SideEffectingBlocks.insert(&BB);
348         return;
349       }
350       BaseMemAddrs[&BB].insert(Base);
351       break;
352     }
353     default: {
354       IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(&II);
355       if (IntrInst) {
356         if (IntrInst->isLifetimeStartOrEnd())
357           break;
358         SideEffectingBlocks.insert(&BB);
359         return;
360       }
361       // Treat all the other cases conservatively if it has side effects.
362       if (II.mayHaveSideEffects()) {
363         SideEffectingBlocks.insert(&BB);
364         return;
365       }
366     }
367     }
368   }
369 }
370 
371 bool CodeExtractorAnalysisCache::doesBlockContainClobberOfAddr(
372     BasicBlock &BB, AllocaInst *Addr) const {
373   if (SideEffectingBlocks.count(&BB))
374     return true;
375   auto It = BaseMemAddrs.find(&BB);
376   if (It != BaseMemAddrs.end())
377     return It->second.count(Addr);
378   return false;
379 }
380 
381 bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers(
382     const CodeExtractorAnalysisCache &CEAC, Instruction *Addr) const {
383   AllocaInst *AI = cast<AllocaInst>(Addr->stripInBoundsConstantOffsets());
384   Function *Func = (*Blocks.begin())->getParent();
385   for (BasicBlock &BB : *Func) {
386     if (Blocks.count(&BB))
387       continue;
388     if (CEAC.doesBlockContainClobberOfAddr(BB, AI))
389       return false;
390   }
391   return true;
392 }
393 
/// Return a block suitable for hoisting lifetime.end markers to: either the
/// unique in-region predecessor of \p CommonExitBlock, or — if several region
/// blocks branch to it — CommonExitBlock itself after it has been split and
/// absorbed into the region.
BasicBlock *
CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) {
  BasicBlock *SinglePredFromOutlineRegion = nullptr;
  assert(!Blocks.count(CommonExitBlock) &&
         "Expect a block outside the region!");
  // Look for a unique in-region predecessor of the exit block; if found we
  // can hoist directly into it without modifying the CFG.
  for (auto *Pred : predecessors(CommonExitBlock)) {
    if (!Blocks.count(Pred))
      continue;
    if (!SinglePredFromOutlineRegion) {
      SinglePredFromOutlineRegion = Pred;
    } else if (SinglePredFromOutlineRegion != Pred) {
      // Second distinct in-region predecessor: no unique one exists.
      SinglePredFromOutlineRegion = nullptr;
      break;
    }
  }

  if (SinglePredFromOutlineRegion)
    return SinglePredFromOutlineRegion;

#ifndef NDEBUG
  // Debug-only sanity helper: returns the exit block's leading PHI, if any.
  auto getFirstPHI = [](BasicBlock *BB) {
    BasicBlock::iterator I = BB->begin();
    PHINode *FirstPhi = nullptr;
    while (I != BB->end()) {
      PHINode *Phi = dyn_cast<PHINode>(I);
      if (!Phi)
        break;
      if (!FirstPhi) {
        FirstPhi = Phi;
        break;
      }
    }
    return FirstPhi;
  };
  // If there are any phi nodes, the single pred either exists or has already
  // be created before code extraction.
  assert(!getFirstPHI(CommonExitBlock) && "Phi not expected");
#endif

  // Split off everything past the leading PHIs; the original (now branch-only)
  // block will be pulled into the region below.
  BasicBlock *NewExitBlock =
      CommonExitBlock->splitBasicBlock(CommonExitBlock->getFirstNonPHIIt());

  // Redirect all out-of-region predecessors to the split-off remainder so that
  // CommonExitBlock is only reachable from inside the region.
  for (BasicBlock *Pred :
       llvm::make_early_inc_range(predecessors(CommonExitBlock))) {
    if (Blocks.count(Pred))
      continue;
    Pred->getTerminator()->replaceUsesOfWith(CommonExitBlock, NewExitBlock);
  }
  // Now add the old exit block to the outline region.
  Blocks.insert(CommonExitBlock);
  return CommonExitBlock;
}
446 
// Find the pair of life time markers for address 'Addr' that are either
// defined inside the outline region or can legally be shrinkwrapped into the
// outline region. If there are no other untracked uses of the address, return
// the pair of markers if found; otherwise return a pair of nullptr.
CodeExtractor::LifetimeMarkerInfo
CodeExtractor::getLifetimeMarkers(const CodeExtractorAnalysisCache &CEAC,
                                  Instruction *Addr,
                                  BasicBlock *ExitBlock) const {
  LifetimeMarkerInfo Info;

  // Classify each user of the address: record the single start/end marker
  // pair and bail on any other use outside the region.
  for (User *U : Addr->users()) {
    IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(U);
    if (IntrInst) {
      // We don't model addresses with multiple start/end markers, but the
      // markers do not need to be in the region.
      if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start) {
        if (Info.LifeStart)
          return {};
        Info.LifeStart = IntrInst;
        continue;
      }
      if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) {
        if (Info.LifeEnd)
          return {};
        Info.LifeEnd = IntrInst;
        continue;
      }
      // Any other intrinsic falls through to the untracked-use check below.
    }
    // Find untracked uses of the address, bail.
    if (!definedInRegion(Blocks, U))
      return {};
  }

  // Both markers are required; a lone start or end cannot be shrinkwrapped.
  if (!Info.LifeStart || !Info.LifeEnd)
    return {};

  // Markers outside the region must be sunk (start) or hoisted (end).
  Info.SinkLifeStart = !definedInRegion(Blocks, Info.LifeStart);
  Info.HoistLifeEnd = !definedInRegion(Blocks, Info.LifeEnd);
  // Do legality check.
  if ((Info.SinkLifeStart || Info.HoistLifeEnd) &&
      !isLegalToShrinkwrapLifetimeMarkers(CEAC, Addr))
    return {};

  // Check to see if we have a place to do hoisting, if not, bail.
  if (Info.HoistLifeEnd && !ExitBlock)
    return {};

  return Info;
}
496 
/// Collect allocas (and their lifetime markers) that can be sunk into the
/// extracted region (\p SinkCands) and lifetime.end markers that must be
/// hoisted past the region's exit (\p HoistCands). \p ExitBlock is set to the
/// region's common exit block, or null if exits are not unique.
void CodeExtractor::findAllocas(const CodeExtractorAnalysisCache &CEAC,
                                ValueSet &SinkCands, ValueSet &HoistCands,
                                BasicBlock *&ExitBlock) const {
  Function *Func = (*Blocks.begin())->getParent();
  ExitBlock = getCommonExitBlock(Blocks);

  // Record the sink/hoist decisions for one marker pair. Returns false when
  // there are no markers to act on.
  auto moveOrIgnoreLifetimeMarkers =
      [&](const LifetimeMarkerInfo &LMI) -> bool {
    if (!LMI.LifeStart)
      return false;
    if (LMI.SinkLifeStart) {
      LLVM_DEBUG(dbgs() << "Sinking lifetime.start: " << *LMI.LifeStart
                        << "\n");
      SinkCands.insert(LMI.LifeStart);
    }
    if (LMI.HoistLifeEnd) {
      LLVM_DEBUG(dbgs() << "Hoisting lifetime.end: " << *LMI.LifeEnd << "\n");
      HoistCands.insert(LMI.LifeEnd);
    }
    return true;
  };

  // Look up allocas in the original function in CodeExtractorAnalysisCache, as
  // this is much faster than walking all the instructions.
  for (AllocaInst *AI : CEAC.getAllocas()) {
    BasicBlock *BB = AI->getParent();
    // Allocas already inside the region need no sinking.
    if (Blocks.count(BB))
      continue;

    // As a prior call to extractCodeRegion() may have shrinkwrapped the alloca,
    // check whether it is actually still in the original function.
    Function *AIFunc = BB->getParent();
    if (AIFunc != Func)
      continue;

    // Simple case: the alloca itself has a well-formed marker pair.
    LifetimeMarkerInfo MarkerInfo = getLifetimeMarkers(CEAC, AI, ExitBlock);
    bool Moved = moveOrIgnoreLifetimeMarkers(MarkerInfo);
    if (Moved) {
      LLVM_DEBUG(dbgs() << "Sinking alloca: " << *AI << "\n");
      SinkCands.insert(AI);
      continue;
    }

    // Find bitcasts in the outlined region that have lifetime marker users
    // outside that region. Replace the lifetime marker use with an
    // outside region bitcast to avoid unnecessary alloca/reload instructions
    // and extra lifetime markers.
    SmallVector<Instruction *, 2> LifetimeBitcastUsers;
    for (User *U : AI->users()) {
      if (!definedInRegion(Blocks, U))
        continue;

      if (U->stripInBoundsConstantOffsets() != AI)
        continue;

      Instruction *Bitcast = cast<Instruction>(U);
      for (User *BU : Bitcast->users()) {
        auto *IntrInst = dyn_cast<LifetimeIntrinsic>(BU);
        if (!IntrInst)
          continue;

        if (definedInRegion(Blocks, IntrInst))
          continue;

        LLVM_DEBUG(dbgs() << "Replace use of extracted region bitcast"
                          << *Bitcast << " in out-of-region lifetime marker "
                          << *IntrInst << "\n");
        LifetimeBitcastUsers.push_back(IntrInst);
      }
    }

    // Rewrite each collected out-of-region marker to use a fresh pointer cast
    // of the alloca, created immediately before the marker.
    for (Instruction *I : LifetimeBitcastUsers) {
      Module *M = AIFunc->getParent();
      LLVMContext &Ctx = M->getContext();
      auto *Int8PtrTy = PointerType::getUnqual(Ctx);
      CastInst *CastI =
          CastInst::CreatePointerCast(AI, Int8PtrTy, "lt.cast", I->getIterator());
      I->replaceUsesOfWith(I->getOperand(1), CastI);
    }

    // Follow any bitcasts.
    SmallVector<Instruction *, 2> Bitcasts;
    SmallVector<LifetimeMarkerInfo, 2> BitcastLifetimeInfo;
    for (User *U : AI->users()) {
      if (U->stripInBoundsConstantOffsets() == AI) {
        Instruction *Bitcast = cast<Instruction>(U);
        LifetimeMarkerInfo LMI = getLifetimeMarkers(CEAC, Bitcast, ExitBlock);
        if (LMI.LifeStart) {
          Bitcasts.push_back(Bitcast);
          BitcastLifetimeInfo.push_back(LMI);
          continue;
        }
      }

      // Found unknown use of AI.
      if (!definedInRegion(Blocks, U)) {
        Bitcasts.clear();
        break;
      }
    }

    // Either no bitcasts reference the alloca or there are unknown uses.
    if (Bitcasts.empty())
      continue;

    LLVM_DEBUG(dbgs() << "Sinking alloca (via bitcast): " << *AI << "\n");
    SinkCands.insert(AI);
    for (unsigned I = 0, E = Bitcasts.size(); I != E; ++I) {
      Instruction *BitcastAddr = Bitcasts[I];
      const LifetimeMarkerInfo &LMI = BitcastLifetimeInfo[I];
      assert(LMI.LifeStart &&
             "Unsafe to sink bitcast without lifetime markers");
      moveOrIgnoreLifetimeMarkers(LMI);
      // Casts that live outside the region must be sunk along with the alloca.
      if (!definedInRegion(Blocks, BitcastAddr)) {
        LLVM_DEBUG(dbgs() << "Sinking bitcast-of-alloca: " << *BitcastAddr
                          << "\n");
        SinkCands.insert(BitcastAddr);
      }
    }
  }
}
618 
/// Return true if the region can be extracted: the block set is non-empty,
/// varargs handling (if allowed) is confined to the region, and stacksave /
/// stackrestore pairs do not cross the region boundary.
bool CodeExtractor::isEligible() const {
  // Empty Blocks means buildExtractionBlockSet rejected the input region.
  if (Blocks.empty())
    return false;
  BasicBlock *Header = *Blocks.begin();
  Function *F = Header->getParent();

  // For functions with varargs, check that varargs handling is only done in the
  // outlined function, i.e vastart and vaend are only used in outlined blocks.
  if (AllowVarArgs && F->getFunctionType()->isVarArg()) {
    auto containsVarArgIntrinsic = [](const Instruction &I) {
      if (const CallInst *CI = dyn_cast<CallInst>(&I))
        if (const Function *Callee = CI->getCalledFunction())
          return Callee->getIntrinsicID() == Intrinsic::vastart ||
                 Callee->getIntrinsicID() == Intrinsic::vaend;
      return false;
    };

    for (auto &BB : *F) {
      if (Blocks.count(&BB))
        continue;
      if (llvm::any_of(BB, containsVarArgIntrinsic))
        return false;
    }
  }
  // stacksave as input implies stackrestore in the outlined function.
  // This can confuse prolog epilog insertion phase.
  // stacksave's uses must not cross outlined function.
  for (BasicBlock *BB : Blocks) {
    for (Instruction &I : *BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (!II)
        continue;
      bool IsSave = II->getIntrinsicID() == Intrinsic::stacksave;
      bool IsRestore = II->getIntrinsicID() == Intrinsic::stackrestore;
      // A stacksave inside the region must not be consumed outside it.
      if (IsSave && any_of(II->users(), [&Blks = this->Blocks](User *U) {
            return !definedInRegion(Blks, U);
          }))
        return false;
      // A stackrestore inside the region must restore a save that is also
      // inside the region.
      if (IsRestore && !definedInRegion(Blocks, II->getArgOperand(0)))
        return false;
    }
  }
  return true;
}
663 
664 void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
665                                       const ValueSet &SinkCands,
666                                       bool CollectGlobalInputs) const {
667   for (BasicBlock *BB : Blocks) {
668     // If a used value is defined outside the region, it's an input.  If an
669     // instruction is used outside the region, it's an output.
670     for (Instruction &II : *BB) {
671       for (auto &OI : II.operands()) {
672         Value *V = OI;
673         if (!SinkCands.count(V) &&
674             (definedInCaller(Blocks, V) ||
675              (CollectGlobalInputs && llvm::isa<llvm::GlobalVariable>(V))))
676           Inputs.insert(V);
677       }
678 
679       for (User *U : II.users())
680         if (!definedInRegion(Blocks, U)) {
681           Outputs.insert(&II);
682           break;
683         }
684     }
685   }
686 }
687 
/// severSplitPHINodesOfEntry - If a PHI node has multiple inputs from outside
/// of the region, we need to split the entry block of the region so that the
/// PHI node is easier to deal with. On return \p Header is updated to the
/// (possibly new) region entry block.
void CodeExtractor::severSplitPHINodesOfEntry(BasicBlock *&Header) {
  unsigned NumPredsFromRegion = 0;
  unsigned NumPredsOutsideRegion = 0;

  // The function's entry block is always split; any other header is split
  // only when its PHIs have more than one out-of-region incoming edge.
  if (Header != &Header->getParent()->getEntryBlock()) {
    PHINode *PN = dyn_cast<PHINode>(Header->begin());
    if (!PN) return;  // No PHI nodes.

    // If the header node contains any PHI nodes, check to see if there is more
    // than one entry from outside the region.  If so, we need to sever the
    // header block into two.
    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
      if (Blocks.count(PN->getIncomingBlock(i)))
        ++NumPredsFromRegion;
      else
        ++NumPredsOutsideRegion;

    // If there is one (or fewer) predecessor from outside the region, we don't
    // need to do anything special.
    if (NumPredsOutsideRegion <= 1) return;
  }

  // Otherwise, we need to split the header block into two pieces: one
  // containing PHI nodes merging values from outside of the region, and a
  // second that contains all of the code for the block and merges back any
  // incoming values from inside of the region.
  BasicBlock *NewBB = SplitBlock(Header, Header->getFirstNonPHIIt(), DT);

  // We only want to code extract the second block now, and it becomes the new
  // header of the region.
  BasicBlock *OldPred = Header;
  Blocks.remove(OldPred);
  Blocks.insert(NewBB);
  Header = NewBB;

  // Okay, now we need to adjust the PHI nodes and any branches from within the
  // region to go to the new header block instead of the old header block.
  if (NumPredsFromRegion) {
    PHINode *PN = cast<PHINode>(OldPred->begin());
    // Loop over all of the predecessors of OldPred that are in the region,
    // changing them to branch to NewBB instead.
    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
      if (Blocks.count(PN->getIncomingBlock(i))) {
        Instruction *TI = PN->getIncomingBlock(i)->getTerminator();
        TI->replaceUsesOfWith(OldPred, NewBB);
      }

    // Okay, everything within the region is now branching to the right block, we
    // just have to update the PHI nodes now, inserting PHI nodes into NewBB.
    BasicBlock::iterator AfterPHIs;
    for (AfterPHIs = OldPred->begin(); isa<PHINode>(AfterPHIs); ++AfterPHIs) {
      PHINode *PN = cast<PHINode>(AfterPHIs);
      // Create a new PHI node in the new region, which has an incoming value
      // from OldPred of PN.
      PHINode *NewPN = PHINode::Create(PN->getType(), 1 + NumPredsFromRegion,
                                       PN->getName() + ".ce");
      NewPN->insertBefore(NewBB->begin());
      PN->replaceAllUsesWith(NewPN);
      NewPN->addIncoming(PN, OldPred);

      // Loop over all of the incoming value in PN, moving them to NewPN if they
      // are from the extracted region.
      for (unsigned i = 0; i != PN->getNumIncomingValues(); ++i) {
        if (Blocks.count(PN->getIncomingBlock(i))) {
          NewPN->addIncoming(PN->getIncomingValue(i), PN->getIncomingBlock(i));
          PN->removeIncomingValue(i);
          // Re-examine index i: removal shifted the next entry into this slot.
          --i;
        }
      }
    }
  }
}
763 
/// severSplitPHINodesOfExits - if PHI nodes in exit blocks have inputs from
/// outlined region, we split these PHIs on two: one with inputs from region
/// and other with remaining incoming blocks; then first PHIs are placed in
/// outlined region.
void CodeExtractor::severSplitPHINodesOfExits() {
  for (BasicBlock *ExitBB : ExtractedFuncRetVals) {
    // Lazily-created split block holding the region-side halves of the PHIs.
    BasicBlock *NewBB = nullptr;

    for (PHINode &PN : ExitBB->phis()) {
      // Find all incoming values from the outlining region.
      SmallVector<unsigned, 2> IncomingVals;
      for (unsigned i = 0; i < PN.getNumIncomingValues(); ++i)
        if (Blocks.count(PN.getIncomingBlock(i)))
          IncomingVals.push_back(i);

      // Do not process PHI if there is one (or fewer) predecessor from region.
      // If PHI has exactly one predecessor from region, only this one incoming
      // will be replaced on codeRepl block, so it should be safe to skip PHI.
      if (IncomingVals.size() <= 1)
        continue;

      // Create block for new PHIs and add it to the list of outlined if it
      // wasn't done before.
      if (!NewBB) {
        NewBB = BasicBlock::Create(ExitBB->getContext(),
                                   ExitBB->getName() + ".split",
                                   ExitBB->getParent(), ExitBB);
        // Reroute every in-region predecessor through the new block.
        SmallVector<BasicBlock *, 4> Preds(predecessors(ExitBB));
        for (BasicBlock *PredBB : Preds)
          if (Blocks.count(PredBB))
            PredBB->getTerminator()->replaceUsesOfWith(ExitBB, NewBB);
        BranchInst::Create(ExitBB, NewBB);
        Blocks.insert(NewBB);
      }

      // Split this PHI.
      PHINode *NewPN = PHINode::Create(PN.getType(), IncomingVals.size(),
                                       PN.getName() + ".ce");
      NewPN->insertBefore(NewBB->getFirstNonPHIIt());
      for (unsigned i : IncomingVals)
        NewPN->addIncoming(PN.getIncomingValue(i), PN.getIncomingBlock(i));
      // Remove in reverse so earlier indices stay valid during removal.
      for (unsigned i : reverse(IncomingVals))
        PN.removeIncomingValue(i, false);
      PN.addIncoming(NewPN, NewBB);
    }
  }
}
811 
812 void CodeExtractor::splitReturnBlocks() {
813   for (BasicBlock *Block : Blocks)
814     if (ReturnInst *RI = dyn_cast<ReturnInst>(Block->getTerminator())) {
815       BasicBlock *New =
816           Block->splitBasicBlock(RI->getIterator(), Block->getName() + ".ret");
817       if (DT) {
818         // Old dominates New. New node dominates all other nodes dominated
819         // by Old.
820         DomTreeNode *OldNode = DT->getNode(Block);
821         SmallVector<DomTreeNode *, 8> Children(OldNode->begin(),
822                                                OldNode->end());
823 
824         DomTreeNode *NewNode = DT->addNewBlock(New, Block);
825 
826         for (DomTreeNode *I : Children)
827           DT->changeImmediateDominator(I, NewNode);
828       }
829     }
830 }
831 
832 Function *CodeExtractor::constructFunctionDeclaration(
833     const ValueSet &inputs, const ValueSet &outputs, BlockFrequency EntryFreq,
834     const Twine &Name, ValueSet &StructValues, StructType *&StructTy) {
835   LLVM_DEBUG(dbgs() << "inputs: " << inputs.size() << "\n");
836   LLVM_DEBUG(dbgs() << "outputs: " << outputs.size() << "\n");
837 
838   Function *oldFunction = Blocks.front()->getParent();
839   Module *M = Blocks.front()->getModule();
840 
841   // Assemble the function's parameter lists.
842   std::vector<Type *> ParamTy;
843   std::vector<Type *> AggParamTy;
844   const DataLayout &DL = M->getDataLayout();
845 
846   // Add the types of the input values to the function's argument list
847   for (Value *value : inputs) {
848     LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
849     if (AggregateArgs && !ExcludeArgsFromAggregate.contains(value)) {
850       AggParamTy.push_back(value->getType());
851       StructValues.insert(value);
852     } else
853       ParamTy.push_back(value->getType());
854   }
855 
856   // Add the types of the output values to the function's argument list.
857   for (Value *output : outputs) {
858     LLVM_DEBUG(dbgs() << "instr used in func: " << *output << "\n");
859     if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
860       AggParamTy.push_back(output->getType());
861       StructValues.insert(output);
862     } else
863       ParamTy.push_back(
864           PointerType::get(output->getContext(), DL.getAllocaAddrSpace()));
865   }
866 
867   assert(
868       (ParamTy.size() + AggParamTy.size()) ==
869           (inputs.size() + outputs.size()) &&
870       "Number of scalar and aggregate params does not match inputs, outputs");
871   assert((StructValues.empty() || AggregateArgs) &&
872          "Expeced StructValues only with AggregateArgs set");
873 
874   // Concatenate scalar and aggregate params in ParamTy.
875   if (!AggParamTy.empty()) {
876     StructTy = StructType::get(M->getContext(), AggParamTy);
877     ParamTy.push_back(PointerType::get(
878         M->getContext(), ArgsInZeroAddressSpace ? 0 : DL.getAllocaAddrSpace()));
879   }
880 
881   Type *RetTy = getSwitchType();
882   LLVM_DEBUG({
883     dbgs() << "Function type: " << *RetTy << " f(";
884     for (Type *i : ParamTy)
885       dbgs() << *i << ", ";
886     dbgs() << ")\n";
887   });
888 
889   FunctionType *funcType = FunctionType::get(
890       RetTy, ParamTy, AllowVarArgs && oldFunction->isVarArg());
891 
892   // Create the new function
893   Function *newFunction =
894       Function::Create(funcType, GlobalValue::InternalLinkage,
895                        oldFunction->getAddressSpace(), Name, M);
896 
897   // Propagate personality info to the new function if there is one.
898   if (oldFunction->hasPersonalityFn())
899     newFunction->setPersonalityFn(oldFunction->getPersonalityFn());
900 
901   // Inherit all of the target dependent attributes and white-listed
902   // target independent attributes.
903   //  (e.g. If the extracted region contains a call to an x86.sse
904   //  instruction we need to make sure that the extracted region has the
905   //  "target-features" attribute allowing it to be lowered.
906   // FIXME: This should be changed to check to see if a specific
907   //           attribute can not be inherited.
908   for (const auto &Attr : oldFunction->getAttributes().getFnAttrs()) {
909     if (Attr.isStringAttribute()) {
910       if (Attr.getKindAsString() == "thunk")
911         continue;
912     } else
913       switch (Attr.getKindAsEnum()) {
914       // Those attributes cannot be propagated safely. Explicitly list them
915       // here so we get a warning if new attributes are added.
916       case Attribute::AllocSize:
917       case Attribute::Builtin:
918       case Attribute::Convergent:
919       case Attribute::JumpTable:
920       case Attribute::Naked:
921       case Attribute::NoBuiltin:
922       case Attribute::NoMerge:
923       case Attribute::NoReturn:
924       case Attribute::NoSync:
925       case Attribute::ReturnsTwice:
926       case Attribute::Speculatable:
927       case Attribute::StackAlignment:
928       case Attribute::WillReturn:
929       case Attribute::AllocKind:
930       case Attribute::PresplitCoroutine:
931       case Attribute::Memory:
932       case Attribute::NoFPClass:
933       case Attribute::CoroDestroyOnlyWhenComplete:
934       case Attribute::CoroElideSafe:
935       case Attribute::NoDivergenceSource:
936         continue;
937       // Those attributes should be safe to propagate to the extracted function.
938       case Attribute::AlwaysInline:
939       case Attribute::Cold:
940       case Attribute::DisableSanitizerInstrumentation:
941       case Attribute::FnRetThunkExtern:
942       case Attribute::Hot:
943       case Attribute::HybridPatchable:
944       case Attribute::NoRecurse:
945       case Attribute::InlineHint:
946       case Attribute::MinSize:
947       case Attribute::NoCallback:
948       case Attribute::NoDuplicate:
949       case Attribute::NoFree:
950       case Attribute::NoImplicitFloat:
951       case Attribute::NoInline:
952       case Attribute::NonLazyBind:
953       case Attribute::NoRedZone:
954       case Attribute::NoUnwind:
955       case Attribute::NoSanitizeBounds:
956       case Attribute::NoSanitizeCoverage:
957       case Attribute::NullPointerIsValid:
958       case Attribute::OptimizeForDebugging:
959       case Attribute::OptForFuzzing:
960       case Attribute::OptimizeNone:
961       case Attribute::OptimizeForSize:
962       case Attribute::SafeStack:
963       case Attribute::ShadowCallStack:
964       case Attribute::SanitizeAddress:
965       case Attribute::SanitizeMemory:
966       case Attribute::SanitizeNumericalStability:
967       case Attribute::SanitizeThread:
968       case Attribute::SanitizeType:
969       case Attribute::SanitizeHWAddress:
970       case Attribute::SanitizeMemTag:
971       case Attribute::SanitizeRealtime:
972       case Attribute::SanitizeRealtimeBlocking:
973       case Attribute::SpeculativeLoadHardening:
974       case Attribute::StackProtect:
975       case Attribute::StackProtectReq:
976       case Attribute::StackProtectStrong:
977       case Attribute::StrictFP:
978       case Attribute::UWTable:
979       case Attribute::VScaleRange:
980       case Attribute::NoCfCheck:
981       case Attribute::MustProgress:
982       case Attribute::NoProfile:
983       case Attribute::SkipProfile:
984         break;
985       // These attributes cannot be applied to functions.
986       case Attribute::Alignment:
987       case Attribute::AllocatedPointer:
988       case Attribute::AllocAlign:
989       case Attribute::ByVal:
990       case Attribute::Captures:
991       case Attribute::Dereferenceable:
992       case Attribute::DereferenceableOrNull:
993       case Attribute::ElementType:
994       case Attribute::InAlloca:
995       case Attribute::InReg:
996       case Attribute::Nest:
997       case Attribute::NoAlias:
998       case Attribute::NoUndef:
999       case Attribute::NonNull:
1000       case Attribute::Preallocated:
1001       case Attribute::ReadNone:
1002       case Attribute::ReadOnly:
1003       case Attribute::Returned:
1004       case Attribute::SExt:
1005       case Attribute::StructRet:
1006       case Attribute::SwiftError:
1007       case Attribute::SwiftSelf:
1008       case Attribute::SwiftAsync:
1009       case Attribute::ZExt:
1010       case Attribute::ImmArg:
1011       case Attribute::ByRef:
1012       case Attribute::WriteOnly:
1013       case Attribute::Writable:
1014       case Attribute::DeadOnUnwind:
1015       case Attribute::Range:
1016       case Attribute::Initializes:
1017       case Attribute::NoExt:
1018       //  These are not really attributes.
1019       case Attribute::None:
1020       case Attribute::EndAttrKinds:
1021       case Attribute::EmptyKey:
1022       case Attribute::TombstoneKey:
1023       case Attribute::DeadOnReturn:
1024         llvm_unreachable("Not a function attribute");
1025       }
1026 
1027     newFunction->addFnAttr(Attr);
1028   }
1029 
1030   // Create scalar and aggregate iterators to name all of the arguments we
1031   // inserted.
1032   Function::arg_iterator ScalarAI = newFunction->arg_begin();
1033 
1034   // Set names and attributes for input and output arguments.
1035   ScalarAI = newFunction->arg_begin();
1036   for (Value *input : inputs) {
1037     if (StructValues.contains(input))
1038       continue;
1039 
1040     ScalarAI->setName(input->getName());
1041     if (input->isSwiftError())
1042       newFunction->addParamAttr(ScalarAI - newFunction->arg_begin(),
1043                                 Attribute::SwiftError);
1044     ++ScalarAI;
1045   }
1046   for (Value *output : outputs) {
1047     if (StructValues.contains(output))
1048       continue;
1049 
1050     ScalarAI->setName(output->getName() + ".out");
1051     ++ScalarAI;
1052   }
1053 
1054   // Update the entry count of the function.
1055   if (BFI) {
1056     auto Count = BFI->getProfileCountFromFreq(EntryFreq);
1057     if (Count.has_value())
1058       newFunction->setEntryCount(
1059           ProfileCount(*Count, Function::PCT_Real)); // FIXME
1060   }
1061 
1062   return newFunction;
1063 }
1064 
1065 /// If the original function has debug info, we have to add a debug location
1066 /// to the new branch instruction from the artificial entry block.
1067 /// We use the debug location of the first instruction in the extracted
1068 /// blocks, as there is no other equivalent line in the source code.
1069 static void applyFirstDebugLoc(Function *oldFunction,
1070                                ArrayRef<BasicBlock *> Blocks,
1071                                Instruction *BranchI) {
1072   if (oldFunction->getSubprogram()) {
1073     any_of(Blocks, [&BranchI](const BasicBlock *BB) {
1074       return any_of(*BB, [&BranchI](const Instruction &I) {
1075         if (!I.getDebugLoc())
1076           return false;
1077         BranchI->setDebugLoc(I.getDebugLoc());
1078         return true;
1079       });
1080     });
1081   }
1082 }
1083 
1084 /// Erase lifetime.start markers which reference inputs to the extraction
1085 /// region, and insert the referenced memory into \p LifetimesStart.
1086 ///
1087 /// The extraction region is defined by a set of blocks (\p Blocks), and a set
1088 /// of allocas which will be moved from the caller function into the extracted
1089 /// function (\p SunkAllocas).
1090 static void eraseLifetimeMarkersOnInputs(const SetVector<BasicBlock *> &Blocks,
1091                                          const SetVector<Value *> &SunkAllocas,
1092                                          SetVector<Value *> &LifetimesStart) {
1093   for (BasicBlock *BB : Blocks) {
1094     for (Instruction &I : llvm::make_early_inc_range(*BB)) {
1095       auto *II = dyn_cast<LifetimeIntrinsic>(&I);
1096       if (!II)
1097         continue;
1098 
1099       // Get the memory operand of the lifetime marker. If the underlying
1100       // object is a sunk alloca, or is otherwise defined in the extraction
1101       // region, the lifetime marker must not be erased.
1102       Value *Mem = II->getOperand(1)->stripInBoundsOffsets();
1103       if (SunkAllocas.count(Mem) || definedInRegion(Blocks, Mem))
1104         continue;
1105 
1106       if (II->getIntrinsicID() == Intrinsic::lifetime_start)
1107         LifetimesStart.insert(Mem);
1108       II->eraseFromParent();
1109     }
1110   }
1111 }
1112 
1113 /// Insert lifetime start/end markers surrounding the call to the new function
1114 /// for objects defined in the caller.
1115 static void insertLifetimeMarkersSurroundingCall(
1116     Module *M, ArrayRef<Value *> LifetimesStart, ArrayRef<Value *> LifetimesEnd,
1117     CallInst *TheCall) {
1118   LLVMContext &Ctx = M->getContext();
1119   auto NegativeOne = ConstantInt::getSigned(Type::getInt64Ty(Ctx), -1);
1120   Instruction *Term = TheCall->getParent()->getTerminator();
1121 
1122   // Emit lifetime markers for the pointers given in \p Objects. Insert the
1123   // markers before the call if \p InsertBefore, and after the call otherwise.
1124   auto insertMarkers = [&](Intrinsic::ID MarkerFunc, ArrayRef<Value *> Objects,
1125                            bool InsertBefore) {
1126     for (Value *Mem : Objects) {
1127       assert((!isa<Instruction>(Mem) || cast<Instruction>(Mem)->getFunction() ==
1128                                             TheCall->getFunction()) &&
1129              "Input memory not defined in original function");
1130 
1131       Function *Func =
1132           Intrinsic::getOrInsertDeclaration(M, MarkerFunc, Mem->getType());
1133       auto Marker = CallInst::Create(Func, {NegativeOne, Mem});
1134       if (InsertBefore)
1135         Marker->insertBefore(TheCall->getIterator());
1136       else
1137         Marker->insertBefore(Term->getIterator());
1138     }
1139   };
1140 
1141   if (!LifetimesStart.empty()) {
1142     insertMarkers(Intrinsic::lifetime_start, LifetimesStart,
1143                   /*InsertBefore=*/true);
1144   }
1145 
1146   if (!LifetimesEnd.empty()) {
1147     insertMarkers(Intrinsic::lifetime_end, LifetimesEnd,
1148                   /*InsertBefore=*/false);
1149   }
1150 }
1151 
1152 void CodeExtractor::moveCodeToFunction(Function *newFunction) {
1153   auto newFuncIt = newFunction->begin();
1154   for (BasicBlock *Block : Blocks) {
1155     // Delete the basic block from the old function, and the list of blocks
1156     Block->removeFromParent();
1157 
1158     // Insert this basic block into the new function
1159     // Insert the original blocks after the entry block created
1160     // for the new function. The entry block may be followed
1161     // by a set of exit blocks at this point, but these exit
1162     // blocks better be placed at the end of the new function.
1163     newFuncIt = newFunction->insert(std::next(newFuncIt), Block);
1164   }
1165 }
1166 
/// Recompute branch probabilities and !prof branch weights for the terminator
/// of \p CodeReplacer (the block that now dispatches to the region's former
/// exit blocks), based on the frequencies with which the original region
/// reached each exit (\p ExitWeights).
void CodeExtractor::calculateNewCallTerminatorWeights(
    BasicBlock *CodeReplacer,
    const DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
    BranchProbabilityInfo *BPI) {
  using Distribution = BlockFrequencyInfoImplBase::Distribution;
  using BlockNode = BlockFrequencyInfoImplBase::BlockNode;

  // Update the branch weights for the exit block.
  Instruction *TI = CodeReplacer->getTerminator();
  SmallVector<unsigned, 8> BranchWeights(TI->getNumSuccessors(), 0);

  // Block Frequency distribution with dummy node.
  Distribution BranchDist;

  // Start with unknown probabilities; zero-frequency edges are pinned to
  // zero below, and the rest are filled in after normalization.
  SmallVector<BranchProbability, 4> EdgeProbabilities(
      TI->getNumSuccessors(), BranchProbability::getUnknown());

  // Add each of the frequencies of the successors.
  for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) {
    BlockNode ExitNode(i);
    uint64_t ExitFreq = ExitWeights.lookup(TI->getSuccessor(i)).getFrequency();
    if (ExitFreq != 0)
      BranchDist.addExit(ExitNode, ExitFreq);
    else
      EdgeProbabilities[i] = BranchProbability::getZero();
  }

  // Check for no total weight.
  if (BranchDist.Total == 0) {
    BPI->setEdgeProbability(CodeReplacer, EdgeProbabilities);
    return;
  }

  // Normalize the distribution so that they can fit in unsigned.
  BranchDist.normalize();

  // Create normalized branch weights and set the metadata.
  for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) {
    const auto &Weight = BranchDist.Weights[I];

    // Get the weight and update the current BFI.
    BranchWeights[Weight.TargetNode.Index] = Weight.Amount;
    BranchProbability BP(Weight.Amount, BranchDist.Total);
    EdgeProbabilities[Weight.TargetNode.Index] = BP;
  }
  BPI->setEdgeProbability(CodeReplacer, EdgeProbabilities);
  // Attach the normalized weights as !prof metadata on the terminator.
  TI->setMetadata(
      LLVMContext::MD_prof,
      MDBuilder(TI->getContext()).createBranchWeights(BranchWeights));
}
1217 
1218 /// Erase debug info intrinsics which refer to values in \p F but aren't in
1219 /// \p F.
1220 static void eraseDebugIntrinsicsWithNonLocalRefs(Function &F) {
1221   for (Instruction &I : instructions(F)) {
1222     SmallVector<DbgVariableIntrinsic *, 4> DbgUsers;
1223     SmallVector<DbgVariableRecord *, 4> DbgVariableRecords;
1224     findDbgUsers(DbgUsers, &I, &DbgVariableRecords);
1225     for (DbgVariableIntrinsic *DVI : DbgUsers)
1226       if (DVI->getFunction() != &F)
1227         DVI->eraseFromParent();
1228     for (DbgVariableRecord *DVR : DbgVariableRecords)
1229       if (DVR->getFunction() != &F)
1230         DVR->eraseFromParent();
1231   }
1232 }
1233 
/// Fix up the debug info in the old and new functions. Following changes are
/// done.
/// 1. If a debug record points to a value that has been replaced, update the
///    record to use the new value.
/// 2. If an Input value that has been replaced was used as a location of a
///    debug record in the Parent function, then materialize a similar record
///    in the new function.
/// 3. Point line locations and debug intrinsics to the new subprogram scope
/// 4. Remove intrinsics which point to values outside of the new function.
static void fixupDebugInfoPostExtraction(Function &OldFunc, Function &NewFunc,
                                         CallInst &TheCall,
                                         const SetVector<Value *> &Inputs,
                                         ArrayRef<Value *> NewValues) {
  DISubprogram *OldSP = OldFunc.getSubprogram();
  LLVMContext &Ctx = OldFunc.getContext();

  if (!OldSP) {
    // Erase any debug info the new function contains.
    stripDebugInfo(NewFunc);
    // Make sure the old function doesn't contain any non-local metadata refs.
    eraseDebugIntrinsicsWithNonLocalRefs(NewFunc);
    return;
  }

  // Create a subprogram for the new function. Leave out a description of the
  // function arguments, as the parameters don't correspond to anything at the
  // source level.
  assert(OldSP->getUnit() && "Missing compile unit for subprogram");
  DIBuilder DIB(*OldFunc.getParent(), /*AllowUnresolved=*/false,
                OldSP->getUnit());
  auto SPType = DIB.createSubroutineType(DIB.getOrCreateTypeArray({}));
  DISubprogram::DISPFlags SPFlags = DISubprogram::SPFlagDefinition |
                                    DISubprogram::SPFlagOptimized |
                                    DISubprogram::SPFlagLocalToUnit;
  auto NewSP = DIB.createFunction(
      OldSP->getUnit(), NewFunc.getName(), NewFunc.getName(), OldSP->getFile(),
      /*LineNo=*/0, SPType, /*ScopeLine=*/0, DINode::FlagZero, SPFlags);
  NewFunc.setSubprogram(NewSP);

  // Update a debug record that used OldLoc to use NewLoc: in-place if it
  // lives in the new function, otherwise by materializing an equivalent
  // record in the new function's entry block.
  auto UpdateOrInsertDebugRecord = [&](auto *DR, Value *OldLoc, Value *NewLoc,
                                       DIExpression *Expr, bool Declare) {
    if (DR->getParent()->getParent() == &NewFunc) {
      DR->replaceVariableLocationOp(OldLoc, NewLoc);
      return;
    }
    if (Declare) {
      DIB.insertDeclare(NewLoc, DR->getVariable(), Expr, DR->getDebugLoc(),
                        &NewFunc.getEntryBlock());
      return;
    }
    DIB.insertDbgValueIntrinsic(
        NewLoc, DR->getVariable(), Expr, DR->getDebugLoc(),
        NewFunc.getEntryBlock().getTerminator()->getIterator());
  };
  for (auto [Input, NewVal] : zip_equal(Inputs, NewValues)) {
    SmallVector<DbgVariableIntrinsic *, 1> DbgUsers;
    SmallVector<DbgVariableRecord *, 1> DPUsers;
    findDbgUsers(DbgUsers, Input, &DPUsers);
    DIExpression *Expr = DIB.createExpression();

    // Iterate the debug users of the Input values. If they are in the
    // extracted function then update their location with the new value. If
    // they are in the parent function then create a similar debug record.
    for (auto *DVI : DbgUsers)
      UpdateOrInsertDebugRecord(DVI, Input, NewVal, Expr,
                                isa<DbgDeclareInst>(DVI));
    for (auto *DVR : DPUsers)
      UpdateOrInsertDebugRecord(DVR, Input, NewVal, Expr, DVR->isDbgDeclare());
  }

  auto IsInvalidLocation = [&NewFunc](Value *Location) {
    // Location is invalid if it isn't a constant, an instruction or an
    // argument, or is an instruction/argument but isn't in the new function.
    if (!Location || (!isa<Constant>(Location) && !isa<Argument>(Location) &&
                      !isa<Instruction>(Location)))
      return true;

    if (Argument *Arg = dyn_cast<Argument>(Location))
      return Arg->getParent() != &NewFunc;
    if (Instruction *LocationInst = dyn_cast<Instruction>(Location))
      return LocationInst->getFunction() != &NewFunc;
    return false;
  };

  // Debug intrinsics in the new function need to be updated in one of two
  // ways:
  //  1) They need to be deleted, because they describe a value in the old
  //     function.
  //  2) They need to point to fresh metadata, e.g. because they currently
  //     point to a variable in the wrong scope.
  SmallDenseMap<DINode *, DINode *> RemappedMetadata;
  SmallVector<DbgVariableRecord *, 4> DVRsToDelete;
  DenseMap<const MDNode *, MDNode *> Cache;

  // Lazily clone the variable's scope chain into NewSP and create a matching
  // local variable; results are memoized in RemappedMetadata.
  auto GetUpdatedDIVariable = [&](DILocalVariable *OldVar) {
    DINode *&NewVar = RemappedMetadata[OldVar];
    if (!NewVar) {
      DILocalScope *NewScope = DILocalScope::cloneScopeForSubprogram(
          *OldVar->getScope(), *NewSP, Ctx, Cache);
      NewVar = DIB.createAutoVariable(
          NewScope, OldVar->getName(), OldVar->getFile(), OldVar->getLine(),
          OldVar->getType(), /*AlwaysPreserve=*/false, DINode::FlagZero,
          OldVar->getAlignInBits());
    }
    return cast<DILocalVariable>(NewVar);
  };

  auto UpdateDbgLabel = [&](auto *LabelRecord) {
    // Point the label record to a fresh label within the new function if
    // the record was not inlined from some other function.
    if (LabelRecord->getDebugLoc().getInlinedAt())
      return;
    DILabel *OldLabel = LabelRecord->getLabel();
    DINode *&NewLabel = RemappedMetadata[OldLabel];
    if (!NewLabel) {
      DILocalScope *NewScope = DILocalScope::cloneScopeForSubprogram(
          *OldLabel->getScope(), *NewSP, Ctx, Cache);
      NewLabel =
          DILabel::get(Ctx, NewScope, OldLabel->getName(), OldLabel->getFile(),
                       OldLabel->getLine(), OldLabel->getColumn(),
                       OldLabel->isArtificial(), OldLabel->getCoroSuspendIdx());
    }
    LabelRecord->setLabel(cast<DILabel>(NewLabel));
  };

  auto UpdateDbgRecordsOnInst = [&](Instruction &I) -> void {
    for (DbgRecord &DR : I.getDbgRecordRange()) {
      if (DbgLabelRecord *DLR = dyn_cast<DbgLabelRecord>(&DR)) {
        UpdateDbgLabel(DLR);
        continue;
      }

      DbgVariableRecord &DVR = cast<DbgVariableRecord>(DR);
      // If any of the used locations are invalid, delete the record.
      if (any_of(DVR.location_ops(), IsInvalidLocation)) {
        DVRsToDelete.push_back(&DVR);
        continue;
      }

      // DbgAssign intrinsics have an extra Value argument:
      if (DVR.isDbgAssign() && IsInvalidLocation(DVR.getAddress())) {
        DVRsToDelete.push_back(&DVR);
        continue;
      }

      // If the variable was in the scope of the old function, i.e. it was not
      // inlined, point the intrinsic to a fresh variable within the new
      // function.
      if (!DVR.getDebugLoc().getInlinedAt())
        DVR.setVariable(GetUpdatedDIVariable(DVR.getVariable()));
    }
  };

  for (Instruction &I : instructions(NewFunc))
    UpdateDbgRecordsOnInst(I);

  // Deletions are deferred to here so the loop above doesn't invalidate the
  // record range it iterates.
  for (auto *DVR : DVRsToDelete)
    DVR->getMarker()->MarkedInstr->dropOneDbgRecord(DVR);
  DIB.finalizeSubprogram(NewSP);

  // Fix up the scope information attached to the line locations and the
  // debug assignment metadata in the new function.
  DenseMap<DIAssignID *, DIAssignID *> AssignmentIDMap;
  for (Instruction &I : instructions(NewFunc)) {
    if (const DebugLoc &DL = I.getDebugLoc())
      I.setDebugLoc(
          DebugLoc::replaceInlinedAtSubprogram(DL, *NewSP, Ctx, Cache));
    for (DbgRecord &DR : I.getDbgRecordRange())
      DR.setDebugLoc(DebugLoc::replaceInlinedAtSubprogram(DR.getDebugLoc(),
                                                          *NewSP, Ctx, Cache));

    // Loop info metadata may contain line locations. Fix them up.
    auto updateLoopInfoLoc = [&Ctx, &Cache, NewSP](Metadata *MD) -> Metadata * {
      if (auto *Loc = dyn_cast_or_null<DILocation>(MD))
        return DebugLoc::replaceInlinedAtSubprogram(Loc, *NewSP, Ctx, Cache);
      return MD;
    };
    updateLoopMetadataDebugLocations(I, updateLoopInfoLoc);
    at::remapAssignID(AssignmentIDMap, I);
  }
  // Give the call site a (synthetic) location so inlining the new function
  // back does not drop debug info.
  if (!TheCall.getDebugLoc())
    TheCall.setDebugLoc(DILocation::get(Ctx, 0, 0, OldSP));

  eraseDebugIntrinsicsWithNonLocalRefs(NewFunc);
}
1419 
1420 Function *
1421 CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC) {
1422   ValueSet Inputs, Outputs;
1423   return extractCodeRegion(CEAC, Inputs, Outputs);
1424 }
1425 
1426 Function *
1427 CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
1428                                  ValueSet &inputs, ValueSet &outputs) {
1429   if (!isEligible())
1430     return nullptr;
1431 
1432   // Assumption: this is a single-entry code region, and the header is the first
1433   // block in the region.
1434   BasicBlock *header = *Blocks.begin();
1435   Function *oldFunction = header->getParent();
1436 
1437   normalizeCFGForExtraction(header);
1438 
1439   // Remove @llvm.assume calls that will be moved to the new function from the
1440   // old function's assumption cache.
1441   for (BasicBlock *Block : Blocks) {
1442     for (Instruction &I : llvm::make_early_inc_range(*Block)) {
1443       if (auto *AI = dyn_cast<AssumeInst>(&I)) {
1444         if (AC)
1445           AC->unregisterAssumption(AI);
1446         AI->eraseFromParent();
1447       }
1448     }
1449   }
1450 
1451   ValueSet SinkingCands, HoistingCands;
1452   BasicBlock *CommonExit = nullptr;
1453   findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
1454   assert(HoistingCands.empty() || CommonExit);
1455 
1456   // Find inputs to, outputs from the code region.
1457   findInputsOutputs(inputs, outputs, SinkingCands);
1458 
1459   // Collect objects which are inputs to the extraction region and also
1460   // referenced by lifetime start markers within it. The effects of these
1461   // markers must be replicated in the calling function to prevent the stack
1462   // coloring pass from merging slots which store input objects.
1463   ValueSet LifetimesStart;
1464   eraseLifetimeMarkersOnInputs(Blocks, SinkingCands, LifetimesStart);
1465 
1466   if (!HoistingCands.empty()) {
1467     auto *HoistToBlock = findOrCreateBlockForHoisting(CommonExit);
1468     Instruction *TI = HoistToBlock->getTerminator();
1469     for (auto *II : HoistingCands)
1470       cast<Instruction>(II)->moveBefore(TI->getIterator());
1471     computeExtractedFuncRetVals();
1472   }
1473 
1474   // CFG/ExitBlocks must not change hereafter
1475 
1476   // Calculate the entry frequency of the new function before we change the root
1477   //   block.
1478   BlockFrequency EntryFreq;
1479   DenseMap<BasicBlock *, BlockFrequency> ExitWeights;
1480   if (BFI) {
1481     assert(BPI && "Both BPI and BFI are required to preserve profile info");
1482     for (BasicBlock *Pred : predecessors(header)) {
1483       if (Blocks.count(Pred))
1484         continue;
1485       EntryFreq +=
1486           BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header);
1487     }
1488 
1489     for (BasicBlock *Succ : ExtractedFuncRetVals) {
1490       for (BasicBlock *Block : predecessors(Succ)) {
1491         if (!Blocks.count(Block))
1492           continue;
1493 
1494         // Update the branch weight for this successor.
1495         BlockFrequency &BF = ExitWeights[Succ];
1496         BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, Succ);
1497       }
1498     }
1499   }
1500 
1501   // Determine position for the replacement code. Do so before header is moved
1502   // to the new function.
1503   BasicBlock *ReplIP = header;
1504   while (ReplIP && Blocks.count(ReplIP))
1505     ReplIP = ReplIP->getNextNode();
1506 
1507   // Construct new function based on inputs/outputs & add allocas for all defs.
1508   std::string SuffixToUse =
1509       Suffix.empty()
1510           ? (header->getName().empty() ? "extracted" : header->getName().str())
1511           : Suffix;
1512 
1513   ValueSet StructValues;
1514   StructType *StructTy = nullptr;
1515   Function *newFunction = constructFunctionDeclaration(
1516       inputs, outputs, EntryFreq, oldFunction->getName() + "." + SuffixToUse,
1517       StructValues, StructTy);
1518   SmallVector<Value *> NewValues;
1519 
1520   emitFunctionBody(inputs, outputs, StructValues, newFunction, StructTy, header,
1521                    SinkingCands, NewValues);
1522 
1523   std::vector<Value *> Reloads;
1524   CallInst *TheCall = emitReplacerCall(
1525       inputs, outputs, StructValues, newFunction, StructTy, oldFunction, ReplIP,
1526       EntryFreq, LifetimesStart.getArrayRef(), Reloads);
1527 
1528   insertReplacerCall(oldFunction, header, TheCall->getParent(), outputs,
1529                      Reloads, ExitWeights);
1530 
1531   fixupDebugInfoPostExtraction(*oldFunction, *newFunction, *TheCall, inputs,
1532                                NewValues);
1533 
1534   LLVM_DEBUG(llvm::dbgs() << "After extractCodeRegion - newFunction:\n");
1535   LLVM_DEBUG(newFunction->dump());
1536   LLVM_DEBUG(llvm::dbgs() << "After extractCodeRegion - oldFunction:\n");
1537   LLVM_DEBUG(oldFunction->dump());
1538   LLVM_DEBUG(if (AC && verifyAssumptionCache(*oldFunction, *newFunction, AC))
1539                  report_fatal_error("Stale Asumption cache for old Function!"));
1540   return newFunction;
1541 }
1542 
/// Canonicalize the CFG around the extraction region before outlining.
/// Takes \p header by reference; the entry-PHI splitting below may update it
/// (signature suggests so — the callee is declared outside this chunk).
void CodeExtractor::normalizeCFGForExtraction(BasicBlock *&header) {
  // If we have any return instructions in the region, split those blocks so
  // that the return is not in the region.
  splitReturnBlocks();

  // If we have to split PHI nodes of the entry or exit blocks, do so now.
  severSplitPHINodesOfEntry(header);

  // If a PHI in an exit block has multiple incoming values from the outlined
  // region, create a new PHI for those values within the region such that only
  // PHI itself becomes an output value, not each of its incoming values
  // individually.
  // computeExtractedFuncRetVals() populates ExtractedFuncRetVals first, which
  // the exit splitting presumably consults (its body is outside this chunk).
  computeExtractedFuncRetVals();
  severSplitPHINodesOfExits();
}
1558 
1559 void CodeExtractor::computeExtractedFuncRetVals() {
1560   ExtractedFuncRetVals.clear();
1561 
1562   SmallPtrSet<BasicBlock *, 2> ExitBlocks;
1563   for (BasicBlock *Block : Blocks) {
1564     for (BasicBlock *Succ : successors(Block)) {
1565       if (Blocks.count(Succ))
1566         continue;
1567 
1568       bool IsNew = ExitBlocks.insert(Succ).second;
1569       if (IsNew)
1570         ExtractedFuncRetVals.push_back(Succ);
1571     }
1572   }
1573 }
1574 
1575 Type *CodeExtractor::getSwitchType() {
1576   LLVMContext &Context = Blocks.front()->getContext();
1577 
1578   assert(ExtractedFuncRetVals.size() < 0xffff &&
1579          "too many exit blocks for switch");
1580   switch (ExtractedFuncRetVals.size()) {
1581   case 0:
1582   case 1:
1583     return Type::getVoidTy(Context);
1584   case 2:
1585     // Conditional branch, return a bool
1586     return Type::getInt1Ty(Context);
1587   default:
1588     return Type::getInt16Ty(Context);
1589   }
1590 }
1591 
/// Populate \p newFunction with the body of the extracted region.
///
/// Steps performed here, in order: create the entry block ("newFuncRoot");
/// sink eligible instructions and allocas into it; compute the replacement
/// value for every input (a scalar argument, or a load through a GEP into the
/// trailing aggregate argument of type \p StructArgTy) and append it to
/// \p NewValues (parallel to \p inputs); move the region's blocks into the
/// new function and rewrite in-region uses of the inputs; create one
/// ".exitStub" return block per collected exit target and retarget region
/// terminators to them; fix header PHIs; store each output to its output
/// argument; and mark the function noreturn when no region block returns or
/// resumes.
void CodeExtractor::emitFunctionBody(
    const ValueSet &inputs, const ValueSet &outputs,
    const ValueSet &StructValues, Function *newFunction,
    StructType *StructArgTy, BasicBlock *header, const ValueSet &SinkingCands,
    SmallVectorImpl<Value *> &NewValues) {
  Function *oldFunction = header->getParent();
  LLVMContext &Context = oldFunction->getContext();

  // The new function needs a root node because other nodes can branch to the
  // head of the region, but the entry node of a function cannot have preds.
  BasicBlock *newFuncRoot =
      BasicBlock::Create(Context, "newFuncRoot", newFunction);

  // Now sink all instructions which only have non-phi uses inside the region.
  // Group the allocas at the start of the block, so that any bitcast uses of
  // the allocas are well-defined.
  // Two passes: non-allocas are moved first, then allocas are moved to the
  // first insertion point, so the allocas end up ahead of the other sunk
  // instructions.
  for (auto *II : SinkingCands) {
    if (!isa<AllocaInst>(II)) {
      cast<Instruction>(II)->moveBefore(*newFuncRoot,
                                        newFuncRoot->getFirstInsertionPt());
    }
  }
  for (auto *II : SinkingCands) {
    if (auto *AI = dyn_cast<AllocaInst>(II)) {
      AI->moveBefore(*newFuncRoot, newFuncRoot->getFirstInsertionPt());
    }
  }

  // Scalar inputs occupy the leading arguments; struct-passed values are
  // reached through the trailing aggregate argument (present only when
  // StructValues is non-empty).
  Function::arg_iterator ScalarAI = newFunction->arg_begin();
  Argument *AggArg = StructValues.empty()
                         ? nullptr
                         : newFunction->getArg(newFunction->arg_size() - 1);

  // Rewrite all users of the inputs in the extracted region to use the
  // arguments (or appropriate addressing into struct) instead.
  for (unsigned i = 0, e = inputs.size(), aggIdx = 0; i != e; ++i) {
    Value *RewriteVal;
    if (StructValues.contains(inputs[i])) {
      Value *Idx[2];
      Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext()));
      Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), aggIdx);
      GetElementPtrInst *GEP = GetElementPtrInst::Create(
          StructArgTy, AggArg, Idx, "gep_" + inputs[i]->getName(), newFuncRoot);
      LoadInst *LoadGEP =
          new LoadInst(StructArgTy->getElementType(aggIdx), GEP,
                       "loadgep_" + inputs[i]->getName(), newFuncRoot);
      // If we load pointer, we can add optional !align metadata
      // The existence of the !align metadata on the instruction tells
      // the optimizer that the value loaded is known to be aligned to
      // a boundary specified by the integer value in the metadata node.
      // Example:
      // %res = load ptr, ptr %input, align 8, !align !align_md_node
      //                                 ^         ^
      //                                 |         |
      //            alignment of %input address    |
      //                                           |
      //                                     alignment of %res object
      if (StructArgTy->getElementType(aggIdx)->isPointerTy()) {
        unsigned AlignmentValue;
        const Triple &TargetTriple =
            newFunction->getParent()->getTargetTriple();
        const DataLayout &DL = header->getDataLayout();
        // Pointers without casting can provide more information about
        // alignment. Use pointers without casts if given target preserves
        // alignment information for cast the operation.
        if (isAlignmentPreservedForAddrCast(TargetTriple))
          AlignmentValue =
              inputs[i]->stripPointerCasts()->getPointerAlignment(DL).value();
        else
          AlignmentValue = inputs[i]->getPointerAlignment(DL).value();
        MDBuilder MDB(header->getContext());
        LoadGEP->setMetadata(
            LLVMContext::MD_align,
            MDNode::get(
                header->getContext(),
                MDB.createConstant(ConstantInt::get(
                    Type::getInt64Ty(header->getContext()), AlignmentValue))));
      }
      RewriteVal = LoadGEP;
      ++aggIdx;
    } else
      RewriteVal = &*ScalarAI++;

    NewValues.push_back(RewriteVal);
  }

  // Transfer the region's blocks into newFunction before rewriting uses, so
  // the Blocks.count() filter below identifies in-region users correctly.
  moveCodeToFunction(newFunction);

  for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
    Value *RewriteVal = NewValues[i];

    // Snapshot the user list: replaceUsesOfWith mutates the use list.
    std::vector<User *> Users(inputs[i]->user_begin(), inputs[i]->user_end());
    for (User *use : Users)
      if (Instruction *inst = dyn_cast<Instruction>(use))
        if (Blocks.count(inst->getParent()))
          inst->replaceUsesOfWith(inputs[i], RewriteVal);
  }

  // Since there may be multiple exits from the original region, make the new
  // function return an unsigned, switch on that number.  This loop iterates
  // over all of the blocks in the extracted region, updating any terminator
  // instructions in the to-be-extracted region that branch to blocks that are
  // not in the region to be extracted.
  std::map<BasicBlock *, BasicBlock *> ExitBlockMap;

  // Iterate over the previously collected targets, and create new blocks inside
  // the function to branch to.
  for (auto P : enumerate(ExtractedFuncRetVals)) {
    BasicBlock *OldTarget = P.value();
    size_t SuccNum = P.index();

    BasicBlock *NewTarget = BasicBlock::Create(
        Context, OldTarget->getName() + ".exitStub", newFunction);
    ExitBlockMap[OldTarget] = NewTarget;

    Value *brVal = nullptr;
    Type *RetTy = getSwitchType();
    assert(ExtractedFuncRetVals.size() < 0xffff &&
           "too many exit blocks for switch");
    switch (ExtractedFuncRetVals.size()) {
    case 0:
    case 1:
      // No value needed.
      break;
    case 2: // Conditional branch, return a bool
      // Note the inversion: exit #0 returns true, exit #1 returns false; the
      // caller-side conditional branch is built to match (see the two-exit
      // case in emitReplacerCall).
      brVal = ConstantInt::get(RetTy, !SuccNum);
      break;
    default:
      brVal = ConstantInt::get(RetTy, SuccNum);
      break;
    }

    ReturnInst::Create(Context, brVal, NewTarget);
  }

  // Retarget every region terminator edge that leaves the region to the
  // corresponding exit stub created above.
  for (BasicBlock *Block : Blocks) {
    Instruction *TI = Block->getTerminator();
    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) {
      if (Blocks.count(TI->getSuccessor(i)))
        continue;
      BasicBlock *OldTarget = TI->getSuccessor(i);
      // add a new basic block which returns the appropriate value
      BasicBlock *NewTarget = ExitBlockMap[OldTarget];
      assert(NewTarget && "Unknown target block!");

      // rewrite the original branch instruction with this new target
      TI->setSuccessor(i, NewTarget);
    }
  }

  // Loop over all of the PHI nodes in the header and exit blocks, and change
  // any references to the old incoming edge to be the new incoming edge.
  for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) {
    PHINode *PN = cast<PHINode>(I);
    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
      if (!Blocks.count(PN->getIncomingBlock(i)))
        PN->setIncomingBlock(i, newFuncRoot);
  }

  // Connect newFunction entry block to new header.
  BranchInst *BranchI = BranchInst::Create(header, newFuncRoot);
  applyFirstDebugLoc(oldFunction, Blocks.getArrayRef(), BranchI);

  // Store the arguments right after the definition of output value.
  // This should be proceeded after creating exit stubs to be ensure that invoke
  // result restore will be placed in the outlined function.
  // First skip past all input arguments so ScalarAI/AggIdx point at the first
  // output argument / output struct slot.
  ScalarAI = newFunction->arg_begin();
  unsigned AggIdx = 0;

  for (Value *Input : inputs) {
    if (StructValues.contains(Input))
      ++AggIdx;
    else
      ++ScalarAI;
  }

  for (Value *Output : outputs) {
    // Find proper insertion point.
    // In case Output is an invoke, we insert the store at the beginning in the
    // 'normal destination' BB. Otherwise we insert the store right after
    // Output.
    BasicBlock::iterator InsertPt;
    if (auto *InvokeI = dyn_cast<InvokeInst>(Output))
      InsertPt = InvokeI->getNormalDest()->getFirstInsertionPt();
    else if (auto *Phi = dyn_cast<PHINode>(Output))
      InsertPt = Phi->getParent()->getFirstInsertionPt();
    else if (auto *OutI = dyn_cast<Instruction>(Output))
      InsertPt = std::next(OutI->getIterator());
    else {
      // Globals don't need to be updated, just advance to the next argument.
      if (StructValues.contains(Output))
        ++AggIdx;
      else
        ++ScalarAI;
      continue;
    }

    assert((InsertPt->getFunction() == newFunction ||
            Blocks.count(InsertPt->getParent())) &&
           "InsertPt should be in new function");

    if (StructValues.contains(Output)) {
      assert(AggArg && "Number of aggregate output arguments should match "
                       "the number of defined values");
      Value *Idx[2];
      Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), AggIdx);
      GetElementPtrInst *GEP = GetElementPtrInst::Create(
          StructArgTy, AggArg, Idx, "gep_" + Output->getName(), InsertPt);
      new StoreInst(Output, GEP, InsertPt);
      ++AggIdx;
    } else {
      assert(ScalarAI != newFunction->arg_end() &&
             "Number of scalar output arguments should match "
             "the number of defined values");
      new StoreInst(Output, &*ScalarAI, InsertPt);
      ++ScalarAI;
    }
  }

  if (ExtractedFuncRetVals.empty()) {
    // Mark the new function `noreturn` if applicable. Terminators which resume
    // exception propagation are treated as returning instructions. This is to
    // avoid inserting traps after calls to outlined functions which unwind.
    if (none_of(Blocks, [](const BasicBlock *BB) {
          const Instruction *Term = BB->getTerminator();
          return isa<ReturnInst>(Term) || isa<ResumeInst>(Term);
        }))
      newFunction->setDoesNotReturn();
  }
}
1823 
1824 CallInst *CodeExtractor::emitReplacerCall(
1825     const ValueSet &inputs, const ValueSet &outputs,
1826     const ValueSet &StructValues, Function *newFunction,
1827     StructType *StructArgTy, Function *oldFunction, BasicBlock *ReplIP,
1828     BlockFrequency EntryFreq, ArrayRef<Value *> LifetimesStart,
1829     std::vector<Value *> &Reloads) {
1830   LLVMContext &Context = oldFunction->getContext();
1831   Module *M = oldFunction->getParent();
1832   const DataLayout &DL = M->getDataLayout();
1833 
1834   // This takes place of the original loop
1835   BasicBlock *codeReplacer =
1836       BasicBlock::Create(Context, "codeRepl", oldFunction, ReplIP);
1837   if (AllocationBlock)
1838     assert(AllocationBlock->getParent() == oldFunction &&
1839            "AllocationBlock is not in the same function");
1840   BasicBlock *AllocaBlock =
1841       AllocationBlock ? AllocationBlock : &oldFunction->getEntryBlock();
1842 
1843   // Update the entry count of the function.
1844   if (BFI)
1845     BFI->setBlockFreq(codeReplacer, EntryFreq);
1846 
1847   std::vector<Value *> params;
1848 
1849   // Add inputs as params, or to be filled into the struct
1850   for (Value *input : inputs) {
1851     if (StructValues.contains(input))
1852       continue;
1853 
1854     params.push_back(input);
1855   }
1856 
1857   // Create allocas for the outputs
1858   std::vector<Value *> ReloadOutputs;
1859   for (Value *output : outputs) {
1860     if (StructValues.contains(output))
1861       continue;
1862 
1863     AllocaInst *alloca = new AllocaInst(
1864         output->getType(), DL.getAllocaAddrSpace(), nullptr,
1865         output->getName() + ".loc", AllocaBlock->getFirstInsertionPt());
1866     params.push_back(alloca);
1867     ReloadOutputs.push_back(alloca);
1868   }
1869 
1870   AllocaInst *Struct = nullptr;
1871   if (!StructValues.empty()) {
1872     Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
1873                             "structArg", AllocaBlock->getFirstInsertionPt());
1874     if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
1875       auto *StructSpaceCast = new AddrSpaceCastInst(
1876           Struct, PointerType ::get(Context, 0), "structArg.ascast");
1877       StructSpaceCast->insertAfter(Struct->getIterator());
1878       params.push_back(StructSpaceCast);
1879     } else {
1880       params.push_back(Struct);
1881     }
1882 
1883     unsigned AggIdx = 0;
1884     for (Value *input : inputs) {
1885       if (!StructValues.contains(input))
1886         continue;
1887 
1888       Value *Idx[2];
1889       Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
1890       Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), AggIdx);
1891       GetElementPtrInst *GEP = GetElementPtrInst::Create(
1892           StructArgTy, Struct, Idx, "gep_" + input->getName());
1893       GEP->insertInto(codeReplacer, codeReplacer->end());
1894       new StoreInst(input, GEP, codeReplacer);
1895 
1896       ++AggIdx;
1897     }
1898   }
1899 
1900   // Emit the call to the function
1901   CallInst *call = CallInst::Create(
1902       newFunction, params, ExtractedFuncRetVals.size() > 1 ? "targetBlock" : "",
1903       codeReplacer);
1904 
1905   // Set swifterror parameter attributes.
1906   unsigned ParamIdx = 0;
1907   unsigned AggIdx = 0;
1908   for (auto input : inputs) {
1909     if (StructValues.contains(input)) {
1910       ++AggIdx;
1911     } else {
1912       if (input->isSwiftError())
1913         call->addParamAttr(ParamIdx, Attribute::SwiftError);
1914       ++ParamIdx;
1915     }
1916   }
1917 
1918   // Add debug location to the new call, if the original function has debug
1919   // info. In that case, the terminator of the entry block of the extracted
1920   // function contains the first debug location of the extracted function,
1921   // set in extractCodeRegion.
1922   if (codeReplacer->getParent()->getSubprogram()) {
1923     if (auto DL = newFunction->getEntryBlock().getTerminator()->getDebugLoc())
1924       call->setDebugLoc(DL);
1925   }
1926 
1927   // Reload the outputs passed in by reference, use the struct if output is in
1928   // the aggregate or reload from the scalar argument.
1929   for (unsigned i = 0, e = outputs.size(), scalarIdx = 0; i != e; ++i) {
1930     Value *Output = nullptr;
1931     if (StructValues.contains(outputs[i])) {
1932       Value *Idx[2];
1933       Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
1934       Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), AggIdx);
1935       GetElementPtrInst *GEP = GetElementPtrInst::Create(
1936           StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName());
1937       GEP->insertInto(codeReplacer, codeReplacer->end());
1938       Output = GEP;
1939       ++AggIdx;
1940     } else {
1941       Output = ReloadOutputs[scalarIdx];
1942       ++scalarIdx;
1943     }
1944     LoadInst *load =
1945         new LoadInst(outputs[i]->getType(), Output,
1946                      outputs[i]->getName() + ".reload", codeReplacer);
1947     Reloads.push_back(load);
1948   }
1949 
1950   // Now we can emit a switch statement using the call as a value.
1951   SwitchInst *TheSwitch =
1952       SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context)),
1953                          codeReplacer, 0, codeReplacer);
1954   for (auto P : enumerate(ExtractedFuncRetVals)) {
1955     BasicBlock *OldTarget = P.value();
1956     size_t SuccNum = P.index();
1957 
1958     TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context), SuccNum),
1959                        OldTarget);
1960   }
1961 
1962   // Now that we've done the deed, simplify the switch instruction.
1963   Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType();
1964   switch (ExtractedFuncRetVals.size()) {
1965   case 0:
1966     // There are no successors (the block containing the switch itself), which
1967     // means that previously this was the last part of the function, and hence
1968     // this should be rewritten as a `ret` or `unreachable`.
1969     if (newFunction->doesNotReturn()) {
1970       // If fn is no return, end with an unreachable terminator.
1971       (void)new UnreachableInst(Context, TheSwitch->getIterator());
1972     } else if (OldFnRetTy->isVoidTy()) {
1973       // We have no return value.
1974       ReturnInst::Create(Context, nullptr,
1975                          TheSwitch->getIterator()); // Return void
1976     } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) {
1977       // return what we have
1978       ReturnInst::Create(Context, TheSwitch->getCondition(),
1979                          TheSwitch->getIterator());
1980     } else {
1981       // Otherwise we must have code extracted an unwind or something, just
1982       // return whatever we want.
1983       ReturnInst::Create(Context, Constant::getNullValue(OldFnRetTy),
1984                          TheSwitch->getIterator());
1985     }
1986 
1987     TheSwitch->eraseFromParent();
1988     break;
1989   case 1:
1990     // Only a single destination, change the switch into an unconditional
1991     // branch.
1992     BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getIterator());
1993     TheSwitch->eraseFromParent();
1994     break;
1995   case 2:
1996     // Only two destinations, convert to a condition branch.
1997     // Remark: This also swaps the target branches:
1998     // 0 -> false -> getSuccessor(2); 1 -> true -> getSuccessor(1)
1999     BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2),
2000                        call, TheSwitch->getIterator());
2001     TheSwitch->eraseFromParent();
2002     break;
2003   default:
2004     // Otherwise, make the default destination of the switch instruction be one
2005     // of the other successors.
2006     TheSwitch->setCondition(call);
2007     TheSwitch->setDefaultDest(
2008         TheSwitch->getSuccessor(ExtractedFuncRetVals.size()));
2009     // Remove redundant case
2010     TheSwitch->removeCase(
2011         SwitchInst::CaseIt(TheSwitch, ExtractedFuncRetVals.size() - 1));
2012     break;
2013   }
2014 
2015   // Insert lifetime markers around the reloads of any output values. The
2016   // allocas output values are stored in are only in-use in the codeRepl block.
2017   insertLifetimeMarkersSurroundingCall(M, ReloadOutputs, ReloadOutputs, call);
2018 
2019   // Replicate the effects of any lifetime start/end markers which referenced
2020   // input objects in the extraction region by placing markers around the call.
2021   insertLifetimeMarkersSurroundingCall(oldFunction->getParent(), LifetimesStart,
2022                                        {}, call);
2023 
2024   return call;
2025 }
2026 
2027 void CodeExtractor::insertReplacerCall(
2028     Function *oldFunction, BasicBlock *header, BasicBlock *codeReplacer,
2029     const ValueSet &outputs, ArrayRef<Value *> Reloads,
2030     const DenseMap<BasicBlock *, BlockFrequency> &ExitWeights) {
2031 
2032   // Rewrite branches to basic blocks outside of the loop to new dummy blocks
2033   // within the new function. This must be done before we lose track of which
2034   // blocks were originally in the code region.
2035   std::vector<User *> Users(header->user_begin(), header->user_end());
2036   for (auto &U : Users)
2037     // The BasicBlock which contains the branch is not in the region
2038     // modify the branch target to a new block
2039     if (Instruction *I = dyn_cast<Instruction>(U))
2040       if (I->isTerminator() && I->getFunction() == oldFunction &&
2041           !Blocks.count(I->getParent()))
2042         I->replaceUsesOfWith(header, codeReplacer);
2043 
2044   // When moving the code region it is sufficient to replace all uses to the
2045   // extracted function values. Since the original definition's block
2046   // dominated its use, it will also be dominated by codeReplacer's switch
2047   // which joined multiple exit blocks.
2048   for (BasicBlock *ExitBB : ExtractedFuncRetVals)
2049     for (PHINode &PN : ExitBB->phis()) {
2050       Value *IncomingCodeReplacerVal = nullptr;
2051       for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) {
2052         // Ignore incoming values from outside of the extracted region.
2053         if (!Blocks.count(PN.getIncomingBlock(i)))
2054           continue;
2055 
2056         // Ensure that there is only one incoming value from codeReplacer.
2057         if (!IncomingCodeReplacerVal) {
2058           PN.setIncomingBlock(i, codeReplacer);
2059           IncomingCodeReplacerVal = PN.getIncomingValue(i);
2060         } else
2061           assert(IncomingCodeReplacerVal == PN.getIncomingValue(i) &&
2062                  "PHI has two incompatbile incoming values from codeRepl");
2063       }
2064     }
2065 
2066   for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
2067     Value *load = Reloads[i];
2068     std::vector<User *> Users(outputs[i]->user_begin(), outputs[i]->user_end());
2069     for (User *U : Users) {
2070       Instruction *inst = cast<Instruction>(U);
2071       if (inst->getParent()->getParent() == oldFunction)
2072         inst->replaceUsesOfWith(outputs[i], load);
2073     }
2074   }
2075 
2076   // Update the branch weights for the exit block.
2077   if (BFI && ExtractedFuncRetVals.size() > 1)
2078     calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);
2079 }
2080 
2081 bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc,
2082                                           const Function &NewFunc,
2083                                           AssumptionCache *AC) {
2084   for (auto AssumeVH : AC->assumptions()) {
2085     auto *I = dyn_cast_or_null<CallInst>(AssumeVH);
2086     if (!I)
2087       continue;
2088 
2089     // There shouldn't be any llvm.assume intrinsics in the new function.
2090     if (I->getFunction() != &OldFunc)
2091       return true;
2092 
2093     // There shouldn't be any stale affected values in the assumption cache
2094     // that were previously in the old function, but that have now been moved
2095     // to the new function.
2096     for (auto AffectedValVH : AC->assumptionsFor(I->getOperand(0))) {
2097       auto *AffectedCI = dyn_cast_or_null<CallInst>(AffectedValVH);
2098       if (!AffectedCI)
2099         continue;
2100       if (AffectedCI->getFunction() != &OldFunc)
2101         return true;
2102       auto *AssumedInst = cast<Instruction>(AffectedCI->getOperand(0));
2103       if (AssumedInst->getFunction() != &OldFunc)
2104         return true;
2105     }
2106   }
2107   return false;
2108 }
2109 
/// Record \p Arg in ExcludeArgsFromAggregate.
/// NOTE(review): the set is presumably consulted when choosing which values
/// go into the aggregate struct argument (StructValues); that selection code
/// is outside this chunk — confirm against constructFunctionDeclaration.
void CodeExtractor::excludeArgFromAggregate(Value *Arg) {
  ExcludeArgsFromAggregate.insert(Arg);
}
2113