xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Scalar/Scalarizer.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- Scalarizer.cpp - Scalarize vector operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass converts vector operations into scalar operations (or, optionally,
10 // operations on smaller vector widths), in order to expose optimization
11 // opportunities on the individual scalar operations.
12 // It is mainly intended for targets that do not have vector units, but it
13 // may also be useful for revectorizing code to different vector widths.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/Transforms/Scalar/Scalarizer.h"
18 #include "llvm/ADT/PostOrderIterator.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/Twine.h"
21 #include "llvm/Analysis/TargetTransformInfo.h"
22 #include "llvm/Analysis/VectorUtils.h"
23 #include "llvm/IR/Argument.h"
24 #include "llvm/IR/BasicBlock.h"
25 #include "llvm/IR/Constants.h"
26 #include "llvm/IR/DataLayout.h"
27 #include "llvm/IR/DerivedTypes.h"
28 #include "llvm/IR/Dominators.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/InstVisitor.h"
32 #include "llvm/IR/InstrTypes.h"
33 #include "llvm/IR/Instruction.h"
34 #include "llvm/IR/Instructions.h"
35 #include "llvm/IR/Intrinsics.h"
36 #include "llvm/IR/LLVMContext.h"
37 #include "llvm/IR/Module.h"
38 #include "llvm/IR/Type.h"
39 #include "llvm/IR/Value.h"
40 #include "llvm/InitializePasses.h"
41 #include "llvm/Support/Casting.h"
42 #include "llvm/Transforms/Utils/Local.h"
43 #include <cassert>
44 #include <cstdint>
45 #include <iterator>
46 #include <map>
47 #include <utility>
48 
49 using namespace llvm;
50 
51 #define DEBUG_TYPE "scalarizer"
52 
53 namespace {
54 
skipPastPhiNodesAndDbg(BasicBlock::iterator Itr)55 BasicBlock::iterator skipPastPhiNodesAndDbg(BasicBlock::iterator Itr) {
56   BasicBlock *BB = Itr->getParent();
57   if (isa<PHINode>(Itr))
58     Itr = BB->getFirstInsertionPt();
59   if (Itr != BB->end())
60     Itr = skipDebugIntrinsics(Itr);
61   return Itr;
62 }
63 
// Used to store the scattered form of a vector: one Value per fragment.
using ValueVector = SmallVector<Value *, 8>;

// Used to map a vector Value and associated type to its scattered form.
// The associated type is only non-null for pointer values that are "scattered"
// when used as pointer operands to load or store.
//
// We use std::map because we want iterators to persist across insertion and
// because the values are relatively large.
using ScatterMap = std::map<std::pair<Value *, Type *>, ValueVector>;

// Lists Instructions that have been replaced with scalar implementations,
// along with a pointer to their scattered forms.
using GatherList = SmallVector<std::pair<Instruction *, ValueVector *>, 16>;
78 
// Describes how a fixed-width vector type is decomposed into fragments.
struct VectorSplit {
  // The type of the vector.
  FixedVectorType *VecTy = nullptr;

  // The number of elements packed in a fragment (other than the remainder).
  unsigned NumPacked = 0;

  // The number of fragments (scalars or smaller vectors) into which the vector
  // shall be split.
  unsigned NumFragments = 0;

  // The type of each complete fragment.
  Type *SplitTy = nullptr;

  // The type of the remainder (last) fragment; null if all fragments are
  // complete.
  Type *RemainderTy = nullptr;

  // Return the type of fragment I: the remainder type for the last fragment
  // when one exists, otherwise the common fragment type.
  Type *getFragmentType(unsigned I) const {
    return RemainderTy && I == NumFragments - 1 ? RemainderTy : SplitTy;
  }
};
101 
// Provides a very limited vector-like interface for lazily accessing one
// component of a scattered vector or vector pointer.
class Scatterer {
public:
  Scatterer() = default;

  // Scatter V into Size components.  If new instructions are needed,
  // insert them before BBI in BB.  If Cache is nonnull, use it to cache
  // the results.
  Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
            const VectorSplit &VS, ValueVector *cachePtr = nullptr);

  // Return component I, creating a new Value for it if necessary.
  Value *operator[](unsigned I);

  // Return the number of components.
  unsigned size() const { return VS.NumFragments; }

private:
  BasicBlock *BB;           // Block in which new instructions are inserted.
  BasicBlock::iterator BBI; // Insertion point within BB.
  Value *V;                 // The vector (or vector pointer) being scattered.
  VectorSplit VS;           // How V's type is split into fragments.
  bool IsPointer;           // True if V has pointer type.
  ValueVector *CachePtr;    // Shared fragment cache; may be null.
  ValueVector Tmp;          // Local fragment cache used when CachePtr is null.
};
129 
130 // FCmpSplitter(FCI)(Builder, X, Y, Name) uses Builder to create an FCmp
131 // called Name that compares X and Y in the same way as FCI.
132 struct FCmpSplitter {
FCmpSplitter__anon7628c1430111::FCmpSplitter133   FCmpSplitter(FCmpInst &fci) : FCI(fci) {}
134 
operator ()__anon7628c1430111::FCmpSplitter135   Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
136                     const Twine &Name) const {
137     return Builder.CreateFCmp(FCI.getPredicate(), Op0, Op1, Name);
138   }
139 
140   FCmpInst &FCI;
141 };
142 
143 // ICmpSplitter(ICI)(Builder, X, Y, Name) uses Builder to create an ICmp
144 // called Name that compares X and Y in the same way as ICI.
145 struct ICmpSplitter {
ICmpSplitter__anon7628c1430111::ICmpSplitter146   ICmpSplitter(ICmpInst &ici) : ICI(ici) {}
147 
operator ()__anon7628c1430111::ICmpSplitter148   Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
149                     const Twine &Name) const {
150     return Builder.CreateICmp(ICI.getPredicate(), Op0, Op1, Name);
151   }
152 
153   ICmpInst &ICI;
154 };
155 
156 // UnarySplitter(UO)(Builder, X, Name) uses Builder to create
157 // a unary operator like UO called Name with operand X.
158 struct UnarySplitter {
UnarySplitter__anon7628c1430111::UnarySplitter159   UnarySplitter(UnaryOperator &uo) : UO(uo) {}
160 
operator ()__anon7628c1430111::UnarySplitter161   Value *operator()(IRBuilder<> &Builder, Value *Op, const Twine &Name) const {
162     return Builder.CreateUnOp(UO.getOpcode(), Op, Name);
163   }
164 
165   UnaryOperator &UO;
166 };
167 
168 // BinarySplitter(BO)(Builder, X, Y, Name) uses Builder to create
169 // a binary operator like BO called Name with operands X and Y.
170 struct BinarySplitter {
BinarySplitter__anon7628c1430111::BinarySplitter171   BinarySplitter(BinaryOperator &bo) : BO(bo) {}
172 
operator ()__anon7628c1430111::BinarySplitter173   Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
174                     const Twine &Name) const {
175     return Builder.CreateBinOp(BO.getOpcode(), Op0, Op1, Name);
176   }
177 
178   BinaryOperator &BO;
179 };
180 
// Information about a load or store that we're scalarizing.
struct VectorLayout {
  VectorLayout() = default;

  // Return the alignment of fragment Frag, derived from the whole-vector
  // alignment and the fragment's byte offset within the vector.
  Align getFragmentAlign(unsigned Frag) {
    return commonAlignment(VecAlign, Frag * SplitSize);
  }

  // The split of the underlying vector type.
  VectorSplit VS;

  // The alignment of the vector.
  Align VecAlign;

