//=== AArch64PostLegalizerCombiner.cpp --------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// Post-legalization combines on generic MachineInstrs.
///
/// The combines here must preserve instruction legality.
///
/// Lowering combines (e.g. pseudo matching) should be handled by
/// AArch64PostLegalizerLowering.
///
/// Combines which don't rely on instruction legality should go in the
/// AArch64PreLegalizerCombiner.
///
//===----------------------------------------------------------------------===//

#include "AArch64TargetMachine.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/CodeGen/GlobalISel/CSEInfo.h"
#include "llvm/CodeGen/GlobalISel/CSEMIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Combiner.h"
#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
#include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h"
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/Support/Debug.h"

#define GET_GICOMBINER_DEPS
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef GET_GICOMBINER_DEPS

#define DEBUG_TYPE "aarch64-postlegalizer-combiner"

using namespace llvm;
using namespace MIPatternMatch;

namespace {

#define GET_GICOMBINER_TYPES
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef GET_GICOMBINER_TYPES

/// This combine tries to do what performExtractVectorEltCombine does in SDAG.
/// Rewrite for pairwise fadd pattern
///   (s32 (g_extract_vector_elt
///           (g_fadd (vXs32 Other)
///                  (g_vector_shuffle (vXs32 Other) undef <1,X,...> )) 0))
/// ->
///   (s32 (g_fadd (g_extract_vector_elt (vXs32 Other) 0)
///              (g_extract_vector_elt (vXs32 Other) 1))
bool matchExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  Register Src1 = MI.getOperand(1).getReg();
  Register Src2 = MI.getOperand(2).getReg();
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  auto Cst = getIConstantVRegValWithLookThrough(Src2, MRI);
  if (!Cst || Cst->Value != 0)
    return false;
  // SDAG also checks for FullFP16, but this looks to be beneficial anyway.

  // Now check for an fadd operation. TODO: expand this for integer add?
  auto *FAddMI = getOpcodeDef(TargetOpcode::G_FADD, Src1, MRI);
  if (!FAddMI)
    return false;

  // If we add support for integer add, must restrict these types to just s64.
  unsigned DstSize = DstTy.getSizeInBits();
  if (DstSize != 16 && DstSize != 32 && DstSize != 64)
    return false;

  Register Src1Op1 = FAddMI->getOperand(1).getReg();
  Register Src1Op2 = FAddMI->getOperand(2).getReg();
  MachineInstr *Shuffle =
      getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op2, MRI);
  MachineInstr *Other = MRI.getVRegDef(Src1Op1);
  if (!Shuffle) {
    Shuffle = getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op1, MRI);
    Other = MRI.getVRegDef(Src1Op2);
  }

  // We're looking for a shuffle that moves the second element to index 0.
  if (Shuffle && Shuffle->getOperand(3).getShuffleMask()[0] == 1 &&
      Other == MRI.getVRegDef(Shuffle->getOperand(1).getReg())) {
    std::get<0>(MatchInfo) = TargetOpcode::G_FADD;
    std::get<1>(MatchInfo) = DstTy;
    std::get<2>(MatchInfo) = Other->getOperand(0).getReg();
    return true;
  }
  return false;
}

void applyExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  unsigned Opc = std::get<0>(MatchInfo);
  assert(Opc == TargetOpcode::G_FADD && "Unexpected opcode!");
  // We want to generate two extracts of elements 0 and 1, and add them.
  LLT Ty = std::get<1>(MatchInfo);
  Register Src = std::get<2>(MatchInfo);
  LLT s64 = LLT::scalar(64);
  B.setInstrAndDebugLoc(MI);
  auto Elt0 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 0));
  auto Elt1 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 1));
  B.buildInstr(Opc, {MI.getOperand(0).getReg()}, {Elt0, Elt1});
  MI.eraseFromParent();
}

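/// \returns true if \p R is defined by a sign-extending operation.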
bool isSignExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: check if extended build vector as well.
  unsigned Opc = MRI.getVRegDef(R)->getOpcode();
  return Opc == TargetOpcode::G_SEXT || Opc == TargetOpcode::G_SEXT_INREG;
}

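/// \returns true if \p R is defined by a zero-extending operation.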
bool isZeroExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: check if extended build vector as well.
  return MRI.getVRegDef(R)->getOpcode() == TargetOpcode::G_ZEXT;
}

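/// Match a G_MUL by a constant that can be rewritten as a cheaper sequence of
/// shifts and adds/subs (ported from the equivalent AArch64ISelLowering
/// combine). On success, \p ApplyFn is populated with a callback that builds
/// the replacement sequence into a given destination register.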
bool matchAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  assert(MI.getOpcode() == TargetOpcode::G_MUL);
  Register LHS = MI.getOperand(1).getReg();
  Register RHS = MI.getOperand(2).getReg();
  Register Dst = MI.getOperand(0).getReg();
  const LLT Ty = MRI.getType(LHS);

  // The below optimizations require a constant RHS.
  auto Const = getIConstantVRegValWithLookThrough(RHS, MRI);
  if (!Const)
    return false;

  APInt ConstValue = Const->Value.sext(Ty.getSizeInBits());
  // The following code is ported from AArch64ISelLowering.
  // Multiplication of a power of two plus/minus one can be done more
  // cheaply as shift+add/sub. For now, this is true unilaterally. If
  // future CPUs have a cheaper MADD instruction, this may need to be
  // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and
  // 64-bit is 5 cycles, so this is always a win.
  // More aggressively, some multiplications N0 * C can be lowered to
  // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
  // e.g. 6=3*2=(2+1)*2.
  // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45,
  // which equals (1+2)*16-(1+2).
  // TrailingZeroes is used to test if the mul can be lowered to
  // shift+add+shift.
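  // For example, C = 6 gives TrailingZeroes = 1 and ShiftedConstValue = 3, so
  // x * 6 is built as (shl (add (shl x, 1), x), 1).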
  unsigned TrailingZeroes = ConstValue.countr_zero();
  if (TrailingZeroes) {
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into smul or umul.
    if (MRI.hasOneNonDBGUse(LHS) &&
        (isSignExtended(LHS, MRI) || isZeroExtended(LHS, MRI)))
      return false;
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into madd or msub.
    if (MRI.hasOneNonDBGUse(Dst)) {
      MachineInstr &UseMI = *MRI.use_instr_begin(Dst);
      unsigned UseOpc = UseMI.getOpcode();
      if (UseOpc == TargetOpcode::G_ADD || UseOpc == TargetOpcode::G_PTR_ADD ||
          UseOpc == TargetOpcode::G_SUB)
        return false;
    }
  }
  // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub
  // and shift+add+shift.
  APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);

  unsigned ShiftAmt, AddSubOpc;
  // Is the shifted value the LHS operand of the add/sub?
  bool ShiftValUseIsLHS = true;
  // Do we need to negate the result?
  bool NegateResult = false;

  if (ConstValue.isNonNegative()) {
    // (mul x, 2^N + 1) => (add (shl x, N), x)
    // (mul x, 2^N - 1) => (sub (shl x, N), x)
    // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
    APInt SCVMinus1 = ShiftedConstValue - 1;
    APInt CVPlus1 = ConstValue + 1;
    if (SCVMinus1.isPowerOf2()) {
      ShiftAmt = SCVMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
    } else if (CVPlus1.isPowerOf2()) {
      ShiftAmt = CVPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
    } else
      return false;
  } else {
    // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
    // (mul x, -(2^N + 1)) => - (add (shl x, N), x)
    APInt CVNegPlus1 = -ConstValue + 1;
    APInt CVNegMinus1 = -ConstValue - 1;
    if (CVNegPlus1.isPowerOf2()) {
      ShiftAmt = CVNegPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
      ShiftValUseIsLHS = false;
    } else if (CVNegMinus1.isPowerOf2()) {
      ShiftAmt = CVNegMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
      NegateResult = true;
    } else
      return false;
  }

  if (NegateResult && TrailingZeroes)
    return false;

  ApplyFn = [=](MachineIRBuilder &B, Register DstReg) {
    auto Shift = B.buildConstant(LLT::scalar(64), ShiftAmt);
    auto ShiftedVal = B.buildShl(Ty, LHS, Shift);

    Register AddSubLHS = ShiftValUseIsLHS ? ShiftedVal.getReg(0) : LHS;
    Register AddSubRHS = ShiftValUseIsLHS ? LHS : ShiftedVal.getReg(0);
    auto Res = B.buildInstr(AddSubOpc, {Ty}, {AddSubLHS, AddSubRHS});
    assert(!(NegateResult && TrailingZeroes) &&
           "NegateResult and TrailingZeroes cannot both be true for now.");
    // Negate the result.
    if (NegateResult) {
      B.buildSub(DstReg, B.buildConstant(Ty, 0), Res);
      return;
    }
    // Shift the result.
    if (TrailingZeroes) {
      B.buildShl(DstReg, Res, B.buildConstant(LLT::scalar(64), TrailingZeroes));
      return;
    }
    B.buildCopy(DstReg, Res.getReg(0));
  };
  return true;
}

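/// Build the replacement sequence computed by matchAArch64MulConstCombine and
/// erase the original G_MUL.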
void applyAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  B.setInstrAndDebugLoc(MI);
  ApplyFn(B, MI.getOperand(0).getReg());
  MI.eraseFromParent();
}

/// Try to fold a G_MERGE_VALUES of 2 s32 sources, where the second source
/// is a zero, into a G_ZEXT of the first.
bool matchFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI) {
  auto &Merge = cast<GMerge>(MI);
  LLT SrcTy = MRI.getType(Merge.getSourceReg(0));
  if (SrcTy != LLT::scalar(32) || Merge.getNumSources() != 2)
    return false;
  return mi_match(Merge.getSourceReg(1), MRI, m_SpecificICst(0));
}

void applyFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI,
                          MachineIRBuilder &B, GISelChangeObserver &Observer) {
  // Mutate %d(s64) = G_MERGE_VALUES %a(s32), 0(s32)
  //  ->
  // %d(s64) = G_ZEXT %a(s32)
  Observer.changingInstr(MI);
  MI.setDesc(B.getTII().get(TargetOpcode::G_ZEXT));
  MI.removeOperand(2);
  Observer.changedInstr(MI);
}

/// \returns True if a G_ANYEXT instruction \p MI should be mutated to a G_ZEXT
/// instruction.
bool matchMutateAnyExtToZExt(MachineInstr &MI, MachineRegisterInfo &MRI) {
  // If this is coming from a scalar compare then we can use a G_ZEXT instead of
  // a G_ANYEXT:
  //
  // %cmp:_(s32) = G_[I|F]CMP ... <-- produces 0/1.
  // %ext:_(s64) = G_ANYEXT %cmp(s32)
  //
  // By doing this, we can leverage more KnownBits combines.
  assert(MI.getOpcode() == TargetOpcode::G_ANYEXT);
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  return MRI.getType(Dst).isScalar() &&
         mi_match(Src, MRI,
                  m_any_of(m_GICmp(m_Pred(), m_Reg(), m_Reg()),
                           m_GFCmp(m_Pred(), m_Reg(), m_Reg())));
}

void applyMutateAnyExtToZExt(MachineInstr &MI, MachineRegisterInfo &MRI,
                             MachineIRBuilder &B,
                             GISelChangeObserver &Observer) {
  Observer.changingInstr(MI);
  MI.setDesc(B.getTII().get(TargetOpcode::G_ZEXT));
  Observer.changedInstr(MI);
}

/// Match a 128b store of zero and split it into two 64 bit stores, for
/// size/performance reasons.
bool matchSplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI) {
  GStore &Store = cast<GStore>(MI);
  if (!Store.isSimple())
    return false;
  LLT ValTy = MRI.getType(Store.getValueReg());
  if (!ValTy.isVector() || ValTy.getSizeInBits() != 128)
    return false;
  if (ValTy.getSizeInBits() != Store.getMemSizeInBits())
    return false; // Don't split truncating stores.
  if (!MRI.hasOneNonDBGUse(Store.getValueReg()))
    return false;
  auto MaybeCst = isConstantOrConstantSplatVector(
      *MRI.getVRegDef(Store.getValueReg()), MRI);
  return MaybeCst && MaybeCst->isZero();
}

void applySplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI,
                            MachineIRBuilder &B,
                            GISelChangeObserver &Observer) {
  B.setInstrAndDebugLoc(MI);
  GStore &Store = cast<GStore>(MI);
  assert(MRI.getType(Store.getValueReg()).isVector() &&
         "Expected a vector store value");
  LLT NewTy = LLT::scalar(64);
  Register PtrReg = Store.getPointerReg();
  auto Zero = B.buildConstant(NewTy, 0);
  auto HighPtr = B.buildPtrAdd(MRI.getType(PtrReg), PtrReg,
                               B.buildConstant(LLT::scalar(64), 8));
  auto &MF = *MI.getMF();
  auto *LowMMO = MF.getMachineMemOperand(&Store.getMMO(), 0, NewTy);
  auto *HighMMO = MF.getMachineMemOperand(&Store.getMMO(), 8, NewTy);
  B.buildStore(Zero, PtrReg, *LowMMO);
  B.buildStore(Zero, HighPtr, *HighMMO);
  Store.eraseFromParent();
}

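/// Match a vector G_OR of two G_ANDs whose mask operands are constant build
/// vectors that are elementwise bitwise complements of each other. Such an OR
/// can be emitted as a single AArch64 G_BSP (bitwise select).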
bool matchOrToBSP(MachineInstr &MI, MachineRegisterInfo &MRI,
                  std::tuple<Register, Register, Register> &MatchInfo) {
  const LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
  if (!DstTy.isVector())
    return false;

  Register AO1, AO2, BVO1, BVO2;
  if (!mi_match(MI, MRI,
                m_GOr(m_GAnd(m_Reg(AO1), m_Reg(BVO1)),
                      m_GAnd(m_Reg(AO2), m_Reg(BVO2)))))
    return false;

  auto *BV1 = getOpcodeDef<GBuildVector>(BVO1, MRI);
  auto *BV2 = getOpcodeDef<GBuildVector>(BVO2, MRI);
  if (!BV1 || !BV2)
    return false;

  for (int I = 0, E = DstTy.getNumElements(); I < E; I++) {
    auto ValAndVReg1 =
        getIConstantVRegValWithLookThrough(BV1->getSourceReg(I), MRI);
    auto ValAndVReg2 =
        getIConstantVRegValWithLookThrough(BV2->getSourceReg(I), MRI);
    if (!ValAndVReg1 || !ValAndVReg2 ||
        ValAndVReg1->Value != ~ValAndVReg2->Value)
      return false;
  }

  MatchInfo = {AO1, AO2, BVO1};
  return true;
}

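/// Emit the AArch64 G_BSP for a pattern matched by matchOrToBSP, passing the
/// matched mask as the first source operand.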
void applyOrToBSP(MachineInstr &MI, MachineRegisterInfo &MRI,
                  MachineIRBuilder &B,
                  std::tuple<Register, Register, Register> &MatchInfo) {
  B.setInstrAndDebugLoc(MI);
  B.buildInstr(
      AArch64::G_BSP, {MI.getOperand(0).getReg()},
      {std::get<2>(MatchInfo), std::get<0>(MatchInfo), std::get<1>(MatchInfo)});
  MI.eraseFromParent();
}

class AArch64PostLegalizerCombinerImpl : public Combiner {
protected:
  // TODO: Make CombinerHelper methods const.
  mutable CombinerHelper Helper;
  const AArch64PostLegalizerCombinerImplRuleConfig &RuleConfig;
  const AArch64Subtarget &STI;

public:
  AArch64PostLegalizerCombinerImpl(
      MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
      GISelKnownBits &KB, GISelCSEInfo *CSEInfo,
      const AArch64PostLegalizerCombinerImplRuleConfig &RuleConfig,
      const AArch64Subtarget &STI, MachineDominatorTree *MDT,
      const LegalizerInfo *LI);

  static const char *getName() { return "AArch64PostLegalizerCombiner"; }

  bool tryCombineAll(MachineInstr &I) const override;

private:
#define GET_GICOMBINER_CLASS_MEMBERS
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef GET_GICOMBINER_CLASS_MEMBERS
};

#define GET_GICOMBINER_IMPL
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef GET_GICOMBINER_IMPL

AArch64PostLegalizerCombinerImpl::AArch64PostLegalizerCombinerImpl(
    MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
    GISelKnownBits &KB, GISelCSEInfo *CSEInfo,
    const AArch64PostLegalizerCombinerImplRuleConfig &RuleConfig,
    const AArch64Subtarget &STI, MachineDominatorTree *MDT,
    const LegalizerInfo *LI)
    : Combiner(MF, CInfo, TPC, &KB, CSEInfo),
      Helper(Observer, B, /*IsPreLegalize*/ false, &KB, MDT, LI),
      RuleConfig(RuleConfig), STI(STI),
#define GET_GICOMBINER_CONSTRUCTOR_INITS
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef GET_GICOMBINER_CONSTRUCTOR_INITS
{
}

class AArch64PostLegalizerCombiner : public MachineFunctionPass {
public:
  static char ID;

  AArch64PostLegalizerCombiner(bool IsOptNone = false);

  StringRef getPassName() const override {
    return "AArch64PostLegalizerCombiner";
  }

  bool runOnMachineFunction(MachineFunction &MF) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;

private:
  bool IsOptNone;
  AArch64PostLegalizerCombinerImplRuleConfig RuleConfig;

  struct StoreInfo {
    GStore *St = nullptr;
    // The G_PTR_ADD that's used by the store. We keep this to cache the
    // MachineInstr def.
    GPtrAdd *Ptr = nullptr;
    // The signed offset to the Ptr instruction.
    int64_t Offset = 0;
    LLT StoredType;
  };
  bool tryOptimizeConsecStores(SmallVectorImpl<StoreInfo> &Stores,
                               CSEMIRBuilder &MIB);

  bool optimizeConsecutiveMemOpAddressing(MachineFunction &MF,
                                          CSEMIRBuilder &MIB);
};
} // end anonymous namespace

void AArch64PostLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.setPreservesCFG();
  getSelectionDAGFallbackAnalysisUsage(AU);
  AU.addRequired<GISelKnownBitsAnalysis>();
  AU.addPreserved<GISelKnownBitsAnalysis>();
  if (!IsOptNone) {
    AU.addRequired<MachineDominatorTree>();
    AU.addPreserved<MachineDominatorTree>();
    AU.addRequired<GISelCSEAnalysisWrapperPass>();
    AU.addPreserved<GISelCSEAnalysisWrapperPass>();
  }
  MachineFunctionPass::getAnalysisUsage(AU);
}

AArch64PostLegalizerCombiner::AArch64PostLegalizerCombiner(bool IsOptNone)
    : MachineFunctionPass(ID), IsOptNone(IsOptNone) {
  initializeAArch64PostLegalizerCombinerPass(*PassRegistry::getPassRegistry());

  if (!RuleConfig.parseCommandLineOption())
    report_fatal_error("Invalid rule identifier");
}

bool AArch64PostLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
  if (MF.getProperties().hasProperty(
          MachineFunctionProperties::Property::FailedISel))
    return false;
  assert(MF.getProperties().hasProperty(
             MachineFunctionProperties::Property::Legalized) &&
         "Expected a legalized function?");
  auto *TPC = &getAnalysis<TargetPassConfig>();
  const Function &F = MF.getFunction();
  bool EnableOpt =
      MF.getTarget().getOptLevel() != CodeGenOptLevel::None && !skipFunction(F);

  const AArch64Subtarget &ST = MF.getSubtarget<AArch64Subtarget>();
  const auto *LI = ST.getLegalizerInfo();

  GISelKnownBits *KB = &getAnalysis<GISelKnownBitsAnalysis>().get(MF);
  MachineDominatorTree *MDT =
      IsOptNone ? nullptr : &getAnalysis<MachineDominatorTree>();
  GISelCSEAnalysisWrapper &Wrapper =
      getAnalysis<GISelCSEAnalysisWrapperPass>().getCSEWrapper();
  auto *CSEInfo = &Wrapper.get(TPC->getCSEConfig());

  CombinerInfo CInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
                     /*LegalizerInfo*/ nullptr, EnableOpt, F.hasOptSize(),
                     F.hasMinSize());
  AArch64PostLegalizerCombinerImpl Impl(MF, CInfo, TPC, *KB, CSEInfo,
                                        RuleConfig, ST, MDT, LI);
  bool Changed = Impl.combineMachineInstrs();

  auto MIB = CSEMIRBuilder(MF);
  MIB.setCSEInfo(CSEInfo);
  Changed |= optimizeConsecutiveMemOpAddressing(MF, MIB);
  return Changed;
}

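/// Given a run of consecutive stores addressed by G_PTR_ADDs off a common
/// base, rewrite them to address off the first store's pointer with small
/// adjusted offsets, when the estimated instruction-count savings justify the
/// extra setup. \returns true if the stores were rewritten.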
bool AArch64PostLegalizerCombiner::tryOptimizeConsecStores(
    SmallVectorImpl<StoreInfo> &Stores, CSEMIRBuilder &MIB) {
  if (Stores.size() <= 2)
    return false;

  // Profitability checks:
  int64_t BaseOffset = Stores[0].Offset;
  unsigned NumPairsExpected = Stores.size() / 2;
  unsigned TotalInstsExpected = NumPairsExpected + (Stores.size() % 2);
  // Size savings will depend on whether we can fold the offset as an
  // immediate of an ADD.
  auto &TLI = *MIB.getMF().getSubtarget().getTargetLowering();
  if (!TLI.isLegalAddImmediate(BaseOffset))
    TotalInstsExpected++;
  int SavingsExpected = Stores.size() - TotalInstsExpected;
  if (SavingsExpected <= 0)
    return false;

  auto &MRI = MIB.getMF().getRegInfo();

  // We have a series of consecutive stores. Factor out the common base
  // pointer and rewrite the offsets.
  Register NewBase = Stores[0].Ptr->getReg(0);
  for (auto &SInfo : Stores) {
    // Compute a new pointer with the new base ptr and adjusted offset.
    MIB.setInstrAndDebugLoc(*SInfo.St);
    auto NewOff = MIB.buildConstant(LLT::scalar(64), SInfo.Offset - BaseOffset);
    auto NewPtr = MIB.buildPtrAdd(MRI.getType(SInfo.St->getPointerReg()),
                                  NewBase, NewOff);
    if (MIB.getObserver())
      MIB.getObserver()->changingInstr(*SInfo.St);
    SInfo.St->getOperand(1).setReg(NewPtr.getReg(0));
    if (MIB.getObserver())
      MIB.getObserver()->changedInstr(*SInfo.St);
  }
  LLVM_DEBUG(dbgs() << "Split a series of " << Stores.size()
                    << " stores into a base pointer and offsets.\n");
  return true;
}

static cl::opt<bool>
    EnableConsecutiveMemOpOpt("aarch64-postlegalizer-consecutive-memops",
                              cl::init(true), cl::Hidden,
                              cl::desc("Enable consecutive memop optimization "
                                       "in AArch64PostLegalizerCombiner"));

bool AArch64PostLegalizerCombiner::optimizeConsecutiveMemOpAddressing(
    MachineFunction &MF, CSEMIRBuilder &MIB) {
  // This combine needs to run after all reassociations/folds on pointer
  // addressing have been done, specifically those that combine two G_PTR_ADDs
  // with constant offsets into a single G_PTR_ADD with a combined offset.
  // The goal of this optimization is to undo that combine in the case where
  // doing so has prevented the formation of pair stores due to illegal
  // addressing modes of STP. The reason that we do it here is because
  // it's much easier to undo the transformation of a series of consecutive
  // mem ops than it is to detect when doing it would be a bad idea looking
  // at a single G_PTR_ADD in the reassociation/ptradd_immed_chain combine.
  //
  // An example:
  //   G_STORE %11:_(<2 x s64>), %base:_(p0) :: (store (<2 x s64>), align 1)
  //   %off1:_(s64) = G_CONSTANT i64 4128
  //   %p1:_(p0) = G_PTR_ADD %base:_, %off1:_(s64)
  //   G_STORE %11:_(<2 x s64>), %p1:_(p0) :: (store (<2 x s64>), align 1)
  //   %off2:_(s64) = G_CONSTANT i64 4144
  //   %p2:_(p0) = G_PTR_ADD %base:_, %off2:_(s64)
  //   G_STORE %11:_(<2 x s64>), %p2:_(p0) :: (store (<2 x s64>), align 1)
  //   %off3:_(s64) = G_CONSTANT i64 4160
  //   %p3:_(p0) = G_PTR_ADD %base:_, %off3:_(s64)
  //   G_STORE %11:_(<2 x s64>), %p3:_(p0) :: (store (<2 x s64>), align 1)
  bool Changed = false;
  auto &MRI = MF.getRegInfo();

  if (!EnableConsecutiveMemOpOpt)
    return Changed;

  SmallVector<StoreInfo, 8> Stores;
  // If we see a load, then we keep track of any values defined by it.
  // In the following example, STP formation will fail anyway because
  // the latter store is using a load result that appears after the prior
  // store. In this situation, if we factor out the offset then
  // we increase code size for no benefit.
  //   G_STORE %v1:_(s64), %base:_(p0) :: (store (s64))
  //   %v2:_(s64) = G_LOAD %ldptr:_(p0) :: (load (s64))
  //   G_STORE %v2:_(s64), %base:_(p0) :: (store (s64))
  SmallVector<Register> LoadValsSinceLastStore;

  auto storeIsValid = [&](StoreInfo &Last, StoreInfo New) {
    // Check if this store is consecutive to the last one.
    if (Last.Ptr->getBaseReg() != New.Ptr->getBaseReg() ||
        (Last.Offset + static_cast<int64_t>(Last.StoredType.getSizeInBytes()) !=
         New.Offset) ||
        Last.StoredType != New.StoredType)
      return false;

    // Check if this store is using a load result that appears after the
    // last store. If so, bail out.
    if (any_of(LoadValsSinceLastStore, [&](Register LoadVal) {
          return New.St->getValueReg() == LoadVal;
        }))
      return false;

    // Check if the current offset would be too large for STP.
    // If not, then STP formation should be able to handle it, so we don't
    // need to do anything.
    int64_t MaxLegalOffset;
    switch (New.StoredType.getSizeInBits()) {
    case 32:
      MaxLegalOffset = 252;
      break;
    case 64:
      MaxLegalOffset = 504;
      break;
    case 128:
      MaxLegalOffset = 1008;
      break;
    default:
      llvm_unreachable("Unexpected stored type size");
    }
    if (New.Offset < MaxLegalOffset)
      return false;

    // If factoring it out still wouldn't help then don't bother.
    return New.Offset - Stores[0].Offset <= MaxLegalOffset;
  };

  auto resetState = [&]() {
    Stores.clear();
    LoadValsSinceLastStore.clear();
  };

  for (auto &MBB : MF) {
    // We're looking inside a single BB at a time since the memset pattern
    // should only be in a single block.
    resetState();
    for (auto &MI : MBB) {
      if (auto *St = dyn_cast<GStore>(&MI)) {
        Register PtrBaseReg;
        APInt Offset;
        LLT StoredValTy = MRI.getType(St->getValueReg());
        unsigned ValSize = StoredValTy.getSizeInBits();
        if (ValSize < 32 || ValSize != St->getMMO().getSizeInBits())
          continue;

        Register PtrReg = St->getPointerReg();
        if (mi_match(
                PtrReg, MRI,
                m_OneNonDBGUse(m_GPtrAdd(m_Reg(PtrBaseReg), m_ICst(Offset))))) {
          GPtrAdd *PtrAdd = cast<GPtrAdd>(MRI.getVRegDef(PtrReg));
          StoreInfo New = {St, PtrAdd, Offset.getSExtValue(), StoredValTy};

          if (Stores.empty()) {
            Stores.push_back(New);
            continue;
          }

          // Check if this store is a valid continuation of the sequence.
          auto &Last = Stores.back();
          if (storeIsValid(Last, New)) {
            Stores.push_back(New);
            LoadValsSinceLastStore.clear(); // Reset the load value tracking.
          } else {
            // The store isn't valid to consider for the prior sequence,
            // so try to optimize what we have so far and start a new sequence.
            Changed |= tryOptimizeConsecStores(Stores, MIB);
            resetState();
            Stores.push_back(New);
          }
        }
      } else if (auto *Ld = dyn_cast<GLoad>(&MI)) {
        LoadValsSinceLastStore.push_back(Ld->getDstReg());
      }
    }
    Changed |= tryOptimizeConsecStores(Stores, MIB);
    resetState();
  }

  return Changed;
}

char AArch64PostLegalizerCombiner::ID = 0;
INITIALIZE_PASS_BEGIN(AArch64PostLegalizerCombiner, DEBUG_TYPE,
                      "Combine AArch64 MachineInstrs after legalization", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(GISelKnownBitsAnalysis)
INITIALIZE_PASS_END(AArch64PostLegalizerCombiner, DEBUG_TYPE,
                    "Combine AArch64 MachineInstrs after legalization", false,
                    false)

namespace llvm {
FunctionPass *createAArch64PostLegalizerCombiner(bool IsOptNone) {
  return new AArch64PostLegalizerCombiner(IsOptNone);
}
} // end namespace llvm