//===-- AArch64CondBrTuning.cpp --- Conditional branch tuning for AArch64 -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// This file contains a pass that transforms CBZ/CBNZ/TBZ/TBNZ instructions
/// into a conditional branch (B.cond), when the NZCV flags can be set for
/// "free". This is preferred on targets that have more flexibility when
/// scheduling B.cond instructions as compared to CBZ/CBNZ/TBZ/TBNZ (assuming
/// all other variables are equal). This can also reduce register pressure.
///
/// A few examples:
///
/// 1) add w8, w0, w1   -> cmn w0, w1             ; CMN is an alias of ADDS.
///    cbz w8, .LBB0_2  -> b.eq .LBB0_2
///
/// 2) add w8, w0, w1   -> adds w8, w0, w1        ; w8 has multiple uses.
///    cbz w8, .LBB1_2  -> b.eq .LBB1_2
///
/// 3) sub w8, w0, w1       -> subs w8, w0, w1    ; w8 has multiple uses.
24 /// tbz w8, #31, .LBB6_2 -> b.pl .LBB6_2 25 /// 26 //===----------------------------------------------------------------------===// 27 28 #include "AArch64.h" 29 #include "AArch64Subtarget.h" 30 #include "llvm/CodeGen/MachineFunction.h" 31 #include "llvm/CodeGen/MachineFunctionPass.h" 32 #include "llvm/CodeGen/MachineInstrBuilder.h" 33 #include "llvm/CodeGen/MachineRegisterInfo.h" 34 #include "llvm/CodeGen/Passes.h" 35 #include "llvm/CodeGen/TargetInstrInfo.h" 36 #include "llvm/CodeGen/TargetRegisterInfo.h" 37 #include "llvm/CodeGen/TargetSubtargetInfo.h" 38 #include "llvm/Support/Debug.h" 39 #include "llvm/Support/raw_ostream.h" 40 41 using namespace llvm; 42 43 #define DEBUG_TYPE "aarch64-cond-br-tuning" 44 #define AARCH64_CONDBR_TUNING_NAME "AArch64 Conditional Branch Tuning" 45 46 namespace { 47 class AArch64CondBrTuning : public MachineFunctionPass { 48 const AArch64InstrInfo *TII; 49 const TargetRegisterInfo *TRI; 50 51 MachineRegisterInfo *MRI; 52 53 public: 54 static char ID; 55 AArch64CondBrTuning() : MachineFunctionPass(ID) { 56 initializeAArch64CondBrTuningPass(*PassRegistry::getPassRegistry()); 57 } 58 void getAnalysisUsage(AnalysisUsage &AU) const override; 59 bool runOnMachineFunction(MachineFunction &MF) override; 60 StringRef getPassName() const override { return AARCH64_CONDBR_TUNING_NAME; } 61 62 private: 63 MachineInstr *getOperandDef(const MachineOperand &MO); 64 MachineInstr *convertToFlagSetting(MachineInstr &MI, bool IsFlagSetting); 65 MachineInstr *convertToCondBr(MachineInstr &MI); 66 bool tryToTuneBranch(MachineInstr &MI, MachineInstr &DefMI); 67 }; 68 } // end anonymous namespace 69 70 char AArch64CondBrTuning::ID = 0; 71 72 INITIALIZE_PASS(AArch64CondBrTuning, "aarch64-cond-br-tuning", 73 AARCH64_CONDBR_TUNING_NAME, false, false) 74 75 void AArch64CondBrTuning::getAnalysisUsage(AnalysisUsage &AU) const { 76 AU.setPreservesCFG(); 77 MachineFunctionPass::getAnalysisUsage(AU); 78 } 79 80 MachineInstr 
*AArch64CondBrTuning::getOperandDef(const MachineOperand &MO) { 81 if (!TargetRegisterInfo::isVirtualRegister(MO.getReg())) 82 return nullptr; 83 return MRI->getUniqueVRegDef(MO.getReg()); 84 } 85 86 MachineInstr *AArch64CondBrTuning::convertToFlagSetting(MachineInstr &MI, 87 bool IsFlagSetting) { 88 // If this is already the flag setting version of the instruction (e.g., SUBS) 89 // just make sure the implicit-def of NZCV isn't marked dead. 90 if (IsFlagSetting) { 91 for (unsigned I = MI.getNumExplicitOperands(), E = MI.getNumOperands(); 92 I != E; ++I) { 93 MachineOperand &MO = MI.getOperand(I); 94 if (MO.isReg() && MO.isDead() && MO.getReg() == AArch64::NZCV) 95 MO.setIsDead(false); 96 } 97 return &MI; 98 } 99 bool Is64Bit; 100 unsigned NewOpc = TII->convertToFlagSettingOpc(MI.getOpcode(), Is64Bit); 101 unsigned NewDestReg = MI.getOperand(0).getReg(); 102 if (MRI->hasOneNonDBGUse(MI.getOperand(0).getReg())) 103 NewDestReg = Is64Bit ? AArch64::XZR : AArch64::WZR; 104 105 MachineInstrBuilder MIB = BuildMI(*MI.getParent(), MI, MI.getDebugLoc(), 106 TII->get(NewOpc), NewDestReg); 107 for (unsigned I = 1, E = MI.getNumOperands(); I != E; ++I) 108 MIB.add(MI.getOperand(I)); 109 110 return MIB; 111 } 112 113 MachineInstr *AArch64CondBrTuning::convertToCondBr(MachineInstr &MI) { 114 AArch64CC::CondCode CC; 115 MachineBasicBlock *TargetMBB = TII->getBranchDestBlock(MI); 116 switch (MI.getOpcode()) { 117 default: 118 llvm_unreachable("Unexpected opcode!"); 119 120 case AArch64::CBZW: 121 case AArch64::CBZX: 122 CC = AArch64CC::EQ; 123 break; 124 case AArch64::CBNZW: 125 case AArch64::CBNZX: 126 CC = AArch64CC::NE; 127 break; 128 case AArch64::TBZW: 129 case AArch64::TBZX: 130 CC = AArch64CC::PL; 131 break; 132 case AArch64::TBNZW: 133 case AArch64::TBNZX: 134 CC = AArch64CC::MI; 135 break; 136 } 137 return BuildMI(*MI.getParent(), MI, MI.getDebugLoc(), TII->get(AArch64::Bcc)) 138 .addImm(CC) 139 .addMBB(TargetMBB); 140 } 141 142 bool 
AArch64CondBrTuning::tryToTuneBranch(MachineInstr &MI, 143 MachineInstr &DefMI) { 144 // We don't want NZCV bits live across blocks. 145 if (MI.getParent() != DefMI.getParent()) 146 return false; 147 148 bool IsFlagSetting = true; 149 unsigned MIOpc = MI.getOpcode(); 150 MachineInstr *NewCmp = nullptr, *NewBr = nullptr; 151 switch (DefMI.getOpcode()) { 152 default: 153 return false; 154 case AArch64::ADDWri: 155 case AArch64::ADDWrr: 156 case AArch64::ADDWrs: 157 case AArch64::ADDWrx: 158 case AArch64::ANDWri: 159 case AArch64::ANDWrr: 160 case AArch64::ANDWrs: 161 case AArch64::BICWrr: 162 case AArch64::BICWrs: 163 case AArch64::SUBWri: 164 case AArch64::SUBWrr: 165 case AArch64::SUBWrs: 166 case AArch64::SUBWrx: 167 IsFlagSetting = false; 168 LLVM_FALLTHROUGH; 169 case AArch64::ADDSWri: 170 case AArch64::ADDSWrr: 171 case AArch64::ADDSWrs: 172 case AArch64::ADDSWrx: 173 case AArch64::ANDSWri: 174 case AArch64::ANDSWrr: 175 case AArch64::ANDSWrs: 176 case AArch64::BICSWrr: 177 case AArch64::BICSWrs: 178 case AArch64::SUBSWri: 179 case AArch64::SUBSWrr: 180 case AArch64::SUBSWrs: 181 case AArch64::SUBSWrx: 182 switch (MIOpc) { 183 default: 184 llvm_unreachable("Unexpected opcode!"); 185 186 case AArch64::CBZW: 187 case AArch64::CBNZW: 188 case AArch64::TBZW: 189 case AArch64::TBNZW: 190 // Check to see if the TBZ/TBNZ is checking the sign bit. 191 if ((MIOpc == AArch64::TBZW || MIOpc == AArch64::TBNZW) && 192 MI.getOperand(1).getImm() != 31) 193 return false; 194 195 // There must not be any instruction between DefMI and MI that clobbers or 196 // reads NZCV. 
197 MachineBasicBlock::iterator I(DefMI), E(MI); 198 for (I = std::next(I); I != E; ++I) { 199 if (I->modifiesRegister(AArch64::NZCV, TRI) || 200 I->readsRegister(AArch64::NZCV, TRI)) 201 return false; 202 } 203 LLVM_DEBUG(dbgs() << " Replacing instructions:\n "); 204 LLVM_DEBUG(DefMI.print(dbgs())); 205 LLVM_DEBUG(dbgs() << " "); 206 LLVM_DEBUG(MI.print(dbgs())); 207 208 NewCmp = convertToFlagSetting(DefMI, IsFlagSetting); 209 NewBr = convertToCondBr(MI); 210 break; 211 } 212 break; 213 214 case AArch64::ADDXri: 215 case AArch64::ADDXrr: 216 case AArch64::ADDXrs: 217 case AArch64::ADDXrx: 218 case AArch64::ANDXri: 219 case AArch64::ANDXrr: 220 case AArch64::ANDXrs: 221 case AArch64::BICXrr: 222 case AArch64::BICXrs: 223 case AArch64::SUBXri: 224 case AArch64::SUBXrr: 225 case AArch64::SUBXrs: 226 case AArch64::SUBXrx: 227 IsFlagSetting = false; 228 LLVM_FALLTHROUGH; 229 case AArch64::ADDSXri: 230 case AArch64::ADDSXrr: 231 case AArch64::ADDSXrs: 232 case AArch64::ADDSXrx: 233 case AArch64::ANDSXri: 234 case AArch64::ANDSXrr: 235 case AArch64::ANDSXrs: 236 case AArch64::BICSXrr: 237 case AArch64::BICSXrs: 238 case AArch64::SUBSXri: 239 case AArch64::SUBSXrr: 240 case AArch64::SUBSXrs: 241 case AArch64::SUBSXrx: 242 switch (MIOpc) { 243 default: 244 llvm_unreachable("Unexpected opcode!"); 245 246 case AArch64::CBZX: 247 case AArch64::CBNZX: 248 case AArch64::TBZX: 249 case AArch64::TBNZX: { 250 // Check to see if the TBZ/TBNZ is checking the sign bit. 251 if ((MIOpc == AArch64::TBZX || MIOpc == AArch64::TBNZX) && 252 MI.getOperand(1).getImm() != 63) 253 return false; 254 // There must not be any instruction between DefMI and MI that clobbers or 255 // reads NZCV. 
256 MachineBasicBlock::iterator I(DefMI), E(MI); 257 for (I = std::next(I); I != E; ++I) { 258 if (I->modifiesRegister(AArch64::NZCV, TRI) || 259 I->readsRegister(AArch64::NZCV, TRI)) 260 return false; 261 } 262 LLVM_DEBUG(dbgs() << " Replacing instructions:\n "); 263 LLVM_DEBUG(DefMI.print(dbgs())); 264 LLVM_DEBUG(dbgs() << " "); 265 LLVM_DEBUG(MI.print(dbgs())); 266 267 NewCmp = convertToFlagSetting(DefMI, IsFlagSetting); 268 NewBr = convertToCondBr(MI); 269 break; 270 } 271 } 272 break; 273 } 274 (void)NewCmp; (void)NewBr; 275 assert(NewCmp && NewBr && "Expected new instructions."); 276 277 LLVM_DEBUG(dbgs() << " with instruction:\n "); 278 LLVM_DEBUG(NewCmp->print(dbgs())); 279 LLVM_DEBUG(dbgs() << " "); 280 LLVM_DEBUG(NewBr->print(dbgs())); 281 282 // If this was a flag setting version of the instruction, we use the original 283 // instruction by just clearing the dead marked on the implicit-def of NCZV. 284 // Therefore, we should not erase this instruction. 285 if (!IsFlagSetting) 286 DefMI.eraseFromParent(); 287 MI.eraseFromParent(); 288 return true; 289 } 290 291 bool AArch64CondBrTuning::runOnMachineFunction(MachineFunction &MF) { 292 if (skipFunction(MF.getFunction())) 293 return false; 294 295 LLVM_DEBUG( 296 dbgs() << "********** AArch64 Conditional Branch Tuning **********\n" 297 << "********** Function: " << MF.getName() << '\n'); 298 299 TII = static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo()); 300 TRI = MF.getSubtarget().getRegisterInfo(); 301 MRI = &MF.getRegInfo(); 302 303 bool Changed = false; 304 for (MachineBasicBlock &MBB : MF) { 305 bool LocalChange = false; 306 for (MachineBasicBlock::iterator I = MBB.getFirstTerminator(), 307 E = MBB.end(); 308 I != E; ++I) { 309 MachineInstr &MI = *I; 310 switch (MI.getOpcode()) { 311 default: 312 break; 313 case AArch64::CBZW: 314 case AArch64::CBZX: 315 case AArch64::CBNZW: 316 case AArch64::CBNZX: 317 case AArch64::TBZW: 318 case AArch64::TBZX: 319 case AArch64::TBNZW: 320 case 
AArch64::TBNZX: 321 MachineInstr *DefMI = getOperandDef(MI.getOperand(0)); 322 LocalChange = (DefMI && tryToTuneBranch(MI, *DefMI)); 323 break; 324 } 325 // If the optimization was successful, we can't optimize any other 326 // branches because doing so would clobber the NZCV flags. 327 if (LocalChange) { 328 Changed = true; 329 break; 330 } 331 } 332 } 333 return Changed; 334 } 335 336 FunctionPass *llvm::createAArch64CondBrTuning() { 337 return new AArch64CondBrTuning(); 338 } 339