1480093f4SDimitry Andric //===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
2480093f4SDimitry Andric //
3480093f4SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4480093f4SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5480093f4SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6480093f4SDimitry Andric //
7480093f4SDimitry Andric //===----------------------------------------------------------------------===//
8480093f4SDimitry Andric //
9480093f4SDimitry Andric /// This pass custom lowers llvm.gather and llvm.scatter instructions to
10480093f4SDimitry Andric /// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to
11480093f4SDimitry Andric /// produce a better final result as we go.
12480093f4SDimitry Andric //
13480093f4SDimitry Andric //===----------------------------------------------------------------------===//
14480093f4SDimitry Andric
15480093f4SDimitry Andric #include "ARM.h"
16480093f4SDimitry Andric #include "ARMBaseInstrInfo.h"
17480093f4SDimitry Andric #include "ARMSubtarget.h"
185ffd83dbSDimitry Andric #include "llvm/Analysis/LoopInfo.h"
19480093f4SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h"
20fe6060f1SDimitry Andric #include "llvm/Analysis/ValueTracking.h"
21480093f4SDimitry Andric #include "llvm/CodeGen/TargetLowering.h"
22480093f4SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
23480093f4SDimitry Andric #include "llvm/CodeGen/TargetSubtargetInfo.h"
24480093f4SDimitry Andric #include "llvm/InitializePasses.h"
25480093f4SDimitry Andric #include "llvm/IR/BasicBlock.h"
26480093f4SDimitry Andric #include "llvm/IR/Constant.h"
27480093f4SDimitry Andric #include "llvm/IR/Constants.h"
28480093f4SDimitry Andric #include "llvm/IR/DerivedTypes.h"
29480093f4SDimitry Andric #include "llvm/IR/Function.h"
30480093f4SDimitry Andric #include "llvm/IR/InstrTypes.h"
31480093f4SDimitry Andric #include "llvm/IR/Instruction.h"
32480093f4SDimitry Andric #include "llvm/IR/Instructions.h"
33480093f4SDimitry Andric #include "llvm/IR/IntrinsicInst.h"
34480093f4SDimitry Andric #include "llvm/IR/Intrinsics.h"
35480093f4SDimitry Andric #include "llvm/IR/IntrinsicsARM.h"
36480093f4SDimitry Andric #include "llvm/IR/IRBuilder.h"
37480093f4SDimitry Andric #include "llvm/IR/PatternMatch.h"
38480093f4SDimitry Andric #include "llvm/IR/Type.h"
39480093f4SDimitry Andric #include "llvm/IR/Value.h"
40480093f4SDimitry Andric #include "llvm/Pass.h"
41480093f4SDimitry Andric #include "llvm/Support/Casting.h"
425ffd83dbSDimitry Andric #include "llvm/Transforms/Utils/Local.h"
43480093f4SDimitry Andric #include <algorithm>
44480093f4SDimitry Andric #include <cassert>
45480093f4SDimitry Andric
46480093f4SDimitry Andric using namespace llvm;
47480093f4SDimitry Andric
48e8d8bef9SDimitry Andric #define DEBUG_TYPE "arm-mve-gather-scatter-lowering"
49480093f4SDimitry Andric
50480093f4SDimitry Andric cl::opt<bool> EnableMaskedGatherScatters(
51e8d8bef9SDimitry Andric "enable-arm-maskedgatscat", cl::Hidden, cl::init(true),
52480093f4SDimitry Andric cl::desc("Enable the generation of masked gathers and scatters"));
53480093f4SDimitry Andric
54480093f4SDimitry Andric namespace {
55480093f4SDimitry Andric
56480093f4SDimitry Andric class MVEGatherScatterLowering : public FunctionPass {
57480093f4SDimitry Andric public:
58480093f4SDimitry Andric static char ID; // Pass identification, replacement for typeid
59480093f4SDimitry Andric
MVEGatherScatterLowering()60480093f4SDimitry Andric explicit MVEGatherScatterLowering() : FunctionPass(ID) {
61480093f4SDimitry Andric initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
62480093f4SDimitry Andric }
63480093f4SDimitry Andric
64480093f4SDimitry Andric bool runOnFunction(Function &F) override;
65480093f4SDimitry Andric
getPassName() const66480093f4SDimitry Andric StringRef getPassName() const override {
67480093f4SDimitry Andric return "MVE gather/scatter lowering";
68480093f4SDimitry Andric }
69480093f4SDimitry Andric
getAnalysisUsage(AnalysisUsage & AU) const70480093f4SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override {
71480093f4SDimitry Andric AU.setPreservesCFG();
72480093f4SDimitry Andric AU.addRequired<TargetPassConfig>();
735ffd83dbSDimitry Andric AU.addRequired<LoopInfoWrapperPass>();
74480093f4SDimitry Andric FunctionPass::getAnalysisUsage(AU);
75480093f4SDimitry Andric }
76480093f4SDimitry Andric
77480093f4SDimitry Andric private:
785ffd83dbSDimitry Andric LoopInfo *LI = nullptr;
79349cc55cSDimitry Andric const DataLayout *DL;
805ffd83dbSDimitry Andric
81480093f4SDimitry Andric // Check this is a valid gather with correct alignment
82480093f4SDimitry Andric bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
835ffd83dbSDimitry Andric Align Alignment);
84480093f4SDimitry Andric // Check whether Ptr is hidden behind a bitcast and look through it
85480093f4SDimitry Andric void lookThroughBitcast(Value *&Ptr);
86fe6060f1SDimitry Andric // Decompose a ptr into Base and Offsets, potentially using a GEP to return a
87fe6060f1SDimitry Andric // scalar base and vector offsets, or else fallback to using a base of 0 and
88fe6060f1SDimitry Andric // offset of Ptr where possible.
89fe6060f1SDimitry Andric Value *decomposePtr(Value *Ptr, Value *&Offsets, int &Scale,
90fe6060f1SDimitry Andric FixedVectorType *Ty, Type *MemoryTy,
91fe6060f1SDimitry Andric IRBuilder<> &Builder);
92480093f4SDimitry Andric // Check for a getelementptr and deduce base and offsets from it, on success
93480093f4SDimitry Andric // returning the base directly and the offsets indirectly using the Offsets
94480093f4SDimitry Andric // argument
95fe6060f1SDimitry Andric Value *decomposeGEP(Value *&Offsets, FixedVectorType *Ty,
96fe6060f1SDimitry Andric GetElementPtrInst *GEP, IRBuilder<> &Builder);
975ffd83dbSDimitry Andric // Compute the scale of this gather/scatter instruction
985ffd83dbSDimitry Andric int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
995ffd83dbSDimitry Andric // If the value is a constant, or derived from constants via additions
1005ffd83dbSDimitry Andric // and multilications, return its numeric value
101bdd1243dSDimitry Andric std::optional<int64_t> getIfConst(const Value *V);
1025ffd83dbSDimitry Andric // If Inst is an add instruction, check whether one summand is a
1035ffd83dbSDimitry Andric // constant. If so, scale this constant and return it together with
1045ffd83dbSDimitry Andric // the other summand.
1055ffd83dbSDimitry Andric std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);
106480093f4SDimitry Andric
107fe6060f1SDimitry Andric Instruction *lowerGather(IntrinsicInst *I);
108480093f4SDimitry Andric // Create a gather from a base + vector of offsets
109fe6060f1SDimitry Andric Instruction *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
110fe6060f1SDimitry Andric Instruction *&Root,
111fe6060f1SDimitry Andric IRBuilder<> &Builder);
112480093f4SDimitry Andric // Create a gather from a vector of pointers
113fe6060f1SDimitry Andric Instruction *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
114fe6060f1SDimitry Andric IRBuilder<> &Builder,
115fe6060f1SDimitry Andric int64_t Increment = 0);
1165ffd83dbSDimitry Andric // Create an incrementing gather from a vector of pointers
117fe6060f1SDimitry Andric Instruction *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
1185ffd83dbSDimitry Andric IRBuilder<> &Builder,
1195ffd83dbSDimitry Andric int64_t Increment = 0);
1205ffd83dbSDimitry Andric
121fe6060f1SDimitry Andric Instruction *lowerScatter(IntrinsicInst *I);
1225ffd83dbSDimitry Andric // Create a scatter to a base + vector of offsets
123fe6060f1SDimitry Andric Instruction *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
1245ffd83dbSDimitry Andric IRBuilder<> &Builder);
1255ffd83dbSDimitry Andric // Create a scatter to a vector of pointers
126fe6060f1SDimitry Andric Instruction *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
1275ffd83dbSDimitry Andric IRBuilder<> &Builder,
1285ffd83dbSDimitry Andric int64_t Increment = 0);
1295ffd83dbSDimitry Andric // Create an incrementing scatter from a vector of pointers
130fe6060f1SDimitry Andric Instruction *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
1315ffd83dbSDimitry Andric IRBuilder<> &Builder,
1325ffd83dbSDimitry Andric int64_t Increment = 0);
1335ffd83dbSDimitry Andric
1345ffd83dbSDimitry Andric // QI gathers and scatters can increment their offsets on their own if
1355ffd83dbSDimitry Andric // the increment is a constant value (digit)
136fe6060f1SDimitry Andric Instruction *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *Ptr,
1375ffd83dbSDimitry Andric IRBuilder<> &Builder);
1385ffd83dbSDimitry Andric // QI gathers/scatters can increment their offsets on their own if the
1395ffd83dbSDimitry Andric // increment is a constant value (digit) - this creates a writeback QI
1405ffd83dbSDimitry Andric // gather/scatter
141fe6060f1SDimitry Andric Instruction *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
1425ffd83dbSDimitry Andric Value *Ptr, unsigned TypeScale,
1435ffd83dbSDimitry Andric IRBuilder<> &Builder);
144e8d8bef9SDimitry Andric
145e8d8bef9SDimitry Andric // Optimise the base and offsets of the given address
146e8d8bef9SDimitry Andric bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI);
147e8d8bef9SDimitry Andric // Try to fold consecutive geps together into one
14881ad6265SDimitry Andric Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, unsigned &Scale,
14981ad6265SDimitry Andric IRBuilder<> &Builder);
1505ffd83dbSDimitry Andric // Check whether these offsets could be moved out of the loop they're in
1515ffd83dbSDimitry Andric bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
1525ffd83dbSDimitry Andric // Pushes the given add out of the loop
1535ffd83dbSDimitry Andric void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
154349cc55cSDimitry Andric // Pushes the given mul or shl out of the loop
155349cc55cSDimitry Andric void pushOutMulShl(unsigned Opc, PHINode *&Phi, Value *IncrementPerRound,
1565ffd83dbSDimitry Andric Value *OffsSecondOperand, unsigned LoopIncrement,
1575ffd83dbSDimitry Andric IRBuilder<> &Builder);
158480093f4SDimitry Andric };
159480093f4SDimitry Andric
160480093f4SDimitry Andric } // end anonymous namespace
161480093f4SDimitry Andric
162480093f4SDimitry Andric char MVEGatherScatterLowering::ID = 0;
163480093f4SDimitry Andric
164480093f4SDimitry Andric INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
165480093f4SDimitry Andric "MVE gather/scattering lowering pass", false, false)
166480093f4SDimitry Andric
createMVEGatherScatterLoweringPass()167480093f4SDimitry Andric Pass *llvm::createMVEGatherScatterLoweringPass() {
168480093f4SDimitry Andric return new MVEGatherScatterLowering();
169480093f4SDimitry Andric }
170480093f4SDimitry Andric
isLegalTypeAndAlignment(unsigned NumElements,unsigned ElemSize,Align Alignment)171480093f4SDimitry Andric bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
172480093f4SDimitry Andric unsigned ElemSize,
1735ffd83dbSDimitry Andric Align Alignment) {
1745ffd83dbSDimitry Andric if (((NumElements == 4 &&
1755ffd83dbSDimitry Andric (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
1765ffd83dbSDimitry Andric (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
177480093f4SDimitry Andric (NumElements == 16 && ElemSize == 8)) &&
1785ffd83dbSDimitry Andric Alignment >= ElemSize / 8)
179480093f4SDimitry Andric return true;
1805ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
1815ffd83dbSDimitry Andric << "valid alignment or vector type \n");
182480093f4SDimitry Andric return false;
183480093f4SDimitry Andric }
184480093f4SDimitry Andric
checkOffsetSize(Value * Offsets,unsigned TargetElemCount)185e8d8bef9SDimitry Andric static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) {
186e8d8bef9SDimitry Andric // Offsets that are not of type <N x i32> are sign extended by the
187e8d8bef9SDimitry Andric // getelementptr instruction, and MVE gathers/scatters treat the offset as
188e8d8bef9SDimitry Andric // unsigned. Thus, if the element size is smaller than 32, we can only allow
189e8d8bef9SDimitry Andric // positive offsets - i.e., the offsets are not allowed to be variables we
190e8d8bef9SDimitry Andric // can't look into.
191e8d8bef9SDimitry Andric // Additionally, <N x i32> offsets have to either originate from a zext of a
192e8d8bef9SDimitry Andric // vector with element types smaller or equal the type of the gather we're
193e8d8bef9SDimitry Andric // looking at, or consist of constants that we can check are small enough
194e8d8bef9SDimitry Andric // to fit into the gather type.
195e8d8bef9SDimitry Andric // Thus we check that 0 < value < 2^TargetElemSize.
196e8d8bef9SDimitry Andric unsigned TargetElemSize = 128 / TargetElemCount;
197e8d8bef9SDimitry Andric unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType())
198e8d8bef9SDimitry Andric ->getElementType()
199e8d8bef9SDimitry Andric ->getScalarSizeInBits();
200e8d8bef9SDimitry Andric if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
201e8d8bef9SDimitry Andric Constant *ConstOff = dyn_cast<Constant>(Offsets);
202e8d8bef9SDimitry Andric if (!ConstOff)
203e8d8bef9SDimitry Andric return false;
204e8d8bef9SDimitry Andric int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
205e8d8bef9SDimitry Andric auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) {
206e8d8bef9SDimitry Andric ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem);
207e8d8bef9SDimitry Andric if (!OConst)
208e8d8bef9SDimitry Andric return false;
209e8d8bef9SDimitry Andric int SExtValue = OConst->getSExtValue();
210e8d8bef9SDimitry Andric if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
211e8d8bef9SDimitry Andric return false;
212e8d8bef9SDimitry Andric return true;
213e8d8bef9SDimitry Andric };
214e8d8bef9SDimitry Andric if (isa<FixedVectorType>(ConstOff->getType())) {
215e8d8bef9SDimitry Andric for (unsigned i = 0; i < TargetElemCount; i++) {
216e8d8bef9SDimitry Andric if (!CheckValueSize(ConstOff->getAggregateElement(i)))
217e8d8bef9SDimitry Andric return false;
218e8d8bef9SDimitry Andric }
219e8d8bef9SDimitry Andric } else {
220e8d8bef9SDimitry Andric if (!CheckValueSize(ConstOff))
221e8d8bef9SDimitry Andric return false;
222e8d8bef9SDimitry Andric }
223e8d8bef9SDimitry Andric }
224e8d8bef9SDimitry Andric return true;
225e8d8bef9SDimitry Andric }
226e8d8bef9SDimitry Andric
decomposePtr(Value * Ptr,Value * & Offsets,int & Scale,FixedVectorType * Ty,Type * MemoryTy,IRBuilder<> & Builder)227fe6060f1SDimitry Andric Value *MVEGatherScatterLowering::decomposePtr(Value *Ptr, Value *&Offsets,
228fe6060f1SDimitry Andric int &Scale, FixedVectorType *Ty,
229fe6060f1SDimitry Andric Type *MemoryTy,
230fe6060f1SDimitry Andric IRBuilder<> &Builder) {
231fe6060f1SDimitry Andric if (auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
232fe6060f1SDimitry Andric if (Value *V = decomposeGEP(Offsets, Ty, GEP, Builder)) {
233fe6060f1SDimitry Andric Scale =
234fe6060f1SDimitry Andric computeScale(GEP->getSourceElementType()->getPrimitiveSizeInBits(),
235fe6060f1SDimitry Andric MemoryTy->getScalarSizeInBits());
236fe6060f1SDimitry Andric return Scale == -1 ? nullptr : V;
237fe6060f1SDimitry Andric }
238fe6060f1SDimitry Andric }
239fe6060f1SDimitry Andric
240fe6060f1SDimitry Andric // If we couldn't use the GEP (or it doesn't exist), attempt to use a
241fe6060f1SDimitry Andric // BasePtr of 0 with Ptr as the Offsets, so long as there are only 4
242fe6060f1SDimitry Andric // elements.
243fe6060f1SDimitry Andric FixedVectorType *PtrTy = cast<FixedVectorType>(Ptr->getType());
244fe6060f1SDimitry Andric if (PtrTy->getNumElements() != 4 || MemoryTy->getScalarSizeInBits() == 32)
245fe6060f1SDimitry Andric return nullptr;
246fe6060f1SDimitry Andric Value *Zero = ConstantInt::get(Builder.getInt32Ty(), 0);
2475f757f3fSDimitry Andric Value *BasePtr = Builder.CreateIntToPtr(Zero, Builder.getPtrTy());
248fe6060f1SDimitry Andric Offsets = Builder.CreatePtrToInt(
249fe6060f1SDimitry Andric Ptr, FixedVectorType::get(Builder.getInt32Ty(), 4));
250fe6060f1SDimitry Andric Scale = 0;
251fe6060f1SDimitry Andric return BasePtr;
252fe6060f1SDimitry Andric }
253fe6060f1SDimitry Andric
decomposeGEP(Value * & Offsets,FixedVectorType * Ty,GetElementPtrInst * GEP,IRBuilder<> & Builder)254fe6060f1SDimitry Andric Value *MVEGatherScatterLowering::decomposeGEP(Value *&Offsets,
255fe6060f1SDimitry Andric FixedVectorType *Ty,
2565ffd83dbSDimitry Andric GetElementPtrInst *GEP,
2575ffd83dbSDimitry Andric IRBuilder<> &Builder) {
258480093f4SDimitry Andric if (!GEP) {
259fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: no getelementpointer "
260fe6060f1SDimitry Andric << "found\n");
261480093f4SDimitry Andric return nullptr;
262480093f4SDimitry Andric }
2635ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
2645ffd83dbSDimitry Andric << " Looking at intrinsic for base + vector of offsets\n");
265480093f4SDimitry Andric Value *GEPPtr = GEP->getPointerOperand();
266e8d8bef9SDimitry Andric Offsets = GEP->getOperand(1);
267e8d8bef9SDimitry Andric if (GEPPtr->getType()->isVectorTy() ||
268e8d8bef9SDimitry Andric !isa<FixedVectorType>(Offsets->getType()))
269480093f4SDimitry Andric return nullptr;
270e8d8bef9SDimitry Andric
271480093f4SDimitry Andric if (GEP->getNumOperands() != 2) {
2725ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
273480093f4SDimitry Andric << " operands. Expanding.\n");
274480093f4SDimitry Andric return nullptr;
275480093f4SDimitry Andric }
276480093f4SDimitry Andric Offsets = GEP->getOperand(1);
277e8d8bef9SDimitry Andric unsigned OffsetsElemCount =
278e8d8bef9SDimitry Andric cast<FixedVectorType>(Offsets->getType())->getNumElements();
2795ffd83dbSDimitry Andric // Paranoid check whether the number of parallel lanes is the same
280e8d8bef9SDimitry Andric assert(Ty->getNumElements() == OffsetsElemCount);
281e8d8bef9SDimitry Andric
282e8d8bef9SDimitry Andric ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets);
283e8d8bef9SDimitry Andric if (ZextOffs)
284480093f4SDimitry Andric Offsets = ZextOffs->getOperand(0);
285e8d8bef9SDimitry Andric FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType());
286e8d8bef9SDimitry Andric
287e8d8bef9SDimitry Andric // If the offsets are already being zext-ed to <N x i32>, that relieves us of
288e8d8bef9SDimitry Andric // having to make sure that they won't overflow.
289e8d8bef9SDimitry Andric if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy())
290e8d8bef9SDimitry Andric ->getElementType()
291e8d8bef9SDimitry Andric ->getScalarSizeInBits() != 32)
292e8d8bef9SDimitry Andric if (!checkOffsetSize(Offsets, OffsetsElemCount))
293480093f4SDimitry Andric return nullptr;
2945ffd83dbSDimitry Andric
295e8d8bef9SDimitry Andric // The offset sizes have been checked; if any truncating or zext-ing is
296e8d8bef9SDimitry Andric // required to fix them, do that now
2975ffd83dbSDimitry Andric if (Ty != Offsets->getType()) {
298e8d8bef9SDimitry Andric if ((Ty->getElementType()->getScalarSizeInBits() <
299e8d8bef9SDimitry Andric OffsetType->getElementType()->getScalarSizeInBits())) {
300e8d8bef9SDimitry Andric Offsets = Builder.CreateTrunc(Offsets, Ty);
3015ffd83dbSDimitry Andric } else {
302e8d8bef9SDimitry Andric Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty));
303480093f4SDimitry Andric }
304480093f4SDimitry Andric }
305480093f4SDimitry Andric // If none of the checks failed, return the gep's base pointer
3065ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
307480093f4SDimitry Andric return GEPPtr;
308480093f4SDimitry Andric }
309480093f4SDimitry Andric
lookThroughBitcast(Value * & Ptr)310480093f4SDimitry Andric void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
311480093f4SDimitry Andric // Look through bitcast instruction if #elements is the same
312480093f4SDimitry Andric if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
3135ffd83dbSDimitry Andric auto *BCTy = cast<FixedVectorType>(BitCast->getType());
3145ffd83dbSDimitry Andric auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
3155ffd83dbSDimitry Andric if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
316fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: looking through "
317fe6060f1SDimitry Andric << "bitcast\n");
318480093f4SDimitry Andric Ptr = BitCast->getOperand(0);
319480093f4SDimitry Andric }
320480093f4SDimitry Andric }
321480093f4SDimitry Andric }
322480093f4SDimitry Andric
computeScale(unsigned GEPElemSize,unsigned MemoryElemSize)3235ffd83dbSDimitry Andric int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
3245ffd83dbSDimitry Andric unsigned MemoryElemSize) {
3255ffd83dbSDimitry Andric // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by 2,
3265ffd83dbSDimitry Andric // or a 8bit, 16bit or 32bit load/store scaled by 1
3275ffd83dbSDimitry Andric if (GEPElemSize == 32 && MemoryElemSize == 32)
3285ffd83dbSDimitry Andric return 2;
3295ffd83dbSDimitry Andric else if (GEPElemSize == 16 && MemoryElemSize == 16)
3305ffd83dbSDimitry Andric return 1;
3315ffd83dbSDimitry Andric else if (GEPElemSize == 8)
3325ffd83dbSDimitry Andric return 0;
3335ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
3345ffd83dbSDimitry Andric << "create intrinsic\n");
3355ffd83dbSDimitry Andric return -1;
3365ffd83dbSDimitry Andric }
3375ffd83dbSDimitry Andric
getIfConst(const Value * V)338bdd1243dSDimitry Andric std::optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
3395ffd83dbSDimitry Andric const Constant *C = dyn_cast<Constant>(V);
340349cc55cSDimitry Andric if (C && C->getSplatValue())
341bdd1243dSDimitry Andric return std::optional<int64_t>{C->getUniqueInteger().getSExtValue()};
3425ffd83dbSDimitry Andric if (!isa<Instruction>(V))
343bdd1243dSDimitry Andric return std::optional<int64_t>{};
3445ffd83dbSDimitry Andric
3455ffd83dbSDimitry Andric const Instruction *I = cast<Instruction>(V);
346349cc55cSDimitry Andric if (I->getOpcode() == Instruction::Add || I->getOpcode() == Instruction::Or ||
347349cc55cSDimitry Andric I->getOpcode() == Instruction::Mul ||
348349cc55cSDimitry Andric I->getOpcode() == Instruction::Shl) {
349bdd1243dSDimitry Andric std::optional<int64_t> Op0 = getIfConst(I->getOperand(0));
350bdd1243dSDimitry Andric std::optional<int64_t> Op1 = getIfConst(I->getOperand(1));
3515ffd83dbSDimitry Andric if (!Op0 || !Op1)
352bdd1243dSDimitry Andric return std::optional<int64_t>{};
3535ffd83dbSDimitry Andric if (I->getOpcode() == Instruction::Add)
354bdd1243dSDimitry Andric return std::optional<int64_t>{*Op0 + *Op1};
3555ffd83dbSDimitry Andric if (I->getOpcode() == Instruction::Mul)
356bdd1243dSDimitry Andric return std::optional<int64_t>{*Op0 * *Op1};
357349cc55cSDimitry Andric if (I->getOpcode() == Instruction::Shl)
358bdd1243dSDimitry Andric return std::optional<int64_t>{*Op0 << *Op1};
359349cc55cSDimitry Andric if (I->getOpcode() == Instruction::Or)
360bdd1243dSDimitry Andric return std::optional<int64_t>{*Op0 | *Op1};
3615ffd83dbSDimitry Andric }
362bdd1243dSDimitry Andric return std::optional<int64_t>{};
3635ffd83dbSDimitry Andric }
3645ffd83dbSDimitry Andric
365349cc55cSDimitry Andric // Return true if I is an Or instruction that is equivalent to an add, due to
366349cc55cSDimitry Andric // the operands having no common bits set.
isAddLikeOr(Instruction * I,const DataLayout & DL)367349cc55cSDimitry Andric static bool isAddLikeOr(Instruction *I, const DataLayout &DL) {
368349cc55cSDimitry Andric return I->getOpcode() == Instruction::Or &&
369349cc55cSDimitry Andric haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1), DL);
370349cc55cSDimitry Andric }
371349cc55cSDimitry Andric
3725ffd83dbSDimitry Andric std::pair<Value *, int64_t>
getVarAndConst(Value * Inst,int TypeScale)3735ffd83dbSDimitry Andric MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
3745ffd83dbSDimitry Andric std::pair<Value *, int64_t> ReturnFalse =
3755ffd83dbSDimitry Andric std::pair<Value *, int64_t>(nullptr, 0);
376349cc55cSDimitry Andric // At this point, the instruction we're looking at must be an add or an
377349cc55cSDimitry Andric // add-like-or.
3785ffd83dbSDimitry Andric Instruction *Add = dyn_cast<Instruction>(Inst);
379349cc55cSDimitry Andric if (Add == nullptr ||
380349cc55cSDimitry Andric (Add->getOpcode() != Instruction::Add && !isAddLikeOr(Add, *DL)))
3815ffd83dbSDimitry Andric return ReturnFalse;
3825ffd83dbSDimitry Andric
3835ffd83dbSDimitry Andric Value *Summand;
384bdd1243dSDimitry Andric std::optional<int64_t> Const;
3855ffd83dbSDimitry Andric // Find out which operand the value that is increased is
3865ffd83dbSDimitry Andric if ((Const = getIfConst(Add->getOperand(0))))
3875ffd83dbSDimitry Andric Summand = Add->getOperand(1);
3885ffd83dbSDimitry Andric else if ((Const = getIfConst(Add->getOperand(1))))
3895ffd83dbSDimitry Andric Summand = Add->getOperand(0);
3905ffd83dbSDimitry Andric else
3915ffd83dbSDimitry Andric return ReturnFalse;
3925ffd83dbSDimitry Andric
3935ffd83dbSDimitry Andric // Check that the constant is small enough for an incrementing gather
39481ad6265SDimitry Andric int64_t Immediate = *Const << TypeScale;
3955ffd83dbSDimitry Andric if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
3965ffd83dbSDimitry Andric return ReturnFalse;
3975ffd83dbSDimitry Andric
3985ffd83dbSDimitry Andric return std::pair<Value *, int64_t>(Summand, Immediate);
3995ffd83dbSDimitry Andric }
4005ffd83dbSDimitry Andric
lowerGather(IntrinsicInst * I)401fe6060f1SDimitry Andric Instruction *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
402480093f4SDimitry Andric using namespace PatternMatch;
403fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"
404fe6060f1SDimitry Andric << *I << "\n");
405480093f4SDimitry Andric
406480093f4SDimitry Andric // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
407480093f4SDimitry Andric // Attempt to turn the masked gather in I into a MVE intrinsic
408480093f4SDimitry Andric // Potentially optimising the addressing modes as we do so.
4095ffd83dbSDimitry Andric auto *Ty = cast<FixedVectorType>(I->getType());
410480093f4SDimitry Andric Value *Ptr = I->getArgOperand(0);
4115ffd83dbSDimitry Andric Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
412480093f4SDimitry Andric Value *Mask = I->getArgOperand(2);
413480093f4SDimitry Andric Value *PassThru = I->getArgOperand(3);
414480093f4SDimitry Andric
4155ffd83dbSDimitry Andric if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
4165ffd83dbSDimitry Andric Alignment))
4175ffd83dbSDimitry Andric return nullptr;
418480093f4SDimitry Andric lookThroughBitcast(Ptr);
419480093f4SDimitry Andric assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
420480093f4SDimitry Andric
421480093f4SDimitry Andric IRBuilder<> Builder(I->getContext());
422480093f4SDimitry Andric Builder.SetInsertPoint(I);
423480093f4SDimitry Andric Builder.SetCurrentDebugLocation(I->getDebugLoc());
4245ffd83dbSDimitry Andric
4255ffd83dbSDimitry Andric Instruction *Root = I;
426fe6060f1SDimitry Andric
427fe6060f1SDimitry Andric Instruction *Load = tryCreateIncrementingGatScat(I, Ptr, Builder);
428fe6060f1SDimitry Andric if (!Load)
429fe6060f1SDimitry Andric Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
430480093f4SDimitry Andric if (!Load)
431480093f4SDimitry Andric Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
432480093f4SDimitry Andric if (!Load)
4335ffd83dbSDimitry Andric return nullptr;
434480093f4SDimitry Andric
435480093f4SDimitry Andric if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
436480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
437480093f4SDimitry Andric << "creating select\n");
438fe6060f1SDimitry Andric Load = SelectInst::Create(Mask, Load, PassThru);
439fe6060f1SDimitry Andric Builder.Insert(Load);
440480093f4SDimitry Andric }
441480093f4SDimitry Andric
4425ffd83dbSDimitry Andric Root->replaceAllUsesWith(Load);
4435ffd83dbSDimitry Andric Root->eraseFromParent();
4445ffd83dbSDimitry Andric if (Root != I)
4455ffd83dbSDimitry Andric // If this was an extending gather, we need to get rid of the sext/zext
4465ffd83dbSDimitry Andric // sext/zext as well as of the gather itself
447480093f4SDimitry Andric I->eraseFromParent();
4485ffd83dbSDimitry Andric
449fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"
450fe6060f1SDimitry Andric << *Load << "\n");
4515ffd83dbSDimitry Andric return Load;
452480093f4SDimitry Andric }
453480093f4SDimitry Andric
tryCreateMaskedGatherBase(IntrinsicInst * I,Value * Ptr,IRBuilder<> & Builder,int64_t Increment)454fe6060f1SDimitry Andric Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
455fe6060f1SDimitry Andric IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
456480093f4SDimitry Andric using namespace PatternMatch;
4575ffd83dbSDimitry Andric auto *Ty = cast<FixedVectorType>(I->getType());
458480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
4595ffd83dbSDimitry Andric if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
460480093f4SDimitry Andric // Can't build an intrinsic for this
461480093f4SDimitry Andric return nullptr;
462480093f4SDimitry Andric Value *Mask = I->getArgOperand(2);
463480093f4SDimitry Andric if (match(Mask, m_One()))
464480093f4SDimitry Andric return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
465480093f4SDimitry Andric {Ty, Ptr->getType()},
4665ffd83dbSDimitry Andric {Ptr, Builder.getInt32(Increment)});
467480093f4SDimitry Andric else
468480093f4SDimitry Andric return Builder.CreateIntrinsic(
469480093f4SDimitry Andric Intrinsic::arm_mve_vldr_gather_base_predicated,
470480093f4SDimitry Andric {Ty, Ptr->getType(), Mask->getType()},
4715ffd83dbSDimitry Andric {Ptr, Builder.getInt32(Increment), Mask});
4725ffd83dbSDimitry Andric }
4735ffd83dbSDimitry Andric
tryCreateMaskedGatherBaseWB(IntrinsicInst * I,Value * Ptr,IRBuilder<> & Builder,int64_t Increment)474fe6060f1SDimitry Andric Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
4755ffd83dbSDimitry Andric IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
4765ffd83dbSDimitry Andric using namespace PatternMatch;
4775ffd83dbSDimitry Andric auto *Ty = cast<FixedVectorType>(I->getType());
478fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers with "
479fe6060f1SDimitry Andric << "writeback\n");
4805ffd83dbSDimitry Andric if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
4815ffd83dbSDimitry Andric // Can't build an intrinsic for this
4825ffd83dbSDimitry Andric return nullptr;
4835ffd83dbSDimitry Andric Value *Mask = I->getArgOperand(2);
4845ffd83dbSDimitry Andric if (match(Mask, m_One()))
4855ffd83dbSDimitry Andric return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
4865ffd83dbSDimitry Andric {Ty, Ptr->getType()},
4875ffd83dbSDimitry Andric {Ptr, Builder.getInt32(Increment)});
4885ffd83dbSDimitry Andric else
4895ffd83dbSDimitry Andric return Builder.CreateIntrinsic(
4905ffd83dbSDimitry Andric Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
4915ffd83dbSDimitry Andric {Ty, Ptr->getType(), Mask->getType()},
4925ffd83dbSDimitry Andric {Ptr, Builder.getInt32(Increment), Mask});
493480093f4SDimitry Andric }
494480093f4SDimitry Andric
tryCreateMaskedGatherOffset(IntrinsicInst * I,Value * Ptr,Instruction * & Root,IRBuilder<> & Builder)495fe6060f1SDimitry Andric Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
4965ffd83dbSDimitry Andric IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
497480093f4SDimitry Andric using namespace PatternMatch;
4985ffd83dbSDimitry Andric
499fe6060f1SDimitry Andric Type *MemoryTy = I->getType();
500fe6060f1SDimitry Andric Type *ResultTy = MemoryTy;
5015ffd83dbSDimitry Andric
5025ffd83dbSDimitry Andric unsigned Unsigned = 1;
5035ffd83dbSDimitry Andric // The size of the gather was already checked in isLegalTypeAndAlignment;
5045ffd83dbSDimitry Andric // if it was not a full vector width an appropriate extend should follow.
5055ffd83dbSDimitry Andric auto *Extend = Root;
506fe6060f1SDimitry Andric bool TruncResult = false;
507fe6060f1SDimitry Andric if (MemoryTy->getPrimitiveSizeInBits() < 128) {
508fe6060f1SDimitry Andric if (I->hasOneUse()) {
509fe6060f1SDimitry Andric // If the gather has a single extend of the correct type, use an extending
510fe6060f1SDimitry Andric // gather and replace the ext. In which case the correct root to replace
511fe6060f1SDimitry Andric // is not the CallInst itself, but the instruction which extends it.
512fe6060f1SDimitry Andric Instruction* User = cast<Instruction>(*I->users().begin());
513fe6060f1SDimitry Andric if (isa<SExtInst>(User) &&
514fe6060f1SDimitry Andric User->getType()->getPrimitiveSizeInBits() == 128) {
515fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
516fe6060f1SDimitry Andric << *User << "\n");
517fe6060f1SDimitry Andric Extend = User;
518fe6060f1SDimitry Andric ResultTy = User->getType();
5195ffd83dbSDimitry Andric Unsigned = 0;
520fe6060f1SDimitry Andric } else if (isa<ZExtInst>(User) &&
521fe6060f1SDimitry Andric User->getType()->getPrimitiveSizeInBits() == 128) {
522fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
523fe6060f1SDimitry Andric << *ResultTy << "\n");
524fe6060f1SDimitry Andric Extend = User;
525fe6060f1SDimitry Andric ResultTy = User->getType();
526480093f4SDimitry Andric }
527fe6060f1SDimitry Andric }
528fe6060f1SDimitry Andric
529fe6060f1SDimitry Andric // If an extend hasn't been found and the type is an integer, create an
530fe6060f1SDimitry Andric // extending gather and truncate back to the original type.
531fe6060f1SDimitry Andric if (ResultTy->getPrimitiveSizeInBits() < 128 &&
532fe6060f1SDimitry Andric ResultTy->isIntOrIntVectorTy()) {
533fe6060f1SDimitry Andric ResultTy = ResultTy->getWithNewBitWidth(
534fe6060f1SDimitry Andric 128 / cast<FixedVectorType>(ResultTy)->getNumElements());
535fe6060f1SDimitry Andric TruncResult = true;
536fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: Small input type, truncing to: "
537fe6060f1SDimitry Andric << *ResultTy << "\n");
538fe6060f1SDimitry Andric }
539fe6060f1SDimitry Andric
5405ffd83dbSDimitry Andric // The final size of the gather must be a full vector width
5415ffd83dbSDimitry Andric if (ResultTy->getPrimitiveSizeInBits() != 128) {
542fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: Extend needed but not provided "
543fe6060f1SDimitry Andric "from the correct type. Expanding\n");
5445ffd83dbSDimitry Andric return nullptr;
5455ffd83dbSDimitry Andric }
5465ffd83dbSDimitry Andric }
5475ffd83dbSDimitry Andric
5485ffd83dbSDimitry Andric Value *Offsets;
549fe6060f1SDimitry Andric int Scale;
550fe6060f1SDimitry Andric Value *BasePtr = decomposePtr(
551fe6060f1SDimitry Andric Ptr, Offsets, Scale, cast<FixedVectorType>(ResultTy), MemoryTy, Builder);
5525ffd83dbSDimitry Andric if (!BasePtr)
5535ffd83dbSDimitry Andric return nullptr;
5545ffd83dbSDimitry Andric
5555ffd83dbSDimitry Andric Root = Extend;
556480093f4SDimitry Andric Value *Mask = I->getArgOperand(2);
557fe6060f1SDimitry Andric Instruction *Load = nullptr;
558480093f4SDimitry Andric if (!match(Mask, m_One()))
559fe6060f1SDimitry Andric Load = Builder.CreateIntrinsic(
560480093f4SDimitry Andric Intrinsic::arm_mve_vldr_gather_offset_predicated,
5615ffd83dbSDimitry Andric {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
562fe6060f1SDimitry Andric {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
5635ffd83dbSDimitry Andric Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
564480093f4SDimitry Andric else
565fe6060f1SDimitry Andric Load = Builder.CreateIntrinsic(
566480093f4SDimitry Andric Intrinsic::arm_mve_vldr_gather_offset,
5675ffd83dbSDimitry Andric {ResultTy, BasePtr->getType(), Offsets->getType()},
568fe6060f1SDimitry Andric {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
5695ffd83dbSDimitry Andric Builder.getInt32(Scale), Builder.getInt32(Unsigned)});
570fe6060f1SDimitry Andric
571fe6060f1SDimitry Andric if (TruncResult) {
572fe6060f1SDimitry Andric Load = TruncInst::Create(Instruction::Trunc, Load, MemoryTy);
573fe6060f1SDimitry Andric Builder.Insert(Load);
574fe6060f1SDimitry Andric }
575fe6060f1SDimitry Andric return Load;
5765ffd83dbSDimitry Andric }
5775ffd83dbSDimitry Andric
lowerScatter(IntrinsicInst * I)578fe6060f1SDimitry Andric Instruction *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
5795ffd83dbSDimitry Andric using namespace PatternMatch;
580fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"
581fe6060f1SDimitry Andric << *I << "\n");
5825ffd83dbSDimitry Andric
5835ffd83dbSDimitry Andric // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
5845ffd83dbSDimitry Andric // Attempt to turn the masked scatter in I into a MVE intrinsic
5855ffd83dbSDimitry Andric // Potentially optimising the addressing modes as we do so.
5865ffd83dbSDimitry Andric Value *Input = I->getArgOperand(0);
5875ffd83dbSDimitry Andric Value *Ptr = I->getArgOperand(1);
5885ffd83dbSDimitry Andric Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
5895ffd83dbSDimitry Andric auto *Ty = cast<FixedVectorType>(Input->getType());
5905ffd83dbSDimitry Andric
5915ffd83dbSDimitry Andric if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
5925ffd83dbSDimitry Andric Alignment))
5935ffd83dbSDimitry Andric return nullptr;
5945ffd83dbSDimitry Andric
5955ffd83dbSDimitry Andric lookThroughBitcast(Ptr);
5965ffd83dbSDimitry Andric assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
5975ffd83dbSDimitry Andric
5985ffd83dbSDimitry Andric IRBuilder<> Builder(I->getContext());
5995ffd83dbSDimitry Andric Builder.SetInsertPoint(I);
6005ffd83dbSDimitry Andric Builder.SetCurrentDebugLocation(I->getDebugLoc());
6015ffd83dbSDimitry Andric
602fe6060f1SDimitry Andric Instruction *Store = tryCreateIncrementingGatScat(I, Ptr, Builder);
603fe6060f1SDimitry Andric if (!Store)
604fe6060f1SDimitry Andric Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
6055ffd83dbSDimitry Andric if (!Store)
6065ffd83dbSDimitry Andric Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
6075ffd83dbSDimitry Andric if (!Store)
6085ffd83dbSDimitry Andric return nullptr;
6095ffd83dbSDimitry Andric
610fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n"
611fe6060f1SDimitry Andric << *Store << "\n");
6125ffd83dbSDimitry Andric I->eraseFromParent();
6135ffd83dbSDimitry Andric return Store;
6145ffd83dbSDimitry Andric }
6155ffd83dbSDimitry Andric
tryCreateMaskedScatterBase(IntrinsicInst * I,Value * Ptr,IRBuilder<> & Builder,int64_t Increment)616fe6060f1SDimitry Andric Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
6175ffd83dbSDimitry Andric IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
6185ffd83dbSDimitry Andric using namespace PatternMatch;
6195ffd83dbSDimitry Andric Value *Input = I->getArgOperand(0);
6205ffd83dbSDimitry Andric auto *Ty = cast<FixedVectorType>(Input->getType());
6215ffd83dbSDimitry Andric // Only QR variants allow truncating
6225ffd83dbSDimitry Andric if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
6235ffd83dbSDimitry Andric // Can't build an intrinsic for this
6245ffd83dbSDimitry Andric return nullptr;
6255ffd83dbSDimitry Andric }
6265ffd83dbSDimitry Andric Value *Mask = I->getArgOperand(3);
6275ffd83dbSDimitry Andric // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
6285ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
6295ffd83dbSDimitry Andric if (match(Mask, m_One()))
6305ffd83dbSDimitry Andric return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
6315ffd83dbSDimitry Andric {Ptr->getType(), Input->getType()},
6325ffd83dbSDimitry Andric {Ptr, Builder.getInt32(Increment), Input});
6335ffd83dbSDimitry Andric else
6345ffd83dbSDimitry Andric return Builder.CreateIntrinsic(
6355ffd83dbSDimitry Andric Intrinsic::arm_mve_vstr_scatter_base_predicated,
6365ffd83dbSDimitry Andric {Ptr->getType(), Input->getType(), Mask->getType()},
6375ffd83dbSDimitry Andric {Ptr, Builder.getInt32(Increment), Input, Mask});
6385ffd83dbSDimitry Andric }
6395ffd83dbSDimitry Andric
tryCreateMaskedScatterBaseWB(IntrinsicInst * I,Value * Ptr,IRBuilder<> & Builder,int64_t Increment)640fe6060f1SDimitry Andric Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
6415ffd83dbSDimitry Andric IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
6425ffd83dbSDimitry Andric using namespace PatternMatch;
6435ffd83dbSDimitry Andric Value *Input = I->getArgOperand(0);
6445ffd83dbSDimitry Andric auto *Ty = cast<FixedVectorType>(Input->getType());
645fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers "
646fe6060f1SDimitry Andric << "with writeback\n");
6475ffd83dbSDimitry Andric if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
6485ffd83dbSDimitry Andric // Can't build an intrinsic for this
6495ffd83dbSDimitry Andric return nullptr;
6505ffd83dbSDimitry Andric Value *Mask = I->getArgOperand(3);
6515ffd83dbSDimitry Andric if (match(Mask, m_One()))
6525ffd83dbSDimitry Andric return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
6535ffd83dbSDimitry Andric {Ptr->getType(), Input->getType()},
6545ffd83dbSDimitry Andric {Ptr, Builder.getInt32(Increment), Input});
6555ffd83dbSDimitry Andric else
6565ffd83dbSDimitry Andric return Builder.CreateIntrinsic(
6575ffd83dbSDimitry Andric Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
6585ffd83dbSDimitry Andric {Ptr->getType(), Input->getType(), Mask->getType()},
6595ffd83dbSDimitry Andric {Ptr, Builder.getInt32(Increment), Input, Mask});
6605ffd83dbSDimitry Andric }
6615ffd83dbSDimitry Andric
tryCreateMaskedScatterOffset(IntrinsicInst * I,Value * Ptr,IRBuilder<> & Builder)662fe6060f1SDimitry Andric Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
6635ffd83dbSDimitry Andric IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
6645ffd83dbSDimitry Andric using namespace PatternMatch;
6655ffd83dbSDimitry Andric Value *Input = I->getArgOperand(0);
6665ffd83dbSDimitry Andric Value *Mask = I->getArgOperand(3);
6675ffd83dbSDimitry Andric Type *InputTy = Input->getType();
6685ffd83dbSDimitry Andric Type *MemoryTy = InputTy;
669fe6060f1SDimitry Andric
6705ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
6715ffd83dbSDimitry Andric << " to base + vector of offsets\n");
6725ffd83dbSDimitry Andric // If the input has been truncated, try to integrate that trunc into the
6735ffd83dbSDimitry Andric // scatter instruction (we don't care about alignment here)
6745ffd83dbSDimitry Andric if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
6755ffd83dbSDimitry Andric Value *PreTrunc = Trunc->getOperand(0);
6765ffd83dbSDimitry Andric Type *PreTruncTy = PreTrunc->getType();
6775ffd83dbSDimitry Andric if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
6785ffd83dbSDimitry Andric Input = PreTrunc;
6795ffd83dbSDimitry Andric InputTy = PreTruncTy;
6805ffd83dbSDimitry Andric }
6815ffd83dbSDimitry Andric }
682fe6060f1SDimitry Andric bool ExtendInput = false;
683fe6060f1SDimitry Andric if (InputTy->getPrimitiveSizeInBits() < 128 &&
684fe6060f1SDimitry Andric InputTy->isIntOrIntVectorTy()) {
685fe6060f1SDimitry Andric // If we can't find a trunc to incorporate into the instruction, create an
686fe6060f1SDimitry Andric // implicit one with a zext, so that we can still create a scatter. We know
687fe6060f1SDimitry Andric // that the input type is 4x/8x/16x and of type i8/i16/i32, so any type
688fe6060f1SDimitry Andric // smaller than 128 bits will divide evenly into a 128bit vector.
689fe6060f1SDimitry Andric InputTy = InputTy->getWithNewBitWidth(
690fe6060f1SDimitry Andric 128 / cast<FixedVectorType>(InputTy)->getNumElements());
691fe6060f1SDimitry Andric ExtendInput = true;
692fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "masked scatters: Small input type, will extend:\n"
693fe6060f1SDimitry Andric << *Input << "\n");
694fe6060f1SDimitry Andric }
6955ffd83dbSDimitry Andric if (InputTy->getPrimitiveSizeInBits() != 128) {
696fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "masked scatters: cannot create scatters for "
697fe6060f1SDimitry Andric "non-standard input types. Expanding.\n");
6985ffd83dbSDimitry Andric return nullptr;
6995ffd83dbSDimitry Andric }
7005ffd83dbSDimitry Andric
7015ffd83dbSDimitry Andric Value *Offsets;
702fe6060f1SDimitry Andric int Scale;
703fe6060f1SDimitry Andric Value *BasePtr = decomposePtr(
704fe6060f1SDimitry Andric Ptr, Offsets, Scale, cast<FixedVectorType>(InputTy), MemoryTy, Builder);
7055ffd83dbSDimitry Andric if (!BasePtr)
7065ffd83dbSDimitry Andric return nullptr;
7075ffd83dbSDimitry Andric
708fe6060f1SDimitry Andric if (ExtendInput)
709fe6060f1SDimitry Andric Input = Builder.CreateZExt(Input, InputTy);
7105ffd83dbSDimitry Andric if (!match(Mask, m_One()))
7115ffd83dbSDimitry Andric return Builder.CreateIntrinsic(
7125ffd83dbSDimitry Andric Intrinsic::arm_mve_vstr_scatter_offset_predicated,
7135ffd83dbSDimitry Andric {BasePtr->getType(), Offsets->getType(), Input->getType(),
7145ffd83dbSDimitry Andric Mask->getType()},
7155ffd83dbSDimitry Andric {BasePtr, Offsets, Input,
7165ffd83dbSDimitry Andric Builder.getInt32(MemoryTy->getScalarSizeInBits()),
7175ffd83dbSDimitry Andric Builder.getInt32(Scale), Mask});
7185ffd83dbSDimitry Andric else
7195ffd83dbSDimitry Andric return Builder.CreateIntrinsic(
7205ffd83dbSDimitry Andric Intrinsic::arm_mve_vstr_scatter_offset,
7215ffd83dbSDimitry Andric {BasePtr->getType(), Offsets->getType(), Input->getType()},
7225ffd83dbSDimitry Andric {BasePtr, Offsets, Input,
7235ffd83dbSDimitry Andric Builder.getInt32(MemoryTy->getScalarSizeInBits()),
7245ffd83dbSDimitry Andric Builder.getInt32(Scale)});
7255ffd83dbSDimitry Andric }
7265ffd83dbSDimitry Andric
tryCreateIncrementingGatScat(IntrinsicInst * I,Value * Ptr,IRBuilder<> & Builder)727fe6060f1SDimitry Andric Instruction *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
728fe6060f1SDimitry Andric IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
7295ffd83dbSDimitry Andric FixedVectorType *Ty;
7305ffd83dbSDimitry Andric if (I->getIntrinsicID() == Intrinsic::masked_gather)
7315ffd83dbSDimitry Andric Ty = cast<FixedVectorType>(I->getType());
7325ffd83dbSDimitry Andric else
7335ffd83dbSDimitry Andric Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());
734fe6060f1SDimitry Andric
7355ffd83dbSDimitry Andric // Incrementing gathers only exist for v4i32
736fe6060f1SDimitry Andric if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
7375ffd83dbSDimitry Andric return nullptr;
738fe6060f1SDimitry Andric // Incrementing gathers are not beneficial outside of a loop
7395ffd83dbSDimitry Andric Loop *L = LI->getLoopFor(I->getParent());
7405ffd83dbSDimitry Andric if (L == nullptr)
7415ffd83dbSDimitry Andric return nullptr;
742fe6060f1SDimitry Andric
743fe6060f1SDimitry Andric // Decompose the GEP into Base and Offsets
744fe6060f1SDimitry Andric GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
745fe6060f1SDimitry Andric Value *Offsets;
746fe6060f1SDimitry Andric Value *BasePtr = decomposeGEP(Offsets, Ty, GEP, Builder);
747fe6060f1SDimitry Andric if (!BasePtr)
748fe6060f1SDimitry Andric return nullptr;
749fe6060f1SDimitry Andric
7505ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
7515ffd83dbSDimitry Andric "wb gather/scatter\n");
7525ffd83dbSDimitry Andric
7535ffd83dbSDimitry Andric // The gep was in charge of making sure the offsets are scaled correctly
7545ffd83dbSDimitry Andric // - calculate that factor so it can be applied by hand
7555ffd83dbSDimitry Andric int TypeScale =
756349cc55cSDimitry Andric computeScale(DL->getTypeSizeInBits(GEP->getOperand(0)->getType()),
757349cc55cSDimitry Andric DL->getTypeSizeInBits(GEP->getType()) /
7585ffd83dbSDimitry Andric cast<FixedVectorType>(GEP->getType())->getNumElements());
7595ffd83dbSDimitry Andric if (TypeScale == -1)
7605ffd83dbSDimitry Andric return nullptr;
7615ffd83dbSDimitry Andric
7625ffd83dbSDimitry Andric if (GEP->hasOneUse()) {
7635ffd83dbSDimitry Andric // Only in this case do we want to build a wb gather, because the wb will
7645ffd83dbSDimitry Andric // change the phi which does affect other users of the gep (which will still
7655ffd83dbSDimitry Andric // be using the phi in the old way)
766fe6060f1SDimitry Andric if (auto *Load = tryCreateIncrementingWBGatScat(I, BasePtr, Offsets,
767fe6060f1SDimitry Andric TypeScale, Builder))
7685ffd83dbSDimitry Andric return Load;
7695ffd83dbSDimitry Andric }
770fe6060f1SDimitry Andric
7715ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
7725ffd83dbSDimitry Andric "non-wb gather/scatter\n");
7735ffd83dbSDimitry Andric
7745ffd83dbSDimitry Andric std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
7755ffd83dbSDimitry Andric if (Add.first == nullptr)
7765ffd83dbSDimitry Andric return nullptr;
7775ffd83dbSDimitry Andric Value *OffsetsIncoming = Add.first;
7785ffd83dbSDimitry Andric int64_t Immediate = Add.second;
7795ffd83dbSDimitry Andric
7805ffd83dbSDimitry Andric // Make sure the offsets are scaled correctly
7815ffd83dbSDimitry Andric Instruction *ScaledOffsets = BinaryOperator::Create(
7825ffd83dbSDimitry Andric Instruction::Shl, OffsetsIncoming,
783*0fca6ea1SDimitry Andric Builder.CreateVectorSplat(Ty->getNumElements(),
784*0fca6ea1SDimitry Andric Builder.getInt32(TypeScale)),
785*0fca6ea1SDimitry Andric "ScaledIndex", I->getIterator());
7865ffd83dbSDimitry Andric // Add the base to the offsets
7875ffd83dbSDimitry Andric OffsetsIncoming = BinaryOperator::Create(
7885ffd83dbSDimitry Andric Instruction::Add, ScaledOffsets,
7895ffd83dbSDimitry Andric Builder.CreateVectorSplat(
7905ffd83dbSDimitry Andric Ty->getNumElements(),
7915ffd83dbSDimitry Andric Builder.CreatePtrToInt(
7925ffd83dbSDimitry Andric BasePtr,
7935ffd83dbSDimitry Andric cast<VectorType>(ScaledOffsets->getType())->getElementType())),
794*0fca6ea1SDimitry Andric "StartIndex", I->getIterator());
7955ffd83dbSDimitry Andric
7965ffd83dbSDimitry Andric if (I->getIntrinsicID() == Intrinsic::masked_gather)
797fe6060f1SDimitry Andric return tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate);
7985ffd83dbSDimitry Andric else
799fe6060f1SDimitry Andric return tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate);
8005ffd83dbSDimitry Andric }
8015ffd83dbSDimitry Andric
tryCreateIncrementingWBGatScat(IntrinsicInst * I,Value * BasePtr,Value * Offsets,unsigned TypeScale,IRBuilder<> & Builder)802fe6060f1SDimitry Andric Instruction *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
8035ffd83dbSDimitry Andric IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
8045ffd83dbSDimitry Andric IRBuilder<> &Builder) {
8055ffd83dbSDimitry Andric // Check whether this gather's offset is incremented by a constant - if so,
8065ffd83dbSDimitry Andric // and the load is of the right type, we can merge this into a QI gather
8075ffd83dbSDimitry Andric Loop *L = LI->getLoopFor(I->getParent());
8085ffd83dbSDimitry Andric // Offsets that are worth merging into this instruction will be incremented
8095ffd83dbSDimitry Andric // by a constant, thus we're looking for an add of a phi and a constant
8105ffd83dbSDimitry Andric PHINode *Phi = dyn_cast<PHINode>(Offsets);
8115ffd83dbSDimitry Andric if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
8125ffd83dbSDimitry Andric Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
8135ffd83dbSDimitry Andric // No phi means no IV to write back to; if there is a phi, we expect it
8145ffd83dbSDimitry Andric // to have exactly two incoming values; the only phis we are interested in
8155ffd83dbSDimitry Andric // will be loop IV's and have exactly two uses, one in their increment and
8165ffd83dbSDimitry Andric // one in the gather's gep
8175ffd83dbSDimitry Andric return nullptr;
8185ffd83dbSDimitry Andric
8195ffd83dbSDimitry Andric unsigned IncrementIndex =
8205ffd83dbSDimitry Andric Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
8215ffd83dbSDimitry Andric // Look through the phi to the phi increment
8225ffd83dbSDimitry Andric Offsets = Phi->getIncomingValue(IncrementIndex);
8235ffd83dbSDimitry Andric
8245ffd83dbSDimitry Andric std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
8255ffd83dbSDimitry Andric if (Add.first == nullptr)
8265ffd83dbSDimitry Andric return nullptr;
8275ffd83dbSDimitry Andric Value *OffsetsIncoming = Add.first;
8285ffd83dbSDimitry Andric int64_t Immediate = Add.second;
8295ffd83dbSDimitry Andric if (OffsetsIncoming != Phi)
8305ffd83dbSDimitry Andric // Then the increment we are looking at is not an increment of the
8315ffd83dbSDimitry Andric // induction variable, and we don't want to do a writeback
8325ffd83dbSDimitry Andric return nullptr;
8335ffd83dbSDimitry Andric
8345ffd83dbSDimitry Andric Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
8355ffd83dbSDimitry Andric unsigned NumElems =
8365ffd83dbSDimitry Andric cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();
8375ffd83dbSDimitry Andric
8385ffd83dbSDimitry Andric // Make sure the offsets are scaled correctly
8395ffd83dbSDimitry Andric Instruction *ScaledOffsets = BinaryOperator::Create(
8405ffd83dbSDimitry Andric Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
8415ffd83dbSDimitry Andric Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
842*0fca6ea1SDimitry Andric "ScaledIndex",
843*0fca6ea1SDimitry Andric Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator());
8445ffd83dbSDimitry Andric // Add the base to the offsets
8455ffd83dbSDimitry Andric OffsetsIncoming = BinaryOperator::Create(
8465ffd83dbSDimitry Andric Instruction::Add, ScaledOffsets,
8475ffd83dbSDimitry Andric Builder.CreateVectorSplat(
8485ffd83dbSDimitry Andric NumElems,
8495ffd83dbSDimitry Andric Builder.CreatePtrToInt(
8505ffd83dbSDimitry Andric BasePtr,
8515ffd83dbSDimitry Andric cast<VectorType>(ScaledOffsets->getType())->getElementType())),
852*0fca6ea1SDimitry Andric "StartIndex",
853*0fca6ea1SDimitry Andric Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator());
8545ffd83dbSDimitry Andric // The gather is pre-incrementing
8555ffd83dbSDimitry Andric OffsetsIncoming = BinaryOperator::Create(
8565ffd83dbSDimitry Andric Instruction::Sub, OffsetsIncoming,
8575ffd83dbSDimitry Andric Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
8585ffd83dbSDimitry Andric "PreIncrementStartIndex",
859*0fca6ea1SDimitry Andric Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator());
8605ffd83dbSDimitry Andric Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);
8615ffd83dbSDimitry Andric
8625ffd83dbSDimitry Andric Builder.SetInsertPoint(I);
8635ffd83dbSDimitry Andric
864fe6060f1SDimitry Andric Instruction *EndResult;
865fe6060f1SDimitry Andric Instruction *NewInduction;
8665ffd83dbSDimitry Andric if (I->getIntrinsicID() == Intrinsic::masked_gather) {
8675ffd83dbSDimitry Andric // Build the incrementing gather
8685ffd83dbSDimitry Andric Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
8695ffd83dbSDimitry Andric // One value to be handed to whoever uses the gather, one is the loop
8705ffd83dbSDimitry Andric // increment
871fe6060f1SDimitry Andric EndResult = ExtractValueInst::Create(Load, 0, "Gather");
872fe6060f1SDimitry Andric NewInduction = ExtractValueInst::Create(Load, 1, "GatherIncrement");
873fe6060f1SDimitry Andric Builder.Insert(EndResult);
874fe6060f1SDimitry Andric Builder.Insert(NewInduction);
8755ffd83dbSDimitry Andric } else {
8765ffd83dbSDimitry Andric // Build the incrementing scatter
877fe6060f1SDimitry Andric EndResult = NewInduction =
878fe6060f1SDimitry Andric tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
8795ffd83dbSDimitry Andric }
8805ffd83dbSDimitry Andric Instruction *AddInst = cast<Instruction>(Offsets);
8815ffd83dbSDimitry Andric AddInst->replaceAllUsesWith(NewInduction);
8825ffd83dbSDimitry Andric AddInst->eraseFromParent();
8835ffd83dbSDimitry Andric Phi->setIncomingValue(IncrementIndex, NewInduction);
8845ffd83dbSDimitry Andric
8855ffd83dbSDimitry Andric return EndResult;
8865ffd83dbSDimitry Andric }
8875ffd83dbSDimitry Andric
pushOutAdd(PHINode * & Phi,Value * OffsSecondOperand,unsigned StartIndex)8885ffd83dbSDimitry Andric void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
8895ffd83dbSDimitry Andric Value *OffsSecondOperand,
8905ffd83dbSDimitry Andric unsigned StartIndex) {
8915ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
892*0fca6ea1SDimitry Andric BasicBlock::iterator InsertionPoint =
893*0fca6ea1SDimitry Andric Phi->getIncomingBlock(StartIndex)->back().getIterator();
8945ffd83dbSDimitry Andric // Initialize the phi with a vector that contains a sum of the constants
8955ffd83dbSDimitry Andric Instruction *NewIndex = BinaryOperator::Create(
8965ffd83dbSDimitry Andric Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
8975ffd83dbSDimitry Andric "PushedOutAdd", InsertionPoint);
8985ffd83dbSDimitry Andric unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;
8995ffd83dbSDimitry Andric
9005ffd83dbSDimitry Andric // Order such that start index comes first (this reduces mov's)
9015ffd83dbSDimitry Andric Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
9025ffd83dbSDimitry Andric Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
9035ffd83dbSDimitry Andric Phi->getIncomingBlock(IncrementIndex));
9047a6dacacSDimitry Andric Phi->removeIncomingValue(1);
9057a6dacacSDimitry Andric Phi->removeIncomingValue((unsigned)0);
9065ffd83dbSDimitry Andric }
9075ffd83dbSDimitry Andric
pushOutMulShl(unsigned Opcode,PHINode * & Phi,Value * IncrementPerRound,Value * OffsSecondOperand,unsigned LoopIncrement,IRBuilder<> & Builder)908349cc55cSDimitry Andric void MVEGatherScatterLowering::pushOutMulShl(unsigned Opcode, PHINode *&Phi,
9095ffd83dbSDimitry Andric Value *IncrementPerRound,
9105ffd83dbSDimitry Andric Value *OffsSecondOperand,
9115ffd83dbSDimitry Andric unsigned LoopIncrement,
9125ffd83dbSDimitry Andric IRBuilder<> &Builder) {
9135ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");
9145ffd83dbSDimitry Andric
9155ffd83dbSDimitry Andric // Create a new scalar add outside of the loop and transform it to a splat
9165ffd83dbSDimitry Andric // by which loop variable can be incremented
917*0fca6ea1SDimitry Andric BasicBlock::iterator InsertionPoint =
918*0fca6ea1SDimitry Andric Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back().getIterator();
9195ffd83dbSDimitry Andric
9205ffd83dbSDimitry Andric // Create a new index
921349cc55cSDimitry Andric Value *StartIndex =
922349cc55cSDimitry Andric BinaryOperator::Create((Instruction::BinaryOps)Opcode,
923349cc55cSDimitry Andric Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
9245ffd83dbSDimitry Andric OffsSecondOperand, "PushedOutMul", InsertionPoint);
9255ffd83dbSDimitry Andric
9265ffd83dbSDimitry Andric Instruction *Product =
927349cc55cSDimitry Andric BinaryOperator::Create((Instruction::BinaryOps)Opcode, IncrementPerRound,
9285ffd83dbSDimitry Andric OffsSecondOperand, "Product", InsertionPoint);
929*0fca6ea1SDimitry Andric
930*0fca6ea1SDimitry Andric BasicBlock::iterator NewIncrInsertPt =
931*0fca6ea1SDimitry Andric Phi->getIncomingBlock(LoopIncrement)->back().getIterator();
932*0fca6ea1SDimitry Andric NewIncrInsertPt = std::prev(NewIncrInsertPt);
933*0fca6ea1SDimitry Andric
9345ffd83dbSDimitry Andric // Increment NewIndex by Product instead of the multiplication
9355ffd83dbSDimitry Andric Instruction *NewIncrement = BinaryOperator::Create(
936*0fca6ea1SDimitry Andric Instruction::Add, Phi, Product, "IncrementPushedOutMul", NewIncrInsertPt);
9375ffd83dbSDimitry Andric
9385ffd83dbSDimitry Andric Phi->addIncoming(StartIndex,
9395ffd83dbSDimitry Andric Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
9405ffd83dbSDimitry Andric Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
9415ffd83dbSDimitry Andric Phi->removeIncomingValue((unsigned)0);
9425ffd83dbSDimitry Andric Phi->removeIncomingValue((unsigned)0);
9435ffd83dbSDimitry Andric }
9445ffd83dbSDimitry Andric
9455ffd83dbSDimitry Andric // Check whether all usages of this instruction are as offsets of
9465ffd83dbSDimitry Andric // gathers/scatters or simple arithmetics only used by gathers/scatters
hasAllGatScatUsers(Instruction * I,const DataLayout & DL)947349cc55cSDimitry Andric static bool hasAllGatScatUsers(Instruction *I, const DataLayout &DL) {
9485ffd83dbSDimitry Andric if (I->hasNUses(0)) {
9495ffd83dbSDimitry Andric return false;
9505ffd83dbSDimitry Andric }
9515ffd83dbSDimitry Andric bool Gatscat = true;
9525ffd83dbSDimitry Andric for (User *U : I->users()) {
9535ffd83dbSDimitry Andric if (!isa<Instruction>(U))
9545ffd83dbSDimitry Andric return false;
9555ffd83dbSDimitry Andric if (isa<GetElementPtrInst>(U) ||
9565ffd83dbSDimitry Andric isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
9575ffd83dbSDimitry Andric return Gatscat;
9585ffd83dbSDimitry Andric } else {
9595ffd83dbSDimitry Andric unsigned OpCode = cast<Instruction>(U)->getOpcode();
960349cc55cSDimitry Andric if ((OpCode == Instruction::Add || OpCode == Instruction::Mul ||
961349cc55cSDimitry Andric OpCode == Instruction::Shl ||
962349cc55cSDimitry Andric isAddLikeOr(cast<Instruction>(U), DL)) &&
963349cc55cSDimitry Andric hasAllGatScatUsers(cast<Instruction>(U), DL)) {
9645ffd83dbSDimitry Andric continue;
9655ffd83dbSDimitry Andric }
9665ffd83dbSDimitry Andric return false;
9675ffd83dbSDimitry Andric }
9685ffd83dbSDimitry Andric }
9695ffd83dbSDimitry Andric return Gatscat;
9705ffd83dbSDimitry Andric }
9715ffd83dbSDimitry Andric
optimiseOffsets(Value * Offsets,BasicBlock * BB,LoopInfo * LI)9725ffd83dbSDimitry Andric bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
9735ffd83dbSDimitry Andric LoopInfo *LI) {
97481ad6265SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize: "
975fe6060f1SDimitry Andric << *Offsets << "\n");
9765ffd83dbSDimitry Andric // Optimise the addresses of gathers/scatters by moving invariant
9775ffd83dbSDimitry Andric // calculations out of the loop
9785ffd83dbSDimitry Andric if (!isa<Instruction>(Offsets))
9795ffd83dbSDimitry Andric return false;
9805ffd83dbSDimitry Andric Instruction *Offs = cast<Instruction>(Offsets);
981349cc55cSDimitry Andric if (Offs->getOpcode() != Instruction::Add && !isAddLikeOr(Offs, *DL) &&
982349cc55cSDimitry Andric Offs->getOpcode() != Instruction::Mul &&
983349cc55cSDimitry Andric Offs->getOpcode() != Instruction::Shl)
9845ffd83dbSDimitry Andric return false;
9855ffd83dbSDimitry Andric Loop *L = LI->getLoopFor(BB);
9865ffd83dbSDimitry Andric if (L == nullptr)
9875ffd83dbSDimitry Andric return false;
9885ffd83dbSDimitry Andric if (!Offs->hasOneUse()) {
989349cc55cSDimitry Andric if (!hasAllGatScatUsers(Offs, *DL))
9905ffd83dbSDimitry Andric return false;
9915ffd83dbSDimitry Andric }
9925ffd83dbSDimitry Andric
9935ffd83dbSDimitry Andric // Find out which, if any, operand of the instruction
9945ffd83dbSDimitry Andric // is a phi node
9955ffd83dbSDimitry Andric PHINode *Phi;
9965ffd83dbSDimitry Andric int OffsSecondOp;
9975ffd83dbSDimitry Andric if (isa<PHINode>(Offs->getOperand(0))) {
9985ffd83dbSDimitry Andric Phi = cast<PHINode>(Offs->getOperand(0));
9995ffd83dbSDimitry Andric OffsSecondOp = 1;
10005ffd83dbSDimitry Andric } else if (isa<PHINode>(Offs->getOperand(1))) {
10015ffd83dbSDimitry Andric Phi = cast<PHINode>(Offs->getOperand(1));
10025ffd83dbSDimitry Andric OffsSecondOp = 0;
10035ffd83dbSDimitry Andric } else {
1004fe6060f1SDimitry Andric bool Changed = false;
10055ffd83dbSDimitry Andric if (isa<Instruction>(Offs->getOperand(0)) &&
10065ffd83dbSDimitry Andric L->contains(cast<Instruction>(Offs->getOperand(0))))
10075ffd83dbSDimitry Andric Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
10085ffd83dbSDimitry Andric if (isa<Instruction>(Offs->getOperand(1)) &&
10095ffd83dbSDimitry Andric L->contains(cast<Instruction>(Offs->getOperand(1))))
10105ffd83dbSDimitry Andric Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
1011fe6060f1SDimitry Andric if (!Changed)
10125ffd83dbSDimitry Andric return false;
10135ffd83dbSDimitry Andric if (isa<PHINode>(Offs->getOperand(0))) {
10145ffd83dbSDimitry Andric Phi = cast<PHINode>(Offs->getOperand(0));
10155ffd83dbSDimitry Andric OffsSecondOp = 1;
10165ffd83dbSDimitry Andric } else if (isa<PHINode>(Offs->getOperand(1))) {
10175ffd83dbSDimitry Andric Phi = cast<PHINode>(Offs->getOperand(1));
10185ffd83dbSDimitry Andric OffsSecondOp = 0;
10195ffd83dbSDimitry Andric } else {
10205ffd83dbSDimitry Andric return false;
10215ffd83dbSDimitry Andric }
10225ffd83dbSDimitry Andric }
10235ffd83dbSDimitry Andric // A phi node we want to perform this function on should be from the
1024fe6060f1SDimitry Andric // loop header.
1025fe6060f1SDimitry Andric if (Phi->getParent() != L->getHeader())
10265ffd83dbSDimitry Andric return false;
10275ffd83dbSDimitry Andric
1028fe6060f1SDimitry Andric // We're looking for a simple add recurrence.
1029fe6060f1SDimitry Andric BinaryOperator *IncInstruction;
1030fe6060f1SDimitry Andric Value *Start, *IncrementPerRound;
1031fe6060f1SDimitry Andric if (!matchSimpleRecurrence(Phi, IncInstruction, Start, IncrementPerRound) ||
1032fe6060f1SDimitry Andric IncInstruction->getOpcode() != Instruction::Add)
10335ffd83dbSDimitry Andric return false;
10345ffd83dbSDimitry Andric
1035fe6060f1SDimitry Andric int IncrementingBlock = Phi->getIncomingValue(0) == IncInstruction ? 0 : 1;
10365ffd83dbSDimitry Andric
10375ffd83dbSDimitry Andric // Get the value that is added to/multiplied with the phi
10385ffd83dbSDimitry Andric Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);
10395ffd83dbSDimitry Andric
10404652422eSDimitry Andric if (IncrementPerRound->getType() != OffsSecondOperand->getType() ||
10414652422eSDimitry Andric !L->isLoopInvariant(OffsSecondOperand))
10425ffd83dbSDimitry Andric // Something has gone wrong, abort
10435ffd83dbSDimitry Andric return false;
10445ffd83dbSDimitry Andric
10455ffd83dbSDimitry Andric // Only proceed if the increment per round is a constant or an instruction
10465ffd83dbSDimitry Andric // which does not originate from within the loop
10475ffd83dbSDimitry Andric if (!isa<Constant>(IncrementPerRound) &&
10485ffd83dbSDimitry Andric !(isa<Instruction>(IncrementPerRound) &&
10495ffd83dbSDimitry Andric !L->contains(cast<Instruction>(IncrementPerRound))))
10505ffd83dbSDimitry Andric return false;
10515ffd83dbSDimitry Andric
1052fe6060f1SDimitry Andric // If the phi is not used by anything else, we can just adapt it when
1053fe6060f1SDimitry Andric // replacing the instruction; if it is, we'll have to duplicate it
1054fe6060f1SDimitry Andric PHINode *NewPhi;
10555ffd83dbSDimitry Andric if (Phi->getNumUses() == 2) {
10565ffd83dbSDimitry Andric // No other users -> reuse existing phi (One user is the instruction
10575ffd83dbSDimitry Andric // we're looking at, the other is the phi increment)
10585ffd83dbSDimitry Andric if (IncInstruction->getNumUses() != 1) {
10595ffd83dbSDimitry Andric // If the incrementing instruction does have more users than
10605ffd83dbSDimitry Andric // our phi, we need to copy it
10615ffd83dbSDimitry Andric IncInstruction = BinaryOperator::Create(
10625ffd83dbSDimitry Andric Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
1063*0fca6ea1SDimitry Andric IncrementPerRound, "LoopIncrement", IncInstruction->getIterator());
10645ffd83dbSDimitry Andric Phi->setIncomingValue(IncrementingBlock, IncInstruction);
10655ffd83dbSDimitry Andric }
10665ffd83dbSDimitry Andric NewPhi = Phi;
10675ffd83dbSDimitry Andric } else {
10685ffd83dbSDimitry Andric // There are other users -> create a new phi
1069*0fca6ea1SDimitry Andric NewPhi = PHINode::Create(Phi->getType(), 2, "NewPhi", Phi->getIterator());
10705ffd83dbSDimitry Andric // Copy the incoming values of the old phi
10715ffd83dbSDimitry Andric NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
10725ffd83dbSDimitry Andric Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
10735ffd83dbSDimitry Andric IncInstruction = BinaryOperator::Create(
10745ffd83dbSDimitry Andric Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
1075*0fca6ea1SDimitry Andric IncrementPerRound, "LoopIncrement", IncInstruction->getIterator());
10765ffd83dbSDimitry Andric NewPhi->addIncoming(IncInstruction,
10775ffd83dbSDimitry Andric Phi->getIncomingBlock(IncrementingBlock));
10785ffd83dbSDimitry Andric IncrementingBlock = 1;
10795ffd83dbSDimitry Andric }
10805ffd83dbSDimitry Andric
10815ffd83dbSDimitry Andric IRBuilder<> Builder(BB->getContext());
10825ffd83dbSDimitry Andric Builder.SetInsertPoint(Phi);
10835ffd83dbSDimitry Andric Builder.SetCurrentDebugLocation(Offs->getDebugLoc());
10845ffd83dbSDimitry Andric
10855ffd83dbSDimitry Andric switch (Offs->getOpcode()) {
10865ffd83dbSDimitry Andric case Instruction::Add:
1087349cc55cSDimitry Andric case Instruction::Or:
10885ffd83dbSDimitry Andric pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
10895ffd83dbSDimitry Andric break;
10905ffd83dbSDimitry Andric case Instruction::Mul:
1091349cc55cSDimitry Andric case Instruction::Shl:
1092349cc55cSDimitry Andric pushOutMulShl(Offs->getOpcode(), NewPhi, IncrementPerRound,
1093349cc55cSDimitry Andric OffsSecondOperand, IncrementingBlock, Builder);
10945ffd83dbSDimitry Andric break;
10955ffd83dbSDimitry Andric default:
10965ffd83dbSDimitry Andric return false;
10975ffd83dbSDimitry Andric }
1098fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: simplified loop variable "
1099fe6060f1SDimitry Andric << "add/mul\n");
11005ffd83dbSDimitry Andric
11015ffd83dbSDimitry Andric // The instruction has now been "absorbed" into the phi value
11025ffd83dbSDimitry Andric Offs->replaceAllUsesWith(NewPhi);
11035ffd83dbSDimitry Andric if (Offs->hasNUses(0))
11045ffd83dbSDimitry Andric Offs->eraseFromParent();
11055ffd83dbSDimitry Andric // Clean up the old increment in case it's unused because we built a new
11065ffd83dbSDimitry Andric // one
11075ffd83dbSDimitry Andric if (IncInstruction->hasNUses(0))
11085ffd83dbSDimitry Andric IncInstruction->eraseFromParent();
11095ffd83dbSDimitry Andric
11105ffd83dbSDimitry Andric return true;
1111480093f4SDimitry Andric }
1112480093f4SDimitry Andric
CheckAndCreateOffsetAdd(Value * X,unsigned ScaleX,Value * Y,unsigned ScaleY,IRBuilder<> & Builder)111381ad6265SDimitry Andric static Value *CheckAndCreateOffsetAdd(Value *X, unsigned ScaleX, Value *Y,
111481ad6265SDimitry Andric unsigned ScaleY, IRBuilder<> &Builder) {
1115e8d8bef9SDimitry Andric // Splat the non-vector value to a vector of the given type - if the value is
1116e8d8bef9SDimitry Andric // a constant (and its value isn't too big), we can even use this opportunity
1117e8d8bef9SDimitry Andric // to scale it to the size of the vector elements
1118e8d8bef9SDimitry Andric auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) {
1119e8d8bef9SDimitry Andric ConstantInt *Const;
1120e8d8bef9SDimitry Andric if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
1121e8d8bef9SDimitry Andric VT->getElementType() != NonVectorVal->getType()) {
1122e8d8bef9SDimitry Andric unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits();
1123e8d8bef9SDimitry Andric uint64_t N = Const->getZExtValue();
1124e8d8bef9SDimitry Andric if (N < (unsigned)(1 << (TargetElemSize - 1))) {
1125e8d8bef9SDimitry Andric NonVectorVal = Builder.CreateVectorSplat(
1126e8d8bef9SDimitry Andric VT->getNumElements(), Builder.getIntN(TargetElemSize, N));
1127e8d8bef9SDimitry Andric return;
1128e8d8bef9SDimitry Andric }
1129e8d8bef9SDimitry Andric }
1130e8d8bef9SDimitry Andric NonVectorVal =
1131e8d8bef9SDimitry Andric Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal);
1132e8d8bef9SDimitry Andric };
1133e8d8bef9SDimitry Andric
1134e8d8bef9SDimitry Andric FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType());
1135e8d8bef9SDimitry Andric FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType());
1136e8d8bef9SDimitry Andric // If one of X, Y is not a vector, we have to splat it in order
1137e8d8bef9SDimitry Andric // to add the two of them.
1138e8d8bef9SDimitry Andric if (XElType && !YElType) {
1139e8d8bef9SDimitry Andric FixSummands(XElType, Y);
1140e8d8bef9SDimitry Andric YElType = cast<FixedVectorType>(Y->getType());
1141e8d8bef9SDimitry Andric } else if (YElType && !XElType) {
1142e8d8bef9SDimitry Andric FixSummands(YElType, X);
1143e8d8bef9SDimitry Andric XElType = cast<FixedVectorType>(X->getType());
1144e8d8bef9SDimitry Andric }
1145e8d8bef9SDimitry Andric assert(XElType && YElType && "Unknown vector types");
1146e8d8bef9SDimitry Andric // Check that the summands are of compatible types
1147e8d8bef9SDimitry Andric if (XElType != YElType) {
1148e8d8bef9SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n");
1149e8d8bef9SDimitry Andric return nullptr;
1150e8d8bef9SDimitry Andric }
1151e8d8bef9SDimitry Andric
1152e8d8bef9SDimitry Andric if (XElType->getElementType()->getScalarSizeInBits() != 32) {
1153e8d8bef9SDimitry Andric // Check that by adding the vectors we do not accidentally
1154e8d8bef9SDimitry Andric // create an overflow
1155e8d8bef9SDimitry Andric Constant *ConstX = dyn_cast<Constant>(X);
1156e8d8bef9SDimitry Andric Constant *ConstY = dyn_cast<Constant>(Y);
1157e8d8bef9SDimitry Andric if (!ConstX || !ConstY)
1158e8d8bef9SDimitry Andric return nullptr;
1159e8d8bef9SDimitry Andric unsigned TargetElemSize = 128 / XElType->getNumElements();
1160e8d8bef9SDimitry Andric for (unsigned i = 0; i < XElType->getNumElements(); i++) {
1161e8d8bef9SDimitry Andric ConstantInt *ConstXEl =
1162e8d8bef9SDimitry Andric dyn_cast<ConstantInt>(ConstX->getAggregateElement(i));
1163e8d8bef9SDimitry Andric ConstantInt *ConstYEl =
1164e8d8bef9SDimitry Andric dyn_cast<ConstantInt>(ConstY->getAggregateElement(i));
1165e8d8bef9SDimitry Andric if (!ConstXEl || !ConstYEl ||
116681ad6265SDimitry Andric ConstXEl->getZExtValue() * ScaleX +
116781ad6265SDimitry Andric ConstYEl->getZExtValue() * ScaleY >=
1168e8d8bef9SDimitry Andric (unsigned)(1 << (TargetElemSize - 1)))
1169e8d8bef9SDimitry Andric return nullptr;
1170e8d8bef9SDimitry Andric }
1171e8d8bef9SDimitry Andric }
1172e8d8bef9SDimitry Andric
117381ad6265SDimitry Andric Value *XScale = Builder.CreateVectorSplat(
117481ad6265SDimitry Andric XElType->getNumElements(),
117581ad6265SDimitry Andric Builder.getIntN(XElType->getScalarSizeInBits(), ScaleX));
117681ad6265SDimitry Andric Value *YScale = Builder.CreateVectorSplat(
117781ad6265SDimitry Andric YElType->getNumElements(),
117881ad6265SDimitry Andric Builder.getIntN(YElType->getScalarSizeInBits(), ScaleY));
117981ad6265SDimitry Andric Value *Add = Builder.CreateAdd(Builder.CreateMul(X, XScale),
118081ad6265SDimitry Andric Builder.CreateMul(Y, YScale));
1181e8d8bef9SDimitry Andric
118281ad6265SDimitry Andric if (checkOffsetSize(Add, XElType->getNumElements()))
1183e8d8bef9SDimitry Andric return Add;
1184e8d8bef9SDimitry Andric else
1185e8d8bef9SDimitry Andric return nullptr;
1186e8d8bef9SDimitry Andric }
1187e8d8bef9SDimitry Andric
foldGEP(GetElementPtrInst * GEP,Value * & Offsets,unsigned & Scale,IRBuilder<> & Builder)1188e8d8bef9SDimitry Andric Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP,
118981ad6265SDimitry Andric Value *&Offsets, unsigned &Scale,
1190e8d8bef9SDimitry Andric IRBuilder<> &Builder) {
1191e8d8bef9SDimitry Andric Value *GEPPtr = GEP->getPointerOperand();
1192e8d8bef9SDimitry Andric Offsets = GEP->getOperand(1);
119381ad6265SDimitry Andric Scale = DL->getTypeAllocSize(GEP->getSourceElementType());
1194e8d8bef9SDimitry Andric // We only merge geps with constant offsets, because only for those
1195e8d8bef9SDimitry Andric // we can make sure that we do not cause an overflow
119681ad6265SDimitry Andric if (GEP->getNumIndices() != 1 || !isa<Constant>(Offsets))
1197e8d8bef9SDimitry Andric return nullptr;
119881ad6265SDimitry Andric if (GetElementPtrInst *BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr)) {
1199e8d8bef9SDimitry Andric // Merge the two geps into one
120081ad6265SDimitry Andric Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Scale, Builder);
1201e8d8bef9SDimitry Andric if (!BaseBasePtr)
1202e8d8bef9SDimitry Andric return nullptr;
120381ad6265SDimitry Andric Offsets = CheckAndCreateOffsetAdd(
120481ad6265SDimitry Andric Offsets, Scale, GEP->getOperand(1),
120581ad6265SDimitry Andric DL->getTypeAllocSize(GEP->getSourceElementType()), Builder);
1206e8d8bef9SDimitry Andric if (Offsets == nullptr)
1207e8d8bef9SDimitry Andric return nullptr;
120881ad6265SDimitry Andric Scale = 1; // Scale is always an i8 at this point.
1209e8d8bef9SDimitry Andric return BaseBasePtr;
1210e8d8bef9SDimitry Andric }
1211e8d8bef9SDimitry Andric return GEPPtr;
1212e8d8bef9SDimitry Andric }
1213e8d8bef9SDimitry Andric
optimiseAddress(Value * Address,BasicBlock * BB,LoopInfo * LI)1214e8d8bef9SDimitry Andric bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB,
1215e8d8bef9SDimitry Andric LoopInfo *LI) {
1216e8d8bef9SDimitry Andric GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address);
1217e8d8bef9SDimitry Andric if (!GEP)
1218e8d8bef9SDimitry Andric return false;
1219e8d8bef9SDimitry Andric bool Changed = false;
1220349cc55cSDimitry Andric if (GEP->hasOneUse() && isa<GetElementPtrInst>(GEP->getPointerOperand())) {
1221e8d8bef9SDimitry Andric IRBuilder<> Builder(GEP->getContext());
1222e8d8bef9SDimitry Andric Builder.SetInsertPoint(GEP);
1223e8d8bef9SDimitry Andric Builder.SetCurrentDebugLocation(GEP->getDebugLoc());
1224e8d8bef9SDimitry Andric Value *Offsets;
122581ad6265SDimitry Andric unsigned Scale;
122681ad6265SDimitry Andric Value *Base = foldGEP(GEP, Offsets, Scale, Builder);
1227e8d8bef9SDimitry Andric // We only want to merge the geps if there is a real chance that they can be
1228e8d8bef9SDimitry Andric // used by an MVE gather; thus the offset has to have the correct size
1229e8d8bef9SDimitry Andric // (always i32 if it is not of vector type) and the base has to be a
1230e8d8bef9SDimitry Andric // pointer.
1231e8d8bef9SDimitry Andric if (Offsets && Base && Base != GEP) {
123281ad6265SDimitry Andric assert(Scale == 1 && "Expected to fold GEP to a scale of 1");
12335f757f3fSDimitry Andric Type *BaseTy = Builder.getPtrTy();
123481ad6265SDimitry Andric if (auto *VecTy = dyn_cast<FixedVectorType>(Base->getType()))
123581ad6265SDimitry Andric BaseTy = FixedVectorType::get(BaseTy, VecTy);
1236e8d8bef9SDimitry Andric GetElementPtrInst *NewAddress = GetElementPtrInst::Create(
123781ad6265SDimitry Andric Builder.getInt8Ty(), Builder.CreateBitCast(Base, BaseTy), Offsets,
1238*0fca6ea1SDimitry Andric "gep.merged", GEP->getIterator());
123981ad6265SDimitry Andric LLVM_DEBUG(dbgs() << "Folded GEP: " << *GEP
124081ad6265SDimitry Andric << "\n new : " << *NewAddress << "\n");
124181ad6265SDimitry Andric GEP->replaceAllUsesWith(
124281ad6265SDimitry Andric Builder.CreateBitCast(NewAddress, GEP->getType()));
1243e8d8bef9SDimitry Andric GEP = NewAddress;
1244e8d8bef9SDimitry Andric Changed = true;
1245e8d8bef9SDimitry Andric }
1246e8d8bef9SDimitry Andric }
1247e8d8bef9SDimitry Andric Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI);
1248e8d8bef9SDimitry Andric return Changed;
1249e8d8bef9SDimitry Andric }
1250e8d8bef9SDimitry Andric
runOnFunction(Function & F)1251480093f4SDimitry Andric bool MVEGatherScatterLowering::runOnFunction(Function &F) {
1252480093f4SDimitry Andric if (!EnableMaskedGatherScatters)
1253480093f4SDimitry Andric return false;
1254480093f4SDimitry Andric auto &TPC = getAnalysis<TargetPassConfig>();
1255480093f4SDimitry Andric auto &TM = TPC.getTM<TargetMachine>();
1256480093f4SDimitry Andric auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
1257480093f4SDimitry Andric if (!ST->hasMVEIntegerOps())
1258480093f4SDimitry Andric return false;
12595ffd83dbSDimitry Andric LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1260*0fca6ea1SDimitry Andric DL = &F.getDataLayout();
1261480093f4SDimitry Andric SmallVector<IntrinsicInst *, 4> Gathers;
12625ffd83dbSDimitry Andric SmallVector<IntrinsicInst *, 4> Scatters;
12635ffd83dbSDimitry Andric
12645ffd83dbSDimitry Andric bool Changed = false;
12655ffd83dbSDimitry Andric
1266480093f4SDimitry Andric for (BasicBlock &BB : F) {
12674652422eSDimitry Andric Changed |= SimplifyInstructionsInBlock(&BB);
12684652422eSDimitry Andric
1269480093f4SDimitry Andric for (Instruction &I : BB) {
1270480093f4SDimitry Andric IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
1271e8d8bef9SDimitry Andric if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
1272e8d8bef9SDimitry Andric isa<FixedVectorType>(II->getType())) {
1273480093f4SDimitry Andric Gathers.push_back(II);
1274e8d8bef9SDimitry Andric Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI);
1275e8d8bef9SDimitry Andric } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
1276e8d8bef9SDimitry Andric isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
12775ffd83dbSDimitry Andric Scatters.push_back(II);
1278e8d8bef9SDimitry Andric Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI);
12795ffd83dbSDimitry Andric }
1280480093f4SDimitry Andric }
1281480093f4SDimitry Andric }
1282*0fca6ea1SDimitry Andric for (IntrinsicInst *I : Gathers) {
1283fe6060f1SDimitry Andric Instruction *L = lowerGather(I);
12845ffd83dbSDimitry Andric if (L == nullptr)
12855ffd83dbSDimitry Andric continue;
1286480093f4SDimitry Andric
12875ffd83dbSDimitry Andric // Get rid of any now dead instructions
1288fe6060f1SDimitry Andric SimplifyInstructionsInBlock(L->getParent());
12895ffd83dbSDimitry Andric Changed = true;
12905ffd83dbSDimitry Andric }
1291480093f4SDimitry Andric
1292*0fca6ea1SDimitry Andric for (IntrinsicInst *I : Scatters) {
1293fe6060f1SDimitry Andric Instruction *S = lowerScatter(I);
12945ffd83dbSDimitry Andric if (S == nullptr)
12955ffd83dbSDimitry Andric continue;
12965ffd83dbSDimitry Andric
12975ffd83dbSDimitry Andric // Get rid of any now dead instructions
1298fe6060f1SDimitry Andric SimplifyInstructionsInBlock(S->getParent());
12995ffd83dbSDimitry Andric Changed = true;
13005ffd83dbSDimitry Andric }
13015ffd83dbSDimitry Andric return Changed;
1302480093f4SDimitry Andric }
1303