1480093f4SDimitry Andric //===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===// 2480093f4SDimitry Andric // 3480093f4SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4480093f4SDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5480093f4SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6480093f4SDimitry Andric // 7480093f4SDimitry Andric //===----------------------------------------------------------------------===// 8480093f4SDimitry Andric // 9480093f4SDimitry Andric /// This pass custom lowers llvm.gather and llvm.scatter instructions to 10480093f4SDimitry Andric /// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to 11480093f4SDimitry Andric /// produce a better final result as we go. 12480093f4SDimitry Andric // 13480093f4SDimitry Andric //===----------------------------------------------------------------------===// 14480093f4SDimitry Andric 15480093f4SDimitry Andric #include "ARM.h" 16480093f4SDimitry Andric #include "ARMBaseInstrInfo.h" 17480093f4SDimitry Andric #include "ARMSubtarget.h" 18*5ffd83dbSDimitry Andric #include "llvm/Analysis/LoopInfo.h" 19480093f4SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h" 20480093f4SDimitry Andric #include "llvm/CodeGen/TargetLowering.h" 21480093f4SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h" 22480093f4SDimitry Andric #include "llvm/CodeGen/TargetSubtargetInfo.h" 23480093f4SDimitry Andric #include "llvm/InitializePasses.h" 24480093f4SDimitry Andric #include "llvm/IR/BasicBlock.h" 25480093f4SDimitry Andric #include "llvm/IR/Constant.h" 26480093f4SDimitry Andric #include "llvm/IR/Constants.h" 27480093f4SDimitry Andric #include "llvm/IR/DerivedTypes.h" 28480093f4SDimitry Andric #include "llvm/IR/Function.h" 29480093f4SDimitry Andric #include "llvm/IR/InstrTypes.h" 30480093f4SDimitry Andric #include "llvm/IR/Instruction.h" 31480093f4SDimitry Andric #include "llvm/IR/Instructions.h" 32480093f4SDimitry Andric #include "llvm/IR/IntrinsicInst.h" 33480093f4SDimitry Andric #include "llvm/IR/Intrinsics.h" 34480093f4SDimitry Andric #include "llvm/IR/IntrinsicsARM.h" 35480093f4SDimitry Andric #include "llvm/IR/IRBuilder.h" 36480093f4SDimitry Andric #include "llvm/IR/PatternMatch.h" 37480093f4SDimitry Andric #include "llvm/IR/Type.h" 38480093f4SDimitry Andric #include "llvm/IR/Value.h" 39480093f4SDimitry Andric #include "llvm/Pass.h" 40480093f4SDimitry Andric #include "llvm/Support/Casting.h" 41*5ffd83dbSDimitry Andric #include "llvm/Transforms/Utils/Local.h" 42480093f4SDimitry Andric #include <algorithm> 43480093f4SDimitry Andric #include <cassert> 44480093f4SDimitry Andric 45480093f4SDimitry Andric using namespace llvm; 46480093f4SDimitry Andric 47480093f4SDimitry Andric #define DEBUG_TYPE "mve-gather-scatter-lowering" 48480093f4SDimitry Andric 49480093f4SDimitry Andric cl::opt<bool> EnableMaskedGatherScatters( 50480093f4SDimitry Andric "enable-arm-maskedgatscat", cl::Hidden, cl::init(false), 51480093f4SDimitry Andric cl::desc("Enable the generation of masked gathers and scatters")); 52480093f4SDimitry Andric 53480093f4SDimitry Andric namespace { 54480093f4SDimitry Andric 55480093f4SDimitry Andric class MVEGatherScatterLowering : public FunctionPass { 56480093f4SDimitry Andric public: 57480093f4SDimitry Andric static char ID; // Pass identification, replacement for typeid 58480093f4SDimitry Andric 59480093f4SDimitry Andric explicit MVEGatherScatterLowering() : FunctionPass(ID) { 60480093f4SDimitry Andric initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry()); 61480093f4SDimitry Andric } 62480093f4SDimitry Andric 63480093f4SDimitry Andric bool runOnFunction(Function &F) override; 64480093f4SDimitry Andric 65480093f4SDimitry Andric StringRef getPassName() const override { 66480093f4SDimitry Andric return "MVE gather/scatter lowering"; 67480093f4SDimitry Andric } 68480093f4SDimitry Andric 69480093f4SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 70480093f4SDimitry Andric AU.setPreservesCFG(); 71480093f4SDimitry Andric AU.addRequired<TargetPassConfig>(); 72*5ffd83dbSDimitry Andric AU.addRequired<LoopInfoWrapperPass>(); 73480093f4SDimitry Andric FunctionPass::getAnalysisUsage(AU); 74480093f4SDimitry Andric } 75480093f4SDimitry Andric 76480093f4SDimitry Andric private: 77*5ffd83dbSDimitry Andric LoopInfo *LI = nullptr; 78*5ffd83dbSDimitry Andric 79480093f4SDimitry Andric // Check this is a valid gather with correct alignment 80480093f4SDimitry Andric bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize, 81*5ffd83dbSDimitry Andric Align Alignment); 82480093f4SDimitry Andric // Check whether Ptr is hidden behind a bitcast and look through it 83480093f4SDimitry Andric void lookThroughBitcast(Value *&Ptr); 84480093f4SDimitry Andric // Check for a getelementptr and deduce base and offsets from it, on success 85480093f4SDimitry Andric // returning the base directly and the offsets indirectly using the Offsets 86480093f4SDimitry Andric // argument 87*5ffd83dbSDimitry Andric Value *checkGEP(Value *&Offsets, Type *Ty, GetElementPtrInst *GEP, 88*5ffd83dbSDimitry Andric IRBuilder<> &Builder); 89*5ffd83dbSDimitry Andric // Compute the scale of this gather/scatter instruction 90*5ffd83dbSDimitry Andric int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize); 91*5ffd83dbSDimitry Andric // If the value is a constant, or derived from constants via additions 92*5ffd83dbSDimitry Andric // and multilications, return its numeric value 93*5ffd83dbSDimitry Andric Optional<int64_t> getIfConst(const Value *V); 94*5ffd83dbSDimitry Andric // If Inst is an add instruction, check whether one summand is a 95*5ffd83dbSDimitry Andric // constant. If so, scale this constant and return it together with 96*5ffd83dbSDimitry Andric // the other summand. 97*5ffd83dbSDimitry Andric std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale); 98480093f4SDimitry Andric 99*5ffd83dbSDimitry Andric Value *lowerGather(IntrinsicInst *I); 100480093f4SDimitry Andric // Create a gather from a base + vector of offsets 101480093f4SDimitry Andric Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr, 102*5ffd83dbSDimitry Andric Instruction *&Root, IRBuilder<> &Builder); 103480093f4SDimitry Andric // Create a gather from a vector of pointers 104480093f4SDimitry Andric Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr, 105*5ffd83dbSDimitry Andric IRBuilder<> &Builder, int64_t Increment = 0); 106*5ffd83dbSDimitry Andric // Create an incrementing gather from a vector of pointers 107*5ffd83dbSDimitry Andric Value *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr, 108*5ffd83dbSDimitry Andric IRBuilder<> &Builder, 109*5ffd83dbSDimitry Andric int64_t Increment = 0); 110*5ffd83dbSDimitry Andric 111*5ffd83dbSDimitry Andric Value *lowerScatter(IntrinsicInst *I); 112*5ffd83dbSDimitry Andric // Create a scatter to a base + vector of offsets 113*5ffd83dbSDimitry Andric Value *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets, 114*5ffd83dbSDimitry Andric IRBuilder<> &Builder); 115*5ffd83dbSDimitry Andric // Create a scatter to a vector of pointers 116*5ffd83dbSDimitry Andric Value *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr, 117*5ffd83dbSDimitry Andric IRBuilder<> &Builder, 118*5ffd83dbSDimitry Andric int64_t Increment = 0); 119*5ffd83dbSDimitry Andric // Create an incrementing scatter from a vector of pointers 120*5ffd83dbSDimitry Andric Value *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr, 121*5ffd83dbSDimitry Andric IRBuilder<> &Builder, 122*5ffd83dbSDimitry Andric int64_t Increment = 0); 123*5ffd83dbSDimitry Andric 124*5ffd83dbSDimitry Andric // QI gathers and scatters can increment their offsets on their own if 125*5ffd83dbSDimitry Andric // the increment is a constant value (digit) 126*5ffd83dbSDimitry Andric Value *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *BasePtr, 127*5ffd83dbSDimitry Andric Value *Ptr, GetElementPtrInst *GEP, 128*5ffd83dbSDimitry Andric IRBuilder<> &Builder); 129*5ffd83dbSDimitry Andric // QI gathers/scatters can increment their offsets on their own if the 130*5ffd83dbSDimitry Andric // increment is a constant value (digit) - this creates a writeback QI 131*5ffd83dbSDimitry Andric // gather/scatter 132*5ffd83dbSDimitry Andric Value *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr, 133*5ffd83dbSDimitry Andric Value *Ptr, unsigned TypeScale, 134*5ffd83dbSDimitry Andric IRBuilder<> &Builder); 135*5ffd83dbSDimitry Andric // Check whether these offsets could be moved out of the loop they're in 136*5ffd83dbSDimitry Andric bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI); 137*5ffd83dbSDimitry Andric // Pushes the given add out of the loop 138*5ffd83dbSDimitry Andric void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex); 139*5ffd83dbSDimitry Andric // Pushes the given mul out of the loop 140*5ffd83dbSDimitry Andric void pushOutMul(PHINode *&Phi, Value *IncrementPerRound, 141*5ffd83dbSDimitry Andric Value *OffsSecondOperand, unsigned LoopIncrement, 142*5ffd83dbSDimitry Andric IRBuilder<> &Builder); 143480093f4SDimitry Andric }; 144480093f4SDimitry Andric 145480093f4SDimitry Andric } // end anonymous namespace 146480093f4SDimitry Andric 147480093f4SDimitry Andric char MVEGatherScatterLowering::ID = 0; 148480093f4SDimitry Andric 149480093f4SDimitry Andric INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE, 150480093f4SDimitry Andric "MVE gather/scattering lowering pass", false, false) 151480093f4SDimitry Andric 152480093f4SDimitry Andric Pass *llvm::createMVEGatherScatterLoweringPass() { 153480093f4SDimitry Andric return new MVEGatherScatterLowering(); 154480093f4SDimitry Andric } 155480093f4SDimitry Andric 156480093f4SDimitry Andric bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements, 157480093f4SDimitry Andric unsigned ElemSize, 158*5ffd83dbSDimitry Andric Align Alignment) { 159*5ffd83dbSDimitry Andric if (((NumElements == 4 && 160*5ffd83dbSDimitry Andric (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) || 161*5ffd83dbSDimitry Andric (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) || 162480093f4SDimitry Andric (NumElements == 16 && ElemSize == 8)) && 163*5ffd83dbSDimitry Andric Alignment >= ElemSize / 8) 164480093f4SDimitry Andric return true; 165*5ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have " 166*5ffd83dbSDimitry Andric << "valid alignment or vector type \n"); 167480093f4SDimitry Andric return false; 168480093f4SDimitry Andric } 169480093f4SDimitry Andric 170*5ffd83dbSDimitry Andric Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, Type *Ty, 171*5ffd83dbSDimitry Andric GetElementPtrInst *GEP, 172*5ffd83dbSDimitry Andric IRBuilder<> &Builder) { 173480093f4SDimitry Andric if (!GEP) { 174*5ffd83dbSDimitry Andric LLVM_DEBUG( 175*5ffd83dbSDimitry Andric dbgs() << "masked gathers/scatters: no getelementpointer found\n"); 176480093f4SDimitry Andric return nullptr; 177480093f4SDimitry Andric } 178*5ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found." 179*5ffd83dbSDimitry Andric << " Looking at intrinsic for base + vector of offsets\n"); 180480093f4SDimitry Andric Value *GEPPtr = GEP->getPointerOperand(); 181480093f4SDimitry Andric if (GEPPtr->getType()->isVectorTy()) { 182480093f4SDimitry Andric return nullptr; 183480093f4SDimitry Andric } 184480093f4SDimitry Andric if (GEP->getNumOperands() != 2) { 185*5ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many" 186480093f4SDimitry Andric << " operands. Expanding.\n"); 187480093f4SDimitry Andric return nullptr; 188480093f4SDimitry Andric } 189480093f4SDimitry Andric Offsets = GEP->getOperand(1); 190*5ffd83dbSDimitry Andric // Paranoid check whether the number of parallel lanes is the same 191*5ffd83dbSDimitry Andric assert(cast<FixedVectorType>(Ty)->getNumElements() == 192*5ffd83dbSDimitry Andric cast<FixedVectorType>(Offsets->getType())->getNumElements()); 193*5ffd83dbSDimitry Andric // Only <N x i32> offsets can be integrated into an arm gather, any smaller 194*5ffd83dbSDimitry Andric // type would have to be sign extended by the gep - and arm gathers can only 195*5ffd83dbSDimitry Andric // zero extend. Additionally, the offsets do have to originate from a zext of 196*5ffd83dbSDimitry Andric // a vector with element types smaller or equal the type of the gather we're 197*5ffd83dbSDimitry Andric // looking at 198*5ffd83dbSDimitry Andric if (Offsets->getType()->getScalarSizeInBits() != 32) 199*5ffd83dbSDimitry Andric return nullptr; 200480093f4SDimitry Andric if (ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets)) 201480093f4SDimitry Andric Offsets = ZextOffs->getOperand(0); 202*5ffd83dbSDimitry Andric else if (!(cast<FixedVectorType>(Offsets->getType())->getNumElements() == 4 && 203*5ffd83dbSDimitry Andric Offsets->getType()->getScalarSizeInBits() == 32)) 204480093f4SDimitry Andric return nullptr; 205*5ffd83dbSDimitry Andric 206*5ffd83dbSDimitry Andric if (Ty != Offsets->getType()) { 207*5ffd83dbSDimitry Andric if ((Ty->getScalarSizeInBits() < 208*5ffd83dbSDimitry Andric Offsets->getType()->getScalarSizeInBits())) { 209*5ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: no correct offset type." 210*5ffd83dbSDimitry Andric << " Can't create intrinsic.\n"); 211*5ffd83dbSDimitry Andric return nullptr; 212*5ffd83dbSDimitry Andric } else { 213*5ffd83dbSDimitry Andric Offsets = Builder.CreateZExt( 214*5ffd83dbSDimitry Andric Offsets, VectorType::getInteger(cast<VectorType>(Ty))); 215480093f4SDimitry Andric } 216480093f4SDimitry Andric } 217480093f4SDimitry Andric // If none of the checks failed, return the gep's base pointer 218*5ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n"); 219480093f4SDimitry Andric return GEPPtr; 220480093f4SDimitry Andric } 221480093f4SDimitry Andric 222480093f4SDimitry Andric void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) { 223480093f4SDimitry Andric // Look through bitcast instruction if #elements is the same 224480093f4SDimitry Andric if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) { 225*5ffd83dbSDimitry Andric auto *BCTy = cast<FixedVectorType>(BitCast->getType()); 226*5ffd83dbSDimitry Andric auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType()); 227*5ffd83dbSDimitry Andric if (BCTy->getNumElements() == BCSrcTy->getNumElements()) { 228*5ffd83dbSDimitry Andric LLVM_DEBUG( 229*5ffd83dbSDimitry Andric dbgs() << "masked gathers/scatters: looking through bitcast\n"); 230480093f4SDimitry Andric Ptr = BitCast->getOperand(0); 231480093f4SDimitry Andric } 232480093f4SDimitry Andric } 233480093f4SDimitry Andric } 234480093f4SDimitry Andric 235*5ffd83dbSDimitry Andric int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize, 236*5ffd83dbSDimitry Andric unsigned MemoryElemSize) { 237*5ffd83dbSDimitry Andric // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by 2, 238*5ffd83dbSDimitry Andric // or a 8bit, 16bit or 32bit load/store scaled by 1 239*5ffd83dbSDimitry Andric if (GEPElemSize == 32 && MemoryElemSize == 32) 240*5ffd83dbSDimitry Andric return 2; 241*5ffd83dbSDimitry Andric else if (GEPElemSize == 16 && MemoryElemSize == 16) 242*5ffd83dbSDimitry Andric return 1; 243*5ffd83dbSDimitry Andric else if (GEPElemSize == 8) 244*5ffd83dbSDimitry Andric return 0; 245*5ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't " 246*5ffd83dbSDimitry Andric << "create intrinsic\n"); 247*5ffd83dbSDimitry Andric return -1; 248*5ffd83dbSDimitry Andric } 249*5ffd83dbSDimitry Andric 250*5ffd83dbSDimitry Andric Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) { 251*5ffd83dbSDimitry Andric const Constant *C = dyn_cast<Constant>(V); 252*5ffd83dbSDimitry Andric if (C != nullptr) 253*5ffd83dbSDimitry Andric return Optional<int64_t>{C->getUniqueInteger().getSExtValue()}; 254*5ffd83dbSDimitry Andric if (!isa<Instruction>(V)) 255*5ffd83dbSDimitry Andric return Optional<int64_t>{}; 256*5ffd83dbSDimitry Andric 257*5ffd83dbSDimitry Andric const Instruction *I = cast<Instruction>(V); 258*5ffd83dbSDimitry Andric if (I->getOpcode() == Instruction::Add || 259*5ffd83dbSDimitry Andric I->getOpcode() == Instruction::Mul) { 260*5ffd83dbSDimitry Andric Optional<int64_t> Op0 = getIfConst(I->getOperand(0)); 261*5ffd83dbSDimitry Andric Optional<int64_t> Op1 = getIfConst(I->getOperand(1)); 262*5ffd83dbSDimitry Andric if (!Op0 || !Op1) 263*5ffd83dbSDimitry Andric return Optional<int64_t>{}; 264*5ffd83dbSDimitry Andric if (I->getOpcode() == Instruction::Add) 265*5ffd83dbSDimitry Andric return Optional<int64_t>{Op0.getValue() + Op1.getValue()}; 266*5ffd83dbSDimitry Andric if (I->getOpcode() == Instruction::Mul) 267*5ffd83dbSDimitry Andric return Optional<int64_t>{Op0.getValue() * Op1.getValue()}; 268*5ffd83dbSDimitry Andric } 269*5ffd83dbSDimitry Andric return Optional<int64_t>{}; 270*5ffd83dbSDimitry Andric } 271*5ffd83dbSDimitry Andric 272*5ffd83dbSDimitry Andric std::pair<Value *, int64_t> 273*5ffd83dbSDimitry Andric MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) { 274*5ffd83dbSDimitry Andric std::pair<Value *, int64_t> ReturnFalse = 275*5ffd83dbSDimitry Andric std::pair<Value *, int64_t>(nullptr, 0); 276*5ffd83dbSDimitry Andric // At this point, the instruction we're looking at must be an add or we 277*5ffd83dbSDimitry Andric // bail out 278*5ffd83dbSDimitry Andric Instruction *Add = dyn_cast<Instruction>(Inst); 279*5ffd83dbSDimitry Andric if (Add == nullptr || Add->getOpcode() != Instruction::Add) 280*5ffd83dbSDimitry Andric return ReturnFalse; 281*5ffd83dbSDimitry Andric 282*5ffd83dbSDimitry Andric Value *Summand; 283*5ffd83dbSDimitry Andric Optional<int64_t> Const; 284*5ffd83dbSDimitry Andric // Find out which operand the value that is increased is 285*5ffd83dbSDimitry Andric if ((Const = getIfConst(Add->getOperand(0)))) 286*5ffd83dbSDimitry Andric Summand = Add->getOperand(1); 287*5ffd83dbSDimitry Andric else if ((Const = getIfConst(Add->getOperand(1)))) 288*5ffd83dbSDimitry Andric Summand = Add->getOperand(0); 289*5ffd83dbSDimitry Andric else 290*5ffd83dbSDimitry Andric return ReturnFalse; 291*5ffd83dbSDimitry Andric 292*5ffd83dbSDimitry Andric // Check that the constant is small enough for an incrementing gather 293*5ffd83dbSDimitry Andric int64_t Immediate = Const.getValue() << TypeScale; 294*5ffd83dbSDimitry Andric if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0) 295*5ffd83dbSDimitry Andric return ReturnFalse; 296*5ffd83dbSDimitry Andric 297*5ffd83dbSDimitry Andric return std::pair<Value *, int64_t>(Summand, Immediate); 298*5ffd83dbSDimitry Andric } 299*5ffd83dbSDimitry Andric 300*5ffd83dbSDimitry Andric Value *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) { 301480093f4SDimitry Andric using namespace PatternMatch; 302480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"); 303480093f4SDimitry Andric 304480093f4SDimitry Andric // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0) 305480093f4SDimitry Andric // Attempt to turn the masked gather in I into a MVE intrinsic 306480093f4SDimitry Andric // Potentially optimising the addressing modes as we do so. 307*5ffd83dbSDimitry Andric auto *Ty = cast<FixedVectorType>(I->getType()); 308480093f4SDimitry Andric Value *Ptr = I->getArgOperand(0); 309*5ffd83dbSDimitry Andric Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue(); 310480093f4SDimitry Andric Value *Mask = I->getArgOperand(2); 311480093f4SDimitry Andric Value *PassThru = I->getArgOperand(3); 312480093f4SDimitry Andric 313*5ffd83dbSDimitry Andric if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(), 314*5ffd83dbSDimitry Andric Alignment)) 315*5ffd83dbSDimitry Andric return nullptr; 316480093f4SDimitry Andric lookThroughBitcast(Ptr); 317480093f4SDimitry Andric assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); 318480093f4SDimitry Andric 319480093f4SDimitry Andric IRBuilder<> Builder(I->getContext()); 320480093f4SDimitry Andric Builder.SetInsertPoint(I); 321480093f4SDimitry Andric Builder.SetCurrentDebugLocation(I->getDebugLoc()); 322*5ffd83dbSDimitry Andric 323*5ffd83dbSDimitry Andric Instruction *Root = I; 324*5ffd83dbSDimitry Andric Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder); 325480093f4SDimitry Andric if (!Load) 326480093f4SDimitry Andric Load = tryCreateMaskedGatherBase(I, Ptr, Builder); 327480093f4SDimitry Andric if (!Load) 328*5ffd83dbSDimitry Andric return nullptr; 329480093f4SDimitry Andric 330480093f4SDimitry Andric if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) { 331480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - " 332480093f4SDimitry Andric << "creating select\n"); 333480093f4SDimitry Andric Load = Builder.CreateSelect(Mask, Load, PassThru); 334480093f4SDimitry Andric } 335480093f4SDimitry Andric 336*5ffd83dbSDimitry Andric Root->replaceAllUsesWith(Load); 337*5ffd83dbSDimitry Andric Root->eraseFromParent(); 338*5ffd83dbSDimitry Andric if (Root != I) 339*5ffd83dbSDimitry Andric // If this was an extending gather, we need to get rid of the sext/zext 340*5ffd83dbSDimitry Andric // sext/zext as well as of the gather itself 341480093f4SDimitry Andric I->eraseFromParent(); 342*5ffd83dbSDimitry Andric 343*5ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"); 344*5ffd83dbSDimitry Andric return Load; 345480093f4SDimitry Andric } 346480093f4SDimitry Andric 347*5ffd83dbSDimitry Andric Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(IntrinsicInst *I, 348*5ffd83dbSDimitry Andric Value *Ptr, 349*5ffd83dbSDimitry Andric IRBuilder<> &Builder, 350*5ffd83dbSDimitry Andric int64_t Increment) { 351480093f4SDimitry Andric using namespace PatternMatch; 352*5ffd83dbSDimitry Andric auto *Ty = cast<FixedVectorType>(I->getType()); 353480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n"); 354*5ffd83dbSDimitry Andric if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) 355480093f4SDimitry Andric // Can't build an intrinsic for this 356480093f4SDimitry Andric return nullptr; 357480093f4SDimitry Andric Value *Mask = I->getArgOperand(2); 358480093f4SDimitry Andric if (match(Mask, m_One())) 359480093f4SDimitry Andric return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base, 360480093f4SDimitry Andric {Ty, Ptr->getType()}, 361*5ffd83dbSDimitry Andric {Ptr, Builder.getInt32(Increment)}); 362480093f4SDimitry Andric else 363480093f4SDimitry Andric return Builder.CreateIntrinsic( 364480093f4SDimitry Andric Intrinsic::arm_mve_vldr_gather_base_predicated, 365480093f4SDimitry Andric {Ty, Ptr->getType(), Mask->getType()}, 366*5ffd83dbSDimitry Andric {Ptr, Builder.getInt32(Increment), Mask}); 367*5ffd83dbSDimitry Andric } 368*5ffd83dbSDimitry Andric 369*5ffd83dbSDimitry Andric Value *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB( 370*5ffd83dbSDimitry Andric IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { 371*5ffd83dbSDimitry Andric using namespace PatternMatch; 372*5ffd83dbSDimitry Andric auto *Ty = cast<FixedVectorType>(I->getType()); 373*5ffd83dbSDimitry Andric LLVM_DEBUG( 374*5ffd83dbSDimitry Andric dbgs() 375*5ffd83dbSDimitry Andric << "masked gathers: loading from vector of pointers with writeback\n"); 376*5ffd83dbSDimitry Andric if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) 377*5ffd83dbSDimitry Andric // Can't build an intrinsic for this 378*5ffd83dbSDimitry Andric return nullptr; 379*5ffd83dbSDimitry Andric Value *Mask = I->getArgOperand(2); 380*5ffd83dbSDimitry Andric if (match(Mask, m_One())) 381*5ffd83dbSDimitry Andric return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb, 382*5ffd83dbSDimitry Andric {Ty, Ptr->getType()}, 383*5ffd83dbSDimitry Andric {Ptr, Builder.getInt32(Increment)}); 384*5ffd83dbSDimitry Andric else 385*5ffd83dbSDimitry Andric return Builder.CreateIntrinsic( 386*5ffd83dbSDimitry Andric Intrinsic::arm_mve_vldr_gather_base_wb_predicated, 387*5ffd83dbSDimitry Andric {Ty, Ptr->getType(), Mask->getType()}, 388*5ffd83dbSDimitry Andric {Ptr, Builder.getInt32(Increment), Mask}); 389480093f4SDimitry Andric } 390480093f4SDimitry Andric 391480093f4SDimitry Andric Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset( 392*5ffd83dbSDimitry Andric IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) { 393480093f4SDimitry Andric using namespace PatternMatch; 394*5ffd83dbSDimitry Andric 395*5ffd83dbSDimitry Andric Type *OriginalTy = I->getType(); 396*5ffd83dbSDimitry Andric Type *ResultTy = OriginalTy; 397*5ffd83dbSDimitry Andric 398*5ffd83dbSDimitry Andric unsigned Unsigned = 1; 399*5ffd83dbSDimitry Andric // The size of the gather was already checked in isLegalTypeAndAlignment; 400*5ffd83dbSDimitry Andric // if it was not a full vector width an appropriate extend should follow. 401*5ffd83dbSDimitry Andric auto *Extend = Root; 402*5ffd83dbSDimitry Andric if (OriginalTy->getPrimitiveSizeInBits() < 128) { 403*5ffd83dbSDimitry Andric // Only transform gathers with exactly one use 404*5ffd83dbSDimitry Andric if (!I->hasOneUse()) 405480093f4SDimitry Andric return nullptr; 406480093f4SDimitry Andric 407*5ffd83dbSDimitry Andric // The correct root to replace is not the CallInst itself, but the 408*5ffd83dbSDimitry Andric // instruction which extends it 409*5ffd83dbSDimitry Andric Extend = cast<Instruction>(*I->users().begin()); 410*5ffd83dbSDimitry Andric if (isa<SExtInst>(Extend)) { 411*5ffd83dbSDimitry Andric Unsigned = 0; 412*5ffd83dbSDimitry Andric } else if (!isa<ZExtInst>(Extend)) { 413*5ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: extend needed but not provided. " 414*5ffd83dbSDimitry Andric << "Expanding\n"); 415480093f4SDimitry Andric return nullptr; 416480093f4SDimitry Andric } 417*5ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: found an extending gather\n"); 418*5ffd83dbSDimitry Andric ResultTy = Extend->getType(); 419*5ffd83dbSDimitry Andric // The final size of the gather must be a full vector width 420*5ffd83dbSDimitry Andric if (ResultTy->getPrimitiveSizeInBits() != 128) { 421*5ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: extending from the wrong type. " 422*5ffd83dbSDimitry Andric << "Expanding\n"); 423*5ffd83dbSDimitry Andric return nullptr; 424*5ffd83dbSDimitry Andric } 425*5ffd83dbSDimitry Andric } 426*5ffd83dbSDimitry Andric 427*5ffd83dbSDimitry Andric GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); 428*5ffd83dbSDimitry Andric Value *Offsets; 429*5ffd83dbSDimitry Andric Value *BasePtr = checkGEP(Offsets, ResultTy, GEP, Builder); 430*5ffd83dbSDimitry Andric if (!BasePtr) 431*5ffd83dbSDimitry Andric return nullptr; 432*5ffd83dbSDimitry Andric // Check whether the offset is a constant increment that could be merged into 433*5ffd83dbSDimitry Andric // a QI gather 434*5ffd83dbSDimitry Andric Value *Load = tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder); 435*5ffd83dbSDimitry Andric if (Load) 436*5ffd83dbSDimitry Andric return Load; 437*5ffd83dbSDimitry Andric 438*5ffd83dbSDimitry Andric int Scale = computeScale( 439*5ffd83dbSDimitry Andric BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(), 440*5ffd83dbSDimitry Andric OriginalTy->getScalarSizeInBits()); 441*5ffd83dbSDimitry Andric if (Scale == -1) 442*5ffd83dbSDimitry Andric return nullptr; 443*5ffd83dbSDimitry Andric Root = Extend; 444480093f4SDimitry Andric 445480093f4SDimitry Andric Value *Mask = I->getArgOperand(2); 446480093f4SDimitry Andric if (!match(Mask, m_One())) 447480093f4SDimitry Andric return Builder.CreateIntrinsic( 448480093f4SDimitry Andric Intrinsic::arm_mve_vldr_gather_offset_predicated, 449*5ffd83dbSDimitry Andric {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()}, 450*5ffd83dbSDimitry Andric {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()), 451*5ffd83dbSDimitry Andric Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask}); 452480093f4SDimitry Andric else 453480093f4SDimitry Andric return Builder.CreateIntrinsic( 454480093f4SDimitry Andric Intrinsic::arm_mve_vldr_gather_offset, 455*5ffd83dbSDimitry Andric {ResultTy, BasePtr->getType(), Offsets->getType()}, 456*5ffd83dbSDimitry Andric {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()), 457*5ffd83dbSDimitry Andric Builder.getInt32(Scale), Builder.getInt32(Unsigned)}); 458*5ffd83dbSDimitry Andric } 459*5ffd83dbSDimitry Andric 460*5ffd83dbSDimitry Andric Value *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) { 461*5ffd83dbSDimitry Andric using namespace PatternMatch; 462*5ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"); 463*5ffd83dbSDimitry Andric 464*5ffd83dbSDimitry Andric // @llvm.masked.scatter.*(data, ptrs, alignment, mask) 465*5ffd83dbSDimitry Andric // Attempt to turn the masked scatter in I into a MVE intrinsic 466*5ffd83dbSDimitry Andric // Potentially optimising the addressing modes as we do so. 467*5ffd83dbSDimitry Andric Value *Input = I->getArgOperand(0); 468*5ffd83dbSDimitry Andric Value *Ptr = I->getArgOperand(1); 469*5ffd83dbSDimitry Andric Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue(); 470*5ffd83dbSDimitry Andric auto *Ty = cast<FixedVectorType>(Input->getType()); 471*5ffd83dbSDimitry Andric 472*5ffd83dbSDimitry Andric if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(), 473*5ffd83dbSDimitry Andric Alignment)) 474*5ffd83dbSDimitry Andric return nullptr; 475*5ffd83dbSDimitry Andric 476*5ffd83dbSDimitry Andric lookThroughBitcast(Ptr); 477*5ffd83dbSDimitry Andric assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); 478*5ffd83dbSDimitry Andric 479*5ffd83dbSDimitry Andric IRBuilder<> Builder(I->getContext()); 480*5ffd83dbSDimitry Andric Builder.SetInsertPoint(I); 481*5ffd83dbSDimitry Andric Builder.SetCurrentDebugLocation(I->getDebugLoc()); 482*5ffd83dbSDimitry Andric 483*5ffd83dbSDimitry Andric Value *Store = tryCreateMaskedScatterOffset(I, Ptr, Builder); 484*5ffd83dbSDimitry Andric if (!Store) 485*5ffd83dbSDimitry Andric Store = tryCreateMaskedScatterBase(I, Ptr, Builder); 486*5ffd83dbSDimitry Andric if (!Store) 487*5ffd83dbSDimitry Andric return nullptr; 488*5ffd83dbSDimitry Andric 489*5ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n"); 490*5ffd83dbSDimitry Andric I->eraseFromParent(); 491*5ffd83dbSDimitry Andric return Store; 492*5ffd83dbSDimitry Andric } 493*5ffd83dbSDimitry Andric 494*5ffd83dbSDimitry Andric Value *MVEGatherScatterLowering::tryCreateMaskedScatterBase( 495*5ffd83dbSDimitry Andric IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { 496*5ffd83dbSDimitry Andric using namespace PatternMatch; 497*5ffd83dbSDimitry Andric Value *Input = I->getArgOperand(0); 498*5ffd83dbSDimitry Andric auto *Ty = cast<FixedVectorType>(Input->getType()); 499*5ffd83dbSDimitry Andric // Only QR variants allow truncating 500*5ffd83dbSDimitry Andric if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) { 501*5ffd83dbSDimitry Andric // Can't build an intrinsic for this 502*5ffd83dbSDimitry Andric return nullptr; 503*5ffd83dbSDimitry Andric } 504*5ffd83dbSDimitry Andric Value *Mask = I->getArgOperand(3); 505*5ffd83dbSDimitry Andric // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask) 506*5ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n"); 507*5ffd83dbSDimitry Andric if (match(Mask, m_One())) 508*5ffd83dbSDimitry Andric return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base, 509*5ffd83dbSDimitry Andric {Ptr->getType(), Input->getType()}, 510*5ffd83dbSDimitry Andric {Ptr, Builder.getInt32(Increment), Input}); 511*5ffd83dbSDimitry Andric else 512*5ffd83dbSDimitry Andric return Builder.CreateIntrinsic( 513*5ffd83dbSDimitry Andric Intrinsic::arm_mve_vstr_scatter_base_predicated, 514*5ffd83dbSDimitry Andric {Ptr->getType(), Input->getType(), Mask->getType()}, 515*5ffd83dbSDimitry Andric {Ptr, Builder.getInt32(Increment), Input, Mask}); 516*5ffd83dbSDimitry Andric } 517*5ffd83dbSDimitry Andric 518*5ffd83dbSDimitry Andric Value *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB( 519*5ffd83dbSDimitry Andric IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { 520*5ffd83dbSDimitry Andric using namespace PatternMatch; 521*5ffd83dbSDimitry Andric Value *Input = I->getArgOperand(0); 522*5ffd83dbSDimitry Andric auto *Ty = cast<FixedVectorType>(Input->getType()); 523*5ffd83dbSDimitry Andric LLVM_DEBUG( 524*5ffd83dbSDimitry Andric dbgs() 525*5ffd83dbSDimitry Andric << "masked scatters: storing to a vector of pointers with writeback\n"); 526*5ffd83dbSDimitry Andric if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) 527*5ffd83dbSDimitry Andric // Can't build an intrinsic for this 528*5ffd83dbSDimitry Andric return nullptr; 529*5ffd83dbSDimitry Andric Value *Mask = I->getArgOperand(3); 530*5ffd83dbSDimitry Andric if (match(Mask, m_One())) 531*5ffd83dbSDimitry Andric return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb, 532*5ffd83dbSDimitry Andric {Ptr->getType(), Input->getType()}, 533*5ffd83dbSDimitry Andric {Ptr, Builder.getInt32(Increment), Input}); 534*5ffd83dbSDimitry Andric else 535*5ffd83dbSDimitry Andric return Builder.CreateIntrinsic( 536*5ffd83dbSDimitry Andric Intrinsic::arm_mve_vstr_scatter_base_wb_predicated, 537*5ffd83dbSDimitry Andric {Ptr->getType(), Input->getType(), Mask->getType()}, 538*5ffd83dbSDimitry Andric {Ptr, Builder.getInt32(Increment), Input, Mask}); 539*5ffd83dbSDimitry Andric } 540*5ffd83dbSDimitry Andric 541*5ffd83dbSDimitry Andric Value *MVEGatherScatterLowering::tryCreateMaskedScatterOffset( 542*5ffd83dbSDimitry Andric IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) { 543*5ffd83dbSDimitry Andric using namespace PatternMatch; 544*5ffd83dbSDimitry Andric Value *Input = I->getArgOperand(0); 545*5ffd83dbSDimitry Andric Value *Mask = I->getArgOperand(3); 546*5ffd83dbSDimitry Andric Type *InputTy = Input->getType(); 547*5ffd83dbSDimitry Andric Type *MemoryTy = InputTy; 548*5ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing" 549*5ffd83dbSDimitry Andric << " to base + vector of offsets\n"); 550*5ffd83dbSDimitry Andric // If the input has been truncated, try to integrate that trunc into the 551*5ffd83dbSDimitry Andric // scatter instruction (we don't care about alignment here) 552*5ffd83dbSDimitry Andric if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) { 553*5ffd83dbSDimitry Andric Value *PreTrunc = Trunc->getOperand(0); 554*5ffd83dbSDimitry Andric Type *PreTruncTy = PreTrunc->getType(); 555*5ffd83dbSDimitry Andric if (PreTruncTy->getPrimitiveSizeInBits() == 128) { 556*5ffd83dbSDimitry Andric Input = PreTrunc; 557*5ffd83dbSDimitry Andric InputTy = PreTruncTy; 558*5ffd83dbSDimitry Andric } 559*5ffd83dbSDimitry Andric } 560*5ffd83dbSDimitry Andric if (InputTy->getPrimitiveSizeInBits() != 128) { 561*5ffd83dbSDimitry Andric LLVM_DEBUG( 562*5ffd83dbSDimitry Andric dbgs() << "masked scatters: cannot create scatters for non-standard" 563*5ffd83dbSDimitry Andric << " input types. Expanding.\n"); 564*5ffd83dbSDimitry Andric return nullptr; 565*5ffd83dbSDimitry Andric } 566*5ffd83dbSDimitry Andric 567*5ffd83dbSDimitry Andric GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); 568*5ffd83dbSDimitry Andric Value *Offsets; 569*5ffd83dbSDimitry Andric Value *BasePtr = checkGEP(Offsets, InputTy, GEP, Builder); 570*5ffd83dbSDimitry Andric if (!BasePtr) 571*5ffd83dbSDimitry Andric return nullptr; 572*5ffd83dbSDimitry Andric // Check whether the offset is a constant increment that could be merged into 573*5ffd83dbSDimitry Andric // a QI gather 574*5ffd83dbSDimitry Andric Value *Store = 575*5ffd83dbSDimitry Andric tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder); 576*5ffd83dbSDimitry Andric if (Store) 577*5ffd83dbSDimitry Andric return Store; 578*5ffd83dbSDimitry Andric int Scale = computeScale( 579*5ffd83dbSDimitry Andric BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(), 580*5ffd83dbSDimitry Andric MemoryTy->getScalarSizeInBits()); 581*5ffd83dbSDimitry Andric if (Scale == -1) 582*5ffd83dbSDimitry Andric return nullptr; 583*5ffd83dbSDimitry Andric 584*5ffd83dbSDimitry Andric if (!match(Mask, m_One())) 585*5ffd83dbSDimitry Andric return Builder.CreateIntrinsic( 586*5ffd83dbSDimitry Andric Intrinsic::arm_mve_vstr_scatter_offset_predicated, 587*5ffd83dbSDimitry Andric {BasePtr->getType(), Offsets->getType(), Input->getType(), 588*5ffd83dbSDimitry Andric Mask->getType()}, 589*5ffd83dbSDimitry Andric {BasePtr, Offsets, Input, 590*5ffd83dbSDimitry Andric Builder.getInt32(MemoryTy->getScalarSizeInBits()), 591*5ffd83dbSDimitry Andric Builder.getInt32(Scale), Mask}); 592*5ffd83dbSDimitry Andric else 593*5ffd83dbSDimitry Andric return Builder.CreateIntrinsic( 594*5ffd83dbSDimitry Andric Intrinsic::arm_mve_vstr_scatter_offset, 595*5ffd83dbSDimitry Andric {BasePtr->getType(), Offsets->getType(), Input->getType()}, 596*5ffd83dbSDimitry Andric {BasePtr, Offsets, Input, 597*5ffd83dbSDimitry Andric Builder.getInt32(MemoryTy->getScalarSizeInBits()), 598*5ffd83dbSDimitry Andric Builder.getInt32(Scale)}); 599*5ffd83dbSDimitry Andric } 600*5ffd83dbSDimitry Andric 601*5ffd83dbSDimitry Andric Value *MVEGatherScatterLowering::tryCreateIncrementingGatScat( 602*5ffd83dbSDimitry Andric IntrinsicInst *I, Value *BasePtr, Value *Offsets, GetElementPtrInst *GEP, 603*5ffd83dbSDimitry Andric IRBuilder<> &Builder) { 604*5ffd83dbSDimitry Andric FixedVectorType *Ty; 605*5ffd83dbSDimitry Andric if (I->getIntrinsicID() == Intrinsic::masked_gather) 606*5ffd83dbSDimitry Andric Ty = cast<FixedVectorType>(I->getType()); 607*5ffd83dbSDimitry Andric else 608*5ffd83dbSDimitry Andric Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType()); 609*5ffd83dbSDimitry Andric // Incrementing gathers only exist for v4i32 610*5ffd83dbSDimitry Andric if (Ty->getNumElements() != 4 || 611*5ffd83dbSDimitry Andric Ty->getScalarSizeInBits() != 32) 612*5ffd83dbSDimitry Andric return nullptr; 613*5ffd83dbSDimitry Andric Loop *L = LI->getLoopFor(I->getParent()); 614*5ffd83dbSDimitry Andric if (L == nullptr) 615*5ffd83dbSDimitry Andric // Incrementing gathers are not beneficial outside of a loop 616*5ffd83dbSDimitry Andric return nullptr; 617*5ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing " 618*5ffd83dbSDimitry Andric "wb gather/scatter\n"); 619*5ffd83dbSDimitry Andric 620*5ffd83dbSDimitry Andric // The gep was in charge of making sure the offsets are scaled correctly 621*5ffd83dbSDimitry Andric // - calculate that factor so it can be applied by hand 622*5ffd83dbSDimitry Andric DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout(); 623*5ffd83dbSDimitry Andric int TypeScale = 624*5ffd83dbSDimitry Andric computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()), 625*5ffd83dbSDimitry Andric DT.getTypeSizeInBits(GEP->getType()) / 626*5ffd83dbSDimitry Andric cast<FixedVectorType>(GEP->getType())->getNumElements()); 627*5ffd83dbSDimitry Andric if (TypeScale == -1) 628*5ffd83dbSDimitry Andric return nullptr; 629*5ffd83dbSDimitry Andric 630*5ffd83dbSDimitry Andric if (GEP->hasOneUse()) { 631*5ffd83dbSDimitry Andric // Only in this case do we want to build a wb gather, because the wb will 632*5ffd83dbSDimitry Andric // change the phi which does affect other users of the gep (which will still 633*5ffd83dbSDimitry Andric // be using the phi in the old way) 634*5ffd83dbSDimitry Andric Value *Load = 635*5ffd83dbSDimitry Andric tryCreateIncrementingWBGatScat(I, BasePtr, Offsets, TypeScale, Builder); 636*5ffd83dbSDimitry Andric if (Load != nullptr) 637*5ffd83dbSDimitry Andric return Load; 638*5ffd83dbSDimitry Andric } 639*5ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing " 640*5ffd83dbSDimitry Andric "non-wb gather/scatter\n"); 641*5ffd83dbSDimitry Andric 642*5ffd83dbSDimitry Andric std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale); 643*5ffd83dbSDimitry Andric if (Add.first == nullptr) 644*5ffd83dbSDimitry Andric return nullptr; 645*5ffd83dbSDimitry Andric Value *OffsetsIncoming = Add.first; 646*5ffd83dbSDimitry Andric int64_t Immediate = Add.second; 647*5ffd83dbSDimitry Andric 648*5ffd83dbSDimitry Andric // Make sure the offsets are scaled correctly 649*5ffd83dbSDimitry Andric Instruction *ScaledOffsets = BinaryOperator::Create( 650*5ffd83dbSDimitry Andric Instruction::Shl, OffsetsIncoming, 651*5ffd83dbSDimitry Andric Builder.CreateVectorSplat(Ty->getNumElements(), Builder.getInt32(TypeScale)), 652*5ffd83dbSDimitry Andric "ScaledIndex", I); 653*5ffd83dbSDimitry Andric // Add the base to the offsets 654*5ffd83dbSDimitry Andric OffsetsIncoming = BinaryOperator::Create( 655*5ffd83dbSDimitry Andric Instruction::Add, ScaledOffsets, 656*5ffd83dbSDimitry Andric Builder.CreateVectorSplat( 657*5ffd83dbSDimitry Andric Ty->getNumElements(), 658*5ffd83dbSDimitry Andric Builder.CreatePtrToInt( 659*5ffd83dbSDimitry Andric BasePtr, 660*5ffd83dbSDimitry Andric cast<VectorType>(ScaledOffsets->getType())->getElementType())), 661*5ffd83dbSDimitry Andric "StartIndex", I); 662*5ffd83dbSDimitry Andric 663*5ffd83dbSDimitry Andric if (I->getIntrinsicID() == Intrinsic::masked_gather) 664*5ffd83dbSDimitry Andric return cast<IntrinsicInst>( 665*5ffd83dbSDimitry Andric tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate)); 666*5ffd83dbSDimitry Andric else 667*5ffd83dbSDimitry Andric return cast<IntrinsicInst>( 668*5ffd83dbSDimitry Andric tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate)); 669*5ffd83dbSDimitry Andric } 670*5ffd83dbSDimitry Andric 671*5ffd83dbSDimitry Andric Value *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat( 672*5ffd83dbSDimitry Andric IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale, 673*5ffd83dbSDimitry Andric IRBuilder<> &Builder) { 674*5ffd83dbSDimitry Andric // Check whether this gather's offset is incremented by a constant - if so, 675*5ffd83dbSDimitry Andric // and the load is of the right type, we can merge this into a QI gather 676*5ffd83dbSDimitry Andric Loop *L = LI->getLoopFor(I->getParent()); 677*5ffd83dbSDimitry Andric // Offsets that are worth merging into this instruction will be incremented 678*5ffd83dbSDimitry Andric // by a constant, thus we're looking for an add of a phi and a constant 679*5ffd83dbSDimitry Andric PHINode *Phi = dyn_cast<PHINode>(Offsets); 680*5ffd83dbSDimitry Andric if (Phi == nullptr || Phi->getNumIncomingValues() != 2 || 681*5ffd83dbSDimitry Andric Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2) 682*5ffd83dbSDimitry Andric // No phi means no IV to write back to; if there is a phi, we expect it 683*5ffd83dbSDimitry Andric // to have exactly two incoming values; the only phis we are interested in 684*5ffd83dbSDimitry Andric // will be loop IV's and have exactly two uses, one in their increment and 685*5ffd83dbSDimitry Andric // one in the gather's gep 686*5ffd83dbSDimitry Andric return nullptr; 687*5ffd83dbSDimitry Andric 688*5ffd83dbSDimitry Andric unsigned IncrementIndex = 689*5ffd83dbSDimitry Andric Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1; 690*5ffd83dbSDimitry Andric // Look through the phi to the phi increment 691*5ffd83dbSDimitry Andric Offsets = Phi->getIncomingValue(IncrementIndex); 692*5ffd83dbSDimitry Andric 693*5ffd83dbSDimitry Andric std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale); 694*5ffd83dbSDimitry Andric if (Add.first == nullptr) 695*5ffd83dbSDimitry Andric return nullptr; 696*5ffd83dbSDimitry Andric Value *OffsetsIncoming = Add.first; 697*5ffd83dbSDimitry Andric int64_t Immediate = Add.second; 698*5ffd83dbSDimitry Andric if (OffsetsIncoming != Phi) 699*5ffd83dbSDimitry Andric // Then the increment we are looking at is not an increment of the 700*5ffd83dbSDimitry Andric // induction variable, and we don't want to do a writeback 701*5ffd83dbSDimitry Andric return nullptr; 702*5ffd83dbSDimitry Andric 703*5ffd83dbSDimitry Andric Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back()); 704*5ffd83dbSDimitry Andric unsigned NumElems = 705*5ffd83dbSDimitry Andric cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements(); 706*5ffd83dbSDimitry Andric 707*5ffd83dbSDimitry Andric // Make sure the offsets are scaled correctly 708*5ffd83dbSDimitry Andric Instruction *ScaledOffsets = BinaryOperator::Create( 709*5ffd83dbSDimitry Andric Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex), 710*5ffd83dbSDimitry Andric Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)), 711*5ffd83dbSDimitry Andric "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back()); 712*5ffd83dbSDimitry Andric // Add the base to the offsets 713*5ffd83dbSDimitry Andric OffsetsIncoming = BinaryOperator::Create( 714*5ffd83dbSDimitry Andric Instruction::Add, ScaledOffsets, 715*5ffd83dbSDimitry Andric Builder.CreateVectorSplat( 716*5ffd83dbSDimitry Andric NumElems, 717*5ffd83dbSDimitry Andric Builder.CreatePtrToInt( 718*5ffd83dbSDimitry Andric BasePtr, 719*5ffd83dbSDimitry Andric cast<VectorType>(ScaledOffsets->getType())->getElementType())), 720*5ffd83dbSDimitry Andric "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back()); 721*5ffd83dbSDimitry Andric // The gather is pre-incrementing 722*5ffd83dbSDimitry Andric OffsetsIncoming = BinaryOperator::Create( 723*5ffd83dbSDimitry Andric Instruction::Sub, OffsetsIncoming, 724*5ffd83dbSDimitry Andric Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)), 725*5ffd83dbSDimitry Andric "PreIncrementStartIndex", 726*5ffd83dbSDimitry Andric &Phi->getIncomingBlock(1 - IncrementIndex)->back()); 727*5ffd83dbSDimitry Andric Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming); 728*5ffd83dbSDimitry Andric 729*5ffd83dbSDimitry Andric Builder.SetInsertPoint(I); 730*5ffd83dbSDimitry Andric 731*5ffd83dbSDimitry Andric Value *EndResult; 732*5ffd83dbSDimitry Andric Value *NewInduction; 733*5ffd83dbSDimitry Andric if (I->getIntrinsicID() == Intrinsic::masked_gather) { 734*5ffd83dbSDimitry Andric // Build the incrementing gather 735*5ffd83dbSDimitry Andric Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate); 736*5ffd83dbSDimitry Andric // One value to be handed to whoever uses the gather, one is the loop 737*5ffd83dbSDimitry Andric // increment 738*5ffd83dbSDimitry Andric EndResult = Builder.CreateExtractValue(Load, 0, "Gather"); 739*5ffd83dbSDimitry Andric NewInduction = Builder.CreateExtractValue(Load, 1, "GatherIncrement"); 740*5ffd83dbSDimitry Andric } else { 741*5ffd83dbSDimitry Andric // Build the incrementing scatter 742*5ffd83dbSDimitry Andric NewInduction = tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate); 743*5ffd83dbSDimitry Andric EndResult = NewInduction; 744*5ffd83dbSDimitry Andric } 745*5ffd83dbSDimitry Andric Instruction *AddInst = cast<Instruction>(Offsets); 746*5ffd83dbSDimitry Andric AddInst->replaceAllUsesWith(NewInduction); 747*5ffd83dbSDimitry Andric AddInst->eraseFromParent(); 748*5ffd83dbSDimitry Andric Phi->setIncomingValue(IncrementIndex, NewInduction); 749*5ffd83dbSDimitry Andric 750*5ffd83dbSDimitry Andric return EndResult; 751*5ffd83dbSDimitry Andric } 752*5ffd83dbSDimitry Andric 753*5ffd83dbSDimitry Andric void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi, 754*5ffd83dbSDimitry Andric Value *OffsSecondOperand, 755*5ffd83dbSDimitry Andric unsigned StartIndex) { 756*5ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n"); 757*5ffd83dbSDimitry Andric Instruction *InsertionPoint = 758*5ffd83dbSDimitry Andric &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back()); 759*5ffd83dbSDimitry Andric // Initialize the phi with a vector that contains a sum of the constants 760*5ffd83dbSDimitry Andric Instruction *NewIndex = BinaryOperator::Create( 761*5ffd83dbSDimitry Andric Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand, 762*5ffd83dbSDimitry Andric "PushedOutAdd", InsertionPoint); 763*5ffd83dbSDimitry Andric unsigned IncrementIndex = StartIndex == 0 ? 1 : 0; 764*5ffd83dbSDimitry Andric 765*5ffd83dbSDimitry Andric // Order such that start index comes first (this reduces mov's) 766*5ffd83dbSDimitry Andric Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex)); 767*5ffd83dbSDimitry Andric Phi->addIncoming(Phi->getIncomingValue(IncrementIndex), 768*5ffd83dbSDimitry Andric Phi->getIncomingBlock(IncrementIndex)); 769*5ffd83dbSDimitry Andric Phi->removeIncomingValue(IncrementIndex); 770*5ffd83dbSDimitry Andric Phi->removeIncomingValue(StartIndex); 771*5ffd83dbSDimitry Andric } 772*5ffd83dbSDimitry Andric 773*5ffd83dbSDimitry Andric void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi, 774*5ffd83dbSDimitry Andric Value *IncrementPerRound, 775*5ffd83dbSDimitry Andric Value *OffsSecondOperand, 776*5ffd83dbSDimitry Andric unsigned LoopIncrement, 777*5ffd83dbSDimitry Andric IRBuilder<> &Builder) { 778*5ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n"); 779*5ffd83dbSDimitry Andric 780*5ffd83dbSDimitry Andric // Create a new scalar add outside of the loop and transform it to a splat 781*5ffd83dbSDimitry Andric // by which loop variable can be incremented 782*5ffd83dbSDimitry Andric Instruction *InsertionPoint = &cast<Instruction>( 783*5ffd83dbSDimitry Andric Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back()); 784*5ffd83dbSDimitry Andric 785*5ffd83dbSDimitry Andric // Create a new index 786*5ffd83dbSDimitry Andric Value *StartIndex = BinaryOperator::Create( 787*5ffd83dbSDimitry Andric Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1), 788*5ffd83dbSDimitry Andric OffsSecondOperand, "PushedOutMul", InsertionPoint); 789*5ffd83dbSDimitry Andric 790*5ffd83dbSDimitry Andric Instruction *Product = 791*5ffd83dbSDimitry Andric BinaryOperator::Create(Instruction::Mul, IncrementPerRound, 792*5ffd83dbSDimitry Andric OffsSecondOperand, "Product", InsertionPoint); 793*5ffd83dbSDimitry Andric // Increment NewIndex by Product instead of the multiplication 794*5ffd83dbSDimitry Andric Instruction *NewIncrement = BinaryOperator::Create( 795*5ffd83dbSDimitry Andric Instruction::Add, Phi, Product, "IncrementPushedOutMul", 796*5ffd83dbSDimitry Andric cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back()) 797*5ffd83dbSDimitry Andric .getPrevNode()); 798*5ffd83dbSDimitry Andric 799*5ffd83dbSDimitry Andric Phi->addIncoming(StartIndex, 800*5ffd83dbSDimitry Andric Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)); 801*5ffd83dbSDimitry Andric Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement)); 802*5ffd83dbSDimitry Andric Phi->removeIncomingValue((unsigned)0); 803*5ffd83dbSDimitry Andric Phi->removeIncomingValue((unsigned)0); 804*5ffd83dbSDimitry Andric return; 805*5ffd83dbSDimitry Andric } 806*5ffd83dbSDimitry Andric 807*5ffd83dbSDimitry Andric // Check whether all usages of this instruction are as offsets of 808*5ffd83dbSDimitry Andric // gathers/scatters or simple arithmetics only used by gathers/scatters 809*5ffd83dbSDimitry Andric static bool hasAllGatScatUsers(Instruction *I) { 810*5ffd83dbSDimitry Andric if (I->hasNUses(0)) { 811*5ffd83dbSDimitry Andric return false; 812*5ffd83dbSDimitry Andric } 813*5ffd83dbSDimitry Andric bool Gatscat = true; 814*5ffd83dbSDimitry Andric for (User *U : I->users()) { 815*5ffd83dbSDimitry Andric if (!isa<Instruction>(U)) 816*5ffd83dbSDimitry Andric return false; 817*5ffd83dbSDimitry Andric if (isa<GetElementPtrInst>(U) || 818*5ffd83dbSDimitry Andric isGatherScatter(dyn_cast<IntrinsicInst>(U))) { 819*5ffd83dbSDimitry Andric return Gatscat; 820*5ffd83dbSDimitry Andric } else { 821*5ffd83dbSDimitry Andric unsigned OpCode = cast<Instruction>(U)->getOpcode(); 822*5ffd83dbSDimitry Andric if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) && 823*5ffd83dbSDimitry Andric hasAllGatScatUsers(cast<Instruction>(U))) { 824*5ffd83dbSDimitry Andric continue; 825*5ffd83dbSDimitry Andric } 826*5ffd83dbSDimitry Andric return false; 827*5ffd83dbSDimitry Andric } 828*5ffd83dbSDimitry Andric } 829*5ffd83dbSDimitry Andric return Gatscat; 830*5ffd83dbSDimitry Andric } 831*5ffd83dbSDimitry Andric 832*5ffd83dbSDimitry Andric bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB, 833*5ffd83dbSDimitry Andric LoopInfo *LI) { 834*5ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n"); 835*5ffd83dbSDimitry Andric // Optimise the addresses of gathers/scatters by moving invariant 836*5ffd83dbSDimitry Andric // calculations out of the loop 837*5ffd83dbSDimitry Andric if (!isa<Instruction>(Offsets)) 838*5ffd83dbSDimitry Andric return false; 839*5ffd83dbSDimitry Andric Instruction *Offs = cast<Instruction>(Offsets); 840*5ffd83dbSDimitry Andric if (Offs->getOpcode() != Instruction::Add && 841*5ffd83dbSDimitry Andric Offs->getOpcode() != Instruction::Mul) 842*5ffd83dbSDimitry Andric return false; 843*5ffd83dbSDimitry Andric Loop *L = LI->getLoopFor(BB); 844*5ffd83dbSDimitry Andric if (L == nullptr) 845*5ffd83dbSDimitry Andric return false; 846*5ffd83dbSDimitry Andric if (!Offs->hasOneUse()) { 847*5ffd83dbSDimitry Andric if (!hasAllGatScatUsers(Offs)) 848*5ffd83dbSDimitry Andric return false; 849*5ffd83dbSDimitry Andric } 850*5ffd83dbSDimitry Andric 851*5ffd83dbSDimitry Andric // Find out which, if any, operand of the instruction 852*5ffd83dbSDimitry Andric // is a phi node 853*5ffd83dbSDimitry Andric PHINode *Phi; 854*5ffd83dbSDimitry Andric int OffsSecondOp; 855*5ffd83dbSDimitry Andric if (isa<PHINode>(Offs->getOperand(0))) { 856*5ffd83dbSDimitry Andric Phi = cast<PHINode>(Offs->getOperand(0)); 857*5ffd83dbSDimitry Andric OffsSecondOp = 1; 858*5ffd83dbSDimitry Andric } else if (isa<PHINode>(Offs->getOperand(1))) { 859*5ffd83dbSDimitry Andric Phi = cast<PHINode>(Offs->getOperand(1)); 860*5ffd83dbSDimitry Andric OffsSecondOp = 0; 861*5ffd83dbSDimitry Andric } else { 862*5ffd83dbSDimitry Andric bool Changed = true; 863*5ffd83dbSDimitry Andric if (isa<Instruction>(Offs->getOperand(0)) && 864*5ffd83dbSDimitry Andric L->contains(cast<Instruction>(Offs->getOperand(0)))) 865*5ffd83dbSDimitry Andric Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI); 866*5ffd83dbSDimitry Andric if (isa<Instruction>(Offs->getOperand(1)) && 867*5ffd83dbSDimitry Andric L->contains(cast<Instruction>(Offs->getOperand(1)))) 868*5ffd83dbSDimitry Andric Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI); 869*5ffd83dbSDimitry Andric if (!Changed) { 870*5ffd83dbSDimitry Andric return false; 871*5ffd83dbSDimitry Andric } else { 872*5ffd83dbSDimitry Andric if (isa<PHINode>(Offs->getOperand(0))) { 873*5ffd83dbSDimitry Andric Phi = cast<PHINode>(Offs->getOperand(0)); 874*5ffd83dbSDimitry Andric OffsSecondOp = 1; 875*5ffd83dbSDimitry Andric } else if (isa<PHINode>(Offs->getOperand(1))) { 876*5ffd83dbSDimitry Andric Phi = cast<PHINode>(Offs->getOperand(1)); 877*5ffd83dbSDimitry Andric OffsSecondOp = 0; 878*5ffd83dbSDimitry Andric } else { 879*5ffd83dbSDimitry Andric return false; 880*5ffd83dbSDimitry Andric } 881*5ffd83dbSDimitry Andric } 882*5ffd83dbSDimitry Andric } 883*5ffd83dbSDimitry Andric // A phi node we want to perform this function on should be from the 884*5ffd83dbSDimitry Andric // loop header, and shouldn't have more than 2 incoming values 885*5ffd83dbSDimitry Andric if (Phi->getParent() != L->getHeader() || 886*5ffd83dbSDimitry Andric Phi->getNumIncomingValues() != 2) 887*5ffd83dbSDimitry Andric return false; 888*5ffd83dbSDimitry Andric 889*5ffd83dbSDimitry Andric // The phi must be an induction variable 890*5ffd83dbSDimitry Andric Instruction *Op; 891*5ffd83dbSDimitry Andric int IncrementingBlock = -1; 892*5ffd83dbSDimitry Andric 893*5ffd83dbSDimitry Andric for (int i = 0; i < 2; i++) 894*5ffd83dbSDimitry Andric if ((Op = dyn_cast<Instruction>(Phi->getIncomingValue(i))) != nullptr) 895*5ffd83dbSDimitry Andric if (Op->getOpcode() == Instruction::Add && 896*5ffd83dbSDimitry Andric (Op->getOperand(0) == Phi || Op->getOperand(1) == Phi)) 897*5ffd83dbSDimitry Andric IncrementingBlock = i; 898*5ffd83dbSDimitry Andric if (IncrementingBlock == -1) 899*5ffd83dbSDimitry Andric return false; 900*5ffd83dbSDimitry Andric 901*5ffd83dbSDimitry Andric Instruction *IncInstruction = 902*5ffd83dbSDimitry Andric cast<Instruction>(Phi->getIncomingValue(IncrementingBlock)); 903*5ffd83dbSDimitry Andric 904*5ffd83dbSDimitry Andric // If the phi is not used by anything else, we can just adapt it when 905*5ffd83dbSDimitry Andric // replacing the instruction; if it is, we'll have to duplicate it 906*5ffd83dbSDimitry Andric PHINode *NewPhi; 907*5ffd83dbSDimitry Andric Value *IncrementPerRound = IncInstruction->getOperand( 908*5ffd83dbSDimitry Andric (IncInstruction->getOperand(0) == Phi) ? 1 : 0); 909*5ffd83dbSDimitry Andric 910*5ffd83dbSDimitry Andric // Get the value that is added to/multiplied with the phi 911*5ffd83dbSDimitry Andric Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp); 912*5ffd83dbSDimitry Andric 913*5ffd83dbSDimitry Andric if (IncrementPerRound->getType() != OffsSecondOperand->getType()) 914*5ffd83dbSDimitry Andric // Something has gone wrong, abort 915*5ffd83dbSDimitry Andric return false; 916*5ffd83dbSDimitry Andric 917*5ffd83dbSDimitry Andric // Only proceed if the increment per round is a constant or an instruction 918*5ffd83dbSDimitry Andric // which does not originate from within the loop 919*5ffd83dbSDimitry Andric if (!isa<Constant>(IncrementPerRound) && 920*5ffd83dbSDimitry Andric !(isa<Instruction>(IncrementPerRound) && 921*5ffd83dbSDimitry Andric !L->contains(cast<Instruction>(IncrementPerRound)))) 922*5ffd83dbSDimitry Andric return false; 923*5ffd83dbSDimitry Andric 924*5ffd83dbSDimitry Andric if (Phi->getNumUses() == 2) { 925*5ffd83dbSDimitry Andric // No other users -> reuse existing phi (One user is the instruction 926*5ffd83dbSDimitry Andric // we're looking at, the other is the phi increment) 927*5ffd83dbSDimitry Andric if (IncInstruction->getNumUses() != 1) { 928*5ffd83dbSDimitry Andric // If the incrementing instruction does have more users than 929*5ffd83dbSDimitry Andric // our phi, we need to copy it 930*5ffd83dbSDimitry Andric IncInstruction = BinaryOperator::Create( 931*5ffd83dbSDimitry Andric Instruction::BinaryOps(IncInstruction->getOpcode()), Phi, 932*5ffd83dbSDimitry Andric IncrementPerRound, "LoopIncrement", IncInstruction); 933*5ffd83dbSDimitry Andric Phi->setIncomingValue(IncrementingBlock, IncInstruction); 934*5ffd83dbSDimitry Andric } 935*5ffd83dbSDimitry Andric NewPhi = Phi; 936*5ffd83dbSDimitry Andric } else { 937*5ffd83dbSDimitry Andric // There are other users -> create a new phi 938*5ffd83dbSDimitry Andric NewPhi = PHINode::Create(Phi->getType(), 0, "NewPhi", Phi); 939*5ffd83dbSDimitry Andric std::vector<Value *> Increases; 940*5ffd83dbSDimitry Andric // Copy the incoming values of the old phi 941*5ffd83dbSDimitry Andric NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1), 942*5ffd83dbSDimitry Andric Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1)); 943*5ffd83dbSDimitry Andric IncInstruction = BinaryOperator::Create( 944*5ffd83dbSDimitry Andric Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi, 945*5ffd83dbSDimitry Andric IncrementPerRound, "LoopIncrement", IncInstruction); 946*5ffd83dbSDimitry Andric NewPhi->addIncoming(IncInstruction, 947*5ffd83dbSDimitry Andric Phi->getIncomingBlock(IncrementingBlock)); 948*5ffd83dbSDimitry Andric IncrementingBlock = 1; 949*5ffd83dbSDimitry Andric } 950*5ffd83dbSDimitry Andric 951*5ffd83dbSDimitry Andric IRBuilder<> Builder(BB->getContext()); 952*5ffd83dbSDimitry Andric Builder.SetInsertPoint(Phi); 953*5ffd83dbSDimitry Andric Builder.SetCurrentDebugLocation(Offs->getDebugLoc()); 954*5ffd83dbSDimitry Andric 955*5ffd83dbSDimitry Andric switch (Offs->getOpcode()) { 956*5ffd83dbSDimitry Andric case Instruction::Add: 957*5ffd83dbSDimitry Andric pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1); 958*5ffd83dbSDimitry Andric break; 959*5ffd83dbSDimitry Andric case Instruction::Mul: 960*5ffd83dbSDimitry Andric pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock, 961*5ffd83dbSDimitry Andric Builder); 962*5ffd83dbSDimitry Andric break; 963*5ffd83dbSDimitry Andric default: 964*5ffd83dbSDimitry Andric return false; 965*5ffd83dbSDimitry Andric } 966*5ffd83dbSDimitry Andric LLVM_DEBUG( 967*5ffd83dbSDimitry Andric dbgs() << "masked gathers/scatters: simplified loop variable add/mul\n"); 968*5ffd83dbSDimitry Andric 969*5ffd83dbSDimitry Andric // The instruction has now been "absorbed" into the phi value 970*5ffd83dbSDimitry Andric Offs->replaceAllUsesWith(NewPhi); 971*5ffd83dbSDimitry Andric if (Offs->hasNUses(0)) 972*5ffd83dbSDimitry Andric Offs->eraseFromParent(); 973*5ffd83dbSDimitry Andric // Clean up the old increment in case it's unused because we built a new 974*5ffd83dbSDimitry Andric // one 975*5ffd83dbSDimitry Andric if (IncInstruction->hasNUses(0)) 976*5ffd83dbSDimitry Andric IncInstruction->eraseFromParent(); 977*5ffd83dbSDimitry Andric 978*5ffd83dbSDimitry Andric return true; 979480093f4SDimitry Andric } 980480093f4SDimitry Andric 981480093f4SDimitry Andric bool MVEGatherScatterLowering::runOnFunction(Function &F) { 982480093f4SDimitry Andric if (!EnableMaskedGatherScatters) 983480093f4SDimitry Andric return false; 984480093f4SDimitry Andric auto &TPC = getAnalysis<TargetPassConfig>(); 985480093f4SDimitry Andric auto &TM = TPC.getTM<TargetMachine>(); 986480093f4SDimitry Andric auto *ST = &TM.getSubtarget<ARMSubtarget>(F); 987480093f4SDimitry Andric if (!ST->hasMVEIntegerOps()) 988480093f4SDimitry Andric return false; 989*5ffd83dbSDimitry Andric LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 990480093f4SDimitry Andric SmallVector<IntrinsicInst *, 4> Gathers; 991*5ffd83dbSDimitry Andric SmallVector<IntrinsicInst *, 4> Scatters; 992*5ffd83dbSDimitry Andric 993*5ffd83dbSDimitry Andric bool Changed = false; 994*5ffd83dbSDimitry Andric 995480093f4SDimitry Andric for (BasicBlock &BB : F) { 996480093f4SDimitry Andric for (Instruction &I : BB) { 997480093f4SDimitry Andric IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I); 998*5ffd83dbSDimitry Andric if (II && II->getIntrinsicID() == Intrinsic::masked_gather) { 999480093f4SDimitry Andric Gathers.push_back(II); 1000*5ffd83dbSDimitry Andric if (isa<GetElementPtrInst>(II->getArgOperand(0))) 1001*5ffd83dbSDimitry Andric Changed |= optimiseOffsets( 1002*5ffd83dbSDimitry Andric cast<Instruction>(II->getArgOperand(0))->getOperand(1), 1003*5ffd83dbSDimitry Andric II->getParent(), LI); 1004*5ffd83dbSDimitry Andric } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) { 1005*5ffd83dbSDimitry Andric Scatters.push_back(II); 1006*5ffd83dbSDimitry Andric if (isa<GetElementPtrInst>(II->getArgOperand(1))) 1007*5ffd83dbSDimitry Andric Changed |= optimiseOffsets( 1008*5ffd83dbSDimitry Andric cast<Instruction>(II->getArgOperand(1))->getOperand(1), 1009*5ffd83dbSDimitry Andric II->getParent(), LI); 1010*5ffd83dbSDimitry Andric } 1011480093f4SDimitry Andric } 1012480093f4SDimitry Andric } 1013480093f4SDimitry Andric 1014*5ffd83dbSDimitry Andric for (unsigned i = 0; i < Gathers.size(); i++) { 1015*5ffd83dbSDimitry Andric IntrinsicInst *I = Gathers[i]; 1016*5ffd83dbSDimitry Andric Value *L = lowerGather(I); 1017*5ffd83dbSDimitry Andric if (L == nullptr) 1018*5ffd83dbSDimitry Andric continue; 1019480093f4SDimitry Andric 1020*5ffd83dbSDimitry Andric // Get rid of any now dead instructions 1021*5ffd83dbSDimitry Andric SimplifyInstructionsInBlock(cast<Instruction>(L)->getParent()); 1022*5ffd83dbSDimitry Andric Changed = true; 1023*5ffd83dbSDimitry Andric } 1024480093f4SDimitry Andric 1025*5ffd83dbSDimitry Andric for (unsigned i = 0; i < Scatters.size(); i++) { 1026*5ffd83dbSDimitry Andric IntrinsicInst *I = Scatters[i]; 1027*5ffd83dbSDimitry Andric Value *S = lowerScatter(I); 1028*5ffd83dbSDimitry Andric if (S == nullptr) 1029*5ffd83dbSDimitry Andric continue; 1030*5ffd83dbSDimitry Andric 1031*5ffd83dbSDimitry Andric // Get rid of any now dead instructions 1032*5ffd83dbSDimitry Andric SimplifyInstructionsInBlock(cast<Instruction>(S)->getParent()); 1033*5ffd83dbSDimitry Andric Changed = true; 1034*5ffd83dbSDimitry Andric } 1035*5ffd83dbSDimitry Andric return Changed; 1036480093f4SDimitry Andric } 1037