xref: /freebsd/contrib/llvm-project/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp (revision 5e801ac66d24704442eba426ed13c3effb8a34e7)
1 //===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// Finalize v8.1-m low-overhead loops by converting the associated pseudo
10 /// instructions into machine operations.
11 /// The expectation is that the loop contains three pseudo instructions:
12 /// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
13 ///   form should be in the preheader, whereas the while form should be in the
14 ///   preheader's only predecessor.
15 /// - t2LoopDec - placed within the loop body.
16 /// - t2LoopEnd - the loop latch terminator.
17 ///
18 /// In addition to this, we also look for the presence of the VCTP instruction,
19 /// which determines whether we can generate the tail-predicated low-overhead
20 /// loop form.
21 ///
22 /// Assumptions and Dependencies:
23 /// Low-overhead loops are constructed and executed using a setup instruction:
24 /// DLS, WLS, DLSTP or WLSTP and an instruction that loops back: LE or LETP.
25 /// WLS(TP) and LE(TP) are branching instructions with a (large) limited range
26 /// but fixed polarity: WLS can only branch forwards and LE can only branch
27 /// backwards. These restrictions mean that this pass is dependent upon block
28 /// layout and block sizes, which is why it's the last pass to run. The same is
29 /// true for ConstantIslands, but this pass does not increase the size of the
30 /// basic blocks, nor does it change the CFG. Instructions are mainly removed
31 /// during the transform and pseudo instructions are replaced by real ones. In
32 /// some cases, when we have to revert to a 'normal' loop, we have to introduce
33 /// multiple instructions for a single pseudo (see RevertWhile and
34 /// RevertLoopEnd). To handle this situation, t2WhileLoopStartLR and t2LoopEnd
35 /// are defined to be as large as this maximum sequence of replacement
36 /// instructions.
37 ///
38 /// A note on VPR.P0 (the lane mask):
39 /// VPT, VCMP, VPNOT and VCTP won't overwrite VPR.P0 when they update it in a
40 /// "VPT Active" context (which includes low-overhead loops and vpt blocks).
41 /// They will simply "and" the result of their calculation with the current
42 /// value of VPR.P0. You can think of it like this:
43 /// \verbatim
44 /// if VPT active:    ; Between a DLSTP/LETP, or for predicated instrs
45 ///   VPR.P0 &= Value
46 /// else
47 ///   VPR.P0 = Value
48 /// \endverbatim
49 /// When we're inside the low-overhead loop (between DLSTP and LETP), we always
50 /// fall in the "VPT active" case, so we can consider that every VPR write by
51 /// one of those instructions is actually an "and".
52 //===----------------------------------------------------------------------===//
53 
54 #include "ARM.h"
55 #include "ARMBaseInstrInfo.h"
56 #include "ARMBaseRegisterInfo.h"
57 #include "ARMBasicBlockInfo.h"
58 #include "ARMSubtarget.h"
59 #include "MVETailPredUtils.h"
60 #include "Thumb2InstrInfo.h"
61 #include "llvm/ADT/SetOperations.h"
62 #include "llvm/ADT/SmallSet.h"
63 #include "llvm/CodeGen/LivePhysRegs.h"
64 #include "llvm/CodeGen/MachineFunctionPass.h"
65 #include "llvm/CodeGen/MachineLoopInfo.h"
66 #include "llvm/CodeGen/MachineLoopUtils.h"
67 #include "llvm/CodeGen/MachineRegisterInfo.h"
68 #include "llvm/CodeGen/Passes.h"
69 #include "llvm/CodeGen/ReachingDefAnalysis.h"
70 #include "llvm/MC/MCInstrDesc.h"
71 
72 using namespace llvm;
73 
74 #define DEBUG_TYPE "arm-low-overhead-loops"
75 #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass"
76 
// Hidden command-line switch, off by default. When it is set,
// LowOverheadLoop::ValidateTailPredicate() returns false, so loops are never
// converted to the tail-predicated form.
static cl::opt<bool>
DisableTailPredication("arm-loloops-disable-tailpred", cl::Hidden,
    cl::desc("Disable tail-predication in the ARM LowOverheadLoop pass"),
    cl::init(false));
