//===- Scalarizer.cpp - Scalarize vector operations -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass converts vector operations into scalar operations (or, optionally,
// operations on smaller vector widths), in order to expose optimization
// opportunities on the individual scalar operations.
// It is mainly intended for targets that do not have vector units, but it
// may also be useful for revectorizing code to different vector widths.
//
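// For example (an illustrative sketch; value names are arbitrary), with full
// scalarization a vector add such as
//
//   %sum = add <2 x i32> %a, %b
//
// is rewritten along the lines of
//
//   %a.i0 = extractelement <2 x i32> %a, i64 0
//   %b.i0 = extractelement <2 x i32> %b, i64 0
//   %sum.i0 = add i32 %a.i0, %b.i0
//   %a.i1 = extractelement <2 x i32> %a, i64 1
//   %b.i1 = extractelement <2 x i32> %b, i64 1
//   %sum.i1 = add i32 %a.i1, %b.i1
//
// with the original vector %sum rebuilt from the scalar results only if a
// vector-typed use remains after the pass.
//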
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/Scalarizer.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
#include <cassert>
#include <cstdint>
#include <iterator>
#include <map>
#include <utility>

using namespace llvm;

#define DEBUG_TYPE "scalarizer"

static cl::opt<bool> ClScalarizeVariableInsertExtract(
    "scalarize-variable-insert-extract", cl::init(true), cl::Hidden,
    cl::desc("Allow the scalarizer pass to scalarize "
             "insertelement/extractelement with variable index"));

// This is disabled by default because having separate loads and stores
// makes it more likely that the -combiner-alias-analysis limits will be
// reached.
static cl::opt<bool> ClScalarizeLoadStore(
    "scalarize-load-store", cl::init(false), cl::Hidden,
    cl::desc("Allow the scalarizer pass to scalarize loads and stores"));

// Split vectors larger than this size into fragments, where each fragment is
// either a vector no larger than this size or a scalar.
//
// Instructions with operands or results of different sizes that would be split
// into a different number of fragments are currently left as-is.
static cl::opt<unsigned> ClScalarizeMinBits(
    "scalarize-min-bits", cl::init(0), cl::Hidden,
    cl::desc("Instruct the scalarizer pass to attempt to keep values at "
             "least this many bits wide"));

namespace {

BasicBlock::iterator skipPastPhiNodesAndDbg(BasicBlock::iterator Itr) {
  BasicBlock *BB = Itr->getParent();
  if (isa<PHINode>(Itr))
    Itr = BB->getFirstInsertionPt();
  if (Itr != BB->end())
    Itr = skipDebugIntrinsics(Itr);
  return Itr;
}

// Used to store the scattered form of a vector.
using ValueVector = SmallVector<Value *, 8>;

// Used to map a vector Value and associated type to its scattered form.
// The associated type is only non-null for pointer values that are "scattered"
// when used as pointer operands to load or store.
//
// We use std::map because we want iterators to persist across insertion and
// because the values are relatively large.
using ScatterMap = std::map<std::pair<Value *, Type *>, ValueVector>;

// Lists Instructions that have been replaced with scalar implementations,
// along with a pointer to their scattered forms.
using GatherList = SmallVector<std::pair<Instruction *, ValueVector *>, 16>;

struct VectorSplit {
  // The type of the vector.
  FixedVectorType *VecTy = nullptr;

  // The number of elements packed in a fragment (other than the remainder).
  unsigned NumPacked = 0;

  // The number of fragments (scalars or smaller vectors) into which the vector
  // shall be split.
  unsigned NumFragments = 0;

  // The type of each complete fragment.
  Type *SplitTy = nullptr;

  // The type of the remainder (last) fragment; null if all fragments are
  // complete.
  Type *RemainderTy = nullptr;

  Type *getFragmentType(unsigned I) const {
    return RemainderTy && I == NumFragments - 1 ? RemainderTy : SplitTy;
  }
};
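
// For example (a sketch, assuming -scalarize-min-bits=32): a <9 x i8> vector
// is described by NumPacked = 4, NumFragments = 3, SplitTy = <4 x i8> and
// RemainderTy = i8, i.e. two <4 x i8> fragments followed by a single scalar
// remainder element, and getFragmentType(2) returns the i8 remainder type.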

// Provides a very limited vector-like interface for lazily accessing one
// component of a scattered vector or vector pointer.
class Scatterer {
public:
  Scatterer() = default;

  // Scatter V into the fragments described by VS.  If new instructions are
  // needed, insert them before BBI in BB.  If CachePtr is nonnull, use it to
  // cache the results.
  Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
            const VectorSplit &VS, ValueVector *cachePtr = nullptr);

  // Return component I, creating a new Value for it if necessary.
  Value *operator[](unsigned I);

  // Return the number of components.
  unsigned size() const { return VS.NumFragments; }

private:
  BasicBlock *BB;
  BasicBlock::iterator BBI;
  Value *V;
  VectorSplit VS;
  bool IsPointer;
  ValueVector *CachePtr;
  ValueVector Tmp;
};

// FCmpSplitter(FCI)(Builder, X, Y, Name) uses Builder to create an FCmp
// called Name that compares X and Y in the same way as FCI.
struct FCmpSplitter {
  FCmpSplitter(FCmpInst &fci) : FCI(fci) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
                    const Twine &Name) const {
    return Builder.CreateFCmp(FCI.getPredicate(), Op0, Op1, Name);
  }

  FCmpInst &FCI;
};

// ICmpSplitter(ICI)(Builder, X, Y, Name) uses Builder to create an ICmp
// called Name that compares X and Y in the same way as ICI.
struct ICmpSplitter {
  ICmpSplitter(ICmpInst &ici) : ICI(ici) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
                    const Twine &Name) const {
    return Builder.CreateICmp(ICI.getPredicate(), Op0, Op1, Name);
  }

  ICmpInst &ICI;
};

// UnarySplitter(UO)(Builder, X, Name) uses Builder to create
// a unary operator like UO called Name with operand X.
struct UnarySplitter {
  UnarySplitter(UnaryOperator &uo) : UO(uo) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op, const Twine &Name) const {
    return Builder.CreateUnOp(UO.getOpcode(), Op, Name);
  }

  UnaryOperator &UO;
};

// BinarySplitter(BO)(Builder, X, Y, Name) uses Builder to create
// a binary operator like BO called Name with operands X and Y.
struct BinarySplitter {
  BinarySplitter(BinaryOperator &bo) : BO(bo) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
                    const Twine &Name) const {
    return Builder.CreateBinOp(BO.getOpcode(), Op0, Op1, Name);
  }

  BinaryOperator &BO;
};

// Information about a load or store that we're scalarizing.
struct VectorLayout {
  VectorLayout() = default;

  // Return the alignment of fragment Frag.
  Align getFragmentAlign(unsigned Frag) {
    return commonAlignment(VecAlign, Frag * SplitSize);
  }

  // The split of the underlying vector type.
  VectorSplit VS;

  // The alignment of the vector.
  Align VecAlign;

  // The size of each (non-remainder) fragment in bytes.
  uint64_t SplitSize = 0;
};
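
// For example (a sketch): for a <4 x i32> access with VecAlign = 16 and
// SplitSize = 4, getFragmentAlign returns 16, 4, 8 and 4 for fragments
// 0 through 3, i.e. the minimum of the vector alignment and the largest
// power of two dividing the fragment's byte offset.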

/// Concatenate the given fragments into a single vector value of the type
/// described in @p VS.
static Value *concatenate(IRBuilder<> &Builder, ArrayRef<Value *> Fragments,
                          const VectorSplit &VS, Twine Name) {
  unsigned NumElements = VS.VecTy->getNumElements();
  SmallVector<int> ExtendMask;
  SmallVector<int> InsertMask;

  if (VS.NumPacked > 1) {
    // Prepare the shufflevector masks once and re-use them for all
    // fragments.
    ExtendMask.resize(NumElements, -1);
    for (unsigned I = 0; I < VS.NumPacked; ++I)
      ExtendMask[I] = I;

    InsertMask.resize(NumElements);
    for (unsigned I = 0; I < NumElements; ++I)
      InsertMask[I] = I;
  }

  Value *Res = PoisonValue::get(VS.VecTy);
  for (unsigned I = 0; I < VS.NumFragments; ++I) {
    Value *Fragment = Fragments[I];

    unsigned NumPacked = VS.NumPacked;
    if (I == VS.NumFragments - 1 && VS.RemainderTy) {
      if (auto *RemVecTy = dyn_cast<FixedVectorType>(VS.RemainderTy))
        NumPacked = RemVecTy->getNumElements();
      else
        NumPacked = 1;
    }

    if (NumPacked == 1) {
      Res = Builder.CreateInsertElement(Res, Fragment, I * VS.NumPacked,
                                        Name + ".upto" + Twine(I));
    } else {
      Fragment = Builder.CreateShuffleVector(Fragment, Fragment, ExtendMask);
      if (I == 0) {
        Res = Fragment;
      } else {
        for (unsigned J = 0; J < NumPacked; ++J)
          InsertMask[I * VS.NumPacked + J] = NumElements + J;
        Res = Builder.CreateShuffleVector(Res, Fragment, InsertMask,
                                          Name + ".upto" + Twine(I));
        for (unsigned J = 0; J < NumPacked; ++J)
          InsertMask[I * VS.NumPacked + J] = I * VS.NumPacked + J;
      }
    }
  }

  return Res;
}
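
// For example (an illustrative sketch; value names are arbitrary):
// concatenating two <2 x float> fragments %f0 and %f1 into a <4 x float>
// named %v emits roughly
//
//   %0 = shufflevector <2 x float> %f0, <2 x float> %f0,
//        <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
//   %1 = shufflevector <2 x float> %f1, <2 x float> %f1,
//        <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
//   %v.upto1 = shufflevector <4 x float> %0, <4 x float> %1,
//        <4 x i32> <i32 0, i32 1, i32 4, i32 5>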

template <typename T>
T getWithDefaultOverride(const cl::opt<T> &ClOption,
                         const std::optional<T> &DefaultOverride) {
  return ClOption.getNumOccurrences() ? ClOption
                                      : DefaultOverride.value_or(ClOption);
}

class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
public:
  ScalarizerVisitor(unsigned ParallelLoopAccessMDKind, DominatorTree *DT,
                    ScalarizerPassOptions Options)
      : ParallelLoopAccessMDKind(ParallelLoopAccessMDKind), DT(DT),
        ScalarizeVariableInsertExtract(
            getWithDefaultOverride(ClScalarizeVariableInsertExtract,
                                   Options.ScalarizeVariableInsertExtract)),
        ScalarizeLoadStore(getWithDefaultOverride(ClScalarizeLoadStore,
                                                  Options.ScalarizeLoadStore)),
        ScalarizeMinBits(getWithDefaultOverride(ClScalarizeMinBits,
                                                Options.ScalarizeMinBits)) {}

  bool visit(Function &F);

  // InstVisitor methods.  They return true if the instruction was scalarized,
  // false if nothing changed.
  bool visitInstruction(Instruction &I) { return false; }
  bool visitSelectInst(SelectInst &SI);
  bool visitICmpInst(ICmpInst &ICI);
  bool visitFCmpInst(FCmpInst &FCI);
  bool visitUnaryOperator(UnaryOperator &UO);
  bool visitBinaryOperator(BinaryOperator &BO);
  bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
  bool visitCastInst(CastInst &CI);
  bool visitBitCastInst(BitCastInst &BCI);
  bool visitInsertElementInst(InsertElementInst &IEI);
  bool visitExtractElementInst(ExtractElementInst &EEI);
  bool visitShuffleVectorInst(ShuffleVectorInst &SVI);
  bool visitPHINode(PHINode &PHI);
  bool visitLoadInst(LoadInst &LI);
  bool visitStoreInst(StoreInst &SI);
  bool visitCallInst(CallInst &CI);
  bool visitFreezeInst(FreezeInst &FI);

private:
  Scatterer scatter(Instruction *Point, Value *V, const VectorSplit &VS);
  void gather(Instruction *Op, const ValueVector &CV, const VectorSplit &VS);
  void replaceUses(Instruction *Op, Value *CV);
  bool canTransferMetadata(unsigned Kind);
  void transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV);
  std::optional<VectorSplit> getVectorSplit(Type *Ty);
  std::optional<VectorLayout> getVectorLayout(Type *Ty, Align Alignment,
                                              const DataLayout &DL);
  bool finish();

  template<typename T> bool splitUnary(Instruction &, const T &);
  template<typename T> bool splitBinary(Instruction &, const T &);

  bool splitCall(CallInst &CI);

  ScatterMap Scattered;
  GatherList Gathered;
  bool Scalarized;

  SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;

  unsigned ParallelLoopAccessMDKind;

  DominatorTree *DT;

  const bool ScalarizeVariableInsertExtract;
  const bool ScalarizeLoadStore;
  const unsigned ScalarizeMinBits;
};

class ScalarizerLegacyPass : public FunctionPass {
public:
  static char ID;

  ScalarizerLegacyPass() : FunctionPass(ID) {
    initializeScalarizerLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
  }
};

} // end anonymous namespace

char ScalarizerLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(ScalarizerLegacyPass, "scalarizer",
                      "Scalarize vector operations", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ScalarizerLegacyPass, "scalarizer",
                    "Scalarize vector operations", false, false)

Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
                     const VectorSplit &VS, ValueVector *cachePtr)
    : BB(bb), BBI(bbi), V(v), VS(VS), CachePtr(cachePtr) {
  IsPointer = V->getType()->isPointerTy();
  if (!CachePtr) {
    Tmp.resize(VS.NumFragments, nullptr);
  } else {
    assert((CachePtr->empty() || VS.NumFragments == CachePtr->size() ||
            IsPointer) &&
           "Inconsistent vector sizes");
    if (VS.NumFragments > CachePtr->size())
      CachePtr->resize(VS.NumFragments, nullptr);
  }
}

// Return fragment Frag, creating a new Value for it if necessary.
Value *Scatterer::operator[](unsigned Frag) {
  ValueVector &CV = CachePtr ? *CachePtr : Tmp;
  // Try to reuse a previous value.
  if (CV[Frag])
    return CV[Frag];
  IRBuilder<> Builder(BB, BBI);
  if (IsPointer) {
    if (Frag == 0)
      CV[Frag] = V;
    else
      CV[Frag] = Builder.CreateConstGEP1_32(VS.SplitTy, V, Frag,
                                            V->getName() + ".i" + Twine(Frag));
    return CV[Frag];
  }

  Type *FragmentTy = VS.getFragmentType(Frag);

  if (auto *VecTy = dyn_cast<FixedVectorType>(FragmentTy)) {
    SmallVector<int> Mask;
    for (unsigned J = 0; J < VecTy->getNumElements(); ++J)
      Mask.push_back(Frag * VS.NumPacked + J);
    CV[Frag] =
        Builder.CreateShuffleVector(V, PoisonValue::get(V->getType()), Mask,
                                    V->getName() + ".i" + Twine(Frag));
  } else {
    // Search through a chain of InsertElementInsts looking for element Frag.
    // Record other elements in the cache.  The new V is still suitable
    // for all uncached indices.
    while (true) {
      InsertElementInst *Insert = dyn_cast<InsertElementInst>(V);
      if (!Insert)
        break;
      ConstantInt *Idx = dyn_cast<ConstantInt>(Insert->getOperand(2));
      if (!Idx)
        break;
      unsigned J = Idx->getZExtValue();
      V = Insert->getOperand(0);
      if (Frag * VS.NumPacked == J) {
        CV[Frag] = Insert->getOperand(1);
        return CV[Frag];
      }

      if (VS.NumPacked == 1 && !CV[J]) {
        // Only cache the first entry we find for each index we're not actively
        // searching for. This prevents us from going too far up the chain and
        // caching incorrect entries.
        CV[J] = Insert->getOperand(1);
      }
    }
    CV[Frag] = Builder.CreateExtractElement(V, Frag * VS.NumPacked,
                                            V->getName() + ".i" + Twine(Frag));
  }

  return CV[Frag];
}
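
// For example (a sketch, with full scalarization): if V is defined by
//   %v1 = insertelement <2 x i32> %v0, i32 %x, i64 1
// then requesting fragment 1 returns %x directly, without creating an
// extractelement, and a request for fragment 0 continues the search
// through %v0.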

bool ScalarizerLegacyPass::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  Module &M = *F.getParent();
  unsigned ParallelLoopAccessMDKind =
      M.getContext().getMDKindID("llvm.mem.parallel_loop_access");
  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT, ScalarizerPassOptions());
  return Impl.visit(F);
}

FunctionPass *llvm::createScalarizerPass() {
  return new ScalarizerLegacyPass();
}

bool ScalarizerVisitor::visit(Function &F) {
  assert(Gathered.empty() && Scattered.empty());

  Scalarized = false;

  // To ensure we replace gathered components correctly we need to do an ordered
  // traversal of the basic blocks in the function.
  ReversePostOrderTraversal<BasicBlock *> RPOT(&F.getEntryBlock());
  for (BasicBlock *BB : RPOT) {
    for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
      Instruction *I = &*II;
      bool Done = InstVisitor::visit(I);
      ++II;
      if (Done && I->getType()->isVoidTy())
        I->eraseFromParent();
    }
  }
  return finish();
}

// Return a scattered form of V that can be accessed by Point.  V must be a
// vector or a pointer to a vector.
Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V,
                                     const VectorSplit &VS) {
  if (Argument *VArg = dyn_cast<Argument>(V)) {
    // Put the scattered form of arguments in the entry block,
    // so that it can be used everywhere.
    Function *F = VArg->getParent();
    BasicBlock *BB = &F->getEntryBlock();
    return Scatterer(BB, BB->begin(), V, VS, &Scattered[{V, VS.SplitTy}]);
  }
  if (Instruction *VOp = dyn_cast<Instruction>(V)) {
    // When scalarizing PHI nodes we might try to examine/rewrite InsertElement
    // nodes in predecessors. If those predecessors are unreachable from entry,
    // then the IR in those blocks could have unexpected properties resulting in
    // infinite loops in Scatterer::operator[]. By simply treating values
    // originating from instructions in unreachable blocks as undef we do not
    // need to analyse them further.
    if (!DT->isReachableFromEntry(VOp->getParent()))
      return Scatterer(Point->getParent(), Point->getIterator(),
                       PoisonValue::get(V->getType()), VS);
    // Put the scattered form of an instruction directly after the
    // instruction, skipping over PHI nodes and debug intrinsics.
    BasicBlock *BB = VOp->getParent();
    return Scatterer(
        BB, skipPastPhiNodesAndDbg(std::next(BasicBlock::iterator(VOp))), V, VS,
        &Scattered[{V, VS.SplitTy}]);
  }
  // In the fallback case, just put the scattered form before Point and
  // keep the result local to Point.
  return Scatterer(Point->getParent(), Point->getIterator(), V, VS);
}

// Replace Op with the gathered form of the components in CV.  Defer the
// deletion of Op and creation of the gathered form to the end of the pass,
// so that we can avoid creating the gathered form if all uses of Op are
// replaced with uses of CV.
void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV,
                               const VectorSplit &VS) {
  transferMetadataAndIRFlags(Op, CV);

  // If we already have a scattered form of Op (created from ExtractElements
  // of Op itself), replace them with the new form.
  ValueVector &SV = Scattered[{Op, VS.SplitTy}];
  if (!SV.empty()) {
    for (unsigned I = 0, E = SV.size(); I != E; ++I) {
      Value *V = SV[I];
      if (V == nullptr || SV[I] == CV[I])
        continue;

      Instruction *Old = cast<Instruction>(V);
      if (isa<Instruction>(CV[I]))
        CV[I]->takeName(Old);
      Old->replaceAllUsesWith(CV[I]);
      PotentiallyDeadInstrs.emplace_back(Old);
    }
  }
  SV = CV;
  Gathered.push_back(GatherList::value_type(Op, &SV));
}

// Replace Op with CV and collect Op as a potentially dead instruction.
void ScalarizerVisitor::replaceUses(Instruction *Op, Value *CV) {
  if (CV != Op) {
    Op->replaceAllUsesWith(CV);
    PotentiallyDeadInstrs.emplace_back(Op);
    Scalarized = true;
  }
}

// Return true if it is safe to transfer the given metadata tag from
// vector to scalar instructions.
bool ScalarizerVisitor::canTransferMetadata(unsigned Tag) {
  return (Tag == LLVMContext::MD_tbaa
          || Tag == LLVMContext::MD_fpmath
          || Tag == LLVMContext::MD_tbaa_struct
          || Tag == LLVMContext::MD_invariant_load
          || Tag == LLVMContext::MD_alias_scope
          || Tag == LLVMContext::MD_noalias
          || Tag == ParallelLoopAccessMDKind
          || Tag == LLVMContext::MD_access_group);
}

// Transfer metadata from Op to the instructions in CV if it is known
// to be safe to do so.
void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op,
                                                   const ValueVector &CV) {
  SmallVector<std::pair<unsigned, MDNode *>, 4> MDs;
  Op->getAllMetadataOtherThanDebugLoc(MDs);
  for (unsigned I = 0, E = CV.size(); I != E; ++I) {
    if (Instruction *New = dyn_cast<Instruction>(CV[I])) {
      for (const auto &MD : MDs)
        if (canTransferMetadata(MD.first))
          New->setMetadata(MD.first, MD.second);
      New->copyIRFlags(Op);
      if (Op->getDebugLoc() && !New->getDebugLoc())
        New->setDebugLoc(Op->getDebugLoc());
    }
  }
}

// Determine how Ty is split, if at all.
std::optional<VectorSplit> ScalarizerVisitor::getVectorSplit(Type *Ty) {
  VectorSplit Split;
  Split.VecTy = dyn_cast<FixedVectorType>(Ty);
  if (!Split.VecTy)
    return {};

  unsigned NumElems = Split.VecTy->getNumElements();
  Type *ElemTy = Split.VecTy->getElementType();

  if (NumElems == 1 || ElemTy->isPointerTy() ||
      2 * ElemTy->getScalarSizeInBits() > ScalarizeMinBits) {
    Split.NumPacked = 1;
    Split.NumFragments = NumElems;
    Split.SplitTy = ElemTy;
  } else {
    Split.NumPacked = ScalarizeMinBits / ElemTy->getScalarSizeInBits();
    if (Split.NumPacked >= NumElems)
      return {};

    Split.NumFragments = divideCeil(NumElems, Split.NumPacked);
    Split.SplitTy = FixedVectorType::get(ElemTy, Split.NumPacked);

    unsigned RemainderElems = NumElems % Split.NumPacked;
    if (RemainderElems > 1)
      Split.RemainderTy = FixedVectorType::get(ElemTy, RemainderElems);
    else if (RemainderElems == 1)
      Split.RemainderTy = ElemTy;
  }

  return Split;
}
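
// For example (a sketch): with -scalarize-min-bits=32, a <4 x i32> is fully
// scalarized (two i32 elements already exceed 32 bits), an <8 x i8> splits
// into two <4 x i8> fragments, and a <4 x i8> is not split at all, since a
// single fragment would cover the whole vector.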

// Determine the layout of the vector type Ty, returning std::nullopt on
// failure.  Alignment is the alignment of the vector.
std::optional<VectorLayout>
ScalarizerVisitor::getVectorLayout(Type *Ty, Align Alignment,
                                   const DataLayout &DL) {
  std::optional<VectorSplit> VS = getVectorSplit(Ty);
  if (!VS)
    return {};

  VectorLayout Layout;
  Layout.VS = *VS;
  // Check that we're dealing with full-byte fragments.
  if (!DL.typeSizeEqualsStoreSize(VS->SplitTy) ||
      (VS->RemainderTy && !DL.typeSizeEqualsStoreSize(VS->RemainderTy)))
    return {};
  Layout.VecAlign = Alignment;
  Layout.SplitSize = DL.getTypeStoreSize(VS->SplitTy);
  return Layout;
}

// Scalarize one-operand instruction I, using Split(Builder, X, Name)
// to create an instruction like I with operand X and name Name.
template<typename Splitter>
bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) {
  std::optional<VectorSplit> VS = getVectorSplit(I.getType());
  if (!VS)
    return false;

  std::optional<VectorSplit> OpVS;
  if (I.getOperand(0)->getType() == I.getType()) {
    OpVS = VS;
  } else {
    OpVS = getVectorSplit(I.getOperand(0)->getType());
    if (!OpVS || VS->NumPacked != OpVS->NumPacked)
      return false;
  }

  IRBuilder<> Builder(&I);
  Scatterer Op = scatter(&I, I.getOperand(0), *OpVS);
  assert(Op.size() == VS->NumFragments && "Mismatched unary operation");
  ValueVector Res;
  Res.resize(VS->NumFragments);
  for (unsigned Frag = 0; Frag < VS->NumFragments; ++Frag)
    Res[Frag] = Split(Builder, Op[Frag], I.getName() + ".i" + Twine(Frag));
  gather(&I, Res, *VS);
  return true;
}

// Scalarize two-operand instruction I, using Split(Builder, X, Y, Name)
// to create an instruction like I with operands X and Y and name Name.
template<typename Splitter>
bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
  std::optional<VectorSplit> VS = getVectorSplit(I.getType());
  if (!VS)
    return false;

  std::optional<VectorSplit> OpVS;
  if (I.getOperand(0)->getType() == I.getType()) {
    OpVS = VS;
  } else {
    OpVS = getVectorSplit(I.getOperand(0)->getType());
    if (!OpVS || VS->NumPacked != OpVS->NumPacked)
      return false;
  }

  IRBuilder<> Builder(&I);
  Scatterer VOp0 = scatter(&I, I.getOperand(0), *OpVS);
  Scatterer VOp1 = scatter(&I, I.getOperand(1), *OpVS);
  assert(VOp0.size() == VS->NumFragments && "Mismatched binary operation");
  assert(VOp1.size() == VS->NumFragments && "Mismatched binary operation");
  ValueVector Res;
  Res.resize(VS->NumFragments);
  for (unsigned Frag = 0; Frag < VS->NumFragments; ++Frag) {
    Value *Op0 = VOp0[Frag];
    Value *Op1 = VOp1[Frag];
    Res[Frag] = Split(Builder, Op0, Op1, I.getName() + ".i" + Twine(Frag));
  }
  gather(&I, Res, *VS);
  return true;
}
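
// For example (an illustrative sketch, assuming -scalarize-min-bits=32):
//   %r = add <8 x i8> %a, %b
// is split into two adds on <4 x i8> fragments,
//   %r.i0 = add <4 x i8> %a.i0, %b.i0
//   %r.i1 = add <4 x i8> %a.i1, %b.i1
// where %a.i0 and friends are produced by shufflevectors and the full %r
// is reassembled by concatenate() only if a vector use survives.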

static bool isTriviallyScalarizable(Intrinsic::ID ID) {
  return isTriviallyVectorizable(ID);
}

/// If CI is a call to a vector-typed intrinsic function, split it into a
/// scalar call per fragment when the intrinsic supports it.
bool ScalarizerVisitor::splitCall(CallInst &CI) {
  std::optional<VectorSplit> VS = getVectorSplit(CI.getType());
  if (!VS)
    return false;

  Function *F = CI.getCalledFunction();
  if (!F)
    return false;

  Intrinsic::ID ID = F->getIntrinsicID();
  if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID))
    return false;

  unsigned NumArgs = CI.arg_size();

  ValueVector ScalarOperands(NumArgs);
  SmallVector<Scatterer, 8> Scattered(NumArgs);
  SmallVector<int> OverloadIdx(NumArgs, -1);

  SmallVector<llvm::Type *, 3> Tys;
  // Add return type if intrinsic is overloaded on it.
  if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
    Tys.push_back(VS->SplitTy);

  // Assumes that any vector operand type has the same number of elements as
  // the return vector type, which is true for all current intrinsics.
  for (unsigned I = 0; I != NumArgs; ++I) {
    Value *OpI = CI.getOperand(I);
    if (auto *OpVecTy = dyn_cast<FixedVectorType>(OpI->getType())) {
      assert(OpVecTy->getNumElements() == VS->VecTy->getNumElements());
      std::optional<VectorSplit> OpVS = getVectorSplit(OpI->getType());
      if (!OpVS || OpVS->NumPacked != VS->NumPacked) {
        // The natural split of the operand doesn't match the result. This could
        // happen if the vector elements are different and the ScalarizeMinBits
        // option is used.
        //
        // We could in principle handle this case as well, at the cost of
        // complicating the scattering machinery to support multiple scattering
        // granularities for a single value.
        return false;
      }

      Scattered[I] = scatter(&CI, OpI, *OpVS);
      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) {
        OverloadIdx[I] = Tys.size();
        Tys.push_back(OpVS->SplitTy);
      }
    } else {
      ScalarOperands[I] = OpI;
      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
        Tys.push_back(OpI->getType());
    }
  }

  ValueVector Res(VS->NumFragments);
  ValueVector ScalarCallOps(NumArgs);

  Function *NewIntrin = Intrinsic::getDeclaration(F->getParent(), ID, Tys);
  IRBuilder<> Builder(&CI);

  // Perform actual scalarization, taking care to preserve any scalar operands.
  for (unsigned I = 0; I < VS->NumFragments; ++I) {
    bool IsRemainder = I == VS->NumFragments - 1 && VS->RemainderTy;
    ScalarCallOps.clear();

    if (IsRemainder)
      Tys[0] = VS->RemainderTy;

    for (unsigned J = 0; J != NumArgs; ++J) {
      if (isVectorIntrinsicWithScalarOpAtArg(ID, J)) {
        ScalarCallOps.push_back(ScalarOperands[J]);
      } else {
        ScalarCallOps.push_back(Scattered[J][I]);
        if (IsRemainder && OverloadIdx[J] >= 0)
          Tys[OverloadIdx[J]] = Scattered[J][I]->getType();
      }
    }

    if (IsRemainder)
      NewIntrin = Intrinsic::getDeclaration(F->getParent(), ID, Tys);

    Res[I] = Builder.CreateCall(NewIntrin, ScalarCallOps,
                                CI.getName() + ".i" + Twine(I));
  }

  gather(&CI, Res, *VS);
  return true;
}
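
// For example (an illustrative sketch, full scalarization):
//   %r = call <2 x float> @llvm.fabs.v2f32(<2 x float> %x)
// becomes two scalar intrinsic calls,
//   %r.i0 = call float @llvm.fabs.f32(float %x.i0)
//   %r.i1 = call float @llvm.fabs.f32(float %x.i1)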

bool ScalarizerVisitor::visitSelectInst(SelectInst &SI) {
  std::optional<VectorSplit> VS = getVectorSplit(SI.getType());
  if (!VS)
    return false;

  std::optional<VectorSplit> CondVS;
  if (isa<FixedVectorType>(SI.getCondition()->getType())) {
    CondVS = getVectorSplit(SI.getCondition()->getType());
    if (!CondVS || CondVS->NumPacked != VS->NumPacked) {
      // This happens when ScalarizeMinBits is used.
      return false;
    }
  }

  IRBuilder<> Builder(&SI);
  Scatterer VOp1 = scatter(&SI, SI.getOperand(1), *VS);
  Scatterer VOp2 = scatter(&SI, SI.getOperand(2), *VS);
  assert(VOp1.size() == VS->NumFragments && "Mismatched select");
  assert(VOp2.size() == VS->NumFragments && "Mismatched select");
  ValueVector Res;
  Res.resize(VS->NumFragments);

  if (CondVS) {
    Scatterer VOp0 = scatter(&SI, SI.getOperand(0), *CondVS);
    assert(VOp0.size() == CondVS->NumFragments && "Mismatched select");
    for (unsigned I = 0; I < VS->NumFragments; ++I) {
      Value *Op0 = VOp0[I];
      Value *Op1 = VOp1[I];
      Value *Op2 = VOp2[I];
      Res[I] = Builder.CreateSelect(Op0, Op1, Op2,
                                    SI.getName() + ".i" + Twine(I));
    }
  } else {
    Value *Op0 = SI.getOperand(0);
    for (unsigned I = 0; I < VS->NumFragments; ++I) {
      Value *Op1 = VOp1[I];
      Value *Op2 = VOp2[I];
      Res[I] = Builder.CreateSelect(Op0, Op1, Op2,
                                    SI.getName() + ".i" + Twine(I));
    }
  }
  gather(&SI, Res, *VS);
  return true;
}

bool ScalarizerVisitor::visitICmpInst(ICmpInst &ICI) {
  return splitBinary(ICI, ICmpSplitter(ICI));
}

bool ScalarizerVisitor::visitFCmpInst(FCmpInst &FCI) {
  return splitBinary(FCI, FCmpSplitter(FCI));
}

bool ScalarizerVisitor::visitUnaryOperator(UnaryOperator &UO) {
  return splitUnary(UO, UnarySplitter(UO));
}

bool ScalarizerVisitor::visitBinaryOperator(BinaryOperator &BO) {
  return splitBinary(BO, BinarySplitter(BO));
}

bool ScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
  std::optional<VectorSplit> VS = getVectorSplit(GEPI.getType());
  if (!VS)
    return false;

  IRBuilder<> Builder(&GEPI);
  unsigned NumIndices = GEPI.getNumIndices();

  // The base pointer and indices might be scalar even if it's a vector GEP.
  SmallVector<Value *, 8> ScalarOps{1 + NumIndices};
  SmallVector<Scatterer, 8> ScatterOps{1 + NumIndices};

  for (unsigned I = 0; I < 1 + NumIndices; ++I) {
    if (auto *VecTy =
            dyn_cast<FixedVectorType>(GEPI.getOperand(I)->getType())) {
      std::optional<VectorSplit> OpVS = getVectorSplit(VecTy);
      if (!OpVS || OpVS->NumPacked != VS->NumPacked) {
        // This can happen when ScalarizeMinBits is used.
        return false;
      }
      ScatterOps[I] = scatter(&GEPI, GEPI.getOperand(I), *OpVS);
    } else {
      ScalarOps[I] = GEPI.getOperand(I);
    }
  }

  ValueVector Res;
  Res.resize(VS->NumFragments);
  for (unsigned I = 0; I < VS->NumFragments; ++I) {
    SmallVector<Value *, 8> SplitOps;
    SplitOps.resize(1 + NumIndices);
    for (unsigned J = 0; J < 1 + NumIndices; ++J) {
      if (ScalarOps[J])
        SplitOps[J] = ScalarOps[J];
      else
        SplitOps[J] = ScatterOps[J][I];
    }
    Res[I] = Builder.CreateGEP(GEPI.getSourceElementType(), SplitOps[0],
                               ArrayRef(SplitOps).drop_front(),
                               GEPI.getName() + ".i" + Twine(I));
    if (GEPI.isInBounds())
      if (GetElementPtrInst *NewGEPI = dyn_cast<GetElementPtrInst>(Res[I]))
        NewGEPI->setIsInBounds();
  }
  gather(&GEPI, Res, *VS);
  return true;
}

bool ScalarizerVisitor::visitCastInst(CastInst &CI) {
  std::optional<VectorSplit> DestVS = getVectorSplit(CI.getDestTy());
  if (!DestVS)
    return false;

  std::optional<VectorSplit> SrcVS = getVectorSplit(CI.getSrcTy());
  if (!SrcVS || SrcVS->NumPacked != DestVS->NumPacked)
    return false;

  IRBuilder<> Builder(&CI);
  Scatterer Op0 = scatter(&CI, CI.getOperand(0), *SrcVS);
  assert(Op0.size() == SrcVS->NumFragments && "Mismatched cast");
  ValueVector Res;
  Res.resize(DestVS->NumFragments);
  for (unsigned I = 0; I < DestVS->NumFragments; ++I)
    Res[I] =
        Builder.CreateCast(CI.getOpcode(), Op0[I], DestVS->getFragmentType(I),
                           CI.getName() + ".i" + Twine(I));
  gather(&CI, Res, *DestVS);
  return true;
}

bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {
  std::optional<VectorSplit> DstVS = getVectorSplit(BCI.getDestTy());
  std::optional<VectorSplit> SrcVS = getVectorSplit(BCI.getSrcTy());
  if (!DstVS || !SrcVS || DstVS->RemainderTy || SrcVS->RemainderTy)
    return false;

  const bool isPointerTy = DstVS->VecTy->getElementType()->isPointerTy();

  // Vectors of pointers are always fully scalarized.
  assert(!isPointerTy || (DstVS->NumPacked == 1 && SrcVS->NumPacked == 1));

  IRBuilder<> Builder(&BCI);
  Scatterer Op0 = scatter(&BCI, BCI.getOperand(0), *SrcVS);
  ValueVector Res;
  Res.resize(DstVS->NumFragments);

  unsigned DstSplitBits = DstVS->SplitTy->getPrimitiveSizeInBits();
  unsigned SrcSplitBits = SrcVS->SplitTy->getPrimitiveSizeInBits();

  if (isPointerTy || DstSplitBits == SrcSplitBits) {
    assert(DstVS->NumFragments == SrcVS->NumFragments);
    for (unsigned I = 0; I < DstVS->NumFragments; ++I) {
      Res[I] = Builder.CreateBitCast(Op0[I], DstVS->getFragmentType(I),
                                     BCI.getName() + ".i" + Twine(I));
    }
  } else if (SrcSplitBits % DstSplitBits == 0) {
    // Convert each source fragment to the same-sized destination vector and
    // then scatter the result to the destination.
    VectorSplit MidVS;
    MidVS.NumPacked = DstVS->NumPacked;
    MidVS.NumFragments = SrcSplitBits / DstSplitBits;
    MidVS.VecTy = FixedVectorType::get(DstVS->VecTy->getElementType(),
                                       MidVS.NumPacked * MidVS.NumFragments);
    MidVS.SplitTy = DstVS->SplitTy;

    unsigned ResI = 0;
    for (unsigned I = 0; I < SrcVS->NumFragments; ++I) {
      Value *V = Op0[I];

      // Look through any existing bitcasts before converting to <N x t2>.
      // In the best case, the resulting conversion might be a no-op.
      Instruction *VI;
      while ((VI = dyn_cast<Instruction>(V)) &&
             VI->getOpcode() == Instruction::BitCast)
        V = VI->getOperand(0);

      V = Builder.CreateBitCast(V, MidVS.VecTy, V->getName() + ".cast");

      Scatterer Mid = scatter(&BCI, V, MidVS);
      for (unsigned J = 0; J < MidVS.NumFragments; ++J)
        Res[ResI++] = Mid[J];
    }
  } else if (DstSplitBits % SrcSplitBits == 0) {
    // Gather enough source fragments to make up a destination fragment and
    // then convert to the destination type.
    VectorSplit MidVS;
    MidVS.NumFragments = DstSplitBits / SrcSplitBits;
    MidVS.NumPacked = SrcVS->NumPacked;
    MidVS.VecTy = FixedVectorType::get(SrcVS->VecTy->getElementType(),
                                       MidVS.NumPacked * MidVS.NumFragments);
    MidVS.SplitTy = SrcVS->SplitTy;

    unsigned SrcI = 0;
    SmallVector<Value *, 8> ConcatOps;
    ConcatOps.resize(MidVS.NumFragments);
    for (unsigned I = 0; I < DstVS->NumFragments; ++I) {
      for (unsigned J = 0; J < MidVS.NumFragments; ++J)
        ConcatOps[J] = Op0[SrcI++];
      Value *V = concatenate(Builder, ConcatOps, MidVS,
                             BCI.getName() + ".i" + Twine(I));
      Res[I] = Builder.CreateBitCast(V, DstVS->getFragmentType(I),
                                     BCI.getName() + ".i" + Twine(I));
    }
  } else {
    return false;
  }

  gather(&BCI, Res, *DstVS);
  return true;
}
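
// For example (a sketch, assuming -scalarize-min-bits=32):
//   %r = bitcast <4 x i16> %x to <2 x i32>
// has SrcVS->SplitTy = <2 x i16> and DstVS->SplitTy = i32, so the
// equal-size path above emits one bitcast per fragment:
//   %r.i0 = bitcast <2 x i16> %x.i0 to i32
//   %r.i1 = bitcast <2 x i16> %x.i1 to i32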

bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
  std::optional<VectorSplit> VS = getVectorSplit(IEI.getType());
  if (!VS)
    return false;

  IRBuilder<> Builder(&IEI);
  Scatterer Op0 = scatter(&IEI, IEI.getOperand(0), *VS);
  Value *NewElt = IEI.getOperand(1);
  Value *InsIdx = IEI.getOperand(2);

  ValueVector Res;
  Res.resize(VS->NumFragments);

  if (auto *CI = dyn_cast<ConstantInt>(InsIdx)) {
    unsigned Idx = CI->getZExtValue();
    unsigned Fragment = Idx / VS->NumPacked;
    for (unsigned I = 0; I < VS->NumFragments; ++I) {
      if (I == Fragment) {
        bool IsPacked = VS->NumPacked > 1;
        if (Fragment == VS->NumFragments - 1 && VS->RemainderTy &&
            !VS->RemainderTy->isVectorTy())
          IsPacked = false;
        if (IsPacked) {
          Res[I] =
              Builder.CreateInsertElement(Op0[I], NewElt, Idx % VS->NumPacked);
        } else {
          Res[I] = NewElt;
        }
      } else {
        Res[I] = Op0[I];
      }
    }
  } else {
    // Never split a variable insertelement that isn't fully scalarized.
    if (!ScalarizeVariableInsertExtract || VS->NumPacked > 1)
      return false;

    for (unsigned I = 0; I < VS->NumFragments; ++I) {
      Value *ShouldReplace =
          Builder.CreateICmpEQ(InsIdx, ConstantInt::get(InsIdx->getType(), I),
                               InsIdx->getName() + ".is." + Twine(I));
      Value *OldElt = Op0[I];
      Res[I] = Builder.CreateSelect(ShouldReplace, NewElt, OldElt,
                                    IEI.getName() + ".i" + Twine(I));
    }
  }

  gather(&IEI, Res, *VS);
  return true;
}
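
// For example (an illustrative sketch, full scalarization): inserting %val
// at a variable index %idx into <2 x i32> %v is lowered to one compare and
// select per element:
//   %idx.is.0 = icmp eq i32 %idx, 0
//   %r.i0 = select i1 %idx.is.0, i32 %val, i32 %v.i0
//   %idx.is.1 = icmp eq i32 %idx, 1
//   %r.i1 = select i1 %idx.is.1, i32 %val, i32 %v.i1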

bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
  std::optional<VectorSplit> VS = getVectorSplit(EEI.getOperand(0)->getType());
  if (!VS)
    return false;

  IRBuilder<> Builder(&EEI);
  Scatterer Op0 = scatter(&EEI, EEI.getOperand(0), *VS);
  Value *ExtIdx = EEI.getOperand(1);

  if (auto *CI = dyn_cast<ConstantInt>(ExtIdx)) {
    unsigned Idx = CI->getZExtValue();
    unsigned Fragment = Idx / VS->NumPacked;
    Value *Res = Op0[Fragment];
    bool IsPacked = VS->NumPacked > 1;
    if (Fragment == VS->NumFragments - 1 && VS->RemainderTy &&
        !VS->RemainderTy->isVectorTy())
      IsPacked = false;
    if (IsPacked)
      Res = Builder.CreateExtractElement(Res, Idx % VS->NumPacked);
    replaceUses(&EEI, Res);
    return true;
  }

  // Never split a variable extractelement that isn't fully scalarized.
  if (!ScalarizeVariableInsertExtract || VS->NumPacked > 1)
    return false;

  Value *Res = PoisonValue::get(VS->VecTy->getElementType());
  for (unsigned I = 0; I < VS->NumFragments; ++I) {
    Value *ShouldExtract =
        Builder.CreateICmpEQ(ExtIdx, ConstantInt::get(ExtIdx->getType(), I),
                             ExtIdx->getName() + ".is." + Twine(I));
    Value *Elt = Op0[I];
    Res = Builder.CreateSelect(ShouldExtract, Elt, Res,
                               EEI.getName() + ".upto" + Twine(I));
  }
  replaceUses(&EEI, Res);
  return true;
}

bool ScalarizerVisitor::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
  std::optional<VectorSplit> VS = getVectorSplit(SVI.getType());
  std::optional<VectorSplit> VSOp =
      getVectorSplit(SVI.getOperand(0)->getType());
  if (!VS || !VSOp || VS->NumPacked > 1 || VSOp->NumPacked > 1)
    return false;

  Scatterer Op0 = scatter(&SVI, SVI.getOperand(0), *VSOp);
  Scatterer Op1 = scatter(&SVI, SVI.getOperand(1), *VSOp);
  ValueVector Res;
  Res.resize(VS->NumFragments);

  for (unsigned I = 0; I < VS->NumFragments; ++I) {
    int Selector = SVI.getMaskValue(I);
    if (Selector < 0)
      Res[I] = PoisonValue::get(VS->VecTy->getElementType());
    else if (unsigned(Selector) < Op0.size())
      Res[I] = Op0[Selector];
    else
      Res[I] = Op1[Selector - Op0.size()];
  }
  gather(&SVI, Res, *VS);
  return true;
}

bool ScalarizerVisitor::visitPHINode(PHINode &PHI) {
  std::optional<VectorSplit> VS = getVectorSplit(PHI.getType());
  if (!VS)
    return false;

  IRBuilder<> Builder(&PHI);
  ValueVector Res;
  Res.resize(VS->NumFragments);

  unsigned NumOps = PHI.getNumOperands();
  for (unsigned I = 0; I < VS->NumFragments; ++I) {
    Res[I] = Builder.CreatePHI(VS->getFragmentType(I), NumOps,
                               PHI.getName() + ".i" + Twine(I));
  }

  for (unsigned I = 0; I < NumOps; ++I) {
    Scatterer Op = scatter(&PHI, PHI.getIncomingValue(I), *VS);
    BasicBlock *IncomingBlock = PHI.getIncomingBlock(I);
    for (unsigned J = 0; J < VS->NumFragments; ++J)
      cast<PHINode>(Res[J])->addIncoming(Op[J], IncomingBlock);
  }
  gather(&PHI, Res, *VS);
  return true;
}

bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) {
  if (!ScalarizeLoadStore)
    return false;
  if (!LI.isSimple())
    return false;

  std::optional<VectorLayout> Layout = getVectorLayout(
      LI.getType(), LI.getAlign(), LI.getModule()->getDataLayout());
  if (!Layout)
    return false;

  IRBuilder<> Builder(&LI);
  Scatterer Ptr = scatter(&LI, LI.getPointerOperand(), Layout->VS);
  ValueVector Res;
  Res.resize(Layout->VS.NumFragments);

  for (unsigned I = 0; I < Layout->VS.NumFragments; ++I) {
    Res[I] = Builder.CreateAlignedLoad(Layout->VS.getFragmentType(I), Ptr[I],
                                       Align(Layout->getFragmentAlign(I)),
                                       LI.getName() + ".i" + Twine(I));
  }
  gather(&LI, Res, Layout->VS);
  return true;
}
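
// For example (an illustrative sketch, assuming -scalarize-load-store and
// full scalarization):
//   %v = load <2 x i32>, ptr %p, align 8
// becomes
//   %p.i1 = getelementptr i32, ptr %p, i32 1
//   %v.i0 = load i32, ptr %p, align 8
//   %v.i1 = load i32, ptr %p.i1, align 4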

bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) {
  if (!ScalarizeLoadStore)
    return false;
  if (!SI.isSimple())
    return false;

  Value *FullValue = SI.getValueOperand();
  std::optional<VectorLayout> Layout = getVectorLayout(
      FullValue->getType(), SI.getAlign(), SI.getModule()->getDataLayout());
  if (!Layout)
    return false;

  IRBuilder<> Builder(&SI);
  Scatterer VPtr = scatter(&SI, SI.getPointerOperand(), Layout->VS);
  Scatterer VVal = scatter(&SI, FullValue, Layout->VS);

  ValueVector Stores;
  Stores.resize(Layout->VS.NumFragments);
  for (unsigned I = 0; I < Layout->VS.NumFragments; ++I) {
    Value *Val = VVal[I];
    Value *Ptr = VPtr[I];
    Stores[I] =
        Builder.CreateAlignedStore(Val, Ptr, Layout->getFragmentAlign(I));
  }
  transferMetadataAndIRFlags(&SI, Stores);
  return true;
}

bool ScalarizerVisitor::visitCallInst(CallInst &CI) {
  return splitCall(CI);
}

bool ScalarizerVisitor::visitFreezeInst(FreezeInst &FI) {
  return splitUnary(FI, [](IRBuilder<> &Builder, Value *Op, const Twine &Name) {
    return Builder.CreateFreeze(Op, Name);
  });
}

// Delete the instructions that we scalarized.  If a full vector result
// is still needed, recreate it using insertelements and/or shufflevectors.
bool ScalarizerVisitor::finish() {
  // The presence of data in Gathered or Scattered indicates changes
  // made to the Function.
  if (Gathered.empty() && Scattered.empty() && !Scalarized)
    return false;
  for (const auto &GMI : Gathered) {
    Instruction *Op = GMI.first;
    ValueVector &CV = *GMI.second;
    if (!Op->use_empty()) {
      // The value is still needed, so recreate it using a series of
      // insertelements and/or shufflevectors.
      Value *Res;
      if (auto *Ty = dyn_cast<FixedVectorType>(Op->getType())) {
        BasicBlock *BB = Op->getParent();
        IRBuilder<> Builder(Op);
        if (isa<PHINode>(Op))
          Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());

        VectorSplit VS = *getVectorSplit(Ty);
        assert(VS.NumFragments == CV.size());

        Res = concatenate(Builder, CV, VS, Op->getName());

        Res->takeName(Op);
      } else {
        assert(CV.size() == 1 && Op->getType() == CV[0]->getType());
        Res = CV[0];
        if (Op == Res)
          continue;
      }
      Op->replaceAllUsesWith(Res);
    }
    PotentiallyDeadInstrs.emplace_back(Op);
  }
  Gathered.clear();
  Scattered.clear();
  Scalarized = false;

  RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);

  return true;
}

PreservedAnalyses ScalarizerPass::run(Function &F,
                                      FunctionAnalysisManager &AM) {
  Module &M = *F.getParent();
  unsigned ParallelLoopAccessMDKind =
      M.getContext().getMDKindID("llvm.mem.parallel_loop_access");
  DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
  ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT, Options);
  bool Changed = Impl.visit(F);
  PreservedAnalyses PA;
  PA.preserve<DominatorTreeAnalysis>();
  return Changed ? PA : PreservedAnalyses::all();
}