xref: /freebsd/contrib/llvm-project/llvm/lib/Target/ARM/MVETailPredication.cpp (revision da759cfa320d5076b075d15ff3f00ab3ba5634fd)
1 //===- MVETailPredication.cpp - MVE Tail Predication ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead
11 /// branches to help accelerate DSP applications. These two extensions can be
12 /// combined to provide implicit vector predication within a low-overhead loop.
13 /// The HardwareLoops pass inserts intrinsics identifying loops that the
14 /// backend will attempt to convert into a low-overhead loop. The vectorizer is
15 /// responsible for generating a vectorized loop in which the lanes are
/// predicated upon the iteration counter. This pass looks at these predicated
/// vector loops, which are targets for low-overhead loops, and prepares them
/// for code generation. Once the vectorizer has produced a masked loop, there
/// are a couple of final forms:
20 /// - A tail-predicated loop, with implicit predication.
/// - A loop containing multiple VCTP instructions, predicating multiple VPT
22 ///   blocks of instructions operating on different vector types.
23 ///
/// This pass inserts the VCTP intrinsic to represent the effect of
25 /// tail predication. This will be picked up by the ARM Low-overhead loop pass,
26 /// which performs the final transformation to a DLSTP or WLSTP tail-predicated
27 /// loop.
28 
29 #include "ARM.h"
30 #include "ARMSubtarget.h"
31 #include "llvm/Analysis/LoopInfo.h"
32 #include "llvm/Analysis/LoopPass.h"
33 #include "llvm/Analysis/ScalarEvolution.h"
34 #include "llvm/Analysis/ScalarEvolutionExpander.h"
35 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
36 #include "llvm/Analysis/TargetTransformInfo.h"
37 #include "llvm/CodeGen/TargetPassConfig.h"
38 #include "llvm/IR/IRBuilder.h"
39 #include "llvm/IR/Instructions.h"
40 #include "llvm/IR/IntrinsicsARM.h"
41 #include "llvm/IR/PatternMatch.h"
42 #include "llvm/Support/Debug.h"
43 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
44 
45 using namespace llvm;
46 
47 #define DEBUG_TYPE "mve-tail-predication"
48 #define DESC "Transform predicated vector loops to use MVE tail predication"
49 
50 cl::opt<bool>
51 DisableTailPredication("disable-mve-tail-predication", cl::Hidden,
52                        cl::init(true),
53                        cl::desc("Disable MVE Tail Predication"));
54 namespace {
55 
56 class MVETailPredication : public LoopPass {
57   SmallVector<IntrinsicInst*, 4> MaskedInsts;
58   Loop *L = nullptr;
59   ScalarEvolution *SE = nullptr;
60   TargetTransformInfo *TTI = nullptr;
61 
62 public:
63   static char ID;
64 
65   MVETailPredication() : LoopPass(ID) { }
66 
67   void getAnalysisUsage(AnalysisUsage &AU) const override {
68     AU.addRequired<ScalarEvolutionWrapperPass>();
69     AU.addRequired<LoopInfoWrapperPass>();
70     AU.addRequired<TargetPassConfig>();
71     AU.addRequired<TargetTransformInfoWrapperPass>();
72     AU.addPreserved<LoopInfoWrapperPass>();
73     AU.setPreservesCFG();
74   }
75 
76   bool runOnLoop(Loop *L, LPPassManager&) override;
77 
78 private:
79 
80   /// Perform the relevant checks on the loop and convert if possible.
81   bool TryConvert(Value *TripCount);
82 
83   /// Return whether this is a vectorized loop, that contains masked
84   /// load/stores.
85   bool IsPredicatedVectorLoop();
86 
87   /// Compute a value for the total number of elements that the predicated
88   /// loop will process.
89   Value *ComputeElements(Value *TripCount, VectorType *VecTy);
90 
91   /// Is the icmp that generates an i1 vector, based upon a loop counter
92   /// and a limit that is defined outside the loop.
93   bool isTailPredicate(Instruction *Predicate, Value *NumElements);
94 
95   /// Insert the intrinsic to represent the effect of tail predication.
96   void InsertVCTPIntrinsic(Instruction *Predicate,
97                            DenseMap<Instruction*, Instruction*> &NewPredicates,
98                            VectorType *VecTy,
99                            Value *NumElements);
100 };
101 
102 } // end namespace
103 
104 static bool IsDecrement(Instruction &I) {
105   auto *Call = dyn_cast<IntrinsicInst>(&I);
106   if (!Call)
107     return false;
108 
109   Intrinsic::ID ID = Call->getIntrinsicID();
110   return ID == Intrinsic::loop_decrement_reg;
111 }
112 
113 static bool IsMasked(Instruction *I) {
114   auto *Call = dyn_cast<IntrinsicInst>(I);
115   if (!Call)
116     return false;
117 
118   Intrinsic::ID ID = Call->getIntrinsicID();
119   // TODO: Support gather/scatter expand/compress operations.
120   return ID == Intrinsic::masked_store || ID == Intrinsic::masked_load;
121 }
122 
123 bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) {
124   if (skipLoop(L) || DisableTailPredication)
125     return false;
126 
127   Function &F = *L->getHeader()->getParent();
128   auto &TPC = getAnalysis<TargetPassConfig>();
129   auto &TM = TPC.getTM<TargetMachine>();
130   auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
131   TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
132   SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
133   this->L = L;
134 
135   // The MVE and LOB extensions are combined to enable tail-predication, but
136   // there's nothing preventing us from generating VCTP instructions for v8.1m.
137   if (!ST->hasMVEIntegerOps() || !ST->hasV8_1MMainlineOps()) {
138     LLVM_DEBUG(dbgs() << "ARM TP: Not a v8.1m.main+mve target.\n");
139     return false;
140   }
141 
142   BasicBlock *Preheader = L->getLoopPreheader();
143   if (!Preheader)
144     return false;
145 
146   auto FindLoopIterations = [](BasicBlock *BB) -> IntrinsicInst* {
147     for (auto &I : *BB) {
148       auto *Call = dyn_cast<IntrinsicInst>(&I);
149       if (!Call)
150         continue;
151 
152       Intrinsic::ID ID = Call->getIntrinsicID();
153       if (ID == Intrinsic::set_loop_iterations ||
154           ID == Intrinsic::test_set_loop_iterations)
155         return cast<IntrinsicInst>(&I);
156     }
157     return nullptr;
158   };
159 
160   // Look for the hardware loop intrinsic that sets the iteration count.
161   IntrinsicInst *Setup = FindLoopIterations(Preheader);
162 
163   // The test.set iteration could live in the pre-preheader.
164   if (!Setup) {
165     if (!Preheader->getSinglePredecessor())
166       return false;
167     Setup = FindLoopIterations(Preheader->getSinglePredecessor());
168     if (!Setup)
169       return false;
170   }
171 
172   // Search for the hardware loop intrinic that decrements the loop counter.
173   IntrinsicInst *Decrement = nullptr;
174   for (auto *BB : L->getBlocks()) {
175     for (auto &I : *BB) {
176       if (IsDecrement(I)) {
177         Decrement = cast<IntrinsicInst>(&I);
178         break;
179       }
180     }
181   }
182 
183   if (!Decrement)
184     return false;
185 
186   LLVM_DEBUG(dbgs() << "ARM TP: Running on Loop: " << *L << *Setup << "\n"
187              << *Decrement << "\n");
188   return TryConvert(Setup->getArgOperand(0));
189 }
190 
191 bool MVETailPredication::isTailPredicate(Instruction *I, Value *NumElements) {
192   // Look for the following:
193 
194   // %trip.count.minus.1 = add i32 %N, -1
195   // %broadcast.splatinsert10 = insertelement <4 x i32> undef,
196   //                                          i32 %trip.count.minus.1, i32 0
197   // %broadcast.splat11 = shufflevector <4 x i32> %broadcast.splatinsert10,
198   //                                    <4 x i32> undef,
199   //                                    <4 x i32> zeroinitializer
200   // ...
201   // ...
202   // %index = phi i32
203   // %broadcast.splatinsert = insertelement <4 x i32> undef, i32 %index, i32 0
204   // %broadcast.splat = shufflevector <4 x i32> %broadcast.splatinsert,
205   //                                  <4 x i32> undef,
206   //                                  <4 x i32> zeroinitializer
207   // %induction = add <4 x i32> %broadcast.splat, <i32 0, i32 1, i32 2, i32 3>
208   // %pred = icmp ule <4 x i32> %induction, %broadcast.splat11
209 
210   // And return whether V == %pred.
211 
212   using namespace PatternMatch;
213 
214   CmpInst::Predicate Pred;
215   Instruction *Shuffle = nullptr;
216   Instruction *Induction = nullptr;
217 
218   // The vector icmp
219   if (!match(I, m_ICmp(Pred, m_Instruction(Induction),
220                        m_Instruction(Shuffle))) ||
221       Pred != ICmpInst::ICMP_ULE)
222     return false;
223 
224   // First find the stuff outside the loop which is setting up the limit
225   // vector....
226   // The invariant shuffle that broadcast the limit into a vector.
227   Instruction *Insert = nullptr;
228   if (!match(Shuffle, m_ShuffleVector(m_Instruction(Insert), m_Undef(),
229                                       m_Zero())))
230     return false;
231 
232   // Insert the limit into a vector.
233   Instruction *BECount = nullptr;
234   if (!match(Insert, m_InsertElement(m_Undef(), m_Instruction(BECount),
235                                      m_Zero())))
236     return false;
237 
238   // The limit calculation, backedge count.
239   Value *TripCount = nullptr;
240   if (!match(BECount, m_Add(m_Value(TripCount), m_AllOnes())))
241     return false;
242 
243   if (TripCount != NumElements || !L->isLoopInvariant(BECount))
244     return false;
245 
246   // Now back to searching inside the loop body...
247   // Find the add with takes the index iv and adds a constant vector to it.
248   Instruction *BroadcastSplat = nullptr;
249   Constant *Const = nullptr;
250   if (!match(Induction, m_Add(m_Instruction(BroadcastSplat),
251                               m_Constant(Const))))
252    return false;
253 
254   // Check that we're adding <0, 1, 2, 3...
255   if (auto *CDS = dyn_cast<ConstantDataSequential>(Const)) {
256     for (unsigned i = 0; i < CDS->getNumElements(); ++i) {
257       if (CDS->getElementAsInteger(i) != i)
258         return false;
259     }
260   } else
261     return false;
262 
263   // The shuffle which broadcasts the index iv into a vector.
264   if (!match(BroadcastSplat, m_ShuffleVector(m_Instruction(Insert), m_Undef(),
265                                              m_Zero())))
266     return false;
267 
268   // The insert element which initialises a vector with the index iv.
269   Instruction *IV = nullptr;
270   if (!match(Insert, m_InsertElement(m_Undef(), m_Instruction(IV), m_Zero())))
271     return false;
272 
273   // The index iv.
274   auto *Phi = dyn_cast<PHINode>(IV);
275   if (!Phi)
276     return false;
277 
278   // TODO: Don't think we need to check the entry value.
279   Value *OnEntry = Phi->getIncomingValueForBlock(L->getLoopPreheader());
280   if (!match(OnEntry, m_Zero()))
281     return false;
282 
283   Value *InLoop = Phi->getIncomingValueForBlock(L->getLoopLatch());
284   unsigned Lanes = cast<VectorType>(Insert->getType())->getNumElements();
285 
286   Instruction *LHS = nullptr;
287   if (!match(InLoop, m_Add(m_Instruction(LHS), m_SpecificInt(Lanes))))
288     return false;
289 
290   return LHS == Phi;
291 }
292 
293 static VectorType* getVectorType(IntrinsicInst *I) {
294   unsigned TypeOp = I->getIntrinsicID() == Intrinsic::masked_load ? 0 : 1;
295   auto *PtrTy = cast<PointerType>(I->getOperand(TypeOp)->getType());
296   return cast<VectorType>(PtrTy->getElementType());
297 }
298 
299 bool MVETailPredication::IsPredicatedVectorLoop() {
300   // Check that the loop contains at least one masked load/store intrinsic.
301   // We only support 'normal' vector instructions - other than masked
302   // load/stores.
303   for (auto *BB : L->getBlocks()) {
304     for (auto &I : *BB) {
305       if (IsMasked(&I)) {
306         VectorType *VecTy = getVectorType(cast<IntrinsicInst>(&I));
307         unsigned Lanes = VecTy->getNumElements();
308         unsigned ElementWidth = VecTy->getScalarSizeInBits();
309         // MVE vectors are 128-bit, but don't support 128 x i1.
310         // TODO: Can we support vectors larger than 128-bits?
311         unsigned MaxWidth = TTI->getRegisterBitWidth(true);
312         if (Lanes * ElementWidth > MaxWidth || Lanes == MaxWidth)
313           return false;
314         MaskedInsts.push_back(cast<IntrinsicInst>(&I));
315       } else if (auto *Int = dyn_cast<IntrinsicInst>(&I)) {
316         for (auto &U : Int->args()) {
317           if (isa<VectorType>(U->getType()))
318             return false;
319         }
320       }
321     }
322   }
323 
324   return !MaskedInsts.empty();
325 }
326 
327 Value* MVETailPredication::ComputeElements(Value *TripCount,
328                                            VectorType *VecTy) {
329   const SCEV *TripCountSE = SE->getSCEV(TripCount);
330   ConstantInt *VF = ConstantInt::get(cast<IntegerType>(TripCount->getType()),
331                                      VecTy->getNumElements());
332 
333   if (VF->equalsInt(1))
334     return nullptr;
335 
336   // TODO: Support constant trip counts.
337   auto VisitAdd = [&](const SCEVAddExpr *S) -> const SCEVMulExpr* {
338     if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) {
339       if (Const->getAPInt() != -VF->getValue())
340         return nullptr;
341     } else
342       return nullptr;
343     return dyn_cast<SCEVMulExpr>(S->getOperand(1));
344   };
345 
346   auto VisitMul = [&](const SCEVMulExpr *S) -> const SCEVUDivExpr* {
347     if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) {
348       if (Const->getValue() != VF)
349         return nullptr;
350     } else
351       return nullptr;
352     return dyn_cast<SCEVUDivExpr>(S->getOperand(1));
353   };
354 
355   auto VisitDiv = [&](const SCEVUDivExpr *S) -> const SCEV* {
356     if (auto *Const = dyn_cast<SCEVConstant>(S->getRHS())) {
357       if (Const->getValue() != VF)
358         return nullptr;
359     } else
360       return nullptr;
361 
362     if (auto *RoundUp = dyn_cast<SCEVAddExpr>(S->getLHS())) {
363       if (auto *Const = dyn_cast<SCEVConstant>(RoundUp->getOperand(0))) {
364         if (Const->getAPInt() != (VF->getValue() - 1))
365           return nullptr;
366       } else
367         return nullptr;
368 
369       return RoundUp->getOperand(1);
370     }
371     return nullptr;
372   };
373 
374   // TODO: Can we use SCEV helpers, such as findArrayDimensions, and friends to
375   // determine the numbers of elements instead? Looks like this is what is used
376   // for delinearization, but I'm not sure if it can be applied to the
377   // vectorized form - at least not without a bit more work than I feel
378   // comfortable with.
379 
380   // Search for Elems in the following SCEV:
381   // (1 + ((-VF + (VF * (((VF - 1) + %Elems) /u VF))<nuw>) /u VF))<nuw><nsw>
382   const SCEV *Elems = nullptr;
383   if (auto *TC = dyn_cast<SCEVAddExpr>(TripCountSE))
384     if (auto *Div = dyn_cast<SCEVUDivExpr>(TC->getOperand(1)))
385       if (auto *Add = dyn_cast<SCEVAddExpr>(Div->getLHS()))
386         if (auto *Mul = VisitAdd(Add))
387           if (auto *Div = VisitMul(Mul))
388             if (auto *Res = VisitDiv(Div))
389               Elems = Res;
390 
391   if (!Elems)
392     return nullptr;
393 
394   Instruction *InsertPt = L->getLoopPreheader()->getTerminator();
395   if (!isSafeToExpandAt(Elems, InsertPt, *SE))
396     return nullptr;
397 
398   auto DL = L->getHeader()->getModule()->getDataLayout();
399   SCEVExpander Expander(*SE, DL, "elements");
400   return Expander.expandCodeFor(Elems, Elems->getType(), InsertPt);
401 }
402 
403 // Look through the exit block to see whether there's a duplicate predicate
404 // instruction. This can happen when we need to perform a select on values
405 // from the last and previous iteration. Instead of doing a straight
406 // replacement of that predicate with the vctp, clone the vctp and place it
407 // in the block. This means that the VPR doesn't have to be live into the
408 // exit block which should make it easier to convert this loop into a proper
409 // tail predicated loop.
410 static void Cleanup(DenseMap<Instruction*, Instruction*> &NewPredicates,
411                     SetVector<Instruction*> &MaybeDead, Loop *L) {
412   BasicBlock *Exit = L->getUniqueExitBlock();
413   if (!Exit) {
414     LLVM_DEBUG(dbgs() << "ARM TP: can't find loop exit block\n");
415     return;
416   }
417 
418   for (auto &Pair : NewPredicates) {
419     Instruction *OldPred = Pair.first;
420     Instruction *NewPred = Pair.second;
421 
422     for (auto &I : *Exit) {
423       if (I.isSameOperationAs(OldPred)) {
424         Instruction *PredClone = NewPred->clone();
425         PredClone->insertBefore(&I);
426         I.replaceAllUsesWith(PredClone);
427         MaybeDead.insert(&I);
428         LLVM_DEBUG(dbgs() << "ARM TP: replacing: "; I.dump();
429                    dbgs() << "ARM TP: with:      "; PredClone->dump());
430         break;
431       }
432     }
433   }
434 
435   // Drop references and add operands to check for dead.
436   SmallPtrSet<Instruction*, 4> Dead;
437   while (!MaybeDead.empty()) {
438     auto *I = MaybeDead.front();
439     MaybeDead.remove(I);
440     if (I->hasNUsesOrMore(1))
441       continue;
442 
443     for (auto &U : I->operands()) {
444       if (auto *OpI = dyn_cast<Instruction>(U))
445         MaybeDead.insert(OpI);
446     }
447     I->dropAllReferences();
448     Dead.insert(I);
449   }
450 
451   for (auto *I : Dead) {
452     LLVM_DEBUG(dbgs() << "ARM TP: removing dead insn: "; I->dump());
453     I->eraseFromParent();
454   }
455 
456   for (auto I : L->blocks())
457     DeleteDeadPHIs(I);
458 }
459 
460 void MVETailPredication::InsertVCTPIntrinsic(Instruction *Predicate,
461     DenseMap<Instruction*, Instruction*> &NewPredicates,
462     VectorType *VecTy, Value *NumElements) {
463   IRBuilder<> Builder(L->getHeader()->getFirstNonPHI());
464   Module *M = L->getHeader()->getModule();
465   Type *Ty = IntegerType::get(M->getContext(), 32);
466 
467   // Insert a phi to count the number of elements processed by the loop.
468   PHINode *Processed = Builder.CreatePHI(Ty, 2);
469   Processed->addIncoming(NumElements, L->getLoopPreheader());
470 
471   // Insert the intrinsic to represent the effect of tail predication.
472   Builder.SetInsertPoint(cast<Instruction>(Predicate));
473   ConstantInt *Factor =
474     ConstantInt::get(cast<IntegerType>(Ty), VecTy->getNumElements());
475 
476   Intrinsic::ID VCTPID;
477   switch (VecTy->getNumElements()) {
478   default:
479     llvm_unreachable("unexpected number of lanes");
480   case 4:  VCTPID = Intrinsic::arm_mve_vctp32; break;
481   case 8:  VCTPID = Intrinsic::arm_mve_vctp16; break;
482   case 16: VCTPID = Intrinsic::arm_mve_vctp8; break;
483 
484     // FIXME: vctp64 currently not supported because the predicate
485     // vector wants to be <2 x i1>, but v2i1 is not a legal MVE
486     // type, so problems happen at isel time.
487     // Intrinsic::arm_mve_vctp64 exists for ACLE intrinsics
488     // purposes, but takes a v4i1 instead of a v2i1.
489   }
490   Function *VCTP = Intrinsic::getDeclaration(M, VCTPID);
491   Value *TailPredicate = Builder.CreateCall(VCTP, Processed);
492   Predicate->replaceAllUsesWith(TailPredicate);
493   NewPredicates[Predicate] = cast<Instruction>(TailPredicate);
494 
495   // Add the incoming value to the new phi.
496   // TODO: This add likely already exists in the loop.
497   Value *Remaining = Builder.CreateSub(Processed, Factor);
498   Processed->addIncoming(Remaining, L->getLoopLatch());
499   LLVM_DEBUG(dbgs() << "ARM TP: Insert processed elements phi: "
500              << *Processed << "\n"
501              << "ARM TP: Inserted VCTP: " << *TailPredicate << "\n");
502 }
503 
504 bool MVETailPredication::TryConvert(Value *TripCount) {
505   if (!IsPredicatedVectorLoop()) {
506     LLVM_DEBUG(dbgs() << "ARM TP: no masked instructions in loop");
507     return false;
508   }
509 
510   LLVM_DEBUG(dbgs() << "ARM TP: Found predicated vector loop.\n");
511 
512   // Walk through the masked intrinsics and try to find whether the predicate
513   // operand is generated from an induction variable.
514   SetVector<Instruction*> Predicates;
515   DenseMap<Instruction*, Instruction*> NewPredicates;
516 
517   for (auto *I : MaskedInsts) {
518     Intrinsic::ID ID = I->getIntrinsicID();
519     unsigned PredOp = ID == Intrinsic::masked_load ? 2 : 3;
520     auto *Predicate = dyn_cast<Instruction>(I->getArgOperand(PredOp));
521     if (!Predicate || Predicates.count(Predicate))
522       continue;
523 
524     VectorType *VecTy = getVectorType(I);
525     Value *NumElements = ComputeElements(TripCount, VecTy);
526     if (!NumElements)
527       continue;
528 
529     if (!isTailPredicate(Predicate, NumElements)) {
530       LLVM_DEBUG(dbgs() << "ARM TP: Not tail predicate: " << *Predicate << "\n");
531       continue;
532     }
533 
534     LLVM_DEBUG(dbgs() << "ARM TP: Found tail predicate: " << *Predicate << "\n");
535     Predicates.insert(Predicate);
536 
537     InsertVCTPIntrinsic(Predicate, NewPredicates, VecTy, NumElements);
538   }
539 
540   // Now clean up.
541   Cleanup(NewPredicates, Predicates, L);
542   return true;
543 }
544 
545 Pass *llvm::createMVETailPredicationPass() {
546   return new MVETailPredication();
547 }
548 
549 char MVETailPredication::ID = 0;
550 
551 INITIALIZE_PASS_BEGIN(MVETailPredication, DEBUG_TYPE, DESC, false, false)
552 INITIALIZE_PASS_END(MVETailPredication, DEBUG_TYPE, DESC, false, false)
553