xref: /freebsd/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp (revision 02e9120893770924227138ba49df1edb3896112a)
1 //===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This pass custom lowers llvm.masked.gather and llvm.masked.scatter
// instructions whose addresses form a strided pattern into RISC-V masked
// strided load/store intrinsics.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "RISCV.h"
15 #include "RISCVTargetMachine.h"
16 #include "llvm/Analysis/InstSimplifyFolder.h"
17 #include "llvm/Analysis/LoopInfo.h"
18 #include "llvm/Analysis/ValueTracking.h"
19 #include "llvm/Analysis/VectorUtils.h"
20 #include "llvm/CodeGen/TargetPassConfig.h"
21 #include "llvm/IR/GetElementPtrTypeIterator.h"
22 #include "llvm/IR/IRBuilder.h"
23 #include "llvm/IR/IntrinsicInst.h"
24 #include "llvm/IR/IntrinsicsRISCV.h"
25 #include "llvm/IR/PatternMatch.h"
26 #include "llvm/Transforms/Utils/Local.h"
27 #include <optional>
28 
29 using namespace llvm;
30 using namespace PatternMatch;
31 
32 #define DEBUG_TYPE "riscv-gather-scatter-lowering"
33 
34 namespace {
35 
36 class RISCVGatherScatterLowering : public FunctionPass {
37   const RISCVSubtarget *ST = nullptr;
38   const RISCVTargetLowering *TLI = nullptr;
39   LoopInfo *LI = nullptr;
40   const DataLayout *DL = nullptr;
41 
42   SmallVector<WeakTrackingVH> MaybeDeadPHIs;
43 
44   // Cache of the BasePtr and Stride determined from this GEP. When a GEP is
45   // used by multiple gathers/scatters, this allow us to reuse the scalar
46   // instructions we created for the first gather/scatter for the others.
47   DenseMap<GetElementPtrInst *, std::pair<Value *, Value *>> StridedAddrs;
48 
49 public:
50   static char ID; // Pass identification, replacement for typeid
51 
52   RISCVGatherScatterLowering() : FunctionPass(ID) {}
53 
54   bool runOnFunction(Function &F) override;
55 
56   void getAnalysisUsage(AnalysisUsage &AU) const override {
57     AU.setPreservesCFG();
58     AU.addRequired<TargetPassConfig>();
59     AU.addRequired<LoopInfoWrapperPass>();
60   }
61 
62   StringRef getPassName() const override {
63     return "RISC-V gather/scatter lowering";
64   }
65 
66 private:
67   bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
68                                  Value *AlignOp);
69 
70   std::pair<Value *, Value *> determineBaseAndStride(GetElementPtrInst *GEP,
71                                                      IRBuilderBase &Builder);
72 
73   bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
74                               PHINode *&BasePtr, BinaryOperator *&Inc,
75                               IRBuilderBase &Builder);
76 };
77 
78 } // end anonymous namespace
79 
80 char RISCVGatherScatterLowering::ID = 0;
81 
82 INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
83                 "RISC-V gather/scatter lowering pass", false, false)
84 
85 FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
86   return new RISCVGatherScatterLowering();
87 }
88 
89 // TODO: Should we consider the mask when looking for a stride?
90 static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
91   if (!isa<FixedVectorType>(StartC->getType()))
92     return std::make_pair(nullptr, nullptr);
93 
94   unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();
95 
96   // Check that the start value is a strided constant.
97   auto *StartVal =
98       dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
99   if (!StartVal)
100     return std::make_pair(nullptr, nullptr);
101   APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
102   ConstantInt *Prev = StartVal;
103   for (unsigned i = 1; i != NumElts; ++i) {
104     auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
105     if (!C)
106       return std::make_pair(nullptr, nullptr);
107 
108     APInt LocalStride = C->getValue() - Prev->getValue();
109     if (i == 1)
110       StrideVal = LocalStride;
111     else if (StrideVal != LocalStride)
112       return std::make_pair(nullptr, nullptr);
113 
114     Prev = C;
115   }
116 
117   Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);
118 
119   return std::make_pair(StartVal, Stride);
120 }
121 
122 static std::pair<Value *, Value *> matchStridedStart(Value *Start,
123                                                      IRBuilderBase &Builder) {
124   // Base case, start is a strided constant.
125   auto *StartC = dyn_cast<Constant>(Start);
126   if (StartC)
127     return matchStridedConstant(StartC);
128 
129   // Base case, start is a stepvector
130   if (match(Start, m_Intrinsic<Intrinsic::experimental_stepvector>())) {
131     auto *Ty = Start->getType()->getScalarType();
132     return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
133   }
134 
135   // Not a constant, maybe it's a strided constant with a splat added or
136   // multipled.
137   auto *BO = dyn_cast<BinaryOperator>(Start);
138   if (!BO || (BO->getOpcode() != Instruction::Add &&
139               BO->getOpcode() != Instruction::Shl &&
140               BO->getOpcode() != Instruction::Mul))
141     return std::make_pair(nullptr, nullptr);
142 
143   // Look for an operand that is splatted.
144   unsigned OtherIndex = 0;
145   Value *Splat = getSplatValue(BO->getOperand(1));
146   if (!Splat && Instruction::isCommutative(BO->getOpcode())) {
147     Splat = getSplatValue(BO->getOperand(0));
148     OtherIndex = 1;
149   }
150   if (!Splat)
151     return std::make_pair(nullptr, nullptr);
152 
153   Value *Stride;
154   std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex),
155                                               Builder);
156   if (!Start)
157     return std::make_pair(nullptr, nullptr);
158 
159   Builder.SetInsertPoint(BO);
160   Builder.SetCurrentDebugLocation(DebugLoc());
161   // Add the splat value to the start or multiply the start and stride by the
162   // splat.
163   switch (BO->getOpcode()) {
164   default:
165     llvm_unreachable("Unexpected opcode");
166   case Instruction::Add:
167     Start = Builder.CreateAdd(Start, Splat);
168     break;
169   case Instruction::Mul:
170     Start = Builder.CreateMul(Start, Splat);
171     Stride = Builder.CreateMul(Stride, Splat);
172     break;
173   case Instruction::Shl:
174     Start = Builder.CreateShl(Start, Splat);
175     Stride = Builder.CreateShl(Stride, Splat);
176     break;
177   }
178 
179   return std::make_pair(Start, Stride);
180 }
181 
182 // Recursively, walk about the use-def chain until we find a Phi with a strided
183 // start value. Build and update a scalar recurrence as we unwind the recursion.
184 // We also update the Stride as we unwind. Our goal is to move all of the
185 // arithmetic out of the loop.
186 bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
187                                                         Value *&Stride,
188                                                         PHINode *&BasePtr,
189                                                         BinaryOperator *&Inc,
190                                                         IRBuilderBase &Builder) {
191   // Our base case is a Phi.
192   if (auto *Phi = dyn_cast<PHINode>(Index)) {
193     // A phi node we want to perform this function on should be from the
194     // loop header.
195     if (Phi->getParent() != L->getHeader())
196       return false;
197 
198     Value *Step, *Start;
199     if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
200         Inc->getOpcode() != Instruction::Add)
201       return false;
202     assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
203     unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
204     assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
205            "Expected one operand of phi to be Inc");
206 
207     // Only proceed if the step is loop invariant.
208     if (!L->isLoopInvariant(Step))
209       return false;
210 
211     // Step should be a splat.
212     Step = getSplatValue(Step);
213     if (!Step)
214       return false;
215 
216     std::tie(Start, Stride) = matchStridedStart(Start, Builder);
217     if (!Start)
218       return false;
219     assert(Stride != nullptr);
220 
221     // Build scalar phi and increment.
222     BasePtr =
223         PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi);
224     Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
225                                     Inc);
226     BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
227     BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));
228 
229     // Note that this Phi might be eligible for removal.
230     MaybeDeadPHIs.push_back(Phi);
231     return true;
232   }
233 
234   // Otherwise look for binary operator.
235   auto *BO = dyn_cast<BinaryOperator>(Index);
236   if (!BO)
237     return false;
238 
239   switch (BO->getOpcode()) {
240   default:
241     return false;
242   case Instruction::Or:
243     // We need to be able to treat Or as Add.
244     if (!haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
245       return false;
246     break;
247   case Instruction::Add:
248     break;
249   case Instruction::Shl:
250     break;
251   case Instruction::Mul:
252     break;
253   }
254 
255   // We should have one operand in the loop and one splat.
256   Value *OtherOp;
257   if (isa<Instruction>(BO->getOperand(0)) &&
258       L->contains(cast<Instruction>(BO->getOperand(0)))) {
259     Index = cast<Instruction>(BO->getOperand(0));
260     OtherOp = BO->getOperand(1);
261   } else if (isa<Instruction>(BO->getOperand(1)) &&
262              L->contains(cast<Instruction>(BO->getOperand(1))) &&
263              Instruction::isCommutative(BO->getOpcode())) {
264     Index = cast<Instruction>(BO->getOperand(1));
265     OtherOp = BO->getOperand(0);
266   } else {
267     return false;
268   }
269 
270   // Make sure other op is loop invariant.
271   if (!L->isLoopInvariant(OtherOp))
272     return false;
273 
274   // Make sure we have a splat.
275   Value *SplatOp = getSplatValue(OtherOp);
276   if (!SplatOp)
277     return false;
278 
279   // Recurse up the use-def chain.
280   if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
281     return false;
282 
283   // Locate the Step and Start values from the recurrence.
284   unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
285   unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
286   Value *Step = Inc->getOperand(StepIndex);
287   Value *Start = BasePtr->getOperand(StartBlock);
288 
289   // We need to adjust the start value in the preheader.
290   Builder.SetInsertPoint(
291       BasePtr->getIncomingBlock(StartBlock)->getTerminator());
292   Builder.SetCurrentDebugLocation(DebugLoc());
293 
294   switch (BO->getOpcode()) {
295   default:
296     llvm_unreachable("Unexpected opcode!");
297   case Instruction::Add:
298   case Instruction::Or: {
299     // An add only affects the start value. It's ok to do this for Or because
300     // we already checked that there are no common set bits.
301     Start = Builder.CreateAdd(Start, SplatOp, "start");
302     break;
303   }
304   case Instruction::Mul: {
305     Start = Builder.CreateMul(Start, SplatOp, "start");
306     Step = Builder.CreateMul(Step, SplatOp, "step");
307     Stride = Builder.CreateMul(Stride, SplatOp, "stride");
308     break;
309   }
310   case Instruction::Shl: {
311     Start = Builder.CreateShl(Start, SplatOp, "start");
312     Step = Builder.CreateShl(Step, SplatOp, "step");
313     Stride = Builder.CreateShl(Stride, SplatOp, "stride");
314     break;
315   }
316   }
317 
318   Inc->setOperand(StepIndex, Step);
319   BasePtr->setIncomingValue(StartBlock, Start);
320   return true;
321 }
322 
323 std::pair<Value *, Value *>
324 RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
325                                                    IRBuilderBase &Builder) {
326 
327   auto I = StridedAddrs.find(GEP);
328   if (I != StridedAddrs.end())
329     return I->second;
330 
331   SmallVector<Value *, 2> Ops(GEP->operands());
332 
333   // Base pointer needs to be a scalar.
334   if (Ops[0]->getType()->isVectorTy())
335     return std::make_pair(nullptr, nullptr);
336 
337   std::optional<unsigned> VecOperand;
338   unsigned TypeScale = 0;
339 
340   // Look for a vector operand and scale.
341   gep_type_iterator GTI = gep_type_begin(GEP);
342   for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
343     if (!Ops[i]->getType()->isVectorTy())
344       continue;
345 
346     if (VecOperand)
347       return std::make_pair(nullptr, nullptr);
348 
349     VecOperand = i;
350 
351     TypeSize TS = DL->getTypeAllocSize(GTI.getIndexedType());
352     if (TS.isScalable())
353       return std::make_pair(nullptr, nullptr);
354 
355     TypeScale = TS.getFixedValue();
356   }
357 
358   // We need to find a vector index to simplify.
359   if (!VecOperand)
360     return std::make_pair(nullptr, nullptr);
361 
362   // We can't extract the stride if the arithmetic is done at a different size
363   // than the pointer type. Adding the stride later may not wrap correctly.
364   // Technically we could handle wider indices, but I don't expect that in
365   // practice.
366   Value *VecIndex = Ops[*VecOperand];
367   Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
368   if (VecIndex->getType() != VecIntPtrTy)
369     return std::make_pair(nullptr, nullptr);
370 
371   // Handle the non-recursive case.  This is what we see if the vectorizer
372   // decides to use a scalar IV + vid on demand instead of a vector IV.
373   auto [Start, Stride] = matchStridedStart(VecIndex, Builder);
374   if (Start) {
375     assert(Stride);
376     Builder.SetInsertPoint(GEP);
377 
378     // Replace the vector index with the scalar start and build a scalar GEP.
379     Ops[*VecOperand] = Start;
380     Type *SourceTy = GEP->getSourceElementType();
381     Value *BasePtr =
382         Builder.CreateGEP(SourceTy, Ops[0], ArrayRef(Ops).drop_front());
383 
384     // Convert stride to pointer size if needed.
385     Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
386     assert(Stride->getType() == IntPtrTy && "Unexpected type");
387 
388     // Scale the stride by the size of the indexed type.
389     if (TypeScale != 1)
390       Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
391 
392     auto P = std::make_pair(BasePtr, Stride);
393     StridedAddrs[GEP] = P;
394     return P;
395   }
396 
397   // Make sure we're in a loop and that has a pre-header and a single latch.
398   Loop *L = LI->getLoopFor(GEP->getParent());
399   if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
400     return std::make_pair(nullptr, nullptr);
401 
402   BinaryOperator *Inc;
403   PHINode *BasePhi;
404   if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
405     return std::make_pair(nullptr, nullptr);
406 
407   assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
408   unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
409   assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
410          "Expected one operand of phi to be Inc");
411 
412   Builder.SetInsertPoint(GEP);
413 
414   // Replace the vector index with the scalar phi and build a scalar GEP.
415   Ops[*VecOperand] = BasePhi;
416   Type *SourceTy = GEP->getSourceElementType();
417   Value *BasePtr =
418       Builder.CreateGEP(SourceTy, Ops[0], ArrayRef(Ops).drop_front());
419 
420   // Final adjustments to stride should go in the start block.
421   Builder.SetInsertPoint(
422       BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());
423 
424   // Convert stride to pointer size if needed.
425   Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
426   assert(Stride->getType() == IntPtrTy && "Unexpected type");
427 
428   // Scale the stride by the size of the indexed type.
429   if (TypeScale != 1)
430     Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
431 
432   auto P = std::make_pair(BasePtr, Stride);
433   StridedAddrs[GEP] = P;
434   return P;
435 }
436 
437 bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
438                                                            Type *DataType,
439                                                            Value *Ptr,
440                                                            Value *AlignOp) {
441   // Make sure the operation will be supported by the backend.
442   MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
443   EVT DataTypeVT = TLI->getValueType(*DL, DataType);
444   if (!MA || !TLI->isLegalStridedLoadStore(DataTypeVT, *MA))
445     return false;
446 
447   // FIXME: Let the backend type legalize by splitting/widening?
448   if (!TLI->isTypeLegal(DataTypeVT))
449     return false;
450 
451   // Pointer should be a GEP.
452   auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
453   if (!GEP)
454     return false;
455 
456   LLVMContext &Ctx = GEP->getContext();
457   IRBuilder<InstSimplifyFolder> Builder(Ctx, *DL);
458   Builder.SetInsertPoint(GEP);
459 
460   Value *BasePtr, *Stride;
461   std::tie(BasePtr, Stride) = determineBaseAndStride(GEP, Builder);
462   if (!BasePtr)
463     return false;
464   assert(Stride != nullptr);
465 
466   Builder.SetInsertPoint(II);
467 
468   CallInst *Call;
469   if (II->getIntrinsicID() == Intrinsic::masked_gather)
470     Call = Builder.CreateIntrinsic(
471         Intrinsic::riscv_masked_strided_load,
472         {DataType, BasePtr->getType(), Stride->getType()},
473         {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
474   else
475     Call = Builder.CreateIntrinsic(
476         Intrinsic::riscv_masked_strided_store,
477         {DataType, BasePtr->getType(), Stride->getType()},
478         {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});
479 
480   Call->takeName(II);
481   II->replaceAllUsesWith(Call);
482   II->eraseFromParent();
483 
484   if (GEP->use_empty())
485     RecursivelyDeleteTriviallyDeadInstructions(GEP);
486 
487   return true;
488 }
489 
490 bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
491   if (skipFunction(F))
492     return false;
493 
494   auto &TPC = getAnalysis<TargetPassConfig>();
495   auto &TM = TPC.getTM<RISCVTargetMachine>();
496   ST = &TM.getSubtarget<RISCVSubtarget>(F);
497   if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
498     return false;
499 
500   TLI = ST->getTargetLowering();
501   DL = &F.getParent()->getDataLayout();
502   LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
503 
504   StridedAddrs.clear();
505 
506   SmallVector<IntrinsicInst *, 4> Gathers;
507   SmallVector<IntrinsicInst *, 4> Scatters;
508 
509   bool Changed = false;
510 
511   for (BasicBlock &BB : F) {
512     for (Instruction &I : BB) {
513       IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
514       if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {
515         Gathers.push_back(II);
516       } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {
517         Scatters.push_back(II);
518       }
519     }
520   }
521 
522   // Rewrite gather/scatter to form strided load/store if possible.
523   for (auto *II : Gathers)
524     Changed |= tryCreateStridedLoadStore(
525         II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
526   for (auto *II : Scatters)
527     Changed |=
528         tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
529                                   II->getArgOperand(1), II->getArgOperand(2));
530 
531   // Remove any dead phis.
532   while (!MaybeDeadPHIs.empty()) {
533     if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
534       RecursivelyDeleteDeadPHINode(Phi);
535   }
536 
537   return Changed;
538 }
539