//===-- AArch64CondBrTuning.cpp --- Conditional branch tuning for AArch64 -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// This file contains a pass that transforms CBZ/CBNZ/TBZ/TBNZ instructions
/// into a conditional branch (B.cond), when the NZCV flags can be set for
/// "free".  This is preferred on targets that have more flexibility when
/// scheduling B.cond instructions as compared to CBZ/CBNZ/TBZ/TBNZ (assuming
/// all other variables are equal).  This can also reduce register pressure.
///
/// A few examples:
///
/// 1) add w8, w0, w1   -> cmn w0, w1            ; CMN is an alias of ADDS.
///    cbz w8, .LBB0_2  -> b.eq .LBB0_2
///
/// 2) add w8, w0, w1   -> adds w8, w0, w1       ; w8 has multiple uses.
///    cbz w8, .LBB1_2  -> b.eq .LBB1_2
///
/// 3) sub w8, w0, w1       -> subs w8, w0, w1   ; w8 has multiple uses.
24 /// tbz w8, #31, .LBB6_2 -> b.pl .LBB6_2 25 /// 26 //===----------------------------------------------------------------------===// 27 28 #include "AArch64.h" 29 #include "AArch64Subtarget.h" 30 #include "llvm/CodeGen/MachineFunction.h" 31 #include "llvm/CodeGen/MachineFunctionPass.h" 32 #include "llvm/CodeGen/MachineInstrBuilder.h" 33 #include "llvm/CodeGen/MachineRegisterInfo.h" 34 #include "llvm/CodeGen/Passes.h" 35 #include "llvm/CodeGen/TargetInstrInfo.h" 36 #include "llvm/CodeGen/TargetRegisterInfo.h" 37 #include "llvm/CodeGen/TargetSubtargetInfo.h" 38 #include "llvm/Support/Debug.h" 39 #include "llvm/Support/raw_ostream.h" 40 41 using namespace llvm; 42 43 #define DEBUG_TYPE "aarch64-cond-br-tuning" 44 #define AARCH64_CONDBR_TUNING_NAME "AArch64 Conditional Branch Tuning" 45 46 namespace { 47 class AArch64CondBrTuning : public MachineFunctionPass { 48 const AArch64InstrInfo *TII; 49 const TargetRegisterInfo *TRI; 50 51 MachineRegisterInfo *MRI; 52 53 public: 54 static char ID; 55 AArch64CondBrTuning() : MachineFunctionPass(ID) { 56 initializeAArch64CondBrTuningPass(*PassRegistry::getPassRegistry()); 57 } 58 void getAnalysisUsage(AnalysisUsage &AU) const override; 59 bool runOnMachineFunction(MachineFunction &MF) override; 60 StringRef getPassName() const override { return AARCH64_CONDBR_TUNING_NAME; } 61 62 private: 63 MachineInstr *getOperandDef(const MachineOperand &MO); 64 MachineInstr *convertToFlagSetting(MachineInstr &MI, bool IsFlagSetting); 65 MachineInstr *convertToCondBr(MachineInstr &MI); 66 bool tryToTuneBranch(MachineInstr &MI, MachineInstr &DefMI); 67 }; 68 } // end anonymous namespace 69 70 char AArch64CondBrTuning::ID = 0; 71 72 INITIALIZE_PASS(AArch64CondBrTuning, "aarch64-cond-br-tuning", 73 AARCH64_CONDBR_TUNING_NAME, false, false) 74 75 void AArch64CondBrTuning::getAnalysisUsage(AnalysisUsage &AU) const { 76 AU.setPreservesCFG(); 77 MachineFunctionPass::getAnalysisUsage(AU); 78 } 79 80 MachineInstr 
*AArch64CondBrTuning::getOperandDef(const MachineOperand &MO) { 81 if (!Register::isVirtualRegister(MO.getReg())) 82 return nullptr; 83 return MRI->getUniqueVRegDef(MO.getReg()); 84 } 85 86 MachineInstr *AArch64CondBrTuning::convertToFlagSetting(MachineInstr &MI, 87 bool IsFlagSetting) { 88 // If this is already the flag setting version of the instruction (e.g., SUBS) 89 // just make sure the implicit-def of NZCV isn't marked dead. 90 if (IsFlagSetting) { 91 for (unsigned I = MI.getNumExplicitOperands(), E = MI.getNumOperands(); 92 I != E; ++I) { 93 MachineOperand &MO = MI.getOperand(I); 94 if (MO.isReg() && MO.isDead() && MO.getReg() == AArch64::NZCV) 95 MO.setIsDead(false); 96 } 97 return &MI; 98 } 99 bool Is64Bit; 100 unsigned NewOpc = TII->convertToFlagSettingOpc(MI.getOpcode(), Is64Bit); 101 Register NewDestReg = MI.getOperand(0).getReg(); 102 if (MRI->hasOneNonDBGUse(MI.getOperand(0).getReg())) 103 NewDestReg = Is64Bit ? AArch64::XZR : AArch64::WZR; 104 105 MachineInstrBuilder MIB = BuildMI(*MI.getParent(), MI, MI.getDebugLoc(), 106 TII->get(NewOpc), NewDestReg); 107 for (unsigned I = 1, E = MI.getNumOperands(); I != E; ++I) 108 MIB.add(MI.getOperand(I)); 109 110 return MIB; 111 } 112 113 MachineInstr *AArch64CondBrTuning::convertToCondBr(MachineInstr &MI) { 114 AArch64CC::CondCode CC; 115 MachineBasicBlock *TargetMBB = TII->getBranchDestBlock(MI); 116 switch (MI.getOpcode()) { 117 default: 118 llvm_unreachable("Unexpected opcode!"); 119 120 case AArch64::CBZW: 121 case AArch64::CBZX: 122 CC = AArch64CC::EQ; 123 break; 124 case AArch64::CBNZW: 125 case AArch64::CBNZX: 126 CC = AArch64CC::NE; 127 break; 128 case AArch64::TBZW: 129 case AArch64::TBZX: 130 CC = AArch64CC::PL; 131 break; 132 case AArch64::TBNZW: 133 case AArch64::TBNZX: 134 CC = AArch64CC::MI; 135 break; 136 } 137 return BuildMI(*MI.getParent(), MI, MI.getDebugLoc(), TII->get(AArch64::Bcc)) 138 .addImm(CC) 139 .addMBB(TargetMBB); 140 } 141 142 bool 
AArch64CondBrTuning::tryToTuneBranch(MachineInstr &MI, 143 MachineInstr &DefMI) { 144 // We don't want NZCV bits live across blocks. 145 if (MI.getParent() != DefMI.getParent()) 146 return false; 147 148 bool IsFlagSetting = true; 149 unsigned MIOpc = MI.getOpcode(); 150 MachineInstr *NewCmp = nullptr, *NewBr = nullptr; 151 switch (DefMI.getOpcode()) { 152 default: 153 return false; 154 case AArch64::ADDWri: 155 case AArch64::ADDWrr: 156 case AArch64::ADDWrs: 157 case AArch64::ADDWrx: 158 case AArch64::ANDWri: 159 case AArch64::ANDWrr: 160 case AArch64::ANDWrs: 161 case AArch64::BICWrr: 162 case AArch64::BICWrs: 163 case AArch64::SUBWri: 164 case AArch64::SUBWrr: 165 case AArch64::SUBWrs: 166 case AArch64::SUBWrx: 167 IsFlagSetting = false; 168 LLVM_FALLTHROUGH; 169 case AArch64::ADDSWri: 170 case AArch64::ADDSWrr: 171 case AArch64::ADDSWrs: 172 case AArch64::ADDSWrx: 173 case AArch64::ANDSWri: 174 case AArch64::ANDSWrr: 175 case AArch64::ANDSWrs: 176 case AArch64::BICSWrr: 177 case AArch64::BICSWrs: 178 case AArch64::SUBSWri: 179 case AArch64::SUBSWrr: 180 case AArch64::SUBSWrs: 181 case AArch64::SUBSWrx: 182 switch (MIOpc) { 183 default: 184 llvm_unreachable("Unexpected opcode!"); 185 186 case AArch64::CBZW: 187 case AArch64::CBNZW: 188 case AArch64::TBZW: 189 case AArch64::TBNZW: 190 // Check to see if the TBZ/TBNZ is checking the sign bit. 191 if ((MIOpc == AArch64::TBZW || MIOpc == AArch64::TBNZW) && 192 MI.getOperand(1).getImm() != 31) 193 return false; 194 195 // There must not be any instruction between DefMI and MI that clobbers or 196 // reads NZCV. 
197 if (isNZCVTouchedInInstructionRange(DefMI, MI, TRI)) 198 return false; 199 LLVM_DEBUG(dbgs() << " Replacing instructions:\n "); 200 LLVM_DEBUG(DefMI.print(dbgs())); 201 LLVM_DEBUG(dbgs() << " "); 202 LLVM_DEBUG(MI.print(dbgs())); 203 204 NewCmp = convertToFlagSetting(DefMI, IsFlagSetting); 205 NewBr = convertToCondBr(MI); 206 break; 207 } 208 break; 209 210 case AArch64::ADDXri: 211 case AArch64::ADDXrr: 212 case AArch64::ADDXrs: 213 case AArch64::ADDXrx: 214 case AArch64::ANDXri: 215 case AArch64::ANDXrr: 216 case AArch64::ANDXrs: 217 case AArch64::BICXrr: 218 case AArch64::BICXrs: 219 case AArch64::SUBXri: 220 case AArch64::SUBXrr: 221 case AArch64::SUBXrs: 222 case AArch64::SUBXrx: 223 IsFlagSetting = false; 224 LLVM_FALLTHROUGH; 225 case AArch64::ADDSXri: 226 case AArch64::ADDSXrr: 227 case AArch64::ADDSXrs: 228 case AArch64::ADDSXrx: 229 case AArch64::ANDSXri: 230 case AArch64::ANDSXrr: 231 case AArch64::ANDSXrs: 232 case AArch64::BICSXrr: 233 case AArch64::BICSXrs: 234 case AArch64::SUBSXri: 235 case AArch64::SUBSXrr: 236 case AArch64::SUBSXrs: 237 case AArch64::SUBSXrx: 238 switch (MIOpc) { 239 default: 240 llvm_unreachable("Unexpected opcode!"); 241 242 case AArch64::CBZX: 243 case AArch64::CBNZX: 244 case AArch64::TBZX: 245 case AArch64::TBNZX: { 246 // Check to see if the TBZ/TBNZ is checking the sign bit. 247 if ((MIOpc == AArch64::TBZX || MIOpc == AArch64::TBNZX) && 248 MI.getOperand(1).getImm() != 63) 249 return false; 250 // There must not be any instruction between DefMI and MI that clobbers or 251 // reads NZCV. 
252 if (isNZCVTouchedInInstructionRange(DefMI, MI, TRI)) 253 return false; 254 LLVM_DEBUG(dbgs() << " Replacing instructions:\n "); 255 LLVM_DEBUG(DefMI.print(dbgs())); 256 LLVM_DEBUG(dbgs() << " "); 257 LLVM_DEBUG(MI.print(dbgs())); 258 259 NewCmp = convertToFlagSetting(DefMI, IsFlagSetting); 260 NewBr = convertToCondBr(MI); 261 break; 262 } 263 } 264 break; 265 } 266 (void)NewCmp; (void)NewBr; 267 assert(NewCmp && NewBr && "Expected new instructions."); 268 269 LLVM_DEBUG(dbgs() << " with instruction:\n "); 270 LLVM_DEBUG(NewCmp->print(dbgs())); 271 LLVM_DEBUG(dbgs() << " "); 272 LLVM_DEBUG(NewBr->print(dbgs())); 273 274 // If this was a flag setting version of the instruction, we use the original 275 // instruction by just clearing the dead marked on the implicit-def of NCZV. 276 // Therefore, we should not erase this instruction. 277 if (!IsFlagSetting) 278 DefMI.eraseFromParent(); 279 MI.eraseFromParent(); 280 return true; 281 } 282 283 bool AArch64CondBrTuning::runOnMachineFunction(MachineFunction &MF) { 284 if (skipFunction(MF.getFunction())) 285 return false; 286 287 LLVM_DEBUG( 288 dbgs() << "********** AArch64 Conditional Branch Tuning **********\n" 289 << "********** Function: " << MF.getName() << '\n'); 290 291 TII = static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo()); 292 TRI = MF.getSubtarget().getRegisterInfo(); 293 MRI = &MF.getRegInfo(); 294 295 bool Changed = false; 296 for (MachineBasicBlock &MBB : MF) { 297 bool LocalChange = false; 298 for (MachineBasicBlock::iterator I = MBB.getFirstTerminator(), 299 E = MBB.end(); 300 I != E; ++I) { 301 MachineInstr &MI = *I; 302 switch (MI.getOpcode()) { 303 default: 304 break; 305 case AArch64::CBZW: 306 case AArch64::CBZX: 307 case AArch64::CBNZW: 308 case AArch64::CBNZX: 309 case AArch64::TBZW: 310 case AArch64::TBZX: 311 case AArch64::TBNZW: 312 case AArch64::TBNZX: 313 MachineInstr *DefMI = getOperandDef(MI.getOperand(0)); 314 LocalChange = (DefMI && tryToTuneBranch(MI, *DefMI)); 315 
break; 316 } 317 // If the optimization was successful, we can't optimize any other 318 // branches because doing so would clobber the NZCV flags. 319 if (LocalChange) { 320 Changed = true; 321 break; 322 } 323 } 324 } 325 return Changed; 326 } 327 328 FunctionPass *llvm::createAArch64CondBrTuning() { 329 return new AArch64CondBrTuning(); 330 } 331