81 
82 static bool isVectorPredicated(MachineInstr *MI) {
83   int PIdx = llvm::findFirstVPTPredOperandIdx(*MI);
84   return PIdx != -1 && MI->getOperand(PIdx + 1).getReg() == ARM::VPR;
85 }
86 
87 static bool isVectorPredicate(MachineInstr *MI) {
88   return MI->findRegisterDefOperandIdx(ARM::VPR) != -1;
89 }
90 
91 static bool hasVPRUse(MachineInstr &MI) {
92   return MI.findRegisterUseOperandIdx(ARM::VPR) != -1;
93 }
94 
95 static bool isDomainMVE(MachineInstr *MI) {
96   uint64_t Domain = MI->getDesc().TSFlags & ARMII::DomainMask;
97   return Domain == ARMII::DomainMVE;
98 }
99 
100 static int getVecSize(const MachineInstr &MI) {
101   const MCInstrDesc &MCID = MI.getDesc();
102   uint64_t Flags = MCID.TSFlags;
103   return (Flags & ARMII::VecSize) >> ARMII::VecSizeShift;
104 }
105 
106 static bool shouldInspect(MachineInstr &MI) {
107   if (MI.isDebugInstr())
108     return false;
109   return isDomainMVE(&MI) || isVectorPredicate(&MI) || hasVPRUse(MI);
110 }
111 
112 namespace {
113 
  // Size-agnostic set-of-instructions type; used for the set parameters of
  // TryRemove so that any SmallPtrSet<MachineInstr *, N> can be passed.
  using InstSet = SmallPtrSetImpl<MachineInstr *>;
115 
116   class PostOrderLoopTraversal {
117     MachineLoop &ML;
118     MachineLoopInfo &MLI;
119     SmallPtrSet<MachineBasicBlock*, 4> Visited;
120     SmallVector<MachineBasicBlock*, 4> Order;
121 
122   public:
123     PostOrderLoopTraversal(MachineLoop &ML, MachineLoopInfo &MLI)
124       : ML(ML), MLI(MLI) { }
125 
126     const SmallVectorImpl<MachineBasicBlock*> &getOrder() const {
127       return Order;
128     }
129 
130     // Visit all the blocks within the loop, as well as exit blocks and any
131     // blocks properly dominating the header.
132     void ProcessLoop() {
133       std::function<void(MachineBasicBlock*)> Search = [this, &Search]
134         (MachineBasicBlock *MBB) -> void {
135         if (Visited.count(MBB))
136           return;
137 
138         Visited.insert(MBB);
139         for (auto *Succ : MBB->successors()) {
140           if (!ML.contains(Succ))
141             continue;
142           Search(Succ);
143         }
144         Order.push_back(MBB);
145       };
146 
147       // Insert exit blocks.
148       SmallVector<MachineBasicBlock*, 2> ExitBlocks;
149       ML.getExitBlocks(ExitBlocks);
150       append_range(Order, ExitBlocks);
151 
152       // Then add the loop body.
153       Search(ML.getHeader());
154 
155       // Then try the preheader and its predecessors.
156       std::function<void(MachineBasicBlock*)> GetPredecessor =
157         [this, &GetPredecessor] (MachineBasicBlock *MBB) -> void {
158         Order.push_back(MBB);
159         if (MBB->pred_size() == 1)
160           GetPredecessor(*MBB->pred_begin());
161       };
162 
163       if (auto *Preheader = ML.getLoopPreheader())
164         GetPredecessor(Preheader);
165       else if (auto *Preheader = MLI.findLoopPreheader(&ML, true, true))
166         GetPredecessor(Preheader);
167     }
168   };
169 
170   struct PredicatedMI {
171     MachineInstr *MI = nullptr;
172     SetVector<MachineInstr*> Predicates;
173 
174   public:
175     PredicatedMI(MachineInstr *I, SetVector<MachineInstr *> &Preds) : MI(I) {
176       assert(I && "Instruction must not be null!");
177       Predicates.insert(Preds.begin(), Preds.end());
178     }
179   };
180 
181   // Represent the current state of the VPR and hold all instances which
182   // represent a VPT block, which is a list of instructions that begins with a
183   // VPT/VPST and has a maximum of four proceeding instructions. All
184   // instructions within the block are predicated upon the vpr and we allow
185   // instructions to define the vpr within in the block too.
186   class VPTState {
187     friend struct LowOverheadLoop;
188 
189     SmallVector<MachineInstr *, 4> Insts;
190 
191     static SmallVector<VPTState, 4> Blocks;
192     static SetVector<MachineInstr *> CurrentPredicates;
193     static std::map<MachineInstr *,
194       std::unique_ptr<PredicatedMI>> PredicatedInsts;
195 
196     static void CreateVPTBlock(MachineInstr *MI) {
197       assert((CurrentPredicates.size() || MI->getParent()->isLiveIn(ARM::VPR))
198              && "Can't begin VPT without predicate");
199       Blocks.emplace_back(MI);
200       // The execution of MI is predicated upon the current set of instructions
201       // that are AND'ed together to form the VPR predicate value. In the case
202       // that MI is a VPT, CurrentPredicates will also just be MI.
203       PredicatedInsts.emplace(
204         MI, std::make_unique<PredicatedMI>(MI, CurrentPredicates));
205     }
206 
207     static void reset() {
208       Blocks.clear();
209       PredicatedInsts.clear();
210       CurrentPredicates.clear();
211     }
212 
213     static void addInst(MachineInstr *MI) {
214       Blocks.back().insert(MI);
215       PredicatedInsts.emplace(
216         MI, std::make_unique<PredicatedMI>(MI, CurrentPredicates));
217     }
218 
219     static void addPredicate(MachineInstr *MI) {
220       LLVM_DEBUG(dbgs() << "ARM Loops: Adding VPT Predicate: " << *MI);
221       CurrentPredicates.insert(MI);
222     }
223 
224     static void resetPredicate(MachineInstr *MI) {
225       LLVM_DEBUG(dbgs() << "ARM Loops: Resetting VPT Predicate: " << *MI);
226       CurrentPredicates.clear();
227       CurrentPredicates.insert(MI);
228     }
229 
230   public:
231     // Have we found an instruction within the block which defines the vpr? If
232     // so, not all the instructions in the block will have the same predicate.
233     static bool hasUniformPredicate(VPTState &Block) {
234       return getDivergent(Block) == nullptr;
235     }
236 
237     // If it exists, return the first internal instruction which modifies the
238     // VPR.
239     static MachineInstr *getDivergent(VPTState &Block) {
240       SmallVectorImpl<MachineInstr *> &Insts = Block.getInsts();
241       for (unsigned i = 1; i < Insts.size(); ++i) {
242         MachineInstr *Next = Insts[i];
243         if (isVectorPredicate(Next))
244           return Next; // Found an instruction altering the vpr.
245       }
246       return nullptr;
247     }
248 
249     // Return whether the given instruction is predicated upon a VCTP.
250     static bool isPredicatedOnVCTP(MachineInstr *MI, bool Exclusive = false) {
251       SetVector<MachineInstr *> &Predicates = PredicatedInsts[MI]->Predicates;
252       if (Exclusive && Predicates.size() != 1)
253         return false;
254       for (auto *PredMI : Predicates)
255         if (isVCTP(PredMI))
256           return true;
257       return false;
258     }
259 
260     // Is the VPST, controlling the block entry, predicated upon a VCTP.
261     static bool isEntryPredicatedOnVCTP(VPTState &Block,
262                                         bool Exclusive = false) {
263       SmallVectorImpl<MachineInstr *> &Insts = Block.getInsts();
264       return isPredicatedOnVCTP(Insts.front(), Exclusive);
265     }
266 
267     // If this block begins with a VPT, we can check whether it's using
268     // at least one predicated input(s), as well as possible loop invariant
269     // which would result in it being implicitly predicated.
270     static bool hasImplicitlyValidVPT(VPTState &Block,
271                                       ReachingDefAnalysis &RDA) {
272       SmallVectorImpl<MachineInstr *> &Insts = Block.getInsts();
273       MachineInstr *VPT = Insts.front();
274       assert(isVPTOpcode(VPT->getOpcode()) &&
275              "Expected VPT block to begin with VPT/VPST");
276 
277       if (VPT->getOpcode() == ARM::MVE_VPST)
278         return false;
279 
280       auto IsOperandPredicated = [&](MachineInstr *MI, unsigned Idx) {
281         MachineInstr *Op = RDA.getMIOperand(MI, MI->getOperand(Idx));
282         return Op && PredicatedInsts.count(Op) && isPredicatedOnVCTP(Op);
283       };
284 
285       auto IsOperandInvariant = [&](MachineInstr *MI, unsigned Idx) {
286         MachineOperand &MO = MI->getOperand(Idx);
287         if (!MO.isReg() || !MO.getReg())
288           return true;
289 
290         SmallPtrSet<MachineInstr *, 2> Defs;
291         RDA.getGlobalReachingDefs(MI, MO.getReg(), Defs);
292         if (Defs.empty())
293           return true;
294 
295         for (auto *Def : Defs)
296           if (Def->getParent() == VPT->getParent())
297             return false;
298         return true;
299       };
300 
301       // Check that at least one of the operands is directly predicated on a
302       // vctp and allow an invariant value too.
303       return (IsOperandPredicated(VPT, 1) || IsOperandPredicated(VPT, 2)) &&
304              (IsOperandPredicated(VPT, 1) || IsOperandInvariant(VPT, 1)) &&
305              (IsOperandPredicated(VPT, 2) || IsOperandInvariant(VPT, 2));
306     }
307 
308     static bool isValid(ReachingDefAnalysis &RDA) {
309       // All predication within the loop should be based on vctp. If the block
310       // isn't predicated on entry, check whether the vctp is within the block
311       // and that all other instructions are then predicated on it.
312       for (auto &Block : Blocks) {
313         if (isEntryPredicatedOnVCTP(Block, false) ||
314             hasImplicitlyValidVPT(Block, RDA))
315           continue;
316 
317         SmallVectorImpl<MachineInstr *> &Insts = Block.getInsts();
318         // We don't know how to convert a block with just a VPT;VCTP into
319         // anything valid once we remove the VCTP. For now just bail out.
320         assert(isVPTOpcode(Insts.front()->getOpcode()) &&
321                "Expected VPT block to start with a VPST or VPT!");
322         if (Insts.size() == 2 && Insts.front()->getOpcode() != ARM::MVE_VPST &&
323             isVCTP(Insts.back()))
324           return false;
325 
326         for (auto *MI : Insts) {
327           // Check that any internal VCTPs are 'Then' predicated.
328           if (isVCTP(MI) && getVPTInstrPredicate(*MI) != ARMVCC::Then)
329             return false;
330           // Skip other instructions that build up the predicate.
331           if (MI->getOpcode() == ARM::MVE_VPST || isVectorPredicate(MI))
332             continue;
333           // Check that any other instructions are predicated upon a vctp.
334           // TODO: We could infer when VPTs are implicitly predicated on the
335           // vctp (when the operands are predicated).
336           if (!isPredicatedOnVCTP(MI)) {
337             LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *MI);
338             return false;
339           }
340         }
341       }
342       return true;
343     }
344 
345     VPTState(MachineInstr *MI) { Insts.push_back(MI); }
346 
347     void insert(MachineInstr *MI) {
348       Insts.push_back(MI);
349       // VPT/VPST + 4 predicated instructions.
350       assert(Insts.size() <= 5 && "Too many instructions in VPT block!");
351     }
352 
353     bool containsVCTP() const {
354       for (auto *MI : Insts)
355         if (isVCTP(MI))
356           return true;
357       return false;
358     }
359 
360     unsigned size() const { return Insts.size(); }
361     SmallVectorImpl<MachineInstr *> &getInsts() { return Insts; }
362   };
363 
364   struct LowOverheadLoop {
365 
366     MachineLoop &ML;
367     MachineBasicBlock *Preheader = nullptr;
368     MachineLoopInfo &MLI;
369     ReachingDefAnalysis &RDA;
370     const TargetRegisterInfo &TRI;
371     const ARMBaseInstrInfo &TII;
372     MachineFunction *MF = nullptr;
373     MachineBasicBlock::iterator StartInsertPt;
374     MachineBasicBlock *StartInsertBB = nullptr;
375     MachineInstr *Start = nullptr;
376     MachineInstr *Dec = nullptr;
377     MachineInstr *End = nullptr;
378     MachineOperand TPNumElements;
379     SmallVector<MachineInstr *, 4> VCTPs;
380     SmallPtrSet<MachineInstr *, 4> ToRemove;
381     SmallPtrSet<MachineInstr *, 4> BlockMasksToRecompute;
382     SmallPtrSet<MachineInstr *, 4> DoubleWidthResultInstrs;
383     SmallPtrSet<MachineInstr *, 4> VMOVCopies;
384     bool Revert = false;
385     bool CannotTailPredicate = false;
386 
387     LowOverheadLoop(MachineLoop &ML, MachineLoopInfo &MLI,
388                     ReachingDefAnalysis &RDA, const TargetRegisterInfo &TRI,
389                     const ARMBaseInstrInfo &TII)
390         : ML(ML), MLI(MLI), RDA(RDA), TRI(TRI), TII(TII),
391           TPNumElements(MachineOperand::CreateImm(0)) {
392       MF = ML.getHeader()->getParent();
393       if (auto *MBB = ML.getLoopPreheader())
394         Preheader = MBB;
395       else if (auto *MBB = MLI.findLoopPreheader(&ML, true, true))
396         Preheader = MBB;
397       VPTState::reset();
398     }
399 
400     // If this is an MVE instruction, check that we know how to use tail
401     // predication with it. Record VPT blocks and return whether the
402     // instruction is valid for tail predication.
403     bool ValidateMVEInst(MachineInstr *MI);
404 
405     void AnalyseMVEInst(MachineInstr *MI) {
406       CannotTailPredicate = !ValidateMVEInst(MI);
407     }
408 
409     bool IsTailPredicationLegal() const {
410       // For now, let's keep things really simple and only support a single
411       // block for tail predication.
412       return !Revert && FoundAllComponents() && !VCTPs.empty() &&
413              !CannotTailPredicate && ML.getNumBlocks() == 1;
414     }
415 
416     // Given that MI is a VCTP, check that is equivalent to any other VCTPs
417     // found.
418     bool AddVCTP(MachineInstr *MI);
419 
420     // Check that the predication in the loop will be equivalent once we
421     // perform the conversion. Also ensure that we can provide the number
422     // of elements to the loop start instruction.
423     bool ValidateTailPredicate();
424 
425     // Check that any values available outside of the loop will be the same
426     // after tail predication conversion.
427     bool ValidateLiveOuts();
428 
429     // Is it safe to define LR with DLS/WLS?
430     // LR can be defined if it is the operand to start, because it's the same
431     // value, or if it's going to be equivalent to the operand to Start.
432     MachineInstr *isSafeToDefineLR();
433 
434     // Check the branch targets are within range and we satisfy our
435     // restrictions.
436     void Validate(ARMBasicBlockUtils *BBUtils);
437 
438     bool FoundAllComponents() const {
439       return Start && Dec && End;
440     }
441 
442     SmallVectorImpl<VPTState> &getVPTBlocks() {
443       return VPTState::Blocks;
444     }
445 
446     // Return the operand for the loop start instruction. This will be the loop
447     // iteration count, or the number of elements if we're tail predicating.
448     MachineOperand &getLoopStartOperand() {
449       if (IsTailPredicationLegal())
450         return TPNumElements;
451       return Start->getOperand(1);
452     }
453 
454     unsigned getStartOpcode() const {
455       bool IsDo = isDoLoopStart(*Start);
456       if (!IsTailPredicationLegal())
457         return IsDo ? ARM::t2DLS : ARM::t2WLS;
458 
459       return VCTPOpcodeToLSTP(VCTPs.back()->getOpcode(), IsDo);
460     }
461 
462     void dump() const {
463       if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start;
464       if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec;
465       if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;
466       if (!VCTPs.empty()) {
467         dbgs() << "ARM Loops: Found VCTP(s):\n";
468         for (auto *MI : VCTPs)
469           dbgs() << " - " << *MI;
470       }
471       if (!FoundAllComponents())
472         dbgs() << "ARM Loops: Not a low-overhead loop.\n";
473       else if (!(Start && Dec && End))
474         dbgs() << "ARM Loops: Failed to find all loop components.\n";
475     }
476   };
477 
478   class ARMLowOverheadLoops : public MachineFunctionPass {
479     MachineFunction           *MF = nullptr;
480     MachineLoopInfo           *MLI = nullptr;
481     ReachingDefAnalysis       *RDA = nullptr;
482     const ARMBaseInstrInfo    *TII = nullptr;
483     MachineRegisterInfo       *MRI = nullptr;
484     const TargetRegisterInfo  *TRI = nullptr;
485     std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr;
486 
487   public:
488     static char ID;
489 
490     ARMLowOverheadLoops() : MachineFunctionPass(ID) { }
491 
492     void getAnalysisUsage(AnalysisUsage &AU) const override {
493       AU.setPreservesCFG();
494       AU.addRequired<MachineLoopInfo>();
495       AU.addRequired<ReachingDefAnalysis>();
496       MachineFunctionPass::getAnalysisUsage(AU);
497     }
498 
499     bool runOnMachineFunction(MachineFunction &MF) override;
500 
501     MachineFunctionProperties getRequiredProperties() const override {
502       return MachineFunctionProperties().set(
503           MachineFunctionProperties::Property::NoVRegs).set(
504           MachineFunctionProperties::Property::TracksLiveness);
505     }
506 
507     StringRef getPassName() const override {
508       return ARM_LOW_OVERHEAD_LOOPS_NAME;
509     }
510 
511   private:
512     bool ProcessLoop(MachineLoop *ML);
513 
514     bool RevertNonLoops();
515 
516     void RevertWhile(MachineInstr *MI) const;
517     void RevertDo(MachineInstr *MI) const;
518 
519     bool RevertLoopDec(MachineInstr *MI) const;
520 
521     void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;
522 
523     void RevertLoopEndDec(MachineInstr *MI) const;
524 
525     void ConvertVPTBlocks(LowOverheadLoop &LoLoop);
526 
527     MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
528 
529     void Expand(LowOverheadLoop &LoLoop);
530 
531     void IterationCountDCE(LowOverheadLoop &LoLoop);
532   };
533 }
534 
// Out-of-class definition of the pass ID declared in ARMLowOverheadLoops.
char ARMLowOverheadLoops::ID = 0;

