//=== AArch64PostLegalizerCombiner.cpp --------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// Post-legalization combines on generic MachineInstrs.
///
/// The combines here must preserve instruction legality.
///
/// Lowering combines (e.g. pseudo matching) should be handled by
/// AArch64PostLegalizerLowering.
///
/// Combines which don't rely on instruction legality should go in the
/// AArch64PreLegalizerCombiner.
///
//===----------------------------------------------------------------------===//

#include "AArch64TargetMachine.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/CodeGen/GlobalISel/CSEInfo.h"
#include "llvm/CodeGen/GlobalISel/CSEMIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Combiner.h"
#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
#include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h"
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/Support/Debug.h"

#define GET_GICOMBINER_DEPS
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef GET_GICOMBINER_DEPS

#define DEBUG_TYPE "aarch64-postlegalizer-combiner"

using namespace llvm;
using namespace MIPatternMatch;

namespace {

#define GET_GICOMBINER_TYPES
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef GET_GICOMBINER_TYPES

/// This combine tries to do what performExtractVectorEltCombine does in SDAG.
/// Rewrite for pairwise fadd pattern
///   (s32 (g_extract_vector_elt
///           (g_fadd (vXs32 Other)
///                   (g_vector_shuffle (vXs32 Other) undef <1,X,...>)) 0))
/// ->
///   (s32 (g_fadd (g_extract_vector_elt (vXs32 Other) 0)
///                (g_extract_vector_elt (vXs32 Other) 1)))
bool matchExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  Register Src1 = MI.getOperand(1).getReg();
  Register Src2 = MI.getOperand(2).getReg();
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  auto Cst = getIConstantVRegValWithLookThrough(Src2, MRI);
  if (!Cst || Cst->Value != 0)
    return false;
  // SDAG also checks for FullFP16, but this looks to be beneficial anyway.

  // Now check for an fadd operation. TODO: expand this for integer add?
  auto *FAddMI = getOpcodeDef(TargetOpcode::G_FADD, Src1, MRI);
  if (!FAddMI)
    return false;

  // If we add support for integer add, must restrict these types to just s64.
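  // Only scalar FP result sizes with a matching pairwise-add form
  // (16, 32 or 64 bits) are handled below.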
  unsigned DstSize = DstTy.getSizeInBits();
  if (DstSize != 16 && DstSize != 32 && DstSize != 64)
    return false;

  Register Src1Op1 = FAddMI->getOperand(1).getReg();
  Register Src1Op2 = FAddMI->getOperand(2).getReg();
  MachineInstr *Shuffle =
      getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op2, MRI);
  MachineInstr *Other = MRI.getVRegDef(Src1Op1);
  if (!Shuffle) {
    Shuffle = getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op1, MRI);
    Other = MRI.getVRegDef(Src1Op2);
  }

  // We're looking for a shuffle that moves the second element to index 0.
  if (Shuffle && Shuffle->getOperand(3).getShuffleMask()[0] == 1 &&
      Other == MRI.getVRegDef(Shuffle->getOperand(1).getReg())) {
    std::get<0>(MatchInfo) = TargetOpcode::G_FADD;
    std::get<1>(MatchInfo) = DstTy;
    std::get<2>(MatchInfo) = Other->getOperand(0).getReg();
    return true;
  }
  return false;
}

void applyExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  unsigned Opc = std::get<0>(MatchInfo);
  assert(Opc == TargetOpcode::G_FADD && "Unexpected opcode!");
  // We want to generate two extracts of elements 0 and 1, and add them.
  LLT Ty = std::get<1>(MatchInfo);
  Register Src = std::get<2>(MatchInfo);
  LLT s64 = LLT::scalar(64);
  B.setInstrAndDebugLoc(MI);
  auto Elt0 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 0));
  auto Elt1 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 1));
  B.buildInstr(Opc, {MI.getOperand(0).getReg()}, {Elt0, Elt1});
  MI.eraseFromParent();
}

bool isSignExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: check if extended build vector as well.
  unsigned Opc = MRI.getVRegDef(R)->getOpcode();
  return Opc == TargetOpcode::G_SEXT || Opc == TargetOpcode::G_SEXT_INREG;
}

bool isZeroExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: check if extended build vector as well.
  return MRI.getVRegDef(R)->getOpcode() == TargetOpcode::G_ZEXT;
}

bool matchAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  assert(MI.getOpcode() == TargetOpcode::G_MUL);
  Register LHS = MI.getOperand(1).getReg();
  Register RHS = MI.getOperand(2).getReg();
  Register Dst = MI.getOperand(0).getReg();
  const LLT Ty = MRI.getType(LHS);

  // The below optimizations require a constant RHS.
  auto Const = getIConstantVRegValWithLookThrough(RHS, MRI);
  if (!Const)
    return false;

  APInt ConstValue = Const->Value.sext(Ty.getSizeInBits());
  // The following code is ported from AArch64ISelLowering.
  // Multiplication of a power of two plus/minus one can be done more
  // cheaply as shift+add/sub. For now, this is true unilaterally. If
  // future CPUs have a cheaper MADD instruction, this may need to be
  // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and
  // 64-bit is 5 cycles, so this is always a win.
  // More aggressively, some multiplications N0 * C can be lowered to
  // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
  // e.g. 6=3*2=(2+1)*2.
  // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45
  // which equals (1+2)*16-(1+2).
  // TrailingZeroes is used to test if the mul can be lowered to
  // shift+add+shift.
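  // Worked example for C = 6 = (2^1 + 1) * 2^1: TrailingZeroes = 1 and
  // ShiftedConstValue = 3, so x * 6 becomes (shl (add (shl x, 1), x), 1).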
  unsigned TrailingZeroes = ConstValue.countr_zero();
  if (TrailingZeroes) {
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into smul or umul.
    if (MRI.hasOneNonDBGUse(LHS) &&
        (isSignExtended(LHS, MRI) || isZeroExtended(LHS, MRI)))
      return false;
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into madd or msub.
    if (MRI.hasOneNonDBGUse(Dst)) {
      MachineInstr &UseMI = *MRI.use_instr_begin(Dst);
      unsigned UseOpc = UseMI.getOpcode();
      if (UseOpc == TargetOpcode::G_ADD || UseOpc == TargetOpcode::G_PTR_ADD ||
          UseOpc == TargetOpcode::G_SUB)
        return false;
    }
  }
  // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub
  // and shift+add+shift.
  APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);

  unsigned ShiftAmt, AddSubOpc;
  // Is the shifted value the LHS operand of the add/sub?
  bool ShiftValUseIsLHS = true;
  // Do we need to negate the result?
  bool NegateResult = false;

  if (ConstValue.isNonNegative()) {
    // (mul x, 2^N + 1) => (add (shl x, N), x)
    // (mul x, 2^N - 1) => (sub (shl x, N), x)
    // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
    APInt SCVMinus1 = ShiftedConstValue - 1;
    APInt CVPlus1 = ConstValue + 1;
    if (SCVMinus1.isPowerOf2()) {
      ShiftAmt = SCVMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
    } else if (CVPlus1.isPowerOf2()) {
      ShiftAmt = CVPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
    } else
      return false;
  } else {
    // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
    // (mul x, -(2^N + 1)) => - (add (shl x, N), x)
    APInt CVNegPlus1 = -ConstValue + 1;
    APInt CVNegMinus1 = -ConstValue - 1;
    if (CVNegPlus1.isPowerOf2()) {
      ShiftAmt = CVNegPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
      ShiftValUseIsLHS = false;
    } else if (CVNegMinus1.isPowerOf2()) {
      ShiftAmt = CVNegMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
      NegateResult = true;
    } else
      return false;
  }

  if (NegateResult && TrailingZeroes)
    return false;

  ApplyFn = [=](MachineIRBuilder &B, Register DstReg) {
    auto Shift = B.buildConstant(LLT::scalar(64), ShiftAmt);
    auto ShiftedVal = B.buildShl(Ty, LHS, Shift);

    Register AddSubLHS = ShiftValUseIsLHS ? ShiftedVal.getReg(0) : LHS;
    Register AddSubRHS = ShiftValUseIsLHS ? LHS : ShiftedVal.getReg(0);
    auto Res = B.buildInstr(AddSubOpc, {Ty}, {AddSubLHS, AddSubRHS});
    assert(!(NegateResult && TrailingZeroes) &&
           "NegateResult and TrailingZeroes cannot both be true for now.");
    // Negate the result.
    if (NegateResult) {
      B.buildSub(DstReg, B.buildConstant(Ty, 0), Res);
      return;
    }
    // Shift the result.
    if (TrailingZeroes) {
      B.buildShl(DstReg, Res, B.buildConstant(LLT::scalar(64), TrailingZeroes));
      return;
    }
    B.buildCopy(DstReg, Res.getReg(0));
  };
  return true;
}

void applyAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  B.setInstrAndDebugLoc(MI);
  ApplyFn(B, MI.getOperand(0).getReg());
  MI.eraseFromParent();
}

/// Try to fold a G_MERGE_VALUES of 2 s32 sources, where the second source
/// is a zero, into a G_ZEXT of the first.
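///
/// e.g. %d(s64) = G_MERGE_VALUES %a(s32), 0(s32)
///       -->
///      %d(s64) = G_ZEXT %a(s32)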
bool matchFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI) {
  auto &Merge = cast<GMerge>(MI);
  LLT SrcTy = MRI.getType(Merge.getSourceReg(0));
  if (SrcTy != LLT::scalar(32) || Merge.getNumSources() != 2)
    return false;
  return mi_match(Merge.getSourceReg(1), MRI, m_SpecificICst(0));
}

void applyFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI,
                          MachineIRBuilder &B, GISelChangeObserver &Observer) {
  // Mutate %d(s64) = G_MERGE_VALUES %a(s32), 0(s32)
  //  ->
  //        %d(s64) = G_ZEXT %a(s32)
  Observer.changingInstr(MI);
  MI.setDesc(B.getTII().get(TargetOpcode::G_ZEXT));
  MI.removeOperand(2);
  Observer.changedInstr(MI);
}

/// \returns True if a G_ANYEXT instruction \p MI should be mutated to a G_ZEXT
/// instruction.
bool matchMutateAnyExtToZExt(MachineInstr &MI, MachineRegisterInfo &MRI) {
  // If this is coming from a scalar compare then we can use a G_ZEXT instead of
  // a G_ANYEXT:
  //
  // %cmp:_(s32) = G_[I|F]CMP ... <-- produces 0/1.
  // %ext:_(s64) = G_ANYEXT %cmp(s32)
  //
  // By doing this, we can leverage more KnownBits combines.
  assert(MI.getOpcode() == TargetOpcode::G_ANYEXT);
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  return MRI.getType(Dst).isScalar() &&
         mi_match(Src, MRI,
                  m_any_of(m_GICmp(m_Pred(), m_Reg(), m_Reg()),
                           m_GFCmp(m_Pred(), m_Reg(), m_Reg())));
}

void applyMutateAnyExtToZExt(MachineInstr &MI, MachineRegisterInfo &MRI,
                             MachineIRBuilder &B,
                             GISelChangeObserver &Observer) {
  Observer.changingInstr(MI);
  MI.setDesc(B.getTII().get(TargetOpcode::G_ZEXT));
  Observer.changedInstr(MI);
}

/// Match a 128b store of zero and split it into two 64 bit stores, for
/// size/performance reasons.
bool matchSplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI) {
  GStore &Store = cast<GStore>(MI);
  if (!Store.isSimple())
    return false;
  LLT ValTy = MRI.getType(Store.getValueReg());
  if (ValTy.isScalableVector())
    return false;
  if (!ValTy.isVector() || ValTy.getSizeInBits() != 128)
    return false;
  if (Store.getMemSizeInBits() != ValTy.getSizeInBits())
    return false; // Don't split truncating stores.
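  // Only split when this store is the sole non-debug user of the stored
  // value.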
  if (!MRI.hasOneNonDBGUse(Store.getValueReg()))
    return false;
  auto MaybeCst = isConstantOrConstantSplatVector(
      *MRI.getVRegDef(Store.getValueReg()), MRI);
  return MaybeCst && MaybeCst->isZero();
}

void applySplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI,
                            MachineIRBuilder &B,
                            GISelChangeObserver &Observer) {
  B.setInstrAndDebugLoc(MI);
  GStore &Store = cast<GStore>(MI);
  assert(MRI.getType(Store.getValueReg()).isVector() &&
         "Expected a vector store value");
  LLT NewTy = LLT::scalar(64);
  Register PtrReg = Store.getPointerReg();
  auto Zero = B.buildConstant(NewTy, 0);
  auto HighPtr = B.buildPtrAdd(MRI.getType(PtrReg), PtrReg,
                               B.buildConstant(LLT::scalar(64), 8));
  auto &MF = *MI.getMF();
  auto *LowMMO = MF.getMachineMemOperand(&Store.getMMO(), 0, NewTy);
  auto *HighMMO = MF.getMachineMemOperand(&Store.getMMO(), 8, NewTy);
  B.buildStore(Zero, PtrReg, *LowMMO);
  B.buildStore(Zero, HighPtr, *HighMMO);
  Store.eraseFromParent();
}

bool matchOrToBSP(MachineInstr &MI, MachineRegisterInfo &MRI,
                  std::tuple<Register, Register, Register> &MatchInfo) {
  const LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
  if (!DstTy.isVector())
    return false;

  Register AO1, AO2, BVO1, BVO2;
  if (!mi_match(MI, MRI,
                m_GOr(m_GAnd(m_Reg(AO1), m_Reg(BVO1)),
                      m_GAnd(m_Reg(AO2), m_Reg(BVO2)))))
    return false;

  auto *BV1 = getOpcodeDef<GBuildVector>(BVO1, MRI);
  auto *BV2 = getOpcodeDef<GBuildVector>(BVO2, MRI);
  if (!BV1 || !BV2)
    return false;

  for (int I = 0, E = DstTy.getNumElements(); I < E; I++) {
    auto ValAndVReg1 =
        getIConstantVRegValWithLookThrough(BV1->getSourceReg(I), MRI);
    auto ValAndVReg2 =
        getIConstantVRegValWithLookThrough(BV2->getSourceReg(I), MRI);
    if (!ValAndVReg1 || !ValAndVReg2 ||
        ValAndVReg1->Value != ~ValAndVReg2->Value)
      return false;
  }

  MatchInfo = {AO1, AO2, BVO1};
  return true;
}

void applyOrToBSP(MachineInstr &MI, MachineRegisterInfo &MRI,
                  MachineIRBuilder &B,
                  std::tuple<Register, Register, Register> &MatchInfo) {
  B.setInstrAndDebugLoc(MI);
  B.buildInstr(
      AArch64::G_BSP, {MI.getOperand(0).getReg()},
      {std::get<2>(MatchInfo), std::get<0>(MatchInfo), std::get<1>(MatchInfo)});
  MI.eraseFromParent();
}

// Combines Mul(And(Srl(X, 15), 0x10001), 0xffff) into CMLTz
bool matchCombineMulCMLT(MachineInstr &MI, MachineRegisterInfo &MRI,
                         Register &SrcReg) {
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  if (DstTy != LLT::fixed_vector(2, 64) && DstTy != LLT::fixed_vector(2, 32) &&
      DstTy != LLT::fixed_vector(4, 32) && DstTy != LLT::fixed_vector(4, 16) &&
      DstTy != LLT::fixed_vector(8, 16))
    return false;

  auto AndMI = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
  if (AndMI->getOpcode() != TargetOpcode::G_AND)
    return false;
  auto LShrMI = getDefIgnoringCopies(AndMI->getOperand(1).getReg(), MRI);
  if (LShrMI->getOpcode() != TargetOpcode::G_LSHR)
    return false;

  // Check the constant splat values
  auto V1 = isConstantOrConstantSplatVector(
      *MRI.getVRegDef(MI.getOperand(2).getReg()), MRI);
  auto V2 = isConstantOrConstantSplatVector(
      *MRI.getVRegDef(AndMI->getOperand(2).getReg()), MRI);
  auto V3 = isConstantOrConstantSplatVector(
      *MRI.getVRegDef(LShrMI->getOperand(2).getReg()), MRI);
  if (!V1.has_value() || !V2.has_value() || !V3.has_value())
    return false;
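  // e.g. for <4 x s32> the half-size is 16, so the expected splats are
  // 0xffff, 0x10001 and 15 respectively.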
  unsigned HalfSize = DstTy.getScalarSizeInBits() / 2;
  if (!V1.value().isMask(HalfSize) || V2.value() != (1ULL | 1ULL << HalfSize) ||
      V3 != (HalfSize - 1))
    return false;

  SrcReg = LShrMI->getOperand(1).getReg();

  return true;
}

void applyCombineMulCMLT(MachineInstr &MI, MachineRegisterInfo &MRI,
                         MachineIRBuilder &B, Register &SrcReg) {
  Register DstReg = MI.getOperand(0).getReg();
  LLT DstTy = MRI.getType(DstReg);
  LLT HalfTy =
      DstTy.changeElementCount(DstTy.getElementCount().multiplyCoefficientBy(2))
          .changeElementSize(DstTy.getScalarSizeInBits() / 2);

  Register ZeroVec = B.buildConstant(HalfTy, 0).getReg(0);
  Register CastReg =
      B.buildInstr(TargetOpcode::G_BITCAST, {HalfTy}, {SrcReg}).getReg(0);
  Register CMLTReg =
      B.buildICmp(CmpInst::Predicate::ICMP_SLT, HalfTy, CastReg, ZeroVec)
          .getReg(0);

  B.buildInstr(TargetOpcode::G_BITCAST, {DstReg}, {CMLTReg}).getReg(0);
  MI.eraseFromParent();
}

class AArch64PostLegalizerCombinerImpl : public Combiner {
protected:
  // TODO: Make CombinerHelper methods const.
  mutable CombinerHelper Helper;
  const AArch64PostLegalizerCombinerImplRuleConfig &RuleConfig;
  const AArch64Subtarget &STI;

public:
  AArch64PostLegalizerCombinerImpl(
      MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
      GISelKnownBits &KB, GISelCSEInfo *CSEInfo,
      const AArch64PostLegalizerCombinerImplRuleConfig &RuleConfig,
      const AArch64Subtarget &STI, MachineDominatorTree *MDT,
      const LegalizerInfo *LI);

  static const char *getName() { return "AArch64PostLegalizerCombiner"; }

  bool tryCombineAll(MachineInstr &I) const override;

private:
#define GET_GICOMBINER_CLASS_MEMBERS
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef GET_GICOMBINER_CLASS_MEMBERS
};

#define GET_GICOMBINER_IMPL
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef GET_GICOMBINER_IMPL

AArch64PostLegalizerCombinerImpl::AArch64PostLegalizerCombinerImpl(
    MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
    GISelKnownBits &KB, GISelCSEInfo *CSEInfo,
    const AArch64PostLegalizerCombinerImplRuleConfig &RuleConfig,
    const AArch64Subtarget &STI, MachineDominatorTree *MDT,
    const LegalizerInfo *LI)
    : Combiner(MF, CInfo, TPC, &KB, CSEInfo),
      Helper(Observer, B, /*IsPreLegalize*/ false, &KB, MDT, LI),
      RuleConfig(RuleConfig), STI(STI),
#define GET_GICOMBINER_CONSTRUCTOR_INITS
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef GET_GICOMBINER_CONSTRUCTOR_INITS
{
}

class AArch64PostLegalizerCombiner : public MachineFunctionPass {
public:
  static char ID;

  AArch64PostLegalizerCombiner(bool IsOptNone = false);

  StringRef getPassName() const override {
    return "AArch64PostLegalizerCombiner";
  }

  bool runOnMachineFunction(MachineFunction &MF) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;

private:
  bool IsOptNone;
  AArch64PostLegalizerCombinerImplRuleConfig RuleConfig;

  struct StoreInfo {
    GStore *St = nullptr;
    // The G_PTR_ADD that's used by the store. We keep this to cache the
    // MachineInstr def.
    GPtrAdd *Ptr = nullptr;
    // The signed offset to the Ptr instruction.
    int64_t Offset = 0;
    LLT StoredType;
  };
  bool tryOptimizeConsecStores(SmallVectorImpl<StoreInfo> &Stores,
                               CSEMIRBuilder &MIB);

  bool optimizeConsecutiveMemOpAddressing(MachineFunction &MF,
                                          CSEMIRBuilder &MIB);
};
} // end anonymous namespace

void AArch64PostLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.setPreservesCFG();
  getSelectionDAGFallbackAnalysisUsage(AU);
  AU.addRequired<GISelKnownBitsAnalysis>();
  AU.addPreserved<GISelKnownBitsAnalysis>();
  if (!IsOptNone) {
    AU.addRequired<MachineDominatorTreeWrapperPass>();
    AU.addPreserved<MachineDominatorTreeWrapperPass>();
    AU.addRequired<GISelCSEAnalysisWrapperPass>();
    AU.addPreserved<GISelCSEAnalysisWrapperPass>();
  }
  MachineFunctionPass::getAnalysisUsage(AU);
}

AArch64PostLegalizerCombiner::AArch64PostLegalizerCombiner(bool IsOptNone)
    : MachineFunctionPass(ID), IsOptNone(IsOptNone) {
  initializeAArch64PostLegalizerCombinerPass(*PassRegistry::getPassRegistry());

  if (!RuleConfig.parseCommandLineOption())
    report_fatal_error("Invalid rule identifier");
}

bool AArch64PostLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
  if (MF.getProperties().hasProperty(
          MachineFunctionProperties::Property::FailedISel))
    return false;
  assert(MF.getProperties().hasProperty(
             MachineFunctionProperties::Property::Legalized) &&
         "Expected a legalized function?");
  auto *TPC = &getAnalysis<TargetPassConfig>();
  const Function &F = MF.getFunction();
  bool EnableOpt =
      MF.getTarget().getOptLevel() != CodeGenOptLevel::None && !skipFunction(F);

  const AArch64Subtarget &ST = MF.getSubtarget<AArch64Subtarget>();
  const auto *LI = ST.getLegalizerInfo();

  GISelKnownBits *KB = &getAnalysis<GISelKnownBitsAnalysis>().get(MF);
  MachineDominatorTree *MDT =
      IsOptNone ? nullptr
                : &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
  GISelCSEAnalysisWrapper &Wrapper =
      getAnalysis<GISelCSEAnalysisWrapperPass>().getCSEWrapper();
  auto *CSEInfo = &Wrapper.get(TPC->getCSEConfig());

  CombinerInfo CInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
                     /*LegalizerInfo*/ nullptr, EnableOpt, F.hasOptSize(),
                     F.hasMinSize());
  AArch64PostLegalizerCombinerImpl Impl(MF, CInfo, TPC, *KB, CSEInfo,
                                        RuleConfig, ST, MDT, LI);
  bool Changed = Impl.combineMachineInstrs();

  auto MIB = CSEMIRBuilder(MF);
  MIB.setCSEInfo(CSEInfo);
  Changed |= optimizeConsecutiveMemOpAddressing(MF, MIB);
  return Changed;
}

bool AArch64PostLegalizerCombiner::tryOptimizeConsecStores(
    SmallVectorImpl<StoreInfo> &Stores, CSEMIRBuilder &MIB) {
  if (Stores.size() <= 2)
    return false;

  // Profitability checks:
  int64_t BaseOffset = Stores[0].Offset;
  unsigned NumPairsExpected = Stores.size() / 2;
  unsigned TotalInstsExpected = NumPairsExpected + (Stores.size() % 2);
  // Size savings will depend on whether we can fold the offset as an
  // immediate of an ADD.
  auto &TLI = *MIB.getMF().getSubtarget().getTargetLowering();
  if (!TLI.isLegalAddImmediate(BaseOffset))
    TotalInstsExpected++;
  int SavingsExpected = Stores.size() - TotalInstsExpected;
  if (SavingsExpected <= 0)
    return false;

  auto &MRI = MIB.getMF().getRegInfo();

  // We have a series of consecutive stores. Factor out the common base
  // pointer and rewrite the offsets.
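  // For the <2 x s64> example in optimizeConsecutiveMemOpAddressing below,
  // NewBase is the first store's existing G_PTR_ADD (base + 4128) and the
  // rewritten offsets become 0, 16 and 32.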
  Register NewBase = Stores[0].Ptr->getReg(0);
  for (auto &SInfo : Stores) {
    // Compute a new pointer with the new base ptr and adjusted offset.
    MIB.setInstrAndDebugLoc(*SInfo.St);
    auto NewOff = MIB.buildConstant(LLT::scalar(64), SInfo.Offset - BaseOffset);
    auto NewPtr = MIB.buildPtrAdd(MRI.getType(SInfo.St->getPointerReg()),
                                  NewBase, NewOff);
    if (MIB.getObserver())
      MIB.getObserver()->changingInstr(*SInfo.St);
    SInfo.St->getOperand(1).setReg(NewPtr.getReg(0));
    if (MIB.getObserver())
      MIB.getObserver()->changedInstr(*SInfo.St);
  }
  LLVM_DEBUG(dbgs() << "Split a series of " << Stores.size()
                    << " stores into a base pointer and offsets.\n");
  return true;
}

static cl::opt<bool>
    EnableConsecutiveMemOpOpt("aarch64-postlegalizer-consecutive-memops",
                              cl::init(true), cl::Hidden,
                              cl::desc("Enable consecutive memop optimization "
                                       "in AArch64PostLegalizerCombiner"));

bool AArch64PostLegalizerCombiner::optimizeConsecutiveMemOpAddressing(
    MachineFunction &MF, CSEMIRBuilder &MIB) {
  // This combine needs to run after all reassociations/folds on pointer
  // addressing have been done, specifically those that combine two G_PTR_ADDs
  // with constant offsets into a single G_PTR_ADD with a combined offset.
  // The goal of this optimization is to undo that combine in the case where
  // doing so has prevented the formation of pair stores due to illegal
  // addressing modes of STP. The reason that we do it here is because
  // it's much easier to undo the transformation of a series of consecutive
  // mem ops than it is to detect when doing it would be a bad idea looking
  // at a single G_PTR_ADD in the reassociation/ptradd_immed_chain combine.
  //
  // An example:
  //   G_STORE %11:_(<2 x s64>), %base:_(p0) :: (store (<2 x s64>), align 1)
  //   %off1:_(s64) = G_CONSTANT i64 4128
  //   %p1:_(p0) = G_PTR_ADD %base:_, %off1:_(s64)
  //   G_STORE %11:_(<2 x s64>), %p1:_(p0) :: (store (<2 x s64>), align 1)
  //   %off2:_(s64) = G_CONSTANT i64 4144
  //   %p2:_(p0) = G_PTR_ADD %base:_, %off2:_(s64)
  //   G_STORE %11:_(<2 x s64>), %p2:_(p0) :: (store (<2 x s64>), align 1)
  //   %off3:_(s64) = G_CONSTANT i64 4160
  //   %p3:_(p0) = G_PTR_ADD %base:_, %off3:_(s64)
  //   G_STORE %11:_(<2 x s64>), %p3:_(p0) :: (store (<2 x s64>), align 1)
  bool Changed = false;
  auto &MRI = MF.getRegInfo();

  if (!EnableConsecutiveMemOpOpt)
    return Changed;

  SmallVector<StoreInfo, 8> Stores;
  // If we see a load, then we keep track of any values defined by it.
  // In the following example, STP formation will fail anyway because
  // the latter store is using a load result that appears after the
  // prior store. In this situation, if we factor out the offset then
  // we increase code size for no benefit.
  //   G_STORE %v1:_(s64), %base:_(p0) :: (store (s64))
  //   %v2:_(s64) = G_LOAD %ldptr:_(p0) :: (load (s64))
  //   G_STORE %v2:_(s64), %base:_(p0) :: (store (s64))
  SmallVector<Register> LoadValsSinceLastStore;

  auto storeIsValid = [&](StoreInfo &Last, StoreInfo New) {
    // Check if this store is consecutive to the last one.
    if (Last.Ptr->getBaseReg() != New.Ptr->getBaseReg() ||
        (Last.Offset + static_cast<int64_t>(Last.StoredType.getSizeInBytes()) !=
         New.Offset) ||
        Last.StoredType != New.StoredType)
      return false;

    // Check if this store is using a load result that appears after the
    // last store. If so, bail out.
    if (any_of(LoadValsSinceLastStore, [&](Register LoadVal) {
          return New.St->getValueReg() == LoadVal;
        }))
      return false;

    // Check if the current offset would be too large for STP.
    // If not, then STP formation should be able to handle it, so we don't
    // need to do anything.
    int64_t MaxLegalOffset;
    switch (New.StoredType.getSizeInBits()) {
    case 32:
      MaxLegalOffset = 252;
      break;
    case 64:
      MaxLegalOffset = 504;
      break;
    case 128:
      MaxLegalOffset = 1008;
      break;
    default:
      llvm_unreachable("Unexpected stored type size");
    }
    if (New.Offset < MaxLegalOffset)
      return false;

    // If factoring it out still wouldn't help then don't bother.
    return New.Offset - Stores[0].Offset <= MaxLegalOffset;
  };

  auto resetState = [&]() {
    Stores.clear();
    LoadValsSinceLastStore.clear();
  };

  for (auto &MBB : MF) {
    // We're looking inside a single BB at a time since the memset pattern
    // should only be in a single block.
    resetState();
    for (auto &MI : MBB) {
      // Skip for scalable vectors.
      if (auto *LdSt = dyn_cast<GLoadStore>(&MI);
          LdSt && MRI.getType(LdSt->getOperand(0).getReg()).isScalableVector())
        continue;

      if (auto *St = dyn_cast<GStore>(&MI)) {
        Register PtrBaseReg;
        APInt Offset;
        LLT StoredValTy = MRI.getType(St->getValueReg());
        unsigned ValSize = StoredValTy.getSizeInBits();
        if (ValSize < 32 || St->getMMO().getSizeInBits() != ValSize)
          continue;

        Register PtrReg = St->getPointerReg();
        if (mi_match(
                PtrReg, MRI,
                m_OneNonDBGUse(m_GPtrAdd(m_Reg(PtrBaseReg), m_ICst(Offset))))) {
          GPtrAdd *PtrAdd = cast<GPtrAdd>(MRI.getVRegDef(PtrReg));
          StoreInfo New = {St, PtrAdd, Offset.getSExtValue(), StoredValTy};

          if (Stores.empty()) {
            Stores.push_back(New);
            continue;
          }

          // Check if this store is a valid continuation of the sequence.
          auto &Last = Stores.back();
          if (storeIsValid(Last, New)) {
            Stores.push_back(New);
            LoadValsSinceLastStore.clear(); // Reset the load value tracking.
          } else {
            // The store isn't valid to consider for the prior sequence,
            // so try to optimize what we have so far and start a new sequence.
            Changed |= tryOptimizeConsecStores(Stores, MIB);
            resetState();
            Stores.push_back(New);
          }
        }
      } else if (auto *Ld = dyn_cast<GLoad>(&MI)) {
        LoadValsSinceLastStore.push_back(Ld->getDstReg());
      }
    }
    Changed |= tryOptimizeConsecStores(Stores, MIB);
    resetState();
  }

  return Changed;
}

char AArch64PostLegalizerCombiner::ID = 0;
INITIALIZE_PASS_BEGIN(AArch64PostLegalizerCombiner, DEBUG_TYPE,
                      "Combine AArch64 MachineInstrs after legalization", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(GISelKnownBitsAnalysis)
INITIALIZE_PASS_END(AArch64PostLegalizerCombiner, DEBUG_TYPE,
                    "Combine AArch64 MachineInstrs after legalization", false,
                    false)

namespace llvm {
FunctionPass *createAArch64PostLegalizerCombiner(bool IsOptNone) {
  return new AArch64PostLegalizerCombiner(IsOptNone);
}
} // end namespace llvm