1 //=== AArch64PostLegalizerCombiner.cpp --------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// Post-legalization combines on generic MachineInstrs.
11 ///
12 /// The combines here must preserve instruction legality.
13 ///
14 /// Lowering combines (e.g. pseudo matching) should be handled by
15 /// AArch64PostLegalizerLowering.
16 ///
17 /// Combines which don't rely on instruction legality should go in the
18 /// AArch64PreLegalizerCombiner.
19 ///
20 //===----------------------------------------------------------------------===//
21
22 #include "AArch64TargetMachine.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/CodeGen/GlobalISel/CSEInfo.h"
25 #include "llvm/CodeGen/GlobalISel/CSEMIRBuilder.h"
26 #include "llvm/CodeGen/GlobalISel/Combiner.h"
27 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
28 #include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
29 #include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h"
30 #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
31 #include "llvm/CodeGen/GlobalISel/GISelValueTracking.h"
32 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
33 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
34 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
35 #include "llvm/CodeGen/GlobalISel/Utils.h"
36 #include "llvm/CodeGen/MachineDominators.h"
37 #include "llvm/CodeGen/MachineFunctionPass.h"
38 #include "llvm/CodeGen/MachineRegisterInfo.h"
39 #include "llvm/CodeGen/TargetOpcodes.h"
40 #include "llvm/CodeGen/TargetPassConfig.h"
41 #include "llvm/Support/Debug.h"
42
43 #define GET_GICOMBINER_DEPS
44 #include "AArch64GenPostLegalizeGICombiner.inc"
45 #undef GET_GICOMBINER_DEPS
46
47 #define DEBUG_TYPE "aarch64-postlegalizer-combiner"
48
49 using namespace llvm;
50 using namespace MIPatternMatch;
51
52 namespace {
53
54 #define GET_GICOMBINER_TYPES
55 #include "AArch64GenPostLegalizeGICombiner.inc"
56 #undef GET_GICOMBINER_TYPES
57
58 /// This combine tries do what performExtractVectorEltCombine does in SDAG.
59 /// Rewrite for pairwise fadd pattern
60 /// (s32 (g_extract_vector_elt
61 /// (g_fadd (vXs32 Other)
62 /// (g_vector_shuffle (vXs32 Other) undef <1,X,...> )) 0))
63 /// ->
64 /// (s32 (g_fadd (g_extract_vector_elt (vXs32 Other) 0)
65 /// (g_extract_vector_elt (vXs32 Other) 1))
matchExtractVecEltPairwiseAdd(MachineInstr & MI,MachineRegisterInfo & MRI,std::tuple<unsigned,LLT,Register> & MatchInfo)66 bool matchExtractVecEltPairwiseAdd(
67 MachineInstr &MI, MachineRegisterInfo &MRI,
68 std::tuple<unsigned, LLT, Register> &MatchInfo) {
69 Register Src1 = MI.getOperand(1).getReg();
70 Register Src2 = MI.getOperand(2).getReg();
71 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
72
73 auto Cst = getIConstantVRegValWithLookThrough(Src2, MRI);
74 if (!Cst || Cst->Value != 0)
75 return false;
76 // SDAG also checks for FullFP16, but this looks to be beneficial anyway.
77
78 // Now check for an fadd operation. TODO: expand this for integer add?
79 auto *FAddMI = getOpcodeDef(TargetOpcode::G_FADD, Src1, MRI);
80 if (!FAddMI)
81 return false;
82
83 // If we add support for integer add, must restrict these types to just s64.
84 unsigned DstSize = DstTy.getSizeInBits();
85 if (DstSize != 16 && DstSize != 32 && DstSize != 64)
86 return false;
87
88 Register Src1Op1 = FAddMI->getOperand(1).getReg();
89 Register Src1Op2 = FAddMI->getOperand(2).getReg();
90 MachineInstr *Shuffle =
91 getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op2, MRI);
92 MachineInstr *Other = MRI.getVRegDef(Src1Op1);
93 if (!Shuffle) {
94 Shuffle = getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op1, MRI);
95 Other = MRI.getVRegDef(Src1Op2);
96 }
97
98 // We're looking for a shuffle that moves the second element to index 0.
99 if (Shuffle && Shuffle->getOperand(3).getShuffleMask()[0] == 1 &&
100 Other == MRI.getVRegDef(Shuffle->getOperand(1).getReg())) {
101 std::get<0>(MatchInfo) = TargetOpcode::G_FADD;
102 std::get<1>(MatchInfo) = DstTy;
103 std::get<2>(MatchInfo) = Other->getOperand(0).getReg();
104 return true;
105 }
106 return false;
107 }
108
applyExtractVecEltPairwiseAdd(MachineInstr & MI,MachineRegisterInfo & MRI,MachineIRBuilder & B,std::tuple<unsigned,LLT,Register> & MatchInfo)109 void applyExtractVecEltPairwiseAdd(
110 MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
111 std::tuple<unsigned, LLT, Register> &MatchInfo) {
112 unsigned Opc = std::get<0>(MatchInfo);
113 assert(Opc == TargetOpcode::G_FADD && "Unexpected opcode!");
114 // We want to generate two extracts of elements 0 and 1, and add them.
115 LLT Ty = std::get<1>(MatchInfo);
116 Register Src = std::get<2>(MatchInfo);
117 LLT s64 = LLT::scalar(64);
118 B.setInstrAndDebugLoc(MI);
119 auto Elt0 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 0));
120 auto Elt1 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 1));
121 B.buildInstr(Opc, {MI.getOperand(0).getReg()}, {Elt0, Elt1});
122 MI.eraseFromParent();
123 }
124
isSignExtended(Register R,MachineRegisterInfo & MRI)125 bool isSignExtended(Register R, MachineRegisterInfo &MRI) {
126 // TODO: check if extended build vector as well.
127 unsigned Opc = MRI.getVRegDef(R)->getOpcode();
128 return Opc == TargetOpcode::G_SEXT || Opc == TargetOpcode::G_SEXT_INREG;
129 }
130
isZeroExtended(Register R,MachineRegisterInfo & MRI)131 bool isZeroExtended(Register R, MachineRegisterInfo &MRI) {
132 // TODO: check if extended build vector as well.
133 return MRI.getVRegDef(R)->getOpcode() == TargetOpcode::G_ZEXT;
134 }
135
matchAArch64MulConstCombine(MachineInstr & MI,MachineRegisterInfo & MRI,std::function<void (MachineIRBuilder & B,Register DstReg)> & ApplyFn)136 bool matchAArch64MulConstCombine(
137 MachineInstr &MI, MachineRegisterInfo &MRI,
138 std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
139 assert(MI.getOpcode() == TargetOpcode::G_MUL);
140 Register LHS = MI.getOperand(1).getReg();
141 Register RHS = MI.getOperand(2).getReg();
142 Register Dst = MI.getOperand(0).getReg();
143 const LLT Ty = MRI.getType(LHS);
144
145 // The below optimizations require a constant RHS.
146 auto Const = getIConstantVRegValWithLookThrough(RHS, MRI);
147 if (!Const)
148 return false;
149
150 APInt ConstValue = Const->Value.sext(Ty.getSizeInBits());
151 // The following code is ported from AArch64ISelLowering.
152 // Multiplication of a power of two plus/minus one can be done more
153 // cheaply as shift+add/sub. For now, this is true unilaterally. If
154 // future CPUs have a cheaper MADD instruction, this may need to be
155 // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and
156 // 64-bit is 5 cycles, so this is always a win.
157 // More aggressively, some multiplications N0 * C can be lowered to
158 // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
159 // e.g. 6=3*2=(2+1)*2.
160 // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45
161 // which equals to (1+2)*16-(1+2).
162 // TrailingZeroes is used to test if the mul can be lowered to
163 // shift+add+shift.
164 unsigned TrailingZeroes = ConstValue.countr_zero();
165 if (TrailingZeroes) {
166 // Conservatively do not lower to shift+add+shift if the mul might be
167 // folded into smul or umul.
168 if (MRI.hasOneNonDBGUse(LHS) &&
169 (isSignExtended(LHS, MRI) || isZeroExtended(LHS, MRI)))
170 return false;
171 // Conservatively do not lower to shift+add+shift if the mul might be
172 // folded into madd or msub.
173 if (MRI.hasOneNonDBGUse(Dst)) {
174 MachineInstr &UseMI = *MRI.use_instr_begin(Dst);
175 unsigned UseOpc = UseMI.getOpcode();
176 if (UseOpc == TargetOpcode::G_ADD || UseOpc == TargetOpcode::G_PTR_ADD ||
177 UseOpc == TargetOpcode::G_SUB)
178 return false;
179 }
180 }
181 // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub
182 // and shift+add+shift.
183 APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);
184
185 unsigned ShiftAmt, AddSubOpc;
186 // Is the shifted value the LHS operand of the add/sub?
187 bool ShiftValUseIsLHS = true;
188 // Do we need to negate the result?
189 bool NegateResult = false;
190
191 if (ConstValue.isNonNegative()) {
192 // (mul x, 2^N + 1) => (add (shl x, N), x)
193 // (mul x, 2^N - 1) => (sub (shl x, N), x)
194 // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
195 APInt SCVMinus1 = ShiftedConstValue - 1;
196 APInt CVPlus1 = ConstValue + 1;
197 if (SCVMinus1.isPowerOf2()) {
198 ShiftAmt = SCVMinus1.logBase2();
199 AddSubOpc = TargetOpcode::G_ADD;
200 } else if (CVPlus1.isPowerOf2()) {
201 ShiftAmt = CVPlus1.logBase2();
202 AddSubOpc = TargetOpcode::G_SUB;
203 } else
204 return false;
205 } else {
206 // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
207 // (mul x, -(2^N + 1)) => - (add (shl x, N), x)
208 APInt CVNegPlus1 = -ConstValue + 1;
209 APInt CVNegMinus1 = -ConstValue - 1;
210 if (CVNegPlus1.isPowerOf2()) {
211 ShiftAmt = CVNegPlus1.logBase2();
212 AddSubOpc = TargetOpcode::G_SUB;
213 ShiftValUseIsLHS = false;
214 } else if (CVNegMinus1.isPowerOf2()) {
215 ShiftAmt = CVNegMinus1.logBase2();
216 AddSubOpc = TargetOpcode::G_ADD;
217 NegateResult = true;
218 } else
219 return false;
220 }
221
222 if (NegateResult && TrailingZeroes)
223 return false;
224
225 ApplyFn = [=](MachineIRBuilder &B, Register DstReg) {
226 auto Shift = B.buildConstant(LLT::scalar(64), ShiftAmt);
227 auto ShiftedVal = B.buildShl(Ty, LHS, Shift);
228
229 Register AddSubLHS = ShiftValUseIsLHS ? ShiftedVal.getReg(0) : LHS;
230 Register AddSubRHS = ShiftValUseIsLHS ? LHS : ShiftedVal.getReg(0);
231 auto Res = B.buildInstr(AddSubOpc, {Ty}, {AddSubLHS, AddSubRHS});
232 assert(!(NegateResult && TrailingZeroes) &&
233 "NegateResult and TrailingZeroes cannot both be true for now.");
234 // Negate the result.
235 if (NegateResult) {
236 B.buildSub(DstReg, B.buildConstant(Ty, 0), Res);
237 return;
238 }
239 // Shift the result.
240 if (TrailingZeroes) {
241 B.buildShl(DstReg, Res, B.buildConstant(LLT::scalar(64), TrailingZeroes));
242 return;
243 }
244 B.buildCopy(DstReg, Res.getReg(0));
245 };
246 return true;
247 }
248
applyAArch64MulConstCombine(MachineInstr & MI,MachineRegisterInfo & MRI,MachineIRBuilder & B,std::function<void (MachineIRBuilder & B,Register DstReg)> & ApplyFn)249 void applyAArch64MulConstCombine(
250 MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
251 std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
252 B.setInstrAndDebugLoc(MI);
253 ApplyFn(B, MI.getOperand(0).getReg());
254 MI.eraseFromParent();
255 }
256
257 /// Try to fold a G_MERGE_VALUES of 2 s32 sources, where the second source
258 /// is a zero, into a G_ZEXT of the first.
matchFoldMergeToZext(MachineInstr & MI,MachineRegisterInfo & MRI)259 bool matchFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI) {
260 auto &Merge = cast<GMerge>(MI);
261 LLT SrcTy = MRI.getType(Merge.getSourceReg(0));
262 if (SrcTy != LLT::scalar(32) || Merge.getNumSources() != 2)
263 return false;
264 return mi_match(Merge.getSourceReg(1), MRI, m_SpecificICst(0));
265 }
266
applyFoldMergeToZext(MachineInstr & MI,MachineRegisterInfo & MRI,MachineIRBuilder & B,GISelChangeObserver & Observer)267 void applyFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI,
268 MachineIRBuilder &B, GISelChangeObserver &Observer) {
269 // Mutate %d(s64) = G_MERGE_VALUES %a(s32), 0(s32)
270 // ->
271 // %d(s64) = G_ZEXT %a(s32)
272 Observer.changingInstr(MI);
273 MI.setDesc(B.getTII().get(TargetOpcode::G_ZEXT));
274 MI.removeOperand(2);
275 Observer.changedInstr(MI);
276 }
277
278 /// \returns True if a G_ANYEXT instruction \p MI should be mutated to a G_ZEXT
279 /// instruction.
matchMutateAnyExtToZExt(MachineInstr & MI,MachineRegisterInfo & MRI)280 bool matchMutateAnyExtToZExt(MachineInstr &MI, MachineRegisterInfo &MRI) {
281 // If this is coming from a scalar compare then we can use a G_ZEXT instead of
282 // a G_ANYEXT:
283 //
284 // %cmp:_(s32) = G_[I|F]CMP ... <-- produces 0/1.
285 // %ext:_(s64) = G_ANYEXT %cmp(s32)
286 //
287 // By doing this, we can leverage more KnownBits combines.
288 assert(MI.getOpcode() == TargetOpcode::G_ANYEXT);
289 Register Dst = MI.getOperand(0).getReg();
290 Register Src = MI.getOperand(1).getReg();
291 return MRI.getType(Dst).isScalar() &&
292 mi_match(Src, MRI,
293 m_any_of(m_GICmp(m_Pred(), m_Reg(), m_Reg()),
294 m_GFCmp(m_Pred(), m_Reg(), m_Reg())));
295 }
296
applyMutateAnyExtToZExt(MachineInstr & MI,MachineRegisterInfo & MRI,MachineIRBuilder & B,GISelChangeObserver & Observer)297 void applyMutateAnyExtToZExt(MachineInstr &MI, MachineRegisterInfo &MRI,
298 MachineIRBuilder &B,
299 GISelChangeObserver &Observer) {
300 Observer.changingInstr(MI);
301 MI.setDesc(B.getTII().get(TargetOpcode::G_ZEXT));
302 Observer.changedInstr(MI);
303 }
304
305 /// Match a 128b store of zero and split it into two 64 bit stores, for
306 /// size/performance reasons.
matchSplitStoreZero128(MachineInstr & MI,MachineRegisterInfo & MRI)307 bool matchSplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI) {
308 GStore &Store = cast<GStore>(MI);
309 if (!Store.isSimple())
310 return false;
311 LLT ValTy = MRI.getType(Store.getValueReg());
312 if (ValTy.isScalableVector())
313 return false;
314 if (!ValTy.isVector() || ValTy.getSizeInBits() != 128)
315 return false;
316 if (Store.getMemSizeInBits() != ValTy.getSizeInBits())
317 return false; // Don't split truncating stores.
318 if (!MRI.hasOneNonDBGUse(Store.getValueReg()))
319 return false;
320 auto MaybeCst = isConstantOrConstantSplatVector(
321 *MRI.getVRegDef(Store.getValueReg()), MRI);
322 return MaybeCst && MaybeCst->isZero();
323 }
324
applySplitStoreZero128(MachineInstr & MI,MachineRegisterInfo & MRI,MachineIRBuilder & B,GISelChangeObserver & Observer)325 void applySplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI,
326 MachineIRBuilder &B,
327 GISelChangeObserver &Observer) {
328 B.setInstrAndDebugLoc(MI);
329 GStore &Store = cast<GStore>(MI);
330 assert(MRI.getType(Store.getValueReg()).isVector() &&
331 "Expected a vector store value");
332 LLT NewTy = LLT::scalar(64);
333 Register PtrReg = Store.getPointerReg();
334 auto Zero = B.buildConstant(NewTy, 0);
335 auto HighPtr = B.buildPtrAdd(MRI.getType(PtrReg), PtrReg,
336 B.buildConstant(LLT::scalar(64), 8));
337 auto &MF = *MI.getMF();
338 auto *LowMMO = MF.getMachineMemOperand(&Store.getMMO(), 0, NewTy);
339 auto *HighMMO = MF.getMachineMemOperand(&Store.getMMO(), 8, NewTy);
340 B.buildStore(Zero, PtrReg, *LowMMO);
341 B.buildStore(Zero, HighPtr, *HighMMO);
342 Store.eraseFromParent();
343 }
344
matchOrToBSP(MachineInstr & MI,MachineRegisterInfo & MRI,std::tuple<Register,Register,Register> & MatchInfo)345 bool matchOrToBSP(MachineInstr &MI, MachineRegisterInfo &MRI,
346 std::tuple<Register, Register, Register> &MatchInfo) {
347 const LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
348 if (!DstTy.isVector())
349 return false;
350
351 Register AO1, AO2, BVO1, BVO2;
352 if (!mi_match(MI, MRI,
353 m_GOr(m_GAnd(m_Reg(AO1), m_Reg(BVO1)),
354 m_GAnd(m_Reg(AO2), m_Reg(BVO2)))))
355 return false;
356
357 auto *BV1 = getOpcodeDef<GBuildVector>(BVO1, MRI);
358 auto *BV2 = getOpcodeDef<GBuildVector>(BVO2, MRI);
359 if (!BV1 || !BV2)
360 return false;
361
362 for (int I = 0, E = DstTy.getNumElements(); I < E; I++) {
363 auto ValAndVReg1 =
364 getIConstantVRegValWithLookThrough(BV1->getSourceReg(I), MRI);
365 auto ValAndVReg2 =
366 getIConstantVRegValWithLookThrough(BV2->getSourceReg(I), MRI);
367 if (!ValAndVReg1 || !ValAndVReg2 ||
368 ValAndVReg1->Value != ~ValAndVReg2->Value)
369 return false;
370 }
371
372 MatchInfo = {AO1, AO2, BVO1};
373 return true;
374 }
375
applyOrToBSP(MachineInstr & MI,MachineRegisterInfo & MRI,MachineIRBuilder & B,std::tuple<Register,Register,Register> & MatchInfo)376 void applyOrToBSP(MachineInstr &MI, MachineRegisterInfo &MRI,
377 MachineIRBuilder &B,
378 std::tuple<Register, Register, Register> &MatchInfo) {
379 B.setInstrAndDebugLoc(MI);
380 B.buildInstr(
381 AArch64::G_BSP, {MI.getOperand(0).getReg()},
382 {std::get<2>(MatchInfo), std::get<0>(MatchInfo), std::get<1>(MatchInfo)});
383 MI.eraseFromParent();
384 }
385
386 // Combines Mul(And(Srl(X, 15), 0x10001), 0xffff) into CMLTz
matchCombineMulCMLT(MachineInstr & MI,MachineRegisterInfo & MRI,Register & SrcReg)387 bool matchCombineMulCMLT(MachineInstr &MI, MachineRegisterInfo &MRI,
388 Register &SrcReg) {
389 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
390
391 if (DstTy != LLT::fixed_vector(2, 64) && DstTy != LLT::fixed_vector(2, 32) &&
392 DstTy != LLT::fixed_vector(4, 32) && DstTy != LLT::fixed_vector(4, 16) &&
393 DstTy != LLT::fixed_vector(8, 16))
394 return false;
395
396 auto AndMI = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
397 if (AndMI->getOpcode() != TargetOpcode::G_AND)
398 return false;
399 auto LShrMI = getDefIgnoringCopies(AndMI->getOperand(1).getReg(), MRI);
400 if (LShrMI->getOpcode() != TargetOpcode::G_LSHR)
401 return false;
402
403 // Check the constant splat values
404 auto V1 = isConstantOrConstantSplatVector(
405 *MRI.getVRegDef(MI.getOperand(2).getReg()), MRI);
406 auto V2 = isConstantOrConstantSplatVector(
407 *MRI.getVRegDef(AndMI->getOperand(2).getReg()), MRI);
408 auto V3 = isConstantOrConstantSplatVector(
409 *MRI.getVRegDef(LShrMI->getOperand(2).getReg()), MRI);
410 if (!V1.has_value() || !V2.has_value() || !V3.has_value())
411 return false;
412 unsigned HalfSize = DstTy.getScalarSizeInBits() / 2;
413 if (!V1.value().isMask(HalfSize) || V2.value() != (1ULL | 1ULL << HalfSize) ||
414 V3 != (HalfSize - 1))
415 return false;
416
417 SrcReg = LShrMI->getOperand(1).getReg();
418
419 return true;
420 }
421
applyCombineMulCMLT(MachineInstr & MI,MachineRegisterInfo & MRI,MachineIRBuilder & B,Register & SrcReg)422 void applyCombineMulCMLT(MachineInstr &MI, MachineRegisterInfo &MRI,
423 MachineIRBuilder &B, Register &SrcReg) {
424 Register DstReg = MI.getOperand(0).getReg();
425 LLT DstTy = MRI.getType(DstReg);
426 LLT HalfTy =
427 DstTy.changeElementCount(DstTy.getElementCount().multiplyCoefficientBy(2))
428 .changeElementSize(DstTy.getScalarSizeInBits() / 2);
429
430 Register ZeroVec = B.buildConstant(HalfTy, 0).getReg(0);
431 Register CastReg =
432 B.buildInstr(TargetOpcode::G_BITCAST, {HalfTy}, {SrcReg}).getReg(0);
433 Register CMLTReg =
434 B.buildICmp(CmpInst::Predicate::ICMP_SLT, HalfTy, CastReg, ZeroVec)
435 .getReg(0);
436
437 B.buildInstr(TargetOpcode::G_BITCAST, {DstReg}, {CMLTReg}).getReg(0);
438 MI.eraseFromParent();
439 }
440
441 // Match mul({z/s}ext , {z/s}ext) => {u/s}mull
matchExtMulToMULL(MachineInstr & MI,MachineRegisterInfo & MRI,GISelValueTracking * KB,std::tuple<bool,Register,Register> & MatchInfo)442 bool matchExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI,
443 GISelValueTracking *KB,
444 std::tuple<bool, Register, Register> &MatchInfo) {
445 // Get the instructions that defined the source operand
446 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
447 MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
448 MachineInstr *I2 = getDefIgnoringCopies(MI.getOperand(2).getReg(), MRI);
449 unsigned I1Opc = I1->getOpcode();
450 unsigned I2Opc = I2->getOpcode();
451 unsigned EltSize = DstTy.getScalarSizeInBits();
452
453 if (!DstTy.isVector() || I1->getNumOperands() < 2 || I2->getNumOperands() < 2)
454 return false;
455
456 auto IsAtLeastDoubleExtend = [&](Register R) {
457 LLT Ty = MRI.getType(R);
458 return EltSize >= Ty.getScalarSizeInBits() * 2;
459 };
460
461 // If the source operands were EXTENDED before, then {U/S}MULL can be used
462 bool IsZExt1 =
463 I1Opc == TargetOpcode::G_ZEXT || I1Opc == TargetOpcode::G_ANYEXT;
464 bool IsZExt2 =
465 I2Opc == TargetOpcode::G_ZEXT || I2Opc == TargetOpcode::G_ANYEXT;
466 if (IsZExt1 && IsZExt2 && IsAtLeastDoubleExtend(I1->getOperand(1).getReg()) &&
467 IsAtLeastDoubleExtend(I2->getOperand(1).getReg())) {
468 get<0>(MatchInfo) = true;
469 get<1>(MatchInfo) = I1->getOperand(1).getReg();
470 get<2>(MatchInfo) = I2->getOperand(1).getReg();
471 return true;
472 }
473
474 bool IsSExt1 =
475 I1Opc == TargetOpcode::G_SEXT || I1Opc == TargetOpcode::G_ANYEXT;
476 bool IsSExt2 =
477 I2Opc == TargetOpcode::G_SEXT || I2Opc == TargetOpcode::G_ANYEXT;
478 if (IsSExt1 && IsSExt2 && IsAtLeastDoubleExtend(I1->getOperand(1).getReg()) &&
479 IsAtLeastDoubleExtend(I2->getOperand(1).getReg())) {
480 get<0>(MatchInfo) = false;
481 get<1>(MatchInfo) = I1->getOperand(1).getReg();
482 get<2>(MatchInfo) = I2->getOperand(1).getReg();
483 return true;
484 }
485
486 // Select UMULL if we can replace the other operand with an extend.
487 APInt Mask = APInt::getHighBitsSet(EltSize, EltSize / 2);
488 if (KB && (IsZExt1 || IsZExt2) &&
489 IsAtLeastDoubleExtend(IsZExt1 ? I1->getOperand(1).getReg()
490 : I2->getOperand(1).getReg())) {
491 Register ZExtOp =
492 IsZExt1 ? MI.getOperand(2).getReg() : MI.getOperand(1).getReg();
493 if (KB->maskedValueIsZero(ZExtOp, Mask)) {
494 get<0>(MatchInfo) = true;
495 get<1>(MatchInfo) = IsZExt1 ? I1->getOperand(1).getReg() : ZExtOp;
496 get<2>(MatchInfo) = IsZExt1 ? ZExtOp : I2->getOperand(1).getReg();
497 return true;
498 }
499 } else if (KB && DstTy == LLT::fixed_vector(2, 64) &&
500 KB->maskedValueIsZero(MI.getOperand(1).getReg(), Mask) &&
501 KB->maskedValueIsZero(MI.getOperand(2).getReg(), Mask)) {
502 get<0>(MatchInfo) = true;
503 get<1>(MatchInfo) = MI.getOperand(1).getReg();
504 get<2>(MatchInfo) = MI.getOperand(2).getReg();
505 return true;
506 }
507
508 if (KB && (IsSExt1 || IsSExt2) &&
509 IsAtLeastDoubleExtend(IsSExt1 ? I1->getOperand(1).getReg()
510 : I2->getOperand(1).getReg())) {
511 Register SExtOp =
512 IsSExt1 ? MI.getOperand(2).getReg() : MI.getOperand(1).getReg();
513 if (KB->computeNumSignBits(SExtOp) > EltSize / 2) {
514 get<0>(MatchInfo) = false;
515 get<1>(MatchInfo) = IsSExt1 ? I1->getOperand(1).getReg() : SExtOp;
516 get<2>(MatchInfo) = IsSExt1 ? SExtOp : I2->getOperand(1).getReg();
517 return true;
518 }
519 } else if (KB && DstTy == LLT::fixed_vector(2, 64) &&
520 KB->computeNumSignBits(MI.getOperand(1).getReg()) > EltSize / 2 &&
521 KB->computeNumSignBits(MI.getOperand(2).getReg()) > EltSize / 2) {
522 get<0>(MatchInfo) = false;
523 get<1>(MatchInfo) = MI.getOperand(1).getReg();
524 get<2>(MatchInfo) = MI.getOperand(2).getReg();
525 return true;
526 }
527
528 return false;
529 }
530
applyExtMulToMULL(MachineInstr & MI,MachineRegisterInfo & MRI,MachineIRBuilder & B,GISelChangeObserver & Observer,std::tuple<bool,Register,Register> & MatchInfo)531 void applyExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI,
532 MachineIRBuilder &B, GISelChangeObserver &Observer,
533 std::tuple<bool, Register, Register> &MatchInfo) {
534 assert(MI.getOpcode() == TargetOpcode::G_MUL &&
535 "Expected a G_MUL instruction");
536
537 // Get the instructions that defined the source operand
538 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
539 bool IsZExt = get<0>(MatchInfo);
540 Register Src1Reg = get<1>(MatchInfo);
541 Register Src2Reg = get<2>(MatchInfo);
542 LLT Src1Ty = MRI.getType(Src1Reg);
543 LLT Src2Ty = MRI.getType(Src2Reg);
544 LLT HalfDstTy = DstTy.changeElementSize(DstTy.getScalarSizeInBits() / 2);
545 unsigned ExtOpc = IsZExt ? TargetOpcode::G_ZEXT : TargetOpcode::G_SEXT;
546
547 if (Src1Ty.getScalarSizeInBits() * 2 != DstTy.getScalarSizeInBits())
548 Src1Reg = B.buildExtOrTrunc(ExtOpc, {HalfDstTy}, {Src1Reg}).getReg(0);
549 if (Src2Ty.getScalarSizeInBits() * 2 != DstTy.getScalarSizeInBits())
550 Src2Reg = B.buildExtOrTrunc(ExtOpc, {HalfDstTy}, {Src2Reg}).getReg(0);
551
552 B.buildInstr(IsZExt ? AArch64::G_UMULL : AArch64::G_SMULL,
553 {MI.getOperand(0).getReg()}, {Src1Reg, Src2Reg});
554 MI.eraseFromParent();
555 }
556
557 class AArch64PostLegalizerCombinerImpl : public Combiner {
558 protected:
559 const CombinerHelper Helper;
560 const AArch64PostLegalizerCombinerImplRuleConfig &RuleConfig;
561 const AArch64Subtarget &STI;
562
563 public:
564 AArch64PostLegalizerCombinerImpl(
565 MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
566 GISelValueTracking &VT, GISelCSEInfo *CSEInfo,
567 const AArch64PostLegalizerCombinerImplRuleConfig &RuleConfig,
568 const AArch64Subtarget &STI, MachineDominatorTree *MDT,
569 const LegalizerInfo *LI);
570
getName()571 static const char *getName() { return "AArch64PostLegalizerCombiner"; }
572
573 bool tryCombineAll(MachineInstr &I) const override;
574
575 private:
576 #define GET_GICOMBINER_CLASS_MEMBERS
577 #include "AArch64GenPostLegalizeGICombiner.inc"
578 #undef GET_GICOMBINER_CLASS_MEMBERS
579 };
580
581 #define GET_GICOMBINER_IMPL
582 #include "AArch64GenPostLegalizeGICombiner.inc"
583 #undef GET_GICOMBINER_IMPL
584
AArch64PostLegalizerCombinerImpl(MachineFunction & MF,CombinerInfo & CInfo,const TargetPassConfig * TPC,GISelValueTracking & VT,GISelCSEInfo * CSEInfo,const AArch64PostLegalizerCombinerImplRuleConfig & RuleConfig,const AArch64Subtarget & STI,MachineDominatorTree * MDT,const LegalizerInfo * LI)585 AArch64PostLegalizerCombinerImpl::AArch64PostLegalizerCombinerImpl(
586 MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
587 GISelValueTracking &VT, GISelCSEInfo *CSEInfo,
588 const AArch64PostLegalizerCombinerImplRuleConfig &RuleConfig,
589 const AArch64Subtarget &STI, MachineDominatorTree *MDT,
590 const LegalizerInfo *LI)
591 : Combiner(MF, CInfo, TPC, &VT, CSEInfo),
592 Helper(Observer, B, /*IsPreLegalize*/ false, &VT, MDT, LI),
593 RuleConfig(RuleConfig), STI(STI),
594 #define GET_GICOMBINER_CONSTRUCTOR_INITS
595 #include "AArch64GenPostLegalizeGICombiner.inc"
596 #undef GET_GICOMBINER_CONSTRUCTOR_INITS
597 {
598 }
599
600 class AArch64PostLegalizerCombiner : public MachineFunctionPass {
601 public:
602 static char ID;
603
604 AArch64PostLegalizerCombiner(bool IsOptNone = false);
605
getPassName() const606 StringRef getPassName() const override {
607 return "AArch64PostLegalizerCombiner";
608 }
609
610 bool runOnMachineFunction(MachineFunction &MF) override;
611 void getAnalysisUsage(AnalysisUsage &AU) const override;
612
613 private:
614 bool IsOptNone;
615 AArch64PostLegalizerCombinerImplRuleConfig RuleConfig;
616
617
618 struct StoreInfo {
619 GStore *St = nullptr;
620 // The G_PTR_ADD that's used by the store. We keep this to cache the
621 // MachineInstr def.
622 GPtrAdd *Ptr = nullptr;
623 // The signed offset to the Ptr instruction.
624 int64_t Offset = 0;
625 LLT StoredType;
626 };
627 bool tryOptimizeConsecStores(SmallVectorImpl<StoreInfo> &Stores,
628 CSEMIRBuilder &MIB);
629
630 bool optimizeConsecutiveMemOpAddressing(MachineFunction &MF,
631 CSEMIRBuilder &MIB);
632 };
633 } // end anonymous namespace
634
getAnalysisUsage(AnalysisUsage & AU) const635 void AArch64PostLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
636 AU.addRequired<TargetPassConfig>();
637 AU.setPreservesCFG();
638 getSelectionDAGFallbackAnalysisUsage(AU);
639 AU.addRequired<GISelValueTrackingAnalysisLegacy>();
640 AU.addPreserved<GISelValueTrackingAnalysisLegacy>();
641 if (!IsOptNone) {
642 AU.addRequired<MachineDominatorTreeWrapperPass>();
643 AU.addPreserved<MachineDominatorTreeWrapperPass>();
644 AU.addRequired<GISelCSEAnalysisWrapperPass>();
645 AU.addPreserved<GISelCSEAnalysisWrapperPass>();
646 }
647 MachineFunctionPass::getAnalysisUsage(AU);
648 }
649
AArch64PostLegalizerCombiner(bool IsOptNone)650 AArch64PostLegalizerCombiner::AArch64PostLegalizerCombiner(bool IsOptNone)
651 : MachineFunctionPass(ID), IsOptNone(IsOptNone) {
652 if (!RuleConfig.parseCommandLineOption())
653 report_fatal_error("Invalid rule identifier");
654 }
655
runOnMachineFunction(MachineFunction & MF)656 bool AArch64PostLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
657 if (MF.getProperties().hasFailedISel())
658 return false;
659 assert(MF.getProperties().hasLegalized() && "Expected a legalized function?");
660 auto *TPC = &getAnalysis<TargetPassConfig>();
661 const Function &F = MF.getFunction();
662 bool EnableOpt =
663 MF.getTarget().getOptLevel() != CodeGenOptLevel::None && !skipFunction(F);
664
665 const AArch64Subtarget &ST = MF.getSubtarget<AArch64Subtarget>();
666 const auto *LI = ST.getLegalizerInfo();
667
668 GISelValueTracking *VT =
669 &getAnalysis<GISelValueTrackingAnalysisLegacy>().get(MF);
670 MachineDominatorTree *MDT =
671 IsOptNone ? nullptr
672 : &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
673 GISelCSEAnalysisWrapper &Wrapper =
674 getAnalysis<GISelCSEAnalysisWrapperPass>().getCSEWrapper();
675 auto *CSEInfo = &Wrapper.get(TPC->getCSEConfig());
676
677 CombinerInfo CInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
678 /*LegalizerInfo*/ nullptr, EnableOpt, F.hasOptSize(),
679 F.hasMinSize());
680 // Disable fixed-point iteration to reduce compile-time
681 CInfo.MaxIterations = 1;
682 CInfo.ObserverLvl = CombinerInfo::ObserverLevel::SinglePass;
683 // Legalizer performs DCE, so a full DCE pass is unnecessary.
684 CInfo.EnableFullDCE = false;
685 AArch64PostLegalizerCombinerImpl Impl(MF, CInfo, TPC, *VT, CSEInfo,
686 RuleConfig, ST, MDT, LI);
687 bool Changed = Impl.combineMachineInstrs();
688
689 auto MIB = CSEMIRBuilder(MF);
690 MIB.setCSEInfo(CSEInfo);
691 Changed |= optimizeConsecutiveMemOpAddressing(MF, MIB);
692 return Changed;
693 }
694
tryOptimizeConsecStores(SmallVectorImpl<StoreInfo> & Stores,CSEMIRBuilder & MIB)695 bool AArch64PostLegalizerCombiner::tryOptimizeConsecStores(
696 SmallVectorImpl<StoreInfo> &Stores, CSEMIRBuilder &MIB) {
697 if (Stores.size() <= 2)
698 return false;
699
700 // Profitabity checks:
701 int64_t BaseOffset = Stores[0].Offset;
702 unsigned NumPairsExpected = Stores.size() / 2;
703 unsigned TotalInstsExpected = NumPairsExpected + (Stores.size() % 2);
704 // Size savings will depend on whether we can fold the offset, as an
705 // immediate of an ADD.
706 auto &TLI = *MIB.getMF().getSubtarget().getTargetLowering();
707 if (!TLI.isLegalAddImmediate(BaseOffset))
708 TotalInstsExpected++;
709 int SavingsExpected = Stores.size() - TotalInstsExpected;
710 if (SavingsExpected <= 0)
711 return false;
712
713 auto &MRI = MIB.getMF().getRegInfo();
714
715 // We have a series of consecutive stores. Factor out the common base
716 // pointer and rewrite the offsets.
717 Register NewBase = Stores[0].Ptr->getReg(0);
718 for (auto &SInfo : Stores) {
719 // Compute a new pointer with the new base ptr and adjusted offset.
720 MIB.setInstrAndDebugLoc(*SInfo.St);
721 auto NewOff = MIB.buildConstant(LLT::scalar(64), SInfo.Offset - BaseOffset);
722 auto NewPtr = MIB.buildPtrAdd(MRI.getType(SInfo.St->getPointerReg()),
723 NewBase, NewOff);
724 if (MIB.getObserver())
725 MIB.getObserver()->changingInstr(*SInfo.St);
726 SInfo.St->getOperand(1).setReg(NewPtr.getReg(0));
727 if (MIB.getObserver())
728 MIB.getObserver()->changedInstr(*SInfo.St);
729 }
730 LLVM_DEBUG(dbgs() << "Split a series of " << Stores.size()
731 << " stores into a base pointer and offsets.\n");
732 return true;
733 }
734
735 static cl::opt<bool>
736 EnableConsecutiveMemOpOpt("aarch64-postlegalizer-consecutive-memops",
737 cl::init(true), cl::Hidden,
738 cl::desc("Enable consecutive memop optimization "
739 "in AArch64PostLegalizerCombiner"));
740
optimizeConsecutiveMemOpAddressing(MachineFunction & MF,CSEMIRBuilder & MIB)741 bool AArch64PostLegalizerCombiner::optimizeConsecutiveMemOpAddressing(
742 MachineFunction &MF, CSEMIRBuilder &MIB) {
743 // This combine needs to run after all reassociations/folds on pointer
744 // addressing have been done, specifically those that combine two G_PTR_ADDs
745 // with constant offsets into a single G_PTR_ADD with a combined offset.
746 // The goal of this optimization is to undo that combine in the case where
747 // doing so has prevented the formation of pair stores due to illegal
748 // addressing modes of STP. The reason that we do it here is because
749 // it's much easier to undo the transformation of a series consecutive
750 // mem ops, than it is to detect when doing it would be a bad idea looking
751 // at a single G_PTR_ADD in the reassociation/ptradd_immed_chain combine.
752 //
753 // An example:
754 // G_STORE %11:_(<2 x s64>), %base:_(p0) :: (store (<2 x s64>), align 1)
755 // %off1:_(s64) = G_CONSTANT i64 4128
756 // %p1:_(p0) = G_PTR_ADD %0:_, %off1:_(s64)
757 // G_STORE %11:_(<2 x s64>), %p1:_(p0) :: (store (<2 x s64>), align 1)
758 // %off2:_(s64) = G_CONSTANT i64 4144
759 // %p2:_(p0) = G_PTR_ADD %0:_, %off2:_(s64)
760 // G_STORE %11:_(<2 x s64>), %p2:_(p0) :: (store (<2 x s64>), align 1)
761 // %off3:_(s64) = G_CONSTANT i64 4160
762 // %p3:_(p0) = G_PTR_ADD %0:_, %off3:_(s64)
763 // G_STORE %11:_(<2 x s64>), %17:_(p0) :: (store (<2 x s64>), align 1)
764 bool Changed = false;
765 auto &MRI = MF.getRegInfo();
766
767 if (!EnableConsecutiveMemOpOpt)
768 return Changed;
769
770 SmallVector<StoreInfo, 8> Stores;
771 // If we see a load, then we keep track of any values defined by it.
772 // In the following example, STP formation will fail anyway because
773 // the latter store is using a load result that appears after the
774 // the prior store. In this situation if we factor out the offset then
775 // we increase code size for no benefit.
776 // G_STORE %v1:_(s64), %base:_(p0) :: (store (s64))
777 // %v2:_(s64) = G_LOAD %ldptr:_(p0) :: (load (s64))
778 // G_STORE %v2:_(s64), %base:_(p0) :: (store (s64))
779 SmallVector<Register> LoadValsSinceLastStore;
780
781 auto storeIsValid = [&](StoreInfo &Last, StoreInfo New) {
782 // Check if this store is consecutive to the last one.
783 if (Last.Ptr->getBaseReg() != New.Ptr->getBaseReg() ||
784 (Last.Offset + static_cast<int64_t>(Last.StoredType.getSizeInBytes()) !=
785 New.Offset) ||
786 Last.StoredType != New.StoredType)
787 return false;
788
789 // Check if this store is using a load result that appears after the
790 // last store. If so, bail out.
791 if (any_of(LoadValsSinceLastStore, [&](Register LoadVal) {
792 return New.St->getValueReg() == LoadVal;
793 }))
794 return false;
795
796 // Check if the current offset would be too large for STP.
797 // If not, then STP formation should be able to handle it, so we don't
798 // need to do anything.
799 int64_t MaxLegalOffset;
800 switch (New.StoredType.getSizeInBits()) {
801 case 32:
802 MaxLegalOffset = 252;
803 break;
804 case 64:
805 MaxLegalOffset = 504;
806 break;
807 case 128:
808 MaxLegalOffset = 1008;
809 break;
810 default:
811 llvm_unreachable("Unexpected stored type size");
812 }
813 if (New.Offset < MaxLegalOffset)
814 return false;
815
816 // If factoring it out still wouldn't help then don't bother.
817 return New.Offset - Stores[0].Offset <= MaxLegalOffset;
818 };
819
820 auto resetState = [&]() {
821 Stores.clear();
822 LoadValsSinceLastStore.clear();
823 };
824
825 for (auto &MBB : MF) {
826 // We're looking inside a single BB at a time since the memset pattern
827 // should only be in a single block.
828 resetState();
829 for (auto &MI : MBB) {
830 // Skip for scalable vectors
831 if (auto *LdSt = dyn_cast<GLoadStore>(&MI);
832 LdSt && MRI.getType(LdSt->getOperand(0).getReg()).isScalableVector())
833 continue;
834
835 if (auto *St = dyn_cast<GStore>(&MI)) {
836 Register PtrBaseReg;
837 APInt Offset;
838 LLT StoredValTy = MRI.getType(St->getValueReg());
839 unsigned ValSize = StoredValTy.getSizeInBits();
840 if (ValSize < 32 || St->getMMO().getSizeInBits() != ValSize)
841 continue;
842
843 Register PtrReg = St->getPointerReg();
844 if (mi_match(
845 PtrReg, MRI,
846 m_OneNonDBGUse(m_GPtrAdd(m_Reg(PtrBaseReg), m_ICst(Offset))))) {
847 GPtrAdd *PtrAdd = cast<GPtrAdd>(MRI.getVRegDef(PtrReg));
848 StoreInfo New = {St, PtrAdd, Offset.getSExtValue(), StoredValTy};
849
850 if (Stores.empty()) {
851 Stores.push_back(New);
852 continue;
853 }
854
855 // Check if this store is a valid continuation of the sequence.
856 auto &Last = Stores.back();
857 if (storeIsValid(Last, New)) {
858 Stores.push_back(New);
859 LoadValsSinceLastStore.clear(); // Reset the load value tracking.
860 } else {
861 // The store isn't a valid to consider for the prior sequence,
862 // so try to optimize what we have so far and start a new sequence.
863 Changed |= tryOptimizeConsecStores(Stores, MIB);
864 resetState();
865 Stores.push_back(New);
866 }
867 }
868 } else if (auto *Ld = dyn_cast<GLoad>(&MI)) {
869 LoadValsSinceLastStore.push_back(Ld->getDstReg());
870 }
871 }
872 Changed |= tryOptimizeConsecStores(Stores, MIB);
873 resetState();
874 }
875
876 return Changed;
877 }
878
879 char AArch64PostLegalizerCombiner::ID = 0;
880 INITIALIZE_PASS_BEGIN(AArch64PostLegalizerCombiner, DEBUG_TYPE,
881 "Combine AArch64 MachineInstrs after legalization", false,
882 false)
883 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
884 INITIALIZE_PASS_DEPENDENCY(GISelValueTrackingAnalysisLegacy)
885 INITIALIZE_PASS_END(AArch64PostLegalizerCombiner, DEBUG_TYPE,
886 "Combine AArch64 MachineInstrs after legalization", false,
887 false)
888
889 namespace llvm {
createAArch64PostLegalizerCombiner(bool IsOptNone)890 FunctionPass *createAArch64PostLegalizerCombiner(bool IsOptNone) {
891 return new AArch64PostLegalizerCombiner(IsOptNone);
892 }
893 } // end namespace llvm
894