//===-- AArch64CondBrTuning.cpp --- Conditional branch tuning for AArch64 -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// This file contains a pass that transforms CBZ/CBNZ/TBZ/TBNZ instructions
/// into a conditional branch (B.cond), when the NZCV flags can be set for
/// "free". This is preferred on targets that have more flexibility when
/// scheduling B.cond instructions as compared to CBZ/CBNZ/TBZ/TBNZ (assuming
/// all other variables are equal). This can also reduce register pressure.
///
/// A few examples:
///
/// 1) add w8, w0, w1   -> cmn w0, w1             ; CMN is an alias of ADDS.
///    cbz w8, .LBB0_2  -> b.eq .LBB0_2
///
/// 2) add w8, w0, w1   -> adds w8, w0, w1        ; w8 has multiple uses.
///    cbz w8, .LBB1_2  -> b.eq .LBB1_2
///
/// 3) sub w8, w0, w1   -> subs w8, w0, w1        ; w8 has multiple uses.
24 /// tbz w8, #31, .LBB6_2 -> b.pl .LBB6_2 25 /// 26 //===----------------------------------------------------------------------===// 27 28 #include "AArch64.h" 29 #include "AArch64Subtarget.h" 30 #include "llvm/CodeGen/MachineFunction.h" 31 #include "llvm/CodeGen/MachineFunctionPass.h" 32 #include "llvm/CodeGen/MachineInstrBuilder.h" 33 #include "llvm/CodeGen/MachineRegisterInfo.h" 34 #include "llvm/CodeGen/Passes.h" 35 #include "llvm/CodeGen/TargetInstrInfo.h" 36 #include "llvm/CodeGen/TargetRegisterInfo.h" 37 #include "llvm/CodeGen/TargetSubtargetInfo.h" 38 #include "llvm/Support/Debug.h" 39 #include "llvm/Support/raw_ostream.h" 40 41 using namespace llvm; 42 43 #define DEBUG_TYPE "aarch64-cond-br-tuning" 44 #define AARCH64_CONDBR_TUNING_NAME "AArch64 Conditional Branch Tuning" 45 46 namespace { 47 class AArch64CondBrTuning : public MachineFunctionPass { 48 const AArch64InstrInfo *TII; 49 const TargetRegisterInfo *TRI; 50 51 MachineRegisterInfo *MRI; 52 53 public: 54 static char ID; 55 AArch64CondBrTuning() : MachineFunctionPass(ID) { 56 initializeAArch64CondBrTuningPass(*PassRegistry::getPassRegistry()); 57 } 58 void getAnalysisUsage(AnalysisUsage &AU) const override; 59 bool runOnMachineFunction(MachineFunction &MF) override; 60 StringRef getPassName() const override { return AARCH64_CONDBR_TUNING_NAME; } 61 62 private: 63 MachineInstr *getOperandDef(const MachineOperand &MO); 64 MachineInstr *convertToFlagSetting(MachineInstr &MI, bool IsFlagSetting); 65 MachineInstr *convertToCondBr(MachineInstr &MI); 66 bool tryToTuneBranch(MachineInstr &MI, MachineInstr &DefMI); 67 }; 68 } // end anonymous namespace 69 70 char AArch64CondBrTuning::ID = 0; 71 72 INITIALIZE_PASS(AArch64CondBrTuning, "aarch64-cond-br-tuning", 73 AARCH64_CONDBR_TUNING_NAME, false, false) 74 75 void AArch64CondBrTuning::getAnalysisUsage(AnalysisUsage &AU) const { 76 AU.setPreservesCFG(); 77 MachineFunctionPass::getAnalysisUsage(AU); 78 } 79 80 MachineInstr 
*AArch64CondBrTuning::getOperandDef(const MachineOperand &MO) { 81 if (!Register::isVirtualRegister(MO.getReg())) 82 return nullptr; 83 return MRI->getUniqueVRegDef(MO.getReg()); 84 } 85 86 MachineInstr *AArch64CondBrTuning::convertToFlagSetting(MachineInstr &MI, 87 bool IsFlagSetting) { 88 // If this is already the flag setting version of the instruction (e.g., SUBS) 89 // just make sure the implicit-def of NZCV isn't marked dead. 90 if (IsFlagSetting) { 91 for (MachineOperand &MO : MI.implicit_operands()) 92 if (MO.isReg() && MO.isDead() && MO.getReg() == AArch64::NZCV) 93 MO.setIsDead(false); 94 return &MI; 95 } 96 bool Is64Bit; 97 unsigned NewOpc = TII->convertToFlagSettingOpc(MI.getOpcode(), Is64Bit); 98 Register NewDestReg = MI.getOperand(0).getReg(); 99 if (MRI->hasOneNonDBGUse(MI.getOperand(0).getReg())) 100 NewDestReg = Is64Bit ? AArch64::XZR : AArch64::WZR; 101 102 MachineInstrBuilder MIB = BuildMI(*MI.getParent(), MI, MI.getDebugLoc(), 103 TII->get(NewOpc), NewDestReg); 104 for (const MachineOperand &MO : llvm::drop_begin(MI.operands())) 105 MIB.add(MO); 106 107 return MIB; 108 } 109 110 MachineInstr *AArch64CondBrTuning::convertToCondBr(MachineInstr &MI) { 111 AArch64CC::CondCode CC; 112 MachineBasicBlock *TargetMBB = TII->getBranchDestBlock(MI); 113 switch (MI.getOpcode()) { 114 default: 115 llvm_unreachable("Unexpected opcode!"); 116 117 case AArch64::CBZW: 118 case AArch64::CBZX: 119 CC = AArch64CC::EQ; 120 break; 121 case AArch64::CBNZW: 122 case AArch64::CBNZX: 123 CC = AArch64CC::NE; 124 break; 125 case AArch64::TBZW: 126 case AArch64::TBZX: 127 CC = AArch64CC::PL; 128 break; 129 case AArch64::TBNZW: 130 case AArch64::TBNZX: 131 CC = AArch64CC::MI; 132 break; 133 } 134 return BuildMI(*MI.getParent(), MI, MI.getDebugLoc(), TII->get(AArch64::Bcc)) 135 .addImm(CC) 136 .addMBB(TargetMBB); 137 } 138 139 bool AArch64CondBrTuning::tryToTuneBranch(MachineInstr &MI, 140 MachineInstr &DefMI) { 141 // We don't want NZCV bits live across blocks. 
142 if (MI.getParent() != DefMI.getParent()) 143 return false; 144 145 bool IsFlagSetting = true; 146 unsigned MIOpc = MI.getOpcode(); 147 MachineInstr *NewCmp = nullptr, *NewBr = nullptr; 148 switch (DefMI.getOpcode()) { 149 default: 150 return false; 151 case AArch64::ADDWri: 152 case AArch64::ADDWrr: 153 case AArch64::ADDWrs: 154 case AArch64::ADDWrx: 155 case AArch64::ANDWri: 156 case AArch64::ANDWrr: 157 case AArch64::ANDWrs: 158 case AArch64::BICWrr: 159 case AArch64::BICWrs: 160 case AArch64::SUBWri: 161 case AArch64::SUBWrr: 162 case AArch64::SUBWrs: 163 case AArch64::SUBWrx: 164 IsFlagSetting = false; 165 LLVM_FALLTHROUGH; 166 case AArch64::ADDSWri: 167 case AArch64::ADDSWrr: 168 case AArch64::ADDSWrs: 169 case AArch64::ADDSWrx: 170 case AArch64::ANDSWri: 171 case AArch64::ANDSWrr: 172 case AArch64::ANDSWrs: 173 case AArch64::BICSWrr: 174 case AArch64::BICSWrs: 175 case AArch64::SUBSWri: 176 case AArch64::SUBSWrr: 177 case AArch64::SUBSWrs: 178 case AArch64::SUBSWrx: 179 switch (MIOpc) { 180 default: 181 llvm_unreachable("Unexpected opcode!"); 182 183 case AArch64::CBZW: 184 case AArch64::CBNZW: 185 case AArch64::TBZW: 186 case AArch64::TBNZW: 187 // Check to see if the TBZ/TBNZ is checking the sign bit. 188 if ((MIOpc == AArch64::TBZW || MIOpc == AArch64::TBNZW) && 189 MI.getOperand(1).getImm() != 31) 190 return false; 191 192 // There must not be any instruction between DefMI and MI that clobbers or 193 // reads NZCV. 
194 if (isNZCVTouchedInInstructionRange(DefMI, MI, TRI)) 195 return false; 196 LLVM_DEBUG(dbgs() << " Replacing instructions:\n "); 197 LLVM_DEBUG(DefMI.print(dbgs())); 198 LLVM_DEBUG(dbgs() << " "); 199 LLVM_DEBUG(MI.print(dbgs())); 200 201 NewCmp = convertToFlagSetting(DefMI, IsFlagSetting); 202 NewBr = convertToCondBr(MI); 203 break; 204 } 205 break; 206 207 case AArch64::ADDXri: 208 case AArch64::ADDXrr: 209 case AArch64::ADDXrs: 210 case AArch64::ADDXrx: 211 case AArch64::ANDXri: 212 case AArch64::ANDXrr: 213 case AArch64::ANDXrs: 214 case AArch64::BICXrr: 215 case AArch64::BICXrs: 216 case AArch64::SUBXri: 217 case AArch64::SUBXrr: 218 case AArch64::SUBXrs: 219 case AArch64::SUBXrx: 220 IsFlagSetting = false; 221 LLVM_FALLTHROUGH; 222 case AArch64::ADDSXri: 223 case AArch64::ADDSXrr: 224 case AArch64::ADDSXrs: 225 case AArch64::ADDSXrx: 226 case AArch64::ANDSXri: 227 case AArch64::ANDSXrr: 228 case AArch64::ANDSXrs: 229 case AArch64::BICSXrr: 230 case AArch64::BICSXrs: 231 case AArch64::SUBSXri: 232 case AArch64::SUBSXrr: 233 case AArch64::SUBSXrs: 234 case AArch64::SUBSXrx: 235 switch (MIOpc) { 236 default: 237 llvm_unreachable("Unexpected opcode!"); 238 239 case AArch64::CBZX: 240 case AArch64::CBNZX: 241 case AArch64::TBZX: 242 case AArch64::TBNZX: { 243 // Check to see if the TBZ/TBNZ is checking the sign bit. 244 if ((MIOpc == AArch64::TBZX || MIOpc == AArch64::TBNZX) && 245 MI.getOperand(1).getImm() != 63) 246 return false; 247 // There must not be any instruction between DefMI and MI that clobbers or 248 // reads NZCV. 
249 if (isNZCVTouchedInInstructionRange(DefMI, MI, TRI)) 250 return false; 251 LLVM_DEBUG(dbgs() << " Replacing instructions:\n "); 252 LLVM_DEBUG(DefMI.print(dbgs())); 253 LLVM_DEBUG(dbgs() << " "); 254 LLVM_DEBUG(MI.print(dbgs())); 255 256 NewCmp = convertToFlagSetting(DefMI, IsFlagSetting); 257 NewBr = convertToCondBr(MI); 258 break; 259 } 260 } 261 break; 262 } 263 (void)NewCmp; (void)NewBr; 264 assert(NewCmp && NewBr && "Expected new instructions."); 265 266 LLVM_DEBUG(dbgs() << " with instruction:\n "); 267 LLVM_DEBUG(NewCmp->print(dbgs())); 268 LLVM_DEBUG(dbgs() << " "); 269 LLVM_DEBUG(NewBr->print(dbgs())); 270 271 // If this was a flag setting version of the instruction, we use the original 272 // instruction by just clearing the dead marked on the implicit-def of NCZV. 273 // Therefore, we should not erase this instruction. 274 if (!IsFlagSetting) 275 DefMI.eraseFromParent(); 276 MI.eraseFromParent(); 277 return true; 278 } 279 280 bool AArch64CondBrTuning::runOnMachineFunction(MachineFunction &MF) { 281 if (skipFunction(MF.getFunction())) 282 return false; 283 284 LLVM_DEBUG( 285 dbgs() << "********** AArch64 Conditional Branch Tuning **********\n" 286 << "********** Function: " << MF.getName() << '\n'); 287 288 TII = static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo()); 289 TRI = MF.getSubtarget().getRegisterInfo(); 290 MRI = &MF.getRegInfo(); 291 292 bool Changed = false; 293 for (MachineBasicBlock &MBB : MF) { 294 bool LocalChange = false; 295 for (MachineInstr &MI : MBB.terminators()) { 296 switch (MI.getOpcode()) { 297 default: 298 break; 299 case AArch64::CBZW: 300 case AArch64::CBZX: 301 case AArch64::CBNZW: 302 case AArch64::CBNZX: 303 case AArch64::TBZW: 304 case AArch64::TBZX: 305 case AArch64::TBNZW: 306 case AArch64::TBNZX: 307 MachineInstr *DefMI = getOperandDef(MI.getOperand(0)); 308 LocalChange = (DefMI && tryToTuneBranch(MI, *DefMI)); 309 break; 310 } 311 // If the optimization was successful, we can't optimize any 
other 312 // branches because doing so would clobber the NZCV flags. 313 if (LocalChange) { 314 Changed = true; 315 break; 316 } 317 } 318 } 319 return Changed; 320 } 321 322 FunctionPass *llvm::createAArch64CondBrTuning() { 323 return new AArch64CondBrTuning(); 324 } 325