xref: /freebsd/contrib/llvm-project/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp (revision 480093f4440d54b30b3025afeac24b48f2ba7a2e)
1*480093f4SDimitry Andric //===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
2*480093f4SDimitry Andric //
3*480093f4SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*480093f4SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*480093f4SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*480093f4SDimitry Andric //
7*480093f4SDimitry Andric //===----------------------------------------------------------------------===//
8*480093f4SDimitry Andric //
/// This pass custom lowers llvm.masked.gather intrinsics to the MVE
/// arm.mve.vldr.gather.* intrinsics (scatters are not handled yet), optimising
/// the addressing modes to produce a better final result as we go.
12*480093f4SDimitry Andric //
13*480093f4SDimitry Andric //===----------------------------------------------------------------------===//
14*480093f4SDimitry Andric 
15*480093f4SDimitry Andric #include "ARM.h"
16*480093f4SDimitry Andric #include "ARMBaseInstrInfo.h"
17*480093f4SDimitry Andric #include "ARMSubtarget.h"
18*480093f4SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h"
19*480093f4SDimitry Andric #include "llvm/CodeGen/TargetLowering.h"
20*480093f4SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
21*480093f4SDimitry Andric #include "llvm/CodeGen/TargetSubtargetInfo.h"
22*480093f4SDimitry Andric #include "llvm/InitializePasses.h"
23*480093f4SDimitry Andric #include "llvm/IR/BasicBlock.h"
24*480093f4SDimitry Andric #include "llvm/IR/Constant.h"
25*480093f4SDimitry Andric #include "llvm/IR/Constants.h"
26*480093f4SDimitry Andric #include "llvm/IR/DerivedTypes.h"
27*480093f4SDimitry Andric #include "llvm/IR/Function.h"
28*480093f4SDimitry Andric #include "llvm/IR/InstrTypes.h"
29*480093f4SDimitry Andric #include "llvm/IR/Instruction.h"
30*480093f4SDimitry Andric #include "llvm/IR/Instructions.h"
31*480093f4SDimitry Andric #include "llvm/IR/IntrinsicInst.h"
32*480093f4SDimitry Andric #include "llvm/IR/Intrinsics.h"
33*480093f4SDimitry Andric #include "llvm/IR/IntrinsicsARM.h"
34*480093f4SDimitry Andric #include "llvm/IR/IRBuilder.h"
35*480093f4SDimitry Andric #include "llvm/IR/PatternMatch.h"
36*480093f4SDimitry Andric #include "llvm/IR/Type.h"
37*480093f4SDimitry Andric #include "llvm/IR/Value.h"
38*480093f4SDimitry Andric #include "llvm/Pass.h"
39*480093f4SDimitry Andric #include "llvm/Support/Casting.h"
40*480093f4SDimitry Andric #include <algorithm>
41*480093f4SDimitry Andric #include <cassert>
42*480093f4SDimitry Andric 
43*480093f4SDimitry Andric using namespace llvm;
44*480093f4SDimitry Andric 
45*480093f4SDimitry Andric #define DEBUG_TYPE "mve-gather-scatter-lowering"
46*480093f4SDimitry Andric 
47*480093f4SDimitry Andric cl::opt<bool> EnableMaskedGatherScatters(
48*480093f4SDimitry Andric     "enable-arm-maskedgatscat", cl::Hidden, cl::init(false),
49*480093f4SDimitry Andric     cl::desc("Enable the generation of masked gathers and scatters"));
50*480093f4SDimitry Andric 
51*480093f4SDimitry Andric namespace {
52*480093f4SDimitry Andric 
53*480093f4SDimitry Andric class MVEGatherScatterLowering : public FunctionPass {
54*480093f4SDimitry Andric public:
55*480093f4SDimitry Andric   static char ID; // Pass identification, replacement for typeid
56*480093f4SDimitry Andric 
57*480093f4SDimitry Andric   explicit MVEGatherScatterLowering() : FunctionPass(ID) {
58*480093f4SDimitry Andric     initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
59*480093f4SDimitry Andric   }
60*480093f4SDimitry Andric 
61*480093f4SDimitry Andric   bool runOnFunction(Function &F) override;
62*480093f4SDimitry Andric 
63*480093f4SDimitry Andric   StringRef getPassName() const override {
64*480093f4SDimitry Andric     return "MVE gather/scatter lowering";
65*480093f4SDimitry Andric   }
66*480093f4SDimitry Andric 
67*480093f4SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
68*480093f4SDimitry Andric     AU.setPreservesCFG();
69*480093f4SDimitry Andric     AU.addRequired<TargetPassConfig>();
70*480093f4SDimitry Andric     FunctionPass::getAnalysisUsage(AU);
71*480093f4SDimitry Andric   }
72*480093f4SDimitry Andric 
73*480093f4SDimitry Andric private:
74*480093f4SDimitry Andric   // Check this is a valid gather with correct alignment
75*480093f4SDimitry Andric   bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
76*480093f4SDimitry Andric                                unsigned Alignment);
77*480093f4SDimitry Andric   // Check whether Ptr is hidden behind a bitcast and look through it
78*480093f4SDimitry Andric   void lookThroughBitcast(Value *&Ptr);
79*480093f4SDimitry Andric   // Check for a getelementptr and deduce base and offsets from it, on success
80*480093f4SDimitry Andric   // returning the base directly and the offsets indirectly using the Offsets
81*480093f4SDimitry Andric   // argument
82*480093f4SDimitry Andric   Value *checkGEP(Value *&Offsets, Type *Ty, Value *Ptr, IRBuilder<> Builder);
83*480093f4SDimitry Andric 
84*480093f4SDimitry Andric   bool lowerGather(IntrinsicInst *I);
85*480093f4SDimitry Andric   // Create a gather from a base + vector of offsets
86*480093f4SDimitry Andric   Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
87*480093f4SDimitry Andric                                      IRBuilder<> Builder);
88*480093f4SDimitry Andric   // Create a gather from a vector of pointers
89*480093f4SDimitry Andric   Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
90*480093f4SDimitry Andric                                    IRBuilder<> Builder);
91*480093f4SDimitry Andric };
92*480093f4SDimitry Andric 
93*480093f4SDimitry Andric } // end anonymous namespace
94*480093f4SDimitry Andric 
95*480093f4SDimitry Andric char MVEGatherScatterLowering::ID = 0;
96*480093f4SDimitry Andric 
97*480093f4SDimitry Andric INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
98*480093f4SDimitry Andric                 "MVE gather/scattering lowering pass", false, false)
99*480093f4SDimitry Andric 
100*480093f4SDimitry Andric Pass *llvm::createMVEGatherScatterLoweringPass() {
101*480093f4SDimitry Andric   return new MVEGatherScatterLowering();
102*480093f4SDimitry Andric }
103*480093f4SDimitry Andric 
104*480093f4SDimitry Andric bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
105*480093f4SDimitry Andric                                                        unsigned ElemSize,
106*480093f4SDimitry Andric                                                        unsigned Alignment) {
107*480093f4SDimitry Andric   // Do only allow non-extending gathers for now
108*480093f4SDimitry Andric   if (((NumElements == 4 && ElemSize == 32) ||
109*480093f4SDimitry Andric        (NumElements == 8 && ElemSize == 16) ||
110*480093f4SDimitry Andric        (NumElements == 16 && ElemSize == 8)) &&
111*480093f4SDimitry Andric       ElemSize / 8 <= Alignment)
112*480093f4SDimitry Andric     return true;
113*480093f4SDimitry Andric   LLVM_DEBUG(dbgs() << "masked gathers: instruction does not have valid "
114*480093f4SDimitry Andric                     << "alignment or vector type \n");
115*480093f4SDimitry Andric   return false;
116*480093f4SDimitry Andric }
117*480093f4SDimitry Andric 
118*480093f4SDimitry Andric Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, Type *Ty, Value *Ptr,
119*480093f4SDimitry Andric                                           IRBuilder<> Builder) {
120*480093f4SDimitry Andric   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
121*480093f4SDimitry Andric   if (!GEP) {
122*480093f4SDimitry Andric     LLVM_DEBUG(dbgs() << "masked gathers: no getelementpointer found\n");
123*480093f4SDimitry Andric     return nullptr;
124*480093f4SDimitry Andric   }
125*480093f4SDimitry Andric   LLVM_DEBUG(dbgs() << "masked gathers: getelementpointer found. Loading"
126*480093f4SDimitry Andric                     << " from base + vector of offsets\n");
127*480093f4SDimitry Andric   Value *GEPPtr = GEP->getPointerOperand();
128*480093f4SDimitry Andric   if (GEPPtr->getType()->isVectorTy()) {
129*480093f4SDimitry Andric     LLVM_DEBUG(dbgs() << "masked gathers: gather from a vector of pointers"
130*480093f4SDimitry Andric                       << " hidden behind a getelementptr currently not"
131*480093f4SDimitry Andric                       << " supported. Expanding.\n");
132*480093f4SDimitry Andric     return nullptr;
133*480093f4SDimitry Andric   }
134*480093f4SDimitry Andric   if (GEP->getNumOperands() != 2) {
135*480093f4SDimitry Andric     LLVM_DEBUG(dbgs() << "masked gathers: getelementptr with too many"
136*480093f4SDimitry Andric                       << " operands. Expanding.\n");
137*480093f4SDimitry Andric     return nullptr;
138*480093f4SDimitry Andric   }
139*480093f4SDimitry Andric   Offsets = GEP->getOperand(1);
140*480093f4SDimitry Andric   // SExt offsets inside masked gathers are not permitted by the architecture;
141*480093f4SDimitry Andric   // we therefore can't fold them
142*480093f4SDimitry Andric   if (ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets))
143*480093f4SDimitry Andric     Offsets = ZextOffs->getOperand(0);
144*480093f4SDimitry Andric   Type *OffsType = VectorType::getInteger(cast<VectorType>(Ty));
145*480093f4SDimitry Andric   // If the offset we found does not have the type the intrinsic expects,
146*480093f4SDimitry Andric   // i.e., the same type as the gather itself, we need to convert it (only i
147*480093f4SDimitry Andric   // types) or fall back to expanding the gather
148*480093f4SDimitry Andric   if (OffsType != Offsets->getType()) {
149*480093f4SDimitry Andric     if (OffsType->getScalarSizeInBits() >
150*480093f4SDimitry Andric         Offsets->getType()->getScalarSizeInBits()) {
151*480093f4SDimitry Andric       LLVM_DEBUG(dbgs() << "masked gathers: extending offsets\n");
152*480093f4SDimitry Andric       Offsets = Builder.CreateZExt(Offsets, OffsType, "");
153*480093f4SDimitry Andric     } else {
154*480093f4SDimitry Andric       LLVM_DEBUG(dbgs() << "masked gathers: no correct offset type. Can't"
155*480093f4SDimitry Andric                         << " create masked gather\n");
156*480093f4SDimitry Andric       return nullptr;
157*480093f4SDimitry Andric     }
158*480093f4SDimitry Andric   }
159*480093f4SDimitry Andric   // If none of the checks failed, return the gep's base pointer
160*480093f4SDimitry Andric   return GEPPtr;
161*480093f4SDimitry Andric }
162*480093f4SDimitry Andric 
163*480093f4SDimitry Andric void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
164*480093f4SDimitry Andric   // Look through bitcast instruction if #elements is the same
165*480093f4SDimitry Andric   if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
166*480093f4SDimitry Andric     Type *BCTy = BitCast->getType();
167*480093f4SDimitry Andric     Type *BCSrcTy = BitCast->getOperand(0)->getType();
168*480093f4SDimitry Andric     if (BCTy->getVectorNumElements() == BCSrcTy->getVectorNumElements()) {
169*480093f4SDimitry Andric       LLVM_DEBUG(dbgs() << "masked gathers: looking through bitcast\n");
170*480093f4SDimitry Andric       Ptr = BitCast->getOperand(0);
171*480093f4SDimitry Andric     }
172*480093f4SDimitry Andric   }
173*480093f4SDimitry Andric }
174*480093f4SDimitry Andric 
175*480093f4SDimitry Andric bool MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
176*480093f4SDimitry Andric   using namespace PatternMatch;
177*480093f4SDimitry Andric   LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n");
178*480093f4SDimitry Andric 
179*480093f4SDimitry Andric   // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
180*480093f4SDimitry Andric   // Attempt to turn the masked gather in I into a MVE intrinsic
181*480093f4SDimitry Andric   // Potentially optimising the addressing modes as we do so.
182*480093f4SDimitry Andric   Type *Ty = I->getType();
183*480093f4SDimitry Andric   Value *Ptr = I->getArgOperand(0);
184*480093f4SDimitry Andric   unsigned Alignment = cast<ConstantInt>(I->getArgOperand(1))->getZExtValue();
185*480093f4SDimitry Andric   Value *Mask = I->getArgOperand(2);
186*480093f4SDimitry Andric   Value *PassThru = I->getArgOperand(3);
187*480093f4SDimitry Andric 
188*480093f4SDimitry Andric   if (!isLegalTypeAndAlignment(Ty->getVectorNumElements(),
189*480093f4SDimitry Andric                                Ty->getScalarSizeInBits(), Alignment))
190*480093f4SDimitry Andric     return false;
191*480093f4SDimitry Andric   lookThroughBitcast(Ptr);
192*480093f4SDimitry Andric   assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
193*480093f4SDimitry Andric 
194*480093f4SDimitry Andric   IRBuilder<> Builder(I->getContext());
195*480093f4SDimitry Andric   Builder.SetInsertPoint(I);
196*480093f4SDimitry Andric   Builder.SetCurrentDebugLocation(I->getDebugLoc());
197*480093f4SDimitry Andric   Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Builder);
198*480093f4SDimitry Andric   if (!Load)
199*480093f4SDimitry Andric     Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
200*480093f4SDimitry Andric   if (!Load)
201*480093f4SDimitry Andric     return false;
202*480093f4SDimitry Andric 
203*480093f4SDimitry Andric   if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
204*480093f4SDimitry Andric     LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
205*480093f4SDimitry Andric                       << "creating select\n");
206*480093f4SDimitry Andric     Load = Builder.CreateSelect(Mask, Load, PassThru);
207*480093f4SDimitry Andric   }
208*480093f4SDimitry Andric 
209*480093f4SDimitry Andric   LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n");
210*480093f4SDimitry Andric   I->replaceAllUsesWith(Load);
211*480093f4SDimitry Andric   I->eraseFromParent();
212*480093f4SDimitry Andric   return true;
213*480093f4SDimitry Andric }
214*480093f4SDimitry Andric 
215*480093f4SDimitry Andric Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
216*480093f4SDimitry Andric     IntrinsicInst *I, Value *Ptr, IRBuilder<> Builder) {
217*480093f4SDimitry Andric   using namespace PatternMatch;
218*480093f4SDimitry Andric   LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
219*480093f4SDimitry Andric   Type *Ty = I->getType();
220*480093f4SDimitry Andric   if (Ty->getVectorNumElements() != 4)
221*480093f4SDimitry Andric     // Can't build an intrinsic for this
222*480093f4SDimitry Andric     return nullptr;
223*480093f4SDimitry Andric   Value *Mask = I->getArgOperand(2);
224*480093f4SDimitry Andric   if (match(Mask, m_One()))
225*480093f4SDimitry Andric     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
226*480093f4SDimitry Andric                                    {Ty, Ptr->getType()},
227*480093f4SDimitry Andric                                    {Ptr, Builder.getInt32(0)});
228*480093f4SDimitry Andric   else
229*480093f4SDimitry Andric     return Builder.CreateIntrinsic(
230*480093f4SDimitry Andric         Intrinsic::arm_mve_vldr_gather_base_predicated,
231*480093f4SDimitry Andric         {Ty, Ptr->getType(), Mask->getType()},
232*480093f4SDimitry Andric         {Ptr, Builder.getInt32(0), Mask});
233*480093f4SDimitry Andric }
234*480093f4SDimitry Andric 
235*480093f4SDimitry Andric Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
236*480093f4SDimitry Andric     IntrinsicInst *I, Value *Ptr, IRBuilder<> Builder) {
237*480093f4SDimitry Andric   using namespace PatternMatch;
238*480093f4SDimitry Andric   Type *Ty = I->getType();
239*480093f4SDimitry Andric   Value *Offsets;
240*480093f4SDimitry Andric   Value *BasePtr = checkGEP(Offsets, Ty, Ptr, Builder);
241*480093f4SDimitry Andric   if (!BasePtr)
242*480093f4SDimitry Andric     return nullptr;
243*480093f4SDimitry Andric 
244*480093f4SDimitry Andric   unsigned Scale;
245*480093f4SDimitry Andric   int GEPElemSize =
246*480093f4SDimitry Andric       BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits();
247*480093f4SDimitry Andric   int ResultElemSize = Ty->getScalarSizeInBits();
248*480093f4SDimitry Andric   // This can be a 32bit load scaled by 4, a 16bit load scaled by 2, or a
249*480093f4SDimitry Andric   // 8bit, 16bit or 32bit load scaled by 1
250*480093f4SDimitry Andric   if (GEPElemSize == 32 && ResultElemSize == 32) {
251*480093f4SDimitry Andric     Scale = 2;
252*480093f4SDimitry Andric   } else if (GEPElemSize == 16 && ResultElemSize == 16) {
253*480093f4SDimitry Andric     Scale = 1;
254*480093f4SDimitry Andric   } else if (GEPElemSize == 8) {
255*480093f4SDimitry Andric     Scale = 0;
256*480093f4SDimitry Andric   } else {
257*480093f4SDimitry Andric     LLVM_DEBUG(dbgs() << "masked gathers: incorrect scale for load. Can't"
258*480093f4SDimitry Andric                       << " create masked gather\n");
259*480093f4SDimitry Andric     return nullptr;
260*480093f4SDimitry Andric   }
261*480093f4SDimitry Andric 
262*480093f4SDimitry Andric   Value *Mask = I->getArgOperand(2);
263*480093f4SDimitry Andric   if (!match(Mask, m_One()))
264*480093f4SDimitry Andric     return Builder.CreateIntrinsic(
265*480093f4SDimitry Andric         Intrinsic::arm_mve_vldr_gather_offset_predicated,
266*480093f4SDimitry Andric         {Ty, BasePtr->getType(), Offsets->getType(), Mask->getType()},
267*480093f4SDimitry Andric         {BasePtr, Offsets, Builder.getInt32(Ty->getScalarSizeInBits()),
268*480093f4SDimitry Andric          Builder.getInt32(Scale), Builder.getInt32(1), Mask});
269*480093f4SDimitry Andric   else
270*480093f4SDimitry Andric     return Builder.CreateIntrinsic(
271*480093f4SDimitry Andric         Intrinsic::arm_mve_vldr_gather_offset,
272*480093f4SDimitry Andric         {Ty, BasePtr->getType(), Offsets->getType()},
273*480093f4SDimitry Andric         {BasePtr, Offsets, Builder.getInt32(Ty->getScalarSizeInBits()),
274*480093f4SDimitry Andric          Builder.getInt32(Scale), Builder.getInt32(1)});
275*480093f4SDimitry Andric }
276*480093f4SDimitry Andric 
277*480093f4SDimitry Andric bool MVEGatherScatterLowering::runOnFunction(Function &F) {
278*480093f4SDimitry Andric   if (!EnableMaskedGatherScatters)
279*480093f4SDimitry Andric     return false;
280*480093f4SDimitry Andric   auto &TPC = getAnalysis<TargetPassConfig>();
281*480093f4SDimitry Andric   auto &TM = TPC.getTM<TargetMachine>();
282*480093f4SDimitry Andric   auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
283*480093f4SDimitry Andric   if (!ST->hasMVEIntegerOps())
284*480093f4SDimitry Andric     return false;
285*480093f4SDimitry Andric   SmallVector<IntrinsicInst *, 4> Gathers;
286*480093f4SDimitry Andric   for (BasicBlock &BB : F) {
287*480093f4SDimitry Andric     for (Instruction &I : BB) {
288*480093f4SDimitry Andric       IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
289*480093f4SDimitry Andric       if (II && II->getIntrinsicID() == Intrinsic::masked_gather)
290*480093f4SDimitry Andric         Gathers.push_back(II);
291*480093f4SDimitry Andric     }
292*480093f4SDimitry Andric   }
293*480093f4SDimitry Andric 
294*480093f4SDimitry Andric   if (Gathers.empty())
295*480093f4SDimitry Andric     return false;
296*480093f4SDimitry Andric 
297*480093f4SDimitry Andric   for (IntrinsicInst *I : Gathers)
298*480093f4SDimitry Andric     lowerGather(I);
299*480093f4SDimitry Andric 
300*480093f4SDimitry Andric   return true;
301*480093f4SDimitry Andric }
302