xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp (revision f126d349810fdb512c0b01e101342d430b947488)
1 //=== AArch64PostLegalizerCombiner.cpp --------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// Post-legalization combines on generic MachineInstrs.
11 ///
12 /// The combines here must preserve instruction legality.
13 ///
14 /// Lowering combines (e.g. pseudo matching) should be handled by
15 /// AArch64PostLegalizerLowering.
16 ///
17 /// Combines which don't rely on instruction legality should go in the
18 /// AArch64PreLegalizerCombiner.
19 ///
20 //===----------------------------------------------------------------------===//
21 
22 #include "AArch64TargetMachine.h"
23 #include "llvm/CodeGen/GlobalISel/Combiner.h"
24 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
25 #include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
26 #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
27 #include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
28 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
29 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
30 #include "llvm/CodeGen/GlobalISel/Utils.h"
31 #include "llvm/CodeGen/MachineDominators.h"
32 #include "llvm/CodeGen/MachineFunctionPass.h"
33 #include "llvm/CodeGen/MachineRegisterInfo.h"
34 #include "llvm/CodeGen/TargetOpcodes.h"
35 #include "llvm/CodeGen/TargetPassConfig.h"
36 #include "llvm/Support/Debug.h"
37 
38 #define DEBUG_TYPE "aarch64-postlegalizer-combiner"
39 
40 using namespace llvm;
41 using namespace MIPatternMatch;
42 
43 /// This combine tries do what performExtractVectorEltCombine does in SDAG.
44 /// Rewrite for pairwise fadd pattern
45 ///   (s32 (g_extract_vector_elt
46 ///           (g_fadd (vXs32 Other)
47 ///                  (g_vector_shuffle (vXs32 Other) undef <1,X,...> )) 0))
48 /// ->
49 ///   (s32 (g_fadd (g_extract_vector_elt (vXs32 Other) 0)
50 ///              (g_extract_vector_elt (vXs32 Other) 1))
51 bool matchExtractVecEltPairwiseAdd(
52     MachineInstr &MI, MachineRegisterInfo &MRI,
53     std::tuple<unsigned, LLT, Register> &MatchInfo) {
54   Register Src1 = MI.getOperand(1).getReg();
55   Register Src2 = MI.getOperand(2).getReg();
56   LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
57 
58   auto Cst = getIConstantVRegValWithLookThrough(Src2, MRI);
59   if (!Cst || Cst->Value != 0)
60     return false;
61   // SDAG also checks for FullFP16, but this looks to be beneficial anyway.
62 
63   // Now check for an fadd operation. TODO: expand this for integer add?
64   auto *FAddMI = getOpcodeDef(TargetOpcode::G_FADD, Src1, MRI);
65   if (!FAddMI)
66     return false;
67 
68   // If we add support for integer add, must restrict these types to just s64.
69   unsigned DstSize = DstTy.getSizeInBits();
70   if (DstSize != 16 && DstSize != 32 && DstSize != 64)
71     return false;
72 
73   Register Src1Op1 = FAddMI->getOperand(1).getReg();
74   Register Src1Op2 = FAddMI->getOperand(2).getReg();
75   MachineInstr *Shuffle =
76       getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op2, MRI);
77   MachineInstr *Other = MRI.getVRegDef(Src1Op1);
78   if (!Shuffle) {
79     Shuffle = getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op1, MRI);
80     Other = MRI.getVRegDef(Src1Op2);
81   }
82 
83   // We're looking for a shuffle that moves the second element to index 0.
84   if (Shuffle && Shuffle->getOperand(3).getShuffleMask()[0] == 1 &&
85       Other == MRI.getVRegDef(Shuffle->getOperand(1).getReg())) {
86     std::get<0>(MatchInfo) = TargetOpcode::G_FADD;
87     std::get<1>(MatchInfo) = DstTy;
88     std::get<2>(MatchInfo) = Other->getOperand(0).getReg();
89     return true;
90   }
91   return false;
92 }
93 
94 bool applyExtractVecEltPairwiseAdd(
95     MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
96     std::tuple<unsigned, LLT, Register> &MatchInfo) {
97   unsigned Opc = std::get<0>(MatchInfo);
98   assert(Opc == TargetOpcode::G_FADD && "Unexpected opcode!");
99   // We want to generate two extracts of elements 0 and 1, and add them.
100   LLT Ty = std::get<1>(MatchInfo);
101   Register Src = std::get<2>(MatchInfo);
102   LLT s64 = LLT::scalar(64);
103   B.setInstrAndDebugLoc(MI);
104   auto Elt0 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 0));
105   auto Elt1 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 1));
106   B.buildInstr(Opc, {MI.getOperand(0).getReg()}, {Elt0, Elt1});
107   MI.eraseFromParent();
108   return true;
109 }
110 
111 static bool isSignExtended(Register R, MachineRegisterInfo &MRI) {
112   // TODO: check if extended build vector as well.
113   unsigned Opc = MRI.getVRegDef(R)->getOpcode();
114   return Opc == TargetOpcode::G_SEXT || Opc == TargetOpcode::G_SEXT_INREG;
115 }
116 
117 static bool isZeroExtended(Register R, MachineRegisterInfo &MRI) {
118   // TODO: check if extended build vector as well.
119   return MRI.getVRegDef(R)->getOpcode() == TargetOpcode::G_ZEXT;
120 }
121 
122 bool matchAArch64MulConstCombine(
123     MachineInstr &MI, MachineRegisterInfo &MRI,
124     std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
125   assert(MI.getOpcode() == TargetOpcode::G_MUL);
126   Register LHS = MI.getOperand(1).getReg();
127   Register RHS = MI.getOperand(2).getReg();
128   Register Dst = MI.getOperand(0).getReg();
129   const LLT Ty = MRI.getType(LHS);
130 
131   // The below optimizations require a constant RHS.
132   auto Const = getIConstantVRegValWithLookThrough(RHS, MRI);
133   if (!Const)
134     return false;
135 
136   const APInt ConstValue = Const->Value.sextOrSelf(Ty.getSizeInBits());
137   // The following code is ported from AArch64ISelLowering.
138   // Multiplication of a power of two plus/minus one can be done more
139   // cheaply as as shift+add/sub. For now, this is true unilaterally. If
140   // future CPUs have a cheaper MADD instruction, this may need to be
141   // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and
142   // 64-bit is 5 cycles, so this is always a win.
143   // More aggressively, some multiplications N0 * C can be lowered to
144   // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
145   // e.g. 6=3*2=(2+1)*2.
146   // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45
147   // which equals to (1+2)*16-(1+2).
148   // TrailingZeroes is used to test if the mul can be lowered to
149   // shift+add+shift.
150   unsigned TrailingZeroes = ConstValue.countTrailingZeros();
151   if (TrailingZeroes) {
152     // Conservatively do not lower to shift+add+shift if the mul might be
153     // folded into smul or umul.
154     if (MRI.hasOneNonDBGUse(LHS) &&
155         (isSignExtended(LHS, MRI) || isZeroExtended(LHS, MRI)))
156       return false;
157     // Conservatively do not lower to shift+add+shift if the mul might be
158     // folded into madd or msub.
159     if (MRI.hasOneNonDBGUse(Dst)) {
160       MachineInstr &UseMI = *MRI.use_instr_begin(Dst);
161       unsigned UseOpc = UseMI.getOpcode();
162       if (UseOpc == TargetOpcode::G_ADD || UseOpc == TargetOpcode::G_PTR_ADD ||
163           UseOpc == TargetOpcode::G_SUB)
164         return false;
165     }
166   }
167   // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub
168   // and shift+add+shift.
169   APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);
170 
171   unsigned ShiftAmt, AddSubOpc;
172   // Is the shifted value the LHS operand of the add/sub?
173   bool ShiftValUseIsLHS = true;
174   // Do we need to negate the result?
175   bool NegateResult = false;
176 
177   if (ConstValue.isNonNegative()) {
178     // (mul x, 2^N + 1) => (add (shl x, N), x)
179     // (mul x, 2^N - 1) => (sub (shl x, N), x)
180     // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
181     APInt SCVMinus1 = ShiftedConstValue - 1;
182     APInt CVPlus1 = ConstValue + 1;
183     if (SCVMinus1.isPowerOf2()) {
184       ShiftAmt = SCVMinus1.logBase2();
185       AddSubOpc = TargetOpcode::G_ADD;
186     } else if (CVPlus1.isPowerOf2()) {
187       ShiftAmt = CVPlus1.logBase2();
188       AddSubOpc = TargetOpcode::G_SUB;
189     } else
190       return false;
191   } else {
192     // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
193     // (mul x, -(2^N + 1)) => - (add (shl x, N), x)
194     APInt CVNegPlus1 = -ConstValue + 1;
195     APInt CVNegMinus1 = -ConstValue - 1;
196     if (CVNegPlus1.isPowerOf2()) {
197       ShiftAmt = CVNegPlus1.logBase2();
198       AddSubOpc = TargetOpcode::G_SUB;
199       ShiftValUseIsLHS = false;
200     } else if (CVNegMinus1.isPowerOf2()) {
201       ShiftAmt = CVNegMinus1.logBase2();
202       AddSubOpc = TargetOpcode::G_ADD;
203       NegateResult = true;
204     } else
205       return false;
206   }
207 
208   if (NegateResult && TrailingZeroes)
209     return false;
210 
211   ApplyFn = [=](MachineIRBuilder &B, Register DstReg) {
212     auto Shift = B.buildConstant(LLT::scalar(64), ShiftAmt);
213     auto ShiftedVal = B.buildShl(Ty, LHS, Shift);
214 
215     Register AddSubLHS = ShiftValUseIsLHS ? ShiftedVal.getReg(0) : LHS;
216     Register AddSubRHS = ShiftValUseIsLHS ? LHS : ShiftedVal.getReg(0);
217     auto Res = B.buildInstr(AddSubOpc, {Ty}, {AddSubLHS, AddSubRHS});
218     assert(!(NegateResult && TrailingZeroes) &&
219            "NegateResult and TrailingZeroes cannot both be true for now.");
220     // Negate the result.
221     if (NegateResult) {
222       B.buildSub(DstReg, B.buildConstant(Ty, 0), Res);
223       return;
224     }
225     // Shift the result.
226     if (TrailingZeroes) {
227       B.buildShl(DstReg, Res, B.buildConstant(LLT::scalar(64), TrailingZeroes));
228       return;
229     }
230     B.buildCopy(DstReg, Res.getReg(0));
231   };
232   return true;
233 }
234 
235 bool applyAArch64MulConstCombine(
236     MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
237     std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
238   B.setInstrAndDebugLoc(MI);
239   ApplyFn(B, MI.getOperand(0).getReg());
240   MI.eraseFromParent();
241   return true;
242 }
243 
244 /// Try to fold a G_MERGE_VALUES of 2 s32 sources, where the second source
245 /// is a zero, into a G_ZEXT of the first.
246 bool matchFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI) {
247   auto &Merge = cast<GMerge>(MI);
248   LLT SrcTy = MRI.getType(Merge.getSourceReg(0));
249   if (SrcTy != LLT::scalar(32) || Merge.getNumSources() != 2)
250     return false;
251   return mi_match(Merge.getSourceReg(1), MRI, m_SpecificICst(0));
252 }
253 
254 void applyFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI,
255                           MachineIRBuilder &B, GISelChangeObserver &Observer) {
256   // Mutate %d(s64) = G_MERGE_VALUES %a(s32), 0(s32)
257   //  ->
258   // %d(s64) = G_ZEXT %a(s32)
259   Observer.changingInstr(MI);
260   MI.setDesc(B.getTII().get(TargetOpcode::G_ZEXT));
261   MI.RemoveOperand(2);
262   Observer.changedInstr(MI);
263 }
264 
265 /// \returns True if a G_ANYEXT instruction \p MI should be mutated to a G_ZEXT
266 /// instruction.
267 static bool matchMutateAnyExtToZExt(MachineInstr &MI, MachineRegisterInfo &MRI) {
268   // If this is coming from a scalar compare then we can use a G_ZEXT instead of
269   // a G_ANYEXT:
270   //
271   // %cmp:_(s32) = G_[I|F]CMP ... <-- produces 0/1.
272   // %ext:_(s64) = G_ANYEXT %cmp(s32)
273   //
274   // By doing this, we can leverage more KnownBits combines.
275   assert(MI.getOpcode() == TargetOpcode::G_ANYEXT);
276   Register Dst = MI.getOperand(0).getReg();
277   Register Src = MI.getOperand(1).getReg();
278   return MRI.getType(Dst).isScalar() &&
279          mi_match(Src, MRI,
280                   m_any_of(m_GICmp(m_Pred(), m_Reg(), m_Reg()),
281                            m_GFCmp(m_Pred(), m_Reg(), m_Reg())));
282 }
283 
284 static void applyMutateAnyExtToZExt(MachineInstr &MI, MachineRegisterInfo &MRI,
285                               MachineIRBuilder &B,
286                               GISelChangeObserver &Observer) {
287   Observer.changingInstr(MI);
288   MI.setDesc(B.getTII().get(TargetOpcode::G_ZEXT));
289   Observer.changedInstr(MI);
290 }
291 
292 /// Match a 128b store of zero and split it into two 64 bit stores, for
293 /// size/performance reasons.
294 static bool matchSplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI) {
295   GStore &Store = cast<GStore>(MI);
296   if (!Store.isSimple())
297     return false;
298   LLT ValTy = MRI.getType(Store.getValueReg());
299   if (!ValTy.isVector() || ValTy.getSizeInBits() != 128)
300     return false;
301   if (ValTy.getSizeInBits() != Store.getMemSizeInBits())
302     return false; // Don't split truncating stores.
303   if (!MRI.hasOneNonDBGUse(Store.getValueReg()))
304     return false;
305   auto MaybeCst = isConstantOrConstantSplatVector(
306       *MRI.getVRegDef(Store.getValueReg()), MRI);
307   return MaybeCst && MaybeCst->isZero();
308 }
309 
310 static void applySplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI,
311                                    MachineIRBuilder &B,
312                                    GISelChangeObserver &Observer) {
313   B.setInstrAndDebugLoc(MI);
314   GStore &Store = cast<GStore>(MI);
315   assert(MRI.getType(Store.getValueReg()).isVector() &&
316          "Expected a vector store value");
317   LLT NewTy = LLT::scalar(64);
318   Register PtrReg = Store.getPointerReg();
319   auto Zero = B.buildConstant(NewTy, 0);
320   auto HighPtr = B.buildPtrAdd(MRI.getType(PtrReg), PtrReg,
321                                B.buildConstant(LLT::scalar(64), 8));
322   auto &MF = *MI.getMF();
323   auto *LowMMO = MF.getMachineMemOperand(&Store.getMMO(), 0, NewTy);
324   auto *HighMMO = MF.getMachineMemOperand(&Store.getMMO(), 8, NewTy);
325   B.buildStore(Zero, PtrReg, *LowMMO);
326   B.buildStore(Zero, HighPtr, *HighMMO);
327   Store.eraseFromParent();
328 }
329 
330 #define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS
331 #include "AArch64GenPostLegalizeGICombiner.inc"
332 #undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS
333 
334 namespace {
335 #define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H
336 #include "AArch64GenPostLegalizeGICombiner.inc"
337 #undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H
338 
339 class AArch64PostLegalizerCombinerInfo : public CombinerInfo {
340   GISelKnownBits *KB;
341   MachineDominatorTree *MDT;
342 
343 public:
344   AArch64GenPostLegalizerCombinerHelperRuleConfig GeneratedRuleCfg;
345 
346   AArch64PostLegalizerCombinerInfo(bool EnableOpt, bool OptSize, bool MinSize,
347                                    GISelKnownBits *KB,
348                                    MachineDominatorTree *MDT)
349       : CombinerInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
350                      /*LegalizerInfo*/ nullptr, EnableOpt, OptSize, MinSize),
351         KB(KB), MDT(MDT) {
352     if (!GeneratedRuleCfg.parseCommandLineOption())
353       report_fatal_error("Invalid rule identifier");
354   }
355 
356   virtual bool combine(GISelChangeObserver &Observer, MachineInstr &MI,
357                        MachineIRBuilder &B) const override;
358 };
359 
360 bool AArch64PostLegalizerCombinerInfo::combine(GISelChangeObserver &Observer,
361                                                MachineInstr &MI,
362                                                MachineIRBuilder &B) const {
363   const auto *LI =
364       MI.getParent()->getParent()->getSubtarget().getLegalizerInfo();
365   CombinerHelper Helper(Observer, B, KB, MDT, LI);
366   AArch64GenPostLegalizerCombinerHelper Generated(GeneratedRuleCfg);
367   return Generated.tryCombineAll(Observer, MI, B, Helper);
368 }
369 
370 #define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP
371 #include "AArch64GenPostLegalizeGICombiner.inc"
372 #undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP
373 
374 class AArch64PostLegalizerCombiner : public MachineFunctionPass {
375 public:
376   static char ID;
377 
378   AArch64PostLegalizerCombiner(bool IsOptNone = false);
379 
380   StringRef getPassName() const override {
381     return "AArch64PostLegalizerCombiner";
382   }
383 
384   bool runOnMachineFunction(MachineFunction &MF) override;
385   void getAnalysisUsage(AnalysisUsage &AU) const override;
386 
387 private:
388   bool IsOptNone;
389 };
390 } // end anonymous namespace
391 
392 void AArch64PostLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
393   AU.addRequired<TargetPassConfig>();
394   AU.setPreservesCFG();
395   getSelectionDAGFallbackAnalysisUsage(AU);
396   AU.addRequired<GISelKnownBitsAnalysis>();
397   AU.addPreserved<GISelKnownBitsAnalysis>();
398   if (!IsOptNone) {
399     AU.addRequired<MachineDominatorTree>();
400     AU.addPreserved<MachineDominatorTree>();
401     AU.addRequired<GISelCSEAnalysisWrapperPass>();
402     AU.addPreserved<GISelCSEAnalysisWrapperPass>();
403   }
404   MachineFunctionPass::getAnalysisUsage(AU);
405 }
406 
407 AArch64PostLegalizerCombiner::AArch64PostLegalizerCombiner(bool IsOptNone)
408     : MachineFunctionPass(ID), IsOptNone(IsOptNone) {
409   initializeAArch64PostLegalizerCombinerPass(*PassRegistry::getPassRegistry());
410 }
411 
412 bool AArch64PostLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
413   if (MF.getProperties().hasProperty(
414           MachineFunctionProperties::Property::FailedISel))
415     return false;
416   assert(MF.getProperties().hasProperty(
417              MachineFunctionProperties::Property::Legalized) &&
418          "Expected a legalized function?");
419   auto *TPC = &getAnalysis<TargetPassConfig>();
420   const Function &F = MF.getFunction();
421   bool EnableOpt =
422       MF.getTarget().getOptLevel() != CodeGenOpt::None && !skipFunction(F);
423   GISelKnownBits *KB = &getAnalysis<GISelKnownBitsAnalysis>().get(MF);
424   MachineDominatorTree *MDT =
425       IsOptNone ? nullptr : &getAnalysis<MachineDominatorTree>();
426   AArch64PostLegalizerCombinerInfo PCInfo(EnableOpt, F.hasOptSize(),
427                                           F.hasMinSize(), KB, MDT);
428   GISelCSEAnalysisWrapper &Wrapper =
429       getAnalysis<GISelCSEAnalysisWrapperPass>().getCSEWrapper();
430   auto *CSEInfo = &Wrapper.get(TPC->getCSEConfig());
431   Combiner C(PCInfo, TPC);
432   return C.combineMachineInstrs(MF, CSEInfo);
433 }
434 
435 char AArch64PostLegalizerCombiner::ID = 0;
436 INITIALIZE_PASS_BEGIN(AArch64PostLegalizerCombiner, DEBUG_TYPE,
437                       "Combine AArch64 MachineInstrs after legalization", false,
438                       false)
439 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
440 INITIALIZE_PASS_DEPENDENCY(GISelKnownBitsAnalysis)
441 INITIALIZE_PASS_END(AArch64PostLegalizerCombiner, DEBUG_TYPE,
442                     "Combine AArch64 MachineInstrs after legalization", false,
443                     false)
444 
445 namespace llvm {
446 FunctionPass *createAArch64PostLegalizerCombiner(bool IsOptNone) {
447   return new AArch64PostLegalizerCombiner(IsOptNone);
448 }
449 } // end namespace llvm
450