//===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// This pass custom lowers llvm.gather and llvm.scatter instructions to
/// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to
/// produce a better final result as we go.
//
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/InitializePasses.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "arm-mve-gather-scatter-lowering"

cl::opt<bool> EnableMaskedGatherScatters(
    "enable-arm-maskedgatscat", cl::Hidden, cl::init(true),
    cl::desc("Enable the generation of masked gathers and scatters"));

namespace {

/// Function pass that rewrites masked gather/scatter intrinsics into the
/// corresponding arm.mve intrinsics.
class MVEGatherScatterLowering : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVEGatherScatterLowering() : FunctionPass(ID) {
    initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "MVE gather/scatter lowering";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

private:
  LoopInfo *LI = nullptr;
  // Initialised to null like LI so that a use before it has been assigned is
  // a deterministic null dereference rather than a read of an indeterminate
  // pointer.
  const DataLayout *DL = nullptr;

  // Check this is a valid gather with correct alignment
  bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
                               Align Alignment);
  // Check whether Ptr is hidden behind a bitcast and look through it
  void lookThroughBitcast(Value *&Ptr);
  // Decompose a ptr into Base and Offsets, potentially using a GEP to return a
  // scalar base and vector offsets, or else fallback to using a base of 0 and
  // offset of Ptr where possible.
  Value *decomposePtr(Value *Ptr, Value *&Offsets, int &Scale,
                      FixedVectorType *Ty, Type *MemoryTy,
                      IRBuilder<> &Builder);
  // Check for a getelementptr and deduce base and offsets from it, on success
  // returning the base directly and the offsets indirectly using the Offsets
  // argument
  Value *decomposeGEP(Value *&Offsets, FixedVectorType *Ty,
                      GetElementPtrInst *GEP, IRBuilder<> &Builder);
  // Compute the scale of this gather/scatter instruction
  int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
  // If the value is a constant, or derived from constants via additions
  // and multiplications, return its numeric value
  Optional<int64_t> getIfConst(const Value *V);
  // If Inst is an add instruction, check whether one summand is a
  // constant. If so, scale this constant and return it together with
  // the other summand.
  std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);

  Instruction *lowerGather(IntrinsicInst *I);
  // Create a gather from a base + vector of offsets
  Instruction *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
                                           Instruction *&Root,
                                           IRBuilder<> &Builder);
  // Create a gather from a vector of pointers
  Instruction *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
                                         IRBuilder<> &Builder,
                                         int64_t Increment = 0);
  // Create an incrementing gather from a vector of pointers
  Instruction *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
                                           IRBuilder<> &Builder,
                                           int64_t Increment = 0);

  Instruction *lowerScatter(IntrinsicInst *I);
  // Create a scatter to a base + vector of offsets
  Instruction *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
                                            IRBuilder<> &Builder);
  // Create a scatter to a vector of pointers
  Instruction *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
                                          IRBuilder<> &Builder,
                                          int64_t Increment = 0);
  // Create an incrementing scatter from a vector of pointers
  Instruction *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
                                            IRBuilder<> &Builder,
                                            int64_t Increment = 0);

  // QI gathers and scatters can increment their offsets on their own if
  // the increment is a constant value (digit)
  Instruction *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *Ptr,
                                            IRBuilder<> &Builder);
  // QI gathers/scatters can increment their offsets on their own if the
  // increment is a constant value (digit) - this creates a writeback QI
  // gather/scatter
  Instruction *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
                                              Value *Ptr, unsigned TypeScale,
                                              IRBuilder<> &Builder);

  // Optimise the base and offsets of the given address
  bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI);
  // Try to fold consecutive geps together into one
  Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, unsigned &Scale,
                 IRBuilder<> &Builder);
  // Check whether these offsets could be moved out of the loop they're in
  bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
  // Pushes the given add out of the loop
  void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
  // Pushes the given mul or shl out of the loop
  void pushOutMulShl(unsigned Opc, PHINode *&Phi, Value *IncrementPerRound,
                     Value *OffsSecondOperand, unsigned LoopIncrement,
                     IRBuilder<> &Builder);
};

} // end anonymous namespace

char MVEGatherScatterLowering::ID = 0;

INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
                "MVE gather/scattering lowering pass", false, false)

Pass *llvm::createMVEGatherScatterLoweringPass() {
  return new MVEGatherScatterLowering();
}

/// Return true if a vector of \p NumElements elements, each \p ElemSize bits
/// wide, is one of the accepted shapes (4 x {32,16,8}, 8 x {16,8}, 16 x 8)
/// and \p Alignment is at least the element size in bytes.
bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
                                                       unsigned ElemSize,
                                                       Align Alignment) {
  if (((NumElements == 4 &&
        (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 16 && ElemSize == 8)) &&
      Alignment >= ElemSize / 8)
    return true;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
                    << "valid alignment or vector type \n");
  return false;
}

/// Check that the (fixed-vector) \p Offsets can be used as unsigned offsets of
/// a gather/scatter with \p TargetElemCount lanes.
///
/// Returns true when the offset element width equals both the lane width
/// (128 / TargetElemCount) and 32, or when \p Offsets is a constant whose
/// every element is a ConstantInt in the range [0, 2^lane-width).
static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) {
  // Offsets that are not of type <N x i32> are sign extended by the
  // getelementptr instruction, and MVE gathers/scatters treat the offset as
  // unsigned. Thus, if the element size is smaller than 32, we can only allow
  // positive offsets - i.e., the offsets are not allowed to be variables we
  // can't look into.
  // Additionally, <N x i32> offsets have to either originate from a zext of a
  // vector with element types smaller or equal the type of the gather we're
  // looking at, or consist of constants that we can check are small enough
  // to fit into the gather type.
  // Thus we check that 0 <= value < 2^TargetElemSize.
  unsigned TargetElemSize = 128 / TargetElemCount;
  unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType())
                                ->getElementType()
                                ->getScalarSizeInBits();
  if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
    Constant *ConstOff = dyn_cast<Constant>(Offsets);
    if (!ConstOff)
      return false;
    int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
    auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) {
      // getAggregateElement() may hand back a null pointer (e.g. for a
      // constant expression), so tolerate null here instead of asserting.
      ConstantInt *OConst = dyn_cast_or_null<ConstantInt>(OffsetElem);
      if (!OConst)
        return false;
      // Keep the full 64-bit value: narrowing to int would let a wide
      // constant wrap around into the accepted range.
      int64_t SExtValue = OConst->getSExtValue();
      if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
        return false;
      return true;
    };
    if (isa<FixedVectorType>(ConstOff->getType())) {
      for (unsigned i = 0; i < TargetElemCount; i++) {
        if (!CheckValueSize(ConstOff->getAggregateElement(i)))
          return false;
      }
    } else {
      if (!CheckValueSize(ConstOff))
        return false;
    }
  }
  return true;
}

/// Split \p Ptr into a base pointer (returned) plus a vector of offsets
/// (written to \p Offsets) and a scale (written to \p Scale).
///
/// If \p Ptr is a GEP that decomposeGEP accepts, the scale is derived with
/// computeScale and nullptr is returned when that scale is -1. Otherwise, for
/// a 4-element pointer vector whose memory element size is not 32 bits, a
/// null base is used with the pointers themselves (ptrtoint'ed to <4 x i32>)
/// as offsets and a scale of 0. In every other case nullptr is returned.
Value *MVEGatherScatterLowering::decomposePtr(Value *Ptr, Value *&Offsets,
                                              int &Scale, FixedVectorType *Ty,
                                              Type *MemoryTy,
                                              IRBuilder<> &Builder) {
  if (auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
    if (Value *V = decomposeGEP(Offsets, Ty, GEP, Builder)) {
      Scale =
          computeScale(GEP->getSourceElementType()->getPrimitiveSizeInBits(),
                       MemoryTy->getScalarSizeInBits());
      return Scale == -1 ? nullptr : V;
    }
  }

  // If we couldn't use the GEP (or it doesn't exist), attempt to use a
  // BasePtr of 0 with Ptr as the Offsets, so long as there are only 4
  // elements.
  FixedVectorType *PtrTy = cast<FixedVectorType>(Ptr->getType());
  if (PtrTy->getNumElements() != 4 || MemoryTy->getScalarSizeInBits() == 32)
    return nullptr;
  Value *Zero = ConstantInt::get(Builder.getInt32Ty(), 0);
  Value *BasePtr = Builder.CreateIntToPtr(Zero, Builder.getInt8PtrTy());
  Offsets = Builder.CreatePtrToInt(
      Ptr, FixedVectorType::get(Builder.getInt32Ty(), 4));
  Scale = 0;
  return BasePtr;
}

