xref: /freebsd/contrib/llvm-project/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp (revision 3ceba58a7509418b47b8fca2d2b6bbf088714e26)
1 //===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// This pass custom lowers llvm.gather and llvm.scatter instructions to
10 /// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to
11 /// produce a better final result as we go.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "ARM.h"
16 #include "ARMBaseInstrInfo.h"
17 #include "ARMSubtarget.h"
18 #include "llvm/Analysis/LoopInfo.h"
19 #include "llvm/Analysis/TargetTransformInfo.h"
20 #include "llvm/Analysis/ValueTracking.h"
21 #include "llvm/CodeGen/TargetLowering.h"
22 #include "llvm/CodeGen/TargetPassConfig.h"
23 #include "llvm/CodeGen/TargetSubtargetInfo.h"
24 #include "llvm/InitializePasses.h"
25 #include "llvm/IR/BasicBlock.h"
26 #include "llvm/IR/Constant.h"
27 #include "llvm/IR/Constants.h"
28 #include "llvm/IR/DerivedTypes.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/InstrTypes.h"
31 #include "llvm/IR/Instruction.h"
32 #include "llvm/IR/Instructions.h"
33 #include "llvm/IR/IntrinsicInst.h"
34 #include "llvm/IR/Intrinsics.h"
35 #include "llvm/IR/IntrinsicsARM.h"
36 #include "llvm/IR/IRBuilder.h"
37 #include "llvm/IR/PatternMatch.h"
38 #include "llvm/IR/Type.h"
39 #include "llvm/IR/Value.h"
40 #include "llvm/Pass.h"
41 #include "llvm/Support/Casting.h"
42 #include "llvm/Transforms/Utils/Local.h"
43 #include <algorithm>
44 #include <cassert>
45 
46 using namespace llvm;
47 
48 #define DEBUG_TYPE "arm-mve-gather-scatter-lowering"
49 
50 cl::opt<bool> EnableMaskedGatherScatters(
51     "enable-arm-maskedgatscat", cl::Hidden, cl::init(true),
52     cl::desc("Enable the generation of masked gathers and scatters"));
53 
54 namespace {
55 
56 class MVEGatherScatterLowering : public FunctionPass {
57 public:
58   static char ID; // Pass identification, replacement for typeid
59 
60   explicit MVEGatherScatterLowering() : FunctionPass(ID) {
61     initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
62   }
63 
64   bool runOnFunction(Function &F) override;
65 
66   StringRef getPassName() const override {
67     return "MVE gather/scatter lowering";
68   }
69 
70   void getAnalysisUsage(AnalysisUsage &AU) const override {
71     AU.setPreservesCFG();
72     AU.addRequired<TargetPassConfig>();
73     AU.addRequired<LoopInfoWrapperPass>();
74     FunctionPass::getAnalysisUsage(AU);
75   }
76 
77 private:
78   LoopInfo *LI = nullptr;
79   const DataLayout *DL;
80 
81   // Check this is a valid gather with correct alignment
82   bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
83                                Align Alignment);
84   // Check whether Ptr is hidden behind a bitcast and look through it
85   void lookThroughBitcast(Value *&Ptr);
86   // Decompose a ptr into Base and Offsets, potentially using a GEP to return a
87   // scalar base and vector offsets, or else fallback to using a base of 0 and
88   // offset of Ptr where possible.
89   Value *decomposePtr(Value *Ptr, Value *&Offsets, int &Scale,
90                       FixedVectorType *Ty, Type *MemoryTy,
91                       IRBuilder<> &Builder);
92   // Check for a getelementptr and deduce base and offsets from it, on success
93   // returning the base directly and the offsets indirectly using the Offsets
94   // argument
95   Value *decomposeGEP(Value *&Offsets, FixedVectorType *Ty,
96                       GetElementPtrInst *GEP, IRBuilder<> &Builder);
97   // Compute the scale of this gather/scatter instruction
98   int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
99   // If the value is a constant, or derived from constants via additions
100   // and multilications, return its numeric value
101   std::optional<int64_t> getIfConst(const Value *V);
102   // If Inst is an add instruction, check whether one summand is a
103   // constant. If so, scale this constant and return it together with
104   // the other summand.
105   std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);
106 
107   Instruction *lowerGather(IntrinsicInst *I);
108   // Create a gather from a base + vector of offsets
109   Instruction *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
110                                            Instruction *&Root,
111                                            IRBuilder<> &Builder);
112   // Create a gather from a vector of pointers
113   Instruction *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
114                                          IRBuilder<> &Builder,
115                                          int64_t Increment = 0);
116   // Create an incrementing gather from a vector of pointers
117   Instruction *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
118                                            IRBuilder<> &Builder,
119                                            int64_t Increment = 0);
120 
121   Instruction *lowerScatter(IntrinsicInst *I);
122   // Create a scatter to a base + vector of offsets
123   Instruction *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
124                                             IRBuilder<> &Builder);
125   // Create a scatter to a vector of pointers
126   Instruction *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
127                                           IRBuilder<> &Builder,
128                                           int64_t Increment = 0);
129   // Create an incrementing scatter from a vector of pointers
130   Instruction *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
131                                             IRBuilder<> &Builder,
132                                             int64_t Increment = 0);
133 
134   // QI gathers and scatters can increment their offsets on their own if
135   // the increment is a constant value (digit)
136   Instruction *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *Ptr,
137                                             IRBuilder<> &Builder);
138   // QI gathers/scatters can increment their offsets on their own if the
139   // increment is a constant value (digit) - this creates a writeback QI
140   // gather/scatter
141   Instruction *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
142                                               Value *Ptr, unsigned TypeScale,
143                                               IRBuilder<> &Builder);
144 
145   // Optimise the base and offsets of the given address
146   bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI);
147   // Try to fold consecutive geps together into one
148   Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, unsigned &Scale,
149                  IRBuilder<> &Builder);
150   // Check whether these offsets could be moved out of the loop they're in
151   bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
152   // Pushes the given add out of the loop
153   void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
154   // Pushes the given mul or shl out of the loop
155   void pushOutMulShl(unsigned Opc, PHINode *&Phi, Value *IncrementPerRound,
156                      Value *OffsSecondOperand, unsigned LoopIncrement,
157                      IRBuilder<> &Builder);
158 };
159 
160 } // end anonymous namespace
161 
162 char MVEGatherScatterLowering::ID = 0;
163 
164 INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
165                 "MVE gather/scattering lowering pass", false, false)
166 
167 Pass *llvm::createMVEGatherScatterLoweringPass() {
168   return new MVEGatherScatterLowering();
169 }
170 
171 bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
172                                                        unsigned ElemSize,
173                                                        Align Alignment) {
174   if (((NumElements == 4 &&
175         (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
176        (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
177        (NumElements == 16 && ElemSize == 8)) &&
178       Alignment >= ElemSize / 8)
179     return true;
180   LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
181                     << "valid alignment or vector type \n");
182   return false;
183 }
184 
185 static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) {
186   // Offsets that are not of type <N x i32> are sign extended by the
187   // getelementptr instruction, and MVE gathers/scatters treat the offset as
188   // unsigned. Thus, if the element size is smaller than 32, we can only allow
189   // positive offsets - i.e., the offsets are not allowed to be variables we
190   // can't look into.
191   // Additionally, <N x i32> offsets have to either originate from a zext of a
192   // vector with element types smaller or equal the type of the gather we're
193   // looking at, or consist of constants that we can check are small enough
194   // to fit into the gather type.
195   // Thus we check that 0 < value < 2^TargetElemSize.
196   unsigned TargetElemSize = 128 / TargetElemCount;
197   unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType())
198                                 ->getElementType()
199                                 ->getScalarSizeInBits();
200   if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
201     Constant *ConstOff = dyn_cast<Constant>(Offsets);
202     if (!ConstOff)
203       return false;
204     int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
205     auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) {
206       ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem);
207       if (!OConst)
208         return false;
209       int SExtValue = OConst->getSExtValue();
210       if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
211         return false;
212       return true;
213     };
214     if (isa<FixedVectorType>(ConstOff->getType())) {
215       for (unsigned i = 0; i < TargetElemCount; i++) {
216         if (!CheckValueSize(ConstOff->getAggregateElement(i)))
217           return false;
218       }
219     } else {
220       if (!CheckValueSize(ConstOff))
221         return false;
222     }
223   }
224   return true;
225 }
226 
227 Value *MVEGatherScatterLowering::decomposePtr(Value *Ptr, Value *&Offsets,
228                                               int &Scale, FixedVectorType *Ty,
229                                               Type *MemoryTy,
230                                               IRBuilder<> &Builder) {
231   if (auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
232     if (Value *V = decomposeGEP(Offsets, Ty, GEP, Builder)) {
233       Scale =
234           computeScale(GEP->getSourceElementType()->getPrimitiveSizeInBits(),
235                        MemoryTy->getScalarSizeInBits());
236       return Scale == -1 ? nullptr : V;
237     }
238   }
239 
240   // If we couldn't use the GEP (or it doesn't exist), attempt to use a
241   // BasePtr of 0 with Ptr as the Offsets, so long as there are only 4
242   // elements.
243   FixedVectorType *PtrTy = cast<FixedVectorType>(Ptr->getType());
244   if (PtrTy->getNumElements() != 4 || MemoryTy->getScalarSizeInBits() == 32)
245     return nullptr;
246   Value *Zero = ConstantInt::get(Builder.getInt32Ty(), 0);
247   Value *BasePtr = Builder.CreateIntToPtr(Zero, Builder.getPtrTy());
248   Offsets = Builder.CreatePtrToInt(
249       Ptr, FixedVectorType::get(Builder.getInt32Ty(), 4));
250   Scale = 0;
251   return BasePtr;
252 }
253 
254 Value *MVEGatherScatterLowering::decomposeGEP(Value *&Offsets,
255                                               FixedVectorType *Ty,
256                                               GetElementPtrInst *GEP,
257                                               IRBuilder<> &Builder) {
258   if (!GEP) {
259     LLVM_DEBUG(dbgs() << "masked gathers/scatters: no getelementpointer "
260                       << "found\n");
261     return nullptr;
262   }
263   LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
264                     << " Looking at intrinsic for base + vector of offsets\n");
265   Value *GEPPtr = GEP->getPointerOperand();
266   Offsets = GEP->getOperand(1);
267   if (GEPPtr->getType()->isVectorTy() ||
268       !isa<FixedVectorType>(Offsets->getType()))
269     return nullptr;
270 
271   if (GEP->getNumOperands() != 2) {
272     LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
273                       << " operands. Expanding.\n");
274     return nullptr;
275   }
276   Offsets = GEP->getOperand(1);
277   unsigned OffsetsElemCount =
278       cast<FixedVectorType>(Offsets->getType())->getNumElements();
279   // Paranoid check whether the number of parallel lanes is the same
280   assert(Ty->getNumElements() == OffsetsElemCount);
281 
282   ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets);
283   if (ZextOffs)
284     Offsets = ZextOffs->getOperand(0);
285   FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType());
286 
287   // If the offsets are already being zext-ed to <N x i32>, that relieves us of
288   // having to make sure that they won't overflow.
289   if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy())
290                            ->getElementType()
291                            ->getScalarSizeInBits() != 32)
292     if (!checkOffsetSize(Offsets, OffsetsElemCount))
293       return nullptr;
294 
295   // The offset sizes have been checked; if any truncating or zext-ing is
296   // required to fix them, do that now
297   if (Ty != Offsets->getType()) {
298     if ((Ty->getElementType()->getScalarSizeInBits() <
299          OffsetType->getElementType()->getScalarSizeInBits())) {
300       Offsets = Builder.CreateTrunc(Offsets, Ty);
301     } else {
302       Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty));
303     }
304   }
305   // If none of the checks failed, return the gep's base pointer
306   LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
307   return GEPPtr;
308 }
309 
310 void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
311   // Look through bitcast instruction if #elements is the same
312   if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
313     auto *BCTy = cast<FixedVectorType>(BitCast->getType());
314     auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
315     if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
316       LLVM_DEBUG(dbgs() << "masked gathers/scatters: looking through "
317                         << "bitcast\n");
318       Ptr = BitCast->getOperand(0);
319     }
320   }
321 }
322 
323 int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
324                                            unsigned MemoryElemSize) {
325   // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by 2,
326   // or a 8bit, 16bit or 32bit load/store scaled by 1
327   if (GEPElemSize == 32 && MemoryElemSize == 32)
328     return 2;
329   else if (GEPElemSize == 16 && MemoryElemSize == 16)
330     return 1;
331   else if (GEPElemSize == 8)
332     return 0;
333   LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
334                     << "create intrinsic\n");
335   return -1;
336 }
337 
338 std::optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
339   const Constant *C = dyn_cast<Constant>(V);
340   if (C && C->getSplatValue())
341     return std::optional<int64_t>{C->getUniqueInteger().getSExtValue()};
342   if (!isa<Instruction>(V))
343     return std::optional<int64_t>{};
344 
345   const Instruction *I = cast<Instruction>(V);
346   if (I->getOpcode() == Instruction::Add || I->getOpcode() == Instruction::Or ||
347       I->getOpcode() == Instruction::Mul ||
348       I->getOpcode() == Instruction::Shl) {
349     std::optional<int64_t> Op0 = getIfConst(I->getOperand(0));
350     std::optional<int64_t> Op1 = getIfConst(I->getOperand(1));
351     if (!Op0 || !Op1)
352       return std::optional<int64_t>{};
353     if (I->getOpcode() == Instruction::Add)
354       return std::optional<int64_t>{*Op0 + *Op1};
355     if (I->getOpcode() == Instruction::Mul)
356       return std::optional<int64_t>{*Op0 * *Op1};
357     if (I->getOpcode() == Instruction::Shl)
358       return std::optional<int64_t>{*Op0 << *Op1};
359     if (I->getOpcode() == Instruction::Or)
360       return std::optional<int64_t>{*Op0 | *Op1};
361   }
362   return std::optional<int64_t>{};
363 }
364 
365 // Return true if I is an Or instruction that is equivalent to an add, due to
366 // the operands having no common bits set.
367 static bool isAddLikeOr(Instruction *I, const DataLayout &DL) {
368   return I->getOpcode() == Instruction::Or &&
369          haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1), DL);
370 }
371 
372 std::pair<Value *, int64_t>
373 MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
374   std::pair<Value *, int64_t> ReturnFalse =
375       std::pair<Value *, int64_t>(nullptr, 0);
376   // At this point, the instruction we're looking at must be an add or an
377   // add-like-or.
378   Instruction *Add = dyn_cast<Instruction>(Inst);
379   if (Add == nullptr ||
380       (Add->getOpcode() != Instruction::Add && !isAddLikeOr(Add, *DL)))
381     return ReturnFalse;
382 
383   Value *Summand;
384   std::optional<int64_t> Const;
385   // Find out which operand the value that is increased is
386   if ((Const = getIfConst(Add->getOperand(0))))
387     Summand = Add->getOperand(1);
388   else if ((Const = getIfConst(Add->getOperand(1))))
389     Summand = Add->getOperand(0);
390   else
391     return ReturnFalse;
392 
393   // Check that the constant is small enough for an incrementing gather
394   int64_t Immediate = *Const << TypeScale;
395   if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
396     return ReturnFalse;
397 
398   return std::pair<Value *, int64_t>(Summand, Immediate);
399 }
400 
401 Instruction *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
402   using namespace PatternMatch;
403   LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"
404                     << *I << "\n");
405 
406   // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
407   // Attempt to turn the masked gather in I into a MVE intrinsic
408   // Potentially optimising the addressing modes as we do so.
409   auto *Ty = cast<FixedVectorType>(I->getType());
410   Value *Ptr = I->getArgOperand(0);
411   Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
412   Value *Mask = I->getArgOperand(2);
413   Value *PassThru = I->getArgOperand(3);
414 
415   if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
416                                Alignment))
417     return nullptr;
418   lookThroughBitcast(Ptr);
419   assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
420 
421   IRBuilder<> Builder(I->getContext());
422   Builder.SetInsertPoint(I);
423   Builder.SetCurrentDebugLocation(I->getDebugLoc());
424 
425   Instruction *Root = I;
426 
427   Instruction *Load = tryCreateIncrementingGatScat(I, Ptr, Builder);
428   if (!Load)
429     Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
430   if (!Load)
431     Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
432   if (!Load)
433     return nullptr;
434 
435   if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
436     LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
437                       << "creating select\n");
438     Load = SelectInst::Create(Mask, Load, PassThru);
439     Builder.Insert(Load);
440   }
441 
442   Root->replaceAllUsesWith(Load);
443   Root->eraseFromParent();
444   if (Root != I)
445     // If this was an extending gather, we need to get rid of the sext/zext
446     // sext/zext as well as of the gather itself
447     I->eraseFromParent();
448 
449   LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"
450                     << *Load << "\n");
451   return Load;
452 }
453 
454 Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
455     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
456   using namespace PatternMatch;
457   auto *Ty = cast<FixedVectorType>(I->getType());
458   LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
459   if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
460     // Can't build an intrinsic for this
461     return nullptr;
462   Value *Mask = I->getArgOperand(2);
463   if (match(Mask, m_One()))
464     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
465                                    {Ty, Ptr->getType()},
466                                    {Ptr, Builder.getInt32(Increment)});
467   else
468     return Builder.CreateIntrinsic(
469         Intrinsic::arm_mve_vldr_gather_base_predicated,
470         {Ty, Ptr->getType(), Mask->getType()},
471         {Ptr, Builder.getInt32(Increment), Mask});
472 }
473 
474 Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
475     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
476   using namespace PatternMatch;
477   auto *Ty = cast<FixedVectorType>(I->getType());
478   LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers with "
479                     << "writeback\n");
480   if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
481     // Can't build an intrinsic for this
482     return nullptr;
483   Value *Mask = I->getArgOperand(2);
484   if (match(Mask, m_One()))
485     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
486                                    {Ty, Ptr->getType()},
487                                    {Ptr, Builder.getInt32(Increment)});
488   else
489     return Builder.CreateIntrinsic(
490         Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
491         {Ty, Ptr->getType(), Mask->getType()},
492         {Ptr, Builder.getInt32(Increment), Mask});
493 }
494 
495 Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
496     IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
497   using namespace PatternMatch;
498 
499   Type *MemoryTy = I->getType();
500   Type *ResultTy = MemoryTy;
501 
502   unsigned Unsigned = 1;
503   // The size of the gather was already checked in isLegalTypeAndAlignment;
504   // if it was not a full vector width an appropriate extend should follow.
505   auto *Extend = Root;
506   bool TruncResult = false;
507   if (MemoryTy->getPrimitiveSizeInBits() < 128) {
508     if (I->hasOneUse()) {
509       // If the gather has a single extend of the correct type, use an extending
510       // gather and replace the ext. In which case the correct root to replace
511       // is not the CallInst itself, but the instruction which extends it.
512       Instruction* User = cast<Instruction>(*I->users().begin());
513       if (isa<SExtInst>(User) &&
514           User->getType()->getPrimitiveSizeInBits() == 128) {
515         LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
516                           << *User << "\n");
517         Extend = User;
518         ResultTy = User->getType();
519         Unsigned = 0;
520       } else if (isa<ZExtInst>(User) &&
521                  User->getType()->getPrimitiveSizeInBits() == 128) {
522         LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
523                           << *ResultTy << "\n");
524         Extend = User;
525         ResultTy = User->getType();
526       }
527     }
528 
529     // If an extend hasn't been found and the type is an integer, create an
530     // extending gather and truncate back to the original type.
531     if (ResultTy->getPrimitiveSizeInBits() < 128 &&
532         ResultTy->isIntOrIntVectorTy()) {
533       ResultTy = ResultTy->getWithNewBitWidth(
534           128 / cast<FixedVectorType>(ResultTy)->getNumElements());
535       TruncResult = true;
536       LLVM_DEBUG(dbgs() << "masked gathers: Small input type, truncing to: "
537                         << *ResultTy << "\n");
538     }
539 
540     // The final size of the gather must be a full vector width
541     if (ResultTy->getPrimitiveSizeInBits() != 128) {
542       LLVM_DEBUG(dbgs() << "masked gathers: Extend needed but not provided "
543                            "from the correct type. Expanding\n");
544       return nullptr;
545     }
546   }
547 
548   Value *Offsets;
549   int Scale;
550   Value *BasePtr = decomposePtr(
551       Ptr, Offsets, Scale, cast<FixedVectorType>(ResultTy), MemoryTy, Builder);
552   if (!BasePtr)
553     return nullptr;
554 
555   Root = Extend;
556   Value *Mask = I->getArgOperand(2);
557   Instruction *Load = nullptr;
558   if (!match(Mask, m_One()))
559     Load = Builder.CreateIntrinsic(
560         Intrinsic::arm_mve_vldr_gather_offset_predicated,
561         {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
562         {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
563          Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
564   else
565     Load = Builder.CreateIntrinsic(
566         Intrinsic::arm_mve_vldr_gather_offset,
567         {ResultTy, BasePtr->getType(), Offsets->getType()},
568         {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
569          Builder.getInt32(Scale), Builder.getInt32(Unsigned)});
570 
571   if (TruncResult) {
572     Load = TruncInst::Create(Instruction::Trunc, Load, MemoryTy);
573     Builder.Insert(Load);
574   }
575   return Load;
576 }
577 
578 Instruction *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
579   using namespace PatternMatch;
580   LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"
581                     << *I << "\n");
582 
583   // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
584   // Attempt to turn the masked scatter in I into a MVE intrinsic
585   // Potentially optimising the addressing modes as we do so.
586   Value *Input = I->getArgOperand(0);
587   Value *Ptr = I->getArgOperand(1);
588   Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
589   auto *Ty = cast<FixedVectorType>(Input->getType());
590 
591   if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
592                                Alignment))
593     return nullptr;
594 
595   lookThroughBitcast(Ptr);
596   assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
597 
598   IRBuilder<> Builder(I->getContext());
599   Builder.SetInsertPoint(I);
600   Builder.SetCurrentDebugLocation(I->getDebugLoc());
601 
602   Instruction *Store = tryCreateIncrementingGatScat(I, Ptr, Builder);
603   if (!Store)
604     Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
605   if (!Store)
606     Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
607   if (!Store)
608     return nullptr;
609 
610   LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n"
611                     << *Store << "\n");
612   I->eraseFromParent();
613   return Store;
614 }
615 
616 Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
617     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
618   using namespace PatternMatch;
619   Value *Input = I->getArgOperand(0);
620   auto *Ty = cast<FixedVectorType>(Input->getType());
621   // Only QR variants allow truncating
622   if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
623     // Can't build an intrinsic for this
624     return nullptr;
625   }
626   Value *Mask = I->getArgOperand(3);
627   //  int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
628   LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
629   if (match(Mask, m_One()))
630     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
631                                    {Ptr->getType(), Input->getType()},
632                                    {Ptr, Builder.getInt32(Increment), Input});
633   else
634     return Builder.CreateIntrinsic(
635         Intrinsic::arm_mve_vstr_scatter_base_predicated,
636         {Ptr->getType(), Input->getType(), Mask->getType()},
637         {Ptr, Builder.getInt32(Increment), Input, Mask});
638 }
639 
640 Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
641     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
642   using namespace PatternMatch;
643   Value *Input = I->getArgOperand(0);
644   auto *Ty = cast<FixedVectorType>(Input->getType());
645   LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers "
646                     << "with writeback\n");
647   if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
648     // Can't build an intrinsic for this
649     return nullptr;
650   Value *Mask = I->getArgOperand(3);
651   if (match(Mask, m_One()))
652     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
653                                    {Ptr->getType(), Input->getType()},
654                                    {Ptr, Builder.getInt32(Increment), Input});
655   else
656     return Builder.CreateIntrinsic(
657         Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
658         {Ptr->getType(), Input->getType(), Mask->getType()},
659         {Ptr, Builder.getInt32(Increment), Input, Mask});
660 }
661 
662 Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
663     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
664   using namespace PatternMatch;
665   Value *Input = I->getArgOperand(0);
666   Value *Mask = I->getArgOperand(3);
667   Type *InputTy = Input->getType();
668   Type *MemoryTy = InputTy;
669 
670   LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
671                     << " to base + vector of offsets\n");
672   // If the input has been truncated, try to integrate that trunc into the
673   // scatter instruction (we don't care about alignment here)
674   if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
675     Value *PreTrunc = Trunc->getOperand(0);
676     Type *PreTruncTy = PreTrunc->getType();
677     if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
678       Input = PreTrunc;
679       InputTy = PreTruncTy;
680     }
681   }
682   bool ExtendInput = false;
683   if (InputTy->getPrimitiveSizeInBits() < 128 &&
684       InputTy->isIntOrIntVectorTy()) {
685     // If we can't find a trunc to incorporate into the instruction, create an
686     // implicit one with a zext, so that we can still create a scatter. We know
687     // that the input type is 4x/8x/16x and of type i8/i16/i32, so any type
688     // smaller than 128 bits will divide evenly into a 128bit vector.
689     InputTy = InputTy->getWithNewBitWidth(
690         128 / cast<FixedVectorType>(InputTy)->getNumElements());
691     ExtendInput = true;
692     LLVM_DEBUG(dbgs() << "masked scatters: Small input type, will extend:\n"
693                       << *Input << "\n");
694   }
695   if (InputTy->getPrimitiveSizeInBits() != 128) {
696     LLVM_DEBUG(dbgs() << "masked scatters: cannot create scatters for "
697                          "non-standard input types. Expanding.\n");
698     return nullptr;
699   }
700 
701   Value *Offsets;
702   int Scale;
703   Value *BasePtr = decomposePtr(
704       Ptr, Offsets, Scale, cast<FixedVectorType>(InputTy), MemoryTy, Builder);
705   if (!BasePtr)
706     return nullptr;
707 
708   if (ExtendInput)
709     Input = Builder.CreateZExt(Input, InputTy);
710   if (!match(Mask, m_One()))
711     return Builder.CreateIntrinsic(
712         Intrinsic::arm_mve_vstr_scatter_offset_predicated,
713         {BasePtr->getType(), Offsets->getType(), Input->getType(),
714          Mask->getType()},
715         {BasePtr, Offsets, Input,
716          Builder.getInt32(MemoryTy->getScalarSizeInBits()),
717          Builder.getInt32(Scale), Mask});
718   else
719     return Builder.CreateIntrinsic(
720         Intrinsic::arm_mve_vstr_scatter_offset,
721         {BasePtr->getType(), Offsets->getType(), Input->getType()},
722         {BasePtr, Offsets, Input,
723          Builder.getInt32(MemoryTy->getScalarSizeInBits()),
724          Builder.getInt32(Scale)});
725 }
726 
727 Instruction *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
728     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
729   FixedVectorType *Ty;
730   if (I->getIntrinsicID() == Intrinsic::masked_gather)
731     Ty = cast<FixedVectorType>(I->getType());
732   else
733     Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());
734 
735   // Incrementing gathers only exist for v4i32
736   if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
737     return nullptr;
738   // Incrementing gathers are not beneficial outside of a loop
739   Loop *L = LI->getLoopFor(I->getParent());
740   if (L == nullptr)
741     return nullptr;
742 
743   // Decompose the GEP into Base and Offsets
744   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
745   Value *Offsets;
746   Value *BasePtr = decomposeGEP(Offsets, Ty, GEP, Builder);
747   if (!BasePtr)
748     return nullptr;
749 
750   LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
751                        "wb gather/scatter\n");
752 
753   // The gep was in charge of making sure the offsets are scaled correctly
754   // - calculate that factor so it can be applied by hand
755   int TypeScale =
756       computeScale(DL->getTypeSizeInBits(GEP->getOperand(0)->getType()),
757                    DL->getTypeSizeInBits(GEP->getType()) /
758                        cast<FixedVectorType>(GEP->getType())->getNumElements());
759   if (TypeScale == -1)
760     return nullptr;
761 
762   if (GEP->hasOneUse()) {
763     // Only in this case do we want to build a wb gather, because the wb will
764     // change the phi which does affect other users of the gep (which will still
765     // be using the phi in the old way)
766     if (auto *Load = tryCreateIncrementingWBGatScat(I, BasePtr, Offsets,
767                                                     TypeScale, Builder))
768       return Load;
769   }
770 
771   LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
772                        "non-wb gather/scatter\n");
773 
774   std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
775   if (Add.first == nullptr)
776     return nullptr;
777   Value *OffsetsIncoming = Add.first;
778   int64_t Immediate = Add.second;
779 
780   // Make sure the offsets are scaled correctly
781   Instruction *ScaledOffsets = BinaryOperator::Create(
782       Instruction::Shl, OffsetsIncoming,
783       Builder.CreateVectorSplat(Ty->getNumElements(),
784                                 Builder.getInt32(TypeScale)),
785       "ScaledIndex", I->getIterator());
786   // Add the base to the offsets
787   OffsetsIncoming = BinaryOperator::Create(
788       Instruction::Add, ScaledOffsets,
789       Builder.CreateVectorSplat(
790           Ty->getNumElements(),
791           Builder.CreatePtrToInt(
792               BasePtr,
793               cast<VectorType>(ScaledOffsets->getType())->getElementType())),
794       "StartIndex", I->getIterator());
795 
796   if (I->getIntrinsicID() == Intrinsic::masked_gather)
797     return tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate);
798   else
799     return tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate);
800 }
801 
802 Instruction *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
803     IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
804     IRBuilder<> &Builder) {
805   // Check whether this gather's offset is incremented by a constant - if so,
806   // and the load is of the right type, we can merge this into a QI gather
807   Loop *L = LI->getLoopFor(I->getParent());
808   // Offsets that are worth merging into this instruction will be incremented
809   // by a constant, thus we're looking for an add of a phi and a constant
810   PHINode *Phi = dyn_cast<PHINode>(Offsets);
811   if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
812       Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
813     // No phi means no IV to write back to; if there is a phi, we expect it
814     // to have exactly two incoming values; the only phis we are interested in
815     // will be loop IV's and have exactly two uses, one in their increment and
816     // one in the gather's gep
817     return nullptr;
818 
819   unsigned IncrementIndex =
820       Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
821   // Look through the phi to the phi increment
822   Offsets = Phi->getIncomingValue(IncrementIndex);
823 
824   std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
825   if (Add.first == nullptr)
826     return nullptr;
827   Value *OffsetsIncoming = Add.first;
828   int64_t Immediate = Add.second;
829   if (OffsetsIncoming != Phi)
830     // Then the increment we are looking at is not an increment of the
831     // induction variable, and we don't want to do a writeback
832     return nullptr;
833 
834   Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
835   unsigned NumElems =
836       cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();
837 
838   // Make sure the offsets are scaled correctly
839   Instruction *ScaledOffsets = BinaryOperator::Create(
840       Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
841       Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
842       "ScaledIndex",
843       Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator());
844   // Add the base to the offsets
845   OffsetsIncoming = BinaryOperator::Create(
846       Instruction::Add, ScaledOffsets,
847       Builder.CreateVectorSplat(
848           NumElems,
849           Builder.CreatePtrToInt(
850               BasePtr,
851               cast<VectorType>(ScaledOffsets->getType())->getElementType())),
852       "StartIndex",
853       Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator());
854   // The gather is pre-incrementing
855   OffsetsIncoming = BinaryOperator::Create(
856       Instruction::Sub, OffsetsIncoming,
857       Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
858       "PreIncrementStartIndex",
859       Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator());
860   Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);
861 
862   Builder.SetInsertPoint(I);
863 
864   Instruction *EndResult;
865   Instruction *NewInduction;
866   if (I->getIntrinsicID() == Intrinsic::masked_gather) {
867     // Build the incrementing gather
868     Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
869     // One value to be handed to whoever uses the gather, one is the loop
870     // increment
871     EndResult = ExtractValueInst::Create(Load, 0, "Gather");
872     NewInduction = ExtractValueInst::Create(Load, 1, "GatherIncrement");
873     Builder.Insert(EndResult);
874     Builder.Insert(NewInduction);
875   } else {
876     // Build the incrementing scatter
877     EndResult = NewInduction =
878         tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
879   }
880   Instruction *AddInst = cast<Instruction>(Offsets);
881   AddInst->replaceAllUsesWith(NewInduction);
882   AddInst->eraseFromParent();
883   Phi->setIncomingValue(IncrementIndex, NewInduction);
884 
885   return EndResult;
886 }
887 
888 void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
889                                           Value *OffsSecondOperand,
890                                           unsigned StartIndex) {
891   LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
892   BasicBlock::iterator InsertionPoint =
893       Phi->getIncomingBlock(StartIndex)->back().getIterator();
894   // Initialize the phi with a vector that contains a sum of the constants
895   Instruction *NewIndex = BinaryOperator::Create(
896       Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
897       "PushedOutAdd", InsertionPoint);
898   unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;
899 
900   // Order such that start index comes first (this reduces mov's)
901   Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
902   Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
903                    Phi->getIncomingBlock(IncrementIndex));
904   Phi->removeIncomingValue(1);
905   Phi->removeIncomingValue((unsigned)0);
906 }
907 
908 void MVEGatherScatterLowering::pushOutMulShl(unsigned Opcode, PHINode *&Phi,
909                                              Value *IncrementPerRound,
910                                              Value *OffsSecondOperand,
911                                              unsigned LoopIncrement,
912                                              IRBuilder<> &Builder) {
913   LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");
914 
915   // Create a new scalar add outside of the loop and transform it to a splat
916   // by which loop variable can be incremented
917   BasicBlock::iterator InsertionPoint =
918       Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back().getIterator();
919 
920   // Create a new index
921   Value *StartIndex =
922       BinaryOperator::Create((Instruction::BinaryOps)Opcode,
923                              Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
924                              OffsSecondOperand, "PushedOutMul", InsertionPoint);
925 
926   Instruction *Product =
927       BinaryOperator::Create((Instruction::BinaryOps)Opcode, IncrementPerRound,
928                              OffsSecondOperand, "Product", InsertionPoint);
929 
930   BasicBlock::iterator NewIncrInsertPt =
931       Phi->getIncomingBlock(LoopIncrement)->back().getIterator();
932   NewIncrInsertPt = std::prev(NewIncrInsertPt);
933 
934   // Increment NewIndex by Product instead of the multiplication
935   Instruction *NewIncrement = BinaryOperator::Create(
936       Instruction::Add, Phi, Product, "IncrementPushedOutMul", NewIncrInsertPt);
937 
938   Phi->addIncoming(StartIndex,
939                    Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
940   Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
941   Phi->removeIncomingValue((unsigned)0);
942   Phi->removeIncomingValue((unsigned)0);
943 }
944 
945 // Check whether all usages of this instruction are as offsets of
946 // gathers/scatters or simple arithmetics only used by gathers/scatters
947 static bool hasAllGatScatUsers(Instruction *I, const DataLayout &DL) {
948   if (I->hasNUses(0)) {
949     return false;
950   }
951   bool Gatscat = true;
952   for (User *U : I->users()) {
953     if (!isa<Instruction>(U))
954       return false;
955     if (isa<GetElementPtrInst>(U) ||
956         isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
957       return Gatscat;
958     } else {
959       unsigned OpCode = cast<Instruction>(U)->getOpcode();
960       if ((OpCode == Instruction::Add || OpCode == Instruction::Mul ||
961            OpCode == Instruction::Shl ||
962            isAddLikeOr(cast<Instruction>(U), DL)) &&
963           hasAllGatScatUsers(cast<Instruction>(U), DL)) {
964         continue;
965       }
966       return false;
967     }
968   }
969   return Gatscat;
970 }
971 
972 bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
973                                                LoopInfo *LI) {
974   LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize: "
975                     << *Offsets << "\n");
976   // Optimise the addresses of gathers/scatters by moving invariant
977   // calculations out of the loop
978   if (!isa<Instruction>(Offsets))
979     return false;
980   Instruction *Offs = cast<Instruction>(Offsets);
981   if (Offs->getOpcode() != Instruction::Add && !isAddLikeOr(Offs, *DL) &&
982       Offs->getOpcode() != Instruction::Mul &&
983       Offs->getOpcode() != Instruction::Shl)
984     return false;
985   Loop *L = LI->getLoopFor(BB);
986   if (L == nullptr)
987     return false;
988   if (!Offs->hasOneUse()) {
989     if (!hasAllGatScatUsers(Offs, *DL))
990       return false;
991   }
992 
993   // Find out which, if any, operand of the instruction
994   // is a phi node
995   PHINode *Phi;
996   int OffsSecondOp;
997   if (isa<PHINode>(Offs->getOperand(0))) {
998     Phi = cast<PHINode>(Offs->getOperand(0));
999     OffsSecondOp = 1;
1000   } else if (isa<PHINode>(Offs->getOperand(1))) {
1001     Phi = cast<PHINode>(Offs->getOperand(1));
1002     OffsSecondOp = 0;
1003   } else {
1004     bool Changed = false;
1005     if (isa<Instruction>(Offs->getOperand(0)) &&
1006         L->contains(cast<Instruction>(Offs->getOperand(0))))
1007       Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
1008     if (isa<Instruction>(Offs->getOperand(1)) &&
1009         L->contains(cast<Instruction>(Offs->getOperand(1))))
1010       Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
1011     if (!Changed)
1012       return false;
1013     if (isa<PHINode>(Offs->getOperand(0))) {
1014       Phi = cast<PHINode>(Offs->getOperand(0));
1015       OffsSecondOp = 1;
1016     } else if (isa<PHINode>(Offs->getOperand(1))) {
1017       Phi = cast<PHINode>(Offs->getOperand(1));
1018       OffsSecondOp = 0;
1019     } else {
1020       return false;
1021     }
1022   }
1023   // A phi node we want to perform this function on should be from the
1024   // loop header.
1025   if (Phi->getParent() != L->getHeader())
1026     return false;
1027 
1028   // We're looking for a simple add recurrence.
1029   BinaryOperator *IncInstruction;
1030   Value *Start, *IncrementPerRound;
1031   if (!matchSimpleRecurrence(Phi, IncInstruction, Start, IncrementPerRound) ||
1032       IncInstruction->getOpcode() != Instruction::Add)
1033     return false;
1034 
1035   int IncrementingBlock = Phi->getIncomingValue(0) == IncInstruction ? 0 : 1;
1036 
1037   // Get the value that is added to/multiplied with the phi
1038   Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);
1039 
1040   if (IncrementPerRound->getType() != OffsSecondOperand->getType() ||
1041       !L->isLoopInvariant(OffsSecondOperand))
1042     // Something has gone wrong, abort
1043     return false;
1044 
1045   // Only proceed if the increment per round is a constant or an instruction
1046   // which does not originate from within the loop
1047   if (!isa<Constant>(IncrementPerRound) &&
1048       !(isa<Instruction>(IncrementPerRound) &&
1049         !L->contains(cast<Instruction>(IncrementPerRound))))
1050     return false;
1051 
1052   // If the phi is not used by anything else, we can just adapt it when
1053   // replacing the instruction; if it is, we'll have to duplicate it
1054   PHINode *NewPhi;
1055   if (Phi->getNumUses() == 2) {
1056     // No other users -> reuse existing phi (One user is the instruction
1057     // we're looking at, the other is the phi increment)
1058     if (IncInstruction->getNumUses() != 1) {
1059       // If the incrementing instruction does have more users than
1060       // our phi, we need to copy it
1061       IncInstruction = BinaryOperator::Create(
1062           Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
1063           IncrementPerRound, "LoopIncrement", IncInstruction->getIterator());
1064       Phi->setIncomingValue(IncrementingBlock, IncInstruction);
1065     }
1066     NewPhi = Phi;
1067   } else {
1068     // There are other users -> create a new phi
1069     NewPhi = PHINode::Create(Phi->getType(), 2, "NewPhi", Phi->getIterator());
1070     // Copy the incoming values of the old phi
1071     NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
1072                         Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
1073     IncInstruction = BinaryOperator::Create(
1074         Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
1075         IncrementPerRound, "LoopIncrement", IncInstruction->getIterator());
1076     NewPhi->addIncoming(IncInstruction,
1077                         Phi->getIncomingBlock(IncrementingBlock));
1078     IncrementingBlock = 1;
1079   }
1080 
1081   IRBuilder<> Builder(BB->getContext());
1082   Builder.SetInsertPoint(Phi);
1083   Builder.SetCurrentDebugLocation(Offs->getDebugLoc());
1084 
1085   switch (Offs->getOpcode()) {
1086   case Instruction::Add:
1087   case Instruction::Or:
1088     pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
1089     break;
1090   case Instruction::Mul:
1091   case Instruction::Shl:
1092     pushOutMulShl(Offs->getOpcode(), NewPhi, IncrementPerRound,
1093                   OffsSecondOperand, IncrementingBlock, Builder);
1094     break;
1095   default:
1096     return false;
1097   }
1098   LLVM_DEBUG(dbgs() << "masked gathers/scatters: simplified loop variable "
1099                     << "add/mul\n");
1100 
1101   // The instruction has now been "absorbed" into the phi value
1102   Offs->replaceAllUsesWith(NewPhi);
1103   if (Offs->hasNUses(0))
1104     Offs->eraseFromParent();
1105   // Clean up the old increment in case it's unused because we built a new
1106   // one
1107   if (IncInstruction->hasNUses(0))
1108     IncInstruction->eraseFromParent();
1109 
1110   return true;
1111 }
1112 
1113 static Value *CheckAndCreateOffsetAdd(Value *X, unsigned ScaleX, Value *Y,
1114                                       unsigned ScaleY, IRBuilder<> &Builder) {
1115   // Splat the non-vector value to a vector of the given type - if the value is
1116   // a constant (and its value isn't too big), we can even use this opportunity
1117   // to scale it to the size of the vector elements
1118   auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) {
1119     ConstantInt *Const;
1120     if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
1121         VT->getElementType() != NonVectorVal->getType()) {
1122       unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits();
1123       uint64_t N = Const->getZExtValue();
1124       if (N < (unsigned)(1 << (TargetElemSize - 1))) {
1125         NonVectorVal = Builder.CreateVectorSplat(
1126             VT->getNumElements(), Builder.getIntN(TargetElemSize, N));
1127         return;
1128       }
1129     }
1130     NonVectorVal =
1131         Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal);
1132   };
1133 
1134   FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType());
1135   FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType());
1136   // If one of X, Y is not a vector, we have to splat it in order
1137   // to add the two of them.
1138   if (XElType && !YElType) {
1139     FixSummands(XElType, Y);
1140     YElType = cast<FixedVectorType>(Y->getType());
1141   } else if (YElType && !XElType) {
1142     FixSummands(YElType, X);
1143     XElType = cast<FixedVectorType>(X->getType());
1144   }
1145   assert(XElType && YElType && "Unknown vector types");
1146   // Check that the summands are of compatible types
1147   if (XElType != YElType) {
1148     LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n");
1149     return nullptr;
1150   }
1151 
1152   if (XElType->getElementType()->getScalarSizeInBits() != 32) {
1153     // Check that by adding the vectors we do not accidentally
1154     // create an overflow
1155     Constant *ConstX = dyn_cast<Constant>(X);
1156     Constant *ConstY = dyn_cast<Constant>(Y);
1157     if (!ConstX || !ConstY)
1158       return nullptr;
1159     unsigned TargetElemSize = 128 / XElType->getNumElements();
1160     for (unsigned i = 0; i < XElType->getNumElements(); i++) {
1161       ConstantInt *ConstXEl =
1162           dyn_cast<ConstantInt>(ConstX->getAggregateElement(i));
1163       ConstantInt *ConstYEl =
1164           dyn_cast<ConstantInt>(ConstY->getAggregateElement(i));
1165       if (!ConstXEl || !ConstYEl ||
1166           ConstXEl->getZExtValue() * ScaleX +
1167                   ConstYEl->getZExtValue() * ScaleY >=
1168               (unsigned)(1 << (TargetElemSize - 1)))
1169         return nullptr;
1170     }
1171   }
1172 
1173   Value *XScale = Builder.CreateVectorSplat(
1174       XElType->getNumElements(),
1175       Builder.getIntN(XElType->getScalarSizeInBits(), ScaleX));
1176   Value *YScale = Builder.CreateVectorSplat(
1177       YElType->getNumElements(),
1178       Builder.getIntN(YElType->getScalarSizeInBits(), ScaleY));
1179   Value *Add = Builder.CreateAdd(Builder.CreateMul(X, XScale),
1180                                  Builder.CreateMul(Y, YScale));
1181 
1182   if (checkOffsetSize(Add, XElType->getNumElements()))
1183     return Add;
1184   else
1185     return nullptr;
1186 }
1187 
1188 Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP,
1189                                          Value *&Offsets, unsigned &Scale,
1190                                          IRBuilder<> &Builder) {
1191   Value *GEPPtr = GEP->getPointerOperand();
1192   Offsets = GEP->getOperand(1);
1193   Scale = DL->getTypeAllocSize(GEP->getSourceElementType());
1194   // We only merge geps with constant offsets, because only for those
1195   // we can make sure that we do not cause an overflow
1196   if (GEP->getNumIndices() != 1 || !isa<Constant>(Offsets))
1197     return nullptr;
1198   if (GetElementPtrInst *BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr)) {
1199     // Merge the two geps into one
1200     Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Scale, Builder);
1201     if (!BaseBasePtr)
1202       return nullptr;
1203     Offsets = CheckAndCreateOffsetAdd(
1204         Offsets, Scale, GEP->getOperand(1),
1205         DL->getTypeAllocSize(GEP->getSourceElementType()), Builder);
1206     if (Offsets == nullptr)
1207       return nullptr;
1208     Scale = 1; // Scale is always an i8 at this point.
1209     return BaseBasePtr;
1210   }
1211   return GEPPtr;
1212 }
1213 
1214 bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB,
1215                                                LoopInfo *LI) {
1216   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address);
1217   if (!GEP)
1218     return false;
1219   bool Changed = false;
1220   if (GEP->hasOneUse() && isa<GetElementPtrInst>(GEP->getPointerOperand())) {
1221     IRBuilder<> Builder(GEP->getContext());
1222     Builder.SetInsertPoint(GEP);
1223     Builder.SetCurrentDebugLocation(GEP->getDebugLoc());
1224     Value *Offsets;
1225     unsigned Scale;
1226     Value *Base = foldGEP(GEP, Offsets, Scale, Builder);
1227     // We only want to merge the geps if there is a real chance that they can be
1228     // used by an MVE gather; thus the offset has to have the correct size
1229     // (always i32 if it is not of vector type) and the base has to be a
1230     // pointer.
1231     if (Offsets && Base && Base != GEP) {
1232       assert(Scale == 1 && "Expected to fold GEP to a scale of 1");
1233       Type *BaseTy = Builder.getPtrTy();
1234       if (auto *VecTy = dyn_cast<FixedVectorType>(Base->getType()))
1235         BaseTy = FixedVectorType::get(BaseTy, VecTy);
1236       GetElementPtrInst *NewAddress = GetElementPtrInst::Create(
1237           Builder.getInt8Ty(), Builder.CreateBitCast(Base, BaseTy), Offsets,
1238           "gep.merged", GEP->getIterator());
1239       LLVM_DEBUG(dbgs() << "Folded GEP: " << *GEP
1240                         << "\n      new :  " << *NewAddress << "\n");
1241       GEP->replaceAllUsesWith(
1242           Builder.CreateBitCast(NewAddress, GEP->getType()));
1243       GEP = NewAddress;
1244       Changed = true;
1245     }
1246   }
1247   Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI);
1248   return Changed;
1249 }
1250 
1251 bool MVEGatherScatterLowering::runOnFunction(Function &F) {
1252   if (!EnableMaskedGatherScatters)
1253     return false;
1254   auto &TPC = getAnalysis<TargetPassConfig>();
1255   auto &TM = TPC.getTM<TargetMachine>();
1256   auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
1257   if (!ST->hasMVEIntegerOps())
1258     return false;
1259   LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1260   DL = &F.getDataLayout();
1261   SmallVector<IntrinsicInst *, 4> Gathers;
1262   SmallVector<IntrinsicInst *, 4> Scatters;
1263 
1264   bool Changed = false;
1265 
1266   for (BasicBlock &BB : F) {
1267     Changed |= SimplifyInstructionsInBlock(&BB);
1268 
1269     for (Instruction &I : BB) {
1270       IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
1271       if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
1272           isa<FixedVectorType>(II->getType())) {
1273         Gathers.push_back(II);
1274         Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI);
1275       } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
1276                  isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
1277         Scatters.push_back(II);
1278         Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI);
1279       }
1280     }
1281   }
1282   for (IntrinsicInst *I : Gathers) {
1283     Instruction *L = lowerGather(I);
1284     if (L == nullptr)
1285       continue;
1286 
1287     // Get rid of any now dead instructions
1288     SimplifyInstructionsInBlock(L->getParent());
1289     Changed = true;
1290   }
1291 
1292   for (IntrinsicInst *I : Scatters) {
1293     Instruction *S = lowerScatter(I);
1294     if (S == nullptr)
1295       continue;
1296 
1297     // Get rid of any now dead instructions
1298     SimplifyInstructionsInBlock(S->getParent());
1299     Changed = true;
1300   }
1301   return Changed;
1302 }
1303