xref: /freebsd/contrib/llvm-project/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp (revision c66ec88fed842fbaad62c30d510644ceb7bd2d71)
1 //===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// This pass custom lowers llvm.gather and llvm.scatter instructions to
10 /// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to
11 /// produce a better final result as we go.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "ARM.h"
16 #include "ARMBaseInstrInfo.h"
17 #include "ARMSubtarget.h"
18 #include "llvm/Analysis/LoopInfo.h"
19 #include "llvm/Analysis/TargetTransformInfo.h"
20 #include "llvm/CodeGen/TargetLowering.h"
21 #include "llvm/CodeGen/TargetPassConfig.h"
22 #include "llvm/CodeGen/TargetSubtargetInfo.h"
23 #include "llvm/InitializePasses.h"
24 #include "llvm/IR/BasicBlock.h"
25 #include "llvm/IR/Constant.h"
26 #include "llvm/IR/Constants.h"
27 #include "llvm/IR/DerivedTypes.h"
28 #include "llvm/IR/Function.h"
29 #include "llvm/IR/InstrTypes.h"
30 #include "llvm/IR/Instruction.h"
31 #include "llvm/IR/Instructions.h"
32 #include "llvm/IR/IntrinsicInst.h"
33 #include "llvm/IR/Intrinsics.h"
34 #include "llvm/IR/IntrinsicsARM.h"
35 #include "llvm/IR/IRBuilder.h"
36 #include "llvm/IR/PatternMatch.h"
37 #include "llvm/IR/Type.h"
38 #include "llvm/IR/Value.h"
39 #include "llvm/Pass.h"
40 #include "llvm/Support/Casting.h"
41 #include "llvm/Transforms/Utils/Local.h"
42 #include <algorithm>
43 #include <cassert>
44 
45 using namespace llvm;
46 
47 #define DEBUG_TYPE "mve-gather-scatter-lowering"
48 
49 cl::opt<bool> EnableMaskedGatherScatters(
50     "enable-arm-maskedgatscat", cl::Hidden, cl::init(false),
51     cl::desc("Enable the generation of masked gathers and scatters"));
52 
53 namespace {
54 
55 class MVEGatherScatterLowering : public FunctionPass {
56 public:
57   static char ID; // Pass identification, replacement for typeid
58 
59   explicit MVEGatherScatterLowering() : FunctionPass(ID) {
60     initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
61   }
62 
63   bool runOnFunction(Function &F) override;
64 
65   StringRef getPassName() const override {
66     return "MVE gather/scatter lowering";
67   }
68 
69   void getAnalysisUsage(AnalysisUsage &AU) const override {
70     AU.setPreservesCFG();
71     AU.addRequired<TargetPassConfig>();
72     AU.addRequired<LoopInfoWrapperPass>();
73     FunctionPass::getAnalysisUsage(AU);
74   }
75 
76 private:
77   LoopInfo *LI = nullptr;
78 
79   // Check this is a valid gather with correct alignment
80   bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
81                                Align Alignment);
82   // Check whether Ptr is hidden behind a bitcast and look through it
83   void lookThroughBitcast(Value *&Ptr);
84   // Check for a getelementptr and deduce base and offsets from it, on success
85   // returning the base directly and the offsets indirectly using the Offsets
86   // argument
87   Value *checkGEP(Value *&Offsets, Type *Ty, GetElementPtrInst *GEP,
88                   IRBuilder<> &Builder);
89   // Compute the scale of this gather/scatter instruction
90   int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
91   // If the value is a constant, or derived from constants via additions
92   // and multilications, return its numeric value
93   Optional<int64_t> getIfConst(const Value *V);
94   // If Inst is an add instruction, check whether one summand is a
95   // constant. If so, scale this constant and return it together with
96   // the other summand.
97   std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);
98 
99   Value *lowerGather(IntrinsicInst *I);
100   // Create a gather from a base + vector of offsets
101   Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
102                                      Instruction *&Root, IRBuilder<> &Builder);
103   // Create a gather from a vector of pointers
104   Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
105                                    IRBuilder<> &Builder, int64_t Increment = 0);
106   // Create an incrementing gather from a vector of pointers
107   Value *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
108                                      IRBuilder<> &Builder,
109                                      int64_t Increment = 0);
110 
111   Value *lowerScatter(IntrinsicInst *I);
112   // Create a scatter to a base + vector of offsets
113   Value *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
114                                       IRBuilder<> &Builder);
115   // Create a scatter to a vector of pointers
116   Value *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
117                                     IRBuilder<> &Builder,
118                                     int64_t Increment = 0);
119   // Create an incrementing scatter from a vector of pointers
120   Value *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
121                                       IRBuilder<> &Builder,
122                                       int64_t Increment = 0);
123 
124   // QI gathers and scatters can increment their offsets on their own if
125   // the increment is a constant value (digit)
126   Value *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *BasePtr,
127                                       Value *Ptr, GetElementPtrInst *GEP,
128                                       IRBuilder<> &Builder);
129   // QI gathers/scatters can increment their offsets on their own if the
130   // increment is a constant value (digit) - this creates a writeback QI
131   // gather/scatter
132   Value *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
133                                         Value *Ptr, unsigned TypeScale,
134                                         IRBuilder<> &Builder);
135   // Check whether these offsets could be moved out of the loop they're in
136   bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
137   // Pushes the given add out of the loop
138   void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
139   // Pushes the given mul out of the loop
140   void pushOutMul(PHINode *&Phi, Value *IncrementPerRound,
141                   Value *OffsSecondOperand, unsigned LoopIncrement,
142                   IRBuilder<> &Builder);
143 };
144 
145 } // end anonymous namespace
146 
147 char MVEGatherScatterLowering::ID = 0;
148 
149 INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
150                 "MVE gather/scattering lowering pass", false, false)
151 
152 Pass *llvm::createMVEGatherScatterLoweringPass() {
153   return new MVEGatherScatterLowering();
154 }
155 
156 bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
157                                                        unsigned ElemSize,
158                                                        Align Alignment) {
159   if (((NumElements == 4 &&
160         (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
161        (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
162        (NumElements == 16 && ElemSize == 8)) &&
163       Alignment >= ElemSize / 8)
164     return true;
165   LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
166                     << "valid alignment or vector type \n");
167   return false;
168 }
169 
170 Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, Type *Ty,
171                                           GetElementPtrInst *GEP,
172                                           IRBuilder<> &Builder) {
173   if (!GEP) {
174     LLVM_DEBUG(
175         dbgs() << "masked gathers/scatters: no getelementpointer found\n");
176     return nullptr;
177   }
178   LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
179                     << " Looking at intrinsic for base + vector of offsets\n");
180   Value *GEPPtr = GEP->getPointerOperand();
181   if (GEPPtr->getType()->isVectorTy()) {
182     return nullptr;
183   }
184   if (GEP->getNumOperands() != 2) {
185     LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
186                       << " operands. Expanding.\n");
187     return nullptr;
188   }
189   Offsets = GEP->getOperand(1);
190   // Paranoid check whether the number of parallel lanes is the same
191   assert(cast<FixedVectorType>(Ty)->getNumElements() ==
192          cast<FixedVectorType>(Offsets->getType())->getNumElements());
193   // Only <N x i32> offsets can be integrated into an arm gather, any smaller
194   // type would have to be sign extended by the gep - and arm gathers can only
195   // zero extend. Additionally, the offsets do have to originate from a zext of
196   // a vector with element types smaller or equal the type of the gather we're
197   // looking at
198   if (Offsets->getType()->getScalarSizeInBits() != 32)
199     return nullptr;
200   if (ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets))
201     Offsets = ZextOffs->getOperand(0);
202   else if (!(cast<FixedVectorType>(Offsets->getType())->getNumElements() == 4 &&
203              Offsets->getType()->getScalarSizeInBits() == 32))
204     return nullptr;
205 
206   if (Ty != Offsets->getType()) {
207     if ((Ty->getScalarSizeInBits() <
208          Offsets->getType()->getScalarSizeInBits())) {
209       LLVM_DEBUG(dbgs() << "masked gathers/scatters: no correct offset type."
210                         << " Can't create intrinsic.\n");
211       return nullptr;
212     } else {
213       Offsets = Builder.CreateZExt(
214           Offsets, VectorType::getInteger(cast<VectorType>(Ty)));
215     }
216   }
217   // If none of the checks failed, return the gep's base pointer
218   LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
219   return GEPPtr;
220 }
221 
222 void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
223   // Look through bitcast instruction if #elements is the same
224   if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
225     auto *BCTy = cast<FixedVectorType>(BitCast->getType());
226     auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
227     if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
228       LLVM_DEBUG(
229           dbgs() << "masked gathers/scatters: looking through bitcast\n");
230       Ptr = BitCast->getOperand(0);
231     }
232   }
233 }
234 
235 int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
236                                            unsigned MemoryElemSize) {
237   // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by 2,
238   // or a 8bit, 16bit or 32bit load/store scaled by 1
239   if (GEPElemSize == 32 && MemoryElemSize == 32)
240     return 2;
241   else if (GEPElemSize == 16 && MemoryElemSize == 16)
242     return 1;
243   else if (GEPElemSize == 8)
244     return 0;
245   LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
246                     << "create intrinsic\n");
247   return -1;
248 }
249 
250 Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
251   const Constant *C = dyn_cast<Constant>(V);
252   if (C != nullptr)
253     return Optional<int64_t>{C->getUniqueInteger().getSExtValue()};
254   if (!isa<Instruction>(V))
255     return Optional<int64_t>{};
256 
257   const Instruction *I = cast<Instruction>(V);
258   if (I->getOpcode() == Instruction::Add ||
259               I->getOpcode() == Instruction::Mul) {
260     Optional<int64_t> Op0 = getIfConst(I->getOperand(0));
261     Optional<int64_t> Op1 = getIfConst(I->getOperand(1));
262     if (!Op0 || !Op1)
263       return Optional<int64_t>{};
264     if (I->getOpcode() == Instruction::Add)
265       return Optional<int64_t>{Op0.getValue() + Op1.getValue()};
266     if (I->getOpcode() == Instruction::Mul)
267       return Optional<int64_t>{Op0.getValue() * Op1.getValue()};
268   }
269   return Optional<int64_t>{};
270 }
271 
272 std::pair<Value *, int64_t>
273 MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
274   std::pair<Value *, int64_t> ReturnFalse =
275       std::pair<Value *, int64_t>(nullptr, 0);
276   // At this point, the instruction we're looking at must be an add or we
277   // bail out
278   Instruction *Add = dyn_cast<Instruction>(Inst);
279   if (Add == nullptr || Add->getOpcode() != Instruction::Add)
280     return ReturnFalse;
281 
282   Value *Summand;
283   Optional<int64_t> Const;
284   // Find out which operand the value that is increased is
285   if ((Const = getIfConst(Add->getOperand(0))))
286     Summand = Add->getOperand(1);
287   else if ((Const = getIfConst(Add->getOperand(1))))
288     Summand = Add->getOperand(0);
289   else
290     return ReturnFalse;
291 
292   // Check that the constant is small enough for an incrementing gather
293   int64_t Immediate = Const.getValue() << TypeScale;
294   if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
295     return ReturnFalse;
296 
297   return std::pair<Value *, int64_t>(Summand, Immediate);
298 }
299 
300 Value *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
301   using namespace PatternMatch;
302   LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n");
303 
304   // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
305   // Attempt to turn the masked gather in I into a MVE intrinsic
306   // Potentially optimising the addressing modes as we do so.
307   auto *Ty = cast<FixedVectorType>(I->getType());
308   Value *Ptr = I->getArgOperand(0);
309   Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
310   Value *Mask = I->getArgOperand(2);
311   Value *PassThru = I->getArgOperand(3);
312 
313   if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
314                                Alignment))
315     return nullptr;
316   lookThroughBitcast(Ptr);
317   assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
318 
319   IRBuilder<> Builder(I->getContext());
320   Builder.SetInsertPoint(I);
321   Builder.SetCurrentDebugLocation(I->getDebugLoc());
322 
323   Instruction *Root = I;
324   Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
325   if (!Load)
326     Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
327   if (!Load)
328     return nullptr;
329 
330   if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
331     LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
332                       << "creating select\n");
333     Load = Builder.CreateSelect(Mask, Load, PassThru);
334   }
335 
336   Root->replaceAllUsesWith(Load);
337   Root->eraseFromParent();
338   if (Root != I)
339     // If this was an extending gather, we need to get rid of the sext/zext
340     // sext/zext as well as of the gather itself
341     I->eraseFromParent();
342 
343   LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n");
344   return Load;
345 }
346 
347 Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(IntrinsicInst *I,
348                                                            Value *Ptr,
349                                                            IRBuilder<> &Builder,
350                                                            int64_t Increment) {
351   using namespace PatternMatch;
352   auto *Ty = cast<FixedVectorType>(I->getType());
353   LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
354   if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
355     // Can't build an intrinsic for this
356     return nullptr;
357   Value *Mask = I->getArgOperand(2);
358   if (match(Mask, m_One()))
359     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
360                                    {Ty, Ptr->getType()},
361                                    {Ptr, Builder.getInt32(Increment)});
362   else
363     return Builder.CreateIntrinsic(
364         Intrinsic::arm_mve_vldr_gather_base_predicated,
365         {Ty, Ptr->getType(), Mask->getType()},
366         {Ptr, Builder.getInt32(Increment), Mask});
367 }
368 
369 Value *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
370     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
371   using namespace PatternMatch;
372   auto *Ty = cast<FixedVectorType>(I->getType());
373   LLVM_DEBUG(
374       dbgs()
375       << "masked gathers: loading from vector of pointers with writeback\n");
376   if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
377     // Can't build an intrinsic for this
378     return nullptr;
379   Value *Mask = I->getArgOperand(2);
380   if (match(Mask, m_One()))
381     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
382                                    {Ty, Ptr->getType()},
383                                    {Ptr, Builder.getInt32(Increment)});
384   else
385     return Builder.CreateIntrinsic(
386         Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
387         {Ty, Ptr->getType(), Mask->getType()},
388         {Ptr, Builder.getInt32(Increment), Mask});
389 }
390 
391 Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
392     IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
393   using namespace PatternMatch;
394 
395   Type *OriginalTy = I->getType();
396   Type *ResultTy = OriginalTy;
397 
398   unsigned Unsigned = 1;
399   // The size of the gather was already checked in isLegalTypeAndAlignment;
400   // if it was not a full vector width an appropriate extend should follow.
401   auto *Extend = Root;
402   if (OriginalTy->getPrimitiveSizeInBits() < 128) {
403     // Only transform gathers with exactly one use
404     if (!I->hasOneUse())
405       return nullptr;
406 
407     // The correct root to replace is not the CallInst itself, but the
408     // instruction which extends it
409     Extend = cast<Instruction>(*I->users().begin());
410     if (isa<SExtInst>(Extend)) {
411       Unsigned = 0;
412     } else if (!isa<ZExtInst>(Extend)) {
413       LLVM_DEBUG(dbgs() << "masked gathers: extend needed but not provided. "
414                         << "Expanding\n");
415       return nullptr;
416     }
417     LLVM_DEBUG(dbgs() << "masked gathers: found an extending gather\n");
418     ResultTy = Extend->getType();
419     // The final size of the gather must be a full vector width
420     if (ResultTy->getPrimitiveSizeInBits() != 128) {
421       LLVM_DEBUG(dbgs() << "masked gathers: extending from the wrong type. "
422                         << "Expanding\n");
423       return nullptr;
424     }
425   }
426 
427   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
428   Value *Offsets;
429   Value *BasePtr = checkGEP(Offsets, ResultTy, GEP, Builder);
430   if (!BasePtr)
431     return nullptr;
432   // Check whether the offset is a constant increment that could be merged into
433   // a QI gather
434   Value *Load = tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
435   if (Load)
436     return Load;
437 
438   int Scale = computeScale(
439       BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
440       OriginalTy->getScalarSizeInBits());
441   if (Scale == -1)
442     return nullptr;
443   Root = Extend;
444 
445   Value *Mask = I->getArgOperand(2);
446   if (!match(Mask, m_One()))
447     return Builder.CreateIntrinsic(
448         Intrinsic::arm_mve_vldr_gather_offset_predicated,
449         {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
450         {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
451          Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
452   else
453     return Builder.CreateIntrinsic(
454         Intrinsic::arm_mve_vldr_gather_offset,
455         {ResultTy, BasePtr->getType(), Offsets->getType()},
456         {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
457          Builder.getInt32(Scale), Builder.getInt32(Unsigned)});
458 }
459 
460 Value *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
461   using namespace PatternMatch;
462   LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n");
463 
464   // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
465   // Attempt to turn the masked scatter in I into a MVE intrinsic
466   // Potentially optimising the addressing modes as we do so.
467   Value *Input = I->getArgOperand(0);
468   Value *Ptr = I->getArgOperand(1);
469   Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
470   auto *Ty = cast<FixedVectorType>(Input->getType());
471 
472   if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
473                                Alignment))
474     return nullptr;
475 
476   lookThroughBitcast(Ptr);
477   assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
478 
479   IRBuilder<> Builder(I->getContext());
480   Builder.SetInsertPoint(I);
481   Builder.SetCurrentDebugLocation(I->getDebugLoc());
482 
483   Value *Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
484   if (!Store)
485     Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
486   if (!Store)
487     return nullptr;
488 
489   LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n");
490   I->eraseFromParent();
491   return Store;
492 }
493 
494 Value *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
495     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
496   using namespace PatternMatch;
497   Value *Input = I->getArgOperand(0);
498   auto *Ty = cast<FixedVectorType>(Input->getType());
499   // Only QR variants allow truncating
500   if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
501     // Can't build an intrinsic for this
502     return nullptr;
503   }
504   Value *Mask = I->getArgOperand(3);
505   //  int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
506   LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
507   if (match(Mask, m_One()))
508     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
509                                    {Ptr->getType(), Input->getType()},
510                                    {Ptr, Builder.getInt32(Increment), Input});
511   else
512     return Builder.CreateIntrinsic(
513         Intrinsic::arm_mve_vstr_scatter_base_predicated,
514         {Ptr->getType(), Input->getType(), Mask->getType()},
515         {Ptr, Builder.getInt32(Increment), Input, Mask});
516 }
517 
518 Value *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
519     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
520   using namespace PatternMatch;
521   Value *Input = I->getArgOperand(0);
522   auto *Ty = cast<FixedVectorType>(Input->getType());
523   LLVM_DEBUG(
524       dbgs()
525       << "masked scatters: storing to a vector of pointers with writeback\n");
526   if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
527     // Can't build an intrinsic for this
528     return nullptr;
529   Value *Mask = I->getArgOperand(3);
530   if (match(Mask, m_One()))
531     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
532                                    {Ptr->getType(), Input->getType()},
533                                    {Ptr, Builder.getInt32(Increment), Input});
534   else
535     return Builder.CreateIntrinsic(
536         Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
537         {Ptr->getType(), Input->getType(), Mask->getType()},
538         {Ptr, Builder.getInt32(Increment), Input, Mask});
539 }
540 
541 Value *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
542     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
543   using namespace PatternMatch;
544   Value *Input = I->getArgOperand(0);
545   Value *Mask = I->getArgOperand(3);
546   Type *InputTy = Input->getType();
547   Type *MemoryTy = InputTy;
548   LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
549                     << " to base + vector of offsets\n");
550   // If the input has been truncated, try to integrate that trunc into the
551   // scatter instruction (we don't care about alignment here)
552   if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
553     Value *PreTrunc = Trunc->getOperand(0);
554     Type *PreTruncTy = PreTrunc->getType();
555     if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
556       Input = PreTrunc;
557       InputTy = PreTruncTy;
558     }
559   }
560   if (InputTy->getPrimitiveSizeInBits() != 128) {
561     LLVM_DEBUG(
562         dbgs() << "masked scatters: cannot create scatters for non-standard"
563                << " input types. Expanding.\n");
564     return nullptr;
565   }
566 
567   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
568   Value *Offsets;
569   Value *BasePtr = checkGEP(Offsets, InputTy, GEP, Builder);
570   if (!BasePtr)
571     return nullptr;
572   // Check whether the offset is a constant increment that could be merged into
573   // a QI gather
574   Value *Store =
575       tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
576   if (Store)
577     return Store;
578   int Scale = computeScale(
579       BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
580       MemoryTy->getScalarSizeInBits());
581   if (Scale == -1)
582     return nullptr;
583 
584   if (!match(Mask, m_One()))
585     return Builder.CreateIntrinsic(
586         Intrinsic::arm_mve_vstr_scatter_offset_predicated,
587         {BasePtr->getType(), Offsets->getType(), Input->getType(),
588          Mask->getType()},
589         {BasePtr, Offsets, Input,
590          Builder.getInt32(MemoryTy->getScalarSizeInBits()),
591          Builder.getInt32(Scale), Mask});
592   else
593     return Builder.CreateIntrinsic(
594         Intrinsic::arm_mve_vstr_scatter_offset,
595         {BasePtr->getType(), Offsets->getType(), Input->getType()},
596         {BasePtr, Offsets, Input,
597          Builder.getInt32(MemoryTy->getScalarSizeInBits()),
598          Builder.getInt32(Scale)});
599 }
600 
601 Value *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
602     IntrinsicInst *I, Value *BasePtr, Value *Offsets, GetElementPtrInst *GEP,
603     IRBuilder<> &Builder) {
604   FixedVectorType *Ty;
605   if (I->getIntrinsicID() == Intrinsic::masked_gather)
606     Ty = cast<FixedVectorType>(I->getType());
607   else
608     Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());
609   // Incrementing gathers only exist for v4i32
610   if (Ty->getNumElements() != 4 ||
611       Ty->getScalarSizeInBits() != 32)
612     return nullptr;
613   Loop *L = LI->getLoopFor(I->getParent());
614   if (L == nullptr)
615     // Incrementing gathers are not beneficial outside of a loop
616     return nullptr;
617   LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
618                        "wb gather/scatter\n");
619 
620   // The gep was in charge of making sure the offsets are scaled correctly
621   // - calculate that factor so it can be applied by hand
622   DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout();
623   int TypeScale =
624       computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()),
625                    DT.getTypeSizeInBits(GEP->getType()) /
626                        cast<FixedVectorType>(GEP->getType())->getNumElements());
627   if (TypeScale == -1)
628     return nullptr;
629 
630   if (GEP->hasOneUse()) {
631     // Only in this case do we want to build a wb gather, because the wb will
632     // change the phi which does affect other users of the gep (which will still
633     // be using the phi in the old way)
634     Value *Load =
635         tryCreateIncrementingWBGatScat(I, BasePtr, Offsets, TypeScale, Builder);
636     if (Load != nullptr)
637       return Load;
638   }
639   LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
640                        "non-wb gather/scatter\n");
641 
642   std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
643   if (Add.first == nullptr)
644     return nullptr;
645   Value *OffsetsIncoming = Add.first;
646   int64_t Immediate = Add.second;
647 
648   // Make sure the offsets are scaled correctly
649   Instruction *ScaledOffsets = BinaryOperator::Create(
650       Instruction::Shl, OffsetsIncoming,
651       Builder.CreateVectorSplat(Ty->getNumElements(), Builder.getInt32(TypeScale)),
652       "ScaledIndex", I);
653   // Add the base to the offsets
654   OffsetsIncoming = BinaryOperator::Create(
655       Instruction::Add, ScaledOffsets,
656       Builder.CreateVectorSplat(
657           Ty->getNumElements(),
658           Builder.CreatePtrToInt(
659               BasePtr,
660               cast<VectorType>(ScaledOffsets->getType())->getElementType())),
661       "StartIndex", I);
662 
663   if (I->getIntrinsicID() == Intrinsic::masked_gather)
664     return cast<IntrinsicInst>(
665         tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate));
666   else
667     return cast<IntrinsicInst>(
668         tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate));
669 }
670 
671 Value *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
672     IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
673     IRBuilder<> &Builder) {
674   // Check whether this gather's offset is incremented by a constant - if so,
675   // and the load is of the right type, we can merge this into a QI gather
676   Loop *L = LI->getLoopFor(I->getParent());
677   // Offsets that are worth merging into this instruction will be incremented
678   // by a constant, thus we're looking for an add of a phi and a constant
679   PHINode *Phi = dyn_cast<PHINode>(Offsets);
680   if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
681       Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
682     // No phi means no IV to write back to; if there is a phi, we expect it
683     // to have exactly two incoming values; the only phis we are interested in
684     // will be loop IV's and have exactly two uses, one in their increment and
685     // one in the gather's gep
686     return nullptr;
687 
688   unsigned IncrementIndex =
689       Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
690   // Look through the phi to the phi increment
691   Offsets = Phi->getIncomingValue(IncrementIndex);
692 
693   std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
694   if (Add.first == nullptr)
695     return nullptr;
696   Value *OffsetsIncoming = Add.first;
697   int64_t Immediate = Add.second;
698   if (OffsetsIncoming != Phi)
699     // Then the increment we are looking at is not an increment of the
700     // induction variable, and we don't want to do a writeback
701     return nullptr;
702 
703   Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
704   unsigned NumElems =
705       cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();
706 
707   // Make sure the offsets are scaled correctly
708   Instruction *ScaledOffsets = BinaryOperator::Create(
709       Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
710       Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
711       "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
712   // Add the base to the offsets
713   OffsetsIncoming = BinaryOperator::Create(
714       Instruction::Add, ScaledOffsets,
715       Builder.CreateVectorSplat(
716           NumElems,
717           Builder.CreatePtrToInt(
718               BasePtr,
719               cast<VectorType>(ScaledOffsets->getType())->getElementType())),
720       "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
721   // The gather is pre-incrementing
722   OffsetsIncoming = BinaryOperator::Create(
723       Instruction::Sub, OffsetsIncoming,
724       Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
725       "PreIncrementStartIndex",
726       &Phi->getIncomingBlock(1 - IncrementIndex)->back());
727   Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);
728 
729   Builder.SetInsertPoint(I);
730 
731   Value *EndResult;
732   Value *NewInduction;
733   if (I->getIntrinsicID() == Intrinsic::masked_gather) {
734     // Build the incrementing gather
735     Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
736     // One value to be handed to whoever uses the gather, one is the loop
737     // increment
738     EndResult = Builder.CreateExtractValue(Load, 0, "Gather");
739     NewInduction = Builder.CreateExtractValue(Load, 1, "GatherIncrement");
740   } else {
741     // Build the incrementing scatter
742     NewInduction = tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
743     EndResult = NewInduction;
744   }
745   Instruction *AddInst = cast<Instruction>(Offsets);
746   AddInst->replaceAllUsesWith(NewInduction);
747   AddInst->eraseFromParent();
748   Phi->setIncomingValue(IncrementIndex, NewInduction);
749 
750   return EndResult;
751 }
752 
753 void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
754                                           Value *OffsSecondOperand,
755                                           unsigned StartIndex) {
756   LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
757   Instruction *InsertionPoint =
758         &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
759   // Initialize the phi with a vector that contains a sum of the constants
760   Instruction *NewIndex = BinaryOperator::Create(
761       Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
762       "PushedOutAdd", InsertionPoint);
763   unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;
764 
765   // Order such that start index comes first (this reduces mov's)
766   Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
767   Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
768                    Phi->getIncomingBlock(IncrementIndex));
769   Phi->removeIncomingValue(IncrementIndex);
770   Phi->removeIncomingValue(StartIndex);
771 }
772 
773 void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
774                                           Value *IncrementPerRound,
775                                           Value *OffsSecondOperand,
776                                           unsigned LoopIncrement,
777                                           IRBuilder<> &Builder) {
778   LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");
779 
780   // Create a new scalar add outside of the loop and transform it to a splat
781   // by which loop variable can be incremented
782   Instruction *InsertionPoint = &cast<Instruction>(
783         Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());
784 
785   // Create a new index
786   Value *StartIndex = BinaryOperator::Create(
787       Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
788       OffsSecondOperand, "PushedOutMul", InsertionPoint);
789 
790   Instruction *Product =
791       BinaryOperator::Create(Instruction::Mul, IncrementPerRound,
792                              OffsSecondOperand, "Product", InsertionPoint);
793   // Increment NewIndex by Product instead of the multiplication
794   Instruction *NewIncrement = BinaryOperator::Create(
795       Instruction::Add, Phi, Product, "IncrementPushedOutMul",
796       cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
797           .getPrevNode());
798 
799   Phi->addIncoming(StartIndex,
800                    Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
801   Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
802   Phi->removeIncomingValue((unsigned)0);
803   Phi->removeIncomingValue((unsigned)0);
804   return;
805 }
806 
807 // Check whether all usages of this instruction are as offsets of
808 // gathers/scatters or simple arithmetics only used by gathers/scatters
809 static bool hasAllGatScatUsers(Instruction *I) {
810   if (I->hasNUses(0)) {
811     return false;
812   }
813   bool Gatscat = true;
814   for (User *U : I->users()) {
815     if (!isa<Instruction>(U))
816       return false;
817     if (isa<GetElementPtrInst>(U) ||
818         isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
819       return Gatscat;
820     } else {
821       unsigned OpCode = cast<Instruction>(U)->getOpcode();
822       if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
823           hasAllGatScatUsers(cast<Instruction>(U))) {
824         continue;
825       }
826       return false;
827     }
828   }
829   return Gatscat;
830 }
831 
832 bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
833                                                LoopInfo *LI) {
834   LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n");
835   // Optimise the addresses of gathers/scatters by moving invariant
836   // calculations out of the loop
837   if (!isa<Instruction>(Offsets))
838     return false;
839   Instruction *Offs = cast<Instruction>(Offsets);
840   if (Offs->getOpcode() != Instruction::Add &&
841       Offs->getOpcode() != Instruction::Mul)
842     return false;
843   Loop *L = LI->getLoopFor(BB);
844   if (L == nullptr)
845     return false;
846   if (!Offs->hasOneUse()) {
847     if (!hasAllGatScatUsers(Offs))
848       return false;
849   }
850 
851   // Find out which, if any, operand of the instruction
852   // is a phi node
853   PHINode *Phi;
854   int OffsSecondOp;
855   if (isa<PHINode>(Offs->getOperand(0))) {
856     Phi = cast<PHINode>(Offs->getOperand(0));
857     OffsSecondOp = 1;
858   } else if (isa<PHINode>(Offs->getOperand(1))) {
859     Phi = cast<PHINode>(Offs->getOperand(1));
860     OffsSecondOp = 0;
861   } else {
862     bool Changed = true;
863     if (isa<Instruction>(Offs->getOperand(0)) &&
864         L->contains(cast<Instruction>(Offs->getOperand(0))))
865       Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
866     if (isa<Instruction>(Offs->getOperand(1)) &&
867         L->contains(cast<Instruction>(Offs->getOperand(1))))
868       Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
869     if (!Changed) {
870       return false;
871     } else {
872       if (isa<PHINode>(Offs->getOperand(0))) {
873         Phi = cast<PHINode>(Offs->getOperand(0));
874         OffsSecondOp = 1;
875       } else if (isa<PHINode>(Offs->getOperand(1))) {
876         Phi = cast<PHINode>(Offs->getOperand(1));
877         OffsSecondOp = 0;
878       } else {
879         return false;
880       }
881     }
882   }
883   // A phi node we want to perform this function on should be from the
884   // loop header, and shouldn't have more than 2 incoming values
885   if (Phi->getParent() != L->getHeader() ||
886       Phi->getNumIncomingValues() != 2)
887     return false;
888 
889   // The phi must be an induction variable
890   Instruction *Op;
891   int IncrementingBlock = -1;
892 
893   for (int i = 0; i < 2; i++)
894     if ((Op = dyn_cast<Instruction>(Phi->getIncomingValue(i))) != nullptr)
895       if (Op->getOpcode() == Instruction::Add &&
896           (Op->getOperand(0) == Phi || Op->getOperand(1) == Phi))
897         IncrementingBlock = i;
898   if (IncrementingBlock == -1)
899     return false;
900 
901   Instruction *IncInstruction =
902       cast<Instruction>(Phi->getIncomingValue(IncrementingBlock));
903 
904   // If the phi is not used by anything else, we can just adapt it when
905   // replacing the instruction; if it is, we'll have to duplicate it
906   PHINode *NewPhi;
907   Value *IncrementPerRound = IncInstruction->getOperand(
908       (IncInstruction->getOperand(0) == Phi) ? 1 : 0);
909 
910   // Get the value that is added to/multiplied with the phi
911   Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);
912 
913   if (IncrementPerRound->getType() != OffsSecondOperand->getType())
914     // Something has gone wrong, abort
915     return false;
916 
917   // Only proceed if the increment per round is a constant or an instruction
918   // which does not originate from within the loop
919   if (!isa<Constant>(IncrementPerRound) &&
920       !(isa<Instruction>(IncrementPerRound) &&
921         !L->contains(cast<Instruction>(IncrementPerRound))))
922     return false;
923 
924   if (Phi->getNumUses() == 2) {
925     // No other users -> reuse existing phi (One user is the instruction
926     // we're looking at, the other is the phi increment)
927     if (IncInstruction->getNumUses() != 1) {
928       // If the incrementing instruction does have more users than
929       // our phi, we need to copy it
930       IncInstruction = BinaryOperator::Create(
931           Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
932           IncrementPerRound, "LoopIncrement", IncInstruction);
933       Phi->setIncomingValue(IncrementingBlock, IncInstruction);
934     }
935     NewPhi = Phi;
936   } else {
937     // There are other users -> create a new phi
938     NewPhi = PHINode::Create(Phi->getType(), 0, "NewPhi", Phi);
939     std::vector<Value *> Increases;
940     // Copy the incoming values of the old phi
941     NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
942                         Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
943     IncInstruction = BinaryOperator::Create(
944         Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
945         IncrementPerRound, "LoopIncrement", IncInstruction);
946     NewPhi->addIncoming(IncInstruction,
947                         Phi->getIncomingBlock(IncrementingBlock));
948     IncrementingBlock = 1;
949   }
950 
951   IRBuilder<> Builder(BB->getContext());
952   Builder.SetInsertPoint(Phi);
953   Builder.SetCurrentDebugLocation(Offs->getDebugLoc());
954 
955   switch (Offs->getOpcode()) {
956   case Instruction::Add:
957     pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
958     break;
959   case Instruction::Mul:
960     pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
961                Builder);
962     break;
963   default:
964     return false;
965   }
966   LLVM_DEBUG(
967       dbgs() << "masked gathers/scatters: simplified loop variable add/mul\n");
968 
969   // The instruction has now been "absorbed" into the phi value
970   Offs->replaceAllUsesWith(NewPhi);
971   if (Offs->hasNUses(0))
972     Offs->eraseFromParent();
973   // Clean up the old increment in case it's unused because we built a new
974   // one
975   if (IncInstruction->hasNUses(0))
976     IncInstruction->eraseFromParent();
977 
978   return true;
979 }
980 
981 bool MVEGatherScatterLowering::runOnFunction(Function &F) {
982   if (!EnableMaskedGatherScatters)
983     return false;
984   auto &TPC = getAnalysis<TargetPassConfig>();
985   auto &TM = TPC.getTM<TargetMachine>();
986   auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
987   if (!ST->hasMVEIntegerOps())
988     return false;
989   LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
990   SmallVector<IntrinsicInst *, 4> Gathers;
991   SmallVector<IntrinsicInst *, 4> Scatters;
992 
993   bool Changed = false;
994 
995   for (BasicBlock &BB : F) {
996     for (Instruction &I : BB) {
997       IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
998       if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {
999         Gathers.push_back(II);
1000         if (isa<GetElementPtrInst>(II->getArgOperand(0)))
1001           Changed |= optimiseOffsets(
1002               cast<Instruction>(II->getArgOperand(0))->getOperand(1),
1003               II->getParent(), LI);
1004       } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {
1005         Scatters.push_back(II);
1006         if (isa<GetElementPtrInst>(II->getArgOperand(1)))
1007           Changed |= optimiseOffsets(
1008               cast<Instruction>(II->getArgOperand(1))->getOperand(1),
1009               II->getParent(), LI);
1010       }
1011     }
1012   }
1013 
1014   for (unsigned i = 0; i < Gathers.size(); i++) {
1015     IntrinsicInst *I = Gathers[i];
1016     Value *L = lowerGather(I);
1017     if (L == nullptr)
1018       continue;
1019 
1020     // Get rid of any now dead instructions
1021     SimplifyInstructionsInBlock(cast<Instruction>(L)->getParent());
1022     Changed = true;
1023   }
1024 
1025   for (unsigned i = 0; i < Scatters.size(); i++) {
1026     IntrinsicInst *I = Scatters[i];
1027     Value *S = lowerScatter(I);
1028     if (S == nullptr)
1029       continue;
1030 
1031     // Get rid of any now dead instructions
1032     SimplifyInstructionsInBlock(cast<Instruction>(S)->getParent());
1033     Changed = true;
1034   }
1035   return Changed;
1036 }
1037