/// Try to express \p GEP as a scalar base pointer plus a vector of offsets.
///
/// On success the GEP's pointer operand is returned and \p Offsets holds the
/// offsets, truncated or zero-extended to the integer vector type matching
/// \p Ty. Returns nullptr if \p GEP is null, does not have exactly one index,
/// has a vector base, has a non-fixed-vector index, or the offsets fail
/// checkOffsetSize.
Value *MVEGatherScatterLowering::decomposeGEP(Value *&Offsets,
                                              FixedVectorType *Ty,
                                              GetElementPtrInst *GEP,
                                              IRBuilder<> &Builder) {
  if (!GEP) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: no getelementpointer "
                      << "found\n");
    return nullptr;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
                    << " Looking at intrinsic for base + vector of offsets\n");

  // Verify the operand count before touching operand 1, so that a GEP without
  // any index is rejected instead of being indexed out of range.
  if (GEP->getNumOperands() != 2) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
                      << " operands. Expanding.\n");
    return nullptr;
  }

  Value *GEPPtr = GEP->getPointerOperand();
  Offsets = GEP->getOperand(1);
  if (GEPPtr->getType()->isVectorTy() ||
      !isa<FixedVectorType>(Offsets->getType()))
    return nullptr;

  unsigned OffsetsElemCount =
      cast<FixedVectorType>(Offsets->getType())->getNumElements();
  // Paranoid check whether the number of parallel lanes is the same
  assert(Ty->getNumElements() == OffsetsElemCount);

  ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets);
  if (ZextOffs)
    Offsets = ZextOffs->getOperand(0);
  FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType());

  // If the offsets are already being zext-ed to <N x i32>, that relieves us of
  // having to make sure that they won't overflow.
  if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy())
                           ->getElementType()
                           ->getScalarSizeInBits() != 32)
    if (!checkOffsetSize(Offsets, OffsetsElemCount))
      return nullptr;

  // The offset sizes have been checked; if any truncating or zext-ing is
  // required to fix them, do that now
  if (Ty != Offsets->getType()) {
    // Ty may be a floating-point vector; integer casts must target the
    // equally sized integer vector type. For integer Ty this is Ty itself.
    VectorType *IntTy = VectorType::getInteger(Ty);
    if ((Ty->getElementType()->getScalarSizeInBits() <
         OffsetType->getElementType()->getScalarSizeInBits())) {
      Offsets = Builder.CreateTrunc(Offsets, IntTy);
    } else {
      Offsets = Builder.CreateZExt(Offsets, IntTy);
    }
  }
  // If none of the checks failed, return the gep's base pointer
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
  return GEPPtr;
}

/// If \p Ptr is a bitcast between vectors with the same number of elements,
/// replace \p Ptr with the bitcast's source operand.
void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
  // Look through bitcast instruction if #elements is the same
  if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
    auto *BCTy = cast<FixedVectorType>(BitCast->getType());
    auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
    if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
      LLVM_DEBUG(dbgs() << "masked gathers/scatters: looking through "
                        << "bitcast\n");
      Ptr = BitCast->getOperand(0);
    }
  }
}

/// Return the log2 scale (2, 1 or 0) for a gather/scatter whose GEP element
/// is \p GEPElemSize bits and whose memory element is \p MemoryElemSize bits,
/// or -1 if the combination is not supported.
int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
                                           unsigned MemoryElemSize) {
  // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by 2,
  // or a 8bit, 16bit or 32bit load/store scaled by 1
  if (GEPElemSize == 32 && MemoryElemSize == 32)
    return 2;
  else if (GEPElemSize == 16 && MemoryElemSize == 16)
    return 1;
  else if (GEPElemSize == 8)
    return 0;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
                    << "create intrinsic\n");
  return -1;
}

/// If \p V is a splat constant, or an add/or/mul/shl whose operands
/// recursively are, return its value; otherwise return an empty Optional.
Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
  const Constant *C = dyn_cast<Constant>(V);
  if (C && C->getSplatValue())
    return Optional<int64_t>{C->getUniqueInteger().getSExtValue()};
  if (!isa<Instruction>(V))
    return Optional<int64_t>{};

  const Instruction *I = cast<Instruction>(V);
  if (I->getOpcode() == Instruction::Add || I->getOpcode() == Instruction::Or ||
      I->getOpcode() == Instruction::Mul ||
      I->getOpcode() == Instruction::Shl) {
    Optional<int64_t> Op0 = getIfConst(I->getOperand(0));
    Optional<int64_t> Op1 = getIfConst(I->getOperand(1));
    if (!Op0 || !Op1)
      return Optional<int64_t>{};
    if (I->getOpcode() == Instruction::Add)
      return Optional<int64_t>{*Op0 + *Op1};
    if (I->getOpcode() == Instruction::Mul)
      return Optional<int64_t>{*Op0 * *Op1};
    if (I->getOpcode() == Instruction::Shl) {
      // A shift amount outside [0, 64) cannot be evaluated on the host
      // without undefined behaviour; treat the value as non-constant.
      if (*Op1 < 0 || *Op1 >= 64)
        return Optional<int64_t>{};
      // Shift as unsigned so that a negative left operand is well defined.
      return Optional<int64_t>{
          static_cast<int64_t>(static_cast<uint64_t>(*Op0) << *Op1)};
    }
    if (I->getOpcode() == Instruction::Or)
      return Optional<int64_t>{*Op0 | *Op1};
  }
  return Optional<int64_t>{};
}

// Return true if I is an Or instruction that is equivalent to an add, due to
// the operands having no common bits set.
static bool isAddLikeOr(Instruction *I, const DataLayout &DL) {
  return I->getOpcode() == Instruction::Or &&
         haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1), DL);
}

/// If \p Inst is an add (or add-like or) with one constant operand, return
/// the other operand together with the constant shifted left by
/// \p TypeScale. Returns {nullptr, 0} if \p Inst has no such form, or if the
/// scaled constant lies outside [-512, 512] or is not a multiple of 4.
std::pair<Value *, int64_t>
MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
  std::pair<Value *, int64_t> ReturnFalse =
      std::pair<Value *, int64_t>(nullptr, 0);
  // At this point, the instruction we're looking at must be an add or an
  // add-like-or.
  Instruction *Add = dyn_cast<Instruction>(Inst);
  if (Add == nullptr ||
      (Add->getOpcode() != Instruction::Add && !isAddLikeOr(Add, *DL)))
    return ReturnFalse;

  Value *Summand;
  Optional<int64_t> Const;
  // Find out which operand the value that is increased is
  if ((Const = getIfConst(Add->getOperand(0))))
    Summand = Add->getOperand(1);
  else if ((Const = getIfConst(Add->getOperand(1))))
    Summand = Add->getOperand(0);
  else
    return ReturnFalse;

  // Check that the constant is small enough for an incrementing gather
  int64_t Immediate = *Const << TypeScale;
  if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
    return ReturnFalse;

  return std::pair<Value *, int64_t>(Summand, Immediate);
}

/// Try to replace the masked gather \p I with an MVE gather intrinsic.
///
/// The incrementing, base+offset and vector-of-pointers forms are tried in
/// that order. On success the original gather (and an extend that was folded
/// into the new gather, if any) is erased and the replacement is returned;
/// otherwise nullptr is returned.
Instruction *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"
                    << *I << "\n");

  // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
  // Attempt to turn the masked gather in I into a MVE intrinsic
  // Potentially optimising the addressing modes as we do so.
  auto *Ty = cast<FixedVectorType>(I->getType());
  Value *Ptr = I->getArgOperand(0);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
  Value *Mask = I->getArgOperand(2);
  Value *PassThru = I->getArgOperand(3);

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;
  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  // Root is the instruction whose uses get replaced; the offset variant may
  // redirect it to an extend of the gather.
  Instruction *Root = I;

  Instruction *Load = tryCreateIncrementingGatScat(I, Ptr, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
  if (!Load)
    return nullptr;

  if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
    LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
                      << "creating select\n");
    Load = SelectInst::Create(Mask, Load, PassThru);
    Builder.Insert(Load);
  }

  Root->replaceAllUsesWith(Load);
  Root->eraseFromParent();
  if (Root != I)
    // If this was an extending gather, we need to get rid of the sext/zext
    // as well as of the gather itself
    I->eraseFromParent();

  LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"
                    << *Load << "\n");
  return Load;
}

/// Build an arm.mve.vldr.gather.base intrinsic (predicated unless the mask is
/// all ones) that loads from the vector of pointers \p Ptr plus
/// \p Increment. Only 4 x 32-bit results are supported; nullptr otherwise.
Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}

/// Same as tryCreateMaskedGatherBase, but emits the writeback
/// (arm.mve.vldr.gather.base.wb) form of the intrinsic.
Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers with "
                    << "writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}

