Lines Matching +full:low +full:- +full:pass

1 //===- MVETailPredication.cpp - MVE Tail Predication ------------*- C++ -*-===//
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
10 /// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead
12 /// combined with a new form of predication called tail-predication, can be used
13 /// to provide implicit vector predication within a low-overhead loop.
17 /// DLSTP and WLSTP instructions, which set up a tail-predicated loop and the
18 /// total number of data elements processed by the loop. The loop-end
22 /// The HardwareLoops pass inserts intrinsics identifying loops that the
23 /// backend will attempt to convert into a low-overhead loop. The vectorizer is
25 /// predicated upon a get.active.lane.mask intrinsic. This pass looks at these
27 /// instructions. This will be picked up by the ARM Low-overhead loop pass later
29 /// tail-predicated loop.
31 //===----------------------------------------------------------------------===//
57 #define DEBUG_TYPE "mve-tail-predication"
61 "tail-predication", cl::desc("MVE tail-predication pass options"),
64 "Don't tail-predicate loops"),
66 "enabled-no-reductions",
67 "Enable tail-predication, but not for reduction loops"),
70 "Enable tail-predication, including reduction loops"),
72 "force-enabled-no-reductions",
73 "Enable tail-predication, but not for reduction loops, "
76 "force-enabled",
77 "Enable tail-predication, including reduction loops, "
128 Function &F = *L->getHeader()->getParent(); in runOnLoop()
134 this->L = L; in runOnLoop()
136 // The MVE and LOB extensions are combined to enable tail-predication, but in runOnLoop()
138 if (!ST->hasMVEIntegerOps() || !ST->hasV8_1MMainlineOps()) { in runOnLoop()
143 BasicBlock *Preheader = L->getLoopPreheader(); in runOnLoop()
147 auto FindLoopIterations = [](BasicBlock *BB) -> IntrinsicInst* { in runOnLoop()
153 Intrinsic::ID ID = Call->getIntrinsicID(); in runOnLoop()
164 // The test.set iteration could live in the pre-preheader. in runOnLoop()
166 if (!Preheader->getSinglePredecessor()) in runOnLoop()
168 Setup = FindLoopIterations(Preheader->getSinglePredecessor()); in runOnLoop()
175 bool Changed = TryConvertActiveLaneMask(Setup->getArgOperand(0)); in runOnLoop()
192 // (ElementCount + (VectorWidth - 1)) / VectorWidth
194 // ((ElementCount + (VectorWidth - 1)) / VectorWidth) - TripCount
203 Value *ElemCount = ActiveLaneMask->getOperand(1); in IsSafeActiveMask()
205 if (!L->makeLoopInvariant(ElemCount, Changed)) in IsSafeActiveMask()
208 auto *EC= SE->getSCEV(ElemCount); in IsSafeActiveMask()
209 auto *TC = SE->getSCEV(TripCount); in IsSafeActiveMask()
211 cast<FixedVectorType>(ActiveLaneMask->getType())->getNumElements(); in IsSafeActiveMask()
220 if (!SE->isLoopInvariant(EC, L)) { in IsSafeActiveMask()
230 auto *IV = ActiveLaneMask->getOperand(0); in IsSafeActiveMask()
231 auto *IVExpr = SE->getSCEV(IV); in IsSafeActiveMask()
235 LLVM_DEBUG(dbgs() << "ARM TP: induction not an add expr: "; IVExpr->dump()); in IsSafeActiveMask()
239 if (AddExpr->getLoop() != L) { in IsSafeActiveMask()
243 auto *Step = dyn_cast<SCEVConstant>(AddExpr->getOperand(1)); in IsSafeActiveMask()
246 AddExpr->getOperand(1)->dump()); in IsSafeActiveMask()
249 auto StepValue = Step->getValue()->getSExtValue(); in IsSafeActiveMask()
266 // ceil(ElementCount/Width), or floor((ElementCount+Width-1)/Width) as we in IsSafeActiveMask()
268 uint64_t TC1 = TC->getZExtValue(); in IsSafeActiveMask()
270 (ConstElemCount->getZExtValue() + VectorWidth - 1) / VectorWidth; in IsSafeActiveMask()
273 // trigger tail-predication; keep the intrinsic as a get.active.lane.mask in IsSafeActiveMask()
283 // tail-predicated loop body, which calculates the remaining elements to be in IsSafeActiveMask()
284 // processed, is non-negative, i.e. it doesn't overflow: in IsSafeActiveMask()
286 // ((ElementCount + VectorWidth - 1) / VectorWidth) - TripCount >= 0 in IsSafeActiveMask()
290 // TripCount == (ElementCount + VectorWidth - 1) / VectorWidth in IsSafeActiveMask()
294 auto *VW = SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth)); in IsSafeActiveMask()
295 // ElementCount + (VW-1): in IsSafeActiveMask()
296 auto *Start = AddExpr->getStart(); in IsSafeActiveMask()
297 auto *ECPlusVWMinus1 = SE->getAddExpr(EC, in IsSafeActiveMask()
298 SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth - 1))); in IsSafeActiveMask()
300 // Ceil = (ElementCount + (VW-1)) / VW in IsSafeActiveMask()
301 auto *Ceil = SE->getUDivExpr(ECPlusVWMinus1, VW); in IsSafeActiveMask()
307 dbgs() << "ARM TP: - TripCount = " << *TC << "\n"; in IsSafeActiveMask()
308 dbgs() << "ARM TP: - ElemCount = " << *EC << "\n"; in IsSafeActiveMask()
309 dbgs() << "ARM TP: - Start = " << *Start << "\n"; in IsSafeActiveMask()
310 dbgs() << "ARM TP: - BETC = " << *SE->getBackedgeTakenCount(L) << "\n"; in IsSafeActiveMask()
311 dbgs() << "ARM TP: - VecWidth = " << VectorWidth << "\n"; in IsSafeActiveMask()
312 dbgs() << "ARM TP: - (ElemCount+VW-1) / VW = " << *Ceil << "\n"; in IsSafeActiveMask()
318 // TC = ((-4 + (4 * ((3 + %N) /u 4))<nuw> - start) /u 4) in IsSafeActiveMask()
320 // and "(ElementCount + (VW-1)) / VW": in IsSafeActiveMask()
325 // TC - Ceil and test it for zero. in IsSafeActiveMask()
327 const SCEV *Div = SE->getUDivExpr( in IsSafeActiveMask()
328 SE->getAddExpr(SE->getMulExpr(Ceil, VW), SE->getNegativeSCEV(VW), in IsSafeActiveMask()
329 SE->getNegativeSCEV(Start)), in IsSafeActiveMask()
331 const SCEV *Sub = SE->getMinusSCEV(SE->getBackedgeTakenCount(L), Div); in IsSafeActiveMask()
332 LLVM_DEBUG(dbgs() << "ARM TP: - Sub = "; Sub->dump()); in IsSafeActiveMask()
337 Sub = SE->applyLoopGuards(Sub, L); in IsSafeActiveMask()
338 LLVM_DEBUG(dbgs() << "ARM TP: - (Guarded) = "; Sub->dump()); in IsSafeActiveMask()
340 if (!Sub->isZero()) { in IsSafeActiveMask()
350 if (auto *BaseC = dyn_cast<SCEVConstant>(AddExpr->getStart())) { in IsSafeActiveMask()
351 if (BaseC->getAPInt().urem(VectorWidth) == 0) in IsSafeActiveMask()
352 return SE->getMinusSCEV(EC, BaseC); in IsSafeActiveMask()
353 } else if (auto *BaseV = dyn_cast<SCEVUnknown>(AddExpr->getStart())) { in IsSafeActiveMask()
354 Type *Ty = BaseV->getType(); in IsSafeActiveMask()
355 APInt Mask = APInt::getLowBitsSet(Ty->getPrimitiveSizeInBits(), in IsSafeActiveMask()
357 if (MaskedValueIsZero(BaseV->getValue(), Mask, in IsSafeActiveMask()
358 L->getHeader()->getDataLayout())) in IsSafeActiveMask()
359 return SE->getMinusSCEV(EC, BaseV); in IsSafeActiveMask()
360 } else if (auto *BaseMul = dyn_cast<SCEVMulExpr>(AddExpr->getStart())) { in IsSafeActiveMask()
361 if (auto *BaseC = dyn_cast<SCEVConstant>(BaseMul->getOperand(0))) in IsSafeActiveMask()
362 if (BaseC->getAPInt().urem(VectorWidth) == 0) in IsSafeActiveMask()
363 return SE->getMinusSCEV(EC, BaseC); in IsSafeActiveMask()
364 if (auto *BaseC = dyn_cast<SCEVConstant>(BaseMul->getOperand(1))) in IsSafeActiveMask()
365 if (BaseC->getAPInt().urem(VectorWidth) == 0) in IsSafeActiveMask()
366 return SE->getMinusSCEV(EC, BaseC); in IsSafeActiveMask()
371 << *AddExpr->getOperand(0) << "\n"); in IsSafeActiveMask()
377 IRBuilder<> Builder(L->getLoopPreheader()->getTerminator()); in InsertVCTPIntrinsic()
378 Module *M = L->getHeader()->getModule(); in InsertVCTPIntrinsic()
379 Type *Ty = IntegerType::get(M->getContext(), 32); in InsertVCTPIntrinsic()
381 cast<FixedVectorType>(ActiveLaneMask->getType())->getNumElements(); in InsertVCTPIntrinsic()
384 Builder.SetInsertPoint(L->getHeader(), L->getHeader()->getFirstNonPHIIt()); in InsertVCTPIntrinsic()
386 Processed->addIncoming(Start, L->getLoopPreheader()); in InsertVCTPIntrinsic()
404 ActiveLaneMask->replaceAllUsesWith(VCTPCall); in InsertVCTPIntrinsic()
409 Processed->addIncoming(Remaining, L->getLoopLatch()); in InsertVCTPIntrinsic()
417 for (auto *BB : L->getBlocks()) in TryConvertActiveLaneMask()
420 if (Int->getIntrinsicID() == Intrinsic::get_active_lane_mask) in TryConvertActiveLaneMask()
439 SCEVExpander Expander(*SE, L->getHeader()->getDataLayout(), in TryConvertActiveLaneMask()
441 Instruction *Ins = L->getLoopPreheader()->getTerminator(); in TryConvertActiveLaneMask()
442 Value *Start = Expander.expandCodeFor(StartSCEV, StartSCEV->getType(), Ins); in TryConvertActiveLaneMask()
450 for (auto *I : L->blocks()) in TryConvertActiveLaneMask()
455 Pass *llvm::createMVETailPredicationPass() { in createMVETailPredicationPass()