1 //===-- AArch64CondBrTuning.cpp --- Conditional branch tuning for AArch64 -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// \file 9 /// This file contains a pass that transforms CBZ/CBNZ/TBZ/TBNZ instructions 10 /// into a conditional branch (B.cond), when the NZCV flags can be set for 11 /// "free". This is preferred on targets that have more flexibility when 12 /// scheduling B.cond instructions as compared to CBZ/CBNZ/TBZ/TBNZ (assuming 13 /// all other variables are equal). This can also reduce register pressure. 14 /// 15 /// A few examples: 16 /// 17 /// 1) add w8, w0, w1 -> cmn w0, w1 ; CMN is an alias of ADDS. 18 /// cbz w8, .LBB_2 -> b.eq .LBB0_2 19 /// 20 /// 2) add w8, w0, w1 -> adds w8, w0, w1 ; w8 has multiple uses. 21 /// cbz w8, .LBB1_2 -> b.eq .LBB1_2 22 /// 23 /// 3) sub w8, w0, w1 -> subs w8, w0, w1 ; w8 has multiple uses. 24 /// tbz w8, #31, .LBB6_2 -> b.pl .LBB6_2 25 /// 26 //===----------------------------------------------------------------------===// 27 28 #include "AArch64.h" 29 #include "AArch64Subtarget.h" 30 #include "llvm/CodeGen/MachineFunction.h" 31 #include "llvm/CodeGen/MachineFunctionPass.h" 32 #include "llvm/CodeGen/MachineInstrBuilder.h" 33 #include "llvm/CodeGen/MachineRegisterInfo.h" 34 #include "llvm/CodeGen/Passes.h" 35 #include "llvm/CodeGen/TargetInstrInfo.h" 36 #include "llvm/CodeGen/TargetRegisterInfo.h" 37 #include "llvm/CodeGen/TargetSubtargetInfo.h" 38 #include "llvm/Support/Debug.h" 39 #include "llvm/Support/raw_ostream.h" 40 41 using namespace llvm; 42 43 #define DEBUG_TYPE "aarch64-cond-br-tuning" 44 #define AARCH64_CONDBR_TUNING_NAME "AArch64 Conditional Branch Tuning" 45 46 namespace { 47 class AArch64CondBrTuning : public MachineFunctionPass { 48 const AArch64InstrInfo *TII; 49 const TargetRegisterInfo *TRI; 50 51 MachineRegisterInfo *MRI; 52 53 public: 54 static char ID; 55 AArch64CondBrTuning() : MachineFunctionPass(ID) { 56 initializeAArch64CondBrTuningPass(*PassRegistry::getPassRegistry()); 57 } 58 void getAnalysisUsage(AnalysisUsage &AU) const override; 59 bool runOnMachineFunction(MachineFunction &MF) override; 60 StringRef getPassName() const override { return AARCH64_CONDBR_TUNING_NAME; } 61 62 private: 63 MachineInstr *getOperandDef(const MachineOperand &MO); 64 MachineInstr *convertToFlagSetting(MachineInstr &MI, bool IsFlagSetting, 65 bool Is64Bit); 66 MachineInstr *convertToCondBr(MachineInstr &MI); 67 bool tryToTuneBranch(MachineInstr &MI, MachineInstr &DefMI); 68 }; 69 } // end anonymous namespace 70 71 char AArch64CondBrTuning::ID = 0; 72 73 INITIALIZE_PASS(AArch64CondBrTuning, "aarch64-cond-br-tuning", 74 AARCH64_CONDBR_TUNING_NAME, false, false) 75 76 void AArch64CondBrTuning::getAnalysisUsage(AnalysisUsage &AU) const { 77 AU.setPreservesCFG(); 78 MachineFunctionPass::getAnalysisUsage(AU); 79 } 80 81 MachineInstr *AArch64CondBrTuning::getOperandDef(const MachineOperand &MO) { 82 if (!MO.getReg().isVirtual()) 83 return nullptr; 84 return MRI->getUniqueVRegDef(MO.getReg()); 85 } 86 87 MachineInstr *AArch64CondBrTuning::convertToFlagSetting(MachineInstr &MI, 88 bool IsFlagSetting, 89 bool Is64Bit) { 90 // If this is already the flag setting version of the instruction (e.g., SUBS) 91 // just make sure the implicit-def of NZCV isn't marked dead. 92 if (IsFlagSetting) { 93 for (MachineOperand &MO : MI.implicit_operands()) 94 if (MO.isReg() && MO.isDead() && MO.getReg() == AArch64::NZCV) 95 MO.setIsDead(false); 96 return &MI; 97 } 98 unsigned NewOpc = TII->convertToFlagSettingOpc(MI.getOpcode()); 99 Register NewDestReg = MI.getOperand(0).getReg(); 100 if (MRI->hasOneNonDBGUse(MI.getOperand(0).getReg())) 101 NewDestReg = Is64Bit ? AArch64::XZR : AArch64::WZR; 102 103 MachineInstrBuilder MIB = BuildMI(*MI.getParent(), MI, MI.getDebugLoc(), 104 TII->get(NewOpc), NewDestReg); 105 for (const MachineOperand &MO : llvm::drop_begin(MI.operands())) 106 MIB.add(MO); 107 108 return MIB; 109 } 110 111 MachineInstr *AArch64CondBrTuning::convertToCondBr(MachineInstr &MI) { 112 AArch64CC::CondCode CC; 113 MachineBasicBlock *TargetMBB = TII->getBranchDestBlock(MI); 114 switch (MI.getOpcode()) { 115 default: 116 llvm_unreachable("Unexpected opcode!"); 117 118 case AArch64::CBZW: 119 case AArch64::CBZX: 120 CC = AArch64CC::EQ; 121 break; 122 case AArch64::CBNZW: 123 case AArch64::CBNZX: 124 CC = AArch64CC::NE; 125 break; 126 case AArch64::TBZW: 127 case AArch64::TBZX: 128 CC = AArch64CC::PL; 129 break; 130 case AArch64::TBNZW: 131 case AArch64::TBNZX: 132 CC = AArch64CC::MI; 133 break; 134 } 135 return BuildMI(*MI.getParent(), MI, MI.getDebugLoc(), TII->get(AArch64::Bcc)) 136 .addImm(CC) 137 .addMBB(TargetMBB); 138 } 139 140 bool AArch64CondBrTuning::tryToTuneBranch(MachineInstr &MI, 141 MachineInstr &DefMI) { 142 // We don't want NZCV bits live across blocks. 143 if (MI.getParent() != DefMI.getParent()) 144 return false; 145 146 bool IsFlagSetting = true; 147 unsigned MIOpc = MI.getOpcode(); 148 MachineInstr *NewCmp = nullptr, *NewBr = nullptr; 149 switch (DefMI.getOpcode()) { 150 default: 151 return false; 152 case AArch64::ADDWri: 153 case AArch64::ADDWrr: 154 case AArch64::ADDWrs: 155 case AArch64::ADDWrx: 156 case AArch64::ANDWri: 157 case AArch64::ANDWrr: 158 case AArch64::ANDWrs: 159 case AArch64::BICWrr: 160 case AArch64::BICWrs: 161 case AArch64::SUBWri: 162 case AArch64::SUBWrr: 163 case AArch64::SUBWrs: 164 case AArch64::SUBWrx: 165 IsFlagSetting = false; 166 [[fallthrough]]; 167 case AArch64::ADDSWri: 168 case AArch64::ADDSWrr: 169 case AArch64::ADDSWrs: 170 case AArch64::ADDSWrx: 171 case AArch64::ANDSWri: 172 case AArch64::ANDSWrr: 173 case AArch64::ANDSWrs: 174 case AArch64::BICSWrr: 175 case AArch64::BICSWrs: 176 case AArch64::SUBSWri: 177 case AArch64::SUBSWrr: 178 case AArch64::SUBSWrs: 179 case AArch64::SUBSWrx: 180 switch (MIOpc) { 181 default: 182 llvm_unreachable("Unexpected opcode!"); 183 184 case AArch64::CBZW: 185 case AArch64::CBNZW: 186 case AArch64::TBZW: 187 case AArch64::TBNZW: 188 // Check to see if the TBZ/TBNZ is checking the sign bit. 189 if ((MIOpc == AArch64::TBZW || MIOpc == AArch64::TBNZW) && 190 MI.getOperand(1).getImm() != 31) 191 return false; 192 193 // There must not be any instruction between DefMI and MI that clobbers or 194 // reads NZCV. 195 if (isNZCVTouchedInInstructionRange(DefMI, MI, TRI)) 196 return false; 197 LLVM_DEBUG(dbgs() << " Replacing instructions:\n "); 198 LLVM_DEBUG(DefMI.print(dbgs())); 199 LLVM_DEBUG(dbgs() << " "); 200 LLVM_DEBUG(MI.print(dbgs())); 201 202 NewCmp = convertToFlagSetting(DefMI, IsFlagSetting, /*Is64Bit=*/false); 203 NewBr = convertToCondBr(MI); 204 break; 205 } 206 break; 207 208 case AArch64::ADDXri: 209 case AArch64::ADDXrr: 210 case AArch64::ADDXrs: 211 case AArch64::ADDXrx: 212 case AArch64::ANDXri: 213 case AArch64::ANDXrr: 214 case AArch64::ANDXrs: 215 case AArch64::BICXrr: 216 case AArch64::BICXrs: 217 case AArch64::SUBXri: 218 case AArch64::SUBXrr: 219 case AArch64::SUBXrs: 220 case AArch64::SUBXrx: 221 IsFlagSetting = false; 222 [[fallthrough]]; 223 case AArch64::ADDSXri: 224 case AArch64::ADDSXrr: 225 case AArch64::ADDSXrs: 226 case AArch64::ADDSXrx: 227 case AArch64::ANDSXri: 228 case AArch64::ANDSXrr: 229 case AArch64::ANDSXrs: 230 case AArch64::BICSXrr: 231 case AArch64::BICSXrs: 232 case AArch64::SUBSXri: 233 case AArch64::SUBSXrr: 234 case AArch64::SUBSXrs: 235 case AArch64::SUBSXrx: 236 switch (MIOpc) { 237 default: 238 llvm_unreachable("Unexpected opcode!"); 239 240 case AArch64::CBZX: 241 case AArch64::CBNZX: 242 case AArch64::TBZX: 243 case AArch64::TBNZX: { 244 // Check to see if the TBZ/TBNZ is checking the sign bit. 245 if ((MIOpc == AArch64::TBZX || MIOpc == AArch64::TBNZX) && 246 MI.getOperand(1).getImm() != 63) 247 return false; 248 // There must not be any instruction between DefMI and MI that clobbers or 249 // reads NZCV. 250 if (isNZCVTouchedInInstructionRange(DefMI, MI, TRI)) 251 return false; 252 LLVM_DEBUG(dbgs() << " Replacing instructions:\n "); 253 LLVM_DEBUG(DefMI.print(dbgs())); 254 LLVM_DEBUG(dbgs() << " "); 255 LLVM_DEBUG(MI.print(dbgs())); 256 257 NewCmp = convertToFlagSetting(DefMI, IsFlagSetting, /*Is64Bit=*/true); 258 NewBr = convertToCondBr(MI); 259 break; 260 } 261 } 262 break; 263 } 264 (void)NewCmp; (void)NewBr; 265 assert(NewCmp && NewBr && "Expected new instructions."); 266 267 LLVM_DEBUG(dbgs() << " with instruction:\n "); 268 LLVM_DEBUG(NewCmp->print(dbgs())); 269 LLVM_DEBUG(dbgs() << " "); 270 LLVM_DEBUG(NewBr->print(dbgs())); 271 272 // If this was a flag setting version of the instruction, we use the original 273 // instruction by just clearing the dead marked on the implicit-def of NCZV. 274 // Therefore, we should not erase this instruction. 275 if (!IsFlagSetting) 276 DefMI.eraseFromParent(); 277 MI.eraseFromParent(); 278 return true; 279 } 280 281 bool AArch64CondBrTuning::runOnMachineFunction(MachineFunction &MF) { 282 if (skipFunction(MF.getFunction())) 283 return false; 284 285 LLVM_DEBUG( 286 dbgs() << "********** AArch64 Conditional Branch Tuning **********\n" 287 << "********** Function: " << MF.getName() << '\n'); 288 289 TII = static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo()); 290 TRI = MF.getSubtarget().getRegisterInfo(); 291 MRI = &MF.getRegInfo(); 292 293 bool Changed = false; 294 for (MachineBasicBlock &MBB : MF) { 295 bool LocalChange = false; 296 for (MachineInstr &MI : MBB.terminators()) { 297 switch (MI.getOpcode()) { 298 default: 299 break; 300 case AArch64::CBZW: 301 case AArch64::CBZX: 302 case AArch64::CBNZW: 303 case AArch64::CBNZX: 304 case AArch64::TBZW: 305 case AArch64::TBZX: 306 case AArch64::TBNZW: 307 case AArch64::TBNZX: 308 MachineInstr *DefMI = getOperandDef(MI.getOperand(0)); 309 LocalChange = (DefMI && tryToTuneBranch(MI, *DefMI)); 310 break; 311 } 312 // If the optimization was successful, we can't optimize any other 313 // branches because doing so would clobber the NZCV flags. 314 if (LocalChange) { 315 Changed = true; 316 break; 317 } 318 } 319 } 320 return Changed; 321 } 322 323 FunctionPass *llvm::createAArch64CondBrTuning() { 324 return new AArch64CondBrTuning(); 325 } 326