1 //=== AArch64PostLegalizerCombiner.cpp --------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// Post-legalization combines on generic MachineInstrs. 11 /// 12 /// The combines here must preserve instruction legality. 13 /// 14 /// Lowering combines (e.g. pseudo matching) should be handled by 15 /// AArch64PostLegalizerLowering. 16 /// 17 /// Combines which don't rely on instruction legality should go in the 18 /// AArch64PreLegalizerCombiner. 19 /// 20 //===----------------------------------------------------------------------===// 21 22 #include "AArch64TargetMachine.h" 23 #include "llvm/CodeGen/GlobalISel/CSEInfo.h" 24 #include "llvm/CodeGen/GlobalISel/Combiner.h" 25 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h" 26 #include "llvm/CodeGen/GlobalISel/CombinerInfo.h" 27 #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h" 28 #include "llvm/CodeGen/GlobalISel/GISelKnownBits.h" 29 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" 30 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h" 31 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 32 #include "llvm/CodeGen/GlobalISel/Utils.h" 33 #include "llvm/CodeGen/MachineDominators.h" 34 #include "llvm/CodeGen/MachineFunctionPass.h" 35 #include "llvm/CodeGen/MachineRegisterInfo.h" 36 #include "llvm/CodeGen/TargetOpcodes.h" 37 #include "llvm/CodeGen/TargetPassConfig.h" 38 #include "llvm/Support/Debug.h" 39 40 #define DEBUG_TYPE "aarch64-postlegalizer-combiner" 41 42 using namespace llvm; 43 using namespace MIPatternMatch; 44 45 /// This combine tries do what performExtractVectorEltCombine does in SDAG. 46 /// Rewrite for pairwise fadd pattern 47 /// (s32 (g_extract_vector_elt 48 /// (g_fadd (vXs32 Other) 49 /// (g_vector_shuffle (vXs32 Other) undef <1,X,...> )) 0)) 50 /// -> 51 /// (s32 (g_fadd (g_extract_vector_elt (vXs32 Other) 0) 52 /// (g_extract_vector_elt (vXs32 Other) 1)) 53 bool matchExtractVecEltPairwiseAdd( 54 MachineInstr &MI, MachineRegisterInfo &MRI, 55 std::tuple<unsigned, LLT, Register> &MatchInfo) { 56 Register Src1 = MI.getOperand(1).getReg(); 57 Register Src2 = MI.getOperand(2).getReg(); 58 LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); 59 60 auto Cst = getIConstantVRegValWithLookThrough(Src2, MRI); 61 if (!Cst || Cst->Value != 0) 62 return false; 63 // SDAG also checks for FullFP16, but this looks to be beneficial anyway. 64 65 // Now check for an fadd operation. TODO: expand this for integer add? 66 auto *FAddMI = getOpcodeDef(TargetOpcode::G_FADD, Src1, MRI); 67 if (!FAddMI) 68 return false; 69 70 // If we add support for integer add, must restrict these types to just s64. 71 unsigned DstSize = DstTy.getSizeInBits(); 72 if (DstSize != 16 && DstSize != 32 && DstSize != 64) 73 return false; 74 75 Register Src1Op1 = FAddMI->getOperand(1).getReg(); 76 Register Src1Op2 = FAddMI->getOperand(2).getReg(); 77 MachineInstr *Shuffle = 78 getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op2, MRI); 79 MachineInstr *Other = MRI.getVRegDef(Src1Op1); 80 if (!Shuffle) { 81 Shuffle = getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op1, MRI); 82 Other = MRI.getVRegDef(Src1Op2); 83 } 84 85 // We're looking for a shuffle that moves the second element to index 0. 86 if (Shuffle && Shuffle->getOperand(3).getShuffleMask()[0] == 1 && 87 Other == MRI.getVRegDef(Shuffle->getOperand(1).getReg())) { 88 std::get<0>(MatchInfo) = TargetOpcode::G_FADD; 89 std::get<1>(MatchInfo) = DstTy; 90 std::get<2>(MatchInfo) = Other->getOperand(0).getReg(); 91 return true; 92 } 93 return false; 94 } 95 96 bool applyExtractVecEltPairwiseAdd( 97 MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B, 98 std::tuple<unsigned, LLT, Register> &MatchInfo) { 99 unsigned Opc = std::get<0>(MatchInfo); 100 assert(Opc == TargetOpcode::G_FADD && "Unexpected opcode!"); 101 // We want to generate two extracts of elements 0 and 1, and add them. 102 LLT Ty = std::get<1>(MatchInfo); 103 Register Src = std::get<2>(MatchInfo); 104 LLT s64 = LLT::scalar(64); 105 B.setInstrAndDebugLoc(MI); 106 auto Elt0 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 0)); 107 auto Elt1 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 1)); 108 B.buildInstr(Opc, {MI.getOperand(0).getReg()}, {Elt0, Elt1}); 109 MI.eraseFromParent(); 110 return true; 111 } 112 113 static bool isSignExtended(Register R, MachineRegisterInfo &MRI) { 114 // TODO: check if extended build vector as well. 115 unsigned Opc = MRI.getVRegDef(R)->getOpcode(); 116 return Opc == TargetOpcode::G_SEXT || Opc == TargetOpcode::G_SEXT_INREG; 117 } 118 119 static bool isZeroExtended(Register R, MachineRegisterInfo &MRI) { 120 // TODO: check if extended build vector as well. 121 return MRI.getVRegDef(R)->getOpcode() == TargetOpcode::G_ZEXT; 122 } 123 124 bool matchAArch64MulConstCombine( 125 MachineInstr &MI, MachineRegisterInfo &MRI, 126 std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) { 127 assert(MI.getOpcode() == TargetOpcode::G_MUL); 128 Register LHS = MI.getOperand(1).getReg(); 129 Register RHS = MI.getOperand(2).getReg(); 130 Register Dst = MI.getOperand(0).getReg(); 131 const LLT Ty = MRI.getType(LHS); 132 133 // The below optimizations require a constant RHS. 134 auto Const = getIConstantVRegValWithLookThrough(RHS, MRI); 135 if (!Const) 136 return false; 137 138 APInt ConstValue = Const->Value.sext(Ty.getSizeInBits()); 139 // The following code is ported from AArch64ISelLowering. 140 // Multiplication of a power of two plus/minus one can be done more 141 // cheaply as as shift+add/sub. For now, this is true unilaterally. If 142 // future CPUs have a cheaper MADD instruction, this may need to be 143 // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and 144 // 64-bit is 5 cycles, so this is always a win. 145 // More aggressively, some multiplications N0 * C can be lowered to 146 // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M, 147 // e.g. 6=3*2=(2+1)*2. 148 // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45 149 // which equals to (1+2)*16-(1+2). 150 // TrailingZeroes is used to test if the mul can be lowered to 151 // shift+add+shift. 152 unsigned TrailingZeroes = ConstValue.countTrailingZeros(); 153 if (TrailingZeroes) { 154 // Conservatively do not lower to shift+add+shift if the mul might be 155 // folded into smul or umul. 156 if (MRI.hasOneNonDBGUse(LHS) && 157 (isSignExtended(LHS, MRI) || isZeroExtended(LHS, MRI))) 158 return false; 159 // Conservatively do not lower to shift+add+shift if the mul might be 160 // folded into madd or msub. 161 if (MRI.hasOneNonDBGUse(Dst)) { 162 MachineInstr &UseMI = *MRI.use_instr_begin(Dst); 163 unsigned UseOpc = UseMI.getOpcode(); 164 if (UseOpc == TargetOpcode::G_ADD || UseOpc == TargetOpcode::G_PTR_ADD || 165 UseOpc == TargetOpcode::G_SUB) 166 return false; 167 } 168 } 169 // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub 170 // and shift+add+shift. 171 APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes); 172 173 unsigned ShiftAmt, AddSubOpc; 174 // Is the shifted value the LHS operand of the add/sub? 175 bool ShiftValUseIsLHS = true; 176 // Do we need to negate the result? 177 bool NegateResult = false; 178 179 if (ConstValue.isNonNegative()) { 180 // (mul x, 2^N + 1) => (add (shl x, N), x) 181 // (mul x, 2^N - 1) => (sub (shl x, N), x) 182 // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M) 183 APInt SCVMinus1 = ShiftedConstValue - 1; 184 APInt CVPlus1 = ConstValue + 1; 185 if (SCVMinus1.isPowerOf2()) { 186 ShiftAmt = SCVMinus1.logBase2(); 187 AddSubOpc = TargetOpcode::G_ADD; 188 } else if (CVPlus1.isPowerOf2()) { 189 ShiftAmt = CVPlus1.logBase2(); 190 AddSubOpc = TargetOpcode::G_SUB; 191 } else 192 return false; 193 } else { 194 // (mul x, -(2^N - 1)) => (sub x, (shl x, N)) 195 // (mul x, -(2^N + 1)) => - (add (shl x, N), x) 196 APInt CVNegPlus1 = -ConstValue + 1; 197 APInt CVNegMinus1 = -ConstValue - 1; 198 if (CVNegPlus1.isPowerOf2()) { 199 ShiftAmt = CVNegPlus1.logBase2(); 200 AddSubOpc = TargetOpcode::G_SUB; 201 ShiftValUseIsLHS = false; 202 } else if (CVNegMinus1.isPowerOf2()) { 203 ShiftAmt = CVNegMinus1.logBase2(); 204 AddSubOpc = TargetOpcode::G_ADD; 205 NegateResult = true; 206 } else 207 return false; 208 } 209 210 if (NegateResult && TrailingZeroes) 211 return false; 212 213 ApplyFn = [=](MachineIRBuilder &B, Register DstReg) { 214 auto Shift = B.buildConstant(LLT::scalar(64), ShiftAmt); 215 auto ShiftedVal = B.buildShl(Ty, LHS, Shift); 216 217 Register AddSubLHS = ShiftValUseIsLHS ? ShiftedVal.getReg(0) : LHS; 218 Register AddSubRHS = ShiftValUseIsLHS ? LHS : ShiftedVal.getReg(0); 219 auto Res = B.buildInstr(AddSubOpc, {Ty}, {AddSubLHS, AddSubRHS}); 220 assert(!(NegateResult && TrailingZeroes) && 221 "NegateResult and TrailingZeroes cannot both be true for now."); 222 // Negate the result. 223 if (NegateResult) { 224 B.buildSub(DstReg, B.buildConstant(Ty, 0), Res); 225 return; 226 } 227 // Shift the result. 228 if (TrailingZeroes) { 229 B.buildShl(DstReg, Res, B.buildConstant(LLT::scalar(64), TrailingZeroes)); 230 return; 231 } 232 B.buildCopy(DstReg, Res.getReg(0)); 233 }; 234 return true; 235 } 236 237 bool applyAArch64MulConstCombine( 238 MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B, 239 std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) { 240 B.setInstrAndDebugLoc(MI); 241 ApplyFn(B, MI.getOperand(0).getReg()); 242 MI.eraseFromParent(); 243 return true; 244 } 245 246 /// Try to fold a G_MERGE_VALUES of 2 s32 sources, where the second source 247 /// is a zero, into a G_ZEXT of the first. 248 bool matchFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI) { 249 auto &Merge = cast<GMerge>(MI); 250 LLT SrcTy = MRI.getType(Merge.getSourceReg(0)); 251 if (SrcTy != LLT::scalar(32) || Merge.getNumSources() != 2) 252 return false; 253 return mi_match(Merge.getSourceReg(1), MRI, m_SpecificICst(0)); 254 } 255 256 void applyFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI, 257 MachineIRBuilder &B, GISelChangeObserver &Observer) { 258 // Mutate %d(s64) = G_MERGE_VALUES %a(s32), 0(s32) 259 // -> 260 // %d(s64) = G_ZEXT %a(s32) 261 Observer.changingInstr(MI); 262 MI.setDesc(B.getTII().get(TargetOpcode::G_ZEXT)); 263 MI.removeOperand(2); 264 Observer.changedInstr(MI); 265 } 266 267 /// \returns True if a G_ANYEXT instruction \p MI should be mutated to a G_ZEXT 268 /// instruction. 269 static bool matchMutateAnyExtToZExt(MachineInstr &MI, MachineRegisterInfo &MRI) { 270 // If this is coming from a scalar compare then we can use a G_ZEXT instead of 271 // a G_ANYEXT: 272 // 273 // %cmp:_(s32) = G_[I|F]CMP ... <-- produces 0/1. 274 // %ext:_(s64) = G_ANYEXT %cmp(s32) 275 // 276 // By doing this, we can leverage more KnownBits combines. 277 assert(MI.getOpcode() == TargetOpcode::G_ANYEXT); 278 Register Dst = MI.getOperand(0).getReg(); 279 Register Src = MI.getOperand(1).getReg(); 280 return MRI.getType(Dst).isScalar() && 281 mi_match(Src, MRI, 282 m_any_of(m_GICmp(m_Pred(), m_Reg(), m_Reg()), 283 m_GFCmp(m_Pred(), m_Reg(), m_Reg()))); 284 } 285 286 static void applyMutateAnyExtToZExt(MachineInstr &MI, MachineRegisterInfo &MRI, 287 MachineIRBuilder &B, 288 GISelChangeObserver &Observer) { 289 Observer.changingInstr(MI); 290 MI.setDesc(B.getTII().get(TargetOpcode::G_ZEXT)); 291 Observer.changedInstr(MI); 292 } 293 294 /// Match a 128b store of zero and split it into two 64 bit stores, for 295 /// size/performance reasons. 296 static bool matchSplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI) { 297 GStore &Store = cast<GStore>(MI); 298 if (!Store.isSimple()) 299 return false; 300 LLT ValTy = MRI.getType(Store.getValueReg()); 301 if (!ValTy.isVector() || ValTy.getSizeInBits() != 128) 302 return false; 303 if (ValTy.getSizeInBits() != Store.getMemSizeInBits()) 304 return false; // Don't split truncating stores. 305 if (!MRI.hasOneNonDBGUse(Store.getValueReg())) 306 return false; 307 auto MaybeCst = isConstantOrConstantSplatVector( 308 *MRI.getVRegDef(Store.getValueReg()), MRI); 309 return MaybeCst && MaybeCst->isZero(); 310 } 311 312 static void applySplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI, 313 MachineIRBuilder &B, 314 GISelChangeObserver &Observer) { 315 B.setInstrAndDebugLoc(MI); 316 GStore &Store = cast<GStore>(MI); 317 assert(MRI.getType(Store.getValueReg()).isVector() && 318 "Expected a vector store value"); 319 LLT NewTy = LLT::scalar(64); 320 Register PtrReg = Store.getPointerReg(); 321 auto Zero = B.buildConstant(NewTy, 0); 322 auto HighPtr = B.buildPtrAdd(MRI.getType(PtrReg), PtrReg, 323 B.buildConstant(LLT::scalar(64), 8)); 324 auto &MF = *MI.getMF(); 325 auto *LowMMO = MF.getMachineMemOperand(&Store.getMMO(), 0, NewTy); 326 auto *HighMMO = MF.getMachineMemOperand(&Store.getMMO(), 8, NewTy); 327 B.buildStore(Zero, PtrReg, *LowMMO); 328 B.buildStore(Zero, HighPtr, *HighMMO); 329 Store.eraseFromParent(); 330 } 331 332 #define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS 333 #include "AArch64GenPostLegalizeGICombiner.inc" 334 #undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS 335 336 namespace { 337 #define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H 338 #include "AArch64GenPostLegalizeGICombiner.inc" 339 #undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H 340 341 class AArch64PostLegalizerCombinerInfo : public CombinerInfo { 342 GISelKnownBits *KB; 343 MachineDominatorTree *MDT; 344 345 public: 346 AArch64GenPostLegalizerCombinerHelperRuleConfig GeneratedRuleCfg; 347 348 AArch64PostLegalizerCombinerInfo(bool EnableOpt, bool OptSize, bool MinSize, 349 GISelKnownBits *KB, 350 MachineDominatorTree *MDT) 351 : CombinerInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false, 352 /*LegalizerInfo*/ nullptr, EnableOpt, OptSize, MinSize), 353 KB(KB), MDT(MDT) { 354 if (!GeneratedRuleCfg.parseCommandLineOption()) 355 report_fatal_error("Invalid rule identifier"); 356 } 357 358 bool combine(GISelChangeObserver &Observer, MachineInstr &MI, 359 MachineIRBuilder &B) const override; 360 }; 361 362 bool AArch64PostLegalizerCombinerInfo::combine(GISelChangeObserver &Observer, 363 MachineInstr &MI, 364 MachineIRBuilder &B) const { 365 const auto *LI = 366 MI.getParent()->getParent()->getSubtarget().getLegalizerInfo(); 367 CombinerHelper Helper(Observer, B, /*IsPreLegalize*/ false, KB, MDT, LI); 368 AArch64GenPostLegalizerCombinerHelper Generated(GeneratedRuleCfg); 369 return Generated.tryCombineAll(Observer, MI, B, Helper); 370 } 371 372 #define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP 373 #include "AArch64GenPostLegalizeGICombiner.inc" 374 #undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP 375 376 class AArch64PostLegalizerCombiner : public MachineFunctionPass { 377 public: 378 static char ID; 379 380 AArch64PostLegalizerCombiner(bool IsOptNone = false); 381 382 StringRef getPassName() const override { 383 return "AArch64PostLegalizerCombiner"; 384 } 385 386 bool runOnMachineFunction(MachineFunction &MF) override; 387 void getAnalysisUsage(AnalysisUsage &AU) const override; 388 389 private: 390 bool IsOptNone; 391 }; 392 } // end anonymous namespace 393 394 void AArch64PostLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const { 395 AU.addRequired<TargetPassConfig>(); 396 AU.setPreservesCFG(); 397 getSelectionDAGFallbackAnalysisUsage(AU); 398 AU.addRequired<GISelKnownBitsAnalysis>(); 399 AU.addPreserved<GISelKnownBitsAnalysis>(); 400 if (!IsOptNone) { 401 AU.addRequired<MachineDominatorTree>(); 402 AU.addPreserved<MachineDominatorTree>(); 403 AU.addRequired<GISelCSEAnalysisWrapperPass>(); 404 AU.addPreserved<GISelCSEAnalysisWrapperPass>(); 405 } 406 MachineFunctionPass::getAnalysisUsage(AU); 407 } 408 409 AArch64PostLegalizerCombiner::AArch64PostLegalizerCombiner(bool IsOptNone) 410 : MachineFunctionPass(ID), IsOptNone(IsOptNone) { 411 initializeAArch64PostLegalizerCombinerPass(*PassRegistry::getPassRegistry()); 412 } 413 414 bool AArch64PostLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) { 415 if (MF.getProperties().hasProperty( 416 MachineFunctionProperties::Property::FailedISel)) 417 return false; 418 assert(MF.getProperties().hasProperty( 419 MachineFunctionProperties::Property::Legalized) && 420 "Expected a legalized function?"); 421 auto *TPC = &getAnalysis<TargetPassConfig>(); 422 const Function &F = MF.getFunction(); 423 bool EnableOpt = 424 MF.getTarget().getOptLevel() != CodeGenOpt::None && !skipFunction(F); 425 GISelKnownBits *KB = &getAnalysis<GISelKnownBitsAnalysis>().get(MF); 426 MachineDominatorTree *MDT = 427 IsOptNone ? nullptr : &getAnalysis<MachineDominatorTree>(); 428 AArch64PostLegalizerCombinerInfo PCInfo(EnableOpt, F.hasOptSize(), 429 F.hasMinSize(), KB, MDT); 430 GISelCSEAnalysisWrapper &Wrapper = 431 getAnalysis<GISelCSEAnalysisWrapperPass>().getCSEWrapper(); 432 auto *CSEInfo = &Wrapper.get(TPC->getCSEConfig()); 433 Combiner C(PCInfo, TPC); 434 return C.combineMachineInstrs(MF, CSEInfo); 435 } 436 437 char AArch64PostLegalizerCombiner::ID = 0; 438 INITIALIZE_PASS_BEGIN(AArch64PostLegalizerCombiner, DEBUG_TYPE, 439 "Combine AArch64 MachineInstrs after legalization", false, 440 false) 441 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) 442 INITIALIZE_PASS_DEPENDENCY(GISelKnownBitsAnalysis) 443 INITIALIZE_PASS_END(AArch64PostLegalizerCombiner, DEBUG_TYPE, 444 "Combine AArch64 MachineInstrs after legalization", false, 445 false) 446 447 namespace llvm { 448 FunctionPass *createAArch64PostLegalizerCombiner(bool IsOptNone) { 449 return new AArch64PostLegalizerCombiner(IsOptNone); 450 } 451 } // end namespace llvm 452