1 //===-- AArch64CondBrTuning.cpp --- Conditional branch tuning for AArch64 -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// \file 9 /// This file contains a pass that transforms CBZ/CBNZ/TBZ/TBNZ instructions 10 /// into a conditional branch (B.cond), when the NZCV flags can be set for 11 /// "free". This is preferred on targets that have more flexibility when 12 /// scheduling B.cond instructions as compared to CBZ/CBNZ/TBZ/TBNZ (assuming 13 /// all other variables are equal). This can also reduce register pressure. 14 /// 15 /// A few examples: 16 /// 17 /// 1) add w8, w0, w1 -> cmn w0, w1 ; CMN is an alias of ADDS. 18 /// cbz w8, .LBB_2 -> b.eq .LBB0_2 19 /// 20 /// 2) add w8, w0, w1 -> adds w8, w0, w1 ; w8 has multiple uses. 21 /// cbz w8, .LBB1_2 -> b.eq .LBB1_2 22 /// 23 /// 3) sub w8, w0, w1 -> subs w8, w0, w1 ; w8 has multiple uses. 24 /// tbz w8, #31, .LBB6_2 -> b.pl .LBB6_2 25 /// 26 //===----------------------------------------------------------------------===// 27 28 #include "AArch64.h" 29 #include "AArch64Subtarget.h" 30 #include "llvm/CodeGen/MachineFunction.h" 31 #include "llvm/CodeGen/MachineFunctionPass.h" 32 #include "llvm/CodeGen/MachineInstrBuilder.h" 33 #include "llvm/CodeGen/MachineRegisterInfo.h" 34 #include "llvm/CodeGen/Passes.h" 35 #include "llvm/CodeGen/TargetInstrInfo.h" 36 #include "llvm/CodeGen/TargetRegisterInfo.h" 37 #include "llvm/CodeGen/TargetSubtargetInfo.h" 38 #include "llvm/Support/Debug.h" 39 #include "llvm/Support/raw_ostream.h" 40 41 using namespace llvm; 42 43 #define DEBUG_TYPE "aarch64-cond-br-tuning" 44 #define AARCH64_CONDBR_TUNING_NAME "AArch64 Conditional Branch Tuning" 45 46 namespace { 47 class AArch64CondBrTuning : public MachineFunctionPass { 48 const AArch64InstrInfo *TII; 49 const TargetRegisterInfo *TRI; 50 51 MachineRegisterInfo *MRI; 52 53 public: 54 static char ID; 55 AArch64CondBrTuning() : MachineFunctionPass(ID) {} 56 void getAnalysisUsage(AnalysisUsage &AU) const override; 57 bool runOnMachineFunction(MachineFunction &MF) override; 58 StringRef getPassName() const override { return AARCH64_CONDBR_TUNING_NAME; } 59 60 private: 61 MachineInstr *getOperandDef(const MachineOperand &MO); 62 MachineInstr *convertToFlagSetting(MachineInstr &MI, bool IsFlagSetting, 63 bool Is64Bit); 64 MachineInstr *convertToCondBr(MachineInstr &MI); 65 bool tryToTuneBranch(MachineInstr &MI, MachineInstr &DefMI); 66 }; 67 } // end anonymous namespace 68 69 char AArch64CondBrTuning::ID = 0; 70 71 INITIALIZE_PASS(AArch64CondBrTuning, "aarch64-cond-br-tuning", 72 AARCH64_CONDBR_TUNING_NAME, false, false) 73 74 void AArch64CondBrTuning::getAnalysisUsage(AnalysisUsage &AU) const { 75 AU.setPreservesCFG(); 76 MachineFunctionPass::getAnalysisUsage(AU); 77 } 78 79 MachineInstr *AArch64CondBrTuning::getOperandDef(const MachineOperand &MO) { 80 if (!MO.getReg().isVirtual()) 81 return nullptr; 82 return MRI->getUniqueVRegDef(MO.getReg()); 83 } 84 85 MachineInstr *AArch64CondBrTuning::convertToFlagSetting(MachineInstr &MI, 86 bool IsFlagSetting, 87 bool Is64Bit) { 88 // If this is already the flag setting version of the instruction (e.g., SUBS) 89 // just make sure the implicit-def of NZCV isn't marked dead. 90 if (IsFlagSetting) { 91 for (MachineOperand &MO : MI.implicit_operands()) 92 if (MO.isReg() && MO.isDead() && MO.getReg() == AArch64::NZCV) 93 MO.setIsDead(false); 94 return &MI; 95 } 96 unsigned NewOpc = TII->convertToFlagSettingOpc(MI.getOpcode()); 97 Register NewDestReg = MI.getOperand(0).getReg(); 98 if (MRI->hasOneNonDBGUse(MI.getOperand(0).getReg())) 99 NewDestReg = Is64Bit ? AArch64::XZR : AArch64::WZR; 100 101 MachineInstrBuilder MIB = BuildMI(*MI.getParent(), MI, MI.getDebugLoc(), 102 TII->get(NewOpc), NewDestReg); 103 104 // If the MI has a debug instruction number, preserve that in the new Machine 105 // Instruction that is created. 106 if (MI.peekDebugInstrNum() != 0) 107 MIB->setDebugInstrNum(MI.peekDebugInstrNum()); 108 109 for (const MachineOperand &MO : llvm::drop_begin(MI.operands())) 110 MIB.add(MO); 111 112 return MIB; 113 } 114 115 MachineInstr *AArch64CondBrTuning::convertToCondBr(MachineInstr &MI) { 116 AArch64CC::CondCode CC; 117 MachineBasicBlock *TargetMBB = TII->getBranchDestBlock(MI); 118 switch (MI.getOpcode()) { 119 default: 120 llvm_unreachable("Unexpected opcode!"); 121 122 case AArch64::CBZW: 123 case AArch64::CBZX: 124 CC = AArch64CC::EQ; 125 break; 126 case AArch64::CBNZW: 127 case AArch64::CBNZX: 128 CC = AArch64CC::NE; 129 break; 130 case AArch64::TBZW: 131 case AArch64::TBZX: 132 CC = AArch64CC::PL; 133 break; 134 case AArch64::TBNZW: 135 case AArch64::TBNZX: 136 CC = AArch64CC::MI; 137 break; 138 } 139 return BuildMI(*MI.getParent(), MI, MI.getDebugLoc(), TII->get(AArch64::Bcc)) 140 .addImm(CC) 141 .addMBB(TargetMBB); 142 } 143 144 bool AArch64CondBrTuning::tryToTuneBranch(MachineInstr &MI, 145 MachineInstr &DefMI) { 146 // We don't want NZCV bits live across blocks. 147 if (MI.getParent() != DefMI.getParent()) 148 return false; 149 150 bool IsFlagSetting = true; 151 unsigned MIOpc = MI.getOpcode(); 152 MachineInstr *NewCmp = nullptr, *NewBr = nullptr; 153 switch (DefMI.getOpcode()) { 154 default: 155 return false; 156 case AArch64::ADDWri: 157 case AArch64::ADDWrr: 158 case AArch64::ADDWrs: 159 case AArch64::ADDWrx: 160 case AArch64::ANDWri: 161 case AArch64::ANDWrr: 162 case AArch64::ANDWrs: 163 case AArch64::BICWrr: 164 case AArch64::BICWrs: 165 case AArch64::SUBWri: 166 case AArch64::SUBWrr: 167 case AArch64::SUBWrs: 168 case AArch64::SUBWrx: 169 IsFlagSetting = false; 170 [[fallthrough]]; 171 case AArch64::ADDSWri: 172 case AArch64::ADDSWrr: 173 case AArch64::ADDSWrs: 174 case AArch64::ADDSWrx: 175 case AArch64::ANDSWri: 176 case AArch64::ANDSWrr: 177 case AArch64::ANDSWrs: 178 case AArch64::BICSWrr: 179 case AArch64::BICSWrs: 180 case AArch64::SUBSWri: 181 case AArch64::SUBSWrr: 182 case AArch64::SUBSWrs: 183 case AArch64::SUBSWrx: 184 switch (MIOpc) { 185 default: 186 llvm_unreachable("Unexpected opcode!"); 187 188 case AArch64::CBZW: 189 case AArch64::CBNZW: 190 case AArch64::TBZW: 191 case AArch64::TBNZW: 192 // Check to see if the TBZ/TBNZ is checking the sign bit. 193 if ((MIOpc == AArch64::TBZW || MIOpc == AArch64::TBNZW) && 194 MI.getOperand(1).getImm() != 31) 195 return false; 196 197 // There must not be any instruction between DefMI and MI that clobbers or 198 // reads NZCV. 199 if (isNZCVTouchedInInstructionRange(DefMI, MI, TRI)) 200 return false; 201 LLVM_DEBUG(dbgs() << " Replacing instructions:\n "); 202 LLVM_DEBUG(DefMI.print(dbgs())); 203 LLVM_DEBUG(dbgs() << " "); 204 LLVM_DEBUG(MI.print(dbgs())); 205 206 NewCmp = convertToFlagSetting(DefMI, IsFlagSetting, /*Is64Bit=*/false); 207 NewBr = convertToCondBr(MI); 208 break; 209 } 210 break; 211 212 case AArch64::ADDXri: 213 case AArch64::ADDXrr: 214 case AArch64::ADDXrs: 215 case AArch64::ADDXrx: 216 case AArch64::ANDXri: 217 case AArch64::ANDXrr: 218 case AArch64::ANDXrs: 219 case AArch64::BICXrr: 220 case AArch64::BICXrs: 221 case AArch64::SUBXri: 222 case AArch64::SUBXrr: 223 case AArch64::SUBXrs: 224 case AArch64::SUBXrx: 225 IsFlagSetting = false; 226 [[fallthrough]]; 227 case AArch64::ADDSXri: 228 case AArch64::ADDSXrr: 229 case AArch64::ADDSXrs: 230 case AArch64::ADDSXrx: 231 case AArch64::ANDSXri: 232 case AArch64::ANDSXrr: 233 case AArch64::ANDSXrs: 234 case AArch64::BICSXrr: 235 case AArch64::BICSXrs: 236 case AArch64::SUBSXri: 237 case AArch64::SUBSXrr: 238 case AArch64::SUBSXrs: 239 case AArch64::SUBSXrx: 240 switch (MIOpc) { 241 default: 242 llvm_unreachable("Unexpected opcode!"); 243 244 case AArch64::CBZX: 245 case AArch64::CBNZX: 246 case AArch64::TBZX: 247 case AArch64::TBNZX: { 248 // Check to see if the TBZ/TBNZ is checking the sign bit. 249 if ((MIOpc == AArch64::TBZX || MIOpc == AArch64::TBNZX) && 250 MI.getOperand(1).getImm() != 63) 251 return false; 252 // There must not be any instruction between DefMI and MI that clobbers or 253 // reads NZCV. 254 if (isNZCVTouchedInInstructionRange(DefMI, MI, TRI)) 255 return false; 256 LLVM_DEBUG(dbgs() << " Replacing instructions:\n "); 257 LLVM_DEBUG(DefMI.print(dbgs())); 258 LLVM_DEBUG(dbgs() << " "); 259 LLVM_DEBUG(MI.print(dbgs())); 260 261 NewCmp = convertToFlagSetting(DefMI, IsFlagSetting, /*Is64Bit=*/true); 262 NewBr = convertToCondBr(MI); 263 break; 264 } 265 } 266 break; 267 } 268 (void)NewCmp; (void)NewBr; 269 assert(NewCmp && NewBr && "Expected new instructions."); 270 271 LLVM_DEBUG(dbgs() << " with instruction:\n "); 272 LLVM_DEBUG(NewCmp->print(dbgs())); 273 LLVM_DEBUG(dbgs() << " "); 274 LLVM_DEBUG(NewBr->print(dbgs())); 275 276 // If this was a flag setting version of the instruction, we use the original 277 // instruction by just clearing the dead marked on the implicit-def of NCZV. 278 // Therefore, we should not erase this instruction. 279 if (!IsFlagSetting) 280 DefMI.eraseFromParent(); 281 MI.eraseFromParent(); 282 return true; 283 } 284 285 bool AArch64CondBrTuning::runOnMachineFunction(MachineFunction &MF) { 286 if (skipFunction(MF.getFunction())) 287 return false; 288 289 LLVM_DEBUG( 290 dbgs() << "********** AArch64 Conditional Branch Tuning **********\n" 291 << "********** Function: " << MF.getName() << '\n'); 292 293 TII = static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo()); 294 TRI = MF.getSubtarget().getRegisterInfo(); 295 MRI = &MF.getRegInfo(); 296 297 bool Changed = false; 298 for (MachineBasicBlock &MBB : MF) { 299 bool LocalChange = false; 300 for (MachineInstr &MI : MBB.terminators()) { 301 switch (MI.getOpcode()) { 302 default: 303 break; 304 case AArch64::CBZW: 305 case AArch64::CBZX: 306 case AArch64::CBNZW: 307 case AArch64::CBNZX: 308 case AArch64::TBZW: 309 case AArch64::TBZX: 310 case AArch64::TBNZW: 311 case AArch64::TBNZX: 312 MachineInstr *DefMI = getOperandDef(MI.getOperand(0)); 313 LocalChange = (DefMI && tryToTuneBranch(MI, *DefMI)); 314 break; 315 } 316 // If the optimization was successful, we can't optimize any other 317 // branches because doing so would clobber the NZCV flags. 318 if (LocalChange) { 319 Changed = true; 320 break; 321 } 322 } 323 } 324 return Changed; 325 } 326 327 FunctionPass *llvm::createAArch64CondBrTuning() { 328 return new AArch64CondBrTuning(); 329 } 330