1*480093f4SDimitry Andric //===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===// 2*480093f4SDimitry Andric // 3*480093f4SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*480093f4SDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5*480093f4SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*480093f4SDimitry Andric // 7*480093f4SDimitry Andric //===----------------------------------------------------------------------===// 8*480093f4SDimitry Andric // 9*480093f4SDimitry Andric /// This pass custom lowers llvm.gather and llvm.scatter instructions to 10*480093f4SDimitry Andric /// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to 11*480093f4SDimitry Andric /// produce a better final result as we go. 12*480093f4SDimitry Andric // 13*480093f4SDimitry Andric //===----------------------------------------------------------------------===// 14*480093f4SDimitry Andric 15*480093f4SDimitry Andric #include "ARM.h" 16*480093f4SDimitry Andric #include "ARMBaseInstrInfo.h" 17*480093f4SDimitry Andric #include "ARMSubtarget.h" 18*480093f4SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h" 19*480093f4SDimitry Andric #include "llvm/CodeGen/TargetLowering.h" 20*480093f4SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h" 21*480093f4SDimitry Andric #include "llvm/CodeGen/TargetSubtargetInfo.h" 22*480093f4SDimitry Andric #include "llvm/InitializePasses.h" 23*480093f4SDimitry Andric #include "llvm/IR/BasicBlock.h" 24*480093f4SDimitry Andric #include "llvm/IR/Constant.h" 25*480093f4SDimitry Andric #include "llvm/IR/Constants.h" 26*480093f4SDimitry Andric #include "llvm/IR/DerivedTypes.h" 27*480093f4SDimitry Andric #include "llvm/IR/Function.h" 28*480093f4SDimitry Andric #include "llvm/IR/InstrTypes.h" 29*480093f4SDimitry Andric #include "llvm/IR/Instruction.h" 30*480093f4SDimitry Andric #include "llvm/IR/Instructions.h" 31*480093f4SDimitry Andric #include "llvm/IR/IntrinsicInst.h" 32*480093f4SDimitry Andric #include "llvm/IR/Intrinsics.h" 33*480093f4SDimitry Andric #include "llvm/IR/IntrinsicsARM.h" 34*480093f4SDimitry Andric #include "llvm/IR/IRBuilder.h" 35*480093f4SDimitry Andric #include "llvm/IR/PatternMatch.h" 36*480093f4SDimitry Andric #include "llvm/IR/Type.h" 37*480093f4SDimitry Andric #include "llvm/IR/Value.h" 38*480093f4SDimitry Andric #include "llvm/Pass.h" 39*480093f4SDimitry Andric #include "llvm/Support/Casting.h" 40*480093f4SDimitry Andric #include <algorithm> 41*480093f4SDimitry Andric #include <cassert> 42*480093f4SDimitry Andric 43*480093f4SDimitry Andric using namespace llvm; 44*480093f4SDimitry Andric 45*480093f4SDimitry Andric #define DEBUG_TYPE "mve-gather-scatter-lowering" 46*480093f4SDimitry Andric 47*480093f4SDimitry Andric cl::opt<bool> EnableMaskedGatherScatters( 48*480093f4SDimitry Andric "enable-arm-maskedgatscat", cl::Hidden, cl::init(false), 49*480093f4SDimitry Andric cl::desc("Enable the generation of masked gathers and scatters")); 50*480093f4SDimitry Andric 51*480093f4SDimitry Andric namespace { 52*480093f4SDimitry Andric 53*480093f4SDimitry Andric class MVEGatherScatterLowering : public FunctionPass { 54*480093f4SDimitry Andric public: 55*480093f4SDimitry Andric static char ID; // Pass identification, replacement for typeid 56*480093f4SDimitry Andric 57*480093f4SDimitry Andric explicit MVEGatherScatterLowering() : FunctionPass(ID) { 58*480093f4SDimitry Andric initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry()); 59*480093f4SDimitry Andric } 60*480093f4SDimitry Andric 61*480093f4SDimitry Andric bool runOnFunction(Function &F) override; 62*480093f4SDimitry Andric 63*480093f4SDimitry Andric StringRef getPassName() const override { 64*480093f4SDimitry Andric return "MVE gather/scatter lowering"; 65*480093f4SDimitry Andric } 66*480093f4SDimitry Andric 67*480093f4SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 68*480093f4SDimitry Andric AU.setPreservesCFG(); 69*480093f4SDimitry Andric AU.addRequired<TargetPassConfig>(); 70*480093f4SDimitry Andric FunctionPass::getAnalysisUsage(AU); 71*480093f4SDimitry Andric } 72*480093f4SDimitry Andric 73*480093f4SDimitry Andric private: 74*480093f4SDimitry Andric // Check this is a valid gather with correct alignment 75*480093f4SDimitry Andric bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize, 76*480093f4SDimitry Andric unsigned Alignment); 77*480093f4SDimitry Andric // Check whether Ptr is hidden behind a bitcast and look through it 78*480093f4SDimitry Andric void lookThroughBitcast(Value *&Ptr); 79*480093f4SDimitry Andric // Check for a getelementptr and deduce base and offsets from it, on success 80*480093f4SDimitry Andric // returning the base directly and the offsets indirectly using the Offsets 81*480093f4SDimitry Andric // argument 82*480093f4SDimitry Andric Value *checkGEP(Value *&Offsets, Type *Ty, Value *Ptr, IRBuilder<> Builder); 83*480093f4SDimitry Andric 84*480093f4SDimitry Andric bool lowerGather(IntrinsicInst *I); 85*480093f4SDimitry Andric // Create a gather from a base + vector of offsets 86*480093f4SDimitry Andric Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr, 87*480093f4SDimitry Andric IRBuilder<> Builder); 88*480093f4SDimitry Andric // Create a gather from a vector of pointers 89*480093f4SDimitry Andric Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr, 90*480093f4SDimitry Andric IRBuilder<> Builder); 91*480093f4SDimitry Andric }; 92*480093f4SDimitry Andric 93*480093f4SDimitry Andric } // end anonymous namespace 94*480093f4SDimitry Andric 95*480093f4SDimitry Andric char MVEGatherScatterLowering::ID = 0; 96*480093f4SDimitry Andric 97*480093f4SDimitry Andric INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE, 98*480093f4SDimitry Andric "MVE gather/scattering lowering pass", false, false) 99*480093f4SDimitry Andric 100*480093f4SDimitry Andric Pass *llvm::createMVEGatherScatterLoweringPass() { 101*480093f4SDimitry Andric return new MVEGatherScatterLowering(); 102*480093f4SDimitry Andric } 103*480093f4SDimitry Andric 104*480093f4SDimitry Andric bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements, 105*480093f4SDimitry Andric unsigned ElemSize, 106*480093f4SDimitry Andric unsigned Alignment) { 107*480093f4SDimitry Andric // Do only allow non-extending gathers for now 108*480093f4SDimitry Andric if (((NumElements == 4 && ElemSize == 32) || 109*480093f4SDimitry Andric (NumElements == 8 && ElemSize == 16) || 110*480093f4SDimitry Andric (NumElements == 16 && ElemSize == 8)) && 111*480093f4SDimitry Andric ElemSize / 8 <= Alignment) 112*480093f4SDimitry Andric return true; 113*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: instruction does not have valid " 114*480093f4SDimitry Andric << "alignment or vector type \n"); 115*480093f4SDimitry Andric return false; 116*480093f4SDimitry Andric } 117*480093f4SDimitry Andric 118*480093f4SDimitry Andric Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, Type *Ty, Value *Ptr, 119*480093f4SDimitry Andric IRBuilder<> Builder) { 120*480093f4SDimitry Andric GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); 121*480093f4SDimitry Andric if (!GEP) { 122*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: no getelementpointer found\n"); 123*480093f4SDimitry Andric return nullptr; 124*480093f4SDimitry Andric } 125*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: getelementpointer found. Loading" 126*480093f4SDimitry Andric << " from base + vector of offsets\n"); 127*480093f4SDimitry Andric Value *GEPPtr = GEP->getPointerOperand(); 128*480093f4SDimitry Andric if (GEPPtr->getType()->isVectorTy()) { 129*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: gather from a vector of pointers" 130*480093f4SDimitry Andric << " hidden behind a getelementptr currently not" 131*480093f4SDimitry Andric << " supported. Expanding.\n"); 132*480093f4SDimitry Andric return nullptr; 133*480093f4SDimitry Andric } 134*480093f4SDimitry Andric if (GEP->getNumOperands() != 2) { 135*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: getelementptr with too many" 136*480093f4SDimitry Andric << " operands. Expanding.\n"); 137*480093f4SDimitry Andric return nullptr; 138*480093f4SDimitry Andric } 139*480093f4SDimitry Andric Offsets = GEP->getOperand(1); 140*480093f4SDimitry Andric // SExt offsets inside masked gathers are not permitted by the architecture; 141*480093f4SDimitry Andric // we therefore can't fold them 142*480093f4SDimitry Andric if (ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets)) 143*480093f4SDimitry Andric Offsets = ZextOffs->getOperand(0); 144*480093f4SDimitry Andric Type *OffsType = VectorType::getInteger(cast<VectorType>(Ty)); 145*480093f4SDimitry Andric // If the offset we found does not have the type the intrinsic expects, 146*480093f4SDimitry Andric // i.e., the same type as the gather itself, we need to convert it (only i 147*480093f4SDimitry Andric // types) or fall back to expanding the gather 148*480093f4SDimitry Andric if (OffsType != Offsets->getType()) { 149*480093f4SDimitry Andric if (OffsType->getScalarSizeInBits() > 150*480093f4SDimitry Andric Offsets->getType()->getScalarSizeInBits()) { 151*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: extending offsets\n"); 152*480093f4SDimitry Andric Offsets = Builder.CreateZExt(Offsets, OffsType, ""); 153*480093f4SDimitry Andric } else { 154*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: no correct offset type. Can't" 155*480093f4SDimitry Andric << " create masked gather\n"); 156*480093f4SDimitry Andric return nullptr; 157*480093f4SDimitry Andric } 158*480093f4SDimitry Andric } 159*480093f4SDimitry Andric // If none of the checks failed, return the gep's base pointer 160*480093f4SDimitry Andric return GEPPtr; 161*480093f4SDimitry Andric } 162*480093f4SDimitry Andric 163*480093f4SDimitry Andric void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) { 164*480093f4SDimitry Andric // Look through bitcast instruction if #elements is the same 165*480093f4SDimitry Andric if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) { 166*480093f4SDimitry Andric Type *BCTy = BitCast->getType(); 167*480093f4SDimitry Andric Type *BCSrcTy = BitCast->getOperand(0)->getType(); 168*480093f4SDimitry Andric if (BCTy->getVectorNumElements() == BCSrcTy->getVectorNumElements()) { 169*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: looking through bitcast\n"); 170*480093f4SDimitry Andric Ptr = BitCast->getOperand(0); 171*480093f4SDimitry Andric } 172*480093f4SDimitry Andric } 173*480093f4SDimitry Andric } 174*480093f4SDimitry Andric 175*480093f4SDimitry Andric bool MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) { 176*480093f4SDimitry Andric using namespace PatternMatch; 177*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"); 178*480093f4SDimitry Andric 179*480093f4SDimitry Andric // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0) 180*480093f4SDimitry Andric // Attempt to turn the masked gather in I into a MVE intrinsic 181*480093f4SDimitry Andric // Potentially optimising the addressing modes as we do so. 182*480093f4SDimitry Andric Type *Ty = I->getType(); 183*480093f4SDimitry Andric Value *Ptr = I->getArgOperand(0); 184*480093f4SDimitry Andric unsigned Alignment = cast<ConstantInt>(I->getArgOperand(1))->getZExtValue(); 185*480093f4SDimitry Andric Value *Mask = I->getArgOperand(2); 186*480093f4SDimitry Andric Value *PassThru = I->getArgOperand(3); 187*480093f4SDimitry Andric 188*480093f4SDimitry Andric if (!isLegalTypeAndAlignment(Ty->getVectorNumElements(), 189*480093f4SDimitry Andric Ty->getScalarSizeInBits(), Alignment)) 190*480093f4SDimitry Andric return false; 191*480093f4SDimitry Andric lookThroughBitcast(Ptr); 192*480093f4SDimitry Andric assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); 193*480093f4SDimitry Andric 194*480093f4SDimitry Andric IRBuilder<> Builder(I->getContext()); 195*480093f4SDimitry Andric Builder.SetInsertPoint(I); 196*480093f4SDimitry Andric Builder.SetCurrentDebugLocation(I->getDebugLoc()); 197*480093f4SDimitry Andric Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Builder); 198*480093f4SDimitry Andric if (!Load) 199*480093f4SDimitry Andric Load = tryCreateMaskedGatherBase(I, Ptr, Builder); 200*480093f4SDimitry Andric if (!Load) 201*480093f4SDimitry Andric return false; 202*480093f4SDimitry Andric 203*480093f4SDimitry Andric if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) { 204*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - " 205*480093f4SDimitry Andric << "creating select\n"); 206*480093f4SDimitry Andric Load = Builder.CreateSelect(Mask, Load, PassThru); 207*480093f4SDimitry Andric } 208*480093f4SDimitry Andric 209*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"); 210*480093f4SDimitry Andric I->replaceAllUsesWith(Load); 211*480093f4SDimitry Andric I->eraseFromParent(); 212*480093f4SDimitry Andric return true; 213*480093f4SDimitry Andric } 214*480093f4SDimitry Andric 215*480093f4SDimitry Andric Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase( 216*480093f4SDimitry Andric IntrinsicInst *I, Value *Ptr, IRBuilder<> Builder) { 217*480093f4SDimitry Andric using namespace PatternMatch; 218*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n"); 219*480093f4SDimitry Andric Type *Ty = I->getType(); 220*480093f4SDimitry Andric if (Ty->getVectorNumElements() != 4) 221*480093f4SDimitry Andric // Can't build an intrinsic for this 222*480093f4SDimitry Andric return nullptr; 223*480093f4SDimitry Andric Value *Mask = I->getArgOperand(2); 224*480093f4SDimitry Andric if (match(Mask, m_One())) 225*480093f4SDimitry Andric return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base, 226*480093f4SDimitry Andric {Ty, Ptr->getType()}, 227*480093f4SDimitry Andric {Ptr, Builder.getInt32(0)}); 228*480093f4SDimitry Andric else 229*480093f4SDimitry Andric return Builder.CreateIntrinsic( 230*480093f4SDimitry Andric Intrinsic::arm_mve_vldr_gather_base_predicated, 231*480093f4SDimitry Andric {Ty, Ptr->getType(), Mask->getType()}, 232*480093f4SDimitry Andric {Ptr, Builder.getInt32(0), Mask}); 233*480093f4SDimitry Andric } 234*480093f4SDimitry Andric 235*480093f4SDimitry Andric Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset( 236*480093f4SDimitry Andric IntrinsicInst *I, Value *Ptr, IRBuilder<> Builder) { 237*480093f4SDimitry Andric using namespace PatternMatch; 238*480093f4SDimitry Andric Type *Ty = I->getType(); 239*480093f4SDimitry Andric Value *Offsets; 240*480093f4SDimitry Andric Value *BasePtr = checkGEP(Offsets, Ty, Ptr, Builder); 241*480093f4SDimitry Andric if (!BasePtr) 242*480093f4SDimitry Andric return nullptr; 243*480093f4SDimitry Andric 244*480093f4SDimitry Andric unsigned Scale; 245*480093f4SDimitry Andric int GEPElemSize = 246*480093f4SDimitry Andric BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(); 247*480093f4SDimitry Andric int ResultElemSize = Ty->getScalarSizeInBits(); 248*480093f4SDimitry Andric // This can be a 32bit load scaled by 4, a 16bit load scaled by 2, or a 249*480093f4SDimitry Andric // 8bit, 16bit or 32bit load scaled by 1 250*480093f4SDimitry Andric if (GEPElemSize == 32 && ResultElemSize == 32) { 251*480093f4SDimitry Andric Scale = 2; 252*480093f4SDimitry Andric } else if (GEPElemSize == 16 && ResultElemSize == 16) { 253*480093f4SDimitry Andric Scale = 1; 254*480093f4SDimitry Andric } else if (GEPElemSize == 8) { 255*480093f4SDimitry Andric Scale = 0; 256*480093f4SDimitry Andric } else { 257*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "masked gathers: incorrect scale for load. Can't" 258*480093f4SDimitry Andric << " create masked gather\n"); 259*480093f4SDimitry Andric return nullptr; 260*480093f4SDimitry Andric } 261*480093f4SDimitry Andric 262*480093f4SDimitry Andric Value *Mask = I->getArgOperand(2); 263*480093f4SDimitry Andric if (!match(Mask, m_One())) 264*480093f4SDimitry Andric return Builder.CreateIntrinsic( 265*480093f4SDimitry Andric Intrinsic::arm_mve_vldr_gather_offset_predicated, 266*480093f4SDimitry Andric {Ty, BasePtr->getType(), Offsets->getType(), Mask->getType()}, 267*480093f4SDimitry Andric {BasePtr, Offsets, Builder.getInt32(Ty->getScalarSizeInBits()), 268*480093f4SDimitry Andric Builder.getInt32(Scale), Builder.getInt32(1), Mask}); 269*480093f4SDimitry Andric else 270*480093f4SDimitry Andric return Builder.CreateIntrinsic( 271*480093f4SDimitry Andric Intrinsic::arm_mve_vldr_gather_offset, 272*480093f4SDimitry Andric {Ty, BasePtr->getType(), Offsets->getType()}, 273*480093f4SDimitry Andric {BasePtr, Offsets, Builder.getInt32(Ty->getScalarSizeInBits()), 274*480093f4SDimitry Andric Builder.getInt32(Scale), Builder.getInt32(1)}); 275*480093f4SDimitry Andric } 276*480093f4SDimitry Andric 277*480093f4SDimitry Andric bool MVEGatherScatterLowering::runOnFunction(Function &F) { 278*480093f4SDimitry Andric if (!EnableMaskedGatherScatters) 279*480093f4SDimitry Andric return false; 280*480093f4SDimitry Andric auto &TPC = getAnalysis<TargetPassConfig>(); 281*480093f4SDimitry Andric auto &TM = TPC.getTM<TargetMachine>(); 282*480093f4SDimitry Andric auto *ST = &TM.getSubtarget<ARMSubtarget>(F); 283*480093f4SDimitry Andric if (!ST->hasMVEIntegerOps()) 284*480093f4SDimitry Andric return false; 285*480093f4SDimitry Andric SmallVector<IntrinsicInst *, 4> Gathers; 286*480093f4SDimitry Andric for (BasicBlock &BB : F) { 287*480093f4SDimitry Andric for (Instruction &I : BB) { 288*480093f4SDimitry Andric IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I); 289*480093f4SDimitry Andric if (II && II->getIntrinsicID() == Intrinsic::masked_gather) 290*480093f4SDimitry Andric Gathers.push_back(II); 291*480093f4SDimitry Andric } 292*480093f4SDimitry Andric } 293*480093f4SDimitry Andric 294*480093f4SDimitry Andric if (Gathers.empty()) 295*480093f4SDimitry Andric return false; 296*480093f4SDimitry Andric 297*480093f4SDimitry Andric for (IntrinsicInst *I : Gathers) 298*480093f4SDimitry Andric lowerGather(I); 299*480093f4SDimitry Andric 300*480093f4SDimitry Andric return true; 301*480093f4SDimitry Andric } 302