xref: /freebsd/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp (revision 3dd5524264095ed8612c28908e13f80668eff2f9)
1 //===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This pass custom lowers llvm.masked.gather and llvm.masked.scatter
// intrinsics whose addresses form a strided sequence inside a loop into
// RISC-V masked strided load/store intrinsics.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "RISCV.h"
15 #include "RISCVTargetMachine.h"
16 #include "llvm/Analysis/LoopInfo.h"
17 #include "llvm/Analysis/ValueTracking.h"
18 #include "llvm/Analysis/VectorUtils.h"
19 #include "llvm/CodeGen/TargetPassConfig.h"
20 #include "llvm/IR/GetElementPtrTypeIterator.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/IntrinsicInst.h"
23 #include "llvm/IR/IntrinsicsRISCV.h"
24 #include "llvm/Transforms/Utils/Local.h"
25 
26 using namespace llvm;
27 
28 #define DEBUG_TYPE "riscv-gather-scatter-lowering"
29 
30 namespace {
31 
32 class RISCVGatherScatterLowering : public FunctionPass {
33   const RISCVSubtarget *ST = nullptr;
34   const RISCVTargetLowering *TLI = nullptr;
35   LoopInfo *LI = nullptr;
36   const DataLayout *DL = nullptr;
37 
38   SmallVector<WeakTrackingVH> MaybeDeadPHIs;
39 
40   // Cache of the BasePtr and Stride determined from this GEP. When a GEP is
41   // used by multiple gathers/scatters, this allow us to reuse the scalar
42   // instructions we created for the first gather/scatter for the others.
43   DenseMap<GetElementPtrInst *, std::pair<Value *, Value *>> StridedAddrs;
44 
45 public:
46   static char ID; // Pass identification, replacement for typeid
47 
48   RISCVGatherScatterLowering() : FunctionPass(ID) {}
49 
50   bool runOnFunction(Function &F) override;
51 
52   void getAnalysisUsage(AnalysisUsage &AU) const override {
53     AU.setPreservesCFG();
54     AU.addRequired<TargetPassConfig>();
55     AU.addRequired<LoopInfoWrapperPass>();
56   }
57 
58   StringRef getPassName() const override {
59     return "RISCV gather/scatter lowering";
60   }
61 
62 private:
63   bool isLegalTypeAndAlignment(Type *DataType, Value *AlignOp);
64 
65   bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
66                                  Value *AlignOp);
67 
68   std::pair<Value *, Value *> determineBaseAndStride(GetElementPtrInst *GEP,
69                                                      IRBuilder<> &Builder);
70 
71   bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
72                               PHINode *&BasePtr, BinaryOperator *&Inc,
73                               IRBuilder<> &Builder);
74 };
75 
76 } // end anonymous namespace
77 
78 char RISCVGatherScatterLowering::ID = 0;
79 
80 INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
81                 "RISCV gather/scatter lowering pass", false, false)
82 
83 FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
84   return new RISCVGatherScatterLowering();
85 }
86 
87 bool RISCVGatherScatterLowering::isLegalTypeAndAlignment(Type *DataType,
88                                                          Value *AlignOp) {
89   Type *ScalarType = DataType->getScalarType();
90   if (!TLI->isLegalElementTypeForRVV(ScalarType))
91     return false;
92 
93   MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
94   if (MA && MA->value() < DL->getTypeStoreSize(ScalarType).getFixedSize())
95     return false;
96 
97   // FIXME: Let the backend type legalize by splitting/widening?
98   EVT DataVT = TLI->getValueType(*DL, DataType);
99   if (!TLI->isTypeLegal(DataVT))
100     return false;
101 
102   return true;
103 }
104 
105 // TODO: Should we consider the mask when looking for a stride?
106 static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
107   unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();
108 
109   // Check that the start value is a strided constant.
110   auto *StartVal =
111       dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
112   if (!StartVal)
113     return std::make_pair(nullptr, nullptr);
114   APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
115   ConstantInt *Prev = StartVal;
116   for (unsigned i = 1; i != NumElts; ++i) {
117     auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
118     if (!C)
119       return std::make_pair(nullptr, nullptr);
120 
121     APInt LocalStride = C->getValue() - Prev->getValue();
122     if (i == 1)
123       StrideVal = LocalStride;
124     else if (StrideVal != LocalStride)
125       return std::make_pair(nullptr, nullptr);
126 
127     Prev = C;
128   }
129 
130   Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);
131 
132   return std::make_pair(StartVal, Stride);
133 }
134 
135 static std::pair<Value *, Value *> matchStridedStart(Value *Start,
136                                                      IRBuilder<> &Builder) {
137   // Base case, start is a strided constant.
138   auto *StartC = dyn_cast<Constant>(Start);
139   if (StartC)
140     return matchStridedConstant(StartC);
141 
142   // Not a constant, maybe it's a strided constant with a splat added to it.
143   auto *BO = dyn_cast<BinaryOperator>(Start);
144   if (!BO || BO->getOpcode() != Instruction::Add)
145     return std::make_pair(nullptr, nullptr);
146 
147   // Look for an operand that is splatted.
148   unsigned OtherIndex = 1;
149   Value *Splat = getSplatValue(BO->getOperand(0));
150   if (!Splat) {
151     Splat = getSplatValue(BO->getOperand(1));
152     OtherIndex = 0;
153   }
154   if (!Splat)
155     return std::make_pair(nullptr, nullptr);
156 
157   Value *Stride;
158   std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex),
159                                               Builder);
160   if (!Start)
161     return std::make_pair(nullptr, nullptr);
162 
163   // Add the splat value to the start.
164   Builder.SetInsertPoint(BO);
165   Builder.SetCurrentDebugLocation(DebugLoc());
166   Start = Builder.CreateAdd(Start, Splat);
167   return std::make_pair(Start, Stride);
168 }
169 
170 // Recursively, walk about the use-def chain until we find a Phi with a strided
171 // start value. Build and update a scalar recurrence as we unwind the recursion.
172 // We also update the Stride as we unwind. Our goal is to move all of the
173 // arithmetic out of the loop.
174 bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
175                                                         Value *&Stride,
176                                                         PHINode *&BasePtr,
177                                                         BinaryOperator *&Inc,
178                                                         IRBuilder<> &Builder) {
179   // Our base case is a Phi.
180   if (auto *Phi = dyn_cast<PHINode>(Index)) {
181     // A phi node we want to perform this function on should be from the
182     // loop header.
183     if (Phi->getParent() != L->getHeader())
184       return false;
185 
186     Value *Step, *Start;
187     if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
188         Inc->getOpcode() != Instruction::Add)
189       return false;
190     assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
191     unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
192     assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
193            "Expected one operand of phi to be Inc");
194 
195     // Only proceed if the step is loop invariant.
196     if (!L->isLoopInvariant(Step))
197       return false;
198 
199     // Step should be a splat.
200     Step = getSplatValue(Step);
201     if (!Step)
202       return false;
203 
204     std::tie(Start, Stride) = matchStridedStart(Start, Builder);
205     if (!Start)
206       return false;
207     assert(Stride != nullptr);
208 
209     // Build scalar phi and increment.
210     BasePtr =
211         PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi);
212     Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
213                                     Inc);
214     BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
215     BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));
216 
217     // Note that this Phi might be eligible for removal.
218     MaybeDeadPHIs.push_back(Phi);
219     return true;
220   }
221 
222   // Otherwise look for binary operator.
223   auto *BO = dyn_cast<BinaryOperator>(Index);
224   if (!BO)
225     return false;
226 
227   if (BO->getOpcode() != Instruction::Add &&
228       BO->getOpcode() != Instruction::Or &&
229       BO->getOpcode() != Instruction::Mul &&
230       BO->getOpcode() != Instruction::Shl)
231     return false;
232 
233   // Only support shift by constant.
234   if (BO->getOpcode() == Instruction::Shl && !isa<Constant>(BO->getOperand(1)))
235     return false;
236 
237   // We need to be able to treat Or as Add.
238   if (BO->getOpcode() == Instruction::Or &&
239       !haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
240     return false;
241 
242   // We should have one operand in the loop and one splat.
243   Value *OtherOp;
244   if (isa<Instruction>(BO->getOperand(0)) &&
245       L->contains(cast<Instruction>(BO->getOperand(0)))) {
246     Index = cast<Instruction>(BO->getOperand(0));
247     OtherOp = BO->getOperand(1);
248   } else if (isa<Instruction>(BO->getOperand(1)) &&
249              L->contains(cast<Instruction>(BO->getOperand(1)))) {
250     Index = cast<Instruction>(BO->getOperand(1));
251     OtherOp = BO->getOperand(0);
252   } else {
253     return false;
254   }
255 
256   // Make sure other op is loop invariant.
257   if (!L->isLoopInvariant(OtherOp))
258     return false;
259 
260   // Make sure we have a splat.
261   Value *SplatOp = getSplatValue(OtherOp);
262   if (!SplatOp)
263     return false;
264 
265   // Recurse up the use-def chain.
266   if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
267     return false;
268 
269   // Locate the Step and Start values from the recurrence.
270   unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
271   unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
272   Value *Step = Inc->getOperand(StepIndex);
273   Value *Start = BasePtr->getOperand(StartBlock);
274 
275   // We need to adjust the start value in the preheader.
276   Builder.SetInsertPoint(
277       BasePtr->getIncomingBlock(StartBlock)->getTerminator());
278   Builder.SetCurrentDebugLocation(DebugLoc());
279 
280   switch (BO->getOpcode()) {
281   default:
282     llvm_unreachable("Unexpected opcode!");
283   case Instruction::Add:
284   case Instruction::Or: {
285     // An add only affects the start value. It's ok to do this for Or because
286     // we already checked that there are no common set bits.
287 
288     // If the start value is Zero, just take the SplatOp.
289     if (isa<ConstantInt>(Start) && cast<ConstantInt>(Start)->isZero())
290       Start = SplatOp;
291     else
292       Start = Builder.CreateAdd(Start, SplatOp, "start");
293     BasePtr->setIncomingValue(StartBlock, Start);
294     break;
295   }
296   case Instruction::Mul: {
297     // If the start is zero we don't need to multiply.
298     if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
299       Start = Builder.CreateMul(Start, SplatOp, "start");
300 
301     Step = Builder.CreateMul(Step, SplatOp, "step");
302 
303     // If the Stride is 1 just take the SplatOpt.
304     if (isa<ConstantInt>(Stride) && cast<ConstantInt>(Stride)->isOne())
305       Stride = SplatOp;
306     else
307       Stride = Builder.CreateMul(Stride, SplatOp, "stride");
308     Inc->setOperand(StepIndex, Step);
309     BasePtr->setIncomingValue(StartBlock, Start);
310     break;
311   }
312   case Instruction::Shl: {
313     // If the start is zero we don't need to shift.
314     if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
315       Start = Builder.CreateShl(Start, SplatOp, "start");
316     Step = Builder.CreateShl(Step, SplatOp, "step");
317     Stride = Builder.CreateShl(Stride, SplatOp, "stride");
318     Inc->setOperand(StepIndex, Step);
319     BasePtr->setIncomingValue(StartBlock, Start);
320     break;
321   }
322   }
323 
324   return true;
325 }
326 
327 std::pair<Value *, Value *>
328 RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
329                                                    IRBuilder<> &Builder) {
330 
331   auto I = StridedAddrs.find(GEP);
332   if (I != StridedAddrs.end())
333     return I->second;
334 
335   SmallVector<Value *, 2> Ops(GEP->operands());
336 
337   // Base pointer needs to be a scalar.
338   if (Ops[0]->getType()->isVectorTy())
339     return std::make_pair(nullptr, nullptr);
340 
341   // Make sure we're in a loop and that has a pre-header and a single latch.
342   Loop *L = LI->getLoopFor(GEP->getParent());
343   if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
344     return std::make_pair(nullptr, nullptr);
345 
346   Optional<unsigned> VecOperand;
347   unsigned TypeScale = 0;
348 
349   // Look for a vector operand and scale.
350   gep_type_iterator GTI = gep_type_begin(GEP);
351   for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
352     if (!Ops[i]->getType()->isVectorTy())
353       continue;
354 
355     if (VecOperand)
356       return std::make_pair(nullptr, nullptr);
357 
358     VecOperand = i;
359 
360     TypeSize TS = DL->getTypeAllocSize(GTI.getIndexedType());
361     if (TS.isScalable())
362       return std::make_pair(nullptr, nullptr);
363 
364     TypeScale = TS.getFixedSize();
365   }
366 
367   // We need to find a vector index to simplify.
368   if (!VecOperand)
369     return std::make_pair(nullptr, nullptr);
370 
371   // We can't extract the stride if the arithmetic is done at a different size
372   // than the pointer type. Adding the stride later may not wrap correctly.
373   // Technically we could handle wider indices, but I don't expect that in
374   // practice.
375   Value *VecIndex = Ops[*VecOperand];
376   Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
377   if (VecIndex->getType() != VecIntPtrTy)
378     return std::make_pair(nullptr, nullptr);
379 
380   Value *Stride;
381   BinaryOperator *Inc;
382   PHINode *BasePhi;
383   if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
384     return std::make_pair(nullptr, nullptr);
385 
386   assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
387   unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
388   assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
389          "Expected one operand of phi to be Inc");
390 
391   Builder.SetInsertPoint(GEP);
392 
393   // Replace the vector index with the scalar phi and build a scalar GEP.
394   Ops[*VecOperand] = BasePhi;
395   Type *SourceTy = GEP->getSourceElementType();
396   Value *BasePtr =
397       Builder.CreateGEP(SourceTy, Ops[0], makeArrayRef(Ops).drop_front());
398 
399   // Final adjustments to stride should go in the start block.
400   Builder.SetInsertPoint(
401       BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());
402 
403   // Convert stride to pointer size if needed.
404   Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
405   assert(Stride->getType() == IntPtrTy && "Unexpected type");
406 
407   // Scale the stride by the size of the indexed type.
408   if (TypeScale != 1)
409     Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
410 
411   auto P = std::make_pair(BasePtr, Stride);
412   StridedAddrs[GEP] = P;
413   return P;
414 }
415 
416 bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
417                                                            Type *DataType,
418                                                            Value *Ptr,
419                                                            Value *AlignOp) {
420   // Make sure the operation will be supported by the backend.
421   if (!isLegalTypeAndAlignment(DataType, AlignOp))
422     return false;
423 
424   // Pointer should be a GEP.
425   auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
426   if (!GEP)
427     return false;
428 
429   IRBuilder<> Builder(GEP);
430 
431   Value *BasePtr, *Stride;
432   std::tie(BasePtr, Stride) = determineBaseAndStride(GEP, Builder);
433   if (!BasePtr)
434     return false;
435   assert(Stride != nullptr);
436 
437   Builder.SetInsertPoint(II);
438 
439   CallInst *Call;
440   if (II->getIntrinsicID() == Intrinsic::masked_gather)
441     Call = Builder.CreateIntrinsic(
442         Intrinsic::riscv_masked_strided_load,
443         {DataType, BasePtr->getType(), Stride->getType()},
444         {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
445   else
446     Call = Builder.CreateIntrinsic(
447         Intrinsic::riscv_masked_strided_store,
448         {DataType, BasePtr->getType(), Stride->getType()},
449         {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});
450 
451   Call->takeName(II);
452   II->replaceAllUsesWith(Call);
453   II->eraseFromParent();
454 
455   if (GEP->use_empty())
456     RecursivelyDeleteTriviallyDeadInstructions(GEP);
457 
458   return true;
459 }
460 
461 bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
462   if (skipFunction(F))
463     return false;
464 
465   auto &TPC = getAnalysis<TargetPassConfig>();
466   auto &TM = TPC.getTM<RISCVTargetMachine>();
467   ST = &TM.getSubtarget<RISCVSubtarget>(F);
468   if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
469     return false;
470 
471   TLI = ST->getTargetLowering();
472   DL = &F.getParent()->getDataLayout();
473   LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
474 
475   StridedAddrs.clear();
476 
477   SmallVector<IntrinsicInst *, 4> Gathers;
478   SmallVector<IntrinsicInst *, 4> Scatters;
479 
480   bool Changed = false;
481 
482   for (BasicBlock &BB : F) {
483     for (Instruction &I : BB) {
484       IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
485       if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
486           isa<FixedVectorType>(II->getType())) {
487         Gathers.push_back(II);
488       } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
489                  isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
490         Scatters.push_back(II);
491       }
492     }
493   }
494 
495   // Rewrite gather/scatter to form strided load/store if possible.
496   for (auto *II : Gathers)
497     Changed |= tryCreateStridedLoadStore(
498         II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
499   for (auto *II : Scatters)
500     Changed |=
501         tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
502                                   II->getArgOperand(1), II->getArgOperand(2));
503 
504   // Remove any dead phis.
505   while (!MaybeDeadPHIs.empty()) {
506     if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
507       RecursivelyDeleteDeadPHINode(Phi);
508   }
509 
510   return Changed;
511 }
512