xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Scalar/Scalarizer.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- Scalarizer.cpp - Scalarize vector operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass converts vector operations into scalar operations (or, optionally,
10 // operations on smaller vector widths), in order to expose optimization
11 // opportunities on the individual scalar operations.
12 // It is mainly intended for targets that do not have vector units, but it
13 // may also be useful for revectorizing code to different vector widths.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/Transforms/Scalar/Scalarizer.h"
18 #include "llvm/ADT/PostOrderIterator.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/Twine.h"
21 #include "llvm/Analysis/VectorUtils.h"
22 #include "llvm/IR/Argument.h"
23 #include "llvm/IR/BasicBlock.h"
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/DataLayout.h"
26 #include "llvm/IR/DerivedTypes.h"
27 #include "llvm/IR/Dominators.h"
28 #include "llvm/IR/Function.h"
29 #include "llvm/IR/IRBuilder.h"
30 #include "llvm/IR/InstVisitor.h"
31 #include "llvm/IR/InstrTypes.h"
32 #include "llvm/IR/Instruction.h"
33 #include "llvm/IR/Instructions.h"
34 #include "llvm/IR/Intrinsics.h"
35 #include "llvm/IR/LLVMContext.h"
36 #include "llvm/IR/Module.h"
37 #include "llvm/IR/Type.h"
38 #include "llvm/IR/Value.h"
39 #include "llvm/Support/Casting.h"
40 #include "llvm/Support/CommandLine.h"
41 #include "llvm/Transforms/Utils/Local.h"
42 #include <cassert>
43 #include <cstdint>
44 #include <iterator>
45 #include <map>
46 #include <utility>
47 
48 using namespace llvm;
49 
50 #define DEBUG_TYPE "scalarizer"
51 
52 static cl::opt<bool> ClScalarizeVariableInsertExtract(
53     "scalarize-variable-insert-extract", cl::init(true), cl::Hidden,
54     cl::desc("Allow the scalarizer pass to scalarize "
55              "insertelement/extractelement with variable index"));
56 
57 // This is disabled by default because having separate loads and stores
58 // makes it more likely that the -combiner-alias-analysis limits will be
59 // reached.
60 static cl::opt<bool> ClScalarizeLoadStore(
61     "scalarize-load-store", cl::init(false), cl::Hidden,
62     cl::desc("Allow the scalarizer pass to scalarize loads and store"));
63 
64 // Split vectors larger than this size into fragments, where each fragment is
65 // either a vector no larger than this size or a scalar.
66 //
67 // Instructions with operands or results of different sizes that would be split
68 // into a different number of fragments are currently left as-is.
69 static cl::opt<unsigned> ClScalarizeMinBits(
70     "scalarize-min-bits", cl::init(0), cl::Hidden,
71     cl::desc("Instruct the scalarizer pass to attempt to keep values of a "
72              "minimum number of bits"));
73 
74 namespace {
75 
skipPastPhiNodesAndDbg(BasicBlock::iterator Itr)76 BasicBlock::iterator skipPastPhiNodesAndDbg(BasicBlock::iterator Itr) {
77   BasicBlock *BB = Itr->getParent();
78   if (isa<PHINode>(Itr))
79     Itr = BB->getFirstInsertionPt();
80   if (Itr != BB->end())
81     Itr = skipDebugIntrinsics(Itr);
82   return Itr;
83 }
84 
85 // Used to store the scattered form of a vector.
86 using ValueVector = SmallVector<Value *, 8>;
87 
88 // Used to map a vector Value and associated type to its scattered form.
89 // The associated type is only non-null for pointer values that are "scattered"
90 // when used as pointer operands to load or store.
91 //
92 // We use std::map because we want iterators to persist across insertion and
93 // because the values are relatively large.
94 using ScatterMap = std::map<std::pair<Value *, Type *>, ValueVector>;
95 
96 // Lists Instructions that have been replaced with scalar implementations,
97 // along with a pointer to their scattered forms.
98 using GatherList = SmallVector<std::pair<Instruction *, ValueVector *>, 16>;
99 
100 struct VectorSplit {
101   // The type of the vector.
102   FixedVectorType *VecTy = nullptr;
103 
104   // The number of elements packed in a fragment (other than the remainder).
105   unsigned NumPacked = 0;
106 
107   // The number of fragments (scalars or smaller vectors) into which the vector
108   // shall be split.
109   unsigned NumFragments = 0;
110 
111   // The type of each complete fragment.
112   Type *SplitTy = nullptr;
113 
114   // The type of the remainder (last) fragment; null if all fragments are
115   // complete.
116   Type *RemainderTy = nullptr;
117 
getFragmentType__anon7628c1430111::VectorSplit118   Type *getFragmentType(unsigned I) const {
119     return RemainderTy && I == NumFragments - 1 ? RemainderTy : SplitTy;
120   }
121 };
122 
123 // Provides a very limited vector-like interface for lazily accessing one
124 // component of a scattered vector or vector pointer.
125 class Scatterer {
126 public:
127   Scatterer() = default;
128 
129   // Scatter V into Size components.  If new instructions are needed,
130   // insert them before BBI in BB.  If Cache is nonnull, use it to cache
131   // the results.
132   Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
133             const VectorSplit &VS, ValueVector *cachePtr = nullptr);
134 
135   // Return component I, creating a new Value for it if necessary.
136   Value *operator[](unsigned I);
137 
138   // Return the number of components.
size() const139   unsigned size() const { return VS.NumFragments; }
140 
141 private:
142   BasicBlock *BB;
143   BasicBlock::iterator BBI;
144   Value *V;
145   VectorSplit VS;
146   bool IsPointer;
147   ValueVector *CachePtr;
148   ValueVector Tmp;
149 };
150 
151 // FCmpSplitter(FCI)(Builder, X, Y, Name) uses Builder to create an FCmp
152 // called Name that compares X and Y in the same way as FCI.
153 struct FCmpSplitter {
FCmpSplitter__anon7628c1430111::FCmpSplitter154   FCmpSplitter(FCmpInst &fci) : FCI(fci) {}
155 
operator ()__anon7628c1430111::FCmpSplitter156   Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
157                     const Twine &Name) const {
158     return Builder.CreateFCmp(FCI.getPredicate(), Op0, Op1, Name);
159   }
160 
161   FCmpInst &FCI;
162 };
163 
164 // ICmpSplitter(ICI)(Builder, X, Y, Name) uses Builder to create an ICmp
165 // called Name that compares X and Y in the same way as ICI.
166 struct ICmpSplitter {
ICmpSplitter__anon7628c1430111::ICmpSplitter167   ICmpSplitter(ICmpInst &ici) : ICI(ici) {}
168 
operator ()__anon7628c1430111::ICmpSplitter169   Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
170                     const Twine &Name) const {
171     return Builder.CreateICmp(ICI.getPredicate(), Op0, Op1, Name);
172   }
173 
174   ICmpInst &ICI;
175 };
176 
177 // UnarySplitter(UO)(Builder, X, Name) uses Builder to create
178 // a unary operator like UO called Name with operand X.
179 struct UnarySplitter {
UnarySplitter__anon7628c1430111::UnarySplitter180   UnarySplitter(UnaryOperator &uo) : UO(uo) {}
181 
operator ()__anon7628c1430111::UnarySplitter182   Value *operator()(IRBuilder<> &Builder, Value *Op, const Twine &Name) const {
183     return Builder.CreateUnOp(UO.getOpcode(), Op, Name);
184   }
185 
186   UnaryOperator &UO;
187 };
188 
189 // BinarySplitter(BO)(Builder, X, Y, Name) uses Builder to create
190 // a binary operator like BO called Name with operands X and Y.
191 struct BinarySplitter {
BinarySplitter__anon7628c1430111::BinarySplitter192   BinarySplitter(BinaryOperator &bo) : BO(bo) {}
193 
operator ()__anon7628c1430111::BinarySplitter194   Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
195                     const Twine &Name) const {
196     return Builder.CreateBinOp(BO.getOpcode(), Op0, Op1, Name);
197   }
198 
199   BinaryOperator &BO;
200 };
201 
202 // Information about a load or store that we're scalarizing.
203 struct VectorLayout {
204   VectorLayout() = default;
205 
206   // Return the alignment of fragment Frag.
getFragmentAlign__anon7628c1430111::VectorLayout207   Align getFragmentAlign(unsigned Frag) {
208     return commonAlignment(VecAlign, Frag * SplitSize);
209   }
210 
211   // The split of the underlying vector type.
212   VectorSplit VS;
213 
214   // The alignment of the vector.
215   Align VecAlign;
216 
217   // The size of each (non-remainder) fragment in bytes.
218   uint64_t SplitSize = 0;
219 };
220 
221 /// Concatenate the given fragments to a single vector value of the type
222 /// described in @p VS.
concatenate(IRBuilder<> & Builder,ArrayRef<Value * > Fragments,const VectorSplit & VS,Twine Name)223 static Value *concatenate(IRBuilder<> &Builder, ArrayRef<Value *> Fragments,
224                           const VectorSplit &VS, Twine Name) {
225   unsigned NumElements = VS.VecTy->getNumElements();
226   SmallVector<int> ExtendMask;
227   SmallVector<int> InsertMask;
228 
229   if (VS.NumPacked > 1) {
230     // Prepare the shufflevector masks once and re-use them for all
231     // fragments.
232     ExtendMask.resize(NumElements, -1);
233     for (unsigned I = 0; I < VS.NumPacked; ++I)
234       ExtendMask[I] = I;
235 
236     InsertMask.resize(NumElements);
237     for (unsigned I = 0; I < NumElements; ++I)
238       InsertMask[I] = I;
239   }
240 
241   Value *Res = PoisonValue::get(VS.VecTy);
242   for (unsigned I = 0; I < VS.NumFragments; ++I) {
243     Value *Fragment = Fragments[I];
244 
245     unsigned NumPacked = VS.NumPacked;
246     if (I == VS.NumFragments - 1 && VS.RemainderTy) {
247       if (auto *RemVecTy = dyn_cast<FixedVectorType>(VS.RemainderTy))
248         NumPacked = RemVecTy->getNumElements();
249       else
250         NumPacked = 1;
251     }
252 
253     if (NumPacked == 1) {
254       Res = Builder.CreateInsertElement(Res, Fragment, I * VS.NumPacked,
255                                         Name + ".upto" + Twine(I));
256     } else {
257       Fragment = Builder.CreateShuffleVector(Fragment, Fragment, ExtendMask);
258       if (I == 0) {
259         Res = Fragment;
260       } else {
261         for (unsigned J = 0; J < NumPacked; ++J)
262           InsertMask[I * VS.NumPacked + J] = NumElements + J;
263         Res = Builder.CreateShuffleVector(Res, Fragment, InsertMask,
264                                           Name + ".upto" + Twine(I));
265         for (unsigned J = 0; J < NumPacked; ++J)
266           InsertMask[I * VS.NumPacked + J] = I * VS.NumPacked + J;
267       }
268     }
269   }
270 
271   return Res;
272 }
273 
274 template <typename T>
getWithDefaultOverride(const cl::opt<T> & ClOption,const std::optional<T> & DefaultOverride)275 T getWithDefaultOverride(const cl::opt<T> &ClOption,
276                          const std::optional<T> &DefaultOverride) {
277   return ClOption.getNumOccurrences() ? ClOption
278                                       : DefaultOverride.value_or(ClOption);
279 }
280 
281 class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
282 public:
ScalarizerVisitor(DominatorTree * DT,ScalarizerPassOptions Options)283   ScalarizerVisitor(DominatorTree *DT, ScalarizerPassOptions Options)
284       : DT(DT), ScalarizeVariableInsertExtract(getWithDefaultOverride(
285                     ClScalarizeVariableInsertExtract,
286                     Options.ScalarizeVariableInsertExtract)),
287         ScalarizeLoadStore(getWithDefaultOverride(ClScalarizeLoadStore,
288                                                   Options.ScalarizeLoadStore)),
289         ScalarizeMinBits(getWithDefaultOverride(ClScalarizeMinBits,
290                                                 Options.ScalarizeMinBits)) {}
291 
292   bool visit(Function &F);
293 
294   // InstVisitor methods.  They return true if the instruction was scalarized,
295   // false if nothing changed.
visitInstruction(Instruction & I)296   bool visitInstruction(Instruction &I) { return false; }
297   bool visitSelectInst(SelectInst &SI);
298   bool visitICmpInst(ICmpInst &ICI);
299   bool visitFCmpInst(FCmpInst &FCI);
300   bool visitUnaryOperator(UnaryOperator &UO);
301   bool visitBinaryOperator(BinaryOperator &BO);
302   bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
303   bool visitCastInst(CastInst &CI);
304   bool visitBitCastInst(BitCastInst &BCI);
305   bool visitInsertElementInst(InsertElementInst &IEI);
306   bool visitExtractElementInst(ExtractElementInst &EEI);
307   bool visitShuffleVectorInst(ShuffleVectorInst &SVI);
308   bool visitPHINode(PHINode &PHI);
309   bool visitLoadInst(LoadInst &LI);
310   bool visitStoreInst(StoreInst &SI);
311   bool visitCallInst(CallInst &ICI);
312   bool visitFreezeInst(FreezeInst &FI);
313 
314 private:
315   Scatterer scatter(Instruction *Point, Value *V, const VectorSplit &VS);
316   void gather(Instruction *Op, const ValueVector &CV, const VectorSplit &VS);
317   void replaceUses(Instruction *Op, Value *CV);
318   bool canTransferMetadata(unsigned Kind);
319   void transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV);
320   std::optional<VectorSplit> getVectorSplit(Type *Ty);
321   std::optional<VectorLayout> getVectorLayout(Type *Ty, Align Alignment,
322                                               const DataLayout &DL);
323   bool finish();
324 
325   template<typename T> bool splitUnary(Instruction &, const T &);
326   template<typename T> bool splitBinary(Instruction &, const T &);
327 
328   bool splitCall(CallInst &CI);
329 
330   ScatterMap Scattered;
331   GatherList Gathered;
332   bool Scalarized;
333 
334   SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
335 
336   DominatorTree *DT;
337 
338   const bool ScalarizeVariableInsertExtract;
339   const bool ScalarizeLoadStore;
340   const unsigned ScalarizeMinBits;
341 };
342 
343 } // end anonymous namespace
344 
Scatterer(BasicBlock * bb,BasicBlock::iterator bbi,Value * v,const VectorSplit & VS,ValueVector * cachePtr)345 Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
346                      const VectorSplit &VS, ValueVector *cachePtr)
347     : BB(bb), BBI(bbi), V(v), VS(VS), CachePtr(cachePtr) {
348   IsPointer = V->getType()->isPointerTy();
349   if (!CachePtr) {
350     Tmp.resize(VS.NumFragments, nullptr);
351   } else {
352     assert((CachePtr->empty() || VS.NumFragments == CachePtr->size() ||
353             IsPointer) &&
354            "Inconsistent vector sizes");
355     if (VS.NumFragments > CachePtr->size())
356       CachePtr->resize(VS.NumFragments, nullptr);
357   }
358 }
359 
360 // Return fragment Frag, creating a new Value for it if necessary.
operator [](unsigned Frag)361 Value *Scatterer::operator[](unsigned Frag) {
362   ValueVector &CV = CachePtr ? *CachePtr : Tmp;
363   // Try to reuse a previous value.
364   if (CV[Frag])
365     return CV[Frag];
366   IRBuilder<> Builder(BB, BBI);
367   if (IsPointer) {
368     if (Frag == 0)
369       CV[Frag] = V;
370     else
371       CV[Frag] = Builder.CreateConstGEP1_32(VS.SplitTy, V, Frag,
372                                             V->getName() + ".i" + Twine(Frag));
373     return CV[Frag];
374   }
375 
376   Type *FragmentTy = VS.getFragmentType(Frag);
377 
378   if (auto *VecTy = dyn_cast<FixedVectorType>(FragmentTy)) {
379     SmallVector<int> Mask;
380     for (unsigned J = 0; J < VecTy->getNumElements(); ++J)
381       Mask.push_back(Frag * VS.NumPacked + J);
382     CV[Frag] =
383         Builder.CreateShuffleVector(V, PoisonValue::get(V->getType()), Mask,
384                                     V->getName() + ".i" + Twine(Frag));
385   } else {
386     // Search through a chain of InsertElementInsts looking for element Frag.
387     // Record other elements in the cache.  The new V is still suitable
388     // for all uncached indices.
389     while (true) {
390       InsertElementInst *Insert = dyn_cast<InsertElementInst>(V);
391       if (!Insert)
392         break;
393       ConstantInt *Idx = dyn_cast<ConstantInt>(Insert->getOperand(2));
394       if (!Idx)
395         break;
396       unsigned J = Idx->getZExtValue();
397       V = Insert->getOperand(0);
398       if (Frag * VS.NumPacked == J) {
399         CV[Frag] = Insert->getOperand(1);
400         return CV[Frag];
401       }
402 
403       if (VS.NumPacked == 1 && !CV[J]) {
404         // Only cache the first entry we find for each index we're not actively
405         // searching for. This prevents us from going too far up the chain and
406         // caching incorrect entries.
407         CV[J] = Insert->getOperand(1);
408       }
409     }
410     CV[Frag] = Builder.CreateExtractElement(V, Frag * VS.NumPacked,
411                                             V->getName() + ".i" + Twine(Frag));
412   }
413 
414   return CV[Frag];
415 }
416 
visit(Function & F)417 bool ScalarizerVisitor::visit(Function &F) {
418   assert(Gathered.empty() && Scattered.empty());
419 
420   Scalarized = false;
421 
422   // To ensure we replace gathered components correctly we need to do an ordered
423   // traversal of the basic blocks in the function.
424   ReversePostOrderTraversal<BasicBlock *> RPOT(&F.getEntryBlock());
425   for (BasicBlock *BB : RPOT) {
426     for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
427       Instruction *I = &*II;
428       bool Done = InstVisitor::visit(I);
429       ++II;
430       if (Done && I->getType()->isVoidTy())
431         I->eraseFromParent();
432     }
433   }
434   return finish();
435 }
436 
437 // Return a scattered form of V that can be accessed by Point.  V must be a
438 // vector or a pointer to a vector.
scatter(Instruction * Point,Value * V,const VectorSplit & VS)439 Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V,
440                                      const VectorSplit &VS) {
441   if (Argument *VArg = dyn_cast<Argument>(V)) {
442     // Put the scattered form of arguments in the entry block,
443     // so that it can be used everywhere.
444     Function *F = VArg->getParent();
445     BasicBlock *BB = &F->getEntryBlock();
446     return Scatterer(BB, BB->begin(), V, VS, &Scattered[{V, VS.SplitTy}]);
447   }
448   if (Instruction *VOp = dyn_cast<Instruction>(V)) {
449     // When scalarizing PHI nodes we might try to examine/rewrite InsertElement
450     // nodes in predecessors. If those predecessors are unreachable from entry,
451     // then the IR in those blocks could have unexpected properties resulting in
452     // infinite loops in Scatterer::operator[]. By simply treating values
453     // originating from instructions in unreachable blocks as undef we do not
454     // need to analyse them further.
455     if (!DT->isReachableFromEntry(VOp->getParent()))
456       return Scatterer(Point->getParent(), Point->getIterator(),
457                        PoisonValue::get(V->getType()), VS);
458     // Put the scattered form of an instruction directly after the
459     // instruction, skipping over PHI nodes and debug intrinsics.
460     BasicBlock *BB = VOp->getParent();
461     return Scatterer(
462         BB, skipPastPhiNodesAndDbg(std::next(BasicBlock::iterator(VOp))), V, VS,
463         &Scattered[{V, VS.SplitTy}]);
464   }
465   // In the fallback case, just put the scattered before Point and
466   // keep the result local to Point.
467   return Scatterer(Point->getParent(), Point->getIterator(), V, VS);
468 }
469 
470 // Replace Op with the gathered form of the components in CV.  Defer the
471 // deletion of Op and creation of the gathered form to the end of the pass,
472 // so that we can avoid creating the gathered form if all uses of Op are
473 // replaced with uses of CV.
gather(Instruction * Op,const ValueVector & CV,const VectorSplit & VS)474 void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV,
475                                const VectorSplit &VS) {
476   transferMetadataAndIRFlags(Op, CV);
477 
478   // If we already have a scattered form of Op (created from ExtractElements
479   // of Op itself), replace them with the new form.
480   ValueVector &SV = Scattered[{Op, VS.SplitTy}];
481   if (!SV.empty()) {
482     for (unsigned I = 0, E = SV.size(); I != E; ++I) {
483       Value *V = SV[I];
484       if (V == nullptr || SV[I] == CV[I])
485         continue;
486 
487       Instruction *Old = cast<Instruction>(V);
488       if (isa<Instruction>(CV[I]))
489         CV[I]->takeName(Old);
490       Old->replaceAllUsesWith(CV[I]);
491       PotentiallyDeadInstrs.emplace_back(Old);
492     }
493   }
494   SV = CV;
495   Gathered.push_back(GatherList::value_type(Op, &SV));
496 }
497 
498 // Replace Op with CV and collect Op has a potentially dead instruction.
replaceUses(Instruction * Op,Value * CV)499 void ScalarizerVisitor::replaceUses(Instruction *Op, Value *CV) {
500   if (CV != Op) {
501     Op->replaceAllUsesWith(CV);
502     PotentiallyDeadInstrs.emplace_back(Op);
503     Scalarized = true;
504   }
505 }
506 
507 // Return true if it is safe to transfer the given metadata tag from
508 // vector to scalar instructions.
canTransferMetadata(unsigned Tag)509 bool ScalarizerVisitor::canTransferMetadata(unsigned Tag) {
510   return (Tag == LLVMContext::MD_tbaa
511           || Tag == LLVMContext::MD_fpmath
512           || Tag == LLVMContext::MD_tbaa_struct
513           || Tag == LLVMContext::MD_invariant_load
514           || Tag == LLVMContext::MD_alias_scope
515           || Tag == LLVMContext::MD_noalias
516           || Tag == LLVMContext::MD_mem_parallel_loop_access
517           || Tag == LLVMContext::MD_access_group);
518 }
519 
520 // Transfer metadata from Op to the instructions in CV if it is known
521 // to be safe to do so.
transferMetadataAndIRFlags(Instruction * Op,const ValueVector & CV)522 void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op,
523                                                    const ValueVector &CV) {
524   SmallVector<std::pair<unsigned, MDNode *>, 4> MDs;
525   Op->getAllMetadataOtherThanDebugLoc(MDs);
526   for (Value *V : CV) {
527     if (Instruction *New = dyn_cast<Instruction>(V)) {
528       for (const auto &MD : MDs)
529         if (canTransferMetadata(MD.first))
530           New->setMetadata(MD.first, MD.second);
531       New->copyIRFlags(Op);
532       if (Op->getDebugLoc() && !New->getDebugLoc())
533         New->setDebugLoc(Op->getDebugLoc());
534     }
535   }
536 }
537 
538 // Determine how Ty is split, if at all.
getVectorSplit(Type * Ty)539 std::optional<VectorSplit> ScalarizerVisitor::getVectorSplit(Type *Ty) {
540   VectorSplit Split;
541   Split.VecTy = dyn_cast<FixedVectorType>(Ty);
542   if (!Split.VecTy)
543     return {};
544 
545   unsigned NumElems = Split.VecTy->getNumElements();
546   Type *ElemTy = Split.VecTy->getElementType();
547 
548   if (NumElems == 1 || ElemTy->isPointerTy() ||
549       2 * ElemTy->getScalarSizeInBits() > ScalarizeMinBits) {
550     Split.NumPacked = 1;
551     Split.NumFragments = NumElems;
552     Split.SplitTy = ElemTy;
553   } else {
554     Split.NumPacked = ScalarizeMinBits / ElemTy->getScalarSizeInBits();
555     if (Split.NumPacked >= NumElems)
556       return {};
557 
558     Split.NumFragments = divideCeil(NumElems, Split.NumPacked);
559     Split.SplitTy = FixedVectorType::get(ElemTy, Split.NumPacked);
560 
561     unsigned RemainderElems = NumElems % Split.NumPacked;
562     if (RemainderElems > 1)
563       Split.RemainderTy = FixedVectorType::get(ElemTy, RemainderElems);
564     else if (RemainderElems == 1)
565       Split.RemainderTy = ElemTy;
566   }
567 
568   return Split;
569 }
570 
571 // Try to fill in Layout from Ty, returning true on success.  Alignment is
572 // the alignment of the vector, or std::nullopt if the ABI default should be
573 // used.
574 std::optional<VectorLayout>
getVectorLayout(Type * Ty,Align Alignment,const DataLayout & DL)575 ScalarizerVisitor::getVectorLayout(Type *Ty, Align Alignment,
576                                    const DataLayout &DL) {
577   std::optional<VectorSplit> VS = getVectorSplit(Ty);
578   if (!VS)
579     return {};
580 
581   VectorLayout Layout;
582   Layout.VS = *VS;
583   // Check that we're dealing with full-byte fragments.
584   if (!DL.typeSizeEqualsStoreSize(VS->SplitTy) ||
585       (VS->RemainderTy && !DL.typeSizeEqualsStoreSize(VS->RemainderTy)))
586     return {};
587   Layout.VecAlign = Alignment;
588   Layout.SplitSize = DL.getTypeStoreSize(VS->SplitTy);
589   return Layout;
590 }
591 
592 // Scalarize one-operand instruction I, using Split(Builder, X, Name)
593 // to create an instruction like I with operand X and name Name.
594 template<typename Splitter>
splitUnary(Instruction & I,const Splitter & Split)595 bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) {
596   std::optional<VectorSplit> VS = getVectorSplit(I.getType());
597   if (!VS)
598     return false;
599 
600   std::optional<VectorSplit> OpVS;
601   if (I.getOperand(0)->getType() == I.getType()) {
602     OpVS = VS;
603   } else {
604     OpVS = getVectorSplit(I.getOperand(0)->getType());
605     if (!OpVS || VS->NumPacked != OpVS->NumPacked)
606       return false;
607   }
608 
609   IRBuilder<> Builder(&I);
610   Scatterer Op = scatter(&I, I.getOperand(0), *OpVS);
611   assert(Op.size() == VS->NumFragments && "Mismatched unary operation");
612   ValueVector Res;
613   Res.resize(VS->NumFragments);
614   for (unsigned Frag = 0; Frag < VS->NumFragments; ++Frag)
615     Res[Frag] = Split(Builder, Op[Frag], I.getName() + ".i" + Twine(Frag));
616   gather(&I, Res, *VS);
617   return true;
618 }
619 
620 // Scalarize two-operand instruction I, using Split(Builder, X, Y, Name)
621 // to create an instruction like I with operands X and Y and name Name.
622 template<typename Splitter>
splitBinary(Instruction & I,const Splitter & Split)623 bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
624   std::optional<VectorSplit> VS = getVectorSplit(I.getType());
625   if (!VS)
626     return false;
627 
628   std::optional<VectorSplit> OpVS;
629   if (I.getOperand(0)->getType() == I.getType()) {
630     OpVS = VS;
631   } else {
632     OpVS = getVectorSplit(I.getOperand(0)->getType());
633     if (!OpVS || VS->NumPacked != OpVS->NumPacked)
634       return false;
635   }
636 
637   IRBuilder<> Builder(&I);
638   Scatterer VOp0 = scatter(&I, I.getOperand(0), *OpVS);
639   Scatterer VOp1 = scatter(&I, I.getOperand(1), *OpVS);
640   assert(VOp0.size() == VS->NumFragments && "Mismatched binary operation");
641   assert(VOp1.size() == VS->NumFragments && "Mismatched binary operation");
642   ValueVector Res;
643   Res.resize(VS->NumFragments);
644   for (unsigned Frag = 0; Frag < VS->NumFragments; ++Frag) {
645     Value *Op0 = VOp0[Frag];
646     Value *Op1 = VOp1[Frag];
647     Res[Frag] = Split(Builder, Op0, Op1, I.getName() + ".i" + Twine(Frag));
648   }
649   gather(&I, Res, *VS);
650   return true;
651 }
652 
isTriviallyScalariable(Intrinsic::ID ID)653 static bool isTriviallyScalariable(Intrinsic::ID ID) {
654   return isTriviallyVectorizable(ID);
655 }
656 
657 /// If a call to a vector typed intrinsic function, split into a scalar call per
658 /// element if possible for the intrinsic.
splitCall(CallInst & CI)659 bool ScalarizerVisitor::splitCall(CallInst &CI) {
660   std::optional<VectorSplit> VS = getVectorSplit(CI.getType());
661   if (!VS)
662     return false;
663 
664   Function *F = CI.getCalledFunction();
665   if (!F)
666     return false;
667 
668   Intrinsic::ID ID = F->getIntrinsicID();
669   if (ID == Intrinsic::not_intrinsic || !isTriviallyScalariable(ID))
670     return false;
671 
672   // unsigned NumElems = VT->getNumElements();
673   unsigned NumArgs = CI.arg_size();
674 
675   ValueVector ScalarOperands(NumArgs);
676   SmallVector<Scatterer, 8> Scattered(NumArgs);
677   SmallVector<int> OverloadIdx(NumArgs, -1);
678 
679   SmallVector<llvm::Type *, 3> Tys;
680   // Add return type if intrinsic is overloaded on it.
681   if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
682     Tys.push_back(VS->SplitTy);
683 
684   // Assumes that any vector type has the same number of elements as the return
685   // vector type, which is true for all current intrinsics.
686   for (unsigned I = 0; I != NumArgs; ++I) {
687     Value *OpI = CI.getOperand(I);
688     if ([[maybe_unused]] auto *OpVecTy =
689             dyn_cast<FixedVectorType>(OpI->getType())) {
690       assert(OpVecTy->getNumElements() == VS->VecTy->getNumElements());
691       std::optional<VectorSplit> OpVS = getVectorSplit(OpI->getType());
692       if (!OpVS || OpVS->NumPacked != VS->NumPacked) {
693         // The natural split of the operand doesn't match the result. This could
694         // happen if the vector elements are different and the ScalarizeMinBits
695         // option is used.
696         //
697         // We could in principle handle this case as well, at the cost of
698         // complicating the scattering machinery to support multiple scattering
699         // granularities for a single value.
700         return false;
701       }
702 
703       Scattered[I] = scatter(&CI, OpI, *OpVS);
704       if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) {
705         OverloadIdx[I] = Tys.size();
706         Tys.push_back(OpVS->SplitTy);
707       }
708     } else {
709       ScalarOperands[I] = OpI;
710       if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
711         Tys.push_back(OpI->getType());
712     }
713   }
714 
715   ValueVector Res(VS->NumFragments);
716   ValueVector ScalarCallOps(NumArgs);
717 
718   Function *NewIntrin = Intrinsic::getDeclaration(F->getParent(), ID, Tys);
719   IRBuilder<> Builder(&CI);
720 
721   // Perform actual scalarization, taking care to preserve any scalar operands.
722   for (unsigned I = 0; I < VS->NumFragments; ++I) {
723     bool IsRemainder = I == VS->NumFragments - 1 && VS->RemainderTy;
724     ScalarCallOps.clear();
725 
726     if (IsRemainder)
727       Tys[0] = VS->RemainderTy;
728 
729     for (unsigned J = 0; J != NumArgs; ++J) {
730       if (isVectorIntrinsicWithScalarOpAtArg(ID, J)) {
731         ScalarCallOps.push_back(ScalarOperands[J]);
732       } else {
733         ScalarCallOps.push_back(Scattered[J][I]);
734         if (IsRemainder && OverloadIdx[J] >= 0)
735           Tys[OverloadIdx[J]] = Scattered[J][I]->getType();
736       }
737     }
738 
739     if (IsRemainder)
740       NewIntrin = Intrinsic::getDeclaration(F->getParent(), ID, Tys);
741 
742     Res[I] = Builder.CreateCall(NewIntrin, ScalarCallOps,
743                                 CI.getName() + ".i" + Twine(I));
744   }
745 
746   gather(&CI, Res, *VS);
747   return true;
748 }
749 
visitSelectInst(SelectInst & SI)750 bool ScalarizerVisitor::visitSelectInst(SelectInst &SI) {
751   std::optional<VectorSplit> VS = getVectorSplit(SI.getType());
752   if (!VS)
753     return false;
754 
755   std::optional<VectorSplit> CondVS;
756   if (isa<FixedVectorType>(SI.getCondition()->getType())) {
757     CondVS = getVectorSplit(SI.getCondition()->getType());
758     if (!CondVS || CondVS->NumPacked != VS->NumPacked) {
759       // This happens when ScalarizeMinBits is used.
760       return false;
761     }
762   }
763 
764   IRBuilder<> Builder(&SI);
765   Scatterer VOp1 = scatter(&SI, SI.getOperand(1), *VS);
766   Scatterer VOp2 = scatter(&SI, SI.getOperand(2), *VS);
767   assert(VOp1.size() == VS->NumFragments && "Mismatched select");
768   assert(VOp2.size() == VS->NumFragments && "Mismatched select");
769   ValueVector Res;
770   Res.resize(VS->NumFragments);
771 
772   if (CondVS) {
773     Scatterer VOp0 = scatter(&SI, SI.getOperand(0), *CondVS);
774     assert(VOp0.size() == CondVS->NumFragments && "Mismatched select");
775     for (unsigned I = 0; I < VS->NumFragments; ++I) {
776       Value *Op0 = VOp0[I];
777       Value *Op1 = VOp1[I];
778       Value *Op2 = VOp2[I];
779       Res[I] = Builder.CreateSelect(Op0, Op1, Op2,
780                                     SI.getName() + ".i" + Twine(I));
781     }
782   } else {
783     Value *Op0 = SI.getOperand(0);
784     for (unsigned I = 0; I < VS->NumFragments; ++I) {
785       Value *Op1 = VOp1[I];
786       Value *Op2 = VOp2[I];
787       Res[I] = Builder.CreateSelect(Op0, Op1, Op2,
788                                     SI.getName() + ".i" + Twine(I));
789     }
790   }
791   gather(&SI, Res, *VS);
792   return true;
793 }
794 
visitICmpInst(ICmpInst & ICI)795 bool ScalarizerVisitor::visitICmpInst(ICmpInst &ICI) {
796   return splitBinary(ICI, ICmpSplitter(ICI));
797 }
798 
visitFCmpInst(FCmpInst & FCI)799 bool ScalarizerVisitor::visitFCmpInst(FCmpInst &FCI) {
800   return splitBinary(FCI, FCmpSplitter(FCI));
801 }
802 
visitUnaryOperator(UnaryOperator & UO)803 bool ScalarizerVisitor::visitUnaryOperator(UnaryOperator &UO) {
804   return splitUnary(UO, UnarySplitter(UO));
805 }
806 
visitBinaryOperator(BinaryOperator & BO)807 bool ScalarizerVisitor::visitBinaryOperator(BinaryOperator &BO) {
808   return splitBinary(BO, BinarySplitter(BO));
809 }
810 
visitGetElementPtrInst(GetElementPtrInst & GEPI)811 bool ScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
812   std::optional<VectorSplit> VS = getVectorSplit(GEPI.getType());
813   if (!VS)
814     return false;
815 
816   IRBuilder<> Builder(&GEPI);
817   unsigned NumIndices = GEPI.getNumIndices();
818 
819   // The base pointer and indices might be scalar even if it's a vector GEP.
820   SmallVector<Value *, 8> ScalarOps{1 + NumIndices};
821   SmallVector<Scatterer, 8> ScatterOps{1 + NumIndices};
822 
823   for (unsigned I = 0; I < 1 + NumIndices; ++I) {
824     if (auto *VecTy =
825             dyn_cast<FixedVectorType>(GEPI.getOperand(I)->getType())) {
826       std::optional<VectorSplit> OpVS = getVectorSplit(VecTy);
827       if (!OpVS || OpVS->NumPacked != VS->NumPacked) {
828         // This can happen when ScalarizeMinBits is used.
829         return false;
830       }
831       ScatterOps[I] = scatter(&GEPI, GEPI.getOperand(I), *OpVS);
832     } else {
833       ScalarOps[I] = GEPI.getOperand(I);
834     }
835   }
836 
837   ValueVector Res;
838   Res.resize(VS->NumFragments);
839   for (unsigned I = 0; I < VS->NumFragments; ++I) {
840     SmallVector<Value *, 8> SplitOps;
841     SplitOps.resize(1 + NumIndices);
842     for (unsigned J = 0; J < 1 + NumIndices; ++J) {
843       if (ScalarOps[J])
844         SplitOps[J] = ScalarOps[J];
845       else
846         SplitOps[J] = ScatterOps[J][I];
847     }
848     Res[I] = Builder.CreateGEP(GEPI.getSourceElementType(), SplitOps[0],
849                                ArrayRef(SplitOps).drop_front(),
850                                GEPI.getName() + ".i" + Twine(I));
851     if (GEPI.isInBounds())
852       if (GetElementPtrInst *NewGEPI = dyn_cast<GetElementPtrInst>(Res[I]))
853         NewGEPI->setIsInBounds();
854   }
855   gather(&GEPI, Res, *VS);
856   return true;
857 }
858 
visitCastInst(CastInst & CI)859 bool ScalarizerVisitor::visitCastInst(CastInst &CI) {
860   std::optional<VectorSplit> DestVS = getVectorSplit(CI.getDestTy());
861   if (!DestVS)
862     return false;
863 
864   std::optional<VectorSplit> SrcVS = getVectorSplit(CI.getSrcTy());
865   if (!SrcVS || SrcVS->NumPacked != DestVS->NumPacked)
866     return false;
867 
868   IRBuilder<> Builder(&CI);
869   Scatterer Op0 = scatter(&CI, CI.getOperand(0), *SrcVS);
870   assert(Op0.size() == SrcVS->NumFragments && "Mismatched cast");
871   ValueVector Res;
872   Res.resize(DestVS->NumFragments);
873   for (unsigned I = 0; I < DestVS->NumFragments; ++I)
874     Res[I] =
875         Builder.CreateCast(CI.getOpcode(), Op0[I], DestVS->getFragmentType(I),
876                            CI.getName() + ".i" + Twine(I));
877   gather(&CI, Res, *DestVS);
878   return true;
879 }
880 
visitBitCastInst(BitCastInst & BCI)881 bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {
882   std::optional<VectorSplit> DstVS = getVectorSplit(BCI.getDestTy());
883   std::optional<VectorSplit> SrcVS = getVectorSplit(BCI.getSrcTy());
884   if (!DstVS || !SrcVS || DstVS->RemainderTy || SrcVS->RemainderTy)
885     return false;
886 
887   const bool isPointerTy = DstVS->VecTy->getElementType()->isPointerTy();
888 
889   // Vectors of pointers are always fully scalarized.
890   assert(!isPointerTy || (DstVS->NumPacked == 1 && SrcVS->NumPacked == 1));
891 
892   IRBuilder<> Builder(&BCI);
893   Scatterer Op0 = scatter(&BCI, BCI.getOperand(0), *SrcVS);
894   ValueVector Res;
895   Res.resize(DstVS->NumFragments);
896 
897   unsigned DstSplitBits = DstVS->SplitTy->getPrimitiveSizeInBits();
898   unsigned SrcSplitBits = SrcVS->SplitTy->getPrimitiveSizeInBits();
899 
900   if (isPointerTy || DstSplitBits == SrcSplitBits) {
901     assert(DstVS->NumFragments == SrcVS->NumFragments);
902     for (unsigned I = 0; I < DstVS->NumFragments; ++I) {
903       Res[I] = Builder.CreateBitCast(Op0[I], DstVS->getFragmentType(I),
904                                      BCI.getName() + ".i" + Twine(I));
905     }
906   } else if (SrcSplitBits % DstSplitBits == 0) {
907     // Convert each source fragment to the same-sized destination vector and
908     // then scatter the result to the destination.
909     VectorSplit MidVS;
910     MidVS.NumPacked = DstVS->NumPacked;
911     MidVS.NumFragments = SrcSplitBits / DstSplitBits;
912     MidVS.VecTy = FixedVectorType::get(DstVS->VecTy->getElementType(),
913                                        MidVS.NumPacked * MidVS.NumFragments);
914     MidVS.SplitTy = DstVS->SplitTy;
915 
916     unsigned ResI = 0;
917     for (unsigned I = 0; I < SrcVS->NumFragments; ++I) {
918       Value *V = Op0[I];
919 
920       // Look through any existing bitcasts before converting to <N x t2>.
921       // In the best case, the resulting conversion might be a no-op.
922       Instruction *VI;
923       while ((VI = dyn_cast<Instruction>(V)) &&
924              VI->getOpcode() == Instruction::BitCast)
925         V = VI->getOperand(0);
926 
927       V = Builder.CreateBitCast(V, MidVS.VecTy, V->getName() + ".cast");
928 
929       Scatterer Mid = scatter(&BCI, V, MidVS);
930       for (unsigned J = 0; J < MidVS.NumFragments; ++J)
931         Res[ResI++] = Mid[J];
932     }
933   } else if (DstSplitBits % SrcSplitBits == 0) {
934     // Gather enough source fragments to make up a destination fragment and
935     // then convert to the destination type.
936     VectorSplit MidVS;
937     MidVS.NumFragments = DstSplitBits / SrcSplitBits;
938     MidVS.NumPacked = SrcVS->NumPacked;
939     MidVS.VecTy = FixedVectorType::get(SrcVS->VecTy->getElementType(),
940                                        MidVS.NumPacked * MidVS.NumFragments);
941     MidVS.SplitTy = SrcVS->SplitTy;
942 
943     unsigned SrcI = 0;
944     SmallVector<Value *, 8> ConcatOps;
945     ConcatOps.resize(MidVS.NumFragments);
946     for (unsigned I = 0; I < DstVS->NumFragments; ++I) {
947       for (unsigned J = 0; J < MidVS.NumFragments; ++J)
948         ConcatOps[J] = Op0[SrcI++];
949       Value *V = concatenate(Builder, ConcatOps, MidVS,
950                              BCI.getName() + ".i" + Twine(I));
951       Res[I] = Builder.CreateBitCast(V, DstVS->getFragmentType(I),
952                                      BCI.getName() + ".i" + Twine(I));
953     }
954   } else {
955     return false;
956   }
957 
958   gather(&BCI, Res, *DstVS);
959   return true;
960 }
961 
visitInsertElementInst(InsertElementInst & IEI)962 bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
963   std::optional<VectorSplit> VS = getVectorSplit(IEI.getType());
964   if (!VS)
965     return false;
966 
967   IRBuilder<> Builder(&IEI);
968   Scatterer Op0 = scatter(&IEI, IEI.getOperand(0), *VS);
969   Value *NewElt = IEI.getOperand(1);
970   Value *InsIdx = IEI.getOperand(2);
971 
972   ValueVector Res;
973   Res.resize(VS->NumFragments);
974 
975   if (auto *CI = dyn_cast<ConstantInt>(InsIdx)) {
976     unsigned Idx = CI->getZExtValue();
977     unsigned Fragment = Idx / VS->NumPacked;
978     for (unsigned I = 0; I < VS->NumFragments; ++I) {
979       if (I == Fragment) {
980         bool IsPacked = VS->NumPacked > 1;
981         if (Fragment == VS->NumFragments - 1 && VS->RemainderTy &&
982             !VS->RemainderTy->isVectorTy())
983           IsPacked = false;
984         if (IsPacked) {
985           Res[I] =
986               Builder.CreateInsertElement(Op0[I], NewElt, Idx % VS->NumPacked);
987         } else {
988           Res[I] = NewElt;
989         }
990       } else {
991         Res[I] = Op0[I];
992       }
993     }
994   } else {
995     // Never split a variable insertelement that isn't fully scalarized.
996     if (!ScalarizeVariableInsertExtract || VS->NumPacked > 1)
997       return false;
998 
999     for (unsigned I = 0; I < VS->NumFragments; ++I) {
1000       Value *ShouldReplace =
1001           Builder.CreateICmpEQ(InsIdx, ConstantInt::get(InsIdx->getType(), I),
1002                                InsIdx->getName() + ".is." + Twine(I));
1003       Value *OldElt = Op0[I];
1004       Res[I] = Builder.CreateSelect(ShouldReplace, NewElt, OldElt,
1005                                     IEI.getName() + ".i" + Twine(I));
1006     }
1007   }
1008 
1009   gather(&IEI, Res, *VS);
1010   return true;
1011 }
1012 
visitExtractElementInst(ExtractElementInst & EEI)1013 bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
1014   std::optional<VectorSplit> VS = getVectorSplit(EEI.getOperand(0)->getType());
1015   if (!VS)
1016     return false;
1017 
1018   IRBuilder<> Builder(&EEI);
1019   Scatterer Op0 = scatter(&EEI, EEI.getOperand(0), *VS);
1020   Value *ExtIdx = EEI.getOperand(1);
1021 
1022   if (auto *CI = dyn_cast<ConstantInt>(ExtIdx)) {
1023     unsigned Idx = CI->getZExtValue();
1024     unsigned Fragment = Idx / VS->NumPacked;
1025     Value *Res = Op0[Fragment];
1026     bool IsPacked = VS->NumPacked > 1;
1027     if (Fragment == VS->NumFragments - 1 && VS->RemainderTy &&
1028         !VS->RemainderTy->isVectorTy())
1029       IsPacked = false;
1030     if (IsPacked)
1031       Res = Builder.CreateExtractElement(Res, Idx % VS->NumPacked);
1032     replaceUses(&EEI, Res);
1033     return true;
1034   }
1035 
1036   // Never split a variable extractelement that isn't fully scalarized.
1037   if (!ScalarizeVariableInsertExtract || VS->NumPacked > 1)
1038     return false;
1039 
1040   Value *Res = PoisonValue::get(VS->VecTy->getElementType());
1041   for (unsigned I = 0; I < VS->NumFragments; ++I) {
1042     Value *ShouldExtract =
1043         Builder.CreateICmpEQ(ExtIdx, ConstantInt::get(ExtIdx->getType(), I),
1044                              ExtIdx->getName() + ".is." + Twine(I));
1045     Value *Elt = Op0[I];
1046     Res = Builder.CreateSelect(ShouldExtract, Elt, Res,
1047                                EEI.getName() + ".upto" + Twine(I));
1048   }
1049   replaceUses(&EEI, Res);
1050   return true;
1051 }
1052 
visitShuffleVectorInst(ShuffleVectorInst & SVI)1053 bool ScalarizerVisitor::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
1054   std::optional<VectorSplit> VS = getVectorSplit(SVI.getType());
1055   std::optional<VectorSplit> VSOp =
1056       getVectorSplit(SVI.getOperand(0)->getType());
1057   if (!VS || !VSOp || VS->NumPacked > 1 || VSOp->NumPacked > 1)
1058     return false;
1059 
1060   Scatterer Op0 = scatter(&SVI, SVI.getOperand(0), *VSOp);
1061   Scatterer Op1 = scatter(&SVI, SVI.getOperand(1), *VSOp);
1062   ValueVector Res;
1063   Res.resize(VS->NumFragments);
1064 
1065   for (unsigned I = 0; I < VS->NumFragments; ++I) {
1066     int Selector = SVI.getMaskValue(I);
1067     if (Selector < 0)
1068       Res[I] = PoisonValue::get(VS->VecTy->getElementType());
1069     else if (unsigned(Selector) < Op0.size())
1070       Res[I] = Op0[Selector];
1071     else
1072       Res[I] = Op1[Selector - Op0.size()];
1073   }
1074   gather(&SVI, Res, *VS);
1075   return true;
1076 }
1077 
visitPHINode(PHINode & PHI)1078 bool ScalarizerVisitor::visitPHINode(PHINode &PHI) {
1079   std::optional<VectorSplit> VS = getVectorSplit(PHI.getType());
1080   if (!VS)
1081     return false;
1082 
1083   IRBuilder<> Builder(&PHI);
1084   ValueVector Res;
1085   Res.resize(VS->NumFragments);
1086 
1087   unsigned NumOps = PHI.getNumOperands();
1088   for (unsigned I = 0; I < VS->NumFragments; ++I) {
1089     Res[I] = Builder.CreatePHI(VS->getFragmentType(I), NumOps,
1090                                PHI.getName() + ".i" + Twine(I));
1091   }
1092 
1093   for (unsigned I = 0; I < NumOps; ++I) {
1094     Scatterer Op = scatter(&PHI, PHI.getIncomingValue(I), *VS);
1095     BasicBlock *IncomingBlock = PHI.getIncomingBlock(I);
1096     for (unsigned J = 0; J < VS->NumFragments; ++J)
1097       cast<PHINode>(Res[J])->addIncoming(Op[J], IncomingBlock);
1098   }
1099   gather(&PHI, Res, *VS);
1100   return true;
1101 }
1102 
visitLoadInst(LoadInst & LI)1103 bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) {
1104   if (!ScalarizeLoadStore)
1105     return false;
1106   if (!LI.isSimple())
1107     return false;
1108 
1109   std::optional<VectorLayout> Layout = getVectorLayout(
1110       LI.getType(), LI.getAlign(), LI.getDataLayout());
1111   if (!Layout)
1112     return false;
1113 
1114   IRBuilder<> Builder(&LI);
1115   Scatterer Ptr = scatter(&LI, LI.getPointerOperand(), Layout->VS);
1116   ValueVector Res;
1117   Res.resize(Layout->VS.NumFragments);
1118 
1119   for (unsigned I = 0; I < Layout->VS.NumFragments; ++I) {
1120     Res[I] = Builder.CreateAlignedLoad(Layout->VS.getFragmentType(I), Ptr[I],
1121                                        Align(Layout->getFragmentAlign(I)),
1122                                        LI.getName() + ".i" + Twine(I));
1123   }
1124   gather(&LI, Res, Layout->VS);
1125   return true;
1126 }
1127 
visitStoreInst(StoreInst & SI)1128 bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) {
1129   if (!ScalarizeLoadStore)
1130     return false;
1131   if (!SI.isSimple())
1132     return false;
1133 
1134   Value *FullValue = SI.getValueOperand();
1135   std::optional<VectorLayout> Layout = getVectorLayout(
1136       FullValue->getType(), SI.getAlign(), SI.getDataLayout());
1137   if (!Layout)
1138     return false;
1139 
1140   IRBuilder<> Builder(&SI);
1141   Scatterer VPtr = scatter(&SI, SI.getPointerOperand(), Layout->VS);
1142   Scatterer VVal = scatter(&SI, FullValue, Layout->VS);
1143 
1144   ValueVector Stores;
1145   Stores.resize(Layout->VS.NumFragments);
1146   for (unsigned I = 0; I < Layout->VS.NumFragments; ++I) {
1147     Value *Val = VVal[I];
1148     Value *Ptr = VPtr[I];
1149     Stores[I] =
1150         Builder.CreateAlignedStore(Val, Ptr, Layout->getFragmentAlign(I));
1151   }
1152   transferMetadataAndIRFlags(&SI, Stores);
1153   return true;
1154 }
1155 
visitCallInst(CallInst & CI)1156 bool ScalarizerVisitor::visitCallInst(CallInst &CI) {
1157   return splitCall(CI);
1158 }
1159 
visitFreezeInst(FreezeInst & FI)1160 bool ScalarizerVisitor::visitFreezeInst(FreezeInst &FI) {
1161   return splitUnary(FI, [](IRBuilder<> &Builder, Value *Op, const Twine &Name) {
1162     return Builder.CreateFreeze(Op, Name);
1163   });
1164 }
1165 
1166 // Delete the instructions that we scalarized.  If a full vector result
1167 // is still needed, recreate it using InsertElements.
finish()1168 bool ScalarizerVisitor::finish() {
1169   // The presence of data in Gathered or Scattered indicates changes
1170   // made to the Function.
1171   if (Gathered.empty() && Scattered.empty() && !Scalarized)
1172     return false;
1173   for (const auto &GMI : Gathered) {
1174     Instruction *Op = GMI.first;
1175     ValueVector &CV = *GMI.second;
1176     if (!Op->use_empty()) {
1177       // The value is still needed, so recreate it using a series of
1178       // insertelements and/or shufflevectors.
1179       Value *Res;
1180       if (auto *Ty = dyn_cast<FixedVectorType>(Op->getType())) {
1181         BasicBlock *BB = Op->getParent();
1182         IRBuilder<> Builder(Op);
1183         if (isa<PHINode>(Op))
1184           Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
1185 
1186         VectorSplit VS = *getVectorSplit(Ty);
1187         assert(VS.NumFragments == CV.size());
1188 
1189         Res = concatenate(Builder, CV, VS, Op->getName());
1190 
1191         Res->takeName(Op);
1192       } else {
1193         assert(CV.size() == 1 && Op->getType() == CV[0]->getType());
1194         Res = CV[0];
1195         if (Op == Res)
1196           continue;
1197       }
1198       Op->replaceAllUsesWith(Res);
1199     }
1200     PotentiallyDeadInstrs.emplace_back(Op);
1201   }
1202   Gathered.clear();
1203   Scattered.clear();
1204   Scalarized = false;
1205 
1206   RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
1207 
1208   return true;
1209 }
1210 
run(Function & F,FunctionAnalysisManager & AM)1211 PreservedAnalyses ScalarizerPass::run(Function &F, FunctionAnalysisManager &AM) {
1212   DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
1213   ScalarizerVisitor Impl(DT, Options);
1214   bool Changed = Impl.visit(F);
1215   PreservedAnalyses PA;
1216   PA.preserve<DominatorTreeAnalysis>();
1217   return Changed ? PA : PreservedAnalyses::all();
1218 }
1219