//=== AArch64PostLegalizerCombiner.cpp --------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// Post-legalization combines on generic MachineInstrs.
///
/// The combines here must preserve instruction legality.
///
/// Lowering combines (e.g. pseudo matching) should be handled by
/// AArch64PostLegalizerLowering.
///
/// Combines which don't rely on instruction legality should go in the
/// AArch64PreLegalizerCombiner.
///
//===----------------------------------------------------------------------===//

#include "AArch64TargetMachine.h"
#include "llvm/CodeGen/GlobalISel/CSEInfo.h"
#include "llvm/CodeGen/GlobalISel/Combiner.h"
#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
#include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "aarch64-postlegalizer-combiner"

using namespace llvm;
using namespace MIPatternMatch;

/// This combine tries to do what performExtractVectorEltCombine does in SDAG.
/// Rewrite for pairwise fadd pattern
///   (s32 (g_extract_vector_elt
///           (g_fadd (vXs32 Other)
///                   (g_vector_shuffle (vXs32 Other) undef <1,X,...>)) 0))
/// ->
///   (s32 (g_fadd (g_extract_vector_elt (vXs32 Other) 0)
///                (g_extract_vector_elt (vXs32 Other) 1)))
bool matchExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  Register Src1 = MI.getOperand(1).getReg();
  Register Src2 = MI.getOperand(2).getReg();
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  auto Cst = getIConstantVRegValWithLookThrough(Src2, MRI);
  if (!Cst || Cst->Value != 0)
    return false;
  // SDAG also checks for FullFP16, but this looks to be beneficial anyway.

  // Now check for an fadd operation. TODO: expand this for integer add?
  auto *FAddMI = getOpcodeDef(TargetOpcode::G_FADD, Src1, MRI);
  if (!FAddMI)
    return false;

  // If we add support for integer add, must restrict these types to just s64.
  unsigned DstSize = DstTy.getSizeInBits();
  if (DstSize != 16 && DstSize != 32 && DstSize != 64)
    return false;

  Register Src1Op1 = FAddMI->getOperand(1).getReg();
  Register Src1Op2 = FAddMI->getOperand(2).getReg();
  MachineInstr *Shuffle =
      getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op2, MRI);
  MachineInstr *Other = MRI.getVRegDef(Src1Op1);
  if (!Shuffle) {
    Shuffle = getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op1, MRI);
    Other = MRI.getVRegDef(Src1Op2);
  }

  // We're looking for a shuffle that moves the second element to index 0.
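  // For example (GMIR sketch; register names are illustrative only):
  //   %shuf:_(<2 x s32>) = G_SHUFFLE_VECTOR %vec(<2 x s32>), %undef,
  //                            shufflemask(1, undef)
  //   %sum:_(<2 x s32>) = G_FADD %vec, %shuf
  //   %elt:_(s32) = G_EXTRACT_VECTOR_ELT %sum(<2 x s32>), 0
  // Besides the mask, the shuffle must read the same register that feeds the
  // other G_FADD operand, so both lanes of the pairwise add come from one
  // source vector.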
  if (Shuffle && Shuffle->getOperand(3).getShuffleMask()[0] == 1 &&
      Other == MRI.getVRegDef(Shuffle->getOperand(1).getReg())) {
    std::get<0>(MatchInfo) = TargetOpcode::G_FADD;
    std::get<1>(MatchInfo) = DstTy;
    std::get<2>(MatchInfo) = Other->getOperand(0).getReg();
    return true;
  }
  return false;
}

bool applyExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  unsigned Opc = std::get<0>(MatchInfo);
  assert(Opc == TargetOpcode::G_FADD && "Unexpected opcode!");
  // We want to generate two extracts of elements 0 and 1, and add them.
  LLT Ty = std::get<1>(MatchInfo);
  Register Src = std::get<2>(MatchInfo);
  LLT s64 = LLT::scalar(64);
  B.setInstrAndDebugLoc(MI);
  auto Elt0 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 0));
  auto Elt1 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 1));
  B.buildInstr(Opc, {MI.getOperand(0).getReg()}, {Elt0, Elt1});
  MI.eraseFromParent();
  return true;
}

static bool isSignExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: check if extended build vector as well.
  unsigned Opc = MRI.getVRegDef(R)->getOpcode();
  return Opc == TargetOpcode::G_SEXT || Opc == TargetOpcode::G_SEXT_INREG;
}

static bool isZeroExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: check if extended build vector as well.
  return MRI.getVRegDef(R)->getOpcode() == TargetOpcode::G_ZEXT;
}

bool matchAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  assert(MI.getOpcode() == TargetOpcode::G_MUL);
  Register LHS = MI.getOperand(1).getReg();
  Register RHS = MI.getOperand(2).getReg();
  Register Dst = MI.getOperand(0).getReg();
  const LLT Ty = MRI.getType(LHS);

  // The below optimizations require a constant RHS.
  auto Const = getIConstantVRegValWithLookThrough(RHS, MRI);
  if (!Const)
    return false;

  const APInt ConstValue = Const->Value.sextOrSelf(Ty.getSizeInBits());
  // The following code is ported from AArch64ISelLowering.
  // Multiplication of a power of two plus/minus one can be done more
  // cheaply as a shift+add/sub. For now, this is true unilaterally. If
  // future CPUs have a cheaper MADD instruction, this may need to be
  // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and
  // 64-bit is 5 cycles, so this is always a win.
  // More aggressively, some multiplications N0 * C can be lowered to
  // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
  // e.g. 6=3*2=(2+1)*2.
  // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45
  // which equals (1+2)*16-(1+2).
  // TrailingZeroes is used to test if the mul can be lowered to
  // shift+add+shift.
  unsigned TrailingZeroes = ConstValue.countTrailingZeros();
  if (TrailingZeroes) {
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into smul or umul.
    if (MRI.hasOneNonDBGUse(LHS) &&
        (isSignExtended(LHS, MRI) || isZeroExtended(LHS, MRI)))
      return false;
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into madd or msub.
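    // A single G_ADD/G_PTR_ADD/G_SUB user of the multiply result can usually
    // be selected as MADD/MSUB, which is likely cheaper than the expanded
    // shift+add+shift sequence.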
    if (MRI.hasOneNonDBGUse(Dst)) {
      MachineInstr &UseMI = *MRI.use_instr_begin(Dst);
      unsigned UseOpc = UseMI.getOpcode();
      if (UseOpc == TargetOpcode::G_ADD || UseOpc == TargetOpcode::G_PTR_ADD ||
          UseOpc == TargetOpcode::G_SUB)
        return false;
    }
  }
  // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub
  // and shift+add+shift.
  APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);

  unsigned ShiftAmt, AddSubOpc;
  // Is the shifted value the LHS operand of the add/sub?
  bool ShiftValUseIsLHS = true;
  // Do we need to negate the result?
  bool NegateResult = false;

  if (ConstValue.isNonNegative()) {
    // (mul x, 2^N + 1) => (add (shl x, N), x)
    // (mul x, 2^N - 1) => (sub (shl x, N), x)
    // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
    APInt SCVMinus1 = ShiftedConstValue - 1;
    APInt CVPlus1 = ConstValue + 1;
    if (SCVMinus1.isPowerOf2()) {
      ShiftAmt = SCVMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
    } else if (CVPlus1.isPowerOf2()) {
      ShiftAmt = CVPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
    } else
      return false;
  } else {
    // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
    // (mul x, -(2^N + 1)) => - (add (shl x, N), x)
    APInt CVNegPlus1 = -ConstValue + 1;
    APInt CVNegMinus1 = -ConstValue - 1;
    if (CVNegPlus1.isPowerOf2()) {
      ShiftAmt = CVNegPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
      ShiftValUseIsLHS = false;
    } else if (CVNegMinus1.isPowerOf2()) {
      ShiftAmt = CVNegMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
      NegateResult = true;
    } else
      return false;
  }

  if (NegateResult && TrailingZeroes)
    return false;

  ApplyFn = [=](MachineIRBuilder &B, Register DstReg) {
    auto Shift = B.buildConstant(LLT::scalar(64), ShiftAmt);
    auto ShiftedVal = B.buildShl(Ty, LHS, Shift);

    Register AddSubLHS = ShiftValUseIsLHS ? ShiftedVal.getReg(0) : LHS;
    Register AddSubRHS = ShiftValUseIsLHS ? LHS : ShiftedVal.getReg(0);
    auto Res = B.buildInstr(AddSubOpc, {Ty}, {AddSubLHS, AddSubRHS});
    assert(!(NegateResult && TrailingZeroes) &&
           "NegateResult and TrailingZeroes cannot both be true for now.");
    // Negate the result.
    if (NegateResult) {
      B.buildSub(DstReg, B.buildConstant(Ty, 0), Res);
      return;
    }
    // Shift the result.
    if (TrailingZeroes) {
      B.buildShl(DstReg, Res, B.buildConstant(LLT::scalar(64), TrailingZeroes));
      return;
    }
    B.buildCopy(DstReg, Res.getReg(0));
  };
  return true;
}

bool applyAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  B.setInstrAndDebugLoc(MI);
  ApplyFn(B, MI.getOperand(0).getReg());
  MI.eraseFromParent();
  return true;
}

/// Try to fold a G_MERGE_VALUES of 2 s32 sources, where the second source
/// is a zero, into a G_ZEXT of the first.
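///
/// That is:
///   %d(s64) = G_MERGE_VALUES %a(s32), 0(s32)
///  ->
///   %d(s64) = G_ZEXT %a(s32)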
bool matchFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI) {
  auto &Merge = cast<GMerge>(MI);
  LLT SrcTy = MRI.getType(Merge.getSourceReg(0));
  if (SrcTy != LLT::scalar(32) || Merge.getNumSources() != 2)
    return false;
  return mi_match(Merge.getSourceReg(1), MRI, m_SpecificICst(0));
}

void applyFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI,
                          MachineIRBuilder &B, GISelChangeObserver &Observer) {
  // Mutate %d(s64) = G_MERGE_VALUES %a(s32), 0(s32)
  //  ->
  // %d(s64) = G_ZEXT %a(s32)
  Observer.changingInstr(MI);
  MI.setDesc(B.getTII().get(TargetOpcode::G_ZEXT));
  MI.RemoveOperand(2);
  Observer.changedInstr(MI);
}

/// \returns True if a G_ANYEXT instruction \p MI should be mutated to a G_ZEXT
/// instruction.
static bool matchMutateAnyExtToZExt(MachineInstr &MI,
                                    MachineRegisterInfo &MRI) {
  // If this is coming from a scalar compare then we can use a G_ZEXT instead
  // of a G_ANYEXT:
  //
  // %cmp:_(s32) = G_[I|F]CMP ... <-- produces 0/1.
  // %ext:_(s64) = G_ANYEXT %cmp(s32)
  //
  // By doing this, we can leverage more KnownBits combines.
  assert(MI.getOpcode() == TargetOpcode::G_ANYEXT);
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  return MRI.getType(Dst).isScalar() &&
         mi_match(Src, MRI,
                  m_any_of(m_GICmp(m_Pred(), m_Reg(), m_Reg()),
                           m_GFCmp(m_Pred(), m_Reg(), m_Reg())));
}

static void applyMutateAnyExtToZExt(MachineInstr &MI, MachineRegisterInfo &MRI,
                                    MachineIRBuilder &B,
                                    GISelChangeObserver &Observer) {
  Observer.changingInstr(MI);
  MI.setDesc(B.getTII().get(TargetOpcode::G_ZEXT));
  Observer.changedInstr(MI);
}

/// Match a 128-bit store of zero and split it into two 64-bit stores, for
/// size/performance reasons.
static bool matchSplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI) {
  GStore &Store = cast<GStore>(MI);
  if (!Store.isSimple())
    return false;
  LLT ValTy = MRI.getType(Store.getValueReg());
  if (!ValTy.isVector() || ValTy.getSizeInBits() != 128)
    return false;
  if (ValTy.getSizeInBits() != Store.getMemSizeInBits())
    return false; // Don't split truncating stores.
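  // Only act if the stored zero vector has no other non-debug users; if it
  // does, it presumably has to be materialized anyway, and splitting this
  // store is less likely to pay off.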
  if (!MRI.hasOneNonDBGUse(Store.getValueReg()))
    return false;
  auto MaybeCst = isConstantOrConstantSplatVector(
      *MRI.getVRegDef(Store.getValueReg()), MRI);
  return MaybeCst && MaybeCst->isZero();
}

static void applySplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI,
                                   MachineIRBuilder &B,
                                   GISelChangeObserver &Observer) {
  B.setInstrAndDebugLoc(MI);
  GStore &Store = cast<GStore>(MI);
  assert(MRI.getType(Store.getValueReg()).isVector() &&
         "Expected a vector store value");
  LLT NewTy = LLT::scalar(64);
  Register PtrReg = Store.getPointerReg();
  auto Zero = B.buildConstant(NewTy, 0);
  auto HighPtr = B.buildPtrAdd(MRI.getType(PtrReg), PtrReg,
                               B.buildConstant(LLT::scalar(64), 8));
  auto &MF = *MI.getMF();
  auto *LowMMO = MF.getMachineMemOperand(&Store.getMMO(), 0, NewTy);
  auto *HighMMO = MF.getMachineMemOperand(&Store.getMMO(), 8, NewTy);
  B.buildStore(Zero, PtrReg, *LowMMO);
  B.buildStore(Zero, HighPtr, *HighMMO);
  Store.eraseFromParent();
}

#define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS

namespace {
#define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H

class AArch64PostLegalizerCombinerInfo : public CombinerInfo {
  GISelKnownBits *KB;
  MachineDominatorTree *MDT;

public:
  AArch64GenPostLegalizerCombinerHelperRuleConfig GeneratedRuleCfg;

  AArch64PostLegalizerCombinerInfo(bool EnableOpt, bool OptSize, bool MinSize,
                                   GISelKnownBits *KB,
                                   MachineDominatorTree *MDT)
      : CombinerInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
                     /*LegalizerInfo*/ nullptr, EnableOpt, OptSize, MinSize),
        KB(KB), MDT(MDT) {
    if (!GeneratedRuleCfg.parseCommandLineOption())
      report_fatal_error("Invalid rule identifier");
  }

  virtual bool combine(GISelChangeObserver &Observer, MachineInstr &MI,
                       MachineIRBuilder &B) const override;
};

bool AArch64PostLegalizerCombinerInfo::combine(GISelChangeObserver &Observer,
                                               MachineInstr &MI,
                                               MachineIRBuilder &B) const {
  const auto *LI =
      MI.getParent()->getParent()->getSubtarget().getLegalizerInfo();
  CombinerHelper Helper(Observer, B, KB, MDT, LI);
  AArch64GenPostLegalizerCombinerHelper Generated(GeneratedRuleCfg);
  return Generated.tryCombineAll(Observer, MI, B, Helper);
}

#define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP

class AArch64PostLegalizerCombiner : public MachineFunctionPass {
public:
  static char ID;

  AArch64PostLegalizerCombiner(bool IsOptNone = false);

  StringRef getPassName() const override {
    return "AArch64PostLegalizerCombiner";
  }

  bool runOnMachineFunction(MachineFunction &MF) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;

private:
  bool IsOptNone;
};
} // end anonymous namespace

void AArch64PostLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.setPreservesCFG();
  getSelectionDAGFallbackAnalysisUsage(AU);
  AU.addRequired<GISelKnownBitsAnalysis>();
  AU.addPreserved<GISelKnownBitsAnalysis>();
  if (!IsOptNone) {
    AU.addRequired<MachineDominatorTree>();
    AU.addPreserved<MachineDominatorTree>();
  }
  // The CSE wrapper is fetched unconditionally in runOnMachineFunction, so it
  // must be required unconditionally here.
  AU.addRequired<GISelCSEAnalysisWrapperPass>();
  AU.addPreserved<GISelCSEAnalysisWrapperPass>();
  MachineFunctionPass::getAnalysisUsage(AU);
}

AArch64PostLegalizerCombiner::AArch64PostLegalizerCombiner(bool IsOptNone)
    : MachineFunctionPass(ID), IsOptNone(IsOptNone) {
  initializeAArch64PostLegalizerCombinerPass(*PassRegistry::getPassRegistry());
}

bool AArch64PostLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
  if (MF.getProperties().hasProperty(
          MachineFunctionProperties::Property::FailedISel))
    return false;
  assert(MF.getProperties().hasProperty(
             MachineFunctionProperties::Property::Legalized) &&
         "Expected a legalized function?");
  auto *TPC = &getAnalysis<TargetPassConfig>();
  const Function &F = MF.getFunction();
  bool EnableOpt =
      MF.getTarget().getOptLevel() != CodeGenOpt::None && !skipFunction(F);
  GISelKnownBits *KB = &getAnalysis<GISelKnownBitsAnalysis>().get(MF);
  MachineDominatorTree *MDT =
      IsOptNone ? nullptr : &getAnalysis<MachineDominatorTree>();
  AArch64PostLegalizerCombinerInfo PCInfo(EnableOpt, F.hasOptSize(),
                                          F.hasMinSize(), KB, MDT);
  GISelCSEAnalysisWrapper &Wrapper =
      getAnalysis<GISelCSEAnalysisWrapperPass>().getCSEWrapper();
  auto *CSEInfo = &Wrapper.get(TPC->getCSEConfig());
  Combiner C(PCInfo, TPC);
  return C.combineMachineInstrs(MF, CSEInfo);
}

char AArch64PostLegalizerCombiner::ID = 0;
INITIALIZE_PASS_BEGIN(AArch64PostLegalizerCombiner, DEBUG_TYPE,
                      "Combine AArch64 MachineInstrs after legalization", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(GISelKnownBitsAnalysis)
INITIALIZE_PASS_END(AArch64PostLegalizerCombiner, DEBUG_TYPE,
                    "Combine AArch64 MachineInstrs after legalization", false,
                    false)

namespace llvm {
FunctionPass *createAArch64PostLegalizerCombiner(bool IsOptNone) {
  return new AArch64PostLegalizerCombiner(IsOptNone);
}
} // end namespace llvm