/// Build an arm.mve.vldr.gather.offset intrinsic (predicated unless the mask
/// is all ones) for the gather \p I loading through \p Ptr.
///
/// When the gathered type is narrower than 128 bits, a single 128-bit
/// sext/zext user is folded into the gather and \p Root is updated to that
/// extend; failing that, integer results are gathered at full width and
/// truncated back. Returns nullptr (leaving \p Root untouched) if no 128-bit
/// result type can be formed or the pointer cannot be decomposed.
Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
    IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
  using namespace PatternMatch;

  Type *MemoryTy = I->getType();
  Type *ResultTy = MemoryTy;

  unsigned Unsigned = 1;
  // The size of the gather was already checked in isLegalTypeAndAlignment;
  // if it was not a full vector width an appropriate extend should follow.
  auto *Extend = Root;
  bool TruncResult = false;
  if (MemoryTy->getPrimitiveSizeInBits() < 128) {
    if (I->hasOneUse()) {
      // If the gather has a single extend of the correct type, use an extending
      // gather and replace the ext. In which case the correct root to replace
      // is not the CallInst itself, but the instruction which extends it.
      Instruction *User = cast<Instruction>(*I->users().begin());
      if (isa<SExtInst>(User) &&
          User->getType()->getPrimitiveSizeInBits() == 128) {
        LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
                          << *User << "\n");
        Extend = User;
        ResultTy = User->getType();
        Unsigned = 0;
      } else if (isa<ZExtInst>(User) &&
                 User->getType()->getPrimitiveSizeInBits() == 128) {
        // Print the extend itself, matching the sext case above.
        LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
                          << *User << "\n");
        Extend = User;
        ResultTy = User->getType();
      }
    }

    // If an extend hasn't been found and the type is an integer, create an
    // extending gather and truncate back to the original type.
    if (ResultTy->getPrimitiveSizeInBits() < 128 &&
        ResultTy->isIntOrIntVectorTy()) {
      ResultTy = ResultTy->getWithNewBitWidth(
          128 / cast<FixedVectorType>(ResultTy)->getNumElements());
      TruncResult = true;
      LLVM_DEBUG(dbgs() << "masked gathers: Small input type, truncing to: "
                        << *ResultTy << "\n");
    }

    // The final size of the gather must be a full vector width
    if (ResultTy->getPrimitiveSizeInBits() != 128) {
      LLVM_DEBUG(dbgs() << "masked gathers: Extend needed but not provided "
                           "from the correct type. Expanding\n");
      return nullptr;
    }
  }

  Value *Offsets;
  int Scale;
  Value *BasePtr = decomposePtr(
      Ptr, Offsets, Scale, cast<FixedVectorType>(ResultTy), MemoryTy, Builder);
  if (!BasePtr)
    return nullptr;

  // Only commit the new root once we know the gather will be created.
  Root = Extend;
  Value *Mask = I->getArgOperand(2);
  Instruction *Load = nullptr;
  if (!match(Mask, m_One()))
    Load = Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset_predicated,
        {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
        {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
  else
    Load = Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset,
        {ResultTy, BasePtr->getType(), Offsets->getType()},
        {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned)});

  if (TruncResult) {
    Load = TruncInst::Create(Instruction::Trunc, Load, MemoryTy);
    Builder.Insert(Load);
  }
  return Load;
}

