xref: /freebsd/contrib/llvm-project/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp (revision 13ec1e3155c7e9bf037b12af186351b7fa9b9450)
1 //===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// This pass custom lowers llvm.gather and llvm.scatter instructions to
10 /// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to
11 /// produce a better final result as we go.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "ARM.h"
16 #include "ARMBaseInstrInfo.h"
17 #include "ARMSubtarget.h"
18 #include "llvm/Analysis/LoopInfo.h"
19 #include "llvm/Analysis/TargetTransformInfo.h"
20 #include "llvm/Analysis/ValueTracking.h"
21 #include "llvm/CodeGen/TargetLowering.h"
22 #include "llvm/CodeGen/TargetPassConfig.h"
23 #include "llvm/CodeGen/TargetSubtargetInfo.h"
24 #include "llvm/InitializePasses.h"
25 #include "llvm/IR/BasicBlock.h"
26 #include "llvm/IR/Constant.h"
27 #include "llvm/IR/Constants.h"
28 #include "llvm/IR/DerivedTypes.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/InstrTypes.h"
31 #include "llvm/IR/Instruction.h"
32 #include "llvm/IR/Instructions.h"
33 #include "llvm/IR/IntrinsicInst.h"
34 #include "llvm/IR/Intrinsics.h"
35 #include "llvm/IR/IntrinsicsARM.h"
36 #include "llvm/IR/IRBuilder.h"
37 #include "llvm/IR/PatternMatch.h"
38 #include "llvm/IR/Type.h"
39 #include "llvm/IR/Value.h"
40 #include "llvm/Pass.h"
41 #include "llvm/Support/Casting.h"
42 #include "llvm/Transforms/Utils/Local.h"
43 #include <algorithm>
44 #include <cassert>
45 
46 using namespace llvm;
47 
48 #define DEBUG_TYPE "arm-mve-gather-scatter-lowering"
49 
50 cl::opt<bool> EnableMaskedGatherScatters(
51     "enable-arm-maskedgatscat", cl::Hidden, cl::init(true),
52     cl::desc("Enable the generation of masked gathers and scatters"));
53 
54 namespace {
55 
56 class MVEGatherScatterLowering : public FunctionPass {
57 public:
58   static char ID; // Pass identification, replacement for typeid
59 
60   explicit MVEGatherScatterLowering() : FunctionPass(ID) {
61     initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
62   }
63 
64   bool runOnFunction(Function &F) override;
65 
66   StringRef getPassName() const override {
67     return "MVE gather/scatter lowering";
68   }
69 
70   void getAnalysisUsage(AnalysisUsage &AU) const override {
71     AU.setPreservesCFG();
72     AU.addRequired<TargetPassConfig>();
73     AU.addRequired<LoopInfoWrapperPass>();
74     FunctionPass::getAnalysisUsage(AU);
75   }
76 
77 private:
78   LoopInfo *LI = nullptr;
79 
80   // Check this is a valid gather with correct alignment
81   bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
82                                Align Alignment);
83   // Check whether Ptr is hidden behind a bitcast and look through it
84   void lookThroughBitcast(Value *&Ptr);
85   // Decompose a ptr into Base and Offsets, potentially using a GEP to return a
86   // scalar base and vector offsets, or else fallback to using a base of 0 and
87   // offset of Ptr where possible.
88   Value *decomposePtr(Value *Ptr, Value *&Offsets, int &Scale,
89                       FixedVectorType *Ty, Type *MemoryTy,
90                       IRBuilder<> &Builder);
91   // Check for a getelementptr and deduce base and offsets from it, on success
92   // returning the base directly and the offsets indirectly using the Offsets
93   // argument
94   Value *decomposeGEP(Value *&Offsets, FixedVectorType *Ty,
95                       GetElementPtrInst *GEP, IRBuilder<> &Builder);
96   // Compute the scale of this gather/scatter instruction
97   int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
98   // If the value is a constant, or derived from constants via additions
99   // and multilications, return its numeric value
100   Optional<int64_t> getIfConst(const Value *V);
101   // If Inst is an add instruction, check whether one summand is a
102   // constant. If so, scale this constant and return it together with
103   // the other summand.
104   std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);
105 
106   Instruction *lowerGather(IntrinsicInst *I);
107   // Create a gather from a base + vector of offsets
108   Instruction *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
109                                            Instruction *&Root,
110                                            IRBuilder<> &Builder);
111   // Create a gather from a vector of pointers
112   Instruction *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
113                                          IRBuilder<> &Builder,
114                                          int64_t Increment = 0);
115   // Create an incrementing gather from a vector of pointers
116   Instruction *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
117                                            IRBuilder<> &Builder,
118                                            int64_t Increment = 0);
119 
120   Instruction *lowerScatter(IntrinsicInst *I);
121   // Create a scatter to a base + vector of offsets
122   Instruction *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
123                                             IRBuilder<> &Builder);
124   // Create a scatter to a vector of pointers
125   Instruction *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
126                                           IRBuilder<> &Builder,
127                                           int64_t Increment = 0);
128   // Create an incrementing scatter from a vector of pointers
129   Instruction *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
130                                             IRBuilder<> &Builder,
131                                             int64_t Increment = 0);
132 
133   // QI gathers and scatters can increment their offsets on their own if
134   // the increment is a constant value (digit)
135   Instruction *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *Ptr,
136                                             IRBuilder<> &Builder);
137   // QI gathers/scatters can increment their offsets on their own if the
138   // increment is a constant value (digit) - this creates a writeback QI
139   // gather/scatter
140   Instruction *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
141                                               Value *Ptr, unsigned TypeScale,
142                                               IRBuilder<> &Builder);
143 
144   // Optimise the base and offsets of the given address
145   bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI);
146   // Try to fold consecutive geps together into one
147   Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, IRBuilder<> &Builder);
148   // Check whether these offsets could be moved out of the loop they're in
149   bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
150   // Pushes the given add out of the loop
151   void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
152   // Pushes the given mul out of the loop
153   void pushOutMul(PHINode *&Phi, Value *IncrementPerRound,
154                   Value *OffsSecondOperand, unsigned LoopIncrement,
155                   IRBuilder<> &Builder);
156 };
157 
158 } // end anonymous namespace
159 
160 char MVEGatherScatterLowering::ID = 0;
161 
162 INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
163                 "MVE gather/scattering lowering pass", false, false)
164 
165 Pass *llvm::createMVEGatherScatterLoweringPass() {
166   return new MVEGatherScatterLowering();
167 }
168 
169 bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
170                                                        unsigned ElemSize,
171                                                        Align Alignment) {
172   if (((NumElements == 4 &&
173         (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
174        (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
175        (NumElements == 16 && ElemSize == 8)) &&
176       Alignment >= ElemSize / 8)
177     return true;
178   LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
179                     << "valid alignment or vector type \n");
180   return false;
181 }
182 
183 static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) {
184   // Offsets that are not of type <N x i32> are sign extended by the
185   // getelementptr instruction, and MVE gathers/scatters treat the offset as
186   // unsigned. Thus, if the element size is smaller than 32, we can only allow
187   // positive offsets - i.e., the offsets are not allowed to be variables we
188   // can't look into.
189   // Additionally, <N x i32> offsets have to either originate from a zext of a
190   // vector with element types smaller or equal the type of the gather we're
191   // looking at, or consist of constants that we can check are small enough
192   // to fit into the gather type.
193   // Thus we check that 0 < value < 2^TargetElemSize.
194   unsigned TargetElemSize = 128 / TargetElemCount;
195   unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType())
196                                 ->getElementType()
197                                 ->getScalarSizeInBits();
198   if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
199     Constant *ConstOff = dyn_cast<Constant>(Offsets);
200     if (!ConstOff)
201       return false;
202     int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
203     auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) {
204       ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem);
205       if (!OConst)
206         return false;
207       int SExtValue = OConst->getSExtValue();
208       if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
209         return false;
210       return true;
211     };
212     if (isa<FixedVectorType>(ConstOff->getType())) {
213       for (unsigned i = 0; i < TargetElemCount; i++) {
214         if (!CheckValueSize(ConstOff->getAggregateElement(i)))
215           return false;
216       }
217     } else {
218       if (!CheckValueSize(ConstOff))
219         return false;
220     }
221   }
222   return true;
223 }
224 
225 Value *MVEGatherScatterLowering::decomposePtr(Value *Ptr, Value *&Offsets,
226                                               int &Scale, FixedVectorType *Ty,
227                                               Type *MemoryTy,
228                                               IRBuilder<> &Builder) {
229   if (auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
230     if (Value *V = decomposeGEP(Offsets, Ty, GEP, Builder)) {
231       Scale =
232           computeScale(GEP->getSourceElementType()->getPrimitiveSizeInBits(),
233                        MemoryTy->getScalarSizeInBits());
234       return Scale == -1 ? nullptr : V;
235     }
236   }
237 
238   // If we couldn't use the GEP (or it doesn't exist), attempt to use a
239   // BasePtr of 0 with Ptr as the Offsets, so long as there are only 4
240   // elements.
241   FixedVectorType *PtrTy = cast<FixedVectorType>(Ptr->getType());
242   if (PtrTy->getNumElements() != 4 || MemoryTy->getScalarSizeInBits() == 32)
243     return nullptr;
244   Value *Zero = ConstantInt::get(Builder.getInt32Ty(), 0);
245   Value *BasePtr = Builder.CreateIntToPtr(Zero, Builder.getInt8PtrTy());
246   Offsets = Builder.CreatePtrToInt(
247       Ptr, FixedVectorType::get(Builder.getInt32Ty(), 4));
248   Scale = 0;
249   return BasePtr;
250 }
251 
252 Value *MVEGatherScatterLowering::decomposeGEP(Value *&Offsets,
253                                               FixedVectorType *Ty,
254                                               GetElementPtrInst *GEP,
255                                               IRBuilder<> &Builder) {
256   if (!GEP) {
257     LLVM_DEBUG(dbgs() << "masked gathers/scatters: no getelementpointer "
258                       << "found\n");
259     return nullptr;
260   }
261   LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
262                     << " Looking at intrinsic for base + vector of offsets\n");
263   Value *GEPPtr = GEP->getPointerOperand();
264   Offsets = GEP->getOperand(1);
265   if (GEPPtr->getType()->isVectorTy() ||
266       !isa<FixedVectorType>(Offsets->getType()))
267     return nullptr;
268 
269   if (GEP->getNumOperands() != 2) {
270     LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
271                       << " operands. Expanding.\n");
272     return nullptr;
273   }
274   Offsets = GEP->getOperand(1);
275   unsigned OffsetsElemCount =
276       cast<FixedVectorType>(Offsets->getType())->getNumElements();
277   // Paranoid check whether the number of parallel lanes is the same
278   assert(Ty->getNumElements() == OffsetsElemCount);
279 
280   ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets);
281   if (ZextOffs)
282     Offsets = ZextOffs->getOperand(0);
283   FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType());
284 
285   // If the offsets are already being zext-ed to <N x i32>, that relieves us of
286   // having to make sure that they won't overflow.
287   if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy())
288                            ->getElementType()
289                            ->getScalarSizeInBits() != 32)
290     if (!checkOffsetSize(Offsets, OffsetsElemCount))
291       return nullptr;
292 
293   // The offset sizes have been checked; if any truncating or zext-ing is
294   // required to fix them, do that now
295   if (Ty != Offsets->getType()) {
296     if ((Ty->getElementType()->getScalarSizeInBits() <
297          OffsetType->getElementType()->getScalarSizeInBits())) {
298       Offsets = Builder.CreateTrunc(Offsets, Ty);
299     } else {
300       Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty));
301     }
302   }
303   // If none of the checks failed, return the gep's base pointer
304   LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
305   return GEPPtr;
306 }
307 
308 void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
309   // Look through bitcast instruction if #elements is the same
310   if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
311     auto *BCTy = cast<FixedVectorType>(BitCast->getType());
312     auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
313     if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
314       LLVM_DEBUG(dbgs() << "masked gathers/scatters: looking through "
315                         << "bitcast\n");
316       Ptr = BitCast->getOperand(0);
317     }
318   }
319 }
320 
321 int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
322                                            unsigned MemoryElemSize) {
323   // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by 2,
324   // or a 8bit, 16bit or 32bit load/store scaled by 1
325   if (GEPElemSize == 32 && MemoryElemSize == 32)
326     return 2;
327   else if (GEPElemSize == 16 && MemoryElemSize == 16)
328     return 1;
329   else if (GEPElemSize == 8)
330     return 0;
331   LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
332                     << "create intrinsic\n");
333   return -1;
334 }
335 
336 Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
337   const Constant *C = dyn_cast<Constant>(V);
338   if (C != nullptr)
339     return Optional<int64_t>{C->getUniqueInteger().getSExtValue()};
340   if (!isa<Instruction>(V))
341     return Optional<int64_t>{};
342 
343   const Instruction *I = cast<Instruction>(V);
344   if (I->getOpcode() == Instruction::Add ||
345               I->getOpcode() == Instruction::Mul) {
346     Optional<int64_t> Op0 = getIfConst(I->getOperand(0));
347     Optional<int64_t> Op1 = getIfConst(I->getOperand(1));
348     if (!Op0 || !Op1)
349       return Optional<int64_t>{};
350     if (I->getOpcode() == Instruction::Add)
351       return Optional<int64_t>{Op0.getValue() + Op1.getValue()};
352     if (I->getOpcode() == Instruction::Mul)
353       return Optional<int64_t>{Op0.getValue() * Op1.getValue()};
354   }
355   return Optional<int64_t>{};
356 }
357 
358 std::pair<Value *, int64_t>
359 MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
360   std::pair<Value *, int64_t> ReturnFalse =
361       std::pair<Value *, int64_t>(nullptr, 0);
362   // At this point, the instruction we're looking at must be an add or we
363   // bail out
364   Instruction *Add = dyn_cast<Instruction>(Inst);
365   if (Add == nullptr || Add->getOpcode() != Instruction::Add)
366     return ReturnFalse;
367 
368   Value *Summand;
369   Optional<int64_t> Const;
370   // Find out which operand the value that is increased is
371   if ((Const = getIfConst(Add->getOperand(0))))
372     Summand = Add->getOperand(1);
373   else if ((Const = getIfConst(Add->getOperand(1))))
374     Summand = Add->getOperand(0);
375   else
376     return ReturnFalse;
377 
378   // Check that the constant is small enough for an incrementing gather
379   int64_t Immediate = Const.getValue() << TypeScale;
380   if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
381     return ReturnFalse;
382 
383   return std::pair<Value *, int64_t>(Summand, Immediate);
384 }
385 
386 Instruction *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
387   using namespace PatternMatch;
388   LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"
389                     << *I << "\n");
390 
391   // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
392   // Attempt to turn the masked gather in I into a MVE intrinsic
393   // Potentially optimising the addressing modes as we do so.
394   auto *Ty = cast<FixedVectorType>(I->getType());
395   Value *Ptr = I->getArgOperand(0);
396   Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
397   Value *Mask = I->getArgOperand(2);
398   Value *PassThru = I->getArgOperand(3);
399 
400   if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
401                                Alignment))
402     return nullptr;
403   lookThroughBitcast(Ptr);
404   assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
405 
406   IRBuilder<> Builder(I->getContext());
407   Builder.SetInsertPoint(I);
408   Builder.SetCurrentDebugLocation(I->getDebugLoc());
409 
410   Instruction *Root = I;
411 
412   Instruction *Load = tryCreateIncrementingGatScat(I, Ptr, Builder);
413   if (!Load)
414     Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
415   if (!Load)
416     Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
417   if (!Load)
418     return nullptr;
419 
420   if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
421     LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
422                       << "creating select\n");
423     Load = SelectInst::Create(Mask, Load, PassThru);
424     Builder.Insert(Load);
425   }
426 
427   Root->replaceAllUsesWith(Load);
428   Root->eraseFromParent();
429   if (Root != I)
430     // If this was an extending gather, we need to get rid of the sext/zext
431     // sext/zext as well as of the gather itself
432     I->eraseFromParent();
433 
434   LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"
435                     << *Load << "\n");
436   return Load;
437 }
438 
439 Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
440     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
441   using namespace PatternMatch;
442   auto *Ty = cast<FixedVectorType>(I->getType());
443   LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
444   if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
445     // Can't build an intrinsic for this
446     return nullptr;
447   Value *Mask = I->getArgOperand(2);
448   if (match(Mask, m_One()))
449     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
450                                    {Ty, Ptr->getType()},
451                                    {Ptr, Builder.getInt32(Increment)});
452   else
453     return Builder.CreateIntrinsic(
454         Intrinsic::arm_mve_vldr_gather_base_predicated,
455         {Ty, Ptr->getType(), Mask->getType()},
456         {Ptr, Builder.getInt32(Increment), Mask});
457 }
458 
459 Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
460     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
461   using namespace PatternMatch;
462   auto *Ty = cast<FixedVectorType>(I->getType());
463   LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers with "
464                     << "writeback\n");
465   if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
466     // Can't build an intrinsic for this
467     return nullptr;
468   Value *Mask = I->getArgOperand(2);
469   if (match(Mask, m_One()))
470     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
471                                    {Ty, Ptr->getType()},
472                                    {Ptr, Builder.getInt32(Increment)});
473   else
474     return Builder.CreateIntrinsic(
475         Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
476         {Ty, Ptr->getType(), Mask->getType()},
477         {Ptr, Builder.getInt32(Increment), Mask});
478 }
479 
480 Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
481     IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
482   using namespace PatternMatch;
483 
484   Type *MemoryTy = I->getType();
485   Type *ResultTy = MemoryTy;
486 
487   unsigned Unsigned = 1;
488   // The size of the gather was already checked in isLegalTypeAndAlignment;
489   // if it was not a full vector width an appropriate extend should follow.
490   auto *Extend = Root;
491   bool TruncResult = false;
492   if (MemoryTy->getPrimitiveSizeInBits() < 128) {
493     if (I->hasOneUse()) {
494       // If the gather has a single extend of the correct type, use an extending
495       // gather and replace the ext. In which case the correct root to replace
496       // is not the CallInst itself, but the instruction which extends it.
497       Instruction* User = cast<Instruction>(*I->users().begin());
498       if (isa<SExtInst>(User) &&
499           User->getType()->getPrimitiveSizeInBits() == 128) {
500         LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
501                           << *User << "\n");
502         Extend = User;
503         ResultTy = User->getType();
504         Unsigned = 0;
505       } else if (isa<ZExtInst>(User) &&
506                  User->getType()->getPrimitiveSizeInBits() == 128) {
507         LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
508                           << *ResultTy << "\n");
509         Extend = User;
510         ResultTy = User->getType();
511       }
512     }
513 
514     // If an extend hasn't been found and the type is an integer, create an
515     // extending gather and truncate back to the original type.
516     if (ResultTy->getPrimitiveSizeInBits() < 128 &&
517         ResultTy->isIntOrIntVectorTy()) {
518       ResultTy = ResultTy->getWithNewBitWidth(
519           128 / cast<FixedVectorType>(ResultTy)->getNumElements());
520       TruncResult = true;
521       LLVM_DEBUG(dbgs() << "masked gathers: Small input type, truncing to: "
522                         << *ResultTy << "\n");
523     }
524 
525     // The final size of the gather must be a full vector width
526     if (ResultTy->getPrimitiveSizeInBits() != 128) {
527       LLVM_DEBUG(dbgs() << "masked gathers: Extend needed but not provided "
528                            "from the correct type. Expanding\n");
529       return nullptr;
530     }
531   }
532 
533   Value *Offsets;
534   int Scale;
535   Value *BasePtr = decomposePtr(
536       Ptr, Offsets, Scale, cast<FixedVectorType>(ResultTy), MemoryTy, Builder);
537   if (!BasePtr)
538     return nullptr;
539 
540   Root = Extend;
541   Value *Mask = I->getArgOperand(2);
542   Instruction *Load = nullptr;
543   if (!match(Mask, m_One()))
544     Load = Builder.CreateIntrinsic(
545         Intrinsic::arm_mve_vldr_gather_offset_predicated,
546         {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
547         {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
548          Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
549   else
550     Load = Builder.CreateIntrinsic(
551         Intrinsic::arm_mve_vldr_gather_offset,
552         {ResultTy, BasePtr->getType(), Offsets->getType()},
553         {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
554          Builder.getInt32(Scale), Builder.getInt32(Unsigned)});
555 
556   if (TruncResult) {
557     Load = TruncInst::Create(Instruction::Trunc, Load, MemoryTy);
558     Builder.Insert(Load);
559   }
560   return Load;
561 }
562 
563 Instruction *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
564   using namespace PatternMatch;
565   LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"
566                     << *I << "\n");
567 
568   // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
569   // Attempt to turn the masked scatter in I into a MVE intrinsic
570   // Potentially optimising the addressing modes as we do so.
571   Value *Input = I->getArgOperand(0);
572   Value *Ptr = I->getArgOperand(1);
573   Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
574   auto *Ty = cast<FixedVectorType>(Input->getType());
575 
576   if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
577                                Alignment))
578     return nullptr;
579 
580   lookThroughBitcast(Ptr);
581   assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
582 
583   IRBuilder<> Builder(I->getContext());
584   Builder.SetInsertPoint(I);
585   Builder.SetCurrentDebugLocation(I->getDebugLoc());
586 
587   Instruction *Store = tryCreateIncrementingGatScat(I, Ptr, Builder);
588   if (!Store)
589     Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
590   if (!Store)
591     Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
592   if (!Store)
593     return nullptr;
594 
595   LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n"
596                     << *Store << "\n");
597   I->eraseFromParent();
598   return Store;
599 }
600 
601 Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
602     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
603   using namespace PatternMatch;
604   Value *Input = I->getArgOperand(0);
605   auto *Ty = cast<FixedVectorType>(Input->getType());
606   // Only QR variants allow truncating
607   if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
608     // Can't build an intrinsic for this
609     return nullptr;
610   }
611   Value *Mask = I->getArgOperand(3);
612   //  int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
613   LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
614   if (match(Mask, m_One()))
615     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
616                                    {Ptr->getType(), Input->getType()},
617                                    {Ptr, Builder.getInt32(Increment), Input});
618   else
619     return Builder.CreateIntrinsic(
620         Intrinsic::arm_mve_vstr_scatter_base_predicated,
621         {Ptr->getType(), Input->getType(), Mask->getType()},
622         {Ptr, Builder.getInt32(Increment), Input, Mask});
623 }
624 
625 Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
626     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
627   using namespace PatternMatch;
628   Value *Input = I->getArgOperand(0);
629   auto *Ty = cast<FixedVectorType>(Input->getType());
630   LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers "
631                     << "with writeback\n");
632   if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
633     // Can't build an intrinsic for this
634     return nullptr;
635   Value *Mask = I->getArgOperand(3);
636   if (match(Mask, m_One()))
637     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
638                                    {Ptr->getType(), Input->getType()},
639                                    {Ptr, Builder.getInt32(Increment), Input});
640   else
641     return Builder.CreateIntrinsic(
642         Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
643         {Ptr->getType(), Input->getType(), Mask->getType()},
644         {Ptr, Builder.getInt32(Increment), Input, Mask});
645 }
646 
647 Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
648     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
649   using namespace PatternMatch;
650   Value *Input = I->getArgOperand(0);
651   Value *Mask = I->getArgOperand(3);
652   Type *InputTy = Input->getType();
653   Type *MemoryTy = InputTy;
654 
655   LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
656                     << " to base + vector of offsets\n");
657   // If the input has been truncated, try to integrate that trunc into the
658   // scatter instruction (we don't care about alignment here)
659   if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
660     Value *PreTrunc = Trunc->getOperand(0);
661     Type *PreTruncTy = PreTrunc->getType();
662     if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
663       Input = PreTrunc;
664       InputTy = PreTruncTy;
665     }
666   }
667   bool ExtendInput = false;
668   if (InputTy->getPrimitiveSizeInBits() < 128 &&
669       InputTy->isIntOrIntVectorTy()) {
670     // If we can't find a trunc to incorporate into the instruction, create an
671     // implicit one with a zext, so that we can still create a scatter. We know
672     // that the input type is 4x/8x/16x and of type i8/i16/i32, so any type
673     // smaller than 128 bits will divide evenly into a 128bit vector.
674     InputTy = InputTy->getWithNewBitWidth(
675         128 / cast<FixedVectorType>(InputTy)->getNumElements());
676     ExtendInput = true;
677     LLVM_DEBUG(dbgs() << "masked scatters: Small input type, will extend:\n"
678                       << *Input << "\n");
679   }
680   if (InputTy->getPrimitiveSizeInBits() != 128) {
681     LLVM_DEBUG(dbgs() << "masked scatters: cannot create scatters for "
682                          "non-standard input types. Expanding.\n");
683     return nullptr;
684   }
685 
686   Value *Offsets;
687   int Scale;
688   Value *BasePtr = decomposePtr(
689       Ptr, Offsets, Scale, cast<FixedVectorType>(InputTy), MemoryTy, Builder);
690   if (!BasePtr)
691     return nullptr;
692 
693   if (ExtendInput)
694     Input = Builder.CreateZExt(Input, InputTy);
695   if (!match(Mask, m_One()))
696     return Builder.CreateIntrinsic(
697         Intrinsic::arm_mve_vstr_scatter_offset_predicated,
698         {BasePtr->getType(), Offsets->getType(), Input->getType(),
699          Mask->getType()},
700         {BasePtr, Offsets, Input,
701          Builder.getInt32(MemoryTy->getScalarSizeInBits()),
702          Builder.getInt32(Scale), Mask});
703   else
704     return Builder.CreateIntrinsic(
705         Intrinsic::arm_mve_vstr_scatter_offset,
706         {BasePtr->getType(), Offsets->getType(), Input->getType()},
707         {BasePtr, Offsets, Input,
708          Builder.getInt32(MemoryTy->getScalarSizeInBits()),
709          Builder.getInt32(Scale)});
710 }
711 
712 Instruction *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
713     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
714   FixedVectorType *Ty;
715   if (I->getIntrinsicID() == Intrinsic::masked_gather)
716     Ty = cast<FixedVectorType>(I->getType());
717   else
718     Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());
719 
720   // Incrementing gathers only exist for v4i32
721   if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
722     return nullptr;
723   // Incrementing gathers are not beneficial outside of a loop
724   Loop *L = LI->getLoopFor(I->getParent());
725   if (L == nullptr)
726     return nullptr;
727 
728   // Decompose the GEP into Base and Offsets
729   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
730   Value *Offsets;
731   Value *BasePtr = decomposeGEP(Offsets, Ty, GEP, Builder);
732   if (!BasePtr)
733     return nullptr;
734 
735   LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
736                        "wb gather/scatter\n");
737 
738   // The gep was in charge of making sure the offsets are scaled correctly
739   // - calculate that factor so it can be applied by hand
740   DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout();
741   int TypeScale =
742       computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()),
743                    DT.getTypeSizeInBits(GEP->getType()) /
744                        cast<FixedVectorType>(GEP->getType())->getNumElements());
745   if (TypeScale == -1)
746     return nullptr;
747 
748   if (GEP->hasOneUse()) {
749     // Only in this case do we want to build a wb gather, because the wb will
750     // change the phi which does affect other users of the gep (which will still
751     // be using the phi in the old way)
752     if (auto *Load = tryCreateIncrementingWBGatScat(I, BasePtr, Offsets,
753                                                     TypeScale, Builder))
754       return Load;
755   }
756 
757   LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
758                        "non-wb gather/scatter\n");
759 
760   std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
761   if (Add.first == nullptr)
762     return nullptr;
763   Value *OffsetsIncoming = Add.first;
764   int64_t Immediate = Add.second;
765 
766   // Make sure the offsets are scaled correctly
767   Instruction *ScaledOffsets = BinaryOperator::Create(
768       Instruction::Shl, OffsetsIncoming,
769       Builder.CreateVectorSplat(Ty->getNumElements(), Builder.getInt32(TypeScale)),
770       "ScaledIndex", I);
771   // Add the base to the offsets
772   OffsetsIncoming = BinaryOperator::Create(
773       Instruction::Add, ScaledOffsets,
774       Builder.CreateVectorSplat(
775           Ty->getNumElements(),
776           Builder.CreatePtrToInt(
777               BasePtr,
778               cast<VectorType>(ScaledOffsets->getType())->getElementType())),
779       "StartIndex", I);
780 
781   if (I->getIntrinsicID() == Intrinsic::masked_gather)
782     return tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate);
783   else
784     return tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate);
785 }
786 
787 Instruction *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
788     IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
789     IRBuilder<> &Builder) {
790   // Check whether this gather's offset is incremented by a constant - if so,
791   // and the load is of the right type, we can merge this into a QI gather
792   Loop *L = LI->getLoopFor(I->getParent());
793   // Offsets that are worth merging into this instruction will be incremented
794   // by a constant, thus we're looking for an add of a phi and a constant
795   PHINode *Phi = dyn_cast<PHINode>(Offsets);
796   if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
797       Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
798     // No phi means no IV to write back to; if there is a phi, we expect it
799     // to have exactly two incoming values; the only phis we are interested in
800     // will be loop IV's and have exactly two uses, one in their increment and
801     // one in the gather's gep
802     return nullptr;
803 
804   unsigned IncrementIndex =
805       Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
806   // Look through the phi to the phi increment
807   Offsets = Phi->getIncomingValue(IncrementIndex);
808 
809   std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
810   if (Add.first == nullptr)
811     return nullptr;
812   Value *OffsetsIncoming = Add.first;
813   int64_t Immediate = Add.second;
814   if (OffsetsIncoming != Phi)
815     // Then the increment we are looking at is not an increment of the
816     // induction variable, and we don't want to do a writeback
817     return nullptr;
818 
819   Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
820   unsigned NumElems =
821       cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();
822 
823   // Make sure the offsets are scaled correctly
824   Instruction *ScaledOffsets = BinaryOperator::Create(
825       Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
826       Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
827       "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
828   // Add the base to the offsets
829   OffsetsIncoming = BinaryOperator::Create(
830       Instruction::Add, ScaledOffsets,
831       Builder.CreateVectorSplat(
832           NumElems,
833           Builder.CreatePtrToInt(
834               BasePtr,
835               cast<VectorType>(ScaledOffsets->getType())->getElementType())),
836       "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
837   // The gather is pre-incrementing
838   OffsetsIncoming = BinaryOperator::Create(
839       Instruction::Sub, OffsetsIncoming,
840       Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
841       "PreIncrementStartIndex",
842       &Phi->getIncomingBlock(1 - IncrementIndex)->back());
843   Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);
844 
845   Builder.SetInsertPoint(I);
846 
847   Instruction *EndResult;
848   Instruction *NewInduction;
849   if (I->getIntrinsicID() == Intrinsic::masked_gather) {
850     // Build the incrementing gather
851     Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
852     // One value to be handed to whoever uses the gather, one is the loop
853     // increment
854     EndResult = ExtractValueInst::Create(Load, 0, "Gather");
855     NewInduction = ExtractValueInst::Create(Load, 1, "GatherIncrement");
856     Builder.Insert(EndResult);
857     Builder.Insert(NewInduction);
858   } else {
859     // Build the incrementing scatter
860     EndResult = NewInduction =
861         tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
862   }
863   Instruction *AddInst = cast<Instruction>(Offsets);
864   AddInst->replaceAllUsesWith(NewInduction);
865   AddInst->eraseFromParent();
866   Phi->setIncomingValue(IncrementIndex, NewInduction);
867 
868   return EndResult;
869 }
870 
871 void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
872                                           Value *OffsSecondOperand,
873                                           unsigned StartIndex) {
874   LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
875   Instruction *InsertionPoint =
876         &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
877   // Initialize the phi with a vector that contains a sum of the constants
878   Instruction *NewIndex = BinaryOperator::Create(
879       Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
880       "PushedOutAdd", InsertionPoint);
881   unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;
882 
883   // Order such that start index comes first (this reduces mov's)
884   Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
885   Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
886                    Phi->getIncomingBlock(IncrementIndex));
887   Phi->removeIncomingValue(IncrementIndex);
888   Phi->removeIncomingValue(StartIndex);
889 }
890 
891 void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
892                                           Value *IncrementPerRound,
893                                           Value *OffsSecondOperand,
894                                           unsigned LoopIncrement,
895                                           IRBuilder<> &Builder) {
896   LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");
897 
898   // Create a new scalar add outside of the loop and transform it to a splat
899   // by which loop variable can be incremented
900   Instruction *InsertionPoint = &cast<Instruction>(
901         Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());
902 
903   // Create a new index
904   Value *StartIndex = BinaryOperator::Create(
905       Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
906       OffsSecondOperand, "PushedOutMul", InsertionPoint);
907 
908   Instruction *Product =
909       BinaryOperator::Create(Instruction::Mul, IncrementPerRound,
910                              OffsSecondOperand, "Product", InsertionPoint);
911   // Increment NewIndex by Product instead of the multiplication
912   Instruction *NewIncrement = BinaryOperator::Create(
913       Instruction::Add, Phi, Product, "IncrementPushedOutMul",
914       cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
915           .getPrevNode());
916 
917   Phi->addIncoming(StartIndex,
918                    Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
919   Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
920   Phi->removeIncomingValue((unsigned)0);
921   Phi->removeIncomingValue((unsigned)0);
922 }
923 
924 // Check whether all usages of this instruction are as offsets of
925 // gathers/scatters or simple arithmetics only used by gathers/scatters
926 static bool hasAllGatScatUsers(Instruction *I) {
927   if (I->hasNUses(0)) {
928     return false;
929   }
930   bool Gatscat = true;
931   for (User *U : I->users()) {
932     if (!isa<Instruction>(U))
933       return false;
934     if (isa<GetElementPtrInst>(U) ||
935         isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
936       return Gatscat;
937     } else {
938       unsigned OpCode = cast<Instruction>(U)->getOpcode();
939       if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
940           hasAllGatScatUsers(cast<Instruction>(U))) {
941         continue;
942       }
943       return false;
944     }
945   }
946   return Gatscat;
947 }
948 
949 bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
950                                                LoopInfo *LI) {
951   LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n"
952                     << *Offsets << "\n");
953   // Optimise the addresses of gathers/scatters by moving invariant
954   // calculations out of the loop
955   if (!isa<Instruction>(Offsets))
956     return false;
957   Instruction *Offs = cast<Instruction>(Offsets);
958   if (Offs->getOpcode() != Instruction::Add &&
959       Offs->getOpcode() != Instruction::Mul)
960     return false;
961   Loop *L = LI->getLoopFor(BB);
962   if (L == nullptr)
963     return false;
964   if (!Offs->hasOneUse()) {
965     if (!hasAllGatScatUsers(Offs))
966       return false;
967   }
968 
969   // Find out which, if any, operand of the instruction
970   // is a phi node
971   PHINode *Phi;
972   int OffsSecondOp;
973   if (isa<PHINode>(Offs->getOperand(0))) {
974     Phi = cast<PHINode>(Offs->getOperand(0));
975     OffsSecondOp = 1;
976   } else if (isa<PHINode>(Offs->getOperand(1))) {
977     Phi = cast<PHINode>(Offs->getOperand(1));
978     OffsSecondOp = 0;
979   } else {
980     bool Changed = false;
981     if (isa<Instruction>(Offs->getOperand(0)) &&
982         L->contains(cast<Instruction>(Offs->getOperand(0))))
983       Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
984     if (isa<Instruction>(Offs->getOperand(1)) &&
985         L->contains(cast<Instruction>(Offs->getOperand(1))))
986       Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
987     if (!Changed)
988       return false;
989     if (isa<PHINode>(Offs->getOperand(0))) {
990       Phi = cast<PHINode>(Offs->getOperand(0));
991       OffsSecondOp = 1;
992     } else if (isa<PHINode>(Offs->getOperand(1))) {
993       Phi = cast<PHINode>(Offs->getOperand(1));
994       OffsSecondOp = 0;
995     } else {
996       return false;
997     }
998   }
999   // A phi node we want to perform this function on should be from the
1000   // loop header.
1001   if (Phi->getParent() != L->getHeader())
1002     return false;
1003 
1004   // We're looking for a simple add recurrence.
1005   BinaryOperator *IncInstruction;
1006   Value *Start, *IncrementPerRound;
1007   if (!matchSimpleRecurrence(Phi, IncInstruction, Start, IncrementPerRound) ||
1008       IncInstruction->getOpcode() != Instruction::Add)
1009     return false;
1010 
1011   int IncrementingBlock = Phi->getIncomingValue(0) == IncInstruction ? 0 : 1;
1012 
1013   // Get the value that is added to/multiplied with the phi
1014   Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);
1015 
1016   if (IncrementPerRound->getType() != OffsSecondOperand->getType() ||
1017       !L->isLoopInvariant(OffsSecondOperand))
1018     // Something has gone wrong, abort
1019     return false;
1020 
1021   // Only proceed if the increment per round is a constant or an instruction
1022   // which does not originate from within the loop
1023   if (!isa<Constant>(IncrementPerRound) &&
1024       !(isa<Instruction>(IncrementPerRound) &&
1025         !L->contains(cast<Instruction>(IncrementPerRound))))
1026     return false;
1027 
1028   // If the phi is not used by anything else, we can just adapt it when
1029   // replacing the instruction; if it is, we'll have to duplicate it
1030   PHINode *NewPhi;
1031   if (Phi->getNumUses() == 2) {
1032     // No other users -> reuse existing phi (One user is the instruction
1033     // we're looking at, the other is the phi increment)
1034     if (IncInstruction->getNumUses() != 1) {
1035       // If the incrementing instruction does have more users than
1036       // our phi, we need to copy it
1037       IncInstruction = BinaryOperator::Create(
1038           Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
1039           IncrementPerRound, "LoopIncrement", IncInstruction);
1040       Phi->setIncomingValue(IncrementingBlock, IncInstruction);
1041     }
1042     NewPhi = Phi;
1043   } else {
1044     // There are other users -> create a new phi
1045     NewPhi = PHINode::Create(Phi->getType(), 2, "NewPhi", Phi);
1046     // Copy the incoming values of the old phi
1047     NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
1048                         Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
1049     IncInstruction = BinaryOperator::Create(
1050         Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
1051         IncrementPerRound, "LoopIncrement", IncInstruction);
1052     NewPhi->addIncoming(IncInstruction,
1053                         Phi->getIncomingBlock(IncrementingBlock));
1054     IncrementingBlock = 1;
1055   }
1056 
1057   IRBuilder<> Builder(BB->getContext());
1058   Builder.SetInsertPoint(Phi);
1059   Builder.SetCurrentDebugLocation(Offs->getDebugLoc());
1060 
1061   switch (Offs->getOpcode()) {
1062   case Instruction::Add:
1063     pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
1064     break;
1065   case Instruction::Mul:
1066     pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
1067                Builder);
1068     break;
1069   default:
1070     return false;
1071   }
1072   LLVM_DEBUG(dbgs() << "masked gathers/scatters: simplified loop variable "
1073                     << "add/mul\n");
1074 
1075   // The instruction has now been "absorbed" into the phi value
1076   Offs->replaceAllUsesWith(NewPhi);
1077   if (Offs->hasNUses(0))
1078     Offs->eraseFromParent();
1079   // Clean up the old increment in case it's unused because we built a new
1080   // one
1081   if (IncInstruction->hasNUses(0))
1082     IncInstruction->eraseFromParent();
1083 
1084   return true;
1085 }
1086 
1087 static Value *CheckAndCreateOffsetAdd(Value *X, Value *Y, Value *GEP,
1088                                       IRBuilder<> &Builder) {
1089   // Splat the non-vector value to a vector of the given type - if the value is
1090   // a constant (and its value isn't too big), we can even use this opportunity
1091   // to scale it to the size of the vector elements
1092   auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) {
1093     ConstantInt *Const;
1094     if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
1095         VT->getElementType() != NonVectorVal->getType()) {
1096       unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits();
1097       uint64_t N = Const->getZExtValue();
1098       if (N < (unsigned)(1 << (TargetElemSize - 1))) {
1099         NonVectorVal = Builder.CreateVectorSplat(
1100             VT->getNumElements(), Builder.getIntN(TargetElemSize, N));
1101         return;
1102       }
1103     }
1104     NonVectorVal =
1105         Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal);
1106   };
1107 
1108   FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType());
1109   FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType());
1110   // If one of X, Y is not a vector, we have to splat it in order
1111   // to add the two of them.
1112   if (XElType && !YElType) {
1113     FixSummands(XElType, Y);
1114     YElType = cast<FixedVectorType>(Y->getType());
1115   } else if (YElType && !XElType) {
1116     FixSummands(YElType, X);
1117     XElType = cast<FixedVectorType>(X->getType());
1118   }
1119   assert(XElType && YElType && "Unknown vector types");
1120   // Check that the summands are of compatible types
1121   if (XElType != YElType) {
1122     LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n");
1123     return nullptr;
1124   }
1125 
1126   if (XElType->getElementType()->getScalarSizeInBits() != 32) {
1127     // Check that by adding the vectors we do not accidentally
1128     // create an overflow
1129     Constant *ConstX = dyn_cast<Constant>(X);
1130     Constant *ConstY = dyn_cast<Constant>(Y);
1131     if (!ConstX || !ConstY)
1132       return nullptr;
1133     unsigned TargetElemSize = 128 / XElType->getNumElements();
1134     for (unsigned i = 0; i < XElType->getNumElements(); i++) {
1135       ConstantInt *ConstXEl =
1136           dyn_cast<ConstantInt>(ConstX->getAggregateElement(i));
1137       ConstantInt *ConstYEl =
1138           dyn_cast<ConstantInt>(ConstY->getAggregateElement(i));
1139       if (!ConstXEl || !ConstYEl ||
1140           ConstXEl->getZExtValue() + ConstYEl->getZExtValue() >=
1141               (unsigned)(1 << (TargetElemSize - 1)))
1142         return nullptr;
1143     }
1144   }
1145 
1146   Value *Add = Builder.CreateAdd(X, Y);
1147 
1148   FixedVectorType *GEPType = cast<FixedVectorType>(GEP->getType());
1149   if (checkOffsetSize(Add, GEPType->getNumElements()))
1150     return Add;
1151   else
1152     return nullptr;
1153 }
1154 
1155 Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP,
1156                                          Value *&Offsets,
1157                                          IRBuilder<> &Builder) {
1158   Value *GEPPtr = GEP->getPointerOperand();
1159   Offsets = GEP->getOperand(1);
1160   // We only merge geps with constant offsets, because only for those
1161   // we can make sure that we do not cause an overflow
1162   if (!isa<Constant>(Offsets))
1163     return nullptr;
1164   GetElementPtrInst *BaseGEP;
1165   if ((BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr))) {
1166     // Merge the two geps into one
1167     Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Builder);
1168     if (!BaseBasePtr)
1169       return nullptr;
1170     Offsets =
1171         CheckAndCreateOffsetAdd(Offsets, GEP->getOperand(1), GEP, Builder);
1172     if (Offsets == nullptr)
1173       return nullptr;
1174     return BaseBasePtr;
1175   }
1176   return GEPPtr;
1177 }
1178 
1179 bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB,
1180                                                LoopInfo *LI) {
1181   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address);
1182   if (!GEP)
1183     return false;
1184   bool Changed = false;
1185   if (GEP->hasOneUse() &&
1186       dyn_cast<GetElementPtrInst>(GEP->getPointerOperand())) {
1187     IRBuilder<> Builder(GEP->getContext());
1188     Builder.SetInsertPoint(GEP);
1189     Builder.SetCurrentDebugLocation(GEP->getDebugLoc());
1190     Value *Offsets;
1191     Value *Base = foldGEP(GEP, Offsets, Builder);
1192     // We only want to merge the geps if there is a real chance that they can be
1193     // used by an MVE gather; thus the offset has to have the correct size
1194     // (always i32 if it is not of vector type) and the base has to be a
1195     // pointer.
1196     if (Offsets && Base && Base != GEP) {
1197       GetElementPtrInst *NewAddress = GetElementPtrInst::Create(
1198           GEP->getSourceElementType(), Base, Offsets, "gep.merged", GEP);
1199       GEP->replaceAllUsesWith(NewAddress);
1200       GEP = NewAddress;
1201       Changed = true;
1202     }
1203   }
1204   Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI);
1205   return Changed;
1206 }
1207 
1208 bool MVEGatherScatterLowering::runOnFunction(Function &F) {
1209   if (!EnableMaskedGatherScatters)
1210     return false;
1211   auto &TPC = getAnalysis<TargetPassConfig>();
1212   auto &TM = TPC.getTM<TargetMachine>();
1213   auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
1214   if (!ST->hasMVEIntegerOps())
1215     return false;
1216   LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1217   SmallVector<IntrinsicInst *, 4> Gathers;
1218   SmallVector<IntrinsicInst *, 4> Scatters;
1219 
1220   bool Changed = false;
1221 
1222   for (BasicBlock &BB : F) {
1223     Changed |= SimplifyInstructionsInBlock(&BB);
1224 
1225     for (Instruction &I : BB) {
1226       IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
1227       if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
1228           isa<FixedVectorType>(II->getType())) {
1229         Gathers.push_back(II);
1230         Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI);
1231       } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
1232                  isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
1233         Scatters.push_back(II);
1234         Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI);
1235       }
1236     }
1237   }
1238   for (unsigned i = 0; i < Gathers.size(); i++) {
1239     IntrinsicInst *I = Gathers[i];
1240     Instruction *L = lowerGather(I);
1241     if (L == nullptr)
1242       continue;
1243 
1244     // Get rid of any now dead instructions
1245     SimplifyInstructionsInBlock(L->getParent());
1246     Changed = true;
1247   }
1248 
1249   for (unsigned i = 0; i < Scatters.size(); i++) {
1250     IntrinsicInst *I = Scatters[i];
1251     Instruction *S = lowerScatter(I);
1252     if (S == nullptr)
1253       continue;
1254 
1255     // Get rid of any now dead instructions
1256     SimplifyInstructionsInBlock(S->getParent());
1257     Changed = true;
1258   }
1259   return Changed;
1260 }
1261