  // The size of each (non-remainder) fragment in bytes.
  uint64_t SplitSize = 0;
};
199 
isStructOfMatchingFixedVectors(Type * Ty)200 static bool isStructOfMatchingFixedVectors(Type *Ty) {
201   if (!isa<StructType>(Ty))
202     return false;
203   unsigned StructSize = Ty->getNumContainedTypes();
204   if (StructSize < 1)
205     return false;
206   FixedVectorType *VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType(0));
207   if (!VecTy)
208     return false;
209   unsigned VecSize = VecTy->getNumElements();
210   for (unsigned I = 1; I < StructSize; I++) {
211     VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType(I));
212     if (!VecTy || VecSize != VecTy->getNumElements())
213       return false;
214   }
215   return true;
216 }
217 
/// Concatenate the given fragments to a single vector value of the type
/// described in @p VS.
static Value *concatenate(IRBuilder<> &Builder, ArrayRef<Value *> Fragments,
                          const VectorSplit &VS, Twine Name) {
  unsigned NumElements = VS.VecTy->getNumElements();
  SmallVector<int> ExtendMask;
  SmallVector<int> InsertMask;

  if (VS.NumPacked > 1) {
    // Prepare the shufflevector masks once and re-use them for all
    // fragments.
    // ExtendMask widens a fragment to the full result width, padding the
    // tail with poison lanes (-1).
    ExtendMask.resize(NumElements, -1);
    for (unsigned I = 0; I < VS.NumPacked; ++I)
      ExtendMask[I] = I;

    // InsertMask starts out as the identity permutation; the lanes covered
    // by each fragment are temporarily redirected below.
    InsertMask.resize(NumElements);
    for (unsigned I = 0; I < NumElements; ++I)
      InsertMask[I] = I;
  }

  Value *Res = PoisonValue::get(VS.VecTy);
  for (unsigned I = 0; I < VS.NumFragments; ++I) {
    Value *Fragment = Fragments[I];

    unsigned NumPacked = VS.NumPacked;
    if (I == VS.NumFragments - 1 && VS.RemainderTy) {
      // The last fragment may be narrower than the others.
      if (auto *RemVecTy = dyn_cast<FixedVectorType>(VS.RemainderTy))
        NumPacked = RemVecTy->getNumElements();
      else
        NumPacked = 1;
    }

    if (NumPacked == 1) {
      // Scalar fragment: a plain insertelement is enough.
      Res = Builder.CreateInsertElement(Res, Fragment, I * VS.NumPacked,
                                        Name + ".upto" + Twine(I));
    } else {
      // Vector fragment: widen it to the result width, then blend it into
      // the accumulated result with a two-operand shuffle.
      Fragment = Builder.CreateShuffleVector(Fragment, Fragment, ExtendMask);
      if (I == 0) {
        Res = Fragment;
      } else {
        // Point this fragment's lanes at the second shuffle operand.
        for (unsigned J = 0; J < NumPacked; ++J)
          InsertMask[I * VS.NumPacked + J] = NumElements + J;
        Res = Builder.CreateShuffleVector(Res, Fragment, InsertMask,
                                          Name + ".upto" + Twine(I));
        // Restore the identity mapping for the next iteration.
        for (unsigned J = 0; J < NumPacked; ++J)
          InsertMask[I * VS.NumPacked + J] = I * VS.NumPacked + J;
      }
    }
  }

  return Res;
}
270 
// Walks a function and rewrites vector operations into per-fragment scalar
// (or narrower-vector) operations, as configured by ScalarizerPassOptions.
class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
public:
  ScalarizerVisitor(DominatorTree *DT, const TargetTransformInfo *TTI,
                    ScalarizerPassOptions Options)
      : DT(DT), TTI(TTI),
        ScalarizeVariableInsertExtract(Options.ScalarizeVariableInsertExtract),
        ScalarizeLoadStore(Options.ScalarizeLoadStore),
        ScalarizeMinBits(Options.ScalarizeMinBits) {}

  // Visit every instruction of F; returns true if anything changed.
  bool visit(Function &F);

  // InstVisitor methods.  They return true if the instruction was scalarized,
  // false if nothing changed.
  bool visitInstruction(Instruction &I) { return false; }
  bool visitSelectInst(SelectInst &SI);
  bool visitICmpInst(ICmpInst &ICI);
  bool visitFCmpInst(FCmpInst &FCI);
  bool visitUnaryOperator(UnaryOperator &UO);
  bool visitBinaryOperator(BinaryOperator &BO);
  bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
  bool visitCastInst(CastInst &CI);
  bool visitBitCastInst(BitCastInst &BCI);
  bool visitInsertElementInst(InsertElementInst &IEI);
  bool visitExtractElementInst(ExtractElementInst &EEI);
  bool visitExtractValueInst(ExtractValueInst &EVI);
  bool visitShuffleVectorInst(ShuffleVectorInst &SVI);
  bool visitPHINode(PHINode &PHI);
  bool visitLoadInst(LoadInst &LI);
  bool visitStoreInst(StoreInst &SI);
  bool visitCallInst(CallInst &ICI);
  bool visitFreezeInst(FreezeInst &FI);

private:
  Scatterer scatter(Instruction *Point, Value *V, const VectorSplit &VS);
  void gather(Instruction *Op, const ValueVector &CV, const VectorSplit &VS);
  void replaceUses(Instruction *Op, Value *CV);
  bool canTransferMetadata(unsigned Kind);
  void transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV);
  std::optional<VectorSplit> getVectorSplit(Type *Ty);
  std::optional<VectorLayout> getVectorLayout(Type *Ty, Align Alignment,
                                              const DataLayout &DL);
  bool finish();

  template<typename T> bool splitUnary(Instruction &, const T &);
  template<typename T> bool splitBinary(Instruction &, const T &);

  bool splitCall(CallInst &CI);

  ScatterMap Scattered;   // Scattered forms created so far.
  GatherList Gathered;    // Instructions awaiting gather at finish().
  bool Scalarized;        // Whether any change has been made.

  SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;

  DominatorTree *DT;
  const TargetTransformInfo *TTI;

  // Cached copies of the pass options.
  const bool ScalarizeVariableInsertExtract;
  const bool ScalarizeLoadStore;
  const unsigned ScalarizeMinBits;
};
332 
// Legacy pass manager wrapper around ScalarizerVisitor.
class ScalarizerLegacyPass : public FunctionPass {
public:
  static char ID; // Pass identification.
  ScalarizerPassOptions Options;
  ScalarizerLegacyPass() : FunctionPass(ID), Options() {}
  ScalarizerLegacyPass(const ScalarizerPassOptions &Options);
  bool runOnFunction(Function &F) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;
};
342 
343 } // end anonymous namespace
344 
// Legacy-PM constructor carrying user-supplied scalarization options.
ScalarizerLegacyPass::ScalarizerLegacyPass(const ScalarizerPassOptions &Options)
    : FunctionPass(ID), Options(Options) {}
347 
getAnalysisUsage(AnalysisUsage & AU) const348 void ScalarizerLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const {
349   AU.addRequired<DominatorTreeWrapperPass>();
350   AU.addRequired<TargetTransformInfoWrapperPass>();
351   AU.addPreserved<DominatorTreeWrapperPass>();
352 }
353 
char ScalarizerLegacyPass::ID = 0;
// Register the legacy pass (command-line name "scalarizer") and its
// DominatorTree dependency with the pass registry.
INITIALIZE_PASS_BEGIN(ScalarizerLegacyPass, "scalarizer",
                      "Scalarize vector operations", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ScalarizerLegacyPass, "scalarizer",
                    "Scalarize vector operations", false, false)
