//=== AArch64PostLegalizerCombiner.cpp --------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// Post-legalization combines on generic MachineInstrs.
///
/// The combines here must preserve instruction legality.
///
/// Lowering combines (e.g. pseudo matching) should be handled by
/// AArch64PostLegalizerLowering.
///
/// Combines which don't rely on instruction legality should go in the
/// AArch64PreLegalizerCombiner.
///
//===----------------------------------------------------------------------===//

#include "AArch64TargetMachine.h"
#include "llvm/CodeGen/GlobalISel/Combiner.h"
#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
#include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "aarch64-postlegalizer-combiner"

using namespace llvm;
using namespace MIPatternMatch;

/// This combine tries to do what performExtractVectorEltCombine does in SDAG.
/// Rewrite for pairwise fadd pattern
///   (s32 (g_extract_vector_elt
///           (g_fadd (vXs32 Other)
///                   (g_vector_shuffle (vXs32 Other) undef <1,X,...>)) 0))
/// ->
///   (s32 (g_fadd (g_extract_vector_elt (vXs32 Other) 0)
///                (g_extract_vector_elt (vXs32 Other) 1)))
bool matchExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  Register Src1 = MI.getOperand(1).getReg();
  Register Src2 = MI.getOperand(2).getReg();
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  auto Cst = getConstantVRegValWithLookThrough(Src2, MRI);
  if (!Cst || Cst->Value != 0)
    return false;
  // SDAG also checks for FullFP16, but this looks to be beneficial anyway.

  // Now check for an fadd operation. TODO: expand this for integer add?
  auto *FAddMI = getOpcodeDef(TargetOpcode::G_FADD, Src1, MRI);
  if (!FAddMI)
    return false;

  // If we add support for integer add, we must restrict these types to just
  // s64.
  unsigned DstSize = DstTy.getSizeInBits();
  if (DstSize != 16 && DstSize != 32 && DstSize != 64)
    return false;

  Register Src1Op1 = FAddMI->getOperand(1).getReg();
  Register Src1Op2 = FAddMI->getOperand(2).getReg();
  MachineInstr *Shuffle =
      getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op2, MRI);
  MachineInstr *Other = MRI.getVRegDef(Src1Op1);
  if (!Shuffle) {
    Shuffle = getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op1, MRI);
    Other = MRI.getVRegDef(Src1Op2);
  }

  // We're looking for a shuffle that moves the second element to index 0.
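  // E.g. for <2 x s32>, a mask of <1,u> puts Other[1] into lane 0 of the
  // shuffle, so lane 0 of the fadd is Other[0] + Other[1]; extracting
  // element 0 therefore yields the pairwise sum.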
  if (Shuffle && Shuffle->getOperand(3).getShuffleMask()[0] == 1 &&
      Other == MRI.getVRegDef(Shuffle->getOperand(1).getReg())) {
    std::get<0>(MatchInfo) = TargetOpcode::G_FADD;
    std::get<1>(MatchInfo) = DstTy;
    std::get<2>(MatchInfo) = Other->getOperand(0).getReg();
    return true;
  }
  return false;
}

bool applyExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  unsigned Opc = std::get<0>(MatchInfo);
  assert(Opc == TargetOpcode::G_FADD && "Unexpected opcode!");
  // We want to generate two extracts of elements 0 and 1, and add them.
  LLT Ty = std::get<1>(MatchInfo);
  Register Src = std::get<2>(MatchInfo);
  LLT s64 = LLT::scalar(64);
  B.setInstrAndDebugLoc(MI);
  auto Elt0 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 0));
  auto Elt1 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 1));
  B.buildInstr(Opc, {MI.getOperand(0).getReg()}, {Elt0, Elt1});
  MI.eraseFromParent();
  return true;
}

static bool isSignExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: check if extended build vector as well.
  unsigned Opc = MRI.getVRegDef(R)->getOpcode();
  return Opc == TargetOpcode::G_SEXT || Opc == TargetOpcode::G_SEXT_INREG;
}

static bool isZeroExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: check if extended build vector as well.
  return MRI.getVRegDef(R)->getOpcode() == TargetOpcode::G_ZEXT;
}

bool matchAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  assert(MI.getOpcode() == TargetOpcode::G_MUL);
  Register LHS = MI.getOperand(1).getReg();
  Register RHS = MI.getOperand(2).getReg();
  Register Dst = MI.getOperand(0).getReg();
  const LLT Ty = MRI.getType(LHS);

  // The below optimizations require a constant RHS.
  auto Const = getConstantVRegValWithLookThrough(RHS, MRI);
  if (!Const)
    return false;

  const APInt ConstValue = Const->Value.sextOrSelf(Ty.getSizeInBits());
  // The following code is ported from AArch64ISelLowering.
  // Multiplication of a power of two plus/minus one can be done more
  // cheaply as a shift+add/sub. For now, this is true unilaterally. If
  // future CPUs have a cheaper MADD instruction, this may need to be
  // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and
  // 64-bit is 5 cycles, so this is always a win.
  // More aggressively, some multiplications N0 * C can be lowered to
  // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
  // e.g. 6 = 3 * 2 = (2 + 1) * 2.
  // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45
  // which equals (1 + 2) * 16 - (1 + 2).
  // TrailingZeroes is used to test if the mul can be lowered to
  // shift+add+shift.
  unsigned TrailingZeroes = ConstValue.countTrailingZeros();
  if (TrailingZeroes) {
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into smul or umul.
    if (MRI.hasOneNonDBGUse(LHS) &&
        (isSignExtended(LHS, MRI) || isZeroExtended(LHS, MRI)))
      return false;
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into madd or msub.
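    // (AArch64's MADD/MSUB compute Ra +/- Rn * Rm in a single instruction, so
    // a mul whose only user is an add or sub is better left for the selector.)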
    if (MRI.hasOneNonDBGUse(Dst)) {
      MachineInstr &UseMI = *MRI.use_instr_begin(Dst);
      unsigned UseOpc = UseMI.getOpcode();
      if (UseOpc == TargetOpcode::G_ADD || UseOpc == TargetOpcode::G_PTR_ADD ||
          UseOpc == TargetOpcode::G_SUB)
        return false;
    }
  }
  // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub
  // and shift+add+shift.
  APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);

  unsigned ShiftAmt, AddSubOpc;
  // Is the shifted value the LHS operand of the add/sub?
  bool ShiftValUseIsLHS = true;
  // Do we need to negate the result?
  bool NegateResult = false;

  if (ConstValue.isNonNegative()) {
    // (mul x, 2^N + 1) => (add (shl x, N), x)
    // (mul x, 2^N - 1) => (sub (shl x, N), x)
    // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
    APInt SCVMinus1 = ShiftedConstValue - 1;
    APInt CVPlus1 = ConstValue + 1;
    if (SCVMinus1.isPowerOf2()) {
      ShiftAmt = SCVMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
    } else if (CVPlus1.isPowerOf2()) {
      ShiftAmt = CVPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
    } else
      return false;
  } else {
    // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
    // (mul x, -(2^N + 1)) => - (add (shl x, N), x)
    APInt CVNegPlus1 = -ConstValue + 1;
    APInt CVNegMinus1 = -ConstValue - 1;
    if (CVNegPlus1.isPowerOf2()) {
      ShiftAmt = CVNegPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
      ShiftValUseIsLHS = false;
    } else if (CVNegMinus1.isPowerOf2()) {
      ShiftAmt = CVNegMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
      NegateResult = true;
    } else
      return false;
  }

  if (NegateResult && TrailingZeroes)
    return false;

  ApplyFn = [=](MachineIRBuilder &B, Register DstReg) {
    auto Shift = B.buildConstant(LLT::scalar(64), ShiftAmt);
    auto ShiftedVal = B.buildShl(Ty, LHS, Shift);

    Register AddSubLHS = ShiftValUseIsLHS ? ShiftedVal.getReg(0) : LHS;
    Register AddSubRHS = ShiftValUseIsLHS ? LHS : ShiftedVal.getReg(0);
    auto Res = B.buildInstr(AddSubOpc, {Ty}, {AddSubLHS, AddSubRHS});
    assert(!(NegateResult && TrailingZeroes) &&
           "NegateResult and TrailingZeroes cannot both be true for now.");
    // Negate the result.
    if (NegateResult) {
      B.buildSub(DstReg, B.buildConstant(Ty, 0), Res);
      return;
    }
    // Shift the result.
    if (TrailingZeroes) {
      B.buildShl(DstReg, Res, B.buildConstant(LLT::scalar(64), TrailingZeroes));
      return;
    }
    B.buildCopy(DstReg, Res.getReg(0));
  };
  return true;
}

bool applyAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  B.setInstrAndDebugLoc(MI);
  ApplyFn(B, MI.getOperand(0).getReg());
  MI.eraseFromParent();
  return true;
}

/// Try to fold a G_MERGE_VALUES of 2 s32 sources, where the second source
/// is a zero, into a G_ZEXT of the first.
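///
/// e.g. %z:_(s32) = G_CONSTANT i32 0
///      %d:_(s64) = G_MERGE_VALUES %a:_(s32), %z:_(s32)
///   becomes
///      %d:_(s64) = G_ZEXT %a:_(s32)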
bool matchFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI) {
  auto &Merge = cast<GMerge>(MI);
  LLT SrcTy = MRI.getType(Merge.getSourceReg(0));
  if (SrcTy != LLT::scalar(32) || Merge.getNumSources() != 2)
    return false;
  return mi_match(Merge.getSourceReg(1), MRI, m_SpecificICst(0));
}

void applyFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI,
                          MachineIRBuilder &B, GISelChangeObserver &Observer) {
  // Mutate %d(s64) = G_MERGE_VALUES %a(s32), 0(s32)
  //  ->
  //  %d(s64) = G_ZEXT %a(s32)
  Observer.changingInstr(MI);
  MI.setDesc(B.getTII().get(TargetOpcode::G_ZEXT));
  MI.RemoveOperand(2);
  Observer.changedInstr(MI);
}

#define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS

namespace {
#define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H

class AArch64PostLegalizerCombinerInfo : public CombinerInfo {
  GISelKnownBits *KB;
  MachineDominatorTree *MDT;

public:
  AArch64GenPostLegalizerCombinerHelperRuleConfig GeneratedRuleCfg;

  AArch64PostLegalizerCombinerInfo(bool EnableOpt, bool OptSize, bool MinSize,
                                   GISelKnownBits *KB,
                                   MachineDominatorTree *MDT)
      : CombinerInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
                     /*LegalizerInfo*/ nullptr, EnableOpt, OptSize, MinSize),
        KB(KB), MDT(MDT) {
    if (!GeneratedRuleCfg.parseCommandLineOption())
      report_fatal_error("Invalid rule identifier");
  }

  virtual bool combine(GISelChangeObserver &Observer, MachineInstr &MI,
                       MachineIRBuilder &B) const override;
};

bool AArch64PostLegalizerCombinerInfo::combine(GISelChangeObserver &Observer,
                                               MachineInstr &MI,
                                               MachineIRBuilder &B) const {
  const auto *LI =
      MI.getParent()->getParent()->getSubtarget().getLegalizerInfo();
  CombinerHelper Helper(Observer, B, KB, MDT, LI);
  AArch64GenPostLegalizerCombinerHelper Generated(GeneratedRuleCfg);
  return Generated.tryCombineAll(Observer, MI, B, Helper);
}

#define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP

class AArch64PostLegalizerCombiner : public MachineFunctionPass {
public:
  static char ID;

  AArch64PostLegalizerCombiner(bool IsOptNone = false);

  StringRef getPassName() const override {
    return "AArch64PostLegalizerCombiner";
  }

  bool runOnMachineFunction(MachineFunction &MF) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;

private:
  bool IsOptNone;
};
} // end anonymous namespace

void AArch64PostLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.setPreservesCFG();
  getSelectionDAGFallbackAnalysisUsage(AU);
  AU.addRequired<GISelKnownBitsAnalysis>();
  AU.addPreserved<GISelKnownBitsAnalysis>();
  // The CSE wrapper is used unconditionally in runOnMachineFunction, so it
  // must be required at every opt level.
  AU.addRequired<GISelCSEAnalysisWrapperPass>();
  AU.addPreserved<GISelCSEAnalysisWrapperPass>();
  if (!IsOptNone) {
    AU.addRequired<MachineDominatorTree>();
    AU.addPreserved<MachineDominatorTree>();
  }
  MachineFunctionPass::getAnalysisUsage(AU);
}

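// IsOptNone mirrors the codegen opt level: at -O0 no MachineDominatorTree is
// required (see getAnalysisUsage above), and the combiner runs with a null
// MDT, disabling dominance-based combines.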
AArch64PostLegalizerCombiner::AArch64PostLegalizerCombiner(bool IsOptNone)
    : MachineFunctionPass(ID), IsOptNone(IsOptNone) {
  initializeAArch64PostLegalizerCombinerPass(*PassRegistry::getPassRegistry());
}

bool AArch64PostLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
  if (MF.getProperties().hasProperty(
          MachineFunctionProperties::Property::FailedISel))
    return false;
  assert(MF.getProperties().hasProperty(
             MachineFunctionProperties::Property::Legalized) &&
         "Expected a legalized function?");
  auto *TPC = &getAnalysis<TargetPassConfig>();
  const Function &F = MF.getFunction();
  bool EnableOpt =
      MF.getTarget().getOptLevel() != CodeGenOpt::None && !skipFunction(F);
  GISelKnownBits *KB = &getAnalysis<GISelKnownBitsAnalysis>().get(MF);
  MachineDominatorTree *MDT =
      IsOptNone ? nullptr : &getAnalysis<MachineDominatorTree>();
  AArch64PostLegalizerCombinerInfo PCInfo(EnableOpt, F.hasOptSize(),
                                          F.hasMinSize(), KB, MDT);
  GISelCSEAnalysisWrapper &Wrapper =
      getAnalysis<GISelCSEAnalysisWrapperPass>().getCSEWrapper();
  auto *CSEInfo = &Wrapper.get(TPC->getCSEConfig());
  Combiner C(PCInfo, TPC);
  return C.combineMachineInstrs(MF, CSEInfo);
}

char AArch64PostLegalizerCombiner::ID = 0;
INITIALIZE_PASS_BEGIN(AArch64PostLegalizerCombiner, DEBUG_TYPE,
                      "Combine AArch64 MachineInstrs after legalization", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(GISelKnownBitsAnalysis)
INITIALIZE_PASS_END(AArch64PostLegalizerCombiner, DEBUG_TYPE,
                    "Combine AArch64 MachineInstrs after legalization", false,
                    false)

namespace llvm {
FunctionPass *createAArch64PostLegalizerCombiner(bool IsOptNone) {
  return new AArch64PostLegalizerCombiner(IsOptNone);
}
} // end namespace llvm