//=== AArch64PostLegalizerCombiner.cpp --------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// Post-legalization combines on generic MachineInstrs.
///
/// The combines here must preserve instruction legality.
///
/// Lowering combines (e.g. pseudo matching) should be handled by
/// AArch64PostLegalizerLowering.
///
/// Combines which don't rely on instruction legality should go in the
/// AArch64PreLegalizerCombiner.
///
//===----------------------------------------------------------------------===//

#include "AArch64TargetMachine.h"
#include "llvm/CodeGen/GlobalISel/Combiner.h"
#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
#include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "aarch64-postlegalizer-combiner"

using namespace llvm;

/// This combine tries to do what performExtractVectorEltCombine does in SDAG.
/// Rewrite for the pairwise fadd pattern:
///   (s32 (g_extract_vector_elt
///           (g_fadd (vXs32 Other)
///                   (g_vector_shuffle (vXs32 Other) undef <1,X,...> )) 0))
/// ->
///   (s32 (g_fadd (g_extract_vector_elt (vXs32 Other) 0)
///                (g_extract_vector_elt (vXs32 Other) 1)))
bool matchExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  Register Src1 = MI.getOperand(1).getReg();
  Register Src2 = MI.getOperand(2).getReg();
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  auto Cst = getConstantVRegValWithLookThrough(Src2, MRI);
  if (!Cst || Cst->Value != 0)
    return false;
  // SDAG also checks for FullFP16, but this looks to be beneficial anyway.

  // Now check for an fadd operation. TODO: expand this for integer add?
  auto *FAddMI = getOpcodeDef(TargetOpcode::G_FADD, Src1, MRI);
  if (!FAddMI)
    return false;

  // If we add support for integer add, we must restrict these types to just
  // s64.
  unsigned DstSize = DstTy.getSizeInBits();
  if (DstSize != 16 && DstSize != 32 && DstSize != 64)
    return false;

  Register Src1Op1 = FAddMI->getOperand(1).getReg();
  Register Src1Op2 = FAddMI->getOperand(2).getReg();
  MachineInstr *Shuffle =
      getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op2, MRI);
  MachineInstr *Other = MRI.getVRegDef(Src1Op1);
  if (!Shuffle) {
    Shuffle = getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op1, MRI);
    Other = MRI.getVRegDef(Src1Op2);
  }

  // We're looking for a shuffle that moves the second element to index 0.
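  // Operand 3 of G_SHUFFLE_VECTOR is the mask; mask[0] == 1 means the shuffle
  // places element 1 of its first input at index 0. We also require that first
  // input to be the same definition as the other G_FADD operand, so both lanes
  // being added come from a single source vector.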
  if (Shuffle && Shuffle->getOperand(3).getShuffleMask()[0] == 1 &&
      Other == MRI.getVRegDef(Shuffle->getOperand(1).getReg())) {
    std::get<0>(MatchInfo) = TargetOpcode::G_FADD;
    std::get<1>(MatchInfo) = DstTy;
    std::get<2>(MatchInfo) = Other->getOperand(0).getReg();
    return true;
  }
  return false;
}

bool applyExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  unsigned Opc = std::get<0>(MatchInfo);
  assert(Opc == TargetOpcode::G_FADD && "Unexpected opcode!");
  // We want to generate two extracts of elements 0 and 1, and add them.
  LLT Ty = std::get<1>(MatchInfo);
  Register Src = std::get<2>(MatchInfo);
  LLT s64 = LLT::scalar(64);
  B.setInstrAndDebugLoc(MI);
  auto Elt0 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 0));
  auto Elt1 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 1));
  B.buildInstr(Opc, {MI.getOperand(0).getReg()}, {Elt0, Elt1});
  MI.eraseFromParent();
  return true;
}

static bool isSignExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: check if extended build vector as well.
  unsigned Opc = MRI.getVRegDef(R)->getOpcode();
  return Opc == TargetOpcode::G_SEXT || Opc == TargetOpcode::G_SEXT_INREG;
}

static bool isZeroExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: check if extended build vector as well.
  return MRI.getVRegDef(R)->getOpcode() == TargetOpcode::G_ZEXT;
}

bool matchAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  assert(MI.getOpcode() == TargetOpcode::G_MUL);
  Register LHS = MI.getOperand(1).getReg();
  Register RHS = MI.getOperand(2).getReg();
  Register Dst = MI.getOperand(0).getReg();
  const LLT Ty = MRI.getType(LHS);

  // The below optimizations require a constant RHS.
  auto Const = getConstantVRegValWithLookThrough(RHS, MRI);
  if (!Const)
    return false;

  const APInt ConstValue = Const->Value.sextOrSelf(Ty.getSizeInBits());
  // The following code is ported from AArch64ISelLowering.
  // Multiplication of a power of two plus/minus one can be done more
  // cheaply as a shift+add/sub. For now, this is true unilaterally. If
  // future CPUs have a cheaper MADD instruction, this may need to be
  // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and
  // 64-bit is 5 cycles, so this is always a win.
  // More aggressively, some multiplications N0 * C can be lowered to
  // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
  // e.g. 6=3*2=(2+1)*2.
  // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45
  // which equals (1+2)*16-(1+2).
  // TrailingZeroes is used to test if the mul can be lowered to
  // shift+add+shift.
  unsigned TrailingZeroes = ConstValue.countTrailingZeros();
  if (TrailingZeroes) {
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into smul or umul.
    if (MRI.hasOneNonDBGUse(LHS) &&
        (isSignExtended(LHS, MRI) || isZeroExtended(LHS, MRI)))
      return false;
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into madd or msub.
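    // A single G_ADD or G_SUB user can absorb the multiply as a MADD/MSUB
    // during instruction selection, which the shift expansion would defeat.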
    if (MRI.hasOneNonDBGUse(Dst)) {
      MachineInstr &UseMI = *MRI.use_instr_begin(Dst);
      if (UseMI.getOpcode() == TargetOpcode::G_ADD ||
          UseMI.getOpcode() == TargetOpcode::G_SUB)
        return false;
    }
  }
  // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub
  // and shift+add+shift.
  APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);

  unsigned ShiftAmt, AddSubOpc;
  // Is the shifted value the LHS operand of the add/sub?
  bool ShiftValUseIsLHS = true;
  // Do we need to negate the result?
  bool NegateResult = false;

  if (ConstValue.isNonNegative()) {
    // (mul x, 2^N + 1) => (add (shl x, N), x)
    // (mul x, 2^N - 1) => (sub (shl x, N), x)
    // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
    APInt SCVMinus1 = ShiftedConstValue - 1;
    APInt CVPlus1 = ConstValue + 1;
    if (SCVMinus1.isPowerOf2()) {
      ShiftAmt = SCVMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
    } else if (CVPlus1.isPowerOf2()) {
      ShiftAmt = CVPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
    } else
      return false;
  } else {
    // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
    // (mul x, -(2^N + 1)) => - (add (shl x, N), x)
    APInt CVNegPlus1 = -ConstValue + 1;
    APInt CVNegMinus1 = -ConstValue - 1;
    if (CVNegPlus1.isPowerOf2()) {
      ShiftAmt = CVNegPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
      ShiftValUseIsLHS = false;
    } else if (CVNegMinus1.isPowerOf2()) {
      ShiftAmt = CVNegMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
      NegateResult = true;
    } else
      return false;
  }

  if (NegateResult && TrailingZeroes)
    return false;

  ApplyFn = [=](MachineIRBuilder &B, Register DstReg) {
    auto Shift = B.buildConstant(LLT::scalar(64), ShiftAmt);
    auto ShiftedVal = B.buildShl(Ty, LHS, Shift);

    Register AddSubLHS = ShiftValUseIsLHS ? ShiftedVal.getReg(0) : LHS;
    Register AddSubRHS = ShiftValUseIsLHS ? LHS : ShiftedVal.getReg(0);
    auto Res = B.buildInstr(AddSubOpc, {Ty}, {AddSubLHS, AddSubRHS});
    assert(!(NegateResult && TrailingZeroes) &&
           "NegateResult and TrailingZeroes cannot both be true for now.");
    // Negate the result.
    if (NegateResult) {
      B.buildSub(DstReg, B.buildConstant(Ty, 0), Res);
      return;
    }
    // Shift the result.
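    // Reapply the power-of-two factor that was stripped off above:
    // (x * (C >> M)) << M == x * C when C has M trailing zero bits.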
    if (TrailingZeroes) {
      B.buildShl(DstReg, Res,
                 B.buildConstant(LLT::scalar(64), TrailingZeroes));
      return;
    }
    B.buildCopy(DstReg, Res.getReg(0));
  };
  return true;
}

bool applyAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  B.setInstrAndDebugLoc(MI);
  ApplyFn(B, MI.getOperand(0).getReg());
  MI.eraseFromParent();
  return true;
}

#define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS

namespace {
#define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H

class AArch64PostLegalizerCombinerInfo : public CombinerInfo {
  GISelKnownBits *KB;
  MachineDominatorTree *MDT;

public:
  AArch64GenPostLegalizerCombinerHelperRuleConfig GeneratedRuleCfg;

  AArch64PostLegalizerCombinerInfo(bool EnableOpt, bool OptSize, bool MinSize,
                                   GISelKnownBits *KB,
                                   MachineDominatorTree *MDT)
      : CombinerInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
                     /*LegalizerInfo*/ nullptr, EnableOpt, OptSize, MinSize),
        KB(KB), MDT(MDT) {
    if (!GeneratedRuleCfg.parseCommandLineOption())
      report_fatal_error("Invalid rule identifier");
  }

  virtual bool combine(GISelChangeObserver &Observer, MachineInstr &MI,
                       MachineIRBuilder &B) const override;
};

bool AArch64PostLegalizerCombinerInfo::combine(GISelChangeObserver &Observer,
                                               MachineInstr &MI,
                                               MachineIRBuilder &B) const {
  const auto *LI =
      MI.getParent()->getParent()->getSubtarget().getLegalizerInfo();
  CombinerHelper Helper(Observer, B, KB, MDT, LI);
  AArch64GenPostLegalizerCombinerHelper Generated(GeneratedRuleCfg);
  return Generated.tryCombineAll(Observer, MI, B, Helper);
}

#define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP

class AArch64PostLegalizerCombiner : public MachineFunctionPass {
public:
  static char ID;

  AArch64PostLegalizerCombiner(bool IsOptNone = false);

  StringRef getPassName() const override {
    return "AArch64PostLegalizerCombiner";
  }

  bool runOnMachineFunction(MachineFunction &MF) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;

private:
  bool IsOptNone;
};
} // end anonymous namespace

void AArch64PostLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.setPreservesCFG();
  getSelectionDAGFallbackAnalysisUsage(AU);
  AU.addRequired<GISelKnownBitsAnalysis>();
  AU.addPreserved<GISelKnownBitsAnalysis>();
  if (!IsOptNone) {
    AU.addRequired<MachineDominatorTree>();
    AU.addPreserved<MachineDominatorTree>();
  }
  MachineFunctionPass::getAnalysisUsage(AU);
}

AArch64PostLegalizerCombiner::AArch64PostLegalizerCombiner(bool IsOptNone)
    : MachineFunctionPass(ID), IsOptNone(IsOptNone) {
  initializeAArch64PostLegalizerCombinerPass(*PassRegistry::getPassRegistry());
}

bool AArch64PostLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
  if (MF.getProperties().hasProperty(
          MachineFunctionProperties::Property::FailedISel))
    return false;
  assert(MF.getProperties().hasProperty(
             MachineFunctionProperties::Property::Legalized) &&
         "Expected a legalized function?");
  auto *TPC = &getAnalysis<TargetPassConfig>();
  const Function &F = MF.getFunction();
  bool EnableOpt =
      MF.getTarget().getOptLevel() != CodeGenOpt::None && !skipFunction(F);
  GISelKnownBits *KB = &getAnalysis<GISelKnownBitsAnalysis>().get(MF);
  MachineDominatorTree *MDT =
      IsOptNone ? nullptr : &getAnalysis<MachineDominatorTree>();
  AArch64PostLegalizerCombinerInfo PCInfo(EnableOpt, F.hasOptSize(),
                                          F.hasMinSize(), KB, MDT);
  Combiner C(PCInfo, TPC);
  return C.combineMachineInstrs(MF, /*CSEInfo*/ nullptr);
}

char AArch64PostLegalizerCombiner::ID = 0;
INITIALIZE_PASS_BEGIN(AArch64PostLegalizerCombiner, DEBUG_TYPE,
                      "Combine AArch64 MachineInstrs after legalization", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(GISelKnownBitsAnalysis)
INITIALIZE_PASS_END(AArch64PostLegalizerCombiner, DEBUG_TYPE,
                    "Combine AArch64 MachineInstrs after legalization", false,
                    false)

namespace llvm {
FunctionPass *createAArch64PostLegalizerCombiner(bool IsOptNone) {
  return new AArch64PostLegalizerCombiner(IsOptNone);
}
} // end namespace llvm