360 
361 Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
362                      const VectorSplit &VS, ValueVector *cachePtr)
363     : BB(bb), BBI(bbi), V(v), VS(VS), CachePtr(cachePtr) {
364   IsPointer = V->getType()->isPointerTy();
365   if (!CachePtr) {
366     Tmp.resize(VS.NumFragments, nullptr);
367   } else {
368     assert((CachePtr->empty() || VS.NumFragments == CachePtr->size() ||
369             IsPointer) &&
370            "Inconsistent vector sizes");
371     if (VS.NumFragments > CachePtr->size())
372       CachePtr->resize(VS.NumFragments, nullptr);
373   }
374 }
375 
// Return fragment Frag, creating a new Value for it if necessary.
Value *Scatterer::operator[](unsigned Frag) {
  ValueVector &CV = CachePtr ? *CachePtr : Tmp;
  // Try to reuse a previous value.
  if (CV[Frag])
    return CV[Frag];
  IRBuilder<> Builder(BB, BBI);
  if (IsPointer) {
    // Pointers are "scattered" by offsetting: fragment 0 is the pointer
    // itself, fragment N is a GEP of N fragments past it.
    if (Frag == 0)
      CV[Frag] = V;
    else
      CV[Frag] = Builder.CreateConstGEP1_32(VS.SplitTy, V, Frag,
                                            V->getName() + ".i" + Twine(Frag));
    return CV[Frag];
  }

  Type *FragmentTy = VS.getFragmentType(Frag);

  if (auto *VecTy = dyn_cast<FixedVectorType>(FragmentTy)) {
    // Vector fragment: extract the covered lanes with a shuffle.
    SmallVector<int> Mask;
    for (unsigned J = 0; J < VecTy->getNumElements(); ++J)
      Mask.push_back(Frag * VS.NumPacked + J);
    CV[Frag] =
        Builder.CreateShuffleVector(V, PoisonValue::get(V->getType()), Mask,
                                    V->getName() + ".i" + Twine(Frag));
  } else {
    // Search through a chain of InsertElementInsts looking for element Frag.
    // Record other elements in the cache.  The new V is still suitable
    // for all uncached indices.
    while (true) {
      InsertElementInst *Insert = dyn_cast<InsertElementInst>(V);
      if (!Insert)
        break;
      ConstantInt *Idx = dyn_cast<ConstantInt>(Insert->getOperand(2));
      if (!Idx)
        break;
      unsigned J = Idx->getZExtValue();
      V = Insert->getOperand(0);
      if (Frag * VS.NumPacked == J) {
        // Found the inserted scalar for the requested element.
        CV[Frag] = Insert->getOperand(1);
        return CV[Frag];
      }

      if (VS.NumPacked == 1 && !CV[J]) {
        // Only cache the first entry we find for each index we're not actively
        // searching for. This prevents us from going too far up the chain and
        // caching incorrect entries.
        CV[J] = Insert->getOperand(1);
      }
    }
    // No insert found; fall back to an explicit extractelement.
    CV[Frag] = Builder.CreateExtractElement(V, Frag * VS.NumPacked,
                                            V->getName() + ".i" + Twine(Frag));
  }

  return CV[Frag];
}
432 
runOnFunction(Function & F)433 bool ScalarizerLegacyPass::runOnFunction(Function &F) {
434   if (skipFunction(F))
435     return false;
436 
437   DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
438   const TargetTransformInfo *TTI =
439       &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
440   ScalarizerVisitor Impl(DT, TTI, Options);
441   return Impl.visit(F);
442 }
443 
// Factory used to create the legacy-PM scalarizer pass with custom options.
FunctionPass *llvm::createScalarizerPass(const ScalarizerPassOptions &Options) {
  return new ScalarizerLegacyPass(Options);
}
447 
// Visit every instruction in F in reverse post-order; returns true if the
// function was modified.
bool ScalarizerVisitor::visit(Function &F) {
  assert(Gathered.empty() && Scattered.empty());

  Scalarized = false;

  // To ensure we replace gathered components correctly we need to do an ordered
  // traversal of the basic blocks in the function.
  ReversePostOrderTraversal<BasicBlock *> RPOT(&F.getEntryBlock());
  for (BasicBlock *BB : RPOT) {
    for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
      Instruction *I = &*II;
      bool Done = InstVisitor::visit(I);
      // Advance before the possible erase below invalidates I's position.
      ++II;
      if (Done && I->getType()->isVoidTy())
        I->eraseFromParent();
    }
  }
  // Emit gathered forms and clean up dead instructions.
  return finish();
}
467 
// Return a scattered form of V that can be accessed by Point.  V must be a
// vector or a pointer to a vector.
Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V,
                                     const VectorSplit &VS) {
  if (Argument *VArg = dyn_cast<Argument>(V)) {
    // Put the scattered form of arguments in the entry block,
    // so that it can be used everywhere.
    Function *F = VArg->getParent();
    BasicBlock *BB = &F->getEntryBlock();
    return Scatterer(BB, BB->begin(), V, VS, &Scattered[{V, VS.SplitTy}]);
  }
  if (Instruction *VOp = dyn_cast<Instruction>(V)) {
    // When scalarizing PHI nodes we might try to examine/rewrite InsertElement
    // nodes in predecessors. If those predecessors are unreachable from entry,
    // then the IR in those blocks could have unexpected properties resulting in
    // infinite loops in Scatterer::operator[]. By simply treating values
    // originating from instructions in unreachable blocks as undef we do not
    // need to analyse them further.
    if (!DT->isReachableFromEntry(VOp->getParent()))
      return Scatterer(Point->getParent(), Point->getIterator(),
                       PoisonValue::get(V->getType()), VS);
    // Put the scattered form of an instruction directly after the
    // instruction, skipping over PHI nodes and debug intrinsics.
    BasicBlock *BB = VOp->getParent();
    return Scatterer(
        BB, skipPastPhiNodesAndDbg(std::next(BasicBlock::iterator(VOp))), V, VS,
        &Scattered[{V, VS.SplitTy}]);
  }
  // In the fallback case (e.g. constants), just put the scattered form before
  // Point and keep the result local to Point (no shared cache entry).
  return Scatterer(Point->getParent(), Point->getIterator(), V, VS);
}
500 
// Replace Op with the gathered form of the components in CV.  Defer the
// deletion of Op and creation of the gathered form to the end of the pass,
// so that we can avoid creating the gathered form if all uses of Op are
// replaced with uses of CV.
void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV,
                               const VectorSplit &VS) {
  transferMetadataAndIRFlags(Op, CV);

  // If we already have a scattered form of Op (created from ExtractElements
  // of Op itself), replace them with the new form.
  ValueVector &SV = Scattered[{Op, VS.SplitTy}];
  if (!SV.empty()) {
    for (unsigned I = 0, E = SV.size(); I != E; ++I) {
      Value *V = SV[I];
      // Skip fragments that were never materialized or already match.
      if (V == nullptr || SV[I] == CV[I])
        continue;

      Instruction *Old = cast<Instruction>(V);
      if (isa<Instruction>(CV[I]))
        CV[I]->takeName(Old);
      Old->replaceAllUsesWith(CV[I]);
      PotentiallyDeadInstrs.emplace_back(Old);
    }
  }
  SV = CV;
  // Record Op for processing in finish().
  Gathered.push_back(GatherList::value_type(Op, &SV));
}
528 
529 // Replace Op with CV and collect Op has a potentially dead instruction.
replaceUses(Instruction * Op,Value * CV)530 void ScalarizerVisitor::replaceUses(Instruction *Op, Value *CV) {
531   if (CV != Op) {
532     Op->replaceAllUsesWith(CV);
533     PotentiallyDeadInstrs.emplace_back(Op);
534     Scalarized = true;
535   }
536 }
537 
538 // Return true if it is safe to transfer the given metadata tag from
539 // vector to scalar instructions.
canTransferMetadata(unsigned Tag)540 bool ScalarizerVisitor::canTransferMetadata(unsigned Tag) {
541   return (Tag == LLVMContext::MD_tbaa
542           || Tag == LLVMContext::MD_fpmath
543           || Tag == LLVMContext::MD_tbaa_struct
544           || Tag == LLVMContext::MD_invariant_load
545           || Tag == LLVMContext::MD_alias_scope
546           || Tag == LLVMContext::MD_noalias
547           || Tag == LLVMContext::MD_mem_parallel_loop_access
548           || Tag == LLVMContext::MD_access_group);
549 }
550 
551 // Transfer metadata from Op to the instructions in CV if it is known
552 // to be safe to do so.
transferMetadataAndIRFlags(Instruction * Op,const ValueVector & CV)553 void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op,
554                                                    const ValueVector &CV) {
555   SmallVector<std::pair<unsigned, MDNode *>, 4> MDs;
556   Op->getAllMetadataOtherThanDebugLoc(MDs);
557   for (Value *V : CV) {
558     if (Instruction *New = dyn_cast<Instruction>(V)) {
559       for (const auto &MD : MDs)
560         if (canTransferMetadata(MD.first))
561           New->setMetadata(MD.first, MD.second);
562       New->copyIRFlags(Op);
563       if (Op->getDebugLoc() && !New->getDebugLoc())
564         New->setDebugLoc(Op->getDebugLoc());
565     }
566   }
567 }
568 
// Determine how Ty is split, if at all.  Returns std::nullopt for
// non-fixed-vector types and for vectors that would not actually be split.
std::optional<VectorSplit> ScalarizerVisitor::getVectorSplit(Type *Ty) {
  VectorSplit Split;
  Split.VecTy = dyn_cast<FixedVectorType>(Ty);
  if (!Split.VecTy)
    return {};

  unsigned NumElems = Split.VecTy->getNumElements();
  Type *ElemTy = Split.VecTy->getElementType();

  // Fully scalarize when two elements would already exceed ScalarizeMinBits
  // (pointers and single-element vectors always scalarize fully).
  if (NumElems == 1 || ElemTy->isPointerTy() ||
      2 * ElemTy->getScalarSizeInBits() > ScalarizeMinBits) {
    Split.NumPacked = 1;
    Split.NumFragments = NumElems;
    Split.SplitTy = ElemTy;
  } else {
    // Otherwise pack as many elements per fragment as fit in
    // ScalarizeMinBits.
    Split.NumPacked = ScalarizeMinBits / ElemTy->getScalarSizeInBits();
    if (Split.NumPacked >= NumElems)
      return {};

    Split.NumFragments = divideCeil(NumElems, Split.NumPacked);
    Split.SplitTy = FixedVectorType::get(ElemTy, Split.NumPacked);

    // The trailing fragment may cover fewer elements.
    unsigned RemainderElems = NumElems % Split.NumPacked;
    if (RemainderElems > 1)
      Split.RemainderTy = FixedVectorType::get(ElemTy, RemainderElems);
    else if (RemainderElems == 1)
      Split.RemainderTy = ElemTy;
  }

  return Split;
}
601 
602 // Try to fill in Layout from Ty, returning true on success.  Alignment is
603 // the alignment of the vector, or std::nullopt if the ABI default should be
604 // used.
605 std::optional<VectorLayout>
getVectorLayout(Type * Ty,Align Alignment,const DataLayout & DL)606 ScalarizerVisitor::getVectorLayout(Type *Ty, Align Alignment,
607                                    const DataLayout &DL) {
608   std::optional<VectorSplit> VS = getVectorSplit(Ty);
609   if (!VS)
610     return {};
611 
612   VectorLayout Layout;
613   Layout.VS = *VS;
614   // Check that we're dealing with full-byte fragments.
615   if (!DL.typeSizeEqualsStoreSize(VS->SplitTy) ||
616       (VS->RemainderTy && !DL.typeSizeEqualsStoreSize(VS->RemainderTy)))
617     return {};
618   Layout.VecAlign = Alignment;
619   Layout.SplitSize = DL.getTypeStoreSize(VS->SplitTy);
620   return Layout;
621 }
622 
623 // Scalarize one-operand instruction I, using Split(Builder, X, Name)
624 // to create an instruction like I with operand X and name Name.
625 template<typename Splitter>
splitUnary(Instruction & I,const Splitter & Split)626 bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) {
627   std::optional<VectorSplit> VS = getVectorSplit(I.getType());
628   if (!VS)
629     return false;
630 
631   std::optional<VectorSplit> OpVS;
632   if (I.getOperand(0)->getType() == I.getType()) {
633     OpVS = VS;
634   } else {
635     OpVS = getVectorSplit(I.getOperand(0)->getType());
636     if (!OpVS || VS->NumPacked != OpVS->NumPacked)
637       return false;
638   }
639 
640   IRBuilder<> Builder(&I);
641   Scatterer Op = scatter(&I, I.getOperand(0), *OpVS);
642   assert(Op.size() == VS->NumFragments && "Mismatched unary operation");
643   ValueVector Res;
644   Res.resize(VS->NumFragments);
645   for (unsigned Frag = 0; Frag < VS->NumFragments; ++Frag)
646     Res[Frag] = Split(Builder, Op[Frag], I.getName() + ".i" + Twine(Frag));
647   gather(&I, Res, *VS);
648   return true;
649 }
650 
651 // Scalarize two-operand instruction I, using Split(Builder, X, Y, Name)
652 // to create an instruction like I with operands X and Y and name Name.
653 template<typename Splitter>
splitBinary(Instruction & I,const Splitter & Split)654 bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
655   std::optional<VectorSplit> VS = getVectorSplit(I.getType());
656   if (!VS)
657     return false;
658 
659   std::optional<VectorSplit> OpVS;
660   if (I.getOperand(0)->getType() == I.getType()) {
661     OpVS = VS;
662   } else {
663     OpVS = getVectorSplit(I.getOperand(0)->getType());
664     if (!OpVS || VS->NumPacked != OpVS->NumPacked)
665       return false;
666   }
667 
668   IRBuilder<> Builder(&I);
669   Scatterer VOp0 = scatter(&I, I.getOperand(0), *OpVS);
670   Scatterer VOp1 = scatter(&I, I.getOperand(1), *OpVS);
671   assert(VOp0.size() == VS->NumFragments && "Mismatched binary operation");
672   assert(VOp1.size() == VS->NumFragments && "Mismatched binary operation");
673   ValueVector Res;
674   Res.resize(VS->NumFragments);
675   for (unsigned Frag = 0; Frag < VS->NumFragments; ++Frag) {
676     Value *Op0 = VOp0[Frag];
677     Value *Op1 = VOp1[Frag];
678     Res[Frag] = Split(Builder, Op0, Op1, I.getName() + ".i" + Twine(Frag));
679   }
680   gather(&I, Res, *VS);
681   return true;
682 }
683 
/// If a call to a vector typed intrinsic function, split into a scalar call per
/// element if possible for the intrinsic.
bool ScalarizerVisitor::splitCall(CallInst &CI) {
  Type *CallType = CI.getType();
  bool AreAllVectorsOfMatchingSize = isStructOfMatchingFixedVectors(CallType);
  std::optional<VectorSplit> VS;
  if (AreAllVectorsOfMatchingSize)
    VS = getVectorSplit(CallType->getContainedType(0));
  else
    VS = getVectorSplit(CallType);
  if (!VS)
    return false;

  // Only direct intrinsic calls can be scalarized.
  Function *F = CI.getCalledFunction();
  if (!F)
    return false;

  Intrinsic::ID ID = F->getIntrinsicID();

  if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID, TTI))
    return false;

  unsigned NumArgs = CI.arg_size();

  // Per-argument state: scalar operands are kept as-is, vector operands are
  // scattered; OverloadIdx maps an argument to its slot in Tys (or -1).
  ValueVector ScalarOperands(NumArgs);
  SmallVector<Scatterer, 8> Scattered(NumArgs);
  SmallVector<int> OverloadIdx(NumArgs, -1);

  SmallVector<llvm::Type *, 3> Tys;
  // Add return type if intrinsic is overloaded on it.
  if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1, TTI))
    Tys.push_back(VS->SplitTy);

  if (AreAllVectorsOfMatchingSize) {
    for (unsigned I = 1; I < CallType->getNumContainedTypes(); I++) {
      std::optional<VectorSplit> CurrVS =
          getVectorSplit(cast<FixedVectorType>(CallType->getContainedType(I)));
      // It is possible for VectorSplit.NumPacked >= NumElems. If that happens a
      // VectorSplit is not returned and we will bailout of handling this call.
      // The secondary bailout case is if NumPacked does not match. This can
      // happen if ScalarizeMinBits is not set to the default. This means with
      // certain ScalarizeMinBits intrinsics like frexp will only scalarize when
      // the struct elements have the same bitness.
      if (!CurrVS || CurrVS->NumPacked != VS->NumPacked)
        return false;
      if (isVectorIntrinsicWithStructReturnOverloadAtField(ID, I, TTI))
        Tys.push_back(CurrVS->SplitTy);
    }
  }
  // Assumes that any vector type has the same number of elements as the return
  // vector type, which is true for all current intrinsics.
  for (unsigned I = 0; I != NumArgs; ++I) {
    Value *OpI = CI.getOperand(I);
    if ([[maybe_unused]] auto *OpVecTy =
            dyn_cast<FixedVectorType>(OpI->getType())) {
      assert(OpVecTy->getNumElements() == VS->VecTy->getNumElements());
      std::optional<VectorSplit> OpVS = getVectorSplit(OpI->getType());
      if (!OpVS || OpVS->NumPacked != VS->NumPacked) {
        // The natural split of the operand doesn't match the result. This could
        // happen if the vector elements are different and the ScalarizeMinBits
        // option is used.
        //
        // We could in principle handle this case as well, at the cost of
        // complicating the scattering machinery to support multiple scattering
        // granularities for a single value.
        return false;
      }

      Scattered[I] = scatter(&CI, OpI, *OpVS);
      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI)) {
        OverloadIdx[I] = Tys.size();
        Tys.push_back(OpVS->SplitTy);
      }
    } else {
      ScalarOperands[I] = OpI;
      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI))
        Tys.push_back(OpI->getType());
    }
  }

  ValueVector Res(VS->NumFragments);
  ValueVector ScalarCallOps(NumArgs);

  Function *NewIntrin =
      Intrinsic::getOrInsertDeclaration(F->getParent(), ID, Tys);
  IRBuilder<> Builder(&CI);

  // Perform actual scalarization, taking care to preserve any scalar operands.
  for (unsigned I = 0; I < VS->NumFragments; ++I) {
    bool IsRemainder = I == VS->NumFragments - 1 && VS->RemainderTy;
    ScalarCallOps.clear();

    // NOTE(review): this assumes the overloaded return type, when present,
    // occupies Tys[0] -- confirm a remainder can't occur with an empty Tys.
    if (IsRemainder)
      Tys[0] = VS->RemainderTy;

    for (unsigned J = 0; J != NumArgs; ++J) {
      if (isVectorIntrinsicWithScalarOpAtArg(ID, J, TTI)) {
        ScalarCallOps.push_back(ScalarOperands[J]);
      } else {
        ScalarCallOps.push_back(Scattered[J][I]);
        // Overloaded operand slots must track the remainder fragment's type.
        if (IsRemainder && OverloadIdx[J] >= 0)
          Tys[OverloadIdx[J]] = Scattered[J][I]->getType();
      }
    }

    // The remainder fragment needs a declaration with the adjusted types.
    if (IsRemainder)
      NewIntrin = Intrinsic::getOrInsertDeclaration(F->getParent(), ID, Tys);

    Res[I] = Builder.CreateCall(NewIntrin, ScalarCallOps,
                                CI.getName() + ".i" + Twine(I));
  }

  gather(&CI, Res, *VS);
  return true;
}
800 
visitSelectInst(SelectInst & SI)801 bool ScalarizerVisitor::visitSelectInst(SelectInst &SI) {
802   std::optional<VectorSplit> VS = getVectorSplit(SI.getType());
803   if (!VS)
804     return false;
805 
806   std::optional<VectorSplit> CondVS;
807   if (isa<FixedVectorType>(SI.getCondition()->getType())) {
808     CondVS = getVectorSplit(SI.getCondition()->getType());
809     if (!CondVS || CondVS->NumPacked != VS->NumPacked) {
810       // This happens when ScalarizeMinBits is used.
811       return false;
812     }
813   }
814 
815   IRBuilder<> Builder(&SI);
816   Scatterer VOp1 = scatter(&SI, SI.getOperand(1), *VS);
817   Scatterer VOp2 = scatter(&SI, SI.getOperand(2), *VS);
818   assert(VOp1.size() == VS->NumFragments && "Mismatched select");
819   assert(VOp2.size() == VS->NumFragments && "Mismatched select");
820   ValueVector Res;
821   Res.resize(VS->NumFragments);
822 
823   if (CondVS) {
824     Scatterer VOp0 = scatter(&SI, SI.getOperand(0), *CondVS);
825     assert(VOp0.size() == CondVS->NumFragments && "Mismatched select");
826     for (unsigned I = 0; I < VS->NumFragments; ++I) {
827       Value *Op0 = VOp0[I];
828       Value *Op1 = VOp1[I];
829       Value *Op2 = VOp2[I];
830       Res[I] = Builder.CreateSelect(Op0, Op1, Op2,
831                                     SI.getName() + ".i" + Twine(I));
832     }
833   } else {
834     Value *Op0 = SI.getOperand(0);
835     for (unsigned I = 0; I < VS->NumFragments; ++I) {
836       Value *Op1 = VOp1[I];
837       Value *Op2 = VOp2[I];
838       Res[I] = Builder.CreateSelect(Op0, Op1, Op2,
839                                     SI.getName() + ".i" + Twine(I));
840     }
841   }
842   gather(&SI, Res, *VS);
843   return true;
844 }
845 
// Scalarize an integer compare by splitting it into per-fragment compares.
bool ScalarizerVisitor::visitICmpInst(ICmpInst &ICI) {
  return splitBinary(ICI, ICmpSplitter(ICI));
}
849 
// Scalarize a floating-point compare by splitting it into per-fragment
// compares.
bool ScalarizerVisitor::visitFCmpInst(FCmpInst &FCI) {
  return splitBinary(FCI, FCmpSplitter(FCI));
}
853 
// Scalarize a unary operator (e.g. fneg) fragment by fragment.
bool ScalarizerVisitor::visitUnaryOperator(UnaryOperator &UO) {
  return splitUnary(UO, UnarySplitter(UO));
}
857 
// Scalarize a binary operator fragment by fragment.
bool ScalarizerVisitor::visitBinaryOperator(BinaryOperator &BO) {
  return splitBinary(BO, BinarySplitter(BO));
}
861 
visitGetElementPtrInst(GetElementPtrInst & GEPI)862 bool ScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
863   std::optional<VectorSplit> VS = getVectorSplit(GEPI.getType());
864   if (!VS)
865     return false;
866 
867   IRBuilder<> Builder(&GEPI);
868   unsigned NumIndices = GEPI.getNumIndices();
869 
870   // The base pointer and indices might be scalar even if it's a vector GEP.
871   SmallVector<Value *, 8> ScalarOps{1 + NumIndices};
872   SmallVector<Scatterer, 8> ScatterOps{1 + NumIndices};
873 
874   for (unsigned I = 0; I < 1 + NumIndices; ++I) {
875     if (auto *VecTy =
876             dyn_cast<FixedVectorType>(GEPI.getOperand(I)->getType())) {
877       std::optional<VectorSplit> OpVS = getVectorSplit(VecTy);
878       if (!OpVS || OpVS->NumPacked != VS->NumPacked) {
879         // This can happen when ScalarizeMinBits is used.
880         return false;
881       }
882       ScatterOps[I] = scatter(&GEPI, GEPI.getOperand(I), *OpVS);
883     } else {
884       ScalarOps[I] = GEPI.getOperand(I);
885     }
886   }
887 
888   ValueVector Res;
889   Res.resize(VS->NumFragments);
890   for (unsigned I = 0; I < VS->NumFragments; ++I) {
891     SmallVector<Value *, 8> SplitOps;
892     SplitOps.resize(1 + NumIndices);
893     for (unsigned J = 0; J < 1 + NumIndices; ++J) {
894       if (ScalarOps[J])
895         SplitOps[J] = ScalarOps[J];
896       else
897         SplitOps[J] = ScatterOps[J][I];
898     }
899     Res[I] = Builder.CreateGEP(GEPI.getSourceElementType(), SplitOps[0],
900                                ArrayRef(SplitOps).drop_front(),
901                                GEPI.getName() + ".i" + Twine(I));
902     if (GEPI.isInBounds())
903       if (GetElementPtrInst *NewGEPI = dyn_cast<GetElementPtrInst>(Res[I]))
904         NewGEPI->setIsInBounds();
905   }
906   gather(&GEPI, Res, *VS);
907   return true;
908 }
909 
visitCastInst(CastInst & CI)910 bool ScalarizerVisitor::visitCastInst(CastInst &CI) {
911   std::optional<VectorSplit> DestVS = getVectorSplit(CI.getDestTy());
912   if (!DestVS)
913     return false;
914 
915   std::optional<VectorSplit> SrcVS = getVectorSplit(CI.getSrcTy());
916   if (!SrcVS || SrcVS->NumPacked != DestVS->NumPacked)
917     return false;
918 
919   IRBuilder<> Builder(&CI);
920   Scatterer Op0 = scatter(&CI, CI.getOperand(0), *SrcVS);
921   assert(Op0.size() == SrcVS->NumFragments && "Mismatched cast");
922   ValueVector Res;
923   Res.resize(DestVS->NumFragments);
924   for (unsigned I = 0; I < DestVS->NumFragments; ++I)
925     Res[I] =
926         Builder.CreateCast(CI.getOpcode(), Op0[I], DestVS->getFragmentType(I),
927                            CI.getName() + ".i" + Twine(I));
928   gather(&CI, Res, *DestVS);
929   return true;
930 }
931 
// Scalarize a bitcast between two splittable vector types. Depending on how
// the fragment sizes of source and destination relate, fragments are cast
// one-to-one, split into several smaller destination fragments, or
// concatenated into larger ones before casting.
bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {
  std::optional<VectorSplit> DstVS = getVectorSplit(BCI.getDestTy());
  std::optional<VectorSplit> SrcVS = getVectorSplit(BCI.getSrcTy());
  // Remainder fragments would break the uniform fragment-size arithmetic
  // below, so bail out if either side has one.
  if (!DstVS || !SrcVS || DstVS->RemainderTy || SrcVS->RemainderTy)
    return false;

  const bool isPointerTy = DstVS->VecTy->getElementType()->isPointerTy();

  // Vectors of pointers are always fully scalarized.
  assert(!isPointerTy || (DstVS->NumPacked == 1 && SrcVS->NumPacked == 1));

  IRBuilder<> Builder(&BCI);
  Scatterer Op0 = scatter(&BCI, BCI.getOperand(0), *SrcVS);
  ValueVector Res;
  Res.resize(DstVS->NumFragments);

  unsigned DstSplitBits = DstVS->SplitTy->getPrimitiveSizeInBits();
  unsigned SrcSplitBits = SrcVS->SplitTy->getPrimitiveSizeInBits();

  if (isPointerTy || DstSplitBits == SrcSplitBits) {
    // Equal fragment sizes: a simple one-to-one bitcast per fragment.
    assert(DstVS->NumFragments == SrcVS->NumFragments);
    for (unsigned I = 0; I < DstVS->NumFragments; ++I) {
      Res[I] = Builder.CreateBitCast(Op0[I], DstVS->getFragmentType(I),
                                     BCI.getName() + ".i" + Twine(I));
    }
  } else if (SrcSplitBits % DstSplitBits == 0) {
    // Convert each source fragment to the same-sized destination vector and
    // then scatter the result to the destination.
    VectorSplit MidVS;
    MidVS.NumPacked = DstVS->NumPacked;
    MidVS.NumFragments = SrcSplitBits / DstSplitBits;
    MidVS.VecTy = FixedVectorType::get(DstVS->VecTy->getElementType(),
                                       MidVS.NumPacked * MidVS.NumFragments);
    MidVS.SplitTy = DstVS->SplitTy;

    unsigned ResI = 0;
    for (unsigned I = 0; I < SrcVS->NumFragments; ++I) {
      Value *V = Op0[I];

      // Look through any existing bitcasts before converting to <N x t2>.
      // In the best case, the resulting conversion might be a no-op.
      Instruction *VI;
      while ((VI = dyn_cast<Instruction>(V)) &&
             VI->getOpcode() == Instruction::BitCast)
        V = VI->getOperand(0);

      V = Builder.CreateBitCast(V, MidVS.VecTy, V->getName() + ".cast");

      // Re-scatter the intermediate vector into destination-sized pieces.
      Scatterer Mid = scatter(&BCI, V, MidVS);
      for (unsigned J = 0; J < MidVS.NumFragments; ++J)
        Res[ResI++] = Mid[J];
    }
  } else if (DstSplitBits % SrcSplitBits == 0) {
    // Gather enough source fragments to make up a destination fragment and
    // then convert to the destination type.
    VectorSplit MidVS;
    MidVS.NumFragments = DstSplitBits / SrcSplitBits;
    MidVS.NumPacked = SrcVS->NumPacked;
    MidVS.VecTy = FixedVectorType::get(SrcVS->VecTy->getElementType(),
                                       MidVS.NumPacked * MidVS.NumFragments);
    MidVS.SplitTy = SrcVS->SplitTy;

    unsigned SrcI = 0;
    SmallVector<Value *, 8> ConcatOps;
    ConcatOps.resize(MidVS.NumFragments);
    for (unsigned I = 0; I < DstVS->NumFragments; ++I) {
      for (unsigned J = 0; J < MidVS.NumFragments; ++J)
        ConcatOps[J] = Op0[SrcI++];
      Value *V = concatenate(Builder, ConcatOps, MidVS,
                             BCI.getName() + ".i" + Twine(I));
      Res[I] = Builder.CreateBitCast(V, DstVS->getFragmentType(I),
                                     BCI.getName() + ".i" + Twine(I));
    }
  } else {
    // Fragment sizes are incommensurate; leave the bitcast alone.
    return false;
  }

  gather(&BCI, Res, *DstVS);
  return true;
}
1012 
// Scalarize an insertelement into a splittable vector. A constant index
// updates only the fragment containing that lane; a variable index is
// expanded into a per-fragment compare-and-select chain (only when the
// vector is fully scalarized and ScalarizeVariableInsertExtract allows it).
bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
  std::optional<VectorSplit> VS = getVectorSplit(IEI.getType());
  if (!VS)
    return false;

  IRBuilder<> Builder(&IEI);
  Scatterer Op0 = scatter(&IEI, IEI.getOperand(0), *VS);
  Value *NewElt = IEI.getOperand(1);
  Value *InsIdx = IEI.getOperand(2);

  ValueVector Res;
  Res.resize(VS->NumFragments);

  if (auto *CI = dyn_cast<ConstantInt>(InsIdx)) {
    unsigned Idx = CI->getZExtValue();
    // Only the fragment holding lane Idx changes; all others pass through.
    unsigned Fragment = Idx / VS->NumPacked;
    for (unsigned I = 0; I < VS->NumFragments; ++I) {
      if (I == Fragment) {
        bool IsPacked = VS->NumPacked > 1;
        // A scalar remainder fragment is replaced wholesale rather than via
        // an insertelement.
        if (Fragment == VS->NumFragments - 1 && VS->RemainderTy &&
            !VS->RemainderTy->isVectorTy())
          IsPacked = false;
        if (IsPacked) {
          Res[I] =
              Builder.CreateInsertElement(Op0[I], NewElt, Idx % VS->NumPacked);
        } else {
          Res[I] = NewElt;
        }
      } else {
        Res[I] = Op0[I];
      }
    }
  } else {
    // Never split a variable insertelement that isn't fully scalarized.
    if (!ScalarizeVariableInsertExtract || VS->NumPacked > 1)
      return false;

    // Select the new element into whichever lane the runtime index matches.
    for (unsigned I = 0; I < VS->NumFragments; ++I) {
      Value *ShouldReplace =
          Builder.CreateICmpEQ(InsIdx, ConstantInt::get(InsIdx->getType(), I),
                               InsIdx->getName() + ".is." + Twine(I));
      Value *OldElt = Op0[I];
      Res[I] = Builder.CreateSelect(ShouldReplace, NewElt, OldElt,
                                    IEI.getName() + ".i" + Twine(I));
    }
  }

  gather(&IEI, Res, *VS);
  return true;
}
1063 
// Scalarize an extractvalue whose operand is a struct-of-matching-vectors
// produced by a trivially scalarizable intrinsic call (e.g. frexp). Each
// scattered fragment of the call result is a struct; the requested field is
// extracted from every fragment.
bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) {
  Value *Op = EVI.getOperand(0);
  Type *OpTy = Op->getType();
  ValueVector Res;
  if (!isStructOfMatchingFixedVectors(OpTy))
    return false;
  if (CallInst *CI = dyn_cast<CallInst>(Op)) {
    Function *F = CI->getCalledFunction();
    if (!F)
      return false;
    Intrinsic::ID ID = F->getIntrinsicID();
    if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID, TTI))
      return false;
    // Note: Fall through means the operand is a `CallInst` to an intrinsic
    // accepted by `isTriviallyScalarizable`.
  } else
    return false;
  Type *VecType = cast<FixedVectorType>(OpTy->getContainedType(0));
  std::optional<VectorSplit> VS = getVectorSplit(VecType);
  if (!VS)
    return false;
  // All struct fields must split identically to the first one.
  for (unsigned I = 1; I < OpTy->getNumContainedTypes(); I++) {
    std::optional<VectorSplit> CurrVS =
        getVectorSplit(cast<FixedVectorType>(OpTy->getContainedType(I)));
    // It is possible for VectorSplit.NumPacked >= NumElems. If that happens a
    // VectorSplit is not returned and we will bailout of handling this call.
    // The secondary bailout case is if NumPacked does not match. This can
    // happen if ScalarizeMinBits is not set to the default. This means with
    // certain ScalarizeMinBits intrinsics like frexp will only scalarize when
    // the struct elements have the same bitness.
    if (!CurrVS || CurrVS->NumPacked != VS->NumPacked)
      return false;
  }
  IRBuilder<> Builder(&EVI);
  Scatterer Op0 = scatter(&EVI, Op, *VS);
  assert(!EVI.getIndices().empty() && "Make sure an index exists");
  // Note for our use case we only care about the top level index.
  unsigned Index = EVI.getIndices()[0];
  for (unsigned OpIdx = 0; OpIdx < Op0.size(); ++OpIdx) {
    Value *ResElem = Builder.CreateExtractValue(
        Op0[OpIdx], Index, EVI.getName() + ".elem" + Twine(Index));
    Res.push_back(ResElem);
  }

  gather(&EVI, Res, *VS);
  return true;
}
1111 
visitExtractElementInst(ExtractElementInst & EEI)1112 bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
1113   std::optional<VectorSplit> VS = getVectorSplit(EEI.getOperand(0)->getType());
1114   if (!VS)
1115     return false;
1116 
1117   IRBuilder<> Builder(&EEI);
1118   Scatterer Op0 = scatter(&EEI, EEI.getOperand(0), *VS);
1119   Value *ExtIdx = EEI.getOperand(1);
1120 
1121   if (auto *CI = dyn_cast<ConstantInt>(ExtIdx)) {
1122     unsigned Idx = CI->getZExtValue();
1123     unsigned Fragment = Idx / VS->NumPacked;
1124     Value *Res = Op0[Fragment];
1125     bool IsPacked = VS->NumPacked > 1;
1126     if (Fragment == VS->NumFragments - 1 && VS->RemainderTy &&
1127         !VS->RemainderTy->isVectorTy())
1128       IsPacked = false;
1129     if (IsPacked)
1130       Res = Builder.CreateExtractElement(Res, Idx % VS->NumPacked);
1131     replaceUses(&EEI, Res);
1132     return true;
1133   }
1134 
1135   // Never split a variable extractelement that isn't fully scalarized.
1136   if (!ScalarizeVariableInsertExtract || VS->NumPacked > 1)
1137     return false;
1138 
1139   Value *Res = PoisonValue::get(VS->VecTy->getElementType());
1140   for (unsigned I = 0; I < VS->NumFragments; ++I) {
1141     Value *ShouldExtract =
1142         Builder.CreateICmpEQ(ExtIdx, ConstantInt::get(ExtIdx->getType(), I),
1143                              ExtIdx->getName() + ".is." + Twine(I));
1144     Value *Elt = Op0[I];
1145     Res = Builder.CreateSelect(ShouldExtract, Elt, Res,
1146                                EEI.getName() + ".upto" + Twine(I));
1147   }
1148   replaceUses(&EEI, Res);
1149   return true;
1150 }
1151 
visitShuffleVectorInst(ShuffleVectorInst & SVI)1152 bool ScalarizerVisitor::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
1153   std::optional<VectorSplit> VS = getVectorSplit(SVI.getType());
1154   std::optional<VectorSplit> VSOp =
1155       getVectorSplit(SVI.getOperand(0)->getType());
1156   if (!VS || !VSOp || VS->NumPacked > 1 || VSOp->NumPacked > 1)
1157     return false;
1158 
1159   Scatterer Op0 = scatter(&SVI, SVI.getOperand(0), *VSOp);
1160   Scatterer Op1 = scatter(&SVI, SVI.getOperand(1), *VSOp);
1161   ValueVector Res;
1162   Res.resize(VS->NumFragments);
1163 
1164   for (unsigned I = 0; I < VS->NumFragments; ++I) {
1165     int Selector = SVI.getMaskValue(I);
1166     if (Selector < 0)
1167       Res[I] = PoisonValue::get(VS->VecTy->getElementType());
1168     else if (unsigned(Selector) < Op0.size())
1169       Res[I] = Op0[Selector];
1170     else
1171       Res[I] = Op1[Selector - Op0.size()];
1172   }
1173   gather(&SVI, Res, *VS);
1174   return true;
1175 }
1176 
visitPHINode(PHINode & PHI)1177 bool ScalarizerVisitor::visitPHINode(PHINode &PHI) {
1178   std::optional<VectorSplit> VS = getVectorSplit(PHI.getType());
1179   if (!VS)
1180     return false;
1181 
1182   IRBuilder<> Builder(&PHI);
1183   ValueVector Res;
1184   Res.resize(VS->NumFragments);
1185 
1186   unsigned NumOps = PHI.getNumOperands();
1187   for (unsigned I = 0; I < VS->NumFragments; ++I) {
1188     Res[I] = Builder.CreatePHI(VS->getFragmentType(I), NumOps,
1189                                PHI.getName() + ".i" + Twine(I));
1190   }
1191 
1192   for (unsigned I = 0; I < NumOps; ++I) {
1193     Scatterer Op = scatter(&PHI, PHI.getIncomingValue(I), *VS);
1194     BasicBlock *IncomingBlock = PHI.getIncomingBlock(I);
1195     for (unsigned J = 0; J < VS->NumFragments; ++J)
1196       cast<PHINode>(Res[J])->addIncoming(Op[J], IncomingBlock);
1197   }
1198   gather(&PHI, Res, *VS);
1199   return true;
1200 }
1201 
visitLoadInst(LoadInst & LI)1202 bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) {
1203   if (!ScalarizeLoadStore)
1204     return false;
1205   if (!LI.isSimple())
1206     return false;
1207 
1208   std::optional<VectorLayout> Layout = getVectorLayout(
1209       LI.getType(), LI.getAlign(), LI.getDataLayout());
1210   if (!Layout)
1211     return false;
1212 
1213   IRBuilder<> Builder(&LI);
1214   Scatterer Ptr = scatter(&LI, LI.getPointerOperand(), Layout->VS);
1215   ValueVector Res;
1216   Res.resize(Layout->VS.NumFragments);
1217 
1218   for (unsigned I = 0; I < Layout->VS.NumFragments; ++I) {
1219     Res[I] = Builder.CreateAlignedLoad(Layout->VS.getFragmentType(I), Ptr[I],
1220                                        Align(Layout->getFragmentAlign(I)),
1221                                        LI.getName() + ".i" + Twine(I));
1222   }
1223   gather(&LI, Res, Layout->VS);
1224   return true;
1225 }
1226 
visitStoreInst(StoreInst & SI)1227 bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) {
1228   if (!ScalarizeLoadStore)
1229     return false;
1230   if (!SI.isSimple())
1231     return false;
1232 
1233   Value *FullValue = SI.getValueOperand();
1234   std::optional<VectorLayout> Layout = getVectorLayout(
1235       FullValue->getType(), SI.getAlign(), SI.getDataLayout());
1236   if (!Layout)
1237     return false;
1238 
1239   IRBuilder<> Builder(&SI);
1240   Scatterer VPtr = scatter(&SI, SI.getPointerOperand(), Layout->VS);
1241   Scatterer VVal = scatter(&SI, FullValue, Layout->VS);
1242 
1243   ValueVector Stores;
1244   Stores.resize(Layout->VS.NumFragments);
1245   for (unsigned I = 0; I < Layout->VS.NumFragments; ++I) {
1246     Value *Val = VVal[I];
1247     Value *Ptr = VPtr[I];
1248     Stores[I] =
1249         Builder.CreateAlignedStore(Val, Ptr, Layout->getFragmentAlign(I));
1250   }
1251   transferMetadataAndIRFlags(&SI, Stores);
1252   return true;
1253 }
1254 
// Delegate to splitCall, which scalarizes calls to supported intrinsics.
bool ScalarizerVisitor::visitCallInst(CallInst &CI) {
  return splitCall(CI);
}
1258 
// Scalarize a freeze by freezing each fragment individually.
bool ScalarizerVisitor::visitFreezeInst(FreezeInst &FI) {
  return splitUnary(FI, [](IRBuilder<> &Builder, Value *Op, const Twine &Name) {
    return Builder.CreateFreeze(Op, Name);
  });
}
1264 
// Delete the instructions that we scalarized.  If a full vector result
// is still needed, recreate it using InsertElements.
bool ScalarizerVisitor::finish() {
  // The presence of data in Gathered or Scattered indicates changes
  // made to the Function.
  if (Gathered.empty() && Scattered.empty() && !Scalarized)
    return false;
  for (const auto &GMI : Gathered) {
    Instruction *Op = GMI.first;
    ValueVector &CV = *GMI.second;
    if (!Op->use_empty()) {
      // The value is still needed, so recreate it using a series of
      // insertelements and/or shufflevectors.
      Value *Res;
      if (auto *Ty = dyn_cast<FixedVectorType>(Op->getType())) {
        BasicBlock *BB = Op->getParent();
        IRBuilder<> Builder(Op);
        // PHIs must stay grouped at the top of their block, so insert the
        // reconstruction code after them instead of before the PHI itself.
        if (isa<PHINode>(Op))
          Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());

        VectorSplit VS = *getVectorSplit(Ty);
        assert(VS.NumFragments == CV.size());

        Res = concatenate(Builder, CV, VS, Op->getName());

        Res->takeName(Op);
      } else if (auto *Ty = dyn_cast<StructType>(Op->getType())) {
        BasicBlock *BB = Op->getParent();
        IRBuilder<> Builder(Op);
        if (isa<PHINode>(Op))
          Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());

        // Iterate over each element in the struct: pull field I out of every
        // scattered fragment to form that field's fragment list.
        unsigned NumOfStructElements = Ty->getNumElements();
        SmallVector<ValueVector, 4> ElemCV(NumOfStructElements);
        for (unsigned I = 0; I < NumOfStructElements; ++I) {
          for (auto *CVelem : CV) {
            Value *Elem = Builder.CreateExtractValue(
                CVelem, I, Op->getName() + ".elem" + Twine(I));
            ElemCV[I].push_back(Elem);
          }
        }
        // Rebuild the struct field-by-field from the concatenated vectors.
        Res = PoisonValue::get(Ty);
        for (unsigned I = 0; I < NumOfStructElements; ++I) {
          Type *ElemTy = Ty->getElementType(I);
          assert(isa<FixedVectorType>(ElemTy) &&
                 "Only Structs of all FixedVectorType supported");
          VectorSplit VS = *getVectorSplit(ElemTy);
          assert(VS.NumFragments == CV.size());

          Value *ConcatenatedVector =
              concatenate(Builder, ElemCV[I], VS, Op->getName());
          Res = Builder.CreateInsertValue(Res, ConcatenatedVector, I,
                                          Op->getName() + ".insert");
        }
      } else {
        // Fully-scalar result: the single fragment is the replacement value.
        assert(CV.size() == 1 && Op->getType() == CV[0]->getType());
        Res = CV[0];
        if (Op == Res)
          continue;
      }
      Op->replaceAllUsesWith(Res);
    }
    PotentiallyDeadInstrs.emplace_back(Op);
  }
  Gathered.clear();
  Scattered.clear();
  Scalarized = false;

  // Remove the now-unused original vector instructions (and anything that
  // becomes trivially dead as a result).
  RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);

  return true;
}
1338 
run(Function & F,FunctionAnalysisManager & AM)1339 PreservedAnalyses ScalarizerPass::run(Function &F, FunctionAnalysisManager &AM) {
1340   DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
1341   const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F);
1342   ScalarizerVisitor Impl(DT, TTI, Options);
1343   bool Changed = Impl.visit(F);
1344   PreservedAnalyses PA;
1345   PA.preserve<DominatorTreeAnalysis>();
1346   return Changed ? PA : PreservedAnalyses::all();
1347 }
1348