xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- SMEPeepholeOpt.cpp - SME peephole optimization pass-----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This pass tries to remove back-to-back (smstart, smstop) and
9 // (smstop, smstart) sequences. The pass is conservative when it cannot
10 // determine that it is safe to remove these sequences.
11 //===----------------------------------------------------------------------===//
12 
13 #include "AArch64InstrInfo.h"
14 #include "AArch64MachineFunctionInfo.h"
15 #include "AArch64Subtarget.h"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/CodeGen/MachineBasicBlock.h"
18 #include "llvm/CodeGen/MachineFunctionPass.h"
19 #include "llvm/CodeGen/MachineRegisterInfo.h"
20 #include "llvm/CodeGen/TargetRegisterInfo.h"
21 
22 using namespace llvm;
23 
24 #define DEBUG_TYPE "aarch64-sme-peephole-opt"
25 
26 namespace {
27 
// Machine-function pass that removes redundant SME streaming-mode changes
// (back-to-back smstart/smstop pairs) and rewrites qualifying REG_SEQUENCEs
// into FORM_TRANSPOSED_REG_TUPLE pseudos. Runs on SSA-form MachineFunctions.
struct SMEPeepholeOpt : public MachineFunctionPass {
  static char ID;

  SMEPeepholeOpt() : MachineFunctionPass(ID) {}

  bool runOnMachineFunction(MachineFunction &MF) override;

  StringRef getPassName() const override {
    return "SME Peephole Optimization pass";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // The pass only erases/rewrites instructions within blocks; it never
    // modifies the CFG.
    AU.setPreservesCFG();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  // Removes cancelling smstart/smstop pairs in \p MBB. Sets
  // \p HasRemovedAllSMChanges when every streaming-mode change seen in the
  // block was removed. Returns true if anything changed.
  bool optimizeStartStopPairs(MachineBasicBlock &MBB,
                              bool &HasRemovedAllSMChanges) const;
  // Attempts to replace a REG_SEQUENCE \p MI with a
  // FORM_TRANSPOSED_REG_TUPLE pseudo. Returns true on success.
  bool visitRegSequence(MachineInstr &MI);
};

char SMEPeepholeOpt::ID = 0;
50 
51 } // end anonymous namespace
52 
isConditionalStartStop(const MachineInstr * MI)53 static bool isConditionalStartStop(const MachineInstr *MI) {
54   return MI->getOpcode() == AArch64::MSRpstatePseudo;
55 }
56 
isMatchingStartStopPair(const MachineInstr * MI1,const MachineInstr * MI2)57 static bool isMatchingStartStopPair(const MachineInstr *MI1,
58                                     const MachineInstr *MI2) {
59   // We only consider the same type of streaming mode change here, i.e.
60   // start/stop SM, or start/stop ZA pairs.
61   if (MI1->getOperand(0).getImm() != MI2->getOperand(0).getImm())
62     return false;
63 
64   // One must be 'start', the other must be 'stop'
65   if (MI1->getOperand(1).getImm() == MI2->getOperand(1).getImm())
66     return false;
67 
68   bool IsConditional = isConditionalStartStop(MI2);
69   if (isConditionalStartStop(MI1) != IsConditional)
70     return false;
71 
72   if (!IsConditional)
73     return true;
74 
75   // Check to make sure the conditional start/stop pairs are identical.
76   if (MI1->getOperand(2).getImm() != MI2->getOperand(2).getImm())
77     return false;
78 
79   // Ensure reg masks are identical.
80   if (MI1->getOperand(4).getRegMask() != MI2->getOperand(4).getRegMask())
81     return false;
82 
83   // This optimisation is unlikely to happen in practice for conditional
84   // smstart/smstop pairs as the virtual registers for pstate.sm will always
85   // be different.
86   // TODO: For this optimisation to apply to conditional smstart/smstop,
87   // this pass will need to do more work to remove redundant calls to
88   // __arm_sme_state.
89 
90   // Only consider conditional start/stop pairs which read the same register
91   // holding the original value of pstate.sm, as some conditional start/stops
92   // require the state on entry to the function.
93   if (MI1->getOperand(3).isReg() && MI2->getOperand(3).isReg()) {
94     Register Reg1 = MI1->getOperand(3).getReg();
95     Register Reg2 = MI2->getOperand(3).getReg();
96     if (Reg1.isPhysical() || Reg2.isPhysical() || Reg1 != Reg2)
97       return false;
98   }
99 
100   return true;
101 }
102 
ChangesStreamingMode(const MachineInstr * MI)103 static bool ChangesStreamingMode(const MachineInstr *MI) {
104   assert((MI->getOpcode() == AArch64::MSRpstatesvcrImm1 ||
105           MI->getOpcode() == AArch64::MSRpstatePseudo) &&
106          "Expected MI to be a smstart/smstop instruction");
107   return MI->getOperand(0).getImm() == AArch64SVCR::SVCRSM ||
108          MI->getOperand(0).getImm() == AArch64SVCR::SVCRSMZA;
109 }
110 
isSVERegOp(const TargetRegisterInfo & TRI,const MachineRegisterInfo & MRI,const MachineOperand & MO)111 static bool isSVERegOp(const TargetRegisterInfo &TRI,
112                        const MachineRegisterInfo &MRI,
113                        const MachineOperand &MO) {
114   if (!MO.isReg())
115     return false;
116 
117   Register R = MO.getReg();
118   if (R.isPhysical())
119     return llvm::any_of(TRI.subregs_inclusive(R), [](const MCPhysReg &SR) {
120       return AArch64::ZPRRegClass.contains(SR) ||
121              AArch64::PPRRegClass.contains(SR);
122     });
123 
124   const TargetRegisterClass *RC = MRI.getRegClass(R);
125   return TRI.getCommonSubClass(&AArch64::ZPRRegClass, RC) ||
126          TRI.getCommonSubClass(&AArch64::PPRRegClass, RC);
127 }
128 
// Scans \p MBB for a start/stop (or stop/start) pair of SME mode-change
// instructions with only mode-agnostic instructions between them, and erases
// both (plus any VG save/restore pseudos tracked in between). Sets
// \p HasRemovedAllSMChanges when every streaming-mode change in the block was
// removed, and returns true if any instruction was erased.
bool SMEPeepholeOpt::optimizeStartStopPairs(
    MachineBasicBlock &MBB, bool &HasRemovedAllSMChanges) const {
  const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
  const TargetRegisterInfo &TRI =
      *MBB.getParent()->getSubtarget().getRegisterInfo();

  bool Changed = false;
  // First unmatched start/stop seen; candidate for pairing with the next one.
  MachineInstr *Prev = nullptr;
  // VG save/restore pseudos seen between Prev and the current instruction;
  // erased together with the pair, otherwise discarded on Reset.
  SmallVector<MachineInstr *, 4> ToBeRemoved;

  // Convenience function to reset the matching of a sequence.
  auto Reset = [&]() {
    Prev = nullptr;
    ToBeRemoved.clear();
  };

  // Walk through instructions in the block trying to find pairs of smstart
  // and smstop nodes that cancel each other out. We only permit a limited
  // set of instructions to appear between them, otherwise we reset our
  // tracking.
  unsigned NumSMChanges = 0;
  unsigned NumSMChangesRemoved = 0;
  // early_inc_range so instructions can be erased while iterating.
  for (MachineInstr &MI : make_early_inc_range(MBB)) {
    switch (MI.getOpcode()) {
    case AArch64::MSRpstatesvcrImm1:
    case AArch64::MSRpstatePseudo: {
      // Count only changes to SM (not ZA-only changes) towards
      // HasRemovedAllSMChanges.
      if (ChangesStreamingMode(&MI))
        NumSMChanges++;

      if (!Prev)
        Prev = &MI;
      else if (isMatchingStartStopPair(Prev, &MI)) {
        // If they match, we can remove them, and possibly any instructions
        // that we marked for deletion in between.
        Prev->eraseFromParent();
        MI.eraseFromParent();
        for (MachineInstr *TBR : ToBeRemoved)
          TBR->eraseFromParent();
        ToBeRemoved.clear();
        Prev = nullptr;
        Changed = true;
        NumSMChangesRemoved += 2;
      } else {
        // Not a matching pair: restart the search with MI as the new
        // candidate first half.
        Reset();
        Prev = &MI;
      }
      continue;
    }
    default:
      if (!Prev)
        // Avoid doing expensive checks when Prev is nullptr.
        continue;
      break;
    }

    // Test if the instructions in between the start/stop sequence are agnostic
    // of streaming mode. If not, the algorithm should reset.
    switch (MI.getOpcode()) {
    default:
      Reset();
      break;
    case AArch64::COALESCER_BARRIER_FPR16:
    case AArch64::COALESCER_BARRIER_FPR32:
    case AArch64::COALESCER_BARRIER_FPR64:
    case AArch64::COALESCER_BARRIER_FPR128:
    case AArch64::COPY:
      // These instructions should be safe when executed on their own, but
      // the code remains conservative when SVE registers are used. There may
      // exist subtle cases where executing a COPY in a different mode results
      // in different behaviour, even if we can't yet come up with any
      // concrete example/test-case.
      if (isSVERegOp(TRI, MRI, MI.getOperand(0)) ||
          isSVERegOp(TRI, MRI, MI.getOperand(1)))
        Reset();
      break;
    case AArch64::ADJCALLSTACKDOWN:
    case AArch64::ADJCALLSTACKUP:
    case AArch64::ANDXri:
    case AArch64::ADDXri:
      // We permit these as they don't generate SVE/NEON instructions.
      break;
    case AArch64::VGRestorePseudo:
    case AArch64::VGSavePseudo:
      // When the smstart/smstop are removed, we should also remove
      // the pseudos that save/restore the VG value for CFI info.
      ToBeRemoved.push_back(&MI);
      break;
    case AArch64::MSRpstatesvcrImm1:
    case AArch64::MSRpstatePseudo:
      llvm_unreachable("Should have been handled");
    }
  }

  // All SM changes removed only if there was at least one and every one of
  // them was erased above.
  HasRemovedAllSMChanges =
      NumSMChanges && (NumSMChanges == NumSMChangesRemoved);
  return Changed;
}
226 
// Using the FORM_TRANSPOSED_REG_TUPLE pseudo can improve register allocation
// of multi-vector intrinsics. However, the pseudo should only be emitted if
// the input registers of the REG_SEQUENCE are copy nodes where the source
// register is in a StridedOrContiguous class. For example:
//
//   %3:zpr2stridedorcontiguous = LD1B_2Z_IMM_PSEUDO ..
//   %4:zpr = COPY %3.zsub1:zpr2stridedorcontiguous
//   %5:zpr = COPY %3.zsub0:zpr2stridedorcontiguous
//   %6:zpr2stridedorcontiguous = LD1B_2Z_PSEUDO ..
//   %7:zpr = COPY %6.zsub1:zpr2stridedorcontiguous
//   %8:zpr = COPY %6.zsub0:zpr2stridedorcontiguous
//   %9:zpr2mul2 = REG_SEQUENCE %5:zpr, %subreg.zsub0, %8:zpr, %subreg.zsub1
//
//   ->  %9:zpr2mul2 = FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO %5:zpr, %8:zpr
//
bool SMEPeepholeOpt::visitRegSequence(MachineInstr &MI) {
  assert(MI.getMF()->getRegInfo().isSSA() && "Expected to be run on SSA form!");

  // Only rewrite REG_SEQUENCEs producing 2- or 4-vector ZPR tuples.
  MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
  switch (MRI.getRegClass(MI.getOperand(0).getReg())->getID()) {
  case AArch64::ZPR2RegClassID:
  case AArch64::ZPR4RegClassID:
  case AArch64::ZPR2Mul2RegClassID:
  case AArch64::ZPR4Mul4RegClassID:
    break;
  default:
    return false;
  }

  // The first operand is the register class created by the REG_SEQUENCE.
  // Each operand pair after this consists of a vreg + subreg index, so
  // for example a sequence of 2 registers will have a total of 5 operands.
  if (MI.getNumOperands() != 5 && MI.getNumOperands() != 9)
    return false;

  // Every input must be a COPY from the same subregister index of an SSA def
  // in a StridedOrContiguous class; SubReg latches the index from the first
  // input so all later inputs can be compared against it.
  MCRegister SubReg = MCRegister::NoRegister;
  for (unsigned I = 1; I < MI.getNumOperands(); I += 2) {
    MachineOperand &MO = MI.getOperand(I);

    // Input must have a single def which is a COPY.
    MachineOperand *Def = MRI.getOneDef(MO.getReg());
    if (!Def || !Def->getParent()->isCopy())
      return false;

    const MachineOperand &CopySrc = Def->getParent()->getOperand(1);
    unsigned OpSubReg = CopySrc.getSubReg();
    if (SubReg == MCRegister::NoRegister)
      SubReg = OpSubReg;

    // The copy source must itself have a single virtual def, and use the
    // same subreg index as the first input.
    MachineOperand *CopySrcOp = MRI.getOneDef(CopySrc.getReg());
    if (!CopySrcOp || !CopySrcOp->isReg() || OpSubReg != SubReg ||
        CopySrcOp->getReg().isPhysical())
      return false;

    // The underlying tuple must be in a StridedOrContiguous class.
    const TargetRegisterClass *CopySrcClass =
        MRI.getRegClass(CopySrcOp->getReg());
    if (CopySrcClass != &AArch64::ZPR2StridedOrContiguousRegClass &&
        CopySrcClass != &AArch64::ZPR4StridedOrContiguousRegClass)
      return false;
  }

  // 5 operands -> 2-register tuple, 9 operands -> 4-register tuple.
  unsigned Opc = MI.getNumOperands() == 5
                     ? AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO
                     : AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO;

  // Build the pseudo in place of MI, forwarding the input registers (the
  // subreg-index operands are implied by the pseudo), then erase MI.
  const TargetInstrInfo *TII =
      MI.getMF()->getSubtarget<AArch64Subtarget>().getInstrInfo();
  MachineInstrBuilder MIB = BuildMI(*MI.getParent(), MI, MI.getDebugLoc(),
                                    TII->get(Opc), MI.getOperand(0).getReg());
  for (unsigned I = 1; I < MI.getNumOperands(); I += 2)
    MIB.addReg(MI.getOperand(I).getReg());

  MI.eraseFromParent();
  return true;
}
301 
// Register the pass with LLVM's pass registry under the
// "aarch64-sme-peephole-opt" command-line name.
INITIALIZE_PASS(SMEPeepholeOpt, "aarch64-sme-peephole-opt",
                "SME Peephole Optimization", false, false)
304 
runOnMachineFunction(MachineFunction & MF)305 bool SMEPeepholeOpt::runOnMachineFunction(MachineFunction &MF) {
306   if (skipFunction(MF.getFunction()))
307     return false;
308 
309   if (!MF.getSubtarget<AArch64Subtarget>().hasSME())
310     return false;
311 
312   assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
313 
314   bool Changed = false;
315   bool FunctionHasAllSMChangesRemoved = false;
316 
317   // Even if the block lives in a function with no SME attributes attached we
318   // still have to analyze all the blocks because we may call a streaming
319   // function that requires smstart/smstop pairs.
320   for (MachineBasicBlock &MBB : MF) {
321     bool BlockHasAllSMChangesRemoved;
322     Changed |= optimizeStartStopPairs(MBB, BlockHasAllSMChangesRemoved);
323     FunctionHasAllSMChangesRemoved |= BlockHasAllSMChangesRemoved;
324 
325     if (MF.getSubtarget<AArch64Subtarget>().isStreaming()) {
326       for (MachineInstr &MI : make_early_inc_range(MBB))
327         if (MI.getOpcode() == AArch64::REG_SEQUENCE)
328           Changed |= visitRegSequence(MI);
329     }
330   }
331 
332   AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
333   if (FunctionHasAllSMChangesRemoved)
334     AFI->setHasStreamingModeChanges(false);
335 
336   return Changed;
337 }
338 
createSMEPeepholeOptPass()339 FunctionPass *llvm::createSMEPeepholeOptPass() { return new SMEPeepholeOpt(); }
340