1 //===- AArch64MacroFusion.cpp - AArch64 Macro Fusion ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file This file contains the AArch64 implementation of the DAG scheduling 10 /// mutation to pair instructions back to back. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "AArch64MacroFusion.h" 15 #include "AArch64Subtarget.h" 16 #include "llvm/CodeGen/MacroFusion.h" 17 #include "llvm/CodeGen/TargetInstrInfo.h" 18 19 using namespace llvm; 20 21 /// CMN, CMP, TST followed by Bcc 22 static bool isArithmeticBccPair(const MachineInstr *FirstMI, 23 const MachineInstr &SecondMI, bool CmpOnly) { 24 if (SecondMI.getOpcode() != AArch64::Bcc) 25 return false; 26 27 // Assume the 1st instr to be a wildcard if it is unspecified. 28 if (FirstMI == nullptr) 29 return true; 30 31 // If we're in CmpOnly mode, we only fuse arithmetic instructions that 32 // discard their result. 33 if (CmpOnly && FirstMI->getOperand(0).isReg() && 34 !(FirstMI->getOperand(0).getReg() == AArch64::XZR || 35 FirstMI->getOperand(0).getReg() == AArch64::WZR)) { 36 return false; 37 } 38 39 switch (FirstMI->getOpcode()) { 40 case AArch64::ADDSWri: 41 case AArch64::ADDSWrr: 42 case AArch64::ADDSXri: 43 case AArch64::ADDSXrr: 44 case AArch64::ANDSWri: 45 case AArch64::ANDSWrr: 46 case AArch64::ANDSXri: 47 case AArch64::ANDSXrr: 48 case AArch64::SUBSWri: 49 case AArch64::SUBSWrr: 50 case AArch64::SUBSXri: 51 case AArch64::SUBSXrr: 52 case AArch64::BICSWrr: 53 case AArch64::BICSXrr: 54 return true; 55 case AArch64::ADDSWrs: 56 case AArch64::ADDSXrs: 57 case AArch64::ANDSWrs: 58 case AArch64::ANDSXrs: 59 case AArch64::SUBSWrs: 60 case AArch64::SUBSXrs: 61 case AArch64::BICSWrs: 62 case AArch64::BICSXrs: 63 // Shift value can be 0 making these behave like the "rr" variant... 64 return !AArch64InstrInfo::hasShiftedReg(*FirstMI); 65 } 66 67 return false; 68 } 69 70 /// ALU operations followed by CBZ/CBNZ. 71 static bool isArithmeticCbzPair(const MachineInstr *FirstMI, 72 const MachineInstr &SecondMI) { 73 if (SecondMI.getOpcode() != AArch64::CBZW && 74 SecondMI.getOpcode() != AArch64::CBZX && 75 SecondMI.getOpcode() != AArch64::CBNZW && 76 SecondMI.getOpcode() != AArch64::CBNZX) 77 return false; 78 79 // Assume the 1st instr to be a wildcard if it is unspecified. 80 if (FirstMI == nullptr) 81 return true; 82 83 switch (FirstMI->getOpcode()) { 84 case AArch64::ADDWri: 85 case AArch64::ADDWrr: 86 case AArch64::ADDXri: 87 case AArch64::ADDXrr: 88 case AArch64::ANDWri: 89 case AArch64::ANDWrr: 90 case AArch64::ANDXri: 91 case AArch64::ANDXrr: 92 case AArch64::EORWri: 93 case AArch64::EORWrr: 94 case AArch64::EORXri: 95 case AArch64::EORXrr: 96 case AArch64::ORRWri: 97 case AArch64::ORRWrr: 98 case AArch64::ORRXri: 99 case AArch64::ORRXrr: 100 case AArch64::SUBWri: 101 case AArch64::SUBWrr: 102 case AArch64::SUBXri: 103 case AArch64::SUBXrr: 104 return true; 105 case AArch64::ADDWrs: 106 case AArch64::ADDXrs: 107 case AArch64::ANDWrs: 108 case AArch64::ANDXrs: 109 case AArch64::SUBWrs: 110 case AArch64::SUBXrs: 111 case AArch64::BICWrs: 112 case AArch64::BICXrs: 113 // Shift value can be 0 making these behave like the "rr" variant... 114 return !AArch64InstrInfo::hasShiftedReg(*FirstMI); 115 } 116 117 return false; 118 } 119 120 /// AES crypto encoding or decoding. 121 static bool isAESPair(const MachineInstr *FirstMI, 122 const MachineInstr &SecondMI) { 123 // Assume the 1st instr to be a wildcard if it is unspecified. 124 switch (SecondMI.getOpcode()) { 125 // AES encode. 126 case AArch64::AESMCrr: 127 case AArch64::AESMCrrTied: 128 return FirstMI == nullptr || FirstMI->getOpcode() == AArch64::AESErr; 129 // AES decode. 130 case AArch64::AESIMCrr: 131 case AArch64::AESIMCrrTied: 132 return FirstMI == nullptr || FirstMI->getOpcode() == AArch64::AESDrr; 133 } 134 135 return false; 136 } 137 138 /// AESE/AESD/PMULL + EOR. 139 static bool isCryptoEORPair(const MachineInstr *FirstMI, 140 const MachineInstr &SecondMI) { 141 if (SecondMI.getOpcode() != AArch64::EORv16i8) 142 return false; 143 144 // Assume the 1st instr to be a wildcard if it is unspecified. 145 if (FirstMI == nullptr) 146 return true; 147 148 switch (FirstMI->getOpcode()) { 149 case AArch64::AESErr: 150 case AArch64::AESDrr: 151 case AArch64::PMULLv16i8: 152 case AArch64::PMULLv8i8: 153 case AArch64::PMULLv1i64: 154 case AArch64::PMULLv2i64: 155 return true; 156 } 157 158 return false; 159 } 160 161 static bool isAdrpAddPair(const MachineInstr *FirstMI, 162 const MachineInstr &SecondMI) { 163 // Assume the 1st instr to be a wildcard if it is unspecified. 164 if ((FirstMI == nullptr || FirstMI->getOpcode() == AArch64::ADRP) && 165 SecondMI.getOpcode() == AArch64::ADDXri) 166 return true; 167 return false; 168 } 169 170 /// Literal generation. 171 static bool isLiteralsPair(const MachineInstr *FirstMI, 172 const MachineInstr &SecondMI) { 173 // Assume the 1st instr to be a wildcard if it is unspecified. 174 // 32 bit immediate. 175 if ((FirstMI == nullptr || FirstMI->getOpcode() == AArch64::MOVZWi) && 176 (SecondMI.getOpcode() == AArch64::MOVKWi && 177 SecondMI.getOperand(3).getImm() == 16)) 178 return true; 179 180 // Lower half of 64 bit immediate. 181 if((FirstMI == nullptr || FirstMI->getOpcode() == AArch64::MOVZXi) && 182 (SecondMI.getOpcode() == AArch64::MOVKXi && 183 SecondMI.getOperand(3).getImm() == 16)) 184 return true; 185 186 // Upper half of 64 bit immediate. 187 if ((FirstMI == nullptr || 188 (FirstMI->getOpcode() == AArch64::MOVKXi && 189 FirstMI->getOperand(3).getImm() == 32)) && 190 (SecondMI.getOpcode() == AArch64::MOVKXi && 191 SecondMI.getOperand(3).getImm() == 48)) 192 return true; 193 194 return false; 195 } 196 197 /// Fuse address generation and loads or stores. 198 static bool isAddressLdStPair(const MachineInstr *FirstMI, 199 const MachineInstr &SecondMI) { 200 switch (SecondMI.getOpcode()) { 201 case AArch64::STRBBui: 202 case AArch64::STRBui: 203 case AArch64::STRDui: 204 case AArch64::STRHHui: 205 case AArch64::STRHui: 206 case AArch64::STRQui: 207 case AArch64::STRSui: 208 case AArch64::STRWui: 209 case AArch64::STRXui: 210 case AArch64::LDRBBui: 211 case AArch64::LDRBui: 212 case AArch64::LDRDui: 213 case AArch64::LDRHHui: 214 case AArch64::LDRHui: 215 case AArch64::LDRQui: 216 case AArch64::LDRSui: 217 case AArch64::LDRWui: 218 case AArch64::LDRXui: 219 case AArch64::LDRSBWui: 220 case AArch64::LDRSBXui: 221 case AArch64::LDRSHWui: 222 case AArch64::LDRSHXui: 223 case AArch64::LDRSWui: 224 // Assume the 1st instr to be a wildcard if it is unspecified. 225 if (FirstMI == nullptr) 226 return true; 227 228 switch (FirstMI->getOpcode()) { 229 case AArch64::ADR: 230 return SecondMI.getOperand(2).getImm() == 0; 231 case AArch64::ADRP: 232 return true; 233 } 234 } 235 236 return false; 237 } 238 239 /// Compare and conditional select. 240 static bool isCCSelectPair(const MachineInstr *FirstMI, 241 const MachineInstr &SecondMI) { 242 // 32 bits 243 if (SecondMI.getOpcode() == AArch64::CSELWr) { 244 // Assume the 1st instr to be a wildcard if it is unspecified. 245 if (FirstMI == nullptr) 246 return true; 247 248 if (FirstMI->definesRegister(AArch64::WZR, /*TRI=*/nullptr)) 249 switch (FirstMI->getOpcode()) { 250 case AArch64::SUBSWrs: 251 return !AArch64InstrInfo::hasShiftedReg(*FirstMI); 252 case AArch64::SUBSWrx: 253 return !AArch64InstrInfo::hasExtendedReg(*FirstMI); 254 case AArch64::SUBSWrr: 255 case AArch64::SUBSWri: 256 return true; 257 } 258 } 259 260 // 64 bits 261 if (SecondMI.getOpcode() == AArch64::CSELXr) { 262 // Assume the 1st instr to be a wildcard if it is unspecified. 263 if (FirstMI == nullptr) 264 return true; 265 266 if (FirstMI->definesRegister(AArch64::XZR, /*TRI=*/nullptr)) 267 switch (FirstMI->getOpcode()) { 268 case AArch64::SUBSXrs: 269 return !AArch64InstrInfo::hasShiftedReg(*FirstMI); 270 case AArch64::SUBSXrx: 271 case AArch64::SUBSXrx64: 272 return !AArch64InstrInfo::hasExtendedReg(*FirstMI); 273 case AArch64::SUBSXrr: 274 case AArch64::SUBSXri: 275 return true; 276 } 277 } 278 279 return false; 280 } 281 282 // Arithmetic and logic. 283 static bool isArithmeticLogicPair(const MachineInstr *FirstMI, 284 const MachineInstr &SecondMI) { 285 if (AArch64InstrInfo::hasShiftedReg(SecondMI)) 286 return false; 287 288 switch (SecondMI.getOpcode()) { 289 // Arithmetic 290 case AArch64::ADDWrr: 291 case AArch64::ADDXrr: 292 case AArch64::SUBWrr: 293 case AArch64::SUBXrr: 294 case AArch64::ADDWrs: 295 case AArch64::ADDXrs: 296 case AArch64::SUBWrs: 297 case AArch64::SUBXrs: 298 // Logic 299 case AArch64::ANDWrr: 300 case AArch64::ANDXrr: 301 case AArch64::BICWrr: 302 case AArch64::BICXrr: 303 case AArch64::EONWrr: 304 case AArch64::EONXrr: 305 case AArch64::EORWrr: 306 case AArch64::EORXrr: 307 case AArch64::ORNWrr: 308 case AArch64::ORNXrr: 309 case AArch64::ORRWrr: 310 case AArch64::ORRXrr: 311 case AArch64::ANDWrs: 312 case AArch64::ANDXrs: 313 case AArch64::BICWrs: 314 case AArch64::BICXrs: 315 case AArch64::EONWrs: 316 case AArch64::EONXrs: 317 case AArch64::EORWrs: 318 case AArch64::EORXrs: 319 case AArch64::ORNWrs: 320 case AArch64::ORNXrs: 321 case AArch64::ORRWrs: 322 case AArch64::ORRXrs: 323 // Assume the 1st instr to be a wildcard if it is unspecified. 324 if (FirstMI == nullptr) 325 return true; 326 327 // Arithmetic 328 switch (FirstMI->getOpcode()) { 329 case AArch64::ADDWrr: 330 case AArch64::ADDXrr: 331 case AArch64::ADDSWrr: 332 case AArch64::ADDSXrr: 333 case AArch64::SUBWrr: 334 case AArch64::SUBXrr: 335 case AArch64::SUBSWrr: 336 case AArch64::SUBSXrr: 337 return true; 338 case AArch64::ADDWrs: 339 case AArch64::ADDXrs: 340 case AArch64::ADDSWrs: 341 case AArch64::ADDSXrs: 342 case AArch64::SUBWrs: 343 case AArch64::SUBXrs: 344 case AArch64::SUBSWrs: 345 case AArch64::SUBSXrs: 346 return !AArch64InstrInfo::hasShiftedReg(*FirstMI); 347 } 348 break; 349 350 // Arithmetic, setting flags. 351 case AArch64::ADDSWrr: 352 case AArch64::ADDSXrr: 353 case AArch64::SUBSWrr: 354 case AArch64::SUBSXrr: 355 case AArch64::ADDSWrs: 356 case AArch64::ADDSXrs: 357 case AArch64::SUBSWrs: 358 case AArch64::SUBSXrs: 359 // Assume the 1st instr to be a wildcard if it is unspecified. 360 if (FirstMI == nullptr) 361 return true; 362 363 // Arithmetic, not setting flags. 364 switch (FirstMI->getOpcode()) { 365 case AArch64::ADDWrr: 366 case AArch64::ADDXrr: 367 case AArch64::SUBWrr: 368 case AArch64::SUBXrr: 369 return true; 370 case AArch64::ADDWrs: 371 case AArch64::ADDXrs: 372 case AArch64::SUBWrs: 373 case AArch64::SUBXrs: 374 return !AArch64InstrInfo::hasShiftedReg(*FirstMI); 375 } 376 break; 377 } 378 379 return false; 380 } 381 382 // "(A + B) + 1" or "(A - B) - 1" 383 static bool isAddSub2RegAndConstOnePair(const MachineInstr *FirstMI, 384 const MachineInstr &SecondMI) { 385 bool NeedsSubtract = false; 386 387 // The 2nd instr must be an add-immediate or subtract-immediate. 388 switch (SecondMI.getOpcode()) { 389 case AArch64::SUBWri: 390 case AArch64::SUBXri: 391 NeedsSubtract = true; 392 [[fallthrough]]; 393 case AArch64::ADDWri: 394 case AArch64::ADDXri: 395 break; 396 397 default: 398 return false; 399 } 400 401 // The immediate in the 2nd instr must be "1". 402 if (!SecondMI.getOperand(2).isImm() || SecondMI.getOperand(2).getImm() != 1) { 403 return false; 404 } 405 406 // Assume the 1st instr to be a wildcard if it is unspecified. 407 if (FirstMI == nullptr) { 408 return true; 409 } 410 411 switch (FirstMI->getOpcode()) { 412 case AArch64::SUBWrs: 413 case AArch64::SUBXrs: 414 if (AArch64InstrInfo::hasShiftedReg(*FirstMI)) 415 return false; 416 [[fallthrough]]; 417 case AArch64::SUBWrr: 418 case AArch64::SUBXrr: 419 if (NeedsSubtract) { 420 return true; 421 } 422 break; 423 424 case AArch64::ADDWrs: 425 case AArch64::ADDXrs: 426 if (AArch64InstrInfo::hasShiftedReg(*FirstMI)) 427 return false; 428 [[fallthrough]]; 429 case AArch64::ADDWrr: 430 case AArch64::ADDXrr: 431 if (!NeedsSubtract) { 432 return true; 433 } 434 break; 435 } 436 437 return false; 438 } 439 440 /// \brief Check if the instr pair, FirstMI and SecondMI, should be fused 441 /// together. Given SecondMI, when FirstMI is unspecified, then check if 442 /// SecondMI may be part of a fused pair at all. 443 static bool shouldScheduleAdjacent(const TargetInstrInfo &TII, 444 const TargetSubtargetInfo &TSI, 445 const MachineInstr *FirstMI, 446 const MachineInstr &SecondMI) { 447 const AArch64Subtarget &ST = static_cast<const AArch64Subtarget&>(TSI); 448 449 // All checking functions assume that the 1st instr is a wildcard if it is 450 // unspecified. 451 if (ST.hasCmpBccFusion() || ST.hasArithmeticBccFusion()) { 452 bool CmpOnly = !ST.hasArithmeticBccFusion(); 453 if (isArithmeticBccPair(FirstMI, SecondMI, CmpOnly)) 454 return true; 455 } 456 if (ST.hasArithmeticCbzFusion() && isArithmeticCbzPair(FirstMI, SecondMI)) 457 return true; 458 if (ST.hasFuseAES() && isAESPair(FirstMI, SecondMI)) 459 return true; 460 if (ST.hasFuseCryptoEOR() && isCryptoEORPair(FirstMI, SecondMI)) 461 return true; 462 if (ST.hasFuseAdrpAdd() && isAdrpAddPair(FirstMI, SecondMI)) 463 return true; 464 if (ST.hasFuseLiterals() && isLiteralsPair(FirstMI, SecondMI)) 465 return true; 466 if (ST.hasFuseAddress() && isAddressLdStPair(FirstMI, SecondMI)) 467 return true; 468 if (ST.hasFuseCCSelect() && isCCSelectPair(FirstMI, SecondMI)) 469 return true; 470 if (ST.hasFuseArithmeticLogic() && isArithmeticLogicPair(FirstMI, SecondMI)) 471 return true; 472 if (ST.hasFuseAddSub2RegAndConstOne() && 473 isAddSub2RegAndConstOnePair(FirstMI, SecondMI)) 474 return true; 475 476 return false; 477 } 478 479 std::unique_ptr<ScheduleDAGMutation> 480 llvm::createAArch64MacroFusionDAGMutation() { 481 return createMacroFusionDAGMutation(shouldScheduleAdjacent); 482 } 483