xref: /freebsd/contrib/llvm-project/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp (revision 5ffd83dbcc34f10e07f6d3e968ae6365869615f4)
1480093f4SDimitry Andric //===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
2480093f4SDimitry Andric //
3480093f4SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4480093f4SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5480093f4SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6480093f4SDimitry Andric //
7480093f4SDimitry Andric //===----------------------------------------------------------------------===//
8480093f4SDimitry Andric //
9480093f4SDimitry Andric /// This pass custom lowers llvm.gather and llvm.scatter instructions to
10480093f4SDimitry Andric /// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to
11480093f4SDimitry Andric /// produce a better final result as we go.
12480093f4SDimitry Andric //
13480093f4SDimitry Andric //===----------------------------------------------------------------------===//
14480093f4SDimitry Andric 
15480093f4SDimitry Andric #include "ARM.h"
16480093f4SDimitry Andric #include "ARMBaseInstrInfo.h"
17480093f4SDimitry Andric #include "ARMSubtarget.h"
18*5ffd83dbSDimitry Andric #include "llvm/Analysis/LoopInfo.h"
19480093f4SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h"
20480093f4SDimitry Andric #include "llvm/CodeGen/TargetLowering.h"
21480093f4SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
22480093f4SDimitry Andric #include "llvm/CodeGen/TargetSubtargetInfo.h"
23480093f4SDimitry Andric #include "llvm/InitializePasses.h"
24480093f4SDimitry Andric #include "llvm/IR/BasicBlock.h"
25480093f4SDimitry Andric #include "llvm/IR/Constant.h"
26480093f4SDimitry Andric #include "llvm/IR/Constants.h"
27480093f4SDimitry Andric #include "llvm/IR/DerivedTypes.h"
28480093f4SDimitry Andric #include "llvm/IR/Function.h"
29480093f4SDimitry Andric #include "llvm/IR/InstrTypes.h"
30480093f4SDimitry Andric #include "llvm/IR/Instruction.h"
31480093f4SDimitry Andric #include "llvm/IR/Instructions.h"
32480093f4SDimitry Andric #include "llvm/IR/IntrinsicInst.h"
33480093f4SDimitry Andric #include "llvm/IR/Intrinsics.h"
34480093f4SDimitry Andric #include "llvm/IR/IntrinsicsARM.h"
35480093f4SDimitry Andric #include "llvm/IR/IRBuilder.h"
36480093f4SDimitry Andric #include "llvm/IR/PatternMatch.h"
37480093f4SDimitry Andric #include "llvm/IR/Type.h"
38480093f4SDimitry Andric #include "llvm/IR/Value.h"
39480093f4SDimitry Andric #include "llvm/Pass.h"
40480093f4SDimitry Andric #include "llvm/Support/Casting.h"
41*5ffd83dbSDimitry Andric #include "llvm/Transforms/Utils/Local.h"
42480093f4SDimitry Andric #include <algorithm>
43480093f4SDimitry Andric #include <cassert>
44480093f4SDimitry Andric 
45480093f4SDimitry Andric using namespace llvm;
46480093f4SDimitry Andric 
47480093f4SDimitry Andric #define DEBUG_TYPE "mve-gather-scatter-lowering"
48480093f4SDimitry Andric 
49480093f4SDimitry Andric cl::opt<bool> EnableMaskedGatherScatters(
50480093f4SDimitry Andric     "enable-arm-maskedgatscat", cl::Hidden, cl::init(false),
51480093f4SDimitry Andric     cl::desc("Enable the generation of masked gathers and scatters"));
52480093f4SDimitry Andric 
53480093f4SDimitry Andric namespace {
54480093f4SDimitry Andric 
55480093f4SDimitry Andric class MVEGatherScatterLowering : public FunctionPass {
56480093f4SDimitry Andric public:
57480093f4SDimitry Andric   static char ID; // Pass identification, replacement for typeid
58480093f4SDimitry Andric 
59480093f4SDimitry Andric   explicit MVEGatherScatterLowering() : FunctionPass(ID) {
60480093f4SDimitry Andric     initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
61480093f4SDimitry Andric   }
62480093f4SDimitry Andric 
63480093f4SDimitry Andric   bool runOnFunction(Function &F) override;
64480093f4SDimitry Andric 
65480093f4SDimitry Andric   StringRef getPassName() const override {
66480093f4SDimitry Andric     return "MVE gather/scatter lowering";
67480093f4SDimitry Andric   }
68480093f4SDimitry Andric 
69480093f4SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
70480093f4SDimitry Andric     AU.setPreservesCFG();
71480093f4SDimitry Andric     AU.addRequired<TargetPassConfig>();
72*5ffd83dbSDimitry Andric     AU.addRequired<LoopInfoWrapperPass>();
73480093f4SDimitry Andric     FunctionPass::getAnalysisUsage(AU);
74480093f4SDimitry Andric   }
75480093f4SDimitry Andric 
76480093f4SDimitry Andric private:
77*5ffd83dbSDimitry Andric   LoopInfo *LI = nullptr;
78*5ffd83dbSDimitry Andric 
79480093f4SDimitry Andric   // Check this is a valid gather with correct alignment
80480093f4SDimitry Andric   bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
81*5ffd83dbSDimitry Andric                                Align Alignment);
82480093f4SDimitry Andric   // Check whether Ptr is hidden behind a bitcast and look through it
83480093f4SDimitry Andric   void lookThroughBitcast(Value *&Ptr);
84480093f4SDimitry Andric   // Check for a getelementptr and deduce base and offsets from it, on success
85480093f4SDimitry Andric   // returning the base directly and the offsets indirectly using the Offsets
86480093f4SDimitry Andric   // argument
87*5ffd83dbSDimitry Andric   Value *checkGEP(Value *&Offsets, Type *Ty, GetElementPtrInst *GEP,
88*5ffd83dbSDimitry Andric                   IRBuilder<> &Builder);
89*5ffd83dbSDimitry Andric   // Compute the scale of this gather/scatter instruction
90*5ffd83dbSDimitry Andric   int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
91*5ffd83dbSDimitry Andric   // If the value is a constant, or derived from constants via additions
92*5ffd83dbSDimitry Andric   // and multilications, return its numeric value
93*5ffd83dbSDimitry Andric   Optional<int64_t> getIfConst(const Value *V);
94*5ffd83dbSDimitry Andric   // If Inst is an add instruction, check whether one summand is a
95*5ffd83dbSDimitry Andric   // constant. If so, scale this constant and return it together with
96*5ffd83dbSDimitry Andric   // the other summand.
97*5ffd83dbSDimitry Andric   std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);
98480093f4SDimitry Andric 
99*5ffd83dbSDimitry Andric   Value *lowerGather(IntrinsicInst *I);
100480093f4SDimitry Andric   // Create a gather from a base + vector of offsets
101480093f4SDimitry Andric   Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
102*5ffd83dbSDimitry Andric                                      Instruction *&Root, IRBuilder<> &Builder);
103480093f4SDimitry Andric   // Create a gather from a vector of pointers
104480093f4SDimitry Andric   Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
105*5ffd83dbSDimitry Andric                                    IRBuilder<> &Builder, int64_t Increment = 0);
106*5ffd83dbSDimitry Andric   // Create an incrementing gather from a vector of pointers
107*5ffd83dbSDimitry Andric   Value *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
108*5ffd83dbSDimitry Andric                                      IRBuilder<> &Builder,
109*5ffd83dbSDimitry Andric                                      int64_t Increment = 0);
110*5ffd83dbSDimitry Andric 
111*5ffd83dbSDimitry Andric   Value *lowerScatter(IntrinsicInst *I);
112*5ffd83dbSDimitry Andric   // Create a scatter to a base + vector of offsets
113*5ffd83dbSDimitry Andric   Value *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
114*5ffd83dbSDimitry Andric                                       IRBuilder<> &Builder);
115*5ffd83dbSDimitry Andric   // Create a scatter to a vector of pointers
116*5ffd83dbSDimitry Andric   Value *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
117*5ffd83dbSDimitry Andric                                     IRBuilder<> &Builder,
118*5ffd83dbSDimitry Andric                                     int64_t Increment = 0);
119*5ffd83dbSDimitry Andric   // Create an incrementing scatter from a vector of pointers
120*5ffd83dbSDimitry Andric   Value *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
121*5ffd83dbSDimitry Andric                                       IRBuilder<> &Builder,
122*5ffd83dbSDimitry Andric                                       int64_t Increment = 0);
123*5ffd83dbSDimitry Andric 
124*5ffd83dbSDimitry Andric   // QI gathers and scatters can increment their offsets on their own if
125*5ffd83dbSDimitry Andric   // the increment is a constant value (digit)
126*5ffd83dbSDimitry Andric   Value *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *BasePtr,
127*5ffd83dbSDimitry Andric                                       Value *Ptr, GetElementPtrInst *GEP,
128*5ffd83dbSDimitry Andric                                       IRBuilder<> &Builder);
129*5ffd83dbSDimitry Andric   // QI gathers/scatters can increment their offsets on their own if the
130*5ffd83dbSDimitry Andric   // increment is a constant value (digit) - this creates a writeback QI
131*5ffd83dbSDimitry Andric   // gather/scatter
132*5ffd83dbSDimitry Andric   Value *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
133*5ffd83dbSDimitry Andric                                         Value *Ptr, unsigned TypeScale,
134*5ffd83dbSDimitry Andric                                         IRBuilder<> &Builder);
135*5ffd83dbSDimitry Andric   // Check whether these offsets could be moved out of the loop they're in
136*5ffd83dbSDimitry Andric   bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
137*5ffd83dbSDimitry Andric   // Pushes the given add out of the loop
138*5ffd83dbSDimitry Andric   void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
139*5ffd83dbSDimitry Andric   // Pushes the given mul out of the loop
140*5ffd83dbSDimitry Andric   void pushOutMul(PHINode *&Phi, Value *IncrementPerRound,
141*5ffd83dbSDimitry Andric                   Value *OffsSecondOperand, unsigned LoopIncrement,
142*5ffd83dbSDimitry Andric                   IRBuilder<> &Builder);
143480093f4SDimitry Andric };
144480093f4SDimitry Andric 
145480093f4SDimitry Andric } // end anonymous namespace
146480093f4SDimitry Andric 
147480093f4SDimitry Andric char MVEGatherScatterLowering::ID = 0;
148480093f4SDimitry Andric 
149480093f4SDimitry Andric INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
150480093f4SDimitry Andric                 "MVE gather/scattering lowering pass", false, false)
151480093f4SDimitry Andric 
152480093f4SDimitry Andric Pass *llvm::createMVEGatherScatterLoweringPass() {
153480093f4SDimitry Andric   return new MVEGatherScatterLowering();
154480093f4SDimitry Andric }
155480093f4SDimitry Andric 
156480093f4SDimitry Andric bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
157480093f4SDimitry Andric                                                        unsigned ElemSize,
158*5ffd83dbSDimitry Andric                                                        Align Alignment) {
159*5ffd83dbSDimitry Andric   if (((NumElements == 4 &&
160*5ffd83dbSDimitry Andric         (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
161*5ffd83dbSDimitry Andric        (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
162480093f4SDimitry Andric        (NumElements == 16 && ElemSize == 8)) &&
163*5ffd83dbSDimitry Andric       Alignment >= ElemSize / 8)
164480093f4SDimitry Andric     return true;
165*5ffd83dbSDimitry Andric   LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
166*5ffd83dbSDimitry Andric                     << "valid alignment or vector type \n");
167480093f4SDimitry Andric   return false;
168480093f4SDimitry Andric }
169480093f4SDimitry Andric 
170*5ffd83dbSDimitry Andric Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, Type *Ty,
171*5ffd83dbSDimitry Andric                                           GetElementPtrInst *GEP,
172*5ffd83dbSDimitry Andric                                           IRBuilder<> &Builder) {
173480093f4SDimitry Andric   if (!GEP) {
174*5ffd83dbSDimitry Andric     LLVM_DEBUG(
175*5ffd83dbSDimitry Andric         dbgs() << "masked gathers/scatters: no getelementpointer found\n");
176480093f4SDimitry Andric     return nullptr;
177480093f4SDimitry Andric   }
178*5ffd83dbSDimitry Andric   LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
179*5ffd83dbSDimitry Andric                     << " Looking at intrinsic for base + vector of offsets\n");
180480093f4SDimitry Andric   Value *GEPPtr = GEP->getPointerOperand();
181480093f4SDimitry Andric   if (GEPPtr->getType()->isVectorTy()) {
182480093f4SDimitry Andric     return nullptr;
183480093f4SDimitry Andric   }
184480093f4SDimitry Andric   if (GEP->getNumOperands() != 2) {
185*5ffd83dbSDimitry Andric     LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
186480093f4SDimitry Andric                       << " operands. Expanding.\n");
187480093f4SDimitry Andric     return nullptr;
188480093f4SDimitry Andric   }
189480093f4SDimitry Andric   Offsets = GEP->getOperand(1);
190*5ffd83dbSDimitry Andric   // Paranoid check whether the number of parallel lanes is the same
191*5ffd83dbSDimitry Andric   assert(cast<FixedVectorType>(Ty)->getNumElements() ==
192*5ffd83dbSDimitry Andric          cast<FixedVectorType>(Offsets->getType())->getNumElements());
193*5ffd83dbSDimitry Andric   // Only <N x i32> offsets can be integrated into an arm gather, any smaller
194*5ffd83dbSDimitry Andric   // type would have to be sign extended by the gep - and arm gathers can only
195*5ffd83dbSDimitry Andric   // zero extend. Additionally, the offsets do have to originate from a zext of
196*5ffd83dbSDimitry Andric   // a vector with element types smaller or equal the type of the gather we're
197*5ffd83dbSDimitry Andric   // looking at
198*5ffd83dbSDimitry Andric   if (Offsets->getType()->getScalarSizeInBits() != 32)
199*5ffd83dbSDimitry Andric     return nullptr;
200480093f4SDimitry Andric   if (ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets))
201480093f4SDimitry Andric     Offsets = ZextOffs->getOperand(0);
202*5ffd83dbSDimitry Andric   else if (!(cast<FixedVectorType>(Offsets->getType())->getNumElements() == 4 &&
203*5ffd83dbSDimitry Andric              Offsets->getType()->getScalarSizeInBits() == 32))
204480093f4SDimitry Andric     return nullptr;
205*5ffd83dbSDimitry Andric 
206*5ffd83dbSDimitry Andric   if (Ty != Offsets->getType()) {
207*5ffd83dbSDimitry Andric     if ((Ty->getScalarSizeInBits() <
208*5ffd83dbSDimitry Andric          Offsets->getType()->getScalarSizeInBits())) {
209*5ffd83dbSDimitry Andric       LLVM_DEBUG(dbgs() << "masked gathers/scatters: no correct offset type."
210*5ffd83dbSDimitry Andric                         << " Can't create intrinsic.\n");
211*5ffd83dbSDimitry Andric       return nullptr;
212*5ffd83dbSDimitry Andric     } else {
213*5ffd83dbSDimitry Andric       Offsets = Builder.CreateZExt(
214*5ffd83dbSDimitry Andric           Offsets, VectorType::getInteger(cast<VectorType>(Ty)));
215480093f4SDimitry Andric     }
216480093f4SDimitry Andric   }
217480093f4SDimitry Andric   // If none of the checks failed, return the gep's base pointer
218*5ffd83dbSDimitry Andric   LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
219480093f4SDimitry Andric   return GEPPtr;
220480093f4SDimitry Andric }
221480093f4SDimitry Andric 
222480093f4SDimitry Andric void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
223480093f4SDimitry Andric   // Look through bitcast instruction if #elements is the same
224480093f4SDimitry Andric   if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
225*5ffd83dbSDimitry Andric     auto *BCTy = cast<FixedVectorType>(BitCast->getType());
226*5ffd83dbSDimitry Andric     auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
227*5ffd83dbSDimitry Andric     if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
228*5ffd83dbSDimitry Andric       LLVM_DEBUG(
229*5ffd83dbSDimitry Andric           dbgs() << "masked gathers/scatters: looking through bitcast\n");
230480093f4SDimitry Andric       Ptr = BitCast->getOperand(0);
231480093f4SDimitry Andric     }
232480093f4SDimitry Andric   }
233480093f4SDimitry Andric }
234480093f4SDimitry Andric 
235*5ffd83dbSDimitry Andric int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
236*5ffd83dbSDimitry Andric                                            unsigned MemoryElemSize) {
237*5ffd83dbSDimitry Andric   // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by 2,
238*5ffd83dbSDimitry Andric   // or a 8bit, 16bit or 32bit load/store scaled by 1
239*5ffd83dbSDimitry Andric   if (GEPElemSize == 32 && MemoryElemSize == 32)
240*5ffd83dbSDimitry Andric     return 2;
241*5ffd83dbSDimitry Andric   else if (GEPElemSize == 16 && MemoryElemSize == 16)
242*5ffd83dbSDimitry Andric     return 1;
243*5ffd83dbSDimitry Andric   else if (GEPElemSize == 8)
244*5ffd83dbSDimitry Andric     return 0;
245*5ffd83dbSDimitry Andric   LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
246*5ffd83dbSDimitry Andric                     << "create intrinsic\n");
247*5ffd83dbSDimitry Andric   return -1;
248*5ffd83dbSDimitry Andric }
249*5ffd83dbSDimitry Andric 
250*5ffd83dbSDimitry Andric Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
251*5ffd83dbSDimitry Andric   const Constant *C = dyn_cast<Constant>(V);
252*5ffd83dbSDimitry Andric   if (C != nullptr)
253*5ffd83dbSDimitry Andric     return Optional<int64_t>{C->getUniqueInteger().getSExtValue()};
254*5ffd83dbSDimitry Andric   if (!isa<Instruction>(V))
255*5ffd83dbSDimitry Andric     return Optional<int64_t>{};
256*5ffd83dbSDimitry Andric 
257*5ffd83dbSDimitry Andric   const Instruction *I = cast<Instruction>(V);
258*5ffd83dbSDimitry Andric   if (I->getOpcode() == Instruction::Add ||
259*5ffd83dbSDimitry Andric               I->getOpcode() == Instruction::Mul) {
260*5ffd83dbSDimitry Andric     Optional<int64_t> Op0 = getIfConst(I->getOperand(0));
261*5ffd83dbSDimitry Andric     Optional<int64_t> Op1 = getIfConst(I->getOperand(1));
262*5ffd83dbSDimitry Andric     if (!Op0 || !Op1)
263*5ffd83dbSDimitry Andric       return Optional<int64_t>{};
264*5ffd83dbSDimitry Andric     if (I->getOpcode() == Instruction::Add)
265*5ffd83dbSDimitry Andric       return Optional<int64_t>{Op0.getValue() + Op1.getValue()};
266*5ffd83dbSDimitry Andric     if (I->getOpcode() == Instruction::Mul)
267*5ffd83dbSDimitry Andric       return Optional<int64_t>{Op0.getValue() * Op1.getValue()};
268*5ffd83dbSDimitry Andric   }
269*5ffd83dbSDimitry Andric   return Optional<int64_t>{};
270*5ffd83dbSDimitry Andric }
271*5ffd83dbSDimitry Andric 
272*5ffd83dbSDimitry Andric std::pair<Value *, int64_t>
273*5ffd83dbSDimitry Andric MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
274*5ffd83dbSDimitry Andric   std::pair<Value *, int64_t> ReturnFalse =
275*5ffd83dbSDimitry Andric       std::pair<Value *, int64_t>(nullptr, 0);
276*5ffd83dbSDimitry Andric   // At this point, the instruction we're looking at must be an add or we
277*5ffd83dbSDimitry Andric   // bail out
278*5ffd83dbSDimitry Andric   Instruction *Add = dyn_cast<Instruction>(Inst);
279*5ffd83dbSDimitry Andric   if (Add == nullptr || Add->getOpcode() != Instruction::Add)
280*5ffd83dbSDimitry Andric     return ReturnFalse;
281*5ffd83dbSDimitry Andric 
282*5ffd83dbSDimitry Andric   Value *Summand;
283*5ffd83dbSDimitry Andric   Optional<int64_t> Const;
284*5ffd83dbSDimitry Andric   // Find out which operand the value that is increased is
285*5ffd83dbSDimitry Andric   if ((Const = getIfConst(Add->getOperand(0))))
286*5ffd83dbSDimitry Andric     Summand = Add->getOperand(1);
287*5ffd83dbSDimitry Andric   else if ((Const = getIfConst(Add->getOperand(1))))
288*5ffd83dbSDimitry Andric     Summand = Add->getOperand(0);
289*5ffd83dbSDimitry Andric   else
290*5ffd83dbSDimitry Andric     return ReturnFalse;
291*5ffd83dbSDimitry Andric 
292*5ffd83dbSDimitry Andric   // Check that the constant is small enough for an incrementing gather
293*5ffd83dbSDimitry Andric   int64_t Immediate = Const.getValue() << TypeScale;
294*5ffd83dbSDimitry Andric   if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
295*5ffd83dbSDimitry Andric     return ReturnFalse;
296*5ffd83dbSDimitry Andric 
297*5ffd83dbSDimitry Andric   return std::pair<Value *, int64_t>(Summand, Immediate);
298*5ffd83dbSDimitry Andric }
299*5ffd83dbSDimitry Andric 
300*5ffd83dbSDimitry Andric Value *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
301480093f4SDimitry Andric   using namespace PatternMatch;
302480093f4SDimitry Andric   LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n");
303480093f4SDimitry Andric 
304480093f4SDimitry Andric   // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
305480093f4SDimitry Andric   // Attempt to turn the masked gather in I into a MVE intrinsic
306480093f4SDimitry Andric   // Potentially optimising the addressing modes as we do so.
307*5ffd83dbSDimitry Andric   auto *Ty = cast<FixedVectorType>(I->getType());
308480093f4SDimitry Andric   Value *Ptr = I->getArgOperand(0);
309*5ffd83dbSDimitry Andric   Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
310480093f4SDimitry Andric   Value *Mask = I->getArgOperand(2);
311480093f4SDimitry Andric   Value *PassThru = I->getArgOperand(3);
312480093f4SDimitry Andric 
313*5ffd83dbSDimitry Andric   if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
314*5ffd83dbSDimitry Andric                                Alignment))
315*5ffd83dbSDimitry Andric     return nullptr;
316480093f4SDimitry Andric   lookThroughBitcast(Ptr);
317480093f4SDimitry Andric   assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
318480093f4SDimitry Andric 
319480093f4SDimitry Andric   IRBuilder<> Builder(I->getContext());
320480093f4SDimitry Andric   Builder.SetInsertPoint(I);
321480093f4SDimitry Andric   Builder.SetCurrentDebugLocation(I->getDebugLoc());
322*5ffd83dbSDimitry Andric 
323*5ffd83dbSDimitry Andric   Instruction *Root = I;
324*5ffd83dbSDimitry Andric   Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
325480093f4SDimitry Andric   if (!Load)
326480093f4SDimitry Andric     Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
327480093f4SDimitry Andric   if (!Load)
328*5ffd83dbSDimitry Andric     return nullptr;
329480093f4SDimitry Andric 
330480093f4SDimitry Andric   if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
331480093f4SDimitry Andric     LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
332480093f4SDimitry Andric                       << "creating select\n");
333480093f4SDimitry Andric     Load = Builder.CreateSelect(Mask, Load, PassThru);
334480093f4SDimitry Andric   }
335480093f4SDimitry Andric 
336*5ffd83dbSDimitry Andric   Root->replaceAllUsesWith(Load);
337*5ffd83dbSDimitry Andric   Root->eraseFromParent();
338*5ffd83dbSDimitry Andric   if (Root != I)
339*5ffd83dbSDimitry Andric     // If this was an extending gather, we need to get rid of the sext/zext
340*5ffd83dbSDimitry Andric     // sext/zext as well as of the gather itself
341480093f4SDimitry Andric     I->eraseFromParent();
342*5ffd83dbSDimitry Andric 
343*5ffd83dbSDimitry Andric   LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n");
344*5ffd83dbSDimitry Andric   return Load;
345480093f4SDimitry Andric }
346480093f4SDimitry Andric 
347*5ffd83dbSDimitry Andric Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(IntrinsicInst *I,
348*5ffd83dbSDimitry Andric                                                            Value *Ptr,
349*5ffd83dbSDimitry Andric                                                            IRBuilder<> &Builder,
350*5ffd83dbSDimitry Andric                                                            int64_t Increment) {
351480093f4SDimitry Andric   using namespace PatternMatch;
352*5ffd83dbSDimitry Andric   auto *Ty = cast<FixedVectorType>(I->getType());
353480093f4SDimitry Andric   LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
354*5ffd83dbSDimitry Andric   if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
355480093f4SDimitry Andric     // Can't build an intrinsic for this
356480093f4SDimitry Andric     return nullptr;
357480093f4SDimitry Andric   Value *Mask = I->getArgOperand(2);
358480093f4SDimitry Andric   if (match(Mask, m_One()))
359480093f4SDimitry Andric     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
360480093f4SDimitry Andric                                    {Ty, Ptr->getType()},
361*5ffd83dbSDimitry Andric                                    {Ptr, Builder.getInt32(Increment)});
362480093f4SDimitry Andric   else
363480093f4SDimitry Andric     return Builder.CreateIntrinsic(
364480093f4SDimitry Andric         Intrinsic::arm_mve_vldr_gather_base_predicated,
365480093f4SDimitry Andric         {Ty, Ptr->getType(), Mask->getType()},
366*5ffd83dbSDimitry Andric         {Ptr, Builder.getInt32(Increment), Mask});
367*5ffd83dbSDimitry Andric }
368*5ffd83dbSDimitry Andric 
369*5ffd83dbSDimitry Andric Value *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
370*5ffd83dbSDimitry Andric     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
371*5ffd83dbSDimitry Andric   using namespace PatternMatch;
372*5ffd83dbSDimitry Andric   auto *Ty = cast<FixedVectorType>(I->getType());
373*5ffd83dbSDimitry Andric   LLVM_DEBUG(
374*5ffd83dbSDimitry Andric       dbgs()
375*5ffd83dbSDimitry Andric       << "masked gathers: loading from vector of pointers with writeback\n");
376*5ffd83dbSDimitry Andric   if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
377*5ffd83dbSDimitry Andric     // Can't build an intrinsic for this
378*5ffd83dbSDimitry Andric     return nullptr;
379*5ffd83dbSDimitry Andric   Value *Mask = I->getArgOperand(2);
380*5ffd83dbSDimitry Andric   if (match(Mask, m_One()))
381*5ffd83dbSDimitry Andric     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
382*5ffd83dbSDimitry Andric                                    {Ty, Ptr->getType()},
383*5ffd83dbSDimitry Andric                                    {Ptr, Builder.getInt32(Increment)});
384*5ffd83dbSDimitry Andric   else
385*5ffd83dbSDimitry Andric     return Builder.CreateIntrinsic(
386*5ffd83dbSDimitry Andric         Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
387*5ffd83dbSDimitry Andric         {Ty, Ptr->getType(), Mask->getType()},
388*5ffd83dbSDimitry Andric         {Ptr, Builder.getInt32(Increment), Mask});
389480093f4SDimitry Andric }
390480093f4SDimitry Andric 
391480093f4SDimitry Andric Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
392*5ffd83dbSDimitry Andric     IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
393480093f4SDimitry Andric   using namespace PatternMatch;
394*5ffd83dbSDimitry Andric 
395*5ffd83dbSDimitry Andric   Type *OriginalTy = I->getType();
396*5ffd83dbSDimitry Andric   Type *ResultTy = OriginalTy;
397*5ffd83dbSDimitry Andric 
398*5ffd83dbSDimitry Andric   unsigned Unsigned = 1;
399*5ffd83dbSDimitry Andric   // The size of the gather was already checked in isLegalTypeAndAlignment;
400*5ffd83dbSDimitry Andric   // if it was not a full vector width an appropriate extend should follow.
401*5ffd83dbSDimitry Andric   auto *Extend = Root;
402*5ffd83dbSDimitry Andric   if (OriginalTy->getPrimitiveSizeInBits() < 128) {
403*5ffd83dbSDimitry Andric     // Only transform gathers with exactly one use
404*5ffd83dbSDimitry Andric     if (!I->hasOneUse())
405480093f4SDimitry Andric       return nullptr;
406480093f4SDimitry Andric 
407*5ffd83dbSDimitry Andric     // The correct root to replace is not the CallInst itself, but the
408*5ffd83dbSDimitry Andric     // instruction which extends it
409*5ffd83dbSDimitry Andric     Extend = cast<Instruction>(*I->users().begin());
410*5ffd83dbSDimitry Andric     if (isa<SExtInst>(Extend)) {
411*5ffd83dbSDimitry Andric       Unsigned = 0;
412*5ffd83dbSDimitry Andric     } else if (!isa<ZExtInst>(Extend)) {
413*5ffd83dbSDimitry Andric       LLVM_DEBUG(dbgs() << "masked gathers: extend needed but not provided. "
414*5ffd83dbSDimitry Andric                         << "Expanding\n");
415480093f4SDimitry Andric       return nullptr;
416480093f4SDimitry Andric     }
417*5ffd83dbSDimitry Andric     LLVM_DEBUG(dbgs() << "masked gathers: found an extending gather\n");
418*5ffd83dbSDimitry Andric     ResultTy = Extend->getType();
419*5ffd83dbSDimitry Andric     // The final size of the gather must be a full vector width
420*5ffd83dbSDimitry Andric     if (ResultTy->getPrimitiveSizeInBits() != 128) {
421*5ffd83dbSDimitry Andric       LLVM_DEBUG(dbgs() << "masked gathers: extending from the wrong type. "
422*5ffd83dbSDimitry Andric                         << "Expanding\n");
423*5ffd83dbSDimitry Andric       return nullptr;
424*5ffd83dbSDimitry Andric     }
425*5ffd83dbSDimitry Andric   }
426*5ffd83dbSDimitry Andric 
427*5ffd83dbSDimitry Andric   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
428*5ffd83dbSDimitry Andric   Value *Offsets;
429*5ffd83dbSDimitry Andric   Value *BasePtr = checkGEP(Offsets, ResultTy, GEP, Builder);
430*5ffd83dbSDimitry Andric   if (!BasePtr)
431*5ffd83dbSDimitry Andric     return nullptr;
432*5ffd83dbSDimitry Andric   // Check whether the offset is a constant increment that could be merged into
433*5ffd83dbSDimitry Andric   // a QI gather
434*5ffd83dbSDimitry Andric   Value *Load = tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
435*5ffd83dbSDimitry Andric   if (Load)
436*5ffd83dbSDimitry Andric     return Load;
437*5ffd83dbSDimitry Andric 
438*5ffd83dbSDimitry Andric   int Scale = computeScale(
439*5ffd83dbSDimitry Andric       BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
440*5ffd83dbSDimitry Andric       OriginalTy->getScalarSizeInBits());
441*5ffd83dbSDimitry Andric   if (Scale == -1)
442*5ffd83dbSDimitry Andric     return nullptr;
443*5ffd83dbSDimitry Andric   Root = Extend;
444480093f4SDimitry Andric 
445480093f4SDimitry Andric   Value *Mask = I->getArgOperand(2);
446480093f4SDimitry Andric   if (!match(Mask, m_One()))
447480093f4SDimitry Andric     return Builder.CreateIntrinsic(
448480093f4SDimitry Andric         Intrinsic::arm_mve_vldr_gather_offset_predicated,
449*5ffd83dbSDimitry Andric         {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
450*5ffd83dbSDimitry Andric         {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
451*5ffd83dbSDimitry Andric          Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
452480093f4SDimitry Andric   else
453480093f4SDimitry Andric     return Builder.CreateIntrinsic(
454480093f4SDimitry Andric         Intrinsic::arm_mve_vldr_gather_offset,
455*5ffd83dbSDimitry Andric         {ResultTy, BasePtr->getType(), Offsets->getType()},
456*5ffd83dbSDimitry Andric         {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
457*5ffd83dbSDimitry Andric          Builder.getInt32(Scale), Builder.getInt32(Unsigned)});
458*5ffd83dbSDimitry Andric }
459*5ffd83dbSDimitry Andric 
460*5ffd83dbSDimitry Andric Value *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
461*5ffd83dbSDimitry Andric   using namespace PatternMatch;
462*5ffd83dbSDimitry Andric   LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n");
463*5ffd83dbSDimitry Andric 
464*5ffd83dbSDimitry Andric   // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
465*5ffd83dbSDimitry Andric   // Attempt to turn the masked scatter in I into a MVE intrinsic
466*5ffd83dbSDimitry Andric   // Potentially optimising the addressing modes as we do so.
467*5ffd83dbSDimitry Andric   Value *Input = I->getArgOperand(0);
468*5ffd83dbSDimitry Andric   Value *Ptr = I->getArgOperand(1);
469*5ffd83dbSDimitry Andric   Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
470*5ffd83dbSDimitry Andric   auto *Ty = cast<FixedVectorType>(Input->getType());
471*5ffd83dbSDimitry Andric 
472*5ffd83dbSDimitry Andric   if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
473*5ffd83dbSDimitry Andric                                Alignment))
474*5ffd83dbSDimitry Andric     return nullptr;
475*5ffd83dbSDimitry Andric 
476*5ffd83dbSDimitry Andric   lookThroughBitcast(Ptr);
477*5ffd83dbSDimitry Andric   assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
478*5ffd83dbSDimitry Andric 
479*5ffd83dbSDimitry Andric   IRBuilder<> Builder(I->getContext());
480*5ffd83dbSDimitry Andric   Builder.SetInsertPoint(I);
481*5ffd83dbSDimitry Andric   Builder.SetCurrentDebugLocation(I->getDebugLoc());
482*5ffd83dbSDimitry Andric 
483*5ffd83dbSDimitry Andric   Value *Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
484*5ffd83dbSDimitry Andric   if (!Store)
485*5ffd83dbSDimitry Andric     Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
486*5ffd83dbSDimitry Andric   if (!Store)
487*5ffd83dbSDimitry Andric     return nullptr;
488*5ffd83dbSDimitry Andric 
489*5ffd83dbSDimitry Andric   LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n");
490*5ffd83dbSDimitry Andric   I->eraseFromParent();
491*5ffd83dbSDimitry Andric   return Store;
492*5ffd83dbSDimitry Andric }
493*5ffd83dbSDimitry Andric 
494*5ffd83dbSDimitry Andric Value *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
495*5ffd83dbSDimitry Andric     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
496*5ffd83dbSDimitry Andric   using namespace PatternMatch;
497*5ffd83dbSDimitry Andric   Value *Input = I->getArgOperand(0);
498*5ffd83dbSDimitry Andric   auto *Ty = cast<FixedVectorType>(Input->getType());
499*5ffd83dbSDimitry Andric   // Only QR variants allow truncating
500*5ffd83dbSDimitry Andric   if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
501*5ffd83dbSDimitry Andric     // Can't build an intrinsic for this
502*5ffd83dbSDimitry Andric     return nullptr;
503*5ffd83dbSDimitry Andric   }
504*5ffd83dbSDimitry Andric   Value *Mask = I->getArgOperand(3);
505*5ffd83dbSDimitry Andric   //  int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
506*5ffd83dbSDimitry Andric   LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
507*5ffd83dbSDimitry Andric   if (match(Mask, m_One()))
508*5ffd83dbSDimitry Andric     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
509*5ffd83dbSDimitry Andric                                    {Ptr->getType(), Input->getType()},
510*5ffd83dbSDimitry Andric                                    {Ptr, Builder.getInt32(Increment), Input});
511*5ffd83dbSDimitry Andric   else
512*5ffd83dbSDimitry Andric     return Builder.CreateIntrinsic(
513*5ffd83dbSDimitry Andric         Intrinsic::arm_mve_vstr_scatter_base_predicated,
514*5ffd83dbSDimitry Andric         {Ptr->getType(), Input->getType(), Mask->getType()},
515*5ffd83dbSDimitry Andric         {Ptr, Builder.getInt32(Increment), Input, Mask});
516*5ffd83dbSDimitry Andric }
517*5ffd83dbSDimitry Andric 
518*5ffd83dbSDimitry Andric Value *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
519*5ffd83dbSDimitry Andric     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
520*5ffd83dbSDimitry Andric   using namespace PatternMatch;
521*5ffd83dbSDimitry Andric   Value *Input = I->getArgOperand(0);
522*5ffd83dbSDimitry Andric   auto *Ty = cast<FixedVectorType>(Input->getType());
523*5ffd83dbSDimitry Andric   LLVM_DEBUG(
524*5ffd83dbSDimitry Andric       dbgs()
525*5ffd83dbSDimitry Andric       << "masked scatters: storing to a vector of pointers with writeback\n");
526*5ffd83dbSDimitry Andric   if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
527*5ffd83dbSDimitry Andric     // Can't build an intrinsic for this
528*5ffd83dbSDimitry Andric     return nullptr;
529*5ffd83dbSDimitry Andric   Value *Mask = I->getArgOperand(3);
530*5ffd83dbSDimitry Andric   if (match(Mask, m_One()))
531*5ffd83dbSDimitry Andric     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
532*5ffd83dbSDimitry Andric                                    {Ptr->getType(), Input->getType()},
533*5ffd83dbSDimitry Andric                                    {Ptr, Builder.getInt32(Increment), Input});
534*5ffd83dbSDimitry Andric   else
535*5ffd83dbSDimitry Andric     return Builder.CreateIntrinsic(
536*5ffd83dbSDimitry Andric         Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
537*5ffd83dbSDimitry Andric         {Ptr->getType(), Input->getType(), Mask->getType()},
538*5ffd83dbSDimitry Andric         {Ptr, Builder.getInt32(Increment), Input, Mask});
539*5ffd83dbSDimitry Andric }
540*5ffd83dbSDimitry Andric 
541*5ffd83dbSDimitry Andric Value *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
542*5ffd83dbSDimitry Andric     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
543*5ffd83dbSDimitry Andric   using namespace PatternMatch;
544*5ffd83dbSDimitry Andric   Value *Input = I->getArgOperand(0);
545*5ffd83dbSDimitry Andric   Value *Mask = I->getArgOperand(3);
546*5ffd83dbSDimitry Andric   Type *InputTy = Input->getType();
547*5ffd83dbSDimitry Andric   Type *MemoryTy = InputTy;
548*5ffd83dbSDimitry Andric   LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
549*5ffd83dbSDimitry Andric                     << " to base + vector of offsets\n");
550*5ffd83dbSDimitry Andric   // If the input has been truncated, try to integrate that trunc into the
551*5ffd83dbSDimitry Andric   // scatter instruction (we don't care about alignment here)
552*5ffd83dbSDimitry Andric   if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
553*5ffd83dbSDimitry Andric     Value *PreTrunc = Trunc->getOperand(0);
554*5ffd83dbSDimitry Andric     Type *PreTruncTy = PreTrunc->getType();
555*5ffd83dbSDimitry Andric     if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
556*5ffd83dbSDimitry Andric       Input = PreTrunc;
557*5ffd83dbSDimitry Andric       InputTy = PreTruncTy;
558*5ffd83dbSDimitry Andric     }
559*5ffd83dbSDimitry Andric   }
560*5ffd83dbSDimitry Andric   if (InputTy->getPrimitiveSizeInBits() != 128) {
561*5ffd83dbSDimitry Andric     LLVM_DEBUG(
562*5ffd83dbSDimitry Andric         dbgs() << "masked scatters: cannot create scatters for non-standard"
563*5ffd83dbSDimitry Andric                << " input types. Expanding.\n");
564*5ffd83dbSDimitry Andric     return nullptr;
565*5ffd83dbSDimitry Andric   }
566*5ffd83dbSDimitry Andric 
567*5ffd83dbSDimitry Andric   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
568*5ffd83dbSDimitry Andric   Value *Offsets;
569*5ffd83dbSDimitry Andric   Value *BasePtr = checkGEP(Offsets, InputTy, GEP, Builder);
570*5ffd83dbSDimitry Andric   if (!BasePtr)
571*5ffd83dbSDimitry Andric     return nullptr;
572*5ffd83dbSDimitry Andric   // Check whether the offset is a constant increment that could be merged into
573*5ffd83dbSDimitry Andric   // a QI gather
574*5ffd83dbSDimitry Andric   Value *Store =
575*5ffd83dbSDimitry Andric       tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
576*5ffd83dbSDimitry Andric   if (Store)
577*5ffd83dbSDimitry Andric     return Store;
578*5ffd83dbSDimitry Andric   int Scale = computeScale(
579*5ffd83dbSDimitry Andric       BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
580*5ffd83dbSDimitry Andric       MemoryTy->getScalarSizeInBits());
581*5ffd83dbSDimitry Andric   if (Scale == -1)
582*5ffd83dbSDimitry Andric     return nullptr;
583*5ffd83dbSDimitry Andric 
584*5ffd83dbSDimitry Andric   if (!match(Mask, m_One()))
585*5ffd83dbSDimitry Andric     return Builder.CreateIntrinsic(
586*5ffd83dbSDimitry Andric         Intrinsic::arm_mve_vstr_scatter_offset_predicated,
587*5ffd83dbSDimitry Andric         {BasePtr->getType(), Offsets->getType(), Input->getType(),
588*5ffd83dbSDimitry Andric          Mask->getType()},
589*5ffd83dbSDimitry Andric         {BasePtr, Offsets, Input,
590*5ffd83dbSDimitry Andric          Builder.getInt32(MemoryTy->getScalarSizeInBits()),
591*5ffd83dbSDimitry Andric          Builder.getInt32(Scale), Mask});
592*5ffd83dbSDimitry Andric   else
593*5ffd83dbSDimitry Andric     return Builder.CreateIntrinsic(
594*5ffd83dbSDimitry Andric         Intrinsic::arm_mve_vstr_scatter_offset,
595*5ffd83dbSDimitry Andric         {BasePtr->getType(), Offsets->getType(), Input->getType()},
596*5ffd83dbSDimitry Andric         {BasePtr, Offsets, Input,
597*5ffd83dbSDimitry Andric          Builder.getInt32(MemoryTy->getScalarSizeInBits()),
598*5ffd83dbSDimitry Andric          Builder.getInt32(Scale)});
599*5ffd83dbSDimitry Andric }
600*5ffd83dbSDimitry Andric 
601*5ffd83dbSDimitry Andric Value *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
602*5ffd83dbSDimitry Andric     IntrinsicInst *I, Value *BasePtr, Value *Offsets, GetElementPtrInst *GEP,
603*5ffd83dbSDimitry Andric     IRBuilder<> &Builder) {
604*5ffd83dbSDimitry Andric   FixedVectorType *Ty;
605*5ffd83dbSDimitry Andric   if (I->getIntrinsicID() == Intrinsic::masked_gather)
606*5ffd83dbSDimitry Andric     Ty = cast<FixedVectorType>(I->getType());
607*5ffd83dbSDimitry Andric   else
608*5ffd83dbSDimitry Andric     Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());
609*5ffd83dbSDimitry Andric   // Incrementing gathers only exist for v4i32
610*5ffd83dbSDimitry Andric   if (Ty->getNumElements() != 4 ||
611*5ffd83dbSDimitry Andric       Ty->getScalarSizeInBits() != 32)
612*5ffd83dbSDimitry Andric     return nullptr;
613*5ffd83dbSDimitry Andric   Loop *L = LI->getLoopFor(I->getParent());
614*5ffd83dbSDimitry Andric   if (L == nullptr)
615*5ffd83dbSDimitry Andric     // Incrementing gathers are not beneficial outside of a loop
616*5ffd83dbSDimitry Andric     return nullptr;
617*5ffd83dbSDimitry Andric   LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
618*5ffd83dbSDimitry Andric                        "wb gather/scatter\n");
619*5ffd83dbSDimitry Andric 
620*5ffd83dbSDimitry Andric   // The gep was in charge of making sure the offsets are scaled correctly
621*5ffd83dbSDimitry Andric   // - calculate that factor so it can be applied by hand
622*5ffd83dbSDimitry Andric   DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout();
623*5ffd83dbSDimitry Andric   int TypeScale =
624*5ffd83dbSDimitry Andric       computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()),
625*5ffd83dbSDimitry Andric                    DT.getTypeSizeInBits(GEP->getType()) /
626*5ffd83dbSDimitry Andric                        cast<FixedVectorType>(GEP->getType())->getNumElements());
627*5ffd83dbSDimitry Andric   if (TypeScale == -1)
628*5ffd83dbSDimitry Andric     return nullptr;
629*5ffd83dbSDimitry Andric 
630*5ffd83dbSDimitry Andric   if (GEP->hasOneUse()) {
631*5ffd83dbSDimitry Andric     // Only in this case do we want to build a wb gather, because the wb will
632*5ffd83dbSDimitry Andric     // change the phi which does affect other users of the gep (which will still
633*5ffd83dbSDimitry Andric     // be using the phi in the old way)
634*5ffd83dbSDimitry Andric     Value *Load =
635*5ffd83dbSDimitry Andric         tryCreateIncrementingWBGatScat(I, BasePtr, Offsets, TypeScale, Builder);
636*5ffd83dbSDimitry Andric     if (Load != nullptr)
637*5ffd83dbSDimitry Andric       return Load;
638*5ffd83dbSDimitry Andric   }
639*5ffd83dbSDimitry Andric   LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
640*5ffd83dbSDimitry Andric                        "non-wb gather/scatter\n");
641*5ffd83dbSDimitry Andric 
642*5ffd83dbSDimitry Andric   std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
643*5ffd83dbSDimitry Andric   if (Add.first == nullptr)
644*5ffd83dbSDimitry Andric     return nullptr;
645*5ffd83dbSDimitry Andric   Value *OffsetsIncoming = Add.first;
646*5ffd83dbSDimitry Andric   int64_t Immediate = Add.second;
647*5ffd83dbSDimitry Andric 
648*5ffd83dbSDimitry Andric   // Make sure the offsets are scaled correctly
649*5ffd83dbSDimitry Andric   Instruction *ScaledOffsets = BinaryOperator::Create(
650*5ffd83dbSDimitry Andric       Instruction::Shl, OffsetsIncoming,
651*5ffd83dbSDimitry Andric       Builder.CreateVectorSplat(Ty->getNumElements(), Builder.getInt32(TypeScale)),
652*5ffd83dbSDimitry Andric       "ScaledIndex", I);
653*5ffd83dbSDimitry Andric   // Add the base to the offsets
654*5ffd83dbSDimitry Andric   OffsetsIncoming = BinaryOperator::Create(
655*5ffd83dbSDimitry Andric       Instruction::Add, ScaledOffsets,
656*5ffd83dbSDimitry Andric       Builder.CreateVectorSplat(
657*5ffd83dbSDimitry Andric           Ty->getNumElements(),
658*5ffd83dbSDimitry Andric           Builder.CreatePtrToInt(
659*5ffd83dbSDimitry Andric               BasePtr,
660*5ffd83dbSDimitry Andric               cast<VectorType>(ScaledOffsets->getType())->getElementType())),
661*5ffd83dbSDimitry Andric       "StartIndex", I);
662*5ffd83dbSDimitry Andric 
663*5ffd83dbSDimitry Andric   if (I->getIntrinsicID() == Intrinsic::masked_gather)
664*5ffd83dbSDimitry Andric     return cast<IntrinsicInst>(
665*5ffd83dbSDimitry Andric         tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate));
666*5ffd83dbSDimitry Andric   else
667*5ffd83dbSDimitry Andric     return cast<IntrinsicInst>(
668*5ffd83dbSDimitry Andric         tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate));
669*5ffd83dbSDimitry Andric }
670*5ffd83dbSDimitry Andric 
671*5ffd83dbSDimitry Andric Value *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
672*5ffd83dbSDimitry Andric     IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
673*5ffd83dbSDimitry Andric     IRBuilder<> &Builder) {
674*5ffd83dbSDimitry Andric   // Check whether this gather's offset is incremented by a constant - if so,
675*5ffd83dbSDimitry Andric   // and the load is of the right type, we can merge this into a QI gather
676*5ffd83dbSDimitry Andric   Loop *L = LI->getLoopFor(I->getParent());
677*5ffd83dbSDimitry Andric   // Offsets that are worth merging into this instruction will be incremented
678*5ffd83dbSDimitry Andric   // by a constant, thus we're looking for an add of a phi and a constant
679*5ffd83dbSDimitry Andric   PHINode *Phi = dyn_cast<PHINode>(Offsets);
680*5ffd83dbSDimitry Andric   if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
681*5ffd83dbSDimitry Andric       Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
682*5ffd83dbSDimitry Andric     // No phi means no IV to write back to; if there is a phi, we expect it
683*5ffd83dbSDimitry Andric     // to have exactly two incoming values; the only phis we are interested in
684*5ffd83dbSDimitry Andric     // will be loop IV's and have exactly two uses, one in their increment and
685*5ffd83dbSDimitry Andric     // one in the gather's gep
686*5ffd83dbSDimitry Andric     return nullptr;
687*5ffd83dbSDimitry Andric 
688*5ffd83dbSDimitry Andric   unsigned IncrementIndex =
689*5ffd83dbSDimitry Andric       Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
690*5ffd83dbSDimitry Andric   // Look through the phi to the phi increment
691*5ffd83dbSDimitry Andric   Offsets = Phi->getIncomingValue(IncrementIndex);
692*5ffd83dbSDimitry Andric 
693*5ffd83dbSDimitry Andric   std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
694*5ffd83dbSDimitry Andric   if (Add.first == nullptr)
695*5ffd83dbSDimitry Andric     return nullptr;
696*5ffd83dbSDimitry Andric   Value *OffsetsIncoming = Add.first;
697*5ffd83dbSDimitry Andric   int64_t Immediate = Add.second;
698*5ffd83dbSDimitry Andric   if (OffsetsIncoming != Phi)
699*5ffd83dbSDimitry Andric     // Then the increment we are looking at is not an increment of the
700*5ffd83dbSDimitry Andric     // induction variable, and we don't want to do a writeback
701*5ffd83dbSDimitry Andric     return nullptr;
702*5ffd83dbSDimitry Andric 
703*5ffd83dbSDimitry Andric   Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
704*5ffd83dbSDimitry Andric   unsigned NumElems =
705*5ffd83dbSDimitry Andric       cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();
706*5ffd83dbSDimitry Andric 
707*5ffd83dbSDimitry Andric   // Make sure the offsets are scaled correctly
708*5ffd83dbSDimitry Andric   Instruction *ScaledOffsets = BinaryOperator::Create(
709*5ffd83dbSDimitry Andric       Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
710*5ffd83dbSDimitry Andric       Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
711*5ffd83dbSDimitry Andric       "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
712*5ffd83dbSDimitry Andric   // Add the base to the offsets
713*5ffd83dbSDimitry Andric   OffsetsIncoming = BinaryOperator::Create(
714*5ffd83dbSDimitry Andric       Instruction::Add, ScaledOffsets,
715*5ffd83dbSDimitry Andric       Builder.CreateVectorSplat(
716*5ffd83dbSDimitry Andric           NumElems,
717*5ffd83dbSDimitry Andric           Builder.CreatePtrToInt(
718*5ffd83dbSDimitry Andric               BasePtr,
719*5ffd83dbSDimitry Andric               cast<VectorType>(ScaledOffsets->getType())->getElementType())),
720*5ffd83dbSDimitry Andric       "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
721*5ffd83dbSDimitry Andric   // The gather is pre-incrementing
722*5ffd83dbSDimitry Andric   OffsetsIncoming = BinaryOperator::Create(
723*5ffd83dbSDimitry Andric       Instruction::Sub, OffsetsIncoming,
724*5ffd83dbSDimitry Andric       Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
725*5ffd83dbSDimitry Andric       "PreIncrementStartIndex",
726*5ffd83dbSDimitry Andric       &Phi->getIncomingBlock(1 - IncrementIndex)->back());
727*5ffd83dbSDimitry Andric   Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);
728*5ffd83dbSDimitry Andric 
729*5ffd83dbSDimitry Andric   Builder.SetInsertPoint(I);
730*5ffd83dbSDimitry Andric 
731*5ffd83dbSDimitry Andric   Value *EndResult;
732*5ffd83dbSDimitry Andric   Value *NewInduction;
733*5ffd83dbSDimitry Andric   if (I->getIntrinsicID() == Intrinsic::masked_gather) {
734*5ffd83dbSDimitry Andric     // Build the incrementing gather
735*5ffd83dbSDimitry Andric     Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
736*5ffd83dbSDimitry Andric     // One value to be handed to whoever uses the gather, one is the loop
737*5ffd83dbSDimitry Andric     // increment
738*5ffd83dbSDimitry Andric     EndResult = Builder.CreateExtractValue(Load, 0, "Gather");
739*5ffd83dbSDimitry Andric     NewInduction = Builder.CreateExtractValue(Load, 1, "GatherIncrement");
740*5ffd83dbSDimitry Andric   } else {
741*5ffd83dbSDimitry Andric     // Build the incrementing scatter
742*5ffd83dbSDimitry Andric     NewInduction = tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
743*5ffd83dbSDimitry Andric     EndResult = NewInduction;
744*5ffd83dbSDimitry Andric   }
745*5ffd83dbSDimitry Andric   Instruction *AddInst = cast<Instruction>(Offsets);
746*5ffd83dbSDimitry Andric   AddInst->replaceAllUsesWith(NewInduction);
747*5ffd83dbSDimitry Andric   AddInst->eraseFromParent();
748*5ffd83dbSDimitry Andric   Phi->setIncomingValue(IncrementIndex, NewInduction);
749*5ffd83dbSDimitry Andric 
750*5ffd83dbSDimitry Andric   return EndResult;
751*5ffd83dbSDimitry Andric }
752*5ffd83dbSDimitry Andric 
753*5ffd83dbSDimitry Andric void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
754*5ffd83dbSDimitry Andric                                           Value *OffsSecondOperand,
755*5ffd83dbSDimitry Andric                                           unsigned StartIndex) {
756*5ffd83dbSDimitry Andric   LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
757*5ffd83dbSDimitry Andric   Instruction *InsertionPoint =
758*5ffd83dbSDimitry Andric         &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
759*5ffd83dbSDimitry Andric   // Initialize the phi with a vector that contains a sum of the constants
760*5ffd83dbSDimitry Andric   Instruction *NewIndex = BinaryOperator::Create(
761*5ffd83dbSDimitry Andric       Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
762*5ffd83dbSDimitry Andric       "PushedOutAdd", InsertionPoint);
763*5ffd83dbSDimitry Andric   unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;
764*5ffd83dbSDimitry Andric 
765*5ffd83dbSDimitry Andric   // Order such that start index comes first (this reduces mov's)
766*5ffd83dbSDimitry Andric   Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
767*5ffd83dbSDimitry Andric   Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
768*5ffd83dbSDimitry Andric                    Phi->getIncomingBlock(IncrementIndex));
769*5ffd83dbSDimitry Andric   Phi->removeIncomingValue(IncrementIndex);
770*5ffd83dbSDimitry Andric   Phi->removeIncomingValue(StartIndex);
771*5ffd83dbSDimitry Andric }
772*5ffd83dbSDimitry Andric 
773*5ffd83dbSDimitry Andric void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
774*5ffd83dbSDimitry Andric                                           Value *IncrementPerRound,
775*5ffd83dbSDimitry Andric                                           Value *OffsSecondOperand,
776*5ffd83dbSDimitry Andric                                           unsigned LoopIncrement,
777*5ffd83dbSDimitry Andric                                           IRBuilder<> &Builder) {
778*5ffd83dbSDimitry Andric   LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");
779*5ffd83dbSDimitry Andric 
780*5ffd83dbSDimitry Andric   // Create a new scalar add outside of the loop and transform it to a splat
781*5ffd83dbSDimitry Andric   // by which loop variable can be incremented
782*5ffd83dbSDimitry Andric   Instruction *InsertionPoint = &cast<Instruction>(
783*5ffd83dbSDimitry Andric         Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());
784*5ffd83dbSDimitry Andric 
785*5ffd83dbSDimitry Andric   // Create a new index
786*5ffd83dbSDimitry Andric   Value *StartIndex = BinaryOperator::Create(
787*5ffd83dbSDimitry Andric       Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
788*5ffd83dbSDimitry Andric       OffsSecondOperand, "PushedOutMul", InsertionPoint);
789*5ffd83dbSDimitry Andric 
790*5ffd83dbSDimitry Andric   Instruction *Product =
791*5ffd83dbSDimitry Andric       BinaryOperator::Create(Instruction::Mul, IncrementPerRound,
792*5ffd83dbSDimitry Andric                              OffsSecondOperand, "Product", InsertionPoint);
793*5ffd83dbSDimitry Andric   // Increment NewIndex by Product instead of the multiplication
794*5ffd83dbSDimitry Andric   Instruction *NewIncrement = BinaryOperator::Create(
795*5ffd83dbSDimitry Andric       Instruction::Add, Phi, Product, "IncrementPushedOutMul",
796*5ffd83dbSDimitry Andric       cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
797*5ffd83dbSDimitry Andric           .getPrevNode());
798*5ffd83dbSDimitry Andric 
799*5ffd83dbSDimitry Andric   Phi->addIncoming(StartIndex,
800*5ffd83dbSDimitry Andric                    Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
801*5ffd83dbSDimitry Andric   Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
802*5ffd83dbSDimitry Andric   Phi->removeIncomingValue((unsigned)0);
803*5ffd83dbSDimitry Andric   Phi->removeIncomingValue((unsigned)0);
804*5ffd83dbSDimitry Andric   return;
805*5ffd83dbSDimitry Andric }
806*5ffd83dbSDimitry Andric 
807*5ffd83dbSDimitry Andric // Check whether all usages of this instruction are as offsets of
808*5ffd83dbSDimitry Andric // gathers/scatters or simple arithmetics only used by gathers/scatters
809*5ffd83dbSDimitry Andric static bool hasAllGatScatUsers(Instruction *I) {
810*5ffd83dbSDimitry Andric   if (I->hasNUses(0)) {
811*5ffd83dbSDimitry Andric     return false;
812*5ffd83dbSDimitry Andric   }
813*5ffd83dbSDimitry Andric   bool Gatscat = true;
814*5ffd83dbSDimitry Andric   for (User *U : I->users()) {
815*5ffd83dbSDimitry Andric     if (!isa<Instruction>(U))
816*5ffd83dbSDimitry Andric       return false;
817*5ffd83dbSDimitry Andric     if (isa<GetElementPtrInst>(U) ||
818*5ffd83dbSDimitry Andric         isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
819*5ffd83dbSDimitry Andric       return Gatscat;
820*5ffd83dbSDimitry Andric     } else {
821*5ffd83dbSDimitry Andric       unsigned OpCode = cast<Instruction>(U)->getOpcode();
822*5ffd83dbSDimitry Andric       if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
823*5ffd83dbSDimitry Andric           hasAllGatScatUsers(cast<Instruction>(U))) {
824*5ffd83dbSDimitry Andric         continue;
825*5ffd83dbSDimitry Andric       }
826*5ffd83dbSDimitry Andric       return false;
827*5ffd83dbSDimitry Andric     }
828*5ffd83dbSDimitry Andric   }
829*5ffd83dbSDimitry Andric   return Gatscat;
830*5ffd83dbSDimitry Andric }
831*5ffd83dbSDimitry Andric 
832*5ffd83dbSDimitry Andric bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
833*5ffd83dbSDimitry Andric                                                LoopInfo *LI) {
834*5ffd83dbSDimitry Andric   LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n");
835*5ffd83dbSDimitry Andric   // Optimise the addresses of gathers/scatters by moving invariant
836*5ffd83dbSDimitry Andric   // calculations out of the loop
837*5ffd83dbSDimitry Andric   if (!isa<Instruction>(Offsets))
838*5ffd83dbSDimitry Andric     return false;
839*5ffd83dbSDimitry Andric   Instruction *Offs = cast<Instruction>(Offsets);
840*5ffd83dbSDimitry Andric   if (Offs->getOpcode() != Instruction::Add &&
841*5ffd83dbSDimitry Andric       Offs->getOpcode() != Instruction::Mul)
842*5ffd83dbSDimitry Andric     return false;
843*5ffd83dbSDimitry Andric   Loop *L = LI->getLoopFor(BB);
844*5ffd83dbSDimitry Andric   if (L == nullptr)
845*5ffd83dbSDimitry Andric     return false;
846*5ffd83dbSDimitry Andric   if (!Offs->hasOneUse()) {
847*5ffd83dbSDimitry Andric     if (!hasAllGatScatUsers(Offs))
848*5ffd83dbSDimitry Andric       return false;
849*5ffd83dbSDimitry Andric   }
850*5ffd83dbSDimitry Andric 
851*5ffd83dbSDimitry Andric   // Find out which, if any, operand of the instruction
852*5ffd83dbSDimitry Andric   // is a phi node
853*5ffd83dbSDimitry Andric   PHINode *Phi;
854*5ffd83dbSDimitry Andric   int OffsSecondOp;
855*5ffd83dbSDimitry Andric   if (isa<PHINode>(Offs->getOperand(0))) {
856*5ffd83dbSDimitry Andric     Phi = cast<PHINode>(Offs->getOperand(0));
857*5ffd83dbSDimitry Andric     OffsSecondOp = 1;
858*5ffd83dbSDimitry Andric   } else if (isa<PHINode>(Offs->getOperand(1))) {
859*5ffd83dbSDimitry Andric     Phi = cast<PHINode>(Offs->getOperand(1));
860*5ffd83dbSDimitry Andric     OffsSecondOp = 0;
861*5ffd83dbSDimitry Andric   } else {
862*5ffd83dbSDimitry Andric     bool Changed = true;
863*5ffd83dbSDimitry Andric     if (isa<Instruction>(Offs->getOperand(0)) &&
864*5ffd83dbSDimitry Andric         L->contains(cast<Instruction>(Offs->getOperand(0))))
865*5ffd83dbSDimitry Andric       Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
866*5ffd83dbSDimitry Andric     if (isa<Instruction>(Offs->getOperand(1)) &&
867*5ffd83dbSDimitry Andric         L->contains(cast<Instruction>(Offs->getOperand(1))))
868*5ffd83dbSDimitry Andric       Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
869*5ffd83dbSDimitry Andric     if (!Changed) {
870*5ffd83dbSDimitry Andric       return false;
871*5ffd83dbSDimitry Andric     } else {
872*5ffd83dbSDimitry Andric       if (isa<PHINode>(Offs->getOperand(0))) {
873*5ffd83dbSDimitry Andric         Phi = cast<PHINode>(Offs->getOperand(0));
874*5ffd83dbSDimitry Andric         OffsSecondOp = 1;
875*5ffd83dbSDimitry Andric       } else if (isa<PHINode>(Offs->getOperand(1))) {
876*5ffd83dbSDimitry Andric         Phi = cast<PHINode>(Offs->getOperand(1));
877*5ffd83dbSDimitry Andric         OffsSecondOp = 0;
878*5ffd83dbSDimitry Andric       } else {
879*5ffd83dbSDimitry Andric         return false;
880*5ffd83dbSDimitry Andric       }
881*5ffd83dbSDimitry Andric     }
882*5ffd83dbSDimitry Andric   }
883*5ffd83dbSDimitry Andric   // A phi node we want to perform this function on should be from the
884*5ffd83dbSDimitry Andric   // loop header, and shouldn't have more than 2 incoming values
885*5ffd83dbSDimitry Andric   if (Phi->getParent() != L->getHeader() ||
886*5ffd83dbSDimitry Andric       Phi->getNumIncomingValues() != 2)
887*5ffd83dbSDimitry Andric     return false;
888*5ffd83dbSDimitry Andric 
889*5ffd83dbSDimitry Andric   // The phi must be an induction variable
890*5ffd83dbSDimitry Andric   Instruction *Op;
891*5ffd83dbSDimitry Andric   int IncrementingBlock = -1;
892*5ffd83dbSDimitry Andric 
893*5ffd83dbSDimitry Andric   for (int i = 0; i < 2; i++)
894*5ffd83dbSDimitry Andric     if ((Op = dyn_cast<Instruction>(Phi->getIncomingValue(i))) != nullptr)
895*5ffd83dbSDimitry Andric       if (Op->getOpcode() == Instruction::Add &&
896*5ffd83dbSDimitry Andric           (Op->getOperand(0) == Phi || Op->getOperand(1) == Phi))
897*5ffd83dbSDimitry Andric         IncrementingBlock = i;
898*5ffd83dbSDimitry Andric   if (IncrementingBlock == -1)
899*5ffd83dbSDimitry Andric     return false;
900*5ffd83dbSDimitry Andric 
901*5ffd83dbSDimitry Andric   Instruction *IncInstruction =
902*5ffd83dbSDimitry Andric       cast<Instruction>(Phi->getIncomingValue(IncrementingBlock));
903*5ffd83dbSDimitry Andric 
904*5ffd83dbSDimitry Andric   // If the phi is not used by anything else, we can just adapt it when
905*5ffd83dbSDimitry Andric   // replacing the instruction; if it is, we'll have to duplicate it
906*5ffd83dbSDimitry Andric   PHINode *NewPhi;
907*5ffd83dbSDimitry Andric   Value *IncrementPerRound = IncInstruction->getOperand(
908*5ffd83dbSDimitry Andric       (IncInstruction->getOperand(0) == Phi) ? 1 : 0);
909*5ffd83dbSDimitry Andric 
910*5ffd83dbSDimitry Andric   // Get the value that is added to/multiplied with the phi
911*5ffd83dbSDimitry Andric   Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);
912*5ffd83dbSDimitry Andric 
913*5ffd83dbSDimitry Andric   if (IncrementPerRound->getType() != OffsSecondOperand->getType())
914*5ffd83dbSDimitry Andric     // Something has gone wrong, abort
915*5ffd83dbSDimitry Andric     return false;
916*5ffd83dbSDimitry Andric 
917*5ffd83dbSDimitry Andric   // Only proceed if the increment per round is a constant or an instruction
918*5ffd83dbSDimitry Andric   // which does not originate from within the loop
919*5ffd83dbSDimitry Andric   if (!isa<Constant>(IncrementPerRound) &&
920*5ffd83dbSDimitry Andric       !(isa<Instruction>(IncrementPerRound) &&
921*5ffd83dbSDimitry Andric         !L->contains(cast<Instruction>(IncrementPerRound))))
922*5ffd83dbSDimitry Andric     return false;
923*5ffd83dbSDimitry Andric 
924*5ffd83dbSDimitry Andric   if (Phi->getNumUses() == 2) {
925*5ffd83dbSDimitry Andric     // No other users -> reuse existing phi (One user is the instruction
926*5ffd83dbSDimitry Andric     // we're looking at, the other is the phi increment)
927*5ffd83dbSDimitry Andric     if (IncInstruction->getNumUses() != 1) {
928*5ffd83dbSDimitry Andric       // If the incrementing instruction does have more users than
929*5ffd83dbSDimitry Andric       // our phi, we need to copy it
930*5ffd83dbSDimitry Andric       IncInstruction = BinaryOperator::Create(
931*5ffd83dbSDimitry Andric           Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
932*5ffd83dbSDimitry Andric           IncrementPerRound, "LoopIncrement", IncInstruction);
933*5ffd83dbSDimitry Andric       Phi->setIncomingValue(IncrementingBlock, IncInstruction);
934*5ffd83dbSDimitry Andric     }
935*5ffd83dbSDimitry Andric     NewPhi = Phi;
936*5ffd83dbSDimitry Andric   } else {
937*5ffd83dbSDimitry Andric     // There are other users -> create a new phi
938*5ffd83dbSDimitry Andric     NewPhi = PHINode::Create(Phi->getType(), 0, "NewPhi", Phi);
939*5ffd83dbSDimitry Andric     std::vector<Value *> Increases;
940*5ffd83dbSDimitry Andric     // Copy the incoming values of the old phi
941*5ffd83dbSDimitry Andric     NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
942*5ffd83dbSDimitry Andric                         Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
943*5ffd83dbSDimitry Andric     IncInstruction = BinaryOperator::Create(
944*5ffd83dbSDimitry Andric         Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
945*5ffd83dbSDimitry Andric         IncrementPerRound, "LoopIncrement", IncInstruction);
946*5ffd83dbSDimitry Andric     NewPhi->addIncoming(IncInstruction,
947*5ffd83dbSDimitry Andric                         Phi->getIncomingBlock(IncrementingBlock));
948*5ffd83dbSDimitry Andric     IncrementingBlock = 1;
949*5ffd83dbSDimitry Andric   }
950*5ffd83dbSDimitry Andric 
951*5ffd83dbSDimitry Andric   IRBuilder<> Builder(BB->getContext());
952*5ffd83dbSDimitry Andric   Builder.SetInsertPoint(Phi);
953*5ffd83dbSDimitry Andric   Builder.SetCurrentDebugLocation(Offs->getDebugLoc());
954*5ffd83dbSDimitry Andric 
955*5ffd83dbSDimitry Andric   switch (Offs->getOpcode()) {
956*5ffd83dbSDimitry Andric   case Instruction::Add:
957*5ffd83dbSDimitry Andric     pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
958*5ffd83dbSDimitry Andric     break;
959*5ffd83dbSDimitry Andric   case Instruction::Mul:
960*5ffd83dbSDimitry Andric     pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
961*5ffd83dbSDimitry Andric                Builder);
962*5ffd83dbSDimitry Andric     break;
963*5ffd83dbSDimitry Andric   default:
964*5ffd83dbSDimitry Andric     return false;
965*5ffd83dbSDimitry Andric   }
966*5ffd83dbSDimitry Andric   LLVM_DEBUG(
967*5ffd83dbSDimitry Andric       dbgs() << "masked gathers/scatters: simplified loop variable add/mul\n");
968*5ffd83dbSDimitry Andric 
969*5ffd83dbSDimitry Andric   // The instruction has now been "absorbed" into the phi value
970*5ffd83dbSDimitry Andric   Offs->replaceAllUsesWith(NewPhi);
971*5ffd83dbSDimitry Andric   if (Offs->hasNUses(0))
972*5ffd83dbSDimitry Andric     Offs->eraseFromParent();
973*5ffd83dbSDimitry Andric   // Clean up the old increment in case it's unused because we built a new
974*5ffd83dbSDimitry Andric   // one
975*5ffd83dbSDimitry Andric   if (IncInstruction->hasNUses(0))
976*5ffd83dbSDimitry Andric     IncInstruction->eraseFromParent();
977*5ffd83dbSDimitry Andric 
978*5ffd83dbSDimitry Andric   return true;
979480093f4SDimitry Andric }
980480093f4SDimitry Andric 
981480093f4SDimitry Andric bool MVEGatherScatterLowering::runOnFunction(Function &F) {
982480093f4SDimitry Andric   if (!EnableMaskedGatherScatters)
983480093f4SDimitry Andric     return false;
984480093f4SDimitry Andric   auto &TPC = getAnalysis<TargetPassConfig>();
985480093f4SDimitry Andric   auto &TM = TPC.getTM<TargetMachine>();
986480093f4SDimitry Andric   auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
987480093f4SDimitry Andric   if (!ST->hasMVEIntegerOps())
988480093f4SDimitry Andric     return false;
989*5ffd83dbSDimitry Andric   LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
990480093f4SDimitry Andric   SmallVector<IntrinsicInst *, 4> Gathers;
991*5ffd83dbSDimitry Andric   SmallVector<IntrinsicInst *, 4> Scatters;
992*5ffd83dbSDimitry Andric 
993*5ffd83dbSDimitry Andric   bool Changed = false;
994*5ffd83dbSDimitry Andric 
995480093f4SDimitry Andric   for (BasicBlock &BB : F) {
996480093f4SDimitry Andric     for (Instruction &I : BB) {
997480093f4SDimitry Andric       IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
998*5ffd83dbSDimitry Andric       if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {
999480093f4SDimitry Andric         Gathers.push_back(II);
1000*5ffd83dbSDimitry Andric         if (isa<GetElementPtrInst>(II->getArgOperand(0)))
1001*5ffd83dbSDimitry Andric           Changed |= optimiseOffsets(
1002*5ffd83dbSDimitry Andric               cast<Instruction>(II->getArgOperand(0))->getOperand(1),
1003*5ffd83dbSDimitry Andric               II->getParent(), LI);
1004*5ffd83dbSDimitry Andric       } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {
1005*5ffd83dbSDimitry Andric         Scatters.push_back(II);
1006*5ffd83dbSDimitry Andric         if (isa<GetElementPtrInst>(II->getArgOperand(1)))
1007*5ffd83dbSDimitry Andric           Changed |= optimiseOffsets(
1008*5ffd83dbSDimitry Andric               cast<Instruction>(II->getArgOperand(1))->getOperand(1),
1009*5ffd83dbSDimitry Andric               II->getParent(), LI);
1010*5ffd83dbSDimitry Andric       }
1011480093f4SDimitry Andric     }
1012480093f4SDimitry Andric   }
1013480093f4SDimitry Andric 
1014*5ffd83dbSDimitry Andric   for (unsigned i = 0; i < Gathers.size(); i++) {
1015*5ffd83dbSDimitry Andric     IntrinsicInst *I = Gathers[i];
1016*5ffd83dbSDimitry Andric     Value *L = lowerGather(I);
1017*5ffd83dbSDimitry Andric     if (L == nullptr)
1018*5ffd83dbSDimitry Andric       continue;
1019480093f4SDimitry Andric 
1020*5ffd83dbSDimitry Andric     // Get rid of any now dead instructions
1021*5ffd83dbSDimitry Andric     SimplifyInstructionsInBlock(cast<Instruction>(L)->getParent());
1022*5ffd83dbSDimitry Andric     Changed = true;
1023*5ffd83dbSDimitry Andric   }
1024480093f4SDimitry Andric 
1025*5ffd83dbSDimitry Andric   for (unsigned i = 0; i < Scatters.size(); i++) {
1026*5ffd83dbSDimitry Andric     IntrinsicInst *I = Scatters[i];
1027*5ffd83dbSDimitry Andric     Value *S = lowerScatter(I);
1028*5ffd83dbSDimitry Andric     if (S == nullptr)
1029*5ffd83dbSDimitry Andric       continue;
1030*5ffd83dbSDimitry Andric 
1031*5ffd83dbSDimitry Andric     // Get rid of any now dead instructions
1032*5ffd83dbSDimitry Andric     SimplifyInstructionsInBlock(cast<Instruction>(S)->getParent());
1033*5ffd83dbSDimitry Andric     Changed = true;
1034*5ffd83dbSDimitry Andric   }
1035*5ffd83dbSDimitry Andric   return Changed;
1036480093f4SDimitry Andric }
1037