15ffd83dbSDimitry Andric //===- MVETailPredication.cpp - MVE Tail Predication ------------*- C++ -*-===//
28bcb0991SDimitry Andric //
38bcb0991SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
48bcb0991SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
58bcb0991SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
68bcb0991SDimitry Andric //
78bcb0991SDimitry Andric //===----------------------------------------------------------------------===//
88bcb0991SDimitry Andric //
98bcb0991SDimitry Andric /// \file
108bcb0991SDimitry Andric /// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead
115ffd83dbSDimitry Andric /// branches to help accelerate DSP applications. These two extensions,
125ffd83dbSDimitry Andric /// combined with a new form of predication called tail-predication, can be used
135ffd83dbSDimitry Andric /// to provide implicit vector predication within a low-overhead loop.
145ffd83dbSDimitry Andric /// This is implicit because the predicate of active/inactive lanes is
155ffd83dbSDimitry Andric /// calculated by hardware, and thus does not need to be explicitly passed
165ffd83dbSDimitry Andric /// to vector instructions. The instructions responsible for this are the
/// DLSTP and WLSTP instructions, which set up a tail-predicated loop and the
/// total number of data elements processed by the loop. The loop-end
195ffd83dbSDimitry Andric /// LETP instruction is responsible for decrementing and setting the remaining
205ffd83dbSDimitry Andric /// elements to be processed and generating the mask of active lanes.
215ffd83dbSDimitry Andric ///
228bcb0991SDimitry Andric /// The HardwareLoops pass inserts intrinsics identifying loops that the
238bcb0991SDimitry Andric /// backend will attempt to convert into a low-overhead loop. The vectorizer is
248bcb0991SDimitry Andric /// responsible for generating a vectorized loop in which the lanes are
/// predicated upon a get.active.lane.mask intrinsic. This pass looks at these
/// get.active.lane.mask intrinsics and attempts to convert them to VCTP
27e8d8bef9SDimitry Andric /// instructions. This will be picked up by the ARM Low-overhead loop pass later
28e8d8bef9SDimitry Andric /// in the backend, which performs the final transformation to a DLSTP or WLSTP
29e8d8bef9SDimitry Andric /// tail-predicated loop.
30e8d8bef9SDimitry Andric //
31e8d8bef9SDimitry Andric //===----------------------------------------------------------------------===//
328bcb0991SDimitry Andric
33480093f4SDimitry Andric #include "ARM.h"
34480093f4SDimitry Andric #include "ARMSubtarget.h"
355ffd83dbSDimitry Andric #include "ARMTargetTransformInfo.h"
368bcb0991SDimitry Andric #include "llvm/Analysis/LoopInfo.h"
378bcb0991SDimitry Andric #include "llvm/Analysis/LoopPass.h"
388bcb0991SDimitry Andric #include "llvm/Analysis/ScalarEvolution.h"
398bcb0991SDimitry Andric #include "llvm/Analysis/ScalarEvolutionExpressions.h"
405ffd83dbSDimitry Andric #include "llvm/Analysis/TargetLibraryInfo.h"
418bcb0991SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h"
4206c3fb27SDimitry Andric #include "llvm/Analysis/ValueTracking.h"
438bcb0991SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
448bcb0991SDimitry Andric #include "llvm/IR/IRBuilder.h"
45480093f4SDimitry Andric #include "llvm/IR/Instructions.h"
46480093f4SDimitry Andric #include "llvm/IR/IntrinsicsARM.h"
478bcb0991SDimitry Andric #include "llvm/IR/PatternMatch.h"
485ffd83dbSDimitry Andric #include "llvm/InitializePasses.h"
498bcb0991SDimitry Andric #include "llvm/Support/Debug.h"
508bcb0991SDimitry Andric #include "llvm/Transforms/Utils/BasicBlockUtils.h"
51e8d8bef9SDimitry Andric #include "llvm/Transforms/Utils/Local.h"
525ffd83dbSDimitry Andric #include "llvm/Transforms/Utils/LoopUtils.h"
535ffd83dbSDimitry Andric #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
548bcb0991SDimitry Andric
558bcb0991SDimitry Andric using namespace llvm;
568bcb0991SDimitry Andric
578bcb0991SDimitry Andric #define DEBUG_TYPE "mve-tail-predication"
588bcb0991SDimitry Andric #define DESC "Transform predicated vector loops to use MVE tail predication"
598bcb0991SDimitry Andric
605ffd83dbSDimitry Andric cl::opt<TailPredication::Mode> EnableTailPredication(
61e8d8bef9SDimitry Andric "tail-predication", cl::desc("MVE tail-predication pass options"),
62e8d8bef9SDimitry Andric cl::init(TailPredication::Enabled),
635ffd83dbSDimitry Andric cl::values(clEnumValN(TailPredication::Disabled, "disabled",
645ffd83dbSDimitry Andric "Don't tail-predicate loops"),
655ffd83dbSDimitry Andric clEnumValN(TailPredication::EnabledNoReductions,
665ffd83dbSDimitry Andric "enabled-no-reductions",
675ffd83dbSDimitry Andric "Enable tail-predication, but not for reduction loops"),
685ffd83dbSDimitry Andric clEnumValN(TailPredication::Enabled,
695ffd83dbSDimitry Andric "enabled",
705ffd83dbSDimitry Andric "Enable tail-predication, including reduction loops"),
715ffd83dbSDimitry Andric clEnumValN(TailPredication::ForceEnabledNoReductions,
725ffd83dbSDimitry Andric "force-enabled-no-reductions",
735ffd83dbSDimitry Andric "Enable tail-predication, but not for reduction loops, "
745ffd83dbSDimitry Andric "and force this which might be unsafe"),
755ffd83dbSDimitry Andric clEnumValN(TailPredication::ForceEnabled,
765ffd83dbSDimitry Andric "force-enabled",
775ffd83dbSDimitry Andric "Enable tail-predication, including reduction loops, "
785ffd83dbSDimitry Andric "and force this which might be unsafe")));
795ffd83dbSDimitry Andric
805ffd83dbSDimitry Andric
818bcb0991SDimitry Andric namespace {
828bcb0991SDimitry Andric
838bcb0991SDimitry Andric class MVETailPredication : public LoopPass {
848bcb0991SDimitry Andric SmallVector<IntrinsicInst*, 4> MaskedInsts;
858bcb0991SDimitry Andric Loop *L = nullptr;
868bcb0991SDimitry Andric ScalarEvolution *SE = nullptr;
878bcb0991SDimitry Andric TargetTransformInfo *TTI = nullptr;
885ffd83dbSDimitry Andric const ARMSubtarget *ST = nullptr;
898bcb0991SDimitry Andric
908bcb0991SDimitry Andric public:
918bcb0991SDimitry Andric static char ID;
928bcb0991SDimitry Andric
MVETailPredication()938bcb0991SDimitry Andric MVETailPredication() : LoopPass(ID) { }
948bcb0991SDimitry Andric
getAnalysisUsage(AnalysisUsage & AU) const958bcb0991SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override {
968bcb0991SDimitry Andric AU.addRequired<ScalarEvolutionWrapperPass>();
978bcb0991SDimitry Andric AU.addRequired<LoopInfoWrapperPass>();
988bcb0991SDimitry Andric AU.addRequired<TargetPassConfig>();
998bcb0991SDimitry Andric AU.addRequired<TargetTransformInfoWrapperPass>();
1008bcb0991SDimitry Andric AU.addPreserved<LoopInfoWrapperPass>();
1018bcb0991SDimitry Andric AU.setPreservesCFG();
1028bcb0991SDimitry Andric }
1038bcb0991SDimitry Andric
1048bcb0991SDimitry Andric bool runOnLoop(Loop *L, LPPassManager&) override;
1058bcb0991SDimitry Andric
1068bcb0991SDimitry Andric private:
107e8d8bef9SDimitry Andric /// Perform the relevant checks on the loop and convert active lane masks if
108e8d8bef9SDimitry Andric /// possible.
109e8d8bef9SDimitry Andric bool TryConvertActiveLaneMask(Value *TripCount);
1108bcb0991SDimitry Andric
111e8d8bef9SDimitry Andric /// Perform several checks on the arguments of @llvm.get.active.lane.mask
112e8d8bef9SDimitry Andric /// intrinsic. E.g., check that the loop induction variable and the element
113e8d8bef9SDimitry Andric /// count are of the form we expect, and also perform overflow checks for
114e8d8bef9SDimitry Andric /// the new expressions that are created.
11506c3fb27SDimitry Andric const SCEV *IsSafeActiveMask(IntrinsicInst *ActiveLaneMask, Value *TripCount);
116480093f4SDimitry Andric
117480093f4SDimitry Andric /// Insert the intrinsic to represent the effect of tail predication.
11806c3fb27SDimitry Andric void InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask, Value *Start);
1198bcb0991SDimitry Andric };
1208bcb0991SDimitry Andric
1218bcb0991SDimitry Andric } // end namespace
1228bcb0991SDimitry Andric
runOnLoop(Loop * L,LPPassManager &)1238bcb0991SDimitry Andric bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) {
1245ffd83dbSDimitry Andric if (skipLoop(L) || !EnableTailPredication)
1258bcb0991SDimitry Andric return false;
1268bcb0991SDimitry Andric
1275ffd83dbSDimitry Andric MaskedInsts.clear();
1288bcb0991SDimitry Andric Function &F = *L->getHeader()->getParent();
1298bcb0991SDimitry Andric auto &TPC = getAnalysis<TargetPassConfig>();
1308bcb0991SDimitry Andric auto &TM = TPC.getTM<TargetMachine>();
1315ffd83dbSDimitry Andric ST = &TM.getSubtarget<ARMSubtarget>(F);
1328bcb0991SDimitry Andric TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
1338bcb0991SDimitry Andric SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
1348bcb0991SDimitry Andric this->L = L;
1358bcb0991SDimitry Andric
1368bcb0991SDimitry Andric // The MVE and LOB extensions are combined to enable tail-predication, but
1378bcb0991SDimitry Andric // there's nothing preventing us from generating VCTP instructions for v8.1m.
1388bcb0991SDimitry Andric if (!ST->hasMVEIntegerOps() || !ST->hasV8_1MMainlineOps()) {
139480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Not a v8.1m.main+mve target.\n");
1408bcb0991SDimitry Andric return false;
1418bcb0991SDimitry Andric }
1428bcb0991SDimitry Andric
1438bcb0991SDimitry Andric BasicBlock *Preheader = L->getLoopPreheader();
1448bcb0991SDimitry Andric if (!Preheader)
1458bcb0991SDimitry Andric return false;
1468bcb0991SDimitry Andric
1478bcb0991SDimitry Andric auto FindLoopIterations = [](BasicBlock *BB) -> IntrinsicInst* {
1488bcb0991SDimitry Andric for (auto &I : *BB) {
1498bcb0991SDimitry Andric auto *Call = dyn_cast<IntrinsicInst>(&I);
1508bcb0991SDimitry Andric if (!Call)
1518bcb0991SDimitry Andric continue;
1528bcb0991SDimitry Andric
1538bcb0991SDimitry Andric Intrinsic::ID ID = Call->getIntrinsicID();
154e8d8bef9SDimitry Andric if (ID == Intrinsic::start_loop_iterations ||
155fe6060f1SDimitry Andric ID == Intrinsic::test_start_loop_iterations)
1568bcb0991SDimitry Andric return cast<IntrinsicInst>(&I);
1578bcb0991SDimitry Andric }
1588bcb0991SDimitry Andric return nullptr;
1598bcb0991SDimitry Andric };
1608bcb0991SDimitry Andric
1618bcb0991SDimitry Andric // Look for the hardware loop intrinsic that sets the iteration count.
1628bcb0991SDimitry Andric IntrinsicInst *Setup = FindLoopIterations(Preheader);
1638bcb0991SDimitry Andric
1648bcb0991SDimitry Andric // The test.set iteration could live in the pre-preheader.
1658bcb0991SDimitry Andric if (!Setup) {
1668bcb0991SDimitry Andric if (!Preheader->getSinglePredecessor())
1678bcb0991SDimitry Andric return false;
1688bcb0991SDimitry Andric Setup = FindLoopIterations(Preheader->getSinglePredecessor());
1698bcb0991SDimitry Andric if (!Setup)
1708bcb0991SDimitry Andric return false;
1718bcb0991SDimitry Andric }
1728bcb0991SDimitry Andric
173e8d8bef9SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Running on Loop: " << *L << *Setup << "\n");
1748bcb0991SDimitry Andric
175e8d8bef9SDimitry Andric bool Changed = TryConvertActiveLaneMask(Setup->getArgOperand(0));
1768bcb0991SDimitry Andric
177e8d8bef9SDimitry Andric return Changed;
1788bcb0991SDimitry Andric }
1798bcb0991SDimitry Andric
1805ffd83dbSDimitry Andric // The active lane intrinsic has this form:
1815ffd83dbSDimitry Andric //
182e8d8bef9SDimitry Andric // @llvm.get.active.lane.mask(IV, TC)
1835ffd83dbSDimitry Andric //
1845ffd83dbSDimitry Andric // Here we perform checks that this intrinsic behaves as expected,
1855ffd83dbSDimitry Andric // which means:
1865ffd83dbSDimitry Andric //
187e8d8bef9SDimitry Andric // 1) Check that the TripCount (TC) belongs to this loop (originally).
188e8d8bef9SDimitry Andric // 2) The element count (TC) needs to be sufficiently large that the decrement
189e8d8bef9SDimitry Andric // of element counter doesn't overflow, which means that we need to prove:
1905ffd83dbSDimitry Andric // ceil(ElementCount / VectorWidth) >= TripCount
1915ffd83dbSDimitry Andric // by rounding up ElementCount up:
1925ffd83dbSDimitry Andric // ((ElementCount + (VectorWidth - 1)) / VectorWidth
1935ffd83dbSDimitry Andric // and evaluate if expression isKnownNonNegative:
1945ffd83dbSDimitry Andric // (((ElementCount + (VectorWidth - 1)) / VectorWidth) - TripCount
1955ffd83dbSDimitry Andric // 3) The IV must be an induction phi with an increment equal to the
1965ffd83dbSDimitry Andric // vector width.
IsSafeActiveMask(IntrinsicInst * ActiveLaneMask,Value * TripCount)19706c3fb27SDimitry Andric const SCEV *MVETailPredication::IsSafeActiveMask(IntrinsicInst *ActiveLaneMask,
198e8d8bef9SDimitry Andric Value *TripCount) {
1995ffd83dbSDimitry Andric bool ForceTailPredication =
2005ffd83dbSDimitry Andric EnableTailPredication == TailPredication::ForceEnabledNoReductions ||
2015ffd83dbSDimitry Andric EnableTailPredication == TailPredication::ForceEnabled;
2025ffd83dbSDimitry Andric
203e8d8bef9SDimitry Andric Value *ElemCount = ActiveLaneMask->getOperand(1);
2044652422eSDimitry Andric bool Changed = false;
2054652422eSDimitry Andric if (!L->makeLoopInvariant(ElemCount, Changed))
20606c3fb27SDimitry Andric return nullptr;
2074652422eSDimitry Andric
208e8d8bef9SDimitry Andric auto *EC= SE->getSCEV(ElemCount);
2095ffd83dbSDimitry Andric auto *TC = SE->getSCEV(TripCount);
210e8d8bef9SDimitry Andric int VectorWidth =
211e8d8bef9SDimitry Andric cast<FixedVectorType>(ActiveLaneMask->getType())->getNumElements();
2120eae32dcSDimitry Andric if (VectorWidth != 2 && VectorWidth != 4 && VectorWidth != 8 &&
2130eae32dcSDimitry Andric VectorWidth != 16)
21406c3fb27SDimitry Andric return nullptr;
215e8d8bef9SDimitry Andric ConstantInt *ConstElemCount = nullptr;
2165ffd83dbSDimitry Andric
217e8d8bef9SDimitry Andric // 1) Smoke tests that the original scalar loop TripCount (TC) belongs to
218e8d8bef9SDimitry Andric // this loop. The scalar tripcount corresponds the number of elements
219e8d8bef9SDimitry Andric // processed by the loop, so we will refer to that from this point on.
220e8d8bef9SDimitry Andric if (!SE->isLoopInvariant(EC, L)) {
221e8d8bef9SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: element count must be loop invariant.\n");
22206c3fb27SDimitry Andric return nullptr;
22306c3fb27SDimitry Andric }
22406c3fb27SDimitry Andric
22506c3fb27SDimitry Andric // 2) Find out if IV is an induction phi. Note that we can't use Loop
22606c3fb27SDimitry Andric // helpers here to get the induction variable, because the hardware loop is
22706c3fb27SDimitry Andric // no longer in loopsimplify form, and also the hwloop intrinsic uses a
22806c3fb27SDimitry Andric // different counter. Using SCEV, we check that the induction is of the
22906c3fb27SDimitry Andric // form i = i + 4, where the increment must be equal to the VectorWidth.
23006c3fb27SDimitry Andric auto *IV = ActiveLaneMask->getOperand(0);
23106c3fb27SDimitry Andric auto *IVExpr = SE->getSCEV(IV);
23206c3fb27SDimitry Andric auto *AddExpr = dyn_cast<SCEVAddRecExpr>(IVExpr);
23306c3fb27SDimitry Andric
23406c3fb27SDimitry Andric if (!AddExpr) {
23506c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: induction not an add expr: "; IVExpr->dump());
23606c3fb27SDimitry Andric return nullptr;
23706c3fb27SDimitry Andric }
23806c3fb27SDimitry Andric // Check that this AddRec is associated with this loop.
23906c3fb27SDimitry Andric if (AddExpr->getLoop() != L) {
24006c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: phi not part of this loop\n");
24106c3fb27SDimitry Andric return nullptr;
24206c3fb27SDimitry Andric }
24306c3fb27SDimitry Andric auto *Step = dyn_cast<SCEVConstant>(AddExpr->getOperand(1));
24406c3fb27SDimitry Andric if (!Step) {
24506c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: induction step is not a constant: ";
24606c3fb27SDimitry Andric AddExpr->getOperand(1)->dump());
24706c3fb27SDimitry Andric return nullptr;
24806c3fb27SDimitry Andric }
24906c3fb27SDimitry Andric auto StepValue = Step->getValue()->getSExtValue();
25006c3fb27SDimitry Andric if (VectorWidth != StepValue) {
25106c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Step value " << StepValue
25206c3fb27SDimitry Andric << " doesn't match vector width " << VectorWidth << "\n");
25306c3fb27SDimitry Andric return nullptr;
2545ffd83dbSDimitry Andric }
2555ffd83dbSDimitry Andric
256e8d8bef9SDimitry Andric if ((ConstElemCount = dyn_cast<ConstantInt>(ElemCount))) {
257e8d8bef9SDimitry Andric ConstantInt *TC = dyn_cast<ConstantInt>(TripCount);
258e8d8bef9SDimitry Andric if (!TC) {
259e8d8bef9SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Constant tripcount expected in "
260e8d8bef9SDimitry Andric "set.loop.iterations\n");
26106c3fb27SDimitry Andric return nullptr;
262e8d8bef9SDimitry Andric }
263e8d8bef9SDimitry Andric
264e8d8bef9SDimitry Andric // Calculate 2 tripcount values and check that they are consistent with
265e8d8bef9SDimitry Andric // each other. The TripCount for a predicated vector loop body is
266e8d8bef9SDimitry Andric // ceil(ElementCount/Width), or floor((ElementCount+Width-1)/Width) as we
267e8d8bef9SDimitry Andric // work it out here.
268e8d8bef9SDimitry Andric uint64_t TC1 = TC->getZExtValue();
269e8d8bef9SDimitry Andric uint64_t TC2 =
270e8d8bef9SDimitry Andric (ConstElemCount->getZExtValue() + VectorWidth - 1) / VectorWidth;
271e8d8bef9SDimitry Andric
272e8d8bef9SDimitry Andric // If the tripcount values are inconsistent, we can't insert the VCTP and
273e8d8bef9SDimitry Andric // trigger tail-predication; keep the intrinsic as a get.active.lane.mask
274e8d8bef9SDimitry Andric // and legalize this.
275e8d8bef9SDimitry Andric if (TC1 != TC2) {
276e8d8bef9SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: inconsistent constant tripcount values: "
277e8d8bef9SDimitry Andric << TC1 << " from set.loop.iterations, and "
278e8d8bef9SDimitry Andric << TC2 << " from get.active.lane.mask\n");
27906c3fb27SDimitry Andric return nullptr;
280e8d8bef9SDimitry Andric }
281e8d8bef9SDimitry Andric } else if (!ForceTailPredication) {
28206c3fb27SDimitry Andric // 3) We need to prove that the sub expression that we create in the
283e8d8bef9SDimitry Andric // tail-predicated loop body, which calculates the remaining elements to be
284e8d8bef9SDimitry Andric // processed, is non-negative, i.e. it doesn't overflow:
2855ffd83dbSDimitry Andric //
286e8d8bef9SDimitry Andric // ((ElementCount + VectorWidth - 1) / VectorWidth) - TripCount >= 0
2875ffd83dbSDimitry Andric //
288e8d8bef9SDimitry Andric // This is true if:
2895ffd83dbSDimitry Andric //
290e8d8bef9SDimitry Andric // TripCount == (ElementCount + VectorWidth - 1) / VectorWidth
2915ffd83dbSDimitry Andric //
292e8d8bef9SDimitry Andric // which what we will be using here.
2935ffd83dbSDimitry Andric //
294e8d8bef9SDimitry Andric auto *VW = SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth));
295e8d8bef9SDimitry Andric // ElementCount + (VW-1):
29606c3fb27SDimitry Andric auto *Start = AddExpr->getStart();
297e8d8bef9SDimitry Andric auto *ECPlusVWMinus1 = SE->getAddExpr(EC,
2985ffd83dbSDimitry Andric SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth - 1)));
2995ffd83dbSDimitry Andric
300e8d8bef9SDimitry Andric // Ceil = ElementCount + (VW-1) / VW
301e8d8bef9SDimitry Andric auto *Ceil = SE->getUDivExpr(ECPlusVWMinus1, VW);
302e8d8bef9SDimitry Andric
303e8d8bef9SDimitry Andric // Prevent unused variable warnings with TC
304e8d8bef9SDimitry Andric (void)TC;
30506c3fb27SDimitry Andric LLVM_DEBUG({
306e8d8bef9SDimitry Andric dbgs() << "ARM TP: Analysing overflow behaviour for:\n";
30706c3fb27SDimitry Andric dbgs() << "ARM TP: - TripCount = " << *TC << "\n";
30806c3fb27SDimitry Andric dbgs() << "ARM TP: - ElemCount = " << *EC << "\n";
30906c3fb27SDimitry Andric dbgs() << "ARM TP: - Start = " << *Start << "\n";
31006c3fb27SDimitry Andric dbgs() << "ARM TP: - BETC = " << *SE->getBackedgeTakenCount(L) << "\n";
311e8d8bef9SDimitry Andric dbgs() << "ARM TP: - VecWidth = " << VectorWidth << "\n";
31206c3fb27SDimitry Andric dbgs() << "ARM TP: - (ElemCount+VW-1) / VW = " << *Ceil << "\n";
31306c3fb27SDimitry Andric });
314e8d8bef9SDimitry Andric
315e8d8bef9SDimitry Andric // As an example, almost all the tripcount expressions (produced by the
316e8d8bef9SDimitry Andric // vectoriser) look like this:
317e8d8bef9SDimitry Andric //
31806c3fb27SDimitry Andric // TC = ((-4 + (4 * ((3 + %N) /u 4))<nuw> - start) /u 4)
319e8d8bef9SDimitry Andric //
320e8d8bef9SDimitry Andric // and "ElementCount + (VW-1) / VW":
321e8d8bef9SDimitry Andric //
322e8d8bef9SDimitry Andric // Ceil = ((3 + %N) /u 4)
323e8d8bef9SDimitry Andric //
324e8d8bef9SDimitry Andric // Check for equality of TC and Ceil by calculating SCEV expression
325e8d8bef9SDimitry Andric // TC - Ceil and test it for zero.
326e8d8bef9SDimitry Andric //
32706c3fb27SDimitry Andric const SCEV *Div = SE->getUDivExpr(
32806c3fb27SDimitry Andric SE->getAddExpr(SE->getMulExpr(Ceil, VW), SE->getNegativeSCEV(VW),
32906c3fb27SDimitry Andric SE->getNegativeSCEV(Start)),
33006c3fb27SDimitry Andric VW);
33106c3fb27SDimitry Andric const SCEV *Sub = SE->getMinusSCEV(SE->getBackedgeTakenCount(L), Div);
33206c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: - Sub = "; Sub->dump());
333e8d8bef9SDimitry Andric
334349cc55cSDimitry Andric // Use context sensitive facts about the path to the loop to refine. This
335349cc55cSDimitry Andric // comes up as the backedge taken count can incorporate context sensitive
336349cc55cSDimitry Andric // reasoning, and our RHS just above doesn't.
337349cc55cSDimitry Andric Sub = SE->applyLoopGuards(Sub, L);
33806c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: - (Guarded) = "; Sub->dump());
339349cc55cSDimitry Andric
340349cc55cSDimitry Andric if (!Sub->isZero()) {
341e8d8bef9SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: possible overflow in sub expression.\n");
34206c3fb27SDimitry Andric return nullptr;
3435ffd83dbSDimitry Andric }
344e8d8bef9SDimitry Andric }
3455ffd83dbSDimitry Andric
34606c3fb27SDimitry Andric // Check that the start value is a multiple of the VectorWidth.
34706c3fb27SDimitry Andric // TODO: This could do with a method to check if the scev is a multiple of
34806c3fb27SDimitry Andric // VectorWidth. For the moment we just check for constants, muls and unknowns
34906c3fb27SDimitry Andric // (which use MaskedValueIsZero and seems to be the most common).
35006c3fb27SDimitry Andric if (auto *BaseC = dyn_cast<SCEVConstant>(AddExpr->getStart())) {
35106c3fb27SDimitry Andric if (BaseC->getAPInt().urem(VectorWidth) == 0)
35206c3fb27SDimitry Andric return SE->getMinusSCEV(EC, BaseC);
35306c3fb27SDimitry Andric } else if (auto *BaseV = dyn_cast<SCEVUnknown>(AddExpr->getStart())) {
35406c3fb27SDimitry Andric Type *Ty = BaseV->getType();
35506c3fb27SDimitry Andric APInt Mask = APInt::getLowBitsSet(Ty->getPrimitiveSizeInBits(),
35606c3fb27SDimitry Andric Log2_64(VectorWidth));
35706c3fb27SDimitry Andric if (MaskedValueIsZero(BaseV->getValue(), Mask,
358*0fca6ea1SDimitry Andric L->getHeader()->getDataLayout()))
35906c3fb27SDimitry Andric return SE->getMinusSCEV(EC, BaseV);
36006c3fb27SDimitry Andric } else if (auto *BaseMul = dyn_cast<SCEVMulExpr>(AddExpr->getStart())) {
36106c3fb27SDimitry Andric if (auto *BaseC = dyn_cast<SCEVConstant>(BaseMul->getOperand(0)))
36206c3fb27SDimitry Andric if (BaseC->getAPInt().urem(VectorWidth) == 0)
36306c3fb27SDimitry Andric return SE->getMinusSCEV(EC, BaseC);
36406c3fb27SDimitry Andric if (auto *BaseC = dyn_cast<SCEVConstant>(BaseMul->getOperand(1)))
36506c3fb27SDimitry Andric if (BaseC->getAPInt().urem(VectorWidth) == 0)
36606c3fb27SDimitry Andric return SE->getMinusSCEV(EC, BaseC);
36706c3fb27SDimitry Andric }
368e8d8bef9SDimitry Andric
36906c3fb27SDimitry Andric LLVM_DEBUG(
37006c3fb27SDimitry Andric dbgs() << "ARM TP: induction base is not know to be a multiple of VF: "
37106c3fb27SDimitry Andric << *AddExpr->getOperand(0) << "\n");
37206c3fb27SDimitry Andric return nullptr;
3735ffd83dbSDimitry Andric }
3745ffd83dbSDimitry Andric
InsertVCTPIntrinsic(IntrinsicInst * ActiveLaneMask,Value * Start)3755ffd83dbSDimitry Andric void MVETailPredication::InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask,
37606c3fb27SDimitry Andric Value *Start) {
3775ffd83dbSDimitry Andric IRBuilder<> Builder(L->getLoopPreheader()->getTerminator());
378480093f4SDimitry Andric Module *M = L->getHeader()->getModule();
379480093f4SDimitry Andric Type *Ty = IntegerType::get(M->getContext(), 32);
380e8d8bef9SDimitry Andric unsigned VectorWidth =
381e8d8bef9SDimitry Andric cast<FixedVectorType>(ActiveLaneMask->getType())->getNumElements();
3828bcb0991SDimitry Andric
383480093f4SDimitry Andric // Insert a phi to count the number of elements processed by the loop.
3845f757f3fSDimitry Andric Builder.SetInsertPoint(L->getHeader(), L->getHeader()->getFirstNonPHIIt());
385480093f4SDimitry Andric PHINode *Processed = Builder.CreatePHI(Ty, 2);
38606c3fb27SDimitry Andric Processed->addIncoming(Start, L->getLoopPreheader());
387480093f4SDimitry Andric
388e8d8bef9SDimitry Andric // Replace @llvm.get.active.mask() with the ARM specific VCTP intrinic, and
389e8d8bef9SDimitry Andric // thus represent the effect of tail predication.
3905ffd83dbSDimitry Andric Builder.SetInsertPoint(ActiveLaneMask);
391e8d8bef9SDimitry Andric ConstantInt *Factor = ConstantInt::get(cast<IntegerType>(Ty), VectorWidth);
392480093f4SDimitry Andric
393480093f4SDimitry Andric Intrinsic::ID VCTPID;
3945ffd83dbSDimitry Andric switch (VectorWidth) {
395480093f4SDimitry Andric default:
396480093f4SDimitry Andric llvm_unreachable("unexpected number of lanes");
3970eae32dcSDimitry Andric case 2: VCTPID = Intrinsic::arm_mve_vctp64; break;
398480093f4SDimitry Andric case 4: VCTPID = Intrinsic::arm_mve_vctp32; break;
399480093f4SDimitry Andric case 8: VCTPID = Intrinsic::arm_mve_vctp16; break;
400480093f4SDimitry Andric case 16: VCTPID = Intrinsic::arm_mve_vctp8; break;
401480093f4SDimitry Andric }
402480093f4SDimitry Andric Function *VCTP = Intrinsic::getDeclaration(M, VCTPID);
4035ffd83dbSDimitry Andric Value *VCTPCall = Builder.CreateCall(VCTP, Processed);
4045ffd83dbSDimitry Andric ActiveLaneMask->replaceAllUsesWith(VCTPCall);
405480093f4SDimitry Andric
406480093f4SDimitry Andric // Add the incoming value to the new phi.
407480093f4SDimitry Andric // TODO: This add likely already exists in the loop.
408480093f4SDimitry Andric Value *Remaining = Builder.CreateSub(Processed, Factor);
409480093f4SDimitry Andric Processed->addIncoming(Remaining, L->getLoopLatch());
410480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Insert processed elements phi: "
411480093f4SDimitry Andric << *Processed << "\n"
4125ffd83dbSDimitry Andric << "ARM TP: Inserted VCTP: " << *VCTPCall << "\n");
413480093f4SDimitry Andric }
414480093f4SDimitry Andric
TryConvertActiveLaneMask(Value * TripCount)415e8d8bef9SDimitry Andric bool MVETailPredication::TryConvertActiveLaneMask(Value *TripCount) {
416e8d8bef9SDimitry Andric SmallVector<IntrinsicInst *, 4> ActiveLaneMasks;
417e8d8bef9SDimitry Andric for (auto *BB : L->getBlocks())
418e8d8bef9SDimitry Andric for (auto &I : *BB)
419e8d8bef9SDimitry Andric if (auto *Int = dyn_cast<IntrinsicInst>(&I))
420e8d8bef9SDimitry Andric if (Int->getIntrinsicID() == Intrinsic::get_active_lane_mask)
421e8d8bef9SDimitry Andric ActiveLaneMasks.push_back(Int);
422e8d8bef9SDimitry Andric
423e8d8bef9SDimitry Andric if (ActiveLaneMasks.empty())
424480093f4SDimitry Andric return false;
425480093f4SDimitry Andric
426480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Found predicated vector loop.\n");
4278bcb0991SDimitry Andric
428e8d8bef9SDimitry Andric for (auto *ActiveLaneMask : ActiveLaneMasks) {
4295ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Found active lane mask: "
4305ffd83dbSDimitry Andric << *ActiveLaneMask << "\n");
4318bcb0991SDimitry Andric
43206c3fb27SDimitry Andric const SCEV *StartSCEV = IsSafeActiveMask(ActiveLaneMask, TripCount);
43306c3fb27SDimitry Andric if (!StartSCEV) {
4345ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Not safe to insert VCTP.\n");
4355ffd83dbSDimitry Andric return false;
4365ffd83dbSDimitry Andric }
43706c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Safe to insert VCTP. Start is " << *StartSCEV
43806c3fb27SDimitry Andric << "\n");
439*0fca6ea1SDimitry Andric SCEVExpander Expander(*SE, L->getHeader()->getDataLayout(),
44006c3fb27SDimitry Andric "start");
44106c3fb27SDimitry Andric Instruction *Ins = L->getLoopPreheader()->getTerminator();
44206c3fb27SDimitry Andric Value *Start = Expander.expandCodeFor(StartSCEV, StartSCEV->getType(), Ins);
44306c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Created start value " << *Start << "\n");
44406c3fb27SDimitry Andric InsertVCTPIntrinsic(ActiveLaneMask, Start);
4458bcb0991SDimitry Andric }
4468bcb0991SDimitry Andric
447e8d8bef9SDimitry Andric // Remove dead instructions and now dead phis.
448e8d8bef9SDimitry Andric for (auto *II : ActiveLaneMasks)
449e8d8bef9SDimitry Andric RecursivelyDeleteTriviallyDeadInstructions(II);
450bdd1243dSDimitry Andric for (auto *I : L->blocks())
451e8d8bef9SDimitry Andric DeleteDeadPHIs(I);
4528bcb0991SDimitry Andric return true;
4538bcb0991SDimitry Andric }
4548bcb0991SDimitry Andric
createMVETailPredicationPass()4558bcb0991SDimitry Andric Pass *llvm::createMVETailPredicationPass() {
4568bcb0991SDimitry Andric return new MVETailPredication();
4578bcb0991SDimitry Andric }
4588bcb0991SDimitry Andric
4598bcb0991SDimitry Andric char MVETailPredication::ID = 0;
4608bcb0991SDimitry Andric
4618bcb0991SDimitry Andric INITIALIZE_PASS_BEGIN(MVETailPredication, DEBUG_TYPE, DESC, false, false)
4628bcb0991SDimitry Andric INITIALIZE_PASS_END(MVETailPredication, DEBUG_TYPE, DESC, false, false)
463