xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Scalar/SROA.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1  //===- SROA.cpp - Scalar Replacement Of Aggregates ------------------------===//
2  //
3  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4  // See https://llvm.org/LICENSE.txt for license information.
5  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6  //
7  //===----------------------------------------------------------------------===//
8  /// \file
9  /// This transformation implements the well known scalar replacement of
10  /// aggregates transformation. It tries to identify promotable elements of an
11  /// aggregate alloca, and promote them to registers. It will also try to
12  /// convert uses of an element (or set of elements) of an alloca into a vector
13  /// or bitfield-style integer scalar if appropriate.
14  ///
15  /// It works to do this with minimal slicing of the alloca so that regions
16  /// which are merely transferred in and out of external memory remain unchanged
17  /// and are not decomposed to scalar code.
18  ///
19  /// Because this also performs alloca promotion, it can be thought of as also
20  /// serving the purpose of SSA formation. The algorithm iterates on the
21  /// function until all opportunities for promotion have been realized.
22  ///
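/// For illustration, a rough sketch of the intended effect on IR (a synthetic
/// example rather than one lifted from a test case): an aggregate alloca
/// whose fields are only stored to and reloaded, such as
///
///   %agg = alloca { i32, i32 }
///   store i32 %x, ptr %agg
///   %f1 = getelementptr inbounds { i32, i32 }, ptr %agg, i32 0, i32 1
///   store i32 %y, ptr %f1
///   %v = load i32, ptr %agg
///
/// is split into a pair of i32 allocas which are then promoted, so that %v
/// simply becomes %x and no memory traffic remains.
///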
23  //===----------------------------------------------------------------------===//
24  
25  #include "llvm/Transforms/Scalar/SROA.h"
26  #include "llvm/ADT/APInt.h"
27  #include "llvm/ADT/ArrayRef.h"
28  #include "llvm/ADT/DenseMap.h"
29  #include "llvm/ADT/MapVector.h"
30  #include "llvm/ADT/PointerIntPair.h"
31  #include "llvm/ADT/STLExtras.h"
32  #include "llvm/ADT/SetVector.h"
33  #include "llvm/ADT/SmallBitVector.h"
34  #include "llvm/ADT/SmallPtrSet.h"
35  #include "llvm/ADT/SmallVector.h"
36  #include "llvm/ADT/Statistic.h"
37  #include "llvm/ADT/StringRef.h"
38  #include "llvm/ADT/Twine.h"
39  #include "llvm/ADT/iterator.h"
40  #include "llvm/ADT/iterator_range.h"
41  #include "llvm/Analysis/AssumptionCache.h"
42  #include "llvm/Analysis/DomTreeUpdater.h"
43  #include "llvm/Analysis/GlobalsModRef.h"
44  #include "llvm/Analysis/Loads.h"
45  #include "llvm/Analysis/PtrUseVisitor.h"
46  #include "llvm/Config/llvm-config.h"
47  #include "llvm/IR/BasicBlock.h"
48  #include "llvm/IR/Constant.h"
49  #include "llvm/IR/ConstantFolder.h"
50  #include "llvm/IR/Constants.h"
51  #include "llvm/IR/DIBuilder.h"
52  #include "llvm/IR/DataLayout.h"
53  #include "llvm/IR/DebugInfo.h"
54  #include "llvm/IR/DebugInfoMetadata.h"
55  #include "llvm/IR/DerivedTypes.h"
56  #include "llvm/IR/Dominators.h"
57  #include "llvm/IR/Function.h"
58  #include "llvm/IR/GetElementPtrTypeIterator.h"
59  #include "llvm/IR/GlobalAlias.h"
60  #include "llvm/IR/IRBuilder.h"
61  #include "llvm/IR/InstVisitor.h"
62  #include "llvm/IR/Instruction.h"
63  #include "llvm/IR/Instructions.h"
64  #include "llvm/IR/IntrinsicInst.h"
65  #include "llvm/IR/LLVMContext.h"
66  #include "llvm/IR/Metadata.h"
67  #include "llvm/IR/Module.h"
68  #include "llvm/IR/Operator.h"
69  #include "llvm/IR/PassManager.h"
70  #include "llvm/IR/Type.h"
71  #include "llvm/IR/Use.h"
72  #include "llvm/IR/User.h"
73  #include "llvm/IR/Value.h"
74  #include "llvm/IR/ValueHandle.h"
75  #include "llvm/InitializePasses.h"
76  #include "llvm/Pass.h"
77  #include "llvm/Support/Casting.h"
78  #include "llvm/Support/CommandLine.h"
79  #include "llvm/Support/Compiler.h"
80  #include "llvm/Support/Debug.h"
81  #include "llvm/Support/ErrorHandling.h"
82  #include "llvm/Support/raw_ostream.h"
83  #include "llvm/Transforms/Scalar.h"
84  #include "llvm/Transforms/Utils/BasicBlockUtils.h"
85  #include "llvm/Transforms/Utils/Local.h"
86  #include "llvm/Transforms/Utils/PromoteMemToReg.h"
87  #include <algorithm>
88  #include <cassert>
89  #include <cstddef>
90  #include <cstdint>
91  #include <cstring>
92  #include <iterator>
93  #include <string>
94  #include <tuple>
95  #include <utility>
96  #include <variant>
97  #include <vector>
98  
99  using namespace llvm;
100  
101  #define DEBUG_TYPE "sroa"
102  
103  STATISTIC(NumAllocasAnalyzed, "Number of allocas analyzed for replacement");
104  STATISTIC(NumAllocaPartitions, "Number of alloca partitions formed");
105  STATISTIC(MaxPartitionsPerAlloca, "Maximum number of partitions per alloca");
106  STATISTIC(NumAllocaPartitionUses, "Number of alloca partition uses rewritten");
107  STATISTIC(MaxUsesPerAllocaPartition, "Maximum number of uses of a partition");
108  STATISTIC(NumNewAllocas, "Number of new, smaller allocas introduced");
109  STATISTIC(NumPromoted, "Number of allocas promoted to SSA values");
110  STATISTIC(NumLoadsSpeculated, "Number of loads speculated to allow promotion");
111  STATISTIC(NumLoadsPredicated,
112            "Number of loads rewritten into predicated loads to allow promotion");
113  STATISTIC(
114      NumStoresPredicated,
115      "Number of stores rewritten into predicated stores to allow promotion");
116  STATISTIC(NumDeleted, "Number of instructions deleted");
117  STATISTIC(NumVectorized, "Number of vectorized aggregates");
118  
119  /// Disable running mem2reg during SROA in order to test or debug SROA.
120  static cl::opt<bool> SROASkipMem2Reg("sroa-skip-mem2reg", cl::init(false),
121                                       cl::Hidden);
122  namespace {
123  
124  class AllocaSliceRewriter;
125  class AllocaSlices;
126  class Partition;
127  
128  class SelectHandSpeculativity {
129    unsigned char Storage = 0; // None are speculatable by default.
130    using TrueVal = Bitfield::Element<bool, 0, 1>;  // Low 0'th bit.
131    using FalseVal = Bitfield::Element<bool, 1, 1>; // Low 1'th bit.
132  public:
133    SelectHandSpeculativity() = default;
134    SelectHandSpeculativity &setAsSpeculatable(bool isTrueVal);
135    bool isSpeculatable(bool isTrueVal) const;
136    bool areAllSpeculatable() const;
137    bool areAnySpeculatable() const;
138    bool areNoneSpeculatable() const;
139    // For interop as int half of PointerIntPair.
140    explicit operator intptr_t() const { return static_cast<intptr_t>(Storage); }
141    explicit SelectHandSpeculativity(intptr_t Storage_) : Storage(Storage_) {}
142  };
143  static_assert(sizeof(SelectHandSpeculativity) == sizeof(unsigned char));
144  
145  using PossiblySpeculatableLoad =
146      PointerIntPair<LoadInst *, 2, SelectHandSpeculativity>;
147  using UnspeculatableStore = StoreInst *;
148  using RewriteableMemOp =
149      std::variant<PossiblySpeculatableLoad, UnspeculatableStore>;
150  using RewriteableMemOps = SmallVector<RewriteableMemOp, 2>;
151  
152  /// An optimization pass providing Scalar Replacement of Aggregates.
153  ///
154  /// This pass takes allocations which can be completely analyzed (that is, they
155  /// don't escape) and tries to turn them into scalar SSA values. There are
156  /// a few steps to this process.
157  ///
158  /// 1) It takes allocations of aggregates and analyzes the ways in which they
159  ///    are used to try to split them into smaller allocations, ideally of
160  ///    a single scalar data type. It will split up memcpy and memset accesses
161  ///    as necessary and try to isolate individual scalar accesses.
162  /// 2) It will transform accesses into forms which are suitable for SSA value
163  ///    promotion. This can be replacing a memset with a scalar store of an
164  ///    integer value, or it can involve speculating operations on a PHI or
165  ///    select to be a PHI or select of the results.
166  /// 3) Finally, this will try to detect a pattern of accesses which map cleanly
167  ///    onto insert and extract operations on a vector value, and convert them to
168  ///    this form. By doing so, it will enable promotion of vector aggregates to
169  ///    SSA vector values.
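///
/// As a small illustration of step 2 (synthetic IR): a memset that exactly
/// covers an i32-sized slice, e.g.
///
///   call void @llvm.memset.p0.i64(ptr %slice, i8 0, i64 4, i1 false)
///
/// can be rewritten as the equivalent scalar store
///
///   store i32 0, ptr %slice
///
/// which mem2reg is then able to promote.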
170  class SROA {
171    LLVMContext *const C;
172    DomTreeUpdater *const DTU;
173    AssumptionCache *const AC;
174    const bool PreserveCFG;
175  
176    /// Worklist of alloca instructions to simplify.
177    ///
178    /// Each alloca in the function is added to this. Each new alloca formed gets
179    /// added to it as well to recursively simplify unless that alloca can be
180    /// directly promoted. Finally, each time we rewrite a use of an alloca other
181    /// the one being actively rewritten, we add it back onto the list if not
182    /// than the one being actively rewritten, we add it back onto the list if not
183    SmallSetVector<AllocaInst *, 16> Worklist;
184  
185    /// A collection of instructions to delete.
186    /// We try to batch deletions to simplify code and make things a bit more
187    /// efficient. We also make sure there are no dangling pointers.
188    SmallVector<WeakVH, 8> DeadInsts;
189  
190    /// Post-promotion worklist.
191    ///
192    /// Sometimes we discover an alloca which has a high probability of becoming
193    /// viable for SROA after a round of promotion takes place. In those cases,
194    /// the alloca is enqueued here for re-processing.
195    ///
196    /// Note that we have to be very careful to clear allocas out of this list in
197    /// the event they are deleted.
198    SmallSetVector<AllocaInst *, 16> PostPromotionWorklist;
199  
200    /// A collection of alloca instructions we can directly promote.
201    std::vector<AllocaInst *> PromotableAllocas;
202  
203    /// A worklist of PHIs to speculate prior to promoting allocas.
204    ///
205    /// All of these PHIs have been checked for the safety of speculation and by
206    /// being speculated will allow promoting allocas currently in the promotable
207    /// queue.
208    SmallSetVector<PHINode *, 8> SpeculatablePHIs;
209  
210    /// A worklist of select instructions to rewrite prior to promoting
211    /// allocas.
212    SmallMapVector<SelectInst *, RewriteableMemOps, 8> SelectsToRewrite;
213  
214    /// Select instructions that use an alloca and are subsequently loaded can be
215    /// rewritten to load both input pointers and then select between the result,
216    /// allowing the load of the alloca to be promoted.
217    /// From this:
218    ///   %P2 = select i1 %cond, ptr %Alloca, ptr %Other
219    ///   %V = load <type>, ptr %P2
220    /// to:
221    ///   %V1 = load <type>, ptr %Alloca      -> will be mem2reg'd
222    ///   %V2 = load <type>, ptr %Other
223    ///   %V = select i1 %cond, <type> %V1, <type> %V2
224    ///
225    /// We can do this to a select if its only uses are loads
226    /// and if either the operand to the select can be loaded unconditionally,
227    ///        or if we are allowed to perform CFG modifications.
228    /// If an intervening bitcast with a single use of the load is found,
229    /// the promotion is still allowed.
230    static std::optional<RewriteableMemOps>
231    isSafeSelectToSpeculate(SelectInst &SI, bool PreserveCFG);
232  
233  public:
234    SROA(LLVMContext *C, DomTreeUpdater *DTU, AssumptionCache *AC,
235         SROAOptions PreserveCFG_)
236        : C(C), DTU(DTU), AC(AC),
237          PreserveCFG(PreserveCFG_ == SROAOptions::PreserveCFG) {}
238  
239    /// Main run method used by both the SROAPass and by the legacy pass.
240    std::pair<bool /*Changed*/, bool /*CFGChanged*/> runSROA(Function &F);
241  
242  private:
243    friend class AllocaSliceRewriter;
244  
245    bool presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS);
246    AllocaInst *rewritePartition(AllocaInst &AI, AllocaSlices &AS, Partition &P);
247    bool splitAlloca(AllocaInst &AI, AllocaSlices &AS);
248    std::pair<bool /*Changed*/, bool /*CFGChanged*/> runOnAlloca(AllocaInst &AI);
249    void clobberUse(Use &U);
250    bool deleteDeadInstructions(SmallPtrSetImpl<AllocaInst *> &DeletedAllocas);
251    bool promoteAllocas(Function &F);
252  };
253  
254  } // end anonymous namespace
255  
256  /// Calculate the fragment of a variable to use when slicing a store
257  /// based on the slice dimensions, existing fragment, and base storage
258  /// fragment.
259  /// Results:
260  /// UseFrag - Use Target as the new fragment.
261  /// UseNoFrag - The new slice already covers the whole variable.
262  /// Skip - The new alloca slice doesn't include this variable.
263  /// FIXME: Can we use calculateFragmentIntersect instead?
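///
/// For example (illustrative numbers): slicing bits [32, 64) out of storage
/// that covers a whole 64-bit variable, with no base or pre-existing
/// fragment, yields Target = {OffsetInBits: 32, SizeInBits: 32} and returns
/// UseFrag; if the slice instead covers bits [0, 64), Target matches the
/// implicit whole-variable fragment and the result is UseNoFrag.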
264  namespace {
265  enum FragCalcResult { UseFrag, UseNoFrag, Skip };
266  }
267  static FragCalcResult
268  calculateFragment(DILocalVariable *Variable,
269                    uint64_t NewStorageSliceOffsetInBits,
270                    uint64_t NewStorageSliceSizeInBits,
271                    std::optional<DIExpression::FragmentInfo> StorageFragment,
272                    std::optional<DIExpression::FragmentInfo> CurrentFragment,
273                    DIExpression::FragmentInfo &Target) {
274    // If the base storage describes part of the variable, apply the offset and
275    // the size constraint.
276    if (StorageFragment) {
277      Target.SizeInBits =
278          std::min(NewStorageSliceSizeInBits, StorageFragment->SizeInBits);
279      Target.OffsetInBits =
280          NewStorageSliceOffsetInBits + StorageFragment->OffsetInBits;
281    } else {
282      Target.SizeInBits = NewStorageSliceSizeInBits;
283      Target.OffsetInBits = NewStorageSliceOffsetInBits;
284    }
285  
286    // If this slice extracts the entirety of an independent variable from a
287    // larger alloca, do not produce a fragment expression, as the variable is
288    // not fragmented.
289    if (!CurrentFragment) {
290      if (auto Size = Variable->getSizeInBits()) {
291        // Treat the current fragment as covering the whole variable.
292        CurrentFragment = DIExpression::FragmentInfo(*Size, 0);
293        if (Target == CurrentFragment)
294          return UseNoFrag;
295      }
296    }
297  
298    // No additional work to do if there isn't a fragment already, or there is
299    // but it already exactly describes the new assignment.
300    if (!CurrentFragment || *CurrentFragment == Target)
301      return UseFrag;
302  
303    // Reject the target fragment if it doesn't fit wholly within the current
304    // fragment. TODO: We could instead chop up the target to fit in the case of
305    // a partial overlap.
306    if (Target.startInBits() < CurrentFragment->startInBits() ||
307        Target.endInBits() > CurrentFragment->endInBits())
308      return Skip;
309  
310    // Target fits within the current fragment, return it.
311    return UseFrag;
312  }
313  
314  static DebugVariable getAggregateVariable(DbgVariableIntrinsic *DVI) {
315    return DebugVariable(DVI->getVariable(), std::nullopt,
316                         DVI->getDebugLoc().getInlinedAt());
317  }
318  static DebugVariable getAggregateVariable(DbgVariableRecord *DVR) {
319    return DebugVariable(DVR->getVariable(), std::nullopt,
320                         DVR->getDebugLoc().getInlinedAt());
321  }
322  
323  /// Helpers for handling new and old debug info modes in migrateDebugInfo.
324  /// These overloads unwrap a DbgInstPtr {Instruction* | DbgRecord*} union based
325  /// on the \p Unused parameter type.
326  DbgVariableRecord *UnwrapDbgInstPtr(DbgInstPtr P, DbgVariableRecord *Unused) {
327    (void)Unused;
328    return static_cast<DbgVariableRecord *>(cast<DbgRecord *>(P));
329  }
330  DbgAssignIntrinsic *UnwrapDbgInstPtr(DbgInstPtr P, DbgAssignIntrinsic *Unused) {
331    (void)Unused;
332    return static_cast<DbgAssignIntrinsic *>(cast<Instruction *>(P));
333  }
334  
335  /// Find linked dbg.assign and generate a new one with the correct
336  /// FragmentInfo. Link Inst to the new dbg.assign.  If Value is nullptr the
337  /// value component is copied from the old dbg.assign to the new.
338  /// \param OldAlloca             Alloca for the variable before splitting.
339  /// \param IsSplit               True if the store (not necessarily alloca)
340  ///                              is being split.
341  /// \param OldAllocaOffsetInBits Offset of the slice taken from OldAlloca.
342  /// \param SliceSizeInBits       New number of bits being written to.
343  /// \param OldInst               Instruction that is being split.
344  /// \param Inst                  New instruction performing this part of the
345  ///                              split store.
346  /// \param Dest                  Store destination.
347  /// \param Value                 Stored value.
348  /// \param DL                    Datalayout.
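///
/// For example (assuming the variable covers the whole of OldAlloca): when a
/// 64-bit store with a linked dbg.assign is split into two 32-bit stores,
/// each new store receives a fresh DIAssignID and a new dbg.assign whose
/// expression carries a fragment covering bits [0, 32) or [32, 64)
/// respectively.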
349  static void migrateDebugInfo(AllocaInst *OldAlloca, bool IsSplit,
350                               uint64_t OldAllocaOffsetInBits,
351                               uint64_t SliceSizeInBits, Instruction *OldInst,
352                               Instruction *Inst, Value *Dest, Value *Value,
353                               const DataLayout &DL) {
354    auto MarkerRange = at::getAssignmentMarkers(OldInst);
355    auto DVRAssignMarkerRange = at::getDVRAssignmentMarkers(OldInst);
356    // Nothing to do if OldInst has no linked dbg.assign intrinsics.
357    if (MarkerRange.empty() && DVRAssignMarkerRange.empty())
358      return;
359  
360    LLVM_DEBUG(dbgs() << "  migrateDebugInfo\n");
361    LLVM_DEBUG(dbgs() << "    OldAlloca: " << *OldAlloca << "\n");
362    LLVM_DEBUG(dbgs() << "    IsSplit: " << IsSplit << "\n");
363    LLVM_DEBUG(dbgs() << "    OldAllocaOffsetInBits: " << OldAllocaOffsetInBits
364                      << "\n");
365    LLVM_DEBUG(dbgs() << "    SliceSizeInBits: " << SliceSizeInBits << "\n");
366    LLVM_DEBUG(dbgs() << "    OldInst: " << *OldInst << "\n");
367    LLVM_DEBUG(dbgs() << "    Inst: " << *Inst << "\n");
368    LLVM_DEBUG(dbgs() << "    Dest: " << *Dest << "\n");
369    if (Value)
370      LLVM_DEBUG(dbgs() << "    Value: " << *Value << "\n");
371  
372    /// Map of aggregate variables to their fragment associated with OldAlloca.
373    DenseMap<DebugVariable, std::optional<DIExpression::FragmentInfo>>
374        BaseFragments;
375    for (auto *DAI : at::getAssignmentMarkers(OldAlloca))
376      BaseFragments[getAggregateVariable(DAI)] =
377          DAI->getExpression()->getFragmentInfo();
378    for (auto *DVR : at::getDVRAssignmentMarkers(OldAlloca))
379      BaseFragments[getAggregateVariable(DVR)] =
380          DVR->getExpression()->getFragmentInfo();
381  
382    // The new inst needs a DIAssignID unique metadata tag (if OldInst has
383    // one). It shouldn't already have one: assert this assumption.
384    assert(!Inst->getMetadata(LLVMContext::MD_DIAssignID));
385    DIAssignID *NewID = nullptr;
386    auto &Ctx = Inst->getContext();
387    DIBuilder DIB(*OldInst->getModule(), /*AllowUnresolved*/ false);
388    assert(OldAlloca->isStaticAlloca());
389  
390    auto MigrateDbgAssign = [&](auto *DbgAssign) {
391      LLVM_DEBUG(dbgs() << "      existing dbg.assign is: " << *DbgAssign
392                        << "\n");
393      auto *Expr = DbgAssign->getExpression();
394      bool SetKillLocation = false;
395  
396      if (IsSplit) {
397        std::optional<DIExpression::FragmentInfo> BaseFragment;
398        {
399          auto R = BaseFragments.find(getAggregateVariable(DbgAssign));
400          if (R == BaseFragments.end())
401            return;
402          BaseFragment = R->second;
403        }
404        std::optional<DIExpression::FragmentInfo> CurrentFragment =
405            Expr->getFragmentInfo();
406        DIExpression::FragmentInfo NewFragment;
407        FragCalcResult Result = calculateFragment(
408            DbgAssign->getVariable(), OldAllocaOffsetInBits, SliceSizeInBits,
409            BaseFragment, CurrentFragment, NewFragment);
410  
411        if (Result == Skip)
412          return;
413        if (Result == UseFrag && !(NewFragment == CurrentFragment)) {
414          if (CurrentFragment) {
415            // Rewrite NewFragment to be relative to the existing one (this is
416            // what createFragmentExpression wants).  CalculateFragment has
417            // already resolved the size for us. FIXME: Should it return the
418            // relative fragment too?
419            NewFragment.OffsetInBits -= CurrentFragment->OffsetInBits;
420          }
421          // Add the new fragment info to the existing expression if possible.
422          if (auto E = DIExpression::createFragmentExpression(
423                  Expr, NewFragment.OffsetInBits, NewFragment.SizeInBits)) {
424            Expr = *E;
425          } else {
426            // Otherwise, add the new fragment info to an empty expression and
427            // discard the value component of this dbg.assign as the value cannot
428            // be computed with the new fragment.
429            Expr = *DIExpression::createFragmentExpression(
430                DIExpression::get(Expr->getContext(), std::nullopt),
431                NewFragment.OffsetInBits, NewFragment.SizeInBits);
432            SetKillLocation = true;
433          }
434        }
435      }
436  
437      // If we haven't created a DIAssignID ID do that now and attach it to Inst.
438      // If we haven't created a DIAssignID yet, do that now and attach it to Inst.
439        NewID = DIAssignID::getDistinct(Ctx);
440        Inst->setMetadata(LLVMContext::MD_DIAssignID, NewID);
441      }
442  
443      ::Value *NewValue = Value ? Value : DbgAssign->getValue();
444      auto *NewAssign = UnwrapDbgInstPtr(
445          DIB.insertDbgAssign(Inst, NewValue, DbgAssign->getVariable(), Expr,
446                              Dest,
447                              DIExpression::get(Expr->getContext(), std::nullopt),
448                              DbgAssign->getDebugLoc()),
449          DbgAssign);
450  
451      // If we've updated the value but the original dbg.assign has an arglist
452      // then kill it now - we can't use the requested new value.
453      // We can't replace the DIArgList with the new value as it'd leave
454      // the DIExpression in an invalid state (DW_OP_LLVM_arg operands without
455      // an arglist). And we can't keep the DIArgList in case the linked store
456      // is being split - in which case the DIArgList + expression may no longer
457      // be computing the correct value.
458      // This should be a very rare situation as it requires the value being
459      // stored to differ from the dbg.assign (i.e., the value has been
460      // represented differently in the debug intrinsic for some reason).
461      SetKillLocation |=
462          Value && (DbgAssign->hasArgList() ||
463                    !DbgAssign->getExpression()->isSingleLocationExpression());
464      if (SetKillLocation)
465        NewAssign->setKillLocation();
466  
467      // We could use more precision here at the cost of some additional (code)
468      // complexity - if the original dbg.assign was adjacent to its store, we
469      // could position this new dbg.assign adjacent to its store rather than the
470      // old dbg.assign. That would result in interleaved dbg.assigns rather than
471      // what we get now:
472      //    split store !1
473      //    split store !2
474      //    dbg.assign !1
475      //    dbg.assign !2
476      // This (current behaviour) results in debug assignments being
477      // noted as slightly offset (in code) from the store. In practice this
478      // should have little effect on the debugging experience due to the fact
479      // that all the split stores should get the same line number.
480      NewAssign->moveBefore(DbgAssign);
481  
482      NewAssign->setDebugLoc(DbgAssign->getDebugLoc());
483      LLVM_DEBUG(dbgs() << "Created new assign: " << *NewAssign << "\n");
484    };
485  
486    for_each(MarkerRange, MigrateDbgAssign);
487    for_each(DVRAssignMarkerRange, MigrateDbgAssign);
488  }
489  
490  namespace {
491  
492  /// A custom IRBuilder inserter which prefixes all names, but only in
493  /// Assert builds.
494  class IRBuilderPrefixedInserter final : public IRBuilderDefaultInserter {
495    std::string Prefix;
496  
497    Twine getNameWithPrefix(const Twine &Name) const {
498      return Name.isTriviallyEmpty() ? Name : Prefix + Name;
499    }
500  
501  public:
502    void SetNamePrefix(const Twine &P) { Prefix = P.str(); }
503  
504    void InsertHelper(Instruction *I, const Twine &Name,
505                      BasicBlock::iterator InsertPt) const override {
506      IRBuilderDefaultInserter::InsertHelper(I, getNameWithPrefix(Name),
507                                             InsertPt);
508    }
509  };
510  
511  /// Provide a type for IRBuilder that drops names in release builds.
512  using IRBuilderTy = IRBuilder<ConstantFolder, IRBuilderPrefixedInserter>;
513  
514  /// A used slice of an alloca.
515  ///
516  /// This structure represents a slice of an alloca used by some instruction. It
517  /// stores both the begin and end offsets of this use, a pointer to the use
518  /// itself, and a flag indicating whether we can classify the use as splittable
519  /// or not when forming partitions of the alloca.
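///
/// For example, a non-volatile i64 store at byte offset 8 into the alloca is
/// recorded as the slice [8, 16), marked splittable because it is a plain
/// integer-typed access (see handleLoadOrStore below).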
520  class Slice {
521    /// The beginning offset of the range.
522    uint64_t BeginOffset = 0;
523  
524    /// The ending offset, not included in the range.
525    uint64_t EndOffset = 0;
526  
527    /// Storage for both the use of this slice and whether it can be
528    /// split.
529    PointerIntPair<Use *, 1, bool> UseAndIsSplittable;
530  
531  public:
532    Slice() = default;
533  
534    Slice(uint64_t BeginOffset, uint64_t EndOffset, Use *U, bool IsSplittable)
535        : BeginOffset(BeginOffset), EndOffset(EndOffset),
536          UseAndIsSplittable(U, IsSplittable) {}
537  
538    uint64_t beginOffset() const { return BeginOffset; }
539    uint64_t endOffset() const { return EndOffset; }
540
541    bool isSplittable() const { return UseAndIsSplittable.getInt(); }
542    void makeUnsplittable() { UseAndIsSplittable.setInt(false); }
543
544    Use *getUse() const { return UseAndIsSplittable.getPointer(); }
545
546    bool isDead() const { return getUse() == nullptr; }
547    void kill() { UseAndIsSplittable.setPointer(nullptr); }
548  
549    /// Support for ordering ranges.
550    ///
551    /// This provides an ordering over ranges such that start offsets are
552    /// always increasing, and within equal start offsets, the end offsets are
553    /// decreasing. Thus the spanning range comes first in a cluster with the
554    /// same start position.
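  ///
  /// For example, three slices all beginning at offset 0 sort as: an
  /// unsplittable [0, 8), then a splittable [0, 16), then a splittable
  /// [0, 8); unsplittable slices come first, and within each group slices
  /// with larger end offsets sort earlier.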
555    bool operator<(const Slice &RHS) const {
556      if (beginOffset() < RHS.beginOffset())
557        return true;
558      if (beginOffset() > RHS.beginOffset())
559        return false;
560      if (isSplittable() != RHS.isSplittable())
561        return !isSplittable();
562      if (endOffset() > RHS.endOffset())
563        return true;
564      return false;
565    }
566  
567    /// Support comparison with a single offset to allow binary searches.
568    friend LLVM_ATTRIBUTE_UNUSED bool operator<(const Slice &LHS,
569                                                uint64_t RHSOffset) {
570      return LHS.beginOffset() < RHSOffset;
571    }
572    friend LLVM_ATTRIBUTE_UNUSED bool operator<(uint64_t LHSOffset,
573                                                const Slice &RHS) {
574      return LHSOffset < RHS.beginOffset();
575    }
576  
577    bool operator==(const Slice &RHS) const {
578      return isSplittable() == RHS.isSplittable() &&
579             beginOffset() == RHS.beginOffset() && endOffset() == RHS.endOffset();
580    }
581    bool operator!=(const Slice &RHS) const { return !operator==(RHS); }
582  };
583  
584  /// Representation of the alloca slices.
585  ///
586  /// This class represents the slices of an alloca which are formed by its
587  /// various uses. If a pointer escapes, we can't fully build a representation
588  /// for the slices used and we reflect that in this structure. The uses are
589  /// stored, sorted by increasing beginning offset and with unsplittable slices
590  /// starting at a particular offset before splittable slices.
591  class AllocaSlices {
592  public:
593    /// Construct the slices of a particular alloca.
594    AllocaSlices(const DataLayout &DL, AllocaInst &AI);
595  
596    /// Test whether a pointer to the allocation escapes our analysis.
597    ///
598    /// If this is true, the slices are never fully built and should be
599    /// ignored.
600    bool isEscaped() const { return PointerEscapingInstr; }
601  
602    /// Support for iterating over the slices.
603    /// @{
604    using iterator = SmallVectorImpl<Slice>::iterator;
605    using range = iterator_range<iterator>;
606  
607    iterator begin() { return Slices.begin(); }
608    iterator end() { return Slices.end(); }
609  
610    using const_iterator = SmallVectorImpl<Slice>::const_iterator;
611    using const_range = iterator_range<const_iterator>;
612  
613    const_iterator begin() const { return Slices.begin(); }
614    const_iterator end() const { return Slices.end(); }
615    /// @}
616  
617    /// Erase a range of slices.
618    void erase(iterator Start, iterator Stop) { Slices.erase(Start, Stop); }
619  
620    /// Insert new slices for this alloca.
621    ///
622    /// This moves the slices into the alloca's slices collection, and re-sorts
623    /// everything so that the usual ordering properties of the alloca's slices
624    /// hold.
625    void insert(ArrayRef<Slice> NewSlices) {
626      int OldSize = Slices.size();
627      Slices.append(NewSlices.begin(), NewSlices.end());
628      auto SliceI = Slices.begin() + OldSize;
629      std::stable_sort(SliceI, Slices.end());
630      std::inplace_merge(Slices.begin(), SliceI, Slices.end());
631    }
632  
633    // Forward declare the iterator and range accessor for walking the
634    // partitions.
635    class partition_iterator;
636    iterator_range<partition_iterator> partitions();
637  
638    /// Access the dead users for this alloca.
639    ArrayRef<Instruction *> getDeadUsers() const { return DeadUsers; }
640  
641    /// Access Uses that should be dropped if the alloca is promotable.
642    ArrayRef<Use *> getDeadUsesIfPromotable() const {
643      return DeadUseIfPromotable;
644    }
645  
646    /// Access the dead operands referring to this alloca.
647    ///
648    /// These are operands which cannot actually be used to refer to the
649    /// alloca as they are outside its range and the user doesn't correct for
650    /// that. These mostly consist of PHI node inputs and the like which we just
651    /// need to replace with undef.
652    ArrayRef<Use *> getDeadOperands() const { return DeadOperands; }
653  
654  #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
655    void print(raw_ostream &OS, const_iterator I, StringRef Indent = "  ") const;
656    void printSlice(raw_ostream &OS, const_iterator I,
657                    StringRef Indent = "  ") const;
658    void printUse(raw_ostream &OS, const_iterator I,
659                  StringRef Indent = "  ") const;
660    void print(raw_ostream &OS) const;
661    void dump(const_iterator I) const;
662    void dump() const;
663  #endif
664  
665  private:
666    template <typename DerivedT, typename RetT = void> class BuilderBase;
667    class SliceBuilder;
668  
669    friend class AllocaSlices::SliceBuilder;
670  
671  #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
672    /// Handle to alloca instruction to simplify method interfaces.
673    AllocaInst &AI;
674  #endif
675  
676    /// The instruction responsible for this alloca not having a known set
677    /// of slices.
678    ///
679    /// When an instruction (potentially) escapes the pointer to the alloca, we
680    /// store a pointer to that here and abort trying to form slices of the
681    /// alloca. This will be null if the alloca slices are analyzed successfully.
682    Instruction *PointerEscapingInstr;
683  
684    /// The slices of the alloca.
685    ///
686    /// We store a vector of the slices formed by uses of the alloca here. This
687    /// vector is sorted by increasing begin offset, and then the unsplittable
688    /// slices before the splittable ones. See the Slice inner class for more
689    /// details.
690    SmallVector<Slice, 8> Slices;
691  
692    /// Instructions which will become dead if we rewrite the alloca.
693    ///
694    /// Note that these are not separated by slice. This is because we expect an
695    /// alloca to be completely rewritten or not rewritten at all. If rewritten,
696    /// all these instructions can simply be removed and replaced with poison as
697    /// they come from outside of the allocated space.
698    SmallVector<Instruction *, 8> DeadUsers;
699  
700    /// Uses which will become dead if we can promote the alloca.
701    SmallVector<Use *, 8> DeadUseIfPromotable;
702  
703    /// Operands which will become dead if we rewrite the alloca.
704    ///
705    /// These are operands that in their particular use can be replaced with
706    /// poison when we rewrite the alloca. These show up in out-of-bounds inputs
707    /// to PHI nodes and the like. They aren't entirely dead (there might be
708    /// a GEP back into the bounds using it elsewhere) and nor is the PHI, but we
709    /// want to swap this particular input for poison to simplify the use lists of
710    /// the alloca.
711    SmallVector<Use *, 8> DeadOperands;
712  };
713  
714  /// A partition of the slices.
715  ///
716  /// An ephemeral representation for a range of slices which can be viewed as
717  /// a partition of the alloca. This range represents a span of the alloca's
718  /// memory which cannot be split, and provides access to all of the slices
719  /// overlapping some part of the partition.
720  ///
721  /// Objects of this type are produced by traversing the alloca's slices, but
722  /// are only ephemeral and not persistent.
723  class Partition {
724  private:
725    friend class AllocaSlices;
726    friend class AllocaSlices::partition_iterator;
727  
728    using iterator = AllocaSlices::iterator;
729  
730    /// The beginning and ending offsets of the alloca for this
731    /// partition.
732    uint64_t BeginOffset = 0, EndOffset = 0;
733  
734    /// The start and end iterators of this partition.
735    iterator SI, SJ;
736  
737    /// A collection of split slice tails overlapping the partition.
738    SmallVector<Slice *, 4> SplitTails;
739  
740    /// Raw constructor builds an empty partition starting and ending at
741    /// the given iterator.
742    Partition(iterator SI) : SI(SI), SJ(SI) {}
743  
744  public:
745    /// The start offset of this partition.
746    ///
747    /// All of the contained slices start at or after this offset.
748    uint64_t beginOffset() const { return BeginOffset; }
749  
750    /// The end offset of this partition.
751    ///
752    /// All of the contained slices end at or before this offset.
753    uint64_t endOffset() const { return EndOffset; }
754  
755    /// The size of the partition.
756    ///
757    /// Note that this can never be zero.
758    uint64_t size() const {
759      assert(BeginOffset < EndOffset && "Partitions must span some bytes!");
760      return EndOffset - BeginOffset;
761    }
762  
763    /// Test whether this partition contains no slices, and merely spans
764    /// a region occupied by split slices.
765    bool empty() const { return SI == SJ; }
766  
767    /// \name Iterate slices that start within the partition.
768    /// These may be splittable or unsplittable. They have a begin offset >= the
769    /// partition begin offset.
770    /// @{
771    // FIXME: We should probably define a "concat_iterator" helper and use that
772    // to stitch together pointee_iterators over the split tails and the
773    // contiguous iterators of the partition. That would give a much nicer
774    // interface here. We could then additionally expose filtered iterators for
775    // split, unsplit, and unsplittable slices based on the usage patterns.
776    iterator begin() const { return SI; }
777    iterator end() const { return SJ; }
778    /// @}
779  
780    /// Get the sequence of split slice tails.
781    ///
782    /// These tails are of slices which start before this partition but are
783    /// split and overlap into the partition. We accumulate these while forming
784    /// partitions.
785    ArrayRef<Slice *> splitSliceTails() const { return SplitTails; }
786  };
787  
788  } // end anonymous namespace
789  
790  /// An iterator over partitions of the alloca's slices.
791  ///
792  /// This iterator implements the core algorithm for partitioning the alloca's
793  /// slices. It is a forward iterator as we don't support backtracking for
794  /// efficiency reasons, and re-use a single storage area to maintain the
795  /// current set of split slices.
796  ///
797  /// It is templated on the slice iterator type to use so that it can operate
798  /// with either const or non-const slice iterators.
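///
/// For example (illustrative offsets), given the sorted slices
///
///   [0, 8) unsplittable, [4, 12) splittable, [12, 20) unsplittable
///
/// the iterator produces three partitions: [0, 8) containing the first two
/// slices, an "empty" partition [8, 12) whose splitSliceTails() holds the
/// tail of [4, 12), and finally [12, 20) containing the last slice.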
799  class AllocaSlices::partition_iterator
800      : public iterator_facade_base<partition_iterator, std::forward_iterator_tag,
801                                    Partition> {
802    friend class AllocaSlices;
803  
804    /// Most of the state for walking the partitions is held in a class
805    /// with a nice interface for examining them.
806    Partition P;
807  
808    /// We need to keep the end of the slices to know when to stop.
809    AllocaSlices::iterator SE;
810  
811    /// We also need to keep track of the maximum split end offset seen.
812    /// FIXME: Do we really?
813    uint64_t MaxSplitSliceEndOffset = 0;
814  
815    /// Sets the partition to be empty at given iterator, and sets the
816    /// end iterator.
817    partition_iterator(AllocaSlices::iterator SI, AllocaSlices::iterator SE)
818        : P(SI), SE(SE) {
819      // If not already at the end, advance our state to form the initial
820      // partition.
821      if (SI != SE)
822        advance();
823    }
824  
825    /// Advance the iterator to the next partition.
826    ///
827    /// Requires that the iterator not be at the end of the slices.
828    void advance() {
829      assert((P.SI != SE || !P.SplitTails.empty()) &&
830             "Cannot advance past the end of the slices!");
831  
832      // Clear out any split uses which have ended.
833      if (!P.SplitTails.empty()) {
834        if (P.EndOffset >= MaxSplitSliceEndOffset) {
835          // If we've finished all splits, this is easy.
836          P.SplitTails.clear();
837          MaxSplitSliceEndOffset = 0;
838        } else {
839          // Remove the uses which have ended in the prior partition. This
840          // cannot change the max split slice end because we just checked that
841          // the prior partition ended prior to that max.
842          llvm::erase_if(P.SplitTails,
843                         [&](Slice *S) { return S->endOffset() <= P.EndOffset; });
844          assert(llvm::any_of(P.SplitTails,
845                              [&](Slice *S) {
846                                return S->endOffset() == MaxSplitSliceEndOffset;
847                              }) &&
848                 "Could not find the current max split slice offset!");
849          assert(llvm::all_of(P.SplitTails,
850                              [&](Slice *S) {
851                                return S->endOffset() <= MaxSplitSliceEndOffset;
852                              }) &&
853                 "Max split slice end offset is not actually the max!");
854        }
855      }
856  
857      // If P.SI is already at the end, then we've cleared the split tail and
858      // now have an end iterator.
859      if (P.SI == SE) {
860        assert(P.SplitTails.empty() && "Failed to clear the split slices!");
861        return;
862      }
863  
864      // If we had a non-empty partition previously, set up the state for
865      // subsequent partitions.
866      if (P.SI != P.SJ) {
867        // Accumulate all the splittable slices which started in the old
868        // partition into the split list.
869        for (Slice &S : P)
870          if (S.isSplittable() && S.endOffset() > P.EndOffset) {
871            P.SplitTails.push_back(&S);
872            MaxSplitSliceEndOffset =
873                std::max(S.endOffset(), MaxSplitSliceEndOffset);
874          }
875  
876        // Start from the end of the previous partition.
877        P.SI = P.SJ;
878  
879        // If P.SI is now at the end, we at most have a tail of split slices.
880        if (P.SI == SE) {
881          P.BeginOffset = P.EndOffset;
882          P.EndOffset = MaxSplitSliceEndOffset;
883          return;
884        }
885  
886        // If we have split slices and the next slice is after a gap and is
887        // not splittable, immediately form an empty partition for the split
888        // slices up until the next slice begins.
889        if (!P.SplitTails.empty() && P.SI->beginOffset() != P.EndOffset &&
890            !P.SI->isSplittable()) {
891          P.BeginOffset = P.EndOffset;
892          P.EndOffset = P.SI->beginOffset();
893          return;
894        }
895      }
896  
897      // OK, we need to consume new slices. Set the end offset based on the
898      // current slice, and step SJ past it. The beginning offset of the
899      // partition is the beginning offset of the next slice unless we have
900      // pre-existing split slices that are continuing, in which case we begin
901      // at the prior end offset.
902      P.BeginOffset = P.SplitTails.empty() ? P.SI->beginOffset() : P.EndOffset;
903      P.EndOffset = P.SI->endOffset();
904      ++P.SJ;
905  
906      // There are two strategies to form a partition based on whether the
907      // partition starts with an unsplittable slice or a splittable slice.
908      if (!P.SI->isSplittable()) {
909        // When we're forming an unsplittable region, it must always start at
910        // the first slice and will extend through its end.
911        assert(P.BeginOffset == P.SI->beginOffset());
912  
913        // Form a partition including all of the overlapping slices with this
914        // unsplittable slice.
915        while (P.SJ != SE && P.SJ->beginOffset() < P.EndOffset) {
916          if (!P.SJ->isSplittable())
917            P.EndOffset = std::max(P.EndOffset, P.SJ->endOffset());
918          ++P.SJ;
919        }
920  
921        // We have a partition across a set of overlapping unsplittable
922        // slices.
923        return;
924      }
925  
926      // If we're starting with a splittable slice, then we need to form
927      // a synthetic partition spanning it and any other overlapping splittable
928      // slices.
929      assert(P.SI->isSplittable() && "Forming a splittable partition!");
930  
931      // Collect all of the overlapping splittable slices.
932      while (P.SJ != SE && P.SJ->beginOffset() < P.EndOffset &&
933             P.SJ->isSplittable()) {
934        P.EndOffset = std::max(P.EndOffset, P.SJ->endOffset());
935        ++P.SJ;
936      }
937  
938      // Back up P.EndOffset if we ended the span early when encountering an
939      // unsplittable slice. This synthesizes the early end offset of
940      // a partition spanning only splittable slices.
941      if (P.SJ != SE && P.SJ->beginOffset() < P.EndOffset) {
942        assert(!P.SJ->isSplittable());
943        P.EndOffset = P.SJ->beginOffset();
944      }
945    }
946  
947  public:
948    bool operator==(const partition_iterator &RHS) const {
949      assert(SE == RHS.SE &&
950             "End iterators don't match between compared partition iterators!");
951  
952      // The observed positions of partitions are marked by the P.SI iterator and
953      // the emptiness of the split slices. The latter is only relevant when
954      // P.SI == SE, as the end iterator will additionally have an empty split
955      // slices list, but the prior may have the same P.SI and a tail of split
956      // slices.
957      if (P.SI == RHS.P.SI && P.SplitTails.empty() == RHS.P.SplitTails.empty()) {
958        assert(P.SJ == RHS.P.SJ &&
959               "Same set of slices formed two different sized partitions!");
960        assert(P.SplitTails.size() == RHS.P.SplitTails.size() &&
961               "Same slice position with differently sized non-empty split "
962               "slice tails!");
963        return true;
964      }
965      return false;
966    }
967  
968    partition_iterator &operator++() {
969      advance();
970      return *this;
971    }
972  
973    Partition &operator*() { return P; }
974  };
975  
976  /// A forward range over the partitions of the alloca's slices.
977  ///
978  /// This accesses an iterator range over the partitions of the alloca's
979  /// slices. It computes these partitions on the fly based on the overlapping
980  /// offsets of the slices and the ability to split them. It will visit "empty"
981  /// partitions to cover regions of the alloca only accessed via split
982  /// slices.
983  iterator_range<AllocaSlices::partition_iterator> AllocaSlices::partitions() {
984    return make_range(partition_iterator(begin(), end()),
985                      partition_iterator(end(), end()));
986  }
987  
988  static Value *foldSelectInst(SelectInst &SI) {
989    // If the condition being selected on is a constant or the same value is
990    // being selected between, fold the select. Yes this does (rarely) happen
991    // early on.
992    if (ConstantInt *CI = dyn_cast<ConstantInt>(SI.getCondition()))
993      return SI.getOperand(1 + CI->isZero());
994    if (SI.getOperand(1) == SI.getOperand(2))
995      return SI.getOperand(1);
996  
997    return nullptr;
998  }
999  
1000  /// A helper that folds a PHI node or a select.
1001  static Value *foldPHINodeOrSelectInst(Instruction &I) {
1002    if (PHINode *PN = dyn_cast<PHINode>(&I)) {
1003      // If PN merges together the same value, return that value.
1004      return PN->hasConstantValue();
1005    }
1006    return foldSelectInst(cast<SelectInst>(I));
1007  }
1008  
1009  /// Builder for the alloca slices.
1010  ///
1011  /// This class builds a set of alloca slices by recursively visiting the uses
1012  /// of an alloca and making a slice for each load and store at each offset.
1013  class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> {
1014    friend class PtrUseVisitor<SliceBuilder>;
1015    friend class InstVisitor<SliceBuilder>;
1016  
1017    using Base = PtrUseVisitor<SliceBuilder>;
1018  
1019    const uint64_t AllocSize;
1020    AllocaSlices &AS;
1021  
1022    SmallDenseMap<Instruction *, unsigned> MemTransferSliceMap;
1023    SmallDenseMap<Instruction *, uint64_t> PHIOrSelectSizes;
1024  
1025    /// Set to de-duplicate dead instructions found in the use walk.
1026    SmallPtrSet<Instruction *, 4> VisitedDeadInsts;
1027  
1028  public:
1029    SliceBuilder(const DataLayout &DL, AllocaInst &AI, AllocaSlices &AS)
1030        : PtrUseVisitor<SliceBuilder>(DL),
1031          AllocSize(DL.getTypeAllocSize(AI.getAllocatedType()).getFixedValue()),
1032          AS(AS) {}
1033  
1034  private:
1035    void markAsDead(Instruction &I) {
1036      if (VisitedDeadInsts.insert(&I).second)
1037        AS.DeadUsers.push_back(&I);
1038    }
1039  
1040    void insertUse(Instruction &I, const APInt &Offset, uint64_t Size,
1041                   bool IsSplittable = false) {
1042      // Completely skip uses which have a zero size or start either before or
1043      // past the end of the allocation.
1044      if (Size == 0 || Offset.uge(AllocSize)) {
1045        LLVM_DEBUG(dbgs() << "WARNING: Ignoring " << Size << " byte use @"
1046                          << Offset
1047                          << " which has zero size or starts outside of the "
1048                          << AllocSize << " byte alloca:\n"
1049                          << "    alloca: " << AS.AI << "\n"
1050                          << "       use: " << I << "\n");
1051        return markAsDead(I);
1052      }
1053  
1054      uint64_t BeginOffset = Offset.getZExtValue();
1055      uint64_t EndOffset = BeginOffset + Size;
1056  
1057      // Clamp the end offset to the end of the allocation. Note that this is
1058      // formulated to handle even the case where "BeginOffset + Size" overflows.
1059      // This may appear superficially to be something we could ignore entirely,
1060      // but that is not so! There may be widened loads or PHI-node uses where
1061      // some instructions are dead but not others. We can't completely ignore
1062      // them, and so have to record at least the information here.
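    // For example, a 16-byte use at offset 4 of an 8-byte alloca is recorded
    // as the clamped slice [4, 8) rather than being dropped.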
1063      assert(AllocSize >= BeginOffset); // Established above.
1064      if (Size > AllocSize - BeginOffset) {
1065        LLVM_DEBUG(dbgs() << "WARNING: Clamping a " << Size << " byte use @"
1066                          << Offset << " to remain within the " << AllocSize
1067                          << " byte alloca:\n"
1068                          << "    alloca: " << AS.AI << "\n"
1069                          << "       use: " << I << "\n");
1070        EndOffset = AllocSize;
1071      }
1072  
1073      AS.Slices.push_back(Slice(BeginOffset, EndOffset, U, IsSplittable));
1074    }
1075  
1076    void visitBitCastInst(BitCastInst &BC) {
1077      if (BC.use_empty())
1078        return markAsDead(BC);
1079  
1080      return Base::visitBitCastInst(BC);
1081    }
1082  
1083    void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) {
1084      if (ASC.use_empty())
1085        return markAsDead(ASC);
1086  
1087      return Base::visitAddrSpaceCastInst(ASC);
1088    }
1089  
1090    void visitGetElementPtrInst(GetElementPtrInst &GEPI) {
1091      if (GEPI.use_empty())
1092        return markAsDead(GEPI);
1093  
1094      return Base::visitGetElementPtrInst(GEPI);
1095    }
1096  
1097    void handleLoadOrStore(Type *Ty, Instruction &I, const APInt &Offset,
1098                           uint64_t Size, bool IsVolatile) {
1099      // We allow splitting of non-volatile loads and stores where the type is an
1100      // integer type. These may be used to implement 'memcpy' or other "transfer
1101      // of bits" patterns.
1102      bool IsSplittable =
1103          Ty->isIntegerTy() && !IsVolatile && DL.typeSizeEqualsStoreSize(Ty);
1104  
1105      insertUse(I, Offset, Size, IsSplittable);
1106    }
1107  
1108    void visitLoadInst(LoadInst &LI) {
1109      assert((!LI.isSimple() || LI.getType()->isSingleValueType()) &&
1110             "All simple FCA loads should have been pre-split");
1111  
1112      if (!IsOffsetKnown)
1113        return PI.setAborted(&LI);
1114  
1115      TypeSize Size = DL.getTypeStoreSize(LI.getType());
1116      if (Size.isScalable())
1117        return PI.setAborted(&LI);
1118  
1119      return handleLoadOrStore(LI.getType(), LI, Offset, Size.getFixedValue(),
1120                               LI.isVolatile());
1121    }
1122  
1123    void visitStoreInst(StoreInst &SI) {
1124      Value *ValOp = SI.getValueOperand();
1125      if (ValOp == *U)
1126        return PI.setEscapedAndAborted(&SI);
1127      if (!IsOffsetKnown)
1128        return PI.setAborted(&SI);
1129  
1130      TypeSize StoreSize = DL.getTypeStoreSize(ValOp->getType());
1131      if (StoreSize.isScalable())
1132        return PI.setAborted(&SI);
1133  
1134      uint64_t Size = StoreSize.getFixedValue();
1135  
1136      // If this memory access can be shown to *statically* extend outside the
1137      // bounds of the allocation, its behavior is undefined, so simply
1138      // ignore it. Note that this is more strict than the generic clamping
1139      // behavior of insertUse. We also try to handle cases which might run the
1140      // risk of overflow.
1141      // FIXME: We should instead consider the pointer to have escaped if this
1142      // function is being instrumented for addressing bugs or race conditions.
1143      if (Size > AllocSize || Offset.ugt(AllocSize - Size)) {
1144        LLVM_DEBUG(dbgs() << "WARNING: Ignoring " << Size << " byte store @"
1145                          << Offset << " which extends past the end of the "
1146                          << AllocSize << " byte alloca:\n"
1147                          << "    alloca: " << AS.AI << "\n"
1148                          << "       use: " << SI << "\n");
1149        return markAsDead(SI);
1150      }
1151  
1152      assert((!SI.isSimple() || ValOp->getType()->isSingleValueType()) &&
1153             "All simple FCA stores should have been pre-split");
1154      handleLoadOrStore(ValOp->getType(), SI, Offset, Size, SI.isVolatile());
1155    }
1156  
1157    void visitMemSetInst(MemSetInst &II) {
1158      assert(II.getRawDest() == *U && "Pointer use is not the destination?");
1159      ConstantInt *Length = dyn_cast<ConstantInt>(II.getLength());
1160      if ((Length && Length->getValue() == 0) ||
1161          (IsOffsetKnown && Offset.uge(AllocSize)))
1162        // Zero-length mem transfer intrinsics can be ignored entirely.
1163        return markAsDead(II);
1164  
1165      if (!IsOffsetKnown)
1166        return PI.setAborted(&II);
1167  
1168      insertUse(II, Offset,
1169                Length ? Length->getLimitedValue()
1170                       : AllocSize - Offset.getLimitedValue(),
1171                (bool)Length);
1172    }
1173  
1174    void visitMemTransferInst(MemTransferInst &II) {
1175      ConstantInt *Length = dyn_cast<ConstantInt>(II.getLength());
1176      if (Length && Length->getValue() == 0)
1177        // Zero-length mem transfer intrinsics can be ignored entirely.
1178        return markAsDead(II);
1179  
1180      // Because we can visit these intrinsics twice, also check to see if the
1181      // first time marked this instruction as dead. If so, skip it.
1182      if (VisitedDeadInsts.count(&II))
1183        return;
1184  
1185      if (!IsOffsetKnown)
1186        return PI.setAborted(&II);
1187  
1188      // This side of the transfer is completely out-of-bounds, and so we can
1189      // nuke the entire transfer. However, we also need to nuke the other side
1190      // if already added to our partitions.
1191      // FIXME: Yet another place we really should bypass this when
1192      // instrumenting for ASan.
1193      if (Offset.uge(AllocSize)) {
1194        SmallDenseMap<Instruction *, unsigned>::iterator MTPI =
1195            MemTransferSliceMap.find(&II);
1196        if (MTPI != MemTransferSliceMap.end())
1197          AS.Slices[MTPI->second].kill();
1198        return markAsDead(II);
1199      }
1200  
1201      uint64_t RawOffset = Offset.getLimitedValue();
1202      uint64_t Size = Length ? Length->getLimitedValue() : AllocSize - RawOffset;
1203  
1204      // Check for the special case where the same exact value is used for both
1205      // source and dest.
1206      if (*U == II.getRawDest() && *U == II.getRawSource()) {
1207        // For non-volatile transfers this is a no-op.
1208        if (!II.isVolatile())
1209          return markAsDead(II);
1210  
1211        return insertUse(II, Offset, Size, /*IsSplittable=*/false);
1212      }
1213  
1214      // If we have seen both source and destination for a mem transfer, then
1215      // they both point to the same alloca.
1216      bool Inserted;
1217      SmallDenseMap<Instruction *, unsigned>::iterator MTPI;
1218      std::tie(MTPI, Inserted) =
1219          MemTransferSliceMap.insert(std::make_pair(&II, AS.Slices.size()));
1220      unsigned PrevIdx = MTPI->second;
1221      if (!Inserted) {
1222        Slice &PrevP = AS.Slices[PrevIdx];
1223  
1224        // Check if the begin offsets match and this is a non-volatile transfer.
1225        // In that case, we can completely elide the transfer.
1226        if (!II.isVolatile() && PrevP.beginOffset() == RawOffset) {
1227          PrevP.kill();
1228          return markAsDead(II);
1229        }
1230  
1231        // Otherwise we have an offset transfer within the same alloca. We can't
1232        // split those.
1233        PrevP.makeUnsplittable();
1234      }
1235  
1236      // Insert the use now that we've fixed up the splittable nature.
1237      insertUse(II, Offset, Size, /*IsSplittable=*/Inserted && Length);
1238  
1239      // Check that we ended up with a valid index in the map.
1240      assert(AS.Slices[PrevIdx].getUse()->getUser() == &II &&
1241             "Map index doesn't point back to a slice with this user.");
1242    }
1243  
1244    // Disable SRoA for any intrinsics except for lifetime invariants and
1245    // invariant group.
1246    // FIXME: What about debug intrinsics? This matches old behavior, but
1247    // doesn't make sense.
1248    void visitIntrinsicInst(IntrinsicInst &II) {
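          // Droppable intrinsic uses (e.g. llvm.assume) never block promotion;
          // record them so they can be dropped if the alloca is promoted.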
1249      if (II.isDroppable()) {
1250        AS.DeadUseIfPromotable.push_back(U);
1251        return;
1252      }
1253  
1254      if (!IsOffsetKnown)
1255        return PI.setAborted(&II);
1256  
1257      if (II.isLifetimeStartOrEnd()) {
1258        ConstantInt *Length = cast<ConstantInt>(II.getArgOperand(0));
1259        uint64_t Size = std::min(AllocSize - Offset.getLimitedValue(),
1260                                 Length->getLimitedValue());
1261        insertUse(II, Offset, Size, true);
1262        return;
1263      }
1264  
1265      if (II.isLaunderOrStripInvariantGroup()) {
1266        insertUse(II, Offset, AllocSize, true);
1267        enqueueUsers(II);
1268        return;
1269      }
1270  
1271      Base::visitIntrinsicInst(II);
1272    }
1273  
1274    Instruction *hasUnsafePHIOrSelectUse(Instruction *Root, uint64_t &Size) {
1275      // We consider any PHI or select that results in a direct load or store of
1276      // the same offset to be a viable use for slicing purposes. These uses
1277      // are considered unsplittable and the size is the maximum loaded or stored
1278      // size.
1279      SmallPtrSet<Instruction *, 4> Visited;
1280      SmallVector<std::pair<Instruction *, Instruction *>, 4> Uses;
1281      Visited.insert(Root);
1282      Uses.push_back(std::make_pair(cast<Instruction>(*U), Root));
1283      const DataLayout &DL = Root->getDataLayout();
1284      // If there are no loads or stores, the access is dead. We mark that as
1285      // a size zero access.
1286      Size = 0;
1287      do {
1288        Instruction *I, *UsedI;
1289        std::tie(UsedI, I) = Uses.pop_back_val();
1290  
1291        if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
1292          TypeSize LoadSize = DL.getTypeStoreSize(LI->getType());
1293          if (LoadSize.isScalable()) {
1294            PI.setAborted(LI);
1295            return nullptr;
1296          }
1297          Size = std::max(Size, LoadSize.getFixedValue());
1298          continue;
1299        }
1300        if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
1301          Value *Op = SI->getOperand(0);
1302          if (Op == UsedI)
1303            return SI;
1304          TypeSize StoreSize = DL.getTypeStoreSize(Op->getType());
1305          if (StoreSize.isScalable()) {
1306            PI.setAborted(SI);
1307            return nullptr;
1308          }
1309          Size = std::max(Size, StoreSize.getFixedValue());
1310          continue;
1311        }
1312  
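              // Walk through pointer-transparent instructions (all-zero-index
              // GEPs, bitcasts, address space casts) and nested PHIs/selects;
              // any other user makes this PHI/select unsafe.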
1313        if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(I)) {
1314          if (!GEP->hasAllZeroIndices())
1315            return GEP;
1316        } else if (!isa<BitCastInst>(I) && !isa<PHINode>(I) &&
1317                   !isa<SelectInst>(I) && !isa<AddrSpaceCastInst>(I)) {
1318          return I;
1319        }
1320  
1321        for (User *U : I->users())
1322          if (Visited.insert(cast<Instruction>(U)).second)
1323            Uses.push_back(std::make_pair(I, cast<Instruction>(U)));
1324      } while (!Uses.empty());
1325  
1326      return nullptr;
1327    }
1328  
1329    void visitPHINodeOrSelectInst(Instruction &I) {
1330      assert(isa<PHINode>(I) || isa<SelectInst>(I));
1331      if (I.use_empty())
1332        return markAsDead(I);
1333  
1334      // If this is a PHI node before a catchswitch, we cannot insert any non-PHI
1335      // instructions in this BB, which may be required during rewriting. Bail out
1336      // on these cases.
1337      if (isa<PHINode>(I) &&
1338          I.getParent()->getFirstInsertionPt() == I.getParent()->end())
1339        return PI.setAborted(&I);
1340  
1341      // TODO: We could use simplifyInstruction here to fold PHINodes and
1342      // SelectInsts. However, doing so requires to change the current
1343      // dead-operand-tracking mechanism. For instance, suppose neither loading
1344      // from %U nor %other traps. Then "load (select undef, %U, %other)" does not
1345      // trap either.  However, if we simply replace %U with undef using the
1346      // current dead-operand-tracking mechanism, "load (select undef, undef,
1347      // %other)" may trap because the select may return the first operand
1348      // "undef".
1349      if (Value *Result = foldPHINodeOrSelectInst(I)) {
1350        if (Result == *U)
1351          // If the result of the constant fold will be the pointer, recurse
1352          // through the PHI/select as if we had RAUW'ed it.
1353          enqueueUsers(I);
1354        else
1355          // Otherwise the operand to the PHI/select is dead, and we can replace
1356          // it with poison.
1357          AS.DeadOperands.push_back(U);
1358  
1359        return;
1360      }
1361  
1362      if (!IsOffsetKnown)
1363        return PI.setAborted(&I);
1364  
1365      // See if we already have computed info on this node.
1366      uint64_t &Size = PHIOrSelectSizes[&I];
1367      if (!Size) {
1368        // This is a new PHI/Select, check for an unsafe use of it.
1369        if (Instruction *UnsafeI = hasUnsafePHIOrSelectUse(&I, Size))
1370          return PI.setAborted(UnsafeI);
1371      }
1372  
1373      // For PHI and select operands outside the alloca, we can't nuke the entire
1374      // phi or select -- the other side might still be relevant, so we special
1375      // case them here and use a separate structure to track the operands
1376      // themselves which should be replaced with poison.
1377      // FIXME: This should instead be escaped in the event we're instrumenting
1378      // for address sanitization.
1379      if (Offset.uge(AllocSize)) {
1380        AS.DeadOperands.push_back(U);
1381        return;
1382      }
1383  
1384      insertUse(I, Offset, Size);
1385    }
1386  
1387    void visitPHINode(PHINode &PN) { visitPHINodeOrSelectInst(PN); }
1388  
1389    void visitSelectInst(SelectInst &SI) { visitPHINodeOrSelectInst(SI); }
1390  
1391    /// Disable SROA entirely if there are unhandled users of the alloca.
1392    void visitInstruction(Instruction &I) { PI.setAborted(&I); }
1393  };
1394  
1395  AllocaSlices::AllocaSlices(const DataLayout &DL, AllocaInst &AI)
1396      :
1397  #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
1398        AI(AI),
1399  #endif
1400        PointerEscapingInstr(nullptr) {
1401    SliceBuilder PB(DL, AI, *this);
1402    SliceBuilder::PtrInfo PtrI = PB.visitPtr(AI);
1403    if (PtrI.isEscaped() || PtrI.isAborted()) {
1404      // FIXME: We should sink the escape vs. abort info into the caller nicely,
1405      // possibly by just storing the PtrInfo in the AllocaSlices.
1406      PointerEscapingInstr = PtrI.getEscapingInst() ? PtrI.getEscapingInst()
1407                                                    : PtrI.getAbortingInst();
1408      assert(PointerEscapingInstr && "Did not track a bad instruction");
1409      return;
1410    }
1411  
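        // Discard any slices that were marked dead while building the use list.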
1412    llvm::erase_if(Slices, [](const Slice &S) { return S.isDead(); });
1413  
1414    // Sort the uses. This arranges for the offsets to be in ascending order,
1415    // and the sizes to be in descending order.
1416    llvm::stable_sort(Slices);
1417  }
1418  
1419  #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
1420  
1421  void AllocaSlices::print(raw_ostream &OS, const_iterator I,
1422                           StringRef Indent) const {
1423    printSlice(OS, I, Indent);
1424    OS << "\n";
1425    printUse(OS, I, Indent);
1426  }
1427  
1428  void AllocaSlices::printSlice(raw_ostream &OS, const_iterator I,
1429                                StringRef Indent) const {
1430    OS << Indent << "[" << I->beginOffset() << "," << I->endOffset() << ")"
1431       << " slice #" << (I - begin())
1432       << (I->isSplittable() ? " (splittable)" : "");
1433  }
1434  
1435  void AllocaSlices::printUse(raw_ostream &OS, const_iterator I,
1436                              StringRef Indent) const {
1437    OS << Indent << "  used by: " << *I->getUse()->getUser() << "\n";
1438  }
1439  
1440  void AllocaSlices::print(raw_ostream &OS) const {
1441    if (PointerEscapingInstr) {
1442      OS << "Can't analyze slices for alloca: " << AI << "\n"
1443         << "  A pointer to this alloca escaped by:\n"
1444         << "  " << *PointerEscapingInstr << "\n";
1445      return;
1446    }
1447  
1448    OS << "Slices of alloca: " << AI << "\n";
1449    for (const_iterator I = begin(), E = end(); I != E; ++I)
1450      print(OS, I);
1451  }
1452  
1453  LLVM_DUMP_METHOD void AllocaSlices::dump(const_iterator I) const {
1454    print(dbgs(), I);
1455  }
1456  LLVM_DUMP_METHOD void AllocaSlices::dump() const { print(dbgs()); }
1457  
1458  #endif // !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
1459  
1460  /// Walk the range of a partitioning looking for a common type to cover this
1461  /// sequence of slices.
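      /// Returns the common type shared by every load or store spanning the
      /// whole range (or null if there is none), together with the widest
      /// byte-width-multiple integer type seen, which callers may fall back on.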
1462  static std::pair<Type *, IntegerType *>
1463  findCommonType(AllocaSlices::const_iterator B, AllocaSlices::const_iterator E,
1464                 uint64_t EndOffset) {
1465    Type *Ty = nullptr;
1466    bool TyIsCommon = true;
1467    IntegerType *ITy = nullptr;
1468  
1469    // Note that we need to look at *every* alloca slice's Use to ensure we
1470    // always get consistent results regardless of the order of slices.
1471    for (AllocaSlices::const_iterator I = B; I != E; ++I) {
1472      Use *U = I->getUse();
1473      if (isa<IntrinsicInst>(*U->getUser()))
1474        continue;
1475      if (I->beginOffset() != B->beginOffset() || I->endOffset() != EndOffset)
1476        continue;
1477  
1478      Type *UserTy = nullptr;
1479      if (LoadInst *LI = dyn_cast<LoadInst>(U->getUser())) {
1480        UserTy = LI->getType();
1481      } else if (StoreInst *SI = dyn_cast<StoreInst>(U->getUser())) {
1482        UserTy = SI->getValueOperand()->getType();
1483      }
1484  
1485      if (IntegerType *UserITy = dyn_cast_or_null<IntegerType>(UserTy)) {
1486        // If the type is larger than the partition, skip it. We only encounter
1487        // this for split integer operations where we want to use the type of the
1488        // entity causing the split. Also skip if the type is not a byte width
1489        // multiple.
1490        if (UserITy->getBitWidth() % 8 != 0 ||
1491            UserITy->getBitWidth() / 8 > (EndOffset - B->beginOffset()))
1492          continue;
1493  
1494        // Track the largest bitwidth integer type used in this way in case there
1495        // is no common type.
1496        if (!ITy || ITy->getBitWidth() < UserITy->getBitWidth())
1497          ITy = UserITy;
1498      }
1499  
1500      // To avoid depending on the order of slices, Ty and TyIsCommon must not
1501      // depend on types skipped above.
1502      if (!UserTy || (Ty && Ty != UserTy))
1503        TyIsCommon = false; // Give up on anything but an iN type.
1504      else
1505        Ty = UserTy;
1506    }
1507  
1508    return {TyIsCommon ? Ty : nullptr, ITy};
1509  }
1510  
1511  /// PHI instructions that use an alloca and are subsequently loaded can be
1512  /// rewritten to load both input pointers in the pred blocks and then PHI the
1513  /// results, allowing the load of the alloca to be promoted.
1514  /// From this:
1515  ///   %P2 = phi [i32* %Alloca, i32* %Other]
1516  ///   %V = load i32* %P2
1517  /// to:
1518  ///   %V1 = load i32* %Alloca      -> will be mem2reg'd
1519  ///   ...
1520  ///   %V2 = load i32* %Other
1521  ///   ...
1522  ///   %V = phi [i32 %V1, i32 %V2]
1523  ///
1524  /// We can do this to a select if its only uses are loads and if the operands
1525  /// to the select can be loaded unconditionally.
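      /// For example (an illustrative sketch of the analogous select case):
      ///   %P2 = select i1 %C, i32* %Alloca, i32* %Other
      ///   %V = load i32* %P2
      /// to:
      ///   %V1 = load i32* %Alloca      -> will be mem2reg'd
      ///   %V2 = load i32* %Other
      ///   %V = select i1 %C, i32 %V1, i32 %V2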
1526  ///
1527  /// FIXME: This should be hoisted into a generic utility, likely in
1528  /// Transforms/Util/Local.h
1529  static bool isSafePHIToSpeculate(PHINode &PN) {
1530    const DataLayout &DL = PN.getDataLayout();
1531  
1532    // For now, we can only do this promotion if the load is in the same block
1533    // as the PHI, and if there are no stores between the phi and load.
1534    // TODO: Allow recursive phi users.
1535    // TODO: Allow stores.
1536    BasicBlock *BB = PN.getParent();
1537    Align MaxAlign;
1538    uint64_t APWidth = DL.getIndexTypeSizeInBits(PN.getType());
1539    Type *LoadType = nullptr;
1540    for (User *U : PN.users()) {
1541      LoadInst *LI = dyn_cast<LoadInst>(U);
1542      if (!LI || !LI->isSimple())
1543        return false;
1544  
1545      // For now we only allow loads in the same block as the PHI.  This is
1546      // a common case that happens when instcombine merges two loads through
1547      // a PHI.
1548      if (LI->getParent() != BB)
1549        return false;
1550  
1551      if (LoadType) {
1552        if (LoadType != LI->getType())
1553          return false;
1554      } else {
1555        LoadType = LI->getType();
1556      }
1557  
1558      // Ensure that there are no instructions between the PHI and the load that
1559      // could store.
1560      for (BasicBlock::iterator BBI(PN); &*BBI != LI; ++BBI)
1561        if (BBI->mayWriteToMemory())
1562          return false;
1563  
1564      MaxAlign = std::max(MaxAlign, LI->getAlign());
1565    }
1566  
1567    if (!LoadType)
1568      return false;
1569  
1570    APInt LoadSize =
1571        APInt(APWidth, DL.getTypeStoreSize(LoadType).getFixedValue());
1572  
1573    // We can only transform this if it is safe to push the loads into the
1574    // predecessor blocks. The only thing to watch out for is that we can't put
1575    // a possibly trapping load in the predecessor if it is a critical edge.
1576    for (unsigned Idx = 0, Num = PN.getNumIncomingValues(); Idx != Num; ++Idx) {
1577      Instruction *TI = PN.getIncomingBlock(Idx)->getTerminator();
1578      Value *InVal = PN.getIncomingValue(Idx);
1579  
1580      // If the value is produced by the terminator of the predecessor (an
1581      // invoke) or it has side-effects, there is no valid place to put a load
1582      // in the predecessor.
1583      if (TI == InVal || TI->mayHaveSideEffects())
1584        return false;
1585  
1586      // If the predecessor has a single successor, then the edge isn't
1587      // critical.
1588      if (TI->getNumSuccessors() == 1)
1589        continue;
1590  
1591      // If this pointer is always safe to load, or if we can prove that there
1592      // is already a load in the block, then we can move the load to the pred
1593      // block.
1594      if (isSafeToLoadUnconditionally(InVal, MaxAlign, LoadSize, DL, TI))
1595        continue;
1596  
1597      return false;
1598    }
1599  
1600    return true;
1601  }
1602  
1603  static void speculatePHINodeLoads(IRBuilderTy &IRB, PHINode &PN) {
1604    LLVM_DEBUG(dbgs() << "    original: " << PN << "\n");
1605  
1606    LoadInst *SomeLoad = cast<LoadInst>(PN.user_back());
1607    Type *LoadTy = SomeLoad->getType();
1608    IRB.SetInsertPoint(&PN);
1609    PHINode *NewPN = IRB.CreatePHI(LoadTy, PN.getNumIncomingValues(),
1610                                   PN.getName() + ".sroa.speculated");
1611  
1612    // Get the AA tags and alignment to use from one of the loads. It does not
1613    // matter which one we get or whether any of them differ.
1614    AAMDNodes AATags = SomeLoad->getAAMetadata();
1615    Align Alignment = SomeLoad->getAlign();
1616  
1617    // Rewrite all loads of the PN to use the new PHI.
1618    while (!PN.use_empty()) {
1619      LoadInst *LI = cast<LoadInst>(PN.user_back());
1620      LI->replaceAllUsesWith(NewPN);
1621      LI->eraseFromParent();
1622    }
1623  
1624    // Inject loads into all of the pred blocks.
1625    DenseMap<BasicBlock *, Value *> InjectedLoads;
1626    for (unsigned Idx = 0, Num = PN.getNumIncomingValues(); Idx != Num; ++Idx) {
1627      BasicBlock *Pred = PN.getIncomingBlock(Idx);
1628      Value *InVal = PN.getIncomingValue(Idx);
1629  
1630      // A PHI node is allowed to have multiple (duplicated) entries for the same
1631      // basic block, as long as the value is the same. So if we already injected
1632      // a load in the predecessor, then we should reuse the same load for all
1633      // duplicated entries.
1634      if (Value *V = InjectedLoads.lookup(Pred)) {
1635        NewPN->addIncoming(V, Pred);
1636        continue;
1637      }
1638  
1639      Instruction *TI = Pred->getTerminator();
1640      IRB.SetInsertPoint(TI);
1641  
1642      LoadInst *Load = IRB.CreateAlignedLoad(
1643          LoadTy, InVal, Alignment,
1644          (PN.getName() + ".sroa.speculate.load." + Pred->getName()));
1645      ++NumLoadsSpeculated;
1646      if (AATags)
1647        Load->setAAMetadata(AATags);
1648      NewPN->addIncoming(Load, Pred);
1649      InjectedLoads[Pred] = Load;
1650    }
1651  
1652    LLVM_DEBUG(dbgs() << "          speculated to: " << *NewPN << "\n");
1653    PN.eraseFromParent();
1654  }
1655  
1656  SelectHandSpeculativity &
1657  SelectHandSpeculativity::setAsSpeculatable(bool isTrueVal) {
1658    if (isTrueVal)
1659      Bitfield::set<SelectHandSpeculativity::TrueVal>(Storage, true);
1660    else
1661      Bitfield::set<SelectHandSpeculativity::FalseVal>(Storage, true);
1662    return *this;
1663  }
1664  
1665  bool SelectHandSpeculativity::isSpeculatable(bool isTrueVal) const {
1666    return isTrueVal ? Bitfield::get<SelectHandSpeculativity::TrueVal>(Storage)
1667                     : Bitfield::get<SelectHandSpeculativity::FalseVal>(Storage);
1668  }
1669  
1670  bool SelectHandSpeculativity::areAllSpeculatable() const {
1671    return isSpeculatable(/*isTrueVal=*/true) &&
1672           isSpeculatable(/*isTrueVal=*/false);
1673  }
1674  
1675  bool SelectHandSpeculativity::areAnySpeculatable() const {
1676    return isSpeculatable(/*isTrueVal=*/true) ||
1677           isSpeculatable(/*isTrueVal=*/false);
1678  }
1679  bool SelectHandSpeculativity::areNoneSpeculatable() const {
1680    return !areAnySpeculatable();
1681  }
1682  
1683  static SelectHandSpeculativity
1684  isSafeLoadOfSelectToSpeculate(LoadInst &LI, SelectInst &SI, bool PreserveCFG) {
1685    assert(LI.isSimple() && "Only for simple loads");
1686    SelectHandSpeculativity Spec;
1687  
1688    const DataLayout &DL = SI.getDataLayout();
1689    for (Value *Value : {SI.getTrueValue(), SI.getFalseValue()})
1690      if (isSafeToLoadUnconditionally(Value, LI.getType(), LI.getAlign(), DL,
1691                                      &LI))
1692        Spec.setAsSpeculatable(/*isTrueVal=*/Value == SI.getTrueValue());
1693      else if (PreserveCFG)
1694        return Spec;
1695  
1696    return Spec;
1697  }
1698  
1699  std::optional<RewriteableMemOps>
1700  SROA::isSafeSelectToSpeculate(SelectInst &SI, bool PreserveCFG) {
1701    RewriteableMemOps Ops;
1702  
1703    for (User *U : SI.users()) {
1704      if (auto *BC = dyn_cast<BitCastInst>(U); BC && BC->hasOneUse())
1705        U = *BC->user_begin();
1706  
1707      if (auto *Store = dyn_cast<StoreInst>(U)) {
1708        // Note that atomic stores can be transformed; atomic semantics do not
1709        // have any meaning for a local alloca. Stores are not speculatable,
1710        // however, so if we can't turn it into a predicated store, we are done.
1711        if (Store->isVolatile() || PreserveCFG)
1712          return {}; // Give up on this `select`.
1713        Ops.emplace_back(Store);
1714        continue;
1715      }
1716  
1717      auto *LI = dyn_cast<LoadInst>(U);
1718  
1719      // Note that atomic loads can be transformed;
1720      // atomic semantics do not have any meaning for a local alloca.
1721      if (!LI || LI->isVolatile())
1722        return {}; // Give up on this `select`.
1723  
1724      PossiblySpeculatableLoad Load(LI);
1725      if (!LI->isSimple()) {
1726        // If the `load` is not simple, we can't speculatively execute it,
1727        // but we could handle this via a CFG modification. But can we?
1728        if (PreserveCFG)
1729          return {}; // Give up on this `select`.
1730        Ops.emplace_back(Load);
1731        continue;
1732      }
1733  
1734      SelectHandSpeculativity Spec =
1735          isSafeLoadOfSelectToSpeculate(*LI, SI, PreserveCFG);
1736      if (PreserveCFG && !Spec.areAllSpeculatable())
1737        return {}; // Give up on this `select`.
1738  
1739      Load.setInt(Spec);
1740      Ops.emplace_back(Load);
1741    }
1742  
1743    return Ops;
1744  }
1745  
1746  static void speculateSelectInstLoads(SelectInst &SI, LoadInst &LI,
1747                                       IRBuilderTy &IRB) {
1748    LLVM_DEBUG(dbgs() << "    original load: " << SI << "\n");
1749  
1750    Value *TV = SI.getTrueValue();
1751    Value *FV = SI.getFalseValue();
1752    // Replace the given load of the select with a select of two loads.
1753  
1754    assert(LI.isSimple() && "We only speculate simple loads");
1755  
1756    IRB.SetInsertPoint(&LI);
1757  
1758    LoadInst *TL =
1759        IRB.CreateAlignedLoad(LI.getType(), TV, LI.getAlign(),
1760                              LI.getName() + ".sroa.speculate.load.true");
1761    LoadInst *FL =
1762        IRB.CreateAlignedLoad(LI.getType(), FV, LI.getAlign(),
1763                              LI.getName() + ".sroa.speculate.load.false");
1764    NumLoadsSpeculated += 2;
1765  
1766    // Transfer alignment and AA info if present.
1767    TL->setAlignment(LI.getAlign());
1768    FL->setAlignment(LI.getAlign());
1769  
1770    AAMDNodes Tags = LI.getAAMetadata();
1771    if (Tags) {
1772      TL->setAAMetadata(Tags);
1773      FL->setAAMetadata(Tags);
1774    }
1775  
1776    Value *V = IRB.CreateSelect(SI.getCondition(), TL, FL,
1777                                LI.getName() + ".sroa.speculated");
1778  
1779    LLVM_DEBUG(dbgs() << "          speculated to: " << *V << "\n");
1780    LI.replaceAllUsesWith(V);
1781  }
1782  
1783  template <typename T>
1784  static void rewriteMemOpOfSelect(SelectInst &SI, T &I,
1785                                   SelectHandSpeculativity Spec,
1786                                   DomTreeUpdater &DTU) {
1787    assert((isa<LoadInst>(I) || isa<StoreInst>(I)) && "Only for load and store!");
1788    LLVM_DEBUG(dbgs() << "    original mem op: " << I << "\n");
1789    BasicBlock *Head = I.getParent();
1790    Instruction *ThenTerm = nullptr;
1791    Instruction *ElseTerm = nullptr;
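        // If neither hand of the select can be speculated, split into a full
        // if-then-else diamond and predicate the memory op on both sides.
        // Otherwise a single conditional block suffices: the speculatable
        // hand's access stays unconditionally in the head block and only the
        // other hand is predicated.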
1792    if (Spec.areNoneSpeculatable())
1793      SplitBlockAndInsertIfThenElse(SI.getCondition(), &I, &ThenTerm, &ElseTerm,
1794                                    SI.getMetadata(LLVMContext::MD_prof), &DTU);
1795    else {
1796      SplitBlockAndInsertIfThen(SI.getCondition(), &I, /*Unreachable=*/false,
1797                                SI.getMetadata(LLVMContext::MD_prof), &DTU,
1798                                /*LI=*/nullptr, /*ThenBlock=*/nullptr);
1799      if (Spec.isSpeculatable(/*isTrueVal=*/true))
1800        cast<BranchInst>(Head->getTerminator())->swapSuccessors();
1801    }
1802    auto *HeadBI = cast<BranchInst>(Head->getTerminator());
1803    Spec = {}; // Do not use `Spec` beyond this point.
1804    BasicBlock *Tail = I.getParent();
1805    Tail->setName(Head->getName() + ".cont");
1806    PHINode *PN;
1807    if (isa<LoadInst>(I))
1808      PN = PHINode::Create(I.getType(), 2, "", I.getIterator());
1809    for (BasicBlock *SuccBB : successors(Head)) {
1810      bool IsThen = SuccBB == HeadBI->getSuccessor(0);
1811      int SuccIdx = IsThen ? 0 : 1;
1812      auto *NewMemOpBB = SuccBB == Tail ? Head : SuccBB;
1813      auto &CondMemOp = cast<T>(*I.clone());
1814      if (NewMemOpBB != Head) {
1815        NewMemOpBB->setName(Head->getName() + (IsThen ? ".then" : ".else"));
1816        if (isa<LoadInst>(I))
1817          ++NumLoadsPredicated;
1818        else
1819          ++NumStoresPredicated;
1820      } else {
1821        CondMemOp.dropUBImplyingAttrsAndMetadata();
1822        ++NumLoadsSpeculated;
1823      }
1824      CondMemOp.insertBefore(NewMemOpBB->getTerminator());
1825      Value *Ptr = SI.getOperand(1 + SuccIdx);
1826      CondMemOp.setOperand(I.getPointerOperandIndex(), Ptr);
1827      if (isa<LoadInst>(I)) {
1828        CondMemOp.setName(I.getName() + (IsThen ? ".then" : ".else") + ".val");
1829        PN->addIncoming(&CondMemOp, NewMemOpBB);
1830      } else
1831        LLVM_DEBUG(dbgs() << "                 to: " << CondMemOp << "\n");
1832    }
1833    if (isa<LoadInst>(I)) {
1834      PN->takeName(&I);
1835      LLVM_DEBUG(dbgs() << "          to: " << *PN << "\n");
1836      I.replaceAllUsesWith(PN);
1837    }
1838  }
1839  
1840  static void rewriteMemOpOfSelect(SelectInst &SelInst, Instruction &I,
1841                                   SelectHandSpeculativity Spec,
1842                                   DomTreeUpdater &DTU) {
1843    if (auto *LI = dyn_cast<LoadInst>(&I))
1844      rewriteMemOpOfSelect(SelInst, *LI, Spec, DTU);
1845    else if (auto *SI = dyn_cast<StoreInst>(&I))
1846      rewriteMemOpOfSelect(SelInst, *SI, Spec, DTU);
1847    else
1848      llvm_unreachable_internal("Only for load and store.");
1849  }
1850  
1851  static bool rewriteSelectInstMemOps(SelectInst &SI,
1852                                      const RewriteableMemOps &Ops,
1853                                      IRBuilderTy &IRB, DomTreeUpdater *DTU) {
1854    bool CFGChanged = false;
1855    LLVM_DEBUG(dbgs() << "    original select: " << SI << "\n");
1856  
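        // Each rewriteable mem op is either an unspeculatable store or a load
        // tagged with per-hand speculatability. Fully speculatable loads become
        // a select of two speculated loads; everything else is predicated by
        // rewriting the CFG around the select.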
1857    for (const RewriteableMemOp &Op : Ops) {
1858      SelectHandSpeculativity Spec;
1859      Instruction *I;
1860      if (auto *const *US = std::get_if<UnspeculatableStore>(&Op)) {
1861        I = *US;
1862      } else {
1863        auto PSL = std::get<PossiblySpeculatableLoad>(Op);
1864        I = PSL.getPointer();
1865        Spec = PSL.getInt();
1866      }
1867      if (Spec.areAllSpeculatable()) {
1868        speculateSelectInstLoads(SI, cast<LoadInst>(*I), IRB);
1869      } else {
1870        assert(DTU && "Should not get here when not allowed to modify the CFG!");
1871        rewriteMemOpOfSelect(SI, *I, Spec, *DTU);
1872        CFGChanged = true;
1873      }
1874      I->eraseFromParent();
1875    }
1876  
1877    for (User *U : make_early_inc_range(SI.users()))
1878      cast<BitCastInst>(U)->eraseFromParent();
1879    SI.eraseFromParent();
1880    return CFGChanged;
1881  }
1882  
1883  /// Compute an adjusted pointer from Ptr by Offset bytes where the
1884  /// resulting pointer has PointerTy.
1885  static Value *getAdjustedPtr(IRBuilderTy &IRB, const DataLayout &DL, Value *Ptr,
1886                               APInt Offset, Type *PointerTy,
1887                               const Twine &NamePrefix) {
1888    if (Offset != 0)
1889      Ptr = IRB.CreateInBoundsPtrAdd(Ptr, IRB.getInt(Offset),
1890                                     NamePrefix + "sroa_idx");
1891    return IRB.CreatePointerBitCastOrAddrSpaceCast(Ptr, PointerTy,
1892                                                   NamePrefix + "sroa_cast");
1893  }
1894  
1895  /// Compute the adjusted alignment for a load or store from an offset.
1896  static Align getAdjustedAlignment(Instruction *I, uint64_t Offset) {
1897    return commonAlignment(getLoadStoreAlignment(I), Offset);
1898  }
1899  
1900  /// Test whether we can convert a value from the old to the new type.
1901  ///
1902  /// This predicate should be used to guard calls to convertValue in order to
1903  /// ensure that we only try to convert viable values. The strategy is that we
1904  /// will peel off single element struct and array wrappings to get to an
1905  /// underlying value, and convert that value.
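      /// For example (illustrative): an i64 and a <2 x i32> are mutually
      /// convertible because they have the same size and are both single value
      /// types, whereas an i32 and an i64 are not, since integer conversions
      /// with differing bit widths are rejected outright.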
1906  static bool canConvertValue(const DataLayout &DL, Type *OldTy, Type *NewTy) {
1907    if (OldTy == NewTy)
1908      return true;
1909  
1910    // For integer types, we can't handle any bit-width differences. This would
1911    // break both vector conversions with extension and introduce endianness
1912    // issues when in conjunction with loads and stores.
1913    if (isa<IntegerType>(OldTy) && isa<IntegerType>(NewTy)) {
1914      assert(cast<IntegerType>(OldTy)->getBitWidth() !=
1915                 cast<IntegerType>(NewTy)->getBitWidth() &&
1916             "We can't have the same bitwidth for different int types");
1917      return false;
1918    }
1919  
1920    if (DL.getTypeSizeInBits(NewTy).getFixedValue() !=
1921        DL.getTypeSizeInBits(OldTy).getFixedValue())
1922      return false;
1923    if (!NewTy->isSingleValueType() || !OldTy->isSingleValueType())
1924      return false;
1925  
1926    // We can convert pointers to integers and vice-versa. Same for vectors
1927    // of pointers and integers.
1928    OldTy = OldTy->getScalarType();
1929    NewTy = NewTy->getScalarType();
1930    if (NewTy->isPointerTy() || OldTy->isPointerTy()) {
1931      if (NewTy->isPointerTy() && OldTy->isPointerTy()) {
1932        unsigned OldAS = OldTy->getPointerAddressSpace();
1933        unsigned NewAS = NewTy->getPointerAddressSpace();
1934        // Convert pointers if they are pointers from the same address space or
1935        // different integral (not non-integral) address spaces with the same
1936        // pointer size.
1937        return OldAS == NewAS ||
1938               (!DL.isNonIntegralAddressSpace(OldAS) &&
1939                !DL.isNonIntegralAddressSpace(NewAS) &&
1940                DL.getPointerSize(OldAS) == DL.getPointerSize(NewAS));
1941      }
1942  
1943      // We can convert integers to integral pointers, but not to non-integral
1944      // pointers.
1945      if (OldTy->isIntegerTy())
1946        return !DL.isNonIntegralPointerType(NewTy);
1947  
1948      // We can convert integral pointers to integers, but non-integral pointers
1949      // need to remain pointers.
1950      if (!DL.isNonIntegralPointerType(OldTy))
1951        return NewTy->isIntegerTy();
1952  
1953      return false;
1954    }
1955  
1956    if (OldTy->isTargetExtTy() || NewTy->isTargetExtTy())
1957      return false;
1958  
1959    return true;
1960  }
1961  
1962  /// Generic routine to convert an SSA value to a value of a different
1963  /// type.
1964  ///
1965  /// This will try various different casting techniques, such as bitcasts,
1966  /// inttoptr, and ptrtoint casts. Use the \c canConvertValue predicate to test
1967  /// two types for viability with this routine.
1968  static Value *convertValue(const DataLayout &DL, IRBuilderTy &IRB, Value *V,
1969                             Type *NewTy) {
1970    Type *OldTy = V->getType();
1971    assert(canConvertValue(DL, OldTy, NewTy) && "Value not convertable to type");
1972  
1973    if (OldTy == NewTy)
1974      return V;
1975  
1976    assert(!(isa<IntegerType>(OldTy) && isa<IntegerType>(NewTy)) &&
1977           "Integer types must be the exact same to convert.");
1978  
1979    // See if we need inttoptr for this type pair. May require additional bitcast.
1980    if (OldTy->isIntOrIntVectorTy() && NewTy->isPtrOrPtrVectorTy()) {
1981      // Expand <2 x i32> to i8* --> <2 x i32> to i64 to i8*
1982      // Expand i128 to <2 x i8*> --> i128 to <2 x i64> to <2 x i8*>
1983      // Expand <4 x i32> to <2 x i8*> --> <4 x i32> to <2 x i64> to <2 x i8*>
1984      // Directly handle i64 to i8*
1985      return IRB.CreateIntToPtr(IRB.CreateBitCast(V, DL.getIntPtrType(NewTy)),
1986                                NewTy);
1987    }
1988  
1989    // See if we need ptrtoint for this type pair. May require additional bitcast.
1990    if (OldTy->isPtrOrPtrVectorTy() && NewTy->isIntOrIntVectorTy()) {
1991      // Expand <2 x i8*> to i128 --> <2 x i8*> to <2 x i64> to i128
1992      // Expand i8* to <2 x i32> --> i8* to i64 to <2 x i32>
1993      // Expand <2 x i8*> to <4 x i32> --> <2 x i8*> to <2 x i64> to <4 x i32>
1994      // Expand i8* to i64 --> i8* to i64 to i64
1995      return IRB.CreateBitCast(IRB.CreatePtrToInt(V, DL.getIntPtrType(OldTy)),
1996                               NewTy);
1997    }
1998  
1999    if (OldTy->isPtrOrPtrVectorTy() && NewTy->isPtrOrPtrVectorTy()) {
2000      unsigned OldAS = OldTy->getPointerAddressSpace();
2001      unsigned NewAS = NewTy->getPointerAddressSpace();
2002      // To convert pointers between different address spaces (already checked
2003      // to be convertible, i.e. they have the same pointer size), we cannot use
2004      // `bitcast` (which requires both pointers to be in the same address space)
2005      // or `addrspacecast` (which is not always a no-op cast). Instead, use a
2006      // pair of no-op `ptrtoint`/`inttoptr` casts through an integer with the
2007      // same bit size.
2008      if (OldAS != NewAS) {
2009        assert(DL.getPointerSize(OldAS) == DL.getPointerSize(NewAS));
2010        return IRB.CreateIntToPtr(IRB.CreatePtrToInt(V, DL.getIntPtrType(OldTy)),
2011                                  NewTy);
2012      }
2013    }
2014  
2015    return IRB.CreateBitCast(V, NewTy);
2016  }
2017  
2018  /// Test whether the given slice use can be promoted to a vector.
2019  ///
2020  /// This function is called to test each entry in a partition which is slated
2021  /// for a single slice.
2022  static bool isVectorPromotionViableForSlice(Partition &P, const Slice &S,
2023                                              VectorType *Ty,
2024                                              uint64_t ElementSize,
2025                                              const DataLayout &DL) {
2026    // First validate the slice offsets.
2027    uint64_t BeginOffset =
2028        std::max(S.beginOffset(), P.beginOffset()) - P.beginOffset();
2029    uint64_t BeginIndex = BeginOffset / ElementSize;
2030    if (BeginIndex * ElementSize != BeginOffset ||
2031        BeginIndex >= cast<FixedVectorType>(Ty)->getNumElements())
2032      return false;
2033    uint64_t EndOffset = std::min(S.endOffset(), P.endOffset()) - P.beginOffset();
2034    uint64_t EndIndex = EndOffset / ElementSize;
2035    if (EndIndex * ElementSize != EndOffset ||
2036        EndIndex > cast<FixedVectorType>(Ty)->getNumElements())
2037      return false;
2038  
2039    assert(EndIndex > BeginIndex && "Empty vector!");
2040    uint64_t NumElements = EndIndex - BeginIndex;
2041    Type *SliceTy = (NumElements == 1)
2042                        ? Ty->getElementType()
2043                        : FixedVectorType::get(Ty->getElementType(), NumElements);
2044  
2045    Type *SplitIntTy =
2046        Type::getIntNTy(Ty->getContext(), NumElements * ElementSize * 8);
2047  
2048    Use *U = S.getUse();
2049  
2050    if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(U->getUser())) {
2051      if (MI->isVolatile())
2052        return false;
2053      if (!S.isSplittable())
2054        return false; // Skip any unsplittable intrinsics.
2055    } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(U->getUser())) {
2056      if (!II->isLifetimeStartOrEnd() && !II->isDroppable())
2057        return false;
2058    } else if (LoadInst *LI = dyn_cast<LoadInst>(U->getUser())) {
2059      if (LI->isVolatile())
2060        return false;
2061      Type *LTy = LI->getType();
2062      // Disable vector promotion when there are loads or stores of an FCA.
2063      if (LTy->isStructTy())
2064        return false;
2065      if (P.beginOffset() > S.beginOffset() || P.endOffset() < S.endOffset()) {
2066        assert(LTy->isIntegerTy());
2067        LTy = SplitIntTy;
2068      }
2069      if (!canConvertValue(DL, SliceTy, LTy))
2070        return false;
2071    } else if (StoreInst *SI = dyn_cast<StoreInst>(U->getUser())) {
2072      if (SI->isVolatile())
2073        return false;
2074      Type *STy = SI->getValueOperand()->getType();
2075      // Disable vector promotion when there are loads or stores of an FCA.
2076      if (STy->isStructTy())
2077        return false;
2078      if (P.beginOffset() > S.beginOffset() || P.endOffset() < S.endOffset()) {
2079        assert(STy->isIntegerTy());
2080        STy = SplitIntTy;
2081      }
2082      if (!canConvertValue(DL, STy, SliceTy))
2083        return false;
2084    } else {
2085      return false;
2086    }
2087  
2088    return true;
2089  }
2090  
2091  /// Test whether a vector type is viable for promotion.
2092  ///
2093  /// This implements the necessary checking for \c checkVectorTypesForPromotion
2094  /// (and thus isVectorPromotionViable) over all slices of the alloca for the
2095  /// given VectorType.
2096  static bool checkVectorTypeForPromotion(Partition &P, VectorType *VTy,
2097                                          const DataLayout &DL) {
2098    uint64_t ElementSize =
2099        DL.getTypeSizeInBits(VTy->getElementType()).getFixedValue();
2100  
2101    // While LLVM defines vectors as bit-packed, we don't support element sizes
2102    // that aren't a whole number of bytes.
2103    if (ElementSize % 8)
2104      return false;
2105    assert((DL.getTypeSizeInBits(VTy).getFixedValue() % 8) == 0 &&
2106           "vector size not a multiple of element size?");
2107    ElementSize /= 8;
2108  
2109    for (const Slice &S : P)
2110      if (!isVectorPromotionViableForSlice(P, S, VTy, ElementSize, DL))
2111        return false;
2112  
2113    for (const Slice *S : P.splitSliceTails())
2114      if (!isVectorPromotionViableForSlice(P, *S, VTy, ElementSize, DL))
2115        return false;
2116  
2117    return true;
2118  }
2119  
2120  /// Test whether any vector type in \p CandidateTys is viable for promotion.
2121  ///
2122  /// This implements the necessary checking for \c isVectorPromotionViable over
2123  /// all slices of the alloca for the given VectorType.
2124  static VectorType *
2125  checkVectorTypesForPromotion(Partition &P, const DataLayout &DL,
2126                               SmallVectorImpl<VectorType *> &CandidateTys,
2127                               bool HaveCommonEltTy, Type *CommonEltTy,
2128                               bool HaveVecPtrTy, bool HaveCommonVecPtrTy,
2129                               VectorType *CommonVecPtrTy) {
2130    // If we didn't find a vector type, nothing to do here.
2131    if (CandidateTys.empty())
2132      return nullptr;
2133  
2134    // Pointer-ness is sticky: if we had a vector-of-pointers candidate type,
2135    // then we should choose it, not some other alternative.
2136    // But we can't perform a no-op pointer address space change via bitcast,
2137    // so if we didn't have a common pointer element type, bail.
2138    if (HaveVecPtrTy && !HaveCommonVecPtrTy)
2139      return nullptr;
2140  
2141    // Try to pick the "best" element type out of the choices.
2142    if (!HaveCommonEltTy && HaveVecPtrTy) {
2143      // If there was a pointer element type, there's really only one choice.
2144      CandidateTys.clear();
2145      CandidateTys.push_back(CommonVecPtrTy);
2146    } else if (!HaveCommonEltTy && !HaveVecPtrTy) {
2147      // Integer-ify vector types.
2148      for (VectorType *&VTy : CandidateTys) {
2149        if (!VTy->getElementType()->isIntegerTy())
2150          VTy = cast<VectorType>(VTy->getWithNewType(IntegerType::getIntNTy(
2151              VTy->getContext(), VTy->getScalarSizeInBits())));
2152      }
2153  
2154      // Rank the remaining candidate vector types. This is easy because we know
2155      // they're all integer vectors. We sort by ascending number of elements.
2156      auto RankVectorTypesComp = [&DL](VectorType *RHSTy, VectorType *LHSTy) {
2157        (void)DL;
2158        assert(DL.getTypeSizeInBits(RHSTy).getFixedValue() ==
2159                   DL.getTypeSizeInBits(LHSTy).getFixedValue() &&
2160               "Cannot have vector types of different sizes!");
2161        assert(RHSTy->getElementType()->isIntegerTy() &&
2162               "All non-integer types eliminated!");
2163        assert(LHSTy->getElementType()->isIntegerTy() &&
2164               "All non-integer types eliminated!");
2165        return cast<FixedVectorType>(RHSTy)->getNumElements() <
2166               cast<FixedVectorType>(LHSTy)->getNumElements();
2167      };
2168      auto RankVectorTypesEq = [&DL](VectorType *RHSTy, VectorType *LHSTy) {
2169        (void)DL;
2170        assert(DL.getTypeSizeInBits(RHSTy).getFixedValue() ==
2171                   DL.getTypeSizeInBits(LHSTy).getFixedValue() &&
2172               "Cannot have vector types of different sizes!");
2173        assert(RHSTy->getElementType()->isIntegerTy() &&
2174               "All non-integer types eliminated!");
2175        assert(LHSTy->getElementType()->isIntegerTy() &&
2176               "All non-integer types eliminated!");
2177        return cast<FixedVectorType>(RHSTy)->getNumElements() ==
2178               cast<FixedVectorType>(LHSTy)->getNumElements();
2179      };
2180      llvm::sort(CandidateTys, RankVectorTypesComp);
2181      CandidateTys.erase(llvm::unique(CandidateTys, RankVectorTypesEq),
2182                         CandidateTys.end());
2183    } else {
2184  // The only way to have the same element type in every vector type is to
2185  // have the same vector type. Check that and remove all but one.
2186  #ifndef NDEBUG
2187      for (VectorType *VTy : CandidateTys) {
2188        assert(VTy->getElementType() == CommonEltTy &&
2189               "Unaccounted for element type!");
2190        assert(VTy == CandidateTys[0] &&
2191               "Different vector types with the same element type!");
2192      }
2193  #endif
2194      CandidateTys.resize(1);
2195    }
2196  
2197    // FIXME: hack. Do we have a named constant for this?
2198    // SDAG SDNode can't have more than 65535 operands.
2199    llvm::erase_if(CandidateTys, [](VectorType *VTy) {
2200      return cast<FixedVectorType>(VTy)->getNumElements() >
2201             std::numeric_limits<unsigned short>::max();
2202    });
2203  
2204    for (VectorType *VTy : CandidateTys)
2205      if (checkVectorTypeForPromotion(P, VTy, DL))
2206        return VTy;
2207  
2208    return nullptr;
2209  }
2210  
2211  static VectorType *createAndCheckVectorTypesForPromotion(
2212      SetVector<Type *> &OtherTys, ArrayRef<VectorType *> CandidateTysCopy,
2213      function_ref<void(Type *)> CheckCandidateType, Partition &P,
2214      const DataLayout &DL, SmallVectorImpl<VectorType *> &CandidateTys,
2215      bool &HaveCommonEltTy, Type *&CommonEltTy, bool &HaveVecPtrTy,
2216      bool &HaveCommonVecPtrTy, VectorType *&CommonVecPtrTy) {
2217    [[maybe_unused]] VectorType *OriginalElt =
2218        CandidateTysCopy.size() ? CandidateTysCopy[0] : nullptr;
2219    // Consider additional vector types where the element type size is a
2220    // multiple of load/store element size.
2221    for (Type *Ty : OtherTys) {
2222      if (!VectorType::isValidElementType(Ty))
2223        continue;
2224      unsigned TypeSize = DL.getTypeSizeInBits(Ty).getFixedValue();
2225      // Make a copy of CandidateTys and iterate through it, because we
2226      // might append to CandidateTys in the loop.
2227      for (VectorType *const VTy : CandidateTysCopy) {
2228        // The elements in the copy should remain invariant throughout the loop
2229        assert(CandidateTysCopy[0] == OriginalElt && "Different Element");
2230        unsigned VectorSize = DL.getTypeSizeInBits(VTy).getFixedValue();
2231        unsigned ElementSize =
2232            DL.getTypeSizeInBits(VTy->getElementType()).getFixedValue();
2233        if (TypeSize != VectorSize && TypeSize != ElementSize &&
2234            VectorSize % TypeSize == 0) {
2235          VectorType *NewVTy = VectorType::get(Ty, VectorSize / TypeSize, false);
2236          CheckCandidateType(NewVTy);
2237        }
2238      }
2239    }
2240  
2241    return checkVectorTypesForPromotion(P, DL, CandidateTys, HaveCommonEltTy,
2242                                        CommonEltTy, HaveVecPtrTy,
2243                                        HaveCommonVecPtrTy, CommonVecPtrTy);
2244  }
2245  
2246  /// Test whether the given alloca partitioning and range of slices can be
2247  /// promoted to a vector.
2248  ///
2249  /// This is a quick test to check whether we can rewrite a particular alloca
2250  /// partition (and its newly formed alloca) into a vector alloca with only
2251  /// whole-vector loads and stores such that it could be promoted to a vector
2252  /// SSA value. We only can ensure this for a limited set of operations, and we
2253  /// don't want to do the rewrites unless we are confident that the result will
2254  /// be promotable, so we have an early test here.
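      /// Candidate vector types are first drawn from the load and store types
      /// seen in the partition; pointer-typed candidates from slices that only
      /// partially cover the partition are deferred and tried only if no
      /// candidate from the first pass is viable.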
2255  static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
2256    // Collect the candidate types for vector-based promotion. Also track whether
2257    // we have different element types.
2258    SmallVector<VectorType *, 4> CandidateTys;
2259    SetVector<Type *> LoadStoreTys;
2260    SetVector<Type *> DeferredTys;
2261    Type *CommonEltTy = nullptr;
2262    VectorType *CommonVecPtrTy = nullptr;
2263    bool HaveVecPtrTy = false;
2264    bool HaveCommonEltTy = true;
2265    bool HaveCommonVecPtrTy = true;
2266    auto CheckCandidateType = [&](Type *Ty) {
2267      if (auto *VTy = dyn_cast<VectorType>(Ty)) {
2268        // Drop all candidates if this vector's total size in bits differs from theirs.
2269        if (!CandidateTys.empty()) {
2270          VectorType *V = CandidateTys[0];
2271          if (DL.getTypeSizeInBits(VTy).getFixedValue() !=
2272              DL.getTypeSizeInBits(V).getFixedValue()) {
2273            CandidateTys.clear();
2274            return;
2275          }
2276        }
2277        CandidateTys.push_back(VTy);
2278        Type *EltTy = VTy->getElementType();
2279  
2280        if (!CommonEltTy)
2281          CommonEltTy = EltTy;
2282        else if (CommonEltTy != EltTy)
2283          HaveCommonEltTy = false;
2284  
2285        if (EltTy->isPointerTy()) {
2286          HaveVecPtrTy = true;
2287          if (!CommonVecPtrTy)
2288            CommonVecPtrTy = VTy;
2289          else if (CommonVecPtrTy != VTy)
2290            HaveCommonVecPtrTy = false;
2291        }
2292      }
2293    };
2294  
2295    // Put load and store types into a set for de-duplication.
2296    for (const Slice &S : P) {
2297      Type *Ty;
2298      if (auto *LI = dyn_cast<LoadInst>(S.getUse()->getUser()))
2299        Ty = LI->getType();
2300      else if (auto *SI = dyn_cast<StoreInst>(S.getUse()->getUser()))
2301        Ty = SI->getValueOperand()->getType();
2302      else
2303        continue;
2304  
2305      auto CandTy = Ty->getScalarType();
2306      if (CandTy->isPointerTy() && (S.beginOffset() != P.beginOffset() ||
2307                                    S.endOffset() != P.endOffset())) {
2308        DeferredTys.insert(Ty);
2309        continue;
2310      }
2311  
2312      LoadStoreTys.insert(Ty);
2313      // Consider any loads or stores that are the exact size of the slice.
2314      if (S.beginOffset() == P.beginOffset() && S.endOffset() == P.endOffset())
2315        CheckCandidateType(Ty);
2316    }
2317  
2318    SmallVector<VectorType *, 4> CandidateTysCopy = CandidateTys;
2319    if (auto *VTy = createAndCheckVectorTypesForPromotion(
2320            LoadStoreTys, CandidateTysCopy, CheckCandidateType, P, DL,
2321            CandidateTys, HaveCommonEltTy, CommonEltTy, HaveVecPtrTy,
2322            HaveCommonVecPtrTy, CommonVecPtrTy))
2323      return VTy;
2324  
2325    CandidateTys.clear();
2326    return createAndCheckVectorTypesForPromotion(
2327        DeferredTys, CandidateTysCopy, CheckCandidateType, P, DL, CandidateTys,
2328        HaveCommonEltTy, CommonEltTy, HaveVecPtrTy, HaveCommonVecPtrTy,
2329        CommonVecPtrTy);
2330  }
2331  
2332  /// Test whether a slice of an alloca is valid for integer widening.
2333  ///
2334  /// This implements the necessary checking for the \c isIntegerWideningViable
2335  /// test below on a single slice of the alloca.
2336  static bool isIntegerWideningViableForSlice(const Slice &S,
2337                                              uint64_t AllocBeginOffset,
2338                                              Type *AllocaTy,
2339                                              const DataLayout &DL,
2340                                              bool &WholeAllocaOp) {
2341    uint64_t Size = DL.getTypeStoreSize(AllocaTy).getFixedValue();
2342  
2343    uint64_t RelBegin = S.beginOffset() - AllocBeginOffset;
2344    uint64_t RelEnd = S.endOffset() - AllocBeginOffset;
2345  
2346    Use *U = S.getUse();
2347  
2348    // Lifetime intrinsics operate over the whole alloca, whose size is usually
2349    // larger than other load/store slices (RelEnd > Size). But lifetime
2350    // intrinsics are always promotable and should not impact the promotability
2351    // of the other slices in the partition.
2352    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(U->getUser())) {
2353      if (II->isLifetimeStartOrEnd() || II->isDroppable())
2354        return true;
2355    }
2356  
2357    // We can't reasonably handle cases where the load or store extends past
2358    // the end of the alloca's type and into its padding.
2359    if (RelEnd > Size)
2360      return false;
2361  
2362    if (LoadInst *LI = dyn_cast<LoadInst>(U->getUser())) {
2363      if (LI->isVolatile())
2364        return false;
2365      // We can't handle loads that extend past the allocated memory.
2366      if (DL.getTypeStoreSize(LI->getType()).getFixedValue() > Size)
2367        return false;
2368      // So far, AllocaSliceRewriter does not support widening split slice tails
2369      // in rewriteIntegerLoad.
2370      if (S.beginOffset() < AllocBeginOffset)
2371        return false;
2372      // Note that we don't count vector loads or stores as whole-alloca
2373      // operations which enable integer widening because we would prefer to use
2374      // vector widening instead.
2375      if (!isa<VectorType>(LI->getType()) && RelBegin == 0 && RelEnd == Size)
2376        WholeAllocaOp = true;
2377      if (IntegerType *ITy = dyn_cast<IntegerType>(LI->getType())) {
2378        if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy).getFixedValue())
2379          return false;
2380      } else if (RelBegin != 0 || RelEnd != Size ||
2381                 !canConvertValue(DL, AllocaTy, LI->getType())) {
2382        // Non-integer loads need to be convertible from the alloca type so that
2383        // they are promotable.
2384        return false;
2385      }
2386    } else if (StoreInst *SI = dyn_cast<StoreInst>(U->getUser())) {
2387      Type *ValueTy = SI->getValueOperand()->getType();
2388      if (SI->isVolatile())
2389        return false;
2390      // We can't handle stores that extend past the allocated memory.
2391      if (DL.getTypeStoreSize(ValueTy).getFixedValue() > Size)
2392        return false;
2393      // So far, AllocaSliceRewriter does not support widening split slice tails
2394      // in rewriteIntegerStore.
2395      if (S.beginOffset() < AllocBeginOffset)
2396        return false;
2397      // Note that we don't count vector loads or stores as whole-alloca
2398      // operations which enable integer widening because we would prefer to use
2399      // vector widening instead.
2400      if (!isa<VectorType>(ValueTy) && RelBegin == 0 && RelEnd == Size)
2401        WholeAllocaOp = true;
2402      if (IntegerType *ITy = dyn_cast<IntegerType>(ValueTy)) {
2403        if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy).getFixedValue())
2404          return false;
2405      } else if (RelBegin != 0 || RelEnd != Size ||
2406                 !canConvertValue(DL, ValueTy, AllocaTy)) {
2407        // Non-integer stores need to be convertible to the alloca type so that
2408        // they are promotable.
2409        return false;
2410      }
2411    } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(U->getUser())) {
2412      if (MI->isVolatile() || !isa<Constant>(MI->getLength()))
2413        return false;
2414      if (!S.isSplittable())
2415        return false; // Skip any unsplittable intrinsics.
2416    } else {
2417      return false;
2418    }
2419  
2420    return true;
2421  }
2422  
2423  /// Test whether the given alloca partition's integer operations can be
2424  /// widened to promotable ones.
2425  ///
2426  /// This is a quick test to check whether we can rewrite the integer loads and
2427  /// stores to a particular alloca into wider loads and stores and be able to
2428  /// promote the resulting alloca.
2429  static bool isIntegerWideningViable(Partition &P, Type *AllocaTy,
2430                                      const DataLayout &DL) {
2431    uint64_t SizeInBits = DL.getTypeSizeInBits(AllocaTy).getFixedValue();
2432    // Don't create integer types larger than the maximum bitwidth.
2433    if (SizeInBits > IntegerType::MAX_INT_BITS)
2434      return false;
2435  
2436    // Don't try to handle allocas with bit-padding.
2437    if (SizeInBits != DL.getTypeStoreSizeInBits(AllocaTy).getFixedValue())
2438      return false;
2439  
2440    // We need to ensure that an integer type with the appropriate bitwidth can
2441    // be converted to the alloca type, whatever that is. We don't want to force
2442    // the alloca itself to have an integer type if there is a more suitable one.
2443    Type *IntTy = Type::getIntNTy(AllocaTy->getContext(), SizeInBits);
2444    if (!canConvertValue(DL, AllocaTy, IntTy) ||
2445        !canConvertValue(DL, IntTy, AllocaTy))
2446      return false;
2447  
2448    // While examining uses, we ensure that the alloca has a covering load or
2449    // store. We don't want to widen the integer operations only to fail to
2450    // promote due to some other unsplittable entry (which we may make splittable
2451    // later). However, if there are only splittable uses, go ahead and assume
2452    // that we cover the alloca.
2453    // FIXME: We shouldn't consider split slices that happen to start in the
2454    // partition here...
2455    bool WholeAllocaOp = P.empty() && DL.isLegalInteger(SizeInBits);
2456  
2457    for (const Slice &S : P)
2458      if (!isIntegerWideningViableForSlice(S, P.beginOffset(), AllocaTy, DL,
2459                                           WholeAllocaOp))
2460        return false;
2461  
2462    for (const Slice *S : P.splitSliceTails())
2463      if (!isIntegerWideningViableForSlice(*S, P.beginOffset(), AllocaTy, DL,
2464                                           WholeAllocaOp))
2465        return false;
2466  
2467    return WholeAllocaOp;
2468  }
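// Illustrative example for isIntegerWideningViable (assuming i32 is a legal
// integer type for the target): a partition typed as i32 that is covered by a
// plain i32 load and also touched by an i8 load at byte offset 1 and an i16
// store at byte offset 2 would typically pass this predicate; the narrow
// accesses are then rewritten by AllocaSliceRewriter below as shift-and-mask
// operations on a single "alloca wide" i32 value.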
2469  
2470  static Value *extractInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *V,
2471                               IntegerType *Ty, uint64_t Offset,
2472                               const Twine &Name) {
2473    LLVM_DEBUG(dbgs() << "       start: " << *V << "\n");
2474    IntegerType *IntTy = cast<IntegerType>(V->getType());
2475    assert(DL.getTypeStoreSize(Ty).getFixedValue() + Offset <=
2476               DL.getTypeStoreSize(IntTy).getFixedValue() &&
2477           "Element extends past full value");
2478    uint64_t ShAmt = 8 * Offset;
2479    if (DL.isBigEndian())
2480      ShAmt = 8 * (DL.getTypeStoreSize(IntTy).getFixedValue() -
2481                   DL.getTypeStoreSize(Ty).getFixedValue() - Offset);
2482    if (ShAmt) {
2483      V = IRB.CreateLShr(V, ShAmt, Name + ".shift");
2484      LLVM_DEBUG(dbgs() << "     shifted: " << *V << "\n");
2485    }
2486    assert(Ty->getBitWidth() <= IntTy->getBitWidth() &&
2487           "Cannot extract to a larger integer!");
2488    if (Ty != IntTy) {
2489      V = IRB.CreateTrunc(V, Ty, Name + ".trunc");
2490      LLVM_DEBUG(dbgs() << "     trunced: " << *V << "\n");
2491    }
2492    return V;
2493  }
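// Worked example for extractInteger (illustrative): extracting an i16 at byte
// Offset 1 from an i64 value. On a little-endian target ShAmt = 8 * 1 = 8, so
// the result is trunc(lshr(V, 8)) to i16; on a big-endian target the element
// lives at the high end, so ShAmt = 8 * (8 - 2 - 1) = 40 and the result is
// trunc(lshr(V, 40)) to i16.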
2494  
2495  static Value *insertInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *Old,
2496                              Value *V, uint64_t Offset, const Twine &Name) {
2497    IntegerType *IntTy = cast<IntegerType>(Old->getType());
2498    IntegerType *Ty = cast<IntegerType>(V->getType());
2499    assert(Ty->getBitWidth() <= IntTy->getBitWidth() &&
2500           "Cannot insert a larger integer!");
2501    LLVM_DEBUG(dbgs() << "       start: " << *V << "\n");
2502    if (Ty != IntTy) {
2503      V = IRB.CreateZExt(V, IntTy, Name + ".ext");
2504      LLVM_DEBUG(dbgs() << "    extended: " << *V << "\n");
2505    }
2506    assert(DL.getTypeStoreSize(Ty).getFixedValue() + Offset <=
2507               DL.getTypeStoreSize(IntTy).getFixedValue() &&
2508           "Element store outside of alloca store");
2509    uint64_t ShAmt = 8 * Offset;
2510    if (DL.isBigEndian())
2511      ShAmt = 8 * (DL.getTypeStoreSize(IntTy).getFixedValue() -
2512                   DL.getTypeStoreSize(Ty).getFixedValue() - Offset);
2513    if (ShAmt) {
2514      V = IRB.CreateShl(V, ShAmt, Name + ".shift");
2515      LLVM_DEBUG(dbgs() << "     shifted: " << *V << "\n");
2516    }
2517  
2518    if (ShAmt || Ty->getBitWidth() < IntTy->getBitWidth()) {
2519      APInt Mask = ~Ty->getMask().zext(IntTy->getBitWidth()).shl(ShAmt);
2520      Old = IRB.CreateAnd(Old, Mask, Name + ".mask");
2521      LLVM_DEBUG(dbgs() << "      masked: " << *Old << "\n");
2522      V = IRB.CreateOr(Old, V, Name + ".insert");
2523      LLVM_DEBUG(dbgs() << "    inserted: " << *V << "\n");
2524    }
2525    return V;
2526  }
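// Worked example for insertInteger (illustrative): inserting an i8 value at
// byte Offset 2 into an i32 Old value on a little-endian target. V is
// zero-extended to i32 and shifted left by ShAmt = 16, Old is masked with
// ~(0xFF << 16) = 0xFF00FFFF, and the two are OR'd together, so the other
// three bytes of Old are preserved.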
2527  
2528  static Value *extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex,
2529                              unsigned EndIndex, const Twine &Name) {
2530    auto *VecTy = cast<FixedVectorType>(V->getType());
2531    unsigned NumElements = EndIndex - BeginIndex;
2532    assert(NumElements <= VecTy->getNumElements() && "Too many elements!");
2533  
2534    if (NumElements == VecTy->getNumElements())
2535      return V;
2536  
2537    if (NumElements == 1) {
2538      V = IRB.CreateExtractElement(V, IRB.getInt32(BeginIndex),
2539                                   Name + ".extract");
2540      LLVM_DEBUG(dbgs() << "     extract: " << *V << "\n");
2541      return V;
2542    }
2543  
2544    auto Mask = llvm::to_vector<8>(llvm::seq<int>(BeginIndex, EndIndex));
2545    V = IRB.CreateShuffleVector(V, Mask, Name + ".extract");
2546    LLVM_DEBUG(dbgs() << "     shuffle: " << *V << "\n");
2547    return V;
2548  }
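// Illustrative example for extractVector: extracting elements [1, 3) from a
// <4 x i32> value emits roughly
//   %vec.extract = shufflevector <4 x i32> %V, <4 x i32> poison,
//                                <2 x i32> <i32 1, i32 2>
// whereas a single-element range such as [2, 3) uses extractelement instead.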
2549  
2550  static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V,
2551                             unsigned BeginIndex, const Twine &Name) {
2552    VectorType *VecTy = cast<VectorType>(Old->getType());
2553    assert(VecTy && "Can only insert a vector into a vector");
2554  
2555    VectorType *Ty = dyn_cast<VectorType>(V->getType());
2556    if (!Ty) {
2557      // Single element to insert.
2558      V = IRB.CreateInsertElement(Old, V, IRB.getInt32(BeginIndex),
2559                                  Name + ".insert");
2560      LLVM_DEBUG(dbgs() << "     insert: " << *V << "\n");
2561      return V;
2562    }
2563  
2564    assert(cast<FixedVectorType>(Ty)->getNumElements() <=
2565               cast<FixedVectorType>(VecTy)->getNumElements() &&
2566           "Too many elements!");
2567    if (cast<FixedVectorType>(Ty)->getNumElements() ==
2568        cast<FixedVectorType>(VecTy)->getNumElements()) {
2569      assert(V->getType() == VecTy && "Vector type mismatch");
2570      return V;
2571    }
2572    unsigned EndIndex = BeginIndex + cast<FixedVectorType>(Ty)->getNumElements();
2573  
2574    // When inserting a smaller vector into the larger one, we first use a
2575    // shufflevector to widen it to the full element count (padding the extra
2576    // lanes with poison), and then a select on a constant lane mask to blend
2577    // the widened value over the loaded vector.
2578    SmallVector<int, 8> Mask;
2579    Mask.reserve(cast<FixedVectorType>(VecTy)->getNumElements());
2580    for (unsigned i = 0; i != cast<FixedVectorType>(VecTy)->getNumElements(); ++i)
2581      if (i >= BeginIndex && i < EndIndex)
2582        Mask.push_back(i - BeginIndex);
2583      else
2584        Mask.push_back(-1);
2585    V = IRB.CreateShuffleVector(V, Mask, Name + ".expand");
2586    LLVM_DEBUG(dbgs() << "    shuffle: " << *V << "\n");
2587  
2588    SmallVector<Constant *, 8> Mask2;
2589    Mask2.reserve(cast<FixedVectorType>(VecTy)->getNumElements());
2590    for (unsigned i = 0; i != cast<FixedVectorType>(VecTy)->getNumElements(); ++i)
2591      Mask2.push_back(IRB.getInt1(i >= BeginIndex && i < EndIndex));
2592  
2593    V = IRB.CreateSelect(ConstantVector::get(Mask2), V, Old, Name + "blend");
2594  
2595    LLVM_DEBUG(dbgs() << "    blend: " << *V << "\n");
2596    return V;
2597  }
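// Illustrative example for insertVector: inserting a <2 x i32> value at
// BeginIndex 1 into a <4 x i32> Old value. The first shuffle widens the input
// to four lanes with the mask <poison, 0, 1, poison>, and the select then
// blends it over Old using the lane mask <false, true, true, false>, so lanes
// 0 and 3 of Old are preserved.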
2598  
2599  namespace {
2600  
2601  /// Visitor to rewrite instructions using a particular slice of an alloca
2602  /// to use a new alloca.
2603  ///
2604  /// Also implements the rewriting to vector-based accesses when the partition
2605  /// passes the isVectorPromotionViable predicate. Most of the rewriting logic
2606  /// lives here.
2607  class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
2608    // Befriend the base class so it can delegate to private visit methods.
2609    friend class InstVisitor<AllocaSliceRewriter, bool>;
2610  
2611    using Base = InstVisitor<AllocaSliceRewriter, bool>;
2612  
2613    const DataLayout &DL;
2614    AllocaSlices &AS;
2615    SROA &Pass;
2616    AllocaInst &OldAI, &NewAI;
2617    const uint64_t NewAllocaBeginOffset, NewAllocaEndOffset;
2618    Type *NewAllocaTy;
2619  
2620    // This is a convenience and flag variable that will be null unless the new
2621    // alloca's integer operations should be widened to this integer type due to
2622    // passing isIntegerWideningViable above. If it is non-null, the desired
2623    // integer type will be stored here for easy access during rewriting.
2624    IntegerType *IntTy;
2625  
2626    // If we are rewriting an alloca partition which can be written as pure
2627    // vector operations, we stash extra information here. When VecTy is
2628    // non-null, we have some strict guarantees about the rewritten alloca:
2629    //   - The new alloca is exactly the size of the vector type here.
2630    //   - The accesses all either map to the entire vector or to a single
2631    //     element.
2632    //   - The set of accessing instructions is only one of those handled above
2633    //     in isVectorPromotionViable. Generally these are the same access kinds
2634    //     which are promotable via mem2reg.
2635    VectorType *VecTy;
2636    Type *ElementTy;
2637    uint64_t ElementSize;
2638  
2639    // The original offset of the slice currently being rewritten relative to
2640    // the original alloca.
2641    uint64_t BeginOffset = 0;
2642    uint64_t EndOffset = 0;
2643  
2644    // The new offsets of the slice currently being rewritten relative to the
2645    // original alloca.
2646    uint64_t NewBeginOffset = 0, NewEndOffset = 0;
2647  
2648    uint64_t SliceSize = 0;
2649    bool IsSplittable = false;
2650    bool IsSplit = false;
2651    Use *OldUse = nullptr;
2652    Instruction *OldPtr = nullptr;
2653  
2654    // Track post-rewrite users which are PHI nodes and Selects.
2655    SmallSetVector<PHINode *, 8> &PHIUsers;
2656    SmallSetVector<SelectInst *, 8> &SelectUsers;
2657  
2658    // Utility IR builder, whose name prefix is set up for each visited use, and
2659    // the insertion point is set to point to the user.
2660    IRBuilderTy IRB;
2661  
2662    // Return the new alloca, addrspacecasted if required to avoid changing the
2663    // addrspace of a volatile access.
2664    Value *getPtrToNewAI(unsigned AddrSpace, bool IsVolatile) {
2665      if (!IsVolatile || AddrSpace == NewAI.getType()->getPointerAddressSpace())
2666        return &NewAI;
2667  
2668      Type *AccessTy = IRB.getPtrTy(AddrSpace);
2669      return IRB.CreateAddrSpaceCast(&NewAI, AccessTy);
2670    }
2671  
2672  public:
2673    AllocaSliceRewriter(const DataLayout &DL, AllocaSlices &AS, SROA &Pass,
2674                        AllocaInst &OldAI, AllocaInst &NewAI,
2675                        uint64_t NewAllocaBeginOffset,
2676                        uint64_t NewAllocaEndOffset, bool IsIntegerPromotable,
2677                        VectorType *PromotableVecTy,
2678                        SmallSetVector<PHINode *, 8> &PHIUsers,
2679                        SmallSetVector<SelectInst *, 8> &SelectUsers)
2680        : DL(DL), AS(AS), Pass(Pass), OldAI(OldAI), NewAI(NewAI),
2681          NewAllocaBeginOffset(NewAllocaBeginOffset),
2682          NewAllocaEndOffset(NewAllocaEndOffset),
2683          NewAllocaTy(NewAI.getAllocatedType()),
2684          IntTy(
2685              IsIntegerPromotable
2686                  ? Type::getIntNTy(NewAI.getContext(),
2687                                    DL.getTypeSizeInBits(NewAI.getAllocatedType())
2688                                        .getFixedValue())
2689                  : nullptr),
2690          VecTy(PromotableVecTy),
2691          ElementTy(VecTy ? VecTy->getElementType() : nullptr),
2692          ElementSize(VecTy ? DL.getTypeSizeInBits(ElementTy).getFixedValue() / 8
2693                            : 0),
2694          PHIUsers(PHIUsers), SelectUsers(SelectUsers),
2695          IRB(NewAI.getContext(), ConstantFolder()) {
2696      if (VecTy) {
2697        assert((DL.getTypeSizeInBits(ElementTy).getFixedValue() % 8) == 0 &&
2698               "Only multiple-of-8 sized vector elements are viable");
2699        ++NumVectorized;
2700      }
2701      assert((!IntTy && !VecTy) || (IntTy && !VecTy) || (!IntTy && VecTy));
2702    }
2703  
2704    bool visit(AllocaSlices::const_iterator I) {
2705      bool CanSROA = true;
2706      BeginOffset = I->beginOffset();
2707      EndOffset = I->endOffset();
2708      IsSplittable = I->isSplittable();
2709      IsSplit =
2710          BeginOffset < NewAllocaBeginOffset || EndOffset > NewAllocaEndOffset;
2711      LLVM_DEBUG(dbgs() << "  rewriting " << (IsSplit ? "split " : ""));
2712      LLVM_DEBUG(AS.printSlice(dbgs(), I, ""));
2713      LLVM_DEBUG(dbgs() << "\n");
2714  
2715      // Compute the intersecting offset range.
2716      assert(BeginOffset < NewAllocaEndOffset);
2717      assert(EndOffset > NewAllocaBeginOffset);
2718      NewBeginOffset = std::max(BeginOffset, NewAllocaBeginOffset);
2719      NewEndOffset = std::min(EndOffset, NewAllocaEndOffset);
2720  
2721      SliceSize = NewEndOffset - NewBeginOffset;
2722      LLVM_DEBUG(dbgs() << "   Begin:(" << BeginOffset << ", " << EndOffset
2723                        << ") NewBegin:(" << NewBeginOffset << ", "
2724                        << NewEndOffset << ") NewAllocaBegin:("
2725                        << NewAllocaBeginOffset << ", " << NewAllocaEndOffset
2726                        << ")\n");
2727      assert(IsSplit || NewBeginOffset == BeginOffset);
2728      OldUse = I->getUse();
2729      OldPtr = cast<Instruction>(OldUse->get());
2730  
2731      Instruction *OldUserI = cast<Instruction>(OldUse->getUser());
2732      IRB.SetInsertPoint(OldUserI);
2733      IRB.SetCurrentDebugLocation(OldUserI->getDebugLoc());
2734      IRB.getInserter().SetNamePrefix(Twine(NewAI.getName()) + "." +
2735                                      Twine(BeginOffset) + ".");
2736  
2737      CanSROA &= visit(cast<Instruction>(OldUse->getUser()));
2738      if (VecTy || IntTy)
2739        assert(CanSROA);
2740      return CanSROA;
2741    }
2742  
2743  private:
2744    // Make sure the other visit overloads are visible.
2745    using Base::visit;
2746  
2747    // Every instruction which can end up as a user must have a rewrite rule.
2748    bool visitInstruction(Instruction &I) {
2749      LLVM_DEBUG(dbgs() << "    !!!! Cannot rewrite: " << I << "\n");
2750      llvm_unreachable("No rewrite rule for this instruction!");
2751    }
2752  
2753    Value *getNewAllocaSlicePtr(IRBuilderTy &IRB, Type *PointerTy) {
2754      // Note that the offset computation can use BeginOffset or NewBeginOffset
2755      // interchangeably for unsplit slices.
2756      assert(IsSplit || BeginOffset == NewBeginOffset);
2757      uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset;
2758  
2759  #ifndef NDEBUG
2760      StringRef OldName = OldPtr->getName();
2761      // Skip through the last '.sroa.' component of the name.
2762      size_t LastSROAPrefix = OldName.rfind(".sroa.");
2763      if (LastSROAPrefix != StringRef::npos) {
2764        OldName = OldName.substr(LastSROAPrefix + strlen(".sroa."));
2765        // Look for an SROA slice index.
2766        size_t IndexEnd = OldName.find_first_not_of("0123456789");
2767        if (IndexEnd != StringRef::npos && OldName[IndexEnd] == '.') {
2768          // Strip the index and look for the offset.
2769          OldName = OldName.substr(IndexEnd + 1);
2770          size_t OffsetEnd = OldName.find_first_not_of("0123456789");
2771          if (OffsetEnd != StringRef::npos && OldName[OffsetEnd] == '.')
2772            // Strip the offset.
2773            OldName = OldName.substr(OffsetEnd + 1);
2774        }
2775      }
2776      // Strip any SROA suffixes as well.
2777      OldName = OldName.substr(0, OldName.find(".sroa_"));
2778  #endif
2779  
2780      return getAdjustedPtr(IRB, DL, &NewAI,
2781                            APInt(DL.getIndexTypeSizeInBits(PointerTy), Offset),
2782                            PointerTy,
2783  #ifndef NDEBUG
2784                            Twine(OldName) + "."
2785  #else
2786                            Twine()
2787  #endif
2788      );
2789    }
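  // Naming example for getNewAllocaSlicePtr (illustrative, builds with
  // assertions only): for an old pointer named "a.sroa.3.16.copyload" the code
  // above strips the ".sroa." marker, the slice index "3." and the offset
  // "16.", so the new GEP is named with the "copyload." prefix. The pointer
  // value produced is identical either way; only the IR value name differs.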
2790  
2791    /// Compute suitable alignment to access this slice of the *new*
2792    /// alloca.
2793    ///
2794    /// This is the alignment implied by the new alloca's own alignment and the
2795    /// slice's byte offset within it, and never exceeds the alloca's alignment.
2796    Align getSliceAlign() {
2797      return commonAlignment(NewAI.getAlign(),
2798                             NewBeginOffset - NewAllocaBeginOffset);
2799    }
2800  
2801    unsigned getIndex(uint64_t Offset) {
2802      assert(VecTy && "Can only call getIndex when rewriting a vector");
2803      uint64_t RelOffset = Offset - NewAllocaBeginOffset;
2804      assert(RelOffset / ElementSize < UINT32_MAX && "Index out of bounds");
2805      uint32_t Index = RelOffset / ElementSize;
2806      assert(Index * ElementSize == RelOffset);
2807      return Index;
2808    }
2809  
2810    void deleteIfTriviallyDead(Value *V) {
2811      Instruction *I = cast<Instruction>(V);
2812      if (isInstructionTriviallyDead(I))
2813        Pass.DeadInsts.push_back(I);
2814    }
2815  
2816    Value *rewriteVectorizedLoadInst(LoadInst &LI) {
2817      unsigned BeginIndex = getIndex(NewBeginOffset);
2818      unsigned EndIndex = getIndex(NewEndOffset);
2819      assert(EndIndex > BeginIndex && "Empty vector!");
2820  
2821      LoadInst *Load = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
2822                                             NewAI.getAlign(), "load");
2823  
2824      Load->copyMetadata(LI, {LLVMContext::MD_mem_parallel_loop_access,
2825                              LLVMContext::MD_access_group});
2826      return extractVector(IRB, Load, BeginIndex, EndIndex, "vec");
2827    }
2828  
2829    Value *rewriteIntegerLoad(LoadInst &LI) {
2830      assert(IntTy && "We cannot insert an integer to the alloca");
2831      assert(!LI.isVolatile());
2832      Value *V = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
2833                                       NewAI.getAlign(), "load");
2834      V = convertValue(DL, IRB, V, IntTy);
2835      assert(NewBeginOffset >= NewAllocaBeginOffset && "Out of bounds offset");
2836      uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset;
2837      if (Offset > 0 || NewEndOffset < NewAllocaEndOffset) {
2838        IntegerType *ExtractTy = Type::getIntNTy(LI.getContext(), SliceSize * 8);
2839        V = extractInteger(DL, IRB, V, ExtractTy, Offset, "extract");
2840      }
2841      // It is possible that the extracted type is not the load type. This
2842      // happens if there is a load past the end of the alloca, and as
2843      // a consequence the slice is narrower but still a candidate for integer
2844      // lowering. To handle this case, we just zero extend the extracted
2845      // integer.
2846      assert(cast<IntegerType>(LI.getType())->getBitWidth() >= SliceSize * 8 &&
2847             "Can only handle an extract for an overly wide load");
2848      if (cast<IntegerType>(LI.getType())->getBitWidth() > SliceSize * 8)
2849        V = IRB.CreateZExt(V, LI.getType());
2850      return V;
2851    }
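  // Illustrative example for rewriteIntegerLoad: with the alloca widened to
  // i32 and a slice covering bytes [1, 3), the whole i32 is loaded, an i16 is
  // extracted at offset 1 via extractInteger, and, if the original load type
  // is wider than the slice (a load running past the slice's end), the result
  // is zero-extended back up to the load's type.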
2852  
2853    bool visitLoadInst(LoadInst &LI) {
2854      LLVM_DEBUG(dbgs() << "    original: " << LI << "\n");
2855      Value *OldOp = LI.getOperand(0);
2856      assert(OldOp == OldPtr);
2857  
2858      AAMDNodes AATags = LI.getAAMetadata();
2859  
2860      unsigned AS = LI.getPointerAddressSpace();
2861  
2862      Type *TargetTy = IsSplit ? Type::getIntNTy(LI.getContext(), SliceSize * 8)
2863                               : LI.getType();
2864      const bool IsLoadPastEnd =
2865          DL.getTypeStoreSize(TargetTy).getFixedValue() > SliceSize;
2866      bool IsPtrAdjusted = false;
2867      Value *V;
2868      if (VecTy) {
2869        V = rewriteVectorizedLoadInst(LI);
2870      } else if (IntTy && LI.getType()->isIntegerTy()) {
2871        V = rewriteIntegerLoad(LI);
2872      } else if (NewBeginOffset == NewAllocaBeginOffset &&
2873                 NewEndOffset == NewAllocaEndOffset &&
2874                 (canConvertValue(DL, NewAllocaTy, TargetTy) ||
2875                  (IsLoadPastEnd && NewAllocaTy->isIntegerTy() &&
2876                   TargetTy->isIntegerTy() && !LI.isVolatile()))) {
2877        Value *NewPtr =
2878            getPtrToNewAI(LI.getPointerAddressSpace(), LI.isVolatile());
2879        LoadInst *NewLI = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), NewPtr,
2880                                                NewAI.getAlign(), LI.isVolatile(),
2881                                                LI.getName());
2882        if (LI.isVolatile())
2883          NewLI->setAtomic(LI.getOrdering(), LI.getSyncScopeID());
2884        if (NewLI->isAtomic())
2885          NewLI->setAlignment(LI.getAlign());
2886  
2887        // Copy any metadata that is valid for the new load. This may require
2888        // conversion to a different kind of metadata, e.g. !nonnull might change
2889        // to !range or vice versa.
2890        copyMetadataForLoad(*NewLI, LI);
2891  
2892        // Do this after copyMetadataForLoad() to preserve the TBAA shift.
2893        if (AATags)
2894          NewLI->setAAMetadata(AATags.adjustForAccess(
2895              NewBeginOffset - BeginOffset, NewLI->getType(), DL));
2896  
2897        // Try to preserve nonnull metadata
2898        V = NewLI;
2899  
2900        // If this is an integer load past the end of the slice (which means the
2901        // bytes outside the slice are undef or this load is dead) just forcibly
2902        // fix the integer size with correct handling of endianness.
2903        if (auto *AITy = dyn_cast<IntegerType>(NewAllocaTy))
2904          if (auto *TITy = dyn_cast<IntegerType>(TargetTy))
2905            if (AITy->getBitWidth() < TITy->getBitWidth()) {
2906              V = IRB.CreateZExt(V, TITy, "load.ext");
2907              if (DL.isBigEndian())
2908                V = IRB.CreateShl(V, TITy->getBitWidth() - AITy->getBitWidth(),
2909                                  "endian_shift");
2910            }
2911      } else {
2912        Type *LTy = IRB.getPtrTy(AS);
2913        LoadInst *NewLI =
2914            IRB.CreateAlignedLoad(TargetTy, getNewAllocaSlicePtr(IRB, LTy),
2915                                  getSliceAlign(), LI.isVolatile(), LI.getName());
2916  
2917        if (AATags)
2918          NewLI->setAAMetadata(AATags.adjustForAccess(
2919              NewBeginOffset - BeginOffset, NewLI->getType(), DL));
2920  
2921        if (LI.isVolatile())
2922          NewLI->setAtomic(LI.getOrdering(), LI.getSyncScopeID());
2923        NewLI->copyMetadata(LI, {LLVMContext::MD_mem_parallel_loop_access,
2924                                 LLVMContext::MD_access_group});
2925  
2926        V = NewLI;
2927        IsPtrAdjusted = true;
2928      }
2929      V = convertValue(DL, IRB, V, TargetTy);
2930  
2931      if (IsSplit) {
2932        assert(!LI.isVolatile());
2933        assert(LI.getType()->isIntegerTy() &&
2934               "Only integer type loads and stores are split");
2935        assert(SliceSize < DL.getTypeStoreSize(LI.getType()).getFixedValue() &&
2936               "Split load isn't smaller than original load");
2937        assert(DL.typeSizeEqualsStoreSize(LI.getType()) &&
2938               "Non-byte-multiple bit width");
2939        // Move the insertion point just past the load so that we can refer to it.
2940        BasicBlock::iterator LIIt = std::next(LI.getIterator());
2941        // Ensure the insertion point comes before any debug-info immediately
2942        // after the load, so that variable values referring to the load are
2943        // dominated by it.
2944        LIIt.setHeadBit(true);
2945        IRB.SetInsertPoint(LI.getParent(), LIIt);
2946        // Create a placeholder value with the same type as LI to use as the
2947        // basis for the new value. This allows us to replace the uses of LI with
2948        // the computed value, and then replace the placeholder with LI, leaving
2949        // LI only used for this computation.
2950        Value *Placeholder =
2951            new LoadInst(LI.getType(), PoisonValue::get(IRB.getPtrTy(AS)), "",
2952                         false, Align(1));
2953        V = insertInteger(DL, IRB, Placeholder, V, NewBeginOffset - BeginOffset,
2954                          "insert");
2955        LI.replaceAllUsesWith(V);
2956        Placeholder->replaceAllUsesWith(&LI);
2957        Placeholder->deleteValue();
2958      } else {
2959        LI.replaceAllUsesWith(V);
2960      }
2961  
2962      Pass.DeadInsts.push_back(&LI);
2963      deleteIfTriviallyDead(OldOp);
2964      LLVM_DEBUG(dbgs() << "          to: " << *V << "\n");
2965      return !LI.isVolatile() && !IsPtrAdjusted;
2966    }
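  // Split-load sketch for visitLoadInst (illustrative): when an integer load
  // is split across partitions, each rewrite loads its own narrow piece and
  // uses insertInteger to build up a replacement value of the original type.
  // The placeholder load above exists only so the partially-built value can
  // take over the original load's uses without forming a cycle through the
  // original load; it is deleted before this visitor returns.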
2967  
2968    bool rewriteVectorizedStoreInst(Value *V, StoreInst &SI, Value *OldOp,
2969                                    AAMDNodes AATags) {
2970      // Capture V for the purpose of debug-info accounting once it's converted
2971      // to a vector store.
2972      Value *OrigV = V;
2973      if (V->getType() != VecTy) {
2974        unsigned BeginIndex = getIndex(NewBeginOffset);
2975        unsigned EndIndex = getIndex(NewEndOffset);
2976        assert(EndIndex > BeginIndex && "Empty vector!");
2977        unsigned NumElements = EndIndex - BeginIndex;
2978        assert(NumElements <= cast<FixedVectorType>(VecTy)->getNumElements() &&
2979               "Too many elements!");
2980        Type *SliceTy = (NumElements == 1)
2981                            ? ElementTy
2982                            : FixedVectorType::get(ElementTy, NumElements);
2983        if (V->getType() != SliceTy)
2984          V = convertValue(DL, IRB, V, SliceTy);
2985  
2986        // Mix in the existing elements.
2987        Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
2988                                           NewAI.getAlign(), "load");
2989        V = insertVector(IRB, Old, V, BeginIndex, "vec");
2990      }
2991      StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlign());
2992      Store->copyMetadata(SI, {LLVMContext::MD_mem_parallel_loop_access,
2993                               LLVMContext::MD_access_group});
2994      if (AATags)
2995        Store->setAAMetadata(AATags.adjustForAccess(NewBeginOffset - BeginOffset,
2996                                                    V->getType(), DL));
2997      Pass.DeadInsts.push_back(&SI);
2998  
2999      // NOTE: Careful to use OrigV rather than V.
3000      migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8, &SI,
3001                       Store, Store->getPointerOperand(), OrigV, DL);
3002      LLVM_DEBUG(dbgs() << "          to: " << *Store << "\n");
3003      return true;
3004    }
3005  
3006    bool rewriteIntegerStore(Value *V, StoreInst &SI, AAMDNodes AATags) {
3007      assert(IntTy && "We cannot extract an integer from the alloca");
3008      assert(!SI.isVolatile());
3009      if (DL.getTypeSizeInBits(V->getType()).getFixedValue() !=
3010          IntTy->getBitWidth()) {
3011        Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
3012                                           NewAI.getAlign(), "oldload");
3013        Old = convertValue(DL, IRB, Old, IntTy);
3014        assert(BeginOffset >= NewAllocaBeginOffset && "Out of bounds offset");
3015        uint64_t Offset = BeginOffset - NewAllocaBeginOffset;
3016        V = insertInteger(DL, IRB, Old, SI.getValueOperand(), Offset, "insert");
3017      }
3018      V = convertValue(DL, IRB, V, NewAllocaTy);
3019      StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlign());
3020      Store->copyMetadata(SI, {LLVMContext::MD_mem_parallel_loop_access,
3021                               LLVMContext::MD_access_group});
3022      if (AATags)
3023        Store->setAAMetadata(AATags.adjustForAccess(NewBeginOffset - BeginOffset,
3024                                                    V->getType(), DL));
3025  
3026      migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8, &SI,
3027                       Store, Store->getPointerOperand(),
3028                       Store->getValueOperand(), DL);
3029  
3030      Pass.DeadInsts.push_back(&SI);
3031      LLVM_DEBUG(dbgs() << "          to: " << *Store << "\n");
3032      return true;
3033    }
3034  
3035    bool visitStoreInst(StoreInst &SI) {
3036      LLVM_DEBUG(dbgs() << "    original: " << SI << "\n");
3037      Value *OldOp = SI.getOperand(1);
3038      assert(OldOp == OldPtr);
3039  
3040      AAMDNodes AATags = SI.getAAMetadata();
3041      Value *V = SI.getValueOperand();
3042  
3043      // Strip all inbounds GEPs and pointer casts to try to dig out any root
3044      // alloca that should be re-examined after promoting this alloca.
3045      if (V->getType()->isPointerTy())
3046        if (AllocaInst *AI = dyn_cast<AllocaInst>(V->stripInBoundsOffsets()))
3047          Pass.PostPromotionWorklist.insert(AI);
3048  
3049      if (SliceSize < DL.getTypeStoreSize(V->getType()).getFixedValue()) {
3050        assert(!SI.isVolatile());
3051        assert(V->getType()->isIntegerTy() &&
3052               "Only integer type loads and stores are split");
3053        assert(DL.typeSizeEqualsStoreSize(V->getType()) &&
3054               "Non-byte-multiple bit width");
3055        IntegerType *NarrowTy = Type::getIntNTy(SI.getContext(), SliceSize * 8);
3056        V = extractInteger(DL, IRB, V, NarrowTy, NewBeginOffset - BeginOffset,
3057                           "extract");
3058      }
3059  
3060      if (VecTy)
3061        return rewriteVectorizedStoreInst(V, SI, OldOp, AATags);
3062      if (IntTy && V->getType()->isIntegerTy())
3063        return rewriteIntegerStore(V, SI, AATags);
3064  
3065      StoreInst *NewSI;
3066      if (NewBeginOffset == NewAllocaBeginOffset &&
3067          NewEndOffset == NewAllocaEndOffset &&
3068          canConvertValue(DL, V->getType(), NewAllocaTy)) {
3069        V = convertValue(DL, IRB, V, NewAllocaTy);
3070        Value *NewPtr =
3071            getPtrToNewAI(SI.getPointerAddressSpace(), SI.isVolatile());
3072  
3073        NewSI =
3074            IRB.CreateAlignedStore(V, NewPtr, NewAI.getAlign(), SI.isVolatile());
3075      } else {
3076        unsigned AS = SI.getPointerAddressSpace();
3077        Value *NewPtr = getNewAllocaSlicePtr(IRB, IRB.getPtrTy(AS));
3078        NewSI =
3079            IRB.CreateAlignedStore(V, NewPtr, getSliceAlign(), SI.isVolatile());
3080      }
3081      NewSI->copyMetadata(SI, {LLVMContext::MD_mem_parallel_loop_access,
3082                               LLVMContext::MD_access_group});
3083      if (AATags)
3084        NewSI->setAAMetadata(AATags.adjustForAccess(NewBeginOffset - BeginOffset,
3085                                                    V->getType(), DL));
3086      if (SI.isVolatile())
3087        NewSI->setAtomic(SI.getOrdering(), SI.getSyncScopeID());
3088      if (NewSI->isAtomic())
3089        NewSI->setAlignment(SI.getAlign());
3090  
3091      migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8, &SI,
3092                       NewSI, NewSI->getPointerOperand(),
3093                       NewSI->getValueOperand(), DL);
3094  
3095      Pass.DeadInsts.push_back(&SI);
3096      deleteIfTriviallyDead(OldOp);
3097  
3098      LLVM_DEBUG(dbgs() << "          to: " << *NewSI << "\n");
3099      return NewSI->getPointerOperand() == &NewAI &&
3100             NewSI->getValueOperand()->getType() == NewAllocaTy &&
3101             !SI.isVolatile();
3102    }
3103  
3104    /// Compute an integer value from splatting an i8 across the given
3105    /// number of bytes.
3106    ///
3107    /// Note that this routine assumes an i8 is a byte. If that isn't true, don't
3108    /// call this routine.
3109    /// FIXME: Heed the advice above.
3110    ///
3111    /// \param V The i8 value to splat.
3112    /// \param Size The number of bytes in the output (assuming i8 is one byte)
3113    Value *getIntegerSplat(Value *V, unsigned Size) {
3114      assert(Size > 0 && "Expected a positive number of bytes.");
3115      IntegerType *VTy = cast<IntegerType>(V->getType());
3116      assert(VTy->getBitWidth() == 8 && "Expected an i8 value for the byte");
3117      if (Size == 1)
3118        return V;
3119  
3120      Type *SplatIntTy = Type::getIntNTy(VTy->getContext(), Size * 8);
3121      V = IRB.CreateMul(
3122          IRB.CreateZExt(V, SplatIntTy, "zext"),
3123          IRB.CreateUDiv(Constant::getAllOnesValue(SplatIntTy),
3124                         IRB.CreateZExt(Constant::getAllOnesValue(V->getType()),
3125                                        SplatIntTy)),
3126          "isplat");
3127      return V;
3128    }
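  // Worked example for getIntegerSplat (illustrative): splatting the byte
  // 0xAB across 4 bytes zero-extends it to i32 and multiplies by
  // 0xFFFFFFFF / 0xFF = 0x01010101, giving 0xABABABAB; the division is
  // between constants and folds away, so at most the multiply remains in the
  // emitted IR.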
3129  
3130    /// Compute a vector splat for a given element value.
3131    Value *getVectorSplat(Value *V, unsigned NumElements) {
3132      V = IRB.CreateVectorSplat(NumElements, V, "vsplat");
3133      LLVM_DEBUG(dbgs() << "       splat: " << *V << "\n");
3134      return V;
3135    }
3136  
3137    bool visitMemSetInst(MemSetInst &II) {
3138      LLVM_DEBUG(dbgs() << "    original: " << II << "\n");
3139      assert(II.getRawDest() == OldPtr);
3140  
3141      AAMDNodes AATags = II.getAAMetadata();
3142  
3143      // If the memset has a variable size, it cannot be split, just adjust the
3144      // pointer to the new alloca.
3145      if (!isa<ConstantInt>(II.getLength())) {
3146        assert(!IsSplit);
3147        assert(NewBeginOffset == BeginOffset);
3148        II.setDest(getNewAllocaSlicePtr(IRB, OldPtr->getType()));
3149        II.setDestAlignment(getSliceAlign());
3150        // In theory we should call migrateDebugInfo here. However, we do not
3151        // emit dbg.assign intrinsics for mem intrinsics storing through non-
3152        // constant geps, or storing a variable number of bytes.
3153        assert(at::getAssignmentMarkers(&II).empty() &&
3154               at::getDVRAssignmentMarkers(&II).empty() &&
3155               "AT: Unexpected link to non-const GEP");
3156        deleteIfTriviallyDead(OldPtr);
3157        return false;
3158      }
3159  
3160      // Record this instruction for deletion.
3161      Pass.DeadInsts.push_back(&II);
3162  
3163      Type *AllocaTy = NewAI.getAllocatedType();
3164      Type *ScalarTy = AllocaTy->getScalarType();
3165  
3166      const bool CanContinue = [&]() {
3167        if (VecTy || IntTy)
3168          return true;
3169        if (BeginOffset > NewAllocaBeginOffset || EndOffset < NewAllocaEndOffset)
3170          return false;
3171        // Length must be in range for FixedVectorType.
3172        auto *C = cast<ConstantInt>(II.getLength());
3173        const uint64_t Len = C->getLimitedValue();
3174        if (Len > std::numeric_limits<unsigned>::max())
3175          return false;
3176        auto *Int8Ty = IntegerType::getInt8Ty(NewAI.getContext());
3177        auto *SrcTy = FixedVectorType::get(Int8Ty, Len);
3178        return canConvertValue(DL, SrcTy, AllocaTy) &&
3179               DL.isLegalInteger(DL.getTypeSizeInBits(ScalarTy).getFixedValue());
3180      }();
3181  
3182      // If this doesn't map cleanly onto the alloca type, and that type isn't
3183      // a single value type, just emit a memset.
3184      if (!CanContinue) {
3185        Type *SizeTy = II.getLength()->getType();
3186        unsigned Sz = NewEndOffset - NewBeginOffset;
3187        Constant *Size = ConstantInt::get(SizeTy, Sz);
3188        MemIntrinsic *New = cast<MemIntrinsic>(IRB.CreateMemSet(
3189            getNewAllocaSlicePtr(IRB, OldPtr->getType()), II.getValue(), Size,
3190            MaybeAlign(getSliceAlign()), II.isVolatile()));
3191        if (AATags)
3192          New->setAAMetadata(
3193              AATags.adjustForAccess(NewBeginOffset - BeginOffset, Sz));
3194  
3195        migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8, &II,
3196                         New, New->getRawDest(), nullptr, DL);
3197  
3198        LLVM_DEBUG(dbgs() << "          to: " << *New << "\n");
3199        return false;
3200      }
3201  
3202      // If we can represent this as a simple value, we have to build the actual
3203      // value to store, which requires expanding the byte present in memset to
3204      // a sensible representation for the alloca type. This is essentially
3205      // splatting the byte to a sufficiently wide integer, splatting it across
3206      // any desired vector width, and bitcasting to the final type.
3207      Value *V;
3208  
3209      if (VecTy) {
3210        // If this is a memset of a vectorized alloca, insert it.
3211        assert(ElementTy == ScalarTy);
3212  
3213        unsigned BeginIndex = getIndex(NewBeginOffset);
3214        unsigned EndIndex = getIndex(NewEndOffset);
3215        assert(EndIndex > BeginIndex && "Empty vector!");
3216        unsigned NumElements = EndIndex - BeginIndex;
3217        assert(NumElements <= cast<FixedVectorType>(VecTy)->getNumElements() &&
3218               "Too many elements!");
3219  
3220        Value *Splat = getIntegerSplat(
3221            II.getValue(), DL.getTypeSizeInBits(ElementTy).getFixedValue() / 8);
3222        Splat = convertValue(DL, IRB, Splat, ElementTy);
3223        if (NumElements > 1)
3224          Splat = getVectorSplat(Splat, NumElements);
3225  
3226        Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
3227                                           NewAI.getAlign(), "oldload");
3228        V = insertVector(IRB, Old, Splat, BeginIndex, "vec");
3229      } else if (IntTy) {
3230        // If this is a memset on an alloca where we can widen stores, insert the
3231        // set integer.
3232        assert(!II.isVolatile());
3233  
3234        uint64_t Size = NewEndOffset - NewBeginOffset;
3235        V = getIntegerSplat(II.getValue(), Size);
3236  
3237        if (IntTy && (BeginOffset != NewAllocaBeginOffset ||
3238                      EndOffset != NewAllocaBeginOffset)) {
3239          Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
3240                                             NewAI.getAlign(), "oldload");
3241          Old = convertValue(DL, IRB, Old, IntTy);
3242          uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset;
3243          V = insertInteger(DL, IRB, Old, V, Offset, "insert");
3244        } else {
3245          assert(V->getType() == IntTy &&
3246                 "Wrong type for an alloca wide integer!");
3247        }
3248        V = convertValue(DL, IRB, V, AllocaTy);
3249      } else {
3250        // Established these invariants above.
3251        assert(NewBeginOffset == NewAllocaBeginOffset);
3252        assert(NewEndOffset == NewAllocaEndOffset);
3253  
3254        V = getIntegerSplat(II.getValue(),
3255                            DL.getTypeSizeInBits(ScalarTy).getFixedValue() / 8);
3256        if (VectorType *AllocaVecTy = dyn_cast<VectorType>(AllocaTy))
3257          V = getVectorSplat(
3258              V, cast<FixedVectorType>(AllocaVecTy)->getNumElements());
3259  
3260        V = convertValue(DL, IRB, V, AllocaTy);
3261      }
3262  
3263      Value *NewPtr = getPtrToNewAI(II.getDestAddressSpace(), II.isVolatile());
3264      StoreInst *New =
3265          IRB.CreateAlignedStore(V, NewPtr, NewAI.getAlign(), II.isVolatile());
3266      New->copyMetadata(II, {LLVMContext::MD_mem_parallel_loop_access,
3267                             LLVMContext::MD_access_group});
3268      if (AATags)
3269        New->setAAMetadata(AATags.adjustForAccess(NewBeginOffset - BeginOffset,
3270                                                  V->getType(), DL));
3271  
3272      migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8, &II,
3273                       New, New->getPointerOperand(), V, DL);
3274  
3275      LLVM_DEBUG(dbgs() << "          to: " << *New << "\n");
3276      return !II.isVolatile();
3277    }
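  // Illustrative example for visitMemSetInst: a constant-length memset of the
  // byte 0 that exactly covers a partition rewritten as an i64 typically ends
  // up as a scalar store of i64 0 to the new alloca, while a memset that only
  // partially covers a partition that could not be widened or vectorized
  // falls back to emitting a narrower memset over just the slice.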
3278  
3279    bool visitMemTransferInst(MemTransferInst &II) {
3280      // Rewriting of memory transfer instructions can be a bit tricky. We break
3281      // them into two categories: split intrinsics and unsplit intrinsics.
3282  
3283      LLVM_DEBUG(dbgs() << "    original: " << II << "\n");
3284  
3285      AAMDNodes AATags = II.getAAMetadata();
3286  
3287      bool IsDest = &II.getRawDestUse() == OldUse;
3288      assert((IsDest && II.getRawDest() == OldPtr) ||
3289             (!IsDest && II.getRawSource() == OldPtr));
3290  
3291      Align SliceAlign = getSliceAlign();
3292      // For unsplit intrinsics, we simply modify the source and destination
3293      // pointers in place. This isn't just an optimization, it is a matter of
3294      // correctness. With unsplit intrinsics we may be dealing with transfers
3295      // within a single alloca before SROA ran, or with transfers that have
3296      // a variable length. We may also be dealing with memmove instead of
3297      // memcpy, and so simply updating the pointers is necessary for us to
3298      // update both source and dest of a single call.
3299      if (!IsSplittable) {
3300        Value *AdjustedPtr = getNewAllocaSlicePtr(IRB, OldPtr->getType());
3301        if (IsDest) {
3302          // Update the address component of linked dbg.assigns.
3303          auto UpdateAssignAddress = [&](auto *DbgAssign) {
3304            if (llvm::is_contained(DbgAssign->location_ops(), II.getDest()) ||
3305                DbgAssign->getAddress() == II.getDest())
3306              DbgAssign->replaceVariableLocationOp(II.getDest(), AdjustedPtr);
3307          };
3308          for_each(at::getAssignmentMarkers(&II), UpdateAssignAddress);
3309          for_each(at::getDVRAssignmentMarkers(&II), UpdateAssignAddress);
3310          II.setDest(AdjustedPtr);
3311          II.setDestAlignment(SliceAlign);
3312        } else {
3313          II.setSource(AdjustedPtr);
3314          II.setSourceAlignment(SliceAlign);
3315        }
3316  
3317        LLVM_DEBUG(dbgs() << "          to: " << II << "\n");
3318        deleteIfTriviallyDead(OldPtr);
3319        return false;
3320      }
3321      // For split transfer intrinsics we have an incredibly useful assurance:
3322      // the source and destination do not reside within the same alloca, and at
3323      // least one of them does not escape. This means that we can replace
3324      // memmove with memcpy, and we don't need to worry about all manner of
3325      // downsides to splitting and transforming the operations.
3326  
3327      // If this doesn't map cleanly onto the alloca type, and that type isn't
3328      // a single value type, just emit a memcpy.
3329      bool EmitMemCpy =
3330          !VecTy && !IntTy &&
3331          (BeginOffset > NewAllocaBeginOffset || EndOffset < NewAllocaEndOffset ||
3332           SliceSize !=
3333               DL.getTypeStoreSize(NewAI.getAllocatedType()).getFixedValue() ||
3334           !DL.typeSizeEqualsStoreSize(NewAI.getAllocatedType()) ||
3335           !NewAI.getAllocatedType()->isSingleValueType());
3336  
3337      // If we're just going to emit a memcpy, the alloca hasn't changed, and the
3338      // size hasn't been shrunk based on analysis of the viable range, this is
3339      // a no-op.
3340      if (EmitMemCpy && &OldAI == &NewAI) {
3341        // Ensure the start lines up.
3342        assert(NewBeginOffset == BeginOffset);
3343  
3344        // Rewrite the size as needed.
3345        if (NewEndOffset != EndOffset)
3346          II.setLength(ConstantInt::get(II.getLength()->getType(),
3347                                        NewEndOffset - NewBeginOffset));
3348        return false;
3349      }
3350      // Record this instruction for deletion.
3351      Pass.DeadInsts.push_back(&II);
3352  
3353      // Strip all inbounds GEPs and pointer casts to try to dig out any root
3354      // alloca that should be re-examined after rewriting this instruction.
3355      Value *OtherPtr = IsDest ? II.getRawSource() : II.getRawDest();
3356      if (AllocaInst *AI =
3357              dyn_cast<AllocaInst>(OtherPtr->stripInBoundsOffsets())) {
3358        assert(AI != &OldAI && AI != &NewAI &&
3359               "Splittable transfers cannot reach the same alloca on both ends.");
3360        Pass.Worklist.insert(AI);
3361      }
3362  
3363      Type *OtherPtrTy = OtherPtr->getType();
3364      unsigned OtherAS = OtherPtrTy->getPointerAddressSpace();
3365  
3366      // Compute the relative offset for the other pointer within the transfer.
3367      unsigned OffsetWidth = DL.getIndexSizeInBits(OtherAS);
3368      APInt OtherOffset(OffsetWidth, NewBeginOffset - BeginOffset);
3369      Align OtherAlign =
3370          (IsDest ? II.getSourceAlign() : II.getDestAlign()).valueOrOne();
3371      OtherAlign =
3372          commonAlignment(OtherAlign, OtherOffset.zextOrTrunc(64).getZExtValue());
3373  
3374      if (EmitMemCpy) {
3375        // Compute the other pointer, folding as much as possible to produce
3376        // a single, simple GEP in most cases.
3377        OtherPtr = getAdjustedPtr(IRB, DL, OtherPtr, OtherOffset, OtherPtrTy,
3378                                  OtherPtr->getName() + ".");
3379  
3380        Value *OurPtr = getNewAllocaSlicePtr(IRB, OldPtr->getType());
3381        Type *SizeTy = II.getLength()->getType();
3382        Constant *Size = ConstantInt::get(SizeTy, NewEndOffset - NewBeginOffset);
3383  
3384        Value *DestPtr, *SrcPtr;
3385        MaybeAlign DestAlign, SrcAlign;
3386        // Note: IsDest is true iff we're copying into the new alloca slice
3387        if (IsDest) {
3388          DestPtr = OurPtr;
3389          DestAlign = SliceAlign;
3390          SrcPtr = OtherPtr;
3391          SrcAlign = OtherAlign;
3392        } else {
3393          DestPtr = OtherPtr;
3394          DestAlign = OtherAlign;
3395          SrcPtr = OurPtr;
3396          SrcAlign = SliceAlign;
3397        }
3398        CallInst *New = IRB.CreateMemCpy(DestPtr, DestAlign, SrcPtr, SrcAlign,
3399                                         Size, II.isVolatile());
3400        if (AATags)
3401          New->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset));
3402  
3403        APInt Offset(DL.getIndexTypeSizeInBits(DestPtr->getType()), 0);
3404        if (IsDest) {
3405          migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8,
3406                           &II, New, DestPtr, nullptr, DL);
3407        } else if (AllocaInst *Base = dyn_cast<AllocaInst>(
3408                       DestPtr->stripAndAccumulateConstantOffsets(
3409                           DL, Offset, /*AllowNonInbounds*/ true))) {
3410          migrateDebugInfo(Base, IsSplit, Offset.getZExtValue() * 8,
3411                           SliceSize * 8, &II, New, DestPtr, nullptr, DL);
3412        }
3413        LLVM_DEBUG(dbgs() << "          to: " << *New << "\n");
3414        return false;
3415      }
3416  
3417      bool IsWholeAlloca = NewBeginOffset == NewAllocaBeginOffset &&
3418                           NewEndOffset == NewAllocaEndOffset;
3419      uint64_t Size = NewEndOffset - NewBeginOffset;
3420      unsigned BeginIndex = VecTy ? getIndex(NewBeginOffset) : 0;
3421      unsigned EndIndex = VecTy ? getIndex(NewEndOffset) : 0;
3422      unsigned NumElements = EndIndex - BeginIndex;
3423      IntegerType *SubIntTy =
3424          IntTy ? Type::getIntNTy(IntTy->getContext(), Size * 8) : nullptr;
3425  
3426      // Reset the other pointer type to match the register type we're going to
3427      // use, but using the address space of the original other pointer.
3428      Type *OtherTy;
3429      if (VecTy && !IsWholeAlloca) {
3430        if (NumElements == 1)
3431          OtherTy = VecTy->getElementType();
3432        else
3433          OtherTy = FixedVectorType::get(VecTy->getElementType(), NumElements);
3434      } else if (IntTy && !IsWholeAlloca) {
3435        OtherTy = SubIntTy;
3436      } else {
3437        OtherTy = NewAllocaTy;
3438      }
3439  
3440      Value *AdjPtr = getAdjustedPtr(IRB, DL, OtherPtr, OtherOffset, OtherPtrTy,
3441                                     OtherPtr->getName() + ".");
3442      MaybeAlign SrcAlign = OtherAlign;
3443      MaybeAlign DstAlign = SliceAlign;
3444      if (!IsDest)
3445        std::swap(SrcAlign, DstAlign);
3446  
3447      Value *SrcPtr;
3448      Value *DstPtr;
3449  
3450      if (IsDest) {
3451        DstPtr = getPtrToNewAI(II.getDestAddressSpace(), II.isVolatile());
3452        SrcPtr = AdjPtr;
3453      } else {
3454        DstPtr = AdjPtr;
3455        SrcPtr = getPtrToNewAI(II.getSourceAddressSpace(), II.isVolatile());
3456      }
3457  
3458      Value *Src;
3459      if (VecTy && !IsWholeAlloca && !IsDest) {
3460        Src = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
3461                                    NewAI.getAlign(), "load");
3462        Src = extractVector(IRB, Src, BeginIndex, EndIndex, "vec");
3463      } else if (IntTy && !IsWholeAlloca && !IsDest) {
3464        Src = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
3465                                    NewAI.getAlign(), "load");
3466        Src = convertValue(DL, IRB, Src, IntTy);
3467        uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset;
3468        Src = extractInteger(DL, IRB, Src, SubIntTy, Offset, "extract");
3469      } else {
3470        LoadInst *Load = IRB.CreateAlignedLoad(OtherTy, SrcPtr, SrcAlign,
3471                                               II.isVolatile(), "copyload");
3472        Load->copyMetadata(II, {LLVMContext::MD_mem_parallel_loop_access,
3473                                LLVMContext::MD_access_group});
3474        if (AATags)
3475          Load->setAAMetadata(AATags.adjustForAccess(NewBeginOffset - BeginOffset,
3476                                                     Load->getType(), DL));
3477        Src = Load;
3478      }
3479  
3480      if (VecTy && !IsWholeAlloca && IsDest) {
3481        Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
3482                                           NewAI.getAlign(), "oldload");
3483        Src = insertVector(IRB, Old, Src, BeginIndex, "vec");
3484      } else if (IntTy && !IsWholeAlloca && IsDest) {
3485        Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
3486                                           NewAI.getAlign(), "oldload");
3487        Old = convertValue(DL, IRB, Old, IntTy);
3488        uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset;
3489        Src = insertInteger(DL, IRB, Old, Src, Offset, "insert");
3490        Src = convertValue(DL, IRB, Src, NewAllocaTy);
3491      }
3492  
3493      StoreInst *Store = cast<StoreInst>(
3494          IRB.CreateAlignedStore(Src, DstPtr, DstAlign, II.isVolatile()));
3495      Store->copyMetadata(II, {LLVMContext::MD_mem_parallel_loop_access,
3496                               LLVMContext::MD_access_group});
3497      if (AATags)
3498        Store->setAAMetadata(AATags.adjustForAccess(NewBeginOffset - BeginOffset,
3499                                                    Src->getType(), DL));
3500  
3501      APInt Offset(DL.getIndexTypeSizeInBits(DstPtr->getType()), 0);
3502      if (IsDest) {
3503  
3504        migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8, &II,
3505                         Store, DstPtr, Src, DL);
3506      } else if (AllocaInst *Base = dyn_cast<AllocaInst>(
3507                     DstPtr->stripAndAccumulateConstantOffsets(
3508                         DL, Offset, /*AllowNonInbounds*/ true))) {
3509        migrateDebugInfo(Base, IsSplit, Offset.getZExtValue() * 8, SliceSize * 8,
3510                         &II, Store, DstPtr, Src, DL);
3511      }
3512  
3513      LLVM_DEBUG(dbgs() << "          to: " << *Store << "\n");
3514      return !II.isVolatile();
3515    }
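  // Illustrative example for visitMemTransferInst: a splittable 4-byte memcpy
  // whose source is a partition rewritten as <4 x float> becomes a load of
  // the whole vector, an extractVector of the copied lane, and an ordinary
  // store to the adjusted destination pointer, so no memcpy remains for this
  // slice.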
3516  
3517    bool visitIntrinsicInst(IntrinsicInst &II) {
3518      assert((II.isLifetimeStartOrEnd() || II.isLaunderOrStripInvariantGroup() ||
3519              II.isDroppable()) &&
3520             "Unexpected intrinsic!");
3521      LLVM_DEBUG(dbgs() << "    original: " << II << "\n");
3522  
3523      // Record this instruction for deletion.
3524      Pass.DeadInsts.push_back(&II);
3525  
3526      if (II.isDroppable()) {
3527        assert(II.getIntrinsicID() == Intrinsic::assume && "Expected assume");
3528        // TODO For now we forget assumed information, this can be improved.
3529        OldPtr->dropDroppableUsesIn(II);
3530        return true;
3531      }
3532  
3533      if (II.isLaunderOrStripInvariantGroup())
3534        return true;
3535  
3536      assert(II.getArgOperand(1) == OldPtr);
3537      // Lifetime intrinsics are only promotable if they cover the whole alloca.
3538      // Therefore, we drop lifetime intrinsics which don't cover the whole
3539      // alloca.
3540      // (In theory, intrinsics which partially cover an alloca could be
3541      // promoted, but PromoteMemToReg doesn't handle that case.)
3542      // FIXME: Check whether the alloca is promotable before dropping the
3543      // lifetime intrinsics?
3544      if (NewBeginOffset != NewAllocaBeginOffset ||
3545          NewEndOffset != NewAllocaEndOffset)
3546        return true;
3547  
3548      ConstantInt *Size =
3549          ConstantInt::get(cast<IntegerType>(II.getArgOperand(0)->getType()),
3550                           NewEndOffset - NewBeginOffset);
3551      // Lifetime intrinsics always expect an i8* so directly get such a pointer
3552      // for the new alloca slice.
3553      Type *PointerTy = IRB.getPtrTy(OldPtr->getType()->getPointerAddressSpace());
3554      Value *Ptr = getNewAllocaSlicePtr(IRB, PointerTy);
3555      Value *New;
3556      if (II.getIntrinsicID() == Intrinsic::lifetime_start)
3557        New = IRB.CreateLifetimeStart(Ptr, Size);
3558      else
3559        New = IRB.CreateLifetimeEnd(Ptr, Size);
3560  
3561      (void)New;
3562      LLVM_DEBUG(dbgs() << "          to: " << *New << "\n");
3563  
3564      return true;
3565    }
3566  
3567    void fixLoadStoreAlign(Instruction &Root) {
3568      // This algorithm implements the same visitor loop as
3569      // hasUnsafePHIOrSelectUse, and fixes the alignment of each load
3570      // or store found.
3571      SmallPtrSet<Instruction *, 4> Visited;
3572      SmallVector<Instruction *, 4> Uses;
3573      Visited.insert(&Root);
3574      Uses.push_back(&Root);
3575      do {
3576        Instruction *I = Uses.pop_back_val();
3577  
3578        if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
3579          LI->setAlignment(std::min(LI->getAlign(), getSliceAlign()));
3580          continue;
3581        }
3582        if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
3583          SI->setAlignment(std::min(SI->getAlign(), getSliceAlign()));
3584          continue;
3585        }
3586  
3587        assert(isa<BitCastInst>(I) || isa<AddrSpaceCastInst>(I) ||
3588               isa<PHINode>(I) || isa<SelectInst>(I) ||
3589               isa<GetElementPtrInst>(I));
3590        for (User *U : I->users())
3591          if (Visited.insert(cast<Instruction>(U)).second)
3592            Uses.push_back(cast<Instruction>(U));
3593      } while (!Uses.empty());
3594    }
3595  
3596    bool visitPHINode(PHINode &PN) {
3597      LLVM_DEBUG(dbgs() << "    original: " << PN << "\n");
3598      assert(BeginOffset >= NewAllocaBeginOffset && "PHIs are unsplittable");
3599      assert(EndOffset <= NewAllocaEndOffset && "PHIs are unsplittable");
3600  
3601      // We would like to compute a new pointer in only one place, but have it be
3602      // as local as possible to the PHI. To do that, we re-use the location of
3603      // the old pointer, which necessarily must be in the right position to
3604      // dominate the PHI.
3605      IRBuilderBase::InsertPointGuard Guard(IRB);
3606      if (isa<PHINode>(OldPtr))
3607        IRB.SetInsertPoint(OldPtr->getParent(),
3608                           OldPtr->getParent()->getFirstInsertionPt());
3609      else
3610        IRB.SetInsertPoint(OldPtr);
3611      IRB.SetCurrentDebugLocation(OldPtr->getDebugLoc());
3612  
3613      Value *NewPtr = getNewAllocaSlicePtr(IRB, OldPtr->getType());
3614      // Replace the operands which were using the old pointer.
3615      std::replace(PN.op_begin(), PN.op_end(), cast<Value>(OldPtr), NewPtr);
3616  
3617      LLVM_DEBUG(dbgs() << "          to: " << PN << "\n");
3618      deleteIfTriviallyDead(OldPtr);
3619  
3620      // Fix the alignment of any loads or stores using this PHI node.
3621      fixLoadStoreAlign(PN);
3622  
3623      // PHIs can't be promoted on their own, but often can be speculated. We
3624      // check the speculation outside of the rewriter so that we see the
3625      // fully-rewritten alloca.
3626      PHIUsers.insert(&PN);
3627      return true;
3628    }
3629  
3630    bool visitSelectInst(SelectInst &SI) {
3631      LLVM_DEBUG(dbgs() << "    original: " << SI << "\n");
3632      assert((SI.getTrueValue() == OldPtr || SI.getFalseValue() == OldPtr) &&
3633             "Pointer isn't an operand!");
3634      assert(BeginOffset >= NewAllocaBeginOffset && "Selects are unsplittable");
3635      assert(EndOffset <= NewAllocaEndOffset && "Selects are unsplittable");
3636  
3637      Value *NewPtr = getNewAllocaSlicePtr(IRB, OldPtr->getType());
3638      // Replace the operands which were using the old pointer.
3639      if (SI.getOperand(1) == OldPtr)
3640        SI.setOperand(1, NewPtr);
3641      if (SI.getOperand(2) == OldPtr)
3642        SI.setOperand(2, NewPtr);
3643  
3644      LLVM_DEBUG(dbgs() << "          to: " << SI << "\n");
3645      deleteIfTriviallyDead(OldPtr);
3646  
3647      // Fix the alignment of any loads or stores using this select.
3648      fixLoadStoreAlign(SI);
3649  
3650      // Selects can't be promoted on their own, but often can be speculated. We
3651      // check the speculation outside of the rewriter so that we see the
3652      // fully-rewritten alloca.
3653      SelectUsers.insert(&SI);
3654      return true;
3655    }
3656  };
3657  
3658  /// Visitor to rewrite aggregate loads and stores as scalar.
3659  ///
3660  /// This pass aggressively rewrites all aggregate loads and stores on
3661  /// a particular pointer (or any pointer derived from it which we can identify)
3662  /// with scalar loads and stores.
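///
/// As a rough illustration (the type and value names below are invented for
/// this comment rather than taken from a test), a first-class-aggregate load
/// such as
///
///   %agg = load { i32, float }, ptr %p
///
/// is rewritten into per-element accesses along the lines of
///
///   %p.0 = getelementptr inbounds { i32, float }, ptr %p, i32 0, i32 0
///   %v.0 = load i32, ptr %p.0
///   %agg.0 = insertvalue { i32, float } poison, i32 %v.0, 0
///   %p.1 = getelementptr inbounds { i32, float }, ptr %p, i32 0, i32 1
///   %v.1 = load float, ptr %p.1
///   %agg.1 = insertvalue { i32, float } %agg.0, float %v.1, 1
///
/// Aggregate stores are handled symmetrically via extractvalue.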
3663  class AggLoadStoreRewriter : public InstVisitor<AggLoadStoreRewriter, bool> {
3664    // Befriend the base class so it can delegate to private visit methods.
3665    friend class InstVisitor<AggLoadStoreRewriter, bool>;
3666  
3667    /// Queue of pointer uses to analyze and potentially rewrite.
3668    SmallVector<Use *, 8> Queue;
3669  
3670    /// Set to prevent us from cycling with phi nodes and loops.
3671    SmallPtrSet<User *, 8> Visited;
3672  
3673    /// The current pointer use being rewritten. This is used to dig up the used
3674    /// value (as opposed to the user).
3675    Use *U = nullptr;
3676  
3677    /// Used to calculate offsets, and hence alignment, of subobjects.
3678    const DataLayout &DL;
3679  
3680    IRBuilderTy &IRB;
3681  
3682  public:
3683    AggLoadStoreRewriter(const DataLayout &DL, IRBuilderTy &IRB)
3684        : DL(DL), IRB(IRB) {}
3685  
3686    /// Rewrite loads and stores through a pointer and all pointers derived from
3687    /// it.
3688    bool rewrite(Instruction &I) {
3689      LLVM_DEBUG(dbgs() << "  Rewriting FCA loads and stores...\n");
3690      enqueueUsers(I);
3691      bool Changed = false;
3692      while (!Queue.empty()) {
3693        U = Queue.pop_back_val();
3694        Changed |= visit(cast<Instruction>(U->getUser()));
3695      }
3696      return Changed;
3697    }
3698  
3699  private:
3700    /// Enqueue all the users of the given instruction for further processing.
3701    /// This uses a set to de-duplicate users.
3702    void enqueueUsers(Instruction &I) {
3703      for (Use &U : I.uses())
3704        if (Visited.insert(U.getUser()).second)
3705          Queue.push_back(&U);
3706    }
3707  
3708    // Conservative default is to not rewrite anything.
3709    bool visitInstruction(Instruction &I) { return false; }
3710  
3711    /// Generic recursive split emission class.
3712    template <typename Derived> class OpSplitter {
3713    protected:
3714      /// The builder used to form new instructions.
3715      IRBuilderTy &IRB;
3716  
3717      /// The indices to be used with insert- or extractvalue to select the
3718      /// appropriate value within the aggregate.
3719      SmallVector<unsigned, 4> Indices;
3720  
3721      /// The indices to a GEP instruction which will move Ptr to the correct slot
3722      /// within the aggregate.
3723      SmallVector<Value *, 4> GEPIndices;
3724  
3725      /// The base pointer of the original op, used as a base for GEPing the
3726      /// split operations.
3727      Value *Ptr;
3728  
3729      /// The base pointee type being GEPed into.
3730      Type *BaseTy;
3731  
3732      /// Known alignment of the base pointer.
3733      Align BaseAlign;
3734  
3735      /// To calculate offset of each component so we can correctly deduce
3736      /// alignments.
3737      const DataLayout &DL;
3738  
3739      /// Initialize the splitter with an insertion point, Ptr and start with a
3740      /// single zero GEP index.
3741      OpSplitter(Instruction *InsertionPoint, Value *Ptr, Type *BaseTy,
3742                 Align BaseAlign, const DataLayout &DL, IRBuilderTy &IRB)
3743          : IRB(IRB), GEPIndices(1, IRB.getInt32(0)), Ptr(Ptr), BaseTy(BaseTy),
3744            BaseAlign(BaseAlign), DL(DL) {
3745        IRB.SetInsertPoint(InsertionPoint);
3746      }
3747  
3748    public:
3749      /// Generic recursive split emission routine.
3750      ///
3751      /// This method recursively splits an aggregate op (load or store) into
3752      /// scalar or vector ops. It splits recursively until it hits a single value
3753      /// and emits that single value operation via the template argument.
3754      ///
3755      /// The logic of this routine relies on GEPs and insertvalue and
3756      /// extractvalue all operating with the same fundamental index list, merely
3757      /// formatted differently (GEPs need actual values).
3758      ///
3759      /// \param Ty  The type being split recursively into smaller ops.
3760      /// \param Agg The aggregate value being built up or stored, depending on
3761      /// whether this is splitting a load or a store respectively.
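    ///
    /// As a small sketch (the type here is made up for this comment): for an
    /// op on { i32, [2 x float] } the recursion visits the leaves with the
    /// index lists {0}, {1, 0} and {1, 1}; each leaf is addressed by a GEP
    /// over {0, <those indices>} and handed to the derived emitFunc as a
    /// single scalar operation.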
3762      void emitSplitOps(Type *Ty, Value *&Agg, const Twine &Name) {
3763        if (Ty->isSingleValueType()) {
3764          unsigned Offset = DL.getIndexedOffsetInType(BaseTy, GEPIndices);
3765          return static_cast<Derived *>(this)->emitFunc(
3766              Ty, Agg, commonAlignment(BaseAlign, Offset), Name);
3767        }
3768  
3769        if (ArrayType *ATy = dyn_cast<ArrayType>(Ty)) {
3770          unsigned OldSize = Indices.size();
3771          (void)OldSize;
3772          for (unsigned Idx = 0, Size = ATy->getNumElements(); Idx != Size;
3773               ++Idx) {
3774            assert(Indices.size() == OldSize && "Did not return to the old size");
3775            Indices.push_back(Idx);
3776            GEPIndices.push_back(IRB.getInt32(Idx));
3777            emitSplitOps(ATy->getElementType(), Agg, Name + "." + Twine(Idx));
3778            GEPIndices.pop_back();
3779            Indices.pop_back();
3780          }
3781          return;
3782        }
3783  
3784        if (StructType *STy = dyn_cast<StructType>(Ty)) {
3785          unsigned OldSize = Indices.size();
3786          (void)OldSize;
3787          for (unsigned Idx = 0, Size = STy->getNumElements(); Idx != Size;
3788               ++Idx) {
3789            assert(Indices.size() == OldSize && "Did not return to the old size");
3790            Indices.push_back(Idx);
3791            GEPIndices.push_back(IRB.getInt32(Idx));
3792            emitSplitOps(STy->getElementType(Idx), Agg, Name + "." + Twine(Idx));
3793            GEPIndices.pop_back();
3794            Indices.pop_back();
3795          }
3796          return;
3797        }
3798  
3799        llvm_unreachable("Only arrays and structs are aggregate loadable types");
3800      }
3801    };
3802  
3803    struct LoadOpSplitter : public OpSplitter<LoadOpSplitter> {
3804      AAMDNodes AATags;
3805  
3806      LoadOpSplitter(Instruction *InsertionPoint, Value *Ptr, Type *BaseTy,
3807                     AAMDNodes AATags, Align BaseAlign, const DataLayout &DL,
3808                     IRBuilderTy &IRB)
3809          : OpSplitter<LoadOpSplitter>(InsertionPoint, Ptr, BaseTy, BaseAlign, DL,
3810                                       IRB),
3811            AATags(AATags) {}
3812  
3813      /// Emit a leaf load of a single value. This is called at the leaves of the
3814      /// recursive emission to actually load values.
3815      void emitFunc(Type *Ty, Value *&Agg, Align Alignment, const Twine &Name) {
3816        assert(Ty->isSingleValueType());
3817        // Load the single value and insert it using the indices.
3818        Value *GEP =
3819            IRB.CreateInBoundsGEP(BaseTy, Ptr, GEPIndices, Name + ".gep");
3820        LoadInst *Load =
3821            IRB.CreateAlignedLoad(Ty, GEP, Alignment, Name + ".load");
3822  
3823        APInt Offset(
3824            DL.getIndexSizeInBits(Ptr->getType()->getPointerAddressSpace()), 0);
3825        if (AATags &&
3826            GEPOperator::accumulateConstantOffset(BaseTy, GEPIndices, DL, Offset))
3827          Load->setAAMetadata(
3828              AATags.adjustForAccess(Offset.getZExtValue(), Load->getType(), DL));
3829  
3830        Agg = IRB.CreateInsertValue(Agg, Load, Indices, Name + ".insert");
3831        LLVM_DEBUG(dbgs() << "          to: " << *Load << "\n");
3832      }
3833    };
3834  
3835    bool visitLoadInst(LoadInst &LI) {
3836      assert(LI.getPointerOperand() == *U);
3837      if (!LI.isSimple() || LI.getType()->isSingleValueType())
3838        return false;
3839  
3840      // We have an aggregate being loaded, split it apart.
3841      LLVM_DEBUG(dbgs() << "    original: " << LI << "\n");
3842      LoadOpSplitter Splitter(&LI, *U, LI.getType(), LI.getAAMetadata(),
3843                              getAdjustedAlignment(&LI, 0), DL, IRB);
3844      Value *V = PoisonValue::get(LI.getType());
3845      Splitter.emitSplitOps(LI.getType(), V, LI.getName() + ".fca");
3846      Visited.erase(&LI);
3847      LI.replaceAllUsesWith(V);
3848      LI.eraseFromParent();
3849      return true;
3850    }
3851  
3852    struct StoreOpSplitter : public OpSplitter<StoreOpSplitter> {
3853      StoreOpSplitter(Instruction *InsertionPoint, Value *Ptr, Type *BaseTy,
3854                      AAMDNodes AATags, StoreInst *AggStore, Align BaseAlign,
3855                      const DataLayout &DL, IRBuilderTy &IRB)
3856          : OpSplitter<StoreOpSplitter>(InsertionPoint, Ptr, BaseTy, BaseAlign,
3857                                        DL, IRB),
3858            AATags(AATags), AggStore(AggStore) {}
3859      AAMDNodes AATags;
3860      StoreInst *AggStore;
3861      /// Emit a leaf store of a single value. This is called at the leaves of the
3862      /// recursive emission to actually produce stores.
3863      void emitFunc(Type *Ty, Value *&Agg, Align Alignment, const Twine &Name) {
3864        assert(Ty->isSingleValueType());
3865        // Extract the single value and store it using the indices.
3866        //
3867        // The gep and extractvalue values are factored out of the CreateStore
3868        // call to make the output independent of the argument evaluation order.
3869        Value *ExtractValue =
3870            IRB.CreateExtractValue(Agg, Indices, Name + ".extract");
3871        Value *InBoundsGEP =
3872            IRB.CreateInBoundsGEP(BaseTy, Ptr, GEPIndices, Name + ".gep");
3873        StoreInst *Store =
3874            IRB.CreateAlignedStore(ExtractValue, InBoundsGEP, Alignment);
3875  
3876        APInt Offset(
3877            DL.getIndexSizeInBits(Ptr->getType()->getPointerAddressSpace()), 0);
3878        GEPOperator::accumulateConstantOffset(BaseTy, GEPIndices, DL, Offset);
3879        if (AATags) {
3880          Store->setAAMetadata(AATags.adjustForAccess(
3881              Offset.getZExtValue(), ExtractValue->getType(), DL));
3882        }
3883  
3884        // migrateDebugInfo requires the base Alloca. Walk to it from this gep.
3885        // If we cannot (because there's an intervening non-const or unbounded
3886        // gep) then we wouldn't expect to see dbg.assign intrinsics linked to
3887        // this instruction.
3888        Value *Base = AggStore->getPointerOperand()->stripInBoundsOffsets();
3889        if (auto *OldAI = dyn_cast<AllocaInst>(Base)) {
3890          uint64_t SizeInBits =
3891              DL.getTypeSizeInBits(Store->getValueOperand()->getType());
3892          migrateDebugInfo(OldAI, /*IsSplit*/ true, Offset.getZExtValue() * 8,
3893                           SizeInBits, AggStore, Store,
3894                           Store->getPointerOperand(), Store->getValueOperand(),
3895                           DL);
3896        } else {
3897          assert(at::getAssignmentMarkers(Store).empty() &&
3898                 at::getDVRAssignmentMarkers(Store).empty() &&
3899                 "AT: unexpected debug.assign linked to store through "
3900                 "unbounded GEP");
3901        }
3902        LLVM_DEBUG(dbgs() << "          to: " << *Store << "\n");
3903      }
3904    };
3905  
3906    bool visitStoreInst(StoreInst &SI) {
3907      if (!SI.isSimple() || SI.getPointerOperand() != *U)
3908        return false;
3909      Value *V = SI.getValueOperand();
3910      if (V->getType()->isSingleValueType())
3911        return false;
3912  
3913      // We have an aggregate being stored, split it apart.
3914      LLVM_DEBUG(dbgs() << "    original: " << SI << "\n");
3915      StoreOpSplitter Splitter(&SI, *U, V->getType(), SI.getAAMetadata(), &SI,
3916                               getAdjustedAlignment(&SI, 0), DL, IRB);
3917      Splitter.emitSplitOps(V->getType(), V, V->getName() + ".fca");
3918      Visited.erase(&SI);
3919      // The stores replacing SI each have markers describing fragments of the
3920      // assignment so delete the assignment markers linked to SI.
3921      at::deleteAssignmentMarkers(&SI);
3922      SI.eraseFromParent();
3923      return true;
3924    }
3925  
3926    bool visitBitCastInst(BitCastInst &BC) {
3927      enqueueUsers(BC);
3928      return false;
3929    }
3930  
3931    bool visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) {
3932      enqueueUsers(ASC);
3933      return false;
3934    }
3935  
3936    // Unfold gep (select cond, ptr1, ptr2), idx
3937    //   => select cond, gep(ptr1, idx), gep(ptr2, idx)
3938    // and  gep ptr, (select cond, idx1, idx2)
3939    //   => select cond, gep(ptr, idx1), gep(ptr, idx2)
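  //
  // A hedged illustration (value names are invented for this comment): given
  //
  //   %sel = select i1 %cond, ptr %a, ptr %b
  //   %gep = getelementptr inbounds i32, ptr %sel, i64 1
  //
  // the GEP is unfolded into
  //
  //   %a.sroa.gep = getelementptr inbounds i32, ptr %a, i64 1
  //   %b.sroa.gep = getelementptr inbounds i32, ptr %b, i64 1
  //   %sel.sroa.sel = select i1 %cond, ptr %a.sroa.gep, ptr %b.sroa.gep
  //
  // so that later slice rewriting sees constant offsets on both arms.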
3940    bool unfoldGEPSelect(GetElementPtrInst &GEPI) {
3941      // Check whether the GEP has exactly one select operand and all indices
3942      // will become constant after the transform.
3943      SelectInst *Sel = dyn_cast<SelectInst>(GEPI.getPointerOperand());
3944      for (Value *Op : GEPI.indices()) {
3945        if (auto *SI = dyn_cast<SelectInst>(Op)) {
3946          if (Sel)
3947            return false;
3948  
3949          Sel = SI;
3950          if (!isa<ConstantInt>(Sel->getTrueValue()) ||
3951              !isa<ConstantInt>(Sel->getFalseValue()))
3952            return false;
3953          continue;
3954        }
3955  
3956        if (!isa<ConstantInt>(Op))
3957          return false;
3958      }
3959  
3960      if (!Sel)
3961        return false;
3962  
3963      LLVM_DEBUG(dbgs() << "  Rewriting gep(select) -> select(gep):\n";
3964                 dbgs() << "    original: " << *Sel << "\n";
3965                 dbgs() << "              " << GEPI << "\n";);
3966  
3967      auto GetNewOps = [&](Value *SelOp) {
3968        SmallVector<Value *> NewOps;
3969        for (Value *Op : GEPI.operands())
3970          if (Op == Sel)
3971            NewOps.push_back(SelOp);
3972          else
3973            NewOps.push_back(Op);
3974        return NewOps;
3975      };
3976  
3977      Value *True = Sel->getTrueValue();
3978      Value *False = Sel->getFalseValue();
3979      SmallVector<Value *> TrueOps = GetNewOps(True);
3980      SmallVector<Value *> FalseOps = GetNewOps(False);
3981  
3982      IRB.SetInsertPoint(&GEPI);
3983      GEPNoWrapFlags NW = GEPI.getNoWrapFlags();
3984  
3985      Type *Ty = GEPI.getSourceElementType();
3986      Value *NTrue = IRB.CreateGEP(Ty, TrueOps[0], ArrayRef(TrueOps).drop_front(),
3987                                   True->getName() + ".sroa.gep", NW);
3988  
3989      Value *NFalse =
3990          IRB.CreateGEP(Ty, FalseOps[0], ArrayRef(FalseOps).drop_front(),
3991                        False->getName() + ".sroa.gep", NW);
3992  
3993      Value *NSel = IRB.CreateSelect(Sel->getCondition(), NTrue, NFalse,
3994                                     Sel->getName() + ".sroa.sel");
3995      Visited.erase(&GEPI);
3996      GEPI.replaceAllUsesWith(NSel);
3997      GEPI.eraseFromParent();
3998      Instruction *NSelI = cast<Instruction>(NSel);
3999      Visited.insert(NSelI);
4000      enqueueUsers(*NSelI);
4001  
4002      LLVM_DEBUG(dbgs() << "          to: " << *NTrue << "\n";
4003                 dbgs() << "              " << *NFalse << "\n";
4004                 dbgs() << "              " << *NSel << "\n";);
4005  
4006      return true;
4007    }
4008  
4009    // Unfold gep (phi ptr1, ptr2), idx
4010    //   => phi ((gep ptr1, idx), (gep ptr2, idx))
4011    // and  gep ptr, (phi idx1, idx2)
4012    //   => phi ((gep ptr, idx1), (gep ptr, idx2))
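  //
  // As a sketch (names invented; the incoming pointers must be arguments,
  // constants, or static allocas for the transform to apply):
  //
  //   %phi = phi ptr [ %a, %bb1 ], [ %b, %bb2 ]
  //   %gep = getelementptr i8, ptr %phi, i64 4
  //
  // becomes, with the new GEPs emitted at the end of the entry block,
  //
  //   %phi.sroa.gep = getelementptr i8, ptr %a, i64 4
  //   %phi.sroa.gep1 = getelementptr i8, ptr %b, i64 4
  //   %phi.sroa.phi = phi ptr [ %phi.sroa.gep, %bb1 ], [ %phi.sroa.gep1, %bb2 ]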
4013    bool unfoldGEPPhi(GetElementPtrInst &GEPI) {
4014      // To prevent infinitely expanding recursive phis, bail if the GEP pointer
4015      // operand (looking through the phi if it is the phi we want to unfold) is
4016      // an instruction besides a static alloca.
4017      PHINode *Phi = dyn_cast<PHINode>(GEPI.getPointerOperand());
4018      auto IsInvalidPointerOperand = [](Value *V) {
4019        if (!isa<Instruction>(V))
4020          return false;
4021        if (auto *AI = dyn_cast<AllocaInst>(V))
4022          return !AI->isStaticAlloca();
4023        return true;
4024      };
4025      if (Phi) {
4026        if (any_of(Phi->operands(), IsInvalidPointerOperand))
4027          return false;
4028      } else {
4029        if (IsInvalidPointerOperand(GEPI.getPointerOperand()))
4030          return false;
4031      }
4032      // Check whether the GEP has exactly one phi operand (including the pointer
4033      // operand) and all indices will become constant after the transform.
4034      for (Value *Op : GEPI.indices()) {
4035        if (auto *SI = dyn_cast<PHINode>(Op)) {
4036          if (Phi)
4037            return false;
4038  
4039          Phi = SI;
4040          if (!all_of(Phi->incoming_values(),
4041                      [](Value *V) { return isa<ConstantInt>(V); }))
4042            return false;
4043          continue;
4044        }
4045  
4046        if (!isa<ConstantInt>(Op))
4047          return false;
4048      }
4049  
4050      if (!Phi)
4051        return false;
4052  
4053      LLVM_DEBUG(dbgs() << "  Rewriting gep(phi) -> phi(gep):\n";
4054                 dbgs() << "    original: " << *Phi << "\n";
4055                 dbgs() << "              " << GEPI << "\n";);
4056  
4057      auto GetNewOps = [&](Value *PhiOp) {
4058        SmallVector<Value *> NewOps;
4059        for (Value *Op : GEPI.operands())
4060          if (Op == Phi)
4061            NewOps.push_back(PhiOp);
4062          else
4063            NewOps.push_back(Op);
4064        return NewOps;
4065      };
4066  
4067      IRB.SetInsertPoint(Phi);
4068      PHINode *NewPhi = IRB.CreatePHI(GEPI.getType(), Phi->getNumIncomingValues(),
4069                                      Phi->getName() + ".sroa.phi");
4070  
4071      Type *SourceTy = GEPI.getSourceElementType();
4072      // We only handle arguments, constants, and static allocas here, so we can
4073      // insert GEPs at the end of the entry block.
4074      IRB.SetInsertPoint(GEPI.getFunction()->getEntryBlock().getTerminator());
4075      for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I) {
4076        Value *Op = Phi->getIncomingValue(I);
4077        BasicBlock *BB = Phi->getIncomingBlock(I);
4078        Value *NewGEP;
4079        if (int NI = NewPhi->getBasicBlockIndex(BB); NI >= 0) {
4080          NewGEP = NewPhi->getIncomingValue(NI);
4081        } else {
4082          SmallVector<Value *> NewOps = GetNewOps(Op);
4083          NewGEP =
4084              IRB.CreateGEP(SourceTy, NewOps[0], ArrayRef(NewOps).drop_front(),
4085                            Phi->getName() + ".sroa.gep", GEPI.getNoWrapFlags());
4086        }
4087        NewPhi->addIncoming(NewGEP, BB);
4088      }
4089  
4090      Visited.erase(&GEPI);
4091      GEPI.replaceAllUsesWith(NewPhi);
4092      GEPI.eraseFromParent();
4093      Visited.insert(NewPhi);
4094      enqueueUsers(*NewPhi);
4095  
4096      LLVM_DEBUG(dbgs() << "          to: ";
4097                 for (Value *In
4098                      : NewPhi->incoming_values()) dbgs()
4099                 << "\n              " << *In;
4100                 dbgs() << "\n              " << *NewPhi << '\n');
4101  
4102      return true;
4103    }
4104  
4105    bool visitGetElementPtrInst(GetElementPtrInst &GEPI) {
4106      if (unfoldGEPSelect(GEPI))
4107        return true;
4108  
4109      if (unfoldGEPPhi(GEPI))
4110        return true;
4111  
4112      enqueueUsers(GEPI);
4113      return false;
4114    }
4115  
4116    bool visitPHINode(PHINode &PN) {
4117      enqueueUsers(PN);
4118      return false;
4119    }
4120  
4121    bool visitSelectInst(SelectInst &SI) {
4122      enqueueUsers(SI);
4123      return false;
4124    }
4125  };
4126  
4127  } // end anonymous namespace
4128  
4129  /// Strip aggregate type wrapping.
4130  ///
4131  /// This removes no-op aggregate types wrapping an underlying type. It will
4132  /// strip as many layers of types as it can without changing either the type
4133  /// size or the allocated size.
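///
/// For example, a wrapper such as { [1 x i64] } is stripped down to i64,
/// whereas { i32, i32 } is left alone because no single element spans the
/// whole type.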
4134  static Type *stripAggregateTypeWrapping(const DataLayout &DL, Type *Ty) {
4135    if (Ty->isSingleValueType())
4136      return Ty;
4137  
4138    uint64_t AllocSize = DL.getTypeAllocSize(Ty).getFixedValue();
4139    uint64_t TypeSize = DL.getTypeSizeInBits(Ty).getFixedValue();
4140  
4141    Type *InnerTy;
4142    if (ArrayType *ArrTy = dyn_cast<ArrayType>(Ty)) {
4143      InnerTy = ArrTy->getElementType();
4144    } else if (StructType *STy = dyn_cast<StructType>(Ty)) {
4145      const StructLayout *SL = DL.getStructLayout(STy);
4146      unsigned Index = SL->getElementContainingOffset(0);
4147      InnerTy = STy->getElementType(Index);
4148    } else {
4149      return Ty;
4150    }
4151  
4152    if (AllocSize > DL.getTypeAllocSize(InnerTy).getFixedValue() ||
4153        TypeSize > DL.getTypeSizeInBits(InnerTy).getFixedValue())
4154      return Ty;
4155  
4156    return stripAggregateTypeWrapping(DL, InnerTy);
4157  }
4158  
4159  /// Try to find a partition of the aggregate type passed in for a given
4160  /// offset and size.
4161  ///
4162  /// This recurses through the aggregate type and tries to compute a subtype
4163  /// based on the offset and size. When the offset and size span a sub-section
4164  /// of an array, it will even compute a new array type for that sub-section,
4165  /// and the same for structs.
4166  ///
4167  /// Note that this routine is very strict and tries to find a partition of the
4168  /// type which produces the *exact* right offset and size. It is not forgiving
4169  /// when the size or offset cause either end of type-based partition to be off.
4170  /// Also, this is a best-effort routine. It is reasonable to give up and not
4171  /// return a type if necessary.
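///
/// A few illustrative cases (assuming a typical data layout): for [4 x i32],
/// an (Offset, Size) of (4, 8) yields [2 x i32] and (4, 4) yields i32, while
/// (2, 4) yields nullptr because the requested range straddles two elements.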
4172  static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset,
4173                                uint64_t Size) {
4174    if (Offset == 0 && DL.getTypeAllocSize(Ty).getFixedValue() == Size)
4175      return stripAggregateTypeWrapping(DL, Ty);
4176    if (Offset > DL.getTypeAllocSize(Ty).getFixedValue() ||
4177        (DL.getTypeAllocSize(Ty).getFixedValue() - Offset) < Size)
4178      return nullptr;
4179  
4180    if (isa<ArrayType>(Ty) || isa<VectorType>(Ty)) {
4181      Type *ElementTy;
4182      uint64_t TyNumElements;
4183      if (auto *AT = dyn_cast<ArrayType>(Ty)) {
4184        ElementTy = AT->getElementType();
4185        TyNumElements = AT->getNumElements();
4186      } else {
4187        // FIXME: This isn't right for vectors with non-byte-sized or
4188        // non-power-of-two sized elements.
4189        auto *VT = cast<FixedVectorType>(Ty);
4190        ElementTy = VT->getElementType();
4191        TyNumElements = VT->getNumElements();
4192      }
4193      uint64_t ElementSize = DL.getTypeAllocSize(ElementTy).getFixedValue();
4194      uint64_t NumSkippedElements = Offset / ElementSize;
4195      if (NumSkippedElements >= TyNumElements)
4196        return nullptr;
4197      Offset -= NumSkippedElements * ElementSize;
4198  
4199      // First check if we need to recurse.
4200      if (Offset > 0 || Size < ElementSize) {
4201        // Bail if the partition ends in a different array element.
4202        if ((Offset + Size) > ElementSize)
4203          return nullptr;
4204        // Recurse through the element type trying to peel off offset bytes.
4205        return getTypePartition(DL, ElementTy, Offset, Size);
4206      }
4207      assert(Offset == 0);
4208  
4209      if (Size == ElementSize)
4210        return stripAggregateTypeWrapping(DL, ElementTy);
4211      assert(Size > ElementSize);
4212      uint64_t NumElements = Size / ElementSize;
4213      if (NumElements * ElementSize != Size)
4214        return nullptr;
4215      return ArrayType::get(ElementTy, NumElements);
4216    }
4217  
4218    StructType *STy = dyn_cast<StructType>(Ty);
4219    if (!STy)
4220      return nullptr;
4221  
4222    const StructLayout *SL = DL.getStructLayout(STy);
4223  
4224    if (SL->getSizeInBits().isScalable())
4225      return nullptr;
4226  
4227    if (Offset >= SL->getSizeInBytes())
4228      return nullptr;
4229    uint64_t EndOffset = Offset + Size;
4230    if (EndOffset > SL->getSizeInBytes())
4231      return nullptr;
4232  
4233    unsigned Index = SL->getElementContainingOffset(Offset);
4234    Offset -= SL->getElementOffset(Index);
4235  
4236    Type *ElementTy = STy->getElementType(Index);
4237    uint64_t ElementSize = DL.getTypeAllocSize(ElementTy).getFixedValue();
4238    if (Offset >= ElementSize)
4239      return nullptr; // The offset points into alignment padding.
4240  
4241    // See if any partition must be contained by the element.
4242    if (Offset > 0 || Size < ElementSize) {
4243      if ((Offset + Size) > ElementSize)
4244        return nullptr;
4245      return getTypePartition(DL, ElementTy, Offset, Size);
4246    }
4247    assert(Offset == 0);
4248  
4249    if (Size == ElementSize)
4250      return stripAggregateTypeWrapping(DL, ElementTy);
4251  
4252    StructType::element_iterator EI = STy->element_begin() + Index,
4253                                 EE = STy->element_end();
4254    if (EndOffset < SL->getSizeInBytes()) {
4255      unsigned EndIndex = SL->getElementContainingOffset(EndOffset);
4256      if (Index == EndIndex)
4257        return nullptr; // Within a single element and its padding.
4258  
4259      // Don't try to form "natural" types if the elements don't line up with the
4260      // expected size.
4261      // FIXME: We could potentially recurse down through the last element in the
4262      // sub-struct to find a natural end point.
4263      if (SL->getElementOffset(EndIndex) != EndOffset)
4264        return nullptr;
4265  
4266      assert(Index < EndIndex);
4267      EE = STy->element_begin() + EndIndex;
4268    }
4269  
4270    // Try to build up a sub-structure.
4271    StructType *SubTy =
4272        StructType::get(STy->getContext(), ArrayRef(EI, EE), STy->isPacked());
4273    const StructLayout *SubSL = DL.getStructLayout(SubTy);
4274    if (Size != SubSL->getSizeInBytes())
4275      return nullptr; // The sub-struct doesn't have quite the size needed.
4276  
4277    return SubTy;
4278  }
4279  
4280  /// Pre-split loads and stores to simplify rewriting.
4281  ///
4282  /// We want to break up the splittable load+store pairs as much as
4283  /// possible. This is important to do as a preprocessing step, as once we
4284  /// start rewriting the accesses to partitions of the alloca we lose the
4285  /// necessary information to correctly split apart paired loads and stores
4286  /// which both point into this alloca. The case to consider is something like
4287  /// the following:
4288  ///
4289  ///   %a = alloca [12 x i8]
4290  ///   %gep1 = getelementptr i8, ptr %a, i32 0
4291  ///   %gep2 = getelementptr i8, ptr %a, i32 4
4292  ///   %gep3 = getelementptr i8, ptr %a, i32 8
4293  ///   store float 0.0, ptr %gep1
4294  ///   store float 1.0, ptr %gep2
4295  ///   %v = load i64, ptr %gep1
4296  ///   store i64 %v, ptr %gep2
4297  ///   %f1 = load float, ptr %gep2
4298  ///   %f2 = load float, ptr %gep3
4299  ///
4300  /// Here we want to form 3 partitions of the alloca, each 4 bytes large, and
4301  /// promote everything so we recover the 2 SSA values that should have been
4302  /// there all along.
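///
/// Concretely (a sketch of the resulting shape; the actual code materializes
/// its own GEPs at the corresponding offsets rather than reusing %gep2 and
/// %gep3), the i64 load/store pair above is pre-split at the partition
/// boundary into two i32 halves:
///
///   %v.lo = load i32, ptr %gep1      ; bytes [0, 4) of %a
///   %v.hi = load i32, ptr %gep2      ; bytes [4, 8) of %a
///   store i32 %v.lo, ptr %gep2       ; bytes [4, 8) of %a
///   store i32 %v.hi, ptr %gep3       ; bytes [8, 12) of %a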
4303  ///
4304  /// \returns true if any changes are made.
4305  bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) {
4306    LLVM_DEBUG(dbgs() << "Pre-splitting loads and stores\n");
4307  
4308    // Track the loads and stores which are candidates for pre-splitting here, in
4309    // the order they first appear during the partition scan. These give stable
4310    // iteration order and a basis for tracking which loads and stores we
4311    // actually split.
4312    SmallVector<LoadInst *, 4> Loads;
4313    SmallVector<StoreInst *, 4> Stores;
4314  
4315    // We need to accumulate the splits required of each load or store where we
4316    // can find them via a direct lookup. This is important to cross-check loads
4317    // and stores against each other. We also track the slice so that we can kill
4318    // all the slices that end up split.
4319    struct SplitOffsets {
4320      Slice *S;
4321      std::vector<uint64_t> Splits;
4322    };
4323    SmallDenseMap<Instruction *, SplitOffsets, 8> SplitOffsetsMap;
4324  
4325    // Track loads out of this alloca which cannot, for any reason, be pre-split.
4326    // This is important as we also cannot pre-split stores of those loads!
4327    // FIXME: This is all pretty gross. It means that we can be more aggressive
4328    // in pre-splitting when the load feeding the store happens to come from
4329    // a separate alloca. Put another way, the effectiveness of SROA would be
4330    // decreased by a frontend which just concatenated all of its local allocas
4331    // into one big flat alloca. But defeating such patterns is exactly the job
4332    // SROA is tasked with! Sadly, to not have this discrepancy we would have
4333    // SROA is tasked with! Sadly, to not have this discrepancy we would have to
4334    // that feeds it *and all stores*. That makes pre-splitting much harder, but
4335    // maybe it would make it more principled?
4336    SmallPtrSet<LoadInst *, 8> UnsplittableLoads;
4337  
4338    LLVM_DEBUG(dbgs() << "  Searching for candidate loads and stores\n");
4339    for (auto &P : AS.partitions()) {
4340      for (Slice &S : P) {
4341        Instruction *I = cast<Instruction>(S.getUse()->getUser());
4342        if (!S.isSplittable() || S.endOffset() <= P.endOffset()) {
4343          // If this is a load we have to track that it can't participate in any
4344          // pre-splitting. If this is a store of a load we have to track that
4345          // that load also can't participate in any pre-splitting.
4346          if (auto *LI = dyn_cast<LoadInst>(I))
4347            UnsplittableLoads.insert(LI);
4348          else if (auto *SI = dyn_cast<StoreInst>(I))
4349            if (auto *LI = dyn_cast<LoadInst>(SI->getValueOperand()))
4350              UnsplittableLoads.insert(LI);
4351          continue;
4352        }
4353        assert(P.endOffset() > S.beginOffset() &&
4354               "Empty or backwards partition!");
4355  
4356        // Determine if this is a pre-splittable slice.
4357        if (auto *LI = dyn_cast<LoadInst>(I)) {
4358          assert(!LI->isVolatile() && "Cannot split volatile loads!");
4359  
4360          // The load must be used exclusively to store into other pointers for
4361          // us to be able to arbitrarily pre-split it. The stores must also be
4362          // simple to avoid changing semantics.
4363          auto IsLoadSimplyStored = [](LoadInst *LI) {
4364            for (User *LU : LI->users()) {
4365              auto *SI = dyn_cast<StoreInst>(LU);
4366              if (!SI || !SI->isSimple())
4367                return false;
4368            }
4369            return true;
4370          };
4371          if (!IsLoadSimplyStored(LI)) {
4372            UnsplittableLoads.insert(LI);
4373            continue;
4374          }
4375  
4376          Loads.push_back(LI);
4377        } else if (auto *SI = dyn_cast<StoreInst>(I)) {
4378          if (S.getUse() != &SI->getOperandUse(SI->getPointerOperandIndex()))
4379            // Skip stores *of* pointers. FIXME: This shouldn't even be possible!
4380            continue;
4381          auto *StoredLoad = dyn_cast<LoadInst>(SI->getValueOperand());
4382          if (!StoredLoad || !StoredLoad->isSimple())
4383            continue;
4384          assert(!SI->isVolatile() && "Cannot split volatile stores!");
4385  
4386          Stores.push_back(SI);
4387        } else {
4388          // Other uses cannot be pre-split.
4389          continue;
4390        }
4391  
4392        // Record the initial split.
4393        LLVM_DEBUG(dbgs() << "    Candidate: " << *I << "\n");
4394        auto &Offsets = SplitOffsetsMap[I];
4395        assert(Offsets.Splits.empty() &&
4396               "Should not have splits the first time we see an instruction!");
4397        Offsets.S = &S;
4398        Offsets.Splits.push_back(P.endOffset() - S.beginOffset());
4399      }
4400  
4401      // Now scan the already split slices, and add a split for any of them which
4402      // we're going to pre-split.
4403      for (Slice *S : P.splitSliceTails()) {
4404        auto SplitOffsetsMapI =
4405            SplitOffsetsMap.find(cast<Instruction>(S->getUse()->getUser()));
4406        if (SplitOffsetsMapI == SplitOffsetsMap.end())
4407          continue;
4408        auto &Offsets = SplitOffsetsMapI->second;
4409  
4410        assert(Offsets.S == S && "Found a mismatched slice!");
4411        assert(!Offsets.Splits.empty() &&
4412               "Cannot have an empty set of splits on the second partition!");
4413        assert(Offsets.Splits.back() ==
4414                   P.beginOffset() - Offsets.S->beginOffset() &&
4415               "Previous split does not end where this one begins!");
4416  
4417        // Record each split. The last partition's end isn't needed as the size
4418        // of the slice dictates that.
4419        if (S->endOffset() > P.endOffset())
4420          Offsets.Splits.push_back(P.endOffset() - Offsets.S->beginOffset());
4421      }
4422    }
4423  
4424    // We may have split loads where some of their stores are split stores. For
4425    // such loads and stores, we can only pre-split them if their splits exactly
4426    // match relative to their starting offset. We have to verify this prior to
4427    // any rewriting.
4428    llvm::erase_if(Stores, [&UnsplittableLoads, &SplitOffsetsMap](StoreInst *SI) {
4429      // Lookup the load we are storing in our map of split
4430      // offsets.
4431      auto *LI = cast<LoadInst>(SI->getValueOperand());
4432      // If it was completely unsplittable, then we're done,
4433      // and this store can't be pre-split.
4434      if (UnsplittableLoads.count(LI))
4435        return true;
4436  
4437      auto LoadOffsetsI = SplitOffsetsMap.find(LI);
4438      if (LoadOffsetsI == SplitOffsetsMap.end())
4439        return false; // Unrelated loads are definitely safe.
4440      auto &LoadOffsets = LoadOffsetsI->second;
4441  
4442      // Now lookup the store's offsets.
4443      auto &StoreOffsets = SplitOffsetsMap[SI];
4444  
4445      // If the relative offsets of each split in the load and
4446      // store match exactly, then we can split them and we
4447      // don't need to remove them here.
4448      if (LoadOffsets.Splits == StoreOffsets.Splits)
4449        return false;
4450  
4451      LLVM_DEBUG(dbgs() << "    Mismatched splits for load and store:\n"
4452                        << "      " << *LI << "\n"
4453                        << "      " << *SI << "\n");
4454  
4455      // We've found a store and load that we need to split
4456      // with mismatched relative splits. Just give up on them
4457      // and remove both instructions from our list of
4458      // candidates.
4459      UnsplittableLoads.insert(LI);
4460      return true;
4461    });
4462    // Now we have to go *back* through all the stores, because a later store may
4463    // have caused an earlier store's load to become unsplittable and if it is
4464    // unsplittable for the later store, then we can't rely on it being split in
4465    // the earlier store either.
4466    llvm::erase_if(Stores, [&UnsplittableLoads](StoreInst *SI) {
4467      auto *LI = cast<LoadInst>(SI->getValueOperand());
4468      return UnsplittableLoads.count(LI);
4469    });
4470    // Once we've established all the loads that can't be split for some reason,
4471    // filter any that made it into our list out.
4472    llvm::erase_if(Loads, [&UnsplittableLoads](LoadInst *LI) {
4473      return UnsplittableLoads.count(LI);
4474    });
4475  
4476    // If no loads or stores are left, there is no pre-splitting to be done for
4477    // this alloca.
4478    if (Loads.empty() && Stores.empty())
4479      return false;
4480  
4481    // From here on, we can't fail and will be building new accesses, so rig up
4482    // an IR builder.
4483    IRBuilderTy IRB(&AI);
4484  
4485    // Collect the new slices which we will merge into the alloca slices.
4486    SmallVector<Slice, 4> NewSlices;
4487  
4488    // Track any allocas we end up splitting loads and stores for so we iterate
4489    // on them.
4490    SmallPtrSet<AllocaInst *, 4> ResplitPromotableAllocas;
4491  
4492    // At this point, we have collected all of the loads and stores we can
4493    // pre-split, and the specific splits needed for them. We actually do the
4494    // splitting in a specific order so that we can handle the case where one of
4495    // the loads is the value operand to one of the stores.
4496    //
4497    // First, we rewrite all of the split loads, and just accumulate each split
4498    // load in a parallel structure. We also build the slices for them and append
4499    // them to the alloca slices.
4500    SmallDenseMap<LoadInst *, std::vector<LoadInst *>, 1> SplitLoadsMap;
4501    std::vector<LoadInst *> SplitLoads;
4502    const DataLayout &DL = AI.getDataLayout();
4503    for (LoadInst *LI : Loads) {
4504      SplitLoads.clear();
4505  
4506      auto &Offsets = SplitOffsetsMap[LI];
4507      unsigned SliceSize = Offsets.S->endOffset() - Offsets.S->beginOffset();
4508      assert(LI->getType()->getIntegerBitWidth() % 8 == 0 &&
4509             "Load must have type size equal to store size");
4510      assert(LI->getType()->getIntegerBitWidth() / 8 >= SliceSize &&
4511             "Load must be >= slice size");
4512  
4513      uint64_t BaseOffset = Offsets.S->beginOffset();
4514      assert(BaseOffset + SliceSize > BaseOffset &&
4515             "Cannot represent alloca access size using 64-bit integers!");
4516  
4517      Instruction *BasePtr = cast<Instruction>(LI->getPointerOperand());
4518      IRB.SetInsertPoint(LI);
4519  
4520      LLVM_DEBUG(dbgs() << "  Splitting load: " << *LI << "\n");
4521  
4522      uint64_t PartOffset = 0, PartSize = Offsets.Splits.front();
4523      int Idx = 0, Size = Offsets.Splits.size();
4524      for (;;) {
4525        auto *PartTy = Type::getIntNTy(LI->getContext(), PartSize * 8);
4526        auto AS = LI->getPointerAddressSpace();
4527        auto *PartPtrTy = LI->getPointerOperandType();
4528        LoadInst *PLoad = IRB.CreateAlignedLoad(
4529            PartTy,
4530            getAdjustedPtr(IRB, DL, BasePtr,
4531                           APInt(DL.getIndexSizeInBits(AS), PartOffset),
4532                           PartPtrTy, BasePtr->getName() + "."),
4533            getAdjustedAlignment(LI, PartOffset),
4534            /*IsVolatile*/ false, LI->getName());
4535        PLoad->copyMetadata(*LI, {LLVMContext::MD_mem_parallel_loop_access,
4536                                  LLVMContext::MD_access_group});
4537  
4538        // Append this load onto the list of split loads so we can find it later
4539        // to rewrite the stores.
4540        SplitLoads.push_back(PLoad);
4541  
4542        // Now build a new slice for the alloca.
4543        NewSlices.push_back(
4544            Slice(BaseOffset + PartOffset, BaseOffset + PartOffset + PartSize,
4545                  &PLoad->getOperandUse(PLoad->getPointerOperandIndex()),
4546                  /*IsSplittable*/ false));
4547        LLVM_DEBUG(dbgs() << "    new slice [" << NewSlices.back().beginOffset()
4548                          << ", " << NewSlices.back().endOffset()
4549                          << "): " << *PLoad << "\n");
4550  
4551        // See if we've handled all the splits.
4552        if (Idx >= Size)
4553          break;
4554  
4555        // Setup the next partition.
4556        PartOffset = Offsets.Splits[Idx];
4557        ++Idx;
4558        PartSize = (Idx < Size ? Offsets.Splits[Idx] : SliceSize) - PartOffset;
4559      }
4560  
4561      // Now that we have the split loads, do the slow walk over all uses of the
4562      // load and rewrite them as split stores, or save the split loads to use
4563      // below if the store is going to be split there anyways.
4564      bool DeferredStores = false;
4565      for (User *LU : LI->users()) {
4566        StoreInst *SI = cast<StoreInst>(LU);
4567        if (!Stores.empty() && SplitOffsetsMap.count(SI)) {
4568          DeferredStores = true;
4569          LLVM_DEBUG(dbgs() << "    Deferred splitting of store: " << *SI
4570                            << "\n");
4571          continue;
4572        }
4573  
4574        Value *StoreBasePtr = SI->getPointerOperand();
4575        IRB.SetInsertPoint(SI);
4576        AAMDNodes AATags = SI->getAAMetadata();
4577  
4578        LLVM_DEBUG(dbgs() << "    Splitting store of load: " << *SI << "\n");
4579  
4580        for (int Idx = 0, Size = SplitLoads.size(); Idx < Size; ++Idx) {
4581          LoadInst *PLoad = SplitLoads[Idx];
4582          uint64_t PartOffset = Idx == 0 ? 0 : Offsets.Splits[Idx - 1];
4583          auto *PartPtrTy = SI->getPointerOperandType();
4584  
4585          auto AS = SI->getPointerAddressSpace();
4586          StoreInst *PStore = IRB.CreateAlignedStore(
4587              PLoad,
4588              getAdjustedPtr(IRB, DL, StoreBasePtr,
4589                             APInt(DL.getIndexSizeInBits(AS), PartOffset),
4590                             PartPtrTy, StoreBasePtr->getName() + "."),
4591              getAdjustedAlignment(SI, PartOffset),
4592              /*IsVolatile*/ false);
4593          PStore->copyMetadata(*SI, {LLVMContext::MD_mem_parallel_loop_access,
4594                                     LLVMContext::MD_access_group,
4595                                     LLVMContext::MD_DIAssignID});
4596  
4597          if (AATags)
4598            PStore->setAAMetadata(
4599                AATags.adjustForAccess(PartOffset, PLoad->getType(), DL));
4600          LLVM_DEBUG(dbgs() << "      +" << PartOffset << ":" << *PStore << "\n");
4601        }
4602  
4603        // We want to immediately iterate on any allocas impacted by splitting
4604        // this store, and we have to track any promotable alloca (indicated by
4605        // a direct store) as needing to be resplit because it is no longer
4606        // promotable.
4607        if (AllocaInst *OtherAI = dyn_cast<AllocaInst>(StoreBasePtr)) {
4608          ResplitPromotableAllocas.insert(OtherAI);
4609          Worklist.insert(OtherAI);
4610        } else if (AllocaInst *OtherAI = dyn_cast<AllocaInst>(
4611                       StoreBasePtr->stripInBoundsOffsets())) {
4612          Worklist.insert(OtherAI);
4613        }
4614  
4615        // Mark the original store as dead.
4616        DeadInsts.push_back(SI);
4617      }
4618  
4619      // Save the split loads if there are deferred stores among the users.
4620      if (DeferredStores)
4621        SplitLoadsMap.insert(std::make_pair(LI, std::move(SplitLoads)));
4622  
4623      // Mark the original load as dead and kill the original slice.
4624      DeadInsts.push_back(LI);
4625      Offsets.S->kill();
4626    }
4627  
4628    // Second, we rewrite all of the split stores. At this point, we know that
4629    // all loads from this alloca have been split already. For stores of such
4630    // loads, we can simply look up the pre-existing split loads. For stores of
4631    // other loads, we split those loads first and then write split stores of
4632    // them.
4633    for (StoreInst *SI : Stores) {
4634      auto *LI = cast<LoadInst>(SI->getValueOperand());
4635      IntegerType *Ty = cast<IntegerType>(LI->getType());
4636      assert(Ty->getBitWidth() % 8 == 0);
4637      uint64_t StoreSize = Ty->getBitWidth() / 8;
4638      assert(StoreSize > 0 && "Cannot have a zero-sized integer store!");
4639  
4640      auto &Offsets = SplitOffsetsMap[SI];
4641      assert(StoreSize == Offsets.S->endOffset() - Offsets.S->beginOffset() &&
4642             "Slice size should always match load size exactly!");
4643      uint64_t BaseOffset = Offsets.S->beginOffset();
4644      assert(BaseOffset + StoreSize > BaseOffset &&
4645             "Cannot represent alloca access size using 64-bit integers!");
4646  
4647      Value *LoadBasePtr = LI->getPointerOperand();
4648      Instruction *StoreBasePtr = cast<Instruction>(SI->getPointerOperand());
4649  
4650      LLVM_DEBUG(dbgs() << "  Splitting store: " << *SI << "\n");
4651  
4652      // Check whether we have an already split load.
4653      auto SplitLoadsMapI = SplitLoadsMap.find(LI);
4654      std::vector<LoadInst *> *SplitLoads = nullptr;
4655      if (SplitLoadsMapI != SplitLoadsMap.end()) {
4656        SplitLoads = &SplitLoadsMapI->second;
4657        assert(SplitLoads->size() == Offsets.Splits.size() + 1 &&
4658               "Too few split loads for the number of splits in the store!");
4659      } else {
4660        LLVM_DEBUG(dbgs() << "          of load: " << *LI << "\n");
4661      }
4662  
4663      uint64_t PartOffset = 0, PartSize = Offsets.Splits.front();
4664      int Idx = 0, Size = Offsets.Splits.size();
4665      for (;;) {
4666        auto *PartTy = Type::getIntNTy(Ty->getContext(), PartSize * 8);
4667        auto *LoadPartPtrTy = LI->getPointerOperandType();
4668        auto *StorePartPtrTy = SI->getPointerOperandType();
4669  
4670        // Either lookup a split load or create one.
4671        LoadInst *PLoad;
4672        if (SplitLoads) {
4673          PLoad = (*SplitLoads)[Idx];
4674        } else {
4675          IRB.SetInsertPoint(LI);
4676          auto AS = LI->getPointerAddressSpace();
4677          PLoad = IRB.CreateAlignedLoad(
4678              PartTy,
4679              getAdjustedPtr(IRB, DL, LoadBasePtr,
4680                             APInt(DL.getIndexSizeInBits(AS), PartOffset),
4681                             LoadPartPtrTy, LoadBasePtr->getName() + "."),
4682              getAdjustedAlignment(LI, PartOffset),
4683              /*IsVolatile*/ false, LI->getName());
4684          PLoad->copyMetadata(*LI, {LLVMContext::MD_mem_parallel_loop_access,
4685                                    LLVMContext::MD_access_group});
4686        }
4687  
4688        // And store this partition.
4689        IRB.SetInsertPoint(SI);
4690        auto AS = SI->getPointerAddressSpace();
4691        StoreInst *PStore = IRB.CreateAlignedStore(
4692            PLoad,
4693            getAdjustedPtr(IRB, DL, StoreBasePtr,
4694                           APInt(DL.getIndexSizeInBits(AS), PartOffset),
4695                           StorePartPtrTy, StoreBasePtr->getName() + "."),
4696            getAdjustedAlignment(SI, PartOffset),
4697            /*IsVolatile*/ false);
4698        PStore->copyMetadata(*SI, {LLVMContext::MD_mem_parallel_loop_access,
4699                                   LLVMContext::MD_access_group});
4700  
4701        // Now build a new slice for the alloca.
4702        NewSlices.push_back(
4703            Slice(BaseOffset + PartOffset, BaseOffset + PartOffset + PartSize,
4704                  &PStore->getOperandUse(PStore->getPointerOperandIndex()),
4705                  /*IsSplittable*/ false));
4706        LLVM_DEBUG(dbgs() << "    new slice [" << NewSlices.back().beginOffset()
4707                          << ", " << NewSlices.back().endOffset()
4708                          << "): " << *PStore << "\n");
4709        if (!SplitLoads) {
4710          LLVM_DEBUG(dbgs() << "      of split load: " << *PLoad << "\n");
4711        }
4712  
4713        // See if we've finished all the splits.
4714        if (Idx >= Size)
4715          break;
4716  
4717        // Setup the next partition.
4718        PartOffset = Offsets.Splits[Idx];
4719        ++Idx;
4720        PartSize = (Idx < Size ? Offsets.Splits[Idx] : StoreSize) - PartOffset;
4721      }
4722  
4723      // We want to immediately iterate on any allocas impacted by splitting
4724      // this load, which is only relevant if it isn't a load of this alloca and
4725      // thus we didn't already split the loads above. We also have to keep track
4726      // of any promotable allocas we split loads on as they can no longer be
4727      // promoted.
4728      if (!SplitLoads) {
4729        if (AllocaInst *OtherAI = dyn_cast<AllocaInst>(LoadBasePtr)) {
4730          assert(OtherAI != &AI && "We can't re-split our own alloca!");
4731          ResplitPromotableAllocas.insert(OtherAI);
4732          Worklist.insert(OtherAI);
4733        } else if (AllocaInst *OtherAI = dyn_cast<AllocaInst>(
4734                       LoadBasePtr->stripInBoundsOffsets())) {
4735          assert(OtherAI != &AI && "We can't re-split our own alloca!");
4736          Worklist.insert(OtherAI);
4737        }
4738      }
4739  
4740      // Mark the original store as dead now that we've split it up and kill its
4741      // slice. Note that we leave the original load in place unless this store
4742      // was its only use. It may in turn be split up if it is an alloca load
4743      // for some other alloca, but it may be a normal load. This may introduce
4744      // redundant loads, but where those can be merged the rest of the optimizer
4745      // should handle the merging, and this uncovers SSA splits which is more
4746      // important. In practice, the original loads will almost always be fully
4747      // split and removed eventually, and the splits will be merged by any
4748      // trivial CSE, including instcombine.
4749      if (LI->hasOneUse()) {
4750        assert(*LI->user_begin() == SI && "Single use isn't this store!");
4751        DeadInsts.push_back(LI);
4752      }
4753      DeadInsts.push_back(SI);
4754      Offsets.S->kill();
4755    }
4756  
4757    // Remove the killed slices that have been pre-split.
4758    llvm::erase_if(AS, [](const Slice &S) { return S.isDead(); });
4759  
4760    // Insert our new slices. This will sort and merge them into the sorted
4761    // sequence.
4762    AS.insert(NewSlices);
4763  
4764    LLVM_DEBUG(dbgs() << "  Pre-split slices:\n");
4765  #ifndef NDEBUG
4766    for (auto I = AS.begin(), E = AS.end(); I != E; ++I)
4767      LLVM_DEBUG(AS.print(dbgs(), I, "    "));
4768  #endif
4769  
4770    // Finally, don't try to promote any allocas that now require re-splitting.
4771    // They have already been added to the worklist above.
4772    llvm::erase_if(PromotableAllocas, [&](AllocaInst *AI) {
4773      return ResplitPromotableAllocas.count(AI);
4774    });
4775  
4776    return true;
4777  }
4778  
4779  /// Rewrite an alloca partition's users.
4780  ///
4781  /// This routine drives both of the rewriting goals of the SROA pass. It tries
4782  /// to rewrite uses of an alloca partition to be conducive for SSA value
4783  /// promotion. If the partition needs a new, more refined alloca, this will
4784  /// build that new alloca, preserving as much type information as possible, and
4785  /// rewrite the uses of the old alloca to point at the new one and have the
4786  /// appropriate new offsets. It also evaluates how successful the rewrite was
4787  /// at enabling promotion and if it was successful queues the alloca to be
4788  /// promoted.
4789  AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,
4790                                     Partition &P) {
4791    // Try to compute a friendly type for this partition of the alloca. This
4792    // won't always succeed, in which case we fall back to a legal integer type
4793    // or an i8 array of an appropriate size.
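  // For instance (illustrative only): on a target where i64 is legal, an
  // 8-byte partition with no better candidate type becomes i64, whereas a
  // 12-byte one falls back to [12 x i8].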
4794    Type *SliceTy = nullptr;
4795    VectorType *SliceVecTy = nullptr;
4796    const DataLayout &DL = AI.getDataLayout();
4797    std::pair<Type *, IntegerType *> CommonUseTy =
4798        findCommonType(P.begin(), P.end(), P.endOffset());
4799    // Do all uses operate on the same type?
4800    if (CommonUseTy.first)
4801      if (DL.getTypeAllocSize(CommonUseTy.first).getFixedValue() >= P.size()) {
4802        SliceTy = CommonUseTy.first;
4803        SliceVecTy = dyn_cast<VectorType>(SliceTy);
4804      }
4805    // If not, can we find an appropriate subtype in the original allocated type?
4806    if (!SliceTy)
4807      if (Type *TypePartitionTy = getTypePartition(DL, AI.getAllocatedType(),
4808                                                   P.beginOffset(), P.size()))
4809        SliceTy = TypePartitionTy;
4810  
4811    // If still not, can we use the largest bitwidth integer type used?
4812    if (!SliceTy && CommonUseTy.second)
4813      if (DL.getTypeAllocSize(CommonUseTy.second).getFixedValue() >= P.size()) {
4814        SliceTy = CommonUseTy.second;
4815        SliceVecTy = dyn_cast<VectorType>(SliceTy);
4816      }
4817    if ((!SliceTy || (SliceTy->isArrayTy() &&
4818                      SliceTy->getArrayElementType()->isIntegerTy())) &&
4819        DL.isLegalInteger(P.size() * 8)) {
4820      SliceTy = Type::getIntNTy(*C, P.size() * 8);
4821    }
4822  
4823    // If the common use types are not viable for promotion then attempt to find
4824    // another type that is viable.
4825    if (SliceVecTy && !checkVectorTypeForPromotion(P, SliceVecTy, DL))
4826      if (Type *TypePartitionTy = getTypePartition(DL, AI.getAllocatedType(),
4827                                                   P.beginOffset(), P.size())) {
4828        VectorType *TypePartitionVecTy = dyn_cast<VectorType>(TypePartitionTy);
4829        if (TypePartitionVecTy &&
4830            checkVectorTypeForPromotion(P, TypePartitionVecTy, DL))
4831          SliceTy = TypePartitionTy;
4832      }
4833  
4834    if (!SliceTy)
4835      SliceTy = ArrayType::get(Type::getInt8Ty(*C), P.size());
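  // Illustrative summary of the fallback chain above (made-up case): for a
  // 4-byte partition whose uses all operate on i32, SliceTy is i32; failing
  // that, a matching subtype of the allocated type (e.g. the i32 field of
  // {i64, i32}); failing that, a legal integer of the partition width; and
  // only as a last resort the [4 x i8] byte array chosen here.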
4836    assert(DL.getTypeAllocSize(SliceTy).getFixedValue() >= P.size());
4837  
4838    bool IsIntegerPromotable = isIntegerWideningViable(P, SliceTy, DL);
4839  
4840    VectorType *VecTy =
4841        IsIntegerPromotable ? nullptr : isVectorPromotionViable(P, DL);
4842    if (VecTy)
4843      SliceTy = VecTy;
4844  
4845    // Check for the case where we're going to rewrite to a new alloca of the
4846    // exact same type as the original, and with the same access offsets. In that
4847    // case, re-use the existing alloca, but still run through the rewriter to
4848    // perform phi and select speculation.
4849    // P.beginOffset() can be non-zero even with the same type in a case with
4850    // out-of-bounds access (e.g. @PR35657 function in SROA/basictest.ll).
4851    AllocaInst *NewAI;
4852    if (SliceTy == AI.getAllocatedType() && P.beginOffset() == 0) {
4853      NewAI = &AI;
4854      // FIXME: We should be able to bail at this point with "nothing changed".
4855      // FIXME: We might want to defer PHI speculation until after here.
4856      // FIXME: return nullptr;
4857    } else {
4858      // Make sure the alignment is compatible with P.beginOffset().
4859      const Align Alignment = commonAlignment(AI.getAlign(), P.beginOffset());
4860      // If we will get at least this much alignment from the type alone, leave
4861      // the alloca's alignment unconstrained.
4862      const bool IsUnconstrained = Alignment <= DL.getABITypeAlign(SliceTy);
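    // Worked example (made-up numbers): an original alloca aligned to 8 with
    // a partition starting at byte offset 4 gives commonAlignment(8, 4) == 4;
    // if SliceTy's ABI alignment is already >= 4 the new alloca is left at
    // the type's preferred alignment rather than pinned to 4.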
4863      NewAI = new AllocaInst(
4864          SliceTy, AI.getAddressSpace(), nullptr,
4865          IsUnconstrained ? DL.getPrefTypeAlign(SliceTy) : Alignment,
4866          AI.getName() + ".sroa." + Twine(P.begin() - AS.begin()),
4867          AI.getIterator());
4868      // Copy the old AI debug location over to the new one.
4869      NewAI->setDebugLoc(AI.getDebugLoc());
4870      ++NumNewAllocas;
4871    }
4872  
4873    LLVM_DEBUG(dbgs() << "Rewriting alloca partition " << "[" << P.beginOffset()
4874                      << "," << P.endOffset() << ") to: " << *NewAI << "\n");
4875  
4876    // Track the high watermark on the worklist as it is only relevant for
4877    // promoted allocas. We will reset it to this point if the alloca is not in
4878    // fact scheduled for promotion.
4879    unsigned PPWOldSize = PostPromotionWorklist.size();
4880    unsigned NumUses = 0;
4881    SmallSetVector<PHINode *, 8> PHIUsers;
4882    SmallSetVector<SelectInst *, 8> SelectUsers;
4883  
4884    AllocaSliceRewriter Rewriter(DL, AS, *this, AI, *NewAI, P.beginOffset(),
4885                                 P.endOffset(), IsIntegerPromotable, VecTy,
4886                                 PHIUsers, SelectUsers);
4887    bool Promotable = true;
4888    for (Slice *S : P.splitSliceTails()) {
4889      Promotable &= Rewriter.visit(S);
4890      ++NumUses;
4891    }
4892    for (Slice &S : P) {
4893      Promotable &= Rewriter.visit(&S);
4894      ++NumUses;
4895    }
4896  
4897    NumAllocaPartitionUses += NumUses;
4898    MaxUsesPerAllocaPartition.updateMax(NumUses);
4899  
4900    // Now that we've processed all the slices in the new partition, check if any
4901    // PHIs or Selects would block promotion.
4902    for (PHINode *PHI : PHIUsers)
4903      if (!isSafePHIToSpeculate(*PHI)) {
4904        Promotable = false;
4905        PHIUsers.clear();
4906        SelectUsers.clear();
4907        break;
4908      }
4909  
4910    SmallVector<std::pair<SelectInst *, RewriteableMemOps>, 2>
4911        NewSelectsToRewrite;
4912    NewSelectsToRewrite.reserve(SelectUsers.size());
4913    for (SelectInst *Sel : SelectUsers) {
4914      std::optional<RewriteableMemOps> Ops =
4915          isSafeSelectToSpeculate(*Sel, PreserveCFG);
4916      if (!Ops) {
4917        Promotable = false;
4918        PHIUsers.clear();
4919        SelectUsers.clear();
4920        NewSelectsToRewrite.clear();
4921        break;
4922      }
4923      NewSelectsToRewrite.emplace_back(std::make_pair(Sel, *Ops));
4924    }
4925  
4926    if (Promotable) {
4927      for (Use *U : AS.getDeadUsesIfPromotable()) {
4928        auto *OldInst = dyn_cast<Instruction>(U->get());
4929        Value::dropDroppableUse(*U);
4930        if (OldInst)
4931          if (isInstructionTriviallyDead(OldInst))
4932            DeadInsts.push_back(OldInst);
4933      }
4934      if (PHIUsers.empty() && SelectUsers.empty()) {
4935        // Promote the alloca.
4936        PromotableAllocas.push_back(NewAI);
4937      } else {
4938        // If we have either PHIs or Selects to speculate, add them to those
4939      // worklists and re-queue the new alloca so that we promote it on the
4940        // next iteration.
4941        for (PHINode *PHIUser : PHIUsers)
4942          SpeculatablePHIs.insert(PHIUser);
4943        SelectsToRewrite.reserve(SelectsToRewrite.size() +
4944                                 NewSelectsToRewrite.size());
4945        for (auto &&KV : llvm::make_range(
4946                 std::make_move_iterator(NewSelectsToRewrite.begin()),
4947                 std::make_move_iterator(NewSelectsToRewrite.end())))
4948          SelectsToRewrite.insert(std::move(KV));
4949        Worklist.insert(NewAI);
4950      }
4951    } else {
4952      // Drop any post-promotion work items if promotion didn't happen.
4953      while (PostPromotionWorklist.size() > PPWOldSize)
4954        PostPromotionWorklist.pop_back();
4955  
4956      // We couldn't promote and we didn't create a new partition, nothing
4957      // happened.
4958      if (NewAI == &AI)
4959        return nullptr;
4960  
4961      // If we can't promote the alloca, iterate on it to check for new
4962      // refinements exposed by splitting the current alloca. Don't iterate on an
4963      // alloca which didn't actually change and didn't get promoted.
4964      Worklist.insert(NewAI);
4965    }
4966  
4967    return NewAI;
4968  }
4969  
4970  // There isn't a shared interface to get the "address" parts out of a
4971  // dbg.declare and dbg.assign, so provide some wrappers now for
4972  // both debug intrinsics and records.
4973  const Value *getAddress(const DbgVariableIntrinsic *DVI) {
4974    if (const auto *DAI = dyn_cast<DbgAssignIntrinsic>(DVI))
4975      return DAI->getAddress();
4976    return cast<DbgDeclareInst>(DVI)->getAddress();
4977  }
4978  
4979  const Value *getAddress(const DbgVariableRecord *DVR) {
4980    assert(DVR->getType() == DbgVariableRecord::LocationType::Declare ||
4981           DVR->getType() == DbgVariableRecord::LocationType::Assign);
4982    return DVR->getAddress();
4983  }
4984  
4985  bool isKillAddress(const DbgVariableIntrinsic *DVI) {
4986    if (const auto *DAI = dyn_cast<DbgAssignIntrinsic>(DVI))
4987      return DAI->isKillAddress();
4988    return cast<DbgDeclareInst>(DVI)->isKillLocation();
4989  }
4990  
4991  bool isKillAddress(const DbgVariableRecord *DVR) {
4992    assert(DVR->getType() == DbgVariableRecord::LocationType::Declare ||
4993           DVR->getType() == DbgVariableRecord::LocationType::Assign);
4994    if (DVR->getType() == DbgVariableRecord::LocationType::Assign)
4995      return DVR->isKillAddress();
4996    return DVR->isKillLocation();
4997  }
4998  
4999  const DIExpression *getAddressExpression(const DbgVariableIntrinsic *DVI) {
5000    if (const auto *DAI = dyn_cast<DbgAssignIntrinsic>(DVI))
5001      return DAI->getAddressExpression();
5002    return cast<DbgDeclareInst>(DVI)->getExpression();
5003  }
5004  
5005  const DIExpression *getAddressExpression(const DbgVariableRecord *DVR) {
5006    assert(DVR->getType() == DbgVariableRecord::LocationType::Declare ||
5007           DVR->getType() == DbgVariableRecord::LocationType::Assign);
5008    if (DVR->getType() == DbgVariableRecord::LocationType::Assign)
5009      return DVR->getAddressExpression();
5010    return DVR->getExpression();
5011  }
5012  
5013  /// Create or replace an existing fragment in a DIExpression with \p Frag.
5014  /// If the expression already contains a DW_OP_LLVM_extract_bits_[sz]ext
5015  /// operation, add \p BitExtractOffset to the offset part.
5016  ///
5017  /// Returns the new expression, or nullptr if this fails (see details below).
5018  ///
5019  /// This function is similar to DIExpression::createFragmentExpression except
5020  /// for 3 important distinctions:
5021  ///   1. The new fragment isn't relative to an existing fragment.
5022  ///   2. It assumes the computed location is a memory location. This means we
5023  ///      don't need to perform checks that creating the fragment preserves the
5024  ///      expression semantics.
5025  ///   3. Existing extract_bits are modified independently of fragment changes
5026  ///      using \p BitExtractOffset. A change to the fragment offset or size
5027  ///      may affect a bit extract. But a bit extract offset can change
5028  ///      independently of the fragment dimensions.
5029  ///
5030  /// Returns the new expression, or nullptr if one couldn't be created.
5031  /// Ideally this is only used to signal that a bit-extract has become
5032  /// zero-sized (and thus the new debug record has no size and can be
5033  /// dropped), however, it fails for other reasons too - see the FIXME below.
5034  ///
5035  /// FIXME: To keep the change that introduces this function NFC it bails
5036  /// in some situations unnecessarily, e.g. when fragment and bit extract
5037  /// sizes differ.
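///
/// Illustrative example (not from a test): with \p Frag = {OffsetInBits: 32,
/// SizeInBits: 32} and no existing bit extract, an input !DIExpression()
/// becomes !DIExpression(DW_OP_LLVM_fragment, 32, 32). If the input instead
/// contains a DW_OP_LLVM_extract_bits_zext op, that op keeps its size, has
/// \p BitExtractOffset folded into its offset argument, and no fragment op is
/// appended.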
5038  static DIExpression *createOrReplaceFragment(const DIExpression *Expr,
5039                                               DIExpression::FragmentInfo Frag,
5040                                               int64_t BitExtractOffset) {
5041    SmallVector<uint64_t, 8> Ops;
5042    bool HasFragment = false;
5043    bool HasBitExtract = false;
5044  
5045    for (auto &Op : Expr->expr_ops()) {
5046      if (Op.getOp() == dwarf::DW_OP_LLVM_fragment) {
5047        HasFragment = true;
5048        continue;
5049      }
5050      if (Op.getOp() == dwarf::DW_OP_LLVM_extract_bits_zext ||
5051          Op.getOp() == dwarf::DW_OP_LLVM_extract_bits_sext) {
5052        HasBitExtract = true;
5053        int64_t ExtractOffsetInBits = Op.getArg(0);
5054        int64_t ExtractSizeInBits = Op.getArg(1);
5055  
5056        // DIExpression::createFragmentExpression doesn't know how to handle
5057        // a fragment that is smaller than the extract. Copy the behaviour
5058        // (bail) to avoid non-NFC changes.
5059        // FIXME: Don't do this.
5060        if (Frag.SizeInBits < uint64_t(ExtractSizeInBits))
5061          return nullptr;
5062  
5063        assert(BitExtractOffset <= 0);
5064        int64_t AdjustedOffset = ExtractOffsetInBits + BitExtractOffset;
5065  
5066        // DIExpression::createFragmentExpression doesn't know what to do
5067        // if the new extract starts "outside" the existing one. Copy the
5068        // behaviour (bail) to avoid non-NFC changes.
5069        // FIXME: Don't do this.
5070        if (AdjustedOffset < 0)
5071          return nullptr;
5072  
5073        Ops.push_back(Op.getOp());
5074        Ops.push_back(std::max<int64_t>(0, AdjustedOffset));
5075        Ops.push_back(ExtractSizeInBits);
5076        continue;
5077      }
5078      Op.appendToVector(Ops);
5079    }
5080  
5081    // Unsupported by createFragmentExpression, so don't support it here yet to
5082    // preserve NFC-ness.
5083    if (HasFragment && HasBitExtract)
5084      return nullptr;
5085  
5086    if (!HasBitExtract) {
5087      Ops.push_back(dwarf::DW_OP_LLVM_fragment);
5088      Ops.push_back(Frag.OffsetInBits);
5089      Ops.push_back(Frag.SizeInBits);
5090    }
5091    return DIExpression::get(Expr->getContext(), Ops);
5092  }
5093  
5094  /// Insert a new dbg.declare.
5095  /// \p Orig Original to copy debug loc and variable from.
5096  /// \p NewAddr Location's new base address.
5097  /// \p NewAddrExpr New expression to apply to address.
5098  /// \p BeforeInst Insert position.
5099  /// \p NewFragment New fragment (absolute, non-relative).
5100  /// \p BitExtractAdjustment Offset to apply to any extract_bits op.
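///
/// For instance (hypothetical output, intrinsic form): splitting %x so that a
/// 32-bit slice lands in %x.sroa.0 may emit roughly
///   call void @llvm.dbg.declare(metadata ptr %x.sroa.0,
///                               metadata !DILocalVariable(name: "x", ...),
///                               metadata !DIExpression(DW_OP_LLVM_fragment, 0, 32))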
5101  static void
5102  insertNewDbgInst(DIBuilder &DIB, DbgDeclareInst *Orig, AllocaInst *NewAddr,
5103                   DIExpression *NewAddrExpr, Instruction *BeforeInst,
5104                   std::optional<DIExpression::FragmentInfo> NewFragment,
5105                   int64_t BitExtractAdjustment) {
5106    if (NewFragment)
5107      NewAddrExpr = createOrReplaceFragment(NewAddrExpr, *NewFragment,
5108                                            BitExtractAdjustment);
5109    if (!NewAddrExpr)
5110      return;
5111  
5112    DIB.insertDeclare(NewAddr, Orig->getVariable(), NewAddrExpr,
5113                      Orig->getDebugLoc(), BeforeInst);
5114  }
5115  
5116  /// Insert a new dbg.assign.
5117  /// \p Orig Original to copy debug loc, variable, value and value expression
5118  ///    from.
5119  /// \p NewAddr Location's new base address.
5120  /// \p NewAddrExpr New expression to apply to address.
5121  /// \p BeforeInst Insert position.
5122  /// \p NewFragment New fragment (absolute, non-relative).
5123  /// \p BitExtractAdjustment Offset to apply to any extract_bits op.
5124  static void
5125  insertNewDbgInst(DIBuilder &DIB, DbgAssignIntrinsic *Orig, AllocaInst *NewAddr,
5126                   DIExpression *NewAddrExpr, Instruction *BeforeInst,
5127                   std::optional<DIExpression::FragmentInfo> NewFragment,
5128                   int64_t BitExtractAdjustment) {
5129    // DIBuilder::insertDbgAssign will insert the #dbg_assign after NewAddr.
5130    (void)BeforeInst;
5131  
5132    // A dbg.assign puts fragment info in the value expression only. The address
5133    // expression has already been built: NewAddrExpr.
5134    DIExpression *NewFragmentExpr = Orig->getExpression();
5135    if (NewFragment)
5136      NewFragmentExpr = createOrReplaceFragment(NewFragmentExpr, *NewFragment,
5137                                                BitExtractAdjustment);
5138    if (!NewFragmentExpr)
5139      return;
5140  
5141    // Apply a DIAssignID to the store if it doesn't already have it.
5142    if (!NewAddr->hasMetadata(LLVMContext::MD_DIAssignID)) {
5143      NewAddr->setMetadata(LLVMContext::MD_DIAssignID,
5144                           DIAssignID::getDistinct(NewAddr->getContext()));
5145    }
5146  
5147    Instruction *NewAssign =
5148        DIB.insertDbgAssign(NewAddr, Orig->getValue(), Orig->getVariable(),
5149                            NewFragmentExpr, NewAddr, NewAddrExpr,
5150                            Orig->getDebugLoc())
5151            .get<Instruction *>();
5152    LLVM_DEBUG(dbgs() << "Created new assign intrinsic: " << *NewAssign << "\n");
5153    (void)NewAssign;
5154  }
5155  
5156  /// Insert a new DbgRecord.
5157  /// \p Orig Original to copy record type, debug loc and variable from, and
5158  ///    additionally value and value expression for dbg_assign records.
5159  /// \p NewAddr Location's new base address.
5160  /// \p NewAddrExpr New expression to apply to address.
5161  /// \p BeforeInst Insert position.
5162  /// \p NewFragment New fragment (absolute, non-relative).
5163  /// \p BitExtractAdjustment Offset to apply to any extract_bits op.
5164  static void
5165  insertNewDbgInst(DIBuilder &DIB, DbgVariableRecord *Orig, AllocaInst *NewAddr,
5166                   DIExpression *NewAddrExpr, Instruction *BeforeInst,
5167                   std::optional<DIExpression::FragmentInfo> NewFragment,
5168                   int64_t BitExtractAdjustment) {
5169    (void)DIB;
5170  
5171    // A dbg_assign puts fragment info in the value expression only. The address
5172    // expression has already been built: NewAddrExpr. A dbg_declare puts the
5173    // new fragment info into NewAddrExpr (as it only has one expression).
5174    DIExpression *NewFragmentExpr =
5175        Orig->isDbgAssign() ? Orig->getExpression() : NewAddrExpr;
5176    if (NewFragment)
5177      NewFragmentExpr = createOrReplaceFragment(NewFragmentExpr, *NewFragment,
5178                                                BitExtractAdjustment);
5179    if (!NewFragmentExpr)
5180      return;
5181  
5182    if (Orig->isDbgDeclare()) {
5183      DbgVariableRecord *DVR = DbgVariableRecord::createDVRDeclare(
5184          NewAddr, Orig->getVariable(), NewFragmentExpr, Orig->getDebugLoc());
5185      BeforeInst->getParent()->insertDbgRecordBefore(DVR,
5186                                                     BeforeInst->getIterator());
5187      return;
5188    }
5189  
5190    // Apply a DIAssignID to the store if it doesn't already have it.
5191    if (!NewAddr->hasMetadata(LLVMContext::MD_DIAssignID)) {
5192      NewAddr->setMetadata(LLVMContext::MD_DIAssignID,
5193                           DIAssignID::getDistinct(NewAddr->getContext()));
5194    }
5195  
5196    DbgVariableRecord *NewAssign = DbgVariableRecord::createLinkedDVRAssign(
5197        NewAddr, Orig->getValue(), Orig->getVariable(), NewFragmentExpr, NewAddr,
5198        NewAddrExpr, Orig->getDebugLoc());
5199    LLVM_DEBUG(dbgs() << "Created new DVRAssign: " << *NewAssign << "\n");
5200    (void)NewAssign;
5201  }
5202  
5203  /// Walks the slices of an alloca and forms partitions based on them,
5204  /// rewriting each of their uses.
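///
/// Informal example: for an alloca with unsplittable slices [0,8) and [8,16)
/// plus a splittable whole-alloca memset slice [0,16), two partitions [0,8)
/// and [8,16) are formed and rewritten, with the memset contributing a split
/// tail to each.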
5205  bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) {
5206    if (AS.begin() == AS.end())
5207      return false;
5208  
5209    unsigned NumPartitions = 0;
5210    bool Changed = false;
5211    const DataLayout &DL = AI.getModule()->getDataLayout();
5212  
5213    // First try to pre-split loads and stores.
5214    Changed |= presplitLoadsAndStores(AI, AS);
5215  
5216    // Now that we have identified any pre-splitting opportunities,
5217    // mark loads and stores unsplittable except for the following case.
5218    // We leave a slice splittable if all other slices are disjoint or fully
5219    // included in the slice, such as whole-alloca loads and stores.
5220    // If we fail to split these during pre-splitting, we want to force them
5221    // to be rewritten into a partition.
5222    bool IsSorted = true;
5223  
5224    uint64_t AllocaSize =
5225        DL.getTypeAllocSize(AI.getAllocatedType()).getFixedValue();
5226    const uint64_t MaxBitVectorSize = 1024;
5227    if (AllocaSize <= MaxBitVectorSize) {
5228      // If a byte boundary is included in any load or store, a slice starting or
5229      // ending at the boundary is not splittable.
5230      SmallBitVector SplittableOffset(AllocaSize + 1, true);
5231      for (Slice &S : AS)
5232        for (unsigned O = S.beginOffset() + 1;
5233             O < S.endOffset() && O < AllocaSize; O++)
5234          SplittableOffset.reset(O);
5235  
5236      for (Slice &S : AS) {
5237        if (!S.isSplittable())
5238          continue;
5239  
5240        if ((S.beginOffset() > AllocaSize || SplittableOffset[S.beginOffset()]) &&
5241            (S.endOffset() > AllocaSize || SplittableOffset[S.endOffset()]))
5242          continue;
5243  
5244        if (isa<LoadInst>(S.getUse()->getUser()) ||
5245            isa<StoreInst>(S.getUse()->getUser())) {
5246          S.makeUnsplittable();
5247          IsSorted = false;
5248        }
5249      }
5250    } else {
5251      // We only allow whole-alloca splittable loads and stores
5252      // for a large alloca to avoid creating an overly large BitVector.
5253      for (Slice &S : AS) {
5254        if (!S.isSplittable())
5255          continue;
5256  
5257        if (S.beginOffset() == 0 && S.endOffset() >= AllocaSize)
5258          continue;
5259  
5260        if (isa<LoadInst>(S.getUse()->getUser()) ||
5261            isa<StoreInst>(S.getUse()->getUser())) {
5262          S.makeUnsplittable();
5263          IsSorted = false;
5264        }
5265      }
5266    }
5267  
5268    if (!IsSorted)
5269      llvm::stable_sort(AS);
5270  
5271    /// Describes the allocas introduced by rewritePartition in order to migrate
5272    /// the debug info.
5273    struct Fragment {
5274      AllocaInst *Alloca;
5275      uint64_t Offset;
5276      uint64_t Size;
5277      Fragment(AllocaInst *AI, uint64_t O, uint64_t S)
5278          : Alloca(AI), Offset(O), Size(S) {}
5279    };
5280    SmallVector<Fragment, 4> Fragments;
5281  
5282    // Rewrite each partition.
5283    for (auto &P : AS.partitions()) {
5284      if (AllocaInst *NewAI = rewritePartition(AI, AS, P)) {
5285        Changed = true;
5286        if (NewAI != &AI) {
5287          uint64_t SizeOfByte = 8;
5288          uint64_t AllocaSize =
5289              DL.getTypeSizeInBits(NewAI->getAllocatedType()).getFixedValue();
5290          // Don't include any padding.
5291          uint64_t Size = std::min(AllocaSize, P.size() * SizeOfByte);
5292          Fragments.push_back(
5293              Fragment(NewAI, P.beginOffset() * SizeOfByte, Size));
5294        }
5295      }
5296      ++NumPartitions;
5297    }
5298  
5299    NumAllocaPartitions += NumPartitions;
5300    MaxPartitionsPerAlloca.updateMax(NumPartitions);
5301  
5302    // Migrate debug information from the old alloca to the new alloca(s)
5303    // and the individual partitions.
5304    auto MigrateOne = [&](auto *DbgVariable) {
5305      // Can't overlap with undef memory.
5306      if (isKillAddress(DbgVariable))
5307        return;
5308  
5309      const Value *DbgPtr = getAddress(DbgVariable);
5310      DIExpression::FragmentInfo VarFrag =
5311          DbgVariable->getFragmentOrEntireVariable();
5312      // Get the address expression constant offset if one exists and the ops
5313      // that come after it.
5314      int64_t CurrentExprOffsetInBytes = 0;
5315      SmallVector<uint64_t> PostOffsetOps;
5316      if (!getAddressExpression(DbgVariable)
5317               ->extractLeadingOffset(CurrentExprOffsetInBytes, PostOffsetOps))
5318        return; // Couldn't interpret this DIExpression - drop the var.
5319  
5320      // Offset defined by a DW_OP_LLVM_extract_bits_[sz]ext.
5321      int64_t ExtractOffsetInBits = 0;
5322      for (auto Op : getAddressExpression(DbgVariable)->expr_ops()) {
5323        if (Op.getOp() == dwarf::DW_OP_LLVM_extract_bits_zext ||
5324            Op.getOp() == dwarf::DW_OP_LLVM_extract_bits_sext) {
5325          ExtractOffsetInBits = Op.getArg(0);
5326          break;
5327        }
5328      }
5329  
5330      DIBuilder DIB(*AI.getModule(), /*AllowUnresolved*/ false);
5331      for (auto Fragment : Fragments) {
5332        int64_t OffsetFromLocationInBits;
5333        std::optional<DIExpression::FragmentInfo> NewDbgFragment;
5334        // Find the variable fragment that the new alloca slice covers.
5335        // Drop debug info for this variable fragment if we can't compute an
5336        // intersect between it and the alloca slice.
5337        if (!DIExpression::calculateFragmentIntersect(
5338                DL, &AI, Fragment.Offset, Fragment.Size, DbgPtr,
5339                CurrentExprOffsetInBytes * 8, ExtractOffsetInBits, VarFrag,
5340                NewDbgFragment, OffsetFromLocationInBits))
5341          continue; // Do not migrate this fragment to this slice.
5342  
5343        // A zero-sized fragment indicates there's no intersection between the variable
5344        // fragment and the alloca slice. Skip this slice for this variable
5345        // fragment.
5346        if (NewDbgFragment && !NewDbgFragment->SizeInBits)
5347          continue; // Do not migrate this fragment to this slice.
5348  
5349        // No fragment indicates DbgVariable's variable or fragment exactly
5350        // overlaps the slice; copy its fragment (or nullopt if there isn't one).
5351        if (!NewDbgFragment)
5352          NewDbgFragment = DbgVariable->getFragment();
5353  
5354        // Reduce the new expression offset by the bit-extract offset since
5355        // we'll be keeping that.
5356        int64_t OffestFromNewAllocaInBits =
5357            OffsetFromLocationInBits - ExtractOffsetInBits;
5358        // We need to adjust an existing bit extract if the offset expression
5359        // can't eat the slack (i.e., if the new offset would be negative).
5360        int64_t BitExtractOffset =
5361            std::min<int64_t>(0, OffestFromNewAllocaInBits);
5362        // The magnitude of a negative value indicates the number of bits into
5363        // the existing variable fragment at which the memory region begins. The new
5364        // variable fragment already excludes those bits - the new DbgPtr offset
5365        // only needs to be applied if it's positive.
5366        OffestFromNewAllocaInBits =
5367            std::max(int64_t(0), OffestFromNewAllocaInBits);
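      // Worked example with made-up numbers: if OffsetFromLocationInBits is 8
      // and ExtractOffsetInBits is 16, OffestFromNewAllocaInBits starts out
      // as -8, so BitExtractOffset becomes -8 (the existing extract's offset
      // shrinks by 8 bits) and the address offset applied below clamps to 0.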
5368  
5369        // Rebuild the expression:
5370        //    {Offset(OffestFromNewAllocaInBits), PostOffsetOps, NewDbgFragment}
5371        // Add NewDbgFragment later, because dbg.assigns don't want it in the
5372        // address expression but the value expression instead.
5373        DIExpression *NewExpr = DIExpression::get(AI.getContext(), PostOffsetOps);
5374        if (OffestFromNewAllocaInBits > 0) {
5375          int64_t OffsetInBytes = (OffestFromNewAllocaInBits + 7) / 8;
5376          NewExpr = DIExpression::prepend(NewExpr, /*flags=*/0, OffsetInBytes);
5377        }
5378  
5379        // Remove any existing intrinsics on the new alloca describing
5380        // the variable fragment.
5381        auto RemoveOne = [DbgVariable](auto *OldDII) {
5382          auto SameVariableFragment = [](const auto *LHS, const auto *RHS) {
5383            return LHS->getVariable() == RHS->getVariable() &&
5384                   LHS->getDebugLoc()->getInlinedAt() ==
5385                       RHS->getDebugLoc()->getInlinedAt();
5386          };
5387          if (SameVariableFragment(OldDII, DbgVariable))
5388            OldDII->eraseFromParent();
5389        };
5390        for_each(findDbgDeclares(Fragment.Alloca), RemoveOne);
5391        for_each(findDVRDeclares(Fragment.Alloca), RemoveOne);
5392  
5393        insertNewDbgInst(DIB, DbgVariable, Fragment.Alloca, NewExpr, &AI,
5394                         NewDbgFragment, BitExtractOffset);
5395      }
5396    };
5397  
5398    // Migrate debug information from the old alloca to the new alloca(s)
5399    // and the individual partitions.
5400    for_each(findDbgDeclares(&AI), MigrateOne);
5401    for_each(findDVRDeclares(&AI), MigrateOne);
5402    for_each(at::getAssignmentMarkers(&AI), MigrateOne);
5403    for_each(at::getDVRAssignmentMarkers(&AI), MigrateOne);
5404  
5405    return Changed;
5406  }
5407  
5408  /// Clobber a use with poison, deleting the used value if it becomes dead.
5409  void SROA::clobberUse(Use &U) {
5410    Value *OldV = U;
5411    // Replace the use with a poison value.
5412    U = PoisonValue::get(OldV->getType());
5413  
5414    // Check for this making an instruction dead. We have to garbage collect
5415    // all the dead instructions to ensure the uses of any alloca end up being
5416    // minimal.
5417    if (Instruction *OldI = dyn_cast<Instruction>(OldV))
5418      if (isInstructionTriviallyDead(OldI)) {
5419        DeadInsts.push_back(OldI);
5420      }
5421  }
5422  
5423  /// Analyze an alloca for SROA.
5424  ///
5425  /// This analyzes the alloca to ensure we can reason about it, builds
5426  /// the slices of the alloca, and then hands it off to be split and
5427  /// rewritten as needed.
5428  std::pair<bool /*Changed*/, bool /*CFGChanged*/>
5429  SROA::runOnAlloca(AllocaInst &AI) {
5430    bool Changed = false;
5431    bool CFGChanged = false;
5432  
5433    LLVM_DEBUG(dbgs() << "SROA alloca: " << AI << "\n");
5434    ++NumAllocasAnalyzed;
5435  
5436    // Special case dead allocas, as they're trivial.
5437    if (AI.use_empty()) {
5438      AI.eraseFromParent();
5439      Changed = true;
5440      return {Changed, CFGChanged};
5441    }
5442    const DataLayout &DL = AI.getDataLayout();
5443  
5444    // Skip alloca forms that this analysis can't handle.
5445    auto *AT = AI.getAllocatedType();
5446    TypeSize Size = DL.getTypeAllocSize(AT);
5447    if (AI.isArrayAllocation() || !AT->isSized() || Size.isScalable() ||
5448        Size.getFixedValue() == 0)
5449      return {Changed, CFGChanged};
5450  
5451    // First, split any FCA loads and stores touching this alloca to promote
5452    // better splitting and promotion opportunities.
5453    IRBuilderTy IRB(&AI);
5454    AggLoadStoreRewriter AggRewriter(DL, IRB);
5455    Changed |= AggRewriter.rewrite(AI);
5456  
5457    // Build the slices using a recursive instruction-visiting builder.
5458    AllocaSlices AS(DL, AI);
5459    LLVM_DEBUG(AS.print(dbgs()));
5460    if (AS.isEscaped())
5461      return {Changed, CFGChanged};
5462  
5463    // Delete all the dead users of this alloca before splitting and rewriting it.
5464    for (Instruction *DeadUser : AS.getDeadUsers()) {
5465      // Free up everything used by this instruction.
5466      for (Use &DeadOp : DeadUser->operands())
5467        clobberUse(DeadOp);
5468  
5469      // Now replace the uses of this instruction.
5470      DeadUser->replaceAllUsesWith(PoisonValue::get(DeadUser->getType()));
5471  
5472      // And mark it for deletion.
5473      DeadInsts.push_back(DeadUser);
5474      Changed = true;
5475    }
5476    for (Use *DeadOp : AS.getDeadOperands()) {
5477      clobberUse(*DeadOp);
5478      Changed = true;
5479    }
5480  
5481    // No slices to split. Leave the dead alloca for a later pass to clean up.
5482    if (AS.begin() == AS.end())
5483      return {Changed, CFGChanged};
5484  
5485    Changed |= splitAlloca(AI, AS);
5486  
5487    LLVM_DEBUG(dbgs() << "  Speculating PHIs\n");
5488    while (!SpeculatablePHIs.empty())
5489      speculatePHINodeLoads(IRB, *SpeculatablePHIs.pop_back_val());
5490  
5491    LLVM_DEBUG(dbgs() << "  Rewriting Selects\n");
5492    auto RemainingSelectsToRewrite = SelectsToRewrite.takeVector();
5493    while (!RemainingSelectsToRewrite.empty()) {
5494      const auto [K, V] = RemainingSelectsToRewrite.pop_back_val();
5495      CFGChanged |=
5496          rewriteSelectInstMemOps(*K, V, IRB, PreserveCFG ? nullptr : DTU);
5497    }
5498  
5499    return {Changed, CFGChanged};
5500  }
5501  
5502  /// Delete the dead instructions accumulated in this run.
5503  ///
5504  /// Recursively deletes the dead instructions we've accumulated. This is done
5505  /// at the very end to maximize locality of the recursive delete and to
5506  /// minimize the problems of invalidated instruction pointers as such pointers
5507  /// are used heavily in the intermediate stages of the algorithm.
5508  ///
5509  /// We also record the alloca instructions deleted here so that they aren't
5510  /// subsequently handed to mem2reg to promote.
5511  bool SROA::deleteDeadInstructions(
5512      SmallPtrSetImpl<AllocaInst *> &DeletedAllocas) {
5513    bool Changed = false;
5514    while (!DeadInsts.empty()) {
5515      Instruction *I = dyn_cast_or_null<Instruction>(DeadInsts.pop_back_val());
5516      if (!I)
5517        continue;
5518      LLVM_DEBUG(dbgs() << "Deleting dead instruction: " << *I << "\n");
5519  
5520      // If the instruction is an alloca, find the possible dbg.declare connected
5521      // to it, and remove it too. We must do this before calling RAUW or we will
5522      // not be able to find it.
5523      if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) {
5524        DeletedAllocas.insert(AI);
5525        for (DbgDeclareInst *OldDII : findDbgDeclares(AI))
5526          OldDII->eraseFromParent();
5527        for (DbgVariableRecord *OldDII : findDVRDeclares(AI))
5528          OldDII->eraseFromParent();
5529      }
5530  
5531      at::deleteAssignmentMarkers(I);
5532      I->replaceAllUsesWith(UndefValue::get(I->getType()));
5533  
5534      for (Use &Operand : I->operands())
5535        if (Instruction *U = dyn_cast<Instruction>(Operand)) {
5536          // Zero out the operand and see if it becomes trivially dead.
5537          Operand = nullptr;
5538          if (isInstructionTriviallyDead(U))
5539            DeadInsts.push_back(U);
5540        }
5541  
5542      ++NumDeleted;
5543      I->eraseFromParent();
5544      Changed = true;
5545    }
5546    return Changed;
5547  }
5548  
5549  /// Promote the allocas, using the best available technique.
5550  ///
5551  /// This attempts to promote whatever allocas have been identified as viable in
5552  /// the PromotableAllocas list. If that list is empty, there is nothing to do.
5553  /// This function returns whether any promotion occurred.
5554  bool SROA::promoteAllocas(Function &F) {
5555    if (PromotableAllocas.empty())
5556      return false;
5557  
5558    NumPromoted += PromotableAllocas.size();
5559  
5560    if (SROASkipMem2Reg) {
5561      LLVM_DEBUG(dbgs() << "Not promoting allocas with mem2reg!\n");
5562    } else {
5563      LLVM_DEBUG(dbgs() << "Promoting allocas with mem2reg...\n");
5564      PromoteMemToReg(PromotableAllocas, DTU->getDomTree(), AC);
5565    }
5566  
5567    PromotableAllocas.clear();
5568    return true;
5569  }
5570  
5571  std::pair<bool /*Changed*/, bool /*CFGChanged*/> SROA::runSROA(Function &F) {
5572    LLVM_DEBUG(dbgs() << "SROA function: " << F.getName() << "\n");
5573  
5574    const DataLayout &DL = F.getDataLayout();
5575    BasicBlock &EntryBB = F.getEntryBlock();
5576    for (BasicBlock::iterator I = EntryBB.begin(), E = std::prev(EntryBB.end());
5577         I != E; ++I) {
5578      if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) {
5579        if (DL.getTypeAllocSize(AI->getAllocatedType()).isScalable() &&
5580            isAllocaPromotable(AI))
5581          PromotableAllocas.push_back(AI);
5582        else
5583          Worklist.insert(AI);
5584      }
5585    }
5586  
5587    bool Changed = false;
5588    bool CFGChanged = false;
5589    // A set of deleted alloca instruction pointers which should be removed from
5590    // the list of promotable allocas.
5591    SmallPtrSet<AllocaInst *, 4> DeletedAllocas;
5592  
5593    do {
5594      while (!Worklist.empty()) {
5595        auto [IterationChanged, IterationCFGChanged] =
5596            runOnAlloca(*Worklist.pop_back_val());
5597        Changed |= IterationChanged;
5598        CFGChanged |= IterationCFGChanged;
5599  
5600        Changed |= deleteDeadInstructions(DeletedAllocas);
5601  
5602        // Remove the deleted allocas from various lists so that we don't try to
5603        // continue processing them.
5604        if (!DeletedAllocas.empty()) {
5605          auto IsInSet = [&](AllocaInst *AI) { return DeletedAllocas.count(AI); };
5606          Worklist.remove_if(IsInSet);
5607          PostPromotionWorklist.remove_if(IsInSet);
5608          llvm::erase_if(PromotableAllocas, IsInSet);
5609          DeletedAllocas.clear();
5610        }
5611      }
5612  
5613      Changed |= promoteAllocas(F);
5614  
5615      Worklist = PostPromotionWorklist;
5616      PostPromotionWorklist.clear();
5617    } while (!Worklist.empty());
5618  
5619    assert((!CFGChanged || Changed) && "Can not only modify the CFG.");
5620    assert((!CFGChanged || !PreserveCFG) &&
5621           "Should not have modified the CFG when told to preserve it.");
5622  
5623    if (Changed && isAssignmentTrackingEnabled(*F.getParent())) {
5624      for (auto &BB : F) {
5625        RemoveRedundantDbgInstrs(&BB);
5626      }
5627    }
5628  
5629    return {Changed, CFGChanged};
5630  }
5631  
5632  PreservedAnalyses SROAPass::run(Function &F, FunctionAnalysisManager &AM) {
5633    DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
5634    AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
5635    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
5636    auto [Changed, CFGChanged] =
5637        SROA(&F.getContext(), &DTU, &AC, PreserveCFG).runSROA(F);
5638    if (!Changed)
5639      return PreservedAnalyses::all();
5640    PreservedAnalyses PA;
5641    if (!CFGChanged)
5642      PA.preserveSet<CFGAnalyses>();
5643    PA.preserve<DominatorTreeAnalysis>();
5644    return PA;
5645  }
5646  
5647  void SROAPass::printPipeline(
5648      raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
5649    static_cast<PassInfoMixin<SROAPass> *>(this)->printPipeline(
5650        OS, MapClassName2PassName);
5651    OS << (PreserveCFG == SROAOptions::PreserveCFG ? "<preserve-cfg>"
5652                                                   : "<modify-cfg>");
5653  }
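// Usage sketch (not part of the pass; assumes the standard opt driver): the
// textual forms printed above can be selected on the command line, e.g.
//   opt -passes='sroa<preserve-cfg>' -S in.ll
//   opt -passes='sroa<modify-cfg>' -S in.ll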
5654  
5655  SROAPass::SROAPass(SROAOptions PreserveCFG) : PreserveCFG(PreserveCFG) {}
5656  
5657  namespace {
5658  
5659  /// A legacy pass for the legacy pass manager that wraps the \c SROA pass.
5660  class SROALegacyPass : public FunctionPass {
5661    SROAOptions PreserveCFG;
5662  
5663  public:
5664    static char ID;
5665  
5666    SROALegacyPass(SROAOptions PreserveCFG = SROAOptions::PreserveCFG)
5667        : FunctionPass(ID), PreserveCFG(PreserveCFG) {
5668      initializeSROALegacyPassPass(*PassRegistry::getPassRegistry());
5669    }
5670  
5671    bool runOnFunction(Function &F) override {
5672      if (skipFunction(F))
5673        return false;
5674  
5675      DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
5676      AssumptionCache &AC =
5677          getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
5678      DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
5679      auto [Changed, _] =
5680          SROA(&F.getContext(), &DTU, &AC, PreserveCFG).runSROA(F);
5681      return Changed;
5682    }
5683  
5684    void getAnalysisUsage(AnalysisUsage &AU) const override {
5685      AU.addRequired<AssumptionCacheTracker>();
5686      AU.addRequired<DominatorTreeWrapperPass>();
5687      AU.addPreserved<GlobalsAAWrapperPass>();
5688      AU.addPreserved<DominatorTreeWrapperPass>();
5689    }
5690  
5691    StringRef getPassName() const override { return "SROA"; }
5692  };
5693  
5694  } // end anonymous namespace
5695  
5696  char SROALegacyPass::ID = 0;
5697  
5698  FunctionPass *llvm::createSROAPass(bool PreserveCFG) {
5699    return new SROALegacyPass(PreserveCFG ? SROAOptions::PreserveCFG
5700                                          : SROAOptions::ModifyCFG);
5701  }
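// Usage sketch for the legacy pass manager (illustrative only; assumes the
// usual legacy::PassManager setup):
//   legacy::PassManager PM;
//   PM.add(createSROAPass(/*PreserveCFG=*/true));
//   PM.run(M);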
5702  
5703  INITIALIZE_PASS_BEGIN(SROALegacyPass, "sroa",
5704                        "Scalar Replacement Of Aggregates", false, false)
5705  INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
5706  INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
5707  INITIALIZE_PASS_END(SROALegacyPass, "sroa", "Scalar Replacement Of Aggregates",
5708                      false, false)
5709