Instruction *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
Andric using namespace PatternMatch; 580fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n" 581fe6060f1SDimitry Andric << *I << "\n"); 5825ffd83dbSDimitry Andric 5835ffd83dbSDimitry Andric // @llvm.masked.scatter.*(data, ptrs, alignment, mask) 5845ffd83dbSDimitry Andric // Attempt to turn the masked scatter in I into a MVE intrinsic 5855ffd83dbSDimitry Andric // Potentially optimising the addressing modes as we do so. 5865ffd83dbSDimitry Andric Value *Input = I->getArgOperand(0); 5875ffd83dbSDimitry Andric Value *Ptr = I->getArgOperand(1); 5885ffd83dbSDimitry Andric Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue(); 5895ffd83dbSDimitry Andric auto *Ty = cast<FixedVectorType>(Input->getType()); 5905ffd83dbSDimitry Andric 5915ffd83dbSDimitry Andric if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(), 5925ffd83dbSDimitry Andric Alignment)) 5935ffd83dbSDimitry Andric return nullptr; 5945ffd83dbSDimitry Andric 5955ffd83dbSDimitry Andric lookThroughBitcast(Ptr); 5965ffd83dbSDimitry Andric assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); 5975ffd83dbSDimitry Andric 5985ffd83dbSDimitry Andric IRBuilder<> Builder(I->getContext()); 5995ffd83dbSDimitry Andric Builder.SetInsertPoint(I); 6005ffd83dbSDimitry Andric Builder.SetCurrentDebugLocation(I->getDebugLoc()); 6015ffd83dbSDimitry Andric 602fe6060f1SDimitry Andric Instruction *Store = tryCreateIncrementingGatScat(I, Ptr, Builder); 603fe6060f1SDimitry Andric if (!Store) 604fe6060f1SDimitry Andric Store = tryCreateMaskedScatterOffset(I, Ptr, Builder); 6055ffd83dbSDimitry Andric if (!Store) 6065ffd83dbSDimitry Andric Store = tryCreateMaskedScatterBase(I, Ptr, Builder); 6075ffd83dbSDimitry Andric if (!Store) 6085ffd83dbSDimitry Andric return nullptr; 6095ffd83dbSDimitry Andric 610fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n" 611fe6060f1SDimitry Andric << *Store << 
"\n"); 6125ffd83dbSDimitry Andric I->eraseFromParent(); 6135ffd83dbSDimitry Andric return Store; 6145ffd83dbSDimitry Andric } 6155ffd83dbSDimitry Andric 616fe6060f1SDimitry Andric Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBase( 6175ffd83dbSDimitry Andric IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { 6185ffd83dbSDimitry Andric using namespace PatternMatch; 6195ffd83dbSDimitry Andric Value *Input = I->getArgOperand(0); 6205ffd83dbSDimitry Andric auto *Ty = cast<FixedVectorType>(Input->getType()); 6215ffd83dbSDimitry Andric // Only QR variants allow truncating 6225ffd83dbSDimitry Andric if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) { 6235ffd83dbSDimitry Andric // Can't build an intrinsic for this 6245ffd83dbSDimitry Andric return nullptr; 6255ffd83dbSDimitry Andric } 6265ffd83dbSDimitry Andric Value *Mask = I->getArgOperand(3); 6275ffd83dbSDimitry Andric // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask) 6285ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n"); 6295ffd83dbSDimitry Andric if (match(Mask, m_One())) 6305ffd83dbSDimitry Andric return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base, 6315ffd83dbSDimitry Andric {Ptr->getType(), Input->getType()}, 6325ffd83dbSDimitry Andric {Ptr, Builder.getInt32(Increment), Input}); 6335ffd83dbSDimitry Andric else 6345ffd83dbSDimitry Andric return Builder.CreateIntrinsic( 6355ffd83dbSDimitry Andric Intrinsic::arm_mve_vstr_scatter_base_predicated, 6365ffd83dbSDimitry Andric {Ptr->getType(), Input->getType(), Mask->getType()}, 6375ffd83dbSDimitry Andric {Ptr, Builder.getInt32(Increment), Input, Mask}); 6385ffd83dbSDimitry Andric } 6395ffd83dbSDimitry Andric 640fe6060f1SDimitry Andric Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB( 6415ffd83dbSDimitry Andric IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { 6425ffd83dbSDimitry Andric 
using namespace PatternMatch; 6435ffd83dbSDimitry Andric Value *Input = I->getArgOperand(0); 6445ffd83dbSDimitry Andric auto *Ty = cast<FixedVectorType>(Input->getType()); 645fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers " 646fe6060f1SDimitry Andric << "with writeback\n"); 6475ffd83dbSDimitry Andric if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) 6485ffd83dbSDimitry Andric // Can't build an intrinsic for this 6495ffd83dbSDimitry Andric return nullptr; 6505ffd83dbSDimitry Andric Value *Mask = I->getArgOperand(3); 6515ffd83dbSDimitry Andric if (match(Mask, m_One())) 6525ffd83dbSDimitry Andric return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb, 6535ffd83dbSDimitry Andric {Ptr->getType(), Input->getType()}, 6545ffd83dbSDimitry Andric {Ptr, Builder.getInt32(Increment), Input}); 6555ffd83dbSDimitry Andric else 6565ffd83dbSDimitry Andric return Builder.CreateIntrinsic( 6575ffd83dbSDimitry Andric Intrinsic::arm_mve_vstr_scatter_base_wb_predicated, 6585ffd83dbSDimitry Andric {Ptr->getType(), Input->getType(), Mask->getType()}, 6595ffd83dbSDimitry Andric {Ptr, Builder.getInt32(Increment), Input, Mask}); 6605ffd83dbSDimitry Andric } 6615ffd83dbSDimitry Andric 662fe6060f1SDimitry Andric Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterOffset( 6635ffd83dbSDimitry Andric IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) { 6645ffd83dbSDimitry Andric using namespace PatternMatch; 6655ffd83dbSDimitry Andric Value *Input = I->getArgOperand(0); 6665ffd83dbSDimitry Andric Value *Mask = I->getArgOperand(3); 6675ffd83dbSDimitry Andric Type *InputTy = Input->getType(); 6685ffd83dbSDimitry Andric Type *MemoryTy = InputTy; 669fe6060f1SDimitry Andric 6705ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. 
Storing" 6715ffd83dbSDimitry Andric << " to base + vector of offsets\n"); 6725ffd83dbSDimitry Andric // If the input has been truncated, try to integrate that trunc into the 6735ffd83dbSDimitry Andric // scatter instruction (we don't care about alignment here) 6745ffd83dbSDimitry Andric if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) { 6755ffd83dbSDimitry Andric Value *PreTrunc = Trunc->getOperand(0); 6765ffd83dbSDimitry Andric Type *PreTruncTy = PreTrunc->getType(); 6775ffd83dbSDimitry Andric if (PreTruncTy->getPrimitiveSizeInBits() == 128) { 6785ffd83dbSDimitry Andric Input = PreTrunc; 6795ffd83dbSDimitry Andric InputTy = PreTruncTy; 6805ffd83dbSDimitry Andric } 6815ffd83dbSDimitry Andric } 682fe6060f1SDimitry Andric bool ExtendInput = false; 683fe6060f1SDimitry Andric if (InputTy->getPrimitiveSizeInBits() < 128 && 684fe6060f1SDimitry Andric InputTy->isIntOrIntVectorTy()) { 685fe6060f1SDimitry Andric // If we can't find a trunc to incorporate into the instruction, create an 686fe6060f1SDimitry Andric // implicit one with a zext, so that we can still create a scatter. We know 687fe6060f1SDimitry Andric // that the input type is 4x/8x/16x and of type i8/i16/i32, so any type 688fe6060f1SDimitry Andric // smaller than 128 bits will divide evenly into a 128bit vector. 689fe6060f1SDimitry Andric InputTy = InputTy->getWithNewBitWidth( 690fe6060f1SDimitry Andric 128 / cast<FixedVectorType>(InputTy)->getNumElements()); 691fe6060f1SDimitry Andric ExtendInput = true; 692fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "masked scatters: Small input type, will extend:\n" 693fe6060f1SDimitry Andric << *Input << "\n"); 694fe6060f1SDimitry Andric } 6955ffd83dbSDimitry Andric if (InputTy->getPrimitiveSizeInBits() != 128) { 696fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "masked scatters: cannot create scatters for " 697fe6060f1SDimitry Andric "non-standard input types. 
Expanding.\n"); 6985ffd83dbSDimitry Andric return nullptr; 6995ffd83dbSDimitry Andric } 7005ffd83dbSDimitry Andric 7015ffd83dbSDimitry Andric Value *Offsets; 702fe6060f1SDimitry Andric int Scale; 703fe6060f1SDimitry Andric Value *BasePtr = decomposePtr( 704fe6060f1SDimitry Andric Ptr, Offsets, Scale, cast<FixedVectorType>(InputTy), MemoryTy, Builder); 7055ffd83dbSDimitry Andric if (!BasePtr) 7065ffd83dbSDimitry Andric return nullptr; 7075ffd83dbSDimitry Andric 708fe6060f1SDimitry Andric if (ExtendInput) 709fe6060f1SDimitry Andric Input = Builder.CreateZExt(Input, InputTy); 7105ffd83dbSDimitry Andric if (!match(Mask, m_One())) 7115ffd83dbSDimitry Andric return Builder.CreateIntrinsic( 7125ffd83dbSDimitry Andric Intrinsic::arm_mve_vstr_scatter_offset_predicated, 7135ffd83dbSDimitry Andric {BasePtr->getType(), Offsets->getType(), Input->getType(), 7145ffd83dbSDimitry Andric Mask->getType()}, 7155ffd83dbSDimitry Andric {BasePtr, Offsets, Input, 7165ffd83dbSDimitry Andric Builder.getInt32(MemoryTy->getScalarSizeInBits()), 7175ffd83dbSDimitry Andric Builder.getInt32(Scale), Mask}); 7185ffd83dbSDimitry Andric else 7195ffd83dbSDimitry Andric return Builder.CreateIntrinsic( 7205ffd83dbSDimitry Andric Intrinsic::arm_mve_vstr_scatter_offset, 7215ffd83dbSDimitry Andric {BasePtr->getType(), Offsets->getType(), Input->getType()}, 7225ffd83dbSDimitry Andric {BasePtr, Offsets, Input, 7235ffd83dbSDimitry Andric Builder.getInt32(MemoryTy->getScalarSizeInBits()), 7245ffd83dbSDimitry Andric Builder.getInt32(Scale)}); 7255ffd83dbSDimitry Andric } 7265ffd83dbSDimitry Andric 727fe6060f1SDimitry Andric Instruction *MVEGatherScatterLowering::tryCreateIncrementingGatScat( 728fe6060f1SDimitry Andric IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) { 7295ffd83dbSDimitry Andric FixedVectorType *Ty; 7305ffd83dbSDimitry Andric if (I->getIntrinsicID() == Intrinsic::masked_gather) 7315ffd83dbSDimitry Andric Ty = cast<FixedVectorType>(I->getType()); 7325ffd83dbSDimitry Andric else 
7335ffd83dbSDimitry Andric Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType()); 734fe6060f1SDimitry Andric 7355ffd83dbSDimitry Andric // Incrementing gathers only exist for v4i32 736fe6060f1SDimitry Andric if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) 7375ffd83dbSDimitry Andric return nullptr; 738fe6060f1SDimitry Andric // Incrementing gathers are not beneficial outside of a loop 7395ffd83dbSDimitry Andric Loop *L = LI->getLoopFor(I->getParent()); 7405ffd83dbSDimitry Andric if (L == nullptr) 7415ffd83dbSDimitry Andric return nullptr; 742fe6060f1SDimitry Andric 743fe6060f1SDimitry Andric // Decompose the GEP into Base and Offsets 744fe6060f1SDimitry Andric GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); 745fe6060f1SDimitry Andric Value *Offsets; 746fe6060f1SDimitry Andric Value *BasePtr = decomposeGEP(Offsets, Ty, GEP, Builder); 747fe6060f1SDimitry Andric if (!BasePtr) 748fe6060f1SDimitry Andric return nullptr; 749fe6060f1SDimitry Andric 7505ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing " 7515ffd83dbSDimitry Andric "wb gather/scatter\n"); 7525ffd83dbSDimitry Andric 7535ffd83dbSDimitry Andric // The gep was in charge of making sure the offsets are scaled correctly 7545ffd83dbSDimitry Andric // - calculate that factor so it can be applied by hand 7555ffd83dbSDimitry Andric int TypeScale = 756349cc55cSDimitry Andric computeScale(DL->getTypeSizeInBits(GEP->getOperand(0)->getType()), 757349cc55cSDimitry Andric DL->getTypeSizeInBits(GEP->getType()) / 7585ffd83dbSDimitry Andric cast<FixedVectorType>(GEP->getType())->getNumElements()); 7595ffd83dbSDimitry Andric if (TypeScale == -1) 7605ffd83dbSDimitry Andric return nullptr; 7615ffd83dbSDimitry Andric 7625ffd83dbSDimitry Andric if (GEP->hasOneUse()) { 7635ffd83dbSDimitry Andric // Only in this case do we want to build a wb gather, because the wb will 7645ffd83dbSDimitry Andric // change the phi which does affect other users of the gep 
(which will still 7655ffd83dbSDimitry Andric // be using the phi in the old way) 766fe6060f1SDimitry Andric if (auto *Load = tryCreateIncrementingWBGatScat(I, BasePtr, Offsets, 767fe6060f1SDimitry Andric TypeScale, Builder)) 7685ffd83dbSDimitry Andric return Load; 7695ffd83dbSDimitry Andric } 770fe6060f1SDimitry Andric 7715ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing " 7725ffd83dbSDimitry Andric "non-wb gather/scatter\n"); 7735ffd83dbSDimitry Andric 7745ffd83dbSDimitry Andric std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale); 7755ffd83dbSDimitry Andric if (Add.first == nullptr) 7765ffd83dbSDimitry Andric return nullptr; 7775ffd83dbSDimitry Andric Value *OffsetsIncoming = Add.first; 7785ffd83dbSDimitry Andric int64_t Immediate = Add.second; 7795ffd83dbSDimitry Andric 7805ffd83dbSDimitry Andric // Make sure the offsets are scaled correctly 7815ffd83dbSDimitry Andric Instruction *ScaledOffsets = BinaryOperator::Create( 7825ffd83dbSDimitry Andric Instruction::Shl, OffsetsIncoming, 7835ffd83dbSDimitry Andric Builder.CreateVectorSplat(Ty->getNumElements(), Builder.getInt32(TypeScale)), 7845ffd83dbSDimitry Andric "ScaledIndex", I); 7855ffd83dbSDimitry Andric // Add the base to the offsets 7865ffd83dbSDimitry Andric OffsetsIncoming = BinaryOperator::Create( 7875ffd83dbSDimitry Andric Instruction::Add, ScaledOffsets, 7885ffd83dbSDimitry Andric Builder.CreateVectorSplat( 7895ffd83dbSDimitry Andric Ty->getNumElements(), 7905ffd83dbSDimitry Andric Builder.CreatePtrToInt( 7915ffd83dbSDimitry Andric BasePtr, 7925ffd83dbSDimitry Andric cast<VectorType>(ScaledOffsets->getType())->getElementType())), 7935ffd83dbSDimitry Andric "StartIndex", I); 7945ffd83dbSDimitry Andric 7955ffd83dbSDimitry Andric if (I->getIntrinsicID() == Intrinsic::masked_gather) 796fe6060f1SDimitry Andric return tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate); 7975ffd83dbSDimitry Andric else 798fe6060f1SDimitry Andric 
return tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate); 7995ffd83dbSDimitry Andric } 8005ffd83dbSDimitry Andric 801fe6060f1SDimitry Andric Instruction *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat( 8025ffd83dbSDimitry Andric IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale, 8035ffd83dbSDimitry Andric IRBuilder<> &Builder) { 8045ffd83dbSDimitry Andric // Check whether this gather's offset is incremented by a constant - if so, 8055ffd83dbSDimitry Andric // and the load is of the right type, we can merge this into a QI gather 8065ffd83dbSDimitry Andric Loop *L = LI->getLoopFor(I->getParent()); 8075ffd83dbSDimitry Andric // Offsets that are worth merging into this instruction will be incremented 8085ffd83dbSDimitry Andric // by a constant, thus we're looking for an add of a phi and a constant 8095ffd83dbSDimitry Andric PHINode *Phi = dyn_cast<PHINode>(Offsets); 8105ffd83dbSDimitry Andric if (Phi == nullptr || Phi->getNumIncomingValues() != 2 || 8115ffd83dbSDimitry Andric Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2) 8125ffd83dbSDimitry Andric // No phi means no IV to write back to; if there is a phi, we expect it 8135ffd83dbSDimitry Andric // to have exactly two incoming values; the only phis we are interested in 8145ffd83dbSDimitry Andric // will be loop IV's and have exactly two uses, one in their increment and 8155ffd83dbSDimitry Andric // one in the gather's gep 8165ffd83dbSDimitry Andric return nullptr; 8175ffd83dbSDimitry Andric 8185ffd83dbSDimitry Andric unsigned IncrementIndex = 8195ffd83dbSDimitry Andric Phi->getIncomingBlock(0) == L->getLoopLatch() ? 
0 : 1; 8205ffd83dbSDimitry Andric // Look through the phi to the phi increment 8215ffd83dbSDimitry Andric Offsets = Phi->getIncomingValue(IncrementIndex); 8225ffd83dbSDimitry Andric 8235ffd83dbSDimitry Andric std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale); 8245ffd83dbSDimitry Andric if (Add.first == nullptr) 8255ffd83dbSDimitry Andric return nullptr; 8265ffd83dbSDimitry Andric Value *OffsetsIncoming = Add.first; 8275ffd83dbSDimitry Andric int64_t Immediate = Add.second; 8285ffd83dbSDimitry Andric if (OffsetsIncoming != Phi) 8295ffd83dbSDimitry Andric // Then the increment we are looking at is not an increment of the 8305ffd83dbSDimitry Andric // induction variable, and we don't want to do a writeback 8315ffd83dbSDimitry Andric return nullptr; 8325ffd83dbSDimitry Andric 8335ffd83dbSDimitry Andric Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back()); 8345ffd83dbSDimitry Andric unsigned NumElems = 8355ffd83dbSDimitry Andric cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements(); 8365ffd83dbSDimitry Andric 8375ffd83dbSDimitry Andric // Make sure the offsets are scaled correctly 8385ffd83dbSDimitry Andric Instruction *ScaledOffsets = BinaryOperator::Create( 8395ffd83dbSDimitry Andric Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex), 8405ffd83dbSDimitry Andric Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)), 8415ffd83dbSDimitry Andric "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back()); 8425ffd83dbSDimitry Andric // Add the base to the offsets 8435ffd83dbSDimitry Andric OffsetsIncoming = BinaryOperator::Create( 8445ffd83dbSDimitry Andric Instruction::Add, ScaledOffsets, 8455ffd83dbSDimitry Andric Builder.CreateVectorSplat( 8465ffd83dbSDimitry Andric NumElems, 8475ffd83dbSDimitry Andric Builder.CreatePtrToInt( 8485ffd83dbSDimitry Andric BasePtr, 8495ffd83dbSDimitry Andric cast<VectorType>(ScaledOffsets->getType())->getElementType())), 8505ffd83dbSDimitry Andric 
"StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back()); 8515ffd83dbSDimitry Andric // The gather is pre-incrementing 8525ffd83dbSDimitry Andric OffsetsIncoming = BinaryOperator::Create( 8535ffd83dbSDimitry Andric Instruction::Sub, OffsetsIncoming, 8545ffd83dbSDimitry Andric Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)), 8555ffd83dbSDimitry Andric "PreIncrementStartIndex", 8565ffd83dbSDimitry Andric &Phi->getIncomingBlock(1 - IncrementIndex)->back()); 8575ffd83dbSDimitry Andric Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming); 8585ffd83dbSDimitry Andric 8595ffd83dbSDimitry Andric Builder.SetInsertPoint(I); 8605ffd83dbSDimitry Andric 861fe6060f1SDimitry Andric Instruction *EndResult; 862fe6060f1SDimitry Andric Instruction *NewInduction; 8635ffd83dbSDimitry Andric if (I->getIntrinsicID() == Intrinsic::masked_gather) { 8645ffd83dbSDimitry Andric // Build the incrementing gather 8655ffd83dbSDimitry Andric Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate); 8665ffd83dbSDimitry Andric // One value to be handed to whoever uses the gather, one is the loop 8675ffd83dbSDimitry Andric // increment 868fe6060f1SDimitry Andric EndResult = ExtractValueInst::Create(Load, 0, "Gather"); 869fe6060f1SDimitry Andric NewInduction = ExtractValueInst::Create(Load, 1, "GatherIncrement"); 870fe6060f1SDimitry Andric Builder.Insert(EndResult); 871fe6060f1SDimitry Andric Builder.Insert(NewInduction); 8725ffd83dbSDimitry Andric } else { 8735ffd83dbSDimitry Andric // Build the incrementing scatter 874fe6060f1SDimitry Andric EndResult = NewInduction = 875fe6060f1SDimitry Andric tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate); 8765ffd83dbSDimitry Andric } 8775ffd83dbSDimitry Andric Instruction *AddInst = cast<Instruction>(Offsets); 8785ffd83dbSDimitry Andric AddInst->replaceAllUsesWith(NewInduction); 8795ffd83dbSDimitry Andric AddInst->eraseFromParent(); 8805ffd83dbSDimitry Andric Phi->setIncomingValue(IncrementIndex, 
NewInduction); 8815ffd83dbSDimitry Andric 8825ffd83dbSDimitry Andric return EndResult; 8835ffd83dbSDimitry Andric } 8845ffd83dbSDimitry Andric 8855ffd83dbSDimitry Andric void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi, 8865ffd83dbSDimitry Andric Value *OffsSecondOperand, 8875ffd83dbSDimitry Andric unsigned StartIndex) { 8885ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n"); 8895ffd83dbSDimitry Andric Instruction *InsertionPoint = 8905ffd83dbSDimitry Andric &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back()); 8915ffd83dbSDimitry Andric // Initialize the phi with a vector that contains a sum of the constants 8925ffd83dbSDimitry Andric Instruction *NewIndex = BinaryOperator::Create( 8935ffd83dbSDimitry Andric Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand, 8945ffd83dbSDimitry Andric "PushedOutAdd", InsertionPoint); 8955ffd83dbSDimitry Andric unsigned IncrementIndex = StartIndex == 0 ? 1 : 0; 8965ffd83dbSDimitry Andric 8975ffd83dbSDimitry Andric // Order such that start index comes first (this reduces mov's) 8985ffd83dbSDimitry Andric Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex)); 8995ffd83dbSDimitry Andric Phi->addIncoming(Phi->getIncomingValue(IncrementIndex), 9005ffd83dbSDimitry Andric Phi->getIncomingBlock(IncrementIndex)); 9015ffd83dbSDimitry Andric Phi->removeIncomingValue(IncrementIndex); 9025ffd83dbSDimitry Andric Phi->removeIncomingValue(StartIndex); 9035ffd83dbSDimitry Andric } 9045ffd83dbSDimitry Andric 905349cc55cSDimitry Andric void MVEGatherScatterLowering::pushOutMulShl(unsigned Opcode, PHINode *&Phi, 9065ffd83dbSDimitry Andric Value *IncrementPerRound, 9075ffd83dbSDimitry Andric Value *OffsSecondOperand, 9085ffd83dbSDimitry Andric unsigned LoopIncrement, 9095ffd83dbSDimitry Andric IRBuilder<> &Builder) { 9105ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n"); 9115ffd83dbSDimitry Andric 
9125ffd83dbSDimitry Andric // Create a new scalar add outside of the loop and transform it to a splat 9135ffd83dbSDimitry Andric // by which loop variable can be incremented 9145ffd83dbSDimitry Andric Instruction *InsertionPoint = &cast<Instruction>( 9155ffd83dbSDimitry Andric Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back()); 9165ffd83dbSDimitry Andric 9175ffd83dbSDimitry Andric // Create a new index 918349cc55cSDimitry Andric Value *StartIndex = 919349cc55cSDimitry Andric BinaryOperator::Create((Instruction::BinaryOps)Opcode, 920349cc55cSDimitry Andric Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1), 9215ffd83dbSDimitry Andric OffsSecondOperand, "PushedOutMul", InsertionPoint); 9225ffd83dbSDimitry Andric 9235ffd83dbSDimitry Andric Instruction *Product = 924349cc55cSDimitry Andric BinaryOperator::Create((Instruction::BinaryOps)Opcode, IncrementPerRound, 9255ffd83dbSDimitry Andric OffsSecondOperand, "Product", InsertionPoint); 9265ffd83dbSDimitry Andric // Increment NewIndex by Product instead of the multiplication 9275ffd83dbSDimitry Andric Instruction *NewIncrement = BinaryOperator::Create( 9285ffd83dbSDimitry Andric Instruction::Add, Phi, Product, "IncrementPushedOutMul", 9295ffd83dbSDimitry Andric cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back()) 9305ffd83dbSDimitry Andric .getPrevNode()); 9315ffd83dbSDimitry Andric 9325ffd83dbSDimitry Andric Phi->addIncoming(StartIndex, 9335ffd83dbSDimitry Andric Phi->getIncomingBlock(LoopIncrement == 1 ? 
0 : 1)); 9345ffd83dbSDimitry Andric Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement)); 9355ffd83dbSDimitry Andric Phi->removeIncomingValue((unsigned)0); 9365ffd83dbSDimitry Andric Phi->removeIncomingValue((unsigned)0); 9375ffd83dbSDimitry Andric } 9385ffd83dbSDimitry Andric 9395ffd83dbSDimitry Andric // Check whether all usages of this instruction are as offsets of 9405ffd83dbSDimitry Andric // gathers/scatters or simple arithmetics only used by gathers/scatters 941349cc55cSDimitry Andric static bool hasAllGatScatUsers(Instruction *I, const DataLayout &DL) { 9425ffd83dbSDimitry Andric if (I->hasNUses(0)) { 9435ffd83dbSDimitry Andric return false; 9445ffd83dbSDimitry Andric } 9455ffd83dbSDimitry Andric bool Gatscat = true; 9465ffd83dbSDimitry Andric for (User *U : I->users()) { 9475ffd83dbSDimitry Andric if (!isa<Instruction>(U)) 9485ffd83dbSDimitry Andric return false; 9495ffd83dbSDimitry Andric if (isa<GetElementPtrInst>(U) || 9505ffd83dbSDimitry Andric isGatherScatter(dyn_cast<IntrinsicInst>(U))) { 9515ffd83dbSDimitry Andric return Gatscat; 9525ffd83dbSDimitry Andric } else { 9535ffd83dbSDimitry Andric unsigned OpCode = cast<Instruction>(U)->getOpcode(); 954349cc55cSDimitry Andric if ((OpCode == Instruction::Add || OpCode == Instruction::Mul || 955349cc55cSDimitry Andric OpCode == Instruction::Shl || 956349cc55cSDimitry Andric isAddLikeOr(cast<Instruction>(U), DL)) && 957349cc55cSDimitry Andric hasAllGatScatUsers(cast<Instruction>(U), DL)) { 9585ffd83dbSDimitry Andric continue; 9595ffd83dbSDimitry Andric } 9605ffd83dbSDimitry Andric return false; 9615ffd83dbSDimitry Andric } 9625ffd83dbSDimitry Andric } 9635ffd83dbSDimitry Andric return Gatscat; 9645ffd83dbSDimitry Andric } 9655ffd83dbSDimitry Andric 9665ffd83dbSDimitry Andric bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB, 9675ffd83dbSDimitry Andric LoopInfo *LI) { 968*81ad6265SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize: " 
969fe6060f1SDimitry Andric << *Offsets << "\n"); 9705ffd83dbSDimitry Andric // Optimise the addresses of gathers/scatters by moving invariant 9715ffd83dbSDimitry Andric // calculations out of the loop 9725ffd83dbSDimitry Andric if (!isa<Instruction>(Offsets)) 9735ffd83dbSDimitry Andric return false; 9745ffd83dbSDimitry Andric Instruction *Offs = cast<Instruction>(Offsets); 975349cc55cSDimitry Andric if (Offs->getOpcode() != Instruction::Add && !isAddLikeOr(Offs, *DL) && 976349cc55cSDimitry Andric Offs->getOpcode() != Instruction::Mul && 977349cc55cSDimitry Andric Offs->getOpcode() != Instruction::Shl) 9785ffd83dbSDimitry Andric return false; 9795ffd83dbSDimitry Andric Loop *L = LI->getLoopFor(BB); 9805ffd83dbSDimitry Andric if (L == nullptr) 9815ffd83dbSDimitry Andric return false; 9825ffd83dbSDimitry Andric if (!Offs->hasOneUse()) { 983349cc55cSDimitry Andric if (!hasAllGatScatUsers(Offs, *DL)) 9845ffd83dbSDimitry Andric return false; 9855ffd83dbSDimitry Andric } 9865ffd83dbSDimitry Andric 9875ffd83dbSDimitry Andric // Find out which, if any, operand of the instruction 9885ffd83dbSDimitry Andric // is a phi node 9895ffd83dbSDimitry Andric PHINode *Phi; 9905ffd83dbSDimitry Andric int OffsSecondOp; 9915ffd83dbSDimitry Andric if (isa<PHINode>(Offs->getOperand(0))) { 9925ffd83dbSDimitry Andric Phi = cast<PHINode>(Offs->getOperand(0)); 9935ffd83dbSDimitry Andric OffsSecondOp = 1; 9945ffd83dbSDimitry Andric } else if (isa<PHINode>(Offs->getOperand(1))) { 9955ffd83dbSDimitry Andric Phi = cast<PHINode>(Offs->getOperand(1)); 9965ffd83dbSDimitry Andric OffsSecondOp = 0; 9975ffd83dbSDimitry Andric } else { 998fe6060f1SDimitry Andric bool Changed = false; 9995ffd83dbSDimitry Andric if (isa<Instruction>(Offs->getOperand(0)) && 10005ffd83dbSDimitry Andric L->contains(cast<Instruction>(Offs->getOperand(0)))) 10015ffd83dbSDimitry Andric Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI); 10025ffd83dbSDimitry Andric if (isa<Instruction>(Offs->getOperand(1)) && 
10035ffd83dbSDimitry Andric L->contains(cast<Instruction>(Offs->getOperand(1)))) 10045ffd83dbSDimitry Andric Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI); 1005fe6060f1SDimitry Andric if (!Changed) 10065ffd83dbSDimitry Andric return false; 10075ffd83dbSDimitry Andric if (isa<PHINode>(Offs->getOperand(0))) { 10085ffd83dbSDimitry Andric Phi = cast<PHINode>(Offs->getOperand(0)); 10095ffd83dbSDimitry Andric OffsSecondOp = 1; 10105ffd83dbSDimitry Andric } else if (isa<PHINode>(Offs->getOperand(1))) { 10115ffd83dbSDimitry Andric Phi = cast<PHINode>(Offs->getOperand(1)); 10125ffd83dbSDimitry Andric OffsSecondOp = 0; 10135ffd83dbSDimitry Andric } else { 10145ffd83dbSDimitry Andric return false; 10155ffd83dbSDimitry Andric } 10165ffd83dbSDimitry Andric } 10175ffd83dbSDimitry Andric // A phi node we want to perform this function on should be from the 1018fe6060f1SDimitry Andric // loop header. 1019fe6060f1SDimitry Andric if (Phi->getParent() != L->getHeader()) 10205ffd83dbSDimitry Andric return false; 10215ffd83dbSDimitry Andric 1022fe6060f1SDimitry Andric // We're looking for a simple add recurrence. 1023fe6060f1SDimitry Andric BinaryOperator *IncInstruction; 1024fe6060f1SDimitry Andric Value *Start, *IncrementPerRound; 1025fe6060f1SDimitry Andric if (!matchSimpleRecurrence(Phi, IncInstruction, Start, IncrementPerRound) || 1026fe6060f1SDimitry Andric IncInstruction->getOpcode() != Instruction::Add) 10275ffd83dbSDimitry Andric return false; 10285ffd83dbSDimitry Andric 1029fe6060f1SDimitry Andric int IncrementingBlock = Phi->getIncomingValue(0) == IncInstruction ? 
0 : 1; 10305ffd83dbSDimitry Andric 10315ffd83dbSDimitry Andric // Get the value that is added to/multiplied with the phi 10325ffd83dbSDimitry Andric Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp); 10335ffd83dbSDimitry Andric 10344652422eSDimitry Andric if (IncrementPerRound->getType() != OffsSecondOperand->getType() || 10354652422eSDimitry Andric !L->isLoopInvariant(OffsSecondOperand)) 10365ffd83dbSDimitry Andric // Something has gone wrong, abort 10375ffd83dbSDimitry Andric return false; 10385ffd83dbSDimitry Andric 10395ffd83dbSDimitry Andric // Only proceed if the increment per round is a constant or an instruction 10405ffd83dbSDimitry Andric // which does not originate from within the loop 10415ffd83dbSDimitry Andric if (!isa<Constant>(IncrementPerRound) && 10425ffd83dbSDimitry Andric !(isa<Instruction>(IncrementPerRound) && 10435ffd83dbSDimitry Andric !L->contains(cast<Instruction>(IncrementPerRound)))) 10445ffd83dbSDimitry Andric return false; 10455ffd83dbSDimitry Andric 1046fe6060f1SDimitry Andric // If the phi is not used by anything else, we can just adapt it when 1047fe6060f1SDimitry Andric // replacing the instruction; if it is, we'll have to duplicate it 1048fe6060f1SDimitry Andric PHINode *NewPhi; 10495ffd83dbSDimitry Andric if (Phi->getNumUses() == 2) { 10505ffd83dbSDimitry Andric // No other users -> reuse existing phi (One user is the instruction 10515ffd83dbSDimitry Andric // we're looking at, the other is the phi increment) 10525ffd83dbSDimitry Andric if (IncInstruction->getNumUses() != 1) { 10535ffd83dbSDimitry Andric // If the incrementing instruction does have more users than 10545ffd83dbSDimitry Andric // our phi, we need to copy it 10555ffd83dbSDimitry Andric IncInstruction = BinaryOperator::Create( 10565ffd83dbSDimitry Andric Instruction::BinaryOps(IncInstruction->getOpcode()), Phi, 10575ffd83dbSDimitry Andric IncrementPerRound, "LoopIncrement", IncInstruction); 10585ffd83dbSDimitry Andric Phi->setIncomingValue(IncrementingBlock, 
IncInstruction); 10595ffd83dbSDimitry Andric } 10605ffd83dbSDimitry Andric NewPhi = Phi; 10615ffd83dbSDimitry Andric } else { 10625ffd83dbSDimitry Andric // There are other users -> create a new phi 1063fe6060f1SDimitry Andric NewPhi = PHINode::Create(Phi->getType(), 2, "NewPhi", Phi); 10645ffd83dbSDimitry Andric // Copy the incoming values of the old phi 10655ffd83dbSDimitry Andric NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1), 10665ffd83dbSDimitry Andric Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1)); 10675ffd83dbSDimitry Andric IncInstruction = BinaryOperator::Create( 10685ffd83dbSDimitry Andric Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi, 10695ffd83dbSDimitry Andric IncrementPerRound, "LoopIncrement", IncInstruction); 10705ffd83dbSDimitry Andric NewPhi->addIncoming(IncInstruction, 10715ffd83dbSDimitry Andric Phi->getIncomingBlock(IncrementingBlock)); 10725ffd83dbSDimitry Andric IncrementingBlock = 1; 10735ffd83dbSDimitry Andric } 10745ffd83dbSDimitry Andric 10755ffd83dbSDimitry Andric IRBuilder<> Builder(BB->getContext()); 10765ffd83dbSDimitry Andric Builder.SetInsertPoint(Phi); 10775ffd83dbSDimitry Andric Builder.SetCurrentDebugLocation(Offs->getDebugLoc()); 10785ffd83dbSDimitry Andric 10795ffd83dbSDimitry Andric switch (Offs->getOpcode()) { 10805ffd83dbSDimitry Andric case Instruction::Add: 1081349cc55cSDimitry Andric case Instruction::Or: 10825ffd83dbSDimitry Andric pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 
0 : 1); 10835ffd83dbSDimitry Andric break; 10845ffd83dbSDimitry Andric case Instruction::Mul: 1085349cc55cSDimitry Andric case Instruction::Shl: 1086349cc55cSDimitry Andric pushOutMulShl(Offs->getOpcode(), NewPhi, IncrementPerRound, 1087349cc55cSDimitry Andric OffsSecondOperand, IncrementingBlock, Builder); 10885ffd83dbSDimitry Andric break; 10895ffd83dbSDimitry Andric default: 10905ffd83dbSDimitry Andric return false; 10915ffd83dbSDimitry Andric } 1092fe6060f1SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: simplified loop variable " 1093fe6060f1SDimitry Andric << "add/mul\n"); 10945ffd83dbSDimitry Andric 10955ffd83dbSDimitry Andric // The instruction has now been "absorbed" into the phi value 10965ffd83dbSDimitry Andric Offs->replaceAllUsesWith(NewPhi); 10975ffd83dbSDimitry Andric if (Offs->hasNUses(0)) 10985ffd83dbSDimitry Andric Offs->eraseFromParent(); 10995ffd83dbSDimitry Andric // Clean up the old increment in case it's unused because we built a new 11005ffd83dbSDimitry Andric // one 11015ffd83dbSDimitry Andric if (IncInstruction->hasNUses(0)) 11025ffd83dbSDimitry Andric IncInstruction->eraseFromParent(); 11035ffd83dbSDimitry Andric 11045ffd83dbSDimitry Andric return true; 1105480093f4SDimitry Andric } 1106480093f4SDimitry Andric 1107*81ad6265SDimitry Andric static Value *CheckAndCreateOffsetAdd(Value *X, unsigned ScaleX, Value *Y, 1108*81ad6265SDimitry Andric unsigned ScaleY, IRBuilder<> &Builder) { 1109e8d8bef9SDimitry Andric // Splat the non-vector value to a vector of the given type - if the value is 1110e8d8bef9SDimitry Andric // a constant (and its value isn't too big), we can even use this opportunity 1111e8d8bef9SDimitry Andric // to scale it to the size of the vector elements 1112e8d8bef9SDimitry Andric auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) { 1113e8d8bef9SDimitry Andric ConstantInt *Const; 1114e8d8bef9SDimitry Andric if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) && 1115e8d8bef9SDimitry Andric 
VT->getElementType() != NonVectorVal->getType()) { 1116e8d8bef9SDimitry Andric unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits(); 1117e8d8bef9SDimitry Andric uint64_t N = Const->getZExtValue(); 1118e8d8bef9SDimitry Andric if (N < (unsigned)(1 << (TargetElemSize - 1))) { 1119e8d8bef9SDimitry Andric NonVectorVal = Builder.CreateVectorSplat( 1120e8d8bef9SDimitry Andric VT->getNumElements(), Builder.getIntN(TargetElemSize, N)); 1121e8d8bef9SDimitry Andric return; 1122e8d8bef9SDimitry Andric } 1123e8d8bef9SDimitry Andric } 1124e8d8bef9SDimitry Andric NonVectorVal = 1125e8d8bef9SDimitry Andric Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal); 1126e8d8bef9SDimitry Andric }; 1127e8d8bef9SDimitry Andric 1128e8d8bef9SDimitry Andric FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType()); 1129e8d8bef9SDimitry Andric FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType()); 1130e8d8bef9SDimitry Andric // If one of X, Y is not a vector, we have to splat it in order 1131e8d8bef9SDimitry Andric // to add the two of them. 
1132e8d8bef9SDimitry Andric if (XElType && !YElType) { 1133e8d8bef9SDimitry Andric FixSummands(XElType, Y); 1134e8d8bef9SDimitry Andric YElType = cast<FixedVectorType>(Y->getType()); 1135e8d8bef9SDimitry Andric } else if (YElType && !XElType) { 1136e8d8bef9SDimitry Andric FixSummands(YElType, X); 1137e8d8bef9SDimitry Andric XElType = cast<FixedVectorType>(X->getType()); 1138e8d8bef9SDimitry Andric } 1139e8d8bef9SDimitry Andric assert(XElType && YElType && "Unknown vector types"); 1140e8d8bef9SDimitry Andric // Check that the summands are of compatible types 1141e8d8bef9SDimitry Andric if (XElType != YElType) { 1142e8d8bef9SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n"); 1143e8d8bef9SDimitry Andric return nullptr; 1144e8d8bef9SDimitry Andric } 1145e8d8bef9SDimitry Andric 1146e8d8bef9SDimitry Andric if (XElType->getElementType()->getScalarSizeInBits() != 32) { 1147e8d8bef9SDimitry Andric // Check that by adding the vectors we do not accidentally 1148e8d8bef9SDimitry Andric // create an overflow 1149e8d8bef9SDimitry Andric Constant *ConstX = dyn_cast<Constant>(X); 1150e8d8bef9SDimitry Andric Constant *ConstY = dyn_cast<Constant>(Y); 1151e8d8bef9SDimitry Andric if (!ConstX || !ConstY) 1152e8d8bef9SDimitry Andric return nullptr; 1153e8d8bef9SDimitry Andric unsigned TargetElemSize = 128 / XElType->getNumElements(); 1154e8d8bef9SDimitry Andric for (unsigned i = 0; i < XElType->getNumElements(); i++) { 1155e8d8bef9SDimitry Andric ConstantInt *ConstXEl = 1156e8d8bef9SDimitry Andric dyn_cast<ConstantInt>(ConstX->getAggregateElement(i)); 1157e8d8bef9SDimitry Andric ConstantInt *ConstYEl = 1158e8d8bef9SDimitry Andric dyn_cast<ConstantInt>(ConstY->getAggregateElement(i)); 1159e8d8bef9SDimitry Andric if (!ConstXEl || !ConstYEl || 1160*81ad6265SDimitry Andric ConstXEl->getZExtValue() * ScaleX + 1161*81ad6265SDimitry Andric ConstYEl->getZExtValue() * ScaleY >= 1162e8d8bef9SDimitry Andric (unsigned)(1 << (TargetElemSize - 1))) 
1163e8d8bef9SDimitry Andric return nullptr; 1164e8d8bef9SDimitry Andric } 1165e8d8bef9SDimitry Andric } 1166e8d8bef9SDimitry Andric 1167*81ad6265SDimitry Andric Value *XScale = Builder.CreateVectorSplat( 1168*81ad6265SDimitry Andric XElType->getNumElements(), 1169*81ad6265SDimitry Andric Builder.getIntN(XElType->getScalarSizeInBits(), ScaleX)); 1170*81ad6265SDimitry Andric Value *YScale = Builder.CreateVectorSplat( 1171*81ad6265SDimitry Andric YElType->getNumElements(), 1172*81ad6265SDimitry Andric Builder.getIntN(YElType->getScalarSizeInBits(), ScaleY)); 1173*81ad6265SDimitry Andric Value *Add = Builder.CreateAdd(Builder.CreateMul(X, XScale), 1174*81ad6265SDimitry Andric Builder.CreateMul(Y, YScale)); 1175e8d8bef9SDimitry Andric 1176*81ad6265SDimitry Andric if (checkOffsetSize(Add, XElType->getNumElements())) 1177e8d8bef9SDimitry Andric return Add; 1178e8d8bef9SDimitry Andric else 1179e8d8bef9SDimitry Andric return nullptr; 1180e8d8bef9SDimitry Andric } 1181e8d8bef9SDimitry Andric 1182e8d8bef9SDimitry Andric Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP, 1183*81ad6265SDimitry Andric Value *&Offsets, unsigned &Scale, 1184e8d8bef9SDimitry Andric IRBuilder<> &Builder) { 1185e8d8bef9SDimitry Andric Value *GEPPtr = GEP->getPointerOperand(); 1186e8d8bef9SDimitry Andric Offsets = GEP->getOperand(1); 1187*81ad6265SDimitry Andric Scale = DL->getTypeAllocSize(GEP->getSourceElementType()); 1188e8d8bef9SDimitry Andric // We only merge geps with constant offsets, because only for those 1189e8d8bef9SDimitry Andric // we can make sure that we do not cause an overflow 1190*81ad6265SDimitry Andric if (GEP->getNumIndices() != 1 || !isa<Constant>(Offsets)) 1191e8d8bef9SDimitry Andric return nullptr; 1192*81ad6265SDimitry Andric if (GetElementPtrInst *BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr)) { 1193e8d8bef9SDimitry Andric // Merge the two geps into one 1194*81ad6265SDimitry Andric Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Scale, Builder); 
1195e8d8bef9SDimitry Andric if (!BaseBasePtr) 1196e8d8bef9SDimitry Andric return nullptr; 1197*81ad6265SDimitry Andric Offsets = CheckAndCreateOffsetAdd( 1198*81ad6265SDimitry Andric Offsets, Scale, GEP->getOperand(1), 1199*81ad6265SDimitry Andric DL->getTypeAllocSize(GEP->getSourceElementType()), Builder); 1200e8d8bef9SDimitry Andric if (Offsets == nullptr) 1201e8d8bef9SDimitry Andric return nullptr; 1202*81ad6265SDimitry Andric Scale = 1; // Scale is always an i8 at this point. 1203e8d8bef9SDimitry Andric return BaseBasePtr; 1204e8d8bef9SDimitry Andric } 1205e8d8bef9SDimitry Andric return GEPPtr; 1206e8d8bef9SDimitry Andric } 1207e8d8bef9SDimitry Andric 1208e8d8bef9SDimitry Andric bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB, 1209e8d8bef9SDimitry Andric LoopInfo *LI) { 1210e8d8bef9SDimitry Andric GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address); 1211e8d8bef9SDimitry Andric if (!GEP) 1212e8d8bef9SDimitry Andric return false; 1213e8d8bef9SDimitry Andric bool Changed = false; 1214349cc55cSDimitry Andric if (GEP->hasOneUse() && isa<GetElementPtrInst>(GEP->getPointerOperand())) { 1215e8d8bef9SDimitry Andric IRBuilder<> Builder(GEP->getContext()); 1216e8d8bef9SDimitry Andric Builder.SetInsertPoint(GEP); 1217e8d8bef9SDimitry Andric Builder.SetCurrentDebugLocation(GEP->getDebugLoc()); 1218e8d8bef9SDimitry Andric Value *Offsets; 1219*81ad6265SDimitry Andric unsigned Scale; 1220*81ad6265SDimitry Andric Value *Base = foldGEP(GEP, Offsets, Scale, Builder); 1221e8d8bef9SDimitry Andric // We only want to merge the geps if there is a real chance that they can be 1222e8d8bef9SDimitry Andric // used by an MVE gather; thus the offset has to have the correct size 1223e8d8bef9SDimitry Andric // (always i32 if it is not of vector type) and the base has to be a 1224e8d8bef9SDimitry Andric // pointer. 
1225e8d8bef9SDimitry Andric if (Offsets && Base && Base != GEP) { 1226*81ad6265SDimitry Andric assert(Scale == 1 && "Expected to fold GEP to a scale of 1"); 1227*81ad6265SDimitry Andric Type *BaseTy = Builder.getInt8PtrTy(); 1228*81ad6265SDimitry Andric if (auto *VecTy = dyn_cast<FixedVectorType>(Base->getType())) 1229*81ad6265SDimitry Andric BaseTy = FixedVectorType::get(BaseTy, VecTy); 1230e8d8bef9SDimitry Andric GetElementPtrInst *NewAddress = GetElementPtrInst::Create( 1231*81ad6265SDimitry Andric Builder.getInt8Ty(), Builder.CreateBitCast(Base, BaseTy), Offsets, 1232*81ad6265SDimitry Andric "gep.merged", GEP); 1233*81ad6265SDimitry Andric LLVM_DEBUG(dbgs() << "Folded GEP: " << *GEP 1234*81ad6265SDimitry Andric << "\n new : " << *NewAddress << "\n"); 1235*81ad6265SDimitry Andric GEP->replaceAllUsesWith( 1236*81ad6265SDimitry Andric Builder.CreateBitCast(NewAddress, GEP->getType())); 1237e8d8bef9SDimitry Andric GEP = NewAddress; 1238e8d8bef9SDimitry Andric Changed = true; 1239e8d8bef9SDimitry Andric } 1240e8d8bef9SDimitry Andric } 1241e8d8bef9SDimitry Andric Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI); 1242e8d8bef9SDimitry Andric return Changed; 1243e8d8bef9SDimitry Andric } 1244e8d8bef9SDimitry Andric 1245480093f4SDimitry Andric bool MVEGatherScatterLowering::runOnFunction(Function &F) { 1246480093f4SDimitry Andric if (!EnableMaskedGatherScatters) 1247480093f4SDimitry Andric return false; 1248480093f4SDimitry Andric auto &TPC = getAnalysis<TargetPassConfig>(); 1249480093f4SDimitry Andric auto &TM = TPC.getTM<TargetMachine>(); 1250480093f4SDimitry Andric auto *ST = &TM.getSubtarget<ARMSubtarget>(F); 1251480093f4SDimitry Andric if (!ST->hasMVEIntegerOps()) 1252480093f4SDimitry Andric return false; 12535ffd83dbSDimitry Andric LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 1254349cc55cSDimitry Andric DL = &F.getParent()->getDataLayout(); 1255480093f4SDimitry Andric SmallVector<IntrinsicInst *, 4> Gathers; 12565ffd83dbSDimitry 
Andric SmallVector<IntrinsicInst *, 4> Scatters; 12575ffd83dbSDimitry Andric 12585ffd83dbSDimitry Andric bool Changed = false; 12595ffd83dbSDimitry Andric 1260480093f4SDimitry Andric for (BasicBlock &BB : F) { 12614652422eSDimitry Andric Changed |= SimplifyInstructionsInBlock(&BB); 12624652422eSDimitry Andric 1263480093f4SDimitry Andric for (Instruction &I : BB) { 1264480093f4SDimitry Andric IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I); 1265e8d8bef9SDimitry Andric if (II && II->getIntrinsicID() == Intrinsic::masked_gather && 1266e8d8bef9SDimitry Andric isa<FixedVectorType>(II->getType())) { 1267480093f4SDimitry Andric Gathers.push_back(II); 1268e8d8bef9SDimitry Andric Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI); 1269e8d8bef9SDimitry Andric } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter && 1270e8d8bef9SDimitry Andric isa<FixedVectorType>(II->getArgOperand(0)->getType())) { 12715ffd83dbSDimitry Andric Scatters.push_back(II); 1272e8d8bef9SDimitry Andric Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI); 12735ffd83dbSDimitry Andric } 1274480093f4SDimitry Andric } 1275480093f4SDimitry Andric } 12765ffd83dbSDimitry Andric for (unsigned i = 0; i < Gathers.size(); i++) { 12775ffd83dbSDimitry Andric IntrinsicInst *I = Gathers[i]; 1278fe6060f1SDimitry Andric Instruction *L = lowerGather(I); 12795ffd83dbSDimitry Andric if (L == nullptr) 12805ffd83dbSDimitry Andric continue; 1281480093f4SDimitry Andric 12825ffd83dbSDimitry Andric // Get rid of any now dead instructions 1283fe6060f1SDimitry Andric SimplifyInstructionsInBlock(L->getParent()); 12845ffd83dbSDimitry Andric Changed = true; 12855ffd83dbSDimitry Andric } 1286480093f4SDimitry Andric 12875ffd83dbSDimitry Andric for (unsigned i = 0; i < Scatters.size(); i++) { 12885ffd83dbSDimitry Andric IntrinsicInst *I = Scatters[i]; 1289fe6060f1SDimitry Andric Instruction *S = lowerScatter(I); 12905ffd83dbSDimitry Andric if (S == nullptr) 12915ffd83dbSDimitry Andric 
continue; 12925ffd83dbSDimitry Andric 12935ffd83dbSDimitry Andric // Get rid of any now dead instructions 1294fe6060f1SDimitry Andric SimplifyInstructionsInBlock(S->getParent()); 12955ffd83dbSDimitry Andric Changed = true; 12965ffd83dbSDimitry Andric } 12975ffd83dbSDimitry Andric return Changed; 1298480093f4SDimitry Andric } 1299