// Out-of-class definitions of VPTState's static bookkeeping. These objects
// are shared by all users of VPTState and are cleared by VPTState::reset().
SmallVector<VPTState, 4> VPTState::Blocks;
SetVector<MachineInstr *> VPTState::CurrentPredicates;
std::map<MachineInstr *,
         std::unique_ptr<PredicatedMI>> VPTState::PredicatedInsts;

// Register the pass under the DEBUG_TYPE name with its descriptive name.
INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
                false, false)
544 
/// Try to schedule \p MI for removal.
///
/// \param MI       The instruction we would like to delete.
/// \param RDA      Reaching-definition analysis used for all queries.
/// \param ToRemove Receives the instructions that may be deleted.
/// \param Ignore   Forwarded to ReachingDefAnalysis::isSafeToRemove.
/// \returns true if RDA reports the removal as safe and the instructions it
/// collected can be deleted without leaving a partially emptied IT block;
/// false otherwise, in which case \p ToRemove is left unchanged.
static bool TryRemove(MachineInstr *MI, ReachingDefAnalysis &RDA,
                      InstSet &ToRemove, InstSet &Ignore) {

  // Check that we can remove all of Killed without having to modify any IT
  // blocks. On success the IT instructions whose blocks become completely
  // empty are added to Killed as well.
  auto WontCorruptITs = [](InstSet &Killed, ReachingDefAnalysis &RDA) {
    // Collect the dead code and the MBBs in which they reside.
    SmallPtrSet<MachineBasicBlock*, 2> BasicBlocks;
    for (auto *Dead : Killed)
      BasicBlocks.insert(Dead->getParent());

    // Collect IT blocks in all affected basic blocks: map each t2IT to the
    // local instructions that use the ITSTATE it defines.
    std::map<MachineInstr *, SmallPtrSet<MachineInstr *, 2>> ITBlocks;
    for (auto *MBB : BasicBlocks) {
      for (auto &IT : *MBB) {
        if (IT.getOpcode() != ARM::t2IT)
          continue;
        RDA.getReachingLocalUses(&IT, MCRegister::from(ARM::ITSTATE),
                                 ITBlocks[&IT]);
      }
    }

    // If we're removing all of the instructions within an IT block, then
    // also remove the IT instruction.
    // ModifiedITs tracks the ITs that currently still have users left after
    // the dead instructions seen so far were taken out; an IT is dropped
    // from it again once its last user has been erased.
    SmallPtrSet<MachineInstr *, 2> ModifiedITs;
    SmallPtrSet<MachineInstr *, 2> RemoveITs;
    for (auto *Dead : Killed) {
      if (MachineOperand *MO = Dead->findRegisterUseOperand(ARM::ITSTATE)) {
        MachineInstr *IT = RDA.getMIOperand(Dead, *MO);
        RemoveITs.insert(IT);
        auto &CurrentBlock = ITBlocks[IT];
        CurrentBlock.erase(Dead);
        if (CurrentBlock.empty())
          ModifiedITs.erase(IT);
        else
          ModifiedITs.insert(IT);
      }
    }
    // Any IT left with only some of its users would have to be rewritten.
    if (!ModifiedITs.empty())
      return false;
    Killed.insert(RemoveITs.begin(), RemoveITs.end());
    return true;
  };

  // Uses is filled in by isSafeToRemove; judging by the debug output below it
  // holds the instructions that can be removed together with MI.
  SmallPtrSet<MachineInstr *, 2> Uses;
  if (!RDA.isSafeToRemove(MI, Uses, Ignore))
    return false;

  if (WontCorruptITs(Uses, RDA)) {
    ToRemove.insert(Uses.begin(), Uses.end());
    LLVM_DEBUG(dbgs() << "ARM Loops: Able to remove: " << *MI
               << " - can also remove:\n";
               for (auto *Use : Uses)
                 dbgs() << "   - " << *Use);

    // Removing the killed operands as well is best effort: failing to do so
    // does not change the result of this function.
    SmallPtrSet<MachineInstr*, 4> Killed;
    RDA.collectKilledOperands(MI, Killed);
    if (WontCorruptITs(Killed, RDA)) {
      ToRemove.insert(Killed.begin(), Killed.end());
      LLVM_DEBUG(for (auto *Dead : Killed)
                   dbgs() << "   - " << *Dead);
    }
    return true;
  }
  return false;
}
611 
612 bool LowOverheadLoop::ValidateTailPredicate() {
613   if (!IsTailPredicationLegal()) {
614     LLVM_DEBUG(if (VCTPs.empty())
615                  dbgs() << "ARM Loops: Didn't find a VCTP instruction.\n";
616                dbgs() << "ARM Loops: Tail-predication is not valid.\n");
617     return false;
618   }
619 
620   assert(!VCTPs.empty() && "VCTP instruction expected but is not set");
621   assert(ML.getBlocks().size() == 1 &&
622          "Shouldn't be processing a loop with more than one block");
623 
624   if (DisableTailPredication) {
625     LLVM_DEBUG(dbgs() << "ARM Loops: tail-predication is disabled\n");
626     return false;
627   }
628 
629   if (!VPTState::isValid(RDA)) {
630     LLVM_DEBUG(dbgs() << "ARM Loops: Invalid VPT state.\n");
631     return false;
632   }
633 
634   if (!ValidateLiveOuts()) {
635     LLVM_DEBUG(dbgs() << "ARM Loops: Invalid live outs.\n");
636     return false;
637   }
638 
639   // For tail predication, we need to provide the number of elements, instead
640   // of the iteration count, to the loop start instruction. The number of
641   // elements is provided to the vctp instruction, so we need to check that
642   // we can use this register at InsertPt.
643   MachineInstr *VCTP = VCTPs.back();
644   if (Start->getOpcode() == ARM::t2DoLoopStartTP ||
645       Start->getOpcode() == ARM::t2WhileLoopStartTP) {
646     TPNumElements = Start->getOperand(2);
647     StartInsertPt = Start;
648     StartInsertBB = Start->getParent();
649   } else {
650     TPNumElements = VCTP->getOperand(1);
651     MCRegister NumElements = TPNumElements.getReg().asMCReg();
652 
653     // If the register is defined within loop, then we can't perform TP.
654     // TODO: Check whether this is just a mov of a register that would be
655     // available.
656     if (RDA.hasLocalDefBefore(VCTP, NumElements)) {
657       LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n");
658       return false;
659     }
660 
661     // The element count register maybe defined after InsertPt, in which case we
662     // need to try to move either InsertPt or the def so that the [w|d]lstp can
663     // use the value.
664 
665     if (StartInsertPt != StartInsertBB->end() &&
666         !RDA.isReachingDefLiveOut(&*StartInsertPt, NumElements)) {
667       if (auto *ElemDef =
668               RDA.getLocalLiveOutMIDef(StartInsertBB, NumElements)) {
669         if (RDA.isSafeToMoveForwards(ElemDef, &*StartInsertPt)) {
670           ElemDef->removeFromParent();
671           StartInsertBB->insert(StartInsertPt, ElemDef);
672           LLVM_DEBUG(dbgs()
673                      << "ARM Loops: Moved element count def: " << *ElemDef);
674         } else if (RDA.isSafeToMoveBackwards(&*StartInsertPt, ElemDef)) {
675           StartInsertPt->removeFromParent();
676           StartInsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef),
677                                      &*StartInsertPt);
678           LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef);
679         } else {
680           // If we fail to move an instruction and the element count is provided
681           // by a mov, use the mov operand if it will have the same value at the
682           // insertion point
683           MachineOperand Operand = ElemDef->getOperand(1);
684           if (isMovRegOpcode(ElemDef->getOpcode()) &&
685               RDA.getUniqueReachingMIDef(ElemDef, Operand.getReg().asMCReg()) ==
686                   RDA.getUniqueReachingMIDef(&*StartInsertPt,
687                                              Operand.getReg().asMCReg())) {
688             TPNumElements = Operand;
689             NumElements = TPNumElements.getReg();
690           } else {
691             LLVM_DEBUG(dbgs()
692                        << "ARM Loops: Unable to move element count to loop "
693                        << "start instruction.\n");
694             return false;
695           }
696         }
697       }
698     }
699 
700     // Especially in the case of while loops, InsertBB may not be the
701     // preheader, so we need to check that the register isn't redefined
702     // before entering the loop.
703     auto CannotProvideElements = [this](MachineBasicBlock *MBB,
704                                         MCRegister NumElements) {
705       if (MBB->empty())
706         return false;
707       // NumElements is redefined in this block.
708       if (RDA.hasLocalDefBefore(&MBB->back(), NumElements))
709         return true;
710 
711       // Don't continue searching up through multiple predecessors.
712       if (MBB->pred_size() > 1)
713         return true;
714 
715       return false;
716     };
717 
718     // Search backwards for a def, until we get to InsertBB.
719     MachineBasicBlock *MBB = Preheader;
720     while (MBB && MBB != StartInsertBB) {
721       if (CannotProvideElements(MBB, NumElements)) {
722         LLVM_DEBUG(dbgs() << "ARM Loops: Unable to provide element count.\n");
723         return false;
724       }
725       MBB = *MBB->pred_begin();
726     }
727   }
728 
729   // Could inserting the [W|D]LSTP cause some unintended affects? In a perfect
730   // world the [w|d]lstp instruction would be last instruction in the preheader
731   // and so it would only affect instructions within the loop body. But due to
732   // scheduling, and/or the logic in this pass (above), the insertion point can
733   // be moved earlier. So if the Loop Start isn't the last instruction in the
734   // preheader, and if the initial element count is smaller than the vector
735   // width, the Loop Start instruction will immediately generate one or more
736   // false lane mask which can, incorrectly, affect the proceeding MVE
737   // instructions in the preheader.
738   if (std::any_of(StartInsertPt, StartInsertBB->end(), shouldInspect)) {
739     LLVM_DEBUG(dbgs() << "ARM Loops: Instruction blocks [W|D]LSTP\n");
740     return false;
741   }
742 
743   // For any DoubleWidthResultInstrs we found whilst scanning instructions, they
744   // need to compute an output size that is smaller than the VCTP mask operates
745   // on. The VecSize of the DoubleWidthResult is the larger vector size - the
746   // size it extends into, so any VCTP VecSize <= is valid.
747   unsigned VCTPVecSize = getVecSize(*VCTP);
748   for (MachineInstr *MI : DoubleWidthResultInstrs) {
749     unsigned InstrVecSize = getVecSize(*MI);
750     if (InstrVecSize > VCTPVecSize) {
751       LLVM_DEBUG(dbgs() << "ARM Loops: Double width result larger than VCTP "
752                         << "VecSize:\n" << *MI);
753       return false;
754     }
755   }
756 
757   // Check that the value change of the element count is what we expect and
758   // that the predication will be equivalent. For this we need:
759   // NumElements = NumElements - VectorWidth. The sub will be a sub immediate
760   // and we can also allow register copies within the chain too.
761   auto IsValidSub = [](MachineInstr *MI, int ExpectedVecWidth) {
762     return -getAddSubImmediate(*MI) == ExpectedVecWidth;
763   };
764 
765   MachineBasicBlock *MBB = VCTP->getParent();
766   // Remove modifications to the element count since they have no purpose in a
767   // tail predicated loop. Explicitly refer to the vctp operand no matter which
768   // register NumElements has been assigned to, since that is what the
769   // modifications will be using
770   if (auto *Def = RDA.getUniqueReachingMIDef(
771           &MBB->back(), VCTP->getOperand(1).getReg().asMCReg())) {
772     SmallPtrSet<MachineInstr*, 2> ElementChain;
773     SmallPtrSet<MachineInstr*, 2> Ignore;
774     unsigned ExpectedVectorWidth = getTailPredVectorWidth(VCTP->getOpcode());
775 
776     Ignore.insert(VCTPs.begin(), VCTPs.end());
777 
778     if (TryRemove(Def, RDA, ElementChain, Ignore)) {
779       bool FoundSub = false;
780 
781       for (auto *MI : ElementChain) {
782         if (isMovRegOpcode(MI->getOpcode()))
783           continue;
784 
785         if (isSubImmOpcode(MI->getOpcode())) {
786           if (FoundSub || !IsValidSub(MI, ExpectedVectorWidth)) {
787             LLVM_DEBUG(dbgs() << "ARM Loops: Unexpected instruction in element"
788                        " count: " << *MI);
789             return false;
790           }
791           FoundSub = true;
792         } else {
793           LLVM_DEBUG(dbgs() << "ARM Loops: Unexpected instruction in element"
794                      " count: " << *MI);
795           return false;
796         }
797       }
798       ToRemove.insert(ElementChain.begin(), ElementChain.end());
799     }
800   }
801 
802   // If we converted the LoopStart to a t2DoLoopStartTP/t2WhileLoopStartTP, we
803   // can also remove any extra instructions in the preheader, which often
804   // includes a now unused MOV.
805   if ((Start->getOpcode() == ARM::t2DoLoopStartTP ||
806        Start->getOpcode() == ARM::t2WhileLoopStartTP) &&
807       Preheader && !Preheader->empty() &&
808       !RDA.hasLocalDefBefore(VCTP, VCTP->getOperand(1).getReg())) {
809     if (auto *Def = RDA.getUniqueReachingMIDef(
810             &Preheader->back(), VCTP->getOperand(1).getReg().asMCReg())) {
811       SmallPtrSet<MachineInstr*, 2> Ignore;
812       Ignore.insert(VCTPs.begin(), VCTPs.end());
813       TryRemove(Def, RDA, ToRemove, Ignore);
814     }
815   }
816 
817   return true;
818 }
819 
820 static bool isRegInClass(const MachineOperand &MO,
821                          const TargetRegisterClass *Class) {
822   return MO.isReg() && MO.getReg() && Class->contains(MO.getReg());
823 }
824 
825 // MVE 'narrowing' operate on half a lane, reading from half and writing
826 // to half, which are referred to has the top and bottom half. The other
827 // half retains its previous value.
828 static bool retainsPreviousHalfElement(const MachineInstr &MI) {
829   const MCInstrDesc &MCID = MI.getDesc();
830   uint64_t Flags = MCID.TSFlags;
831   return (Flags & ARMII::RetainsPreviousHalfElement) != 0;
832 }
833 
834 // Some MVE instructions read from the top/bottom halves of their operand(s)
835 // and generate a vector result with result elements that are double the
836 // width of the input.
837 static bool producesDoubleWidthResult(const MachineInstr &MI) {
838   const MCInstrDesc &MCID = MI.getDesc();
839   uint64_t Flags = MCID.TSFlags;
840   return (Flags & ARMII::DoubleWidthResult) != 0;
841 }
842 
843 static bool isHorizontalReduction(const MachineInstr &MI) {
844   const MCInstrDesc &MCID = MI.getDesc();
845   uint64_t Flags = MCID.TSFlags;
846   return (Flags & ARMII::HorizontalReduction) != 0;
847 }
848 
849 // Can this instruction generate a non-zero result when given only zeroed
850 // operands? This allows us to know that, given operands with false bytes
851 // zeroed by masked loads, that the result will also contain zeros in those
852 // bytes.
853 static bool canGenerateNonZeros(const MachineInstr &MI) {
854 
855   // Check for instructions which can write into a larger element size,
856   // possibly writing into a previous zero'd lane.
857   if (producesDoubleWidthResult(MI))
858     return true;
859 
860   switch (MI.getOpcode()) {
861   default:
862     break;
863   // FIXME: VNEG FP and -0? I think we'll need to handle this once we allow
864   // fp16 -> fp32 vector conversions.
865   // Instructions that perform a NOT will generate 1s from 0s.
866   case ARM::MVE_VMVN:
867   case ARM::MVE_VORN:
868   // Count leading zeros will do just that!
869   case ARM::MVE_VCLZs8:
870   case ARM::MVE_VCLZs16:
871   case ARM::MVE_VCLZs32:
872     return true;
873   }
874   return false;
875 }
876 
877 // Look at its register uses to see if it only can only receive zeros
878 // into its false lanes which would then produce zeros. Also check that
879 // the output register is also defined by an FalseLanesZero instruction
880 // so that if tail-predication happens, the lanes that aren't updated will
881 // still be zeros.
882 static bool producesFalseLanesZero(MachineInstr &MI,
883                                    const TargetRegisterClass *QPRs,
884                                    const ReachingDefAnalysis &RDA,
885                                    InstSet &FalseLanesZero) {
886   if (canGenerateNonZeros(MI))
887     return false;
888 
889   bool isPredicated = isVectorPredicated(&MI);
890   // Predicated loads will write zeros to the falsely predicated bytes of the
891   // destination register.
892   if (MI.mayLoad())
893     return isPredicated;
894 
895   auto IsZeroInit = [](MachineInstr *Def) {
896     return !isVectorPredicated(Def) &&
897            Def->getOpcode() == ARM::MVE_VMOVimmi32 &&
898            Def->getOperand(1).getImm() == 0;
899   };
900 
901   bool AllowScalars = isHorizontalReduction(MI);
902   for (auto &MO : MI.operands()) {
903     if (!MO.isReg() || !MO.getReg())
904       continue;
905     if (!isRegInClass(MO, QPRs) && AllowScalars)
906       continue;
907     // Skip the lr predicate reg
908     int PIdx = llvm::findFirstVPTPredOperandIdx(MI);
909     if (PIdx != -1 && (int)MI.getOperandNo(&MO) == PIdx + 2)
910       continue;
911 
912     // Check that this instruction will produce zeros in its false lanes:
913     // - If it only consumes false lanes zero or constant 0 (vmov #0)
914     // - If it's predicated, it only matters that it's def register already has
915     //   false lane zeros, so we can ignore the uses.
916     SmallPtrSet<MachineInstr *, 2> Defs;
917     RDA.getGlobalReachingDefs(&MI, MO.getReg(), Defs);
918     for (auto *Def : Defs) {
919       if (Def == &MI || FalseLanesZero.count(Def) || IsZeroInit(Def))
920         continue;
921       if (MO.isUse() && isPredicated)
922         continue;
923       return false;
924     }
925   }
926   LLVM_DEBUG(dbgs() << "ARM Loops: Always False Zeros: " << MI);
927   return true;
928 }
929 
930 bool LowOverheadLoop::ValidateLiveOuts() {
931   // We want to find out if the tail-predicated version of this loop will
932   // produce the same values as the loop in its original form. For this to
933   // be true, the newly inserted implicit predication must not change the
934   // the (observable) results.
935   // We're doing this because many instructions in the loop will not be
936   // predicated and so the conversion from VPT predication to tail-predication
937   // can result in different values being produced; due to the tail-predication
938   // preventing many instructions from updating their falsely predicated
939   // lanes. This analysis assumes that all the instructions perform lane-wise
940   // operations and don't perform any exchanges.
941   // A masked load, whether through VPT or tail predication, will write zeros
942   // to any of the falsely predicated bytes. So, from the loads, we know that
943   // the false lanes are zeroed and here we're trying to track that those false
944   // lanes remain zero, or where they change, the differences are masked away
945   // by their user(s).
946   // All MVE stores have to be predicated, so we know that any predicate load
947   // operands, or stored results are equivalent already. Other explicitly
948   // predicated instructions will perform the same operation in the original
949   // loop and the tail-predicated form too. Because of this, we can insert
950   // loads, stores and other predicated instructions into our Predicated
951   // set and build from there.
952   const TargetRegisterClass *QPRs = TRI.getRegClass(ARM::MQPRRegClassID);
953   SetVector<MachineInstr *> FalseLanesUnknown;
954   SmallPtrSet<MachineInstr *, 4> FalseLanesZero;
955   SmallPtrSet<MachineInstr *, 4> Predicated;
956   MachineBasicBlock *Header = ML.getHeader();
957 
958   LLVM_DEBUG(dbgs() << "ARM Loops: Validating Live outs\n");
959 
960   for (auto &MI : *Header) {
961     if (!shouldInspect(MI))
962       continue;
963 
964     if (isVCTP(&MI) || isVPTOpcode(MI.getOpcode()))
965       continue;
966 
967     bool isPredicated = isVectorPredicated(&MI);
968     bool retainsOrReduces =
969       retainsPreviousHalfElement(MI) || isHorizontalReduction(MI);
970 
971     if (isPredicated)
972       Predicated.insert(&MI);
973     if (producesFalseLanesZero(MI, QPRs, RDA, FalseLanesZero))
974       FalseLanesZero.insert(&MI);
975     else if (MI.getNumDefs() == 0)
976       continue;
977     else if (!isPredicated && retainsOrReduces) {
978       LLVM_DEBUG(dbgs() << "  Unpredicated instruction that retainsOrReduces: " << MI);
979       return false;
980     } else if (!isPredicated && MI.getOpcode() != ARM::MQPRCopy)
981       FalseLanesUnknown.insert(&MI);
982   }
983 
984   LLVM_DEBUG({
985     dbgs() << "  Predicated:\n";
986     for (auto *I : Predicated)
987       dbgs() << "  " << *I;
988     dbgs() << "  FalseLanesZero:\n";
989     for (auto *I : FalseLanesZero)
990       dbgs() << "  " << *I;
991     dbgs() << "  FalseLanesUnknown:\n";
992     for (auto *I : FalseLanesUnknown)
993       dbgs() << "  " << *I;
994   });
995 
996   auto HasPredicatedUsers = [this](MachineInstr *MI, const MachineOperand &MO,
997                               SmallPtrSetImpl<MachineInstr *> &Predicated) {
998     SmallPtrSet<MachineInstr *, 2> Uses;
999     RDA.getGlobalUses(MI, MO.getReg().asMCReg(), Uses);
1000     for (auto *Use : Uses) {
1001       if (Use != MI && !Predicated.count(Use))
1002         return false;
1003     }
1004     return true;
1005   };
1006 
1007   // Visit the unknowns in reverse so that we can start at the values being
1008   // stored and then we can work towards the leaves, hopefully adding more
1009   // instructions to Predicated. Successfully terminating the loop means that
1010   // all the unknown values have to found to be masked by predicated user(s).
1011   // For any unpredicated values, we store them in NonPredicated so that we
1012   // can later check whether these form a reduction.
1013   SmallPtrSet<MachineInstr*, 2> NonPredicated;
1014   for (auto *MI : reverse(FalseLanesUnknown)) {
1015     for (auto &MO : MI->operands()) {
1016       if (!isRegInClass(MO, QPRs) || !MO.isDef())
1017         continue;
1018       if (!HasPredicatedUsers(MI, MO, Predicated)) {
1019         LLVM_DEBUG(dbgs() << "  Found an unknown def of : "
1020                           << TRI.getRegAsmName(MO.getReg()) << " at " << *MI);
1021         NonPredicated.insert(MI);
1022         break;
1023       }
1024     }
1025     // Any unknown false lanes have been masked away by the user(s).
1026     if (!NonPredicated.contains(MI))
1027       Predicated.insert(MI);
1028   }
1029 
1030   SmallPtrSet<MachineInstr *, 2> LiveOutMIs;
1031   SmallVector<MachineBasicBlock *, 2> ExitBlocks;
1032   ML.getExitBlocks(ExitBlocks);
1033   assert(ML.getNumBlocks() == 1 && "Expected single block loop!");
1034   assert(ExitBlocks.size() == 1 && "Expected a single exit block");
1035   MachineBasicBlock *ExitBB = ExitBlocks.front();
1036   for (const MachineBasicBlock::RegisterMaskPair &RegMask : ExitBB->liveins()) {
1037     // TODO: Instead of blocking predication, we could move the vctp to the exit
1038     // block and calculate it's operand there in or the preheader.
1039     if (RegMask.PhysReg == ARM::VPR) {
1040       LLVM_DEBUG(dbgs() << "  VPR is live in to the exit block.");
1041       return false;
1042     }
1043     // Check Q-regs that are live in the exit blocks. We don't collect scalars
1044     // because they won't be affected by lane predication.
1045     if (QPRs->contains(RegMask.PhysReg))
1046       if (auto *MI = RDA.getLocalLiveOutMIDef(Header, RegMask.PhysReg))
1047         LiveOutMIs.insert(MI);
1048   }
1049 
1050   // We've already validated that any VPT predication within the loop will be
1051   // equivalent when we perform the predication transformation; so we know that
1052   // any VPT predicated instruction is predicated upon VCTP. Any live-out
1053   // instruction needs to be predicated, so check this here. The instructions
1054   // in NonPredicated have been found to be a reduction that we can ensure its
1055   // legality. Any MQPRCopy found will need to validate its input as if it was
1056   // live out.
1057   SmallVector<MachineInstr *> Worklist(LiveOutMIs.begin(), LiveOutMIs.end());
1058   while (!Worklist.empty()) {
1059     MachineInstr *MI = Worklist.pop_back_val();
1060     if (MI->getOpcode() == ARM::MQPRCopy) {
1061       VMOVCopies.insert(MI);
1062       MachineInstr *CopySrc =
1063           RDA.getUniqueReachingMIDef(MI, MI->getOperand(1).getReg());
1064       if (CopySrc)
1065         Worklist.push_back(CopySrc);
1066     } else if (NonPredicated.count(MI) && FalseLanesUnknown.contains(MI)) {
1067       LLVM_DEBUG(dbgs() << " Unable to handle live out: " << *MI);
1068       VMOVCopies.clear();
1069       return false;
1070     }
1071   }
1072 
1073   return true;
1074 }
1075 
1076 void LowOverheadLoop::Validate(ARMBasicBlockUtils *BBUtils) {
1077   if (Revert)
1078     return;
1079 
1080   // Check branch target ranges: WLS[TP] can only branch forwards and LE[TP]
1081   // can only jump back.
1082   auto ValidateRanges = [](MachineInstr *Start, MachineInstr *End,
1083                            ARMBasicBlockUtils *BBUtils, MachineLoop &ML) {
1084     MachineBasicBlock *TgtBB = End->getOpcode() == ARM::t2LoopEnd
1085                                    ? End->getOperand(1).getMBB()
1086                                    : End->getOperand(2).getMBB();
1087     // TODO Maybe there's cases where the target doesn't have to be the header,
1088     // but for now be safe and revert.
1089     if (TgtBB != ML.getHeader()) {
1090       LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targeting header.\n");
1091       return false;
1092     }
1093 
1094     // The WLS and LE instructions have 12-bits for the label offset. WLS
1095     // requires a positive offset, while LE uses negative.
1096     if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML.getHeader()) ||
1097         !BBUtils->isBBInRange(End, ML.getHeader(), 4094)) {
1098       LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
1099       return false;
1100     }
1101 
1102     if (isWhileLoopStart(*Start)) {
1103       MachineBasicBlock *TargetBB = getWhileLoopStartTargetBB(*Start);
1104       if (BBUtils->getOffsetOf(Start) > BBUtils->getOffsetOf(TargetBB) ||
1105           !BBUtils->isBBInRange(Start, TargetBB, 4094)) {
1106         LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
1107         return false;
1108       }
1109     }
1110     return true;
1111   };
1112 
1113   StartInsertPt = MachineBasicBlock::iterator(Start);
1114   StartInsertBB = Start->getParent();
1115   LLVM_DEBUG(dbgs() << "ARM Loops: Will insert LoopStart at "
1116                     << *StartInsertPt);
1117 
1118   Revert = !ValidateRanges(Start, End, BBUtils, ML);
1119   CannotTailPredicate = !ValidateTailPredicate();
1120 }
1121 
1122 bool LowOverheadLoop::AddVCTP(MachineInstr *MI) {
1123   LLVM_DEBUG(dbgs() << "ARM Loops: Adding VCTP: " << *MI);
1124   if (VCTPs.empty()) {
1125     VCTPs.push_back(MI);
1126     return true;
1127   }
1128 
1129   // If we find another VCTP, check whether it uses the same value as the main VCTP.
1130   // If it does, store it in the VCTPs set, else refuse it.
1131   MachineInstr *Prev = VCTPs.back();
1132   if (!Prev->getOperand(1).isIdenticalTo(MI->getOperand(1)) ||
1133       !RDA.hasSameReachingDef(Prev, MI, MI->getOperand(1).getReg().asMCReg())) {
1134     LLVM_DEBUG(dbgs() << "ARM Loops: Found VCTP with a different reaching "
1135                          "definition from the main VCTP");
1136     return false;
1137   }
1138   VCTPs.push_back(MI);
1139   return true;
1140 }
1141 
1142 static bool ValidateMVEStore(MachineInstr *MI, MachineLoop *ML) {
1143 
1144   auto GetFrameIndex = [](MachineMemOperand *Operand) {
1145     const PseudoSourceValue *PseudoValue = Operand->getPseudoValue();
1146     if (PseudoValue && PseudoValue->kind() == PseudoSourceValue::FixedStack) {
1147       if (const auto *FS = dyn_cast<FixedStackPseudoSourceValue>(PseudoValue)) {
1148         return FS->getFrameIndex();
1149       }
1150     }
1151     return -1;
1152   };
1153 
1154   auto IsStackOp = [GetFrameIndex](MachineInstr *I) {
1155     switch (I->getOpcode()) {
1156     case ARM::MVE_VSTRWU32:
1157     case ARM::MVE_VLDRWU32: {
1158       return I->getOperand(1).getReg() == ARM::SP &&
1159              I->memoperands().size() == 1 &&
1160              GetFrameIndex(I->memoperands().front()) >= 0;
1161     }
1162     default:
1163       return false;
1164     }
1165   };
1166 
1167   // An unpredicated vector register spill is allowed if all of the uses of the
1168   // stack slot are within the loop
1169   if (MI->getOpcode() != ARM::MVE_VSTRWU32 || !IsStackOp(MI))
1170     return false;
1171 
1172   // Search all blocks after the loop for accesses to the same stack slot.
1173   // ReachingDefAnalysis doesn't work for sp as it relies on registers being
1174   // live-out (which sp never is) to know what blocks to look in
1175   if (MI->memoperands().size() == 0)
1176     return false;
1177   int FI = GetFrameIndex(MI->memoperands().front());
1178 
1179   auto &FrameInfo = MI->getParent()->getParent()->getFrameInfo();
1180   if (FI == -1 || !FrameInfo.isSpillSlotObjectIndex(FI))
1181     return false;
1182 
1183   SmallVector<MachineBasicBlock *> Frontier;
1184   ML->getExitBlocks(Frontier);
1185   SmallPtrSet<MachineBasicBlock *, 4> Visited{MI->getParent()};
1186   unsigned Idx = 0;
1187   while (Idx < Frontier.size()) {
1188     MachineBasicBlock *BB = Frontier[Idx];
1189     bool LookAtSuccessors = true;
1190     for (auto &I : *BB) {
1191       if (!IsStackOp(&I) || I.memoperands().size() == 0)
1192         continue;
1193       if (GetFrameIndex(I.memoperands().front()) != FI)
1194         continue;
1195       // If this block has a store to the stack slot before any loads then we
1196       // can ignore the block
1197       if (I.getOpcode() == ARM::MVE_VSTRWU32) {
1198         LookAtSuccessors = false;
1199         break;
1200       }
1201       // If the store and the load are using the same stack slot then the
1202       // store isn't valid for tail predication
1203       if (I.getOpcode() == ARM::MVE_VLDRWU32)
1204         return false;
1205     }
1206 
1207     if (LookAtSuccessors) {
1208       for (auto Succ : BB->successors()) {
1209         if (!Visited.contains(Succ) && !is_contained(Frontier, Succ))
1210           Frontier.push_back(Succ);
1211       }
1212     }
1213     Visited.insert(BB);
1214     Idx++;
1215   }
1216 
1217   return true;
1218 }
1219 
1220 bool LowOverheadLoop::ValidateMVEInst(MachineInstr *MI) {
1221   if (CannotTailPredicate)
1222     return false;
1223 
1224   if (!shouldInspect(*MI))
1225     return true;
1226 
1227   if (MI->getOpcode() == ARM::MVE_VPSEL ||
1228       MI->getOpcode() == ARM::MVE_VPNOT) {
1229     // TODO: Allow VPSEL and VPNOT, we currently cannot because:
1230     // 1) It will use the VPR as a predicate operand, but doesn't have to be
1231     //    instead a VPT block, which means we can assert while building up
1232     //    the VPT block because we don't find another VPT or VPST to being a new
1233     //    one.
1234     // 2) VPSEL still requires a VPR operand even after tail predicating,
1235     //    which means we can't remove it unless there is another
1236     //    instruction, such as vcmp, that can provide the VPR def.
1237     return false;
1238   }
1239 
1240   // Record all VCTPs and check that they're equivalent to one another.
1241   if (isVCTP(MI) && !AddVCTP(MI))
1242     return false;
1243 
1244   // Inspect uses first so that any instructions that alter the VPR don't
1245   // alter the predicate upon themselves.
1246   const MCInstrDesc &MCID = MI->getDesc();
1247   bool IsUse = false;
1248   unsigned LastOpIdx = MI->getNumOperands() - 1;
1249   for (auto &Op : enumerate(reverse(MCID.operands()))) {
1250     const MachineOperand &MO = MI->getOperand(LastOpIdx - Op.index());
1251     if (!MO.isReg() || !MO.isUse() || MO.getReg() != ARM::VPR)
1252       continue;
1253 
1254     if (ARM::isVpred(Op.value().OperandType)) {
1255       VPTState::addInst(MI);
1256       IsUse = true;
1257     } else if (MI->getOpcode() != ARM::MVE_VPST) {
1258       LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI);
1259       return false;
1260     }
1261   }
1262 
1263   // If we find an instruction that has been marked as not valid for tail
1264   // predication, only allow the instruction if it's contained within a valid
1265   // VPT block.
1266   bool RequiresExplicitPredication =
1267     (MCID.TSFlags & ARMII::ValidForTailPredication) == 0;
1268   if (isDomainMVE(MI) && RequiresExplicitPredication) {
1269     if (MI->getOpcode() == ARM::MQPRCopy)
1270       return true;
1271     if (!IsUse && producesDoubleWidthResult(*MI)) {
1272       DoubleWidthResultInstrs.insert(MI);
1273       return true;
1274     }
1275 
1276     LLVM_DEBUG(if (!IsUse) dbgs()
1277                << "ARM Loops: Can't tail predicate: " << *MI);
1278     return IsUse;
1279   }
1280 
1281   // If the instruction is already explicitly predicated, then the conversion
1282   // will be fine, but ensure that all store operations are predicated.
1283   if (MI->mayStore() && !ValidateMVEStore(MI, &ML))
1284     return IsUse;
1285 
1286   // If this instruction defines the VPR, update the predicate for the
1287   // proceeding instructions.
1288   if (isVectorPredicate(MI)) {
1289     // Clear the existing predicate when we're not in VPT Active state,
1290     // otherwise we add to it.
1291     if (!isVectorPredicated(MI))
1292       VPTState::resetPredicate(MI);
1293     else
1294       VPTState::addPredicate(MI);
1295   }
1296 
1297   // Finally once the predicate has been modified, we can start a new VPT
1298   // block if necessary.
1299   if (isVPTOpcode(MI->getOpcode()))
1300     VPTState::CreateVPTBlock(MI);
1301 
1302   return true;
1303 }
1304 
1305 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
1306   const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget());
1307   if (!ST.hasLOB())
1308     return false;
1309 
1310   MF = &mf;
1311   LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n");
1312 
1313   MLI = &getAnalysis<MachineLoopInfo>();
1314   RDA = &getAnalysis<ReachingDefAnalysis>();
1315   MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness);
1316   MRI = &MF->getRegInfo();
1317   TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo());
1318   TRI = ST.getRegisterInfo();
1319   BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF));
1320   BBUtils->computeAllBlockSizes();
1321   BBUtils->adjustBBOffsetsAfter(&MF->front());
1322 
1323   bool Changed = false;
1324   for (auto ML : *MLI) {
1325     if (ML->isOutermost())
1326       Changed |= ProcessLoop(ML);
1327   }
1328   Changed |= RevertNonLoops();
1329   return Changed;
1330 }
1331 
1332 bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
1333 
1334   bool Changed = false;
1335 
1336   // Process inner loops first.
1337   for (auto I = ML->begin(), E = ML->end(); I != E; ++I)
1338     Changed |= ProcessLoop(*I);
1339 
1340   LLVM_DEBUG({
1341     dbgs() << "ARM Loops: Processing loop containing:\n";
1342     if (auto *Preheader = ML->getLoopPreheader())
1343       dbgs() << " - Preheader: " << printMBBReference(*Preheader) << "\n";
1344     else if (auto *Preheader = MLI->findLoopPreheader(ML, true, true))
1345       dbgs() << " - Preheader: " << printMBBReference(*Preheader) << "\n";
1346     for (auto *MBB : ML->getBlocks())
1347       dbgs() << " - Block: " << printMBBReference(*MBB) << "\n";
1348   });
1349 
1350   // Search the given block for a loop start instruction. If one isn't found,
1351   // and there's only one predecessor block, search that one too.
1352   std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
1353     [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
1354     for (auto &MI : *MBB) {
1355       if (isLoopStart(MI))
1356         return &MI;
1357     }
1358     if (MBB->pred_size() == 1)
1359       return SearchForStart(*MBB->pred_begin());
1360     return nullptr;
1361   };
1362 
1363   LowOverheadLoop LoLoop(*ML, *MLI, *RDA, *TRI, *TII);
1364   // Search the preheader for the start intrinsic.
1365   // FIXME: I don't see why we shouldn't be supporting multiple predecessors
1366   // with potentially multiple set.loop.iterations, so we need to enable this.
1367   if (LoLoop.Preheader)
1368     LoLoop.Start = SearchForStart(LoLoop.Preheader);
1369   else
1370     return Changed;
1371 
1372   // Find the low-overhead loop components and decide whether or not to fall
1373   // back to a normal loop. Also look for a vctp instructions and decide
1374   // whether we can convert that predicate using tail predication.
1375   for (auto *MBB : reverse(ML->getBlocks())) {
1376     for (auto &MI : *MBB) {
1377       if (MI.isDebugValue())
1378         continue;
1379       else if (MI.getOpcode() == ARM::t2LoopDec)
1380         LoLoop.Dec = &MI;
1381       else if (MI.getOpcode() == ARM::t2LoopEnd)
1382         LoLoop.End = &MI;
1383       else if (MI.getOpcode() == ARM::t2LoopEndDec)
1384         LoLoop.End = LoLoop.Dec = &MI;
1385       else if (isLoopStart(MI))
1386         LoLoop.Start = &MI;
1387       else if (MI.getDesc().isCall()) {
1388         // TODO: Though the call will require LE to execute again, does this
1389         // mean we should revert? Always executing LE hopefully should be
1390         // faster than performing a sub,cmp,br or even subs,br.
1391         LoLoop.Revert = true;
1392         LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
1393       } else {
1394         // Record VPR defs and build up their corresponding vpt blocks.
1395         // Check we know how to tail predicate any mve instructions.
1396         LoLoop.AnalyseMVEInst(&MI);
1397       }
1398     }
1399   }
1400 
1401   LLVM_DEBUG(LoLoop.dump());
1402   if (!LoLoop.FoundAllComponents()) {
1403     LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find loop start, update, end\n");
1404     return Changed;
1405   }
1406 
1407   assert(LoLoop.Start->getOpcode() != ARM::t2WhileLoopStart &&
1408          "Expected t2WhileLoopStart to be removed before regalloc!");
1409 
1410   // Check that the only instruction using LoopDec is LoopEnd. This can only
1411   // happen when the Dec and End are separate, not a single t2LoopEndDec.
1412   // TODO: Check for copy chains that really have no effect.
1413   if (LoLoop.Dec != LoLoop.End) {
1414     SmallPtrSet<MachineInstr *, 2> Uses;
1415     RDA->getReachingLocalUses(LoLoop.Dec, MCRegister::from(ARM::LR), Uses);
1416     if (Uses.size() > 1 || !Uses.count(LoLoop.End)) {
1417       LLVM_DEBUG(dbgs() << "ARM Loops: Unable to remove LoopDec.\n");
1418       LoLoop.Revert = true;
1419     }
1420   }
1421   LoLoop.Validate(BBUtils.get());
1422   Expand(LoLoop);
1423   return true;
1424 }
1425 
1426 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a
1427 // beq that branches to the exit branch.
1428 // TODO: We could also try to generate a cbz if the value in LR is also in
1429 // another low register.
1430 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
1431   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI);
1432   MachineBasicBlock *DestBB = getWhileLoopStartTargetBB(*MI);
1433   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
1434     ARM::tBcc : ARM::t2Bcc;
1435 
1436   RevertWhileLoopStartLR(MI, TII, BrOpc);
1437 }
1438 
1439 void ARMLowOverheadLoops::RevertDo(MachineInstr *MI) const {
1440   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to mov: " << *MI);
1441   RevertDoLoopStart(MI, TII);
1442 }
1443 
1444 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI) const {
1445   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
1446   MachineBasicBlock *MBB = MI->getParent();
1447   SmallPtrSet<MachineInstr*, 1> Ignore;
1448   for (auto I = MachineBasicBlock::iterator(MI), E = MBB->end(); I != E; ++I) {
1449     if (I->getOpcode() == ARM::t2LoopEnd) {
1450       Ignore.insert(&*I);
1451       break;
1452     }
1453   }
1454 
1455   // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
1456   bool SetFlags =
1457       RDA->isSafeToDefRegAt(MI, MCRegister::from(ARM::CPSR), Ignore);
1458 
1459   llvm::RevertLoopDec(MI, TII, SetFlags);
1460   return SetFlags;
1461 }
1462 
1463 // Generate a subs, or sub and cmp, and a branch instead of an LE.
1464 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
1465   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI);
1466 
1467   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
1468   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
1469     ARM::tBcc : ARM::t2Bcc;
1470 
1471   llvm::RevertLoopEnd(MI, TII, BrOpc, SkipCmp);
1472 }
1473 
1474 // Generate a subs, or sub and cmp, and a branch instead of an LE.
1475 void ARMLowOverheadLoops::RevertLoopEndDec(MachineInstr *MI) const {
1476   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to subs, br: " << *MI);
1477   assert(MI->getOpcode() == ARM::t2LoopEndDec && "Expected a t2LoopEndDec!");
1478   MachineBasicBlock *MBB = MI->getParent();
1479 
1480   MachineInstrBuilder MIB =
1481       BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(ARM::t2SUBri));
1482   MIB.addDef(ARM::LR);
1483   MIB.add(MI->getOperand(1));
1484   MIB.addImm(1);
1485   MIB.addImm(ARMCC::AL);
1486   MIB.addReg(ARM::NoRegister);
1487   MIB.addReg(ARM::CPSR);
1488   MIB->getOperand(5).setIsDef(true);
1489 
1490   MachineBasicBlock *DestBB = MI->getOperand(2).getMBB();
1491   unsigned BrOpc =
1492       BBUtils->isBBInRange(MI, DestBB, 254) ? ARM::tBcc : ARM::t2Bcc;
1493 
1494   // Create bne
1495   MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
1496   MIB.add(MI->getOperand(2)); // branch target
1497   MIB.addImm(ARMCC::NE);      // condition code
1498   MIB.addReg(ARM::CPSR);
1499 
1500   MI->eraseFromParent();
1501 }
1502 
1503 // Perform dead code elimation on the loop iteration count setup expression.
1504 // If we are tail-predicating, the number of elements to be processed is the
1505 // operand of the VCTP instruction in the vector body, see getCount(), which is
1506 // register $r3 in this example:
1507 //
1508 //   $lr = big-itercount-expression
1509 //   ..
1510 //   $lr = t2DoLoopStart renamable $lr
1511 //   vector.body:
1512 //     ..
1513 //     $vpr = MVE_VCTP32 renamable $r3
1514 //     renamable $lr = t2LoopDec killed renamable $lr, 1
1515 //     t2LoopEnd renamable $lr, %vector.body
1516 //     tB %end
1517 //
1518 // What we would like achieve here is to replace the do-loop start pseudo
1519 // instruction t2DoLoopStart with:
1520 //
1521 //    $lr = MVE_DLSTP_32 killed renamable $r3
1522 //
1523 // Thus, $r3 which defines the number of elements, is written to $lr,
1524 // and then we want to delete the whole chain that used to define $lr,
1525 // see the comment below how this chain could look like.
1526 //
1527 void ARMLowOverheadLoops::IterationCountDCE(LowOverheadLoop &LoLoop) {
1528   if (!LoLoop.IsTailPredicationLegal())
1529     return;
1530 
1531   LLVM_DEBUG(dbgs() << "ARM Loops: Trying DCE on loop iteration count.\n");
1532 
1533   MachineInstr *Def = RDA->getMIOperand(LoLoop.Start, 1);
1534   if (!Def) {
1535     LLVM_DEBUG(dbgs() << "ARM Loops: Couldn't find iteration count.\n");
1536     return;
1537   }
1538 
1539   // Collect and remove the users of iteration count.
1540   SmallPtrSet<MachineInstr*, 4> Killed  = { LoLoop.Start, LoLoop.Dec,
1541                                             LoLoop.End };
1542   if (!TryRemove(Def, *RDA, LoLoop.ToRemove, Killed))
1543     LLVM_DEBUG(dbgs() << "ARM Loops: Unsafe to remove loop iteration count.\n");
1544 }
1545 
1546 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
1547   LLVM_DEBUG(dbgs() << "ARM Loops: Expanding LoopStart.\n");
1548   // When using tail-predication, try to delete the dead code that was used to
1549   // calculate the number of loop iterations.
1550   IterationCountDCE(LoLoop);
1551 
1552   MachineBasicBlock::iterator InsertPt = LoLoop.StartInsertPt;
1553   MachineInstr *Start = LoLoop.Start;
1554   MachineBasicBlock *MBB = LoLoop.StartInsertBB;
1555   unsigned Opc = LoLoop.getStartOpcode();
1556   MachineOperand &Count = LoLoop.getLoopStartOperand();
1557 
1558   // A DLS lr, lr we needn't emit
1559   MachineInstr* NewStart;
1560   if (Opc == ARM::t2DLS && Count.isReg() && Count.getReg() == ARM::LR) {
1561     LLVM_DEBUG(dbgs() << "ARM Loops: Didn't insert start: DLS lr, lr");
1562     NewStart = nullptr;
1563   } else {
1564     MachineInstrBuilder MIB =
1565       BuildMI(*MBB, InsertPt, Start->getDebugLoc(), TII->get(Opc));
1566 
1567     MIB.addDef(ARM::LR);
1568     MIB.add(Count);
1569     if (isWhileLoopStart(*Start))
1570       MIB.addMBB(getWhileLoopStartTargetBB(*Start));
1571 
1572     LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
1573     NewStart = &*MIB;
1574   }
1575 
1576   LoLoop.ToRemove.insert(Start);
1577   return NewStart;
1578 }
1579 
1580 void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
1581   auto RemovePredicate = [](MachineInstr *MI) {
1582     if (MI->isDebugInstr())
1583       return;
1584     LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI);
1585     int PIdx = llvm::findFirstVPTPredOperandIdx(*MI);
1586     assert(PIdx >= 1 && "Trying to unpredicate a non-predicated instruction");
1587     assert(MI->getOperand(PIdx).getImm() == ARMVCC::Then &&
1588            "Expected Then predicate!");
1589     MI->getOperand(PIdx).setImm(ARMVCC::None);
1590     MI->getOperand(PIdx + 1).setReg(0);
1591   };
1592 
1593   for (auto &Block : LoLoop.getVPTBlocks()) {
1594     SmallVectorImpl<MachineInstr *> &Insts = Block.getInsts();
1595 
1596     auto ReplaceVCMPWithVPT = [&](MachineInstr *&TheVCMP, MachineInstr *At) {
1597       assert(TheVCMP && "Replacing a removed or non-existent VCMP");
1598       // Replace the VCMP with a VPT
1599       MachineInstrBuilder MIB =
1600           BuildMI(*At->getParent(), At, At->getDebugLoc(),
1601                   TII->get(VCMPOpcodeToVPT(TheVCMP->getOpcode())));
1602       MIB.addImm(ARMVCC::Then);
1603       // Register one
1604       MIB.add(TheVCMP->getOperand(1));
1605       // Register two
1606       MIB.add(TheVCMP->getOperand(2));
1607       // The comparison code, e.g. ge, eq, lt
1608       MIB.add(TheVCMP->getOperand(3));
1609       LLVM_DEBUG(dbgs() << "ARM Loops: Combining with VCMP to VPT: " << *MIB);
1610       LoLoop.BlockMasksToRecompute.insert(MIB.getInstr());
1611       LoLoop.ToRemove.insert(TheVCMP);
1612       TheVCMP = nullptr;
1613     };
1614 
1615     if (VPTState::isEntryPredicatedOnVCTP(Block, /*exclusive*/ true)) {
1616       MachineInstr *VPST = Insts.front();
1617       if (VPTState::hasUniformPredicate(Block)) {
1618         // A vpt block starting with VPST, is only predicated upon vctp and has no
1619         // internal vpr defs:
1620         // - Remove vpst.
1621         // - Unpredicate the remaining instructions.
1622         LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *VPST);
1623         for (unsigned i = 1; i < Insts.size(); ++i)
1624           RemovePredicate(Insts[i]);
1625       } else {
1626         // The VPT block has a non-uniform predicate but it uses a vpst and its
1627         // entry is guarded only by a vctp, which means we:
1628         // - Need to remove the original vpst.
1629         // - Then need to unpredicate any following instructions, until
1630         //   we come across the divergent vpr def.
1631         // - Insert a new vpst to predicate the instruction(s) that following
1632         //   the divergent vpr def.
1633         MachineInstr *Divergent = VPTState::getDivergent(Block);
1634         MachineBasicBlock *MBB = Divergent->getParent();
1635         auto DivergentNext = ++MachineBasicBlock::iterator(Divergent);
1636         while (DivergentNext != MBB->end() && DivergentNext->isDebugInstr())
1637           ++DivergentNext;
1638 
1639         bool DivergentNextIsPredicated =
1640             DivergentNext != MBB->end() &&
1641             getVPTInstrPredicate(*DivergentNext) != ARMVCC::None;
1642 
1643         for (auto I = ++MachineBasicBlock::iterator(VPST), E = DivergentNext;
1644              I != E; ++I)
1645           RemovePredicate(&*I);
1646 
1647         // Check if the instruction defining vpr is a vcmp so it can be combined
1648         // with the VPST This should be the divergent instruction
1649         MachineInstr *VCMP =
1650             VCMPOpcodeToVPT(Divergent->getOpcode()) != 0 ? Divergent : nullptr;
1651 
1652         if (DivergentNextIsPredicated) {
1653           // Insert a VPST at the divergent only if the next instruction
1654           // would actually use it. A VCMP following a VPST can be
1655           // merged into a VPT so do that instead if the VCMP exists.
1656           if (!VCMP) {
1657             // Create a VPST (with a null mask for now, we'll recompute it
1658             // later)
1659             MachineInstrBuilder MIB =
1660                 BuildMI(*Divergent->getParent(), Divergent,
1661                         Divergent->getDebugLoc(), TII->get(ARM::MVE_VPST));
1662             MIB.addImm(0);
1663             LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB);
1664             LoLoop.BlockMasksToRecompute.insert(MIB.getInstr());
1665           } else {
1666             // No RDA checks are necessary here since the VPST would have been
1667             // directly after the VCMP
1668             ReplaceVCMPWithVPT(VCMP, VCMP);
1669           }
1670         }
1671       }
1672       LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *VPST);
1673       LoLoop.ToRemove.insert(VPST);
1674     } else if (Block.containsVCTP()) {
1675       // The vctp will be removed, so either the entire block will be dead or
1676       // the block mask of the vp(s)t will need to be recomputed.
1677       MachineInstr *VPST = Insts.front();
1678       if (Block.size() == 2) {
1679         assert(VPST->getOpcode() == ARM::MVE_VPST &&
1680                "Found a VPST in an otherwise empty vpt block");
1681         LoLoop.ToRemove.insert(VPST);
1682       } else
1683         LoLoop.BlockMasksToRecompute.insert(VPST);
1684     } else if (Insts.front()->getOpcode() == ARM::MVE_VPST) {
1685       // If this block starts with a VPST then attempt to merge it with the
1686       // preceeding un-merged VCMP into a VPT. This VCMP comes from a VPT
1687       // block that no longer exists
1688       MachineInstr *VPST = Insts.front();
1689       auto Next = ++MachineBasicBlock::iterator(VPST);
1690       assert(getVPTInstrPredicate(*Next) != ARMVCC::None &&
1691              "The instruction after a VPST must be predicated");
1692       (void)Next;
1693       MachineInstr *VprDef = RDA->getUniqueReachingMIDef(VPST, ARM::VPR);
1694       if (VprDef && VCMPOpcodeToVPT(VprDef->getOpcode()) &&
1695           !LoLoop.ToRemove.contains(VprDef)) {
1696         MachineInstr *VCMP = VprDef;
1697         // The VCMP and VPST can only be merged if the VCMP's operands will have
1698         // the same values at the VPST.
1699         // If any of the instructions between the VCMP and VPST are predicated
1700         // then a different code path is expected to have merged the VCMP and
1701         // VPST already.
1702         if (!std::any_of(++MachineBasicBlock::iterator(VCMP),
1703                          MachineBasicBlock::iterator(VPST), hasVPRUse) &&
1704             RDA->hasSameReachingDef(VCMP, VPST, VCMP->getOperand(1).getReg()) &&
1705             RDA->hasSameReachingDef(VCMP, VPST, VCMP->getOperand(2).getReg())) {
1706           ReplaceVCMPWithVPT(VCMP, VPST);
1707           LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *VPST);
1708           LoLoop.ToRemove.insert(VPST);
1709         }
1710       }
1711     }
1712   }
1713 
1714   LoLoop.ToRemove.insert(LoLoop.VCTPs.begin(), LoLoop.VCTPs.end());
1715 }
1716 
1717 void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
1718 
1719   // Combine the LoopDec and LoopEnd instructions into LE(TP).
1720   auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) {
1721     MachineInstr *End = LoLoop.End;
1722     MachineBasicBlock *MBB = End->getParent();
1723     unsigned Opc = LoLoop.IsTailPredicationLegal() ?
1724       ARM::MVE_LETP : ARM::t2LEUpdate;
1725     MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(),
1726                                       TII->get(Opc));
1727     MIB.addDef(ARM::LR);
1728     unsigned Off = LoLoop.Dec == LoLoop.End ? 1 : 0;
1729     MIB.add(End->getOperand(Off + 0));
1730     MIB.add(End->getOperand(Off + 1));
1731     LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
1732     LoLoop.ToRemove.insert(LoLoop.Dec);
1733     LoLoop.ToRemove.insert(End);
1734     return &*MIB;
1735   };
1736 
1737   // TODO: We should be able to automatically remove these branches before we
1738   // get here - probably by teaching analyzeBranch about the pseudo
1739   // instructions.
1740   // If there is an unconditional branch, after I, that just branches to the
1741   // next block, remove it.
1742   auto RemoveDeadBranch = [](MachineInstr *I) {
1743     MachineBasicBlock *BB = I->getParent();
1744     MachineInstr *Terminator = &BB->instr_back();
1745     if (Terminator->isUnconditionalBranch() && I != Terminator) {
1746       MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
1747       if (BB->isLayoutSuccessor(Succ)) {
1748         LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
1749         Terminator->eraseFromParent();
1750       }
1751     }
1752   };
1753 
1754   // And VMOVCopies need to become 2xVMOVD for tail predication to be valid.
1755   // Anything other MQPRCopy can be converted to MVE_VORR later on.
1756   auto ExpandVMOVCopies = [this](SmallPtrSet<MachineInstr *, 4> &VMOVCopies) {
1757     for (auto *MI : VMOVCopies) {
1758       LLVM_DEBUG(dbgs() << "Converting copy to VMOVD: " << *MI);
1759       assert(MI->getOpcode() == ARM::MQPRCopy && "Only expected MQPRCOPY!");
1760       MachineBasicBlock *MBB = MI->getParent();
1761       Register Dst = MI->getOperand(0).getReg();
1762       Register Src = MI->getOperand(1).getReg();
1763       auto MIB1 = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(ARM::VMOVD),
1764                           ARM::D0 + (Dst - ARM::Q0) * 2)
1765                       .addReg(ARM::D0 + (Src - ARM::Q0) * 2)
1766                       .add(predOps(ARMCC::AL));
1767       (void)MIB1;
1768       LLVM_DEBUG(dbgs() << " into " << *MIB1);
1769       auto MIB2 = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(ARM::VMOVD),
1770                           ARM::D0 + (Dst - ARM::Q0) * 2 + 1)
1771                       .addReg(ARM::D0 + (Src - ARM::Q0) * 2 + 1)
1772                       .add(predOps(ARMCC::AL));
1773       LLVM_DEBUG(dbgs() << " and  " << *MIB2);
1774       (void)MIB2;
1775       MI->eraseFromParent();
1776     }
1777   };
1778 
1779   if (LoLoop.Revert) {
1780     if (isWhileLoopStart(*LoLoop.Start))
1781       RevertWhile(LoLoop.Start);
1782     else
1783       RevertDo(LoLoop.Start);
1784     if (LoLoop.Dec == LoLoop.End)
1785       RevertLoopEndDec(LoLoop.End);
1786     else
1787       RevertLoopEnd(LoLoop.End, RevertLoopDec(LoLoop.Dec));
1788   } else {
1789     ExpandVMOVCopies(LoLoop.VMOVCopies);
1790     LoLoop.Start = ExpandLoopStart(LoLoop);
1791     if (LoLoop.Start)
1792       RemoveDeadBranch(LoLoop.Start);
1793     LoLoop.End = ExpandLoopEnd(LoLoop);
1794     RemoveDeadBranch(LoLoop.End);
1795     if (LoLoop.IsTailPredicationLegal())
1796       ConvertVPTBlocks(LoLoop);
1797     for (auto *I : LoLoop.ToRemove) {
1798       LLVM_DEBUG(dbgs() << "ARM Loops: Erasing " << *I);
1799       I->eraseFromParent();
1800     }
1801     for (auto *I : LoLoop.BlockMasksToRecompute) {
1802       LLVM_DEBUG(dbgs() << "ARM Loops: Recomputing VPT/VPST Block Mask: " << *I);
1803       recomputeVPTBlockMask(*I);
1804       LLVM_DEBUG(dbgs() << "           ... done: " << *I);
1805     }
1806   }
1807 
1808   PostOrderLoopTraversal DFS(LoLoop.ML, *MLI);
1809   DFS.ProcessLoop();
1810   const SmallVectorImpl<MachineBasicBlock*> &PostOrder = DFS.getOrder();
1811   for (auto *MBB : PostOrder) {
1812     recomputeLiveIns(*MBB);
1813     // FIXME: For some reason, the live-in print order is non-deterministic for
1814     // our tests and I can't out why... So just sort them.
1815     MBB->sortUniqueLiveIns();
1816   }
1817 
1818   for (auto *MBB : reverse(PostOrder))
1819     recomputeLivenessFlags(*MBB);
1820 
1821   // We've moved, removed and inserted new instructions, so update RDA.
1822   RDA->reset();
1823 }
1824 
1825 bool ARMLowOverheadLoops::RevertNonLoops() {
1826   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n");
1827   bool Changed = false;
1828 
1829   for (auto &MBB : *MF) {
1830     SmallVector<MachineInstr*, 4> Starts;
1831     SmallVector<MachineInstr*, 4> Decs;
1832     SmallVector<MachineInstr*, 4> Ends;
1833     SmallVector<MachineInstr *, 4> EndDecs;
1834     SmallVector<MachineInstr *, 4> MQPRCopies;
1835 
1836     for (auto &I : MBB) {
1837       if (isLoopStart(I))
1838         Starts.push_back(&I);
1839       else if (I.getOpcode() == ARM::t2LoopDec)
1840         Decs.push_back(&I);
1841       else if (I.getOpcode() == ARM::t2LoopEnd)
1842         Ends.push_back(&I);
1843       else if (I.getOpcode() == ARM::t2LoopEndDec)
1844         EndDecs.push_back(&I);
1845       else if (I.getOpcode() == ARM::MQPRCopy)
1846         MQPRCopies.push_back(&I);
1847     }
1848 
1849     if (Starts.empty() && Decs.empty() && Ends.empty() && EndDecs.empty() &&
1850         MQPRCopies.empty())
1851       continue;
1852 
1853     Changed = true;
1854 
1855     for (auto *Start : Starts) {
1856       if (isWhileLoopStart(*Start))
1857         RevertWhile(Start);
1858       else
1859         RevertDo(Start);
1860     }
1861     for (auto *Dec : Decs)
1862       RevertLoopDec(Dec);
1863 
1864     for (auto *End : Ends)
1865       RevertLoopEnd(End);
1866     for (auto *End : EndDecs)
1867       RevertLoopEndDec(End);
1868     for (auto *MI : MQPRCopies) {
1869       LLVM_DEBUG(dbgs() << "Converting copy to VORR: " << *MI);
1870       assert(MI->getOpcode() == ARM::MQPRCopy && "Only expected MQPRCOPY!");
1871       MachineBasicBlock *MBB = MI->getParent();
1872       auto MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(ARM::MVE_VORR),
1873                          MI->getOperand(0).getReg())
1874                      .add(MI->getOperand(1))
1875                      .add(MI->getOperand(1));
1876       addUnpredicatedMveVpredROp(MIB, MI->getOperand(0).getReg());
1877       MI->eraseFromParent();
1878     }
1879   }
1880   return Changed;
1881 }
1882 
1883 FunctionPass *llvm::createARMLowOverheadLoopsPass() {
1884   return new ARMLowOverheadLoops();
1885 }
1886