xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeExtractor.cpp (revision 3008333d442f4daf0318cb1d249240e086208d68)
1  //===- CodeExtractor.cpp - Pull code region into a new function -----------===//
2  //
3  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4  // See https://llvm.org/LICENSE.txt for license information.
5  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6  //
7  //===----------------------------------------------------------------------===//
8  //
9  // This file implements the interface to tear out a code region, such as an
10  // individual loop or a parallel section, into a new function, replacing it with
11  // a call to the new function.
12  //
13  //===----------------------------------------------------------------------===//
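//
// A minimal usage sketch (illustrative only; BlocksToExtract, DT, and
// ParentFunc are placeholder names), assuming the declarations in
// llvm/Transforms/Utils/CodeExtractor.h:
//
//   CodeExtractor CE(BlocksToExtract, /*DT=*/&DT);
//   if (CE.isEligible()) {
//     CodeExtractorAnalysisCache CEAC(*ParentFunc);
//     if (Function *Outlined = CE.extractCodeRegion(CEAC))
//       ; // The region now lives in Outlined, replaced by a call to it.
//   }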
14  
15  #include "llvm/Transforms/Utils/CodeExtractor.h"
16  #include "llvm/ADT/ArrayRef.h"
17  #include "llvm/ADT/DenseMap.h"
18  #include "llvm/ADT/Optional.h"
19  #include "llvm/ADT/STLExtras.h"
20  #include "llvm/ADT/SetVector.h"
21  #include "llvm/ADT/SmallPtrSet.h"
22  #include "llvm/ADT/SmallVector.h"
23  #include "llvm/Analysis/AssumptionCache.h"
24  #include "llvm/Analysis/BlockFrequencyInfo.h"
25  #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
26  #include "llvm/Analysis/BranchProbabilityInfo.h"
27  #include "llvm/Analysis/LoopInfo.h"
28  #include "llvm/IR/Argument.h"
29  #include "llvm/IR/Attributes.h"
30  #include "llvm/IR/BasicBlock.h"
31  #include "llvm/IR/CFG.h"
32  #include "llvm/IR/Constant.h"
33  #include "llvm/IR/Constants.h"
34  #include "llvm/IR/DataLayout.h"
35  #include "llvm/IR/DerivedTypes.h"
36  #include "llvm/IR/Dominators.h"
37  #include "llvm/IR/Function.h"
38  #include "llvm/IR/GlobalValue.h"
39  #include "llvm/IR/InstrTypes.h"
40  #include "llvm/IR/Instruction.h"
41  #include "llvm/IR/Instructions.h"
42  #include "llvm/IR/IntrinsicInst.h"
43  #include "llvm/IR/Intrinsics.h"
44  #include "llvm/IR/LLVMContext.h"
45  #include "llvm/IR/MDBuilder.h"
46  #include "llvm/IR/Module.h"
47  #include "llvm/IR/PatternMatch.h"
48  #include "llvm/IR/Type.h"
49  #include "llvm/IR/User.h"
50  #include "llvm/IR/Value.h"
51  #include "llvm/IR/Verifier.h"
52  #include "llvm/Pass.h"
53  #include "llvm/Support/BlockFrequency.h"
54  #include "llvm/Support/BranchProbability.h"
55  #include "llvm/Support/Casting.h"
56  #include "llvm/Support/CommandLine.h"
57  #include "llvm/Support/Debug.h"
58  #include "llvm/Support/ErrorHandling.h"
59  #include "llvm/Support/raw_ostream.h"
60  #include "llvm/Transforms/Utils/BasicBlockUtils.h"
61  #include "llvm/Transforms/Utils/Local.h"
62  #include <cassert>
63  #include <cstdint>
64  #include <iterator>
65  #include <map>
66  #include <set>
67  #include <utility>
68  #include <vector>
69  
70  using namespace llvm;
71  using namespace llvm::PatternMatch;
72  using ProfileCount = Function::ProfileCount;
73  
74  #define DEBUG_TYPE "code-extractor"
75  
76  // Provide a command-line option to aggregate function arguments into a struct
77  // for functions produced by the code extractor. This is useful when converting
78  // extracted functions to pthread-based code, as only one argument (void*) can
79  // be passed in to pthread_create().
80  static cl::opt<bool>
81  AggregateArgsOpt("aggregate-extracted-args", cl::Hidden,
82                   cl::desc("Aggregate arguments to code-extracted functions"));
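
// As a rough illustration (not from the original source), a region with inputs
// (i32 %x, float* %y) is normally outlined as
//   define internal void @foo.extracted(i32 %x, float* %y)
// whereas with argument aggregation it becomes approximately
//   define internal void @foo.extracted({ i32, float* }* %structArg)
// so a single pointer can be forwarded through pthread_create()'s void*.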
83  
84  /// Test whether a block is valid for extraction.
85  static bool isBlockValidForExtraction(const BasicBlock &BB,
86                                        const SetVector<BasicBlock *> &Result,
87                                        bool AllowVarArgs, bool AllowAlloca) {
88    // Taking the address of a basic block moved to another function is illegal.
89    if (BB.hasAddressTaken())
90      return false;
91  
92    // Don't hoist code that uses the address of another basic block, as it's
93    // likely to lead to unexpected behavior, like cross-function jumps.
94    SmallPtrSet<User const *, 16> Visited;
95    SmallVector<User const *, 16> ToVisit;
96  
97    for (Instruction const &Inst : BB)
98      ToVisit.push_back(&Inst);
99  
100    while (!ToVisit.empty()) {
101      User const *Curr = ToVisit.pop_back_val();
102      if (!Visited.insert(Curr).second)
103        continue;
104      if (isa<BlockAddress const>(Curr))
105        return false; // even a reference to self is likely to be incompatible
106  
107      if (isa<Instruction>(Curr) && cast<Instruction>(Curr)->getParent() != &BB)
108        continue;
109  
110      for (auto const &U : Curr->operands()) {
111        if (auto *UU = dyn_cast<User>(U))
112          ToVisit.push_back(UU);
113      }
114    }
115  
116    // If explicitly requested, allow vastart and alloca. For invoke instructions
117    // verify that extraction is valid.
118    for (BasicBlock::const_iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
119      if (isa<AllocaInst>(I)) {
120         if (!AllowAlloca)
121           return false;
122         continue;
123      }
124  
125      if (const auto *II = dyn_cast<InvokeInst>(I)) {
126        // Unwind destination (either a landingpad, catchswitch, or cleanuppad)
127        // must be a part of the subgraph which is being extracted.
128        if (auto *UBB = II->getUnwindDest())
129          if (!Result.count(UBB))
130            return false;
131        continue;
132      }
133  
134      // All catch handlers of a catchswitch instruction as well as the unwind
135      // destination must be in the subgraph.
136      if (const auto *CSI = dyn_cast<CatchSwitchInst>(I)) {
137        if (auto *UBB = CSI->getUnwindDest())
138          if (!Result.count(UBB))
139            return false;
140        for (auto *HBB : CSI->handlers())
141          if (!Result.count(const_cast<BasicBlock*>(HBB)))
142            return false;
143        continue;
144      }
145  
146      // Make sure that the entire catch handler is within the subgraph. It is
147      // sufficient to check that the catch return's block is in the list.
148      if (const auto *CPI = dyn_cast<CatchPadInst>(I)) {
149        for (const auto *U : CPI->users())
150          if (const auto *CRI = dyn_cast<CatchReturnInst>(U))
151            if (!Result.count(const_cast<BasicBlock*>(CRI->getParent())))
152              return false;
153        continue;
154      }
155  
156      // Do similar checks for the cleanup handler - the entire handler must be in
157      // the subgraph which is going to be extracted. For a cleanup return we
158      // additionally check that the unwind destination is also in the subgraph.
159      if (const auto *CPI = dyn_cast<CleanupPadInst>(I)) {
160        for (const auto *U : CPI->users())
161          if (const auto *CRI = dyn_cast<CleanupReturnInst>(U))
162            if (!Result.count(const_cast<BasicBlock*>(CRI->getParent())))
163              return false;
164        continue;
165      }
166      if (const auto *CRI = dyn_cast<CleanupReturnInst>(I)) {
167        if (auto *UBB = CRI->getUnwindDest())
168          if (!Result.count(UBB))
169            return false;
170        continue;
171      }
172  
173      if (const CallInst *CI = dyn_cast<CallInst>(I)) {
174        if (const Function *F = CI->getCalledFunction()) {
175          auto IID = F->getIntrinsicID();
176          if (IID == Intrinsic::vastart) {
177            if (AllowVarArgs)
178              continue;
179            else
180              return false;
181          }
182  
183        // Currently, we miscompile outlined copies of eh_typeid_for. There are
184          // proposals for fixing this in llvm.org/PR39545.
185          if (IID == Intrinsic::eh_typeid_for)
186            return false;
187        }
188      }
189    }
190  
191    return true;
192  }
193  
194  /// Build a set of blocks to extract if the input blocks are viable.
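///
/// Illustrative example (not from the original source): given the CFG
///   entry -> A -> B -> exit,  plus an edge entry -> B
/// the candidate set {A, B} is rejected because B, a non-front block, has the
/// outside predecessor `entry`, whereas the set {B} on its own is acceptable
/// since only its front block may be entered from outside.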
195  static SetVector<BasicBlock *>
196  buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
197                          bool AllowVarArgs, bool AllowAlloca) {
198    assert(!BBs.empty() && "The set of blocks to extract must be non-empty");
199    SetVector<BasicBlock *> Result;
200  
201    // Loop over the blocks, adding them to our set-vector, and aborting with an
202    // empty set if we encounter invalid blocks.
203    for (BasicBlock *BB : BBs) {
204      // If this block is dead, don't process it.
205      if (DT && !DT->isReachableFromEntry(BB))
206        continue;
207  
208      if (!Result.insert(BB))
209        llvm_unreachable("Repeated basic blocks in extraction input");
210    }
211  
212    LLVM_DEBUG(dbgs() << "Region front block: " << Result.front()->getName()
213                      << '\n');
214  
215    for (auto *BB : Result) {
216      if (!isBlockValidForExtraction(*BB, Result, AllowVarArgs, AllowAlloca))
217        return {};
218  
219      // Make sure that the first block is not a landing pad.
220      if (BB == Result.front()) {
221        if (BB->isEHPad()) {
222          LLVM_DEBUG(dbgs() << "The first block cannot be an unwind block\n");
223          return {};
224        }
225        continue;
226      }
227  
228      // All blocks other than the first must not have predecessors outside of
229      // the subgraph which is being extracted.
230      for (auto *PBB : predecessors(BB))
231        if (!Result.count(PBB)) {
232          LLVM_DEBUG(dbgs() << "No blocks in this region may have entries from "
233                               "outside the region except for the first block!\n"
234                            << "Problematic source BB: " << BB->getName() << "\n"
235                            << "Problematic destination BB: " << PBB->getName()
236                            << "\n");
237          return {};
238        }
239    }
240  
241    return Result;
242  }
243  
244  CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
245                               bool AggregateArgs, BlockFrequencyInfo *BFI,
246                               BranchProbabilityInfo *BPI, AssumptionCache *AC,
247                               bool AllowVarArgs, bool AllowAlloca,
248                               std::string Suffix)
249      : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
250        BPI(BPI), AC(AC), AllowVarArgs(AllowVarArgs),
251        Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
252        Suffix(Suffix) {}
253  
254  CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
255                               BlockFrequencyInfo *BFI,
256                               BranchProbabilityInfo *BPI, AssumptionCache *AC,
257                               std::string Suffix)
258      : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
259        BPI(BPI), AC(AC), AllowVarArgs(false),
260        Blocks(buildExtractionBlockSet(L.getBlocks(), &DT,
261                                       /* AllowVarArgs */ false,
262                                       /* AllowAlloca */ false)),
263        Suffix(Suffix) {}
264  
265  /// definedInRegion - Return true if the specified value is defined in the
266  /// extracted region.
267  static bool definedInRegion(const SetVector<BasicBlock *> &Blocks, Value *V) {
268    if (Instruction *I = dyn_cast<Instruction>(V))
269      if (Blocks.count(I->getParent()))
270        return true;
271    return false;
272  }
273  
274  /// definedInCaller - Return true if the specified value is defined in the
275  /// function being code extracted, but not in the region being extracted.
276  /// These values must be passed in as live-ins to the function.
277  static bool definedInCaller(const SetVector<BasicBlock *> &Blocks, Value *V) {
278    if (isa<Argument>(V)) return true;
279    if (Instruction *I = dyn_cast<Instruction>(V))
280      if (!Blocks.count(I->getParent()))
281        return true;
282    return false;
283  }
284  
285  static BasicBlock *getCommonExitBlock(const SetVector<BasicBlock *> &Blocks) {
286    BasicBlock *CommonExitBlock = nullptr;
287    auto hasNonCommonExitSucc = [&](BasicBlock *Block) {
288      for (auto *Succ : successors(Block)) {
289        // Internal edges, ok.
290        if (Blocks.count(Succ))
291          continue;
292        if (!CommonExitBlock) {
293          CommonExitBlock = Succ;
294          continue;
295        }
296        if (CommonExitBlock != Succ)
297          return true;
298      }
299      return false;
300    };
301  
302    if (any_of(Blocks, hasNonCommonExitSucc))
303      return nullptr;
304  
305    return CommonExitBlock;
306  }
307  
308  CodeExtractorAnalysisCache::CodeExtractorAnalysisCache(Function &F) {
309    for (BasicBlock &BB : F) {
310      for (Instruction &II : BB.instructionsWithoutDebug())
311        if (auto *AI = dyn_cast<AllocaInst>(&II))
312          Allocas.push_back(AI);
313  
314      findSideEffectInfoForBlock(BB);
315    }
316  }
317  
318  void CodeExtractorAnalysisCache::findSideEffectInfoForBlock(BasicBlock &BB) {
319    for (Instruction &II : BB.instructionsWithoutDebug()) {
320      unsigned Opcode = II.getOpcode();
321      Value *MemAddr = nullptr;
322      switch (Opcode) {
323      case Instruction::Store:
324      case Instruction::Load: {
325        if (Opcode == Instruction::Store) {
326          StoreInst *SI = cast<StoreInst>(&II);
327          MemAddr = SI->getPointerOperand();
328        } else {
329          LoadInst *LI = cast<LoadInst>(&II);
330          MemAddr = LI->getPointerOperand();
331        }
332        // A global variable cannot alias locals.
333        if (isa<Constant>(MemAddr))
334          break;
335        Value *Base = MemAddr->stripInBoundsConstantOffsets();
336        if (!isa<AllocaInst>(Base)) {
337          SideEffectingBlocks.insert(&BB);
338          return;
339        }
340        BaseMemAddrs[&BB].insert(Base);
341        break;
342      }
343      default: {
344        IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(&II);
345        if (IntrInst) {
346          if (IntrInst->isLifetimeStartOrEnd())
347            break;
348          SideEffectingBlocks.insert(&BB);
349          return;
350        }
351        // Treat all other instructions conservatively if they may have side effects.
352        if (II.mayHaveSideEffects()) {
353          SideEffectingBlocks.insert(&BB);
354          return;
355        }
356      }
357      }
358    }
359  }
360  
361  bool CodeExtractorAnalysisCache::doesBlockContainClobberOfAddr(
362      BasicBlock &BB, AllocaInst *Addr) const {
363    if (SideEffectingBlocks.count(&BB))
364      return true;
365    auto It = BaseMemAddrs.find(&BB);
366    if (It != BaseMemAddrs.end())
367      return It->second.count(Addr);
368    return false;
369  }
370  
371  bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers(
372      const CodeExtractorAnalysisCache &CEAC, Instruction *Addr) const {
373    AllocaInst *AI = cast<AllocaInst>(Addr->stripInBoundsConstantOffsets());
374    Function *Func = (*Blocks.begin())->getParent();
375    for (BasicBlock &BB : *Func) {
376      if (Blocks.count(&BB))
377        continue;
378      if (CEAC.doesBlockContainClobberOfAddr(BB, AI))
379        return false;
380    }
381    return true;
382  }
383  
384  BasicBlock *
385  CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) {
386    BasicBlock *SinglePredFromOutlineRegion = nullptr;
387    assert(!Blocks.count(CommonExitBlock) &&
388           "Expect a block outside the region!");
389    for (auto *Pred : predecessors(CommonExitBlock)) {
390      if (!Blocks.count(Pred))
391        continue;
392      if (!SinglePredFromOutlineRegion) {
393        SinglePredFromOutlineRegion = Pred;
394      } else if (SinglePredFromOutlineRegion != Pred) {
395        SinglePredFromOutlineRegion = nullptr;
396        break;
397      }
398    }
399  
400    if (SinglePredFromOutlineRegion)
401      return SinglePredFromOutlineRegion;
402  
403  #ifndef NDEBUG
404    auto getFirstPHI = [](BasicBlock *BB) {
405      BasicBlock::iterator I = BB->begin();
406      PHINode *FirstPhi = nullptr;
407      while (I != BB->end()) {
408        PHINode *Phi = dyn_cast<PHINode>(I);
409        if (!Phi)
410          break;
411        if (!FirstPhi) {
412          FirstPhi = Phi;
413          break;
414        }
415      }
416      return FirstPhi;
417    };
418    // If there are any phi nodes, the single pred either exists or has already
419    // been created before code extraction.
420    assert(!getFirstPHI(CommonExitBlock) && "Phi not expected");
421  #endif
422  
423    BasicBlock *NewExitBlock = CommonExitBlock->splitBasicBlock(
424        CommonExitBlock->getFirstNonPHI()->getIterator());
425  
426    for (auto PI = pred_begin(CommonExitBlock), PE = pred_end(CommonExitBlock);
427         PI != PE;) {
428      BasicBlock *Pred = *PI++;
429      if (Blocks.count(Pred))
430        continue;
431      Pred->getTerminator()->replaceUsesOfWith(CommonExitBlock, NewExitBlock);
432    }
433    // Now add the old exit block to the outline region.
434    Blocks.insert(CommonExitBlock);
435    return CommonExitBlock;
436  }
437  
438  // Find the pair of lifetime markers for address 'Addr' that are either
439  // defined inside the outline region or can legally be shrinkwrapped into the
440  // outline region. If there are no other untracked uses of the address, return
441  // the pair of markers if found; otherwise return a pair of nullptrs.
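//
// Illustrative IR shape (not from the original source) of the markers this
// routine looks for on an address %p (e.g. the bitcast of an alloca):
//   call void @llvm.lifetime.start.p0i8(i64 16, i8* %p)   ; Info.LifeStart
//   ...
//   call void @llvm.lifetime.end.p0i8(i64 16, i8* %p)     ; Info.LifeEnd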
442  CodeExtractor::LifetimeMarkerInfo
443  CodeExtractor::getLifetimeMarkers(const CodeExtractorAnalysisCache &CEAC,
444                                    Instruction *Addr,
445                                    BasicBlock *ExitBlock) const {
446    LifetimeMarkerInfo Info;
447  
448    for (User *U : Addr->users()) {
449      IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(U);
450      if (IntrInst) {
451        if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start) {
452          // Do not handle the case where Addr has multiple start markers.
453          if (Info.LifeStart)
454            return {};
455          Info.LifeStart = IntrInst;
456        }
457        if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) {
458          if (Info.LifeEnd)
459            return {};
460          Info.LifeEnd = IntrInst;
461        }
462        continue;
463      }
464      // Found an untracked use of the address, bail.
465      if (!definedInRegion(Blocks, U))
466        return {};
467    }
468  
469    if (!Info.LifeStart || !Info.LifeEnd)
470      return {};
471  
472    Info.SinkLifeStart = !definedInRegion(Blocks, Info.LifeStart);
473    Info.HoistLifeEnd = !definedInRegion(Blocks, Info.LifeEnd);
474    // Do legality check.
475    if ((Info.SinkLifeStart || Info.HoistLifeEnd) &&
476        !isLegalToShrinkwrapLifetimeMarkers(CEAC, Addr))
477      return {};
478  
479    // Check to see if we have a place to do hoisting; if not, bail.
480    if (Info.HoistLifeEnd && !ExitBlock)
481      return {};
482  
483    return Info;
484  }
485  
486  void CodeExtractor::findAllocas(const CodeExtractorAnalysisCache &CEAC,
487                                  ValueSet &SinkCands, ValueSet &HoistCands,
488                                  BasicBlock *&ExitBlock) const {
489    Function *Func = (*Blocks.begin())->getParent();
490    ExitBlock = getCommonExitBlock(Blocks);
491  
492    auto moveOrIgnoreLifetimeMarkers =
493        [&](const LifetimeMarkerInfo &LMI) -> bool {
494      if (!LMI.LifeStart)
495        return false;
496      if (LMI.SinkLifeStart) {
497        LLVM_DEBUG(dbgs() << "Sinking lifetime.start: " << *LMI.LifeStart
498                          << "\n");
499        SinkCands.insert(LMI.LifeStart);
500      }
501      if (LMI.HoistLifeEnd) {
502        LLVM_DEBUG(dbgs() << "Hoisting lifetime.end: " << *LMI.LifeEnd << "\n");
503        HoistCands.insert(LMI.LifeEnd);
504      }
505      return true;
506    };
507  
508    // Look up allocas in the original function in CodeExtractorAnalysisCache, as
509    // this is much faster than walking all the instructions.
510    for (AllocaInst *AI : CEAC.getAllocas()) {
511      BasicBlock *BB = AI->getParent();
512      if (Blocks.count(BB))
513        continue;
514  
515      // As a prior call to extractCodeRegion() may have shrinkwrapped the alloca,
516      // check whether it is actually still in the original function.
517      Function *AIFunc = BB->getParent();
518      if (AIFunc != Func)
519        continue;
520  
521      LifetimeMarkerInfo MarkerInfo = getLifetimeMarkers(CEAC, AI, ExitBlock);
522      bool Moved = moveOrIgnoreLifetimeMarkers(MarkerInfo);
523      if (Moved) {
524        LLVM_DEBUG(dbgs() << "Sinking alloca: " << *AI << "\n");
525        SinkCands.insert(AI);
526        continue;
527      }
528  
529      // Follow any bitcasts.
530      SmallVector<Instruction *, 2> Bitcasts;
531      SmallVector<LifetimeMarkerInfo, 2> BitcastLifetimeInfo;
532      for (User *U : AI->users()) {
533        if (U->stripInBoundsConstantOffsets() == AI) {
534          Instruction *Bitcast = cast<Instruction>(U);
535          LifetimeMarkerInfo LMI = getLifetimeMarkers(CEAC, Bitcast, ExitBlock);
536          if (LMI.LifeStart) {
537            Bitcasts.push_back(Bitcast);
538            BitcastLifetimeInfo.push_back(LMI);
539            continue;
540          }
541        }
542  
543        // Found unknown use of AI.
544        if (!definedInRegion(Blocks, U)) {
545          Bitcasts.clear();
546          break;
547        }
548      }
549  
550      // Either no bitcasts reference the alloca or there are unknown uses.
551      if (Bitcasts.empty())
552        continue;
553  
554      LLVM_DEBUG(dbgs() << "Sinking alloca (via bitcast): " << *AI << "\n");
555      SinkCands.insert(AI);
556      for (unsigned I = 0, E = Bitcasts.size(); I != E; ++I) {
557        Instruction *BitcastAddr = Bitcasts[I];
558        const LifetimeMarkerInfo &LMI = BitcastLifetimeInfo[I];
559        assert(LMI.LifeStart &&
560               "Unsafe to sink bitcast without lifetime markers");
561        moveOrIgnoreLifetimeMarkers(LMI);
562        if (!definedInRegion(Blocks, BitcastAddr)) {
563          LLVM_DEBUG(dbgs() << "Sinking bitcast-of-alloca: " << *BitcastAddr
564                            << "\n");
565          SinkCands.insert(BitcastAddr);
566        }
567      }
568    }
569  }
570  
571  bool CodeExtractor::isEligible() const {
572    if (Blocks.empty())
573      return false;
574    BasicBlock *Header = *Blocks.begin();
575    Function *F = Header->getParent();
576  
577    // For functions with varargs, check that varargs handling is only done in the
578    // outlined function, i.e. vastart and vaend are only used in outlined blocks.
579    if (AllowVarArgs && F->getFunctionType()->isVarArg()) {
580      auto containsVarArgIntrinsic = [](const Instruction &I) {
581        if (const CallInst *CI = dyn_cast<CallInst>(&I))
582          if (const Function *Callee = CI->getCalledFunction())
583            return Callee->getIntrinsicID() == Intrinsic::vastart ||
584                   Callee->getIntrinsicID() == Intrinsic::vaend;
585        return false;
586      };
587  
588      for (auto &BB : *F) {
589        if (Blocks.count(&BB))
590          continue;
591        if (llvm::any_of(BB, containsVarArgIntrinsic))
592          return false;
593      }
594    }
595    return true;
596  }
597  
598  void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
599                                        const ValueSet &SinkCands) const {
600    for (BasicBlock *BB : Blocks) {
601      // If a used value is defined outside the region, it's an input.  If an
602      // instruction is used outside the region, it's an output.
603      for (Instruction &II : *BB) {
604        for (auto &OI : II.operands()) {
605          Value *V = OI;
606          if (!SinkCands.count(V) && definedInCaller(Blocks, V))
607            Inputs.insert(V);
608        }
609  
610        for (User *U : II.users())
611          if (!definedInRegion(Blocks, U)) {
612            Outputs.insert(&II);
613            break;
614          }
615      }
616    }
617  }
618  
619  /// severSplitPHINodesOfEntry - If a PHI node has multiple inputs from outside
620  /// of the region, we need to split the entry block of the region so that the
621  /// PHI node is easier to deal with.
622  void CodeExtractor::severSplitPHINodesOfEntry(BasicBlock *&Header) {
623    unsigned NumPredsFromRegion = 0;
624    unsigned NumPredsOutsideRegion = 0;
625  
626    if (Header != &Header->getParent()->getEntryBlock()) {
627      PHINode *PN = dyn_cast<PHINode>(Header->begin());
628      if (!PN) return;  // No PHI nodes.
629  
630      // If the header node contains any PHI nodes, check to see if there is more
631      // than one entry from outside the region.  If so, we need to sever the
632      // header block into two.
633      for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
634        if (Blocks.count(PN->getIncomingBlock(i)))
635          ++NumPredsFromRegion;
636        else
637          ++NumPredsOutsideRegion;
638  
639      // If there is one (or fewer) predecessor from outside the region, we don't
640      // need to do anything special.
641      if (NumPredsOutsideRegion <= 1) return;
642    }
643  
644    // Otherwise, we need to split the header block into two pieces: one
645    // containing PHI nodes merging values from outside of the region, and a
646    // second that contains all of the code for the block and merges back any
647    // incoming values from inside of the region.
648    BasicBlock *NewBB = SplitBlock(Header, Header->getFirstNonPHI(), DT);
649  
650    // We only want to code extract the second block now, and it becomes the new
651    // header of the region.
652    BasicBlock *OldPred = Header;
653    Blocks.remove(OldPred);
654    Blocks.insert(NewBB);
655    Header = NewBB;
656  
657    // Okay, now we need to adjust the PHI nodes and any branches from within the
658    // region to go to the new header block instead of the old header block.
659    if (NumPredsFromRegion) {
660      PHINode *PN = cast<PHINode>(OldPred->begin());
661      // Loop over all of the predecessors of OldPred that are in the region,
662      // changing them to branch to NewBB instead.
663      for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
664        if (Blocks.count(PN->getIncomingBlock(i))) {
665          Instruction *TI = PN->getIncomingBlock(i)->getTerminator();
666          TI->replaceUsesOfWith(OldPred, NewBB);
667        }
668  
669      // Okay, everything within the region is now branching to the right block;
670      // we just have to update the PHI nodes now, inserting PHI nodes into NewBB.
671      BasicBlock::iterator AfterPHIs;
672      for (AfterPHIs = OldPred->begin(); isa<PHINode>(AfterPHIs); ++AfterPHIs) {
673        PHINode *PN = cast<PHINode>(AfterPHIs);
674        // Create a new PHI node in the new region, which has an incoming value
675        // from OldPred of PN.
676        PHINode *NewPN = PHINode::Create(PN->getType(), 1 + NumPredsFromRegion,
677                                         PN->getName() + ".ce", &NewBB->front());
678        PN->replaceAllUsesWith(NewPN);
679        NewPN->addIncoming(PN, OldPred);
680  
681        // Loop over all of the incoming values in PN, moving them to NewPN if they
682        // are from the extracted region.
683        for (unsigned i = 0; i != PN->getNumIncomingValues(); ++i) {
684          if (Blocks.count(PN->getIncomingBlock(i))) {
685            NewPN->addIncoming(PN->getIncomingValue(i), PN->getIncomingBlock(i));
686            PN->removeIncomingValue(i);
687            --i;
688          }
689        }
690      }
691    }
692  }
693  
694  /// severSplitPHINodesOfExits - If PHI nodes in exit blocks have inputs from the
695  /// outlined region, we split these PHIs in two: one with inputs from the region
696  /// and the other with the remaining incoming blocks; the first PHIs are then
697  /// placed in the outlined region.
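///
/// Illustrative sketch (not from the original source): an exit-block PHI
///   %v = phi i32 [ %a, %R1 ], [ %b, %R2 ], [ %c, %Out ]
/// where %R1 and %R2 are region blocks becomes, after splitting,
///   %v.ce = phi i32 [ %a, %R1 ], [ %b, %R2 ]              ; in the new .split block
///   %v    = phi i32 [ %c, %Out ], [ %v.ce, %ExitBB.split ] ; back in the exit block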
698  void CodeExtractor::severSplitPHINodesOfExits(
699      const SmallPtrSetImpl<BasicBlock *> &Exits) {
700    for (BasicBlock *ExitBB : Exits) {
701      BasicBlock *NewBB = nullptr;
702  
703      for (PHINode &PN : ExitBB->phis()) {
704        // Find all incoming values from the outlining region.
705        SmallVector<unsigned, 2> IncomingVals;
706        for (unsigned i = 0; i < PN.getNumIncomingValues(); ++i)
707          if (Blocks.count(PN.getIncomingBlock(i)))
708            IncomingVals.push_back(i);
709  
710        // Do not process the PHI if there is one (or fewer) predecessor from the
711        // region. If the PHI has exactly one predecessor from the region, only that
712        // incoming value is replaced in the codeRepl block, so it is safe to skip.
713        if (IncomingVals.size() <= 1)
714          continue;
715  
716        // Create a block for the new PHIs and add it to the list of outlined
717        // blocks if this wasn't done before.
718        if (!NewBB) {
719          NewBB = BasicBlock::Create(ExitBB->getContext(),
720                                     ExitBB->getName() + ".split",
721                                     ExitBB->getParent(), ExitBB);
722          SmallVector<BasicBlock *, 4> Preds(pred_begin(ExitBB),
723                                             pred_end(ExitBB));
724          for (BasicBlock *PredBB : Preds)
725            if (Blocks.count(PredBB))
726              PredBB->getTerminator()->replaceUsesOfWith(ExitBB, NewBB);
727          BranchInst::Create(ExitBB, NewBB);
728          Blocks.insert(NewBB);
729        }
730  
731        // Split this PHI.
732        PHINode *NewPN =
733            PHINode::Create(PN.getType(), IncomingVals.size(),
734                            PN.getName() + ".ce", NewBB->getFirstNonPHI());
735        for (unsigned i : IncomingVals)
736          NewPN->addIncoming(PN.getIncomingValue(i), PN.getIncomingBlock(i));
737        for (unsigned i : reverse(IncomingVals))
738          PN.removeIncomingValue(i, false);
739        PN.addIncoming(NewPN, NewBB);
740      }
741    }
742  }
743  
744  void CodeExtractor::splitReturnBlocks() {
745    for (BasicBlock *Block : Blocks)
746      if (ReturnInst *RI = dyn_cast<ReturnInst>(Block->getTerminator())) {
747        BasicBlock *New =
748            Block->splitBasicBlock(RI->getIterator(), Block->getName() + ".ret");
749        if (DT) {
750          // Old dominates New. New node dominates all other nodes dominated
751          // by Old.
752          DomTreeNode *OldNode = DT->getNode(Block);
753          SmallVector<DomTreeNode *, 8> Children(OldNode->begin(),
754                                                 OldNode->end());
755  
756          DomTreeNode *NewNode = DT->addNewBlock(New, Block);
757  
758          for (DomTreeNode *I : Children)
759            DT->changeImmediateDominator(I, NewNode);
760        }
761      }
762  }
763  
764  /// constructFunction - make a function based on inputs and outputs, as follows:
765  /// f(in0, ..., inN, out0, ..., outN)
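///
/// As a rough illustration (not from the original source), with two i32 inputs,
/// one i32 output and at most one exit block this produces
///   define internal void @old.suffix(i32 %in0, i32 %in1, i32* %out0.out)
/// or, when arguments are aggregated,
///   define internal void @old.suffix({ i32, i32, i32 }*)
/// The return type is void for 0 or 1 exit blocks, i1 for 2, and i16 otherwise.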
766  Function *CodeExtractor::constructFunction(const ValueSet &inputs,
767                                             const ValueSet &outputs,
768                                             BasicBlock *header,
769                                             BasicBlock *newRootNode,
770                                             BasicBlock *newHeader,
771                                             Function *oldFunction,
772                                             Module *M) {
773    LLVM_DEBUG(dbgs() << "inputs: " << inputs.size() << "\n");
774    LLVM_DEBUG(dbgs() << "outputs: " << outputs.size() << "\n");
775  
776    // The new function returns an exit block selector; outputs go back by reference.
777    switch (NumExitBlocks) {
778    case 0:
779    case 1: RetTy = Type::getVoidTy(header->getContext()); break;
780    case 2: RetTy = Type::getInt1Ty(header->getContext()); break;
781    default: RetTy = Type::getInt16Ty(header->getContext()); break;
782    }
783  
784    std::vector<Type *> paramTy;
785  
786    // Add the types of the input values to the function's argument list
787    for (Value *value : inputs) {
788      LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
789      paramTy.push_back(value->getType());
790    }
791  
792    // Add the types of the output values to the function's argument list.
793    for (Value *output : outputs) {
794      LLVM_DEBUG(dbgs() << "instr used in func: " << *output << "\n");
795      if (AggregateArgs)
796        paramTy.push_back(output->getType());
797      else
798        paramTy.push_back(PointerType::getUnqual(output->getType()));
799    }
800  
801    LLVM_DEBUG({
802      dbgs() << "Function type: " << *RetTy << " f(";
803      for (Type *i : paramTy)
804        dbgs() << *i << ", ";
805      dbgs() << ")\n";
806    });
807  
808    StructType *StructTy = nullptr;
809    if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
810      StructTy = StructType::get(M->getContext(), paramTy);
811      paramTy.clear();
812      paramTy.push_back(PointerType::getUnqual(StructTy));
813    }
814    FunctionType *funcType =
815                    FunctionType::get(RetTy, paramTy,
816                                      AllowVarArgs && oldFunction->isVarArg());
817  
818    std::string SuffixToUse =
819        Suffix.empty()
820            ? (header->getName().empty() ? "extracted" : header->getName().str())
821            : Suffix;
822    // Create the new function
823    Function *newFunction = Function::Create(
824        funcType, GlobalValue::InternalLinkage, oldFunction->getAddressSpace(),
825        oldFunction->getName() + "." + SuffixToUse, M);
826    // If the old function is no-throw, so is the new one.
827    if (oldFunction->doesNotThrow())
828      newFunction->setDoesNotThrow();
829  
830    // Inherit the uwtable attribute if we need to.
831    if (oldFunction->hasUWTable())
832      newFunction->setHasUWTable();
833  
834    // Inherit all of the target-dependent attributes and white-listed
835    // target-independent attributes.
836    // (e.g. if the extracted region contains a call to an x86.sse
837    // instruction, we need to make sure that the extracted region has the
838    // "target-features" attribute allowing it to be lowered.)
839    // FIXME: This should be changed to check to see if a specific
840    //        attribute cannot be inherited.
841    for (const auto &Attr : oldFunction->getAttributes().getFnAttributes()) {
842      if (Attr.isStringAttribute()) {
843        if (Attr.getKindAsString() == "thunk")
844          continue;
845      } else
846        switch (Attr.getKindAsEnum()) {
847        // Those attributes cannot be propagated safely. Explicitly list them
848        // here so we get a warning if new attributes are added. This list also
849        // includes non-function attributes.
850        case Attribute::Alignment:
851        case Attribute::AllocSize:
852        case Attribute::ArgMemOnly:
853        case Attribute::Builtin:
854        case Attribute::ByVal:
855        case Attribute::Convergent:
856        case Attribute::Dereferenceable:
857        case Attribute::DereferenceableOrNull:
858        case Attribute::InAlloca:
859        case Attribute::InReg:
860        case Attribute::InaccessibleMemOnly:
861        case Attribute::InaccessibleMemOrArgMemOnly:
862        case Attribute::JumpTable:
863        case Attribute::Naked:
864        case Attribute::Nest:
865        case Attribute::NoAlias:
866        case Attribute::NoBuiltin:
867        case Attribute::NoCapture:
868        case Attribute::NoReturn:
869        case Attribute::NoSync:
870        case Attribute::None:
871        case Attribute::NonNull:
872        case Attribute::ReadNone:
873        case Attribute::ReadOnly:
874        case Attribute::Returned:
875        case Attribute::ReturnsTwice:
876        case Attribute::SExt:
877        case Attribute::Speculatable:
878        case Attribute::StackAlignment:
879        case Attribute::StructRet:
880        case Attribute::SwiftError:
881        case Attribute::SwiftSelf:
882        case Attribute::WillReturn:
883        case Attribute::WriteOnly:
884        case Attribute::ZExt:
885        case Attribute::ImmArg:
886        case Attribute::EndAttrKinds:
887          continue;
888        // Those attributes should be safe to propagate to the extracted function.
889        case Attribute::AlwaysInline:
890        case Attribute::Cold:
891        case Attribute::NoRecurse:
892        case Attribute::InlineHint:
893        case Attribute::MinSize:
894        case Attribute::NoDuplicate:
895        case Attribute::NoFree:
896        case Attribute::NoImplicitFloat:
897        case Attribute::NoInline:
898        case Attribute::NonLazyBind:
899        case Attribute::NoRedZone:
900        case Attribute::NoUnwind:
901        case Attribute::OptForFuzzing:
902        case Attribute::OptimizeNone:
903        case Attribute::OptimizeForSize:
904        case Attribute::SafeStack:
905        case Attribute::ShadowCallStack:
906        case Attribute::SanitizeAddress:
907        case Attribute::SanitizeMemory:
908        case Attribute::SanitizeThread:
909        case Attribute::SanitizeHWAddress:
910        case Attribute::SanitizeMemTag:
911        case Attribute::SpeculativeLoadHardening:
912        case Attribute::StackProtect:
913        case Attribute::StackProtectReq:
914        case Attribute::StackProtectStrong:
915        case Attribute::StrictFP:
916        case Attribute::UWTable:
917        case Attribute::NoCfCheck:
918          break;
919        }
920  
921      newFunction->addFnAttr(Attr);
922    }
923    newFunction->getBasicBlockList().push_back(newRootNode);
924  
925    // Create an iterator to name all of the arguments we inserted.
926    Function::arg_iterator AI = newFunction->arg_begin();
927  
928    // Rewrite all users of the inputs in the extracted region to use the
929    // arguments (or appropriate addressing into struct) instead.
930    for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
931      Value *RewriteVal;
932      if (AggregateArgs) {
933        Value *Idx[2];
934        Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext()));
935        Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), i);
936        Instruction *TI = newFunction->begin()->getTerminator();
937        GetElementPtrInst *GEP = GetElementPtrInst::Create(
938            StructTy, &*AI, Idx, "gep_" + inputs[i]->getName(), TI);
939        RewriteVal = new LoadInst(StructTy->getElementType(i), GEP,
940                                  "loadgep_" + inputs[i]->getName(), TI);
941      } else
942        RewriteVal = &*AI++;
943  
944      std::vector<User *> Users(inputs[i]->user_begin(), inputs[i]->user_end());
945      for (User *use : Users)
946        if (Instruction *inst = dyn_cast<Instruction>(use))
947          if (Blocks.count(inst->getParent()))
948            inst->replaceUsesOfWith(inputs[i], RewriteVal);
949    }
950  
951    // Set names for input and output arguments.
952    if (!AggregateArgs) {
953      AI = newFunction->arg_begin();
954      for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI)
955        AI->setName(inputs[i]->getName());
956      for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI)
957        AI->setName(outputs[i]->getName()+".out");
958    }
959  
960    // Rewrite branches to basic blocks outside of the region to new dummy blocks
961    // within the new function. This must be done before we lose track of which
962    // blocks were originally in the code region.
963    std::vector<User *> Users(header->user_begin(), header->user_end());
964    for (auto &U : Users)
965      // If the BasicBlock which contains the branch is not in the region,
966      // modify the branch target to the new block.
967      if (Instruction *I = dyn_cast<Instruction>(U))
968        if (I->isTerminator() && I->getFunction() == oldFunction &&
969            !Blocks.count(I->getParent()))
970          I->replaceUsesOfWith(header, newHeader);
971  
972    return newFunction;
973  }
974  
975  /// Erase lifetime.start markers which reference inputs to the extraction
976  /// region, and insert the referenced memory into \p LifetimesStart.
977  ///
978  /// The extraction region is defined by a set of blocks (\p Blocks), and a set
979  /// of allocas which will be moved from the caller function into the extracted
980  /// function (\p SunkAllocas).
981  static void eraseLifetimeMarkersOnInputs(const SetVector<BasicBlock *> &Blocks,
982                                           const SetVector<Value *> &SunkAllocas,
983                                           SetVector<Value *> &LifetimesStart) {
984    for (BasicBlock *BB : Blocks) {
985      for (auto It = BB->begin(), End = BB->end(); It != End;) {
986        auto *II = dyn_cast<IntrinsicInst>(&*It);
987        ++It;
988        if (!II || !II->isLifetimeStartOrEnd())
989          continue;
990  
991        // Get the memory operand of the lifetime marker. If the underlying
992        // object is a sunk alloca, or is otherwise defined in the extraction
993        // region, the lifetime marker must not be erased.
994        Value *Mem = II->getOperand(1)->stripInBoundsOffsets();
995        if (SunkAllocas.count(Mem) || definedInRegion(Blocks, Mem))
996          continue;
997  
998        if (II->getIntrinsicID() == Intrinsic::lifetime_start)
999          LifetimesStart.insert(Mem);
1000        II->eraseFromParent();
1001      }
1002    }
1003  }
1004  
1005  /// Insert lifetime start/end markers surrounding the call to the new function
1006  /// for objects defined in the caller.
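///
/// Illustrative result (not from the original source): for one object %mem the
/// codeRepl block ends up shaped roughly as
///   %lt.cast = bitcast <ty>* %mem to i8*
///   call void @llvm.lifetime.start.p0i8(i64 -1, i8* %lt.cast)
///   %targetBlock = call i16 @old.extracted(...)        ; TheCall
///   ...
///   call void @llvm.lifetime.end.p0i8(i64 -1, i8* %lt.cast)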
1007  static void insertLifetimeMarkersSurroundingCall(
1008      Module *M, ArrayRef<Value *> LifetimesStart, ArrayRef<Value *> LifetimesEnd,
1009      CallInst *TheCall) {
1010    LLVMContext &Ctx = M->getContext();
1011    auto Int8PtrTy = Type::getInt8PtrTy(Ctx);
1012    auto NegativeOne = ConstantInt::getSigned(Type::getInt64Ty(Ctx), -1);
1013    Instruction *Term = TheCall->getParent()->getTerminator();
1014  
1015    // The memory argument to a lifetime marker must be an i8*. Cache any bitcasts
1016    // needed to satisfy this requirement so they may be reused.
1017    DenseMap<Value *, Value *> Bitcasts;
1018  
1019    // Emit lifetime markers for the pointers given in \p Objects. Insert the
1020    // markers before the call if \p InsertBefore, and after the call otherwise.
1021    auto insertMarkers = [&](Function *MarkerFunc, ArrayRef<Value *> Objects,
1022                             bool InsertBefore) {
1023      for (Value *Mem : Objects) {
1024        assert((!isa<Instruction>(Mem) || cast<Instruction>(Mem)->getFunction() ==
1025                                              TheCall->getFunction()) &&
1026               "Input memory not defined in original function");
1027        Value *&MemAsI8Ptr = Bitcasts[Mem];
1028        if (!MemAsI8Ptr) {
1029          if (Mem->getType() == Int8PtrTy)
1030            MemAsI8Ptr = Mem;
1031          else
1032            MemAsI8Ptr =
1033                CastInst::CreatePointerCast(Mem, Int8PtrTy, "lt.cast", TheCall);
1034        }
1035  
1036        auto Marker = CallInst::Create(MarkerFunc, {NegativeOne, MemAsI8Ptr});
1037        if (InsertBefore)
1038          Marker->insertBefore(TheCall);
1039        else
1040          Marker->insertBefore(Term);
1041      }
1042    };
1043  
1044    if (!LifetimesStart.empty()) {
1045      auto StartFn = llvm::Intrinsic::getDeclaration(
1046          M, llvm::Intrinsic::lifetime_start, Int8PtrTy);
1047      insertMarkers(StartFn, LifetimesStart, /*InsertBefore=*/true);
1048    }
1049  
1050    if (!LifetimesEnd.empty()) {
1051      auto EndFn = llvm::Intrinsic::getDeclaration(
1052          M, llvm::Intrinsic::lifetime_end, Int8PtrTy);
1053      insertMarkers(EndFn, LifetimesEnd, /*InsertBefore=*/false);
1054    }
1055  }
1056  
1057  /// emitCallAndSwitchStatement - This method sets up the caller side by adding
1058  /// the call instruction, splitting any PHI nodes in the header block as
1059  /// necessary.
1060  CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
1061                                                      BasicBlock *codeReplacer,
1062                                                      ValueSet &inputs,
1063                                                      ValueSet &outputs) {
1064    // Emit a call to the new function, passing in: a pointer to the struct (if
1065    // aggregating parameters), or plain inputs and allocated memory for outputs.
1066    std::vector<Value *> params, StructValues, ReloadOutputs, Reloads;
1067  
1068    Module *M = newFunction->getParent();
1069    LLVMContext &Context = M->getContext();
1070    const DataLayout &DL = M->getDataLayout();
1071    CallInst *call = nullptr;
1072  
1073    // Add inputs as params, or to be filled into the struct
1074    unsigned ArgNo = 0;
1075    SmallVector<unsigned, 1> SwiftErrorArgs;
1076    for (Value *input : inputs) {
1077      if (AggregateArgs)
1078        StructValues.push_back(input);
1079      else {
1080        params.push_back(input);
1081        if (input->isSwiftError())
1082          SwiftErrorArgs.push_back(ArgNo);
1083      }
1084      ++ArgNo;
1085    }
1086  
1087    // Create allocas for the outputs
1088    for (Value *output : outputs) {
1089      if (AggregateArgs) {
1090        StructValues.push_back(output);
1091      } else {
1092        AllocaInst *alloca =
1093          new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
1094                         nullptr, output->getName() + ".loc",
1095                         &codeReplacer->getParent()->front().front());
1096        ReloadOutputs.push_back(alloca);
1097        params.push_back(alloca);
1098      }
1099    }
1100  
1101    StructType *StructArgTy = nullptr;
1102    AllocaInst *Struct = nullptr;
1103    if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
1104      std::vector<Type *> ArgTypes;
1105      for (ValueSet::iterator v = StructValues.begin(),
1106             ve = StructValues.end(); v != ve; ++v)
1107        ArgTypes.push_back((*v)->getType());
1108  
1109      // Allocate a struct at the beginning of this function
1110      StructArgTy = StructType::get(newFunction->getContext(), ArgTypes);
1111      Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
1112                              "structArg",
1113                              &codeReplacer->getParent()->front().front());
1114      params.push_back(Struct);
1115  
1116      for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
1117        Value *Idx[2];
1118        Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
1119        Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
1120        GetElementPtrInst *GEP = GetElementPtrInst::Create(
1121            StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
1122        codeReplacer->getInstList().push_back(GEP);
1123        StoreInst *SI = new StoreInst(StructValues[i], GEP);
1124        codeReplacer->getInstList().push_back(SI);
1125      }
1126    }
1127  
1128    // Emit the call to the function
1129    call = CallInst::Create(newFunction, params,
1130                            NumExitBlocks > 1 ? "targetBlock" : "");
1131    // Add debug location to the new call, if the original function has debug
1132    // info. In that case, the terminator of the entry block of the extracted
1133    // function contains the first debug location of the extracted function,
1134    // set in extractCodeRegion.
1135    if (codeReplacer->getParent()->getSubprogram()) {
1136      if (auto DL = newFunction->getEntryBlock().getTerminator()->getDebugLoc())
1137        call->setDebugLoc(DL);
1138    }
1139    codeReplacer->getInstList().push_back(call);
1140  
1141    // Set swifterror parameter attributes.
1142    for (unsigned SwiftErrArgNo : SwiftErrorArgs) {
1143      call->addParamAttr(SwiftErrArgNo, Attribute::SwiftError);
1144      newFunction->addParamAttr(SwiftErrArgNo, Attribute::SwiftError);
1145    }
1146  
1147    Function::arg_iterator OutputArgBegin = newFunction->arg_begin();
1148    unsigned FirstOut = inputs.size();
1149    if (!AggregateArgs)
1150      std::advance(OutputArgBegin, inputs.size());
1151  
1152    // Reload the outputs passed in by reference.
1153    for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
1154      Value *Output = nullptr;
1155      if (AggregateArgs) {
1156        Value *Idx[2];
1157        Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
1158        Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
1159        GetElementPtrInst *GEP = GetElementPtrInst::Create(
1160            StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName());
1161        codeReplacer->getInstList().push_back(GEP);
1162        Output = GEP;
1163      } else {
1164        Output = ReloadOutputs[i];
1165      }
1166      LoadInst *load = new LoadInst(outputs[i]->getType(), Output,
1167                                    outputs[i]->getName() + ".reload");
1168      Reloads.push_back(load);
1169      codeReplacer->getInstList().push_back(load);
1170      std::vector<User *> Users(outputs[i]->user_begin(), outputs[i]->user_end());
1171      for (unsigned u = 0, e = Users.size(); u != e; ++u) {
1172        Instruction *inst = cast<Instruction>(Users[u]);
1173        if (!Blocks.count(inst->getParent()))
1174          inst->replaceUsesOfWith(outputs[i], load);
1175      }
1176    }
1177  
1178    // Now we can emit a switch statement using the call as a value.
1179    SwitchInst *TheSwitch =
1180        SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context)),
1181                           codeReplacer, 0, codeReplacer);
1182  
1183    // Since there may be multiple exits from the original region, make the new
1184    // function return an unsigned value and switch on that number.  This loop
1185    // iterates over all of the blocks in the extracted region, updating any
1186    // terminator instructions in the to-be-extracted region that branch to blocks
1187    // that are not in the region to be extracted.
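  //
  // For example (illustrative only), with three exit blocks %exitA, %exitB and
  // %exitC, the caller side is eventually shaped roughly as
  //   %targetBlock = call i16 @old.extracted(...)
  //   switch i16 %targetBlock, label %exitC [ i16 0, label %exitA
  //                                           i16 1, label %exitB ]
  // after the simplification at the end of this function.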
1188    std::map<BasicBlock *, BasicBlock *> ExitBlockMap;
1189  
1190    unsigned switchVal = 0;
1191    for (BasicBlock *Block : Blocks) {
1192      Instruction *TI = Block->getTerminator();
1193      for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
1194        if (!Blocks.count(TI->getSuccessor(i))) {
1195          BasicBlock *OldTarget = TI->getSuccessor(i);
1196          // add a new basic block which returns the appropriate value
1197          BasicBlock *&NewTarget = ExitBlockMap[OldTarget];
1198          if (!NewTarget) {
1199            // If we don't already have an exit stub for this non-extracted
1200            // destination, create one now!
1201            NewTarget = BasicBlock::Create(Context,
1202                                           OldTarget->getName() + ".exitStub",
1203                                           newFunction);
1204            unsigned SuccNum = switchVal++;
1205  
1206            Value *brVal = nullptr;
1207            switch (NumExitBlocks) {
1208            case 0:
1209            case 1: break;  // No value needed.
1210            case 2:         // Conditional branch, return a bool
1211              brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum);
1212              break;
1213            default:
1214              brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum);
1215              break;
1216            }
1217  
1218            ReturnInst::Create(Context, brVal, NewTarget);
1219  
1220            // Update the switch instruction.
1221            TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context),
1222                                                SuccNum),
1223                               OldTarget);
1224          }
1225  
1226          // rewrite the original branch instruction with this new target
1227          TI->setSuccessor(i, NewTarget);
1228        }
1229    }
1230  
1231    // Store the arguments right after the definition of the output value.
1232    // This must be done after creating the exit stubs to ensure that the store of
1233    // an invoke's result is placed in the outlined function.
1234    Function::arg_iterator OAI = OutputArgBegin;
1235    for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
1236      auto *OutI = dyn_cast<Instruction>(outputs[i]);
1237      if (!OutI)
1238        continue;
1239  
1240      // Find proper insertion point.
1241      BasicBlock::iterator InsertPt;
1242      // In case OutI is an invoke, we insert the store at the beginning in the
1243      // 'normal destination' BB. Otherwise we insert the store right after OutI.
1244      if (auto *InvokeI = dyn_cast<InvokeInst>(OutI))
1245        InsertPt = InvokeI->getNormalDest()->getFirstInsertionPt();
1246      else if (auto *Phi = dyn_cast<PHINode>(OutI))
1247        InsertPt = Phi->getParent()->getFirstInsertionPt();
1248      else
1249        InsertPt = std::next(OutI->getIterator());
1250  
1251      Instruction *InsertBefore = &*InsertPt;
1252      assert((InsertBefore->getFunction() == newFunction ||
1253              Blocks.count(InsertBefore->getParent())) &&
1254             "InsertPt should be in new function");
1255      assert(OAI != newFunction->arg_end() &&
1256             "Number of output arguments should match "
1257             "the amount of defined values");
1258      if (AggregateArgs) {
1259        Value *Idx[2];
1260        Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
1261        Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
1262        GetElementPtrInst *GEP = GetElementPtrInst::Create(
1263            StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(),
1264            InsertBefore);
1265        new StoreInst(outputs[i], GEP, InsertBefore);
1266        // Since there should be only one struct argument aggregating
1267        // all the output values, we shouldn't increment OAI, which always
1268        // points to the struct argument, in this case.
1269      } else {
1270        new StoreInst(outputs[i], &*OAI, InsertBefore);
1271        ++OAI;
1272      }
1273    }
1274  
1275    // Now that we've done the deed, simplify the switch instruction.
1276    Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType();
1277    switch (NumExitBlocks) {
1278    case 0:
1279      // There are no successors of the block containing the switch itself, which
1280      // means that previously this was the last part of the function, and hence
1281      // this should be rewritten as a `ret'.
1282  
1283      // Check if the function should return a value
1284      if (OldFnRetTy->isVoidTy()) {
1285        ReturnInst::Create(Context, nullptr, TheSwitch);  // Return void
1286      } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) {
1287        // return what we have
1288        ReturnInst::Create(Context, TheSwitch->getCondition(), TheSwitch);
1289      } else {
1290        // Otherwise we must have extracted code containing an unwind or something;
1291        // just return whatever we want.
1292        ReturnInst::Create(Context,
1293                           Constant::getNullValue(OldFnRetTy), TheSwitch);
1294      }
1295  
1296      TheSwitch->eraseFromParent();
1297      break;
1298    case 1:
1299      // Only a single destination, change the switch into an unconditional
1300      // branch.
1301      BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch);
1302      TheSwitch->eraseFromParent();
1303      break;
1304    case 2:
1305      BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2),
1306                         call, TheSwitch);
1307      TheSwitch->eraseFromParent();
1308      break;
1309    default:
1310      // Otherwise, make the default destination of the switch instruction be one
1311      // of the other successors.
1312      TheSwitch->setCondition(call);
1313      TheSwitch->setDefaultDest(TheSwitch->getSuccessor(NumExitBlocks));
1314      // Remove redundant case
1315      TheSwitch->removeCase(SwitchInst::CaseIt(TheSwitch, NumExitBlocks-1));
1316      break;
1317    }
1318  
1319    // Insert lifetime markers around the reloads of any output values. The allocas
1320    // that output values are stored in are only in use in the codeRepl block.
1321    insertLifetimeMarkersSurroundingCall(M, ReloadOutputs, ReloadOutputs, call);
1322  
1323    return call;
1324  }
1325  
1326  void CodeExtractor::moveCodeToFunction(Function *newFunction) {
1327    Function *oldFunc = (*Blocks.begin())->getParent();
1328    Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList();
1329    Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList();
1330  
1331    for (BasicBlock *Block : Blocks) {
1332      // Delete the basic block from the old function, and the list of blocks
1333      oldBlocks.remove(Block);
1334  
1335      // Insert this basic block into the new function
1336      newBlocks.push_back(Block);
1337    }
1338  }
1339  
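      // Convert the accumulated exit-block frequencies in \p ExitWeights into
      // normalized branch weights on \p CodeReplacer's terminator, and keep
      // \p BPI's edge probabilities consistent with them.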
1340  void CodeExtractor::calculateNewCallTerminatorWeights(
1341      BasicBlock *CodeReplacer,
1342      DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
1343      BranchProbabilityInfo *BPI) {
1344    using Distribution = BlockFrequencyInfoImplBase::Distribution;
1345    using BlockNode = BlockFrequencyInfoImplBase::BlockNode;
1346  
1347    // Update the branch weights for the exit block.
1348    Instruction *TI = CodeReplacer->getTerminator();
1349    SmallVector<unsigned, 8> BranchWeights(TI->getNumSuccessors(), 0);
1350  
1351    // Block Frequency distribution with dummy node.
1352    Distribution BranchDist;
1353  
1354    // Add each of the frequencies of the successors.
1355    for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) {
1356      BlockNode ExitNode(i);
1357      uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency();
1358      if (ExitFreq != 0)
1359        BranchDist.addExit(ExitNode, ExitFreq);
1360      else
1361        BPI->setEdgeProbability(CodeReplacer, i, BranchProbability::getZero());
1362    }
1363  
1364    // Check for no total weight.
1365    if (BranchDist.Total == 0)
1366      return;
1367  
1368    // Normalize the distribution so that the weights fit in unsigned.
1369    BranchDist.normalize();
1370  
1371    // Create normalized branch weights and set the metadata.
1372    for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) {
1373      const auto &Weight = BranchDist.Weights[I];
1374  
1375      // Get the weight and update the current BFI.
1376      BranchWeights[Weight.TargetNode.Index] = Weight.Amount;
1377      BranchProbability BP(Weight.Amount, BranchDist.Total);
1378      BPI->setEdgeProbability(CodeReplacer, Weight.TargetNode.Index, BP);
1379    }
1380    TI->setMetadata(
1381        LLVMContext::MD_prof,
1382        MDBuilder(TI->getContext()).createBranchWeights(BranchWeights));
1383  }
1384  
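      // Illustrative sketch of how a client typically drives the extraction
      // (variable names here are the caller's choice; see CodeExtractor.h for
      // the full set of constructor options):
      //
      //   CodeExtractor CE(BlocksToExtract, &DT);
      //   CodeExtractorAnalysisCache CEAC(F);
      //   if (CE.isEligible())
      //     if (Function *Outlined = CE.extractCodeRegion(CEAC))
      //       ; // the region in F has been replaced by a call to Outlined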
1385  Function *
1386  CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC) {
1387    if (!isEligible())
1388      return nullptr;
1389  
1390    // Assumption: this is a single-entry code region, and the header is the first
1391    // block in the region.
1392    BasicBlock *header = *Blocks.begin();
1393    Function *oldFunction = header->getParent();
1394  
1395    // Calculate the entry frequency of the new function before we change the root
1396    //   block.
1397    BlockFrequency EntryFreq;
1398    if (BFI) {
1399      assert(BPI && "Both BPI and BFI are required to preserve profile info");
1400      for (BasicBlock *Pred : predecessors(header)) {
1401        if (Blocks.count(Pred))
1402          continue;
1403        EntryFreq +=
1404            BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header);
1405      }
1406    }
1407  
1408    if (AC) {
1409      // Remove @llvm.assume calls that were moved to the new function from the
1410      // old function's assumption cache.
1411      for (BasicBlock *Block : Blocks)
1412        for (auto &I : *Block)
1413          if (match(&I, m_Intrinsic<Intrinsic::assume>()))
1414            AC->unregisterAssumption(cast<CallInst>(&I));
1415    }
1416  
1417    // If we have any return instructions in the region, split those blocks so
1418    // that the return is not in the region.
1419    splitReturnBlocks();
1420  
1421    // Calculate the exit blocks for the extracted region and the total exit
1422    // weights for each of those blocks.
1423    DenseMap<BasicBlock *, BlockFrequency> ExitWeights;
1424    SmallPtrSet<BasicBlock *, 1> ExitBlocks;
1425    for (BasicBlock *Block : Blocks) {
1426      for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE;
1427           ++SI) {
1428        if (!Blocks.count(*SI)) {
1429          // Update the branch weight for this successor.
1430          if (BFI) {
1431            BlockFrequency &BF = ExitWeights[*SI];
1432            BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI);
1433          }
1434          ExitBlocks.insert(*SI);
1435        }
1436      }
1437    }
1438    NumExitBlocks = ExitBlocks.size();
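        // NumExitBlocks determines both the return type of the outlined function
        // (see constructFunction) and how the switch in codeRepl is simplified in
        // emitCallAndSwitchStatement.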
1439  
1440    // If we have to split PHI nodes of the entry or exit blocks, do so now.
1441    severSplitPHINodesOfEntry(header);
1442    severSplitPHINodesOfExits(ExitBlocks);
1443  
1444    // This block takes the place of the extracted region in the old function.
1445    BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(),
1446                                                  "codeRepl", oldFunction,
1447                                                  header);
1448  
1449    // The new function needs a root node because other nodes can branch to the
1450    // head of the region, but the entry node of a function cannot have preds.
1451    BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(),
1452                                                 "newFuncRoot");
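        // newFuncRoot is created detached here; it becomes the entry block of the
        // extracted function when constructFunction() builds that function below.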
1453    auto *BranchI = BranchInst::Create(header);
1454    // If the original function has debug info, we have to add a debug location
1455    // to the new branch instruction from the artificial entry block.
1456    // We use the debug location of the first instruction in the extracted
1457    // blocks, as there is no other equivalent line in the source code.
1458    if (oldFunction->getSubprogram()) {
1459      any_of(Blocks, [&BranchI](const BasicBlock *BB) {
1460        return any_of(*BB, [&BranchI](const Instruction &I) {
1461          if (!I.getDebugLoc())
1462            return false;
1463          BranchI->setDebugLoc(I.getDebugLoc());
1464          return true;
1465        });
1466      });
1467    }
1468    newFuncRoot->getInstList().push_back(BranchI);
1469  
1470    ValueSet inputs, outputs, SinkingCands, HoistingCands;
1471    BasicBlock *CommonExit = nullptr;
1472    findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
1473    assert(HoistingCands.empty() || CommonExit);
1474  
1475    // Find inputs to, outputs from the code region.
1476    findInputsOutputs(inputs, outputs, SinkingCands);
1477  
1478    // Now sink all instructions which only have non-phi uses inside the region.
1479    // Group the allocas at the start of the block, so that any bitcast uses of
1480    // the allocas are well-defined.
1481    AllocaInst *FirstSunkAlloca = nullptr;
1482    for (auto *II : SinkingCands) {
1483      if (auto *AI = dyn_cast<AllocaInst>(II)) {
1484        AI->moveBefore(*newFuncRoot, newFuncRoot->getFirstInsertionPt());
1485        if (!FirstSunkAlloca)
1486          FirstSunkAlloca = AI;
1487      }
1488    }
1489    assert((SinkingCands.empty() || FirstSunkAlloca) &&
1490           "Did not expect a sink candidate without any allocas");
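        // Place the remaining, non-alloca sink candidates right after the grouped
        // allocas at the top of newFuncRoot.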
1491    for (auto *II : SinkingCands) {
1492      if (!isa<AllocaInst>(II)) {
1493        cast<Instruction>(II)->moveAfter(FirstSunkAlloca);
1494      }
1495    }
1496  
1497    if (!HoistingCands.empty()) {
1498      auto *HoistToBlock = findOrCreateBlockForHoisting(CommonExit);
1499      Instruction *TI = HoistToBlock->getTerminator();
1500      for (auto *II : HoistingCands)
1501        cast<Instruction>(II)->moveBefore(TI);
1502    }
1503  
1504    // Collect objects which are inputs to the extraction region and also
1505    // referenced by lifetime start markers within it. The effects of these
1506    // markers must be replicated in the calling function to prevent the stack
1507    // coloring pass from merging slots which store input objects.
1508    ValueSet LifetimesStart;
1509    eraseLifetimeMarkersOnInputs(Blocks, SinkingCands, LifetimesStart);
1510  
1511    // Construct new function based on inputs/outputs & add allocas for all defs.
1512    Function *newFunction =
1513        constructFunction(inputs, outputs, header, newFuncRoot, codeReplacer,
1514                          oldFunction, oldFunction->getParent());
1515  
1516    // Update the entry count of the function.
1517    if (BFI) {
1518      auto Count = BFI->getProfileCountFromFreq(EntryFreq.getFrequency());
1519      if (Count.hasValue())
1520        newFunction->setEntryCount(
1521            ProfileCount(Count.getValue(), Function::PCT_Real)); // FIXME
1522      BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency());
1523    }
1524  
1525    CallInst *TheCall =
1526        emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
1527  
1528    moveCodeToFunction(newFunction);
1529  
1530    // Replicate the effects of any lifetime start/end markers which referenced
1531    // input objects in the extraction region by placing markers around the call.
1532    insertLifetimeMarkersSurroundingCall(
1533        oldFunction->getParent(), LifetimesStart.getArrayRef(), {}, TheCall);
1534  
1535    // Propagate personality info to the new function if there is one.
1536    if (oldFunction->hasPersonalityFn())
1537      newFunction->setPersonalityFn(oldFunction->getPersonalityFn());
1538  
1539    // Update the branch weights for the exit block.
1540    if (BFI && NumExitBlocks > 1)
1541      calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);
1542  
1543    // Loop over all of the PHI nodes in the header and exit blocks, and change
1544    // any references to the old incoming edge to be the new incoming edge.
1545    for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) {
1546      PHINode *PN = cast<PHINode>(I);
1547      for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
1548        if (!Blocks.count(PN->getIncomingBlock(i)))
1549          PN->setIncomingBlock(i, newFuncRoot);
1550    }
1551  
1552    for (BasicBlock *ExitBB : ExitBlocks)
1553      for (PHINode &PN : ExitBB->phis()) {
1554        Value *IncomingCodeReplacerVal = nullptr;
1555        for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) {
1556          // Ignore incoming values from outside of the extracted region.
1557          if (!Blocks.count(PN.getIncomingBlock(i)))
1558            continue;
1559  
1560          // Ensure that there is only one incoming value from codeReplacer.
1561          if (!IncomingCodeReplacerVal) {
1562            PN.setIncomingBlock(i, codeReplacer);
1563            IncomingCodeReplacerVal = PN.getIncomingValue(i);
1564          } else
1565            assert(IncomingCodeReplacerVal == PN.getIncomingValue(i) &&
1566                 "PHI has two incompatible incoming values from codeRepl");
1567        }
1568      }
1569  
1570    // Erase debug info intrinsics. Variable updates within the new function are
1571    // invisible to debuggers. This could be improved by defining a DISubprogram
1572    // for the new function.
1573    for (BasicBlock &BB : *newFunction) {
1574      auto BlockIt = BB.begin();
1575      // Remove debug info intrinsics from the new function.
1576      while (BlockIt != BB.end()) {
1577        Instruction *Inst = &*BlockIt;
1578        ++BlockIt;
1579        if (isa<DbgInfoIntrinsic>(Inst))
1580          Inst->eraseFromParent();
1581      }
1582      // Remove debug info intrinsics which refer to values in the new function
1583      // from the old function.
1584      SmallVector<DbgVariableIntrinsic *, 4> DbgUsers;
1585      for (Instruction &I : BB)
1586        findDbgUsers(DbgUsers, &I);
1587      for (DbgVariableIntrinsic *DVI : DbgUsers)
1588        DVI->eraseFromParent();
1589    }
1590  
1591    // Mark the new function `noreturn` if applicable. Terminators which resume
1592    // exception propagation are treated as returning instructions. This is to
1593    // avoid inserting traps after calls to outlined functions which unwind.
1594    bool doesNotReturn = none_of(*newFunction, [](const BasicBlock &BB) {
1595      const Instruction *Term = BB.getTerminator();
1596      return isa<ReturnInst>(Term) || isa<ResumeInst>(Term);
1597    });
1598    if (doesNotReturn)
1599      newFunction->setDoesNotReturn();
1600  
1601    LLVM_DEBUG(if (verifyFunction(*newFunction, &errs())) {
1602      newFunction->dump();
1603      report_fatal_error("verification of newFunction failed!");
1604    });
1605    LLVM_DEBUG(if (verifyFunction(*oldFunction))
1606               report_fatal_error("verification of oldFunction failed!"));
1607    LLVM_DEBUG(if (AC && verifyAssumptionCache(*oldFunction, AC))
1608               report_fatal_error("Stale Assumption cache for old Function!"));
1609    return newFunction;
1610  }
1611  
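      // Return true if \p AC is stale, i.e. it still contains an @llvm.assume
      // call that no longer lives in \p F (for example because the assume was
      // moved into an extracted function).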
1612  bool CodeExtractor::verifyAssumptionCache(const Function& F,
1613                                            AssumptionCache *AC) {
1614    for (auto AssumeVH : AC->assumptions()) {
1615      CallInst *I = cast<CallInst>(AssumeVH);
1616      if (I->getFunction() != &F)
1617        return true;
1618    }
1619    return false;
1620  }
1621