1349cc55cSDimitry Andric //===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
2349cc55cSDimitry Andric //
3349cc55cSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4349cc55cSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5349cc55cSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6349cc55cSDimitry Andric //
7349cc55cSDimitry Andric //===----------------------------------------------------------------------===//
8349cc55cSDimitry Andric //
9349cc55cSDimitry Andric // This pass custom lowers llvm.gather and llvm.scatter instructions to
1006c3fb27SDimitry Andric // RISC-V intrinsics.
11349cc55cSDimitry Andric //
12349cc55cSDimitry Andric //===----------------------------------------------------------------------===//
13349cc55cSDimitry Andric
14349cc55cSDimitry Andric #include "RISCV.h"
15349cc55cSDimitry Andric #include "RISCVTargetMachine.h"
1606c3fb27SDimitry Andric #include "llvm/Analysis/InstSimplifyFolder.h"
17349cc55cSDimitry Andric #include "llvm/Analysis/LoopInfo.h"
18349cc55cSDimitry Andric #include "llvm/Analysis/ValueTracking.h"
19349cc55cSDimitry Andric #include "llvm/Analysis/VectorUtils.h"
20349cc55cSDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
21349cc55cSDimitry Andric #include "llvm/IR/GetElementPtrTypeIterator.h"
22349cc55cSDimitry Andric #include "llvm/IR/IRBuilder.h"
23349cc55cSDimitry Andric #include "llvm/IR/IntrinsicInst.h"
24349cc55cSDimitry Andric #include "llvm/IR/IntrinsicsRISCV.h"
25bdd1243dSDimitry Andric #include "llvm/IR/PatternMatch.h"
26349cc55cSDimitry Andric #include "llvm/Transforms/Utils/Local.h"
27bdd1243dSDimitry Andric #include <optional>
28349cc55cSDimitry Andric
29349cc55cSDimitry Andric using namespace llvm;
30bdd1243dSDimitry Andric using namespace PatternMatch;
31349cc55cSDimitry Andric
32349cc55cSDimitry Andric #define DEBUG_TYPE "riscv-gather-scatter-lowering"
33349cc55cSDimitry Andric
34349cc55cSDimitry Andric namespace {
35349cc55cSDimitry Andric
36349cc55cSDimitry Andric class RISCVGatherScatterLowering : public FunctionPass {
37349cc55cSDimitry Andric const RISCVSubtarget *ST = nullptr;
38349cc55cSDimitry Andric const RISCVTargetLowering *TLI = nullptr;
39349cc55cSDimitry Andric LoopInfo *LI = nullptr;
40349cc55cSDimitry Andric const DataLayout *DL = nullptr;
41349cc55cSDimitry Andric
42349cc55cSDimitry Andric SmallVector<WeakTrackingVH> MaybeDeadPHIs;
43349cc55cSDimitry Andric
4481ad6265SDimitry Andric // Cache of the BasePtr and Stride determined from this GEP. When a GEP is
4581ad6265SDimitry Andric // used by multiple gathers/scatters, this allow us to reuse the scalar
4681ad6265SDimitry Andric // instructions we created for the first gather/scatter for the others.
4781ad6265SDimitry Andric DenseMap<GetElementPtrInst *, std::pair<Value *, Value *>> StridedAddrs;
4881ad6265SDimitry Andric
49349cc55cSDimitry Andric public:
50349cc55cSDimitry Andric static char ID; // Pass identification, replacement for typeid
51349cc55cSDimitry Andric
RISCVGatherScatterLowering()52349cc55cSDimitry Andric RISCVGatherScatterLowering() : FunctionPass(ID) {}
53349cc55cSDimitry Andric
54349cc55cSDimitry Andric bool runOnFunction(Function &F) override;
55349cc55cSDimitry Andric
getAnalysisUsage(AnalysisUsage & AU) const56349cc55cSDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override {
57349cc55cSDimitry Andric AU.setPreservesCFG();
58349cc55cSDimitry Andric AU.addRequired<TargetPassConfig>();
59349cc55cSDimitry Andric AU.addRequired<LoopInfoWrapperPass>();
60349cc55cSDimitry Andric }
61349cc55cSDimitry Andric
getPassName() const62349cc55cSDimitry Andric StringRef getPassName() const override {
6306c3fb27SDimitry Andric return "RISC-V gather/scatter lowering";
64349cc55cSDimitry Andric }
65349cc55cSDimitry Andric
66349cc55cSDimitry Andric private:
67349cc55cSDimitry Andric bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
68349cc55cSDimitry Andric Value *AlignOp);
69349cc55cSDimitry Andric
705f757f3fSDimitry Andric std::pair<Value *, Value *> determineBaseAndStride(Instruction *Ptr,
7106c3fb27SDimitry Andric IRBuilderBase &Builder);
72349cc55cSDimitry Andric
73349cc55cSDimitry Andric bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
74349cc55cSDimitry Andric PHINode *&BasePtr, BinaryOperator *&Inc,
7506c3fb27SDimitry Andric IRBuilderBase &Builder);
76349cc55cSDimitry Andric };
77349cc55cSDimitry Andric
78349cc55cSDimitry Andric } // end anonymous namespace
79349cc55cSDimitry Andric
80349cc55cSDimitry Andric char RISCVGatherScatterLowering::ID = 0;
81349cc55cSDimitry Andric
82349cc55cSDimitry Andric INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
8306c3fb27SDimitry Andric "RISC-V gather/scatter lowering pass", false, false)
84349cc55cSDimitry Andric
createRISCVGatherScatterLoweringPass()85349cc55cSDimitry Andric FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
86349cc55cSDimitry Andric return new RISCVGatherScatterLowering();
87349cc55cSDimitry Andric }
88349cc55cSDimitry Andric
89349cc55cSDimitry Andric // TODO: Should we consider the mask when looking for a stride?
matchStridedConstant(Constant * StartC)90349cc55cSDimitry Andric static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
9106c3fb27SDimitry Andric if (!isa<FixedVectorType>(StartC->getType()))
9206c3fb27SDimitry Andric return std::make_pair(nullptr, nullptr);
9306c3fb27SDimitry Andric
94349cc55cSDimitry Andric unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();
95349cc55cSDimitry Andric
96349cc55cSDimitry Andric // Check that the start value is a strided constant.
97349cc55cSDimitry Andric auto *StartVal =
98349cc55cSDimitry Andric dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
99349cc55cSDimitry Andric if (!StartVal)
100349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr);
101349cc55cSDimitry Andric APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
102349cc55cSDimitry Andric ConstantInt *Prev = StartVal;
103349cc55cSDimitry Andric for (unsigned i = 1; i != NumElts; ++i) {
104349cc55cSDimitry Andric auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
105349cc55cSDimitry Andric if (!C)
106349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr);
107349cc55cSDimitry Andric
108349cc55cSDimitry Andric APInt LocalStride = C->getValue() - Prev->getValue();
109349cc55cSDimitry Andric if (i == 1)
110349cc55cSDimitry Andric StrideVal = LocalStride;
111349cc55cSDimitry Andric else if (StrideVal != LocalStride)
112349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr);
113349cc55cSDimitry Andric
114349cc55cSDimitry Andric Prev = C;
115349cc55cSDimitry Andric }
116349cc55cSDimitry Andric
117349cc55cSDimitry Andric Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);
118349cc55cSDimitry Andric
119349cc55cSDimitry Andric return std::make_pair(StartVal, Stride);
120349cc55cSDimitry Andric }
121349cc55cSDimitry Andric
matchStridedStart(Value * Start,IRBuilderBase & Builder)12204eeddc0SDimitry Andric static std::pair<Value *, Value *> matchStridedStart(Value *Start,
12306c3fb27SDimitry Andric IRBuilderBase &Builder) {
12404eeddc0SDimitry Andric // Base case, start is a strided constant.
12504eeddc0SDimitry Andric auto *StartC = dyn_cast<Constant>(Start);
12604eeddc0SDimitry Andric if (StartC)
12704eeddc0SDimitry Andric return matchStridedConstant(StartC);
12804eeddc0SDimitry Andric
129bdd1243dSDimitry Andric // Base case, start is a stepvector
130bdd1243dSDimitry Andric if (match(Start, m_Intrinsic<Intrinsic::experimental_stepvector>())) {
131bdd1243dSDimitry Andric auto *Ty = Start->getType()->getScalarType();
132bdd1243dSDimitry Andric return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
133bdd1243dSDimitry Andric }
134bdd1243dSDimitry Andric
13506c3fb27SDimitry Andric // Not a constant, maybe it's a strided constant with a splat added or
13606c3fb27SDimitry Andric // multipled.
13704eeddc0SDimitry Andric auto *BO = dyn_cast<BinaryOperator>(Start);
13806c3fb27SDimitry Andric if (!BO || (BO->getOpcode() != Instruction::Add &&
1397a6dacacSDimitry Andric BO->getOpcode() != Instruction::Or &&
14006c3fb27SDimitry Andric BO->getOpcode() != Instruction::Shl &&
14106c3fb27SDimitry Andric BO->getOpcode() != Instruction::Mul))
14204eeddc0SDimitry Andric return std::make_pair(nullptr, nullptr);
14304eeddc0SDimitry Andric
1447a6dacacSDimitry Andric if (BO->getOpcode() == Instruction::Or &&
1457a6dacacSDimitry Andric !cast<PossiblyDisjointInst>(BO)->isDisjoint())
1467a6dacacSDimitry Andric return std::make_pair(nullptr, nullptr);
1477a6dacacSDimitry Andric
14804eeddc0SDimitry Andric // Look for an operand that is splatted.
14906c3fb27SDimitry Andric unsigned OtherIndex = 0;
15006c3fb27SDimitry Andric Value *Splat = getSplatValue(BO->getOperand(1));
15106c3fb27SDimitry Andric if (!Splat && Instruction::isCommutative(BO->getOpcode())) {
15206c3fb27SDimitry Andric Splat = getSplatValue(BO->getOperand(0));
15306c3fb27SDimitry Andric OtherIndex = 1;
15404eeddc0SDimitry Andric }
15504eeddc0SDimitry Andric if (!Splat)
15604eeddc0SDimitry Andric return std::make_pair(nullptr, nullptr);
15704eeddc0SDimitry Andric
15804eeddc0SDimitry Andric Value *Stride;
15904eeddc0SDimitry Andric std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex),
16004eeddc0SDimitry Andric Builder);
16104eeddc0SDimitry Andric if (!Start)
16204eeddc0SDimitry Andric return std::make_pair(nullptr, nullptr);
16304eeddc0SDimitry Andric
16404eeddc0SDimitry Andric Builder.SetInsertPoint(BO);
16504eeddc0SDimitry Andric Builder.SetCurrentDebugLocation(DebugLoc());
16606c3fb27SDimitry Andric // Add the splat value to the start or multiply the start and stride by the
16706c3fb27SDimitry Andric // splat.
16806c3fb27SDimitry Andric switch (BO->getOpcode()) {
16906c3fb27SDimitry Andric default:
17006c3fb27SDimitry Andric llvm_unreachable("Unexpected opcode");
1717a6dacacSDimitry Andric case Instruction::Or:
1727a6dacacSDimitry Andric // TODO: We'd be better off creating disjoint or here, but we don't yet
1737a6dacacSDimitry Andric // have an IRBuilder API for that.
1747a6dacacSDimitry Andric [[fallthrough]];
17506c3fb27SDimitry Andric case Instruction::Add:
17604eeddc0SDimitry Andric Start = Builder.CreateAdd(Start, Splat);
17706c3fb27SDimitry Andric break;
17806c3fb27SDimitry Andric case Instruction::Mul:
17906c3fb27SDimitry Andric Start = Builder.CreateMul(Start, Splat);
18006c3fb27SDimitry Andric Stride = Builder.CreateMul(Stride, Splat);
18106c3fb27SDimitry Andric break;
18206c3fb27SDimitry Andric case Instruction::Shl:
18306c3fb27SDimitry Andric Start = Builder.CreateShl(Start, Splat);
18406c3fb27SDimitry Andric Stride = Builder.CreateShl(Stride, Splat);
18506c3fb27SDimitry Andric break;
18606c3fb27SDimitry Andric }
18706c3fb27SDimitry Andric
18804eeddc0SDimitry Andric return std::make_pair(Start, Stride);
18904eeddc0SDimitry Andric }
19004eeddc0SDimitry Andric
191349cc55cSDimitry Andric // Recursively, walk about the use-def chain until we find a Phi with a strided
192349cc55cSDimitry Andric // start value. Build and update a scalar recurrence as we unwind the recursion.
193349cc55cSDimitry Andric // We also update the Stride as we unwind. Our goal is to move all of the
194349cc55cSDimitry Andric // arithmetic out of the loop.
matchStridedRecurrence(Value * Index,Loop * L,Value * & Stride,PHINode * & BasePtr,BinaryOperator * & Inc,IRBuilderBase & Builder)195349cc55cSDimitry Andric bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
196349cc55cSDimitry Andric Value *&Stride,
197349cc55cSDimitry Andric PHINode *&BasePtr,
198349cc55cSDimitry Andric BinaryOperator *&Inc,
19906c3fb27SDimitry Andric IRBuilderBase &Builder) {
200349cc55cSDimitry Andric // Our base case is a Phi.
201349cc55cSDimitry Andric if (auto *Phi = dyn_cast<PHINode>(Index)) {
202349cc55cSDimitry Andric // A phi node we want to perform this function on should be from the
203349cc55cSDimitry Andric // loop header.
204349cc55cSDimitry Andric if (Phi->getParent() != L->getHeader())
205349cc55cSDimitry Andric return false;
206349cc55cSDimitry Andric
207349cc55cSDimitry Andric Value *Step, *Start;
208349cc55cSDimitry Andric if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
209349cc55cSDimitry Andric Inc->getOpcode() != Instruction::Add)
210349cc55cSDimitry Andric return false;
211349cc55cSDimitry Andric assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
212349cc55cSDimitry Andric unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
213349cc55cSDimitry Andric assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
214349cc55cSDimitry Andric "Expected one operand of phi to be Inc");
215349cc55cSDimitry Andric
216349cc55cSDimitry Andric // Only proceed if the step is loop invariant.
217349cc55cSDimitry Andric if (!L->isLoopInvariant(Step))
218349cc55cSDimitry Andric return false;
219349cc55cSDimitry Andric
220349cc55cSDimitry Andric // Step should be a splat.
221349cc55cSDimitry Andric Step = getSplatValue(Step);
222349cc55cSDimitry Andric if (!Step)
223349cc55cSDimitry Andric return false;
224349cc55cSDimitry Andric
22504eeddc0SDimitry Andric std::tie(Start, Stride) = matchStridedStart(Start, Builder);
226349cc55cSDimitry Andric if (!Start)
227349cc55cSDimitry Andric return false;
228349cc55cSDimitry Andric assert(Stride != nullptr);
229349cc55cSDimitry Andric
230349cc55cSDimitry Andric // Build scalar phi and increment.
231349cc55cSDimitry Andric BasePtr =
232*0fca6ea1SDimitry Andric PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi->getIterator());
233349cc55cSDimitry Andric Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
234*0fca6ea1SDimitry Andric Inc->getIterator());
235349cc55cSDimitry Andric BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
236349cc55cSDimitry Andric BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));
237349cc55cSDimitry Andric
238349cc55cSDimitry Andric // Note that this Phi might be eligible for removal.
239349cc55cSDimitry Andric MaybeDeadPHIs.push_back(Phi);
240349cc55cSDimitry Andric return true;
241349cc55cSDimitry Andric }
242349cc55cSDimitry Andric
243349cc55cSDimitry Andric // Otherwise look for binary operator.
244349cc55cSDimitry Andric auto *BO = dyn_cast<BinaryOperator>(Index);
245349cc55cSDimitry Andric if (!BO)
246349cc55cSDimitry Andric return false;
247349cc55cSDimitry Andric
24806c3fb27SDimitry Andric switch (BO->getOpcode()) {
24906c3fb27SDimitry Andric default:
250349cc55cSDimitry Andric return false;
25106c3fb27SDimitry Andric case Instruction::Or:
252349cc55cSDimitry Andric // We need to be able to treat Or as Add.
2537a6dacacSDimitry Andric if (!cast<PossiblyDisjointInst>(BO)->isDisjoint())
254349cc55cSDimitry Andric return false;
25506c3fb27SDimitry Andric break;
25606c3fb27SDimitry Andric case Instruction::Add:
25706c3fb27SDimitry Andric break;
25806c3fb27SDimitry Andric case Instruction::Shl:
25906c3fb27SDimitry Andric break;
26006c3fb27SDimitry Andric case Instruction::Mul:
26106c3fb27SDimitry Andric break;
26206c3fb27SDimitry Andric }
263349cc55cSDimitry Andric
264349cc55cSDimitry Andric // We should have one operand in the loop and one splat.
265349cc55cSDimitry Andric Value *OtherOp;
266349cc55cSDimitry Andric if (isa<Instruction>(BO->getOperand(0)) &&
267349cc55cSDimitry Andric L->contains(cast<Instruction>(BO->getOperand(0)))) {
268349cc55cSDimitry Andric Index = cast<Instruction>(BO->getOperand(0));
269349cc55cSDimitry Andric OtherOp = BO->getOperand(1);
270349cc55cSDimitry Andric } else if (isa<Instruction>(BO->getOperand(1)) &&
27106c3fb27SDimitry Andric L->contains(cast<Instruction>(BO->getOperand(1))) &&
27206c3fb27SDimitry Andric Instruction::isCommutative(BO->getOpcode())) {
273349cc55cSDimitry Andric Index = cast<Instruction>(BO->getOperand(1));
274349cc55cSDimitry Andric OtherOp = BO->getOperand(0);
275349cc55cSDimitry Andric } else {
276349cc55cSDimitry Andric return false;
277349cc55cSDimitry Andric }
278349cc55cSDimitry Andric
279349cc55cSDimitry Andric // Make sure other op is loop invariant.
280349cc55cSDimitry Andric if (!L->isLoopInvariant(OtherOp))
281349cc55cSDimitry Andric return false;
282349cc55cSDimitry Andric
283349cc55cSDimitry Andric // Make sure we have a splat.
284349cc55cSDimitry Andric Value *SplatOp = getSplatValue(OtherOp);
285349cc55cSDimitry Andric if (!SplatOp)
286349cc55cSDimitry Andric return false;
287349cc55cSDimitry Andric
288349cc55cSDimitry Andric // Recurse up the use-def chain.
289349cc55cSDimitry Andric if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
290349cc55cSDimitry Andric return false;
291349cc55cSDimitry Andric
292349cc55cSDimitry Andric // Locate the Step and Start values from the recurrence.
293349cc55cSDimitry Andric unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
294349cc55cSDimitry Andric unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
295349cc55cSDimitry Andric Value *Step = Inc->getOperand(StepIndex);
296349cc55cSDimitry Andric Value *Start = BasePtr->getOperand(StartBlock);
297349cc55cSDimitry Andric
298349cc55cSDimitry Andric // We need to adjust the start value in the preheader.
299349cc55cSDimitry Andric Builder.SetInsertPoint(
300349cc55cSDimitry Andric BasePtr->getIncomingBlock(StartBlock)->getTerminator());
301349cc55cSDimitry Andric Builder.SetCurrentDebugLocation(DebugLoc());
302349cc55cSDimitry Andric
303349cc55cSDimitry Andric switch (BO->getOpcode()) {
304349cc55cSDimitry Andric default:
305349cc55cSDimitry Andric llvm_unreachable("Unexpected opcode!");
306349cc55cSDimitry Andric case Instruction::Add:
307349cc55cSDimitry Andric case Instruction::Or: {
308349cc55cSDimitry Andric // An add only affects the start value. It's ok to do this for Or because
309349cc55cSDimitry Andric // we already checked that there are no common set bits.
310349cc55cSDimitry Andric Start = Builder.CreateAdd(Start, SplatOp, "start");
311349cc55cSDimitry Andric break;
312349cc55cSDimitry Andric }
313349cc55cSDimitry Andric case Instruction::Mul: {
314349cc55cSDimitry Andric Start = Builder.CreateMul(Start, SplatOp, "start");
315349cc55cSDimitry Andric Step = Builder.CreateMul(Step, SplatOp, "step");
316349cc55cSDimitry Andric Stride = Builder.CreateMul(Stride, SplatOp, "stride");
317349cc55cSDimitry Andric break;
318349cc55cSDimitry Andric }
319349cc55cSDimitry Andric case Instruction::Shl: {
320349cc55cSDimitry Andric Start = Builder.CreateShl(Start, SplatOp, "start");
321349cc55cSDimitry Andric Step = Builder.CreateShl(Step, SplatOp, "step");
322349cc55cSDimitry Andric Stride = Builder.CreateShl(Stride, SplatOp, "stride");
323349cc55cSDimitry Andric break;
324349cc55cSDimitry Andric }
325349cc55cSDimitry Andric }
326349cc55cSDimitry Andric
32706c3fb27SDimitry Andric Inc->setOperand(StepIndex, Step);
32806c3fb27SDimitry Andric BasePtr->setIncomingValue(StartBlock, Start);
329349cc55cSDimitry Andric return true;
330349cc55cSDimitry Andric }
331349cc55cSDimitry Andric
332349cc55cSDimitry Andric std::pair<Value *, Value *>
determineBaseAndStride(Instruction * Ptr,IRBuilderBase & Builder)3335f757f3fSDimitry Andric RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
33406c3fb27SDimitry Andric IRBuilderBase &Builder) {
335349cc55cSDimitry Andric
3365f757f3fSDimitry Andric // A gather/scatter of a splat is a zero strided load/store.
3375f757f3fSDimitry Andric if (auto *BasePtr = getSplatValue(Ptr)) {
3385f757f3fSDimitry Andric Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
3395f757f3fSDimitry Andric return std::make_pair(BasePtr, ConstantInt::get(IntPtrTy, 0));
3405f757f3fSDimitry Andric }
3415f757f3fSDimitry Andric
3425f757f3fSDimitry Andric auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
3435f757f3fSDimitry Andric if (!GEP)
3445f757f3fSDimitry Andric return std::make_pair(nullptr, nullptr);
3455f757f3fSDimitry Andric
34681ad6265SDimitry Andric auto I = StridedAddrs.find(GEP);
34781ad6265SDimitry Andric if (I != StridedAddrs.end())
34881ad6265SDimitry Andric return I->second;
34981ad6265SDimitry Andric
350349cc55cSDimitry Andric SmallVector<Value *, 2> Ops(GEP->operands());
351349cc55cSDimitry Andric
352*0fca6ea1SDimitry Andric // If the base pointer is a vector, check if it's strided.
353*0fca6ea1SDimitry Andric Value *Base = GEP->getPointerOperand();
354*0fca6ea1SDimitry Andric if (auto *BaseInst = dyn_cast<Instruction>(Base);
355*0fca6ea1SDimitry Andric BaseInst && BaseInst->getType()->isVectorTy()) {
356*0fca6ea1SDimitry Andric // If GEP's offset is scalar then we can add it to the base pointer's base.
357*0fca6ea1SDimitry Andric auto IsScalar = [](Value *Idx) { return !Idx->getType()->isVectorTy(); };
358*0fca6ea1SDimitry Andric if (all_of(GEP->indices(), IsScalar)) {
359*0fca6ea1SDimitry Andric auto [BaseBase, Stride] = determineBaseAndStride(BaseInst, Builder);
360*0fca6ea1SDimitry Andric if (BaseBase) {
361*0fca6ea1SDimitry Andric Builder.SetInsertPoint(GEP);
362*0fca6ea1SDimitry Andric SmallVector<Value *> Indices(GEP->indices());
363*0fca6ea1SDimitry Andric Value *OffsetBase =
364*0fca6ea1SDimitry Andric Builder.CreateGEP(GEP->getSourceElementType(), BaseBase, Indices,
365*0fca6ea1SDimitry Andric GEP->getName() + "offset", GEP->isInBounds());
366*0fca6ea1SDimitry Andric return {OffsetBase, Stride};
367*0fca6ea1SDimitry Andric }
368*0fca6ea1SDimitry Andric }
369*0fca6ea1SDimitry Andric }
370*0fca6ea1SDimitry Andric
371349cc55cSDimitry Andric // Base pointer needs to be a scalar.
372*0fca6ea1SDimitry Andric Value *ScalarBase = Base;
3735f757f3fSDimitry Andric if (ScalarBase->getType()->isVectorTy()) {
3745f757f3fSDimitry Andric ScalarBase = getSplatValue(ScalarBase);
3755f757f3fSDimitry Andric if (!ScalarBase)
376349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr);
3775f757f3fSDimitry Andric }
378349cc55cSDimitry Andric
379bdd1243dSDimitry Andric std::optional<unsigned> VecOperand;
380349cc55cSDimitry Andric unsigned TypeScale = 0;
381349cc55cSDimitry Andric
382349cc55cSDimitry Andric // Look for a vector operand and scale.
383349cc55cSDimitry Andric gep_type_iterator GTI = gep_type_begin(GEP);
384349cc55cSDimitry Andric for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
385349cc55cSDimitry Andric if (!Ops[i]->getType()->isVectorTy())
386349cc55cSDimitry Andric continue;
387349cc55cSDimitry Andric
388349cc55cSDimitry Andric if (VecOperand)
389349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr);
390349cc55cSDimitry Andric
391349cc55cSDimitry Andric VecOperand = i;
392349cc55cSDimitry Andric
3931db9f3b2SDimitry Andric TypeSize TS = GTI.getSequentialElementStride(*DL);
394349cc55cSDimitry Andric if (TS.isScalable())
395349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr);
396349cc55cSDimitry Andric
397bdd1243dSDimitry Andric TypeScale = TS.getFixedValue();
398349cc55cSDimitry Andric }
399349cc55cSDimitry Andric
400349cc55cSDimitry Andric // We need to find a vector index to simplify.
401349cc55cSDimitry Andric if (!VecOperand)
402349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr);
403349cc55cSDimitry Andric
404349cc55cSDimitry Andric // We can't extract the stride if the arithmetic is done at a different size
405349cc55cSDimitry Andric // than the pointer type. Adding the stride later may not wrap correctly.
406349cc55cSDimitry Andric // Technically we could handle wider indices, but I don't expect that in
4075f757f3fSDimitry Andric // practice. Handle one special case here - constants. This simplifies
4085f757f3fSDimitry Andric // writing test cases.
409349cc55cSDimitry Andric Value *VecIndex = Ops[*VecOperand];
410349cc55cSDimitry Andric Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
4115f757f3fSDimitry Andric if (VecIndex->getType() != VecIntPtrTy) {
4125f757f3fSDimitry Andric auto *VecIndexC = dyn_cast<Constant>(VecIndex);
4135f757f3fSDimitry Andric if (!VecIndexC)
414349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr);
4155f757f3fSDimitry Andric if (VecIndex->getType()->getScalarSizeInBits() > VecIntPtrTy->getScalarSizeInBits())
4165f757f3fSDimitry Andric VecIndex = ConstantFoldCastInstruction(Instruction::Trunc, VecIndexC, VecIntPtrTy);
4175f757f3fSDimitry Andric else
4185f757f3fSDimitry Andric VecIndex = ConstantFoldCastInstruction(Instruction::SExt, VecIndexC, VecIntPtrTy);
4195f757f3fSDimitry Andric }
420349cc55cSDimitry Andric
421bdd1243dSDimitry Andric // Handle the non-recursive case. This is what we see if the vectorizer
422bdd1243dSDimitry Andric // decides to use a scalar IV + vid on demand instead of a vector IV.
423bdd1243dSDimitry Andric auto [Start, Stride] = matchStridedStart(VecIndex, Builder);
424bdd1243dSDimitry Andric if (Start) {
425bdd1243dSDimitry Andric assert(Stride);
426bdd1243dSDimitry Andric Builder.SetInsertPoint(GEP);
427bdd1243dSDimitry Andric
428bdd1243dSDimitry Andric // Replace the vector index with the scalar start and build a scalar GEP.
429bdd1243dSDimitry Andric Ops[*VecOperand] = Start;
430bdd1243dSDimitry Andric Type *SourceTy = GEP->getSourceElementType();
431bdd1243dSDimitry Andric Value *BasePtr =
4325f757f3fSDimitry Andric Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());
433bdd1243dSDimitry Andric
434bdd1243dSDimitry Andric // Convert stride to pointer size if needed.
435bdd1243dSDimitry Andric Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
436bdd1243dSDimitry Andric assert(Stride->getType() == IntPtrTy && "Unexpected type");
437bdd1243dSDimitry Andric
438bdd1243dSDimitry Andric // Scale the stride by the size of the indexed type.
439bdd1243dSDimitry Andric if (TypeScale != 1)
440bdd1243dSDimitry Andric Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
441bdd1243dSDimitry Andric
442bdd1243dSDimitry Andric auto P = std::make_pair(BasePtr, Stride);
443bdd1243dSDimitry Andric StridedAddrs[GEP] = P;
444bdd1243dSDimitry Andric return P;
445bdd1243dSDimitry Andric }
446bdd1243dSDimitry Andric
447bdd1243dSDimitry Andric // Make sure we're in a loop and that has a pre-header and a single latch.
448bdd1243dSDimitry Andric Loop *L = LI->getLoopFor(GEP->getParent());
449bdd1243dSDimitry Andric if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
450bdd1243dSDimitry Andric return std::make_pair(nullptr, nullptr);
451bdd1243dSDimitry Andric
452349cc55cSDimitry Andric BinaryOperator *Inc;
453349cc55cSDimitry Andric PHINode *BasePhi;
454349cc55cSDimitry Andric if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
455349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr);
456349cc55cSDimitry Andric
457349cc55cSDimitry Andric assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
458349cc55cSDimitry Andric unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
459349cc55cSDimitry Andric assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
460349cc55cSDimitry Andric "Expected one operand of phi to be Inc");
461349cc55cSDimitry Andric
462349cc55cSDimitry Andric Builder.SetInsertPoint(GEP);
463349cc55cSDimitry Andric
464349cc55cSDimitry Andric // Replace the vector index with the scalar phi and build a scalar GEP.
465349cc55cSDimitry Andric Ops[*VecOperand] = BasePhi;
466349cc55cSDimitry Andric Type *SourceTy = GEP->getSourceElementType();
467349cc55cSDimitry Andric Value *BasePtr =
4685f757f3fSDimitry Andric Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());
469349cc55cSDimitry Andric
470349cc55cSDimitry Andric // Final adjustments to stride should go in the start block.
471349cc55cSDimitry Andric Builder.SetInsertPoint(
472349cc55cSDimitry Andric BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());
473349cc55cSDimitry Andric
474349cc55cSDimitry Andric // Convert stride to pointer size if needed.
475349cc55cSDimitry Andric Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
476349cc55cSDimitry Andric assert(Stride->getType() == IntPtrTy && "Unexpected type");
477349cc55cSDimitry Andric
478349cc55cSDimitry Andric // Scale the stride by the size of the indexed type.
479349cc55cSDimitry Andric if (TypeScale != 1)
480349cc55cSDimitry Andric Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
481349cc55cSDimitry Andric
48281ad6265SDimitry Andric auto P = std::make_pair(BasePtr, Stride);
48381ad6265SDimitry Andric StridedAddrs[GEP] = P;
48481ad6265SDimitry Andric return P;
485349cc55cSDimitry Andric }
486349cc55cSDimitry Andric
tryCreateStridedLoadStore(IntrinsicInst * II,Type * DataType,Value * Ptr,Value * AlignOp)487349cc55cSDimitry Andric bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
488349cc55cSDimitry Andric Type *DataType,
489349cc55cSDimitry Andric Value *Ptr,
490349cc55cSDimitry Andric Value *AlignOp) {
491349cc55cSDimitry Andric // Make sure the operation will be supported by the backend.
49206c3fb27SDimitry Andric MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
49306c3fb27SDimitry Andric EVT DataTypeVT = TLI->getValueType(*DL, DataType);
49406c3fb27SDimitry Andric if (!MA || !TLI->isLegalStridedLoadStore(DataTypeVT, *MA))
49506c3fb27SDimitry Andric return false;
49606c3fb27SDimitry Andric
49706c3fb27SDimitry Andric // FIXME: Let the backend type legalize by splitting/widening?
49806c3fb27SDimitry Andric if (!TLI->isTypeLegal(DataTypeVT))
499349cc55cSDimitry Andric return false;
500349cc55cSDimitry Andric
5015f757f3fSDimitry Andric // Pointer should be an instruction.
5025f757f3fSDimitry Andric auto *PtrI = dyn_cast<Instruction>(Ptr);
5035f757f3fSDimitry Andric if (!PtrI)
504349cc55cSDimitry Andric return false;
505349cc55cSDimitry Andric
5065f757f3fSDimitry Andric LLVMContext &Ctx = PtrI->getContext();
50706c3fb27SDimitry Andric IRBuilder<InstSimplifyFolder> Builder(Ctx, *DL);
5085f757f3fSDimitry Andric Builder.SetInsertPoint(PtrI);
509349cc55cSDimitry Andric
510349cc55cSDimitry Andric Value *BasePtr, *Stride;
5115f757f3fSDimitry Andric std::tie(BasePtr, Stride) = determineBaseAndStride(PtrI, Builder);
512349cc55cSDimitry Andric if (!BasePtr)
513349cc55cSDimitry Andric return false;
514349cc55cSDimitry Andric assert(Stride != nullptr);
515349cc55cSDimitry Andric
516349cc55cSDimitry Andric Builder.SetInsertPoint(II);
517349cc55cSDimitry Andric
518349cc55cSDimitry Andric CallInst *Call;
519349cc55cSDimitry Andric if (II->getIntrinsicID() == Intrinsic::masked_gather)
520349cc55cSDimitry Andric Call = Builder.CreateIntrinsic(
521349cc55cSDimitry Andric Intrinsic::riscv_masked_strided_load,
522349cc55cSDimitry Andric {DataType, BasePtr->getType(), Stride->getType()},
523349cc55cSDimitry Andric {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
524349cc55cSDimitry Andric else
525349cc55cSDimitry Andric Call = Builder.CreateIntrinsic(
526349cc55cSDimitry Andric Intrinsic::riscv_masked_strided_store,
527349cc55cSDimitry Andric {DataType, BasePtr->getType(), Stride->getType()},
528349cc55cSDimitry Andric {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});
529349cc55cSDimitry Andric
530349cc55cSDimitry Andric Call->takeName(II);
531349cc55cSDimitry Andric II->replaceAllUsesWith(Call);
532349cc55cSDimitry Andric II->eraseFromParent();
533349cc55cSDimitry Andric
5345f757f3fSDimitry Andric if (PtrI->use_empty())
5355f757f3fSDimitry Andric RecursivelyDeleteTriviallyDeadInstructions(PtrI);
536349cc55cSDimitry Andric
537349cc55cSDimitry Andric return true;
538349cc55cSDimitry Andric }
539349cc55cSDimitry Andric
runOnFunction(Function & F)540349cc55cSDimitry Andric bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
541349cc55cSDimitry Andric if (skipFunction(F))
542349cc55cSDimitry Andric return false;
543349cc55cSDimitry Andric
544349cc55cSDimitry Andric auto &TPC = getAnalysis<TargetPassConfig>();
545349cc55cSDimitry Andric auto &TM = TPC.getTM<RISCVTargetMachine>();
546349cc55cSDimitry Andric ST = &TM.getSubtarget<RISCVSubtarget>(F);
547349cc55cSDimitry Andric if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
548349cc55cSDimitry Andric return false;
549349cc55cSDimitry Andric
550349cc55cSDimitry Andric TLI = ST->getTargetLowering();
551*0fca6ea1SDimitry Andric DL = &F.getDataLayout();
552349cc55cSDimitry Andric LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
553349cc55cSDimitry Andric
55481ad6265SDimitry Andric StridedAddrs.clear();
55581ad6265SDimitry Andric
556349cc55cSDimitry Andric SmallVector<IntrinsicInst *, 4> Gathers;
557349cc55cSDimitry Andric SmallVector<IntrinsicInst *, 4> Scatters;
558349cc55cSDimitry Andric
559349cc55cSDimitry Andric bool Changed = false;
560349cc55cSDimitry Andric
561349cc55cSDimitry Andric for (BasicBlock &BB : F) {
562349cc55cSDimitry Andric for (Instruction &I : BB) {
563349cc55cSDimitry Andric IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
564bdd1243dSDimitry Andric if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {
565349cc55cSDimitry Andric Gathers.push_back(II);
566bdd1243dSDimitry Andric } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {
567349cc55cSDimitry Andric Scatters.push_back(II);
568349cc55cSDimitry Andric }
569349cc55cSDimitry Andric }
570349cc55cSDimitry Andric }
571349cc55cSDimitry Andric
572349cc55cSDimitry Andric // Rewrite gather/scatter to form strided load/store if possible.
573349cc55cSDimitry Andric for (auto *II : Gathers)
574349cc55cSDimitry Andric Changed |= tryCreateStridedLoadStore(
575349cc55cSDimitry Andric II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
576349cc55cSDimitry Andric for (auto *II : Scatters)
577349cc55cSDimitry Andric Changed |=
578349cc55cSDimitry Andric tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
579349cc55cSDimitry Andric II->getArgOperand(1), II->getArgOperand(2));
580349cc55cSDimitry Andric
581349cc55cSDimitry Andric // Remove any dead phis.
582349cc55cSDimitry Andric while (!MaybeDeadPHIs.empty()) {
583349cc55cSDimitry Andric if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
584349cc55cSDimitry Andric RecursivelyDeleteDeadPHINode(Phi);
585349cc55cSDimitry Andric }
586349cc55cSDimitry Andric
587349cc55cSDimitry Andric return Changed;
588349cc55cSDimitry Andric }
589