xref: /freebsd/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp (revision 66fd12cf4896eb08ad8e7a2627537f84ead84dd3)
1 //===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This pass custom lowers llvm.masked.gather and llvm.masked.scatter
// intrinsics whose addresses form a constant-stride sequence into RISCV
// masked strided load/store intrinsics.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "RISCV.h"
15 #include "RISCVTargetMachine.h"
16 #include "llvm/Analysis/LoopInfo.h"
17 #include "llvm/Analysis/ValueTracking.h"
18 #include "llvm/Analysis/VectorUtils.h"
19 #include "llvm/CodeGen/TargetPassConfig.h"
20 #include "llvm/IR/GetElementPtrTypeIterator.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/IntrinsicInst.h"
23 #include "llvm/IR/IntrinsicsRISCV.h"
24 #include "llvm/IR/PatternMatch.h"
25 #include "llvm/Transforms/Utils/Local.h"
26 #include <optional>
27 
28 using namespace llvm;
29 using namespace PatternMatch;
30 
31 #define DEBUG_TYPE "riscv-gather-scatter-lowering"
32 
33 namespace {
34 
35 class RISCVGatherScatterLowering : public FunctionPass {
36   const RISCVSubtarget *ST = nullptr;
37   const RISCVTargetLowering *TLI = nullptr;
38   LoopInfo *LI = nullptr;
39   const DataLayout *DL = nullptr;
40 
41   SmallVector<WeakTrackingVH> MaybeDeadPHIs;
42 
43   // Cache of the BasePtr and Stride determined from this GEP. When a GEP is
44   // used by multiple gathers/scatters, this allow us to reuse the scalar
45   // instructions we created for the first gather/scatter for the others.
46   DenseMap<GetElementPtrInst *, std::pair<Value *, Value *>> StridedAddrs;
47 
48 public:
49   static char ID; // Pass identification, replacement for typeid
50 
51   RISCVGatherScatterLowering() : FunctionPass(ID) {}
52 
53   bool runOnFunction(Function &F) override;
54 
55   void getAnalysisUsage(AnalysisUsage &AU) const override {
56     AU.setPreservesCFG();
57     AU.addRequired<TargetPassConfig>();
58     AU.addRequired<LoopInfoWrapperPass>();
59   }
60 
61   StringRef getPassName() const override {
62     return "RISCV gather/scatter lowering";
63   }
64 
65 private:
66   bool isLegalTypeAndAlignment(Type *DataType, Value *AlignOp);
67 
68   bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
69                                  Value *AlignOp);
70 
71   std::pair<Value *, Value *> determineBaseAndStride(GetElementPtrInst *GEP,
72                                                      IRBuilder<> &Builder);
73 
74   bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
75                               PHINode *&BasePtr, BinaryOperator *&Inc,
76                               IRBuilder<> &Builder);
77 };
78 
79 } // end anonymous namespace
80 
81 char RISCVGatherScatterLowering::ID = 0;
82 
83 INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
84                 "RISCV gather/scatter lowering pass", false, false)
85 
86 FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
87   return new RISCVGatherScatterLowering();
88 }
89 
90 bool RISCVGatherScatterLowering::isLegalTypeAndAlignment(Type *DataType,
91                                                          Value *AlignOp) {
92   Type *ScalarType = DataType->getScalarType();
93   if (!TLI->isLegalElementTypeForRVV(ScalarType))
94     return false;
95 
96   MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
97   if (MA && MA->value() < DL->getTypeStoreSize(ScalarType).getFixedValue())
98     return false;
99 
100   // FIXME: Let the backend type legalize by splitting/widening?
101   EVT DataVT = TLI->getValueType(*DL, DataType);
102   if (!TLI->isTypeLegal(DataVT))
103     return false;
104 
105   return true;
106 }
107 
108 // TODO: Should we consider the mask when looking for a stride?
109 static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
110   unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();
111 
112   // Check that the start value is a strided constant.
113   auto *StartVal =
114       dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
115   if (!StartVal)
116     return std::make_pair(nullptr, nullptr);
117   APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
118   ConstantInt *Prev = StartVal;
119   for (unsigned i = 1; i != NumElts; ++i) {
120     auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
121     if (!C)
122       return std::make_pair(nullptr, nullptr);
123 
124     APInt LocalStride = C->getValue() - Prev->getValue();
125     if (i == 1)
126       StrideVal = LocalStride;
127     else if (StrideVal != LocalStride)
128       return std::make_pair(nullptr, nullptr);
129 
130     Prev = C;
131   }
132 
133   Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);
134 
135   return std::make_pair(StartVal, Stride);
136 }
137 
138 static std::pair<Value *, Value *> matchStridedStart(Value *Start,
139                                                      IRBuilder<> &Builder) {
140   // Base case, start is a strided constant.
141   auto *StartC = dyn_cast<Constant>(Start);
142   if (StartC)
143     return matchStridedConstant(StartC);
144 
145   // Base case, start is a stepvector
146   if (match(Start, m_Intrinsic<Intrinsic::experimental_stepvector>())) {
147     auto *Ty = Start->getType()->getScalarType();
148     return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
149   }
150 
151   // Not a constant, maybe it's a strided constant with a splat added to it.
152   auto *BO = dyn_cast<BinaryOperator>(Start);
153   if (!BO || BO->getOpcode() != Instruction::Add)
154     return std::make_pair(nullptr, nullptr);
155 
156   // Look for an operand that is splatted.
157   unsigned OtherIndex = 1;
158   Value *Splat = getSplatValue(BO->getOperand(0));
159   if (!Splat) {
160     Splat = getSplatValue(BO->getOperand(1));
161     OtherIndex = 0;
162   }
163   if (!Splat)
164     return std::make_pair(nullptr, nullptr);
165 
166   Value *Stride;
167   std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex),
168                                               Builder);
169   if (!Start)
170     return std::make_pair(nullptr, nullptr);
171 
172   // Add the splat value to the start.
173   Builder.SetInsertPoint(BO);
174   Builder.SetCurrentDebugLocation(DebugLoc());
175   Start = Builder.CreateAdd(Start, Splat);
176   return std::make_pair(Start, Stride);
177 }
178 
179 // Recursively, walk about the use-def chain until we find a Phi with a strided
180 // start value. Build and update a scalar recurrence as we unwind the recursion.
181 // We also update the Stride as we unwind. Our goal is to move all of the
182 // arithmetic out of the loop.
183 bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
184                                                         Value *&Stride,
185                                                         PHINode *&BasePtr,
186                                                         BinaryOperator *&Inc,
187                                                         IRBuilder<> &Builder) {
188   // Our base case is a Phi.
189   if (auto *Phi = dyn_cast<PHINode>(Index)) {
190     // A phi node we want to perform this function on should be from the
191     // loop header.
192     if (Phi->getParent() != L->getHeader())
193       return false;
194 
195     Value *Step, *Start;
196     if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
197         Inc->getOpcode() != Instruction::Add)
198       return false;
199     assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
200     unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
201     assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
202            "Expected one operand of phi to be Inc");
203 
204     // Only proceed if the step is loop invariant.
205     if (!L->isLoopInvariant(Step))
206       return false;
207 
208     // Step should be a splat.
209     Step = getSplatValue(Step);
210     if (!Step)
211       return false;
212 
213     std::tie(Start, Stride) = matchStridedStart(Start, Builder);
214     if (!Start)
215       return false;
216     assert(Stride != nullptr);
217 
218     // Build scalar phi and increment.
219     BasePtr =
220         PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi);
221     Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
222                                     Inc);
223     BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
224     BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));
225 
226     // Note that this Phi might be eligible for removal.
227     MaybeDeadPHIs.push_back(Phi);
228     return true;
229   }
230 
231   // Otherwise look for binary operator.
232   auto *BO = dyn_cast<BinaryOperator>(Index);
233   if (!BO)
234     return false;
235 
236   if (BO->getOpcode() != Instruction::Add &&
237       BO->getOpcode() != Instruction::Or &&
238       BO->getOpcode() != Instruction::Mul &&
239       BO->getOpcode() != Instruction::Shl)
240     return false;
241 
242   // Only support shift by constant.
243   if (BO->getOpcode() == Instruction::Shl && !isa<Constant>(BO->getOperand(1)))
244     return false;
245 
246   // We need to be able to treat Or as Add.
247   if (BO->getOpcode() == Instruction::Or &&
248       !haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
249     return false;
250 
251   // We should have one operand in the loop and one splat.
252   Value *OtherOp;
253   if (isa<Instruction>(BO->getOperand(0)) &&
254       L->contains(cast<Instruction>(BO->getOperand(0)))) {
255     Index = cast<Instruction>(BO->getOperand(0));
256     OtherOp = BO->getOperand(1);
257   } else if (isa<Instruction>(BO->getOperand(1)) &&
258              L->contains(cast<Instruction>(BO->getOperand(1)))) {
259     Index = cast<Instruction>(BO->getOperand(1));
260     OtherOp = BO->getOperand(0);
261   } else {
262     return false;
263   }
264 
265   // Make sure other op is loop invariant.
266   if (!L->isLoopInvariant(OtherOp))
267     return false;
268 
269   // Make sure we have a splat.
270   Value *SplatOp = getSplatValue(OtherOp);
271   if (!SplatOp)
272     return false;
273 
274   // Recurse up the use-def chain.
275   if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
276     return false;
277 
278   // Locate the Step and Start values from the recurrence.
279   unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
280   unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
281   Value *Step = Inc->getOperand(StepIndex);
282   Value *Start = BasePtr->getOperand(StartBlock);
283 
284   // We need to adjust the start value in the preheader.
285   Builder.SetInsertPoint(
286       BasePtr->getIncomingBlock(StartBlock)->getTerminator());
287   Builder.SetCurrentDebugLocation(DebugLoc());
288 
289   switch (BO->getOpcode()) {
290   default:
291     llvm_unreachable("Unexpected opcode!");
292   case Instruction::Add:
293   case Instruction::Or: {
294     // An add only affects the start value. It's ok to do this for Or because
295     // we already checked that there are no common set bits.
296 
297     // If the start value is Zero, just take the SplatOp.
298     if (isa<ConstantInt>(Start) && cast<ConstantInt>(Start)->isZero())
299       Start = SplatOp;
300     else
301       Start = Builder.CreateAdd(Start, SplatOp, "start");
302     BasePtr->setIncomingValue(StartBlock, Start);
303     break;
304   }
305   case Instruction::Mul: {
306     // If the start is zero we don't need to multiply.
307     if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
308       Start = Builder.CreateMul(Start, SplatOp, "start");
309 
310     Step = Builder.CreateMul(Step, SplatOp, "step");
311 
312     // If the Stride is 1 just take the SplatOpt.
313     if (isa<ConstantInt>(Stride) && cast<ConstantInt>(Stride)->isOne())
314       Stride = SplatOp;
315     else
316       Stride = Builder.CreateMul(Stride, SplatOp, "stride");
317     Inc->setOperand(StepIndex, Step);
318     BasePtr->setIncomingValue(StartBlock, Start);
319     break;
320   }
321   case Instruction::Shl: {
322     // If the start is zero we don't need to shift.
323     if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
324       Start = Builder.CreateShl(Start, SplatOp, "start");
325     Step = Builder.CreateShl(Step, SplatOp, "step");
326     Stride = Builder.CreateShl(Stride, SplatOp, "stride");
327     Inc->setOperand(StepIndex, Step);
328     BasePtr->setIncomingValue(StartBlock, Start);
329     break;
330   }
331   }
332 
333   return true;
334 }
335 
336 std::pair<Value *, Value *>
337 RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
338                                                    IRBuilder<> &Builder) {
339 
340   auto I = StridedAddrs.find(GEP);
341   if (I != StridedAddrs.end())
342     return I->second;
343 
344   SmallVector<Value *, 2> Ops(GEP->operands());
345 
346   // Base pointer needs to be a scalar.
347   if (Ops[0]->getType()->isVectorTy())
348     return std::make_pair(nullptr, nullptr);
349 
350   std::optional<unsigned> VecOperand;
351   unsigned TypeScale = 0;
352 
353   // Look for a vector operand and scale.
354   gep_type_iterator GTI = gep_type_begin(GEP);
355   for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
356     if (!Ops[i]->getType()->isVectorTy())
357       continue;
358 
359     if (VecOperand)
360       return std::make_pair(nullptr, nullptr);
361 
362     VecOperand = i;
363 
364     TypeSize TS = DL->getTypeAllocSize(GTI.getIndexedType());
365     if (TS.isScalable())
366       return std::make_pair(nullptr, nullptr);
367 
368     TypeScale = TS.getFixedValue();
369   }
370 
371   // We need to find a vector index to simplify.
372   if (!VecOperand)
373     return std::make_pair(nullptr, nullptr);
374 
375   // We can't extract the stride if the arithmetic is done at a different size
376   // than the pointer type. Adding the stride later may not wrap correctly.
377   // Technically we could handle wider indices, but I don't expect that in
378   // practice.
379   Value *VecIndex = Ops[*VecOperand];
380   Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
381   if (VecIndex->getType() != VecIntPtrTy)
382     return std::make_pair(nullptr, nullptr);
383 
384   // Handle the non-recursive case.  This is what we see if the vectorizer
385   // decides to use a scalar IV + vid on demand instead of a vector IV.
386   auto [Start, Stride] = matchStridedStart(VecIndex, Builder);
387   if (Start) {
388     assert(Stride);
389     Builder.SetInsertPoint(GEP);
390 
391     // Replace the vector index with the scalar start and build a scalar GEP.
392     Ops[*VecOperand] = Start;
393     Type *SourceTy = GEP->getSourceElementType();
394     Value *BasePtr =
395         Builder.CreateGEP(SourceTy, Ops[0], ArrayRef(Ops).drop_front());
396 
397     // Convert stride to pointer size if needed.
398     Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
399     assert(Stride->getType() == IntPtrTy && "Unexpected type");
400 
401     // Scale the stride by the size of the indexed type.
402     if (TypeScale != 1)
403       Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
404 
405     auto P = std::make_pair(BasePtr, Stride);
406     StridedAddrs[GEP] = P;
407     return P;
408   }
409 
410   // Make sure we're in a loop and that has a pre-header and a single latch.
411   Loop *L = LI->getLoopFor(GEP->getParent());
412   if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
413     return std::make_pair(nullptr, nullptr);
414 
415   BinaryOperator *Inc;
416   PHINode *BasePhi;
417   if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
418     return std::make_pair(nullptr, nullptr);
419 
420   assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
421   unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
422   assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
423          "Expected one operand of phi to be Inc");
424 
425   Builder.SetInsertPoint(GEP);
426 
427   // Replace the vector index with the scalar phi and build a scalar GEP.
428   Ops[*VecOperand] = BasePhi;
429   Type *SourceTy = GEP->getSourceElementType();
430   Value *BasePtr =
431       Builder.CreateGEP(SourceTy, Ops[0], ArrayRef(Ops).drop_front());
432 
433   // Final adjustments to stride should go in the start block.
434   Builder.SetInsertPoint(
435       BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());
436 
437   // Convert stride to pointer size if needed.
438   Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
439   assert(Stride->getType() == IntPtrTy && "Unexpected type");
440 
441   // Scale the stride by the size of the indexed type.
442   if (TypeScale != 1)
443     Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
444 
445   auto P = std::make_pair(BasePtr, Stride);
446   StridedAddrs[GEP] = P;
447   return P;
448 }
449 
450 bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
451                                                            Type *DataType,
452                                                            Value *Ptr,
453                                                            Value *AlignOp) {
454   // Make sure the operation will be supported by the backend.
455   if (!isLegalTypeAndAlignment(DataType, AlignOp))
456     return false;
457 
458   // Pointer should be a GEP.
459   auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
460   if (!GEP)
461     return false;
462 
463   IRBuilder<> Builder(GEP);
464 
465   Value *BasePtr, *Stride;
466   std::tie(BasePtr, Stride) = determineBaseAndStride(GEP, Builder);
467   if (!BasePtr)
468     return false;
469   assert(Stride != nullptr);
470 
471   Builder.SetInsertPoint(II);
472 
473   CallInst *Call;
474   if (II->getIntrinsicID() == Intrinsic::masked_gather)
475     Call = Builder.CreateIntrinsic(
476         Intrinsic::riscv_masked_strided_load,
477         {DataType, BasePtr->getType(), Stride->getType()},
478         {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
479   else
480     Call = Builder.CreateIntrinsic(
481         Intrinsic::riscv_masked_strided_store,
482         {DataType, BasePtr->getType(), Stride->getType()},
483         {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});
484 
485   Call->takeName(II);
486   II->replaceAllUsesWith(Call);
487   II->eraseFromParent();
488 
489   if (GEP->use_empty())
490     RecursivelyDeleteTriviallyDeadInstructions(GEP);
491 
492   return true;
493 }
494 
495 bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
496   if (skipFunction(F))
497     return false;
498 
499   auto &TPC = getAnalysis<TargetPassConfig>();
500   auto &TM = TPC.getTM<RISCVTargetMachine>();
501   ST = &TM.getSubtarget<RISCVSubtarget>(F);
502   if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
503     return false;
504 
505   TLI = ST->getTargetLowering();
506   DL = &F.getParent()->getDataLayout();
507   LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
508 
509   StridedAddrs.clear();
510 
511   SmallVector<IntrinsicInst *, 4> Gathers;
512   SmallVector<IntrinsicInst *, 4> Scatters;
513 
514   bool Changed = false;
515 
516   for (BasicBlock &BB : F) {
517     for (Instruction &I : BB) {
518       IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
519       if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {
520         Gathers.push_back(II);
521       } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {
522         Scatters.push_back(II);
523       }
524     }
525   }
526 
527   // Rewrite gather/scatter to form strided load/store if possible.
528   for (auto *II : Gathers)
529     Changed |= tryCreateStridedLoadStore(
530         II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
531   for (auto *II : Scatters)
532     Changed |=
533         tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
534                                   II->getArgOperand(1), II->getArgOperand(2));
535 
536   // Remove any dead phis.
537   while (!MaybeDeadPHIs.empty()) {
538     if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
539       RecursivelyDeleteDeadPHINode(Phi);
540   }
541 
542   return Changed;
543 }
544