//===- SMEPeepholeOpt.cpp - SME peephole optimization pass-----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This pass tries to remove back-to-back (smstart, smstop) and
// (smstop, smstart) sequences. The pass is conservative when it cannot
// determine that it is safe to remove these sequences.
//===----------------------------------------------------------------------===//

#include "AArch64InstrInfo.h"
#include "AArch64MachineFunctionInfo.h"
#include "AArch64Subtarget.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"

using namespace llvm;

#define DEBUG_TYPE "aarch64-sme-peephole-opt"

namespace {

struct SMEPeepholeOpt : public MachineFunctionPass {
  static char ID;

  SMEPeepholeOpt() : MachineFunctionPass(ID) {}

  bool runOnMachineFunction(MachineFunction &MF) override;

  StringRef getPassName() const override {
    return "SME Peephole Optimization pass";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // This pass only removes/rewrites instructions within blocks; it never
    // modifies the CFG.
    AU.setPreservesCFG();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  /// Scan \p MBB for adjacent (modulo a small set of permitted intervening
  /// instructions) smstart/smstop pairs that cancel out and erase them.
  /// \p HasRemovedAllSMChanges is set to true when the block contained at
  /// least one streaming-mode change and all of them were removed.
  /// Returns true if any instruction was erased.
  bool optimizeStartStopPairs(MachineBasicBlock &MBB,
                              bool &HasRemovedAllSMChanges) const;

  /// If \p MI is a REG_SEQUENCE whose inputs are copies out of
  /// ZPR{2,4}StridedOrContiguous tuples, replace it with a
  /// FORM_TRANSPOSED_REG_TUPLE pseudo (see comment above the definition).
  /// Returns true if \p MI was replaced.
  bool visitRegSequence(MachineInstr &MI);
};

char SMEPeepholeOpt::ID = 0;

} // end anonymous namespace

// Returns true for the conditional form of a streaming-mode change, i.e. the
// MSRpstatePseudo that only switches mode depending on the current pstate.sm
// value (as opposed to the unconditional MSRpstatesvcrImm1).
static bool isConditionalStartStop(const MachineInstr *MI) {
  return MI->getOpcode() == AArch64::MSRpstatePseudo;
}

// Returns true if \p MI1 and \p MI2 form a cancelling pair of mode changes:
// same SVCR target, opposite start/stop direction, and (for the conditional
// form) identical condition operands.
static bool isMatchingStartStopPair(const MachineInstr *MI1,
                                    const MachineInstr *MI2) {
  // We only consider the same type of streaming mode change here, i.e.
  // start/stop SM, or start/stop ZA pairs.
  if (MI1->getOperand(0).getImm() != MI2->getOperand(0).getImm())
    return false;

  // One must be 'start', the other must be 'stop'
  if (MI1->getOperand(1).getImm() == MI2->getOperand(1).getImm())
    return false;

  bool IsConditional = isConditionalStartStop(MI2);
  if (isConditionalStartStop(MI1) != IsConditional)
    return false;

  if (!IsConditional)
    return true;

  // Check to make sure the conditional start/stop pairs are identical.
  if (MI1->getOperand(2).getImm() != MI2->getOperand(2).getImm())
    return false;

  // Ensure reg masks are identical. Comparing the RegMask pointers is
  // conservative: equal pointers guarantee equal masks, while distinct
  // pointers simply fail the match.
  if (MI1->getOperand(4).getRegMask() != MI2->getOperand(4).getRegMask())
    return false;

  // This optimisation is unlikely to happen in practice for conditional
  // smstart/smstop pairs as the virtual registers for pstate.sm will always
  // be different.
  // TODO: For this optimisation to apply to conditional smstart/smstop,
  // this pass will need to do more work to remove redundant calls to
  // __arm_sme_state.

  // Only consider conditional start/stop pairs which read the same register
  // holding the original value of pstate.sm, as some conditional start/stops
  // require the state on entry to the function.
  if (MI1->getOperand(3).isReg() && MI2->getOperand(3).isReg()) {
    Register Reg1 = MI1->getOperand(3).getReg();
    Register Reg2 = MI2->getOperand(3).getReg();
    if (Reg1.isPhysical() || Reg2.isPhysical() || Reg1 != Reg2)
      return false;
  }

  return true;
}

// Returns true if \p MI changes the streaming-SVE mode (PSTATE.SM), either
// alone (SVCRSM) or together with ZA (SVCRSMZA). A ZA-only change returns
// false.
static bool ChangesStreamingMode(const MachineInstr *MI) {
  assert((MI->getOpcode() == AArch64::MSRpstatesvcrImm1 ||
          MI->getOpcode() == AArch64::MSRpstatePseudo) &&
         "Expected MI to be a smstart/smstop instruction");
  return MI->getOperand(0).getImm() == AArch64SVCR::SVCRSM ||
         MI->getOperand(0).getImm() == AArch64SVCR::SVCRSMZA;
}

// Returns true if operand \p MO is (or may alias) an SVE vector or predicate
// register, for either physical or virtual registers.
static bool isSVERegOp(const TargetRegisterInfo &TRI,
                       const MachineRegisterInfo &MRI,
                       const MachineOperand &MO) {
  if (!MO.isReg())
    return false;

  Register R = MO.getReg();
  if (R.isPhysical())
    // Check the register and all of its sub-registers so that e.g. a Z
    // register is caught via its overlapping smaller registers too.
    return llvm::any_of(TRI.subregs_inclusive(R), [](const MCPhysReg &SR) {
      return AArch64::ZPRRegClass.contains(SR) ||
             AArch64::PPRRegClass.contains(SR);
    });

  const TargetRegisterClass *RC = MRI.getRegClass(R);
  return TRI.getCommonSubClass(&AArch64::ZPRRegClass, RC) ||
         TRI.getCommonSubClass(&AArch64::PPRRegClass, RC);
}

bool SMEPeepholeOpt::optimizeStartStopPairs(
    MachineBasicBlock &MBB, bool &HasRemovedAllSMChanges) const {
  const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
  const TargetRegisterInfo &TRI =
      *MBB.getParent()->getSubtarget().getRegisterInfo();

  bool Changed = false;
  // The candidate start/stop instruction we are trying to pair up, if any.
  MachineInstr *Prev = nullptr;
  // VG save/restore pseudos seen between Prev and the current instruction;
  // these become dead if the pair is removed.
  SmallVector<MachineInstr *, 4> ToBeRemoved;

  // Convenience function to reset the matching of a sequence.
  auto Reset = [&]() {
    Prev = nullptr;
    ToBeRemoved.clear();
  };

  // Walk through instructions in the block trying to find pairs of smstart
  // and smstop nodes that cancel each other out. We only permit a limited
  // set of instructions to appear between them, otherwise we reset our
  // tracking.
  unsigned NumSMChanges = 0;
  unsigned NumSMChangesRemoved = 0;
  for (MachineInstr &MI : make_early_inc_range(MBB)) {
    switch (MI.getOpcode()) {
    case AArch64::MSRpstatesvcrImm1:
    case AArch64::MSRpstatePseudo: {
      if (ChangesStreamingMode(&MI))
        NumSMChanges++;

      if (!Prev)
        Prev = &MI;
      else if (isMatchingStartStopPair(Prev, &MI)) {
        // If they match, we can remove them, and possibly any instructions
        // that we marked for deletion in between.
        Prev->eraseFromParent();
        MI.eraseFromParent();
        for (MachineInstr *TBR : ToBeRemoved)
          TBR->eraseFromParent();
        ToBeRemoved.clear();
        Prev = nullptr;
        Changed = true;
        NumSMChangesRemoved += 2;
      } else {
        Reset();
        // This unmatched start/stop may itself pair with a later one.
        Prev = &MI;
      }
      continue;
    }
    default:
      if (!Prev)
        // Avoid doing expensive checks when Prev is nullptr.
        continue;
      break;
    }

    // Test if the instructions in between the start/stop sequence are agnostic
    // of streaming mode. If not, the algorithm should reset.
    switch (MI.getOpcode()) {
    default:
      Reset();
      break;
    case AArch64::COALESCER_BARRIER_FPR16:
    case AArch64::COALESCER_BARRIER_FPR32:
    case AArch64::COALESCER_BARRIER_FPR64:
    case AArch64::COALESCER_BARRIER_FPR128:
    case AArch64::COPY:
      // These instructions should be safe when executed on their own, but
      // the code remains conservative when SVE registers are used. There may
      // exist subtle cases where executing a COPY in a different mode results
      // in different behaviour, even if we can't yet come up with any
      // concrete example/test-case.
      if (isSVERegOp(TRI, MRI, MI.getOperand(0)) ||
          isSVERegOp(TRI, MRI, MI.getOperand(1)))
        Reset();
      break;
    case AArch64::ADJCALLSTACKDOWN:
    case AArch64::ADJCALLSTACKUP:
    case AArch64::ANDXri:
    case AArch64::ADDXri:
      // We permit these as they don't generate SVE/NEON instructions.
      break;
    case AArch64::VGRestorePseudo:
    case AArch64::VGSavePseudo:
      // When the smstart/smstop are removed, we should also remove
      // the pseudos that save/restore the VG value for CFI info.
      ToBeRemoved.push_back(&MI);
      break;
    case AArch64::MSRpstatesvcrImm1:
    case AArch64::MSRpstatePseudo:
      llvm_unreachable("Should have been handled");
    }
  }

  // All mode changes removed only if there was at least one and the counts
  // balance out.
  HasRemovedAllSMChanges =
      NumSMChanges && (NumSMChanges == NumSMChangesRemoved);
  return Changed;
}

// Using the FORM_TRANSPOSED_REG_TUPLE pseudo can improve register allocation
// of multi-vector intrinsics. However, the pseudo should only be emitted if
// the input registers of the REG_SEQUENCE are copy nodes where the source
// register is in a StridedOrContiguous class. For example:
//
// %3:zpr2stridedorcontiguous = LD1B_2Z_IMM_PSEUDO ..
// %4:zpr = COPY %3.zsub1:zpr2stridedorcontiguous
// %5:zpr = COPY %3.zsub0:zpr2stridedorcontiguous
// %6:zpr2stridedorcontiguous = LD1B_2Z_PSEUDO ..
// %7:zpr = COPY %6.zsub1:zpr2stridedorcontiguous
// %8:zpr = COPY %6.zsub0:zpr2stridedorcontiguous
// %9:zpr2mul2 = REG_SEQUENCE %5:zpr, %subreg.zsub0, %8:zpr, %subreg.zsub1
//
// -> %9:zpr2mul2 = FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO %5:zpr, %8:zpr
//
bool SMEPeepholeOpt::visitRegSequence(MachineInstr &MI) {
  assert(MI.getMF()->getRegInfo().isSSA() && "Expected to be run on SSA form!");

  // Only rewrite REG_SEQUENCEs that define one of the ZPR tuple classes.
  MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
  switch (MRI.getRegClass(MI.getOperand(0).getReg())->getID()) {
  case AArch64::ZPR2RegClassID:
  case AArch64::ZPR4RegClassID:
  case AArch64::ZPR2Mul2RegClassID:
  case AArch64::ZPR4Mul4RegClassID:
    break;
  default:
    return false;
  }

  // The first operand is the register class created by the REG_SEQUENCE.
  // Each operand pair after this consists of a vreg + subreg index, so
  // for example a sequence of 2 registers will have a total of 5 operands.
  if (MI.getNumOperands() != 5 && MI.getNumOperands() != 9)
    return false;

  // Every input must be a COPY out of a ZPR{2,4}StridedOrContiguous tuple,
  // and all copies must read the same subregister index of their source.
  MCRegister SubReg = MCRegister::NoRegister;
  for (unsigned I = 1; I < MI.getNumOperands(); I += 2) {
    MachineOperand &MO = MI.getOperand(I);

    // SSA form guarantees at most one def; require exactly one and that it
    // is a COPY.
    MachineOperand *Def = MRI.getOneDef(MO.getReg());
    if (!Def || !Def->getParent()->isCopy())
      return false;

    const MachineOperand &CopySrc = Def->getParent()->getOperand(1);
    unsigned OpSubReg = CopySrc.getSubReg();
    if (SubReg == MCRegister::NoRegister)
      SubReg = OpSubReg; // First copy establishes the expected subreg index.

    MachineOperand *CopySrcOp = MRI.getOneDef(CopySrc.getReg());
    if (!CopySrcOp || !CopySrcOp->isReg() || OpSubReg != SubReg ||
        CopySrcOp->getReg().isPhysical())
      return false;

    const TargetRegisterClass *CopySrcClass =
        MRI.getRegClass(CopySrcOp->getReg());
    if (CopySrcClass != &AArch64::ZPR2StridedOrContiguousRegClass &&
        CopySrcClass != &AArch64::ZPR4StridedOrContiguousRegClass)
      return false;
  }

  // 5 operands => 2-register tuple, 9 operands => 4-register tuple.
  unsigned Opc = MI.getNumOperands() == 5
                     ? AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO
                     : AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO;

  const TargetInstrInfo *TII =
      MI.getMF()->getSubtarget<AArch64Subtarget>().getInstrInfo();
  MachineInstrBuilder MIB = BuildMI(*MI.getParent(), MI, MI.getDebugLoc(),
                                    TII->get(Opc), MI.getOperand(0).getReg());
  // The pseudo takes the tuple elements as plain register operands; the
  // subreg indices of the REG_SEQUENCE are implied by operand order.
  for (unsigned I = 1; I < MI.getNumOperands(); I += 2)
    MIB.addReg(MI.getOperand(I).getReg());

  MI.eraseFromParent();
  return true;
}

INITIALIZE_PASS(SMEPeepholeOpt, "aarch64-sme-peephole-opt",
                "SME Peephole Optimization", false, false)

bool SMEPeepholeOpt::runOnMachineFunction(MachineFunction &MF) {
  if (skipFunction(MF.getFunction()))
    return false;

  // Nothing to do on targets without SME.
  if (!MF.getSubtarget<AArch64Subtarget>().hasSME())
    return false;

  assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");

  bool Changed = false;
  bool FunctionHasAllSMChangesRemoved = false;

  // Even if the block lives in a function with no SME attributes attached we
  // still have to analyze all the blocks because we may call a streaming
  // function that requires smstart/smstop pairs.
  for (MachineBasicBlock &MBB : MF) {
    bool BlockHasAllSMChangesRemoved;
    Changed |= optimizeStartStopPairs(MBB, BlockHasAllSMChangesRemoved);
    FunctionHasAllSMChangesRemoved |= BlockHasAllSMChangesRemoved;

    // The FORM_TRANSPOSED_REG_TUPLE rewrite only applies to streaming
    // functions (multi-vector SME intrinsics).
    if (MF.getSubtarget<AArch64Subtarget>().isStreaming()) {
      for (MachineInstr &MI : make_early_inc_range(MBB))
        if (MI.getOpcode() == AArch64::REG_SEQUENCE)
          Changed |= visitRegSequence(MI);
    }
  }

  // If no streaming-mode changes survive anywhere in the function, record
  // that so later passes (e.g. frame lowering for VG/CFI) can rely on it.
  AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
  if (FunctionHasAllSMChangesRemoved)
    AFI->setHasStreamingModeChanges(false);

  return Changed;
}

FunctionPass *llvm::createSMEPeepholeOptPass() { return new SMEPeepholeOpt(); }