//===-- AArch64ISelLowering.cpp - AArch64 DAG Lowering Implementation  ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the AArch64TargetLowering class.
//
//===----------------------------------------------------------------------===//

#include "AArch64ISelLowering.h"
#include "AArch64CallingConvention.h"
#include "AArch64ExpandImm.h"
#include "AArch64MachineFunctionInfo.h"
#include "AArch64PerfectShuffle.h"
#include "AArch64RegisterInfo.h"
#include "AArch64Subtarget.h"
#include "MCTargetDesc/AArch64AddressingModes.h"
#include "Utils/AArch64BaseInfo.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/ObjCARCUtil.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/ComplexDeinterleavingPass.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/RuntimeLibcallUtil.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/TargetCallingConv.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/CodeGenTypes/MachineValueType.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/Value.h"
#include "llvm/MC/MCRegisterInfo.h"
#include "llvm/Support/AtomicOrdering.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InstructionCost.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/SipHash.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/TargetParser/Triple.h"
#include <algorithm>
#include <bitset>
#include <cassert>
#include <cctype>
#include <cstdint>
#include <cstdlib>
#include <iterator>
#include <limits>
#include <optional>
#include <tuple>
#include <utility>
#include <vector>

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "aarch64-lower"

STATISTIC(NumTailCalls, "Number of tail calls");
STATISTIC(NumShiftInserts, "Number of vector shift inserts");
STATISTIC(NumOptimizedImms, "Number of times immediates were optimized");

// FIXME: The necessary dtprel relocations don't seem to be supported
// well in the GNU bfd and gold linkers at the moment. Therefore, by
// default, for now, fall back to GeneralDynamic code generation.
cl::opt<bool> EnableAArch64ELFLocalDynamicTLSGeneration(
    "aarch64-elf-ldtls-generation", cl::Hidden,
    cl::desc("Allow AArch64 Local Dynamic TLS code generation"),
    cl::init(false));

static cl::opt<bool>
EnableOptimizeLogicalImm("aarch64-enable-logical-imm", cl::Hidden,
                         cl::desc("Enable AArch64 logical imm instruction "
                                  "optimization"),
                         cl::init(true));

// Temporary option added for the purpose of testing functionality added
// to DAGCombiner.cpp in D92230. It is expected that this can be removed
// in the future once both implementations are based on MGATHER rather
// than the GLD1 nodes added for the SVE gather load intrinsics.
static cl::opt<bool>
EnableCombineMGatherIntrinsics("aarch64-enable-mgather-combine", cl::Hidden,
                               cl::desc("Combine extends of AArch64 masked "
                                        "gather intrinsics"),
                               cl::init(true));

static cl::opt<bool> EnableExtToTBL("aarch64-enable-ext-to-tbl", cl::Hidden,
                                    cl::desc("Combine ext and trunc to TBL"),
                                    cl::init(true));

// XOR, OR and CMP all use ALU ports, so the data dependency chain becomes the
// bottleneck after this transform on high-end CPUs. This maximum leaf-node
// limit guards the transform so that cmp+ccmp remains profitable.
static cl::opt<unsigned> MaxXors("aarch64-max-xors", cl::init(16), cl::Hidden,
                                 cl::desc("Maximum of xors"));

// When this is enabled, we will not fall back to DAG ISel for any instruction
// that encounters scalable vector types, even though SVE support is still
// incomplete for some instructions.
// See [AArch64TargetLowering::fallbackToDAGISel] for implementation details.
cl::opt<bool> EnableSVEGISel(
    "aarch64-enable-gisel-sve", cl::Hidden,
    cl::desc("Enable / disable SVE scalable vectors in Global ISel"),
    cl::init(false));

/// Value type used for condition codes.
static const MVT MVT_CC = MVT::i32;

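// Physical registers used for passing the first eight integer and FP/SIMD
// arguments under AAPCS64 (X0-X7 and Q0-Q7).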
static const MCPhysReg GPRArgRegs[] = {AArch64::X0, AArch64::X1, AArch64::X2,
                                       AArch64::X3, AArch64::X4, AArch64::X5,
                                       AArch64::X6, AArch64::X7};
static const MCPhysReg FPRArgRegs[] = {AArch64::Q0, AArch64::Q1, AArch64::Q2,
                                       AArch64::Q3, AArch64::Q4, AArch64::Q5,
                                       AArch64::Q6, AArch64::Q7};

ArrayRef<MCPhysReg> llvm::AArch64::getGPRArgRegs() { return GPRArgRegs; }

ArrayRef<MCPhysReg> llvm::AArch64::getFPRArgRegs() { return FPRArgRegs; }

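// Map a scalar element type to the corresponding "packed" SVE vector type,
// i.e. the scalable vector whose known-minimum size fills a 128-bit granule
// (e.g. i8 -> nxv16i8, f32 -> nxv4f32).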
static inline EVT getPackedSVEVectorVT(EVT VT) {
  switch (VT.getSimpleVT().SimpleTy) {
  default:
    llvm_unreachable("unexpected element type for vector");
  case MVT::i8:
    return MVT::nxv16i8;
  case MVT::i16:
    return MVT::nxv8i16;
  case MVT::i32:
    return MVT::nxv4i32;
  case MVT::i64:
    return MVT::nxv2i64;
  case MVT::f16:
    return MVT::nxv8f16;
  case MVT::f32:
    return MVT::nxv4f32;
  case MVT::f64:
    return MVT::nxv2f64;
  case MVT::bf16:
    return MVT::nxv8bf16;
  }
}

// NOTE: Currently there's only a need to return integer vector types. If this
// changes then just add an extra "type" parameter.
static inline EVT getPackedSVEVectorVT(ElementCount EC) {
  switch (EC.getKnownMinValue()) {
  default:
    llvm_unreachable("unexpected element count for vector");
  case 16:
    return MVT::nxv16i8;
  case 8:
    return MVT::nxv8i16;
  case 4:
    return MVT::nxv4i32;
  case 2:
    return MVT::nxv2i64;
  }
}

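// Map a scalable predicate type to the packed integer vector type with the
// same lane count and maximal element width (e.g. nxv4i1 -> nxv4i32).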
static inline EVT getPromotedVTForPredicate(EVT VT) {
  assert(VT.isScalableVector() && (VT.getVectorElementType() == MVT::i1) &&
         "Expected scalable predicate vector type!");
  switch (VT.getVectorMinNumElements()) {
  default:
    llvm_unreachable("unexpected element count for vector");
  case 2:
    return MVT::nxv2i64;
  case 4:
    return MVT::nxv4i32;
  case 8:
    return MVT::nxv8i16;
  case 16:
    return MVT::nxv16i8;
  }
}

/// Returns true if VT's elements occupy the lowest bit positions of its
/// associated register class without any intervening space.
///
/// For example, nxv2f16, nxv4f16 and nxv8f16 are legal types that belong to the
/// same register class, but only nxv8f16 can be treated as a packed vector.
static inline bool isPackedVectorType(EVT VT, SelectionDAG &DAG) {
  assert(VT.isVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
         "Expected legal vector type!");
  return VT.isFixedLengthVector() ||
         VT.getSizeInBits().getKnownMinValue() == AArch64::SVEBitsPerBlock;
}

// Returns true for ####_MERGE_PASSTHRU opcodes, whose operands have a leading
// predicate and end with a passthru value matching the result type.
static bool isMergePassthruOpcode(unsigned Opc) {
  switch (Opc) {
  default:
    return false;
  case AArch64ISD::BITREVERSE_MERGE_PASSTHRU:
  case AArch64ISD::BSWAP_MERGE_PASSTHRU:
  case AArch64ISD::REVH_MERGE_PASSTHRU:
  case AArch64ISD::REVW_MERGE_PASSTHRU:
  case AArch64ISD::REVD_MERGE_PASSTHRU:
  case AArch64ISD::CTLZ_MERGE_PASSTHRU:
  case AArch64ISD::CTPOP_MERGE_PASSTHRU:
  case AArch64ISD::DUP_MERGE_PASSTHRU:
  case AArch64ISD::ABS_MERGE_PASSTHRU:
  case AArch64ISD::NEG_MERGE_PASSTHRU:
  case AArch64ISD::FNEG_MERGE_PASSTHRU:
  case AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU:
  case AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU:
  case AArch64ISD::FCEIL_MERGE_PASSTHRU:
  case AArch64ISD::FFLOOR_MERGE_PASSTHRU:
  case AArch64ISD::FNEARBYINT_MERGE_PASSTHRU:
  case AArch64ISD::FRINT_MERGE_PASSTHRU:
  case AArch64ISD::FROUND_MERGE_PASSTHRU:
  case AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU:
  case AArch64ISD::FTRUNC_MERGE_PASSTHRU:
  case AArch64ISD::FP_ROUND_MERGE_PASSTHRU:
  case AArch64ISD::FP_EXTEND_MERGE_PASSTHRU:
  case AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU:
  case AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU:
  case AArch64ISD::FCVTZU_MERGE_PASSTHRU:
  case AArch64ISD::FCVTZS_MERGE_PASSTHRU:
  case AArch64ISD::FSQRT_MERGE_PASSTHRU:
  case AArch64ISD::FRECPX_MERGE_PASSTHRU:
  case AArch64ISD::FABS_MERGE_PASSTHRU:
    return true;
  }
}

// Returns true if inactive lanes are known to be zeroed by construction.
static bool isZeroingInactiveLanes(SDValue Op) {
  switch (Op.getOpcode()) {
  default:
    return false;
  // We guarantee i1 splat_vectors to zero the other lanes
  case ISD::SPLAT_VECTOR:
  case AArch64ISD::PTRUE:
  case AArch64ISD::SETCC_MERGE_ZERO:
    return true;
  case ISD::INTRINSIC_WO_CHAIN:
    switch (Op.getConstantOperandVal(0)) {
    default:
      return false;
    case Intrinsic::aarch64_sve_ptrue:
    case Intrinsic::aarch64_sve_pnext:
    case Intrinsic::aarch64_sve_cmpeq:
    case Intrinsic::aarch64_sve_cmpne:
    case Intrinsic::aarch64_sve_cmpge:
    case Intrinsic::aarch64_sve_cmpgt:
    case Intrinsic::aarch64_sve_cmphs:
    case Intrinsic::aarch64_sve_cmphi:
    case Intrinsic::aarch64_sve_cmpeq_wide:
    case Intrinsic::aarch64_sve_cmpne_wide:
    case Intrinsic::aarch64_sve_cmpge_wide:
    case Intrinsic::aarch64_sve_cmpgt_wide:
    case Intrinsic::aarch64_sve_cmplt_wide:
    case Intrinsic::aarch64_sve_cmple_wide:
    case Intrinsic::aarch64_sve_cmphs_wide:
    case Intrinsic::aarch64_sve_cmphi_wide:
    case Intrinsic::aarch64_sve_cmplo_wide:
    case Intrinsic::aarch64_sve_cmpls_wide:
    case Intrinsic::aarch64_sve_fcmpeq:
    case Intrinsic::aarch64_sve_fcmpne:
    case Intrinsic::aarch64_sve_fcmpge:
    case Intrinsic::aarch64_sve_fcmpgt:
    case Intrinsic::aarch64_sve_fcmpuo:
    case Intrinsic::aarch64_sve_facgt:
    case Intrinsic::aarch64_sve_facge:
    case Intrinsic::aarch64_sve_whilege:
    case Intrinsic::aarch64_sve_whilegt:
    case Intrinsic::aarch64_sve_whilehi:
    case Intrinsic::aarch64_sve_whilehs:
    case Intrinsic::aarch64_sve_whilele:
    case Intrinsic::aarch64_sve_whilelo:
    case Intrinsic::aarch64_sve_whilels:
    case Intrinsic::aarch64_sve_whilelt:
    case Intrinsic::aarch64_sve_match:
    case Intrinsic::aarch64_sve_nmatch:
    case Intrinsic::aarch64_sve_whilege_x2:
    case Intrinsic::aarch64_sve_whilegt_x2:
    case Intrinsic::aarch64_sve_whilehi_x2:
    case Intrinsic::aarch64_sve_whilehs_x2:
    case Intrinsic::aarch64_sve_whilele_x2:
    case Intrinsic::aarch64_sve_whilelo_x2:
    case Intrinsic::aarch64_sve_whilels_x2:
    case Intrinsic::aarch64_sve_whilelt_x2:
      return true;
    }
  }
}

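// Split a pointer-authentication discriminator into the {constant, address}
// pair used by the ptrauth pseudo-instructions. A ptrauth.blend collapses to
// its two components; a plain value becomes either a pure 16-bit constant
// discriminator or, failing that, a pure address discriminator.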
static std::tuple<SDValue, SDValue>
extractPtrauthBlendDiscriminators(SDValue Disc, SelectionDAG *DAG) {
  SDLoc DL(Disc);
  SDValue AddrDisc;
  SDValue ConstDisc;

  // If this is a blend, remember the constant and address discriminators.
  // Otherwise, it's either a constant discriminator, or a non-blended
  // address discriminator.
  if (Disc->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
      Disc->getConstantOperandVal(0) == Intrinsic::ptrauth_blend) {
    AddrDisc = Disc->getOperand(1);
    ConstDisc = Disc->getOperand(2);
  } else {
    ConstDisc = Disc;
  }

  // If the constant discriminator (either the blend RHS, or the entire
  // discriminator value) isn't a 16-bit constant, bail out, and let the
  // discriminator be computed separately.
  const auto *ConstDiscN = dyn_cast<ConstantSDNode>(ConstDisc);
  if (!ConstDiscN || !isUInt<16>(ConstDiscN->getZExtValue()))
    return std::make_tuple(DAG->getTargetConstant(0, DL, MVT::i64), Disc);

  // If there's no address discriminator, use NoRegister, which we'll later
  // replace with XZR, or directly use a Z variant of the inst. when available.
  if (!AddrDisc)
    AddrDisc = DAG->getRegister(AArch64::NoRegister, MVT::i64);

  return std::make_tuple(
      DAG->getTargetConstant(ConstDiscN->getZExtValue(), DL, MVT::i64),
      AddrDisc);
}

AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
                                             const AArch64Subtarget &STI)
    : TargetLowering(TM), Subtarget(&STI) {
  // AArch64 doesn't have comparisons which set GPRs or setcc instructions, so
  // we have to make something up. Arbitrarily, choose ZeroOrOne.
  setBooleanContents(ZeroOrOneBooleanContent);
  // When comparing vectors the result sets the different elements in the
  // vector to all-one or all-zero.
  setBooleanVectorContents(ZeroOrNegativeOneBooleanContent);

  // Set up the register classes.
  addRegisterClass(MVT::i32, &AArch64::GPR32allRegClass);
  addRegisterClass(MVT::i64, &AArch64::GPR64allRegClass);

  if (Subtarget->hasLS64()) {
    addRegisterClass(MVT::i64x8, &AArch64::GPR64x8ClassRegClass);
    setOperationAction(ISD::LOAD, MVT::i64x8, Custom);
    setOperationAction(ISD::STORE, MVT::i64x8, Custom);
  }

  if (Subtarget->hasFPARMv8()) {
    addRegisterClass(MVT::f16, &AArch64::FPR16RegClass);
    addRegisterClass(MVT::bf16, &AArch64::FPR16RegClass);
    addRegisterClass(MVT::f32, &AArch64::FPR32RegClass);
    addRegisterClass(MVT::f64, &AArch64::FPR64RegClass);
    addRegisterClass(MVT::f128, &AArch64::FPR128RegClass);
  }

  if (Subtarget->hasNEON()) {
    addRegisterClass(MVT::v16i8, &AArch64::FPR8RegClass);
    addRegisterClass(MVT::v8i16, &AArch64::FPR16RegClass);

    addDRType(MVT::v2f32);
    addDRType(MVT::v8i8);
    addDRType(MVT::v4i16);
    addDRType(MVT::v2i32);
    addDRType(MVT::v1i64);
    addDRType(MVT::v1f64);
    addDRType(MVT::v4f16);
    addDRType(MVT::v4bf16);

    addQRType(MVT::v4f32);
    addQRType(MVT::v2f64);
    addQRType(MVT::v16i8);
    addQRType(MVT::v8i16);
    addQRType(MVT::v4i32);
    addQRType(MVT::v2i64);
    addQRType(MVT::v8f16);
    addQRType(MVT::v8bf16);
  }

  if (Subtarget->isSVEorStreamingSVEAvailable()) {
    // Add legal sve predicate types
    addRegisterClass(MVT::nxv1i1, &AArch64::PPRRegClass);
    addRegisterClass(MVT::nxv2i1, &AArch64::PPRRegClass);
    addRegisterClass(MVT::nxv4i1, &AArch64::PPRRegClass);
    addRegisterClass(MVT::nxv8i1, &AArch64::PPRRegClass);
    addRegisterClass(MVT::nxv16i1, &AArch64::PPRRegClass);

    // Add legal sve data types
    addRegisterClass(MVT::nxv16i8, &AArch64::ZPRRegClass);
    addRegisterClass(MVT::nxv8i16, &AArch64::ZPRRegClass);
    addRegisterClass(MVT::nxv4i32, &AArch64::ZPRRegClass);
    addRegisterClass(MVT::nxv2i64, &AArch64::ZPRRegClass);

    addRegisterClass(MVT::nxv2f16, &AArch64::ZPRRegClass);
    addRegisterClass(MVT::nxv4f16, &AArch64::ZPRRegClass);
    addRegisterClass(MVT::nxv8f16, &AArch64::ZPRRegClass);
    addRegisterClass(MVT::nxv2f32, &AArch64::ZPRRegClass);
    addRegisterClass(MVT::nxv4f32, &AArch64::ZPRRegClass);
    addRegisterClass(MVT::nxv2f64, &AArch64::ZPRRegClass);

    addRegisterClass(MVT::nxv2bf16, &AArch64::ZPRRegClass);
    addRegisterClass(MVT::nxv4bf16, &AArch64::ZPRRegClass);
    addRegisterClass(MVT::nxv8bf16, &AArch64::ZPRRegClass);

    if (Subtarget->useSVEForFixedLengthVectors()) {
      for (MVT VT : MVT::integer_fixedlen_vector_valuetypes())
        if (useSVEForFixedLengthVectorVT(VT))
          addRegisterClass(VT, &AArch64::ZPRRegClass);

      for (MVT VT : MVT::fp_fixedlen_vector_valuetypes())
        if (useSVEForFixedLengthVectorVT(VT))
          addRegisterClass(VT, &AArch64::ZPRRegClass);
    }
  }

  if (Subtarget->hasSVE2p1() || Subtarget->hasSME2()) {
    addRegisterClass(MVT::aarch64svcount, &AArch64::PPRRegClass);
    setOperationPromotedToType(ISD::LOAD, MVT::aarch64svcount, MVT::nxv16i1);
    setOperationPromotedToType(ISD::STORE, MVT::aarch64svcount, MVT::nxv16i1);

    setOperationAction(ISD::SELECT, MVT::aarch64svcount, Custom);
    setOperationAction(ISD::SELECT_CC, MVT::aarch64svcount, Expand);
  }

  // Compute derived properties from the register classes
  computeRegisterProperties(Subtarget->getRegisterInfo());

  // Provide all sorts of operation actions
  setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
  setOperationAction(ISD::GlobalTLSAddress, MVT::i64, Custom);
  setOperationAction(ISD::SETCC, MVT::i32, Custom);
  setOperationAction(ISD::SETCC, MVT::i64, Custom);
  setOperationAction(ISD::SETCC, MVT::bf16, Custom);
  setOperationAction(ISD::SETCC, MVT::f16, Custom);
  setOperationAction(ISD::SETCC, MVT::f32, Custom);
  setOperationAction(ISD::SETCC, MVT::f64, Custom);
  setOperationAction(ISD::STRICT_FSETCC, MVT::bf16, Custom);
  setOperationAction(ISD::STRICT_FSETCC, MVT::f16, Custom);
  setOperationAction(ISD::STRICT_FSETCC, MVT::f32, Custom);
  setOperationAction(ISD::STRICT_FSETCC, MVT::f64, Custom);
  setOperationAction(ISD::STRICT_FSETCCS, MVT::f16, Custom);
  setOperationAction(ISD::STRICT_FSETCCS, MVT::f32, Custom);
  setOperationAction(ISD::STRICT_FSETCCS, MVT::f64, Custom);
  setOperationAction(ISD::BITREVERSE, MVT::i32, Legal);
  setOperationAction(ISD::BITREVERSE, MVT::i64, Legal);
  setOperationAction(ISD::BRCOND, MVT::Other, Custom);
  setOperationAction(ISD::BR_CC, MVT::i32, Custom);
  setOperationAction(ISD::BR_CC, MVT::i64, Custom);
  setOperationAction(ISD::BR_CC, MVT::f16, Custom);
  setOperationAction(ISD::BR_CC, MVT::f32, Custom);
  setOperationAction(ISD::BR_CC, MVT::f64, Custom);
  setOperationAction(ISD::SELECT, MVT::i32, Custom);
  setOperationAction(ISD::SELECT, MVT::i64, Custom);
  setOperationAction(ISD::SELECT, MVT::f16, Custom);
  setOperationAction(ISD::SELECT, MVT::bf16, Custom);
  setOperationAction(ISD::SELECT, MVT::f32, Custom);
  setOperationAction(ISD::SELECT, MVT::f64, Custom);
  setOperationAction(ISD::SELECT_CC, MVT::i32, Custom);
  setOperationAction(ISD::SELECT_CC, MVT::i64, Custom);
  setOperationAction(ISD::SELECT_CC, MVT::f16, Custom);
  setOperationAction(ISD::SELECT_CC, MVT::bf16, Custom);
  setOperationAction(ISD::SELECT_CC, MVT::f32, Custom);
  setOperationAction(ISD::SELECT_CC, MVT::f64, Custom);
  setOperationAction(ISD::BR_JT, MVT::Other, Custom);
  setOperationAction(ISD::JumpTable, MVT::i64, Custom);
  setOperationAction(ISD::BRIND, MVT::Other, Custom);
  setOperationAction(ISD::SETCCCARRY, MVT::i64, Custom);

  setOperationAction(ISD::PtrAuthGlobalAddress, MVT::i64, Custom);

  setOperationAction(ISD::SHL_PARTS, MVT::i64, Custom);
  setOperationAction(ISD::SRA_PARTS, MVT::i64, Custom);
  setOperationAction(ISD::SRL_PARTS, MVT::i64, Custom);

  setOperationAction(ISD::FREM, MVT::f32, Expand);
  setOperationAction(ISD::FREM, MVT::f64, Expand);
  setOperationAction(ISD::FREM, MVT::f80, Expand);

  setOperationAction(ISD::BUILD_PAIR, MVT::i64, Expand);

  // Custom lowering hooks are needed for XOR
  // to fold it into CSINC/CSINV.
  setOperationAction(ISD::XOR, MVT::i32, Custom);
  setOperationAction(ISD::XOR, MVT::i64, Custom);

  // Virtually no operation on f128 is legal, but LLVM can't expand them when
  // there's a valid register class, so we need custom operations in most cases.
  setOperationAction(ISD::FABS, MVT::f128, Expand);
  setOperationAction(ISD::FADD, MVT::f128, LibCall);
  setOperationAction(ISD::FCOPYSIGN, MVT::f128, Expand);
  setOperationAction(ISD::FCOS, MVT::f128, Expand);
  setOperationAction(ISD::FDIV, MVT::f128, LibCall);
  setOperationAction(ISD::FMA, MVT::f128, Expand);
  setOperationAction(ISD::FMUL, MVT::f128, LibCall);
  setOperationAction(ISD::FNEG, MVT::f128, Expand);
  setOperationAction(ISD::FPOW, MVT::f128, Expand);
  setOperationAction(ISD::FREM, MVT::f128, Expand);
  setOperationAction(ISD::FRINT, MVT::f128, Expand);
  setOperationAction(ISD::FSIN, MVT::f128, Expand);
  setOperationAction(ISD::FSINCOS, MVT::f128, Expand);
  setOperationAction(ISD::FSQRT, MVT::f128, Expand);
  setOperationAction(ISD::FSUB, MVT::f128, LibCall);
  setOperationAction(ISD::FTAN, MVT::f128, Expand);
  setOperationAction(ISD::FTRUNC, MVT::f128, Expand);
  setOperationAction(ISD::SETCC, MVT::f128, Custom);
  setOperationAction(ISD::STRICT_FSETCC, MVT::f128, Custom);
  setOperationAction(ISD::STRICT_FSETCCS, MVT::f128, Custom);
  setOperationAction(ISD::BR_CC, MVT::f128, Custom);
  setOperationAction(ISD::SELECT, MVT::f128, Custom);
  setOperationAction(ISD::SELECT_CC, MVT::f128, Custom);
  setOperationAction(ISD::FP_EXTEND, MVT::f128, Custom);
  // FIXME: f128 FMINIMUM and FMAXIMUM (including STRICT versions) currently
  // aren't handled.

  // Lowering for many of the conversions is actually specified by the non-f128
  // type. The LowerXXX function will be trivial when f128 isn't involved.
  setOperationAction(ISD::FP_TO_SINT, MVT::i32, Custom);
  setOperationAction(ISD::FP_TO_SINT, MVT::i64, Custom);
  setOperationAction(ISD::FP_TO_SINT, MVT::i128, Custom);
  setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::i32, Custom);
  setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::i64, Custom);
  setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::i128, Custom);
  setOperationAction(ISD::FP_TO_UINT, MVT::i32, Custom);
  setOperationAction(ISD::FP_TO_UINT, MVT::i64, Custom);
  setOperationAction(ISD::FP_TO_UINT, MVT::i128, Custom);
  setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::i32, Custom);
  setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::i64, Custom);
  setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::i128, Custom);
  setOperationAction(ISD::SINT_TO_FP, MVT::i32, Custom);
  setOperationAction(ISD::SINT_TO_FP, MVT::i64, Custom);
  setOperationAction(ISD::SINT_TO_FP, MVT::i128, Custom);
  setOperationAction(ISD::STRICT_SINT_TO_FP, MVT::i32, Custom);
  setOperationAction(ISD::STRICT_SINT_TO_FP, MVT::i64, Custom);
  setOperationAction(ISD::STRICT_SINT_TO_FP, MVT::i128, Custom);
  setOperationAction(ISD::UINT_TO_FP, MVT::i32, Custom);
  setOperationAction(ISD::UINT_TO_FP, MVT::i64, Custom);
  setOperationAction(ISD::UINT_TO_FP, MVT::i128, Custom);
  setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i32, Custom);
  setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i64, Custom);
  setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i128, Custom);
  if (Subtarget->hasFPARMv8()) {
    setOperationAction(ISD::FP_ROUND, MVT::f16, Custom);
    setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom);
  }
  setOperationAction(ISD::FP_ROUND, MVT::f32, Custom);
  setOperationAction(ISD::FP_ROUND, MVT::f64, Custom);
  if (Subtarget->hasFPARMv8()) {
    setOperationAction(ISD::STRICT_FP_ROUND, MVT::f16, Custom);
    setOperationAction(ISD::STRICT_FP_ROUND, MVT::bf16, Custom);
  }
  setOperationAction(ISD::STRICT_FP_ROUND, MVT::f32, Custom);
  setOperationAction(ISD::STRICT_FP_ROUND, MVT::f64, Custom);

  setOperationAction(ISD::FP_TO_UINT_SAT, MVT::i32, Custom);
  setOperationAction(ISD::FP_TO_UINT_SAT, MVT::i64, Custom);
  setOperationAction(ISD::FP_TO_SINT_SAT, MVT::i32, Custom);
  setOperationAction(ISD::FP_TO_SINT_SAT, MVT::i64, Custom);

  // Variable arguments.
  setOperationAction(ISD::VASTART, MVT::Other, Custom);
  setOperationAction(ISD::VAARG, MVT::Other, Custom);
  setOperationAction(ISD::VACOPY, MVT::Other, Custom);
  setOperationAction(ISD::VAEND, MVT::Other, Expand);

  // Variable-sized objects.
  setOperationAction(ISD::STACKSAVE, MVT::Other, Expand);
  setOperationAction(ISD::STACKRESTORE, MVT::Other, Expand);

  // Lowering Funnel Shifts to EXTR
  setOperationAction(ISD::FSHR, MVT::i32, Custom);
  setOperationAction(ISD::FSHR, MVT::i64, Custom);
  setOperationAction(ISD::FSHL, MVT::i32, Custom);
  setOperationAction(ISD::FSHL, MVT::i64, Custom);

  setOperationAction(ISD::DYNAMIC_STACKALLOC, MVT::i64, Custom);

  // Constant pool entries
  setOperationAction(ISD::ConstantPool, MVT::i64, Custom);

  // BlockAddress
  setOperationAction(ISD::BlockAddress, MVT::i64, Custom);

  // AArch64 lacks both left-rotate and popcount instructions.
  setOperationAction(ISD::ROTL, MVT::i32, Expand);
  setOperationAction(ISD::ROTL, MVT::i64, Expand);
  for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
    setOperationAction(ISD::ROTL, VT, Expand);
    setOperationAction(ISD::ROTR, VT, Expand);
  }

  // AArch64 doesn't have i32 MULH{S|U}.
  setOperationAction(ISD::MULHU, MVT::i32, Expand);
  setOperationAction(ISD::MULHS, MVT::i32, Expand);

  // AArch64 doesn't have {U|S}MUL_LOHI.
  setOperationAction(ISD::UMUL_LOHI, MVT::i32, Expand);
  setOperationAction(ISD::SMUL_LOHI, MVT::i32, Expand);
  setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);
  setOperationAction(ISD::SMUL_LOHI, MVT::i64, Expand);

  if (Subtarget->hasCSSC()) {
    setOperationAction(ISD::CTPOP, MVT::i32, Legal);
    setOperationAction(ISD::CTPOP, MVT::i64, Legal);
    setOperationAction(ISD::CTPOP, MVT::i128, Expand);

    setOperationAction(ISD::PARITY, MVT::i128, Expand);

    setOperationAction(ISD::CTTZ, MVT::i32, Legal);
    setOperationAction(ISD::CTTZ, MVT::i64, Legal);
    setOperationAction(ISD::CTTZ, MVT::i128, Expand);

    setOperationAction(ISD::ABS, MVT::i32, Legal);
    setOperationAction(ISD::ABS, MVT::i64, Legal);

    setOperationAction(ISD::SMAX, MVT::i32, Legal);
    setOperationAction(ISD::SMAX, MVT::i64, Legal);
    setOperationAction(ISD::UMAX, MVT::i32, Legal);
    setOperationAction(ISD::UMAX, MVT::i64, Legal);

    setOperationAction(ISD::SMIN, MVT::i32, Legal);
    setOperationAction(ISD::SMIN, MVT::i64, Legal);
    setOperationAction(ISD::UMIN, MVT::i32, Legal);
    setOperationAction(ISD::UMIN, MVT::i64, Legal);
  } else {
    setOperationAction(ISD::CTPOP, MVT::i32, Custom);
    setOperationAction(ISD::CTPOP, MVT::i64, Custom);
    setOperationAction(ISD::CTPOP, MVT::i128, Custom);

    setOperationAction(ISD::PARITY, MVT::i64, Custom);
    setOperationAction(ISD::PARITY, MVT::i128, Custom);

    setOperationAction(ISD::ABS, MVT::i32, Custom);
    setOperationAction(ISD::ABS, MVT::i64, Custom);
  }

  setOperationAction(ISD::SDIVREM, MVT::i32, Expand);
  setOperationAction(ISD::SDIVREM, MVT::i64, Expand);
  for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
    setOperationAction(ISD::SDIVREM, VT, Expand);
    setOperationAction(ISD::UDIVREM, VT, Expand);
  }
  setOperationAction(ISD::SREM, MVT::i32, Expand);
  setOperationAction(ISD::SREM, MVT::i64, Expand);
  setOperationAction(ISD::UDIVREM, MVT::i32, Expand);
  setOperationAction(ISD::UDIVREM, MVT::i64, Expand);
  setOperationAction(ISD::UREM, MVT::i32, Expand);
  setOperationAction(ISD::UREM, MVT::i64, Expand);

  // Custom lower Add/Sub/Mul with overflow.
  setOperationAction(ISD::SADDO, MVT::i32, Custom);
  setOperationAction(ISD::SADDO, MVT::i64, Custom);
  setOperationAction(ISD::UADDO, MVT::i32, Custom);
  setOperationAction(ISD::UADDO, MVT::i64, Custom);
  setOperationAction(ISD::SSUBO, MVT::i32, Custom);
  setOperationAction(ISD::SSUBO, MVT::i64, Custom);
  setOperationAction(ISD::USUBO, MVT::i32, Custom);
  setOperationAction(ISD::USUBO, MVT::i64, Custom);
  setOperationAction(ISD::SMULO, MVT::i32, Custom);
  setOperationAction(ISD::SMULO, MVT::i64, Custom);
  setOperationAction(ISD::UMULO, MVT::i32, Custom);
  setOperationAction(ISD::UMULO, MVT::i64, Custom);

  setOperationAction(ISD::UADDO_CARRY, MVT::i32, Custom);
  setOperationAction(ISD::UADDO_CARRY, MVT::i64, Custom);
  setOperationAction(ISD::USUBO_CARRY, MVT::i32, Custom);
  setOperationAction(ISD::USUBO_CARRY, MVT::i64, Custom);
  setOperationAction(ISD::SADDO_CARRY, MVT::i32, Custom);
  setOperationAction(ISD::SADDO_CARRY, MVT::i64, Custom);
  setOperationAction(ISD::SSUBO_CARRY, MVT::i32, Custom);
  setOperationAction(ISD::SSUBO_CARRY, MVT::i64, Custom);

  setOperationAction(ISD::FSIN, MVT::f32, Expand);
  setOperationAction(ISD::FSIN, MVT::f64, Expand);
  setOperationAction(ISD::FCOS, MVT::f32, Expand);
  setOperationAction(ISD::FCOS, MVT::f64, Expand);
  setOperationAction(ISD::FPOW, MVT::f32, Expand);
  setOperationAction(ISD::FPOW, MVT::f64, Expand);
  setOperationAction(ISD::FCOPYSIGN, MVT::f64, Custom);
  setOperationAction(ISD::FCOPYSIGN, MVT::f32, Custom);
  if (Subtarget->hasFullFP16()) {
    setOperationAction(ISD::FCOPYSIGN, MVT::f16, Custom);
    setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Custom);
  } else {
    setOperationAction(ISD::FCOPYSIGN, MVT::f16, Promote);
    setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Promote);
  }

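  // These operations have no AArch64 instructions: the scalar f16/bf16 forms
  // are promoted (and thus legalised via wider floats or libcalls) and the
  // short vector forms are expanded.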
  for (auto Op : {ISD::FREM,         ISD::FPOW,          ISD::FPOWI,
                  ISD::FCOS,         ISD::FSIN,          ISD::FSINCOS,
                  ISD::FACOS,        ISD::FASIN,         ISD::FATAN,
                  ISD::FCOSH,        ISD::FSINH,         ISD::FTANH,
                  ISD::FTAN,         ISD::FEXP,          ISD::FEXP2,
                  ISD::FEXP10,       ISD::FLOG,          ISD::FLOG2,
                  ISD::FLOG10,       ISD::STRICT_FREM,   ISD::STRICT_FPOW,
                  ISD::STRICT_FPOWI, ISD::STRICT_FCOS,   ISD::STRICT_FSIN,
                  ISD::STRICT_FACOS, ISD::STRICT_FASIN,  ISD::STRICT_FATAN,
                  ISD::STRICT_FCOSH, ISD::STRICT_FSINH,  ISD::STRICT_FTANH,
                  ISD::STRICT_FEXP,  ISD::STRICT_FEXP2,  ISD::STRICT_FLOG,
                  ISD::STRICT_FLOG2, ISD::STRICT_FLOG10, ISD::STRICT_FTAN}) {
    setOperationAction(Op, MVT::f16, Promote);
    setOperationAction(Op, MVT::v4f16, Expand);
    setOperationAction(Op, MVT::v8f16, Expand);
    setOperationAction(Op, MVT::bf16, Promote);
    setOperationAction(Op, MVT::v4bf16, Expand);
    setOperationAction(Op, MVT::v8bf16, Expand);
  }

  auto LegalizeNarrowFP = [this](MVT ScalarVT) {
    for (auto Op : {
             ISD::SETCC,
             ISD::SELECT_CC,
             ISD::BR_CC,
             ISD::FADD,
             ISD::FSUB,
             ISD::FMUL,
             ISD::FDIV,
             ISD::FMA,
             ISD::FCEIL,
             ISD::FSQRT,
             ISD::FFLOOR,
             ISD::FNEARBYINT,
             ISD::FRINT,
             ISD::FROUND,
             ISD::FROUNDEVEN,
             ISD::FTRUNC,
             ISD::FMINNUM,
             ISD::FMAXNUM,
             ISD::FMINIMUM,
             ISD::FMAXIMUM,
             ISD::STRICT_FADD,
             ISD::STRICT_FSUB,
             ISD::STRICT_FMUL,
             ISD::STRICT_FDIV,
             ISD::STRICT_FMA,
             ISD::STRICT_FCEIL,
             ISD::STRICT_FFLOOR,
             ISD::STRICT_FSQRT,
             ISD::STRICT_FRINT,
             ISD::STRICT_FNEARBYINT,
             ISD::STRICT_FROUND,
             ISD::STRICT_FTRUNC,
             ISD::STRICT_FROUNDEVEN,
             ISD::STRICT_FMINNUM,
             ISD::STRICT_FMAXNUM,
             ISD::STRICT_FMINIMUM,
             ISD::STRICT_FMAXIMUM,
         })
      setOperationAction(Op, ScalarVT, Promote);

    for (auto Op : {ISD::FNEG, ISD::FABS})
      setOperationAction(Op, ScalarVT, Legal);

    // Round-to-integer operations need custom lowering for fp16, as Promote
    // doesn't work because the result type is integer.
    for (auto Op : {ISD::LROUND, ISD::LLROUND, ISD::LRINT, ISD::LLRINT,
                    ISD::STRICT_LROUND, ISD::STRICT_LLROUND, ISD::STRICT_LRINT,
                    ISD::STRICT_LLRINT})
      setOperationAction(Op, ScalarVT, Custom);

    // promote v4f16 to v4f32 when that is known to be safe.
    auto V4Narrow = MVT::getVectorVT(ScalarVT, 4);
    setOperationPromotedToType(ISD::FADD,       V4Narrow, MVT::v4f32);
    setOperationPromotedToType(ISD::FSUB,       V4Narrow, MVT::v4f32);
    setOperationPromotedToType(ISD::FMUL,       V4Narrow, MVT::v4f32);
    setOperationPromotedToType(ISD::FDIV,       V4Narrow, MVT::v4f32);
    setOperationPromotedToType(ISD::FCEIL,      V4Narrow, MVT::v4f32);
    setOperationPromotedToType(ISD::FFLOOR,     V4Narrow, MVT::v4f32);
    setOperationPromotedToType(ISD::FROUND,     V4Narrow, MVT::v4f32);
    setOperationPromotedToType(ISD::FTRUNC,     V4Narrow, MVT::v4f32);
    setOperationPromotedToType(ISD::FROUNDEVEN, V4Narrow, MVT::v4f32);
    setOperationPromotedToType(ISD::FRINT,      V4Narrow, MVT::v4f32);
    setOperationPromotedToType(ISD::FNEARBYINT, V4Narrow, MVT::v4f32);

    setOperationAction(ISD::FABS,        V4Narrow, Legal);
    setOperationAction(ISD::FNEG,        V4Narrow, Legal);
    setOperationAction(ISD::FMA,         V4Narrow, Expand);
    setOperationAction(ISD::SETCC,       V4Narrow, Custom);
    setOperationAction(ISD::BR_CC,       V4Narrow, Expand);
    setOperationAction(ISD::SELECT,      V4Narrow, Expand);
    setOperationAction(ISD::SELECT_CC,   V4Narrow, Expand);
    setOperationAction(ISD::FCOPYSIGN,   V4Narrow, Custom);
    setOperationAction(ISD::FSQRT,       V4Narrow, Expand);

    auto V8Narrow = MVT::getVectorVT(ScalarVT, 8);
    setOperationAction(ISD::FABS,        V8Narrow, Legal);
    setOperationAction(ISD::FADD,        V8Narrow, Legal);
    setOperationAction(ISD::FCEIL,       V8Narrow, Legal);
    setOperationAction(ISD::FCOPYSIGN,   V8Narrow, Custom);
    setOperationAction(ISD::FDIV,        V8Narrow, Legal);
    setOperationAction(ISD::FFLOOR,      V8Narrow, Legal);
    setOperationAction(ISD::FMA,         V8Narrow, Expand);
    setOperationAction(ISD::FMUL,        V8Narrow, Legal);
    setOperationAction(ISD::FNEARBYINT,  V8Narrow, Legal);
    setOperationAction(ISD::FNEG,        V8Narrow, Legal);
    setOperationAction(ISD::FROUND,      V8Narrow, Legal);
    setOperationAction(ISD::FROUNDEVEN,  V8Narrow, Legal);
    setOperationAction(ISD::FRINT,       V8Narrow, Legal);
    setOperationAction(ISD::FSQRT,       V8Narrow, Expand);
    setOperationAction(ISD::FSUB,        V8Narrow, Legal);
    setOperationAction(ISD::FTRUNC,      V8Narrow, Legal);
    setOperationAction(ISD::SETCC,       V8Narrow, Expand);
    setOperationAction(ISD::BR_CC,       V8Narrow, Expand);
    setOperationAction(ISD::SELECT,      V8Narrow, Expand);
    setOperationAction(ISD::SELECT_CC,   V8Narrow, Expand);
    setOperationAction(ISD::FP_EXTEND,   V8Narrow, Expand);
  };

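  // Without FullFP16 the f16 type has no native arithmetic, so it is legalised
  // like bf16; bf16 always takes this path because it has no generally
  // available native arithmetic.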
  if (!Subtarget->hasFullFP16()) {
    LegalizeNarrowFP(MVT::f16);
  }
  LegalizeNarrowFP(MVT::bf16);
  setOperationAction(ISD::FP_ROUND, MVT::v4f32, Custom);
  setOperationAction(ISD::FP_ROUND, MVT::v4bf16, Custom);

  // AArch64 has implementations of a lot of rounding-like FP operations.
  for (auto Op :
       {ISD::FFLOOR,          ISD::FNEARBYINT,      ISD::FCEIL,
        ISD::FRINT,           ISD::FTRUNC,          ISD::FROUND,
        ISD::FROUNDEVEN,      ISD::FMINNUM,         ISD::FMAXNUM,
        ISD::FMINIMUM,        ISD::FMAXIMUM,        ISD::LROUND,
        ISD::LLROUND,         ISD::LRINT,           ISD::LLRINT,
        ISD::STRICT_FFLOOR,   ISD::STRICT_FCEIL,    ISD::STRICT_FNEARBYINT,
        ISD::STRICT_FRINT,    ISD::STRICT_FTRUNC,   ISD::STRICT_FROUNDEVEN,
        ISD::STRICT_FROUND,   ISD::STRICT_FMINNUM,  ISD::STRICT_FMAXNUM,
        ISD::STRICT_FMINIMUM, ISD::STRICT_FMAXIMUM, ISD::STRICT_LROUND,
        ISD::STRICT_LLROUND,  ISD::STRICT_LRINT,    ISD::STRICT_LLRINT}) {
    for (MVT Ty : {MVT::f32, MVT::f64})
      setOperationAction(Op, Ty, Legal);
    if (Subtarget->hasFullFP16())
      setOperationAction(Op, MVT::f16, Legal);
  }

  // Basic strict FP operations are legal
  for (auto Op : {ISD::STRICT_FADD, ISD::STRICT_FSUB, ISD::STRICT_FMUL,
                  ISD::STRICT_FDIV, ISD::STRICT_FMA, ISD::STRICT_FSQRT}) {
    for (MVT Ty : {MVT::f32, MVT::f64})
      setOperationAction(Op, Ty, Legal);
    if (Subtarget->hasFullFP16())
      setOperationAction(Op, MVT::f16, Legal);
  }

  // Strict conversion to a larger type is legal
  for (auto VT : {MVT::f32, MVT::f64})
    setOperationAction(ISD::STRICT_FP_EXTEND, VT, Legal);

  setOperationAction(ISD::PREFETCH, MVT::Other, Custom);

  setOperationAction(ISD::GET_ROUNDING, MVT::i32, Custom);
  setOperationAction(ISD::SET_ROUNDING, MVT::Other, Custom);
  setOperationAction(ISD::GET_FPMODE, MVT::i32, Custom);
  setOperationAction(ISD::SET_FPMODE, MVT::i32, Custom);
  setOperationAction(ISD::RESET_FPMODE, MVT::Other, Custom);

  setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i128, Custom);
  if (!Subtarget->hasLSE() && !Subtarget->outlineAtomics()) {
    setOperationAction(ISD::ATOMIC_LOAD_SUB, MVT::i32, LibCall);
    setOperationAction(ISD::ATOMIC_LOAD_SUB, MVT::i64, LibCall);
  } else {
    setOperationAction(ISD::ATOMIC_LOAD_SUB, MVT::i32, Expand);
    setOperationAction(ISD::ATOMIC_LOAD_SUB, MVT::i64, Expand);
  }
  setOperationAction(ISD::ATOMIC_LOAD_AND, MVT::i32, Custom);
  setOperationAction(ISD::ATOMIC_LOAD_AND, MVT::i64, Custom);

  // Generate outline atomics library calls only if LSE was not specified for
  // subtarget
  if (Subtarget->outlineAtomics() && !Subtarget->hasLSE()) {
    setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i8, LibCall);
    setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i16, LibCall);
    setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i32, LibCall);
    setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i64, LibCall);
    setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i128, LibCall);
    setOperationAction(ISD::ATOMIC_SWAP, MVT::i8, LibCall);
    setOperationAction(ISD::ATOMIC_SWAP, MVT::i16, LibCall);
    setOperationAction(ISD::ATOMIC_SWAP, MVT::i32, LibCall);
    setOperationAction(ISD::ATOMIC_SWAP, MVT::i64, LibCall);
    setOperationAction(ISD::ATOMIC_LOAD_ADD, MVT::i8, LibCall);
    setOperationAction(ISD::ATOMIC_LOAD_ADD, MVT::i16, LibCall);
    setOperationAction(ISD::ATOMIC_LOAD_ADD, MVT::i32, LibCall);
    setOperationAction(ISD::ATOMIC_LOAD_ADD, MVT::i64, LibCall);
    setOperationAction(ISD::ATOMIC_LOAD_OR, MVT::i8, LibCall);
    setOperationAction(ISD::ATOMIC_LOAD_OR, MVT::i16, LibCall);
    setOperationAction(ISD::ATOMIC_LOAD_OR, MVT::i32, LibCall);
    setOperationAction(ISD::ATOMIC_LOAD_OR, MVT::i64, LibCall);
    setOperationAction(ISD::ATOMIC_LOAD_CLR, MVT::i8, LibCall);
    setOperationAction(ISD::ATOMIC_LOAD_CLR, MVT::i16, LibCall);
    setOperationAction(ISD::ATOMIC_LOAD_CLR, MVT::i32, LibCall);
    setOperationAction(ISD::ATOMIC_LOAD_CLR, MVT::i64, LibCall);
    setOperationAction(ISD::ATOMIC_LOAD_XOR, MVT::i8, LibCall);
    setOperationAction(ISD::ATOMIC_LOAD_XOR, MVT::i16, LibCall);
    setOperationAction(ISD::ATOMIC_LOAD_XOR, MVT::i32, LibCall);
    setOperationAction(ISD::ATOMIC_LOAD_XOR, MVT::i64, LibCall);
#define LCALLNAMES(A, B, N)                                                    \
  setLibcallName(A##N##_RELAX, #B #N "_relax");                                \
  setLibcallName(A##N##_ACQ, #B #N "_acq");                                    \
  setLibcallName(A##N##_REL, #B #N "_rel");                                    \
  setLibcallName(A##N##_ACQ_REL, #B #N "_acq_rel");
#define LCALLNAME4(A, B)                                                       \
  LCALLNAMES(A, B, 1)                                                          \
  LCALLNAMES(A, B, 2) LCALLNAMES(A, B, 4) LCALLNAMES(A, B, 8)
#define LCALLNAME5(A, B)                                                       \
  LCALLNAMES(A, B, 1)                                                          \
  LCALLNAMES(A, B, 2)                                                          \
  LCALLNAMES(A, B, 4) LCALLNAMES(A, B, 8) LCALLNAMES(A, B, 16)
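    // For example, LCALLNAME5(RTLIB::OUTLINE_ATOMIC_CAS, __aarch64_cas)
    // registers __aarch64_cas1_relax, __aarch64_cas1_acq, __aarch64_cas1_rel,
    // __aarch64_cas1_acq_rel, and likewise for the 2/4/8/16-byte variants.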
    LCALLNAME5(RTLIB::OUTLINE_ATOMIC_CAS, __aarch64_cas)
    LCALLNAME4(RTLIB::OUTLINE_ATOMIC_SWP, __aarch64_swp)
    LCALLNAME4(RTLIB::OUTLINE_ATOMIC_LDADD, __aarch64_ldadd)
    LCALLNAME4(RTLIB::OUTLINE_ATOMIC_LDSET, __aarch64_ldset)
    LCALLNAME4(RTLIB::OUTLINE_ATOMIC_LDCLR, __aarch64_ldclr)
    LCALLNAME4(RTLIB::OUTLINE_ATOMIC_LDEOR, __aarch64_ldeor)
#undef LCALLNAMES
#undef LCALLNAME4
#undef LCALLNAME5
  }

  if (Subtarget->hasLSE128()) {
    // Custom lowering because i128 is not legal. Must be replaced by 2x64
    // values. ATOMIC_LOAD_AND also needs op legalisation to emit LDCLRP.
    setOperationAction(ISD::ATOMIC_LOAD_AND, MVT::i128, Custom);
    setOperationAction(ISD::ATOMIC_LOAD_OR, MVT::i128, Custom);
    setOperationAction(ISD::ATOMIC_SWAP, MVT::i128, Custom);
  }

  // 128-bit loads and stores can be done without expanding
  setOperationAction(ISD::LOAD, MVT::i128, Custom);
  setOperationAction(ISD::STORE, MVT::i128, Custom);

  // Aligned 128-bit loads and stores are single-copy atomic according to the
  // v8.4a spec. LRCPC3 introduces 128-bit STILP/LDIAPP but still requires LSE2.
  if (Subtarget->hasLSE2()) {
    setOperationAction(ISD::ATOMIC_LOAD, MVT::i128, Custom);
    setOperationAction(ISD::ATOMIC_STORE, MVT::i128, Custom);
  }

  // 256 bit non-temporal stores can be lowered to STNP. Do this as part of the
  // custom lowering, as there are no un-paired non-temporal stores and
  // legalization will break up 256 bit inputs.
  setOperationAction(ISD::STORE, MVT::v32i8, Custom);
  setOperationAction(ISD::STORE, MVT::v16i16, Custom);
  setOperationAction(ISD::STORE, MVT::v16f16, Custom);
  setOperationAction(ISD::STORE, MVT::v16bf16, Custom);
  setOperationAction(ISD::STORE, MVT::v8i32, Custom);
  setOperationAction(ISD::STORE, MVT::v8f32, Custom);
  setOperationAction(ISD::STORE, MVT::v4f64, Custom);
  setOperationAction(ISD::STORE, MVT::v4i64, Custom);

  // 256 bit non-temporal loads can be lowered to LDNP. This is done using
  // custom lowering, as there are no un-paired non-temporal loads and
  // legalization will break up 256 bit inputs.
  setOperationAction(ISD::LOAD, MVT::v32i8, Custom);
  setOperationAction(ISD::LOAD, MVT::v16i16, Custom);
  setOperationAction(ISD::LOAD, MVT::v16f16, Custom);
  setOperationAction(ISD::LOAD, MVT::v16bf16, Custom);
  setOperationAction(ISD::LOAD, MVT::v8i32, Custom);
  setOperationAction(ISD::LOAD, MVT::v8f32, Custom);
  setOperationAction(ISD::LOAD, MVT::v4f64, Custom);
  setOperationAction(ISD::LOAD, MVT::v4i64, Custom);

  // Lower READCYCLECOUNTER using an mrs from CNTVCT_EL0.
  setOperationAction(ISD::READCYCLECOUNTER, MVT::i64, Legal);

  if (getLibcallName(RTLIB::SINCOS_STRET_F32) != nullptr &&
      getLibcallName(RTLIB::SINCOS_STRET_F64) != nullptr) {
    // Issue __sincos_stret if available.
    setOperationAction(ISD::FSINCOS, MVT::f64, Custom);
    setOperationAction(ISD::FSINCOS, MVT::f32, Custom);
  } else {
    setOperationAction(ISD::FSINCOS, MVT::f64, Expand);
    setOperationAction(ISD::FSINCOS, MVT::f32, Expand);
  }

  // Make floating-point constants legal for the large code model, so they don't
  // become loads from the constant pool.
  if (Subtarget->isTargetMachO() && TM.getCodeModel() == CodeModel::Large) {
    setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
    setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
  }

  // AArch64 does not have floating-point extending loads, i1 sign-extending
  // load, floating-point truncating stores, or v2i32->v2i16 truncating store.
  for (MVT VT : MVT::fp_valuetypes()) {
    setLoadExtAction(ISD::EXTLOAD, VT, MVT::bf16, Expand);
    setLoadExtAction(ISD::EXTLOAD, VT, MVT::f16, Expand);
    setLoadExtAction(ISD::EXTLOAD, VT, MVT::f32, Expand);
    setLoadExtAction(ISD::EXTLOAD, VT, MVT::f64, Expand);
    setLoadExtAction(ISD::EXTLOAD, VT, MVT::f80, Expand);
  }
  for (MVT VT : MVT::integer_valuetypes())
    setLoadExtAction(ISD::SEXTLOAD, VT, MVT::i1, Expand);

  for (MVT WideVT : MVT::fp_valuetypes()) {
    for (MVT NarrowVT : MVT::fp_valuetypes()) {
      if (WideVT.getScalarSizeInBits() > NarrowVT.getScalarSizeInBits()) {
        setTruncStoreAction(WideVT, NarrowVT, Expand);
      }
    }
  }

  if (Subtarget->hasFPARMv8()) {
    setOperationAction(ISD::BITCAST, MVT::i16, Custom);
    setOperationAction(ISD::BITCAST, MVT::f16, Custom);
    setOperationAction(ISD::BITCAST, MVT::bf16, Custom);
  }

  // Indexed loads and stores are supported.
  for (unsigned im = (unsigned)ISD::PRE_INC;
       im != (unsigned)ISD::LAST_INDEXED_MODE; ++im) {
    setIndexedLoadAction(im, MVT::i8, Legal);
    setIndexedLoadAction(im, MVT::i16, Legal);
    setIndexedLoadAction(im, MVT::i32, Legal);
    setIndexedLoadAction(im, MVT::i64, Legal);
    setIndexedLoadAction(im, MVT::f64, Legal);
    setIndexedLoadAction(im, MVT::f32, Legal);
    setIndexedLoadAction(im, MVT::f16, Legal);
    setIndexedLoadAction(im, MVT::bf16, Legal);
    setIndexedStoreAction(im, MVT::i8, Legal);
    setIndexedStoreAction(im, MVT::i16, Legal);
    setIndexedStoreAction(im, MVT::i32, Legal);
    setIndexedStoreAction(im, MVT::i64, Legal);
    setIndexedStoreAction(im, MVT::f64, Legal);
    setIndexedStoreAction(im, MVT::f32, Legal);
    setIndexedStoreAction(im, MVT::f16, Legal);
    setIndexedStoreAction(im, MVT::bf16, Legal);
  }

  // Trap.
  setOperationAction(ISD::TRAP, MVT::Other, Legal);
  setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal);
  setOperationAction(ISD::UBSANTRAP, MVT::Other, Legal);

  // We combine OR nodes for bitfield operations.
  setTargetDAGCombine(ISD::OR);
  // Try to create BICs for vector ANDs.
  setTargetDAGCombine(ISD::AND);

  // llvm.init.trampoline and llvm.adjust.trampoline
  setOperationAction(ISD::INIT_TRAMPOLINE, MVT::Other, Custom);
  setOperationAction(ISD::ADJUST_TRAMPOLINE, MVT::Other, Custom);

  // Vector add and sub nodes may conceal a high-half opportunity.
  // Also, try to fold ADD into CSINC/CSINV.
  setTargetDAGCombine({ISD::ADD, ISD::ABS, ISD::SUB, ISD::XOR, ISD::SINT_TO_FP,
                       ISD::UINT_TO_FP});

  setTargetDAGCombine({ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FP_TO_SINT_SAT,
                       ISD::FP_TO_UINT_SAT, ISD::FADD});

  // Try and combine setcc with csel
  setTargetDAGCombine(ISD::SETCC);

  setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);

  setTargetDAGCombine({ISD::ANY_EXTEND, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND,
                       ISD::SIGN_EXTEND_INREG, ISD::CONCAT_VECTORS,
                       ISD::EXTRACT_SUBVECTOR, ISD::INSERT_SUBVECTOR,
                       ISD::STORE, ISD::BUILD_VECTOR});
  setTargetDAGCombine(ISD::TRUNCATE);
  setTargetDAGCombine(ISD::LOAD);

  setTargetDAGCombine(ISD::MSTORE);

  setTargetDAGCombine(ISD::MUL);

  setTargetDAGCombine({ISD::SELECT, ISD::VSELECT});

  setTargetDAGCombine({ISD::INTRINSIC_VOID, ISD::INTRINSIC_W_CHAIN,
                       ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT,
                       ISD::VECREDUCE_ADD, ISD::STEP_VECTOR});

  setTargetDAGCombine({ISD::MGATHER, ISD::MSCATTER});

  setTargetDAGCombine(ISD::FP_EXTEND);

  setTargetDAGCombine(ISD::GlobalAddress);

  setTargetDAGCombine(ISD::CTLZ);

  setTargetDAGCombine(ISD::VECREDUCE_AND);
  setTargetDAGCombine(ISD::VECREDUCE_OR);
  setTargetDAGCombine(ISD::VECREDUCE_XOR);

  setTargetDAGCombine(ISD::SCALAR_TO_VECTOR);

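  // Inline expansion thresholds for memset/memcpy/memmove/memcmp: above these
  // store/load counts the generic lowering falls back to the library call.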
  // In case of strict alignment, avoid an excessive number of byte wide stores.
  MaxStoresPerMemsetOptSize = 8;
  MaxStoresPerMemset =
      Subtarget->requiresStrictAlign() ? MaxStoresPerMemsetOptSize : 32;

  MaxGluedStoresPerMemcpy = 4;
  MaxStoresPerMemcpyOptSize = 4;
  MaxStoresPerMemcpy =
      Subtarget->requiresStrictAlign() ? MaxStoresPerMemcpyOptSize : 16;

  MaxStoresPerMemmoveOptSize = 4;
  MaxStoresPerMemmove = 4;

  MaxLoadsPerMemcmpOptSize = 4;
  MaxLoadsPerMemcmp =
      Subtarget->requiresStrictAlign() ? MaxLoadsPerMemcmpOptSize : 8;

  setStackPointerRegisterToSaveRestore(AArch64::SP);

  setSchedulingPreference(Sched::Hybrid);

  EnableExtLdPromotion = true;

  // Set required alignment.
  setMinFunctionAlignment(Align(4));
  // Set preferred alignments.

  // Don't align loops on Windows. The SEH unwind info generation needs to
  // know the exact length of functions before the alignments have been
  // expanded.
  if (!Subtarget->isTargetWindows())
    setPrefLoopAlignment(STI.getPrefLoopAlignment());
  setMaxBytesForAlignment(STI.getMaxBytesForLoopAlignment());
  setPrefFunctionAlignment(STI.getPrefFunctionAlignment());

  // Only change the limit for entries in a jump table if specified by
  // the subtarget, but not at the command line.
  unsigned MaxJT = STI.getMaximumJumpTableSize();
  if (MaxJT && getMaximumJumpTableSize() == UINT_MAX)
    setMaximumJumpTableSize(MaxJT);

  setHasExtractBitsInsn(true);

  setMaxDivRemBitWidthSupported(128);

  setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);

  if (Subtarget->isNeonAvailable()) {
    // FIXME: v1f64 shouldn't be legal if we can avoid it, because it leads to
    // silliness like this:
    // clang-format off
    for (auto Op :
         {ISD::SELECT,            ISD::SELECT_CC,
          ISD::BR_CC,             ISD::FADD,           ISD::FSUB,
          ISD::FMUL,              ISD::FDIV,           ISD::FMA,
          ISD::FNEG,              ISD::FABS,           ISD::FCEIL,
          ISD::FSQRT,             ISD::FFLOOR,         ISD::FNEARBYINT,
          ISD::FSIN,              ISD::FCOS,           ISD::FTAN,
          ISD::FASIN,             ISD::FACOS,          ISD::FATAN,
          ISD::FSINH,             ISD::FCOSH,          ISD::FTANH,
          ISD::FPOW,              ISD::FLOG,           ISD::FLOG2,
          ISD::FLOG10,            ISD::FEXP,           ISD::FEXP2,
          ISD::FEXP10,            ISD::FRINT,          ISD::FROUND,
          ISD::FROUNDEVEN,        ISD::FTRUNC,         ISD::FMINNUM,
          ISD::FMAXNUM,           ISD::FMINIMUM,       ISD::FMAXIMUM,
          ISD::STRICT_FADD,       ISD::STRICT_FSUB,    ISD::STRICT_FMUL,
          ISD::STRICT_FDIV,       ISD::STRICT_FMA,     ISD::STRICT_FCEIL,
          ISD::STRICT_FFLOOR,     ISD::STRICT_FSQRT,   ISD::STRICT_FRINT,
          ISD::STRICT_FNEARBYINT, ISD::STRICT_FROUND,  ISD::STRICT_FTRUNC,
          ISD::STRICT_FROUNDEVEN, ISD::STRICT_FMINNUM, ISD::STRICT_FMAXNUM,
          ISD::STRICT_FMINIMUM,   ISD::STRICT_FMAXIMUM})
      setOperationAction(Op, MVT::v1f64, Expand);
    // clang-format on
    for (auto Op :
         {ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::SINT_TO_FP, ISD::UINT_TO_FP,
          ISD::FP_ROUND, ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT, ISD::MUL,
          ISD::STRICT_FP_TO_SINT, ISD::STRICT_FP_TO_UINT,
          ISD::STRICT_SINT_TO_FP, ISD::STRICT_UINT_TO_FP, ISD::STRICT_FP_ROUND})
      setOperationAction(Op, MVT::v1i64, Expand);

1211     // AArch64 doesn't have direct vector->f32 conversion instructions for
1212     // elements smaller than i32, so promote the input to i32 first.
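    // For example (sketch), uitofp <4 x i8> to <4 x float> has its operand
    // widened to v4i32 first, after which the conversion maps onto a single
    // UCVTF.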
1213     setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v4i8, MVT::v4i32);
1214     setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v4i8, MVT::v4i32);
1215 
1216     // Similarly, there is no direct i32 -> f64 vector conversion instruction,
1217     // nor a direct i32 -> f16 vector conversion. Set these to Custom so the
1218     // conversion happens in two steps: v4i32 -> v4f32 -> v4f16.
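    // For example (sketch), sitofp <4 x i32> to <4 x half> is lowered as an
    // SCVTF to v4f32 followed by an FCVTN down to v4f16.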
1219     for (auto Op : {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::STRICT_SINT_TO_FP,
1220                     ISD::STRICT_UINT_TO_FP})
1221       for (auto VT : {MVT::v2i32, MVT::v2i64, MVT::v4i32})
1222         setOperationAction(Op, VT, Custom);
1223 
1224     if (Subtarget->hasFullFP16()) {
1225       setOperationAction(ISD::ConstantFP, MVT::f16, Legal);
1226       setOperationAction(ISD::ConstantFP, MVT::bf16, Legal);
1227 
1228       setOperationAction(ISD::SINT_TO_FP, MVT::v8i8, Custom);
1229       setOperationAction(ISD::UINT_TO_FP, MVT::v8i8, Custom);
1230       setOperationAction(ISD::SINT_TO_FP, MVT::v16i8, Custom);
1231       setOperationAction(ISD::UINT_TO_FP, MVT::v16i8, Custom);
1232       setOperationAction(ISD::SINT_TO_FP, MVT::v4i16, Custom);
1233       setOperationAction(ISD::UINT_TO_FP, MVT::v4i16, Custom);
1234       setOperationAction(ISD::SINT_TO_FP, MVT::v8i16, Custom);
1235       setOperationAction(ISD::UINT_TO_FP, MVT::v8i16, Custom);
1236     } else {
1237       // When AArch64 doesn't have fullfp16 support, promote the input
1238       // to i32 first.
1239       setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v8i8, MVT::v8i32);
1240       setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v8i8, MVT::v8i32);
1241       setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v16i8, MVT::v16i32);
1242       setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v16i8, MVT::v16i32);
1243       setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v4i16, MVT::v4i32);
1244       setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v4i16, MVT::v4i32);
1245       setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v8i16, MVT::v8i32);
1246       setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v8i16, MVT::v8i32);
1247     }
1248 
1249     setOperationAction(ISD::CTLZ,       MVT::v1i64, Expand);
1250     setOperationAction(ISD::CTLZ,       MVT::v2i64, Expand);
1251     setOperationAction(ISD::BITREVERSE, MVT::v8i8, Legal);
1252     setOperationAction(ISD::BITREVERSE, MVT::v16i8, Legal);
1253     setOperationAction(ISD::BITREVERSE, MVT::v2i32, Custom);
1254     setOperationAction(ISD::BITREVERSE, MVT::v4i32, Custom);
1255     setOperationAction(ISD::BITREVERSE, MVT::v1i64, Custom);
1256     setOperationAction(ISD::BITREVERSE, MVT::v2i64, Custom);
1257     for (auto VT : {MVT::v1i64, MVT::v2i64}) {
1258       setOperationAction(ISD::UMAX, VT, Custom);
1259       setOperationAction(ISD::SMAX, VT, Custom);
1260       setOperationAction(ISD::UMIN, VT, Custom);
1261       setOperationAction(ISD::SMIN, VT, Custom);
1262     }
1263 
1264     // Custom handling for some quad-vector types to detect MULL.
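    // For example (sketch), a v8i16 multiply whose operands are sign-extended
    // from v8i8, i.e. mul (sext x), (sext y), can be selected as a single
    // SMULL rather than two extends plus a MUL.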
1265     setOperationAction(ISD::MUL, MVT::v8i16, Custom);
1266     setOperationAction(ISD::MUL, MVT::v4i32, Custom);
1267     setOperationAction(ISD::MUL, MVT::v2i64, Custom);
1268     setOperationAction(ISD::MUL, MVT::v4i16, Custom);
1269     setOperationAction(ISD::MUL, MVT::v2i32, Custom);
1270     setOperationAction(ISD::MUL, MVT::v1i64, Custom);
1271 
1272     // Saturates
1273     for (MVT VT : { MVT::v8i8, MVT::v4i16, MVT::v2i32,
1274                     MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) {
1275       setOperationAction(ISD::SADDSAT, VT, Legal);
1276       setOperationAction(ISD::UADDSAT, VT, Legal);
1277       setOperationAction(ISD::SSUBSAT, VT, Legal);
1278       setOperationAction(ISD::USUBSAT, VT, Legal);
1279     }
1280 
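    // Halving adds, rounding halving adds and absolute differences map
    // directly onto NEON (S|U)HADD, (S|U)RHADD and (S|U)ABD for these types.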
1281     for (MVT VT : {MVT::v8i8, MVT::v4i16, MVT::v2i32, MVT::v16i8, MVT::v8i16,
1282                    MVT::v4i32}) {
1283       setOperationAction(ISD::AVGFLOORS, VT, Legal);
1284       setOperationAction(ISD::AVGFLOORU, VT, Legal);
1285       setOperationAction(ISD::AVGCEILS, VT, Legal);
1286       setOperationAction(ISD::AVGCEILU, VT, Legal);
1287       setOperationAction(ISD::ABDS, VT, Legal);
1288       setOperationAction(ISD::ABDU, VT, Legal);
1289     }
1290 
1291     // Vector reductions
1292     for (MVT VT : { MVT::v4f16, MVT::v2f32,
1293                     MVT::v8f16, MVT::v4f32, MVT::v2f64 }) {
1294       if (VT.getVectorElementType() != MVT::f16 || Subtarget->hasFullFP16()) {
1295         setOperationAction(ISD::VECREDUCE_FMAX, VT, Legal);
1296         setOperationAction(ISD::VECREDUCE_FMIN, VT, Legal);
1297         setOperationAction(ISD::VECREDUCE_FMAXIMUM, VT, Legal);
1298         setOperationAction(ISD::VECREDUCE_FMINIMUM, VT, Legal);
1299 
1300         setOperationAction(ISD::VECREDUCE_FADD, VT, Legal);
1301       }
1302     }
1303     for (MVT VT : { MVT::v8i8, MVT::v4i16, MVT::v2i32,
1304                     MVT::v16i8, MVT::v8i16, MVT::v4i32 }) {
1305       setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
1306       setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
1307       setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
1308       setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
1309       setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
1310       setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
1311       setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
1312       setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
1313     }
1314     setOperationAction(ISD::VECREDUCE_ADD, MVT::v2i64, Custom);
1315     setOperationAction(ISD::VECREDUCE_AND, MVT::v2i64, Custom);
1316     setOperationAction(ISD::VECREDUCE_OR, MVT::v2i64, Custom);
1317     setOperationAction(ISD::VECREDUCE_XOR, MVT::v2i64, Custom);
1318 
1319     setOperationAction(ISD::ANY_EXTEND, MVT::v4i32, Legal);
1320     setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand);
1321     // Likewise, narrowing and extending vector loads/stores aren't handled
1322     // directly.
1323     for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
1324       setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Expand);
1325 
1326       if (VT == MVT::v16i8 || VT == MVT::v8i16 || VT == MVT::v4i32) {
1327         setOperationAction(ISD::MULHS, VT, Legal);
1328         setOperationAction(ISD::MULHU, VT, Legal);
1329       } else {
1330         setOperationAction(ISD::MULHS, VT, Expand);
1331         setOperationAction(ISD::MULHU, VT, Expand);
1332       }
1333       setOperationAction(ISD::SMUL_LOHI, VT, Expand);
1334       setOperationAction(ISD::UMUL_LOHI, VT, Expand);
1335 
1336       setOperationAction(ISD::BSWAP, VT, Expand);
1337       setOperationAction(ISD::CTTZ, VT, Expand);
1338 
1339       for (MVT InnerVT : MVT::fixedlen_vector_valuetypes()) {
1340         setTruncStoreAction(VT, InnerVT, Expand);
1341         setLoadExtAction(ISD::SEXTLOAD, VT, InnerVT, Expand);
1342         setLoadExtAction(ISD::ZEXTLOAD, VT, InnerVT, Expand);
1343         setLoadExtAction(ISD::EXTLOAD, VT, InnerVT, Expand);
1344       }
1345     }
1346 
1347     // AArch64 has implementations of a lot of rounding-like FP operations.
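    // Roughly, FFLOOR/FCEIL/FTRUNC/FROUND/FROUNDEVEN/FRINT/FNEARBYINT map onto
    // FRINTM/FRINTP/FRINTZ/FRINTA/FRINTN/FRINTX/FRINTI respectively.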
1348     for (auto Op :
1349          {ISD::FFLOOR, ISD::FNEARBYINT, ISD::FCEIL, ISD::FRINT, ISD::FTRUNC,
1350           ISD::FROUND, ISD::FROUNDEVEN, ISD::STRICT_FFLOOR,
1351           ISD::STRICT_FNEARBYINT, ISD::STRICT_FCEIL, ISD::STRICT_FRINT,
1352           ISD::STRICT_FTRUNC, ISD::STRICT_FROUND, ISD::STRICT_FROUNDEVEN}) {
1353       for (MVT Ty : {MVT::v2f32, MVT::v4f32, MVT::v2f64})
1354         setOperationAction(Op, Ty, Legal);
1355       if (Subtarget->hasFullFP16())
1356         for (MVT Ty : {MVT::v4f16, MVT::v8f16})
1357           setOperationAction(Op, Ty, Legal);
1358     }
1359 
1360     // LRINT and LLRINT.
1361     for (auto Op : {ISD::LRINT, ISD::LLRINT}) {
1362       for (MVT Ty : {MVT::v2f32, MVT::v4f32, MVT::v2f64})
1363         setOperationAction(Op, Ty, Custom);
1364       if (Subtarget->hasFullFP16())
1365         for (MVT Ty : {MVT::v4f16, MVT::v8f16})
1366           setOperationAction(Op, Ty, Custom);
1367     }
1368 
1369     setTruncStoreAction(MVT::v4i16, MVT::v4i8, Custom);
1370 
1371     setOperationAction(ISD::BITCAST, MVT::i2, Custom);
1372     setOperationAction(ISD::BITCAST, MVT::i4, Custom);
1373     setOperationAction(ISD::BITCAST, MVT::i8, Custom);
1374     setOperationAction(ISD::BITCAST, MVT::i16, Custom);
1375 
1376     setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);
1377     setOperationAction(ISD::BITCAST, MVT::v2i16, Custom);
1378     setOperationAction(ISD::BITCAST, MVT::v4i8, Custom);
1379 
1380     setLoadExtAction(ISD::EXTLOAD,  MVT::v4i16, MVT::v4i8, Custom);
1381     setLoadExtAction(ISD::SEXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
1382     setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
1383     setLoadExtAction(ISD::EXTLOAD,  MVT::v4i32, MVT::v4i8, Custom);
1384     setLoadExtAction(ISD::SEXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
1385     setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
1386 
1387     // ADDP custom lowering
1388     for (MVT VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32, MVT::v4i64 })
1389       setOperationAction(ISD::ADD, VT, Custom);
1390     // FADDP custom lowering
1391     for (MVT VT : { MVT::v16f16, MVT::v8f32, MVT::v4f64 })
1392       setOperationAction(ISD::FADD, VT, Custom);
1393   } else /* !isNeonAvailable */ {
1394     for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
1395       for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op)
1396         setOperationAction(Op, VT, Expand);
1397 
1398       if (VT.is128BitVector() || VT.is64BitVector()) {
1399         setOperationAction(ISD::LOAD, VT, Legal);
1400         setOperationAction(ISD::STORE, VT, Legal);
1401         setOperationAction(ISD::BITCAST, VT,
1402                            Subtarget->isLittleEndian() ? Legal : Expand);
1403       }
1404       for (MVT InnerVT : MVT::fixedlen_vector_valuetypes()) {
1405         setTruncStoreAction(VT, InnerVT, Expand);
1406         setLoadExtAction(ISD::SEXTLOAD, VT, InnerVT, Expand);
1407         setLoadExtAction(ISD::ZEXTLOAD, VT, InnerVT, Expand);
1408         setLoadExtAction(ISD::EXTLOAD, VT, InnerVT, Expand);
1409       }
1410     }
1411   }
1412 
1413   if (Subtarget->hasSME()) {
1414     setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
1415   }
1416 
1417   // FIXME: Move lowering for more nodes here if those are common between
1418   // SVE and SME.
1419   if (Subtarget->isSVEorStreamingSVEAvailable()) {
1420     for (auto VT :
1421          {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1, MVT::nxv1i1}) {
1422       setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
1423       setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
1424       setOperationAction(ISD::VECTOR_DEINTERLEAVE, VT, Custom);
1425       setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom);
1426     }
1427   }
1428 
1429   if (Subtarget->isSVEorStreamingSVEAvailable()) {
1430     for (auto VT : {MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64}) {
1431       setOperationAction(ISD::BITREVERSE, VT, Custom);
1432       setOperationAction(ISD::BSWAP, VT, Custom);
1433       setOperationAction(ISD::CTLZ, VT, Custom);
1434       setOperationAction(ISD::CTPOP, VT, Custom);
1435       setOperationAction(ISD::CTTZ, VT, Custom);
1436       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
1437       setOperationAction(ISD::UINT_TO_FP, VT, Custom);
1438       setOperationAction(ISD::SINT_TO_FP, VT, Custom);
1439       setOperationAction(ISD::FP_TO_UINT, VT, Custom);
1440       setOperationAction(ISD::FP_TO_SINT, VT, Custom);
1441       setOperationAction(ISD::MLOAD, VT, Custom);
1442       setOperationAction(ISD::MUL, VT, Custom);
1443       setOperationAction(ISD::MULHS, VT, Custom);
1444       setOperationAction(ISD::MULHU, VT, Custom);
1445       setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
1446       setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
1447       setOperationAction(ISD::SELECT, VT, Custom);
1448       setOperationAction(ISD::SETCC, VT, Custom);
1449       setOperationAction(ISD::SDIV, VT, Custom);
1450       setOperationAction(ISD::UDIV, VT, Custom);
1451       setOperationAction(ISD::SMIN, VT, Custom);
1452       setOperationAction(ISD::UMIN, VT, Custom);
1453       setOperationAction(ISD::SMAX, VT, Custom);
1454       setOperationAction(ISD::UMAX, VT, Custom);
1455       setOperationAction(ISD::SHL, VT, Custom);
1456       setOperationAction(ISD::SRL, VT, Custom);
1457       setOperationAction(ISD::SRA, VT, Custom);
1458       setOperationAction(ISD::ABS, VT, Custom);
1459       setOperationAction(ISD::ABDS, VT, Custom);
1460       setOperationAction(ISD::ABDU, VT, Custom);
1461       setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
1462       setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
1463       setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
1464       setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
1465       setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
1466       setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
1467       setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
1468       setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
1469       setOperationAction(ISD::VECTOR_DEINTERLEAVE, VT, Custom);
1470       setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom);
1471 
1472       setOperationAction(ISD::UMUL_LOHI, VT, Expand);
1473       setOperationAction(ISD::SMUL_LOHI, VT, Expand);
1474       setOperationAction(ISD::SELECT_CC, VT, Expand);
1475       setOperationAction(ISD::ROTL, VT, Expand);
1476       setOperationAction(ISD::ROTR, VT, Expand);
1477 
1478       setOperationAction(ISD::SADDSAT, VT, Legal);
1479       setOperationAction(ISD::UADDSAT, VT, Legal);
1480       setOperationAction(ISD::SSUBSAT, VT, Legal);
1481       setOperationAction(ISD::USUBSAT, VT, Legal);
1482       setOperationAction(ISD::UREM, VT, Expand);
1483       setOperationAction(ISD::SREM, VT, Expand);
1484       setOperationAction(ISD::SDIVREM, VT, Expand);
1485       setOperationAction(ISD::UDIVREM, VT, Expand);
1486 
1487       setOperationAction(ISD::AVGFLOORS, VT, Custom);
1488       setOperationAction(ISD::AVGFLOORU, VT, Custom);
1489       setOperationAction(ISD::AVGCEILS, VT, Custom);
1490       setOperationAction(ISD::AVGCEILU, VT, Custom);
1491 
1492       if (!Subtarget->isLittleEndian())
1493         setOperationAction(ISD::BITCAST, VT, Expand);
1494 
1495       if (Subtarget->hasSVE2() ||
1496           (Subtarget->hasSME() && Subtarget->isStreaming()))
1497         // For SLI/SRI.
1498         setOperationAction(ISD::OR, VT, Custom);
1499     }
1500 
1501     // Illegal unpacked integer vector types.
1502     for (auto VT : {MVT::nxv8i8, MVT::nxv4i16, MVT::nxv2i32}) {
1503       setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
1504       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
1505     }
1506 
1507     // Legalize unpacked bitcasts to REINTERPRET_CAST.
1508     for (auto VT : {MVT::nxv2i16, MVT::nxv4i16, MVT::nxv2i32, MVT::nxv2bf16,
1509                     MVT::nxv4bf16, MVT::nxv2f16, MVT::nxv4f16, MVT::nxv2f32})
1510       setOperationAction(ISD::BITCAST, VT, Custom);
1511 
1512     for (auto VT :
1513          { MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv4i8,
1514            MVT::nxv4i16, MVT::nxv4i32, MVT::nxv8i8, MVT::nxv8i16 })
1515       setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Legal);
1516 
1517     for (auto VT :
1518          {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1, MVT::nxv1i1}) {
1519       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
1520       setOperationAction(ISD::SELECT, VT, Custom);
1521       setOperationAction(ISD::SETCC, VT, Custom);
1522       setOperationAction(ISD::TRUNCATE, VT, Custom);
1523       setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
1524       setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
1525       setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
1526 
1527       setOperationAction(ISD::SELECT_CC, VT, Expand);
1528       setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
1529       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
1530 
1531       // There are no legal MVT::nxv16f## based types.
1532       if (VT != MVT::nxv16i1) {
1533         setOperationAction(ISD::SINT_TO_FP, VT, Custom);
1534         setOperationAction(ISD::UINT_TO_FP, VT, Custom);
1535       }
1536     }
1537 
1538     // NEON doesn't support masked loads/stores, but SME and SVE do.
1539     for (auto VT :
1540          {MVT::v4f16, MVT::v8f16, MVT::v2f32, MVT::v4f32, MVT::v1f64,
1541           MVT::v2f64, MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16,
1542           MVT::v2i32, MVT::v4i32, MVT::v1i64, MVT::v2i64}) {
1543       setOperationAction(ISD::MLOAD, VT, Custom);
1544       setOperationAction(ISD::MSTORE, VT, Custom);
1545     }
1546 
1547     // Firstly, exclude all scalable vector extending loads/truncating stores,
1548     // including both integer and floating-point scalable vectors.
1549     for (MVT VT : MVT::scalable_vector_valuetypes()) {
1550       for (MVT InnerVT : MVT::scalable_vector_valuetypes()) {
1551         setTruncStoreAction(VT, InnerVT, Expand);
1552         setLoadExtAction(ISD::SEXTLOAD, VT, InnerVT, Expand);
1553         setLoadExtAction(ISD::ZEXTLOAD, VT, InnerVT, Expand);
1554         setLoadExtAction(ISD::EXTLOAD, VT, InnerVT, Expand);
1555       }
1556     }
1557 
1558     // Then, selectively enable those which we directly support.
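    // For example (sketch), an extending load from nxv2i8 to nxv2i64 can then
    // be selected as a single LD1B, and the matching truncating store as a
    // single ST1B.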
1559     setTruncStoreAction(MVT::nxv2i64, MVT::nxv2i8, Legal);
1560     setTruncStoreAction(MVT::nxv2i64, MVT::nxv2i16, Legal);
1561     setTruncStoreAction(MVT::nxv2i64, MVT::nxv2i32, Legal);
1562     setTruncStoreAction(MVT::nxv4i32, MVT::nxv4i8, Legal);
1563     setTruncStoreAction(MVT::nxv4i32, MVT::nxv4i16, Legal);
1564     setTruncStoreAction(MVT::nxv8i16, MVT::nxv8i8, Legal);
1565     for (auto Op : {ISD::ZEXTLOAD, ISD::SEXTLOAD, ISD::EXTLOAD}) {
1566       setLoadExtAction(Op, MVT::nxv2i64, MVT::nxv2i8, Legal);
1567       setLoadExtAction(Op, MVT::nxv2i64, MVT::nxv2i16, Legal);
1568       setLoadExtAction(Op, MVT::nxv2i64, MVT::nxv2i32, Legal);
1569       setLoadExtAction(Op, MVT::nxv4i32, MVT::nxv4i8, Legal);
1570       setLoadExtAction(Op, MVT::nxv4i32, MVT::nxv4i16, Legal);
1571       setLoadExtAction(Op, MVT::nxv8i16, MVT::nxv8i8, Legal);
1572     }
1573 
1574     // SVE supports truncating stores of 64- and 128-bit vectors.
1575     setTruncStoreAction(MVT::v2i64, MVT::v2i8, Custom);
1576     setTruncStoreAction(MVT::v2i64, MVT::v2i16, Custom);
1577     setTruncStoreAction(MVT::v2i64, MVT::v2i32, Custom);
1578     setTruncStoreAction(MVT::v2i32, MVT::v2i8, Custom);
1579     setTruncStoreAction(MVT::v2i32, MVT::v2i16, Custom);
1580 
1581     for (auto VT : {MVT::nxv2f16, MVT::nxv4f16, MVT::nxv8f16, MVT::nxv2f32,
1582                     MVT::nxv4f32, MVT::nxv2f64}) {
1583       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
1584       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
1585       setOperationAction(ISD::MLOAD, VT, Custom);
1586       setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
1587       setOperationAction(ISD::SELECT, VT, Custom);
1588       setOperationAction(ISD::SETCC, VT, Custom);
1589       setOperationAction(ISD::FADD, VT, Custom);
1590       setOperationAction(ISD::FCOPYSIGN, VT, Custom);
1591       setOperationAction(ISD::FDIV, VT, Custom);
1592       setOperationAction(ISD::FMA, VT, Custom);
1593       setOperationAction(ISD::FMAXIMUM, VT, Custom);
1594       setOperationAction(ISD::FMAXNUM, VT, Custom);
1595       setOperationAction(ISD::FMINIMUM, VT, Custom);
1596       setOperationAction(ISD::FMINNUM, VT, Custom);
1597       setOperationAction(ISD::FMUL, VT, Custom);
1598       setOperationAction(ISD::FNEG, VT, Custom);
1599       setOperationAction(ISD::FSUB, VT, Custom);
1600       setOperationAction(ISD::FCEIL, VT, Custom);
1601       setOperationAction(ISD::FFLOOR, VT, Custom);
1602       setOperationAction(ISD::FNEARBYINT, VT, Custom);
1603       setOperationAction(ISD::FRINT, VT, Custom);
1604       setOperationAction(ISD::LRINT, VT, Custom);
1605       setOperationAction(ISD::LLRINT, VT, Custom);
1606       setOperationAction(ISD::FROUND, VT, Custom);
1607       setOperationAction(ISD::FROUNDEVEN, VT, Custom);
1608       setOperationAction(ISD::FTRUNC, VT, Custom);
1609       setOperationAction(ISD::FSQRT, VT, Custom);
1610       setOperationAction(ISD::FABS, VT, Custom);
1611       setOperationAction(ISD::FP_EXTEND, VT, Custom);
1612       setOperationAction(ISD::FP_ROUND, VT, Custom);
1613       setOperationAction(ISD::VECREDUCE_FADD, VT, Custom);
1614       setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom);
1615       setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom);
1616       setOperationAction(ISD::VECREDUCE_FMAXIMUM, VT, Custom);
1617       setOperationAction(ISD::VECREDUCE_FMINIMUM, VT, Custom);
1618       setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
1619       setOperationAction(ISD::VECTOR_DEINTERLEAVE, VT, Custom);
1620       setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom);
1621 
1622       setOperationAction(ISD::SELECT_CC, VT, Expand);
1623       setOperationAction(ISD::FREM, VT, Expand);
1624       setOperationAction(ISD::FPOW, VT, Expand);
1625       setOperationAction(ISD::FPOWI, VT, Expand);
1626       setOperationAction(ISD::FCOS, VT, Expand);
1627       setOperationAction(ISD::FSIN, VT, Expand);
1628       setOperationAction(ISD::FSINCOS, VT, Expand);
1629       setOperationAction(ISD::FTAN, VT, Expand);
1630       setOperationAction(ISD::FACOS, VT, Expand);
1631       setOperationAction(ISD::FASIN, VT, Expand);
1632       setOperationAction(ISD::FATAN, VT, Expand);
1633       setOperationAction(ISD::FCOSH, VT, Expand);
1634       setOperationAction(ISD::FSINH, VT, Expand);
1635       setOperationAction(ISD::FTANH, VT, Expand);
1636       setOperationAction(ISD::FEXP, VT, Expand);
1637       setOperationAction(ISD::FEXP2, VT, Expand);
1638       setOperationAction(ISD::FEXP10, VT, Expand);
1639       setOperationAction(ISD::FLOG, VT, Expand);
1640       setOperationAction(ISD::FLOG2, VT, Expand);
1641       setOperationAction(ISD::FLOG10, VT, Expand);
1642 
1643       setCondCodeAction(ISD::SETO, VT, Expand);
1644       setCondCodeAction(ISD::SETOLT, VT, Expand);
1645       setCondCodeAction(ISD::SETLT, VT, Expand);
1646       setCondCodeAction(ISD::SETOLE, VT, Expand);
1647       setCondCodeAction(ISD::SETLE, VT, Expand);
1648       setCondCodeAction(ISD::SETULT, VT, Expand);
1649       setCondCodeAction(ISD::SETULE, VT, Expand);
1650       setCondCodeAction(ISD::SETUGE, VT, Expand);
1651       setCondCodeAction(ISD::SETUGT, VT, Expand);
1652       setCondCodeAction(ISD::SETUEQ, VT, Expand);
1653       setCondCodeAction(ISD::SETONE, VT, Expand);
1654 
1655       if (!Subtarget->isLittleEndian())
1656         setOperationAction(ISD::BITCAST, VT, Expand);
1657     }
1658 
1659     for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
1660       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
1661       setOperationAction(ISD::MLOAD, VT, Custom);
1662       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
1663       setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
1664       setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
1665 
1666       if (!Subtarget->isLittleEndian())
1667         setOperationAction(ISD::BITCAST, VT, Expand);
1668     }
1669 
1670     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
1671     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i16, Custom);
1672 
1673     // NEON doesn't support integer divides, but SVE does
1674     for (auto VT : {MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16, MVT::v2i32,
1675                     MVT::v4i32, MVT::v1i64, MVT::v2i64}) {
1676       setOperationAction(ISD::SDIV, VT, Custom);
1677       setOperationAction(ISD::UDIV, VT, Custom);
1678     }
1679 
1680     // NEON doesn't support 64-bit vector integer muls, but SVE does.
1681     setOperationAction(ISD::MUL, MVT::v1i64, Custom);
1682     setOperationAction(ISD::MUL, MVT::v2i64, Custom);
1683 
1684     // NOTE: Currently this has to happen after computeRegisterProperties rather
1685     // than the preferred option of combining it with the addRegisterClass call.
1686     if (Subtarget->useSVEForFixedLengthVectors()) {
1687       for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) {
1688         if (useSVEForFixedLengthVectorVT(
1689                 VT, /*OverrideNEON=*/!Subtarget->isNeonAvailable()))
1690           addTypeForFixedLengthSVE(VT);
1691       }
1692       for (MVT VT : MVT::fp_fixedlen_vector_valuetypes()) {
1693         if (useSVEForFixedLengthVectorVT(
1694                 VT, /*OverrideNEON=*/!Subtarget->isNeonAvailable()))
1695           addTypeForFixedLengthSVE(VT);
1696       }
1697 
1698       // 64-bit results can mean an input wider than NEON supports.
1699       for (auto VT : {MVT::v8i8, MVT::v4i16})
1700         setOperationAction(ISD::TRUNCATE, VT, Custom);
1701       setOperationAction(ISD::FP_ROUND, MVT::v4f16, Custom);
1702 
1703       // 128-bit results imply an input wider than NEON supports.
1704       for (auto VT : {MVT::v16i8, MVT::v8i16, MVT::v4i32})
1705         setOperationAction(ISD::TRUNCATE, VT, Custom);
1706       for (auto VT : {MVT::v8f16, MVT::v4f32})
1707         setOperationAction(ISD::FP_ROUND, VT, Custom);
1708 
1709       // These operations are not supported on NEON but SVE can do them.
1710       setOperationAction(ISD::BITREVERSE, MVT::v1i64, Custom);
1711       setOperationAction(ISD::CTLZ, MVT::v1i64, Custom);
1712       setOperationAction(ISD::CTLZ, MVT::v2i64, Custom);
1713       setOperationAction(ISD::CTTZ, MVT::v1i64, Custom);
1714       setOperationAction(ISD::MULHS, MVT::v1i64, Custom);
1715       setOperationAction(ISD::MULHS, MVT::v2i64, Custom);
1716       setOperationAction(ISD::MULHU, MVT::v1i64, Custom);
1717       setOperationAction(ISD::MULHU, MVT::v2i64, Custom);
1718       setOperationAction(ISD::SMAX, MVT::v1i64, Custom);
1719       setOperationAction(ISD::SMAX, MVT::v2i64, Custom);
1720       setOperationAction(ISD::SMIN, MVT::v1i64, Custom);
1721       setOperationAction(ISD::SMIN, MVT::v2i64, Custom);
1722       setOperationAction(ISD::UMAX, MVT::v1i64, Custom);
1723       setOperationAction(ISD::UMAX, MVT::v2i64, Custom);
1724       setOperationAction(ISD::UMIN, MVT::v1i64, Custom);
1725       setOperationAction(ISD::UMIN, MVT::v2i64, Custom);
1726       setOperationAction(ISD::VECREDUCE_SMAX, MVT::v2i64, Custom);
1727       setOperationAction(ISD::VECREDUCE_SMIN, MVT::v2i64, Custom);
1728       setOperationAction(ISD::VECREDUCE_UMAX, MVT::v2i64, Custom);
1729       setOperationAction(ISD::VECREDUCE_UMIN, MVT::v2i64, Custom);
1730 
1731       // Int operations with no NEON support.
1732       for (auto VT : {MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16,
1733                       MVT::v2i32, MVT::v4i32, MVT::v2i64}) {
1734         setOperationAction(ISD::BITREVERSE, VT, Custom);
1735         setOperationAction(ISD::CTTZ, VT, Custom);
1736         setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
1737         setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
1738         setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
1739         setOperationAction(ISD::MULHS, VT, Custom);
1740         setOperationAction(ISD::MULHU, VT, Custom);
1741       }
1742 
1743       // Use SVE for vectors with more than 2 elements.
1744       for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v4f32})
1745         setOperationAction(ISD::VECREDUCE_FADD, VT, Custom);
1746     }
1747 
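    // Predicate splices are promoted below so that, for example, an nxv16i1
    // VECTOR_SPLICE is performed as an nxv16i8 splice whose result is
    // truncated back to a predicate.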
1748     setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv2i1, MVT::nxv2i64);
1749     setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv4i1, MVT::nxv4i32);
1750     setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv8i1, MVT::nxv8i16);
1751     setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv16i1, MVT::nxv16i8);
1752 
1753     setOperationAction(ISD::VSCALE, MVT::i32, Custom);
1754 
1755     for (auto VT : {MVT::v16i1, MVT::v8i1, MVT::v4i1, MVT::v2i1})
1756       setOperationAction(ISD::INTRINSIC_WO_CHAIN, VT, Custom);
1757   }
1758 
1759   // Handle operations that are only available in non-streaming SVE mode.
1760   if (Subtarget->isSVEAvailable()) {
1761     for (auto VT : {MVT::nxv16i8,  MVT::nxv8i16, MVT::nxv4i32,  MVT::nxv2i64,
1762                     MVT::nxv2f16,  MVT::nxv4f16, MVT::nxv8f16,  MVT::nxv2f32,
1763                     MVT::nxv4f32,  MVT::nxv2f64, MVT::nxv2bf16, MVT::nxv4bf16,
1764                     MVT::nxv8bf16, MVT::v4f16,   MVT::v8f16,    MVT::v2f32,
1765                     MVT::v4f32,    MVT::v1f64,   MVT::v2f64,    MVT::v8i8,
1766                     MVT::v16i8,    MVT::v4i16,   MVT::v8i16,    MVT::v2i32,
1767                     MVT::v4i32,    MVT::v1i64,   MVT::v2i64}) {
1768       setOperationAction(ISD::MGATHER, VT, Custom);
1769       setOperationAction(ISD::MSCATTER, VT, Custom);
1770     }
1771 
1772     for (auto VT : {MVT::nxv2f16, MVT::nxv4f16, MVT::nxv8f16, MVT::nxv2f32,
1773                     MVT::nxv4f32, MVT::nxv2f64, MVT::v4f16, MVT::v8f16,
1774                     MVT::v2f32, MVT::v4f32, MVT::v2f64})
1775       setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
1776 
1777     // Histcnt is SVE2 only
1778     if (Subtarget->hasSVE2())
1779       setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::Other,
1780                          Custom);
1781   }
1782 
1783 
1784   if (Subtarget->hasMOPS() && Subtarget->hasMTE()) {
1785     // Only required for llvm.aarch64.mops.memset.tag
1786     setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);
1787   }
1788 
1789   setOperationAction(ISD::INTRINSIC_VOID, MVT::Other, Custom);
1790 
1791   if (Subtarget->hasSVE()) {
1792     setOperationAction(ISD::FLDEXP, MVT::f64, Custom);
1793     setOperationAction(ISD::FLDEXP, MVT::f32, Custom);
1794     setOperationAction(ISD::FLDEXP, MVT::f16, Custom);
1795     setOperationAction(ISD::FLDEXP, MVT::bf16, Custom);
1796   }
1797 
1798   PredictableSelectIsExpensive = Subtarget->predictableSelectIsExpensive();
1799 
1800   IsStrictFPEnabled = true;
1801   setMaxAtomicSizeInBitsSupported(128);
1802 
1803   // On MSVC, both 32-bit and 64-bit, ldexpf(f32) is not defined.  MinGW has
1804   // it, but it's just a wrapper around ldexp.
1805   if (Subtarget->isTargetWindows()) {
1806     for (ISD::NodeType Op : {ISD::FLDEXP, ISD::STRICT_FLDEXP, ISD::FFREXP})
1807       if (isOperationExpand(Op, MVT::f32))
1808         setOperationAction(Op, MVT::f32, Promote);
1809   }
1810 
1811   // LegalizeDAG currently can't expand fp16 LDEXP/FREXP on targets where i16
1812   // isn't legal.
1813   for (ISD::NodeType Op : {ISD::FLDEXP, ISD::STRICT_FLDEXP, ISD::FFREXP})
1814     if (isOperationExpand(Op, MVT::f16))
1815       setOperationAction(Op, MVT::f16, Promote);
1816 
1817   if (Subtarget->isWindowsArm64EC()) {
1818     // FIXME: are there intrinsics we need to exclude from this?
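    // Arm64EC exposes native functions through '#'-prefixed symbols, so the
    // libcall names are rewritten below; e.g. the memcpy libcall is emitted
    // against "#memcpy".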
1819     for (int i = 0; i < RTLIB::UNKNOWN_LIBCALL; ++i) {
1820       auto code = static_cast<RTLIB::Libcall>(i);
1821       auto libcallName = getLibcallName(code);
1822       if ((libcallName != nullptr) && (libcallName[0] != '#')) {
1823         setLibcallName(code, Saver.save(Twine("#") + libcallName).data());
1824       }
1825     }
1826   }
1827 }
1828 
1829 void AArch64TargetLowering::addTypeForNEON(MVT VT) {
1830   assert(VT.isVector() && "VT should be a vector type");
1831 
1832   if (VT.isFloatingPoint()) {
1833     MVT PromoteTo = EVT(VT).changeVectorElementTypeToInteger().getSimpleVT();
1834     setOperationPromotedToType(ISD::LOAD, VT, PromoteTo);
1835     setOperationPromotedToType(ISD::STORE, VT, PromoteTo);
1836   }
1837 
1838   // Mark vector float intrinsics as expand.
1839   if (VT == MVT::v2f32 || VT == MVT::v4f32 || VT == MVT::v2f64) {
1840     setOperationAction(ISD::FSIN, VT, Expand);
1841     setOperationAction(ISD::FCOS, VT, Expand);
1842     setOperationAction(ISD::FTAN, VT, Expand);
1843     setOperationAction(ISD::FASIN, VT, Expand);
1844     setOperationAction(ISD::FACOS, VT, Expand);
1845     setOperationAction(ISD::FATAN, VT, Expand);
1846     setOperationAction(ISD::FSINH, VT, Expand);
1847     setOperationAction(ISD::FCOSH, VT, Expand);
1848     setOperationAction(ISD::FTANH, VT, Expand);
1849     setOperationAction(ISD::FPOW, VT, Expand);
1850     setOperationAction(ISD::FLOG, VT, Expand);
1851     setOperationAction(ISD::FLOG2, VT, Expand);
1852     setOperationAction(ISD::FLOG10, VT, Expand);
1853     setOperationAction(ISD::FEXP, VT, Expand);
1854     setOperationAction(ISD::FEXP2, VT, Expand);
1855     setOperationAction(ISD::FEXP10, VT, Expand);
1856   }
1857 
1858   // But we do support custom-lowering for FCOPYSIGN.
1859   if (VT == MVT::v2f32 || VT == MVT::v4f32 || VT == MVT::v2f64 ||
1860       ((VT == MVT::v4bf16 || VT == MVT::v8bf16 || VT == MVT::v4f16 ||
1861         VT == MVT::v8f16) &&
1862        Subtarget->hasFullFP16()))
1863     setOperationAction(ISD::FCOPYSIGN, VT, Custom);
1864 
1865   setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
1866   setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
1867   setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
1868   setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, VT, Custom);
1869   setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom);
1870   setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
1871   setOperationAction(ISD::SRA, VT, Custom);
1872   setOperationAction(ISD::SRL, VT, Custom);
1873   setOperationAction(ISD::SHL, VT, Custom);
1874   setOperationAction(ISD::OR, VT, Custom);
1875   setOperationAction(ISD::SETCC, VT, Custom);
1876   setOperationAction(ISD::CONCAT_VECTORS, VT, Legal);
1877 
1878   setOperationAction(ISD::SELECT, VT, Expand);
1879   setOperationAction(ISD::SELECT_CC, VT, Expand);
1880   setOperationAction(ISD::VSELECT, VT, Expand);
1881   for (MVT InnerVT : MVT::all_valuetypes())
1882     setLoadExtAction(ISD::EXTLOAD, InnerVT, VT, Expand);
1883 
1884   // CNT supports only B element sizes; use UADDLP to widen for larger elements.
1885   if (VT != MVT::v8i8 && VT != MVT::v16i8)
1886     setOperationAction(ISD::CTPOP, VT, Custom);
1887 
1888   setOperationAction(ISD::UDIV, VT, Expand);
1889   setOperationAction(ISD::SDIV, VT, Expand);
1890   setOperationAction(ISD::UREM, VT, Expand);
1891   setOperationAction(ISD::SREM, VT, Expand);
1892   setOperationAction(ISD::FREM, VT, Expand);
1893 
1894   for (unsigned Opcode :
1895        {ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FP_TO_SINT_SAT,
1896         ISD::FP_TO_UINT_SAT, ISD::STRICT_FP_TO_SINT, ISD::STRICT_FP_TO_UINT})
1897     setOperationAction(Opcode, VT, Custom);
1898 
1899   if (!VT.isFloatingPoint())
1900     setOperationAction(ISD::ABS, VT, Legal);
1901 
1902   // [SU][MIN|MAX] are available for all NEON types apart from i64.
1903   if (!VT.isFloatingPoint() && VT != MVT::v2i64 && VT != MVT::v1i64)
1904     for (unsigned Opcode : {ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX})
1905       setOperationAction(Opcode, VT, Legal);
1906 
1907   // F[MIN|MAX][NUM|NAN] and simple strict operations are available for all FP
1908   // NEON types.
1909   if (VT.isFloatingPoint() &&
1910       VT.getVectorElementType() != MVT::bf16 &&
1911       (VT.getVectorElementType() != MVT::f16 || Subtarget->hasFullFP16()))
1912     for (unsigned Opcode :
1913          {ISD::FMINIMUM, ISD::FMAXIMUM, ISD::FMINNUM, ISD::FMAXNUM,
1914           ISD::STRICT_FMINIMUM, ISD::STRICT_FMAXIMUM, ISD::STRICT_FMINNUM,
1915           ISD::STRICT_FMAXNUM, ISD::STRICT_FADD, ISD::STRICT_FSUB,
1916           ISD::STRICT_FMUL, ISD::STRICT_FDIV, ISD::STRICT_FMA,
1917           ISD::STRICT_FSQRT})
1918       setOperationAction(Opcode, VT, Legal);
1919 
1920   // Strict fp extend and trunc are legal
1921   if (VT.isFloatingPoint() && VT.getScalarSizeInBits() != 16)
1922     setOperationAction(ISD::STRICT_FP_EXTEND, VT, Legal);
1923   if (VT.isFloatingPoint() && VT.getScalarSizeInBits() != 64)
1924     setOperationAction(ISD::STRICT_FP_ROUND, VT, Legal);
1925 
1926   // FIXME: We could potentially make use of the vector comparison instructions
1927   // for STRICT_FSETCC and STRICT_FSETCCS, but there are a number of
1928   // complications:
1929   //  * FCMPEQ/NE are quiet comparisons, the rest are signalling comparisons,
1930   //    so we would need to expand when the condition code doesn't match the
1931   //    kind of comparison.
1932   //  * Some kinds of comparison require more than one FCMXY instruction so
1933   //    would need to be expanded instead.
1934   //  * The lowering of the non-strict versions involves target-specific ISD
1935   //    nodes so we would likely need to add strict versions of all of them and
1936   //    handle them appropriately.
1937   setOperationAction(ISD::STRICT_FSETCC, VT, Expand);
1938   setOperationAction(ISD::STRICT_FSETCCS, VT, Expand);
1939 
1940   if (Subtarget->isLittleEndian()) {
1941     for (unsigned im = (unsigned)ISD::PRE_INC;
1942          im != (unsigned)ISD::LAST_INDEXED_MODE; ++im) {
1943       setIndexedLoadAction(im, VT, Legal);
1944       setIndexedStoreAction(im, VT, Legal);
1945     }
1946   }
1947 
1948   if (Subtarget->hasD128()) {
1949     setOperationAction(ISD::READ_REGISTER, MVT::i128, Custom);
1950     setOperationAction(ISD::WRITE_REGISTER, MVT::i128, Custom);
1951   }
1952 }
1953 
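// Example (sketch): with SVE available, @llvm.get.active.lane.mask producing
// nxv4i1 from i64 operands maps directly onto WHILELO, so the hook below
// returns false; an unsupported result type (e.g. nxv1i1) or a scalar operand
// type other than i32/i64 makes it return true and the intrinsic is expanded
// generically.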
1954 bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
1955                                                           EVT OpVT) const {
1956   // Only SVE has a 1:1 mapping from intrinsic -> instruction (whilelo).
1957   if (!Subtarget->hasSVE())
1958     return true;
1959 
1960   // We can only support legal predicate result types. We can use the SVE
1961   // whilelo instruction for generating fixed-width predicates too.
1962   if (ResVT != MVT::nxv2i1 && ResVT != MVT::nxv4i1 && ResVT != MVT::nxv8i1 &&
1963       ResVT != MVT::nxv16i1 && ResVT != MVT::v2i1 && ResVT != MVT::v4i1 &&
1964       ResVT != MVT::v8i1 && ResVT != MVT::v16i1)
1965     return true;
1966 
1967   // The whilelo instruction only works with i32 or i64 scalar inputs.
1968   if (OpVT != MVT::i32 && OpVT != MVT::i64)
1969     return true;
1970 
1971   return false;
1972 }
1973 
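// Example (sketch): @llvm.experimental.cttz.elts on an nxv16i1 predicate can
// be selected as BRKB followed by CNTP when SVE or streaming SVE is available,
// so the hook below only requests expansion for predicate types outside the
// legal set.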
1974 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
1975   if (!Subtarget->isSVEorStreamingSVEAvailable())
1976     return true;
1977 
1978   // We can only use the BRKB + CNTP sequence with legal predicate types. We can
1979   // also support fixed-width predicates.
1980   return VT != MVT::nxv16i1 && VT != MVT::nxv8i1 && VT != MVT::nxv4i1 &&
1981          VT != MVT::nxv2i1 && VT != MVT::v16i1 && VT != MVT::v8i1 &&
1982          VT != MVT::v4i1 && VT != MVT::v2i1;
1983 }
1984 
1985 void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
1986   assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
1987 
1988   // By default everything must be expanded.
1989   for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op)
1990     setOperationAction(Op, VT, Expand);
1991 
1992   if (VT.isFloatingPoint()) {
1993     setCondCodeAction(ISD::SETO, VT, Expand);
1994     setCondCodeAction(ISD::SETOLT, VT, Expand);
1995     setCondCodeAction(ISD::SETOLE, VT, Expand);
1996     setCondCodeAction(ISD::SETULT, VT, Expand);
1997     setCondCodeAction(ISD::SETULE, VT, Expand);
1998     setCondCodeAction(ISD::SETUGE, VT, Expand);
1999     setCondCodeAction(ISD::SETUGT, VT, Expand);
2000     setCondCodeAction(ISD::SETUEQ, VT, Expand);
2001     setCondCodeAction(ISD::SETONE, VT, Expand);
2002   }
2003 
2004   TargetLoweringBase::LegalizeAction Default =
2005       VT == MVT::v1f64 ? Expand : Custom;
2006 
2007   // Mark integer truncating stores/extending loads as having custom lowering
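  // For example, for VT == v8i32 the loop below visits InnerVT = v8i8 and
  // v8i16, marking the corresponding truncating stores and extending loads
  // with the chosen default action.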
2008   if (VT.isInteger()) {
2009     MVT InnerVT = VT.changeVectorElementType(MVT::i8);
2010     while (InnerVT != VT) {
2011       setTruncStoreAction(VT, InnerVT, Default);
2012       setLoadExtAction(ISD::ZEXTLOAD, VT, InnerVT, Default);
2013       setLoadExtAction(ISD::SEXTLOAD, VT, InnerVT, Default);
2014       setLoadExtAction(ISD::EXTLOAD, VT, InnerVT, Default);
2015       InnerVT = InnerVT.changeVectorElementType(
2016           MVT::getIntegerVT(2 * InnerVT.getScalarSizeInBits()));
2017     }
2018   }
2019 
2020   // Mark floating-point truncating stores/extending loads as having custom
2021   // lowering
2022   if (VT.isFloatingPoint()) {
2023     MVT InnerVT = VT.changeVectorElementType(MVT::f16);
2024     while (InnerVT != VT) {
2025       setTruncStoreAction(VT, InnerVT, Custom);
2026       setLoadExtAction(ISD::EXTLOAD, VT, InnerVT, Default);
2027       InnerVT = InnerVT.changeVectorElementType(
2028           MVT::getFloatingPointVT(2 * InnerVT.getScalarSizeInBits()));
2029     }
2030   }
2031 
2032   bool PreferNEON = VT.is64BitVector() || VT.is128BitVector();
2033   bool PreferSVE = !PreferNEON && Subtarget->isSVEAvailable();
2034 
2035   // Lower fixed length vector operations to scalable equivalents.
2036   setOperationAction(ISD::ABS, VT, Default);
2037   setOperationAction(ISD::ADD, VT, Default);
2038   setOperationAction(ISD::AND, VT, Default);
2039   setOperationAction(ISD::ANY_EXTEND, VT, Default);
2040   setOperationAction(ISD::BITCAST, VT, PreferNEON ? Legal : Default);
2041   setOperationAction(ISD::BITREVERSE, VT, Default);
2042   setOperationAction(ISD::BSWAP, VT, Default);
2043   setOperationAction(ISD::BUILD_VECTOR, VT, Default);
2044   setOperationAction(ISD::CONCAT_VECTORS, VT, Default);
2045   setOperationAction(ISD::CTLZ, VT, Default);
2046   setOperationAction(ISD::CTPOP, VT, Default);
2047   setOperationAction(ISD::CTTZ, VT, Default);
2048   setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Default);
2049   setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Default);
2050   setOperationAction(ISD::FABS, VT, Default);
2051   setOperationAction(ISD::FADD, VT, Default);
2052   setOperationAction(ISD::FCEIL, VT, Default);
2053   setOperationAction(ISD::FCOPYSIGN, VT, Default);
2054   setOperationAction(ISD::FDIV, VT, Default);
2055   setOperationAction(ISD::FFLOOR, VT, Default);
2056   setOperationAction(ISD::FMA, VT, Default);
2057   setOperationAction(ISD::FMAXIMUM, VT, Default);
2058   setOperationAction(ISD::FMAXNUM, VT, Default);
2059   setOperationAction(ISD::FMINIMUM, VT, Default);
2060   setOperationAction(ISD::FMINNUM, VT, Default);
2061   setOperationAction(ISD::FMUL, VT, Default);
2062   setOperationAction(ISD::FNEARBYINT, VT, Default);
2063   setOperationAction(ISD::FNEG, VT, Default);
2064   setOperationAction(ISD::FP_EXTEND, VT, Default);
2065   setOperationAction(ISD::FP_ROUND, VT, Default);
2066   setOperationAction(ISD::FP_TO_SINT, VT, Default);
2067   setOperationAction(ISD::FP_TO_UINT, VT, Default);
2068   setOperationAction(ISD::FRINT, VT, Default);
2069   setOperationAction(ISD::LRINT, VT, Default);
2070   setOperationAction(ISD::LLRINT, VT, Default);
2071   setOperationAction(ISD::FROUND, VT, Default);
2072   setOperationAction(ISD::FROUNDEVEN, VT, Default);
2073   setOperationAction(ISD::FSQRT, VT, Default);
2074   setOperationAction(ISD::FSUB, VT, Default);
2075   setOperationAction(ISD::FTRUNC, VT, Default);
2076   setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Default);
2077   setOperationAction(ISD::LOAD, VT, PreferNEON ? Legal : Default);
2078   setOperationAction(ISD::MGATHER, VT, PreferSVE ? Default : Expand);
2079   setOperationAction(ISD::MLOAD, VT, Default);
2080   setOperationAction(ISD::MSCATTER, VT, PreferSVE ? Default : Expand);
2081   setOperationAction(ISD::MSTORE, VT, Default);
2082   setOperationAction(ISD::MUL, VT, Default);
2083   setOperationAction(ISD::MULHS, VT, Default);
2084   setOperationAction(ISD::MULHU, VT, Default);
2085   setOperationAction(ISD::OR, VT, Default);
2086   setOperationAction(ISD::SCALAR_TO_VECTOR, VT, PreferNEON ? Legal : Expand);
2087   setOperationAction(ISD::SDIV, VT, Default);
2088   setOperationAction(ISD::SELECT, VT, Default);
2089   setOperationAction(ISD::SETCC, VT, Default);
2090   setOperationAction(ISD::SHL, VT, Default);
2091   setOperationAction(ISD::SIGN_EXTEND, VT, Default);
2092   setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Default);
2093   setOperationAction(ISD::SINT_TO_FP, VT, Default);
2094   setOperationAction(ISD::SMAX, VT, Default);
2095   setOperationAction(ISD::SMIN, VT, Default);
2096   setOperationAction(ISD::SPLAT_VECTOR, VT, Default);
2097   setOperationAction(ISD::SRA, VT, Default);
2098   setOperationAction(ISD::SRL, VT, Default);
2099   setOperationAction(ISD::STORE, VT, PreferNEON ? Legal : Default);
2100   setOperationAction(ISD::SUB, VT, Default);
2101   setOperationAction(ISD::TRUNCATE, VT, Default);
2102   setOperationAction(ISD::UDIV, VT, Default);
2103   setOperationAction(ISD::UINT_TO_FP, VT, Default);
2104   setOperationAction(ISD::UMAX, VT, Default);
2105   setOperationAction(ISD::UMIN, VT, Default);
2106   setOperationAction(ISD::VECREDUCE_ADD, VT, Default);
2107   setOperationAction(ISD::VECREDUCE_AND, VT, Default);
2108   setOperationAction(ISD::VECREDUCE_FADD, VT, Default);
2109   setOperationAction(ISD::VECREDUCE_FMAX, VT, Default);
2110   setOperationAction(ISD::VECREDUCE_FMIN, VT, Default);
2111   setOperationAction(ISD::VECREDUCE_FMAXIMUM, VT, Default);
2112   setOperationAction(ISD::VECREDUCE_FMINIMUM, VT, Default);
2113   setOperationAction(ISD::VECREDUCE_OR, VT, Default);
2114   setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, PreferSVE ? Default : Expand);
2115   setOperationAction(ISD::VECREDUCE_SMAX, VT, Default);
2116   setOperationAction(ISD::VECREDUCE_SMIN, VT, Default);
2117   setOperationAction(ISD::VECREDUCE_UMAX, VT, Default);
2118   setOperationAction(ISD::VECREDUCE_UMIN, VT, Default);
2119   setOperationAction(ISD::VECREDUCE_XOR, VT, Default);
2120   setOperationAction(ISD::VECTOR_SHUFFLE, VT, Default);
2121   setOperationAction(ISD::VECTOR_SPLICE, VT, Default);
2122   setOperationAction(ISD::VSELECT, VT, Default);
2123   setOperationAction(ISD::XOR, VT, Default);
2124   setOperationAction(ISD::ZERO_EXTEND, VT, Default);
2125 }
2126 
2127 void AArch64TargetLowering::addDRType(MVT VT) {
2128   addRegisterClass(VT, &AArch64::FPR64RegClass);
2129   if (Subtarget->isNeonAvailable())
2130     addTypeForNEON(VT);
2131 }
2132 
2133 void AArch64TargetLowering::addQRType(MVT VT) {
2134   addRegisterClass(VT, &AArch64::FPR128RegClass);
2135   if (Subtarget->isNeonAvailable())
2136     addTypeForNEON(VT);
2137 }
2138 
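// Example: a compare of two nxv4i32 vectors produces an nxv4i1 predicate, a
// compare of two v4i32 vectors produces a v4i32 mask, and scalar compares
// produce i32.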
2139 EVT AArch64TargetLowering::getSetCCResultType(const DataLayout &,
2140                                               LLVMContext &C, EVT VT) const {
2141   if (!VT.isVector())
2142     return MVT::i32;
2143   if (VT.isScalableVector())
2144     return EVT::getVectorVT(C, MVT::i1, VT.getVectorElementCount());
2145   return VT.changeVectorElementTypeToInteger();
2146 }
2147 
2148 // isIntImmediate - This method tests to see if the node is a constant
2149 // operand. If so, Imm will receive the value.
2150 static bool isIntImmediate(const SDNode *N, uint64_t &Imm) {
2151   if (const ConstantSDNode *C = dyn_cast<const ConstantSDNode>(N)) {
2152     Imm = C->getZExtValue();
2153     return true;
2154   }
2155   return false;
2156 }
2157 
2158 // isOpcWithIntImmediate - This method tests to see if the node is a specific
2159 // opcode and that it has an immediate integer right operand.
2160 // If so, Imm will receive the value.
2161 static bool isOpcWithIntImmediate(const SDNode *N, unsigned Opc,
2162                                   uint64_t &Imm) {
2163   return N->getOpcode() == Opc &&
2164          isIntImmediate(N->getOperand(1).getNode(), Imm);
2165 }
2166 
2167 static bool optimizeLogicalImm(SDValue Op, unsigned Size, uint64_t Imm,
2168                                const APInt &Demanded,
2169                                TargetLowering::TargetLoweringOpt &TLO,
2170                                unsigned NewOpc) {
2171   uint64_t OldImm = Imm, NewImm, Enc;
2172   uint64_t Mask = ((uint64_t)(-1LL) >> (64 - Size)), OrigMask = Mask;
2173 
2174   // Return if the immediate is already all zeros, all ones, a bimm32 or a
2175   // bimm64.
2176   if (Imm == 0 || Imm == Mask ||
2177       AArch64_AM::isLogicalImmediate(Imm & Mask, Size))
2178     return false;
2179 
2180   unsigned EltSize = Size;
2181   uint64_t DemandedBits = Demanded.getZExtValue();
2182 
2183   // Clear bits that are not demanded.
2184   Imm &= DemandedBits;
2185 
2186   while (true) {
2187     // The goal here is to set the non-demanded bits in a way that minimizes
2188     // the number of switching between 0 and 1. In order to achieve this goal,
2189     // we set the non-demanded bits to the value of the preceding demanded bits.
2190     // For example, if we have an immediate 0bx10xx0x1 ('x' indicates a
2191     // non-demanded bit), we copy bit0 (1) to the least significant 'x',
2192     // bit2 (0) to 'xx', and bit6 (1) to the most significant 'x'.
2193     // The final result is 0b11000011.
2194     uint64_t NonDemandedBits = ~DemandedBits;
2195     uint64_t InvertedImm = ~Imm & DemandedBits;
2196     uint64_t RotatedImm =
2197         ((InvertedImm << 1) | (InvertedImm >> (EltSize - 1) & 1)) &
2198         NonDemandedBits;
2199     uint64_t Sum = RotatedImm + NonDemandedBits;
2200     bool Carry = NonDemandedBits & ~Sum & (1ULL << (EltSize - 1));
2201     uint64_t Ones = (Sum + Carry) & NonDemandedBits;
2202     NewImm = (Imm | Ones) & Mask;
2203 
2204     // If NewImm or its bitwise NOT is a shifted mask, it is a bitmask immediate
2205     // or all-ones or all-zeros, in which case we can stop searching. Otherwise,
2206     // we halve the element size and continue the search.
2207     if (isShiftedMask_64(NewImm) || isShiftedMask_64(~(NewImm | ~Mask)))
2208       break;
2209 
2210     // We cannot shrink the element size any further if it is 2-bits.
2211     if (EltSize == 2)
2212       return false;
2213 
2214     EltSize /= 2;
2215     Mask >>= EltSize;
2216     uint64_t Hi = Imm >> EltSize, DemandedBitsHi = DemandedBits >> EltSize;
2217 
2218     // Return if there is mismatch in any of the demanded bits of Imm and Hi.
2219     if (((Imm ^ Hi) & (DemandedBits & DemandedBitsHi) & Mask) != 0)
2220       return false;
2221 
2222     // Merge the upper and lower halves of Imm and DemandedBits.
2223     Imm |= Hi;
2224     DemandedBits |= DemandedBitsHi;
2225   }
2226 
2227   ++NumOptimizedImms;
2228 
2229   // Replicate the element across the register width.
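  // For example, if the search above settled on an 8-bit element of 0xc3 for a
  // 32-bit operation, the replicated immediate becomes 0xc3c3c3c3.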
2230   while (EltSize < Size) {
2231     NewImm |= NewImm << EltSize;
2232     EltSize *= 2;
2233   }
2234 
2235   (void)OldImm;
2236   assert(((OldImm ^ NewImm) & Demanded.getZExtValue()) == 0 &&
2237          "demanded bits should never be altered");
2238   assert(OldImm != NewImm && "the new imm shouldn't be equal to the old imm");
2239 
2240   // Create the new constant immediate node.
2241   EVT VT = Op.getValueType();
2242   SDLoc DL(Op);
2243   SDValue New;
2244 
2245   // If the new constant immediate is all-zeros or all-ones, let the target
2246   // independent DAG combine optimize this node.
2247   if (NewImm == 0 || NewImm == OrigMask) {
2248     New = TLO.DAG.getNode(Op.getOpcode(), DL, VT, Op.getOperand(0),
2249                           TLO.DAG.getConstant(NewImm, DL, VT));
2250   // Otherwise, create a machine node so that target independent DAG combine
2251   // doesn't undo this optimization.
2252   } else {
2253     Enc = AArch64_AM::encodeLogicalImmediate(NewImm, Size);
2254     SDValue EncConst = TLO.DAG.getTargetConstant(Enc, DL, VT);
2255     New = SDValue(
2256         TLO.DAG.getMachineNode(NewOpc, DL, VT, Op.getOperand(0), EncConst), 0);
2257   }
2258 
2259   return TLO.CombineTo(Op, New);
2260 }
2261 
2262 bool AArch64TargetLowering::targetShrinkDemandedConstant(
2263     SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
2264     TargetLoweringOpt &TLO) const {
2265   // Delay this optimization to as late as possible.
2266   if (!TLO.LegalOps)
2267     return false;
2268 
2269   if (!EnableOptimizeLogicalImm)
2270     return false;
2271 
2272   EVT VT = Op.getValueType();
2273   if (VT.isVector())
2274     return false;
2275 
2276   unsigned Size = VT.getSizeInBits();
2277   assert((Size == 32 || Size == 64) &&
2278          "i32 or i64 is expected after legalization.");
2279 
2280   // Exit early if we demand all bits.
2281   if (DemandedBits.popcount() == Size)
2282     return false;
2283 
2284   unsigned NewOpc;
2285   switch (Op.getOpcode()) {
2286   default:
2287     return false;
2288   case ISD::AND:
2289     NewOpc = Size == 32 ? AArch64::ANDWri : AArch64::ANDXri;
2290     break;
2291   case ISD::OR:
2292     NewOpc = Size == 32 ? AArch64::ORRWri : AArch64::ORRXri;
2293     break;
2294   case ISD::XOR:
2295     NewOpc = Size == 32 ? AArch64::EORWri : AArch64::EORXri;
2296     break;
2297   }
2298   ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op.getOperand(1));
2299   if (!C)
2300     return false;
2301   uint64_t Imm = C->getZExtValue();
2302   return optimizeLogicalImm(Op, Size, Imm, DemandedBits, TLO, NewOpc);
2303 }
2304 
2305 /// computeKnownBitsForTargetNode - Determine which of the bits specified in
2306 /// Mask are known to be either zero or one and return them in Known.
2307 void AArch64TargetLowering::computeKnownBitsForTargetNode(
2308     const SDValue Op, KnownBits &Known, const APInt &DemandedElts,
2309     const SelectionDAG &DAG, unsigned Depth) const {
2310   switch (Op.getOpcode()) {
2311   default:
2312     break;
2313   case AArch64ISD::DUP: {
2314     SDValue SrcOp = Op.getOperand(0);
2315     Known = DAG.computeKnownBits(SrcOp, Depth + 1);
2316     if (SrcOp.getValueSizeInBits() != Op.getScalarValueSizeInBits()) {
2317       assert(SrcOp.getValueSizeInBits() > Op.getScalarValueSizeInBits() &&
2318              "Expected DUP implicit truncation");
2319       Known = Known.trunc(Op.getScalarValueSizeInBits());
2320     }
2321     break;
2322   }
2323   case AArch64ISD::CSEL: {
2324     KnownBits Known2;
2325     Known = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
2326     Known2 = DAG.computeKnownBits(Op->getOperand(1), Depth + 1);
2327     Known = Known.intersectWith(Known2);
2328     break;
2329   }
2330   case AArch64ISD::BICi: {
2331     // Compute the bit cleared value.
2332     uint64_t Mask =
2333         ~(Op->getConstantOperandVal(1) << Op->getConstantOperandVal(2));
2334     Known = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
2335     Known &= KnownBits::makeConstant(APInt(Known.getBitWidth(), Mask));
2336     break;
2337   }
2338   case AArch64ISD::VLSHR: {
2339     KnownBits Known2;
2340     Known = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
2341     Known2 = DAG.computeKnownBits(Op->getOperand(1), Depth + 1);
2342     Known = KnownBits::lshr(Known, Known2);
2343     break;
2344   }
2345   case AArch64ISD::VASHR: {
2346     KnownBits Known2;
2347     Known = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
2348     Known2 = DAG.computeKnownBits(Op->getOperand(1), Depth + 1);
2349     Known = KnownBits::ashr(Known, Known2);
2350     break;
2351   }
2352   case AArch64ISD::VSHL: {
2353     KnownBits Known2;
2354     Known = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
2355     Known2 = DAG.computeKnownBits(Op->getOperand(1), Depth + 1);
2356     Known = KnownBits::shl(Known, Known2);
2357     break;
2358   }
2359   case AArch64ISD::MOVI: {
2360     Known = KnownBits::makeConstant(
2361         APInt(Known.getBitWidth(), Op->getConstantOperandVal(0)));
2362     break;
2363   }
2364   case AArch64ISD::LOADgot:
2365   case AArch64ISD::ADDlow: {
2366     if (!Subtarget->isTargetILP32())
2367       break;
2368     // In ILP32 mode all valid pointers are in the low 4GB of the address-space.
2369     Known.Zero = APInt::getHighBitsSet(64, 32);
2370     break;
2371   }
2372   case AArch64ISD::ASSERT_ZEXT_BOOL: {
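    // The wrapped value is a zero-extended i1, so bits 1-7 are known zero in
    // addition to whatever is already known about the operand.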
2373     Known = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
2374     Known.Zero |= APInt(Known.getBitWidth(), 0xFE);
2375     break;
2376   }
2377   case ISD::INTRINSIC_W_CHAIN: {
2378     Intrinsic::ID IntID =
2379         static_cast<Intrinsic::ID>(Op->getConstantOperandVal(1));
2380     switch (IntID) {
2381     default: return;
2382     case Intrinsic::aarch64_ldaxr:
2383     case Intrinsic::aarch64_ldxr: {
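      // Exclusive loads zero-extend the loaded value to the result width, so
      // every bit above the memory width is known to be zero.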
2384       unsigned BitWidth = Known.getBitWidth();
2385       EVT VT = cast<MemIntrinsicSDNode>(Op)->getMemoryVT();
2386       unsigned MemBits = VT.getScalarSizeInBits();
2387       Known.Zero |= APInt::getHighBitsSet(BitWidth, BitWidth - MemBits);
2388       return;
2389     }
2390     }
2391     break;
2392   }
2393   case ISD::INTRINSIC_WO_CHAIN:
2394   case ISD::INTRINSIC_VOID: {
2395     unsigned IntNo = Op.getConstantOperandVal(0);
2396     switch (IntNo) {
2397     default:
2398       break;
2399     case Intrinsic::aarch64_neon_uaddlv: {
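      // UADDLV sums 8 (v8i8) or 16 (v16i8) unsigned bytes, so the result fits
      // in 11 or 12 bits respectively; all higher bits are known zero.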
2400       MVT VT = Op.getOperand(1).getValueType().getSimpleVT();
2401       unsigned BitWidth = Known.getBitWidth();
2402       if (VT == MVT::v8i8 || VT == MVT::v16i8) {
2403         unsigned Bound = (VT == MVT::v8i8) ?  11 : 12;
2404         assert(BitWidth >= Bound && "Unexpected width!");
2405         APInt Mask = APInt::getHighBitsSet(BitWidth, BitWidth - Bound);
2406         Known.Zero |= Mask;
2407       }
2408       break;
2409     }
2410     case Intrinsic::aarch64_neon_umaxv:
2411     case Intrinsic::aarch64_neon_uminv: {
2412       // Figure out the datatype of the vector operand. The UMINV instruction
2413       // will zero extend the result, so we can mark as known zero all the
2414       // bits larger than the element datatype. 32-bit or larger elements do not
2415       // need this, as those are legal types and will be handled by isel directly.
2416       MVT VT = Op.getOperand(1).getValueType().getSimpleVT();
2417       unsigned BitWidth = Known.getBitWidth();
2418       if (VT == MVT::v8i8 || VT == MVT::v16i8) {
2419         assert(BitWidth >= 8 && "Unexpected width!");
2420         APInt Mask = APInt::getHighBitsSet(BitWidth, BitWidth - 8);
2421         Known.Zero |= Mask;
2422       } else if (VT == MVT::v4i16 || VT == MVT::v8i16) {
2423         assert(BitWidth >= 16 && "Unexpected width!");
2424         APInt Mask = APInt::getHighBitsSet(BitWidth, BitWidth - 16);
2425         Known.Zero |= Mask;
2426       }
2427       break;
2428     }
2429     }
2430   }
2431   }
2432 }
2433 
2434 unsigned AArch64TargetLowering::ComputeNumSignBitsForTargetNode(
2435     SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG,
2436     unsigned Depth) const {
2437   EVT VT = Op.getValueType();
2438   unsigned VTBits = VT.getScalarSizeInBits();
2439   unsigned Opcode = Op.getOpcode();
2440   switch (Opcode) {
2441     case AArch64ISD::CMEQ:
2442     case AArch64ISD::CMGE:
2443     case AArch64ISD::CMGT:
2444     case AArch64ISD::CMHI:
2445     case AArch64ISD::CMHS:
2446     case AArch64ISD::FCMEQ:
2447     case AArch64ISD::FCMGE:
2448     case AArch64ISD::FCMGT:
2449     case AArch64ISD::CMEQz:
2450     case AArch64ISD::CMGEz:
2451     case AArch64ISD::CMGTz:
2452     case AArch64ISD::CMLEz:
2453     case AArch64ISD::CMLTz:
2454     case AArch64ISD::FCMEQz:
2455     case AArch64ISD::FCMGEz:
2456     case AArch64ISD::FCMGTz:
2457     case AArch64ISD::FCMLEz:
2458     case AArch64ISD::FCMLTz:
2459       // Compares return either 0 or all-ones
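      // (every bit is a copy of the sign bit), so report the full width.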
2460       return VTBits;
2461   }
2462 
2463   return 1;
2464 }
2465 
2466 MVT AArch64TargetLowering::getScalarShiftAmountTy(const DataLayout &DL,
2467                                                   EVT) const {
2468   return MVT::i64;
2469 }
2470 
2471 bool AArch64TargetLowering::allowsMisalignedMemoryAccesses(
2472     EVT VT, unsigned AddrSpace, Align Alignment, MachineMemOperand::Flags Flags,
2473     unsigned *Fast) const {
2474   if (Subtarget->requiresStrictAlign())
2475     return false;
2476 
2477   if (Fast) {
2478     // Some CPUs are fine with unaligned stores except for 128-bit ones.
2479     *Fast = !Subtarget->isMisaligned128StoreSlow() || VT.getStoreSize() != 16 ||
2480             // See comments in performSTORECombine() for more details about
2481             // these conditions.
2482 
2483             // Code that uses clang vector extensions can mark that it
2484             // wants unaligned accesses to be treated as fast by
2485             // underspecifying alignment to be 1 or 2.
2486             Alignment <= 2 ||
2487 
2488             // Disregard v2i64. Memcpy lowering produces those and splitting
2489             // them regresses performance on micro-benchmarks and olden/bh.
2490             VT == MVT::v2i64;
2491   }
2492   return true;
2493 }
2494 
2495 // Same as above but handling LLTs instead.
2496 bool AArch64TargetLowering::allowsMisalignedMemoryAccesses(
2497     LLT Ty, unsigned AddrSpace, Align Alignment, MachineMemOperand::Flags Flags,
2498     unsigned *Fast) const {
2499   if (Subtarget->requiresStrictAlign())
2500     return false;
2501 
2502   if (Fast) {
2503     // Some CPUs are fine with unaligned stores except for 128-bit ones.
2504     *Fast = !Subtarget->isMisaligned128StoreSlow() ||
2505             Ty.getSizeInBytes() != 16 ||
2506             // See comments in performSTORECombine() for more details about
2507             // these conditions.
2508 
2509             // Code that uses clang vector extensions can mark that it
2510             // wants unaligned accesses to be treated as fast by
2511             // underspecifying alignment to be 1 or 2.
2512             Alignment <= 2 ||
2513 
2514             // Disregard v2i64. Memcpy lowering produces those and splitting
2515             // them regresses performance on micro-benchmarks and olden/bh.
2516             Ty == LLT::fixed_vector(2, 64);
2517   }
2518   return true;
2519 }
2520 
2521 FastISel *
2522 AArch64TargetLowering::createFastISel(FunctionLoweringInfo &funcInfo,
2523                                       const TargetLibraryInfo *libInfo) const {
2524   return AArch64::createFastISel(funcInfo, libInfo);
2525 }
2526 
2527 const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
2528 #define MAKE_CASE(V)                                                           \
2529   case V:                                                                      \
2530     return #V;
2531   switch ((AArch64ISD::NodeType)Opcode) {
2532   case AArch64ISD::FIRST_NUMBER:
2533     break;
2534     MAKE_CASE(AArch64ISD::ALLOCATE_ZA_BUFFER)
2535     MAKE_CASE(AArch64ISD::INIT_TPIDR2OBJ)
2536     MAKE_CASE(AArch64ISD::COALESCER_BARRIER)
2537     MAKE_CASE(AArch64ISD::VG_SAVE)
2538     MAKE_CASE(AArch64ISD::VG_RESTORE)
2539     MAKE_CASE(AArch64ISD::SMSTART)
2540     MAKE_CASE(AArch64ISD::SMSTOP)
2541     MAKE_CASE(AArch64ISD::RESTORE_ZA)
2542     MAKE_CASE(AArch64ISD::RESTORE_ZT)
2543     MAKE_CASE(AArch64ISD::SAVE_ZT)
2544     MAKE_CASE(AArch64ISD::CALL)
2545     MAKE_CASE(AArch64ISD::ADRP)
2546     MAKE_CASE(AArch64ISD::ADR)
2547     MAKE_CASE(AArch64ISD::ADDlow)
2548     MAKE_CASE(AArch64ISD::AUTH_CALL)
2549     MAKE_CASE(AArch64ISD::AUTH_TC_RETURN)
2550     MAKE_CASE(AArch64ISD::AUTH_CALL_RVMARKER)
2551     MAKE_CASE(AArch64ISD::LOADgot)
2552     MAKE_CASE(AArch64ISD::RET_GLUE)
2553     MAKE_CASE(AArch64ISD::BRCOND)
2554     MAKE_CASE(AArch64ISD::CSEL)
2555     MAKE_CASE(AArch64ISD::CSINV)
2556     MAKE_CASE(AArch64ISD::CSNEG)
2557     MAKE_CASE(AArch64ISD::CSINC)
2558     MAKE_CASE(AArch64ISD::THREAD_POINTER)
2559     MAKE_CASE(AArch64ISD::TLSDESC_CALLSEQ)
2560     MAKE_CASE(AArch64ISD::PROBED_ALLOCA)
2561     MAKE_CASE(AArch64ISD::ABDS_PRED)
2562     MAKE_CASE(AArch64ISD::ABDU_PRED)
2563     MAKE_CASE(AArch64ISD::HADDS_PRED)
2564     MAKE_CASE(AArch64ISD::HADDU_PRED)
2565     MAKE_CASE(AArch64ISD::MUL_PRED)
2566     MAKE_CASE(AArch64ISD::MULHS_PRED)
2567     MAKE_CASE(AArch64ISD::MULHU_PRED)
2568     MAKE_CASE(AArch64ISD::RHADDS_PRED)
2569     MAKE_CASE(AArch64ISD::RHADDU_PRED)
2570     MAKE_CASE(AArch64ISD::SDIV_PRED)
2571     MAKE_CASE(AArch64ISD::SHL_PRED)
2572     MAKE_CASE(AArch64ISD::SMAX_PRED)
2573     MAKE_CASE(AArch64ISD::SMIN_PRED)
2574     MAKE_CASE(AArch64ISD::SRA_PRED)
2575     MAKE_CASE(AArch64ISD::SRL_PRED)
2576     MAKE_CASE(AArch64ISD::UDIV_PRED)
2577     MAKE_CASE(AArch64ISD::UMAX_PRED)
2578     MAKE_CASE(AArch64ISD::UMIN_PRED)
2579     MAKE_CASE(AArch64ISD::SRAD_MERGE_OP1)
2580     MAKE_CASE(AArch64ISD::FNEG_MERGE_PASSTHRU)
2581     MAKE_CASE(AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU)
2582     MAKE_CASE(AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU)
2583     MAKE_CASE(AArch64ISD::FCEIL_MERGE_PASSTHRU)
2584     MAKE_CASE(AArch64ISD::FFLOOR_MERGE_PASSTHRU)
2585     MAKE_CASE(AArch64ISD::FNEARBYINT_MERGE_PASSTHRU)
2586     MAKE_CASE(AArch64ISD::FRINT_MERGE_PASSTHRU)
2587     MAKE_CASE(AArch64ISD::FROUND_MERGE_PASSTHRU)
2588     MAKE_CASE(AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU)
2589     MAKE_CASE(AArch64ISD::FTRUNC_MERGE_PASSTHRU)
2590     MAKE_CASE(AArch64ISD::FP_ROUND_MERGE_PASSTHRU)
2591     MAKE_CASE(AArch64ISD::FP_EXTEND_MERGE_PASSTHRU)
2592     MAKE_CASE(AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU)
2593     MAKE_CASE(AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU)
2594     MAKE_CASE(AArch64ISD::FCVTZU_MERGE_PASSTHRU)
2595     MAKE_CASE(AArch64ISD::FCVTZS_MERGE_PASSTHRU)
2596     MAKE_CASE(AArch64ISD::FSQRT_MERGE_PASSTHRU)
2597     MAKE_CASE(AArch64ISD::FRECPX_MERGE_PASSTHRU)
2598     MAKE_CASE(AArch64ISD::FABS_MERGE_PASSTHRU)
2599     MAKE_CASE(AArch64ISD::ABS_MERGE_PASSTHRU)
2600     MAKE_CASE(AArch64ISD::NEG_MERGE_PASSTHRU)
2601     MAKE_CASE(AArch64ISD::SETCC_MERGE_ZERO)
2602     MAKE_CASE(AArch64ISD::ADC)
2603     MAKE_CASE(AArch64ISD::SBC)
2604     MAKE_CASE(AArch64ISD::ADDS)
2605     MAKE_CASE(AArch64ISD::SUBS)
2606     MAKE_CASE(AArch64ISD::ADCS)
2607     MAKE_CASE(AArch64ISD::SBCS)
2608     MAKE_CASE(AArch64ISD::ANDS)
2609     MAKE_CASE(AArch64ISD::CCMP)
2610     MAKE_CASE(AArch64ISD::CCMN)
2611     MAKE_CASE(AArch64ISD::FCCMP)
2612     MAKE_CASE(AArch64ISD::FCMP)
2613     MAKE_CASE(AArch64ISD::STRICT_FCMP)
2614     MAKE_CASE(AArch64ISD::STRICT_FCMPE)
2615     MAKE_CASE(AArch64ISD::FCVTXN)
2616     MAKE_CASE(AArch64ISD::SME_ZA_LDR)
2617     MAKE_CASE(AArch64ISD::SME_ZA_STR)
2618     MAKE_CASE(AArch64ISD::DUP)
2619     MAKE_CASE(AArch64ISD::DUPLANE8)
2620     MAKE_CASE(AArch64ISD::DUPLANE16)
2621     MAKE_CASE(AArch64ISD::DUPLANE32)
2622     MAKE_CASE(AArch64ISD::DUPLANE64)
2623     MAKE_CASE(AArch64ISD::DUPLANE128)
2624     MAKE_CASE(AArch64ISD::MOVI)
2625     MAKE_CASE(AArch64ISD::MOVIshift)
2626     MAKE_CASE(AArch64ISD::MOVIedit)
2627     MAKE_CASE(AArch64ISD::MOVImsl)
2628     MAKE_CASE(AArch64ISD::FMOV)
2629     MAKE_CASE(AArch64ISD::MVNIshift)
2630     MAKE_CASE(AArch64ISD::MVNImsl)
2631     MAKE_CASE(AArch64ISD::BICi)
2632     MAKE_CASE(AArch64ISD::ORRi)
2633     MAKE_CASE(AArch64ISD::BSP)
2634     MAKE_CASE(AArch64ISD::ZIP1)
2635     MAKE_CASE(AArch64ISD::ZIP2)
2636     MAKE_CASE(AArch64ISD::UZP1)
2637     MAKE_CASE(AArch64ISD::UZP2)
2638     MAKE_CASE(AArch64ISD::TRN1)
2639     MAKE_CASE(AArch64ISD::TRN2)
2640     MAKE_CASE(AArch64ISD::REV16)
2641     MAKE_CASE(AArch64ISD::REV32)
2642     MAKE_CASE(AArch64ISD::REV64)
2643     MAKE_CASE(AArch64ISD::EXT)
2644     MAKE_CASE(AArch64ISD::SPLICE)
2645     MAKE_CASE(AArch64ISD::VSHL)
2646     MAKE_CASE(AArch64ISD::VLSHR)
2647     MAKE_CASE(AArch64ISD::VASHR)
2648     MAKE_CASE(AArch64ISD::VSLI)
2649     MAKE_CASE(AArch64ISD::VSRI)
2650     MAKE_CASE(AArch64ISD::CMEQ)
2651     MAKE_CASE(AArch64ISD::CMGE)
2652     MAKE_CASE(AArch64ISD::CMGT)
2653     MAKE_CASE(AArch64ISD::CMHI)
2654     MAKE_CASE(AArch64ISD::CMHS)
2655     MAKE_CASE(AArch64ISD::FCMEQ)
2656     MAKE_CASE(AArch64ISD::FCMGE)
2657     MAKE_CASE(AArch64ISD::FCMGT)
2658     MAKE_CASE(AArch64ISD::CMEQz)
2659     MAKE_CASE(AArch64ISD::CMGEz)
2660     MAKE_CASE(AArch64ISD::CMGTz)
2661     MAKE_CASE(AArch64ISD::CMLEz)
2662     MAKE_CASE(AArch64ISD::CMLTz)
2663     MAKE_CASE(AArch64ISD::FCMEQz)
2664     MAKE_CASE(AArch64ISD::FCMGEz)
2665     MAKE_CASE(AArch64ISD::FCMGTz)
2666     MAKE_CASE(AArch64ISD::FCMLEz)
2667     MAKE_CASE(AArch64ISD::FCMLTz)
2668     MAKE_CASE(AArch64ISD::SADDV)
2669     MAKE_CASE(AArch64ISD::UADDV)
2670     MAKE_CASE(AArch64ISD::UADDLV)
2671     MAKE_CASE(AArch64ISD::SADDLV)
2672     MAKE_CASE(AArch64ISD::SDOT)
2673     MAKE_CASE(AArch64ISD::UDOT)
2674     MAKE_CASE(AArch64ISD::SMINV)
2675     MAKE_CASE(AArch64ISD::UMINV)
2676     MAKE_CASE(AArch64ISD::SMAXV)
2677     MAKE_CASE(AArch64ISD::UMAXV)
2678     MAKE_CASE(AArch64ISD::SADDV_PRED)
2679     MAKE_CASE(AArch64ISD::UADDV_PRED)
2680     MAKE_CASE(AArch64ISD::SMAXV_PRED)
2681     MAKE_CASE(AArch64ISD::UMAXV_PRED)
2682     MAKE_CASE(AArch64ISD::SMINV_PRED)
2683     MAKE_CASE(AArch64ISD::UMINV_PRED)
2684     MAKE_CASE(AArch64ISD::ORV_PRED)
2685     MAKE_CASE(AArch64ISD::EORV_PRED)
2686     MAKE_CASE(AArch64ISD::ANDV_PRED)
2687     MAKE_CASE(AArch64ISD::CLASTA_N)
2688     MAKE_CASE(AArch64ISD::CLASTB_N)
2689     MAKE_CASE(AArch64ISD::LASTA)
2690     MAKE_CASE(AArch64ISD::LASTB)
2691     MAKE_CASE(AArch64ISD::REINTERPRET_CAST)
2692     MAKE_CASE(AArch64ISD::LS64_BUILD)
2693     MAKE_CASE(AArch64ISD::LS64_EXTRACT)
2694     MAKE_CASE(AArch64ISD::TBL)
2695     MAKE_CASE(AArch64ISD::FADD_PRED)
2696     MAKE_CASE(AArch64ISD::FADDA_PRED)
2697     MAKE_CASE(AArch64ISD::FADDV_PRED)
2698     MAKE_CASE(AArch64ISD::FDIV_PRED)
2699     MAKE_CASE(AArch64ISD::FMA_PRED)
2700     MAKE_CASE(AArch64ISD::FMAX_PRED)
2701     MAKE_CASE(AArch64ISD::FMAXV_PRED)
2702     MAKE_CASE(AArch64ISD::FMAXNM_PRED)
2703     MAKE_CASE(AArch64ISD::FMAXNMV_PRED)
2704     MAKE_CASE(AArch64ISD::FMIN_PRED)
2705     MAKE_CASE(AArch64ISD::FMINV_PRED)
2706     MAKE_CASE(AArch64ISD::FMINNM_PRED)
2707     MAKE_CASE(AArch64ISD::FMINNMV_PRED)
2708     MAKE_CASE(AArch64ISD::FMUL_PRED)
2709     MAKE_CASE(AArch64ISD::FSUB_PRED)
2710     MAKE_CASE(AArch64ISD::RDSVL)
2711     MAKE_CASE(AArch64ISD::BIC)
2712     MAKE_CASE(AArch64ISD::CBZ)
2713     MAKE_CASE(AArch64ISD::CBNZ)
2714     MAKE_CASE(AArch64ISD::TBZ)
2715     MAKE_CASE(AArch64ISD::TBNZ)
2716     MAKE_CASE(AArch64ISD::TC_RETURN)
2717     MAKE_CASE(AArch64ISD::PREFETCH)
2718     MAKE_CASE(AArch64ISD::SITOF)
2719     MAKE_CASE(AArch64ISD::UITOF)
2720     MAKE_CASE(AArch64ISD::NVCAST)
2721     MAKE_CASE(AArch64ISD::MRS)
2722     MAKE_CASE(AArch64ISD::SQSHL_I)
2723     MAKE_CASE(AArch64ISD::UQSHL_I)
2724     MAKE_CASE(AArch64ISD::SRSHR_I)
2725     MAKE_CASE(AArch64ISD::URSHR_I)
2726     MAKE_CASE(AArch64ISD::SQSHLU_I)
2727     MAKE_CASE(AArch64ISD::WrapperLarge)
2728     MAKE_CASE(AArch64ISD::LD2post)
2729     MAKE_CASE(AArch64ISD::LD3post)
2730     MAKE_CASE(AArch64ISD::LD4post)
2731     MAKE_CASE(AArch64ISD::ST2post)
2732     MAKE_CASE(AArch64ISD::ST3post)
2733     MAKE_CASE(AArch64ISD::ST4post)
2734     MAKE_CASE(AArch64ISD::LD1x2post)
2735     MAKE_CASE(AArch64ISD::LD1x3post)
2736     MAKE_CASE(AArch64ISD::LD1x4post)
2737     MAKE_CASE(AArch64ISD::ST1x2post)
2738     MAKE_CASE(AArch64ISD::ST1x3post)
2739     MAKE_CASE(AArch64ISD::ST1x4post)
2740     MAKE_CASE(AArch64ISD::LD1DUPpost)
2741     MAKE_CASE(AArch64ISD::LD2DUPpost)
2742     MAKE_CASE(AArch64ISD::LD3DUPpost)
2743     MAKE_CASE(AArch64ISD::LD4DUPpost)
2744     MAKE_CASE(AArch64ISD::LD1LANEpost)
2745     MAKE_CASE(AArch64ISD::LD2LANEpost)
2746     MAKE_CASE(AArch64ISD::LD3LANEpost)
2747     MAKE_CASE(AArch64ISD::LD4LANEpost)
2748     MAKE_CASE(AArch64ISD::ST2LANEpost)
2749     MAKE_CASE(AArch64ISD::ST3LANEpost)
2750     MAKE_CASE(AArch64ISD::ST4LANEpost)
2751     MAKE_CASE(AArch64ISD::SMULL)
2752     MAKE_CASE(AArch64ISD::UMULL)
2753     MAKE_CASE(AArch64ISD::PMULL)
2754     MAKE_CASE(AArch64ISD::FRECPE)
2755     MAKE_CASE(AArch64ISD::FRECPS)
2756     MAKE_CASE(AArch64ISD::FRSQRTE)
2757     MAKE_CASE(AArch64ISD::FRSQRTS)
2758     MAKE_CASE(AArch64ISD::STG)
2759     MAKE_CASE(AArch64ISD::STZG)
2760     MAKE_CASE(AArch64ISD::ST2G)
2761     MAKE_CASE(AArch64ISD::STZ2G)
2762     MAKE_CASE(AArch64ISD::SUNPKHI)
2763     MAKE_CASE(AArch64ISD::SUNPKLO)
2764     MAKE_CASE(AArch64ISD::UUNPKHI)
2765     MAKE_CASE(AArch64ISD::UUNPKLO)
2766     MAKE_CASE(AArch64ISD::INSR)
2767     MAKE_CASE(AArch64ISD::PTEST)
2768     MAKE_CASE(AArch64ISD::PTEST_ANY)
2769     MAKE_CASE(AArch64ISD::PTRUE)
2770     MAKE_CASE(AArch64ISD::LD1_MERGE_ZERO)
2771     MAKE_CASE(AArch64ISD::LD1S_MERGE_ZERO)
2772     MAKE_CASE(AArch64ISD::LDNF1_MERGE_ZERO)
2773     MAKE_CASE(AArch64ISD::LDNF1S_MERGE_ZERO)
2774     MAKE_CASE(AArch64ISD::LDFF1_MERGE_ZERO)
2775     MAKE_CASE(AArch64ISD::LDFF1S_MERGE_ZERO)
2776     MAKE_CASE(AArch64ISD::LD1RQ_MERGE_ZERO)
2777     MAKE_CASE(AArch64ISD::LD1RO_MERGE_ZERO)
2778     MAKE_CASE(AArch64ISD::SVE_LD2_MERGE_ZERO)
2779     MAKE_CASE(AArch64ISD::SVE_LD3_MERGE_ZERO)
2780     MAKE_CASE(AArch64ISD::SVE_LD4_MERGE_ZERO)
2781     MAKE_CASE(AArch64ISD::GLD1_MERGE_ZERO)
2782     MAKE_CASE(AArch64ISD::GLD1_SCALED_MERGE_ZERO)
2783     MAKE_CASE(AArch64ISD::GLD1_SXTW_MERGE_ZERO)
2784     MAKE_CASE(AArch64ISD::GLD1_UXTW_MERGE_ZERO)
2785     MAKE_CASE(AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO)
2786     MAKE_CASE(AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO)
2787     MAKE_CASE(AArch64ISD::GLD1_IMM_MERGE_ZERO)
2788     MAKE_CASE(AArch64ISD::GLD1Q_MERGE_ZERO)
2789     MAKE_CASE(AArch64ISD::GLD1Q_INDEX_MERGE_ZERO)
2790     MAKE_CASE(AArch64ISD::GLD1S_MERGE_ZERO)
2791     MAKE_CASE(AArch64ISD::GLD1S_SCALED_MERGE_ZERO)
2792     MAKE_CASE(AArch64ISD::GLD1S_SXTW_MERGE_ZERO)
2793     MAKE_CASE(AArch64ISD::GLD1S_UXTW_MERGE_ZERO)
2794     MAKE_CASE(AArch64ISD::GLD1S_SXTW_SCALED_MERGE_ZERO)
2795     MAKE_CASE(AArch64ISD::GLD1S_UXTW_SCALED_MERGE_ZERO)
2796     MAKE_CASE(AArch64ISD::GLD1S_IMM_MERGE_ZERO)
2797     MAKE_CASE(AArch64ISD::GLDFF1_MERGE_ZERO)
2798     MAKE_CASE(AArch64ISD::GLDFF1_SCALED_MERGE_ZERO)
2799     MAKE_CASE(AArch64ISD::GLDFF1_SXTW_MERGE_ZERO)
2800     MAKE_CASE(AArch64ISD::GLDFF1_UXTW_MERGE_ZERO)
2801     MAKE_CASE(AArch64ISD::GLDFF1_SXTW_SCALED_MERGE_ZERO)
2802     MAKE_CASE(AArch64ISD::GLDFF1_UXTW_SCALED_MERGE_ZERO)
2803     MAKE_CASE(AArch64ISD::GLDFF1_IMM_MERGE_ZERO)
2804     MAKE_CASE(AArch64ISD::GLDFF1S_MERGE_ZERO)
2805     MAKE_CASE(AArch64ISD::GLDFF1S_SCALED_MERGE_ZERO)
2806     MAKE_CASE(AArch64ISD::GLDFF1S_SXTW_MERGE_ZERO)
2807     MAKE_CASE(AArch64ISD::GLDFF1S_UXTW_MERGE_ZERO)
2808     MAKE_CASE(AArch64ISD::GLDFF1S_SXTW_SCALED_MERGE_ZERO)
2809     MAKE_CASE(AArch64ISD::GLDFF1S_UXTW_SCALED_MERGE_ZERO)
2810     MAKE_CASE(AArch64ISD::GLDFF1S_IMM_MERGE_ZERO)
2811     MAKE_CASE(AArch64ISD::GLDNT1_MERGE_ZERO)
2812     MAKE_CASE(AArch64ISD::GLDNT1_INDEX_MERGE_ZERO)
2813     MAKE_CASE(AArch64ISD::GLDNT1S_MERGE_ZERO)
2814     MAKE_CASE(AArch64ISD::SST1Q_PRED)
2815     MAKE_CASE(AArch64ISD::SST1Q_INDEX_PRED)
2816     MAKE_CASE(AArch64ISD::ST1_PRED)
2817     MAKE_CASE(AArch64ISD::SST1_PRED)
2818     MAKE_CASE(AArch64ISD::SST1_SCALED_PRED)
2819     MAKE_CASE(AArch64ISD::SST1_SXTW_PRED)
2820     MAKE_CASE(AArch64ISD::SST1_UXTW_PRED)
2821     MAKE_CASE(AArch64ISD::SST1_SXTW_SCALED_PRED)
2822     MAKE_CASE(AArch64ISD::SST1_UXTW_SCALED_PRED)
2823     MAKE_CASE(AArch64ISD::SST1_IMM_PRED)
2824     MAKE_CASE(AArch64ISD::SSTNT1_PRED)
2825     MAKE_CASE(AArch64ISD::SSTNT1_INDEX_PRED)
2826     MAKE_CASE(AArch64ISD::LDP)
2827     MAKE_CASE(AArch64ISD::LDIAPP)
2828     MAKE_CASE(AArch64ISD::LDNP)
2829     MAKE_CASE(AArch64ISD::STP)
2830     MAKE_CASE(AArch64ISD::STILP)
2831     MAKE_CASE(AArch64ISD::STNP)
2832     MAKE_CASE(AArch64ISD::BITREVERSE_MERGE_PASSTHRU)
2833     MAKE_CASE(AArch64ISD::BSWAP_MERGE_PASSTHRU)
2834     MAKE_CASE(AArch64ISD::REVH_MERGE_PASSTHRU)
2835     MAKE_CASE(AArch64ISD::REVW_MERGE_PASSTHRU)
2836     MAKE_CASE(AArch64ISD::REVD_MERGE_PASSTHRU)
2837     MAKE_CASE(AArch64ISD::CTLZ_MERGE_PASSTHRU)
2838     MAKE_CASE(AArch64ISD::CTPOP_MERGE_PASSTHRU)
2839     MAKE_CASE(AArch64ISD::DUP_MERGE_PASSTHRU)
2840     MAKE_CASE(AArch64ISD::INDEX_VECTOR)
2841     MAKE_CASE(AArch64ISD::ADDP)
2842     MAKE_CASE(AArch64ISD::SADDLP)
2843     MAKE_CASE(AArch64ISD::UADDLP)
2844     MAKE_CASE(AArch64ISD::CALL_RVMARKER)
2845     MAKE_CASE(AArch64ISD::ASSERT_ZEXT_BOOL)
2846     MAKE_CASE(AArch64ISD::MOPS_MEMSET)
2847     MAKE_CASE(AArch64ISD::MOPS_MEMSET_TAGGING)
2848     MAKE_CASE(AArch64ISD::MOPS_MEMCOPY)
2849     MAKE_CASE(AArch64ISD::MOPS_MEMMOVE)
2850     MAKE_CASE(AArch64ISD::CALL_BTI)
2851     MAKE_CASE(AArch64ISD::MRRS)
2852     MAKE_CASE(AArch64ISD::MSRR)
2853     MAKE_CASE(AArch64ISD::RSHRNB_I)
2854     MAKE_CASE(AArch64ISD::CTTZ_ELTS)
2855     MAKE_CASE(AArch64ISD::CALL_ARM64EC_TO_X64)
2856     MAKE_CASE(AArch64ISD::URSHR_I_PRED)
2857   }
2858 #undef MAKE_CASE
2859   return nullptr;
2860 }
2861 
2862 MachineBasicBlock *
2863 AArch64TargetLowering::EmitF128CSEL(MachineInstr &MI,
2864                                     MachineBasicBlock *MBB) const {
2865   // We materialise the F128CSEL pseudo-instruction as some control flow and a
2866   // phi node:
2867 
2868   // OrigBB:
2869   //     [... previous instrs leading to comparison ...]
2870   //     b.ne TrueBB
2871   //     b EndBB
2872   // TrueBB:
2873   //     ; Fallthrough
2874   // EndBB:
2875   //     Dest = PHI [IfTrue, TrueBB], [IfFalse, OrigBB]
2876 
2877   MachineFunction *MF = MBB->getParent();
2878   const TargetInstrInfo *TII = Subtarget->getInstrInfo();
2879   const BasicBlock *LLVM_BB = MBB->getBasicBlock();
2880   DebugLoc DL = MI.getDebugLoc();
2881   MachineFunction::iterator It = ++MBB->getIterator();
2882 
2883   Register DestReg = MI.getOperand(0).getReg();
2884   Register IfTrueReg = MI.getOperand(1).getReg();
2885   Register IfFalseReg = MI.getOperand(2).getReg();
2886   unsigned CondCode = MI.getOperand(3).getImm();
2887   bool NZCVKilled = MI.getOperand(4).isKill();
2888 
2889   MachineBasicBlock *TrueBB = MF->CreateMachineBasicBlock(LLVM_BB);
2890   MachineBasicBlock *EndBB = MF->CreateMachineBasicBlock(LLVM_BB);
2891   MF->insert(It, TrueBB);
2892   MF->insert(It, EndBB);
2893 
2894   // Transfer rest of current basic-block to EndBB
2895   EndBB->splice(EndBB->begin(), MBB, std::next(MachineBasicBlock::iterator(MI)),
2896                 MBB->end());
2897   EndBB->transferSuccessorsAndUpdatePHIs(MBB);
2898 
2899   BuildMI(MBB, DL, TII->get(AArch64::Bcc)).addImm(CondCode).addMBB(TrueBB);
2900   BuildMI(MBB, DL, TII->get(AArch64::B)).addMBB(EndBB);
2901   MBB->addSuccessor(TrueBB);
2902   MBB->addSuccessor(EndBB);
2903 
2904   // TrueBB falls through to the end.
2905   TrueBB->addSuccessor(EndBB);
2906 
2907   if (!NZCVKilled) {
2908     TrueBB->addLiveIn(AArch64::NZCV);
2909     EndBB->addLiveIn(AArch64::NZCV);
2910   }
2911 
2912   BuildMI(*EndBB, EndBB->begin(), DL, TII->get(AArch64::PHI), DestReg)
2913       .addReg(IfTrueReg)
2914       .addMBB(TrueBB)
2915       .addReg(IfFalseReg)
2916       .addMBB(MBB);
2917 
2918   MI.eraseFromParent();
2919   return EndBB;
2920 }
2921 
2922 MachineBasicBlock *AArch64TargetLowering::EmitLoweredCatchRet(
2923        MachineInstr &MI, MachineBasicBlock *BB) const {
2924   assert(!isAsynchronousEHPersonality(classifyEHPersonality(
2925              BB->getParent()->getFunction().getPersonalityFn())) &&
2926          "SEH does not use catchret!");
2927   return BB;
2928 }
2929 
2930 MachineBasicBlock *
2931 AArch64TargetLowering::EmitDynamicProbedAlloc(MachineInstr &MI,
2932                                               MachineBasicBlock *MBB) const {
2933   MachineFunction &MF = *MBB->getParent();
2934   MachineBasicBlock::iterator MBBI = MI.getIterator();
2935   DebugLoc DL = MBB->findDebugLoc(MBBI);
2936   const AArch64InstrInfo &TII =
2937       *MF.getSubtarget<AArch64Subtarget>().getInstrInfo();
2938   Register TargetReg = MI.getOperand(0).getReg();
2939   MachineBasicBlock::iterator NextInst =
2940       TII.probedStackAlloc(MBBI, TargetReg, false);
2941 
2942   MI.eraseFromParent();
2943   return NextInst->getParent();
2944 }
2945 
2946 MachineBasicBlock *
2947 AArch64TargetLowering::EmitTileLoad(unsigned Opc, unsigned BaseReg,
2948                                     MachineInstr &MI,
2949                                     MachineBasicBlock *BB) const {
2950   const TargetInstrInfo *TII = Subtarget->getInstrInfo();
2951   MachineInstrBuilder MIB = BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(Opc));
2952 
2953   MIB.addReg(BaseReg + MI.getOperand(0).getImm(), RegState::Define);
2954   MIB.add(MI.getOperand(1)); // slice index register
2955   MIB.add(MI.getOperand(2)); // slice index offset
2956   MIB.add(MI.getOperand(3)); // pg
2957   MIB.add(MI.getOperand(4)); // base
2958   MIB.add(MI.getOperand(5)); // offset
2959 
2960   MI.eraseFromParent(); // The pseudo is gone now.
2961   return BB;
2962 }
2963 
2964 MachineBasicBlock *
2965 AArch64TargetLowering::EmitFill(MachineInstr &MI, MachineBasicBlock *BB) const {
2966   const TargetInstrInfo *TII = Subtarget->getInstrInfo();
2967   MachineInstrBuilder MIB =
2968       BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::LDR_ZA));
2969 
2970   MIB.addReg(AArch64::ZA, RegState::Define);
2971   MIB.add(MI.getOperand(0)); // Vector select register
2972   MIB.add(MI.getOperand(1)); // Vector select offset
2973   MIB.add(MI.getOperand(2)); // Base
2974   MIB.add(MI.getOperand(1)); // Offset, same as vector select offset
2975 
2976   MI.eraseFromParent(); // The pseudo is gone now.
2977   return BB;
2978 }
2979 
2980 MachineBasicBlock *AArch64TargetLowering::EmitZTInstr(MachineInstr &MI,
2981                                                       MachineBasicBlock *BB,
2982                                                       unsigned Opcode,
2983                                                       bool Op0IsDef) const {
2984   const TargetInstrInfo *TII = Subtarget->getInstrInfo();
2985   MachineInstrBuilder MIB;
2986 
2987   MIB = BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(Opcode))
2988             .addReg(MI.getOperand(0).getReg(), Op0IsDef ? RegState::Define : 0);
2989   for (unsigned I = 1; I < MI.getNumOperands(); ++I)
2990     MIB.add(MI.getOperand(I));
2991 
2992   MI.eraseFromParent(); // The pseudo is gone now.
2993   return BB;
2994 }
2995 
2996 MachineBasicBlock *
2997 AArch64TargetLowering::EmitZAInstr(unsigned Opc, unsigned BaseReg,
2998                                    MachineInstr &MI,
2999                                    MachineBasicBlock *BB) const {
3000   const TargetInstrInfo *TII = Subtarget->getInstrInfo();
3001   MachineInstrBuilder MIB = BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(Opc));
3002   unsigned StartIdx = 0;
3003 
3004   bool HasTile = BaseReg != AArch64::ZA;
3005   bool HasZPROut = HasTile && MI.getOperand(0).isReg();
3006   if (HasZPROut) {
3007     MIB.add(MI.getOperand(StartIdx)); // Output ZPR
3008     ++StartIdx;
3009   }
3010   if (HasTile) {
3011     MIB.addReg(BaseReg + MI.getOperand(StartIdx).getImm(),
3012                RegState::Define);                           // Output ZA Tile
3013     MIB.addReg(BaseReg + MI.getOperand(StartIdx).getImm()); // Input Za Tile
3014     StartIdx++;
3015   } else {
3016     // Avoids all instructions with mnemonic za.<sz>[Reg, Imm,
3017     if (MI.getOperand(0).isReg() && !MI.getOperand(1).isImm()) {
3018       MIB.add(MI.getOperand(StartIdx)); // Output ZPR
3019       ++StartIdx;
3020     }
3021     MIB.addReg(BaseReg, RegState::Define).addReg(BaseReg);
3022   }
3023   for (unsigned I = StartIdx; I < MI.getNumOperands(); ++I)
3024     MIB.add(MI.getOperand(I));
3025 
3026   MI.eraseFromParent(); // The pseudo is gone now.
3027   return BB;
3028 }
3029 
3030 MachineBasicBlock *
3031 AArch64TargetLowering::EmitZero(MachineInstr &MI, MachineBasicBlock *BB) const {
3032   const TargetInstrInfo *TII = Subtarget->getInstrInfo();
3033   MachineInstrBuilder MIB =
3034       BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::ZERO_M));
3035   MIB.add(MI.getOperand(0)); // Mask
3036 
3037   unsigned Mask = MI.getOperand(0).getImm();
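  // Each set bit in the mask selects one of the eight 64-bit tiles ZAD0-ZAD7;
  // mark every selected tile as implicitly defined by the zeroing instruction.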
3038   for (unsigned I = 0; I < 8; I++) {
3039     if (Mask & (1 << I))
3040       MIB.addDef(AArch64::ZAD0 + I, RegState::ImplicitDefine);
3041   }
3042 
3043   MI.eraseFromParent(); // The pseudo is gone now.
3044   return BB;
3045 }
3046 
3047 MachineBasicBlock *
3048 AArch64TargetLowering::EmitInitTPIDR2Object(MachineInstr &MI,
3049                                             MachineBasicBlock *BB) const {
3050   MachineFunction *MF = BB->getParent();
3051   MachineFrameInfo &MFI = MF->getFrameInfo();
3052   AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
3053   TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
3054   if (TPIDR2.Uses > 0) {
3055     const TargetInstrInfo *TII = Subtarget->getInstrInfo();
3056     // Store the buffer pointer to the TPIDR2 stack object.
3057     BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STRXui))
3058         .addReg(MI.getOperand(0).getReg())
3059         .addFrameIndex(TPIDR2.FrameIndex)
3060         .addImm(0);
3061     // Set the reserved bytes (10-15) to zero
3062     BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STRHHui))
3063         .addReg(AArch64::WZR)
3064         .addFrameIndex(TPIDR2.FrameIndex)
3065         .addImm(5);
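    // The halfword store above clears bytes 10-11 (scaled offset 5); the word
    // store below clears bytes 12-15 (scaled offset 3).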
3066     BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STRWui))
3067         .addReg(AArch64::WZR)
3068         .addFrameIndex(TPIDR2.FrameIndex)
3069         .addImm(3);
3070   } else
3071     MFI.RemoveStackObject(TPIDR2.FrameIndex);
3072 
3073   BB->remove_instr(&MI);
3074   return BB;
3075 }
3076 
3077 MachineBasicBlock *
3078 AArch64TargetLowering::EmitAllocateZABuffer(MachineInstr &MI,
3079                                             MachineBasicBlock *BB) const {
3080   MachineFunction *MF = BB->getParent();
3081   MachineFrameInfo &MFI = MF->getFrameInfo();
3082   AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
3083   // TODO This function grows the stack with a subtraction, which doesn't work
3084   // on Windows. Some refactoring to share the functionality in
3085   // LowerWindowsDYNAMIC_STACKALLOC will be required once the Windows ABI
3086   // supports SME
3087   assert(!MF->getSubtarget<AArch64Subtarget>().isTargetWindows() &&
3088          "Lazy ZA save is not yet supported on Windows");
3089 
3090   TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
3091 
3092   if (TPIDR2.Uses > 0) {
3093     const TargetInstrInfo *TII = Subtarget->getInstrInfo();
3094     MachineRegisterInfo &MRI = MF->getRegInfo();
3095 
3096     // The MSUBXrrr below won't always be emitted in a form that accepts SP
3097     // directly, so copy SP into a plain GPR64 virtual register first.
3098     Register SP = MRI.createVirtualRegister(&AArch64::GPR64RegClass);
3099     BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), SP)
3100         .addReg(AArch64::SP);
3101 
3102     // Allocate a lazy-save buffer object of the size given, normally SVL * SVL
3103     auto Size = MI.getOperand(1).getReg();
3104     auto Dest = MI.getOperand(0).getReg();
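    // MSUB computes Dest = SP - Size * Size, i.e. the stack pointer moved down
    // by the size of the lazy-save buffer.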
3105     BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::MSUBXrrr), Dest)
3106         .addReg(Size)
3107         .addReg(Size)
3108         .addReg(SP);
3109     BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
3110             AArch64::SP)
3111         .addReg(Dest);
3112 
3113     // We have just allocated a variable sized object, tell this to PEI.
3114     MFI.CreateVariableSizedObject(Align(16), nullptr);
3115   }
3116 
3117   BB->remove_instr(&MI);
3118   return BB;
3119 }
3120 
3121 MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
3122     MachineInstr &MI, MachineBasicBlock *BB) const {
3123 
3124   int SMEOrigInstr = AArch64::getSMEPseudoMap(MI.getOpcode());
3125   if (SMEOrigInstr != -1) {
3126     const TargetInstrInfo *TII = Subtarget->getInstrInfo();
3127     uint64_t SMEMatrixType =
3128         TII->get(MI.getOpcode()).TSFlags & AArch64::SMEMatrixTypeMask;
3129     switch (SMEMatrixType) {
3130     case (AArch64::SMEMatrixArray):
3131       return EmitZAInstr(SMEOrigInstr, AArch64::ZA, MI, BB);
3132     case (AArch64::SMEMatrixTileB):
3133       return EmitZAInstr(SMEOrigInstr, AArch64::ZAB0, MI, BB);
3134     case (AArch64::SMEMatrixTileH):
3135       return EmitZAInstr(SMEOrigInstr, AArch64::ZAH0, MI, BB);
3136     case (AArch64::SMEMatrixTileS):
3137       return EmitZAInstr(SMEOrigInstr, AArch64::ZAS0, MI, BB);
3138     case (AArch64::SMEMatrixTileD):
3139       return EmitZAInstr(SMEOrigInstr, AArch64::ZAD0, MI, BB);
3140     case (AArch64::SMEMatrixTileQ):
3141       return EmitZAInstr(SMEOrigInstr, AArch64::ZAQ0, MI, BB);
3142     }
3143   }
3144 
3145   switch (MI.getOpcode()) {
3146   default:
3147 #ifndef NDEBUG
3148     MI.dump();
3149 #endif
3150     llvm_unreachable("Unexpected instruction for custom inserter!");
3151   case AArch64::InitTPIDR2Obj:
3152     return EmitInitTPIDR2Object(MI, BB);
3153   case AArch64::AllocateZABuffer:
3154     return EmitAllocateZABuffer(MI, BB);
3155   case AArch64::F128CSEL:
3156     return EmitF128CSEL(MI, BB);
3157   case TargetOpcode::STATEPOINT:
3158     // STATEPOINT is a pseudo instruction which has no implicit defs/uses,
3159     // while the BL call instruction (which the statepoint is lowered to at the
3160     // end) has an implicit def of LR. This def is early-clobber as it is set at
3161     // the moment of the call, earlier than any use is read.
3162     // Add this implicit dead def here as a workaround.
3163     MI.addOperand(*MI.getMF(),
3164                   MachineOperand::CreateReg(
3165                       AArch64::LR, /*isDef*/ true,
3166                       /*isImp*/ true, /*isKill*/ false, /*isDead*/ true,
3167                       /*isUndef*/ false, /*isEarlyClobber*/ true));
3168     [[fallthrough]];
3169   case TargetOpcode::STACKMAP:
3170   case TargetOpcode::PATCHPOINT:
3171     return emitPatchPoint(MI, BB);
3172 
3173   case TargetOpcode::PATCHABLE_EVENT_CALL:
3174   case TargetOpcode::PATCHABLE_TYPED_EVENT_CALL:
3175     return BB;
3176 
3177   case AArch64::CATCHRET:
3178     return EmitLoweredCatchRet(MI, BB);
3179 
3180   case AArch64::PROBED_STACKALLOC_DYN:
3181     return EmitDynamicProbedAlloc(MI, BB);
3182 
3183   case AArch64::LD1_MXIPXX_H_PSEUDO_B:
3184     return EmitTileLoad(AArch64::LD1_MXIPXX_H_B, AArch64::ZAB0, MI, BB);
3185   case AArch64::LD1_MXIPXX_H_PSEUDO_H:
3186     return EmitTileLoad(AArch64::LD1_MXIPXX_H_H, AArch64::ZAH0, MI, BB);
3187   case AArch64::LD1_MXIPXX_H_PSEUDO_S:
3188     return EmitTileLoad(AArch64::LD1_MXIPXX_H_S, AArch64::ZAS0, MI, BB);
3189   case AArch64::LD1_MXIPXX_H_PSEUDO_D:
3190     return EmitTileLoad(AArch64::LD1_MXIPXX_H_D, AArch64::ZAD0, MI, BB);
3191   case AArch64::LD1_MXIPXX_H_PSEUDO_Q:
3192     return EmitTileLoad(AArch64::LD1_MXIPXX_H_Q, AArch64::ZAQ0, MI, BB);
3193   case AArch64::LD1_MXIPXX_V_PSEUDO_B:
3194     return EmitTileLoad(AArch64::LD1_MXIPXX_V_B, AArch64::ZAB0, MI, BB);
3195   case AArch64::LD1_MXIPXX_V_PSEUDO_H:
3196     return EmitTileLoad(AArch64::LD1_MXIPXX_V_H, AArch64::ZAH0, MI, BB);
3197   case AArch64::LD1_MXIPXX_V_PSEUDO_S:
3198     return EmitTileLoad(AArch64::LD1_MXIPXX_V_S, AArch64::ZAS0, MI, BB);
3199   case AArch64::LD1_MXIPXX_V_PSEUDO_D:
3200     return EmitTileLoad(AArch64::LD1_MXIPXX_V_D, AArch64::ZAD0, MI, BB);
3201   case AArch64::LD1_MXIPXX_V_PSEUDO_Q:
3202     return EmitTileLoad(AArch64::LD1_MXIPXX_V_Q, AArch64::ZAQ0, MI, BB);
3203   case AArch64::LDR_ZA_PSEUDO:
3204     return EmitFill(MI, BB);
3205   case AArch64::LDR_TX_PSEUDO:
3206     return EmitZTInstr(MI, BB, AArch64::LDR_TX, /*Op0IsDef=*/true);
3207   case AArch64::STR_TX_PSEUDO:
3208     return EmitZTInstr(MI, BB, AArch64::STR_TX, /*Op0IsDef=*/false);
3209   case AArch64::ZERO_M_PSEUDO:
3210     return EmitZero(MI, BB);
3211   case AArch64::ZERO_T_PSEUDO:
3212     return EmitZTInstr(MI, BB, AArch64::ZERO_T, /*Op0IsDef=*/true);
3213   }
3214 }
3215 
3216 //===----------------------------------------------------------------------===//
3217 // AArch64 Lowering private implementation.
3218 //===----------------------------------------------------------------------===//
3219 
3220 //===----------------------------------------------------------------------===//
3221 // Lowering Code
3222 //===----------------------------------------------------------------------===//
3223 
3224 // Forward declarations of SVE fixed length lowering helpers
3225 static EVT getContainerForFixedLengthVector(SelectionDAG &DAG, EVT VT);
3226 static SDValue convertToScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
3227 static SDValue convertFromScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
3228 static SDValue convertFixedMaskToScalableVector(SDValue Mask,
3229                                                 SelectionDAG &DAG);
3230 static SDValue getPredicateForVector(SelectionDAG &DAG, SDLoc &DL, EVT VT);
3231 static SDValue getPredicateForScalableVector(SelectionDAG &DAG, SDLoc &DL,
3232                                              EVT VT);
3233 
3234 /// isZerosVector - Check whether SDNode N is a zero-filled vector.
3235 static bool isZerosVector(const SDNode *N) {
3236   // Look through a bit convert.
3237   while (N->getOpcode() == ISD::BITCAST)
3238     N = N->getOperand(0).getNode();
3239 
3240   if (ISD::isConstantSplatVectorAllZeros(N))
3241     return true;
3242 
3243   if (N->getOpcode() != AArch64ISD::DUP)
3244     return false;
3245 
3246   auto Opnd0 = N->getOperand(0);
3247   return isNullConstant(Opnd0) || isNullFPConstant(Opnd0);
3248 }
3249 
3250 /// changeIntCCToAArch64CC - Convert a DAG integer condition code to an AArch64
3251 /// CC
3252 static AArch64CC::CondCode changeIntCCToAArch64CC(ISD::CondCode CC) {
3253   switch (CC) {
3254   default:
3255     llvm_unreachable("Unknown condition code!");
3256   case ISD::SETNE:
3257     return AArch64CC::NE;
3258   case ISD::SETEQ:
3259     return AArch64CC::EQ;
3260   case ISD::SETGT:
3261     return AArch64CC::GT;
3262   case ISD::SETGE:
3263     return AArch64CC::GE;
3264   case ISD::SETLT:
3265     return AArch64CC::LT;
3266   case ISD::SETLE:
3267     return AArch64CC::LE;
3268   case ISD::SETUGT:
3269     return AArch64CC::HI;
3270   case ISD::SETUGE:
3271     return AArch64CC::HS;
3272   case ISD::SETULT:
3273     return AArch64CC::LO;
3274   case ISD::SETULE:
3275     return AArch64CC::LS;
3276   }
3277 }
3278 
3279 /// changeFPCCToAArch64CC - Convert a DAG fp condition code to an AArch64 CC.
3280 static void changeFPCCToAArch64CC(ISD::CondCode CC,
3281                                   AArch64CC::CondCode &CondCode,
3282                                   AArch64CC::CondCode &CondCode2) {
3283   CondCode2 = AArch64CC::AL;
3284   switch (CC) {
3285   default:
3286     llvm_unreachable("Unknown FP condition!");
3287   case ISD::SETEQ:
3288   case ISD::SETOEQ:
3289     CondCode = AArch64CC::EQ;
3290     break;
3291   case ISD::SETGT:
3292   case ISD::SETOGT:
3293     CondCode = AArch64CC::GT;
3294     break;
3295   case ISD::SETGE:
3296   case ISD::SETOGE:
3297     CondCode = AArch64CC::GE;
3298     break;
3299   case ISD::SETOLT:
3300     CondCode = AArch64CC::MI;
3301     break;
3302   case ISD::SETOLE:
3303     CondCode = AArch64CC::LS;
3304     break;
3305   case ISD::SETONE:
3306     CondCode = AArch64CC::MI;
3307     CondCode2 = AArch64CC::GT;
3308     break;
3309   case ISD::SETO:
3310     CondCode = AArch64CC::VC;
3311     break;
3312   case ISD::SETUO:
3313     CondCode = AArch64CC::VS;
3314     break;
3315   case ISD::SETUEQ:
3316     CondCode = AArch64CC::EQ;
3317     CondCode2 = AArch64CC::VS;
3318     break;
3319   case ISD::SETUGT:
3320     CondCode = AArch64CC::HI;
3321     break;
3322   case ISD::SETUGE:
3323     CondCode = AArch64CC::PL;
3324     break;
3325   case ISD::SETLT:
3326   case ISD::SETULT:
3327     CondCode = AArch64CC::LT;
3328     break;
3329   case ISD::SETLE:
3330   case ISD::SETULE:
3331     CondCode = AArch64CC::LE;
3332     break;
3333   case ISD::SETNE:
3334   case ISD::SETUNE:
3335     CondCode = AArch64CC::NE;
3336     break;
3337   }
3338 }
3339 
3340 /// Convert a DAG fp condition code to an AArch64 CC.
3341 /// This differs from changeFPCCToAArch64CC in that it returns cond codes that
3342 /// should be AND'ed instead of OR'ed.
3343 static void changeFPCCToANDAArch64CC(ISD::CondCode CC,
3344                                      AArch64CC::CondCode &CondCode,
3345                                      AArch64CC::CondCode &CondCode2) {
3346   CondCode2 = AArch64CC::AL;
3347   switch (CC) {
3348   default:
3349     changeFPCCToAArch64CC(CC, CondCode, CondCode2);
3350     assert(CondCode2 == AArch64CC::AL);
3351     break;
3352   case ISD::SETONE:
3353     // (a one b)
3354     // == ((a olt b) || (a ogt b))
3355     // == ((a ord b) && (a une b))
3356     CondCode = AArch64CC::VC;
3357     CondCode2 = AArch64CC::NE;
3358     break;
3359   case ISD::SETUEQ:
3360     // (a ueq b)
3361     // == ((a uno b) || (a oeq b))
3362     // == ((a ule b) && (a uge b))
3363     CondCode = AArch64CC::PL;
3364     CondCode2 = AArch64CC::LE;
3365     break;
3366   }
3367 }
3368 
3369 /// changeVectorFPCCToAArch64CC - Convert a DAG fp condition code to an AArch64
3370 /// CC usable with the vector instructions. Fewer operations are available
3371 /// without a real NZCV register, so we have to use less efficient combinations
3372 /// to get the same effect.
3373 static void changeVectorFPCCToAArch64CC(ISD::CondCode CC,
3374                                         AArch64CC::CondCode &CondCode,
3375                                         AArch64CC::CondCode &CondCode2,
3376                                         bool &Invert) {
3377   Invert = false;
3378   switch (CC) {
3379   default:
3380     // Mostly the scalar mappings work fine.
3381     changeFPCCToAArch64CC(CC, CondCode, CondCode2);
3382     break;
3383   case ISD::SETUO:
3384     Invert = true;
3385     [[fallthrough]];
3386   case ISD::SETO:
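    // Ordered(a, b) == (a < b) || (a >= b): exactly one of the two holds for
    // non-NaN operands, and both compares are false if either operand is NaN.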
3387     CondCode = AArch64CC::MI;
3388     CondCode2 = AArch64CC::GE;
3389     break;
3390   case ISD::SETUEQ:
3391   case ISD::SETULT:
3392   case ISD::SETULE:
3393   case ISD::SETUGT:
3394   case ISD::SETUGE:
3395     // All of the compare-mask comparisons are ordered, but we can switch
3396     // between the two by a double inversion. E.g. ULE == !OGT.
3397     Invert = true;
3398     changeFPCCToAArch64CC(getSetCCInverse(CC, /* FP inverse */ MVT::f32),
3399                           CondCode, CondCode2);
3400     break;
3401   }
3402 }
3403 
3404 static bool isLegalArithImmed(uint64_t C) {
3405   // Matches AArch64DAGToDAGISel::SelectArithImmed().
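  // A legal arithmetic immediate is an unsigned 12-bit value, optionally
  // shifted left by 12 bits.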
3406   bool IsLegal = (C >> 12 == 0) || ((C & 0xFFFULL) == 0 && C >> 24 == 0);
3407   LLVM_DEBUG(dbgs() << "Is imm " << C
3408                     << " legal: " << (IsLegal ? "yes\n" : "no\n"));
3409   return IsLegal;
3410 }
3411 
3412 static bool cannotBeIntMin(SDValue CheckedVal, SelectionDAG &DAG) {
3413   KnownBits KnownSrc = DAG.computeKnownBits(CheckedVal);
3414   return !KnownSrc.getSignedMinValue().isMinSignedValue();
3415 }
3416 
3417 // Can a (CMP op1, (sub 0, op2)) be turned into a CMN instruction on
3418 // the grounds that "op1 - (-op2) == op1 + op2"? Not always: the C and V flags
3419 // can be set differently by this operation. It comes down to whether
3420 // "SInt(~op2)+1 == SInt(~op2+1)" (and the same for UInt). If they are equal
3421 // then everything is fine. If not then the optimization is wrong. Thus general
3422 // comparisons are only valid if op2 != 0.
3423 //
3424 // So, finally, the only LLVM-native comparisons that don't mention C or V
3425 // are the ones that aren't unsigned comparisons. They're the only ones we can
3426 // safely use CMN for in the absence of information about op2.
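// For example, with op2 == 0: "cmp op1, #0" never borrows and so sets C, while
// "cmn op1, #0" never carries and so leaves C clear, meaning unsigned
// conditions would observe different flags.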
3427 static bool isCMN(SDValue Op, ISD::CondCode CC, SelectionDAG &DAG) {
3428   return Op.getOpcode() == ISD::SUB && isNullConstant(Op.getOperand(0)) &&
3429          (isIntEqualitySetCC(CC) ||
3430           (isUnsignedIntSetCC(CC) && DAG.isKnownNeverZero(Op.getOperand(1))) ||
3431           (isSignedIntSetCC(CC) && cannotBeIntMin(Op.getOperand(1), DAG)));
3432 }
3433 
3434 static SDValue emitStrictFPComparison(SDValue LHS, SDValue RHS, const SDLoc &dl,
3435                                       SelectionDAG &DAG, SDValue Chain,
3436                                       bool IsSignaling) {
3437   EVT VT = LHS.getValueType();
3438   assert(VT != MVT::f128);
3439 
3440   const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();
3441 
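  // Without full FP16 support (and always for bf16) there is no native compare
  // at this width, so promote both operands to f32 and compare there, threading
  // the strict-FP chain through both extends.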
3442   if ((VT == MVT::f16 && !FullFP16) || VT == MVT::bf16) {
3443     LHS = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {MVT::f32, MVT::Other},
3444                       {Chain, LHS});
3445     RHS = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {MVT::f32, MVT::Other},
3446                       {LHS.getValue(1), RHS});
3447     Chain = RHS.getValue(1);
3448     VT = MVT::f32;
3449   }
3450   unsigned Opcode =
3451       IsSignaling ? AArch64ISD::STRICT_FCMPE : AArch64ISD::STRICT_FCMP;
3452   return DAG.getNode(Opcode, dl, {VT, MVT::Other}, {Chain, LHS, RHS});
3453 }
3454 
3455 static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
3456                               const SDLoc &dl, SelectionDAG &DAG) {
3457   EVT VT = LHS.getValueType();
3458   const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();
3459 
3460   if (VT.isFloatingPoint()) {
3461     assert(VT != MVT::f128);
3462     if ((VT == MVT::f16 && !FullFP16) || VT == MVT::bf16) {
3463       LHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, LHS);
3464       RHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, RHS);
3465       VT = MVT::f32;
3466     }
3467     return DAG.getNode(AArch64ISD::FCMP, dl, VT, LHS, RHS);
3468   }
3469 
3470   // The CMP instruction is just an alias for SUBS, and representing it as
3471   // SUBS means that it's possible to get CSE with subtract operations.
3472   // A later phase can perform the optimization of setting the destination
3473   // register to WZR/XZR if it ends up being unused.
3474   unsigned Opcode = AArch64ISD::SUBS;
3475 
3476   if (isCMN(RHS, CC, DAG)) {
3477     // Can we combine a (CMP op1, (sub 0, op2)) into a CMN instruction?
3478     Opcode = AArch64ISD::ADDS;
3479     RHS = RHS.getOperand(1);
3480   } else if (LHS.getOpcode() == ISD::SUB && isNullConstant(LHS.getOperand(0)) &&
3481              isIntEqualitySetCC(CC)) {
3482     // As we are looking for EQ/NE compares, the operands can be commuted; can
3483     // we combine a (CMP (sub 0, op1), op2) into a CMN instruction?
3484     Opcode = AArch64ISD::ADDS;
3485     LHS = LHS.getOperand(1);
3486   } else if (isNullConstant(RHS) && !isUnsignedIntSetCC(CC)) {
3487     if (LHS.getOpcode() == ISD::AND) {
3488       // Similarly, (CMP (and X, Y), 0) can be implemented with a TST
3489       // (a.k.a. ANDS) except that the flags are only guaranteed to work for one
3490       // of the signed comparisons.
3491       const SDValue ANDSNode = DAG.getNode(AArch64ISD::ANDS, dl,
3492                                            DAG.getVTList(VT, MVT_CC),
3493                                            LHS.getOperand(0),
3494                                            LHS.getOperand(1));
3495       // Replace all users of (and X, Y) with newly generated (ands X, Y)
3496       DAG.ReplaceAllUsesWith(LHS, ANDSNode);
3497       return ANDSNode.getValue(1);
3498     } else if (LHS.getOpcode() == AArch64ISD::ANDS) {
3499       // Use result of ANDS
3500       return LHS.getValue(1);
3501     }
3502   }
3503 
3504   return DAG.getNode(Opcode, dl, DAG.getVTList(VT, MVT_CC), LHS, RHS)
3505       .getValue(1);
3506 }
3507 
3508 /// \defgroup AArch64CCMP CMP;CCMP matching
3509 ///
3510 /// These functions deal with the formation of CMP;CCMP;... sequences.
3511 /// The CCMP/CCMN/FCCMP/FCCMPE instructions allow the conditional execution of
3512 /// a comparison. They set the NZCV flags to a predefined value if their
3513 /// predicate is false. This allows us to express arbitrary conjunctions, for
3514 /// example "cmp 0 (and (setCA (cmp A)) (setCB (cmp B)))"
3515 /// expressed as:
3516 ///   cmp A
3517 ///   ccmp B, inv(CB), CA
3518 ///   check for CB flags
3519 ///
3520 /// This naturally lets us implement chains of AND operations with SETCC
3521 /// operands. And we can even implement some other situations by transforming
3522 /// them:
3523 ///   - We can implement (NEG SETCC) i.e. negating a single comparison by
3524 ///     negating the flags used in a CCMP/FCCMP operation.
3525 ///   - We can negate the result of a whole chain of CMP/CCMP/FCCMP operations
3526 ///     by negating the flags we test for afterwards. i.e.
3527 ///     NEG (CMP CCMP CCMP ...) can be implemented.
3528 ///   - Note that we can only ever negate all previously processed results.
3529 ///     What we can not implement by flipping the flags to test is a negation
3530 ///     of two sub-trees (because the negation affects all sub-trees emitted so
3531 ///     far, so the 2nd sub-tree we emit would also affect the first).
3532 /// With those tools we can implement some OR operations:
3533 ///   - (OR (SETCC A) (SETCC B)) can be implemented via:
3534 ///     NEG (AND (NEG (SETCC A)) (NEG (SETCC B)))
3535 ///   - After transforming OR to NEG/AND combinations we may be able to use NEG
3536 ///     elimination rules from earlier to implement the whole thing as a
3537 ///     CCMP/FCCMP chain.
3538 ///
3539 /// As a complete example:
3540 ///     or (or (setCA (cmp A)) (setCB (cmp B)))
3541 ///        (and (setCC (cmp C)) (setCD (cmp D)))
3542 /// can be reassociated to:
3543 ///     or (and (setCC (cmp C)) (setCD (cmp D)))
3544 ///        (or (setCA (cmp A)) (setCB (cmp B)))
3545 /// can be transformed to:
3546 ///     not (and (not (and (setCC (cmp C)) (setCD (cmp D))))
3547 ///              (and (not (setCA (cmp A))) (not (setCB (cmp B))))))
3548 /// which can be implemented as:
3549 ///   cmp C
3550 ///   ccmp D, inv(CD), CC
3551 ///   ccmp A, CA, inv(CD)
3552 ///   ccmp B, CB, inv(CA)
3553 ///   check for CB flags
3554 ///
3555 /// A counterexample is "or (and A B) (and C D)" which translates to
3556 /// not (and (not (and (not A) (not B))) (not (and (not C) (not D)))), we
3557 /// can only implement 1 of the inner (not) operations, but not both!
3558 /// @{
3559 
3560 /// Create a conditional comparison; Use CCMP, CCMN or FCCMP as appropriate.
3561 static SDValue emitConditionalComparison(SDValue LHS, SDValue RHS,
3562                                          ISD::CondCode CC, SDValue CCOp,
3563                                          AArch64CC::CondCode Predicate,
3564                                          AArch64CC::CondCode OutCC,
3565                                          const SDLoc &DL, SelectionDAG &DAG) {
3566   unsigned Opcode = 0;
3567   const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();
3568 
3569   if (LHS.getValueType().isFloatingPoint()) {
3570     assert(LHS.getValueType() != MVT::f128);
3571     if ((LHS.getValueType() == MVT::f16 && !FullFP16) ||
3572         LHS.getValueType() == MVT::bf16) {
3573       LHS = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, LHS);
3574       RHS = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, RHS);
3575     }
3576     Opcode = AArch64ISD::FCCMP;
3577   } else if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(RHS)) {
3578     APInt Imm = Const->getAPIntValue();
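    // A small negative RHS in [-31, -1] can be handled by negating it and
    // using CCMN instead, which takes a 5-bit unsigned immediate.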
3579     if (Imm.isNegative() && Imm.sgt(-32)) {
3580       Opcode = AArch64ISD::CCMN;
3581       RHS = DAG.getConstant(Imm.abs(), DL, Const->getValueType(0));
3582     }
3583   } else if (isCMN(RHS, CC, DAG)) {
3584     Opcode = AArch64ISD::CCMN;
3585     RHS = RHS.getOperand(1);
3586   } else if (LHS.getOpcode() == ISD::SUB && isNullConstant(LHS.getOperand(0)) &&
3587              isIntEqualitySetCC(CC)) {
3588     // As we are looking for EQ/NE compares, the operands can be commuted; can
3589     // we combine a (CCMP (sub 0, op1), op2) into a CCMN instruction?
3590     Opcode = AArch64ISD::CCMN;
3591     LHS = LHS.getOperand(1);
3592   }
3593   if (Opcode == 0)
3594     Opcode = AArch64ISD::CCMP;
3595 
3596   SDValue Condition = DAG.getConstant(Predicate, DL, MVT_CC);
3597   AArch64CC::CondCode InvOutCC = AArch64CC::getInvertedCondCode(OutCC);
3598   unsigned NZCV = AArch64CC::getNZCVToSatisfyCondCode(InvOutCC);
3599   SDValue NZCVOp = DAG.getConstant(NZCV, DL, MVT::i32);
3600   return DAG.getNode(Opcode, DL, MVT_CC, LHS, RHS, NZCVOp, Condition, CCOp);
3601 }
3602 
3603 /// Returns true if @p Val is a tree of AND/OR/SETCC operations that can be
3604 /// expressed as a conjunction. See \ref AArch64CCMP.
3605 /// \param CanNegate    Set to true if we can negate the whole sub-tree just by
3606 ///                     changing the conditions on the SETCC tests.
3607 ///                     (this means we can call emitConjunctionRec() with
3608 ///                      Negate==true on this sub-tree)
3609 /// \param MustBeFirst  Set to true if this subtree needs to be negated and we
3610 ///                     cannot do the negation naturally. We are required to
3611 ///                     emit the subtree first in this case.
3612 /// \param WillNegate   Is true if we are called when the result of this
3613 ///                     subexpression must be negated. This happens when the
3614 ///                     outer expression is an OR. We can use this fact to know
3615 ///                     that we have a double negation (or (or ...) ...) that
3616 ///                     can be implemented for free.
3617 static bool canEmitConjunction(const SDValue Val, bool &CanNegate,
3618                                bool &MustBeFirst, bool WillNegate,
3619                                unsigned Depth = 0) {
3620   if (!Val.hasOneUse())
3621     return false;
3622   unsigned Opcode = Val->getOpcode();
3623   if (Opcode == ISD::SETCC) {
3624     if (Val->getOperand(0).getValueType() == MVT::f128)
3625       return false;
3626     CanNegate = true;
3627     MustBeFirst = false;
3628     return true;
3629   }
3630   // Protect against exponential runtime and stack overflow.
3631   if (Depth > 6)
3632     return false;
3633   if (Opcode == ISD::AND || Opcode == ISD::OR) {
3634     bool IsOR = Opcode == ISD::OR;
3635     SDValue O0 = Val->getOperand(0);
3636     SDValue O1 = Val->getOperand(1);
3637     bool CanNegateL;
3638     bool MustBeFirstL;
3639     if (!canEmitConjunction(O0, CanNegateL, MustBeFirstL, IsOR, Depth+1))
3640       return false;
3641     bool CanNegateR;
3642     bool MustBeFirstR;
3643     if (!canEmitConjunction(O1, CanNegateR, MustBeFirstR, IsOR, Depth+1))
3644       return false;
3645 
3646     if (MustBeFirstL && MustBeFirstR)
3647       return false;
3648 
3649     if (IsOR) {
3650       // For an OR expression we need to be able to naturally negate at least
3651       // one side or we cannot do the transformation at all.
3652       if (!CanNegateL && !CanNegateR)
3653         return false;
3654       // If the result of the OR will be negated and we can naturally negate
3655       // the leaves, then this sub-tree as a whole negates naturally.
3656       CanNegate = WillNegate && CanNegateL && CanNegateR;
3657       // If we cannot naturally negate the whole sub-tree, then this must be
3658       // emitted first.
3659       MustBeFirst = !CanNegate;
3660     } else {
3661       assert(Opcode == ISD::AND && "Must be OR or AND");
3662       // We cannot naturally negate an AND operation.
3663       CanNegate = false;
3664       MustBeFirst = MustBeFirstL || MustBeFirstR;
3665     }
3666     return true;
3667   }
3668   return false;
3669 }
3670 
3671 /// Emit conjunction or disjunction tree with the CMP/FCMP followed by a chain
3672 /// of CCMP/CFCMP ops. See @ref AArch64CCMP.
3673 /// Tries to transform the given i1-producing node @p Val into a series of
3674 /// compare and conditional compare operations. @returns an NZCV-flags-producing
3675 /// node and sets @p OutCC to the flags that should be tested, or returns
3676 /// SDValue() if the transformation was not possible.
3677 /// \p Negate is true if we want this sub-tree to be negated just by changing
3678 /// SETCC conditions.
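///
/// As a rough sketch, (and (setcc a, b, lt), (setcc c, d, gt)) is emitted with
/// the right leaf as a plain comparison and the left leaf as a conditional
/// comparison predicated on it:
///   cmp  c, d
///   ccmp a, b, #<nzcv>, gt
/// and OutCC is set to LT as the condition the caller should test.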
3679 static SDValue emitConjunctionRec(SelectionDAG &DAG, SDValue Val,
3680     AArch64CC::CondCode &OutCC, bool Negate, SDValue CCOp,
3681     AArch64CC::CondCode Predicate) {
3682   // We're at a tree leaf, produce a conditional comparison operation.
3683   unsigned Opcode = Val->getOpcode();
3684   if (Opcode == ISD::SETCC) {
3685     SDValue LHS = Val->getOperand(0);
3686     SDValue RHS = Val->getOperand(1);
3687     ISD::CondCode CC = cast<CondCodeSDNode>(Val->getOperand(2))->get();
3688     bool isInteger = LHS.getValueType().isInteger();
3689     if (Negate)
3690       CC = getSetCCInverse(CC, LHS.getValueType());
3691     SDLoc DL(Val);
3692     // Determine OutCC and handle FP special case.
3693     if (isInteger) {
3694       OutCC = changeIntCCToAArch64CC(CC);
3695     } else {
3696       assert(LHS.getValueType().isFloatingPoint());
3697       AArch64CC::CondCode ExtraCC;
3698       changeFPCCToANDAArch64CC(CC, OutCC, ExtraCC);
3699       // Some floating point conditions can't be tested with a single condition
3700       // code. Construct an additional comparison in this case.
3701       if (ExtraCC != AArch64CC::AL) {
3702         SDValue ExtraCmp;
3703         if (!CCOp.getNode())
3704           ExtraCmp = emitComparison(LHS, RHS, CC, DL, DAG);
3705         else
3706           ExtraCmp = emitConditionalComparison(LHS, RHS, CC, CCOp, Predicate,
3707                                                ExtraCC, DL, DAG);
3708         CCOp = ExtraCmp;
3709         Predicate = ExtraCC;
3710       }
3711     }
3712 
3713     // Produce a normal comparison if we are first in the chain
3714     if (!CCOp)
3715       return emitComparison(LHS, RHS, CC, DL, DAG);
3716     // Otherwise produce a ccmp.
3717     return emitConditionalComparison(LHS, RHS, CC, CCOp, Predicate, OutCC, DL,
3718                                      DAG);
3719   }
3720   assert(Val->hasOneUse() && "Valid conjunction/disjunction tree");
3721 
3722   bool IsOR = Opcode == ISD::OR;
3723 
3724   SDValue LHS = Val->getOperand(0);
3725   bool CanNegateL;
3726   bool MustBeFirstL;
3727   bool ValidL = canEmitConjunction(LHS, CanNegateL, MustBeFirstL, IsOR);
3728   assert(ValidL && "Valid conjunction/disjunction tree");
3729   (void)ValidL;
3730 
3731   SDValue RHS = Val->getOperand(1);
3732   bool CanNegateR;
3733   bool MustBeFirstR;
3734   bool ValidR = canEmitConjunction(RHS, CanNegateR, MustBeFirstR, IsOR);
3735   assert(ValidR && "Valid conjunction/disjunction tree");
3736   (void)ValidR;
3737 
3738   // Swap sub-tree that must come first to the right side.
3739   if (MustBeFirstL) {
3740     assert(!MustBeFirstR && "Valid conjunction/disjunction tree");
3741     std::swap(LHS, RHS);
3742     std::swap(CanNegateL, CanNegateR);
3743     std::swap(MustBeFirstL, MustBeFirstR);
3744   }
3745 
3746   bool NegateR;
3747   bool NegateAfterR;
3748   bool NegateL;
3749   bool NegateAfterAll;
3750   if (Opcode == ISD::OR) {
3751     // Swap the sub-tree that we can negate naturally to the left.
3752     if (!CanNegateL) {
3753       assert(CanNegateR && "at least one side must be negatable");
3754       assert(!MustBeFirstR && "invalid conjunction/disjunction tree");
3755       assert(!Negate);
3756       std::swap(LHS, RHS);
3757       NegateR = false;
3758       NegateAfterR = true;
3759     } else {
3760       // Negate the right sub-tree if possible, otherwise negate its result afterwards.
3761       NegateR = CanNegateR;
3762       NegateAfterR = !CanNegateR;
3763     }
3764     NegateL = true;
3765     NegateAfterAll = !Negate;
3766   } else {
3767     assert(Opcode == ISD::AND && "Valid conjunction/disjunction tree");
3768     assert(!Negate && "Valid conjunction/disjunction tree");
3769 
3770     NegateL = false;
3771     NegateR = false;
3772     NegateAfterR = false;
3773     NegateAfterAll = false;
3774   }
3775 
3776   // Emit sub-trees.
3777   AArch64CC::CondCode RHSCC;
3778   SDValue CmpR = emitConjunctionRec(DAG, RHS, RHSCC, NegateR, CCOp, Predicate);
3779   if (NegateAfterR)
3780     RHSCC = AArch64CC::getInvertedCondCode(RHSCC);
3781   SDValue CmpL = emitConjunctionRec(DAG, LHS, OutCC, NegateL, CmpR, RHSCC);
3782   if (NegateAfterAll)
3783     OutCC = AArch64CC::getInvertedCondCode(OutCC);
3784   return CmpL;
3785 }
3786 
3787 /// Emit expression as a conjunction (a series of CCMP/CFCMP ops).
3788 /// In some cases this is even possible with OR operations in the expression.
3789 /// See \ref AArch64CCMP.
3790 /// \see emitConjunctionRec().
3791 static SDValue emitConjunction(SelectionDAG &DAG, SDValue Val,
3792                                AArch64CC::CondCode &OutCC) {
3793   bool DummyCanNegate;
3794   bool DummyMustBeFirst;
3795   if (!canEmitConjunction(Val, DummyCanNegate, DummyMustBeFirst, false))
3796     return SDValue();
3797 
3798   return emitConjunctionRec(DAG, Val, OutCC, false, SDValue(), AArch64CC::AL);
3799 }
3800 
3801 /// @}
3802 
3803 /// Returns how profitable it is to fold a comparison's operand's shift and/or
3804 /// extension operations.
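/// For example, an operand such as (shl (and x, 0xff), 2) can be folded into
/// the compare itself (roughly "cmp w1, w0, uxtb #2") and scores 2, while a
/// bare constant shift only scores 1.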
3805 static unsigned getCmpOperandFoldingProfit(SDValue Op) {
3806   auto isSupportedExtend = [&](SDValue V) {
3807     if (V.getOpcode() == ISD::SIGN_EXTEND_INREG)
3808       return true;
3809 
3810     if (V.getOpcode() == ISD::AND)
3811       if (ConstantSDNode *MaskCst = dyn_cast<ConstantSDNode>(V.getOperand(1))) {
3812         uint64_t Mask = MaskCst->getZExtValue();
3813         return (Mask == 0xFF || Mask == 0xFFFF || Mask == 0xFFFFFFFF);
3814       }
3815 
3816     return false;
3817   };
3818 
3819   if (!Op.hasOneUse())
3820     return 0;
3821 
3822   if (isSupportedExtend(Op))
3823     return 1;
3824 
3825   unsigned Opc = Op.getOpcode();
3826   if (Opc == ISD::SHL || Opc == ISD::SRL || Opc == ISD::SRA)
3827     if (ConstantSDNode *ShiftCst = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
3828       uint64_t Shift = ShiftCst->getZExtValue();
3829       if (isSupportedExtend(Op.getOperand(0)))
3830         return (Shift <= 4) ? 2 : 1;
3831       EVT VT = Op.getValueType();
3832       if ((VT == MVT::i32 && Shift <= 31) || (VT == MVT::i64 && Shift <= 63))
3833         return 1;
3834     }
3835 
3836   return 0;
3837 }
3838 
3839 static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC,
3840                              SDValue &AArch64cc, SelectionDAG &DAG,
3841                              const SDLoc &dl) {
3842   if (ConstantSDNode *RHSC = dyn_cast<ConstantSDNode>(RHS.getNode())) {
3843     EVT VT = RHS.getValueType();
3844     uint64_t C = RHSC->getZExtValue();
3845     if (!isLegalArithImmed(C)) {
3846       // Constant does not fit, try adjusting it by one?
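      // As a sketch: (x s< 0xfff001) cannot be encoded directly, but it is
      // equivalent to (x s<= 0xfff000), and 0xfff000 (0xfff << 12) is a legal
      // arithmetic immediate.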
3847       switch (CC) {
3848       default:
3849         break;
3850       case ISD::SETLT:
3851       case ISD::SETGE:
3852         if ((VT == MVT::i32 && C != 0x80000000 &&
3853              isLegalArithImmed((uint32_t)(C - 1))) ||
3854             (VT == MVT::i64 && C != 0x80000000ULL &&
3855              isLegalArithImmed(C - 1ULL))) {
3856           CC = (CC == ISD::SETLT) ? ISD::SETLE : ISD::SETGT;
3857           C = (VT == MVT::i32) ? (uint32_t)(C - 1) : C - 1;
3858           RHS = DAG.getConstant(C, dl, VT);
3859         }
3860         break;
3861       case ISD::SETULT:
3862       case ISD::SETUGE:
3863         if ((VT == MVT::i32 && C != 0 &&
3864              isLegalArithImmed((uint32_t)(C - 1))) ||
3865             (VT == MVT::i64 && C != 0ULL && isLegalArithImmed(C - 1ULL))) {
3866           CC = (CC == ISD::SETULT) ? ISD::SETULE : ISD::SETUGT;
3867           C = (VT == MVT::i32) ? (uint32_t)(C - 1) : C - 1;
3868           RHS = DAG.getConstant(C, dl, VT);
3869         }
3870         break;
3871       case ISD::SETLE:
3872       case ISD::SETGT:
3873         if ((VT == MVT::i32 && C != INT32_MAX &&
3874              isLegalArithImmed((uint32_t)(C + 1))) ||
3875             (VT == MVT::i64 && C != INT64_MAX &&
3876              isLegalArithImmed(C + 1ULL))) {
3877           CC = (CC == ISD::SETLE) ? ISD::SETLT : ISD::SETGE;
3878           C = (VT == MVT::i32) ? (uint32_t)(C + 1) : C + 1;
3879           RHS = DAG.getConstant(C, dl, VT);
3880         }
3881         break;
3882       case ISD::SETULE:
3883       case ISD::SETUGT:
3884         if ((VT == MVT::i32 && C != UINT32_MAX &&
3885              isLegalArithImmed((uint32_t)(C + 1))) ||
3886             (VT == MVT::i64 && C != UINT64_MAX &&
3887              isLegalArithImmed(C + 1ULL))) {
3888           CC = (CC == ISD::SETULE) ? ISD::SETULT : ISD::SETUGE;
3889           C = (VT == MVT::i32) ? (uint32_t)(C + 1) : C + 1;
3890           RHS = DAG.getConstant(C, dl, VT);
3891         }
3892         break;
3893       }
3894     }
3895   }
3896 
3897   // Comparisons are canonicalized so that the RHS operand is simpler than the
3898   // LHS one, the extreme case being when RHS is an immediate. However, AArch64
3899   // can fold some shift+extend operations on the RHS operand, so swap the
3900   // operands if that can be done.
3901   //
3902   // For example:
3903   //    lsl     w13, w11, #1
3904   //    cmp     w13, w12
3905   // can be turned into:
3906   //    cmp     w12, w11, lsl #1
3907   if (!isa<ConstantSDNode>(RHS) ||
3908       !isLegalArithImmed(RHS->getAsAPIntVal().abs().getZExtValue())) {
3909     bool LHSIsCMN = isCMN(LHS, CC, DAG);
3910     bool RHSIsCMN = isCMN(RHS, CC, DAG);
3911     SDValue TheLHS = LHSIsCMN ? LHS.getOperand(1) : LHS;
3912     SDValue TheRHS = RHSIsCMN ? RHS.getOperand(1) : RHS;
3913 
3914     if (getCmpOperandFoldingProfit(TheLHS) + (LHSIsCMN ? 1 : 0) >
3915         getCmpOperandFoldingProfit(TheRHS) + (RHSIsCMN ? 1 : 0)) {
3916       std::swap(LHS, RHS);
3917       CC = ISD::getSetCCSwappedOperands(CC);
3918     }
3919   }
3920 
3921   SDValue Cmp;
3922   AArch64CC::CondCode AArch64CC;
3923   if (isIntEqualitySetCC(CC) && isa<ConstantSDNode>(RHS)) {
3924     const ConstantSDNode *RHSC = cast<ConstantSDNode>(RHS);
3925 
3926     // The imm operand of ADDS is an unsigned immediate, in the range 0 to 4095.
3927     // For the i8 operand, the largest immediate is 255, so this can be easily
3928     // encoded in the compare instruction. For the i16 operand, however, the
3929     // largest immediate cannot be encoded in the compare.
3930     // Therefore, use a sign extending load and cmn to avoid materializing the
3931     // -1 constant. For example,
3932     // movz w1, #65535
3933     // ldrh w0, [x0, #0]
3934     // cmp w0, w1
3935     // >
3936     // ldrsh w0, [x0, #0]
3937     // cmn w0, #1
3938     // Fundamentally, we're relying on the property that (zext LHS) == (zext RHS)
3939     // if and only if (sext LHS) == (sext RHS). The checks are in place to
3940     // ensure both the LHS and RHS are truly zero extended and to make sure the
3941     // transformation is profitable.
3942     if ((RHSC->getZExtValue() >> 16 == 0) && isa<LoadSDNode>(LHS) &&
3943         cast<LoadSDNode>(LHS)->getExtensionType() == ISD::ZEXTLOAD &&
3944         cast<LoadSDNode>(LHS)->getMemoryVT() == MVT::i16 &&
3945         LHS.getNode()->hasNUsesOfValue(1, 0)) {
3946       int16_t ValueofRHS = RHS->getAsZExtVal();
3947       if (ValueofRHS < 0 && isLegalArithImmed(-ValueofRHS)) {
3948         SDValue SExt =
3949             DAG.getNode(ISD::SIGN_EXTEND_INREG, dl, LHS.getValueType(), LHS,
3950                         DAG.getValueType(MVT::i16));
3951         Cmp = emitComparison(SExt, DAG.getConstant(ValueofRHS, dl,
3952                                                    RHS.getValueType()),
3953                              CC, dl, DAG);
3954         AArch64CC = changeIntCCToAArch64CC(CC);
3955       }
3956     }
3957 
3958     if (!Cmp && (RHSC->isZero() || RHSC->isOne())) {
3959       if ((Cmp = emitConjunction(DAG, LHS, AArch64CC))) {
3960         if ((CC == ISD::SETNE) ^ RHSC->isZero())
3961           AArch64CC = AArch64CC::getInvertedCondCode(AArch64CC);
3962       }
3963     }
3964   }
3965 
3966   if (!Cmp) {
3967     Cmp = emitComparison(LHS, RHS, CC, dl, DAG);
3968     AArch64CC = changeIntCCToAArch64CC(CC);
3969   }
3970   AArch64cc = DAG.getConstant(AArch64CC, dl, MVT_CC);
3971   return Cmp;
3972 }
3973 
3974 static std::pair<SDValue, SDValue>
3975 getAArch64XALUOOp(AArch64CC::CondCode &CC, SDValue Op, SelectionDAG &DAG) {
3976   assert((Op.getValueType() == MVT::i32 || Op.getValueType() == MVT::i64) &&
3977          "Unsupported value type");
3978   SDValue Value, Overflow;
3979   SDLoc DL(Op);
3980   SDValue LHS = Op.getOperand(0);
3981   SDValue RHS = Op.getOperand(1);
3982   unsigned Opc = 0;
3983   switch (Op.getOpcode()) {
3984   default:
3985     llvm_unreachable("Unknown overflow instruction!");
3986   case ISD::SADDO:
3987     Opc = AArch64ISD::ADDS;
3988     CC = AArch64CC::VS;
3989     break;
3990   case ISD::UADDO:
3991     Opc = AArch64ISD::ADDS;
3992     CC = AArch64CC::HS;
3993     break;
3994   case ISD::SSUBO:
3995     Opc = AArch64ISD::SUBS;
3996     CC = AArch64CC::VS;
3997     break;
3998   case ISD::USUBO:
3999     Opc = AArch64ISD::SUBS;
4000     CC = AArch64CC::LO;
4001     break;
4002   // Multiply needs a little bit of extra work.
4003   case ISD::SMULO:
4004   case ISD::UMULO: {
4005     CC = AArch64CC::NE;
4006     bool IsSigned = Op.getOpcode() == ISD::SMULO;
4007     if (Op.getValueType() == MVT::i32) {
4008       // Extend to 64-bits, then perform a 64-bit multiply.
4009       unsigned ExtendOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
4010       LHS = DAG.getNode(ExtendOpc, DL, MVT::i64, LHS);
4011       RHS = DAG.getNode(ExtendOpc, DL, MVT::i64, RHS);
4012       SDValue Mul = DAG.getNode(ISD::MUL, DL, MVT::i64, LHS, RHS);
4013       Value = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Mul);
4014 
4015       // Check that the result fits into a 32-bit integer.
4016       SDVTList VTs = DAG.getVTList(MVT::i64, MVT_CC);
4017       if (IsSigned) {
4018         // cmp xreg, wreg, sxtw
4019         SDValue SExtMul = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, Value);
4020         Overflow =
4021             DAG.getNode(AArch64ISD::SUBS, DL, VTs, Mul, SExtMul).getValue(1);
4022       } else {
4023         // tst xreg, #0xffffffff00000000
4024         SDValue UpperBits = DAG.getConstant(0xFFFFFFFF00000000, DL, MVT::i64);
4025         Overflow =
4026             DAG.getNode(AArch64ISD::ANDS, DL, VTs, Mul, UpperBits).getValue(1);
4027       }
4028       break;
4029     }
4030     assert(Op.getValueType() == MVT::i64 && "Expected an i64 value type");
4031     // For the 64 bit multiply
4032     Value = DAG.getNode(ISD::MUL, DL, MVT::i64, LHS, RHS);
4033     if (IsSigned) {
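      // The signed 64x64 product fits in 64 bits iff the high half (MULHS)
      // equals the sign extension of the low half (Value >> 63), so the
      // overflow check compares those two values.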
4034       SDValue UpperBits = DAG.getNode(ISD::MULHS, DL, MVT::i64, LHS, RHS);
4035       SDValue LowerBits = DAG.getNode(ISD::SRA, DL, MVT::i64, Value,
4036                                       DAG.getConstant(63, DL, MVT::i64));
4037       // It is important that LowerBits is last, otherwise the arithmetic
4038       // shift will not be folded into the compare (SUBS).
4039       SDVTList VTs = DAG.getVTList(MVT::i64, MVT::i32);
4040       Overflow = DAG.getNode(AArch64ISD::SUBS, DL, VTs, UpperBits, LowerBits)
4041                      .getValue(1);
4042     } else {
4043       SDValue UpperBits = DAG.getNode(ISD::MULHU, DL, MVT::i64, LHS, RHS);
4044       SDVTList VTs = DAG.getVTList(MVT::i64, MVT::i32);
4045       Overflow =
4046           DAG.getNode(AArch64ISD::SUBS, DL, VTs,
4047                       DAG.getConstant(0, DL, MVT::i64),
4048                       UpperBits).getValue(1);
4049     }
4050     break;
4051   }
4052   } // switch (...)
4053 
4054   if (Opc) {
4055     SDVTList VTs = DAG.getVTList(Op->getValueType(0), MVT::i32);
4056 
4057     // Emit the AArch64 operation with overflow check.
4058     Value = DAG.getNode(Opc, DL, VTs, LHS, RHS);
4059     Overflow = Value.getValue(1);
4060   }
4061   return std::make_pair(Value, Overflow);
4062 }
4063 
4064 SDValue AArch64TargetLowering::LowerXOR(SDValue Op, SelectionDAG &DAG) const {
4065   if (useSVEForFixedLengthVectorVT(Op.getValueType(),
4066                                    !Subtarget->isNeonAvailable()))
4067     return LowerToScalableOp(Op, DAG);
4068 
4069   SDValue Sel = Op.getOperand(0);
4070   SDValue Other = Op.getOperand(1);
4071   SDLoc dl(Sel);
4072 
4073   // If the operand is an overflow checking operation, invert the condition
4074   // code and kill the Not operation. I.e., transform:
4075   // (xor (overflow_op_bool, 1))
4076   //   -->
4077   // (csel 1, 0, invert(cc), overflow_op_bool)
4078   // ... which later gets transformed to just a cset instruction with an
4079   // inverted condition code, rather than a cset + eor sequence.
4080   if (isOneConstant(Other) && ISD::isOverflowIntrOpRes(Sel)) {
4081     // Only lower legal XALUO ops.
4082     if (!DAG.getTargetLoweringInfo().isTypeLegal(Sel->getValueType(0)))
4083       return SDValue();
4084 
4085     SDValue TVal = DAG.getConstant(1, dl, MVT::i32);
4086     SDValue FVal = DAG.getConstant(0, dl, MVT::i32);
4087     AArch64CC::CondCode CC;
4088     SDValue Value, Overflow;
4089     std::tie(Value, Overflow) = getAArch64XALUOOp(CC, Sel.getValue(0), DAG);
4090     SDValue CCVal = DAG.getConstant(getInvertedCondCode(CC), dl, MVT::i32);
4091     return DAG.getNode(AArch64ISD::CSEL, dl, Op.getValueType(), TVal, FVal,
4092                        CCVal, Overflow);
4093   }
4094   // If neither operand is a SELECT_CC, give up.
4095   if (Sel.getOpcode() != ISD::SELECT_CC)
4096     std::swap(Sel, Other);
4097   if (Sel.getOpcode() != ISD::SELECT_CC)
4098     return Op;
4099 
4100   // The folding we want to perform is:
4101   // (xor x, (select_cc a, b, cc, 0, -1) )
4102   //   -->
4103   // (csel x, (xor x, -1), cc ...)
4104   //
4105   // The latter will get matched to a CSINV instruction.
4106 
4107   ISD::CondCode CC = cast<CondCodeSDNode>(Sel.getOperand(4))->get();
4108   SDValue LHS = Sel.getOperand(0);
4109   SDValue RHS = Sel.getOperand(1);
4110   SDValue TVal = Sel.getOperand(2);
4111   SDValue FVal = Sel.getOperand(3);
4112 
4113   // FIXME: This could be generalized to non-integer comparisons.
4114   if (LHS.getValueType() != MVT::i32 && LHS.getValueType() != MVT::i64)
4115     return Op;
4116 
4117   ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(FVal);
4118   ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(TVal);
4119 
4120   // The values aren't constants, this isn't the pattern we're looking for.
4121   if (!CFVal || !CTVal)
4122     return Op;
4123 
4124   // We can commute the SELECT_CC by inverting the condition.  This
4125   // might be needed to make this fit into a CSINV pattern.
4126   if (CTVal->isAllOnes() && CFVal->isZero()) {
4127     std::swap(TVal, FVal);
4128     std::swap(CTVal, CFVal);
4129     CC = ISD::getSetCCInverse(CC, LHS.getValueType());
4130   }
4131 
4132   // If the constants line up, perform the transform!
4133   if (CTVal->isZero() && CFVal->isAllOnes()) {
4134     SDValue CCVal;
4135     SDValue Cmp = getAArch64Cmp(LHS, RHS, CC, CCVal, DAG, dl);
4136 
4137     FVal = Other;
4138     TVal = DAG.getNode(ISD::XOR, dl, Other.getValueType(), Other,
4139                        DAG.getConstant(-1ULL, dl, Other.getValueType()));
4140 
4141     return DAG.getNode(AArch64ISD::CSEL, dl, Sel.getValueType(), FVal, TVal,
4142                        CCVal, Cmp);
4143   }
4144 
4145   return Op;
4146 }
4147 
4148 // If Invert is false, sets 'C' bit of NZCV to 0 if value is 0, else sets 'C'
4149 // bit to 1. If Invert is true, sets 'C' bit of NZCV to 1 if value is 0, else
4150 // sets 'C' bit to 0.
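// In other words, for Invert==false this emits SUBS(Value, 1), whose carry is
// set (no borrow) exactly when Value is non-zero; for Invert==true it emits
// SUBS(0, Value), whose carry is set exactly when Value is zero.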
4151 static SDValue valueToCarryFlag(SDValue Value, SelectionDAG &DAG, bool Invert) {
4152   SDLoc DL(Value);
4153   EVT VT = Value.getValueType();
4154   SDValue Op0 = Invert ? DAG.getConstant(0, DL, VT) : Value;
4155   SDValue Op1 = Invert ? Value : DAG.getConstant(1, DL, VT);
4156   SDValue Cmp =
4157       DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, MVT::Glue), Op0, Op1);
4158   return Cmp.getValue(1);
4159 }
4160 
4161 // If Invert is false, value is 1 if 'C' bit of NZCV is 1, else 0.
4162 // If Invert is true, value is 0 if 'C' bit of NZCV is 1, else 1.
4163 static SDValue carryFlagToValue(SDValue Glue, EVT VT, SelectionDAG &DAG,
4164                                 bool Invert) {
4165   assert(Glue.getResNo() == 1);
4166   SDLoc DL(Glue);
4167   SDValue Zero = DAG.getConstant(0, DL, VT);
4168   SDValue One = DAG.getConstant(1, DL, VT);
4169   unsigned Cond = Invert ? AArch64CC::LO : AArch64CC::HS;
4170   SDValue CC = DAG.getConstant(Cond, DL, MVT::i32);
4171   return DAG.getNode(AArch64ISD::CSEL, DL, VT, One, Zero, CC, Glue);
4172 }
4173 
4174 // Value is 1 if 'V' bit of NZCV is 1, else 0
4175 static SDValue overflowFlagToValue(SDValue Glue, EVT VT, SelectionDAG &DAG) {
4176   assert(Glue.getResNo() == 1);
4177   SDLoc DL(Glue);
4178   SDValue Zero = DAG.getConstant(0, DL, VT);
4179   SDValue One = DAG.getConstant(1, DL, VT);
4180   SDValue CC = DAG.getConstant(AArch64CC::VS, DL, MVT::i32);
4181   return DAG.getNode(AArch64ISD::CSEL, DL, VT, One, Zero, CC, Glue);
4182 }
4183 
4184 // This lowering is inefficient, but it will get cleaned up by
4185 // `foldOverflowCheck`
4186 static SDValue lowerADDSUBO_CARRY(SDValue Op, SelectionDAG &DAG,
4187                                   unsigned Opcode, bool IsSigned) {
4188   EVT VT0 = Op.getValue(0).getValueType();
4189   EVT VT1 = Op.getValue(1).getValueType();
4190 
4191   if (VT0 != MVT::i32 && VT0 != MVT::i64)
4192     return SDValue();
4193 
4194   bool InvertCarry = Opcode == AArch64ISD::SBCS;
4195   SDValue OpLHS = Op.getOperand(0);
4196   SDValue OpRHS = Op.getOperand(1);
4197   SDValue OpCarryIn = valueToCarryFlag(Op.getOperand(2), DAG, InvertCarry);
4198 
4199   SDLoc DL(Op);
4200   SDVTList VTs = DAG.getVTList(VT0, VT1);
4201 
4202   SDValue Sum = DAG.getNode(Opcode, DL, DAG.getVTList(VT0, MVT::Glue), OpLHS,
4203                             OpRHS, OpCarryIn);
4204 
4205   SDValue OutFlag =
4206       IsSigned ? overflowFlagToValue(Sum.getValue(1), VT1, DAG)
4207                : carryFlagToValue(Sum.getValue(1), VT1, DAG, InvertCarry);
4208 
4209   return DAG.getNode(ISD::MERGE_VALUES, DL, VTs, Sum, OutFlag);
4210 }
4211 
4212 static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) {
4213   // Let legalize expand this if it isn't a legal type yet.
4214   if (!DAG.getTargetLoweringInfo().isTypeLegal(Op.getValueType()))
4215     return SDValue();
4216 
4217   SDLoc dl(Op);
4218   AArch64CC::CondCode CC;
4219   // The actual operation that sets the overflow or carry flag.
4220   SDValue Value, Overflow;
4221   std::tie(Value, Overflow) = getAArch64XALUOOp(CC, Op, DAG);
4222 
4223   // We use 0 and 1 as false and true values.
4224   SDValue TVal = DAG.getConstant(1, dl, MVT::i32);
4225   SDValue FVal = DAG.getConstant(0, dl, MVT::i32);
4226 
4227   // We use an inverted condition, because the conditional select is inverted
4228   // too. This will allow it to be selected to a single instruction:
4229   // CSINC Wd, WZR, WZR, invert(cond).
4230   SDValue CCVal = DAG.getConstant(getInvertedCondCode(CC), dl, MVT::i32);
4231   Overflow = DAG.getNode(AArch64ISD::CSEL, dl, MVT::i32, FVal, TVal,
4232                          CCVal, Overflow);
4233 
4234   SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::i32);
4235   return DAG.getNode(ISD::MERGE_VALUES, dl, VTs, Value, Overflow);
4236 }
4237 
4238 // Prefetch operands are:
4239 // 1: Address to prefetch
4240 // 2: bool isWrite
4241 // 3: int locality (0 = no locality ... 3 = extreme locality)
4242 // 4: bool isDataCache
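// For instance, a read prefetch with maximal locality into the data cache
// (isWrite=0, locality=3, isDataCache=1) encodes to PrfOp 0b00000 (PLDL1KEEP),
// while a streaming write prefetch (isWrite=1, locality=0, isDataCache=1)
// encodes to 0b10001 (PSTL1STRM).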
4243 static SDValue LowerPREFETCH(SDValue Op, SelectionDAG &DAG) {
4244   SDLoc DL(Op);
4245   unsigned IsWrite = Op.getConstantOperandVal(2);
4246   unsigned Locality = Op.getConstantOperandVal(3);
4247   unsigned IsData = Op.getConstantOperandVal(4);
4248 
4249   bool IsStream = !Locality;
4250   // When the locality number is set
4251   if (Locality) {
4252     // The front-end should have filtered out the out-of-range values
4253     assert(Locality <= 3 && "Prefetch locality out-of-range");
4254     // The locality degree is the opposite of the cache speed.
4255     // Put the number the other way around.
4256     // The encoding starts at 0 for level 1
4257     Locality = 3 - Locality;
4258   }
4259 
4260   // Build the mask value encoding the expected behavior.
4261   unsigned PrfOp = (IsWrite << 4) |     // Load/Store bit
4262                    (!IsData << 3) |     // IsDataCache bit
4263                    (Locality << 1) |    // Cache level bits
4264                    (unsigned)IsStream;  // Stream bit
4265   return DAG.getNode(AArch64ISD::PREFETCH, DL, MVT::Other, Op.getOperand(0),
4266                      DAG.getTargetConstant(PrfOp, DL, MVT::i32),
4267                      Op.getOperand(1));
4268 }
4269 
4270 SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
4271                                               SelectionDAG &DAG) const {
4272   EVT VT = Op.getValueType();
4273   if (VT.isScalableVector())
4274     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_EXTEND_MERGE_PASSTHRU);
4275 
4276   if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
4277     return LowerFixedLengthFPExtendToSVE(Op, DAG);
4278 
4279   assert(Op.getValueType() == MVT::f128 && "Unexpected lowering");
4280   return SDValue();
4281 }
4282 
4283 SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op,
4284                                              SelectionDAG &DAG) const {
4285   EVT VT = Op.getValueType();
4286   if (VT.isScalableVector())
4287     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
4288 
4289   bool IsStrict = Op->isStrictFPOpcode();
4290   SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
4291   EVT SrcVT = SrcVal.getValueType();
4292   bool Trunc = Op.getConstantOperandVal(IsStrict ? 2 : 1) == 1;
4293 
4294   if (useSVEForFixedLengthVectorVT(SrcVT, !Subtarget->isNeonAvailable()))
4295     return LowerFixedLengthFPRoundToSVE(Op, DAG);
4296 
4297   // Expand cases where the result type is BF16 but we don't have hardware
4298   // instructions to lower it.
4299   if (VT.getScalarType() == MVT::bf16 &&
4300       !((Subtarget->hasNEON() || Subtarget->hasSME()) &&
4301         Subtarget->hasBF16())) {
4302     SDLoc dl(Op);
4303     SDValue Narrow = SrcVal;
4304     SDValue NaN;
4305     EVT I32 = SrcVT.changeElementType(MVT::i32);
4306     EVT F32 = SrcVT.changeElementType(MVT::f32);
4307     if (SrcVT.getScalarType() == MVT::f32) {
4308       bool NeverSNaN = DAG.isKnownNeverSNaN(Narrow);
4309       Narrow = DAG.getNode(ISD::BITCAST, dl, I32, Narrow);
4310       if (!NeverSNaN) {
4311         // Set the quiet bit.
4312         NaN = DAG.getNode(ISD::OR, dl, I32, Narrow,
4313                           DAG.getConstant(0x400000, dl, I32));
4314       }
4315     } else if (SrcVT.getScalarType() == MVT::f64) {
4316       Narrow = DAG.getNode(AArch64ISD::FCVTXN, dl, F32, Narrow);
4317       Narrow = DAG.getNode(ISD::BITCAST, dl, I32, Narrow);
4318     } else {
4319       return SDValue();
4320     }
4321     if (!Trunc) {
4322       SDValue One = DAG.getConstant(1, dl, I32);
4323       SDValue Lsb = DAG.getNode(ISD::SRL, dl, I32, Narrow,
4324                                 DAG.getShiftAmountConstant(16, I32, dl));
4325       Lsb = DAG.getNode(ISD::AND, dl, I32, Lsb, One);
4326       SDValue RoundingBias =
4327           DAG.getNode(ISD::ADD, dl, I32, DAG.getConstant(0x7fff, dl, I32), Lsb);
4328       Narrow = DAG.getNode(ISD::ADD, dl, I32, Narrow, RoundingBias);
4329     }
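    // The bias implements round-to-nearest-even on the upper 16 bits: a value
    // exactly halfway (low half == 0x8000) only carries into bit 16 when that
    // bit is already 1, so ties round towards an even bf16 mantissa.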
4330 
4331     // Don't round if we had a NaN, we don't want to turn 0x7fffffff into
4332     // 0x80000000.
4333     if (NaN) {
4334       SDValue IsNaN = DAG.getSetCC(
4335           dl, getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), SrcVT),
4336           SrcVal, SrcVal, ISD::SETUO);
4337       Narrow = DAG.getSelect(dl, I32, IsNaN, NaN, Narrow);
4338     }
4339 
4340     // Now that we have rounded, shift the bits into position.
4341     Narrow = DAG.getNode(ISD::SRL, dl, I32, Narrow,
4342                          DAG.getShiftAmountConstant(16, I32, dl));
4343     if (VT.isVector()) {
4344       EVT I16 = I32.changeVectorElementType(MVT::i16);
4345       Narrow = DAG.getNode(ISD::TRUNCATE, dl, I16, Narrow);
4346       return DAG.getNode(ISD::BITCAST, dl, VT, Narrow);
4347     }
4348     Narrow = DAG.getNode(ISD::BITCAST, dl, F32, Narrow);
4349     SDValue Result = DAG.getTargetExtractSubreg(AArch64::hsub, dl, VT, Narrow);
4350     return IsStrict ? DAG.getMergeValues({Result, Op.getOperand(0)}, dl)
4351                     : Result;
4352   }
4353 
4354   if (SrcVT != MVT::f128) {
4355     // Expand cases where the input is a vector bigger than NEON.
4356     if (useSVEForFixedLengthVectorVT(SrcVT))
4357       return SDValue();
4358 
4359     // It's legal except when f128 is involved
4360     return Op;
4361   }
4362 
4363   return SDValue();
4364 }
4365 
4366 SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
4367                                                     SelectionDAG &DAG) const {
4368   // Warning: We maintain cost tables in AArch64TargetTransformInfo.cpp.
4369   // Any additional optimization in this function should be recorded
4370   // in the cost tables.
4371   bool IsStrict = Op->isStrictFPOpcode();
4372   EVT InVT = Op.getOperand(IsStrict ? 1 : 0).getValueType();
4373   EVT VT = Op.getValueType();
4374 
4375   if (VT.isScalableVector()) {
4376     unsigned Opcode = Op.getOpcode() == ISD::FP_TO_UINT
4377                           ? AArch64ISD::FCVTZU_MERGE_PASSTHRU
4378                           : AArch64ISD::FCVTZS_MERGE_PASSTHRU;
4379     return LowerToPredicatedOp(Op, DAG, Opcode);
4380   }
4381 
4382   if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()) ||
4383       useSVEForFixedLengthVectorVT(InVT, !Subtarget->isNeonAvailable()))
4384     return LowerFixedLengthFPToIntToSVE(Op, DAG);
4385 
4386   unsigned NumElts = InVT.getVectorNumElements();
4387 
4388   // f16 conversions are promoted to f32 when full fp16 is not supported.
4389   if ((InVT.getVectorElementType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
4390       InVT.getVectorElementType() == MVT::bf16) {
4391     MVT NewVT = MVT::getVectorVT(MVT::f32, NumElts);
4392     SDLoc dl(Op);
4393     if (IsStrict) {
4394       SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NewVT, MVT::Other},
4395                                 {Op.getOperand(0), Op.getOperand(1)});
4396       return DAG.getNode(Op.getOpcode(), dl, {VT, MVT::Other},
4397                          {Ext.getValue(1), Ext.getValue(0)});
4398     }
4399     return DAG.getNode(
4400         Op.getOpcode(), dl, Op.getValueType(),
4401         DAG.getNode(ISD::FP_EXTEND, dl, NewVT, Op.getOperand(0)));
4402   }
4403 
4404   uint64_t VTSize = VT.getFixedSizeInBits();
4405   uint64_t InVTSize = InVT.getFixedSizeInBits();
4406   if (VTSize < InVTSize) {
4407     SDLoc dl(Op);
4408     if (IsStrict) {
4409       InVT = InVT.changeVectorElementTypeToInteger();
4410       SDValue Cv = DAG.getNode(Op.getOpcode(), dl, {InVT, MVT::Other},
4411                                {Op.getOperand(0), Op.getOperand(1)});
4412       SDValue Trunc = DAG.getNode(ISD::TRUNCATE, dl, VT, Cv);
4413       return DAG.getMergeValues({Trunc, Cv.getValue(1)}, dl);
4414     }
4415     SDValue Cv =
4416         DAG.getNode(Op.getOpcode(), dl, InVT.changeVectorElementTypeToInteger(),
4417                     Op.getOperand(0));
4418     return DAG.getNode(ISD::TRUNCATE, dl, VT, Cv);
4419   }
4420 
4421   if (VTSize > InVTSize) {
4422     SDLoc dl(Op);
4423     MVT ExtVT =
4424         MVT::getVectorVT(MVT::getFloatingPointVT(VT.getScalarSizeInBits()),
4425                          VT.getVectorNumElements());
4426     if (IsStrict) {
4427       SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {ExtVT, MVT::Other},
4428                                 {Op.getOperand(0), Op.getOperand(1)});
4429       return DAG.getNode(Op.getOpcode(), dl, {VT, MVT::Other},
4430                          {Ext.getValue(1), Ext.getValue(0)});
4431     }
4432     SDValue Ext = DAG.getNode(ISD::FP_EXTEND, dl, ExtVT, Op.getOperand(0));
4433     return DAG.getNode(Op.getOpcode(), dl, VT, Ext);
4434   }
4435 
4436   // Use a scalar operation for conversions between single-element vectors of
4437   // the same size.
4438   if (NumElts == 1) {
4439     SDLoc dl(Op);
4440     SDValue Extract = DAG.getNode(
4441         ISD::EXTRACT_VECTOR_ELT, dl, InVT.getScalarType(),
4442         Op.getOperand(IsStrict ? 1 : 0), DAG.getConstant(0, dl, MVT::i64));
4443     EVT ScalarVT = VT.getScalarType();
4444     if (IsStrict)
4445       return DAG.getNode(Op.getOpcode(), dl, {ScalarVT, MVT::Other},
4446                          {Op.getOperand(0), Extract});
4447     return DAG.getNode(Op.getOpcode(), dl, ScalarVT, Extract);
4448   }
4449 
4450   // Type changing conversions are illegal.
4451   return Op;
4452 }
4453 
4454 SDValue AArch64TargetLowering::LowerFP_TO_INT(SDValue Op,
4455                                               SelectionDAG &DAG) const {
4456   bool IsStrict = Op->isStrictFPOpcode();
4457   SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
4458 
4459   if (SrcVal.getValueType().isVector())
4460     return LowerVectorFP_TO_INT(Op, DAG);
4461 
4462   // f16 conversions are promoted to f32 when full fp16 is not supported.
4463   if ((SrcVal.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
4464       SrcVal.getValueType() == MVT::bf16) {
4465     SDLoc dl(Op);
4466     if (IsStrict) {
4467       SDValue Ext =
4468           DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {MVT::f32, MVT::Other},
4469                       {Op.getOperand(0), SrcVal});
4470       return DAG.getNode(Op.getOpcode(), dl, {Op.getValueType(), MVT::Other},
4471                          {Ext.getValue(1), Ext.getValue(0)});
4472     }
4473     return DAG.getNode(
4474         Op.getOpcode(), dl, Op.getValueType(),
4475         DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, SrcVal));
4476   }
4477 
4478   if (SrcVal.getValueType() != MVT::f128) {
4479     // It's legal except when f128 is involved
4480     return Op;
4481   }
4482 
4483   return SDValue();
4484 }
4485 
4486 SDValue
4487 AArch64TargetLowering::LowerVectorFP_TO_INT_SAT(SDValue Op,
4488                                                 SelectionDAG &DAG) const {
4489   // AArch64 FP-to-int conversions saturate to the destination element size, so
4490   // we can lower common saturating conversions to simple instructions.
4491   SDValue SrcVal = Op.getOperand(0);
4492   EVT SrcVT = SrcVal.getValueType();
4493   EVT DstVT = Op.getValueType();
4494   EVT SatVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
4495 
4496   uint64_t SrcElementWidth = SrcVT.getScalarSizeInBits();
4497   uint64_t DstElementWidth = DstVT.getScalarSizeInBits();
4498   uint64_t SatWidth = SatVT.getScalarSizeInBits();
4499   assert(SatWidth <= DstElementWidth &&
4500          "Saturation width cannot exceed result width");
4501 
4502   // TODO: Consider lowering to SVE operations, as in LowerVectorFP_TO_INT.
4503   // Currently, the `llvm.fpto[su]i.sat.*` intrinsics don't accept scalable
4504   // types, so this is hard to reach.
4505   if (DstVT.isScalableVector())
4506     return SDValue();
4507 
4508   EVT SrcElementVT = SrcVT.getVectorElementType();
4509 
4510   // In the absence of FP16 support, promote f16 to f32 and saturate the result.
4511   if ((SrcElementVT == MVT::f16 &&
4512        (!Subtarget->hasFullFP16() || DstElementWidth > 16)) ||
4513       SrcElementVT == MVT::bf16) {
4514     MVT F32VT = MVT::getVectorVT(MVT::f32, SrcVT.getVectorNumElements());
4515     SrcVal = DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), F32VT, SrcVal);
4516     SrcVT = F32VT;
4517     SrcElementVT = MVT::f32;
4518     SrcElementWidth = 32;
4519   } else if (SrcElementVT != MVT::f64 && SrcElementVT != MVT::f32 &&
4520              SrcElementVT != MVT::f16 && SrcElementVT != MVT::bf16)
4521     return SDValue();
4522 
4523   SDLoc DL(Op);
4524   // Expand to f64 if we are saturating to i64, to help keep the lanes the
4525   // same width and produce a fcvtzu.
4526   if (SatWidth == 64 && SrcElementWidth < 64) {
4527     MVT F64VT = MVT::getVectorVT(MVT::f64, SrcVT.getVectorNumElements());
4528     SrcVal = DAG.getNode(ISD::FP_EXTEND, DL, F64VT, SrcVal);
4529     SrcVT = F64VT;
4530     SrcElementVT = MVT::f64;
4531     SrcElementWidth = 64;
4532   }
4533   // Cases that we can emit directly.
4534   if (SrcElementWidth == DstElementWidth && SrcElementWidth == SatWidth)
4535     return DAG.getNode(Op.getOpcode(), DL, DstVT, SrcVal,
4536                        DAG.getValueType(DstVT.getScalarType()));
4537 
4538   // Otherwise we emit a cvt that saturates to a higher BW, and saturate the
4539   // result. This is only valid if the legal cvt is larger than the saturate
4540   // width. For double, as we don't have MIN/MAX, it can be simpler to scalarize
4541   // (at least until sqxtn is selected).
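  // As a sketch (assuming full fp16 support), llvm.fptosi.sat.i8 on a v8f16
  // input becomes a saturating v8i16 fcvtzs, then an smin with 127 and an smax
  // with -128, followed by a truncate to v8i8.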
4542   if (SrcElementWidth < SatWidth || SrcElementVT == MVT::f64)
4543     return SDValue();
4544 
4545   EVT IntVT = SrcVT.changeVectorElementTypeToInteger();
4546   SDValue NativeCvt = DAG.getNode(Op.getOpcode(), DL, IntVT, SrcVal,
4547                                   DAG.getValueType(IntVT.getScalarType()));
4548   SDValue Sat;
4549   if (Op.getOpcode() == ISD::FP_TO_SINT_SAT) {
4550     SDValue MinC = DAG.getConstant(
4551         APInt::getSignedMaxValue(SatWidth).sext(SrcElementWidth), DL, IntVT);
4552     SDValue Min = DAG.getNode(ISD::SMIN, DL, IntVT, NativeCvt, MinC);
4553     SDValue MaxC = DAG.getConstant(
4554         APInt::getSignedMinValue(SatWidth).sext(SrcElementWidth), DL, IntVT);
4555     Sat = DAG.getNode(ISD::SMAX, DL, IntVT, Min, MaxC);
4556   } else {
4557     SDValue MinC = DAG.getConstant(
4558         APInt::getAllOnes(SatWidth).zext(SrcElementWidth), DL, IntVT);
4559     Sat = DAG.getNode(ISD::UMIN, DL, IntVT, NativeCvt, MinC);
4560   }
4561 
4562   return DAG.getNode(ISD::TRUNCATE, DL, DstVT, Sat);
4563 }
4564 
4565 SDValue AArch64TargetLowering::LowerFP_TO_INT_SAT(SDValue Op,
4566                                                   SelectionDAG &DAG) const {
4567   // AArch64 FP-to-int conversions saturate to the destination register size, so
4568   // we can lower common saturating conversions to simple instructions.
4569   SDValue SrcVal = Op.getOperand(0);
4570   EVT SrcVT = SrcVal.getValueType();
4571 
4572   if (SrcVT.isVector())
4573     return LowerVectorFP_TO_INT_SAT(Op, DAG);
4574 
4575   EVT DstVT = Op.getValueType();
4576   EVT SatVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
4577   uint64_t SatWidth = SatVT.getScalarSizeInBits();
4578   uint64_t DstWidth = DstVT.getScalarSizeInBits();
4579   assert(SatWidth <= DstWidth && "Saturation width cannot exceed result width");
4580 
4581   // In the absence of FP16 support, promote f16 to f32 and saturate the result.
4582   if ((SrcVT == MVT::f16 && !Subtarget->hasFullFP16()) || SrcVT == MVT::bf16) {
4583     SrcVal = DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), MVT::f32, SrcVal);
4584     SrcVT = MVT::f32;
4585   } else if (SrcVT != MVT::f64 && SrcVT != MVT::f32 && SrcVT != MVT::f16 &&
4586              SrcVT != MVT::bf16)
4587     return SDValue();
4588 
4589   SDLoc DL(Op);
4590   // Cases that we can emit directly.
4591   if ((SrcVT == MVT::f64 || SrcVT == MVT::f32 ||
4592        (SrcVT == MVT::f16 && Subtarget->hasFullFP16())) &&
4593       DstVT == SatVT && (DstVT == MVT::i64 || DstVT == MVT::i32))
4594     return DAG.getNode(Op.getOpcode(), DL, DstVT, SrcVal,
4595                        DAG.getValueType(DstVT));
4596 
4597   // Otherwise we emit a cvt that saturates to a higher BW, and saturate the
4598   // result. This is only valid if the legal cvt is larger than the saturate
4599   // width.
4600   if (DstWidth < SatWidth)
4601     return SDValue();
4602 
4603   SDValue NativeCvt =
4604       DAG.getNode(Op.getOpcode(), DL, DstVT, SrcVal, DAG.getValueType(DstVT));
4605   SDValue Sat;
4606   if (Op.getOpcode() == ISD::FP_TO_SINT_SAT) {
4607     SDValue MinC = DAG.getConstant(
4608         APInt::getSignedMaxValue(SatWidth).sext(DstWidth), DL, DstVT);
4609     SDValue Min = DAG.getNode(ISD::SMIN, DL, DstVT, NativeCvt, MinC);
4610     SDValue MaxC = DAG.getConstant(
4611         APInt::getSignedMinValue(SatWidth).sext(DstWidth), DL, DstVT);
4612     Sat = DAG.getNode(ISD::SMAX, DL, DstVT, Min, MaxC);
4613   } else {
4614     SDValue MinC = DAG.getConstant(
4615         APInt::getAllOnes(SatWidth).zext(DstWidth), DL, DstVT);
4616     Sat = DAG.getNode(ISD::UMIN, DL, DstVT, NativeCvt, MinC);
4617   }
4618 
4619   return DAG.getNode(ISD::TRUNCATE, DL, DstVT, Sat);
4620 }
4621 
4622 SDValue AArch64TargetLowering::LowerVectorXRINT(SDValue Op,
4623                                                 SelectionDAG &DAG) const {
4624   EVT VT = Op.getValueType();
4625   SDValue Src = Op.getOperand(0);
4626   SDLoc DL(Op);
4627 
4628   assert(VT.isVector() && "Expected vector type");
4629 
4630   EVT CastVT =
4631       VT.changeVectorElementType(Src.getValueType().getVectorElementType());
4632 
4633   // Round the floating-point value into a floating-point register with the
4634   // current rounding mode.
4635   SDValue FOp = DAG.getNode(ISD::FRINT, DL, CastVT, Src);
4636 
4637   // Truncate the rounded floating point to an integer.
4638   return DAG.getNode(ISD::FP_TO_SINT_SAT, DL, VT, FOp,
4639                      DAG.getValueType(VT.getVectorElementType()));
4640 }
4641 
4642 SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
4643                                                     SelectionDAG &DAG) const {
4644   // Warning: We maintain cost tables in AArch64TargetTransformInfo.cpp.
4645   // Any additional optimization in this function should be recorded
4646   // in the cost tables.
4647   bool IsStrict = Op->isStrictFPOpcode();
4648   EVT VT = Op.getValueType();
4649   SDLoc dl(Op);
4650   SDValue In = Op.getOperand(IsStrict ? 1 : 0);
4651   EVT InVT = In.getValueType();
4652   unsigned Opc = Op.getOpcode();
4653   bool IsSigned = Opc == ISD::SINT_TO_FP || Opc == ISD::STRICT_SINT_TO_FP;
4654 
4655   if (VT.isScalableVector()) {
4656     if (InVT.getVectorElementType() == MVT::i1) {
4657       // We can't directly extend an SVE predicate; extend it first.
4658       unsigned CastOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
4659       EVT CastVT = getPromotedVTForPredicate(InVT);
4660       In = DAG.getNode(CastOpc, dl, CastVT, In);
4661       return DAG.getNode(Opc, dl, VT, In);
4662     }
4663 
4664     unsigned Opcode = IsSigned ? AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU
4665                                : AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU;
4666     return LowerToPredicatedOp(Op, DAG, Opcode);
4667   }
4668 
4669   if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()) ||
4670       useSVEForFixedLengthVectorVT(InVT, !Subtarget->isNeonAvailable()))
4671     return LowerFixedLengthIntToFPToSVE(Op, DAG);
4672 
4673   // Promote bf16 conversions to f32.
4674   if (VT.getVectorElementType() == MVT::bf16) {
4675     EVT F32 = VT.changeElementType(MVT::f32);
4676     if (IsStrict) {
4677       SDValue Val = DAG.getNode(Op.getOpcode(), dl, {F32, MVT::Other},
4678                                 {Op.getOperand(0), In});
4679       return DAG.getNode(
4680           ISD::STRICT_FP_ROUND, dl, {Op.getValueType(), MVT::Other},
4681           {Val.getValue(1), Val.getValue(0), DAG.getIntPtrConstant(0, dl)});
4682     }
4683     return DAG.getNode(ISD::FP_ROUND, dl, Op.getValueType(),
4684                        DAG.getNode(Op.getOpcode(), dl, F32, In),
4685                        DAG.getIntPtrConstant(0, dl));
4686   }
4687 
4688   uint64_t VTSize = VT.getFixedSizeInBits();
4689   uint64_t InVTSize = InVT.getFixedSizeInBits();
4690   if (VTSize < InVTSize) {
4691     MVT CastVT =
4692         MVT::getVectorVT(MVT::getFloatingPointVT(InVT.getScalarSizeInBits()),
4693                          InVT.getVectorNumElements());
4694     if (IsStrict) {
4695       In = DAG.getNode(Opc, dl, {CastVT, MVT::Other},
4696                        {Op.getOperand(0), In});
4697       return DAG.getNode(
4698           ISD::STRICT_FP_ROUND, dl, {VT, MVT::Other},
4699           {In.getValue(1), In.getValue(0), DAG.getIntPtrConstant(0, dl)});
4700     }
4701     In = DAG.getNode(Opc, dl, CastVT, In);
4702     return DAG.getNode(ISD::FP_ROUND, dl, VT, In,
4703                        DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
4704   }
4705 
4706   if (VTSize > InVTSize) {
4707     unsigned CastOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
4708     EVT CastVT = VT.changeVectorElementTypeToInteger();
4709     In = DAG.getNode(CastOpc, dl, CastVT, In);
4710     if (IsStrict)
4711       return DAG.getNode(Opc, dl, {VT, MVT::Other}, {Op.getOperand(0), In});
4712     return DAG.getNode(Opc, dl, VT, In);
4713   }
4714 
4715   // Use a scalar operation for conversions between single-element vectors of
4716   // the same size.
4717   if (VT.getVectorNumElements() == 1) {
4718     SDValue Extract = DAG.getNode(
4719         ISD::EXTRACT_VECTOR_ELT, dl, InVT.getScalarType(),
4720         In, DAG.getConstant(0, dl, MVT::i64));
4721     EVT ScalarVT = VT.getScalarType();
4722     if (IsStrict)
4723       return DAG.getNode(Op.getOpcode(), dl, {ScalarVT, MVT::Other},
4724                          {Op.getOperand(0), Extract});
4725     return DAG.getNode(Op.getOpcode(), dl, ScalarVT, Extract);
4726   }
4727 
4728   return Op;
4729 }
4730 
4731 SDValue AArch64TargetLowering::LowerINT_TO_FP(SDValue Op,
4732                                             SelectionDAG &DAG) const {
4733   if (Op.getValueType().isVector())
4734     return LowerVectorINT_TO_FP(Op, DAG);
4735 
4736   bool IsStrict = Op->isStrictFPOpcode();
4737   SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
4738 
4739   bool IsSigned = Op->getOpcode() == ISD::STRICT_SINT_TO_FP ||
4740                   Op->getOpcode() == ISD::SINT_TO_FP;
4741 
4742   auto IntToFpViaPromotion = [&](EVT PromoteVT) {
4743     SDLoc dl(Op);
4744     if (IsStrict) {
4745       SDValue Val = DAG.getNode(Op.getOpcode(), dl, {PromoteVT, MVT::Other},
4746                                 {Op.getOperand(0), SrcVal});
4747       return DAG.getNode(
4748           ISD::STRICT_FP_ROUND, dl, {Op.getValueType(), MVT::Other},
4749           {Val.getValue(1), Val.getValue(0), DAG.getIntPtrConstant(0, dl)});
4750     }
4751     return DAG.getNode(ISD::FP_ROUND, dl, Op.getValueType(),
4752                        DAG.getNode(Op.getOpcode(), dl, PromoteVT, SrcVal),
4753                        DAG.getIntPtrConstant(0, dl));
4754   };
4755 
4756   if (Op.getValueType() == MVT::bf16) {
4757     unsigned MaxWidth = IsSigned
4758                             ? DAG.ComputeMaxSignificantBits(SrcVal)
4759                             : DAG.computeKnownBits(SrcVal).countMaxActiveBits();
4760     // bf16 conversions are promoted to f32 when converting from i16.
4761     if (MaxWidth <= 24) {
4762       return IntToFpViaPromotion(MVT::f32);
4763     }
4764 
4765     // bf16 conversions are promoted to f64 when converting from i32.
4766     if (MaxWidth <= 53) {
4767       return IntToFpViaPromotion(MVT::f64);
4768     }
4769 
4770     // We need to be careful about i64 -> bf16.
4771     // Consider an i32 22216703.
4772       // This number cannot be represented exactly as an f32, so an itofp will
4773       // turn it into 22216704.0 and an fptrunc to bf16 will then give 22282240.0.
4774     // However, the correct bf16 was supposed to be 22151168.0
4775     // We need to use sticky rounding to get this correct.
4776     if (SrcVal.getValueType() == MVT::i64) {
4777       SDLoc DL(Op);
4778       // This algorithm is equivalent to the following:
4779       // uint64_t SrcHi = SrcVal & ~0xfffull;
4780       // uint64_t SrcLo = SrcVal &  0xfffull;
4781       // uint64_t Highest = SrcVal >> 53;
4782       // bool HasHighest = Highest != 0;
4783       // uint64_t ToRound = HasHighest ? SrcHi : SrcVal;
4784       // double  Rounded = static_cast<double>(ToRound);
4785       // uint64_t RoundedBits = std::bit_cast<uint64_t>(Rounded);
4786       // uint64_t HasLo = SrcLo != 0;
4787       // bool NeedsAdjustment = HasHighest & HasLo;
4788       // uint64_t AdjustedBits = RoundedBits | uint64_t{NeedsAdjustment};
4789       // double Adjusted = std::bit_cast<double>(AdjustedBits);
4790       // return static_cast<__bf16>(Adjusted);
4791       //
4792       // Essentially, what happens is that SrcVal either fits perfectly in a
4793       // double-precision value or it is too big. If it is sufficiently small,
4794       // we should just go u64 -> double -> bf16 in a naive way. Otherwise, we
4795       // ensure that u64 -> double has no rounding error by only using the 52
4796       // MSB of the input. The low order bits will get merged into a sticky bit
4797       // which will avoid issues incurred by double rounding.
4798 
4799       // Signed conversion is more or less like so:
4800       // copysign((__bf16)abs(SrcVal), SrcVal)
4801       SDValue SignBit;
4802       if (IsSigned) {
4803         SignBit = DAG.getNode(ISD::AND, DL, MVT::i64, SrcVal,
4804                               DAG.getConstant(1ull << 63, DL, MVT::i64));
4805         SrcVal = DAG.getNode(ISD::ABS, DL, MVT::i64, SrcVal);
4806       }
4807       SDValue SrcHi = DAG.getNode(ISD::AND, DL, MVT::i64, SrcVal,
4808                                   DAG.getConstant(~0xfffull, DL, MVT::i64));
4809       SDValue SrcLo = DAG.getNode(ISD::AND, DL, MVT::i64, SrcVal,
4810                                   DAG.getConstant(0xfffull, DL, MVT::i64));
4811       SDValue Highest =
4812           DAG.getNode(ISD::SRL, DL, MVT::i64, SrcVal,
4813                       DAG.getShiftAmountConstant(53, MVT::i64, DL));
4814       SDValue Zero64 = DAG.getConstant(0, DL, MVT::i64);
4815       SDValue ToRound =
4816           DAG.getSelectCC(DL, Highest, Zero64, SrcHi, SrcVal, ISD::SETNE);
4817       SDValue Rounded =
4818           IsStrict ? DAG.getNode(Op.getOpcode(), DL, {MVT::f64, MVT::Other},
4819                                  {Op.getOperand(0), ToRound})
4820                    : DAG.getNode(Op.getOpcode(), DL, MVT::f64, ToRound);
4821 
4822       SDValue RoundedBits = DAG.getNode(ISD::BITCAST, DL, MVT::i64, Rounded);
4823       if (SignBit) {
4824         RoundedBits = DAG.getNode(ISD::OR, DL, MVT::i64, RoundedBits, SignBit);
4825       }
4826 
4827       SDValue HasHighest = DAG.getSetCC(
4828           DL,
4829           getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), MVT::i64),
4830           Highest, Zero64, ISD::SETNE);
4831 
4832       SDValue HasLo = DAG.getSetCC(
4833           DL,
4834           getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), MVT::i64),
4835           SrcLo, Zero64, ISD::SETNE);
4836 
4837       SDValue NeedsAdjustment =
4838           DAG.getNode(ISD::AND, DL, HasLo.getValueType(), HasHighest, HasLo);
4839       NeedsAdjustment = DAG.getZExtOrTrunc(NeedsAdjustment, DL, MVT::i64);
4840 
4841       SDValue AdjustedBits =
4842           DAG.getNode(ISD::OR, DL, MVT::i64, RoundedBits, NeedsAdjustment);
4843       SDValue Adjusted = DAG.getNode(ISD::BITCAST, DL, MVT::f64, AdjustedBits);
4844       return IsStrict
4845                  ? DAG.getNode(ISD::STRICT_FP_ROUND, DL,
4846                                {Op.getValueType(), MVT::Other},
4847                                {Rounded.getValue(1), Adjusted,
4848                                 DAG.getIntPtrConstant(0, DL)})
4849                  : DAG.getNode(ISD::FP_ROUND, DL, Op.getValueType(), Adjusted,
4850                                DAG.getIntPtrConstant(0, DL, true));
4851     }
4852   }
4853 
4854   // f16 conversions are promoted to f32 when full fp16 is not supported.
4855   if (Op.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) {
4856     return IntToFpViaPromotion(MVT::f32);
4857   }
4858 
4859   // i128 conversions are libcalls.
4860   if (SrcVal.getValueType() == MVT::i128)
4861     return SDValue();
4862 
4863   // Other conversions are legal, unless it's to the completely software-based
4864   // fp128.
4865   if (Op.getValueType() != MVT::f128)
4866     return Op;
4867   return SDValue();
4868 }
4869 
4870 SDValue AArch64TargetLowering::LowerFSINCOS(SDValue Op,
4871                                             SelectionDAG &DAG) const {
4872   // For iOS, we want to call an alternative entry point: __sincos_stret,
4873   // which returns the values in two S / D registers.
4874   SDLoc dl(Op);
4875   SDValue Arg = Op.getOperand(0);
4876   EVT ArgVT = Arg.getValueType();
4877   Type *ArgTy = ArgVT.getTypeForEVT(*DAG.getContext());
4878 
4879   ArgListTy Args;
4880   ArgListEntry Entry;
4881 
4882   Entry.Node = Arg;
4883   Entry.Ty = ArgTy;
4884   Entry.IsSExt = false;
4885   Entry.IsZExt = false;
4886   Args.push_back(Entry);
4887 
4888   RTLIB::Libcall LC = ArgVT == MVT::f64 ? RTLIB::SINCOS_STRET_F64
4889                                         : RTLIB::SINCOS_STRET_F32;
4890   const char *LibcallName = getLibcallName(LC);
4891   SDValue Callee =
4892       DAG.getExternalSymbol(LibcallName, getPointerTy(DAG.getDataLayout()));
4893 
4894   StructType *RetTy = StructType::get(ArgTy, ArgTy);
4895   TargetLowering::CallLoweringInfo CLI(DAG);
4896   CLI.setDebugLoc(dl)
4897       .setChain(DAG.getEntryNode())
4898       .setLibCallee(CallingConv::Fast, RetTy, Callee, std::move(Args));
4899 
4900   std::pair<SDValue, SDValue> CallResult = LowerCallTo(CLI);
4901   return CallResult.first;
4902 }
4903 
4904 static MVT getSVEContainerType(EVT ContentTy);
4905 
4906 SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op,
4907                                             SelectionDAG &DAG) const {
4908   EVT OpVT = Op.getValueType();
4909   EVT ArgVT = Op.getOperand(0).getValueType();
4910 
4911   if (useSVEForFixedLengthVectorVT(OpVT))
4912     return LowerFixedLengthBitcastToSVE(Op, DAG);
4913 
4914   if (OpVT.isScalableVector()) {
4915     // Bitcasting between unpacked vector types of different element counts is
4916     // not a NOP because the live elements are laid out differently.
4917     //                01234567
4918     // e.g. nxv2i32 = XX??XX??
4919     //      nxv4f16 = X?X?X?X?
4920     if (OpVT.getVectorElementCount() != ArgVT.getVectorElementCount())
4921       return SDValue();
4922 
4923     if (isTypeLegal(OpVT) && !isTypeLegal(ArgVT)) {
4924       assert(OpVT.isFloatingPoint() && !ArgVT.isFloatingPoint() &&
4925              "Expected int->fp bitcast!");
4926       SDValue ExtResult =
4927           DAG.getNode(ISD::ANY_EXTEND, SDLoc(Op), getSVEContainerType(ArgVT),
4928                       Op.getOperand(0));
4929       return getSVESafeBitCast(OpVT, ExtResult, DAG);
4930     }
4931     return getSVESafeBitCast(OpVT, Op.getOperand(0), DAG);
4932   }
4933 
4934   if (OpVT != MVT::f16 && OpVT != MVT::bf16)
4935     return SDValue();
4936 
4937   // Bitcasts between f16 and bf16 are legal.
4938   if (ArgVT == MVT::f16 || ArgVT == MVT::bf16)
4939     return Op;
4940 
4941   assert(ArgVT == MVT::i16);
4942   SDLoc DL(Op);
4943 
4944   Op = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op.getOperand(0));
4945   Op = DAG.getNode(ISD::BITCAST, DL, MVT::f32, Op);
4946   return DAG.getTargetExtractSubreg(AArch64::hsub, DL, OpVT, Op);
4947 }
4948 
4949 static EVT getExtensionTo64Bits(const EVT &OrigVT) {
4950   if (OrigVT.getSizeInBits() >= 64)
4951     return OrigVT;
4952 
4953   assert(OrigVT.isSimple() && "Expecting a simple value type");
4954 
4955   MVT::SimpleValueType OrigSimpleTy = OrigVT.getSimpleVT().SimpleTy;
4956   switch (OrigSimpleTy) {
4957   default: llvm_unreachable("Unexpected Vector Type");
4958   case MVT::v2i8:
4959   case MVT::v2i16:
4960     return MVT::v2i32;
4961   case MVT::v4i8:
4962     return MVT::v4i16;
4963   }
4964 }
4965 
4966 static SDValue addRequiredExtensionForVectorMULL(SDValue N, SelectionDAG &DAG,
4967                                                  const EVT &OrigTy,
4968                                                  const EVT &ExtTy,
4969                                                  unsigned ExtOpcode) {
4970   // The vector originally had a size of OrigTy. It was then extended to ExtTy.
4971   // We expect the ExtTy to be 128-bits total. If the OrigTy is less than
4972   // 64-bits we need to insert a new extension so that it will be 64-bits.
4973   assert(ExtTy.is128BitVector() && "Unexpected extension size");
4974   if (OrigTy.getSizeInBits() >= 64)
4975     return N;
4976 
4977   // Must extend size to at least 64 bits to be used as an operand for VMULL.
4978   EVT NewVT = getExtensionTo64Bits(OrigTy);
4979 
4980   return DAG.getNode(ExtOpcode, SDLoc(N), NewVT, N);
4981 }
4982 
4983 // Returns lane if Op extracts from a two-element vector and lane is constant
4984 // (i.e., extractelt(<2 x Ty> %v, ConstantLane)), and std::nullopt otherwise.
4985 static std::optional<uint64_t>
4986 getConstantLaneNumOfExtractHalfOperand(SDValue &Op) {
4987   SDNode *OpNode = Op.getNode();
4988   if (OpNode->getOpcode() != ISD::EXTRACT_VECTOR_ELT)
4989     return std::nullopt;
4990 
4991   EVT VT = OpNode->getOperand(0).getValueType();
4992   ConstantSDNode *C = dyn_cast<ConstantSDNode>(OpNode->getOperand(1));
4993   if (!VT.isFixedLengthVector() || VT.getVectorNumElements() != 2 || !C)
4994     return std::nullopt;
4995 
4996   return C->getZExtValue();
4997 }
4998 
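// Returns true if N is a BUILD_VECTOR whose elements are all constants that
// fit into half of the vector's element width (interpreted as signed or
// unsigned according to isSigned), i.e. the vector behaves as if it had been
// extended from a vector with half-width elements. For example, a v4i32 whose
// elements all lie in [0, 65535] acts like a zero-extended v4i16.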
4999 static bool isExtendedBUILD_VECTOR(SDValue N, SelectionDAG &DAG,
5000                                    bool isSigned) {
5001   EVT VT = N.getValueType();
5002 
5003   if (N.getOpcode() != ISD::BUILD_VECTOR)
5004     return false;
5005 
5006   for (const SDValue &Elt : N->op_values()) {
5007     if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Elt)) {
5008       unsigned EltSize = VT.getScalarSizeInBits();
5009       unsigned HalfSize = EltSize / 2;
5010       if (isSigned) {
5011         if (!isIntN(HalfSize, C->getSExtValue()))
5012           return false;
5013       } else {
5014         if (!isUIntN(HalfSize, C->getZExtValue()))
5015           return false;
5016       }
5017       continue;
5018     }
5019     return false;
5020   }
5021 
5022   return true;
5023 }
5024 
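// Produce the narrow operand that S/UMULL expects from a 128-bit value N:
// truncate when the top half of every lane is known to be zero, strip an
// explicit extend (re-extending its source to 64 bits when necessary), or
// otherwise narrow a constant BUILD_VECTOR element by element.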
5025 static SDValue skipExtensionForVectorMULL(SDValue N, SelectionDAG &DAG) {
5026   EVT VT = N.getValueType();
5027   assert(VT.is128BitVector() && "Unexpected vector MULL size");
5028 
5029   unsigned NumElts = VT.getVectorNumElements();
5030   unsigned OrigEltSize = VT.getScalarSizeInBits();
5031   unsigned EltSize = OrigEltSize / 2;
5032   MVT TruncVT = MVT::getVectorVT(MVT::getIntegerVT(EltSize), NumElts);
5033 
5034   APInt HiBits = APInt::getHighBitsSet(OrigEltSize, EltSize);
5035   if (DAG.MaskedValueIsZero(N, HiBits))
5036     return DAG.getNode(ISD::TRUNCATE, SDLoc(N), TruncVT, N);
5037 
5038   if (ISD::isExtOpcode(N.getOpcode()))
5039     return addRequiredExtensionForVectorMULL(N.getOperand(0), DAG,
5040                                              N.getOperand(0).getValueType(), VT,
5041                                              N.getOpcode());
5042 
5043   assert(N.getOpcode() == ISD::BUILD_VECTOR && "expected BUILD_VECTOR");
5044   SDLoc dl(N);
5045   SmallVector<SDValue, 8> Ops;
5046   for (unsigned i = 0; i != NumElts; ++i) {
5047     const APInt &CInt = N.getConstantOperandAPInt(i);
5048     // Element types smaller than 32 bits are not legal, so use i32 elements.
5049     // The values are implicitly truncated so sext vs. zext doesn't matter.
5050     Ops.push_back(DAG.getConstant(CInt.zextOrTrunc(32), dl, MVT::i32));
5051   }
5052   return DAG.getBuildVector(TruncVT, dl, Ops);
5053 }
5054 
5055 static bool isSignExtended(SDValue N, SelectionDAG &DAG) {
5056   return N.getOpcode() == ISD::SIGN_EXTEND ||
5057          N.getOpcode() == ISD::ANY_EXTEND ||
5058          isExtendedBUILD_VECTOR(N, DAG, true);
5059 }
5060 
5061 static bool isZeroExtended(SDValue N, SelectionDAG &DAG) {
5062   return N.getOpcode() == ISD::ZERO_EXTEND ||
5063          N.getOpcode() == ISD::ANY_EXTEND ||
5064          isExtendedBUILD_VECTOR(N, DAG, false);
5065 }
5066 
5067 static bool isAddSubSExt(SDValue N, SelectionDAG &DAG) {
5068   unsigned Opcode = N.getOpcode();
5069   if (Opcode == ISD::ADD || Opcode == ISD::SUB) {
5070     SDValue N0 = N.getOperand(0);
5071     SDValue N1 = N.getOperand(1);
5072     return N0->hasOneUse() && N1->hasOneUse() &&
5073       isSignExtended(N0, DAG) && isSignExtended(N1, DAG);
5074   }
5075   return false;
5076 }
5077 
5078 static bool isAddSubZExt(SDValue N, SelectionDAG &DAG) {
5079   unsigned Opcode = N.getOpcode();
5080   if (Opcode == ISD::ADD || Opcode == ISD::SUB) {
5081     SDValue N0 = N.getOperand(0);
5082     SDValue N1 = N.getOperand(1);
5083     return N0->hasOneUse() && N1->hasOneUse() &&
5084       isZeroExtended(N0, DAG) && isZeroExtended(N1, DAG);
5085   }
5086   return false;
5087 }
5088 
5089 SDValue AArch64TargetLowering::LowerGET_ROUNDING(SDValue Op,
5090                                                  SelectionDAG &DAG) const {
5091   // The rounding mode is in bits 23:22 of the FPCR.
5092   // The ARM rounding mode value to FLT_ROUNDS mapping is 0->1, 1->2, 2->3, 3->0.
5093   // The formula we use to implement this is (((FPCR + (1 << 22)) >> 22) & 3),
5094   // so that the shift and the AND get folded into a single bitfield extract.
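  //
  // As an illustration (showing only the RMode field and assuming the
  // architectural encoding RN=0, RP=1, RM=2, RZ=3):
  //   RMode = 0 (RN): ((0x000000 + 0x400000) >> 22) & 3 = 1  (to nearest)
  //   RMode = 1 (RP): ((0x400000 + 0x400000) >> 22) & 3 = 2  (toward +inf)
  //   RMode = 2 (RM): ((0x800000 + 0x400000) >> 22) & 3 = 3  (toward -inf)
  //   RMode = 3 (RZ): ((0xC00000 + 0x400000) >> 22) & 3 = 0  (toward zero)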
5095   SDLoc dl(Op);
5096 
5097   SDValue Chain = Op.getOperand(0);
5098   SDValue FPCR_64 = DAG.getNode(
5099       ISD::INTRINSIC_W_CHAIN, dl, {MVT::i64, MVT::Other},
5100       {Chain, DAG.getConstant(Intrinsic::aarch64_get_fpcr, dl, MVT::i64)});
5101   Chain = FPCR_64.getValue(1);
5102   SDValue FPCR_32 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, FPCR_64);
5103   SDValue FltRounds = DAG.getNode(ISD::ADD, dl, MVT::i32, FPCR_32,
5104                                   DAG.getConstant(1U << 22, dl, MVT::i32));
5105   SDValue RMODE = DAG.getNode(ISD::SRL, dl, MVT::i32, FltRounds,
5106                               DAG.getConstant(22, dl, MVT::i32));
5107   SDValue AND = DAG.getNode(ISD::AND, dl, MVT::i32, RMODE,
5108                             DAG.getConstant(3, dl, MVT::i32));
5109   return DAG.getMergeValues({AND, Chain}, dl);
5110 }
5111 
5112 SDValue AArch64TargetLowering::LowerSET_ROUNDING(SDValue Op,
5113                                                  SelectionDAG &DAG) const {
5114   SDLoc DL(Op);
5115   SDValue Chain = Op->getOperand(0);
5116   SDValue RMValue = Op->getOperand(1);
5117 
5118   // The rounding mode is in bits 23:22 of the FPCR.
5119   // The llvm.set.rounding argument value to the rounding mode in FPCR mapping
5120   // is 0->3, 1->0, 2->1, 3->2. The formula we use to implement this is
5121   // (((arg - 1) & 3) << 22).
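  //
  // As an illustration (assuming the llvm.set.rounding argument encoding
  // 0=toward zero, 1=to nearest, 2=toward +inf, 3=toward -inf):
  //   arg = 0: ((0 - 1) & 3) = 3 -> RZ
  //   arg = 1: ((1 - 1) & 3) = 0 -> RN
  //   arg = 2: ((2 - 1) & 3) = 1 -> RP
  //   arg = 3: ((3 - 1) & 3) = 2 -> RM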
5122   //
5123   // The argument of llvm.set.rounding must be within the segment [0, 3], so
5124   // NearestTiesToAway (4) is not handled here. It is the responsibility of the
5125   // code that generates llvm.set.rounding to ensure this condition.
5126 
5127   // Calculate new value of FPCR[23:22].
5128   RMValue = DAG.getNode(ISD::SUB, DL, MVT::i32, RMValue,
5129                         DAG.getConstant(1, DL, MVT::i32));
5130   RMValue = DAG.getNode(ISD::AND, DL, MVT::i32, RMValue,
5131                         DAG.getConstant(0x3, DL, MVT::i32));
5132   RMValue =
5133       DAG.getNode(ISD::SHL, DL, MVT::i32, RMValue,
5134                   DAG.getConstant(AArch64::RoundingBitsPos, DL, MVT::i32));
5135   RMValue = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, RMValue);
5136 
5137   // Get current value of FPCR.
5138   SDValue Ops[] = {
5139       Chain, DAG.getTargetConstant(Intrinsic::aarch64_get_fpcr, DL, MVT::i64)};
5140   SDValue FPCR =
5141       DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL, {MVT::i64, MVT::Other}, Ops);
5142   Chain = FPCR.getValue(1);
5143   FPCR = FPCR.getValue(0);
5144 
5145   // Put the new rounding mode into FPCR[23:22].
5146   const int RMMask = ~(AArch64::Rounding::rmMask << AArch64::RoundingBitsPos);
5147   FPCR = DAG.getNode(ISD::AND, DL, MVT::i64, FPCR,
5148                      DAG.getConstant(RMMask, DL, MVT::i64));
5149   FPCR = DAG.getNode(ISD::OR, DL, MVT::i64, FPCR, RMValue);
5150   SDValue Ops2[] = {
5151       Chain, DAG.getTargetConstant(Intrinsic::aarch64_set_fpcr, DL, MVT::i64),
5152       FPCR};
5153   return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Ops2);
5154 }
5155 
5156 SDValue AArch64TargetLowering::LowerGET_FPMODE(SDValue Op,
5157                                                SelectionDAG &DAG) const {
5158   SDLoc DL(Op);
5159   SDValue Chain = Op->getOperand(0);
5160 
5161   // Get current value of FPCR.
5162   SDValue Ops[] = {
5163       Chain, DAG.getTargetConstant(Intrinsic::aarch64_get_fpcr, DL, MVT::i64)};
5164   SDValue FPCR =
5165       DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL, {MVT::i64, MVT::Other}, Ops);
5166   Chain = FPCR.getValue(1);
5167   FPCR = FPCR.getValue(0);
5168 
5169   // Truncate FPCR to 32 bits.
5170   SDValue Result = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, FPCR);
5171 
5172   return DAG.getMergeValues({Result, Chain}, DL);
5173 }
5174 
5175 SDValue AArch64TargetLowering::LowerSET_FPMODE(SDValue Op,
5176                                                SelectionDAG &DAG) const {
5177   SDLoc DL(Op);
5178   SDValue Chain = Op->getOperand(0);
5179   SDValue Mode = Op->getOperand(1);
5180 
5181   // Extend the specified value to 64 bits.
5182   SDValue FPCR = DAG.getZExtOrTrunc(Mode, DL, MVT::i64);
5183 
5184   // Set new value of FPCR.
5185   SDValue Ops2[] = {
5186       Chain, DAG.getConstant(Intrinsic::aarch64_set_fpcr, DL, MVT::i64), FPCR};
5187   return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Ops2);
5188 }
5189 
5190 SDValue AArch64TargetLowering::LowerRESET_FPMODE(SDValue Op,
5191                                                  SelectionDAG &DAG) const {
5192   SDLoc DL(Op);
5193   SDValue Chain = Op->getOperand(0);
5194 
5195   // Get current value of FPCR.
5196   SDValue Ops[] = {
5197       Chain, DAG.getTargetConstant(Intrinsic::aarch64_get_fpcr, DL, MVT::i64)};
5198   SDValue FPCR =
5199       DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL, {MVT::i64, MVT::Other}, Ops);
5200   Chain = FPCR.getValue(1);
5201   FPCR = FPCR.getValue(0);
5202 
5203   // Clear bits that are not reserved.
5204   SDValue FPSCRMasked = DAG.getNode(
5205       ISD::AND, DL, MVT::i64, FPCR,
5206       DAG.getConstant(AArch64::ReservedFPControlBits, DL, MVT::i64));
5207 
5208   // Set new value of FPCR.
5209   SDValue Ops2[] = {Chain,
5210                     DAG.getConstant(Intrinsic::aarch64_set_fpcr, DL, MVT::i64),
5211                     FPSCRMasked};
5212   return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Ops2);
5213 }
5214 
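// Decide whether mul(N0, N1) can be lowered to a widening SMULL/UMULL.
// Returns the AArch64ISD opcode to use (or 0 if none applies), possibly
// rewriting N0/N1 in place. IsMLA is set when one operand is an add/sub of
// extends, in which case the caller splits the multiply into two widening
// multiplies feeding that add/sub. As a sketch: a mul of two v8i16 values
// that are each sign-extended from v8i8 can use SMULL on the v8i8 sources.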
5215 static unsigned selectUmullSmull(SDValue &N0, SDValue &N1, SelectionDAG &DAG,
5216                                  SDLoc DL, bool &IsMLA) {
5217   bool IsN0SExt = isSignExtended(N0, DAG);
5218   bool IsN1SExt = isSignExtended(N1, DAG);
5219   if (IsN0SExt && IsN1SExt)
5220     return AArch64ISD::SMULL;
5221 
5222   bool IsN0ZExt = isZeroExtended(N0, DAG);
5223   bool IsN1ZExt = isZeroExtended(N1, DAG);
5224 
5225   if (IsN0ZExt && IsN1ZExt)
5226     return AArch64ISD::UMULL;
5227 
5228   // Select SMULL if we can replace zext with sext.
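  // (If the sign bit of the zero-extended source is known to be zero, sign-
  // and zero-extension produce the same value, so SMULL remains correct.)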
5229   if (((IsN0SExt && IsN1ZExt) || (IsN0ZExt && IsN1SExt)) &&
5230       !isExtendedBUILD_VECTOR(N0, DAG, false) &&
5231       !isExtendedBUILD_VECTOR(N1, DAG, false)) {
5232     SDValue ZextOperand;
5233     if (IsN0ZExt)
5234       ZextOperand = N0.getOperand(0);
5235     else
5236       ZextOperand = N1.getOperand(0);
5237     if (DAG.SignBitIsZero(ZextOperand)) {
5238       SDValue NewSext =
5239           DAG.getSExtOrTrunc(ZextOperand, DL, N0.getValueType());
5240       if (IsN0ZExt)
5241         N0 = NewSext;
5242       else
5243         N1 = NewSext;
5244       return AArch64ISD::SMULL;
5245     }
5246   }
5247 
5248   // Select UMULL if we can replace the other operand with an extend.
5249   if (IsN0ZExt || IsN1ZExt) {
5250     EVT VT = N0.getValueType();
5251     APInt Mask = APInt::getHighBitsSet(VT.getScalarSizeInBits(),
5252                                        VT.getScalarSizeInBits() / 2);
5253     if (DAG.MaskedValueIsZero(IsN0ZExt ? N1 : N0, Mask))
5254       return AArch64ISD::UMULL;
5255   }
5256 
5257   if (!IsN1SExt && !IsN1ZExt)
5258     return 0;
5259 
5260   // Look for (s/zext A + s/zext B) * (s/zext C). We want to turn these
5261   // into (s/zext A * s/zext C) + (s/zext B * s/zext C)
5262   if (IsN1SExt && isAddSubSExt(N0, DAG)) {
5263     IsMLA = true;
5264     return AArch64ISD::SMULL;
5265   }
5266   if (IsN1ZExt && isAddSubZExt(N0, DAG)) {
5267     IsMLA = true;
5268     return AArch64ISD::UMULL;
5269   }
5270   if (IsN0ZExt && isAddSubZExt(N1, DAG)) {
5271     std::swap(N0, N1);
5272     IsMLA = true;
5273     return AArch64ISD::UMULL;
5274   }
5275   return 0;
5276 }
5277 
5278 SDValue AArch64TargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const {
5279   EVT VT = Op.getValueType();
5280 
5281   bool OverrideNEON = !Subtarget->isNeonAvailable();
5282   if (VT.isScalableVector() || useSVEForFixedLengthVectorVT(VT, OverrideNEON))
5283     return LowerToPredicatedOp(Op, DAG, AArch64ISD::MUL_PRED);
5284 
5285   // Multiplications are only custom-lowered for 128-bit and 64-bit vectors so
5286   // that VMULL can be detected.  Otherwise v2i64 multiplications are not legal.
5287   assert((VT.is128BitVector() || VT.is64BitVector()) && VT.isInteger() &&
5288          "unexpected type for custom-lowering ISD::MUL");
5289   SDValue N0 = Op.getOperand(0);
5290   SDValue N1 = Op.getOperand(1);
5291   bool isMLA = false;
5292   EVT OVT = VT;
5293   if (VT.is64BitVector()) {
5294     if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
5295         isNullConstant(N0.getOperand(1)) &&
5296         N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
5297         isNullConstant(N1.getOperand(1))) {
5298       N0 = N0.getOperand(0);
5299       N1 = N1.getOperand(0);
5300       VT = N0.getValueType();
5301     } else {
5302       if (VT == MVT::v1i64) {
5303         if (Subtarget->hasSVE())
5304           return LowerToPredicatedOp(Op, DAG, AArch64ISD::MUL_PRED);
5305         // Fall through to expand this.  It is not legal.
5306         return SDValue();
5307       } else
5308         // Other vector multiplications are legal.
5309         return Op;
5310     }
5311   }
5312 
5313   SDLoc DL(Op);
5314   unsigned NewOpc = selectUmullSmull(N0, N1, DAG, DL, isMLA);
5315 
5316   if (!NewOpc) {
5317     if (VT.getVectorElementType() == MVT::i64) {
5318       // If SVE is available then i64 vector multiplications can also be made
5319       // legal.
5320       if (Subtarget->hasSVE())
5321         return LowerToPredicatedOp(Op, DAG, AArch64ISD::MUL_PRED);
5322       // Fall through to expand this.  It is not legal.
5323       return SDValue();
5324     } else
5325       // Other vector multiplications are legal.
5326       return Op;
5327   }
5328 
5329   // Legalize to a S/UMULL instruction
5330   SDValue Op0;
5331   SDValue Op1 = skipExtensionForVectorMULL(N1, DAG);
5332   if (!isMLA) {
5333     Op0 = skipExtensionForVectorMULL(N0, DAG);
5334     assert(Op0.getValueType().is64BitVector() &&
5335            Op1.getValueType().is64BitVector() &&
5336            "unexpected types for extended operands to VMULL");
5337     return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, OVT,
5338                        DAG.getNode(NewOpc, DL, VT, Op0, Op1),
5339                        DAG.getConstant(0, DL, MVT::i64));
5340   }
5341   // Optimize (zext A + zext B) * C to (S/UMULL A, C) + (S/UMULL B, C) during
5342   // isel lowering to take advantage of no-stall back-to-back s/umul + s/umla.
5343   // This is true for CPUs with accumulate forwarding such as Cortex-A53/A57.
5344   SDValue N00 = skipExtensionForVectorMULL(N0.getOperand(0), DAG);
5345   SDValue N01 = skipExtensionForVectorMULL(N0.getOperand(1), DAG);
5346   EVT Op1VT = Op1.getValueType();
5347   return DAG.getNode(
5348       ISD::EXTRACT_SUBVECTOR, DL, OVT,
5349       DAG.getNode(N0.getOpcode(), DL, VT,
5350                   DAG.getNode(NewOpc, DL, VT,
5351                               DAG.getNode(ISD::BITCAST, DL, Op1VT, N00), Op1),
5352                   DAG.getNode(NewOpc, DL, VT,
5353                               DAG.getNode(ISD::BITCAST, DL, Op1VT, N01), Op1)),
5354       DAG.getConstant(0, DL, MVT::i64));
5355 }
5356 
5357 static inline SDValue getPTrue(SelectionDAG &DAG, SDLoc DL, EVT VT,
5358                                int Pattern) {
5359   if (VT == MVT::nxv1i1 && Pattern == AArch64SVEPredPattern::all)
5360     return DAG.getConstant(1, DL, MVT::nxv1i1);
5361   return DAG.getNode(AArch64ISD::PTRUE, DL, VT,
5362                      DAG.getTargetConstant(Pattern, DL, MVT::i32));
5363 }
5364 
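// Fold a WHILELO/WHILELT/WHILELS/WHILELE intrinsic with constant bounds into
// a PTRUE with a fixed vector-length pattern when the number of active lanes
// is known to fit within the minimum SVE vector length. As a sketch:
// whilelo(0, 4) producing nxv4i1 becomes ptrue VL4 whenever at least a
// 128-bit SVE vector is guaranteed.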
5365 static SDValue optimizeIncrementingWhile(SDValue Op, SelectionDAG &DAG,
5366                                          bool IsSigned, bool IsEqual) {
5367   if (!isa<ConstantSDNode>(Op.getOperand(1)) ||
5368       !isa<ConstantSDNode>(Op.getOperand(2)))
5369     return SDValue();
5370 
5371   SDLoc dl(Op);
5372   APInt X = Op.getConstantOperandAPInt(1);
5373   APInt Y = Op.getConstantOperandAPInt(2);
5374   bool Overflow;
5375   APInt NumActiveElems =
5376       IsSigned ? Y.ssub_ov(X, Overflow) : Y.usub_ov(X, Overflow);
5377 
5378   if (Overflow)
5379     return SDValue();
5380 
5381   if (IsEqual) {
5382     APInt One(NumActiveElems.getBitWidth(), 1, IsSigned);
5383     NumActiveElems = IsSigned ? NumActiveElems.sadd_ov(One, Overflow)
5384                               : NumActiveElems.uadd_ov(One, Overflow);
5385     if (Overflow)
5386       return SDValue();
5387   }
5388 
5389   std::optional<unsigned> PredPattern =
5390       getSVEPredPatternFromNumElements(NumActiveElems.getZExtValue());
5391   unsigned MinSVEVectorSize = std::max(
5392       DAG.getSubtarget<AArch64Subtarget>().getMinSVEVectorSizeInBits(), 128u);
5393   unsigned ElementSize = 128 / Op.getValueType().getVectorMinNumElements();
5394   if (PredPattern != std::nullopt &&
5395       NumActiveElems.getZExtValue() <= (MinSVEVectorSize / ElementSize))
5396     return getPTrue(DAG, dl, Op.getValueType(), *PredPattern);
5397 
5398   return SDValue();
5399 }
5400 
5401 // Returns a safe bitcast between two scalable vector predicates, where
5402 // any newly created lanes from a widening bitcast are defined as zero.
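// For illustration: when widening <vscale x 4 x i1> to <vscale x 16 x i1>,
// only every fourth lane of the result corresponds to an input lane; the
// remaining lanes did not exist in the source, so they are cleared by ANDing
// the reinterpreted value with an all-true <vscale x 4 x i1> mask that has
// been reinterpreted the same way.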
5403 static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) {
5404   SDLoc DL(Op);
5405   EVT InVT = Op.getValueType();
5406 
5407   assert(InVT.getVectorElementType() == MVT::i1 &&
5408          VT.getVectorElementType() == MVT::i1 &&
5409          "Expected a predicate-to-predicate bitcast");
5410   assert(VT.isScalableVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
5411          InVT.isScalableVector() &&
5412          DAG.getTargetLoweringInfo().isTypeLegal(InVT) &&
5413          "Only expect to cast between legal scalable predicate types!");
5414 
5415   // Return the operand if the cast isn't changing type,
5416   // e.g. <n x 16 x i1> -> <n x 16 x i1>
5417   if (InVT == VT)
5418     return Op;
5419 
5420   SDValue Reinterpret = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
5421 
5422   // We only have to zero the lanes if new lanes are being defined, e.g. when
5423   // casting from <vscale x 2 x i1> to <vscale x 16 x i1>. If this is not the
5424   // case (e.g. when casting from <vscale x 16 x i1> -> <vscale x 2 x i1>) then
5425   // we can return here.
5426   if (InVT.bitsGT(VT))
5427     return Reinterpret;
5428 
5429   // Check if the other lanes are already known to be zeroed by
5430   // construction.
5431   if (isZeroingInactiveLanes(Op))
5432     return Reinterpret;
5433 
5434   // Zero the newly introduced lanes.
5435   SDValue Mask = DAG.getConstant(1, DL, InVT);
5436   Mask = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Mask);
5437   return DAG.getNode(ISD::AND, DL, VT, Reinterpret, Mask);
5438 }
5439 
5440 SDValue AArch64TargetLowering::getRuntimePStateSM(SelectionDAG &DAG,
5441                                                   SDValue Chain, SDLoc DL,
5442                                                   EVT VT) const {
5443   SDValue Callee = DAG.getExternalSymbol("__arm_sme_state",
5444                                          getPointerTy(DAG.getDataLayout()));
5445   Type *Int64Ty = Type::getInt64Ty(*DAG.getContext());
5446   Type *RetTy = StructType::get(Int64Ty, Int64Ty);
5447   TargetLowering::CallLoweringInfo CLI(DAG);
5448   ArgListTy Args;
5449   CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
5450       CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2,
5451       RetTy, Callee, std::move(Args));
5452   std::pair<SDValue, SDValue> CallResult = LowerCallTo(CLI);
5453   SDValue Mask = DAG.getConstant(/*PSTATE.SM*/ 1, DL, MVT::i64);
5454   return DAG.getNode(ISD::AND, DL, MVT::i64, CallResult.first.getOperand(0),
5455                      Mask);
5456 }
5457 
5458 // Lower an SME LDR/STR ZA intrinsic
5459 // Case 1: If the vector number (vecnum) is an immediate in range, it gets
5460 // folded into the instruction
5461 //    ldr(%tileslice, %ptr, 11) -> ldr [%tileslice, 11], [%ptr, 11]
5462 // Case 2: If the vecnum is not an immediate, then it is used to modify the base
5463 // and tile slice registers
5464 //    ldr(%tileslice, %ptr, %vecnum)
5465 //    ->
5466 //    %svl = rdsvl
5467 //    %ptr2 = %ptr + %svl * %vecnum
5468 //    %tileslice2 = %tileslice + %vecnum
5469 //    ldr [%tileslice2, 0], [%ptr2, 0]
5470 // Case 3: If the vecnum is an immediate out of range, then the same is done as
5471 // case 2, but the base and slice registers are modified by the largest
5472 // multiple of 16 not exceeding the vecnum, and the remainder is folded into the
5473 // instruction. This means that successive loads and stores that are offset from
5474 // each other can share the same base and slice register updates.
5475 //    ldr(%tileslice, %ptr, 22)
5476 //    ldr(%tileslice, %ptr, 23)
5477 //    ->
5478 //    %svl = rdsvl
5479 //    %ptr2 = %ptr + %svl * 16
5480 //    %tileslice2 = %tileslice + 16
5481 //    ldr [%tileslice2, 6], [%ptr2, 6]
5482 //    ldr [%tileslice2, 7], [%ptr2, 7]
5483 // Case 4: If the vecnum is an add of an immediate, then the non-immediate
5484 // operand and the immediate can be folded into the instruction, like case 2.
5485 //    ldr(%tileslice, %ptr, %vecnum + 7)
5486 //    ldr(%tileslice, %ptr, %vecnum + 8)
5487 //    ->
5488 //    %svl = rdsvl
5489 //    %ptr2 = %ptr + %svl * %vecnum
5490 //    %tileslice2 = %tileslice + %vecnum
5491 //    ldr [%tileslice2, 7], [%ptr2, 7]
5492 //    ldr [%tileslice2, 8], [%ptr2, 8]
5493 // Case 5: The vecnum being an add of an immediate out of range is also handled,
5494 // in which case the same remainder logic as case 3 is used.
5495 SDValue LowerSMELdrStr(SDValue N, SelectionDAG &DAG, bool IsLoad) {
5496   SDLoc DL(N);
5497 
5498   SDValue TileSlice = N->getOperand(2);
5499   SDValue Base = N->getOperand(3);
5500   SDValue VecNum = N->getOperand(4);
5501   int32_t ConstAddend = 0;
5502   SDValue VarAddend = VecNum;
5503 
5504   // If the vnum is an add of an immediate, we can fold it into the instruction
5505   if (VecNum.getOpcode() == ISD::ADD &&
5506       isa<ConstantSDNode>(VecNum.getOperand(1))) {
5507     ConstAddend = cast<ConstantSDNode>(VecNum.getOperand(1))->getSExtValue();
5508     VarAddend = VecNum.getOperand(0);
5509   } else if (auto ImmNode = dyn_cast<ConstantSDNode>(VecNum)) {
5510     ConstAddend = ImmNode->getSExtValue();
5511     VarAddend = SDValue();
5512   }
5513 
5514   int32_t ImmAddend = ConstAddend % 16;
5515   if (int32_t C = (ConstAddend - ImmAddend)) {
5516     SDValue CVal = DAG.getTargetConstant(C, DL, MVT::i32);
5517     VarAddend = VarAddend
5518                     ? DAG.getNode(ISD::ADD, DL, MVT::i32, {VarAddend, CVal})
5519                     : CVal;
5520   }
5521 
5522   if (VarAddend) {
5523     // Get the vector length that will be multiplied by vnum
5524     auto SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
5525                            DAG.getConstant(1, DL, MVT::i32));
5526 
5527     // Multiply SVL and vnum then add it to the base
5528     SDValue Mul = DAG.getNode(
5529         ISD::MUL, DL, MVT::i64,
5530         {SVL, DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, VarAddend)});
5531     Base = DAG.getNode(ISD::ADD, DL, MVT::i64, {Base, Mul});
5532     // Just add vnum to the tileslice
5533     TileSlice = DAG.getNode(ISD::ADD, DL, MVT::i32, {TileSlice, VarAddend});
5534   }
5535 
5536   return DAG.getNode(IsLoad ? AArch64ISD::SME_ZA_LDR : AArch64ISD::SME_ZA_STR,
5537                      DL, MVT::Other,
5538                      {/*Chain=*/N.getOperand(0), TileSlice, Base,
5539                       DAG.getTargetConstant(ImmAddend, DL, MVT::i32)});
5540 }
5541 
5542 SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
5543                                                    SelectionDAG &DAG) const {
5544   unsigned IntNo = Op.getConstantOperandVal(1);
5545   SDLoc DL(Op);
5546   switch (IntNo) {
5547   default:
5548     return SDValue(); // Don't custom lower most intrinsics.
5549   case Intrinsic::aarch64_prefetch: {
5550     SDValue Chain = Op.getOperand(0);
5551     SDValue Addr = Op.getOperand(2);
5552 
5553     unsigned IsWrite = Op.getConstantOperandVal(3);
5554     unsigned Locality = Op.getConstantOperandVal(4);
5555     unsigned IsStream = Op.getConstantOperandVal(5);
5556     unsigned IsData = Op.getConstantOperandVal(6);
5557     unsigned PrfOp = (IsWrite << 4) |    // Load/Store bit
5558                      (!IsData << 3) |    // IsDataCache bit
5559                      (Locality << 1) |   // Cache level bits
5560                      (unsigned)IsStream; // Stream bit
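    //
    // As an illustration with hypothetical operand values: a streaming write
    // prefetch of data into the level-2 cache (IsWrite=1, IsData=1,
    // Locality=1, IsStream=1) packs to (1 << 4) | (0 << 3) | (1 << 1) | 1
    // = 0b10011, which is the encoding of PSTL2STRM.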
5561 
5562     return DAG.getNode(AArch64ISD::PREFETCH, DL, MVT::Other, Chain,
5563                        DAG.getTargetConstant(PrfOp, DL, MVT::i32), Addr);
5564   }
5565   case Intrinsic::aarch64_sme_str:
5566   case Intrinsic::aarch64_sme_ldr: {
5567     return LowerSMELdrStr(Op, DAG, IntNo == Intrinsic::aarch64_sme_ldr);
5568   }
5569   case Intrinsic::aarch64_sme_za_enable:
5570     return DAG.getNode(
5571         AArch64ISD::SMSTART, DL, MVT::Other,
5572         Op->getOperand(0), // Chain
5573         DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
5574         DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
5575   case Intrinsic::aarch64_sme_za_disable:
5576     return DAG.getNode(
5577         AArch64ISD::SMSTOP, DL, MVT::Other,
5578         Op->getOperand(0), // Chain
5579         DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
5580         DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
5581   }
5582 }
5583 
5584 SDValue AArch64TargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
5585                                                       SelectionDAG &DAG) const {
5586   unsigned IntNo = Op.getConstantOperandVal(1);
5587   SDLoc DL(Op);
5588   switch (IntNo) {
5589   default:
5590     return SDValue(); // Don't custom lower most intrinsics.
5591   case Intrinsic::aarch64_mops_memset_tag: {
5592     auto Node = cast<MemIntrinsicSDNode>(Op.getNode());
5593     SDValue Chain = Node->getChain();
5594     SDValue Dst = Op.getOperand(2);
5595     SDValue Val = Op.getOperand(3);
5596     Val = DAG.getAnyExtOrTrunc(Val, DL, MVT::i64);
5597     SDValue Size = Op.getOperand(4);
5598     auto Alignment = Node->getMemOperand()->getAlign();
5599     bool IsVol = Node->isVolatile();
5600     auto DstPtrInfo = Node->getPointerInfo();
5601 
5602     const auto &SDI =
5603         static_cast<const AArch64SelectionDAGInfo &>(DAG.getSelectionDAGInfo());
5604     SDValue MS =
5605         SDI.EmitMOPS(AArch64ISD::MOPS_MEMSET_TAGGING, DAG, DL, Chain, Dst, Val,
5606                      Size, Alignment, IsVol, DstPtrInfo, MachinePointerInfo{});
5607 
5608     // MOPS_MEMSET_TAGGING has 3 results (DstWb, SizeWb, Chain) whereas the
5609     // intrinsic has 2. So hide SizeWb using MERGE_VALUES. Otherwise
5610     // LowerOperationWrapper will complain that the number of results has
5611     // changed.
5612     return DAG.getMergeValues({MS.getValue(0), MS.getValue(2)}, DL);
5613   }
5614   }
5615 }
5616 
5617 SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
5618                                                      SelectionDAG &DAG) const {
5619   unsigned IntNo = Op.getConstantOperandVal(0);
5620   SDLoc dl(Op);
5621   switch (IntNo) {
5622   default: return SDValue();    // Don't custom lower most intrinsics.
5623   case Intrinsic::thread_pointer: {
5624     EVT PtrVT = getPointerTy(DAG.getDataLayout());
5625     return DAG.getNode(AArch64ISD::THREAD_POINTER, dl, PtrVT);
5626   }
5627   case Intrinsic::aarch64_neon_abs: {
5628     EVT Ty = Op.getValueType();
5629     if (Ty == MVT::i64) {
5630       SDValue Result = DAG.getNode(ISD::BITCAST, dl, MVT::v1i64,
5631                                    Op.getOperand(1));
5632       Result = DAG.getNode(ISD::ABS, dl, MVT::v1i64, Result);
5633       return DAG.getNode(ISD::BITCAST, dl, MVT::i64, Result);
5634     } else if (Ty.isVector() && Ty.isInteger() && isTypeLegal(Ty)) {
5635       return DAG.getNode(ISD::ABS, dl, Ty, Op.getOperand(1));
5636     } else {
5637       report_fatal_error("Unexpected type for AArch64 NEON intrinsic");
5638     }
5639   }
5640   case Intrinsic::aarch64_neon_pmull64: {
5641     SDValue LHS = Op.getOperand(1);
5642     SDValue RHS = Op.getOperand(2);
5643 
5644     std::optional<uint64_t> LHSLane =
5645         getConstantLaneNumOfExtractHalfOperand(LHS);
5646     std::optional<uint64_t> RHSLane =
5647         getConstantLaneNumOfExtractHalfOperand(RHS);
5648 
5649     assert((!LHSLane || *LHSLane < 2) && "Expect lane to be None or 0 or 1");
5650     assert((!RHSLane || *RHSLane < 2) && "Expect lane to be None or 0 or 1");
5651 
5652     // 'aarch64_neon_pmull64' takes i64 parameters, while pmull/pmull2
5653     // instructions execute on SIMD registers. So canonicalize i64 to v1i64,
5654     // which ISel recognizes better. For example, generate a ldr into d*
5655     // registers as opposed to a GPR load followed by a fmov.
5656     auto TryVectorizeOperand = [](SDValue N, std::optional<uint64_t> NLane,
5657                                   std::optional<uint64_t> OtherLane,
5658                                   const SDLoc &dl,
5659                                   SelectionDAG &DAG) -> SDValue {
5660       // If the operand is a higher half itself, rewrite it to
5661       // extract_high_v2i64; this way aarch64_neon_pmull64 could
5662       // re-use the dag-combiner function with aarch64_neon_{pmull,smull,umull}.
5663       if (NLane && *NLane == 1)
5664         return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v1i64,
5665                            N.getOperand(0), DAG.getConstant(1, dl, MVT::i64));
5666 
5667       // Operand N is not a higher half but the other operand is.
5668       if (OtherLane && *OtherLane == 1) {
5669         // If this operand is a lower half, rewrite it to
5670         // extract_high_v2i64(duplane(<2 x Ty>, 0)). This saves a roundtrip to
5671         // align lanes of two operands. A roundtrip sequence (to move from lane
5672         // 1 to lane 0) is like this:
5673         //   mov x8, v0.d[1]
5674         //   fmov d0, x8
5675         if (NLane && *NLane == 0)
5676           return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v1i64,
5677                              DAG.getNode(AArch64ISD::DUPLANE64, dl, MVT::v2i64,
5678                                          N.getOperand(0),
5679                                          DAG.getConstant(0, dl, MVT::i64)),
5680                              DAG.getConstant(1, dl, MVT::i64));
5681 
5682         // Otherwise just dup from main to all lanes.
5683         return DAG.getNode(AArch64ISD::DUP, dl, MVT::v1i64, N);
5684       }
5685 
5686       // Neither operand is an extract of higher half, so codegen may just use
5687       // the non-high version of PMULL instruction. Use v1i64 to represent i64.
5688       assert(N.getValueType() == MVT::i64 &&
5689              "Intrinsic aarch64_neon_pmull64 requires i64 parameters");
5690       return DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v1i64, N);
5691     };
5692 
5693     LHS = TryVectorizeOperand(LHS, LHSLane, RHSLane, dl, DAG);
5694     RHS = TryVectorizeOperand(RHS, RHSLane, LHSLane, dl, DAG);
5695 
5696     return DAG.getNode(AArch64ISD::PMULL, dl, Op.getValueType(), LHS, RHS);
5697   }
5698   case Intrinsic::aarch64_neon_smax:
5699     return DAG.getNode(ISD::SMAX, dl, Op.getValueType(),
5700                        Op.getOperand(1), Op.getOperand(2));
5701   case Intrinsic::aarch64_neon_umax:
5702     return DAG.getNode(ISD::UMAX, dl, Op.getValueType(),
5703                        Op.getOperand(1), Op.getOperand(2));
5704   case Intrinsic::aarch64_neon_smin:
5705     return DAG.getNode(ISD::SMIN, dl, Op.getValueType(),
5706                        Op.getOperand(1), Op.getOperand(2));
5707   case Intrinsic::aarch64_neon_umin:
5708     return DAG.getNode(ISD::UMIN, dl, Op.getValueType(),
5709                        Op.getOperand(1), Op.getOperand(2));
5710   case Intrinsic::aarch64_neon_scalar_sqxtn:
5711   case Intrinsic::aarch64_neon_scalar_sqxtun:
5712   case Intrinsic::aarch64_neon_scalar_uqxtn: {
5713     assert(Op.getValueType() == MVT::i32 || Op.getValueType() == MVT::f32);
5714     if (Op.getValueType() == MVT::i32)
5715       return DAG.getNode(ISD::BITCAST, dl, MVT::i32,
5716                          DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, MVT::f32,
5717                                      Op.getOperand(0),
5718                                      DAG.getNode(ISD::BITCAST, dl, MVT::f64,
5719                                                  Op.getOperand(1))));
5720     return SDValue();
5721   }
5722   case Intrinsic::aarch64_sve_whilelo:
5723     return optimizeIncrementingWhile(Op, DAG, /*IsSigned=*/false,
5724                                      /*IsEqual=*/false);
5725   case Intrinsic::aarch64_sve_whilelt:
5726     return optimizeIncrementingWhile(Op, DAG, /*IsSigned=*/true,
5727                                      /*IsEqual=*/false);
5728   case Intrinsic::aarch64_sve_whilels:
5729     return optimizeIncrementingWhile(Op, DAG, /*IsSigned=*/false,
5730                                      /*IsEqual=*/true);
5731   case Intrinsic::aarch64_sve_whilele:
5732     return optimizeIncrementingWhile(Op, DAG, /*IsSigned=*/true,
5733                                      /*IsEqual=*/true);
5734   case Intrinsic::aarch64_sve_sunpkhi:
5735     return DAG.getNode(AArch64ISD::SUNPKHI, dl, Op.getValueType(),
5736                        Op.getOperand(1));
5737   case Intrinsic::aarch64_sve_sunpklo:
5738     return DAG.getNode(AArch64ISD::SUNPKLO, dl, Op.getValueType(),
5739                        Op.getOperand(1));
5740   case Intrinsic::aarch64_sve_uunpkhi:
5741     return DAG.getNode(AArch64ISD::UUNPKHI, dl, Op.getValueType(),
5742                        Op.getOperand(1));
5743   case Intrinsic::aarch64_sve_uunpklo:
5744     return DAG.getNode(AArch64ISD::UUNPKLO, dl, Op.getValueType(),
5745                        Op.getOperand(1));
5746   case Intrinsic::aarch64_sve_clasta_n:
5747     return DAG.getNode(AArch64ISD::CLASTA_N, dl, Op.getValueType(),
5748                        Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
5749   case Intrinsic::aarch64_sve_clastb_n:
5750     return DAG.getNode(AArch64ISD::CLASTB_N, dl, Op.getValueType(),
5751                        Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
5752   case Intrinsic::aarch64_sve_lasta:
5753     return DAG.getNode(AArch64ISD::LASTA, dl, Op.getValueType(),
5754                        Op.getOperand(1), Op.getOperand(2));
5755   case Intrinsic::aarch64_sve_lastb:
5756     return DAG.getNode(AArch64ISD::LASTB, dl, Op.getValueType(),
5757                        Op.getOperand(1), Op.getOperand(2));
5758   case Intrinsic::aarch64_sve_rev:
5759     return DAG.getNode(ISD::VECTOR_REVERSE, dl, Op.getValueType(),
5760                        Op.getOperand(1));
5761   case Intrinsic::aarch64_sve_tbl:
5762     return DAG.getNode(AArch64ISD::TBL, dl, Op.getValueType(),
5763                        Op.getOperand(1), Op.getOperand(2));
5764   case Intrinsic::aarch64_sve_trn1:
5765     return DAG.getNode(AArch64ISD::TRN1, dl, Op.getValueType(),
5766                        Op.getOperand(1), Op.getOperand(2));
5767   case Intrinsic::aarch64_sve_trn2:
5768     return DAG.getNode(AArch64ISD::TRN2, dl, Op.getValueType(),
5769                        Op.getOperand(1), Op.getOperand(2));
5770   case Intrinsic::aarch64_sve_uzp1:
5771     return DAG.getNode(AArch64ISD::UZP1, dl, Op.getValueType(),
5772                        Op.getOperand(1), Op.getOperand(2));
5773   case Intrinsic::aarch64_sve_uzp2:
5774     return DAG.getNode(AArch64ISD::UZP2, dl, Op.getValueType(),
5775                        Op.getOperand(1), Op.getOperand(2));
5776   case Intrinsic::aarch64_sve_zip1:
5777     return DAG.getNode(AArch64ISD::ZIP1, dl, Op.getValueType(),
5778                        Op.getOperand(1), Op.getOperand(2));
5779   case Intrinsic::aarch64_sve_zip2:
5780     return DAG.getNode(AArch64ISD::ZIP2, dl, Op.getValueType(),
5781                        Op.getOperand(1), Op.getOperand(2));
5782   case Intrinsic::aarch64_sve_splice:
5783     return DAG.getNode(AArch64ISD::SPLICE, dl, Op.getValueType(),
5784                        Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
5785   case Intrinsic::aarch64_sve_ptrue:
5786     return getPTrue(DAG, dl, Op.getValueType(), Op.getConstantOperandVal(1));
5787   case Intrinsic::aarch64_sve_clz:
5788     return DAG.getNode(AArch64ISD::CTLZ_MERGE_PASSTHRU, dl, Op.getValueType(),
5789                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5790   case Intrinsic::aarch64_sme_cntsb:
5791     return DAG.getNode(AArch64ISD::RDSVL, dl, Op.getValueType(),
5792                        DAG.getConstant(1, dl, MVT::i32));
5793   case Intrinsic::aarch64_sme_cntsh: {
5794     SDValue One = DAG.getConstant(1, dl, MVT::i32);
5795     SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, dl, Op.getValueType(), One);
5796     return DAG.getNode(ISD::SRL, dl, Op.getValueType(), Bytes, One);
5797   }
5798   case Intrinsic::aarch64_sme_cntsw: {
5799     SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, dl, Op.getValueType(),
5800                                 DAG.getConstant(1, dl, MVT::i32));
5801     return DAG.getNode(ISD::SRL, dl, Op.getValueType(), Bytes,
5802                        DAG.getConstant(2, dl, MVT::i32));
5803   }
5804   case Intrinsic::aarch64_sme_cntsd: {
5805     SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, dl, Op.getValueType(),
5806                                 DAG.getConstant(1, dl, MVT::i32));
5807     return DAG.getNode(ISD::SRL, dl, Op.getValueType(), Bytes,
5808                        DAG.getConstant(3, dl, MVT::i32));
5809   }
5810   case Intrinsic::aarch64_sve_cnt: {
5811     SDValue Data = Op.getOperand(3);
5812     // CTPOP only supports integer operands.
5813     if (Data.getValueType().isFloatingPoint())
5814       Data = DAG.getNode(ISD::BITCAST, dl, Op.getValueType(), Data);
5815     return DAG.getNode(AArch64ISD::CTPOP_MERGE_PASSTHRU, dl, Op.getValueType(),
5816                        Op.getOperand(2), Data, Op.getOperand(1));
5817   }
5818   case Intrinsic::aarch64_sve_dupq_lane:
5819     return LowerDUPQLane(Op, DAG);
5820   case Intrinsic::aarch64_sve_convert_from_svbool:
5821     if (Op.getValueType() == MVT::aarch64svcount)
5822       return DAG.getNode(ISD::BITCAST, dl, Op.getValueType(), Op.getOperand(1));
5823     return getSVEPredicateBitCast(Op.getValueType(), Op.getOperand(1), DAG);
5824   case Intrinsic::aarch64_sve_convert_to_svbool:
5825     if (Op.getOperand(1).getValueType() == MVT::aarch64svcount)
5826       return DAG.getNode(ISD::BITCAST, dl, MVT::nxv16i1, Op.getOperand(1));
5827     return getSVEPredicateBitCast(MVT::nxv16i1, Op.getOperand(1), DAG);
5828   case Intrinsic::aarch64_sve_fneg:
5829     return DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, dl, Op.getValueType(),
5830                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5831   case Intrinsic::aarch64_sve_frintp:
5832     return DAG.getNode(AArch64ISD::FCEIL_MERGE_PASSTHRU, dl, Op.getValueType(),
5833                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5834   case Intrinsic::aarch64_sve_frintm:
5835     return DAG.getNode(AArch64ISD::FFLOOR_MERGE_PASSTHRU, dl, Op.getValueType(),
5836                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5837   case Intrinsic::aarch64_sve_frinti:
5838     return DAG.getNode(AArch64ISD::FNEARBYINT_MERGE_PASSTHRU, dl, Op.getValueType(),
5839                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5840   case Intrinsic::aarch64_sve_frintx:
5841     return DAG.getNode(AArch64ISD::FRINT_MERGE_PASSTHRU, dl, Op.getValueType(),
5842                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5843   case Intrinsic::aarch64_sve_frinta:
5844     return DAG.getNode(AArch64ISD::FROUND_MERGE_PASSTHRU, dl, Op.getValueType(),
5845                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5846   case Intrinsic::aarch64_sve_frintn:
5847     return DAG.getNode(AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU, dl, Op.getValueType(),
5848                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5849   case Intrinsic::aarch64_sve_frintz:
5850     return DAG.getNode(AArch64ISD::FTRUNC_MERGE_PASSTHRU, dl, Op.getValueType(),
5851                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5852   case Intrinsic::aarch64_sve_ucvtf:
5853     return DAG.getNode(AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU, dl,
5854                        Op.getValueType(), Op.getOperand(2), Op.getOperand(3),
5855                        Op.getOperand(1));
5856   case Intrinsic::aarch64_sve_scvtf:
5857     return DAG.getNode(AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU, dl,
5858                        Op.getValueType(), Op.getOperand(2), Op.getOperand(3),
5859                        Op.getOperand(1));
5860   case Intrinsic::aarch64_sve_fcvtzu:
5861     return DAG.getNode(AArch64ISD::FCVTZU_MERGE_PASSTHRU, dl,
5862                        Op.getValueType(), Op.getOperand(2), Op.getOperand(3),
5863                        Op.getOperand(1));
5864   case Intrinsic::aarch64_sve_fcvtzs:
5865     return DAG.getNode(AArch64ISD::FCVTZS_MERGE_PASSTHRU, dl,
5866                        Op.getValueType(), Op.getOperand(2), Op.getOperand(3),
5867                        Op.getOperand(1));
5868   case Intrinsic::aarch64_sve_fsqrt:
5869     return DAG.getNode(AArch64ISD::FSQRT_MERGE_PASSTHRU, dl, Op.getValueType(),
5870                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5871   case Intrinsic::aarch64_sve_frecpx:
5872     return DAG.getNode(AArch64ISD::FRECPX_MERGE_PASSTHRU, dl, Op.getValueType(),
5873                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5874   case Intrinsic::aarch64_sve_frecpe_x:
5875     return DAG.getNode(AArch64ISD::FRECPE, dl, Op.getValueType(),
5876                        Op.getOperand(1));
5877   case Intrinsic::aarch64_sve_frecps_x:
5878     return DAG.getNode(AArch64ISD::FRECPS, dl, Op.getValueType(),
5879                        Op.getOperand(1), Op.getOperand(2));
5880   case Intrinsic::aarch64_sve_frsqrte_x:
5881     return DAG.getNode(AArch64ISD::FRSQRTE, dl, Op.getValueType(),
5882                        Op.getOperand(1));
5883   case Intrinsic::aarch64_sve_frsqrts_x:
5884     return DAG.getNode(AArch64ISD::FRSQRTS, dl, Op.getValueType(),
5885                        Op.getOperand(1), Op.getOperand(2));
5886   case Intrinsic::aarch64_sve_fabs:
5887     return DAG.getNode(AArch64ISD::FABS_MERGE_PASSTHRU, dl, Op.getValueType(),
5888                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5889   case Intrinsic::aarch64_sve_abs:
5890     return DAG.getNode(AArch64ISD::ABS_MERGE_PASSTHRU, dl, Op.getValueType(),
5891                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5892   case Intrinsic::aarch64_sve_neg:
5893     return DAG.getNode(AArch64ISD::NEG_MERGE_PASSTHRU, dl, Op.getValueType(),
5894                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5895   case Intrinsic::aarch64_sve_insr: {
5896     SDValue Scalar = Op.getOperand(2);
5897     EVT ScalarTy = Scalar.getValueType();
5898     if ((ScalarTy == MVT::i8) || (ScalarTy == MVT::i16))
5899       Scalar = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, Scalar);
5900 
5901     return DAG.getNode(AArch64ISD::INSR, dl, Op.getValueType(),
5902                        Op.getOperand(1), Scalar);
5903   }
5904   case Intrinsic::aarch64_sve_rbit:
5905     return DAG.getNode(AArch64ISD::BITREVERSE_MERGE_PASSTHRU, dl,
5906                        Op.getValueType(), Op.getOperand(2), Op.getOperand(3),
5907                        Op.getOperand(1));
5908   case Intrinsic::aarch64_sve_revb:
5909     return DAG.getNode(AArch64ISD::BSWAP_MERGE_PASSTHRU, dl, Op.getValueType(),
5910                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5911   case Intrinsic::aarch64_sve_revh:
5912     return DAG.getNode(AArch64ISD::REVH_MERGE_PASSTHRU, dl, Op.getValueType(),
5913                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5914   case Intrinsic::aarch64_sve_revw:
5915     return DAG.getNode(AArch64ISD::REVW_MERGE_PASSTHRU, dl, Op.getValueType(),
5916                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5917   case Intrinsic::aarch64_sve_revd:
5918     return DAG.getNode(AArch64ISD::REVD_MERGE_PASSTHRU, dl, Op.getValueType(),
5919                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5920   case Intrinsic::aarch64_sve_sxtb:
5921     return DAG.getNode(
5922         AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
5923         Op.getOperand(2), Op.getOperand(3),
5924         DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i8)),
5925         Op.getOperand(1));
5926   case Intrinsic::aarch64_sve_sxth:
5927     return DAG.getNode(
5928         AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
5929         Op.getOperand(2), Op.getOperand(3),
5930         DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i16)),
5931         Op.getOperand(1));
5932   case Intrinsic::aarch64_sve_sxtw:
5933     return DAG.getNode(
5934         AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
5935         Op.getOperand(2), Op.getOperand(3),
5936         DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i32)),
5937         Op.getOperand(1));
5938   case Intrinsic::aarch64_sve_uxtb:
5939     return DAG.getNode(
5940         AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
5941         Op.getOperand(2), Op.getOperand(3),
5942         DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i8)),
5943         Op.getOperand(1));
5944   case Intrinsic::aarch64_sve_uxth:
5945     return DAG.getNode(
5946         AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
5947         Op.getOperand(2), Op.getOperand(3),
5948         DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i16)),
5949         Op.getOperand(1));
5950   case Intrinsic::aarch64_sve_uxtw:
5951     return DAG.getNode(
5952         AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
5953         Op.getOperand(2), Op.getOperand(3),
5954         DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i32)),
5955         Op.getOperand(1));
5956   case Intrinsic::localaddress: {
5957     const auto &MF = DAG.getMachineFunction();
5958     const auto *RegInfo = Subtarget->getRegisterInfo();
5959     unsigned Reg = RegInfo->getLocalAddressRegister(MF);
5960     return DAG.getCopyFromReg(DAG.getEntryNode(), dl, Reg,
5961                               Op.getSimpleValueType());
5962   }
5963 
5964   case Intrinsic::eh_recoverfp: {
5965     // FIXME: This needs to be implemented to correctly handle highly aligned
5966     // stack objects. For now we simply return the incoming FP. Refer D53541
5967     // for more details.
5968     SDValue FnOp = Op.getOperand(1);
5969     SDValue IncomingFPOp = Op.getOperand(2);
5970     GlobalAddressSDNode *GSD = dyn_cast<GlobalAddressSDNode>(FnOp);
5971     auto *Fn = dyn_cast_or_null<Function>(GSD ? GSD->getGlobal() : nullptr);
5972     if (!Fn)
5973       report_fatal_error(
5974           "llvm.eh.recoverfp must take a function as the first argument");
5975     return IncomingFPOp;
5976   }
5977 
5978   case Intrinsic::aarch64_neon_vsri:
5979   case Intrinsic::aarch64_neon_vsli:
5980   case Intrinsic::aarch64_sve_sri:
5981   case Intrinsic::aarch64_sve_sli: {
5982     EVT Ty = Op.getValueType();
5983 
5984     if (!Ty.isVector())
5985       report_fatal_error("Unexpected type for aarch64_neon_vsli");
5986 
5987     assert(Op.getConstantOperandVal(3) <= Ty.getScalarSizeInBits());
5988 
5989     bool IsShiftRight = IntNo == Intrinsic::aarch64_neon_vsri ||
5990                         IntNo == Intrinsic::aarch64_sve_sri;
5991     unsigned Opcode = IsShiftRight ? AArch64ISD::VSRI : AArch64ISD::VSLI;
5992     return DAG.getNode(Opcode, dl, Ty, Op.getOperand(1), Op.getOperand(2),
5993                        Op.getOperand(3));
5994   }
5995 
5996   case Intrinsic::aarch64_neon_srhadd:
5997   case Intrinsic::aarch64_neon_urhadd:
5998   case Intrinsic::aarch64_neon_shadd:
5999   case Intrinsic::aarch64_neon_uhadd: {
6000     bool IsSignedAdd = (IntNo == Intrinsic::aarch64_neon_srhadd ||
6001                         IntNo == Intrinsic::aarch64_neon_shadd);
6002     bool IsRoundingAdd = (IntNo == Intrinsic::aarch64_neon_srhadd ||
6003                           IntNo == Intrinsic::aarch64_neon_urhadd);
6004     unsigned Opcode = IsSignedAdd
6005                           ? (IsRoundingAdd ? ISD::AVGCEILS : ISD::AVGFLOORS)
6006                           : (IsRoundingAdd ? ISD::AVGCEILU : ISD::AVGFLOORU);
6007     return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
6008                        Op.getOperand(2));
6009   }
6010   case Intrinsic::aarch64_neon_saddlp:
6011   case Intrinsic::aarch64_neon_uaddlp: {
6012     unsigned Opcode = IntNo == Intrinsic::aarch64_neon_uaddlp
6013                           ? AArch64ISD::UADDLP
6014                           : AArch64ISD::SADDLP;
6015     return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1));
6016   }
6017   case Intrinsic::aarch64_neon_sdot:
6018   case Intrinsic::aarch64_neon_udot:
6019   case Intrinsic::aarch64_sve_sdot:
6020   case Intrinsic::aarch64_sve_udot: {
6021     unsigned Opcode = (IntNo == Intrinsic::aarch64_neon_udot ||
6022                        IntNo == Intrinsic::aarch64_sve_udot)
6023                           ? AArch64ISD::UDOT
6024                           : AArch64ISD::SDOT;
6025     return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
6026                        Op.getOperand(2), Op.getOperand(3));
6027   }
6028   case Intrinsic::get_active_lane_mask: {
6029     SDValue ID =
6030         DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, dl, MVT::i64);
6031 
6032     EVT VT = Op.getValueType();
6033     if (VT.isScalableVector())
6034       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT, ID, Op.getOperand(1),
6035                          Op.getOperand(2));
6036 
6037     // We can use the SVE whilelo instruction to lower this intrinsic by
6038     // creating the appropriate sequence of scalable vector operations and
6039     // then extracting a fixed-width subvector from the scalable vector.
6040 
6041     EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
6042     EVT WhileVT = ContainerVT.changeElementType(MVT::i1);
6043 
6044     SDValue Mask = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, WhileVT, ID,
6045                                Op.getOperand(1), Op.getOperand(2));
6046     SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, dl, ContainerVT, Mask);
6047     return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, MaskAsInt,
6048                        DAG.getVectorIdxConstant(0, dl));
6049   }
6050   case Intrinsic::aarch64_neon_uaddlv: {
6051     EVT OpVT = Op.getOperand(1).getValueType();
6052     EVT ResVT = Op.getValueType();
6053     if (ResVT == MVT::i32 && (OpVT == MVT::v8i8 || OpVT == MVT::v16i8 ||
6054                               OpVT == MVT::v8i16 || OpVT == MVT::v4i16)) {
6055       // In order to avoid insert_subvector, use v4i32 rather than v2i32.
6056       SDValue UADDLV =
6057           DAG.getNode(AArch64ISD::UADDLV, dl, MVT::v4i32, Op.getOperand(1));
6058       SDValue EXTRACT_VEC_ELT =
6059           DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, UADDLV,
6060                       DAG.getConstant(0, dl, MVT::i64));
6061       return EXTRACT_VEC_ELT;
6062     }
6063     return SDValue();
6064   }
6065   case Intrinsic::experimental_cttz_elts: {
6066     SDValue CttzOp = Op.getOperand(1);
6067     EVT VT = CttzOp.getValueType();
6068     assert(VT.getVectorElementType() == MVT::i1 && "Expected MVT::i1");
6069 
6070     if (VT.isFixedLengthVector()) {
6071       // We can use SVE instructions to lower this intrinsic by first creating
6072       // an SVE predicate register mask from the fixed-width vector.
6073       EVT NewVT = getTypeToTransformTo(*DAG.getContext(), VT);
6074       SDValue Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, NewVT, CttzOp);
6075       CttzOp = convertFixedMaskToScalableVector(Mask, DAG);
6076     }
6077 
6078     SDValue NewCttzElts =
6079         DAG.getNode(AArch64ISD::CTTZ_ELTS, dl, MVT::i64, CttzOp);
6080     return DAG.getZExtOrTrunc(NewCttzElts, dl, Op.getValueType());
6081   }
6082   }
6083 }
6084 
6085 bool AArch64TargetLowering::shouldExtendGSIndex(EVT VT, EVT &EltTy) const {
6086   if (VT.getVectorElementType() == MVT::i8 ||
6087       VT.getVectorElementType() == MVT::i16) {
6088     EltTy = MVT::i32;
6089     return true;
6090   }
6091   return false;
6092 }
6093 
6094 bool AArch64TargetLowering::shouldRemoveExtendFromGSIndex(SDValue Extend,
6095                                                           EVT DataVT) const {
6096   const EVT IndexVT = Extend.getOperand(0).getValueType();
6097   // SVE only supports implicit extension of 32-bit indices.
6098   if (!Subtarget->hasSVE() || IndexVT.getVectorElementType() != MVT::i32)
6099     return false;
6100 
6101   // Indices cannot be smaller than the main data type.
6102   if (IndexVT.getScalarSizeInBits() < DataVT.getScalarSizeInBits())
6103     return false;
6104 
6105   // Scalable vectors with "vscale * 2" or fewer elements sit within a 64-bit
6106   // element container type, which would violate the previous clause.
6107   return DataVT.isFixedLengthVector() || DataVT.getVectorMinNumElements() > 2;
6108 }
6109 
6110 bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const {
6111   EVT ExtVT = ExtVal.getValueType();
6112   if (!ExtVT.isScalableVector() && !Subtarget->useSVEForFixedLengthVectors())
6113     return false;
6114 
6115   // It may be worth creating extending masked loads if there are multiple
6116   // masked loads using the same predicate. That way we'll end up creating
6117   // extending masked loads that may then get split by the legaliser. This
6118   // results in just one set of predicate unpacks at the start, instead of
6119   // multiple sets of vector unpacks after each load.
6120   if (auto *Ld = dyn_cast<MaskedLoadSDNode>(ExtVal->getOperand(0))) {
6121     if (!isLoadExtLegalOrCustom(ISD::ZEXTLOAD, ExtVT, Ld->getValueType(0))) {
6122       // Disable extending masked loads for fixed-width vectors for now, since
6123       // the resulting code quality doesn't look great.
6124       if (!ExtVT.isScalableVector())
6125         return false;
6126 
6127       unsigned NumExtMaskedLoads = 0;
6128       for (auto *U : Ld->getMask()->uses())
6129         if (isa<MaskedLoadSDNode>(U))
6130           NumExtMaskedLoads++;
6131 
6132       if (NumExtMaskedLoads <= 1)
6133         return false;
6134     }
6135   }
6136 
6137   return true;
6138 }
6139 
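// Maps the (scaled, signed, extend) addressing-mode triple of a gather to the
// corresponding GLD1 opcode. Note the sign of the index only matters when it
// needs extending (SXTW vs. UXTW); unextended indices use the same opcode
// regardless of sign.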
6140 unsigned getGatherVecOpcode(bool IsScaled, bool IsSigned, bool NeedsExtend) {
6141   std::map<std::tuple<bool, bool, bool>, unsigned> AddrModes = {
6142       {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ false),
6143        AArch64ISD::GLD1_MERGE_ZERO},
6144       {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ true),
6145        AArch64ISD::GLD1_UXTW_MERGE_ZERO},
6146       {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ false),
6147        AArch64ISD::GLD1_MERGE_ZERO},
6148       {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ true),
6149        AArch64ISD::GLD1_SXTW_MERGE_ZERO},
6150       {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ false),
6151        AArch64ISD::GLD1_SCALED_MERGE_ZERO},
6152       {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ true),
6153        AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO},
6154       {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ false),
6155        AArch64ISD::GLD1_SCALED_MERGE_ZERO},
6156       {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ true),
6157        AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO},
6158   };
6159   auto Key = std::make_tuple(IsScaled, IsSigned, NeedsExtend);
6160   return AddrModes.find(Key)->second;
6161 }
6162 
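// Converts a zero-extending GLD1 gather opcode into its sign-extending GLD1S
// counterpart, preserving the addressing mode.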
6163 unsigned getSignExtendedGatherOpcode(unsigned Opcode) {
6164   switch (Opcode) {
6165   default:
6166     llvm_unreachable("unimplemented opcode");
6167     return Opcode;
6168   case AArch64ISD::GLD1_MERGE_ZERO:
6169     return AArch64ISD::GLD1S_MERGE_ZERO;
6170   case AArch64ISD::GLD1_IMM_MERGE_ZERO:
6171     return AArch64ISD::GLD1S_IMM_MERGE_ZERO;
6172   case AArch64ISD::GLD1_UXTW_MERGE_ZERO:
6173     return AArch64ISD::GLD1S_UXTW_MERGE_ZERO;
6174   case AArch64ISD::GLD1_SXTW_MERGE_ZERO:
6175     return AArch64ISD::GLD1S_SXTW_MERGE_ZERO;
6176   case AArch64ISD::GLD1_SCALED_MERGE_ZERO:
6177     return AArch64ISD::GLD1S_SCALED_MERGE_ZERO;
6178   case AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO:
6179     return AArch64ISD::GLD1S_UXTW_SCALED_MERGE_ZERO;
6180   case AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO:
6181     return AArch64ISD::GLD1S_SXTW_SCALED_MERGE_ZERO;
6182   }
6183 }
6184 
6185 SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
6186                                             SelectionDAG &DAG) const {
6187   MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(Op);
6188 
6189   SDLoc DL(Op);
6190   SDValue Chain = MGT->getChain();
6191   SDValue PassThru = MGT->getPassThru();
6192   SDValue Mask = MGT->getMask();
6193   SDValue BasePtr = MGT->getBasePtr();
6194   SDValue Index = MGT->getIndex();
6195   SDValue Scale = MGT->getScale();
6196   EVT VT = Op.getValueType();
6197   EVT MemVT = MGT->getMemoryVT();
6198   ISD::LoadExtType ExtType = MGT->getExtensionType();
6199   ISD::MemIndexType IndexType = MGT->getIndexType();
6200 
6201   // SVE supports zero (and so undef) passthrough values only; everything else
6202   // must be handled manually by an explicit select on the load's output.
6203   if (!PassThru->isUndef() && !isZerosVector(PassThru.getNode())) {
6204     SDValue Ops[] = {Chain, DAG.getUNDEF(VT), Mask, BasePtr, Index, Scale};
6205     SDValue Load =
6206         DAG.getMaskedGather(MGT->getVTList(), MemVT, DL, Ops,
6207                             MGT->getMemOperand(), IndexType, ExtType);
6208     SDValue Select = DAG.getSelect(DL, VT, Mask, Load, PassThru);
6209     return DAG.getMergeValues({Select, Load.getValue(1)}, DL);
6210   }
6211 
6212   bool IsScaled = MGT->isIndexScaled();
6213   bool IsSigned = MGT->isIndexSigned();
6214 
6215   // SVE supports an index scaled by sizeof(MemVT.elt) only; everything else
6216   // must be calculated beforehand.
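  // For example (illustrative): a gather of i32 elements (4-byte store size)
  // whose Scale operand is 8 is rewritten below as Index << 3 with Scale == 1.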
6217   uint64_t ScaleVal = Scale->getAsZExtVal();
6218   if (IsScaled && ScaleVal != MemVT.getScalarStoreSize()) {
6219     assert(isPowerOf2_64(ScaleVal) && "Expecting power-of-two types");
6220     EVT IndexVT = Index.getValueType();
6221     Index = DAG.getNode(ISD::SHL, DL, IndexVT, Index,
6222                         DAG.getConstant(Log2_32(ScaleVal), DL, IndexVT));
6223     Scale = DAG.getTargetConstant(1, DL, Scale.getValueType());
6224 
6225     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
6226     return DAG.getMaskedGather(MGT->getVTList(), MemVT, DL, Ops,
6227                                MGT->getMemOperand(), IndexType, ExtType);
6228   }
6229 
6230   // Lower fixed length gather to a scalable equivalent.
6231   if (VT.isFixedLengthVector()) {
6232     assert(Subtarget->useSVEForFixedLengthVectors() &&
6233            "Cannot lower when not using SVE for fixed vectors!");
6234 
6235     // NOTE: Handle floating-point as if integer then bitcast the result.
6236     EVT DataVT = VT.changeVectorElementTypeToInteger();
6237     MemVT = MemVT.changeVectorElementTypeToInteger();
6238 
6239     // Find the smallest integer fixed length vector we can use for the gather.
6240     EVT PromotedVT = VT.changeVectorElementType(MVT::i32);
6241     if (DataVT.getVectorElementType() == MVT::i64 ||
6242         Index.getValueType().getVectorElementType() == MVT::i64 ||
6243         Mask.getValueType().getVectorElementType() == MVT::i64)
6244       PromotedVT = VT.changeVectorElementType(MVT::i64);
6245 
6246     // Promote vector operands except for passthrough, which we know is either
6247     // undef or zero, and thus best constructed directly.
6248     unsigned ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
6249     Index = DAG.getNode(ExtOpcode, DL, PromotedVT, Index);
6250     Mask = DAG.getNode(ISD::SIGN_EXTEND, DL, PromotedVT, Mask);
6251 
6252     // A promoted result type forces the need for an extending load.
6253     if (PromotedVT != DataVT && ExtType == ISD::NON_EXTLOAD)
6254       ExtType = ISD::EXTLOAD;
6255 
6256     EVT ContainerVT = getContainerForFixedLengthVector(DAG, PromotedVT);
6257 
6258     // Convert fixed length vector operands to scalable.
6259     MemVT = ContainerVT.changeVectorElementType(MemVT.getVectorElementType());
6260     Index = convertToScalableVector(DAG, ContainerVT, Index);
6261     Mask = convertFixedMaskToScalableVector(Mask, DAG);
6262     PassThru = PassThru->isUndef() ? DAG.getUNDEF(ContainerVT)
6263                                    : DAG.getConstant(0, DL, ContainerVT);
6264 
6265     // Emit equivalent scalable vector gather.
6266     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
6267     SDValue Load =
6268         DAG.getMaskedGather(DAG.getVTList(ContainerVT, MVT::Other), MemVT, DL,
6269                             Ops, MGT->getMemOperand(), IndexType, ExtType);
6270 
6271     // Extract fixed length data then convert to the required result type.
6272     SDValue Result = convertFromScalableVector(DAG, PromotedVT, Load);
6273     Result = DAG.getNode(ISD::TRUNCATE, DL, DataVT, Result);
6274     if (VT.isFloatingPoint())
6275       Result = DAG.getNode(ISD::BITCAST, DL, VT, Result);
6276 
6277     return DAG.getMergeValues({Result, Load.getValue(1)}, DL);
6278   }
6279 
6280   // Everything else is legal.
6281   return Op;
6282 }
6283 
6284 SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
6285                                              SelectionDAG &DAG) const {
6286   MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(Op);
6287 
6288   SDLoc DL(Op);
6289   SDValue Chain = MSC->getChain();
6290   SDValue StoreVal = MSC->getValue();
6291   SDValue Mask = MSC->getMask();
6292   SDValue BasePtr = MSC->getBasePtr();
6293   SDValue Index = MSC->getIndex();
6294   SDValue Scale = MSC->getScale();
6295   EVT VT = StoreVal.getValueType();
6296   EVT MemVT = MSC->getMemoryVT();
6297   ISD::MemIndexType IndexType = MSC->getIndexType();
6298   bool Truncating = MSC->isTruncatingStore();
6299 
6300   bool IsScaled = MSC->isIndexScaled();
6301   bool IsSigned = MSC->isIndexSigned();
6302 
6303   // SVE supports an index scaled by sizeof(MemVT.elt) only; everything else
6304   // must be calculated beforehand.
6305   uint64_t ScaleVal = Scale->getAsZExtVal();
6306   if (IsScaled && ScaleVal != MemVT.getScalarStoreSize()) {
6307     assert(isPowerOf2_64(ScaleVal) && "Expecting power-of-two types");
6308     EVT IndexVT = Index.getValueType();
6309     Index = DAG.getNode(ISD::SHL, DL, IndexVT, Index,
6310                         DAG.getConstant(Log2_32(ScaleVal), DL, IndexVT));
6311     Scale = DAG.getTargetConstant(1, DL, Scale.getValueType());
6312 
6313     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
6314     return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops,
6315                                 MSC->getMemOperand(), IndexType, Truncating);
6316   }
6317 
6318   // Lower fixed length scatter to a scalable equivalent.
6319   if (VT.isFixedLengthVector()) {
6320     assert(Subtarget->useSVEForFixedLengthVectors() &&
6321            "Cannot lower when not using SVE for fixed vectors!");
6322 
6323     // Once bitcast we treat floating-point scatters as if integer.
6324     if (VT.isFloatingPoint()) {
6325       VT = VT.changeVectorElementTypeToInteger();
6326       MemVT = MemVT.changeVectorElementTypeToInteger();
6327       StoreVal = DAG.getNode(ISD::BITCAST, DL, VT, StoreVal);
6328     }
6329 
6330     // Find the smallest integer fixed length vector we can use for the scatter.
6331     EVT PromotedVT = VT.changeVectorElementType(MVT::i32);
6332     if (VT.getVectorElementType() == MVT::i64 ||
6333         Index.getValueType().getVectorElementType() == MVT::i64 ||
6334         Mask.getValueType().getVectorElementType() == MVT::i64)
6335       PromotedVT = VT.changeVectorElementType(MVT::i64);
6336 
6337     // Promote vector operands.
6338     unsigned ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
6339     Index = DAG.getNode(ExtOpcode, DL, PromotedVT, Index);
6340     Mask = DAG.getNode(ISD::SIGN_EXTEND, DL, PromotedVT, Mask);
6341     StoreVal = DAG.getNode(ISD::ANY_EXTEND, DL, PromotedVT, StoreVal);
6342 
6343     // A promoted value type forces the need for a truncating store.
6344     if (PromotedVT != VT)
6345       Truncating = true;
6346 
6347     EVT ContainerVT = getContainerForFixedLengthVector(DAG, PromotedVT);
6348 
6349     // Convert fixed length vector operands to scalable.
6350     MemVT = ContainerVT.changeVectorElementType(MemVT.getVectorElementType());
6351     Index = convertToScalableVector(DAG, ContainerVT, Index);
6352     Mask = convertFixedMaskToScalableVector(Mask, DAG);
6353     StoreVal = convertToScalableVector(DAG, ContainerVT, StoreVal);
6354 
6355     // Emit equivalent scalable vector scatter.
6356     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
6357     return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops,
6358                                 MSC->getMemOperand(), IndexType, Truncating);
6359   }
6360 
6361   // Everything else is legal.
6362   return Op;
6363 }
6364 
6365 SDValue AArch64TargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
6366   SDLoc DL(Op);
6367   MaskedLoadSDNode *LoadNode = cast<MaskedLoadSDNode>(Op);
6368   assert(LoadNode && "Expected custom lowering of a masked load node");
6369   EVT VT = Op->getValueType(0);
6370 
6371   if (useSVEForFixedLengthVectorVT(VT, /*OverrideNEON=*/true))
6372     return LowerFixedLengthVectorMLoadToSVE(Op, DAG);
6373 
6374   SDValue PassThru = LoadNode->getPassThru();
6375   SDValue Mask = LoadNode->getMask();
6376 
6377   if (PassThru->isUndef() || isZerosVector(PassThru.getNode()))
6378     return Op;
6379 
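  // SVE only handles zero/undef passthrough values directly, so re-emit the
  // load with an undef passthrough and select the original passthrough from
  // the loaded result afterwards.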
6380   SDValue Load = DAG.getMaskedLoad(
6381       VT, DL, LoadNode->getChain(), LoadNode->getBasePtr(),
6382       LoadNode->getOffset(), Mask, DAG.getUNDEF(VT), LoadNode->getMemoryVT(),
6383       LoadNode->getMemOperand(), LoadNode->getAddressingMode(),
6384       LoadNode->getExtensionType());
6385 
6386   SDValue Result = DAG.getSelect(DL, VT, Mask, Load, PassThru);
6387 
6388   return DAG.getMergeValues({Result, Load.getValue(1)}, DL);
6389 }
6390 
6391 // Custom lower trunc store for v4i8 vectors, since it is promoted to v4i16.
6392 static SDValue LowerTruncateVectorStore(SDLoc DL, StoreSDNode *ST,
6393                                         EVT VT, EVT MemVT,
6394                                         SelectionDAG &DAG) {
6395   assert(VT.isVector() && "VT should be a vector type");
6396   assert(MemVT == MVT::v4i8 && VT == MVT::v4i16);
6397 
6398   SDValue Value = ST->getValue();
6399 
6400   // First extend the promoted v4i16 to v8i16, truncate to v8i8, and extract
6401   // the word lane which represents the v4i8 subvector.  This optimizes the
6402   // store to:
6403   //
6404   //   xtn  v0.8b, v0.8h
6405   //   str  s0, [x0]
6406 
6407   SDValue Undef = DAG.getUNDEF(MVT::i16);
6408   SDValue UndefVec = DAG.getBuildVector(MVT::v4i16, DL,
6409                                         {Undef, Undef, Undef, Undef});
6410 
6411   SDValue TruncExt = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v8i16,
6412                                  Value, UndefVec);
6413   SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::v8i8, TruncExt);
6414 
6415   Trunc = DAG.getNode(ISD::BITCAST, DL, MVT::v2i32, Trunc);
6416   SDValue ExtractTrunc = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32,
6417                                      Trunc, DAG.getConstant(0, DL, MVT::i64));
6418 
6419   return DAG.getStore(ST->getChain(), DL, ExtractTrunc,
6420                       ST->getBasePtr(), ST->getMemOperand());
6421 }
6422 
6423 // Custom lowering for any store, vector or scalar, with or without a
6424 // truncate operation.  Currently we only custom lower truncating stores
6425 // from vector v4i16 to v4i8 and volatile stores of i128.
6426 SDValue AArch64TargetLowering::LowerSTORE(SDValue Op,
6427                                           SelectionDAG &DAG) const {
6428   SDLoc Dl(Op);
6429   StoreSDNode *StoreNode = cast<StoreSDNode>(Op);
6430   assert(StoreNode && "Can only custom lower store nodes");
6431 
6432   SDValue Value = StoreNode->getValue();
6433 
6434   EVT VT = Value.getValueType();
6435   EVT MemVT = StoreNode->getMemoryVT();
6436 
6437   if (VT.isVector()) {
6438     if (useSVEForFixedLengthVectorVT(
6439             VT,
6440             /*OverrideNEON=*/Subtarget->useSVEForFixedLengthVectors()))
6441       return LowerFixedLengthVectorStoreToSVE(Op, DAG);
6442 
6443     unsigned AS = StoreNode->getAddressSpace();
6444     Align Alignment = StoreNode->getAlign();
6445     if (Alignment < MemVT.getStoreSize() &&
6446         !allowsMisalignedMemoryAccesses(MemVT, AS, Alignment,
6447                                         StoreNode->getMemOperand()->getFlags(),
6448                                         nullptr)) {
6449       return scalarizeVectorStore(StoreNode, DAG);
6450     }
6451 
6452     if (StoreNode->isTruncatingStore() && VT == MVT::v4i16 &&
6453         MemVT == MVT::v4i8) {
6454       return LowerTruncateVectorStore(Dl, StoreNode, VT, MemVT, DAG);
6455     }
6456     // 256-bit non-temporal stores can be lowered to STNP. Do this as part of
6457     // the custom lowering, as there are no unpaired non-temporal stores and
6458     // legalization will break up 256-bit inputs.
6459     ElementCount EC = MemVT.getVectorElementCount();
6460     if (StoreNode->isNonTemporal() && MemVT.getSizeInBits() == 256u &&
6461         EC.isKnownEven() && DAG.getDataLayout().isLittleEndian() &&
6462         (MemVT.getScalarSizeInBits() == 8u ||
6463          MemVT.getScalarSizeInBits() == 16u ||
6464          MemVT.getScalarSizeInBits() == 32u ||
6465          MemVT.getScalarSizeInBits() == 64u)) {
6466       SDValue Lo =
6467           DAG.getNode(ISD::EXTRACT_SUBVECTOR, Dl,
6468                       MemVT.getHalfNumVectorElementsVT(*DAG.getContext()),
6469                       StoreNode->getValue(), DAG.getConstant(0, Dl, MVT::i64));
6470       SDValue Hi =
6471           DAG.getNode(ISD::EXTRACT_SUBVECTOR, Dl,
6472                       MemVT.getHalfNumVectorElementsVT(*DAG.getContext()),
6473                       StoreNode->getValue(),
6474                       DAG.getConstant(EC.getKnownMinValue() / 2, Dl, MVT::i64));
6475       SDValue Result = DAG.getMemIntrinsicNode(
6476           AArch64ISD::STNP, Dl, DAG.getVTList(MVT::Other),
6477           {StoreNode->getChain(), Lo, Hi, StoreNode->getBasePtr()},
6478           StoreNode->getMemoryVT(), StoreNode->getMemOperand());
6479       return Result;
6480     }
6481   } else if (MemVT == MVT::i128 && StoreNode->isVolatile()) {
6482     return LowerStore128(Op, DAG);
6483   } else if (MemVT == MVT::i64x8) {
6484     SDValue Value = StoreNode->getValue();
6485     assert(Value->getValueType(0) == MVT::i64x8);
6486     SDValue Chain = StoreNode->getChain();
6487     SDValue Base = StoreNode->getBasePtr();
6488     EVT PtrVT = Base.getValueType();
6489     for (unsigned i = 0; i < 8; i++) {
6490       SDValue Part = DAG.getNode(AArch64ISD::LS64_EXTRACT, Dl, MVT::i64,
6491                                  Value, DAG.getConstant(i, Dl, MVT::i32));
6492       SDValue Ptr = DAG.getNode(ISD::ADD, Dl, PtrVT, Base,
6493                                 DAG.getConstant(i * 8, Dl, PtrVT));
6494       Chain = DAG.getStore(Chain, Dl, Part, Ptr, StoreNode->getPointerInfo(),
6495                            StoreNode->getOriginalAlign());
6496     }
6497     return Chain;
6498   }
6499 
6500   return SDValue();
6501 }
6502 
6503 /// Lower atomic or volatile 128-bit stores to a single STP (or, for release ordering, STILP) instruction.
6504 SDValue AArch64TargetLowering::LowerStore128(SDValue Op,
6505                                              SelectionDAG &DAG) const {
6506   MemSDNode *StoreNode = cast<MemSDNode>(Op);
6507   assert(StoreNode->getMemoryVT() == MVT::i128);
6508   assert(StoreNode->isVolatile() || StoreNode->isAtomic());
6509 
6510   bool IsStoreRelease =
6511       StoreNode->getMergedOrdering() == AtomicOrdering::Release;
6512   if (StoreNode->isAtomic())
6513     assert((Subtarget->hasFeature(AArch64::FeatureLSE2) &&
6514             Subtarget->hasFeature(AArch64::FeatureRCPC3) && IsStoreRelease) ||
6515            StoreNode->getMergedOrdering() == AtomicOrdering::Unordered ||
6516            StoreNode->getMergedOrdering() == AtomicOrdering::Monotonic);
6517 
6518   SDValue Value = (StoreNode->getOpcode() == ISD::STORE ||
6519                    StoreNode->getOpcode() == ISD::ATOMIC_STORE)
6520                       ? StoreNode->getOperand(1)
6521                       : StoreNode->getOperand(2);
6522   SDLoc DL(Op);
6523   auto StoreValue = DAG.SplitScalar(Value, DL, MVT::i64, MVT::i64);
6524   unsigned Opcode = IsStoreRelease ? AArch64ISD::STILP : AArch64ISD::STP;
6525   if (DAG.getDataLayout().isBigEndian())
6526     std::swap(StoreValue.first, StoreValue.second);
6527   SDValue Result = DAG.getMemIntrinsicNode(
6528       Opcode, DL, DAG.getVTList(MVT::Other),
6529       {StoreNode->getChain(), StoreValue.first, StoreValue.second,
6530        StoreNode->getBasePtr()},
6531       StoreNode->getMemoryVT(), StoreNode->getMemOperand());
6532   return Result;
6533 }
6534 
6535 SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
6536                                          SelectionDAG &DAG) const {
6537   SDLoc DL(Op);
6538   LoadSDNode *LoadNode = cast<LoadSDNode>(Op);
6539   assert(LoadNode && "Expected custom lowering of a load node");
6540 
6541   if (LoadNode->getMemoryVT() == MVT::i64x8) {
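    // MVT::i64x8 is the type used by the Armv8.7 LS64 (64-byte load/store)
    // builtins; here the value is assembled from eight individual i64 loads.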
6542     SmallVector<SDValue, 8> Ops;
6543     SDValue Base = LoadNode->getBasePtr();
6544     SDValue Chain = LoadNode->getChain();
6545     EVT PtrVT = Base.getValueType();
6546     for (unsigned i = 0; i < 8; i++) {
6547       SDValue Ptr = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
6548                                 DAG.getConstant(i * 8, DL, PtrVT));
6549       SDValue Part = DAG.getLoad(MVT::i64, DL, Chain, Ptr,
6550                                  LoadNode->getPointerInfo(),
6551                                  LoadNode->getOriginalAlign());
6552       Ops.push_back(Part);
6553       Chain = SDValue(Part.getNode(), 1);
6554     }
6555     SDValue Loaded = DAG.getNode(AArch64ISD::LS64_BUILD, DL, MVT::i64x8, Ops);
6556     return DAG.getMergeValues({Loaded, Chain}, DL);
6557   }
6558 
6559   // Custom lowering for extending v4i8 vector loads.
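  // Illustratively, a zero-extending v4i8 -> v4i16 load is expected to become
  // something like:
  //   ldr   s0, [x0]
  //   ushll v0.8h, v0.8b, #0
  // with only the low four lanes of the result being used.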
6560   EVT VT = Op->getValueType(0);
6561   assert((VT == MVT::v4i16 || VT == MVT::v4i32) && "Expected v4i16 or v4i32");
6562 
6563   if (LoadNode->getMemoryVT() != MVT::v4i8)
6564     return SDValue();
6565 
6566   // Avoid generating unaligned loads.
6567   if (Subtarget->requiresStrictAlign() && LoadNode->getAlign() < Align(4))
6568     return SDValue();
6569 
6570   unsigned ExtType;
6571   if (LoadNode->getExtensionType() == ISD::SEXTLOAD)
6572     ExtType = ISD::SIGN_EXTEND;
6573   else if (LoadNode->getExtensionType() == ISD::ZEXTLOAD ||
6574            LoadNode->getExtensionType() == ISD::EXTLOAD)
6575     ExtType = ISD::ZERO_EXTEND;
6576   else
6577     return SDValue();
6578 
6579   SDValue Load = DAG.getLoad(MVT::f32, DL, LoadNode->getChain(),
6580                              LoadNode->getBasePtr(), MachinePointerInfo());
6581   SDValue Chain = Load.getValue(1);
6582   SDValue Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v2f32, Load);
6583   SDValue BC = DAG.getNode(ISD::BITCAST, DL, MVT::v8i8, Vec);
6584   SDValue Ext = DAG.getNode(ExtType, DL, MVT::v8i16, BC);
6585   Ext = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v4i16, Ext,
6586                     DAG.getConstant(0, DL, MVT::i64));
6587   if (VT == MVT::v4i32)
6588     Ext = DAG.getNode(ExtType, DL, MVT::v4i32, Ext);
6589   return DAG.getMergeValues({Ext, Chain}, DL);
6590 }
6591 
6592 // Generate SUBS and CSEL for integer abs.
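// For example (illustrative), abs(w0) can be selected as:
//   subs wzr, w0, #0      // compare the input against zero
//   cneg w0, w0, mi       // negate when the input was negative
// which is the SUBS + CSEL/CSNEG pattern built below.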
6593 SDValue AArch64TargetLowering::LowerABS(SDValue Op, SelectionDAG &DAG) const {
6594   MVT VT = Op.getSimpleValueType();
6595 
6596   if (VT.isVector())
6597     return LowerToPredicatedOp(Op, DAG, AArch64ISD::ABS_MERGE_PASSTHRU);
6598 
6599   SDLoc DL(Op);
6600   SDValue Neg = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
6601                             Op.getOperand(0));
6602   // Generate SUBS & CSEL.
6603   SDValue Cmp =
6604       DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, MVT::i32),
6605                   Op.getOperand(0), DAG.getConstant(0, DL, VT));
6606   return DAG.getNode(AArch64ISD::CSEL, DL, VT, Op.getOperand(0), Neg,
6607                      DAG.getConstant(AArch64CC::PL, DL, MVT::i32),
6608                      Cmp.getValue(1));
6609 }
6610 
6611 static SDValue LowerBRCOND(SDValue Op, SelectionDAG &DAG) {
6612   SDValue Chain = Op.getOperand(0);
6613   SDValue Cond = Op.getOperand(1);
6614   SDValue Dest = Op.getOperand(2);
6615 
6616   AArch64CC::CondCode CC;
6617   if (SDValue Cmp = emitConjunction(DAG, Cond, CC)) {
6618     SDLoc dl(Op);
6619     SDValue CCVal = DAG.getConstant(CC, dl, MVT::i32);
6620     return DAG.getNode(AArch64ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
6621                        Cmp);
6622   }
6623 
6624   return SDValue();
6625 }
6626 
6627 // Treat FSHR with constant shifts as a legal operation; otherwise it is
6628 // expanded.  FSHL is converted to FSHR before deciding what to do with it.
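// For example (illustrative): fshl(a, b, 3) on i64 becomes fshr(a, b, 61),
// since shifting the concatenated value left by N is equivalent to shifting
// it right by (bitwidth - N).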
6629 static SDValue LowerFunnelShift(SDValue Op, SelectionDAG &DAG) {
6630   SDValue Shifts = Op.getOperand(2);
6631   // Check if the shift amount is a constant.
6632   // If the opcode is FSHL, convert it to FSHR.
6633   if (auto *ShiftNo = dyn_cast<ConstantSDNode>(Shifts)) {
6634     SDLoc DL(Op);
6635     MVT VT = Op.getSimpleValueType();
6636 
6637     if (Op.getOpcode() == ISD::FSHL) {
6638       unsigned int NewShiftNo =
6639           VT.getFixedSizeInBits() - ShiftNo->getZExtValue();
6640       return DAG.getNode(
6641           ISD::FSHR, DL, VT, Op.getOperand(0), Op.getOperand(1),
6642           DAG.getConstant(NewShiftNo, DL, Shifts.getValueType()));
6643     } else if (Op.getOpcode() == ISD::FSHR) {
6644       return Op;
6645     }
6646   }
6647 
6648   return SDValue();
6649 }
6650 
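// Lower FLDEXP (x * 2^exp) by placing the scalar operands in lane 0 of SVE
// vectors, applying the predicated SVE FSCALE intrinsic, and extracting
// lane 0 of the result; half types are widened to f32 first and rounded back
// afterwards.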
6651 static SDValue LowerFLDEXP(SDValue Op, SelectionDAG &DAG) {
6652   SDValue X = Op.getOperand(0);
6653   EVT XScalarTy = X.getValueType();
6654   SDValue Exp = Op.getOperand(1);
6655 
6656   SDLoc DL(Op);
6657   EVT XVT, ExpVT;
6658   switch (Op.getSimpleValueType().SimpleTy) {
6659   default:
6660     return SDValue();
6661   case MVT::bf16:
6662   case MVT::f16:
6663     X = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, X);
6664     [[fallthrough]];
6665   case MVT::f32:
6666     XVT = MVT::nxv4f32;
6667     ExpVT = MVT::nxv4i32;
6668     break;
6669   case MVT::f64:
6670     XVT = MVT::nxv2f64;
6671     ExpVT = MVT::nxv2i64;
6672     Exp = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, Exp);
6673     break;
6674   }
6675 
6676   SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
6677   SDValue VX =
6678       DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, XVT, DAG.getUNDEF(XVT), X, Zero);
6679   SDValue VExp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ExpVT,
6680                              DAG.getUNDEF(ExpVT), Exp, Zero);
6681   SDValue VPg = getPTrue(DAG, DL, XVT.changeVectorElementType(MVT::i1),
6682                          AArch64SVEPredPattern::all);
6683   SDValue FScale =
6684       DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, XVT,
6685                   DAG.getConstant(Intrinsic::aarch64_sve_fscale, DL, MVT::i64),
6686                   VPg, VX, VExp);
6687   SDValue Final =
6688       DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, X.getValueType(), FScale, Zero);
6689   if (X.getValueType() != XScalarTy)
6690     Final = DAG.getNode(ISD::FP_ROUND, DL, XScalarTy, Final,
6691                         DAG.getIntPtrConstant(1, SDLoc(Op)));
6692   return Final;
6693 }
6694 
6695 SDValue AArch64TargetLowering::LowerADJUST_TRAMPOLINE(SDValue Op,
6696                                                       SelectionDAG &DAG) const {
6697   // Note: x18 cannot be used for the Nest parameter on Windows and macOS.
6698   if (Subtarget->isTargetDarwin() || Subtarget->isTargetWindows())
6699     report_fatal_error(
6700         "ADJUST_TRAMPOLINE operation is only supported on Linux.");
6701 
6702   return Op.getOperand(0);
6703 }
6704 
6705 SDValue AArch64TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op,
6706                                                     SelectionDAG &DAG) const {
6707 
6708   // Note: x18 cannot be used for the Nest parameter on Windows and macOS.
6709   if (Subtarget->isTargetDarwin() || Subtarget->isTargetWindows())
6710     report_fatal_error("INIT_TRAMPOLINE operation is only supported on Linux.");
6711 
6712   SDValue Chain = Op.getOperand(0);
6713   SDValue Trmp = Op.getOperand(1); // trampoline
6714   SDValue FPtr = Op.getOperand(2); // nested function
6715   SDValue Nest = Op.getOperand(3); // 'nest' parameter value
6716   SDLoc dl(Op);
6717 
6718   EVT PtrVT = getPointerTy(DAG.getDataLayout());
6719   Type *IntPtrTy = DAG.getDataLayout().getIntPtrType(*DAG.getContext());
6720 
6721   TargetLowering::ArgListTy Args;
6722   TargetLowering::ArgListEntry Entry;
6723 
6724   Entry.Ty = IntPtrTy;
6725   Entry.Node = Trmp;
6726   Args.push_back(Entry);
6727   Entry.Node = DAG.getConstant(20, dl, MVT::i64);
6728   Args.push_back(Entry);
6729 
6730   Entry.Node = FPtr;
6731   Args.push_back(Entry);
6732   Entry.Node = Nest;
6733   Args.push_back(Entry);
6734 
6735   // Lower to a call to __trampoline_setup(Trmp, TrampSize, FPtr, ctx_reg)
6736   TargetLowering::CallLoweringInfo CLI(DAG);
6737   CLI.setDebugLoc(dl).setChain(Chain).setLibCallee(
6738       CallingConv::C, Type::getVoidTy(*DAG.getContext()),
6739       DAG.getExternalSymbol("__trampoline_setup", PtrVT), std::move(Args));
6740 
6741   std::pair<SDValue, SDValue> CallResult = LowerCallTo(CLI);
6742   return CallResult.second;
6743 }
6744 
6745 SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
6746                                               SelectionDAG &DAG) const {
6747   LLVM_DEBUG(dbgs() << "Custom lowering: ");
6748   LLVM_DEBUG(Op.dump());
6749 
6750   switch (Op.getOpcode()) {
6751   default:
6752     llvm_unreachable("unimplemented operand");
6753     return SDValue();
6754   case ISD::BITCAST:
6755     return LowerBITCAST(Op, DAG);
6756   case ISD::GlobalAddress:
6757     return LowerGlobalAddress(Op, DAG);
6758   case ISD::GlobalTLSAddress:
6759     return LowerGlobalTLSAddress(Op, DAG);
6760   case ISD::PtrAuthGlobalAddress:
6761     return LowerPtrAuthGlobalAddress(Op, DAG);
6762   case ISD::ADJUST_TRAMPOLINE:
6763     return LowerADJUST_TRAMPOLINE(Op, DAG);
6764   case ISD::INIT_TRAMPOLINE:
6765     return LowerINIT_TRAMPOLINE(Op, DAG);
6766   case ISD::SETCC:
6767   case ISD::STRICT_FSETCC:
6768   case ISD::STRICT_FSETCCS:
6769     return LowerSETCC(Op, DAG);
6770   case ISD::SETCCCARRY:
6771     return LowerSETCCCARRY(Op, DAG);
6772   case ISD::BRCOND:
6773     return LowerBRCOND(Op, DAG);
6774   case ISD::BR_CC:
6775     return LowerBR_CC(Op, DAG);
6776   case ISD::SELECT:
6777     return LowerSELECT(Op, DAG);
6778   case ISD::SELECT_CC:
6779     return LowerSELECT_CC(Op, DAG);
6780   case ISD::JumpTable:
6781     return LowerJumpTable(Op, DAG);
6782   case ISD::BR_JT:
6783     return LowerBR_JT(Op, DAG);
6784   case ISD::BRIND:
6785     return LowerBRIND(Op, DAG);
6786   case ISD::ConstantPool:
6787     return LowerConstantPool(Op, DAG);
6788   case ISD::BlockAddress:
6789     return LowerBlockAddress(Op, DAG);
6790   case ISD::VASTART:
6791     return LowerVASTART(Op, DAG);
6792   case ISD::VACOPY:
6793     return LowerVACOPY(Op, DAG);
6794   case ISD::VAARG:
6795     return LowerVAARG(Op, DAG);
6796   case ISD::UADDO_CARRY:
6797     return lowerADDSUBO_CARRY(Op, DAG, AArch64ISD::ADCS, false /*unsigned*/);
6798   case ISD::USUBO_CARRY:
6799     return lowerADDSUBO_CARRY(Op, DAG, AArch64ISD::SBCS, false /*unsigned*/);
6800   case ISD::SADDO_CARRY:
6801     return lowerADDSUBO_CARRY(Op, DAG, AArch64ISD::ADCS, true /*signed*/);
6802   case ISD::SSUBO_CARRY:
6803     return lowerADDSUBO_CARRY(Op, DAG, AArch64ISD::SBCS, true /*signed*/);
6804   case ISD::SADDO:
6805   case ISD::UADDO:
6806   case ISD::SSUBO:
6807   case ISD::USUBO:
6808   case ISD::SMULO:
6809   case ISD::UMULO:
6810     return LowerXALUO(Op, DAG);
6811   case ISD::FADD:
6812     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FADD_PRED);
6813   case ISD::FSUB:
6814     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSUB_PRED);
6815   case ISD::FMUL:
6816     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
6817   case ISD::FMA:
6818     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
6819   case ISD::FDIV:
6820     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FDIV_PRED);
6821   case ISD::FNEG:
6822     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU);
6823   case ISD::FCEIL:
6824     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FCEIL_MERGE_PASSTHRU);
6825   case ISD::FFLOOR:
6826     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FFLOOR_MERGE_PASSTHRU);
6827   case ISD::FNEARBYINT:
6828     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEARBYINT_MERGE_PASSTHRU);
6829   case ISD::FRINT:
6830     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FRINT_MERGE_PASSTHRU);
6831   case ISD::FROUND:
6832     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FROUND_MERGE_PASSTHRU);
6833   case ISD::FROUNDEVEN:
6834     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU);
6835   case ISD::FTRUNC:
6836     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FTRUNC_MERGE_PASSTHRU);
6837   case ISD::FSQRT:
6838     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSQRT_MERGE_PASSTHRU);
6839   case ISD::FABS:
6840     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FABS_MERGE_PASSTHRU);
6841   case ISD::FP_ROUND:
6842   case ISD::STRICT_FP_ROUND:
6843     return LowerFP_ROUND(Op, DAG);
6844   case ISD::FP_EXTEND:
6845     return LowerFP_EXTEND(Op, DAG);
6846   case ISD::FRAMEADDR:
6847     return LowerFRAMEADDR(Op, DAG);
6848   case ISD::SPONENTRY:
6849     return LowerSPONENTRY(Op, DAG);
6850   case ISD::RETURNADDR:
6851     return LowerRETURNADDR(Op, DAG);
6852   case ISD::ADDROFRETURNADDR:
6853     return LowerADDROFRETURNADDR(Op, DAG);
6854   case ISD::CONCAT_VECTORS:
6855     return LowerCONCAT_VECTORS(Op, DAG);
6856   case ISD::INSERT_VECTOR_ELT:
6857     return LowerINSERT_VECTOR_ELT(Op, DAG);
6858   case ISD::EXTRACT_VECTOR_ELT:
6859     return LowerEXTRACT_VECTOR_ELT(Op, DAG);
6860   case ISD::BUILD_VECTOR:
6861     return LowerBUILD_VECTOR(Op, DAG);
6862   case ISD::ZERO_EXTEND_VECTOR_INREG:
6863     return LowerZERO_EXTEND_VECTOR_INREG(Op, DAG);
6864   case ISD::VECTOR_SHUFFLE:
6865     return LowerVECTOR_SHUFFLE(Op, DAG);
6866   case ISD::SPLAT_VECTOR:
6867     return LowerSPLAT_VECTOR(Op, DAG);
6868   case ISD::EXTRACT_SUBVECTOR:
6869     return LowerEXTRACT_SUBVECTOR(Op, DAG);
6870   case ISD::INSERT_SUBVECTOR:
6871     return LowerINSERT_SUBVECTOR(Op, DAG);
6872   case ISD::SDIV:
6873   case ISD::UDIV:
6874     return LowerDIV(Op, DAG);
6875   case ISD::SMIN:
6876   case ISD::UMIN:
6877   case ISD::SMAX:
6878   case ISD::UMAX:
6879     return LowerMinMax(Op, DAG);
6880   case ISD::SRA:
6881   case ISD::SRL:
6882   case ISD::SHL:
6883     return LowerVectorSRA_SRL_SHL(Op, DAG);
6884   case ISD::SHL_PARTS:
6885   case ISD::SRL_PARTS:
6886   case ISD::SRA_PARTS:
6887     return LowerShiftParts(Op, DAG);
6888   case ISD::CTPOP:
6889   case ISD::PARITY:
6890     return LowerCTPOP_PARITY(Op, DAG);
6891   case ISD::FCOPYSIGN:
6892     return LowerFCOPYSIGN(Op, DAG);
6893   case ISD::OR:
6894     return LowerVectorOR(Op, DAG);
6895   case ISD::XOR:
6896     return LowerXOR(Op, DAG);
6897   case ISD::PREFETCH:
6898     return LowerPREFETCH(Op, DAG);
6899   case ISD::SINT_TO_FP:
6900   case ISD::UINT_TO_FP:
6901   case ISD::STRICT_SINT_TO_FP:
6902   case ISD::STRICT_UINT_TO_FP:
6903     return LowerINT_TO_FP(Op, DAG);
6904   case ISD::FP_TO_SINT:
6905   case ISD::FP_TO_UINT:
6906   case ISD::STRICT_FP_TO_SINT:
6907   case ISD::STRICT_FP_TO_UINT:
6908     return LowerFP_TO_INT(Op, DAG);
6909   case ISD::FP_TO_SINT_SAT:
6910   case ISD::FP_TO_UINT_SAT:
6911     return LowerFP_TO_INT_SAT(Op, DAG);
6912   case ISD::FSINCOS:
6913     return LowerFSINCOS(Op, DAG);
6914   case ISD::GET_ROUNDING:
6915     return LowerGET_ROUNDING(Op, DAG);
6916   case ISD::SET_ROUNDING:
6917     return LowerSET_ROUNDING(Op, DAG);
6918   case ISD::GET_FPMODE:
6919     return LowerGET_FPMODE(Op, DAG);
6920   case ISD::SET_FPMODE:
6921     return LowerSET_FPMODE(Op, DAG);
6922   case ISD::RESET_FPMODE:
6923     return LowerRESET_FPMODE(Op, DAG);
6924   case ISD::MUL:
6925     return LowerMUL(Op, DAG);
6926   case ISD::MULHS:
6927     return LowerToPredicatedOp(Op, DAG, AArch64ISD::MULHS_PRED);
6928   case ISD::MULHU:
6929     return LowerToPredicatedOp(Op, DAG, AArch64ISD::MULHU_PRED);
6930   case ISD::INTRINSIC_W_CHAIN:
6931     return LowerINTRINSIC_W_CHAIN(Op, DAG);
6932   case ISD::INTRINSIC_WO_CHAIN:
6933     return LowerINTRINSIC_WO_CHAIN(Op, DAG);
6934   case ISD::INTRINSIC_VOID:
6935     return LowerINTRINSIC_VOID(Op, DAG);
6936   case ISD::ATOMIC_STORE:
6937     if (cast<MemSDNode>(Op)->getMemoryVT() == MVT::i128) {
6938       assert(Subtarget->hasLSE2() || Subtarget->hasRCPC3());
6939       return LowerStore128(Op, DAG);
6940     }
6941     return SDValue();
6942   case ISD::STORE:
6943     return LowerSTORE(Op, DAG);
6944   case ISD::MSTORE:
6945     return LowerFixedLengthVectorMStoreToSVE(Op, DAG);
6946   case ISD::MGATHER:
6947     return LowerMGATHER(Op, DAG);
6948   case ISD::MSCATTER:
6949     return LowerMSCATTER(Op, DAG);
6950   case ISD::VECREDUCE_SEQ_FADD:
6951     return LowerVECREDUCE_SEQ_FADD(Op, DAG);
6952   case ISD::VECREDUCE_ADD:
6953   case ISD::VECREDUCE_AND:
6954   case ISD::VECREDUCE_OR:
6955   case ISD::VECREDUCE_XOR:
6956   case ISD::VECREDUCE_SMAX:
6957   case ISD::VECREDUCE_SMIN:
6958   case ISD::VECREDUCE_UMAX:
6959   case ISD::VECREDUCE_UMIN:
6960   case ISD::VECREDUCE_FADD:
6961   case ISD::VECREDUCE_FMAX:
6962   case ISD::VECREDUCE_FMIN:
6963   case ISD::VECREDUCE_FMAXIMUM:
6964   case ISD::VECREDUCE_FMINIMUM:
6965     return LowerVECREDUCE(Op, DAG);
6966   case ISD::ATOMIC_LOAD_AND:
6967     return LowerATOMIC_LOAD_AND(Op, DAG);
6968   case ISD::DYNAMIC_STACKALLOC:
6969     return LowerDYNAMIC_STACKALLOC(Op, DAG);
6970   case ISD::VSCALE:
6971     return LowerVSCALE(Op, DAG);
6972   case ISD::ANY_EXTEND:
6973   case ISD::SIGN_EXTEND:
6974   case ISD::ZERO_EXTEND:
6975     return LowerFixedLengthVectorIntExtendToSVE(Op, DAG);
6976   case ISD::SIGN_EXTEND_INREG: {
6977     // Only custom lower when ExtraVT has a legal byte based element type.
6978     EVT ExtraVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
6979     EVT ExtraEltVT = ExtraVT.getVectorElementType();
6980     if ((ExtraEltVT != MVT::i8) && (ExtraEltVT != MVT::i16) &&
6981         (ExtraEltVT != MVT::i32) && (ExtraEltVT != MVT::i64))
6982       return SDValue();
6983 
6984     return LowerToPredicatedOp(Op, DAG,
6985                                AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU);
6986   }
6987   case ISD::TRUNCATE:
6988     return LowerTRUNCATE(Op, DAG);
6989   case ISD::MLOAD:
6990     return LowerMLOAD(Op, DAG);
6991   case ISD::LOAD:
6992     if (useSVEForFixedLengthVectorVT(Op.getValueType(),
6993                                      !Subtarget->isNeonAvailable()))
6994       return LowerFixedLengthVectorLoadToSVE(Op, DAG);
6995     return LowerLOAD(Op, DAG);
6996   case ISD::ADD:
6997   case ISD::AND:
6998   case ISD::SUB:
6999     return LowerToScalableOp(Op, DAG);
7000   case ISD::FMAXIMUM:
7001     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMAX_PRED);
7002   case ISD::FMAXNUM:
7003     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMAXNM_PRED);
7004   case ISD::FMINIMUM:
7005     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMIN_PRED);
7006   case ISD::FMINNUM:
7007     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMINNM_PRED);
7008   case ISD::VSELECT:
7009     return LowerFixedLengthVectorSelectToSVE(Op, DAG);
7010   case ISD::ABS:
7011     return LowerABS(Op, DAG);
7012   case ISD::ABDS:
7013     return LowerToPredicatedOp(Op, DAG, AArch64ISD::ABDS_PRED);
7014   case ISD::ABDU:
7015     return LowerToPredicatedOp(Op, DAG, AArch64ISD::ABDU_PRED);
7016   case ISD::AVGFLOORS:
7017     return LowerAVG(Op, DAG, AArch64ISD::HADDS_PRED);
7018   case ISD::AVGFLOORU:
7019     return LowerAVG(Op, DAG, AArch64ISD::HADDU_PRED);
7020   case ISD::AVGCEILS:
7021     return LowerAVG(Op, DAG, AArch64ISD::RHADDS_PRED);
7022   case ISD::AVGCEILU:
7023     return LowerAVG(Op, DAG, AArch64ISD::RHADDU_PRED);
7024   case ISD::BITREVERSE:
7025     return LowerBitreverse(Op, DAG);
7026   case ISD::BSWAP:
7027     return LowerToPredicatedOp(Op, DAG, AArch64ISD::BSWAP_MERGE_PASSTHRU);
7028   case ISD::CTLZ:
7029     return LowerToPredicatedOp(Op, DAG, AArch64ISD::CTLZ_MERGE_PASSTHRU);
7030   case ISD::CTTZ:
7031     return LowerCTTZ(Op, DAG);
7032   case ISD::VECTOR_SPLICE:
7033     return LowerVECTOR_SPLICE(Op, DAG);
7034   case ISD::VECTOR_DEINTERLEAVE:
7035     return LowerVECTOR_DEINTERLEAVE(Op, DAG);
7036   case ISD::VECTOR_INTERLEAVE:
7037     return LowerVECTOR_INTERLEAVE(Op, DAG);
7038   case ISD::LRINT:
7039   case ISD::LLRINT:
7040     if (Op.getValueType().isVector())
7041       return LowerVectorXRINT(Op, DAG);
7042     [[fallthrough]];
7043   case ISD::LROUND:
7044   case ISD::LLROUND: {
7045     assert((Op.getOperand(0).getValueType() == MVT::f16 ||
7046             Op.getOperand(0).getValueType() == MVT::bf16) &&
7047            "Expected custom lowering of rounding operations only for f16");
7048     SDLoc DL(Op);
7049     SDValue Ext = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Op.getOperand(0));
7050     return DAG.getNode(Op.getOpcode(), DL, Op.getValueType(), Ext);
7051   }
7052   case ISD::STRICT_LROUND:
7053   case ISD::STRICT_LLROUND:
7054   case ISD::STRICT_LRINT:
7055   case ISD::STRICT_LLRINT: {
7056     assert((Op.getOperand(1).getValueType() == MVT::f16 ||
7057             Op.getOperand(1).getValueType() == MVT::bf16) &&
7058            "Expected custom lowering of rounding operations only for f16");
7059     SDLoc DL(Op);
7060     SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, DL, {MVT::f32, MVT::Other},
7061                               {Op.getOperand(0), Op.getOperand(1)});
7062     return DAG.getNode(Op.getOpcode(), DL, {Op.getValueType(), MVT::Other},
7063                        {Ext.getValue(1), Ext.getValue(0)});
7064   }
7065   case ISD::WRITE_REGISTER: {
7066     assert(Op.getOperand(2).getValueType() == MVT::i128 &&
7067            "WRITE_REGISTER custom lowering is only for 128-bit sysregs");
7068     SDLoc DL(Op);
7069 
7070     SDValue Chain = Op.getOperand(0);
7071     SDValue SysRegName = Op.getOperand(1);
7072     std::pair<SDValue, SDValue> Pair =
7073         DAG.SplitScalar(Op.getOperand(2), DL, MVT::i64, MVT::i64);
7074 
7075     // chain = MSRR(chain, sysregname, lo, hi)
7076     SDValue Result = DAG.getNode(AArch64ISD::MSRR, DL, MVT::Other, Chain,
7077                                  SysRegName, Pair.first, Pair.second);
7078 
7079     return Result;
7080   }
7081   case ISD::FSHL:
7082   case ISD::FSHR:
7083     return LowerFunnelShift(Op, DAG);
7084   case ISD::FLDEXP:
7085     return LowerFLDEXP(Op, DAG);
7086   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
7087     return LowerVECTOR_HISTOGRAM(Op, DAG);
7088   }
7089 }
7090 
7091 bool AArch64TargetLowering::mergeStoresAfterLegalization(EVT VT) const {
7092   return !Subtarget->useSVEForFixedLengthVectors();
7093 }
7094 
7095 bool AArch64TargetLowering::useSVEForFixedLengthVectorVT(
7096     EVT VT, bool OverrideNEON) const {
7097   if (!VT.isFixedLengthVector() || !VT.isSimple())
7098     return false;
7099 
7100   // Don't use SVE for vectors we cannot scalarize if required.
7101   switch (VT.getVectorElementType().getSimpleVT().SimpleTy) {
7102   // Fixed length predicates should be promoted to i8.
7103   // NOTE: This is consistent with how NEON (and thus 64/128bit vectors) work.
7104   case MVT::i1:
7105   default:
7106     return false;
7107   case MVT::i8:
7108   case MVT::i16:
7109   case MVT::i32:
7110   case MVT::i64:
7111   case MVT::f16:
7112   case MVT::f32:
7113   case MVT::f64:
7114     break;
7115   }
7116 
7117   // NEON-sized vectors can be emulated using SVE instructions.
7118   if (OverrideNEON && (VT.is128BitVector() || VT.is64BitVector()))
7119     return Subtarget->isSVEorStreamingSVEAvailable();
7120 
7121   // Ensure NEON MVTs only belong to a single register class.
7122   if (VT.getFixedSizeInBits() <= 128)
7123     return false;
7124 
7125   // Ensure wider than NEON code generation is enabled.
7126   if (!Subtarget->useSVEForFixedLengthVectors())
7127     return false;
7128 
7129   // Don't use SVE for types that don't fit.
7130   if (VT.getFixedSizeInBits() > Subtarget->getMinSVEVectorSizeInBits())
7131     return false;
7132 
7133   // TODO: Perhaps an artificial restriction, but worth having whilst getting
7134   // the base fixed length SVE support in place.
7135   if (!VT.isPow2VectorType())
7136     return false;
7137 
7138   return true;
7139 }
7140 
7141 //===----------------------------------------------------------------------===//
7142 //                      Calling Convention Implementation
7143 //===----------------------------------------------------------------------===//
7144 
7145 static unsigned getIntrinsicID(const SDNode *N) {
7146   unsigned Opcode = N->getOpcode();
7147   switch (Opcode) {
7148   default:
7149     return Intrinsic::not_intrinsic;
7150   case ISD::INTRINSIC_WO_CHAIN: {
7151     unsigned IID = N->getConstantOperandVal(0);
7152     if (IID < Intrinsic::num_intrinsics)
7153       return IID;
7154     return Intrinsic::not_intrinsic;
7155   }
7156   }
7157 }
7158 
7159 bool AArch64TargetLowering::isReassocProfitable(SelectionDAG &DAG, SDValue N0,
7160                                                 SDValue N1) const {
7161   if (!N0.hasOneUse())
7162     return false;
7163 
7164   unsigned IID = getIntrinsicID(N1.getNode());
7165   // Avoid reassociating expressions that can be lowered to smlal/umlal.
7166   if (IID == Intrinsic::aarch64_neon_umull ||
7167       N1.getOpcode() == AArch64ISD::UMULL ||
7168       IID == Intrinsic::aarch64_neon_smull ||
7169       N1.getOpcode() == AArch64ISD::SMULL)
7170     return N0.getOpcode() != ISD::ADD;
7171 
7172   return true;
7173 }
7174 
7175 /// Selects the correct CCAssignFn for a given CallingConvention value.
7176 CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC,
7177                                                      bool IsVarArg) const {
7178   switch (CC) {
7179   default:
7180     report_fatal_error("Unsupported calling convention.");
7181   case CallingConv::GHC:
7182     return CC_AArch64_GHC;
7183   case CallingConv::PreserveNone:
7184     // The VarArg implementation makes assumptions about register
7185     // argument passing that do not hold for preserve_none, so we
7186     // instead fall back to C argument passing.
7187     // The non-vararg case is handled in the CC function itself.
7188     if (!IsVarArg)
7189       return CC_AArch64_Preserve_None;
7190     [[fallthrough]];
7191   case CallingConv::C:
7192   case CallingConv::Fast:
7193   case CallingConv::PreserveMost:
7194   case CallingConv::PreserveAll:
7195   case CallingConv::CXX_FAST_TLS:
7196   case CallingConv::Swift:
7197   case CallingConv::SwiftTail:
7198   case CallingConv::Tail:
7199   case CallingConv::GRAAL:
7200     if (Subtarget->isTargetWindows()) {
7201       if (IsVarArg) {
7202         if (Subtarget->isWindowsArm64EC())
7203           return CC_AArch64_Arm64EC_VarArg;
7204         return CC_AArch64_Win64_VarArg;
7205       }
7206       return CC_AArch64_Win64PCS;
7207     }
7208     if (!Subtarget->isTargetDarwin())
7209       return CC_AArch64_AAPCS;
7210     if (!IsVarArg)
7211       return CC_AArch64_DarwinPCS;
7212     return Subtarget->isTargetILP32() ? CC_AArch64_DarwinPCS_ILP32_VarArg
7213                                       : CC_AArch64_DarwinPCS_VarArg;
7214   case CallingConv::Win64:
7215     if (IsVarArg) {
7216       if (Subtarget->isWindowsArm64EC())
7217         return CC_AArch64_Arm64EC_VarArg;
7218       return CC_AArch64_Win64_VarArg;
7219     }
7220     return CC_AArch64_Win64PCS;
7221   case CallingConv::CFGuard_Check:
7222     if (Subtarget->isWindowsArm64EC())
7223       return CC_AArch64_Arm64EC_CFGuard_Check;
7224     return CC_AArch64_Win64_CFGuard_Check;
7225   case CallingConv::AArch64_VectorCall:
7226   case CallingConv::AArch64_SVE_VectorCall:
7227   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
7228   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
7229     return CC_AArch64_AAPCS;
7230   case CallingConv::ARM64EC_Thunk_X64:
7231     return CC_AArch64_Arm64EC_Thunk;
7232   case CallingConv::ARM64EC_Thunk_Native:
7233     return CC_AArch64_Arm64EC_Thunk_Native;
7234   }
7235 }
7236 
7237 CCAssignFn *
7238 AArch64TargetLowering::CCAssignFnForReturn(CallingConv::ID CC) const {
7239   switch (CC) {
7240   default:
7241     return RetCC_AArch64_AAPCS;
7242   case CallingConv::ARM64EC_Thunk_X64:
7243     return RetCC_AArch64_Arm64EC_Thunk;
7244   case CallingConv::CFGuard_Check:
7245     if (Subtarget->isWindowsArm64EC())
7246       return RetCC_AArch64_Arm64EC_CFGuard_Check;
7247     return RetCC_AArch64_AAPCS;
7248   }
7249 }
7250 
7251 static bool isPassedInFPR(EVT VT) {
7252   return VT.isFixedLengthVector() ||
7253          (VT.isFloatingPoint() && !VT.isScalableVector());
7254 }
7255 
7256 SDValue AArch64TargetLowering::LowerFormalArguments(
7257     SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
7258     const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &DL,
7259     SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
7260   MachineFunction &MF = DAG.getMachineFunction();
7261   const Function &F = MF.getFunction();
7262   MachineFrameInfo &MFI = MF.getFrameInfo();
7263   bool IsWin64 =
7264       Subtarget->isCallingConvWin64(F.getCallingConv(), F.isVarArg());
7265   bool StackViaX4 = CallConv == CallingConv::ARM64EC_Thunk_X64 ||
7266                     (isVarArg && Subtarget->isWindowsArm64EC());
7267   AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
7268 
7269   SmallVector<ISD::OutputArg, 4> Outs;
7270   GetReturnInfo(CallConv, F.getReturnType(), F.getAttributes(), Outs,
7271                 DAG.getTargetLoweringInfo(), MF.getDataLayout());
7272   if (any_of(Outs, [](ISD::OutputArg &Out){ return Out.VT.isScalableVector(); }))
7273     FuncInfo->setIsSVECC(true);
7274 
7275   // Assign locations to all of the incoming arguments.
7276   SmallVector<CCValAssign, 16> ArgLocs;
7277   DenseMap<unsigned, SDValue> CopiedRegs;
7278   CCState CCInfo(CallConv, isVarArg, MF, ArgLocs, *DAG.getContext());
7279 
7280   // At this point, Ins[].VT may already be promoted to i32. To correctly
7281   // handle passing i8 as i8 instead of i32 on stack, we pass in both i32 and
7282   // i8 to CC_AArch64_AAPCS with i32 being ValVT and i8 being LocVT.
7283   // Since AnalyzeFormalArguments uses Ins[].VT for both ValVT and LocVT, here
7284   // we use a special version of AnalyzeFormalArguments to pass in ValVT and
7285   // LocVT.
7286   unsigned NumArgs = Ins.size();
7287   Function::const_arg_iterator CurOrigArg = F.arg_begin();
7288   unsigned CurArgIdx = 0;
7289   for (unsigned i = 0; i != NumArgs; ++i) {
7290     MVT ValVT = Ins[i].VT;
7291     if (Ins[i].isOrigArg()) {
7292       std::advance(CurOrigArg, Ins[i].getOrigArgIndex() - CurArgIdx);
7293       CurArgIdx = Ins[i].getOrigArgIndex();
7294 
7295       // Get type of the original argument.
7296       EVT ActualVT = getValueType(DAG.getDataLayout(), CurOrigArg->getType(),
7297                                   /*AllowUnknown*/ true);
7298       MVT ActualMVT = ActualVT.isSimple() ? ActualVT.getSimpleVT() : MVT::Other;
7299       // If ActualMVT is i1/i8/i16, we should set LocVT to i8/i8/i16.
7300       if (ActualMVT == MVT::i1 || ActualMVT == MVT::i8)
7301         ValVT = MVT::i8;
7302       else if (ActualMVT == MVT::i16)
7303         ValVT = MVT::i16;
7304     }
7305     bool UseVarArgCC = false;
7306     if (IsWin64)
7307       UseVarArgCC = isVarArg;
7308     CCAssignFn *AssignFn = CCAssignFnForCall(CallConv, UseVarArgCC);
7309     bool Res =
7310         AssignFn(i, ValVT, ValVT, CCValAssign::Full, Ins[i].Flags, CCInfo);
7311     assert(!Res && "Call operand has unhandled type");
7312     (void)Res;
7313   }
7314 
7315   SMEAttrs Attrs(MF.getFunction());
7316   bool IsLocallyStreaming =
7317       !Attrs.hasStreamingInterface() && Attrs.hasStreamingBody();
7318   assert(Chain.getOpcode() == ISD::EntryToken && "Unexpected Chain value");
7319   SDValue Glue = Chain.getValue(1);
7320 
7321   SmallVector<SDValue, 16> ArgValues;
7322   unsigned ExtraArgLocs = 0;
7323   for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
7324     CCValAssign &VA = ArgLocs[i - ExtraArgLocs];
7325 
7326     if (Ins[i].Flags.isByVal()) {
7327       // Byval is used for HFAs in the PCS, but the system should work in a
7328       // non-compliant manner for larger structs.
7329       EVT PtrVT = getPointerTy(DAG.getDataLayout());
7330       int Size = Ins[i].Flags.getByValSize();
7331       unsigned NumRegs = (Size + 7) / 8;
7332 
7333       // FIXME: This works on big-endian for composite byvals, which are the common
7334       // case. It should also work for fundamental types.
7335       unsigned FrameIdx =
7336         MFI.CreateFixedObject(8 * NumRegs, VA.getLocMemOffset(), false);
7337       SDValue FrameIdxN = DAG.getFrameIndex(FrameIdx, PtrVT);
7338       InVals.push_back(FrameIdxN);
7339 
7340       continue;
7341     }
7342 
7343     if (Ins[i].Flags.isSwiftAsync())
7344       MF.getInfo<AArch64FunctionInfo>()->setHasSwiftAsyncContext(true);
7345 
7346     SDValue ArgValue;
7347     if (VA.isRegLoc()) {
7348       // Arguments stored in registers.
7349       EVT RegVT = VA.getLocVT();
7350       const TargetRegisterClass *RC;
7351 
7352       if (RegVT == MVT::i32)
7353         RC = &AArch64::GPR32RegClass;
7354       else if (RegVT == MVT::i64)
7355         RC = &AArch64::GPR64RegClass;
7356       else if (RegVT == MVT::f16 || RegVT == MVT::bf16)
7357         RC = &AArch64::FPR16RegClass;
7358       else if (RegVT == MVT::f32)
7359         RC = &AArch64::FPR32RegClass;
7360       else if (RegVT == MVT::f64 || RegVT.is64BitVector())
7361         RC = &AArch64::FPR64RegClass;
7362       else if (RegVT == MVT::f128 || RegVT.is128BitVector())
7363         RC = &AArch64::FPR128RegClass;
7364       else if (RegVT.isScalableVector() &&
7365                RegVT.getVectorElementType() == MVT::i1) {
7366         FuncInfo->setIsSVECC(true);
7367         RC = &AArch64::PPRRegClass;
7368       } else if (RegVT == MVT::aarch64svcount) {
7369         FuncInfo->setIsSVECC(true);
7370         RC = &AArch64::PPRRegClass;
7371       } else if (RegVT.isScalableVector()) {
7372         FuncInfo->setIsSVECC(true);
7373         RC = &AArch64::ZPRRegClass;
7374       } else
7375         llvm_unreachable("RegVT not supported by FORMAL_ARGUMENTS Lowering");
7376 
7377       // Transform the arguments in physical registers into virtual ones.
7378       Register Reg = MF.addLiveIn(VA.getLocReg(), RC);
7379 
7380       if (IsLocallyStreaming) {
7381         // LocallyStreamingFunctions must insert the SMSTART in the correct
7382         // position, so we use Glue to ensure no instructions can be scheduled
7383         // between the chain of:
7384         //        t0: ch,glue = EntryNode
7385         //      t1:  res,ch,glue = CopyFromReg
7386         //     ...
7387         //   tn: res,ch,glue = CopyFromReg t(n-1), ..
7388         // t(n+1): ch, glue = SMSTART t0:0, ...., tn:2
7389         // ^^^^^^
7390         // This will be the new Chain/Root node.
7391         ArgValue = DAG.getCopyFromReg(Chain, DL, Reg, RegVT, Glue);
7392         Glue = ArgValue.getValue(2);
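             // Values passed in FPRs/vector registers are wrapped in a
             // COALESCER_BARRIER so their copies are not coalesced away across
             // the streaming-mode change (SMSTART) inserted further down.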
7393         if (isPassedInFPR(ArgValue.getValueType())) {
7394           ArgValue =
7395               DAG.getNode(AArch64ISD::COALESCER_BARRIER, DL,
7396                           DAG.getVTList(ArgValue.getValueType(), MVT::Glue),
7397                           {ArgValue, Glue});
7398           Glue = ArgValue.getValue(1);
7399         }
7400       } else
7401         ArgValue = DAG.getCopyFromReg(Chain, DL, Reg, RegVT);
7402 
7403       // If this is an 8, 16 or 32-bit value, it is really passed promoted
7404       // to 64 bits.  Insert an assert[sz]ext to capture this, then
7405       // truncate to the right size.
7406       switch (VA.getLocInfo()) {
7407       default:
7408         llvm_unreachable("Unknown loc info!");
7409       case CCValAssign::Full:
7410         break;
7411       case CCValAssign::Indirect:
7412         assert(
7413             (VA.getValVT().isScalableVT() || Subtarget->isWindowsArm64EC()) &&
7414             "Indirect arguments should be scalable on most subtargets");
7415         break;
7416       case CCValAssign::BCvt:
7417         ArgValue = DAG.getNode(ISD::BITCAST, DL, VA.getValVT(), ArgValue);
7418         break;
7419       case CCValAssign::AExt:
7420       case CCValAssign::SExt:
7421       case CCValAssign::ZExt:
7422         break;
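           // AExtUpper: the value was passed in the upper 32 bits of the
           // 64-bit register; shift it down before truncating to the value type.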
7423       case CCValAssign::AExtUpper:
7424         ArgValue = DAG.getNode(ISD::SRL, DL, RegVT, ArgValue,
7425                                DAG.getConstant(32, DL, RegVT));
7426         ArgValue = DAG.getZExtOrTrunc(ArgValue, DL, VA.getValVT());
7427         break;
7428       }
7429     } else { // VA.isRegLoc()
7430       assert(VA.isMemLoc() && "CCValAssign is neither reg nor mem");
7431       unsigned ArgOffset = VA.getLocMemOffset();
7432       unsigned ArgSize = (VA.getLocInfo() == CCValAssign::Indirect
7433                               ? VA.getLocVT().getSizeInBits()
7434                               : VA.getValVT().getSizeInBits()) / 8;
7435 
7436       uint32_t BEAlign = 0;
7437       if (!Subtarget->isLittleEndian() && ArgSize < 8 &&
7438           !Ins[i].Flags.isInConsecutiveRegs())
7439         BEAlign = 8 - ArgSize;
7440 
7441       SDValue FIN;
7442       MachinePointerInfo PtrInfo;
7443       if (StackViaX4) {
7444         // In both the ARM64EC varargs convention and the thunk convention,
7445         // arguments on the stack are accessed relative to x4, not sp. In
7446         // the thunk convention, there's an additional offset of 32 bytes
7447         // to account for the shadow store.
7448         unsigned ObjOffset = ArgOffset + BEAlign;
7449         if (CallConv == CallingConv::ARM64EC_Thunk_X64)
7450           ObjOffset += 32;
7451         Register VReg = MF.addLiveIn(AArch64::X4, &AArch64::GPR64RegClass);
7452         SDValue Val = DAG.getCopyFromReg(Chain, DL, VReg, MVT::i64);
7453         FIN = DAG.getNode(ISD::ADD, DL, MVT::i64, Val,
7454                           DAG.getConstant(ObjOffset, DL, MVT::i64));
7455         PtrInfo = MachinePointerInfo::getUnknownStack(MF);
7456       } else {
7457         int FI = MFI.CreateFixedObject(ArgSize, ArgOffset + BEAlign, true);
7458 
7459         // Create load nodes to retrieve arguments from the stack.
7460         FIN = DAG.getFrameIndex(FI, getPointerTy(DAG.getDataLayout()));
7461         PtrInfo = MachinePointerInfo::getFixedStack(MF, FI);
7462       }
7463 
7464       // For NON_EXTLOAD, the generic code in getLoad asserts that ValVT == MemVT.
7465       ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
7466       MVT MemVT = VA.getValVT();
7467 
7468       switch (VA.getLocInfo()) {
7469       default:
7470         break;
7471       case CCValAssign::Trunc:
7472       case CCValAssign::BCvt:
7473         MemVT = VA.getLocVT();
7474         break;
7475       case CCValAssign::Indirect:
7476         assert((VA.getValVT().isScalableVector() ||
7477                 Subtarget->isWindowsArm64EC()) &&
7478                "Indirect arguments should be scalable on most subtargets");
7479         MemVT = VA.getLocVT();
7480         break;
7481       case CCValAssign::SExt:
7482         ExtType = ISD::SEXTLOAD;
7483         break;
7484       case CCValAssign::ZExt:
7485         ExtType = ISD::ZEXTLOAD;
7486         break;
7487       case CCValAssign::AExt:
7488         ExtType = ISD::EXTLOAD;
7489         break;
7490       }
7491 
7492       ArgValue = DAG.getExtLoad(ExtType, DL, VA.getLocVT(), Chain, FIN, PtrInfo,
7493                                 MemVT);
7494     }
7495 
7496     if (VA.getLocInfo() == CCValAssign::Indirect) {
7497       assert((VA.getValVT().isScalableVT() ||
7498               Subtarget->isWindowsArm64EC()) &&
7499              "Indirect arguments should be scalable on most subtargets");
7500 
7501       uint64_t PartSize = VA.getValVT().getStoreSize().getKnownMinValue();
7502       unsigned NumParts = 1;
7503       if (Ins[i].Flags.isInConsecutiveRegs()) {
7504         while (!Ins[i + NumParts - 1].Flags.isInConsecutiveRegsLast())
7505           ++NumParts;
7506       }
7507 
7508       MVT PartLoad = VA.getValVT();
7509       SDValue Ptr = ArgValue;
7510 
7511       // Ensure we generate all loads for each tuple part, whilst updating the
7512       // pointer after each load correctly using vscale.
7513       while (NumParts > 0) {
7514         ArgValue = DAG.getLoad(PartLoad, DL, Chain, Ptr, MachinePointerInfo());
7515         InVals.push_back(ArgValue);
7516         NumParts--;
7517         if (NumParts > 0) {
7518           SDValue BytesIncrement;
7519           if (PartLoad.isScalableVector()) {
7520             BytesIncrement = DAG.getVScale(
7521                 DL, Ptr.getValueType(),
7522                 APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize));
7523           } else {
7524             BytesIncrement = DAG.getConstant(
7525                 APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize), DL,
7526                 Ptr.getValueType());
7527           }
7528           SDNodeFlags Flags;
7529           Flags.setNoUnsignedWrap(true);
7530           Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
7531                             BytesIncrement, Flags);
7532           ExtraArgLocs++;
7533           i++;
7534         }
7535       }
7536     } else {
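           // On ILP32 targets, pointer arguments arrive in 64-bit registers with
           // the upper 32 bits zero; record this with an AssertZext of i32.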
7537       if (Subtarget->isTargetILP32() && Ins[i].Flags.isPointer())
7538         ArgValue = DAG.getNode(ISD::AssertZext, DL, ArgValue.getValueType(),
7539                                ArgValue, DAG.getValueType(MVT::i32));
7540 
7541       // i1 arguments are zero-extended to i8 by the caller. Emit a
7542       // hint to reflect this.
7543       if (Ins[i].isOrigArg()) {
7544         Argument *OrigArg = F.getArg(Ins[i].getOrigArgIndex());
7545         if (OrigArg->getType()->isIntegerTy(1)) {
7546           if (!Ins[i].Flags.isZExt()) {
7547             ArgValue = DAG.getNode(AArch64ISD::ASSERT_ZEXT_BOOL, DL,
7548                                    ArgValue.getValueType(), ArgValue);
7549           }
7550         }
7551       }
7552 
7553       InVals.push_back(ArgValue);
7554     }
7555   }
7556   assert((ArgLocs.size() + ExtraArgLocs) == Ins.size());
7557 
7558   // Insert the SMSTART if this is a locally streaming function and
7559   // make sure it is Glued to the last CopyFromReg value.
7560   if (IsLocallyStreaming) {
7561     SDValue PStateSM;
7562     if (Attrs.hasStreamingCompatibleInterface()) {
7563       PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
7564       Register Reg = MF.getRegInfo().createVirtualRegister(
7565           getRegClassFor(PStateSM.getValueType().getSimpleVT()));
7566       FuncInfo->setPStateSMReg(Reg);
7567       Chain = DAG.getCopyToReg(Chain, DL, Reg, PStateSM);
7568       Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue,
7569                                   AArch64SME::IfCallerIsNonStreaming, PStateSM);
7570     } else
7571       Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue,
7572                                   AArch64SME::Always);
7573 
7574     // Ensure that the SMSTART happens after the CopyToReg/CopyFromReg pair
7575     // above so that its chain result is used.
7576     for (unsigned I=0; I<InVals.size(); ++I) {
7577       Register Reg = MF.getRegInfo().createVirtualRegister(
7578           getRegClassFor(InVals[I].getValueType().getSimpleVT()));
7579       Chain = DAG.getCopyToReg(Chain, DL, Reg, InVals[I]);
7580       InVals[I] = DAG.getCopyFromReg(Chain, DL, Reg,
7581                                      InVals[I].getValueType());
7582     }
7583   }
7584 
7585   // varargs
7586   if (isVarArg) {
7587     if (!Subtarget->isTargetDarwin() || IsWin64) {
7588       // The AAPCS variadic function ABI is identical to the non-variadic
7589       // one. As a result there may be more arguments in registers and we should
7590       // save them for future reference.
7591       // Win64 variadic functions also pass arguments in registers, but all float
7592       // arguments are passed in integer registers.
7593       saveVarArgRegisters(CCInfo, DAG, DL, Chain);
7594     }
7595 
7596     // This will point to the next argument passed via stack.
7597     unsigned VarArgsOffset = CCInfo.getStackSize();
7598     // We currently pass all varargs at 8-byte alignment, or 4 for ILP32
7599     VarArgsOffset = alignTo(VarArgsOffset, Subtarget->isTargetILP32() ? 4 : 8);
7600     FuncInfo->setVarArgsStackOffset(VarArgsOffset);
7601     FuncInfo->setVarArgsStackIndex(
7602         MFI.CreateFixedObject(4, VarArgsOffset, true));
7603 
7604     if (MFI.hasMustTailInVarArgFunc()) {
7605       SmallVector<MVT, 2> RegParmTypes;
7606       RegParmTypes.push_back(MVT::i64);
7607       RegParmTypes.push_back(MVT::f128);
7608       // Compute the set of forwarded registers. The rest are scratch.
7609       SmallVectorImpl<ForwardedRegister> &Forwards =
7610                                        FuncInfo->getForwardedMustTailRegParms();
7611       CCInfo.analyzeMustTailForwardedRegisters(Forwards, RegParmTypes,
7612                                                CC_AArch64_AAPCS);
7613 
7614       // Conservatively forward X8, since it might be used for aggregate return.
7615       if (!CCInfo.isAllocated(AArch64::X8)) {
7616         Register X8VReg = MF.addLiveIn(AArch64::X8, &AArch64::GPR64RegClass);
7617         Forwards.push_back(ForwardedRegister(X8VReg, AArch64::X8, MVT::i64));
7618       }
7619     }
7620   }
7621 
7622   // On Windows, InReg pointers must be returned, so record the pointer in a
7623   // virtual register at the start of the function so it can be returned in the
7624   // epilogue.
7625   if (IsWin64 || F.getCallingConv() == CallingConv::ARM64EC_Thunk_X64) {
7626     for (unsigned I = 0, E = Ins.size(); I != E; ++I) {
7627       if ((F.getCallingConv() == CallingConv::ARM64EC_Thunk_X64 ||
7628            Ins[I].Flags.isInReg()) &&
7629           Ins[I].Flags.isSRet()) {
7630         assert(!FuncInfo->getSRetReturnReg());
7631 
7632         MVT PtrTy = getPointerTy(DAG.getDataLayout());
7633         Register Reg =
7634             MF.getRegInfo().createVirtualRegister(getRegClassFor(PtrTy));
7635         FuncInfo->setSRetReturnReg(Reg);
7636 
7637         SDValue Copy = DAG.getCopyToReg(DAG.getEntryNode(), DL, Reg, InVals[I]);
7638         Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Copy, Chain);
7639         break;
7640       }
7641     }
7642   }
7643 
7644   unsigned StackArgSize = CCInfo.getStackSize();
7645   bool TailCallOpt = MF.getTarget().Options.GuaranteedTailCallOpt;
7646   if (DoesCalleeRestoreStack(CallConv, TailCallOpt)) {
7647     // This is a non-standard ABI so by fiat I say we're allowed to make full
7648     // use of the stack area to be popped, which must be aligned to 16 bytes in
7649     // any case:
7650     StackArgSize = alignTo(StackArgSize, 16);
7651 
7652     // If we're expected to restore the stack (e.g. fastcc) then we'll be adding
7653     // a multiple of 16.
7654     FuncInfo->setArgumentStackToRestore(StackArgSize);
7655 
7656     // This realignment carries over to the available bytes below. Our own
7657     // callers will guarantee the space is free by giving an aligned value to
7658     // CALLSEQ_START.
7659   }
7660   // Even if we're not expected to free up the space, it's useful to know how
7661   // much is there while considering tail calls (because we can reuse it).
7662   FuncInfo->setBytesInStackArgArea(StackArgSize);
7663 
7664   if (Subtarget->hasCustomCallingConv())
7665     Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);
7666 
7667   // Create a 16-byte TPIDR2 object. The dynamic buffer
7668   // will be expanded and stored in the static object later using a pseudo node.
7669   if (SMEAttrs(MF.getFunction()).hasZAState()) {
7670     TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
7671     TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
7672     SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
7673                               DAG.getConstant(1, DL, MVT::i32));
7674 
7675     SDValue Buffer;
7676     if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
7677       Buffer = DAG.getNode(AArch64ISD::ALLOCATE_ZA_BUFFER, DL,
7678                            DAG.getVTList(MVT::i64, MVT::Other), {Chain, SVL});
7679     } else {
7680       SDValue Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
7681       Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL,
7682                            DAG.getVTList(MVT::i64, MVT::Other),
7683                            {Chain, Size, DAG.getConstant(1, DL, MVT::i64)});
7684       MFI.CreateVariableSizedObject(Align(16), nullptr);
7685     }
7686     Chain = DAG.getNode(
7687         AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
7688         {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
7689   }
7690 
7691   if (CallConv == CallingConv::PreserveNone) {
7692     for (const ISD::InputArg &I : Ins) {
7693       if (I.Flags.isSwiftSelf() || I.Flags.isSwiftError() ||
7694           I.Flags.isSwiftAsync()) {
7695         MachineFunction &MF = DAG.getMachineFunction();
7696         DAG.getContext()->diagnose(DiagnosticInfoUnsupported(
7697             MF.getFunction(),
7698             "Swift attributes can't be used with preserve_none",
7699             DL.getDebugLoc()));
7700         break;
7701       }
7702     }
7703   }
7704 
7705   return Chain;
7706 }
7707 
7708 void AArch64TargetLowering::saveVarArgRegisters(CCState &CCInfo,
7709                                                 SelectionDAG &DAG,
7710                                                 const SDLoc &DL,
7711                                                 SDValue &Chain) const {
7712   MachineFunction &MF = DAG.getMachineFunction();
7713   MachineFrameInfo &MFI = MF.getFrameInfo();
7714   AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
7715   auto PtrVT = getPointerTy(DAG.getDataLayout());
7716   Function &F = MF.getFunction();
7717   bool IsWin64 =
7718       Subtarget->isCallingConvWin64(F.getCallingConv(), F.isVarArg());
7719 
7720   SmallVector<SDValue, 8> MemOps;
7721 
7722   auto GPRArgRegs = AArch64::getGPRArgRegs();
7723   unsigned NumGPRArgRegs = GPRArgRegs.size();
7724   if (Subtarget->isWindowsArm64EC()) {
7725     // In the ARM64EC ABI, only x0-x3 are used to pass arguments to varargs
7726     // functions.
7727     NumGPRArgRegs = 4;
7728   }
7729   unsigned FirstVariadicGPR = CCInfo.getFirstUnallocated(GPRArgRegs);
7730 
7731   unsigned GPRSaveSize = 8 * (NumGPRArgRegs - FirstVariadicGPR);
7732   int GPRIdx = 0;
7733   if (GPRSaveSize != 0) {
7734     if (IsWin64) {
7735       GPRIdx = MFI.CreateFixedObject(GPRSaveSize, -(int)GPRSaveSize, false);
7736       if (GPRSaveSize & 15)
7737         // The extra size here, if triggered, will always be 8.
7738         MFI.CreateFixedObject(16 - (GPRSaveSize & 15), -(int)alignTo(GPRSaveSize, 16), false);
7739     } else
7740       GPRIdx = MFI.CreateStackObject(GPRSaveSize, Align(8), false);
7741 
7742     SDValue FIN;
7743     if (Subtarget->isWindowsArm64EC()) {
7744       // With the Arm64EC ABI, we reserve the save area as usual, but we
7745       // compute its address relative to x4.  For a normal AArch64->AArch64
7746       // call, x4 == sp on entry, but calls from an entry thunk can pass in a
7747       // different address.
7748       Register VReg = MF.addLiveIn(AArch64::X4, &AArch64::GPR64RegClass);
7749       SDValue Val = DAG.getCopyFromReg(Chain, DL, VReg, MVT::i64);
7750       FIN = DAG.getNode(ISD::SUB, DL, MVT::i64, Val,
7751                         DAG.getConstant(GPRSaveSize, DL, MVT::i64));
7752     } else {
7753       FIN = DAG.getFrameIndex(GPRIdx, PtrVT);
7754     }
7755 
7756     for (unsigned i = FirstVariadicGPR; i < NumGPRArgRegs; ++i) {
7757       Register VReg = MF.addLiveIn(GPRArgRegs[i], &AArch64::GPR64RegClass);
7758       SDValue Val = DAG.getCopyFromReg(Chain, DL, VReg, MVT::i64);
7759       SDValue Store =
7760           DAG.getStore(Val.getValue(1), DL, Val, FIN,
7761                        IsWin64 ? MachinePointerInfo::getFixedStack(
7762                                      MF, GPRIdx, (i - FirstVariadicGPR) * 8)
7763                                : MachinePointerInfo::getStack(MF, i * 8));
7764       MemOps.push_back(Store);
7765       FIN =
7766           DAG.getNode(ISD::ADD, DL, PtrVT, FIN, DAG.getConstant(8, DL, PtrVT));
7767     }
7768   }
7769   FuncInfo->setVarArgsGPRIndex(GPRIdx);
7770   FuncInfo->setVarArgsGPRSize(GPRSaveSize);
7771 
7772   if (Subtarget->hasFPARMv8() && !IsWin64) {
7773     auto FPRArgRegs = AArch64::getFPRArgRegs();
7774     const unsigned NumFPRArgRegs = FPRArgRegs.size();
7775     unsigned FirstVariadicFPR = CCInfo.getFirstUnallocated(FPRArgRegs);
7776 
7777     unsigned FPRSaveSize = 16 * (NumFPRArgRegs - FirstVariadicFPR);
7778     int FPRIdx = 0;
7779     if (FPRSaveSize != 0) {
7780       FPRIdx = MFI.CreateStackObject(FPRSaveSize, Align(16), false);
7781 
7782       SDValue FIN = DAG.getFrameIndex(FPRIdx, PtrVT);
7783 
7784       for (unsigned i = FirstVariadicFPR; i < NumFPRArgRegs; ++i) {
7785         Register VReg = MF.addLiveIn(FPRArgRegs[i], &AArch64::FPR128RegClass);
7786         SDValue Val = DAG.getCopyFromReg(Chain, DL, VReg, MVT::f128);
7787 
7788         SDValue Store = DAG.getStore(Val.getValue(1), DL, Val, FIN,
7789                                      MachinePointerInfo::getStack(MF, i * 16));
7790         MemOps.push_back(Store);
7791         FIN = DAG.getNode(ISD::ADD, DL, PtrVT, FIN,
7792                           DAG.getConstant(16, DL, PtrVT));
7793       }
7794     }
7795     FuncInfo->setVarArgsFPRIndex(FPRIdx);
7796     FuncInfo->setVarArgsFPRSize(FPRSaveSize);
7797   }
7798 
7799   if (!MemOps.empty()) {
7800     Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, MemOps);
7801   }
7802 }
7803 
7804 /// LowerCallResult - Lower the result values of a call into the
7805 /// appropriate copies out of appropriate physical registers.
7806 SDValue AArch64TargetLowering::LowerCallResult(
7807     SDValue Chain, SDValue InGlue, CallingConv::ID CallConv, bool isVarArg,
7808     const SmallVectorImpl<CCValAssign> &RVLocs, const SDLoc &DL,
7809     SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals, bool isThisReturn,
7810     SDValue ThisVal, bool RequiresSMChange) const {
7811   DenseMap<unsigned, SDValue> CopiedRegs;
7812   // Copy all of the result registers out of their specified physreg.
7813   for (unsigned i = 0; i != RVLocs.size(); ++i) {
7814     CCValAssign VA = RVLocs[i];
7815 
7816     // Pass 'this' value directly from the argument to return value, to avoid
7817     // reg unit interference
7818     if (i == 0 && isThisReturn) {
7819       assert(!VA.needsCustom() && VA.getLocVT() == MVT::i64 &&
7820              "unexpected return calling convention register assignment");
7821       InVals.push_back(ThisVal);
7822       continue;
7823     }
7824 
7825     // Avoid copying a physreg twice since RegAllocFast is incompetent and only
7826     // allows one use of a physreg per block.
7827     SDValue Val = CopiedRegs.lookup(VA.getLocReg());
7828     if (!Val) {
7829       Val =
7830           DAG.getCopyFromReg(Chain, DL, VA.getLocReg(), VA.getLocVT(), InGlue);
7831       Chain = Val.getValue(1);
7832       InGlue = Val.getValue(2);
7833       CopiedRegs[VA.getLocReg()] = Val;
7834     }
7835 
7836     switch (VA.getLocInfo()) {
7837     default:
7838       llvm_unreachable("Unknown loc info!");
7839     case CCValAssign::Full:
7840       break;
7841     case CCValAssign::BCvt:
7842       Val = DAG.getNode(ISD::BITCAST, DL, VA.getValVT(), Val);
7843       break;
7844     case CCValAssign::AExtUpper:
7845       Val = DAG.getNode(ISD::SRL, DL, VA.getLocVT(), Val,
7846                         DAG.getConstant(32, DL, VA.getLocVT()));
7847       [[fallthrough]];
7848     case CCValAssign::AExt:
7849       [[fallthrough]];
7850     case CCValAssign::ZExt:
7851       Val = DAG.getZExtOrTrunc(Val, DL, VA.getValVT());
7852       break;
7853     }
7854 
7855     if (RequiresSMChange && isPassedInFPR(VA.getValVT()))
7856       Val = DAG.getNode(AArch64ISD::COALESCER_BARRIER, DL, Val.getValueType(),
7857                         Val);
7858 
7859     InVals.push_back(Val);
7860   }
7861 
7862   return Chain;
7863 }
7864 
7865 /// Return true if the calling convention is one that we can guarantee TCO for.
7866 static bool canGuaranteeTCO(CallingConv::ID CC, bool GuaranteeTailCalls) {
7867   return (CC == CallingConv::Fast && GuaranteeTailCalls) ||
7868          CC == CallingConv::Tail || CC == CallingConv::SwiftTail;
7869 }
7870 
7871 /// Return true if we might ever do TCO for calls with this calling convention.
7872 static bool mayTailCallThisCC(CallingConv::ID CC) {
7873   switch (CC) {
7874   case CallingConv::C:
7875   case CallingConv::AArch64_SVE_VectorCall:
7876   case CallingConv::PreserveMost:
7877   case CallingConv::PreserveAll:
7878   case CallingConv::PreserveNone:
7879   case CallingConv::Swift:
7880   case CallingConv::SwiftTail:
7881   case CallingConv::Tail:
7882   case CallingConv::Fast:
7883     return true;
7884   default:
7885     return false;
7886   }
7887 }
7888 
7889 /// Return true if the calling convention supports varargs.
7890 /// Currently only conventions that pass varargs the way the C
7891 /// calling convention does are eligible.
7892 /// Calling conventions listed in this function must also
7893 /// be properly handled in AArch64Subtarget::isCallingConvWin64.
7894 static bool callConvSupportsVarArgs(CallingConv::ID CC) {
7895   switch (CC) {
7896   case CallingConv::C:
7897   case CallingConv::PreserveNone:
7898     return true;
7899   default:
7900     return false;
7901   }
7902 }
7903 
7904 static void analyzeCallOperands(const AArch64TargetLowering &TLI,
7905                                 const AArch64Subtarget *Subtarget,
7906                                 const TargetLowering::CallLoweringInfo &CLI,
7907                                 CCState &CCInfo) {
7908   const SelectionDAG &DAG = CLI.DAG;
7909   CallingConv::ID CalleeCC = CLI.CallConv;
7910   bool IsVarArg = CLI.IsVarArg;
7911   const SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
7912   bool IsCalleeWin64 = Subtarget->isCallingConvWin64(CalleeCC, IsVarArg);
7913 
7914   // For Arm64EC thunks, allocate 32 extra bytes at the bottom of the stack
7915   // for the shadow store.
7916   if (CalleeCC == CallingConv::ARM64EC_Thunk_X64)
7917     CCInfo.AllocateStack(32, Align(16));
7918 
7919   unsigned NumArgs = Outs.size();
7920   for (unsigned i = 0; i != NumArgs; ++i) {
7921     MVT ArgVT = Outs[i].VT;
7922     ISD::ArgFlagsTy ArgFlags = Outs[i].Flags;
7923 
7924     bool UseVarArgCC = false;
7925     if (IsVarArg) {
7926       // On Windows, the fixed arguments in a vararg call are passed in GPRs
7927       // too, so use the vararg CC to force them to integer registers.
7928       if (IsCalleeWin64) {
7929         UseVarArgCC = true;
7930       } else {
7931         UseVarArgCC = !Outs[i].IsFixed;
7932       }
7933     }
7934 
7935     if (!UseVarArgCC) {
7936       // Get type of the original argument.
7937       EVT ActualVT =
7938           TLI.getValueType(DAG.getDataLayout(), CLI.Args[Outs[i].OrigArgIndex].Ty,
7939                        /*AllowUnknown*/ true);
7940       MVT ActualMVT = ActualVT.isSimple() ? ActualVT.getSimpleVT() : ArgVT;
7941       // If ActualMVT is i1/i8/i16, we should set LocVT to i8/i8/i16.
7942       if (ActualMVT == MVT::i1 || ActualMVT == MVT::i8)
7943         ArgVT = MVT::i8;
7944       else if (ActualMVT == MVT::i16)
7945         ArgVT = MVT::i16;
7946     }
7947 
7948     CCAssignFn *AssignFn = TLI.CCAssignFnForCall(CalleeCC, UseVarArgCC);
7949     bool Res = AssignFn(i, ArgVT, ArgVT, CCValAssign::Full, ArgFlags, CCInfo);
7950     assert(!Res && "Call operand has unhandled type");
7951     (void)Res;
7952   }
7953 }
7954 
7955 bool AArch64TargetLowering::isEligibleForTailCallOptimization(
7956     const CallLoweringInfo &CLI) const {
7957   CallingConv::ID CalleeCC = CLI.CallConv;
7958   if (!mayTailCallThisCC(CalleeCC))
7959     return false;
7960 
7961   SDValue Callee = CLI.Callee;
7962   bool IsVarArg = CLI.IsVarArg;
7963   const SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
7964   const SmallVector<SDValue, 32> &OutVals = CLI.OutVals;
7965   const SmallVector<ISD::InputArg, 32> &Ins = CLI.Ins;
7966   const SelectionDAG &DAG = CLI.DAG;
7967   MachineFunction &MF = DAG.getMachineFunction();
7968   const Function &CallerF = MF.getFunction();
7969   CallingConv::ID CallerCC = CallerF.getCallingConv();
7970 
7971   // SME Streaming functions are not eligible for TCO as they may require
7972   // the streaming mode or ZA to be restored after returning from the call.
7973   SMEAttrs CallerAttrs(MF.getFunction());
7974   auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
7975   if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
7976       CallerAttrs.requiresLazySave(CalleeAttrs) ||
7977       CallerAttrs.hasStreamingBody())
7978     return false;
7979 
7980   // Functions using the C or Fast calling convention that have an SVE signature
7981   // preserve more registers and should assume the SVE_VectorCall CC.
7982   // The check for matching callee-saved regs will determine whether it is
7983   // eligible for TCO.
7984   if ((CallerCC == CallingConv::C || CallerCC == CallingConv::Fast) &&
7985       MF.getInfo<AArch64FunctionInfo>()->isSVECC())
7986     CallerCC = CallingConv::AArch64_SVE_VectorCall;
7987 
7988   bool CCMatch = CallerCC == CalleeCC;
7989 
7990   // When using the Windows calling convention on a non-windows OS, we want
7991   // to back up and restore X18 in such functions; we can't do a tail call
7992   // from those functions.
7993   if (CallerCC == CallingConv::Win64 && !Subtarget->isTargetWindows() &&
7994       CalleeCC != CallingConv::Win64)
7995     return false;
7996 
7997   // Byval parameters hand the function a pointer directly into the stack area
7998   // we want to reuse during a tail call. Working around this *is* possible (see
7999   // X86) but less efficient and uglier in LowerCall.
8000   for (Function::const_arg_iterator i = CallerF.arg_begin(),
8001                                     e = CallerF.arg_end();
8002        i != e; ++i) {
8003     if (i->hasByValAttr())
8004       return false;
8005 
8006     // On Windows, "inreg" attributes signify non-aggregate indirect returns.
8007     // In this case, it is necessary to save/restore X0 in the callee. Tail
8008     // call opt interferes with this. So we disable tail call opt when the
8009     // caller has an argument with "inreg" attribute.
8010 
8011     // FIXME: Check whether the callee also has an "inreg" argument.
8012     if (i->hasInRegAttr())
8013       return false;
8014   }
8015 
8016   if (canGuaranteeTCO(CalleeCC, getTargetMachine().Options.GuaranteedTailCallOpt))
8017     return CCMatch;
8018 
8019   // Externally-defined functions with weak linkage should not be
8020   // tail-called on AArch64 when the OS does not support dynamic
8021   // pre-emption of symbols, as the AAELF spec requires normal calls
8022   // to undefined weak functions to be replaced with a NOP or jump to the
8023   // next instruction. The behaviour of branch instructions in this
8024   // situation (as used for tail calls) is implementation-defined, so we
8025   // cannot rely on the linker replacing the tail call with a return.
8026   if (GlobalAddressSDNode *G = dyn_cast<GlobalAddressSDNode>(Callee)) {
8027     const GlobalValue *GV = G->getGlobal();
8028     const Triple &TT = getTargetMachine().getTargetTriple();
8029     if (GV->hasExternalWeakLinkage() &&
8030         (!TT.isOSWindows() || TT.isOSBinFormatELF() || TT.isOSBinFormatMachO()))
8031       return false;
8032   }
8033 
8034   // Now we search for cases where we can use a tail call without changing the
8035   // ABI. Sibcall is used in some places (particularly gcc) to refer to this
8036   // concept.
8037 
8038   // I want anyone implementing a new calling convention to think long and hard
8039   // about this check.
8040   if (IsVarArg && !callConvSupportsVarArgs(CalleeCC))
8041     report_fatal_error("Unsupported variadic calling convention");
8042 
8043   LLVMContext &C = *DAG.getContext();
8044   // Check that the call results are passed in the same way.
8045   if (!CCState::resultsCompatible(CalleeCC, CallerCC, MF, C, Ins,
8046                                   CCAssignFnForCall(CalleeCC, IsVarArg),
8047                                   CCAssignFnForCall(CallerCC, IsVarArg)))
8048     return false;
8049   // The callee has to preserve all registers the caller needs to preserve.
8050   const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
8051   const uint32_t *CallerPreserved = TRI->getCallPreservedMask(MF, CallerCC);
8052   if (!CCMatch) {
8053     const uint32_t *CalleePreserved = TRI->getCallPreservedMask(MF, CalleeCC);
8054     if (Subtarget->hasCustomCallingConv()) {
8055       TRI->UpdateCustomCallPreservedMask(MF, &CallerPreserved);
8056       TRI->UpdateCustomCallPreservedMask(MF, &CalleePreserved);
8057     }
8058     if (!TRI->regmaskSubsetEqual(CallerPreserved, CalleePreserved))
8059       return false;
8060   }
8061 
8062   // Nothing more to check if the callee is taking no arguments
8063   if (Outs.empty())
8064     return true;
8065 
8066   SmallVector<CCValAssign, 16> ArgLocs;
8067   CCState CCInfo(CalleeCC, IsVarArg, MF, ArgLocs, C);
8068 
8069   analyzeCallOperands(*this, Subtarget, CLI, CCInfo);
8070 
8071   if (IsVarArg && !(CLI.CB && CLI.CB->isMustTailCall())) {
8072     // When the call is musttail, additional checks have been done and we can safely skip this check.
8073     // At least two cases here: if caller is fastcc then we can't have any
8074     // memory arguments (we'd be expected to clean up the stack afterwards). If
8075     // caller is C then we could potentially use its argument area.
8076 
8077     // FIXME: for now we take the most conservative of these in both cases:
8078     // disallow all variadic memory operands.
8079     for (const CCValAssign &ArgLoc : ArgLocs)
8080       if (!ArgLoc.isRegLoc())
8081         return false;
8082   }
8083 
8084   const AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
8085 
8086   // If any of the arguments is passed indirectly, it must be SVE, so
8087   // 'getBytesInStackArgArea' is not sufficient to determine whether we need to
8088   // allocate space on the stack. That is why we determine explicitly here that
8089   // the call cannot be a tail call.
8090   if (llvm::any_of(ArgLocs, [&](CCValAssign &A) {
8091         assert((A.getLocInfo() != CCValAssign::Indirect ||
8092                 A.getValVT().isScalableVector() ||
8093                 Subtarget->isWindowsArm64EC()) &&
8094                "Expected value to be scalable");
8095         return A.getLocInfo() == CCValAssign::Indirect;
8096       }))
8097     return false;
8098 
8099   // If the stack arguments for this call do not fit into our own save area then
8100   // the call cannot be made tail.
8101   if (CCInfo.getStackSize() > FuncInfo->getBytesInStackArgArea())
8102     return false;
8103 
8104   const MachineRegisterInfo &MRI = MF.getRegInfo();
8105   if (!parametersInCSRMatch(MRI, CallerPreserved, ArgLocs, OutVals))
8106     return false;
8107 
8108   return true;
8109 }
8110 
8111 SDValue AArch64TargetLowering::addTokenForArgument(SDValue Chain,
8112                                                    SelectionDAG &DAG,
8113                                                    MachineFrameInfo &MFI,
8114                                                    int ClobberedFI) const {
8115   SmallVector<SDValue, 8> ArgChains;
8116   int64_t FirstByte = MFI.getObjectOffset(ClobberedFI);
8117   int64_t LastByte = FirstByte + MFI.getObjectSize(ClobberedFI) - 1;
8118 
8119   // Include the original chain at the beginning of the list. When this is
8120   // used by target LowerCall hooks, this helps legalize find the
8121   // CALLSEQ_BEGIN node.
8122   ArgChains.push_back(Chain);
8123 
8124   // Add a chain value for each stack argument that overlaps the clobbered object.
8125   for (SDNode *U : DAG.getEntryNode().getNode()->uses())
8126     if (LoadSDNode *L = dyn_cast<LoadSDNode>(U))
8127       if (FrameIndexSDNode *FI = dyn_cast<FrameIndexSDNode>(L->getBasePtr()))
8128         if (FI->getIndex() < 0) {
8129           int64_t InFirstByte = MFI.getObjectOffset(FI->getIndex());
8130           int64_t InLastByte = InFirstByte;
8131           InLastByte += MFI.getObjectSize(FI->getIndex()) - 1;
8132 
8133           if ((InFirstByte <= FirstByte && FirstByte <= InLastByte) ||
8134               (FirstByte <= InFirstByte && InFirstByte <= LastByte))
8135             ArgChains.push_back(SDValue(L, 1));
8136         }
8137 
8138   // Build a tokenfactor for all the chains.
8139   return DAG.getNode(ISD::TokenFactor, SDLoc(Chain), MVT::Other, ArgChains);
8140 }
8141 
8142 bool AArch64TargetLowering::DoesCalleeRestoreStack(CallingConv::ID CallCC,
8143                                                    bool TailCallOpt) const {
8144   return (CallCC == CallingConv::Fast && TailCallOpt) ||
8145          CallCC == CallingConv::Tail || CallCC == CallingConv::SwiftTail;
8146 }
8147 
8148 // Check if the value is zero-extended from i1 to i8
8149 static bool checkZExtBool(SDValue Arg, const SelectionDAG &DAG) {
8150   unsigned SizeInBits = Arg.getValueType().getSizeInBits();
8151   if (SizeInBits < 8)
8152     return false;
8153 
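       // 0xFE covers bits 1-7: if they are all known to be zero, the low byte
       // already holds a zero-extended i1.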
8154   APInt RequiredZero(SizeInBits, 0xFE);
8155   KnownBits Bits = DAG.computeKnownBits(Arg, 4);
8156   bool ZExtBool = (Bits.Zero & RequiredZero) == RequiredZero;
8157   return ZExtBool;
8158 }
8159 
8160 void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
8161                                                           SDNode *Node) const {
8162   // Live-in physreg copies that are glued to SMSTART are applied as
8163   // implicit-def's in the InstrEmitter. Here we remove them, allowing the
8164   // register allocator to pass call args in callee saved regs, without extra
8165   // copies to avoid these fake clobbers of actually-preserved GPRs.
8166   if (MI.getOpcode() == AArch64::MSRpstatesvcrImm1 ||
8167       MI.getOpcode() == AArch64::MSRpstatePseudo) {
8168     for (unsigned I = MI.getNumOperands() - 1; I > 0; --I)
8169       if (MachineOperand &MO = MI.getOperand(I);
8170           MO.isReg() && MO.isImplicit() && MO.isDef() &&
8171           (AArch64::GPR32RegClass.contains(MO.getReg()) ||
8172            AArch64::GPR64RegClass.contains(MO.getReg())))
8173         MI.removeOperand(I);
8174 
8175     // The SVE vector length can change when entering/leaving streaming mode.
8176     if (MI.getOperand(0).getImm() == AArch64SVCR::SVCRSM ||
8177         MI.getOperand(0).getImm() == AArch64SVCR::SVCRSMZA) {
8178       MI.addOperand(MachineOperand::CreateReg(AArch64::VG, /*IsDef=*/false,
8179                                               /*IsImplicit=*/true));
8180       MI.addOperand(MachineOperand::CreateReg(AArch64::VG, /*IsDef=*/true,
8181                                               /*IsImplicit=*/true));
8182     }
8183   }
8184 
8185   // Add an implicit use of 'VG' for ADDXri/SUBXri, which are instructions that
8186   // have nothing to do with VG, were it not that they are used to materialise a
8187   // frame-address. If they contain a frame-index to a scalable vector, this
8188   // will likely require an ADDVL instruction to materialise the address, thus
8189   // reading VG.
8190   const MachineFunction &MF = *MI.getMF();
8191   if (MF.getInfo<AArch64FunctionInfo>()->hasStreamingModeChanges() &&
8192       (MI.getOpcode() == AArch64::ADDXri ||
8193        MI.getOpcode() == AArch64::SUBXri)) {
8194     const MachineOperand &MO = MI.getOperand(1);
8195     if (MO.isFI() && MF.getFrameInfo().getStackID(MO.getIndex()) ==
8196                          TargetStackID::ScalableVector)
8197       MI.addOperand(MachineOperand::CreateReg(AArch64::VG, /*IsDef=*/false,
8198                                               /*IsImplicit=*/true));
8199   }
8200 }
8201 
8202 SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
8203                                                    bool Enable, SDValue Chain,
8204                                                    SDValue InGlue,
8205                                                    unsigned Condition,
8206                                                    SDValue PStateSM) const {
8207   MachineFunction &MF = DAG.getMachineFunction();
8208   AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
8209   FuncInfo->setHasStreamingModeChanges(true);
8210 
8211   const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
8212   SDValue RegMask = DAG.getRegisterMask(TRI->getSMStartStopCallPreservedMask());
8213   SDValue MSROp =
8214       DAG.getTargetConstant((int32_t)AArch64SVCR::SVCRSM, DL, MVT::i32);
8215   SDValue ConditionOp = DAG.getTargetConstant(Condition, DL, MVT::i64);
8216   SmallVector<SDValue> Ops = {Chain, MSROp, ConditionOp};
8217   if (Condition != AArch64SME::Always) {
8218     assert(PStateSM && "PStateSM should be defined");
8219     Ops.push_back(PStateSM);
8220   }
8221   Ops.push_back(RegMask);
8222 
8223   if (InGlue)
8224     Ops.push_back(InGlue);
8225 
8226   unsigned Opcode = Enable ? AArch64ISD::SMSTART : AArch64ISD::SMSTOP;
8227   return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
8228 }
8229 
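     /// Select the condition under which the SMSTART/SMSTOP emitted around a
     /// call must actually execute, based on the caller's and callee's SME
     /// streaming attributes.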
8230 static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
8231                                const SMEAttrs &CalleeAttrs) {
8232   if (!CallerAttrs.hasStreamingCompatibleInterface() ||
8233       CallerAttrs.hasStreamingBody())
8234     return AArch64SME::Always;
8235   if (CalleeAttrs.hasNonStreamingInterface())
8236     return AArch64SME::IfCallerIsStreaming;
8237   if (CalleeAttrs.hasStreamingInterface())
8238     return AArch64SME::IfCallerIsNonStreaming;
8239 
8240   llvm_unreachable("Unsupported attributes");
8241 }
8242 
8243 /// LowerCall - Lower a call to a callseq_start + CALL + callseq_end chain,
8244 /// and add input and output parameter nodes.
8245 SDValue
8246 AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
8247                                  SmallVectorImpl<SDValue> &InVals) const {
8248   SelectionDAG &DAG = CLI.DAG;
8249   SDLoc &DL = CLI.DL;
8250   SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
8251   SmallVector<SDValue, 32> &OutVals = CLI.OutVals;
8252   SmallVector<ISD::InputArg, 32> &Ins = CLI.Ins;
8253   SDValue Chain = CLI.Chain;
8254   SDValue Callee = CLI.Callee;
8255   bool &IsTailCall = CLI.IsTailCall;
8256   CallingConv::ID &CallConv = CLI.CallConv;
8257   bool IsVarArg = CLI.IsVarArg;
8258 
8259   MachineFunction &MF = DAG.getMachineFunction();
8260   MachineFunction::CallSiteInfo CSInfo;
8261   bool IsThisReturn = false;
8262 
8263   AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
8264   bool TailCallOpt = MF.getTarget().Options.GuaranteedTailCallOpt;
8265   bool IsCFICall = CLI.CB && CLI.CB->isIndirectCall() && CLI.CFIType;
8266   bool IsSibCall = false;
8267   bool GuardWithBTI = false;
8268 
8269   if (CLI.CB && CLI.CB->hasFnAttr(Attribute::ReturnsTwice) &&
8270       !Subtarget->noBTIAtReturnTwice()) {
8271     GuardWithBTI = FuncInfo->branchTargetEnforcement();
8272   }
8273 
8274   // Analyze operands of the call, assigning locations to each operand.
8275   SmallVector<CCValAssign, 16> ArgLocs;
8276   CCState CCInfo(CallConv, IsVarArg, MF, ArgLocs, *DAG.getContext());
8277 
8278   if (IsVarArg) {
8279     unsigned NumArgs = Outs.size();
8280 
8281     for (unsigned i = 0; i != NumArgs; ++i) {
8282       if (!Outs[i].IsFixed && Outs[i].VT.isScalableVector())
8283         report_fatal_error("Passing SVE types to variadic functions is "
8284                            "currently not supported");
8285     }
8286   }
8287 
8288   analyzeCallOperands(*this, Subtarget, CLI, CCInfo);
8289 
8290   CCAssignFn *RetCC = CCAssignFnForReturn(CallConv);
8291   // Assign locations to each value returned by this call.
8292   SmallVector<CCValAssign, 16> RVLocs;
8293   CCState RetCCInfo(CallConv, IsVarArg, DAG.getMachineFunction(), RVLocs,
8294                     *DAG.getContext());
8295   RetCCInfo.AnalyzeCallResult(Ins, RetCC);
8296 
8297   // Check callee args/returns for SVE registers and set calling convention
8298   // accordingly.
8299   if (CallConv == CallingConv::C || CallConv == CallingConv::Fast) {
8300     auto HasSVERegLoc = [](CCValAssign &Loc) {
8301       if (!Loc.isRegLoc())
8302         return false;
8303       return AArch64::ZPRRegClass.contains(Loc.getLocReg()) ||
8304              AArch64::PPRRegClass.contains(Loc.getLocReg());
8305     };
8306     if (any_of(RVLocs, HasSVERegLoc) || any_of(ArgLocs, HasSVERegLoc))
8307       CallConv = CallingConv::AArch64_SVE_VectorCall;
8308   }
8309 
8310   if (IsTailCall) {
8311     // Check if it's really possible to do a tail call.
8312     IsTailCall = isEligibleForTailCallOptimization(CLI);
8313 
8314     // A sibling call is one where we're under the usual C ABI and not planning
8315     // to change that but can still do a tail call:
8316     if (!TailCallOpt && IsTailCall && CallConv != CallingConv::Tail &&
8317         CallConv != CallingConv::SwiftTail)
8318       IsSibCall = true;
8319 
8320     if (IsTailCall)
8321       ++NumTailCalls;
8322   }
8323 
8324   if (!IsTailCall && CLI.CB && CLI.CB->isMustTailCall())
8325     report_fatal_error("failed to perform tail call elimination on a call "
8326                        "site marked musttail");
8327 
8328   // Get a count of how many bytes are to be pushed on the stack.
8329   unsigned NumBytes = CCInfo.getStackSize();
8330 
8331   if (IsSibCall) {
8332     // Since we're not changing the ABI to make this a tail call, the memory
8333     // operands are already available in the caller's incoming argument space.
8334     NumBytes = 0;
8335   }
8336 
8337   // FPDiff is the byte offset of the call's argument area from the callee's.
8338   // Stores to callee stack arguments will be placed in FixedStackSlots offset
8339   // by this amount for a tail call. In a sibling call it must be 0 because the
8340   // caller will deallocate the entire stack and the callee still expects its
8341   // arguments to begin at SP+0. Completely unused for non-tail calls.
8342   int FPDiff = 0;
8343 
8344   if (IsTailCall && !IsSibCall) {
8345     unsigned NumReusableBytes = FuncInfo->getBytesInStackArgArea();
8346 
8347     // Since callee will pop argument stack as a tail call, we must keep the
8348     // popped size 16-byte aligned.
8349     NumBytes = alignTo(NumBytes, 16);
8350 
8351     // FPDiff will be negative if this tail call requires more space than we
8352     // would automatically have in our incoming argument space. Positive if we
8353     // can actually shrink the stack.
8354     FPDiff = NumReusableBytes - NumBytes;
8355 
8356     // Update the required reserved area if this is the tail call requiring the
8357     // most argument stack space.
8358     if (FPDiff < 0 && FuncInfo->getTailCallReservedStack() < (unsigned)-FPDiff)
8359       FuncInfo->setTailCallReservedStack(-FPDiff);
8360 
8361     // The stack pointer must be 16-byte aligned at all times it's used for a
8362     // memory operation, which in practice means at *all* times and in
8363     // particular across call boundaries. Therefore our own arguments started at
8364     // a 16-byte aligned SP and the delta applied for the tail call should
8365     // satisfy the same constraint.
8366     assert(FPDiff % 16 == 0 && "unaligned stack on tail call");
8367   }
8368 
8369   // Determine whether we need any streaming mode changes.
8370   SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
8371   if (CLI.CB)
8372     CalleeAttrs = SMEAttrs(*CLI.CB);
8373   else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
8374     CalleeAttrs = SMEAttrs(ES->getSymbol());
8375 
8376   auto DescribeCallsite =
8377       [&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
8378     R << "call from '" << ore::NV("Caller", MF.getName()) << "' to '";
8379     if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
8380       R << ore::NV("Callee", ES->getSymbol());
8381     else if (CLI.CB && CLI.CB->getCalledFunction())
8382       R << ore::NV("Callee", CLI.CB->getCalledFunction()->getName());
8383     else
8384       R << "unknown callee";
8385     R << "'";
8386     return R;
8387   };
8388 
8389   bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
8390   if (RequiresLazySave) {
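         // Populate the TPIDR2 block for the lazy save: the number of ZA save
         // slices (RDSVL #1) is stored at offset 8, and TPIDR2_EL0 is then
         // pointed at the block via the aarch64.sme.set.tpidr2 intrinsic.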
8391     const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
8392     MachinePointerInfo MPI =
8393         MachinePointerInfo::getStack(MF, TPIDR2.FrameIndex);
8394     SDValue TPIDR2ObjAddr = DAG.getFrameIndex(
8395         TPIDR2.FrameIndex,
8396         DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
8397     SDValue NumZaSaveSlicesAddr =
8398         DAG.getNode(ISD::ADD, DL, TPIDR2ObjAddr.getValueType(), TPIDR2ObjAddr,
8399                     DAG.getConstant(8, DL, TPIDR2ObjAddr.getValueType()));
8400     SDValue NumZaSaveSlices = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
8401                                           DAG.getConstant(1, DL, MVT::i32));
8402     Chain = DAG.getTruncStore(Chain, DL, NumZaSaveSlices, NumZaSaveSlicesAddr,
8403                               MPI, MVT::i16);
8404     Chain = DAG.getNode(
8405         ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
8406         DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
8407         TPIDR2ObjAddr);
8408     OptimizationRemarkEmitter ORE(&MF.getFunction());
8409     ORE.emit([&]() {
8410       auto R = CLI.CB ? OptimizationRemarkAnalysis("sme", "SMELazySaveZA",
8411                                                    CLI.CB)
8412                       : OptimizationRemarkAnalysis("sme", "SMELazySaveZA",
8413                                                    &MF.getFunction());
8414       return DescribeCallsite(R) << " sets up a lazy save for ZA";
8415     });
8416   }
8417 
8418   SDValue PStateSM;
8419   bool RequiresSMChange = CallerAttrs.requiresSMChange(CalleeAttrs);
8420   if (RequiresSMChange) {
8421     if (CallerAttrs.hasStreamingInterfaceOrBody())
8422       PStateSM = DAG.getConstant(1, DL, MVT::i64);
8423     else if (CallerAttrs.hasNonStreamingInterface())
8424       PStateSM = DAG.getConstant(0, DL, MVT::i64);
8425     else
8426       PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
8427     OptimizationRemarkEmitter ORE(&MF.getFunction());
8428     ORE.emit([&]() {
8429       auto R = CLI.CB ? OptimizationRemarkAnalysis("sme", "SMETransition",
8430                                                    CLI.CB)
8431                       : OptimizationRemarkAnalysis("sme", "SMETransition",
8432                                                    &MF.getFunction());
8433       DescribeCallsite(R) << " requires a streaming mode transition";
8434       return R;
8435     });
8436   }
8437 
8438   SDValue ZTFrameIdx;
8439   MachineFrameInfo &MFI = MF.getFrameInfo();
8440   bool ShouldPreserveZT0 = CallerAttrs.requiresPreservingZT0(CalleeAttrs);
8441 
8442   // If the caller has ZT0 state which will not be preserved by the callee,
8443   // spill ZT0 before the call.
8444   if (ShouldPreserveZT0) {
8445     unsigned ZTObj = MFI.CreateSpillStackObject(64, Align(16));
8446     ZTFrameIdx = DAG.getFrameIndex(
8447         ZTObj,
8448         DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
8449 
8450     Chain = DAG.getNode(AArch64ISD::SAVE_ZT, DL, DAG.getVTList(MVT::Other),
8451                         {Chain, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
8452   }
8453 
8454   // If caller shares ZT0 but the callee is not shared ZA, we need to stop
8455   // PSTATE.ZA before the call if there is no lazy-save active.
8456   bool DisableZA = CallerAttrs.requiresDisablingZABeforeCall(CalleeAttrs);
8457   assert((!DisableZA || !RequiresLazySave) &&
8458          "Lazy-save should have PSTATE.SM=1 on entry to the function");
8459 
8460   if (DisableZA)
8461     Chain = DAG.getNode(
8462         AArch64ISD::SMSTOP, DL, MVT::Other, Chain,
8463         DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
8464         DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
8465 
8466   // Adjust the stack pointer for the new arguments...
8467   // These operations are automatically eliminated by the prolog/epilog pass
8468   if (!IsSibCall)
8469     Chain = DAG.getCALLSEQ_START(Chain, IsTailCall ? 0 : NumBytes, 0, DL);
8470 
8471   SDValue StackPtr = DAG.getCopyFromReg(Chain, DL, AArch64::SP,
8472                                         getPointerTy(DAG.getDataLayout()));
8473 
8474   SmallVector<std::pair<unsigned, SDValue>, 8> RegsToPass;
8475   SmallSet<unsigned, 8> RegsUsed;
8476   SmallVector<SDValue, 8> MemOpChains;
8477   auto PtrVT = getPointerTy(DAG.getDataLayout());
8478 
8479   if (IsVarArg && CLI.CB && CLI.CB->isMustTailCall()) {
8480     const auto &Forwards = FuncInfo->getForwardedMustTailRegParms();
8481     for (const auto &F : Forwards) {
8482       SDValue Val = DAG.getCopyFromReg(Chain, DL, F.VReg, F.VT);
8483       RegsToPass.emplace_back(F.PReg, Val);
8484     }
8485   }
8486 
8487   // Walk the register/memloc assignments, inserting copies/loads.
8488   unsigned ExtraArgLocs = 0;
8489   for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
8490     CCValAssign &VA = ArgLocs[i - ExtraArgLocs];
8491     SDValue Arg = OutVals[i];
8492     ISD::ArgFlagsTy Flags = Outs[i].Flags;
8493 
8494     // Promote the value if needed.
8495     switch (VA.getLocInfo()) {
8496     default:
8497       llvm_unreachable("Unknown loc info!");
8498     case CCValAssign::Full:
8499       break;
8500     case CCValAssign::SExt:
8501       Arg = DAG.getNode(ISD::SIGN_EXTEND, DL, VA.getLocVT(), Arg);
8502       break;
8503     case CCValAssign::ZExt:
8504       Arg = DAG.getNode(ISD::ZERO_EXTEND, DL, VA.getLocVT(), Arg);
8505       break;
8506     case CCValAssign::AExt:
8507       if (Outs[i].ArgVT == MVT::i1) {
8508         // AAPCS requires i1 to be zero-extended to 8-bits by the caller.
8509         //
8510         // Check if we actually have to do this, because the value may
8511         // already be zero-extended.
8512         //
8513         // We cannot just emit a (zext i8 (trunc (assert-zext i8)))
8514         // and rely on DAGCombiner to fold this, because the following
8515         // (anyext i32) is combined with (zext i8) in DAG.getNode:
8516         //
8517         //   (ext (zext x)) -> (zext x)
8518         //
8519         // This will give us (zext i32), which we cannot remove, so
8520         // try to check this beforehand.
8521         if (!checkZExtBool(Arg, DAG)) {
8522           Arg = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Arg);
8523           Arg = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i8, Arg);
8524         }
8525       }
8526       Arg = DAG.getNode(ISD::ANY_EXTEND, DL, VA.getLocVT(), Arg);
8527       break;
8528     case CCValAssign::AExtUpper:
8529       assert(VA.getValVT() == MVT::i32 && "only expect 32 -> 64 upper bits");
8530       Arg = DAG.getNode(ISD::ANY_EXTEND, DL, VA.getLocVT(), Arg);
8531       Arg = DAG.getNode(ISD::SHL, DL, VA.getLocVT(), Arg,
8532                         DAG.getConstant(32, DL, VA.getLocVT()));
8533       break;
8534     case CCValAssign::BCvt:
8535       Arg = DAG.getBitcast(VA.getLocVT(), Arg);
8536       break;
8537     case CCValAssign::Trunc:
8538       Arg = DAG.getZExtOrTrunc(Arg, DL, VA.getLocVT());
8539       break;
8540     case CCValAssign::FPExt:
8541       Arg = DAG.getNode(ISD::FP_EXTEND, DL, VA.getLocVT(), Arg);
8542       break;
8543     case CCValAssign::Indirect:
8544       bool isScalable = VA.getValVT().isScalableVT();
8545       assert((isScalable || Subtarget->isWindowsArm64EC()) &&
8546              "Indirect arguments should be scalable on most subtargets");
8547 
8548       uint64_t StoreSize = VA.getValVT().getStoreSize().getKnownMinValue();
8549       uint64_t PartSize = StoreSize;
8550       unsigned NumParts = 1;
8551       if (Outs[i].Flags.isInConsecutiveRegs()) {
8552         while (!Outs[i + NumParts - 1].Flags.isInConsecutiveRegsLast())
8553           ++NumParts;
8554         StoreSize *= NumParts;
8555       }
8556 
8557       Type *Ty = EVT(VA.getValVT()).getTypeForEVT(*DAG.getContext());
8558       Align Alignment = DAG.getDataLayout().getPrefTypeAlign(Ty);
8559       MachineFrameInfo &MFI = MF.getFrameInfo();
8560       int FI = MFI.CreateStackObject(StoreSize, Alignment, false);
8561       if (isScalable)
8562         MFI.setStackID(FI, TargetStackID::ScalableVector);
8563 
8564       MachinePointerInfo MPI = MachinePointerInfo::getFixedStack(MF, FI);
8565       SDValue Ptr = DAG.getFrameIndex(
8566           FI, DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
8567       SDValue SpillSlot = Ptr;
8568 
8569       // Ensure we generate all stores for each tuple part, whilst updating the
8570       // pointer after each store correctly using vscale.
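      // (Illustrative sketch, assuming an SVE tuple argument: a two-part tuple
      //  of <vscale x 4 x i32> values is spilled as two stores, with the
      //  pointer advanced by vscale * 16 bytes between them.)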
8571       while (NumParts) {
8572         SDValue Store = DAG.getStore(Chain, DL, OutVals[i], Ptr, MPI);
8573         MemOpChains.push_back(Store);
8574 
8575         NumParts--;
8576         if (NumParts > 0) {
8577           SDValue BytesIncrement;
8578           if (isScalable) {
8579             BytesIncrement = DAG.getVScale(
8580                 DL, Ptr.getValueType(),
8581                 APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize));
8582           } else {
8583             BytesIncrement = DAG.getConstant(
8584                 APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize), DL,
8585                 Ptr.getValueType());
8586           }
8587           SDNodeFlags Flags;
8588           Flags.setNoUnsignedWrap(true);
8589 
8590           MPI = MachinePointerInfo(MPI.getAddrSpace());
8591           Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
8592                             BytesIncrement, Flags);
8593           ExtraArgLocs++;
8594           i++;
8595         }
8596       }
8597 
8598       Arg = SpillSlot;
8599       break;
8600     }
8601 
8602     if (VA.isRegLoc()) {
8603       if (i == 0 && Flags.isReturned() && !Flags.isSwiftSelf() &&
8604           Outs[0].VT == MVT::i64) {
8605         assert(VA.getLocVT() == MVT::i64 &&
8606                "unexpected calling convention register assignment");
8607         assert(!Ins.empty() && Ins[0].VT == MVT::i64 &&
8608                "unexpected use of 'returned'");
8609         IsThisReturn = true;
8610       }
8611       if (RegsUsed.count(VA.getLocReg())) {
8612         // If this register has already been used then we're trying to pack
8613         // parts of an [N x i32] into an X-register. The extension type will
8614         // take care of putting the two halves in the right place but we have to
8615         // combine them.
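        // (Illustrative example: for a [2 x i32] packed into a single
        //  X-register, the first half lands in bits [31:0] and the second,
        //  shifted via AExtUpper, in bits [63:32]; the OR below reassembles
        //  the full 64-bit value.)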
8616         SDValue &Bits =
8617             llvm::find_if(RegsToPass,
8618                           [=](const std::pair<unsigned, SDValue> &Elt) {
8619                             return Elt.first == VA.getLocReg();
8620                           })
8621                 ->second;
8622         Bits = DAG.getNode(ISD::OR, DL, Bits.getValueType(), Bits, Arg);
8623         // Call site info is used for the function's parameter entry-value
8624         // tracking. For now we only track simple cases where the parameter
8625         // is transferred through a whole register.
8626         llvm::erase_if(CSInfo.ArgRegPairs,
8627                        [&VA](MachineFunction::ArgRegPair ArgReg) {
8628                          return ArgReg.Reg == VA.getLocReg();
8629                        });
8630       } else {
8631         // Add an extra level of indirection for streaming mode changes by
8632         // using a pseudo copy node that the simple register coalescer cannot
8633         // rematerialise between a smstart/smstop and the call.
8634         if (RequiresSMChange && isPassedInFPR(Arg.getValueType()))
8635           Arg = DAG.getNode(AArch64ISD::COALESCER_BARRIER, DL,
8636                             Arg.getValueType(), Arg);
8637         RegsToPass.emplace_back(VA.getLocReg(), Arg);
8638         RegsUsed.insert(VA.getLocReg());
8639         const TargetOptions &Options = DAG.getTarget().Options;
8640         if (Options.EmitCallSiteInfo)
8641           CSInfo.ArgRegPairs.emplace_back(VA.getLocReg(), i);
8642       }
8643     } else {
8644       assert(VA.isMemLoc());
8645 
8646       SDValue DstAddr;
8647       MachinePointerInfo DstInfo;
8648 
8649       // FIXME: This works on big-endian for composite byvals, which are the
8650       // common case. It should also work for fundamental types.
8651       uint32_t BEAlign = 0;
8652       unsigned OpSize;
8653       if (VA.getLocInfo() == CCValAssign::Indirect ||
8654           VA.getLocInfo() == CCValAssign::Trunc)
8655         OpSize = VA.getLocVT().getFixedSizeInBits();
8656       else
8657         OpSize = Flags.isByVal() ? Flags.getByValSize() * 8
8658                                  : VA.getValVT().getSizeInBits();
8659       OpSize = (OpSize + 7) / 8;
8660       if (!Subtarget->isLittleEndian() && !Flags.isByVal() &&
8661           !Flags.isInConsecutiveRegs()) {
8662         if (OpSize < 8)
8663           BEAlign = 8 - OpSize;
8664       }
8665       unsigned LocMemOffset = VA.getLocMemOffset();
8666       int32_t Offset = LocMemOffset + BEAlign;
8667       SDValue PtrOff = DAG.getIntPtrConstant(Offset, DL);
8668       PtrOff = DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr, PtrOff);
8669 
8670       if (IsTailCall) {
8671         Offset = Offset + FPDiff;
8672         int FI = MF.getFrameInfo().CreateFixedObject(OpSize, Offset, true);
8673 
8674         DstAddr = DAG.getFrameIndex(FI, PtrVT);
8675         DstInfo = MachinePointerInfo::getFixedStack(MF, FI);
8676 
8677         // Make sure any stack arguments overlapping with where we're storing
8678         // are loaded before this eventual operation. Otherwise they'll be
8679         // clobbered.
8680         Chain = addTokenForArgument(Chain, DAG, MF.getFrameInfo(), FI);
8681       } else {
8682         SDValue PtrOff = DAG.getIntPtrConstant(Offset, DL);
8683 
8684         DstAddr = DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr, PtrOff);
8685         DstInfo = MachinePointerInfo::getStack(MF, LocMemOffset);
8686       }
8687 
8688       if (Outs[i].Flags.isByVal()) {
8689         SDValue SizeNode =
8690             DAG.getConstant(Outs[i].Flags.getByValSize(), DL, MVT::i64);
8691         SDValue Cpy = DAG.getMemcpy(
8692             Chain, DL, DstAddr, Arg, SizeNode,
8693             Outs[i].Flags.getNonZeroByValAlign(),
8694             /*isVol = */ false, /*AlwaysInline = */ false,
8695             /*CI=*/nullptr, std::nullopt, DstInfo, MachinePointerInfo());
8696 
8697         MemOpChains.push_back(Cpy);
8698       } else {
8699         // Since we pass i1/i8/i16 as i1/i8/i16 on stack and Arg is already
8700         // promoted to a legal register type i32, we should truncate Arg back to
8701         // i1/i8/i16.
8702         if (VA.getValVT() == MVT::i1 || VA.getValVT() == MVT::i8 ||
8703             VA.getValVT() == MVT::i16)
8704           Arg = DAG.getNode(ISD::TRUNCATE, DL, VA.getValVT(), Arg);
8705 
8706         SDValue Store = DAG.getStore(Chain, DL, Arg, DstAddr, DstInfo);
8707         MemOpChains.push_back(Store);
8708       }
8709     }
8710   }
8711 
8712   if (IsVarArg && Subtarget->isWindowsArm64EC()) {
8713     SDValue ParamPtr = StackPtr;
8714     if (IsTailCall) {
8715       // Create a dummy object at the top of the stack that can be used to get
8716       // the SP after the epilogue
8717       int FI = MF.getFrameInfo().CreateFixedObject(1, FPDiff, true);
8718       ParamPtr = DAG.getFrameIndex(FI, PtrVT);
8719     }
8720 
8721     // For vararg calls, the Arm64EC ABI requires values in x4 and x5
8722     // describing the argument list.  x4 contains the address of the
8723     // first stack parameter. x5 contains the size in bytes of all parameters
8724     // passed on the stack.
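    // (Sketch: x4 ends up holding the address of the outgoing stack-argument
    //  area and x5 holds NumBytes, the size in bytes of that area, so a call
    //  whose arguments all fit in registers passes x5 = 0.)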
8725     RegsToPass.emplace_back(AArch64::X4, ParamPtr);
8726     RegsToPass.emplace_back(AArch64::X5,
8727                             DAG.getConstant(NumBytes, DL, MVT::i64));
8728   }
8729 
8730   if (!MemOpChains.empty())
8731     Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, MemOpChains);
8732 
8733   SDValue InGlue;
8734   if (RequiresSMChange) {
8735     if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
8736       Chain = DAG.getNode(AArch64ISD::VG_SAVE, DL,
8737                           DAG.getVTList(MVT::Other, MVT::Glue), Chain);
8738       InGlue = Chain.getValue(1);
8739     }
8740 
8741     SDValue NewChain = changeStreamingMode(
8742         DAG, DL, CalleeAttrs.hasStreamingInterface(), Chain, InGlue,
8743         getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
8744     Chain = NewChain.getValue(0);
8745     InGlue = NewChain.getValue(1);
8746   }
8747 
8748   // Build a sequence of copy-to-reg nodes chained together with token chain
8749   // and flag operands which copy the outgoing args into the appropriate regs.
8750   for (auto &RegToPass : RegsToPass) {
8751     Chain = DAG.getCopyToReg(Chain, DL, RegToPass.first,
8752                              RegToPass.second, InGlue);
8753     InGlue = Chain.getValue(1);
8754   }
8755 
8756   // If the callee is a GlobalAddress/ExternalSymbol node (quite common, every
8757   // direct call is) turn it into a TargetGlobalAddress/TargetExternalSymbol
8758   // node so that legalize doesn't hack it.
8759   if (auto *G = dyn_cast<GlobalAddressSDNode>(Callee)) {
8760     auto GV = G->getGlobal();
8761     unsigned OpFlags =
8762         Subtarget->classifyGlobalFunctionReference(GV, getTargetMachine());
8763     if (OpFlags & AArch64II::MO_GOT) {
8764       Callee = DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, OpFlags);
8765       Callee = DAG.getNode(AArch64ISD::LOADgot, DL, PtrVT, Callee);
8766     } else {
8767       const GlobalValue *GV = G->getGlobal();
8768       Callee = DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, OpFlags);
8769     }
8770   } else if (auto *S = dyn_cast<ExternalSymbolSDNode>(Callee)) {
8771     bool UseGot = (getTargetMachine().getCodeModel() == CodeModel::Large &&
8772                    Subtarget->isTargetMachO()) ||
8773                   MF.getFunction().getParent()->getRtLibUseGOT();
8774     const char *Sym = S->getSymbol();
8775     if (UseGot) {
8776       Callee = DAG.getTargetExternalSymbol(Sym, PtrVT, AArch64II::MO_GOT);
8777       Callee = DAG.getNode(AArch64ISD::LOADgot, DL, PtrVT, Callee);
8778     } else {
8779       Callee = DAG.getTargetExternalSymbol(Sym, PtrVT, 0);
8780     }
8781   }
8782 
8783   // We don't usually want to end the call-sequence here because we would tidy
8784   // the frame up *after* the call. However, in the ABI-changing tail-call case
8785   // we've carefully laid out the parameters so that when sp is reset they'll be
8786   // in the correct location.
8787   if (IsTailCall && !IsSibCall) {
8788     Chain = DAG.getCALLSEQ_END(Chain, 0, 0, InGlue, DL);
8789     InGlue = Chain.getValue(1);
8790   }
8791 
8792   unsigned Opc = IsTailCall ? AArch64ISD::TC_RETURN : AArch64ISD::CALL;
8793 
8794   std::vector<SDValue> Ops;
8795   Ops.push_back(Chain);
8796   Ops.push_back(Callee);
8797 
8798   // Calls with operand bundle "clang.arc.attachedcall" are special. They should
8799   // be expanded to the call, directly followed by a special marker sequence and
8800   // a call to an ObjC library function.  Use CALL_RVMARKER to do that.
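  // (Illustrative shape of the final code for the common retainRV case,
  //  assuming the usual ObjC runtime marker on AArch64:
  //    bl  <callee>
  //    mov x29, x29                ; attached-call marker
  //    bl  objc_retainAutoreleasedReturnValue )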
8801   if (CLI.CB && objcarc::hasAttachedCallOpBundle(CLI.CB)) {
8802     assert(!IsTailCall &&
8803            "tail calls cannot be marked with clang.arc.attachedcall");
8804     Opc = AArch64ISD::CALL_RVMARKER;
8805 
8806     // Add a target global address for the retainRV/claimRV runtime function
8807     // just before the call target.
8808     Function *ARCFn = *objcarc::getAttachedARCFunction(CLI.CB);
8809     auto GA = DAG.getTargetGlobalAddress(ARCFn, DL, PtrVT);
8810     Ops.insert(Ops.begin() + 1, GA);
8811   } else if (CallConv == CallingConv::ARM64EC_Thunk_X64) {
8812     Opc = AArch64ISD::CALL_ARM64EC_TO_X64;
8813   } else if (GuardWithBTI) {
8814     Opc = AArch64ISD::CALL_BTI;
8815   }
8816 
8817   if (IsTailCall) {
8818     // Each tail call may have to adjust the stack by a different amount, so
8819     // this information must travel along with the operation for eventual
8820     // consumption by emitEpilogue.
8821     Ops.push_back(DAG.getTargetConstant(FPDiff, DL, MVT::i32));
8822   }
8823 
8824   if (CLI.PAI) {
8825     const uint64_t Key = CLI.PAI->Key;
8826     assert((Key == AArch64PACKey::IA || Key == AArch64PACKey::IB) &&
8827            "Invalid auth call key");
8828 
8829     // Split the discriminator into address/integer components.
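    // (Illustratively: a discriminator formed by blending an address with a
    //  small constant splits into AddrDisc = that address and IntDisc = the
    //  constant; a plain constant discriminator leaves AddrDisc empty/zero.)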
8830     SDValue AddrDisc, IntDisc;
8831     std::tie(IntDisc, AddrDisc) =
8832         extractPtrauthBlendDiscriminators(CLI.PAI->Discriminator, &DAG);
8833 
8834     if (Opc == AArch64ISD::CALL_RVMARKER)
8835       Opc = AArch64ISD::AUTH_CALL_RVMARKER;
8836     else
8837       Opc = IsTailCall ? AArch64ISD::AUTH_TC_RETURN : AArch64ISD::AUTH_CALL;
8838     Ops.push_back(DAG.getTargetConstant(Key, DL, MVT::i32));
8839     Ops.push_back(IntDisc);
8840     Ops.push_back(AddrDisc);
8841   }
8842 
8843   // Add argument registers to the end of the list so that they are known live
8844   // into the call.
8845   for (auto &RegToPass : RegsToPass)
8846     Ops.push_back(DAG.getRegister(RegToPass.first,
8847                                   RegToPass.second.getValueType()));
8848 
8849   // Add a register mask operand representing the call-preserved registers.
8850   const uint32_t *Mask;
8851   const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
8852   if (IsThisReturn) {
8853     // For 'this' returns, use the X0-preserving mask if applicable
8854     Mask = TRI->getThisReturnPreservedMask(MF, CallConv);
8855     if (!Mask) {
8856       IsThisReturn = false;
8857       Mask = TRI->getCallPreservedMask(MF, CallConv);
8858     }
8859   } else
8860     Mask = TRI->getCallPreservedMask(MF, CallConv);
8861 
8862   if (Subtarget->hasCustomCallingConv())
8863     TRI->UpdateCustomCallPreservedMask(MF, &Mask);
8864 
8865   if (TRI->isAnyArgRegReserved(MF))
8866     TRI->emitReservedArgRegCallError(MF);
8867 
8868   assert(Mask && "Missing call preserved mask for calling convention");
8869   Ops.push_back(DAG.getRegisterMask(Mask));
8870 
8871   if (InGlue.getNode())
8872     Ops.push_back(InGlue);
8873 
8874   SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
8875 
8876   // If we're doing a tail call, use a TC_RETURN here rather than an
8877   // actual call instruction.
8878   if (IsTailCall) {
8879     MF.getFrameInfo().setHasTailCall();
8880     SDValue Ret = DAG.getNode(Opc, DL, NodeTys, Ops);
8881     if (IsCFICall)
8882       Ret.getNode()->setCFIType(CLI.CFIType->getZExtValue());
8883 
8884     DAG.addNoMergeSiteInfo(Ret.getNode(), CLI.NoMerge);
8885     DAG.addCallSiteInfo(Ret.getNode(), std::move(CSInfo));
8886     return Ret;
8887   }
8888 
8889   // Returns a chain and a flag for retval copy to use.
8890   Chain = DAG.getNode(Opc, DL, NodeTys, Ops);
8891   if (IsCFICall)
8892     Chain.getNode()->setCFIType(CLI.CFIType->getZExtValue());
8893 
8894   DAG.addNoMergeSiteInfo(Chain.getNode(), CLI.NoMerge);
8895   InGlue = Chain.getValue(1);
8896   DAG.addCallSiteInfo(Chain.getNode(), std::move(CSInfo));
8897 
8898   uint64_t CalleePopBytes =
8899       DoesCalleeRestoreStack(CallConv, TailCallOpt) ? alignTo(NumBytes, 16) : 0;
8900 
8901   Chain = DAG.getCALLSEQ_END(Chain, NumBytes, CalleePopBytes, InGlue, DL);
8902   InGlue = Chain.getValue(1);
8903 
8904   // Handle result values, copying them out of physregs into vregs that we
8905   // return.
8906   SDValue Result = LowerCallResult(
8907       Chain, InGlue, CallConv, IsVarArg, RVLocs, DL, DAG, InVals, IsThisReturn,
8908       IsThisReturn ? OutVals[0] : SDValue(), RequiresSMChange);
8909 
8910   if (!Ins.empty())
8911     InGlue = Result.getValue(Result->getNumValues() - 1);
8912 
8913   if (RequiresSMChange) {
8914     assert(PStateSM && "Expected a PStateSM to be set");
8915     Result = changeStreamingMode(
8916         DAG, DL, !CalleeAttrs.hasStreamingInterface(), Result, InGlue,
8917         getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
8918 
8919     if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
8920       InGlue = Result.getValue(1);
8921       Result =
8922           DAG.getNode(AArch64ISD::VG_RESTORE, DL,
8923                       DAG.getVTList(MVT::Other, MVT::Glue), {Result, InGlue});
8924     }
8925   }
8926 
8927   if (CallerAttrs.requiresEnablingZAAfterCall(CalleeAttrs))
8928     // Unconditionally resume ZA.
8929     Result = DAG.getNode(
8930         AArch64ISD::SMSTART, DL, MVT::Other, Result,
8931         DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
8932         DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
8933 
8934   if (ShouldPreserveZT0)
8935     Result =
8936         DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
8937                     {Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
8938 
8939   if (RequiresLazySave) {
8940     // Conditionally restore the lazy save using a pseudo node.
8941     TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
8942     SDValue RegMask = DAG.getRegisterMask(
8943         TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
8944     SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
8945         "__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout()));
8946     SDValue TPIDR2_EL0 = DAG.getNode(
8947         ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
8948         DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
8949 
8950     // Copy the address of the TPIDR2 block into X0 before 'calling' the
8951     // RESTORE_ZA pseudo.
8952     SDValue Glue;
8953     SDValue TPIDR2Block = DAG.getFrameIndex(
8954         TPIDR2.FrameIndex,
8955         DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
8956     Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue);
8957     Result =
8958         DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
8959                     {Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
8960                      RestoreRoutine, RegMask, Result.getValue(1)});
8961 
8962     // Finally reset the TPIDR2_EL0 register to 0.
8963     Result = DAG.getNode(
8964         ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
8965         DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
8966         DAG.getConstant(0, DL, MVT::i64));
8967     TPIDR2.Uses++;
8968   }
8969 
8970   if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0) {
8971     for (unsigned I = 0; I < InVals.size(); ++I) {
8972       // The smstart/smstop is chained as part of the call, but when the
8973       // resulting chain is discarded (which happens when the call is not part
8974       // of a chain, e.g. a call to @llvm.cos()), we need to ensure the
8975       // smstart/smstop is chained to the result value. We can do that by doing
8976       // a vreg -> vreg copy.
8977       Register Reg = MF.getRegInfo().createVirtualRegister(
8978           getRegClassFor(InVals[I].getValueType().getSimpleVT()));
8979       SDValue X = DAG.getCopyToReg(Result, DL, Reg, InVals[I]);
8980       InVals[I] = DAG.getCopyFromReg(X, DL, Reg,
8981                                      InVals[I].getValueType());
8982     }
8983   }
8984 
8985   if (CallConv == CallingConv::PreserveNone) {
8986     for (const ISD::OutputArg &O : Outs) {
8987       if (O.Flags.isSwiftSelf() || O.Flags.isSwiftError() ||
8988           O.Flags.isSwiftAsync()) {
8989         MachineFunction &MF = DAG.getMachineFunction();
8990         DAG.getContext()->diagnose(DiagnosticInfoUnsupported(
8991             MF.getFunction(),
8992             "Swift attributes can't be used with preserve_none",
8993             DL.getDebugLoc()));
8994         break;
8995       }
8996     }
8997   }
8998 
8999   return Result;
9000 }
9001 
9002 bool AArch64TargetLowering::CanLowerReturn(
9003     CallingConv::ID CallConv, MachineFunction &MF, bool isVarArg,
9004     const SmallVectorImpl<ISD::OutputArg> &Outs, LLVMContext &Context) const {
9005   CCAssignFn *RetCC = CCAssignFnForReturn(CallConv);
9006   SmallVector<CCValAssign, 16> RVLocs;
9007   CCState CCInfo(CallConv, isVarArg, MF, RVLocs, Context);
9008   return CCInfo.CheckReturn(Outs, RetCC);
9009 }
9010 
9011 SDValue
9012 AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
9013                                    bool isVarArg,
9014                                    const SmallVectorImpl<ISD::OutputArg> &Outs,
9015                                    const SmallVectorImpl<SDValue> &OutVals,
9016                                    const SDLoc &DL, SelectionDAG &DAG) const {
9017   auto &MF = DAG.getMachineFunction();
9018   auto *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
9019 
9020   CCAssignFn *RetCC = CCAssignFnForReturn(CallConv);
9021   SmallVector<CCValAssign, 16> RVLocs;
9022   CCState CCInfo(CallConv, isVarArg, MF, RVLocs, *DAG.getContext());
9023   CCInfo.AnalyzeReturn(Outs, RetCC);
9024 
9025   // Copy the result values into the output registers.
9026   SDValue Glue;
9027   SmallVector<std::pair<unsigned, SDValue>, 4> RetVals;
9028   SmallSet<unsigned, 4> RegsUsed;
9029   for (unsigned i = 0, realRVLocIdx = 0; i != RVLocs.size();
9030        ++i, ++realRVLocIdx) {
9031     CCValAssign &VA = RVLocs[i];
9032     assert(VA.isRegLoc() && "Can only return in registers!");
9033     SDValue Arg = OutVals[realRVLocIdx];
9034 
9035     switch (VA.getLocInfo()) {
9036     default:
9037       llvm_unreachable("Unknown loc info!");
9038     case CCValAssign::Full:
9039       if (Outs[i].ArgVT == MVT::i1) {
9040         // AAPCS requires i1 to be zero-extended to i8 by the producer of the
9041         // value. This is strictly redundant on Darwin (which uses "zeroext
9042         // i1"), but will be optimised out before ISel.
9043         Arg = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Arg);
9044         Arg = DAG.getNode(ISD::ZERO_EXTEND, DL, VA.getLocVT(), Arg);
9045       }
9046       break;
9047     case CCValAssign::BCvt:
9048       Arg = DAG.getNode(ISD::BITCAST, DL, VA.getLocVT(), Arg);
9049       break;
9050     case CCValAssign::AExt:
9051     case CCValAssign::ZExt:
9052       Arg = DAG.getZExtOrTrunc(Arg, DL, VA.getLocVT());
9053       break;
9054     case CCValAssign::AExtUpper:
9055       assert(VA.getValVT() == MVT::i32 && "only expect 32 -> 64 upper bits");
9056       Arg = DAG.getZExtOrTrunc(Arg, DL, VA.getLocVT());
9057       Arg = DAG.getNode(ISD::SHL, DL, VA.getLocVT(), Arg,
9058                         DAG.getConstant(32, DL, VA.getLocVT()));
9059       break;
9060     }
9061 
9062     if (RegsUsed.count(VA.getLocReg())) {
9063       SDValue &Bits =
9064           llvm::find_if(RetVals, [=](const std::pair<unsigned, SDValue> &Elt) {
9065             return Elt.first == VA.getLocReg();
9066           })->second;
9067       Bits = DAG.getNode(ISD::OR, DL, Bits.getValueType(), Bits, Arg);
9068     } else {
9069       RetVals.emplace_back(VA.getLocReg(), Arg);
9070       RegsUsed.insert(VA.getLocReg());
9071     }
9072   }
9073 
9074   const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
9075 
9076   // Emit SMSTOP before returning from a locally streaming function
9077   SMEAttrs FuncAttrs(MF.getFunction());
9078   if (FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface()) {
9079     if (FuncAttrs.hasStreamingCompatibleInterface()) {
9080       Register Reg = FuncInfo->getPStateSMReg();
9081       assert(Reg.isValid() && "PStateSM Register is invalid");
9082       SDValue PStateSM = DAG.getCopyFromReg(Chain, DL, Reg, MVT::i64);
9083       Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
9084                                   /*Glue*/ SDValue(),
9085                                   AArch64SME::IfCallerIsNonStreaming, PStateSM);
9086     } else
9087       Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
9088                                   /*Glue*/ SDValue(), AArch64SME::Always);
9089     Glue = Chain.getValue(1);
9090   }
9091 
9092   SmallVector<SDValue, 4> RetOps(1, Chain);
9093   for (auto &RetVal : RetVals) {
9094     if (FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface() &&
9095         isPassedInFPR(RetVal.second.getValueType()))
9096       RetVal.second = DAG.getNode(AArch64ISD::COALESCER_BARRIER, DL,
9097                                   RetVal.second.getValueType(), RetVal.second);
9098     Chain = DAG.getCopyToReg(Chain, DL, RetVal.first, RetVal.second, Glue);
9099     Glue = Chain.getValue(1);
9100     RetOps.push_back(
9101         DAG.getRegister(RetVal.first, RetVal.second.getValueType()));
9102   }
9103 
9104   // Windows AArch64 ABIs require that for returning structs by value we copy
9105   // the sret argument into X0 for the return.
9106   // We saved the argument into a virtual register in the entry block,
9107   // so now we copy the value out and into X0.
9108   if (unsigned SRetReg = FuncInfo->getSRetReturnReg()) {
9109     SDValue Val = DAG.getCopyFromReg(RetOps[0], DL, SRetReg,
9110                                      getPointerTy(MF.getDataLayout()));
9111 
9112     unsigned RetValReg = AArch64::X0;
9113     if (CallConv == CallingConv::ARM64EC_Thunk_X64)
9114       RetValReg = AArch64::X8;
9115     Chain = DAG.getCopyToReg(Chain, DL, RetValReg, Val, Glue);
9116     Glue = Chain.getValue(1);
9117 
9118     RetOps.push_back(
9119       DAG.getRegister(RetValReg, getPointerTy(DAG.getDataLayout())));
9120   }
9121 
9122   const MCPhysReg *I = TRI->getCalleeSavedRegsViaCopy(&MF);
9123   if (I) {
9124     for (; *I; ++I) {
9125       if (AArch64::GPR64RegClass.contains(*I))
9126         RetOps.push_back(DAG.getRegister(*I, MVT::i64));
9127       else if (AArch64::FPR64RegClass.contains(*I))
9128         RetOps.push_back(DAG.getRegister(*I, MVT::getFloatingPointVT(64)));
9129       else
9130         llvm_unreachable("Unexpected register class in CSRsViaCopy!");
9131     }
9132   }
9133 
9134   RetOps[0] = Chain; // Update chain.
9135 
9136   // Add the glue if we have it.
9137   if (Glue.getNode())
9138     RetOps.push_back(Glue);
9139 
9140   if (CallConv == CallingConv::ARM64EC_Thunk_X64) {
9141     // ARM64EC entry thunks use a special return sequence: instead of a regular
9142     // "ret" instruction, they need to explicitly call the emulator.
9143     EVT PtrVT = getPointerTy(DAG.getDataLayout());
9144     SDValue Arm64ECRetDest =
9145         DAG.getExternalSymbol("__os_arm64x_dispatch_ret", PtrVT);
9146     Arm64ECRetDest =
9147         getAddr(cast<ExternalSymbolSDNode>(Arm64ECRetDest), DAG, 0);
9148     Arm64ECRetDest = DAG.getLoad(PtrVT, DL, DAG.getEntryNode(), Arm64ECRetDest,
9149                                  MachinePointerInfo());
9150     RetOps.insert(RetOps.begin() + 1, Arm64ECRetDest);
9151     RetOps.insert(RetOps.begin() + 2, DAG.getTargetConstant(0, DL, MVT::i32));
9152     return DAG.getNode(AArch64ISD::TC_RETURN, DL, MVT::Other, RetOps);
9153   }
9154 
9155   return DAG.getNode(AArch64ISD::RET_GLUE, DL, MVT::Other, RetOps);
9156 }
9157 
9158 //===----------------------------------------------------------------------===//
9159 //  Other Lowering Code
9160 //===----------------------------------------------------------------------===//
9161 
9162 SDValue AArch64TargetLowering::getTargetNode(GlobalAddressSDNode *N, EVT Ty,
9163                                              SelectionDAG &DAG,
9164                                              unsigned Flag) const {
9165   return DAG.getTargetGlobalAddress(N->getGlobal(), SDLoc(N), Ty,
9166                                     N->getOffset(), Flag);
9167 }
9168 
9169 SDValue AArch64TargetLowering::getTargetNode(JumpTableSDNode *N, EVT Ty,
9170                                              SelectionDAG &DAG,
9171                                              unsigned Flag) const {
9172   return DAG.getTargetJumpTable(N->getIndex(), Ty, Flag);
9173 }
9174 
9175 SDValue AArch64TargetLowering::getTargetNode(ConstantPoolSDNode *N, EVT Ty,
9176                                              SelectionDAG &DAG,
9177                                              unsigned Flag) const {
9178   return DAG.getTargetConstantPool(N->getConstVal(), Ty, N->getAlign(),
9179                                    N->getOffset(), Flag);
9180 }
9181 
9182 SDValue AArch64TargetLowering::getTargetNode(BlockAddressSDNode* N, EVT Ty,
9183                                              SelectionDAG &DAG,
9184                                              unsigned Flag) const {
9185   return DAG.getTargetBlockAddress(N->getBlockAddress(), Ty, 0, Flag);
9186 }
9187 
9188 SDValue AArch64TargetLowering::getTargetNode(ExternalSymbolSDNode *N, EVT Ty,
9189                                              SelectionDAG &DAG,
9190                                              unsigned Flag) const {
9191   return DAG.getTargetExternalSymbol(N->getSymbol(), Ty, Flag);
9192 }
9193 
9194 // (loadGOT sym)
9195 template <class NodeTy>
9196 SDValue AArch64TargetLowering::getGOT(NodeTy *N, SelectionDAG &DAG,
9197                                       unsigned Flags) const {
9198   LLVM_DEBUG(dbgs() << "AArch64TargetLowering::getGOT\n");
9199   SDLoc DL(N);
9200   EVT Ty = getPointerTy(DAG.getDataLayout());
9201   SDValue GotAddr = getTargetNode(N, Ty, DAG, AArch64II::MO_GOT | Flags);
9202   // FIXME: Once remat is capable of dealing with instructions with register
9203   // operands, expand this into two nodes instead of using a wrapper node.
9204   return DAG.getNode(AArch64ISD::LOADgot, DL, Ty, GotAddr);
9205 }
9206 
9207 // (wrapper %highest(sym), %higher(sym), %hi(sym), %lo(sym))
9208 template <class NodeTy>
9209 SDValue AArch64TargetLowering::getAddrLarge(NodeTy *N, SelectionDAG &DAG,
9210                                             unsigned Flags) const {
9211   LLVM_DEBUG(dbgs() << "AArch64TargetLowering::getAddrLarge\n");
9212   SDLoc DL(N);
9213   EVT Ty = getPointerTy(DAG.getDataLayout());
9214   const unsigned char MO_NC = AArch64II::MO_NC;
9215   return DAG.getNode(
9216       AArch64ISD::WrapperLarge, DL, Ty,
9217       getTargetNode(N, Ty, DAG, AArch64II::MO_G3 | Flags),
9218       getTargetNode(N, Ty, DAG, AArch64II::MO_G2 | MO_NC | Flags),
9219       getTargetNode(N, Ty, DAG, AArch64II::MO_G1 | MO_NC | Flags),
9220       getTargetNode(N, Ty, DAG, AArch64II::MO_G0 | MO_NC | Flags));
9221 }
9222 
9223 // (addlow (adrp %hi(sym)) %lo(sym))
9224 template <class NodeTy>
9225 SDValue AArch64TargetLowering::getAddr(NodeTy *N, SelectionDAG &DAG,
9226                                        unsigned Flags) const {
9227   LLVM_DEBUG(dbgs() << "AArch64TargetLowering::getAddr\n");
9228   SDLoc DL(N);
9229   EVT Ty = getPointerTy(DAG.getDataLayout());
9230   SDValue Hi = getTargetNode(N, Ty, DAG, AArch64II::MO_PAGE | Flags);
9231   SDValue Lo = getTargetNode(N, Ty, DAG,
9232                              AArch64II::MO_PAGEOFF | AArch64II::MO_NC | Flags);
9233   SDValue ADRP = DAG.getNode(AArch64ISD::ADRP, DL, Ty, Hi);
9234   return DAG.getNode(AArch64ISD::ADDlow, DL, Ty, ADRP, Lo);
9235 }
9236 
9237 // (adr sym)
9238 template <class NodeTy>
9239 SDValue AArch64TargetLowering::getAddrTiny(NodeTy *N, SelectionDAG &DAG,
9240                                            unsigned Flags) const {
9241   LLVM_DEBUG(dbgs() << "AArch64TargetLowering::getAddrTiny\n");
9242   SDLoc DL(N);
9243   EVT Ty = getPointerTy(DAG.getDataLayout());
9244   SDValue Sym = getTargetNode(N, Ty, DAG, Flags);
9245   return DAG.getNode(AArch64ISD::ADR, DL, Ty, Sym);
9246 }
9247 
9248 SDValue AArch64TargetLowering::LowerGlobalAddress(SDValue Op,
9249                                                   SelectionDAG &DAG) const {
9250   GlobalAddressSDNode *GN = cast<GlobalAddressSDNode>(Op);
9251   const GlobalValue *GV = GN->getGlobal();
9252   unsigned OpFlags = Subtarget->ClassifyGlobalReference(GV, getTargetMachine());
9253 
9254   if (OpFlags != AArch64II::MO_NO_FLAG)
9255     assert(cast<GlobalAddressSDNode>(Op)->getOffset() == 0 &&
9256            "unexpected offset in global node");
9257 
9258   // This also catches the large code model case for Darwin, and tiny code
9259   // model with got relocations.
9260   if ((OpFlags & AArch64II::MO_GOT) != 0) {
9261     return getGOT(GN, DAG, OpFlags);
9262   }
9263 
9264   SDValue Result;
9265   if (getTargetMachine().getCodeModel() == CodeModel::Large &&
9266       !getTargetMachine().isPositionIndependent()) {
9267     Result = getAddrLarge(GN, DAG, OpFlags);
9268   } else if (getTargetMachine().getCodeModel() == CodeModel::Tiny) {
9269     Result = getAddrTiny(GN, DAG, OpFlags);
9270   } else {
9271     Result = getAddr(GN, DAG, OpFlags);
9272   }
9273   EVT PtrVT = getPointerTy(DAG.getDataLayout());
9274   SDLoc DL(GN);
9275   if (OpFlags & (AArch64II::MO_DLLIMPORT | AArch64II::MO_COFFSTUB))
9276     Result = DAG.getLoad(PtrVT, DL, DAG.getEntryNode(), Result,
9277                          MachinePointerInfo::getGOT(DAG.getMachineFunction()));
9278   return Result;
9279 }
9280 
9281 /// Convert a TLS address reference into the correct sequence of loads
9282 /// and calls to compute the variable's address (for Darwin, currently) and
9283 /// return an SDValue containing the final node.
9284 
9285 /// Darwin only has one TLS scheme which must be capable of dealing with the
9286 /// fully general situation, in the worst case. This means:
9287 ///     + "extern __thread" declaration.
9288 ///     + Defined in a possibly unknown dynamic library.
9289 ///
9290 /// The general system is that each __thread variable has a [3 x i64] descriptor
9291 /// which contains information used by the runtime to calculate the address. The
9292 /// only part of this the compiler needs to know about is the first xword, which
9293 /// contains a function pointer that must be called with the address of the
9294 /// entire descriptor in "x0".
9295 ///
9296 /// Since this descriptor may be in a different unit, in general even the
9297 /// descriptor must be accessed via an indirect load. The "ideal" code sequence
9298 /// is:
9299 ///     adrp x0, _var@TLVPPAGE
9300 ///     ldr x0, [x0, _var@TLVPPAGEOFF]   ; x0 now contains address of descriptor
9301 ///     ldr x1, [x0]                     ; x1 contains 1st entry of descriptor,
9302 ///                                      ; the function pointer
9303 ///     blr x1                           ; Uses descriptor address in x0
9304 ///     ; Address of _var is now in x0.
9305 ///
9306 /// If the address of _var's descriptor *is* known to the linker, then it can
9307 /// change the first "ldr" instruction to an appropriate "add x0, x0, #imm" for
9308 /// a slight efficiency gain.
9309 SDValue
9310 AArch64TargetLowering::LowerDarwinGlobalTLSAddress(SDValue Op,
9311                                                    SelectionDAG &DAG) const {
9312   assert(Subtarget->isTargetDarwin() &&
9313          "This function expects a Darwin target");
9314 
9315   SDLoc DL(Op);
9316   MVT PtrVT = getPointerTy(DAG.getDataLayout());
9317   MVT PtrMemVT = getPointerMemTy(DAG.getDataLayout());
9318   const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
9319 
9320   SDValue TLVPAddr =
9321       DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, AArch64II::MO_TLS);
9322   SDValue DescAddr = DAG.getNode(AArch64ISD::LOADgot, DL, PtrVT, TLVPAddr);
9323 
9324   // The first entry in the descriptor is a function pointer that we must call
9325   // to obtain the address of the variable.
9326   SDValue Chain = DAG.getEntryNode();
9327   SDValue FuncTLVGet = DAG.getLoad(
9328       PtrMemVT, DL, Chain, DescAddr,
9329       MachinePointerInfo::getGOT(DAG.getMachineFunction()),
9330       Align(PtrMemVT.getSizeInBits() / 8),
9331       MachineMemOperand::MOInvariant | MachineMemOperand::MODereferenceable);
9332   Chain = FuncTLVGet.getValue(1);
9333 
9334   // Extend loaded pointer if necessary (i.e. if ILP32) to DAG pointer.
9335   FuncTLVGet = DAG.getZExtOrTrunc(FuncTLVGet, DL, PtrVT);
9336 
9337   MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
9338   MFI.setAdjustsStack(true);
9339 
9340   // TLS calls preserve all registers except those that absolutely must be
9341   // trashed: X0 (it takes an argument), LR (it's a call) and NZCV (let's not be
9342   // silly).
9343   const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
9344   const uint32_t *Mask = TRI->getTLSCallPreservedMask();
9345   if (Subtarget->hasCustomCallingConv())
9346     TRI->UpdateCustomCallPreservedMask(DAG.getMachineFunction(), &Mask);
9347 
9348   // Finally, we can make the call. This is just a degenerate version of a
9349   // normal AArch64 call node: x0 takes the address of the descriptor, and
9350   // returns the address of the variable in this thread.
9351   Chain = DAG.getCopyToReg(Chain, DL, AArch64::X0, DescAddr, SDValue());
9352 
9353   unsigned Opcode = AArch64ISD::CALL;
9354   SmallVector<SDValue, 8> Ops;
9355   Ops.push_back(Chain);
9356   Ops.push_back(FuncTLVGet);
9357 
9358   // With ptrauth-calls, the tlv access thunk pointer is authenticated (IA, 0).
9359   if (DAG.getMachineFunction().getFunction().hasFnAttribute("ptrauth-calls")) {
9360     Opcode = AArch64ISD::AUTH_CALL;
9361     Ops.push_back(DAG.getTargetConstant(AArch64PACKey::IA, DL, MVT::i32));
9362     Ops.push_back(DAG.getTargetConstant(0, DL, MVT::i64)); // Integer Disc.
9363     Ops.push_back(DAG.getRegister(AArch64::NoRegister, MVT::i64)); // Addr Disc.
9364   }
9365 
9366   Ops.push_back(DAG.getRegister(AArch64::X0, MVT::i64));
9367   Ops.push_back(DAG.getRegisterMask(Mask));
9368   Ops.push_back(Chain.getValue(1));
9369   Chain = DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
9370   return DAG.getCopyFromReg(Chain, DL, AArch64::X0, PtrVT, Chain.getValue(1));
9371 }
9372 
9373 /// Convert a thread-local variable reference into a sequence of instructions to
9374 /// compute the variable's address for the local exec TLS model of ELF targets.
9375 /// The sequence depends on the maximum TLS area size.
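/// (For reference, and assuming the usual command-line mapping, the TLSSize
/// values handled below (12/24/32/48) correspond to -mtls-size settings, i.e.
/// maximum TLS areas of 4KiB, 16MiB, 4GiB and 256TiB respectively.)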
9376 SDValue AArch64TargetLowering::LowerELFTLSLocalExec(const GlobalValue *GV,
9377                                                     SDValue ThreadBase,
9378                                                     const SDLoc &DL,
9379                                                     SelectionDAG &DAG) const {
9380   EVT PtrVT = getPointerTy(DAG.getDataLayout());
9381   SDValue TPOff, Addr;
9382 
9383   switch (DAG.getTarget().Options.TLSSize) {
9384   default:
9385     llvm_unreachable("Unexpected TLS size");
9386 
9387   case 12: {
9388     // mrs   x0, TPIDR_EL0
9389     // add   x0, x0, :tprel_lo12:a
9390     SDValue Var = DAG.getTargetGlobalAddress(
9391         GV, DL, PtrVT, 0, AArch64II::MO_TLS | AArch64II::MO_PAGEOFF);
9392     return SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, ThreadBase,
9393                                       Var,
9394                                       DAG.getTargetConstant(0, DL, MVT::i32)),
9395                    0);
9396   }
9397 
9398   case 24: {
9399     // mrs   x0, TPIDR_EL0
9400     // add   x0, x0, :tprel_hi12:a
9401     // add   x0, x0, :tprel_lo12_nc:a
9402     SDValue HiVar = DAG.getTargetGlobalAddress(
9403         GV, DL, PtrVT, 0, AArch64II::MO_TLS | AArch64II::MO_HI12);
9404     SDValue LoVar = DAG.getTargetGlobalAddress(
9405         GV, DL, PtrVT, 0,
9406         AArch64II::MO_TLS | AArch64II::MO_PAGEOFF | AArch64II::MO_NC);
9407     Addr = SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, ThreadBase,
9408                                       HiVar,
9409                                       DAG.getTargetConstant(0, DL, MVT::i32)),
9410                    0);
9411     return SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, Addr,
9412                                       LoVar,
9413                                       DAG.getTargetConstant(0, DL, MVT::i32)),
9414                    0);
9415   }
9416 
9417   case 32: {
9418     // mrs   x1, TPIDR_EL0
9419     // movz  x0, #:tprel_g1:a
9420     // movk  x0, #:tprel_g0_nc:a
9421     // add   x0, x1, x0
9422     SDValue HiVar = DAG.getTargetGlobalAddress(
9423         GV, DL, PtrVT, 0, AArch64II::MO_TLS | AArch64II::MO_G1);
9424     SDValue LoVar = DAG.getTargetGlobalAddress(
9425         GV, DL, PtrVT, 0,
9426         AArch64II::MO_TLS | AArch64II::MO_G0 | AArch64II::MO_NC);
9427     TPOff = SDValue(DAG.getMachineNode(AArch64::MOVZXi, DL, PtrVT, HiVar,
9428                                        DAG.getTargetConstant(16, DL, MVT::i32)),
9429                     0);
9430     TPOff = SDValue(DAG.getMachineNode(AArch64::MOVKXi, DL, PtrVT, TPOff, LoVar,
9431                                        DAG.getTargetConstant(0, DL, MVT::i32)),
9432                     0);
9433     return DAG.getNode(ISD::ADD, DL, PtrVT, ThreadBase, TPOff);
9434   }
9435 
9436   case 48: {
9437     // mrs   x1, TPIDR_EL0
9438     // movz  x0, #:tprel_g2:a
9439     // movk  x0, #:tprel_g1_nc:a
9440     // movk  x0, #:tprel_g0_nc:a
9441     // add   x0, x1, x0
9442     SDValue HiVar = DAG.getTargetGlobalAddress(
9443         GV, DL, PtrVT, 0, AArch64II::MO_TLS | AArch64II::MO_G2);
9444     SDValue MiVar = DAG.getTargetGlobalAddress(
9445         GV, DL, PtrVT, 0,
9446         AArch64II::MO_TLS | AArch64II::MO_G1 | AArch64II::MO_NC);
9447     SDValue LoVar = DAG.getTargetGlobalAddress(
9448         GV, DL, PtrVT, 0,
9449         AArch64II::MO_TLS | AArch64II::MO_G0 | AArch64II::MO_NC);
9450     TPOff = SDValue(DAG.getMachineNode(AArch64::MOVZXi, DL, PtrVT, HiVar,
9451                                        DAG.getTargetConstant(32, DL, MVT::i32)),
9452                     0);
9453     TPOff = SDValue(DAG.getMachineNode(AArch64::MOVKXi, DL, PtrVT, TPOff, MiVar,
9454                                        DAG.getTargetConstant(16, DL, MVT::i32)),
9455                     0);
9456     TPOff = SDValue(DAG.getMachineNode(AArch64::MOVKXi, DL, PtrVT, TPOff, LoVar,
9457                                        DAG.getTargetConstant(0, DL, MVT::i32)),
9458                     0);
9459     return DAG.getNode(ISD::ADD, DL, PtrVT, ThreadBase, TPOff);
9460   }
9461   }
9462 }
9463 
9464 /// When accessing thread-local variables under either the general-dynamic or
9465 /// local-dynamic system, we make a "TLS-descriptor" call. The variable will
9466 /// have a descriptor, accessible via a PC-relative ADRP, and whose first entry
9467 /// is a function pointer to carry out the resolution.
9468 ///
9469 /// The sequence is:
9470 ///    adrp  x0, :tlsdesc:var
9471 ///    ldr   x1, [x0, #:tlsdesc_lo12:var]
9472 ///    add   x0, x0, #:tlsdesc_lo12:var
9473 ///    .tlsdesccall var
9474 ///    blr   x1
9475 ///    (TPIDR_EL0 offset now in x0)
9476 ///
9477 ///  The above sequence must be produced unscheduled, to enable the linker to
9478 ///  optimize/relax this sequence.
9479 ///  Therefore, a pseudo-instruction (TLSDESC_CALLSEQ) is used to represent the
9480 ///  above sequence, and expanded really late in the compilation flow, to ensure
9481 ///  the sequence is produced as per above.
9482 SDValue AArch64TargetLowering::LowerELFTLSDescCallSeq(SDValue SymAddr,
9483                                                       const SDLoc &DL,
9484                                                       SelectionDAG &DAG) const {
9485   EVT PtrVT = getPointerTy(DAG.getDataLayout());
9486 
9487   SDValue Chain = DAG.getEntryNode();
9488   SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
9489 
9490   Chain =
9491       DAG.getNode(AArch64ISD::TLSDESC_CALLSEQ, DL, NodeTys, {Chain, SymAddr});
9492   SDValue Glue = Chain.getValue(1);
9493 
9494   return DAG.getCopyFromReg(Chain, DL, AArch64::X0, PtrVT, Glue);
9495 }
9496 
9497 SDValue
9498 AArch64TargetLowering::LowerELFGlobalTLSAddress(SDValue Op,
9499                                                 SelectionDAG &DAG) const {
9500   assert(Subtarget->isTargetELF() && "This function expects an ELF target");
9501 
9502   const GlobalAddressSDNode *GA = cast<GlobalAddressSDNode>(Op);
9503 
9504   TLSModel::Model Model = getTargetMachine().getTLSModel(GA->getGlobal());
9505 
9506   if (!EnableAArch64ELFLocalDynamicTLSGeneration) {
9507     if (Model == TLSModel::LocalDynamic)
9508       Model = TLSModel::GeneralDynamic;
9509   }
9510 
9511   if (getTargetMachine().getCodeModel() == CodeModel::Large &&
9512       Model != TLSModel::LocalExec)
9513     report_fatal_error("ELF TLS only supported in small memory model or "
9514                        "in local exec TLS model");
9515   // Different choices can be made for the maximum size of the TLS area for a
9516   // module. For the small address model, the default TLS size is 16MiB and the
9517   // maximum TLS size is 4GiB.
9518   // FIXME: add tiny and large code model support for TLS access models other
9519   // than local exec. We currently generate the same code as small for tiny,
9520   // which may be larger than needed.
9521 
9522   SDValue TPOff;
9523   EVT PtrVT = getPointerTy(DAG.getDataLayout());
9524   SDLoc DL(Op);
9525   const GlobalValue *GV = GA->getGlobal();
9526 
9527   SDValue ThreadBase = DAG.getNode(AArch64ISD::THREAD_POINTER, DL, PtrVT);
9528 
9529   if (Model == TLSModel::LocalExec) {
9530     return LowerELFTLSLocalExec(GV, ThreadBase, DL, DAG);
9531   } else if (Model == TLSModel::InitialExec) {
9532     TPOff = DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, AArch64II::MO_TLS);
9533     TPOff = DAG.getNode(AArch64ISD::LOADgot, DL, PtrVT, TPOff);
9534   } else if (Model == TLSModel::LocalDynamic) {
9535     // Local-dynamic accesses proceed in two phases. A general-dynamic TLS
9536     // descriptor call against the special symbol _TLS_MODULE_BASE_ to calculate
9537     // the beginning of the module's TLS region, followed by a DTPREL offset
9538     // calculation.
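    // (Illustrative final sequence, assuming the usual TLS relocations:
    //    <TLSDESC call against _TLS_MODULE_BASE_>  ; TLS base offset in x0
    //    add x0, x0, :dtprel_hi12:var
    //    add x0, x0, :dtprel_lo12_nc:var
    //  which is what the ADDXri pair below materialises.)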
9539 
9540     // These accesses will need deduplicating if there's more than one.
9541     AArch64FunctionInfo *MFI =
9542         DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
9543     MFI->incNumLocalDynamicTLSAccesses();
9544 
9545     // The call needs a relocation too for linker relaxation. It doesn't make
9546     // sense to call it MO_PAGE or MO_PAGEOFF though so we need another copy of
9547     // the address.
9548     SDValue SymAddr = DAG.getTargetExternalSymbol("_TLS_MODULE_BASE_", PtrVT,
9549                                                   AArch64II::MO_TLS);
9550 
9551     // Now we can calculate the offset from TPIDR_EL0 to this module's
9552     // thread-local area.
9553     TPOff = LowerELFTLSDescCallSeq(SymAddr, DL, DAG);
9554 
9555     // Now use :dtprel_whatever: operations to calculate this variable's offset
9556     // in its thread-storage area.
9557     SDValue HiVar = DAG.getTargetGlobalAddress(
9558         GV, DL, MVT::i64, 0, AArch64II::MO_TLS | AArch64II::MO_HI12);
9559     SDValue LoVar = DAG.getTargetGlobalAddress(
9560         GV, DL, MVT::i64, 0,
9561         AArch64II::MO_TLS | AArch64II::MO_PAGEOFF | AArch64II::MO_NC);
9562 
9563     TPOff = SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, TPOff, HiVar,
9564                                        DAG.getTargetConstant(0, DL, MVT::i32)),
9565                     0);
9566     TPOff = SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, TPOff, LoVar,
9567                                        DAG.getTargetConstant(0, DL, MVT::i32)),
9568                     0);
9569   } else if (Model == TLSModel::GeneralDynamic) {
9570     // The call needs a relocation too for linker relaxation. It doesn't make
9571     // sense to call it MO_PAGE or MO_PAGEOFF though so we need another copy of
9572     // the address.
9573     SDValue SymAddr =
9574         DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, AArch64II::MO_TLS);
9575 
9576     // Finally we can make a call to calculate the offset from tpidr_el0.
9577     TPOff = LowerELFTLSDescCallSeq(SymAddr, DL, DAG);
9578   } else
9579     llvm_unreachable("Unsupported ELF TLS access model");
9580 
9581   return DAG.getNode(ISD::ADD, DL, PtrVT, ThreadBase, TPOff);
9582 }
9583 
9584 SDValue
9585 AArch64TargetLowering::LowerWindowsGlobalTLSAddress(SDValue Op,
9586                                                     SelectionDAG &DAG) const {
9587   assert(Subtarget->isTargetWindows() && "Windows specific TLS lowering");
9588 
9589   SDValue Chain = DAG.getEntryNode();
9590   EVT PtrVT = getPointerTy(DAG.getDataLayout());
9591   SDLoc DL(Op);
9592 
9593   SDValue TEB = DAG.getRegister(AArch64::X18, MVT::i64);
9594 
9595   // Load the ThreadLocalStoragePointer from the TEB
9596   // A pointer to the TLS array is located at offset 0x58 from the TEB.
9597   SDValue TLSArray =
9598       DAG.getNode(ISD::ADD, DL, PtrVT, TEB, DAG.getIntPtrConstant(0x58, DL));
9599   TLSArray = DAG.getLoad(PtrVT, DL, Chain, TLSArray, MachinePointerInfo());
9600   Chain = TLSArray.getValue(1);
9601 
9602   // Load the TLS index from the C runtime.
9603   // This does the same as getAddr(), but without having a GlobalAddressSDNode.
9604   // This also does the same as LOADgot, but using a generic i32 load,
9605   // while LOADgot only loads i64.
9606   SDValue TLSIndexHi =
9607       DAG.getTargetExternalSymbol("_tls_index", PtrVT, AArch64II::MO_PAGE);
9608   SDValue TLSIndexLo = DAG.getTargetExternalSymbol(
9609       "_tls_index", PtrVT, AArch64II::MO_PAGEOFF | AArch64II::MO_NC);
9610   SDValue ADRP = DAG.getNode(AArch64ISD::ADRP, DL, PtrVT, TLSIndexHi);
9611   SDValue TLSIndex =
9612       DAG.getNode(AArch64ISD::ADDlow, DL, PtrVT, ADRP, TLSIndexLo);
9613   TLSIndex = DAG.getLoad(MVT::i32, DL, Chain, TLSIndex, MachinePointerInfo());
9614   Chain = TLSIndex.getValue(1);
9615 
9616   // The pointer to the thread's TLS data area is found by indexing the
9617   // TLSArray with the TLS index scaled by 8 (the pointer size).
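  // (Sketch of the overall address computation performed by this function:
  //    TlsBase = *( *(TEB + 0x58) + _tls_index * 8 )
  //    Addr    = TlsBase + <offset of GV within the .tls section>
  //  matching the loads above and the ADDXri/ADDlow pair below.)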
9618   TLSIndex = DAG.getNode(ISD::ZERO_EXTEND, DL, PtrVT, TLSIndex);
9619   SDValue Slot = DAG.getNode(ISD::SHL, DL, PtrVT, TLSIndex,
9620                              DAG.getConstant(3, DL, PtrVT));
9621   SDValue TLS = DAG.getLoad(PtrVT, DL, Chain,
9622                             DAG.getNode(ISD::ADD, DL, PtrVT, TLSArray, Slot),
9623                             MachinePointerInfo());
9624   Chain = TLS.getValue(1);
9625 
9626   const GlobalAddressSDNode *GA = cast<GlobalAddressSDNode>(Op);
9627   const GlobalValue *GV = GA->getGlobal();
9628   SDValue TGAHi = DAG.getTargetGlobalAddress(
9629       GV, DL, PtrVT, 0, AArch64II::MO_TLS | AArch64II::MO_HI12);
9630   SDValue TGALo = DAG.getTargetGlobalAddress(
9631       GV, DL, PtrVT, 0,
9632       AArch64II::MO_TLS | AArch64II::MO_PAGEOFF | AArch64II::MO_NC);
9633 
9634   // Add the offset from the start of the .tls section (section base).
9635   SDValue Addr =
9636       SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, TLS, TGAHi,
9637                                  DAG.getTargetConstant(0, DL, MVT::i32)),
9638               0);
9639   Addr = DAG.getNode(AArch64ISD::ADDlow, DL, PtrVT, Addr, TGALo);
9640   return Addr;
9641 }
9642 
9643 SDValue AArch64TargetLowering::LowerGlobalTLSAddress(SDValue Op,
9644                                                      SelectionDAG &DAG) const {
9645   const GlobalAddressSDNode *GA = cast<GlobalAddressSDNode>(Op);
9646   if (DAG.getTarget().useEmulatedTLS())
9647     return LowerToTLSEmulatedModel(GA, DAG);
9648 
9649   if (Subtarget->isTargetDarwin())
9650     return LowerDarwinGlobalTLSAddress(Op, DAG);
9651   if (Subtarget->isTargetELF())
9652     return LowerELFGlobalTLSAddress(Op, DAG);
9653   if (Subtarget->isTargetWindows())
9654     return LowerWindowsGlobalTLSAddress(Op, DAG);
9655 
9656   llvm_unreachable("Unexpected platform trying to use TLS");
9657 }
9658 
9659 //===----------------------------------------------------------------------===//
9660 //                      PtrAuthGlobalAddress lowering
9661 //
9662 // We have 3 lowering alternatives to choose from:
9663 // - MOVaddrPAC: similar to MOVaddr, with added PAC.
9664 //   If the GV doesn't need a GOT load (i.e., is locally defined)
9665 //   materialize the pointer using adrp+add+pac. See LowerMOVaddrPAC.
9666 //
9667 // - LOADgotPAC: similar to LOADgot, with added PAC.
9668 //   If the GV needs a GOT load, materialize the pointer using the usual
9669 //   GOT adrp+ldr, +pac. Pointers in GOT are assumed to be not signed, the GOT
9670 //   section is assumed to be read-only (for example, via relro mechanism). See
9671 //   LowerMOVaddrPAC.
9672 //
9673 // - LOADauthptrstatic: similar to LOADgot, but use a
9674 //   special stub slot instead of a GOT slot.
9675 //   Load a signed pointer for symbol 'sym' from a stub slot named
9676 //   'sym$auth_ptr$key$disc' filled by dynamic linker during relocation
9677 //   resolving. This usually lowers to adrp+ldr, but also emits an entry into
9678 //   .data with an @AUTH relocation. See LowerLOADauthptrstatic.
9679 //
9680 // All 3 are pseudos that are expanded late into longer sequences: this lets us
9681 // provide integrity guarantees on the to-be-signed intermediate values.
9682 //
9683 // LOADauthptrstatic is undesirable because it requires a large section filled
9684 // with often similarly-signed pointers, making it a good harvesting target.
9685 // Thus, it's only used for ptrauth references to extern_weak to avoid null
9686 // checks.
9687 
9688 SDValue AArch64TargetLowering::LowerPtrAuthGlobalAddressStatically(
9689     SDValue TGA, SDLoc DL, EVT VT, AArch64PACKey::ID KeyC,
9690     SDValue Discriminator, SDValue AddrDiscriminator, SelectionDAG &DAG) const {
9691   const auto *TGN = cast<GlobalAddressSDNode>(TGA.getNode());
9692   assert(TGN->getGlobal()->hasExternalWeakLinkage());
9693 
9694   // Offsets and extern_weak don't mix well: ptrauth aside, you'd get the
9695   // offset alone as a pointer if the symbol wasn't available, which would
9696   // probably break null checks in users. Ptrauth complicates things further:
9697   // error out.
9698   if (TGN->getOffset() != 0)
9699     report_fatal_error(
9700         "unsupported non-zero offset in weak ptrauth global reference");
9701 
9702   if (!isNullConstant(AddrDiscriminator))
9703     report_fatal_error("unsupported weak addr-div ptrauth global");
9704 
9705   SDValue Key = DAG.getTargetConstant(KeyC, DL, MVT::i32);
9706   return SDValue(DAG.getMachineNode(AArch64::LOADauthptrstatic, DL, MVT::i64,
9707                                     {TGA, Key, Discriminator}),
9708                  0);
9709 }
9710 
9711 SDValue
9712 AArch64TargetLowering::LowerPtrAuthGlobalAddress(SDValue Op,
9713                                                  SelectionDAG &DAG) const {
9714   SDValue Ptr = Op.getOperand(0);
9715   uint64_t KeyC = Op.getConstantOperandVal(1);
9716   SDValue AddrDiscriminator = Op.getOperand(2);
9717   uint64_t DiscriminatorC = Op.getConstantOperandVal(3);
9718   EVT VT = Op.getValueType();
9719   SDLoc DL(Op);
9720 
9721   if (KeyC > AArch64PACKey::LAST)
9722     report_fatal_error("key in ptrauth global out of range [0, " +
9723                        Twine((int)AArch64PACKey::LAST) + "]");
9724 
9725   // Blend only works if the integer discriminator is 16-bit wide.
9726   if (!isUInt<16>(DiscriminatorC))
9727     report_fatal_error(
9728         "constant discriminator in ptrauth global out of range [0, 0xffff]");
9729 
9730   // Choosing between 3 lowering alternatives is target-specific.
9731   if (!Subtarget->isTargetELF() && !Subtarget->isTargetMachO())
9732     report_fatal_error("ptrauth global lowering only supported on MachO/ELF");
9733 
9734   int64_t PtrOffsetC = 0;
9735   if (Ptr.getOpcode() == ISD::ADD) {
9736     PtrOffsetC = Ptr.getConstantOperandVal(1);
9737     Ptr = Ptr.getOperand(0);
9738   }
9739   const auto *PtrN = cast<GlobalAddressSDNode>(Ptr.getNode());
9740   const GlobalValue *PtrGV = PtrN->getGlobal();
9741 
9742   // Classify the reference to determine whether it needs a GOT load.
9743   const unsigned OpFlags =
9744       Subtarget->ClassifyGlobalReference(PtrGV, getTargetMachine());
9745   const bool NeedsGOTLoad = ((OpFlags & AArch64II::MO_GOT) != 0);
9746   assert(((OpFlags & (~AArch64II::MO_GOT)) == 0) &&
9747          "unsupported non-GOT op flags on ptrauth global reference");
9748 
9749   // Fold any offset into the GV; our pseudos expect it there.
9750   PtrOffsetC += PtrN->getOffset();
9751   SDValue TPtr = DAG.getTargetGlobalAddress(PtrGV, DL, VT, PtrOffsetC,
9752                                             /*TargetFlags=*/0);
9753   assert(PtrN->getTargetFlags() == 0 &&
9754          "unsupported target flags on ptrauth global");
9755 
9756   SDValue Key = DAG.getTargetConstant(KeyC, DL, MVT::i32);
9757   SDValue Discriminator = DAG.getTargetConstant(DiscriminatorC, DL, MVT::i64);
9758   SDValue TAddrDiscriminator = !isNullConstant(AddrDiscriminator)
9759                                    ? AddrDiscriminator
9760                                    : DAG.getRegister(AArch64::XZR, MVT::i64);
9761 
9762   // No GOT load needed -> MOVaddrPAC
9763   if (!NeedsGOTLoad) {
9764     assert(!PtrGV->hasExternalWeakLinkage() && "extern_weak should use GOT");
9765     return SDValue(
9766         DAG.getMachineNode(AArch64::MOVaddrPAC, DL, MVT::i64,
9767                            {TPtr, Key, TAddrDiscriminator, Discriminator}),
9768         0);
9769   }
9770 
9771   // GOT load -> LOADgotPAC
9772   // Note that we disallow extern_weak refs to avoid null checks later.
9773   if (!PtrGV->hasExternalWeakLinkage())
9774     return SDValue(
9775         DAG.getMachineNode(AArch64::LOADgotPAC, DL, MVT::i64,
9776                            {TPtr, Key, TAddrDiscriminator, Discriminator}),
9777         0);
9778 
9779   // extern_weak ref -> LOADauthptrstatic
9780   return LowerPtrAuthGlobalAddressStatically(
9781       TPtr, DL, VT, (AArch64PACKey::ID)KeyC, Discriminator, AddrDiscriminator,
9782       DAG);
9783 }
9784 
9785 // Looks through \param Val to determine the bit that can be used to
9786 // check the sign of the value. It returns the unextended value and
9787 // the sign bit position.
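// For example, (sign_extend_inreg x, i8) yields {x, 7}, while a plain i32
// value yields {value, 31}.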
9788 std::pair<SDValue, uint64_t> lookThroughSignExtension(SDValue Val) {
9789   if (Val.getOpcode() == ISD::SIGN_EXTEND_INREG)
9790     return {Val.getOperand(0),
9791             cast<VTSDNode>(Val.getOperand(1))->getVT().getFixedSizeInBits() -
9792                 1};
9793 
9794   if (Val.getOpcode() == ISD::SIGN_EXTEND)
9795     return {Val.getOperand(0),
9796             Val.getOperand(0)->getValueType(0).getFixedSizeInBits() - 1};
9797 
9798   return {Val, Val.getValueSizeInBits() - 1};
9799 }
9800 
9801 SDValue AArch64TargetLowering::LowerBR_CC(SDValue Op, SelectionDAG &DAG) const {
9802   SDValue Chain = Op.getOperand(0);
9803   ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(1))->get();
9804   SDValue LHS = Op.getOperand(2);
9805   SDValue RHS = Op.getOperand(3);
9806   SDValue Dest = Op.getOperand(4);
9807   SDLoc dl(Op);
9808 
9809   MachineFunction &MF = DAG.getMachineFunction();
9810   // Speculation tracking/SLH assumes that optimized TB(N)Z/CB(N)Z instructions
9811   // will not be produced, as they are conditional branch instructions that do
9812   // not set flags.
9813   bool ProduceNonFlagSettingCondBr =
9814       !MF.getFunction().hasFnAttribute(Attribute::SpeculativeLoadHardening);
9815 
9816   // Handle f128 first, since lowering it will result in comparing the return
9817   // value of a libcall against zero, which is just what the rest of LowerBR_CC
9818   // is expecting to deal with.
9819   if (LHS.getValueType() == MVT::f128) {
9820     softenSetCCOperands(DAG, MVT::f128, LHS, RHS, CC, dl, LHS, RHS);
9821 
9822     // If softenSetCCOperands returned a scalar, we need to compare the result
9823     // against zero to select between true and false values.
9824     if (!RHS.getNode()) {
9825       RHS = DAG.getConstant(0, dl, LHS.getValueType());
9826       CC = ISD::SETNE;
9827     }
9828   }
9829 
9830   // Optimize {s|u}{add|sub|mul}.with.overflow feeding into a branch
9831   // instruction.
9832   if (ISD::isOverflowIntrOpRes(LHS) && isOneConstant(RHS) &&
9833       (CC == ISD::SETEQ || CC == ISD::SETNE)) {
9834     // Only lower legal XALUO ops.
9835     if (!DAG.getTargetLoweringInfo().isTypeLegal(LHS->getValueType(0)))
9836       return SDValue();
9837 
9838     // The actual operation with overflow check.
9839     AArch64CC::CondCode OFCC;
9840     SDValue Value, Overflow;
9841     std::tie(Value, Overflow) = getAArch64XALUOOp(OFCC, LHS.getValue(0), DAG);
9842 
9843     if (CC == ISD::SETNE)
9844       OFCC = getInvertedCondCode(OFCC);
9845     SDValue CCVal = DAG.getConstant(OFCC, dl, MVT::i32);
9846 
9847     return DAG.getNode(AArch64ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
9848                        Overflow);
9849   }
9850 
9851   if (LHS.getValueType().isInteger()) {
9852     assert((LHS.getValueType() == RHS.getValueType()) &&
9853            (LHS.getValueType() == MVT::i32 || LHS.getValueType() == MVT::i64));
9854 
9855     // If the RHS of the comparison is zero, we can potentially fold this
9856     // to a specialized branch.
9857     const ConstantSDNode *RHSC = dyn_cast<ConstantSDNode>(RHS);
9858     if (RHSC && RHSC->getZExtValue() == 0 && ProduceNonFlagSettingCondBr) {
9859       if (CC == ISD::SETEQ) {
9860         // See if we can use a TBZ to fold in an AND as well.
9861         // TBZ has a smaller branch displacement than CBZ.  If the offset is
9862         // out of bounds, a late MI-layer pass rewrites branches.
9863         // 403.gcc is an example that hits this case.
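        // For example, (brcond (seteq (and x, 4), 0), dest) becomes
        // "tbz x, #2, dest".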
9864         if (LHS.getOpcode() == ISD::AND &&
9865             isa<ConstantSDNode>(LHS.getOperand(1)) &&
9866             isPowerOf2_64(LHS.getConstantOperandVal(1))) {
9867           SDValue Test = LHS.getOperand(0);
9868           uint64_t Mask = LHS.getConstantOperandVal(1);
9869           return DAG.getNode(AArch64ISD::TBZ, dl, MVT::Other, Chain, Test,
9870                              DAG.getConstant(Log2_64(Mask), dl, MVT::i64),
9871                              Dest);
9872         }
9873 
9874         return DAG.getNode(AArch64ISD::CBZ, dl, MVT::Other, Chain, LHS, Dest);
9875       } else if (CC == ISD::SETNE) {
9876         // See if we can use a TBZ to fold in an AND as well.
9877         // TBZ has a smaller branch displacement than CBZ.  If the offset is
9878         // out of bounds, a late MI-layer pass rewrites branches.
9879         // 403.gcc is an example that hits this case.
9880         if (LHS.getOpcode() == ISD::AND &&
9881             isa<ConstantSDNode>(LHS.getOperand(1)) &&
9882             isPowerOf2_64(LHS.getConstantOperandVal(1))) {
9883           SDValue Test = LHS.getOperand(0);
9884           uint64_t Mask = LHS.getConstantOperandVal(1);
9885           return DAG.getNode(AArch64ISD::TBNZ, dl, MVT::Other, Chain, Test,
9886                              DAG.getConstant(Log2_64(Mask), dl, MVT::i64),
9887                              Dest);
9888         }
9889 
9890         return DAG.getNode(AArch64ISD::CBNZ, dl, MVT::Other, Chain, LHS, Dest);
9891       } else if (CC == ISD::SETLT && LHS.getOpcode() != ISD::AND) {
9892         // Don't combine AND since emitComparison converts the AND to an ANDS
9893         // (a.k.a. TST) and the test in the test bit and branch instruction
9894         // becomes redundant.  This would also increase register pressure.
9895         uint64_t SignBitPos;
9896         std::tie(LHS, SignBitPos) = lookThroughSignExtension(LHS);
9897         return DAG.getNode(AArch64ISD::TBNZ, dl, MVT::Other, Chain, LHS,
9898                            DAG.getConstant(SignBitPos, dl, MVT::i64), Dest);
9899       }
9900     }
9901     if (RHSC && RHSC->getSExtValue() == -1 && CC == ISD::SETGT &&
9902         LHS.getOpcode() != ISD::AND && ProduceNonFlagSettingCondBr) {
9903       // Don't combine AND since emitComparison converts the AND to an ANDS
9904       // (a.k.a. TST) and the test in the test bit and branch instruction
9905       // becomes redundant.  This would also increase register pressure.
9906       uint64_t SignBitPos;
9907       std::tie(LHS, SignBitPos) = lookThroughSignExtension(LHS);
9908       return DAG.getNode(AArch64ISD::TBZ, dl, MVT::Other, Chain, LHS,
9909                          DAG.getConstant(SignBitPos, dl, MVT::i64), Dest);
9910     }
9911 
9912     SDValue CCVal;
9913     SDValue Cmp = getAArch64Cmp(LHS, RHS, CC, CCVal, DAG, dl);
9914     return DAG.getNode(AArch64ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
9915                        Cmp);
9916   }
9917 
9918   assert(LHS.getValueType() == MVT::f16 || LHS.getValueType() == MVT::bf16 ||
9919          LHS.getValueType() == MVT::f32 || LHS.getValueType() == MVT::f64);
9920 
9921   // Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
9922   // clean.  Some of them require two branches to implement.
9923   SDValue Cmp = emitComparison(LHS, RHS, CC, dl, DAG);
9924   AArch64CC::CondCode CC1, CC2;
9925   changeFPCCToAArch64CC(CC, CC1, CC2);
9926   SDValue CC1Val = DAG.getConstant(CC1, dl, MVT::i32);
9927   SDValue BR1 =
9928       DAG.getNode(AArch64ISD::BRCOND, dl, MVT::Other, Chain, Dest, CC1Val, Cmp);
9929   if (CC2 != AArch64CC::AL) {
9930     SDValue CC2Val = DAG.getConstant(CC2, dl, MVT::i32);
9931     return DAG.getNode(AArch64ISD::BRCOND, dl, MVT::Other, BR1, Dest, CC2Val,
9932                        Cmp);
9933   }
9934 
9935   return BR1;
9936 }
9937 
9938 SDValue AArch64TargetLowering::LowerFCOPYSIGN(SDValue Op,
9939                                               SelectionDAG &DAG) const {
9940   if (!Subtarget->isNeonAvailable() &&
9941       !Subtarget->useSVEForFixedLengthVectors())
9942     return SDValue();
9943 
9944   EVT VT = Op.getValueType();
9945   EVT IntVT = VT.changeTypeToInteger();
9946   SDLoc DL(Op);
9947 
9948   SDValue In1 = Op.getOperand(0);
9949   SDValue In2 = Op.getOperand(1);
9950   EVT SrcVT = In2.getValueType();
9951 
9952   if (!SrcVT.bitsEq(VT))
9953     In2 = DAG.getFPExtendOrRound(In2, DL, VT);
9954 
9955   if (VT.isScalableVector())
9956     IntVT =
9957         getPackedSVEVectorVT(VT.getVectorElementType().changeTypeToInteger());
9958 
9959   if (VT.isFixedLengthVector() &&
9960       useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable())) {
9961     EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
9962 
9963     In1 = convertToScalableVector(DAG, ContainerVT, In1);
9964     In2 = convertToScalableVector(DAG, ContainerVT, In2);
9965 
9966     SDValue Res = DAG.getNode(ISD::FCOPYSIGN, DL, ContainerVT, In1, In2);
9967     return convertFromScalableVector(DAG, VT, Res);
9968   }
9969 
9970   auto BitCast = [this](EVT VT, SDValue Op, SelectionDAG &DAG) {
9971     if (VT.isScalableVector())
9972       return getSVESafeBitCast(VT, Op, DAG);
9973 
9974     return DAG.getBitcast(VT, Op);
9975   };
9976 
9977   SDValue VecVal1, VecVal2;
9978   EVT VecVT;
9979   auto SetVecVal = [&](int Idx = -1) {
9980     if (!VT.isVector()) {
9981       VecVal1 =
9982           DAG.getTargetInsertSubreg(Idx, DL, VecVT, DAG.getUNDEF(VecVT), In1);
9983       VecVal2 =
9984           DAG.getTargetInsertSubreg(Idx, DL, VecVT, DAG.getUNDEF(VecVT), In2);
9985     } else {
9986       VecVal1 = BitCast(VecVT, In1, DAG);
9987       VecVal2 = BitCast(VecVT, In2, DAG);
9988     }
9989   };
9990   if (VT.isVector()) {
9991     VecVT = IntVT;
9992     SetVecVal();
9993   } else if (VT == MVT::f64) {
9994     VecVT = MVT::v2i64;
9995     SetVecVal(AArch64::dsub);
9996   } else if (VT == MVT::f32) {
9997     VecVT = MVT::v4i32;
9998     SetVecVal(AArch64::ssub);
9999   } else if (VT == MVT::f16 || VT == MVT::bf16) {
10000     VecVT = MVT::v8i16;
10001     SetVecVal(AArch64::hsub);
10002   } else {
10003     llvm_unreachable("Invalid type for copysign!");
10004   }
10005 
10006   unsigned BitWidth = In1.getScalarValueSizeInBits();
10007   SDValue SignMaskV = DAG.getConstant(~APInt::getSignMask(BitWidth), DL, VecVT);
10008 
10009   // We want to materialize a mask with every bit but the high bit set, but the
10010   // AdvSIMD immediate moves cannot materialize that in a single instruction for
10011   // 64-bit elements. Instead, materialize all bits set and then negate that.
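  // (Bitcasting the all-ones pattern to f64 and applying FNEG flips only the
  // sign bit, leaving 0x7fffffffffffffff in each 64-bit lane.)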
10012   if (VT == MVT::f64 || VT == MVT::v2f64) {
10013     SignMaskV = DAG.getConstant(APInt::getAllOnes(BitWidth), DL, VecVT);
10014     SignMaskV = DAG.getNode(ISD::BITCAST, DL, MVT::v2f64, SignMaskV);
10015     SignMaskV = DAG.getNode(ISD::FNEG, DL, MVT::v2f64, SignMaskV);
10016     SignMaskV = DAG.getNode(ISD::BITCAST, DL, MVT::v2i64, SignMaskV);
10017   }
10018 
10019   SDValue BSP =
10020       DAG.getNode(AArch64ISD::BSP, DL, VecVT, SignMaskV, VecVal1, VecVal2);
10021   if (VT == MVT::f16 || VT == MVT::bf16)
10022     return DAG.getTargetExtractSubreg(AArch64::hsub, DL, VT, BSP);
10023   if (VT == MVT::f32)
10024     return DAG.getTargetExtractSubreg(AArch64::ssub, DL, VT, BSP);
10025   if (VT == MVT::f64)
10026     return DAG.getTargetExtractSubreg(AArch64::dsub, DL, VT, BSP);
10027 
10028   return BitCast(VT, BSP, DAG);
10029 }
10030 
10031 SDValue AArch64TargetLowering::LowerCTPOP_PARITY(SDValue Op,
10032                                                  SelectionDAG &DAG) const {
10033   if (DAG.getMachineFunction().getFunction().hasFnAttribute(
10034           Attribute::NoImplicitFloat))
10035     return SDValue();
10036 
10037   EVT VT = Op.getValueType();
10038   if (VT.isScalableVector() ||
10039       useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
10040     return LowerToPredicatedOp(Op, DAG, AArch64ISD::CTPOP_MERGE_PASSTHRU);
10041 
10042   if (!Subtarget->isNeonAvailable())
10043     return SDValue();
10044 
10045   bool IsParity = Op.getOpcode() == ISD::PARITY;
10046   SDValue Val = Op.getOperand(0);
10047   SDLoc DL(Op);
10048 
10049   // For i32, a general parity computation using EORs is more efficient than
10050   // using floating point.
10051   if (VT == MVT::i32 && IsParity)
10052     return SDValue();
10053 
10054   // If there is no CNT instruction available, GPR popcount can
10055   // be more efficiently lowered to the following sequence that uses
10056   // AdvSIMD registers/instructions as long as the copies to/from
10057   // the AdvSIMD registers are cheap.
10058   //  FMOV    D0, X0        // copy 64-bit int to vector, high bits zero'd
10059   //  CNT     V0.8B, V0.8B  // 8xbyte pop-counts
10060   //  ADDV    B0, V0.8B     // sum 8xbyte pop-counts
10061   //  UMOV    X0, V0.B[0]   // copy byte result back to integer reg
10062   if (VT == MVT::i32 || VT == MVT::i64) {
10063     if (VT == MVT::i32)
10064       Val = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, Val);
10065     Val = DAG.getNode(ISD::BITCAST, DL, MVT::v8i8, Val);
10066 
10067     SDValue CtPop = DAG.getNode(ISD::CTPOP, DL, MVT::v8i8, Val);
10068     SDValue UaddLV = DAG.getNode(AArch64ISD::UADDLV, DL, MVT::v4i32, CtPop);
10069     UaddLV = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32, UaddLV,
10070                          DAG.getConstant(0, DL, MVT::i64));
10071 
10072     if (IsParity)
10073       UaddLV = DAG.getNode(ISD::AND, DL, MVT::i32, UaddLV,
10074                            DAG.getConstant(1, DL, MVT::i32));
10075 
10076     if (VT == MVT::i64)
10077       UaddLV = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, UaddLV);
10078     return UaddLV;
10079   } else if (VT == MVT::i128) {
10080     Val = DAG.getNode(ISD::BITCAST, DL, MVT::v16i8, Val);
10081 
10082     SDValue CtPop = DAG.getNode(ISD::CTPOP, DL, MVT::v16i8, Val);
10083     SDValue UaddLV = DAG.getNode(AArch64ISD::UADDLV, DL, MVT::v4i32, CtPop);
10084     UaddLV = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32, UaddLV,
10085                          DAG.getConstant(0, DL, MVT::i64));
10086 
10087     if (IsParity)
10088       UaddLV = DAG.getNode(ISD::AND, DL, MVT::i32, UaddLV,
10089                            DAG.getConstant(1, DL, MVT::i32));
10090 
10091     return DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i128, UaddLV);
10092   }
10093 
10094   assert(!IsParity && "ISD::PARITY of vector types not supported");
10095 
10096   assert((VT == MVT::v1i64 || VT == MVT::v2i64 || VT == MVT::v2i32 ||
10097           VT == MVT::v4i32 || VT == MVT::v4i16 || VT == MVT::v8i16) &&
10098          "Unexpected type for custom ctpop lowering");
10099 
10100   EVT VT8Bit = VT.is64BitVector() ? MVT::v8i8 : MVT::v16i8;
10101   Val = DAG.getBitcast(VT8Bit, Val);
10102   Val = DAG.getNode(ISD::CTPOP, DL, VT8Bit, Val);
10103 
10104   if (Subtarget->hasDotProd() && VT.getScalarSizeInBits() != 16 &&
10105       VT.getVectorNumElements() >= 2) {
10106     EVT DT = VT == MVT::v2i64 ? MVT::v4i32 : VT;
10107     SDValue Zeros = DAG.getConstant(0, DL, DT);
10108     SDValue Ones = DAG.getConstant(1, DL, VT8Bit);
10109 
10110     if (VT == MVT::v2i64) {
10111       Val = DAG.getNode(AArch64ISD::UDOT, DL, DT, Zeros, Ones, Val);
10112       Val = DAG.getNode(AArch64ISD::UADDLP, DL, VT, Val);
10113     } else if (VT == MVT::v2i32) {
10114       Val = DAG.getNode(AArch64ISD::UDOT, DL, DT, Zeros, Ones, Val);
10115     } else if (VT == MVT::v4i32) {
10116       Val = DAG.getNode(AArch64ISD::UDOT, DL, DT, Zeros, Ones, Val);
10117     } else {
10118       llvm_unreachable("Unexpected type for custom ctpop lowering");
10119     }
10120 
10121     return Val;
10122   }
10123 
10124   // Widen v8i8/v16i8 CTPOP result to VT by repeatedly widening pairwise adds.
10125   unsigned EltSize = 8;
10126   unsigned NumElts = VT.is64BitVector() ? 8 : 16;
10127   while (EltSize != VT.getScalarSizeInBits()) {
10128     EltSize *= 2;
10129     NumElts /= 2;
10130     MVT WidenVT = MVT::getVectorVT(MVT::getIntegerVT(EltSize), NumElts);
10131     Val = DAG.getNode(AArch64ISD::UADDLP, DL, WidenVT, Val);
10132   }
10133 
10134   return Val;
10135 }
10136 
10137 SDValue AArch64TargetLowering::LowerCTTZ(SDValue Op, SelectionDAG &DAG) const {
10138   EVT VT = Op.getValueType();
10139   assert(VT.isScalableVector() ||
10140          useSVEForFixedLengthVectorVT(
10141              VT, /*OverrideNEON=*/Subtarget->useSVEForFixedLengthVectors()));
10142 
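  // CTTZ(x) == CTLZ(bitreverse(x)): reversing the bits moves the lowest set
  // bit to the highest position, so counting leading zeros afterwards gives
  // the trailing-zero count.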
10143   SDLoc DL(Op);
10144   SDValue RBIT = DAG.getNode(ISD::BITREVERSE, DL, VT, Op.getOperand(0));
10145   return DAG.getNode(ISD::CTLZ, DL, VT, RBIT);
10146 }
10147 
10148 SDValue AArch64TargetLowering::LowerMinMax(SDValue Op,
10149                                            SelectionDAG &DAG) const {
10150 
10151   EVT VT = Op.getValueType();
10152   SDLoc DL(Op);
10153   unsigned Opcode = Op.getOpcode();
10154   ISD::CondCode CC;
10155   switch (Opcode) {
10156   default:
10157     llvm_unreachable("Wrong instruction");
10158   case ISD::SMAX:
10159     CC = ISD::SETGT;
10160     break;
10161   case ISD::SMIN:
10162     CC = ISD::SETLT;
10163     break;
10164   case ISD::UMAX:
10165     CC = ISD::SETUGT;
10166     break;
10167   case ISD::UMIN:
10168     CC = ISD::SETULT;
10169     break;
10170   }
10171 
10172   if (VT.isScalableVector() ||
10173       useSVEForFixedLengthVectorVT(
10174           VT, /*OverrideNEON=*/Subtarget->useSVEForFixedLengthVectors())) {
10175     switch (Opcode) {
10176     default:
10177       llvm_unreachable("Wrong instruction");
10178     case ISD::SMAX:
10179       return LowerToPredicatedOp(Op, DAG, AArch64ISD::SMAX_PRED);
10180     case ISD::SMIN:
10181       return LowerToPredicatedOp(Op, DAG, AArch64ISD::SMIN_PRED);
10182     case ISD::UMAX:
10183       return LowerToPredicatedOp(Op, DAG, AArch64ISD::UMAX_PRED);
10184     case ISD::UMIN:
10185       return LowerToPredicatedOp(Op, DAG, AArch64ISD::UMIN_PRED);
10186     }
10187   }
10188 
10189   SDValue Op0 = Op.getOperand(0);
10190   SDValue Op1 = Op.getOperand(1);
10191   SDValue Cond = DAG.getSetCC(DL, VT, Op0, Op1, CC);
10192   return DAG.getSelect(DL, VT, Cond, Op0, Op1);
10193 }
10194 
10195 SDValue AArch64TargetLowering::LowerBitreverse(SDValue Op,
10196                                                SelectionDAG &DAG) const {
10197   EVT VT = Op.getValueType();
10198 
10199   if (VT.isScalableVector() ||
10200       useSVEForFixedLengthVectorVT(
10201           VT, /*OverrideNEON=*/Subtarget->useSVEForFixedLengthVectors()))
10202     return LowerToPredicatedOp(Op, DAG, AArch64ISD::BITREVERSE_MERGE_PASSTHRU);
10203 
10204   SDLoc DL(Op);
10205   SDValue REVB;
10206   MVT VST;
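  // Bit-reversing an N-bit element is done in two steps: reverse the bytes of
  // each element (REV32/REV64 below), then bit-reverse each byte via the
  // v8i8/v16i8 BITREVERSE (RBIT).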
10207 
10208   switch (VT.getSimpleVT().SimpleTy) {
10209   default:
10210     llvm_unreachable("Invalid type for bitreverse!");
10211 
10212   case MVT::v2i32: {
10213     VST = MVT::v8i8;
10214     REVB = DAG.getNode(AArch64ISD::REV32, DL, VST, Op.getOperand(0));
10215 
10216     break;
10217   }
10218 
10219   case MVT::v4i32: {
10220     VST = MVT::v16i8;
10221     REVB = DAG.getNode(AArch64ISD::REV32, DL, VST, Op.getOperand(0));
10222 
10223     break;
10224   }
10225 
10226   case MVT::v1i64: {
10227     VST = MVT::v8i8;
10228     REVB = DAG.getNode(AArch64ISD::REV64, DL, VST, Op.getOperand(0));
10229 
10230     break;
10231   }
10232 
10233   case MVT::v2i64: {
10234     VST = MVT::v16i8;
10235     REVB = DAG.getNode(AArch64ISD::REV64, DL, VST, Op.getOperand(0));
10236 
10237     break;
10238   }
10239   }
10240 
10241   return DAG.getNode(AArch64ISD::NVCAST, DL, VT,
10242                      DAG.getNode(ISD::BITREVERSE, DL, VST, REVB));
10243 }
10244 
10245 // Check whether N forms a continuous comparison sequence, i.e. a chain of ORs
10245 // of XORs.
10246 static bool
10247 isOrXorChain(SDValue N, unsigned &Num,
10248              SmallVector<std::pair<SDValue, SDValue>, 16> &WorkList) {
10249   if (Num == MaxXors)
10250     return false;
10251 
10252   // Skip the one-use zext
10253   if (N->getOpcode() == ISD::ZERO_EXTEND && N->hasOneUse())
10254     N = N->getOperand(0);
10255 
10256   // The leaf node must be XOR
10257   if (N->getOpcode() == ISD::XOR) {
10258     WorkList.push_back(std::make_pair(N->getOperand(0), N->getOperand(1)));
10259     Num++;
10260     return true;
10261   }
10262 
10263   // All the non-leaf nodes must be OR.
10264   if (N->getOpcode() != ISD::OR || !N->hasOneUse())
10265     return false;
10266 
10267   if (isOrXorChain(N->getOperand(0), Num, WorkList) &&
10268       isOrXorChain(N->getOperand(1), Num, WorkList))
10269     return true;
10270   return false;
10271 }
10272 
10273 // Transform chains of ORs and XORs, which are usually produced by memcmp/bcmp
10273 // expansion.
10274 static SDValue performOrXorChainCombine(SDNode *N, SelectionDAG &DAG) {
10275   SDValue LHS = N->getOperand(0);
10276   SDValue RHS = N->getOperand(1);
10277   SDLoc DL(N);
10278   EVT VT = N->getValueType(0);
10279   SmallVector<std::pair<SDValue, SDValue>, 16> WorkList;
10280 
10281   // Only handle integer compares.
10282   if (N->getOpcode() != ISD::SETCC)
10283     return SDValue();
10284 
10285   ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
10286   // Try to express conjunction "cmp 0 (or (xor A0 A1) (xor B0 B1))" as:
10287   // sub A0, A1; ccmp B0, B1, 0, eq; cmp inv(Cond) flag
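  // Such chains typically come from memcmp/bcmp expansion, e.g. a 16-byte
  // equality test lowered to (A0 == A1) && (B0 == B1).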
10288   unsigned NumXors = 0;
10289   if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) && isNullConstant(RHS) &&
10290       LHS->getOpcode() == ISD::OR && LHS->hasOneUse() &&
10291       isOrXorChain(LHS, NumXors, WorkList)) {
10292     SDValue XOR0, XOR1;
10293     std::tie(XOR0, XOR1) = WorkList[0];
10294     unsigned LogicOp = (Cond == ISD::SETEQ) ? ISD::AND : ISD::OR;
10295     SDValue Cmp = DAG.getSetCC(DL, VT, XOR0, XOR1, Cond);
10296     for (unsigned I = 1; I < WorkList.size(); I++) {
10297       std::tie(XOR0, XOR1) = WorkList[I];
10298       SDValue CmpChain = DAG.getSetCC(DL, VT, XOR0, XOR1, Cond);
10299       Cmp = DAG.getNode(LogicOp, DL, VT, Cmp, CmpChain);
10300     }
10301 
10302     // Exit early by inverting the condition, which helps reduce indentation.
10303     return Cmp;
10304   }
10305 
10306   return SDValue();
10307 }
10308 
10309 SDValue AArch64TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
10310 
10311   if (Op.getValueType().isVector())
10312     return LowerVSETCC(Op, DAG);
10313 
10314   bool IsStrict = Op->isStrictFPOpcode();
10315   bool IsSignaling = Op.getOpcode() == ISD::STRICT_FSETCCS;
10316   unsigned OpNo = IsStrict ? 1 : 0;
10317   SDValue Chain;
10318   if (IsStrict)
10319     Chain = Op.getOperand(0);
10320   SDValue LHS = Op.getOperand(OpNo + 0);
10321   SDValue RHS = Op.getOperand(OpNo + 1);
10322   ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(OpNo + 2))->get();
10323   SDLoc dl(Op);
10324 
10325   // We chose ZeroOrOneBooleanContents, so use zero and one.
10326   EVT VT = Op.getValueType();
10327   SDValue TVal = DAG.getConstant(1, dl, VT);
10328   SDValue FVal = DAG.getConstant(0, dl, VT);
10329 
10330   // Handle f128 first, since one possible outcome is a normal integer
10331   // comparison which gets picked up by the next if statement.
10332   if (LHS.getValueType() == MVT::f128) {
10333     softenSetCCOperands(DAG, MVT::f128, LHS, RHS, CC, dl, LHS, RHS, Chain,
10334                         IsSignaling);
10335 
10336     // If softenSetCCOperands returned a scalar, use it.
10337     if (!RHS.getNode()) {
10338       assert(LHS.getValueType() == Op.getValueType() &&
10339              "Unexpected setcc expansion!");
10340       return IsStrict ? DAG.getMergeValues({LHS, Chain}, dl) : LHS;
10341     }
10342   }
10343 
10344   if (LHS.getValueType().isInteger()) {
10345     SDValue CCVal;
10346     SDValue Cmp = getAArch64Cmp(
10347         LHS, RHS, ISD::getSetCCInverse(CC, LHS.getValueType()), CCVal, DAG, dl);
10348 
10349     // Note that we inverted the condition above, so we reverse the order of
10350     // the true and false operands here.  This will allow the setcc to be
10351     // matched to a single CSINC instruction.
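    // For example, (setcc x, y, eq) selects between 0 and 1 on the inverted
    // condition and can be matched to "cmp x, y; cset w0, eq".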
10352     SDValue Res = DAG.getNode(AArch64ISD::CSEL, dl, VT, FVal, TVal, CCVal, Cmp);
10353     return IsStrict ? DAG.getMergeValues({Res, Chain}, dl) : Res;
10354   }
10355 
10356   // Now we know we're dealing with FP values.
10357   assert(LHS.getValueType() == MVT::bf16 || LHS.getValueType() == MVT::f16 ||
10358          LHS.getValueType() == MVT::f32 || LHS.getValueType() == MVT::f64);
10359 
10360   // If that fails, we'll need to perform an FCMP + CSEL sequence.  Go ahead
10361   // and do the comparison.
10362   SDValue Cmp;
10363   if (IsStrict)
10364     Cmp = emitStrictFPComparison(LHS, RHS, dl, DAG, Chain, IsSignaling);
10365   else
10366     Cmp = emitComparison(LHS, RHS, CC, dl, DAG);
10367 
10368   AArch64CC::CondCode CC1, CC2;
10369   changeFPCCToAArch64CC(CC, CC1, CC2);
10370   SDValue Res;
10371   if (CC2 == AArch64CC::AL) {
10372     changeFPCCToAArch64CC(ISD::getSetCCInverse(CC, LHS.getValueType()), CC1,
10373                           CC2);
10374     SDValue CC1Val = DAG.getConstant(CC1, dl, MVT::i32);
10375 
10376     // Note that we inverted the condition above, so we reverse the order of
10377     // the true and false operands here.  This will allow the setcc to be
10378     // matched to a single CSINC instruction.
10379     Res = DAG.getNode(AArch64ISD::CSEL, dl, VT, FVal, TVal, CC1Val, Cmp);
10380   } else {
10381     // Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't
10382     // totally clean.  Some of them require two CSELs to implement.  As is in
10383     // this case, we emit the first CSEL and then emit a second using the output
10384     // of the first as the RHS.  We're effectively OR'ing the two CC's together.
10385 
10386     // FIXME: It would be nice if we could match the two CSELs to two CSINCs.
10387     SDValue CC1Val = DAG.getConstant(CC1, dl, MVT::i32);
10388     SDValue CS1 =
10389         DAG.getNode(AArch64ISD::CSEL, dl, VT, TVal, FVal, CC1Val, Cmp);
10390 
10391     SDValue CC2Val = DAG.getConstant(CC2, dl, MVT::i32);
10392     Res = DAG.getNode(AArch64ISD::CSEL, dl, VT, TVal, CS1, CC2Val, Cmp);
10393   }
10394   return IsStrict ? DAG.getMergeValues({Res, Cmp.getValue(1)}, dl) : Res;
10395 }
10396 
10397 SDValue AArch64TargetLowering::LowerSETCCCARRY(SDValue Op,
10398                                                SelectionDAG &DAG) const {
10399 
10400   SDValue LHS = Op.getOperand(0);
10401   SDValue RHS = Op.getOperand(1);
10402   EVT VT = LHS.getValueType();
10403   if (VT != MVT::i32 && VT != MVT::i64)
10404     return SDValue();
10405 
10406   SDLoc DL(Op);
10407   SDValue Carry = Op.getOperand(2);
10408   // SBCS uses a carry not a borrow so the carry flag should be inverted first.
10409   SDValue InvCarry = valueToCarryFlag(Carry, DAG, true);
10410   SDValue Cmp = DAG.getNode(AArch64ISD::SBCS, DL, DAG.getVTList(VT, MVT::Glue),
10411                             LHS, RHS, InvCarry);
10412 
10413   EVT OpVT = Op.getValueType();
10414   SDValue TVal = DAG.getConstant(1, DL, OpVT);
10415   SDValue FVal = DAG.getConstant(0, DL, OpVT);
10416 
10417   ISD::CondCode Cond = cast<CondCodeSDNode>(Op.getOperand(3))->get();
10418   ISD::CondCode CondInv = ISD::getSetCCInverse(Cond, VT);
10419   SDValue CCVal =
10420       DAG.getConstant(changeIntCCToAArch64CC(CondInv), DL, MVT::i32);
10421   // Inputs are swapped because the condition is inverted. This will allow
10422   // matching with a single CSINC instruction.
10423   return DAG.getNode(AArch64ISD::CSEL, DL, OpVT, FVal, TVal, CCVal,
10424                      Cmp.getValue(1));
10425 }
10426 
10427 SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
10428                                               SDValue RHS, SDValue TVal,
10429                                               SDValue FVal, const SDLoc &dl,
10430                                               SelectionDAG &DAG) const {
10431   // Handle f128 first, because it will result in a comparison of some RTLIB
10432   // call result against zero.
10433   if (LHS.getValueType() == MVT::f128) {
10434     softenSetCCOperands(DAG, MVT::f128, LHS, RHS, CC, dl, LHS, RHS);
10435 
10436     // If softenSetCCOperands returned a scalar, we need to compare the result
10437     // against zero to select between true and false values.
10438     if (!RHS.getNode()) {
10439       RHS = DAG.getConstant(0, dl, LHS.getValueType());
10440       CC = ISD::SETNE;
10441     }
10442   }
10443 
10444   // Also handle f16, for which we need to do a f32 comparison.
10445   if ((LHS.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
10446       LHS.getValueType() == MVT::bf16) {
10447     LHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, LHS);
10448     RHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, RHS);
10449   }
10450 
10451   // Next, handle integers.
10452   if (LHS.getValueType().isInteger()) {
10453     assert((LHS.getValueType() == RHS.getValueType()) &&
10454            (LHS.getValueType() == MVT::i32 || LHS.getValueType() == MVT::i64));
10455 
10456     ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(FVal);
10457     ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(TVal);
10458     ConstantSDNode *RHSC = dyn_cast<ConstantSDNode>(RHS);
10459     // Check for sign pattern (SELECT_CC setgt, iN lhs, -1, 1, -1) and transform
10460     // into (OR (ASR lhs, N-1), 1), which requires fewer instructions for the
10461     // supported types.
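    // For example, for i32 this becomes "asr w8, w0, #31; orr w0, w8, #0x1",
    // producing 1 for non-negative inputs and -1 otherwise.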
10462     if (CC == ISD::SETGT && RHSC && RHSC->isAllOnes() && CTVal && CFVal &&
10463         CTVal->isOne() && CFVal->isAllOnes() &&
10464         LHS.getValueType() == TVal.getValueType()) {
10465       EVT VT = LHS.getValueType();
10466       SDValue Shift =
10467           DAG.getNode(ISD::SRA, dl, VT, LHS,
10468                       DAG.getConstant(VT.getSizeInBits() - 1, dl, VT));
10469       return DAG.getNode(ISD::OR, dl, VT, Shift, DAG.getConstant(1, dl, VT));
10470     }
10471 
10472     // Check for SMAX(lhs, 0) and SMIN(lhs, 0) patterns.
10473     // (SELECT_CC setgt, lhs, 0, lhs, 0) -> (BIC lhs, (SRA lhs, typesize-1))
10474     // (SELECT_CC setlt, lhs, 0, lhs, 0) -> (AND lhs, (SRA lhs, typesize-1))
10475     // Both require fewer instructions than a compare and conditional select.
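    // For example, smax(w0, 0) becomes "asr w8, w0, #31; bic w0, w0, w8": the
    // shift yields all-ones for negative inputs, so BIC clears the result to
    // zero in that case.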
10476     if ((CC == ISD::SETGT || CC == ISD::SETLT) && LHS == TVal &&
10477         RHSC && RHSC->isZero() && CFVal && CFVal->isZero() &&
10478         LHS.getValueType() == RHS.getValueType()) {
10479       EVT VT = LHS.getValueType();
10480       SDValue Shift =
10481           DAG.getNode(ISD::SRA, dl, VT, LHS,
10482                       DAG.getConstant(VT.getSizeInBits() - 1, dl, VT));
10483 
10484       if (CC == ISD::SETGT)
10485         Shift = DAG.getNOT(dl, Shift, VT);
10486 
10487       return DAG.getNode(ISD::AND, dl, VT, LHS, Shift);
10488     }
10489 
10490     unsigned Opcode = AArch64ISD::CSEL;
10491 
10492     // If both the TVal and the FVal are constants, see if we can swap them in
10493     // order to form a CSINV or CSINC out of them.
10494     if (CTVal && CFVal && CTVal->isAllOnes() && CFVal->isZero()) {
10495       std::swap(TVal, FVal);
10496       std::swap(CTVal, CFVal);
10497       CC = ISD::getSetCCInverse(CC, LHS.getValueType());
10498     } else if (CTVal && CFVal && CTVal->isOne() && CFVal->isZero()) {
10499       std::swap(TVal, FVal);
10500       std::swap(CTVal, CFVal);
10501       CC = ISD::getSetCCInverse(CC, LHS.getValueType());
10502     } else if (TVal.getOpcode() == ISD::XOR) {
10503       // If TVal is a NOT we want to swap TVal and FVal so that we can match
10504       // with a CSINV rather than a CSEL.
10505       if (isAllOnesConstant(TVal.getOperand(1))) {
10506         std::swap(TVal, FVal);
10507         std::swap(CTVal, CFVal);
10508         CC = ISD::getSetCCInverse(CC, LHS.getValueType());
10509       }
10510     } else if (TVal.getOpcode() == ISD::SUB) {
10511       // If TVal is a negation (SUB from 0) we want to swap TVal and FVal so
10512       // that we can match with a CSNEG rather than a CSEL.
10513       if (isNullConstant(TVal.getOperand(0))) {
10514         std::swap(TVal, FVal);
10515         std::swap(CTVal, CFVal);
10516         CC = ISD::getSetCCInverse(CC, LHS.getValueType());
10517       }
10518     } else if (CTVal && CFVal) {
10519       const int64_t TrueVal = CTVal->getSExtValue();
10520       const int64_t FalseVal = CFVal->getSExtValue();
10521       bool Swap = false;
10522 
10523       // If both TVal and FVal are constants, see if FVal is the
10524       // inverse/negation/increment of TVal and generate a CSINV/CSNEG/CSINC
10525       // instead of a CSEL in that case.
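      // For example, (cond ? 7 : -8) can use CSINV since -8 == ~7, and
      // (cond ? 5 : 4) can use CSINC since 5 == 4 + 1.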
10526       if (TrueVal == ~FalseVal) {
10527         Opcode = AArch64ISD::CSINV;
10528       } else if (FalseVal > std::numeric_limits<int64_t>::min() &&
10529                  TrueVal == -FalseVal) {
10530         Opcode = AArch64ISD::CSNEG;
10531       } else if (TVal.getValueType() == MVT::i32) {
10532         // If our operands are only 32-bit wide, make sure we use 32-bit
10533         // arithmetic for the check whether we can use CSINC. This ensures that
10534         // the addition in the check will wrap around properly in case there is
10535         // an overflow (which would not be the case if we do the check with
10536         // 64-bit arithmetic).
10537         const uint32_t TrueVal32 = CTVal->getZExtValue();
10538         const uint32_t FalseVal32 = CFVal->getZExtValue();
10539 
10540         if ((TrueVal32 == FalseVal32 + 1) || (TrueVal32 + 1 == FalseVal32)) {
10541           Opcode = AArch64ISD::CSINC;
10542 
10543           if (TrueVal32 > FalseVal32) {
10544             Swap = true;
10545           }
10546         }
10547       } else {
10548         // 64-bit check whether we can use CSINC.
10549         const uint64_t TrueVal64 = TrueVal;
10550         const uint64_t FalseVal64 = FalseVal;
10551 
10552         if ((TrueVal64 == FalseVal64 + 1) || (TrueVal64 + 1 == FalseVal64)) {
10553           Opcode = AArch64ISD::CSINC;
10554 
10555           if (TrueVal > FalseVal) {
10556             Swap = true;
10557           }
10558         }
10559       }
10560 
10561       // Swap TVal and FVal if necessary.
10562       if (Swap) {
10563         std::swap(TVal, FVal);
10564         std::swap(CTVal, CFVal);
10565         CC = ISD::getSetCCInverse(CC, LHS.getValueType());
10566       }
10567 
10568       if (Opcode != AArch64ISD::CSEL) {
10569         // Drop FVal since we can get its value by simply inverting/negating
10570         // TVal.
10571         FVal = TVal;
10572       }
10573     }
10574 
10575     // Avoid materializing a constant when possible by reusing a known value in
10576     // a register.  However, don't perform this optimization if the known value
10577     // is one, zero or negative one in the case of a CSEL.  We can always
10578     // materialize these values using CSINC, CSEL and CSINV with wzr/xzr as the
10579     // FVal, respectively.
10580     ConstantSDNode *RHSVal = dyn_cast<ConstantSDNode>(RHS);
10581     if (Opcode == AArch64ISD::CSEL && RHSVal && !RHSVal->isOne() &&
10582         !RHSVal->isZero() && !RHSVal->isAllOnes()) {
10583       AArch64CC::CondCode AArch64CC = changeIntCCToAArch64CC(CC);
10584       // Transform "a == C ? C : x" to "a == C ? a : x" and "a != C ? x : C" to
10585       // "a != C ? x : a" to avoid materializing C.
10586       if (CTVal && CTVal == RHSVal && AArch64CC == AArch64CC::EQ)
10587         TVal = LHS;
10588       else if (CFVal && CFVal == RHSVal && AArch64CC == AArch64CC::NE)
10589         FVal = LHS;
10590     } else if (Opcode == AArch64ISD::CSNEG && RHSVal && RHSVal->isOne()) {
10591       assert(CTVal && CFVal && "Expected constant operands for CSNEG.");
10592       // Use a CSINV to transform "a == C ? 1 : -1" to "a == C ? a : -1" to
10593       // avoid materializing C.
10594       AArch64CC::CondCode AArch64CC = changeIntCCToAArch64CC(CC);
10595       if (CTVal == RHSVal && AArch64CC == AArch64CC::EQ) {
10596         Opcode = AArch64ISD::CSINV;
10597         TVal = LHS;
10598         FVal = DAG.getConstant(0, dl, FVal.getValueType());
10599       }
10600     }
10601 
10602     SDValue CCVal;
10603     SDValue Cmp = getAArch64Cmp(LHS, RHS, CC, CCVal, DAG, dl);
10604     EVT VT = TVal.getValueType();
10605     return DAG.getNode(Opcode, dl, VT, TVal, FVal, CCVal, Cmp);
10606   }
10607 
10608   // Now we know we're dealing with FP values.
10609   assert(LHS.getValueType() == MVT::f16 || LHS.getValueType() == MVT::f32 ||
10610          LHS.getValueType() == MVT::f64);
10611   assert(LHS.getValueType() == RHS.getValueType());
10612   EVT VT = TVal.getValueType();
10613   SDValue Cmp = emitComparison(LHS, RHS, CC, dl, DAG);
10614 
10615   // Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
10616   // clean.  Some of them require two CSELs to implement.
10617   AArch64CC::CondCode CC1, CC2;
10618   changeFPCCToAArch64CC(CC, CC1, CC2);
10619 
10620   if (DAG.getTarget().Options.UnsafeFPMath) {
10621     // Transform "a == 0.0 ? 0.0 : x" to "a == 0.0 ? a : x" and
10622     // "a != 0.0 ? x : 0.0" to "a != 0.0 ? x : a" to avoid materializing 0.0.
10623     ConstantFPSDNode *RHSVal = dyn_cast<ConstantFPSDNode>(RHS);
10624     if (RHSVal && RHSVal->isZero()) {
10625       ConstantFPSDNode *CFVal = dyn_cast<ConstantFPSDNode>(FVal);
10626       ConstantFPSDNode *CTVal = dyn_cast<ConstantFPSDNode>(TVal);
10627 
10628       if ((CC == ISD::SETEQ || CC == ISD::SETOEQ || CC == ISD::SETUEQ) &&
10629           CTVal && CTVal->isZero() && TVal.getValueType() == LHS.getValueType())
10630         TVal = LHS;
10631       else if ((CC == ISD::SETNE || CC == ISD::SETONE || CC == ISD::SETUNE) &&
10632                CFVal && CFVal->isZero() &&
10633                FVal.getValueType() == LHS.getValueType())
10634         FVal = LHS;
10635     }
10636   }
10637 
10638   // Emit first, and possibly only, CSEL.
10639   SDValue CC1Val = DAG.getConstant(CC1, dl, MVT::i32);
10640   SDValue CS1 = DAG.getNode(AArch64ISD::CSEL, dl, VT, TVal, FVal, CC1Val, Cmp);
10641 
10642   // If we need a second CSEL, emit it, using the output of the first as the
10643   // RHS.  We're effectively OR'ing the two CC's together.
10644   if (CC2 != AArch64CC::AL) {
10645     SDValue CC2Val = DAG.getConstant(CC2, dl, MVT::i32);
10646     return DAG.getNode(AArch64ISD::CSEL, dl, VT, TVal, CS1, CC2Val, Cmp);
10647   }
10648 
10649   // Otherwise, return the output of the first CSEL.
10650   return CS1;
10651 }
10652 
10653 SDValue AArch64TargetLowering::LowerVECTOR_SPLICE(SDValue Op,
10654                                                   SelectionDAG &DAG) const {
10655   EVT Ty = Op.getValueType();
10656   auto Idx = Op.getConstantOperandAPInt(2);
10657   int64_t IdxVal = Idx.getSExtValue();
10658   assert(Ty.isScalableVector() &&
10659          "Only expect scalable vectors for custom lowering of VECTOR_SPLICE");
10660 
10661   // We can use the splice instruction for certain index values where we are
10662   // able to efficiently generate the correct predicate. The index will be
10663   // inverted and used directly as the input to the ptrue instruction, i.e.
10664   // -1 -> vl1, -2 -> vl2, etc. The predicate will then be reversed to get the
10665   // splice predicate. However, we can only do this if we can guarantee that
10666   // there are enough elements in the vector, hence we check the index <= min
10667   // number of elements.
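  // For example, with IdxVal == -2 we build a PTRUE with pattern VL2, reverse
  // it, and use it as the SPLICE predicate so that the last two elements of
  // the first operand are followed by the leading elements of the second.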
10668   std::optional<unsigned> PredPattern;
10669   if (Ty.isScalableVector() && IdxVal < 0 &&
10670       (PredPattern = getSVEPredPatternFromNumElements(std::abs(IdxVal))) !=
10671           std::nullopt) {
10672     SDLoc DL(Op);
10673 
10674     // Create a predicate where all but the last -IdxVal elements are false.
10675     EVT PredVT = Ty.changeVectorElementType(MVT::i1);
10676     SDValue Pred = getPTrue(DAG, DL, PredVT, *PredPattern);
10677     Pred = DAG.getNode(ISD::VECTOR_REVERSE, DL, PredVT, Pred);
10678 
10679     // Now splice the two inputs together using the predicate.
10680     return DAG.getNode(AArch64ISD::SPLICE, DL, Ty, Pred, Op.getOperand(0),
10681                        Op.getOperand(1));
10682   }
10683 
10684   // We can select to an EXT instruction when indexing the first 256 bytes.
10685   unsigned BlockSize = AArch64::SVEBitsPerBlock / Ty.getVectorMinNumElements();
10686   if (IdxVal >= 0 && (IdxVal * BlockSize / 8) < 256)
10687     return Op;
10688 
10689   return SDValue();
10690 }
10691 
10692 SDValue AArch64TargetLowering::LowerSELECT_CC(SDValue Op,
10693                                               SelectionDAG &DAG) const {
10694   ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(4))->get();
10695   SDValue LHS = Op.getOperand(0);
10696   SDValue RHS = Op.getOperand(1);
10697   SDValue TVal = Op.getOperand(2);
10698   SDValue FVal = Op.getOperand(3);
10699   SDLoc DL(Op);
10700   return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
10701 }
10702 
10703 SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
10704                                            SelectionDAG &DAG) const {
10705   SDValue CCVal = Op->getOperand(0);
10706   SDValue TVal = Op->getOperand(1);
10707   SDValue FVal = Op->getOperand(2);
10708   SDLoc DL(Op);
10709 
10710   EVT Ty = Op.getValueType();
10711   if (Ty == MVT::aarch64svcount) {
10712     TVal = DAG.getNode(ISD::BITCAST, DL, MVT::nxv16i1, TVal);
10713     FVal = DAG.getNode(ISD::BITCAST, DL, MVT::nxv16i1, FVal);
10714     SDValue Sel =
10715         DAG.getNode(ISD::SELECT, DL, MVT::nxv16i1, CCVal, TVal, FVal);
10716     return DAG.getNode(ISD::BITCAST, DL, Ty, Sel);
10717   }
10718 
10719   if (Ty.isScalableVector()) {
10720     MVT PredVT = MVT::getVectorVT(MVT::i1, Ty.getVectorElementCount());
10721     SDValue SplatPred = DAG.getNode(ISD::SPLAT_VECTOR, DL, PredVT, CCVal);
10722     return DAG.getNode(ISD::VSELECT, DL, Ty, SplatPred, TVal, FVal);
10723   }
10724 
10725   if (useSVEForFixedLengthVectorVT(Ty, !Subtarget->isNeonAvailable())) {
10726     // FIXME: Ideally this would be the same as above using i1 types, however
10727     // for the moment we can't deal with fixed i1 vector types properly, so
10728     // instead extend the predicate to a result type sized integer vector.
10729     MVT SplatValVT = MVT::getIntegerVT(Ty.getScalarSizeInBits());
10730     MVT PredVT = MVT::getVectorVT(SplatValVT, Ty.getVectorElementCount());
10731     SDValue SplatVal = DAG.getSExtOrTrunc(CCVal, DL, SplatValVT);
10732     SDValue SplatPred = DAG.getNode(ISD::SPLAT_VECTOR, DL, PredVT, SplatVal);
10733     return DAG.getNode(ISD::VSELECT, DL, Ty, SplatPred, TVal, FVal);
10734   }
10735 
10736   // Optimize {s|u}{add|sub|mul}.with.overflow feeding into a select
10737   // instruction.
10738   if (ISD::isOverflowIntrOpRes(CCVal)) {
10739     // Only lower legal XALUO ops.
10740     if (!DAG.getTargetLoweringInfo().isTypeLegal(CCVal->getValueType(0)))
10741       return SDValue();
10742 
10743     AArch64CC::CondCode OFCC;
10744     SDValue Value, Overflow;
10745     std::tie(Value, Overflow) = getAArch64XALUOOp(OFCC, CCVal.getValue(0), DAG);
10746     SDValue CCVal = DAG.getConstant(OFCC, DL, MVT::i32);
10747 
10748     return DAG.getNode(AArch64ISD::CSEL, DL, Op.getValueType(), TVal, FVal,
10749                        CCVal, Overflow);
10750   }
10751 
10752   // Lower it the same way as we would lower a SELECT_CC node.
10753   ISD::CondCode CC;
10754   SDValue LHS, RHS;
10755   if (CCVal.getOpcode() == ISD::SETCC) {
10756     LHS = CCVal.getOperand(0);
10757     RHS = CCVal.getOperand(1);
10758     CC = cast<CondCodeSDNode>(CCVal.getOperand(2))->get();
10759   } else {
10760     LHS = CCVal;
10761     RHS = DAG.getConstant(0, DL, CCVal.getValueType());
10762     CC = ISD::SETNE;
10763   }
10764 
10765   // If we are lowering an f16/bf16 and do not have FullFP16, convert to an f32
10766   // in order to use FCSELSrrr.
10767   if ((Ty == MVT::f16 || Ty == MVT::bf16) && !Subtarget->hasFullFP16()) {
10768     TVal = DAG.getTargetInsertSubreg(AArch64::hsub, DL, MVT::f32,
10769                                      DAG.getUNDEF(MVT::f32), TVal);
10770     FVal = DAG.getTargetInsertSubreg(AArch64::hsub, DL, MVT::f32,
10771                                      DAG.getUNDEF(MVT::f32), FVal);
10772   }
10773 
10774   SDValue Res = LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
10775 
10776   if ((Ty == MVT::f16 || Ty == MVT::bf16) && !Subtarget->hasFullFP16()) {
10777     return DAG.getTargetExtractSubreg(AArch64::hsub, DL, Ty, Res);
10778   }
10779 
10780   return Res;
10781 }
10782 
10783 SDValue AArch64TargetLowering::LowerJumpTable(SDValue Op,
10784                                               SelectionDAG &DAG) const {
10785   // Jump table entries are emitted as PC-relative offsets. No additional
10786   // tweaking is necessary here. Just get the address of the jump table.
10787   JumpTableSDNode *JT = cast<JumpTableSDNode>(Op);
10788 
10789   CodeModel::Model CM = getTargetMachine().getCodeModel();
10790   if (CM == CodeModel::Large && !getTargetMachine().isPositionIndependent() &&
10791       !Subtarget->isTargetMachO())
10792     return getAddrLarge(JT, DAG);
10793   if (CM == CodeModel::Tiny)
10794     return getAddrTiny(JT, DAG);
10795   return getAddr(JT, DAG);
10796 }
10797 
10798 SDValue AArch64TargetLowering::LowerBR_JT(SDValue Op,
10799                                           SelectionDAG &DAG) const {
10800   // Jump table entries are emitted as PC-relative offsets. No additional
10801   // tweaking is necessary here. Just get the address of the jump table.
10802   SDLoc DL(Op);
10803   SDValue JT = Op.getOperand(1);
10804   SDValue Entry = Op.getOperand(2);
10805   int JTI = cast<JumpTableSDNode>(JT.getNode())->getIndex();
10806 
10807   auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
10808   AFI->setJumpTableEntryInfo(JTI, 4, nullptr);
10809 
10810   // With aarch64-jump-table-hardening, we only expand the jump table dispatch
10811   // sequence later, to guarantee the integrity of the intermediate values.
10812   if (DAG.getMachineFunction().getFunction().hasFnAttribute(
10813           "aarch64-jump-table-hardening")) {
10814     CodeModel::Model CM = getTargetMachine().getCodeModel();
10815     if (Subtarget->isTargetMachO()) {
10816       if (CM != CodeModel::Small && CM != CodeModel::Large)
10817         report_fatal_error("Unsupported code-model for hardened jump-table");
10818     } else {
10819       // Note that COFF support would likely also need JUMP_TABLE_DEBUG_INFO.
10820       assert(Subtarget->isTargetELF() &&
10821              "jump table hardening only supported on MachO/ELF");
10822       if (CM != CodeModel::Small)
10823         report_fatal_error("Unsupported code-model for hardened jump-table");
10824     }
10825 
10826     SDValue X16Copy = DAG.getCopyToReg(DAG.getEntryNode(), DL, AArch64::X16,
10827                                        Entry, SDValue());
10828     SDNode *B = DAG.getMachineNode(AArch64::BR_JumpTable, DL, MVT::Other,
10829                                    DAG.getTargetJumpTable(JTI, MVT::i32),
10830                                    X16Copy.getValue(0), X16Copy.getValue(1));
10831     return SDValue(B, 0);
10832   }
10833 
10834   SDNode *Dest =
10835       DAG.getMachineNode(AArch64::JumpTableDest32, DL, MVT::i64, MVT::i64, JT,
10836                          Entry, DAG.getTargetJumpTable(JTI, MVT::i32));
10837   SDValue JTInfo = DAG.getJumpTableDebugInfo(JTI, Op.getOperand(0), DL);
10838   return DAG.getNode(ISD::BRIND, DL, MVT::Other, JTInfo, SDValue(Dest, 0));
10839 }
10840 
10841 SDValue AArch64TargetLowering::LowerBRIND(SDValue Op, SelectionDAG &DAG) const {
10842   SDValue Chain = Op.getOperand(0);
10843   SDValue Dest = Op.getOperand(1);
10844 
10845   // BR_JT is lowered to BRIND, but the later lowering is specific to indirectbr
10846   // Skip over the jump-table BRINDs, where the destination is JumpTableDest32.
10847   if (Dest->isMachineOpcode() &&
10848       Dest->getMachineOpcode() == AArch64::JumpTableDest32)
10849     return SDValue();
10850 
10851   const MachineFunction &MF = DAG.getMachineFunction();
10852   std::optional<uint16_t> BADisc =
10853       Subtarget->getPtrAuthBlockAddressDiscriminatorIfEnabled(MF.getFunction());
10854   if (!BADisc)
10855     return SDValue();
10856 
10857   SDLoc DL(Op);
10858 
10859   SDValue Disc = DAG.getTargetConstant(*BADisc, DL, MVT::i64);
10860   SDValue Key = DAG.getTargetConstant(AArch64PACKey::IA, DL, MVT::i32);
10861   SDValue AddrDisc = DAG.getRegister(AArch64::XZR, MVT::i64);
10862 
10863   SDNode *BrA = DAG.getMachineNode(AArch64::BRA, DL, MVT::Other,
10864                                    {Dest, Key, Disc, AddrDisc, Chain});
10865   return SDValue(BrA, 0);
10866 }
10867 
10868 SDValue AArch64TargetLowering::LowerConstantPool(SDValue Op,
10869                                                  SelectionDAG &DAG) const {
10870   ConstantPoolSDNode *CP = cast<ConstantPoolSDNode>(Op);
10871   CodeModel::Model CM = getTargetMachine().getCodeModel();
10872   if (CM == CodeModel::Large) {
10873     // Use the GOT for the large code model on iOS.
10874     if (Subtarget->isTargetMachO()) {
10875       return getGOT(CP, DAG);
10876     }
10877     if (!getTargetMachine().isPositionIndependent())
10878       return getAddrLarge(CP, DAG);
10879   } else if (CM == CodeModel::Tiny) {
10880     return getAddrTiny(CP, DAG);
10881   }
10882   return getAddr(CP, DAG);
10883 }
10884 
10885 SDValue AArch64TargetLowering::LowerBlockAddress(SDValue Op,
10886                                                SelectionDAG &DAG) const {
10887   BlockAddressSDNode *BAN = cast<BlockAddressSDNode>(Op);
10888   const BlockAddress *BA = BAN->getBlockAddress();
10889 
10890   if (std::optional<uint16_t> BADisc =
10891           Subtarget->getPtrAuthBlockAddressDiscriminatorIfEnabled(
10892               *BA->getFunction())) {
10893     SDLoc DL(Op);
10894 
10895     // This isn't cheap, but BRIND is rare.
10896     SDValue TargetBA = DAG.getTargetBlockAddress(BA, BAN->getValueType(0));
10897 
10898     SDValue Disc = DAG.getTargetConstant(*BADisc, DL, MVT::i64);
10899 
10900     SDValue Key = DAG.getTargetConstant(AArch64PACKey::IA, DL, MVT::i32);
10901     SDValue AddrDisc = DAG.getRegister(AArch64::XZR, MVT::i64);
10902 
10903     SDNode *MOV =
10904         DAG.getMachineNode(AArch64::MOVaddrPAC, DL, {MVT::Other, MVT::Glue},
10905                            {TargetBA, Key, AddrDisc, Disc});
10906     return DAG.getCopyFromReg(SDValue(MOV, 0), DL, AArch64::X16, MVT::i64,
10907                               SDValue(MOV, 1));
10908   }
10909 
10910   CodeModel::Model CM = getTargetMachine().getCodeModel();
10911   if (CM == CodeModel::Large && !Subtarget->isTargetMachO()) {
10912     if (!getTargetMachine().isPositionIndependent())
10913       return getAddrLarge(BAN, DAG);
10914   } else if (CM == CodeModel::Tiny) {
10915     return getAddrTiny(BAN, DAG);
10916   }
10917   return getAddr(BAN, DAG);
10918 }
10919 
10920 SDValue AArch64TargetLowering::LowerDarwin_VASTART(SDValue Op,
10921                                                  SelectionDAG &DAG) const {
10922   AArch64FunctionInfo *FuncInfo =
10923       DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
10924 
10925   SDLoc DL(Op);
10926   SDValue FR = DAG.getFrameIndex(FuncInfo->getVarArgsStackIndex(),
10927                                  getPointerTy(DAG.getDataLayout()));
10928   FR = DAG.getZExtOrTrunc(FR, DL, getPointerMemTy(DAG.getDataLayout()));
10929   const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
10930   return DAG.getStore(Op.getOperand(0), DL, FR, Op.getOperand(1),
10931                       MachinePointerInfo(SV));
10932 }
10933 
10934 SDValue AArch64TargetLowering::LowerWin64_VASTART(SDValue Op,
10935                                                   SelectionDAG &DAG) const {
10936   MachineFunction &MF = DAG.getMachineFunction();
10937   AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
10938 
10939   SDLoc DL(Op);
10940   SDValue FR;
10941   if (Subtarget->isWindowsArm64EC()) {
10942     // With the Arm64EC ABI, we compute the address of the varargs save area
10943     // relative to x4. For a normal AArch64->AArch64 call, x4 == sp on entry,
10944     // but calls from an entry thunk can pass in a different address.
10945     Register VReg = MF.addLiveIn(AArch64::X4, &AArch64::GPR64RegClass);
10946     SDValue Val = DAG.getCopyFromReg(DAG.getEntryNode(), DL, VReg, MVT::i64);
10947     uint64_t StackOffset;
10948     if (FuncInfo->getVarArgsGPRSize() > 0)
10949       StackOffset = -(uint64_t)FuncInfo->getVarArgsGPRSize();
10950     else
10951       StackOffset = FuncInfo->getVarArgsStackOffset();
10952     FR = DAG.getNode(ISD::ADD, DL, MVT::i64, Val,
10953                      DAG.getConstant(StackOffset, DL, MVT::i64));
10954   } else {
10955     FR = DAG.getFrameIndex(FuncInfo->getVarArgsGPRSize() > 0
10956                                ? FuncInfo->getVarArgsGPRIndex()
10957                                : FuncInfo->getVarArgsStackIndex(),
10958                            getPointerTy(DAG.getDataLayout()));
10959   }
10960   const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
10961   return DAG.getStore(Op.getOperand(0), DL, FR, Op.getOperand(1),
10962                       MachinePointerInfo(SV));
10963 }
10964 
10965 SDValue AArch64TargetLowering::LowerAAPCS_VASTART(SDValue Op,
10966                                                   SelectionDAG &DAG) const {
10967   // The layout of the va_list struct is specified in the AArch64 Procedure Call
10968   // Standard, section B.3.
10969   MachineFunction &MF = DAG.getMachineFunction();
10970   AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
10971   unsigned PtrSize = Subtarget->isTargetILP32() ? 4 : 8;
10972   auto PtrMemVT = getPointerMemTy(DAG.getDataLayout());
10973   auto PtrVT = getPointerTy(DAG.getDataLayout());
10974   SDLoc DL(Op);
10975 
10976   SDValue Chain = Op.getOperand(0);
10977   SDValue VAList = Op.getOperand(1);
10978   const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
10979   SmallVector<SDValue, 4> MemOps;
10980 
10981   // void *__stack at offset 0
10982   unsigned Offset = 0;
10983   SDValue Stack = DAG.getFrameIndex(FuncInfo->getVarArgsStackIndex(), PtrVT);
10984   Stack = DAG.getZExtOrTrunc(Stack, DL, PtrMemVT);
10985   MemOps.push_back(DAG.getStore(Chain, DL, Stack, VAList,
10986                                 MachinePointerInfo(SV), Align(PtrSize)));
10987 
10988   // void *__gr_top at offset 8 (4 on ILP32)
10989   Offset += PtrSize;
10990   int GPRSize = FuncInfo->getVarArgsGPRSize();
10991   if (GPRSize > 0) {
10992     SDValue GRTop, GRTopAddr;
10993 
10994     GRTopAddr = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
10995                             DAG.getConstant(Offset, DL, PtrVT));
10996 
10997     GRTop = DAG.getFrameIndex(FuncInfo->getVarArgsGPRIndex(), PtrVT);
10998     GRTop = DAG.getNode(ISD::ADD, DL, PtrVT, GRTop,
10999                         DAG.getConstant(GPRSize, DL, PtrVT));
11000     GRTop = DAG.getZExtOrTrunc(GRTop, DL, PtrMemVT);
11001 
11002     MemOps.push_back(DAG.getStore(Chain, DL, GRTop, GRTopAddr,
11003                                   MachinePointerInfo(SV, Offset),
11004                                   Align(PtrSize)));
11005   }
11006 
11007   // void *__vr_top at offset 16 (8 on ILP32)
11008   Offset += PtrSize;
11009   int FPRSize = FuncInfo->getVarArgsFPRSize();
11010   if (FPRSize > 0) {
11011     SDValue VRTop, VRTopAddr;
11012     VRTopAddr = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
11013                             DAG.getConstant(Offset, DL, PtrVT));
11014 
11015     VRTop = DAG.getFrameIndex(FuncInfo->getVarArgsFPRIndex(), PtrVT);
11016     VRTop = DAG.getNode(ISD::ADD, DL, PtrVT, VRTop,
11017                         DAG.getConstant(FPRSize, DL, PtrVT));
11018     VRTop = DAG.getZExtOrTrunc(VRTop, DL, PtrMemVT);
11019 
11020     MemOps.push_back(DAG.getStore(Chain, DL, VRTop, VRTopAddr,
11021                                   MachinePointerInfo(SV, Offset),
11022                                   Align(PtrSize)));
11023   }
11024 
11025   // int __gr_offs at offset 24 (12 on ILP32)
11026   Offset += PtrSize;
11027   SDValue GROffsAddr = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
11028                                    DAG.getConstant(Offset, DL, PtrVT));
11029   MemOps.push_back(
11030       DAG.getStore(Chain, DL, DAG.getConstant(-GPRSize, DL, MVT::i32),
11031                    GROffsAddr, MachinePointerInfo(SV, Offset), Align(4)));
11032 
11033   // int __vr_offs at offset 28 (16 on ILP32)
11034   Offset += 4;
11035   SDValue VROffsAddr = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
11036                                    DAG.getConstant(Offset, DL, PtrVT));
11037   MemOps.push_back(
11038       DAG.getStore(Chain, DL, DAG.getConstant(-FPRSize, DL, MVT::i32),
11039                    VROffsAddr, MachinePointerInfo(SV, Offset), Align(4)));
11040 
11041   return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, MemOps);
11042 }
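// Illustrative note (not part of the original source): the five stores built
// by LowerAAPCS_VASTART above fill in a structure equivalent to the AAPCS64
// (section B.3) va_list. A minimal LP64 sketch, using the same offsets as the
// code above:
//
//   typedef struct {
//     void *__stack;   // offset 0:  next stacked (overflow) argument
//     void *__gr_top;  // offset 8:  end of the general-register save area
//     void *__vr_top;  // offset 16: end of the FP/SIMD register save area
//     int   __gr_offs; // offset 24: -(bytes of unused GP save area), <= 0
//     int   __vr_offs; // offset 28: -(bytes of unused FP save area), <= 0
//   } va_list;   // 32 bytes on LP64; ILP32 uses 4-byte pointers (20 bytes).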
11043 
11044 SDValue AArch64TargetLowering::LowerVASTART(SDValue Op,
11045                                             SelectionDAG &DAG) const {
11046   MachineFunction &MF = DAG.getMachineFunction();
11047   Function &F = MF.getFunction();
11048 
11049   if (Subtarget->isCallingConvWin64(F.getCallingConv(), F.isVarArg()))
11050     return LowerWin64_VASTART(Op, DAG);
11051   else if (Subtarget->isTargetDarwin())
11052     return LowerDarwin_VASTART(Op, DAG);
11053   else
11054     return LowerAAPCS_VASTART(Op, DAG);
11055 }
11056 
11057 SDValue AArch64TargetLowering::LowerVACOPY(SDValue Op,
11058                                            SelectionDAG &DAG) const {
11059   // AAPCS has three pointers and two ints (= 32 bytes); Darwin and Windows
11060   // use a single pointer.
11061   SDLoc DL(Op);
11062   unsigned PtrSize = Subtarget->isTargetILP32() ? 4 : 8;
11063   unsigned VaListSize =
11064       (Subtarget->isTargetDarwin() || Subtarget->isTargetWindows())
11065           ? PtrSize
11066           : Subtarget->isTargetILP32() ? 20 : 32;
11067   const Value *DestSV = cast<SrcValueSDNode>(Op.getOperand(3))->getValue();
11068   const Value *SrcSV = cast<SrcValueSDNode>(Op.getOperand(4))->getValue();
11069 
11070   return DAG.getMemcpy(Op.getOperand(0), DL, Op.getOperand(1), Op.getOperand(2),
11071                        DAG.getConstant(VaListSize, DL, MVT::i32),
11072                        Align(PtrSize), false, false, /*CI=*/nullptr,
11073                        std::nullopt, MachinePointerInfo(DestSV),
11074                        MachinePointerInfo(SrcSV));
11075 }
11076 
11077 SDValue AArch64TargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
11078   assert(Subtarget->isTargetDarwin() &&
11079          "automatic va_arg instruction only works on Darwin");
11080 
11081   const Value *V = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
11082   EVT VT = Op.getValueType();
11083   SDLoc DL(Op);
11084   SDValue Chain = Op.getOperand(0);
11085   SDValue Addr = Op.getOperand(1);
11086   MaybeAlign Align(Op.getConstantOperandVal(3));
11087   unsigned MinSlotSize = Subtarget->isTargetILP32() ? 4 : 8;
11088   auto PtrVT = getPointerTy(DAG.getDataLayout());
11089   auto PtrMemVT = getPointerMemTy(DAG.getDataLayout());
11090   SDValue VAList =
11091       DAG.getLoad(PtrMemVT, DL, Chain, Addr, MachinePointerInfo(V));
11092   Chain = VAList.getValue(1);
11093   VAList = DAG.getZExtOrTrunc(VAList, DL, PtrVT);
11094 
11095   if (VT.isScalableVector())
11096     report_fatal_error("Passing SVE types to variadic functions is "
11097                        "currently not supported");
11098 
11099   if (Align && *Align > MinSlotSize) {
11100     VAList = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
11101                          DAG.getConstant(Align->value() - 1, DL, PtrVT));
11102     VAList = DAG.getNode(ISD::AND, DL, PtrVT, VAList,
11103                          DAG.getConstant(-(int64_t)Align->value(), DL, PtrVT));
11104   }
11105 
11106   Type *ArgTy = VT.getTypeForEVT(*DAG.getContext());
11107   unsigned ArgSize = DAG.getDataLayout().getTypeAllocSize(ArgTy);
11108 
11109   // Scalar integer and FP values smaller than 64 bits are implicitly extended
11110   // up to 64 bits.  At the very least, we have to increase the striding of the
11111   // vaargs list to match this, and for FP values we need to introduce
11112   // FP_ROUND nodes as well.
11113   if (VT.isInteger() && !VT.isVector())
11114     ArgSize = std::max(ArgSize, MinSlotSize);
11115   bool NeedFPTrunc = false;
11116   if (VT.isFloatingPoint() && !VT.isVector() && VT != MVT::f64) {
11117     ArgSize = 8;
11118     NeedFPTrunc = true;
11119   }
11120 
11121   // Increment the pointer, VAList, to the next vaarg
11122   SDValue VANext = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
11123                                DAG.getConstant(ArgSize, DL, PtrVT));
11124   VANext = DAG.getZExtOrTrunc(VANext, DL, PtrMemVT);
11125 
11126   // Store the incremented VAList to the legalized pointer
11127   SDValue APStore =
11128       DAG.getStore(Chain, DL, VANext, Addr, MachinePointerInfo(V));
11129 
11130   // Load the actual argument out of the pointer VAList
11131   if (NeedFPTrunc) {
11132     // Load the value as an f64.
11133     SDValue WideFP =
11134         DAG.getLoad(MVT::f64, DL, APStore, VAList, MachinePointerInfo());
11135     // Round the value down to an f32.
11136     SDValue NarrowFP =
11137         DAG.getNode(ISD::FP_ROUND, DL, VT, WideFP.getValue(0),
11138                     DAG.getIntPtrConstant(1, DL, /*isTarget=*/true));
11139     SDValue Ops[] = { NarrowFP, WideFP.getValue(1) };
11140     // Merge the rounded value with the chain output of the load.
11141     return DAG.getMergeValues(Ops, DL);
11142   }
11143 
11144   return DAG.getLoad(VT, DL, APStore, VAList, MachinePointerInfo());
11145 }
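// Illustrative example (assumed C source, not from this file): on Darwin a
// variadic float is passed as a promoted double in an 8-byte slot, so the
// lowering above reads it back as f64 and narrows it:
//
//   float f = va_arg(ap, float);
//   //  -> load the current va_list pointer and align it if required
//   //  -> store (pointer + 8) back as the new va_list value
//   //  -> load the slot as f64 and emit FP_ROUND to f32 (the NeedFPTrunc path)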
11146 
11147 SDValue AArch64TargetLowering::LowerFRAMEADDR(SDValue Op,
11148                                               SelectionDAG &DAG) const {
11149   MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
11150   MFI.setFrameAddressIsTaken(true);
11151 
11152   EVT VT = Op.getValueType();
11153   SDLoc DL(Op);
11154   unsigned Depth = Op.getConstantOperandVal(0);
11155   SDValue FrameAddr =
11156       DAG.getCopyFromReg(DAG.getEntryNode(), DL, AArch64::FP, MVT::i64);
11157   while (Depth--)
11158     FrameAddr = DAG.getLoad(VT, DL, DAG.getEntryNode(), FrameAddr,
11159                             MachinePointerInfo());
11160 
11161   if (Subtarget->isTargetILP32())
11162     FrameAddr = DAG.getNode(ISD::AssertZext, DL, MVT::i64, FrameAddr,
11163                             DAG.getValueType(VT));
11164 
11165   return FrameAddr;
11166 }
11167 
11168 SDValue AArch64TargetLowering::LowerSPONENTRY(SDValue Op,
11169                                               SelectionDAG &DAG) const {
11170   MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
11171 
11172   EVT VT = getPointerTy(DAG.getDataLayout());
11173   SDLoc DL(Op);
11174   int FI = MFI.CreateFixedObject(4, 0, false);
11175   return DAG.getFrameIndex(FI, VT);
11176 }
11177 
11178 #define GET_REGISTER_MATCHER
11179 #include "AArch64GenAsmMatcher.inc"
11180 
11181 // FIXME? Maybe this could be a TableGen attribute on some registers and
11182 // this table could be generated automatically from RegInfo.
11183 Register AArch64TargetLowering::
11184 getRegisterByName(const char* RegName, LLT VT, const MachineFunction &MF) const {
11185   Register Reg = MatchRegisterName(RegName);
11186   if (AArch64::X1 <= Reg && Reg <= AArch64::X28) {
11187     const AArch64RegisterInfo *MRI = Subtarget->getRegisterInfo();
11188     unsigned DwarfRegNum = MRI->getDwarfRegNum(Reg, false);
11189     if (!Subtarget->isXRegisterReserved(DwarfRegNum) &&
11190         !MRI->isReservedReg(MF, Reg))
11191       Reg = 0;
11192   }
11193   if (Reg)
11194     return Reg;
11195   report_fatal_error(Twine("Invalid register name \""
11196                               + StringRef(RegName) + "\"."));
11197 }
11198 
11199 SDValue AArch64TargetLowering::LowerADDROFRETURNADDR(SDValue Op,
11200                                                      SelectionDAG &DAG) const {
11201   DAG.getMachineFunction().getFrameInfo().setFrameAddressIsTaken(true);
11202 
11203   EVT VT = Op.getValueType();
11204   SDLoc DL(Op);
11205 
11206   SDValue FrameAddr =
11207       DAG.getCopyFromReg(DAG.getEntryNode(), DL, AArch64::FP, VT);
11208   SDValue Offset = DAG.getConstant(8, DL, getPointerTy(DAG.getDataLayout()));
11209 
11210   return DAG.getNode(ISD::ADD, DL, VT, FrameAddr, Offset);
11211 }
11212 
11213 SDValue AArch64TargetLowering::LowerRETURNADDR(SDValue Op,
11214                                                SelectionDAG &DAG) const {
11215   MachineFunction &MF = DAG.getMachineFunction();
11216   MachineFrameInfo &MFI = MF.getFrameInfo();
11217   MFI.setReturnAddressIsTaken(true);
11218 
11219   EVT VT = Op.getValueType();
11220   SDLoc DL(Op);
11221   unsigned Depth = Op.getConstantOperandVal(0);
11222   SDValue ReturnAddress;
11223   if (Depth) {
11224     SDValue FrameAddr = LowerFRAMEADDR(Op, DAG);
11225     SDValue Offset = DAG.getConstant(8, DL, getPointerTy(DAG.getDataLayout()));
11226     ReturnAddress = DAG.getLoad(
11227         VT, DL, DAG.getEntryNode(),
11228         DAG.getNode(ISD::ADD, DL, VT, FrameAddr, Offset), MachinePointerInfo());
11229   } else {
11230     // Return LR, which contains the return address. Mark it an implicit
11231     // live-in.
11232     Register Reg = MF.addLiveIn(AArch64::LR, &AArch64::GPR64RegClass);
11233     ReturnAddress = DAG.getCopyFromReg(DAG.getEntryNode(), DL, Reg, VT);
11234   }
11235 
11236   // The XPACLRI instruction assembles to a hint-space instruction before
11237   // Armv8.3-A, so it can safely be used on any pre-Armv8.3-A architecture.
11238   // On Armv8.3-A and onwards, XPACI is available, so use that
11239   // instead.
11240   SDNode *St;
11241   if (Subtarget->hasPAuth()) {
11242     St = DAG.getMachineNode(AArch64::XPACI, DL, VT, ReturnAddress);
11243   } else {
11244     // XPACLRI operates on LR therefore we must move the operand accordingly.
11245     SDValue Chain =
11246         DAG.getCopyToReg(DAG.getEntryNode(), DL, AArch64::LR, ReturnAddress);
11247     St = DAG.getMachineNode(AArch64::XPACLRI, DL, VT, Chain);
11248   }
11249   return SDValue(St, 0);
11250 }
11251 
11252 /// LowerShiftParts - Lower SHL_PARTS/SRA_PARTS/SRL_PARTS, which return two
11253 /// i64 values and take a 2 x i64 value to shift plus a shift amount.
11254 SDValue AArch64TargetLowering::LowerShiftParts(SDValue Op,
11255                                                SelectionDAG &DAG) const {
11256   SDValue Lo, Hi;
11257   expandShiftParts(Op.getNode(), Lo, Hi, DAG);
11258   return DAG.getMergeValues({Lo, Hi}, SDLoc(Op));
11259 }
11260 
11261 bool AArch64TargetLowering::isOffsetFoldingLegal(
11262     const GlobalAddressSDNode *GA) const {
11263   // Offsets are folded in the DAG combine rather than here so that we can
11264   // intelligently choose an offset based on the uses.
11265   return false;
11266 }
11267 
11268 bool AArch64TargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT,
11269                                          bool OptForSize) const {
11270   bool IsLegal = false;
11271   // We can materialize #0.0 as fmov $Rd, XZR for the 64-bit and 32-bit cases,
11272   // and for the 16-bit case when the target has full fp16 support.
11273   // We encode bf16 bit patterns as if they were fp16. This results in very
11274   // strange looking assembly but should populate the register with appropriate
11275   // values. Let's say we wanted to encode 0xR3FC0 which is 1.5 in BF16. We will
11276   // end up encoding this as the imm8 0x7f. This imm8 will be expanded to the
11277   // FP16 1.9375 which shares the same bit pattern as BF16 1.5.
11278   // FIXME: We should be able to handle f128 as well with a clever lowering.
11279   const APInt ImmInt = Imm.bitcastToAPInt();
11280   if (VT == MVT::f64)
11281     IsLegal = AArch64_AM::getFP64Imm(ImmInt) != -1 || Imm.isPosZero();
11282   else if (VT == MVT::f32)
11283     IsLegal = AArch64_AM::getFP32Imm(ImmInt) != -1 || Imm.isPosZero();
11284   else if (VT == MVT::f16 || VT == MVT::bf16)
11285     IsLegal =
11286         (Subtarget->hasFullFP16() && AArch64_AM::getFP16Imm(ImmInt) != -1) ||
11287         Imm.isPosZero();
11288 
11289   // If we cannot materialize the value in the immediate field of an fmov,
11290   // check if it can be encoded as the immediate operand of a logical instruction.
11291   // The immediate value will be created with either MOVZ, MOVN, or ORR.
11292   // TODO: fmov h0, w0 is also legal, however we don't have an isel pattern to
11293   //       generate that fmov.
11294   if (!IsLegal && (VT == MVT::f64 || VT == MVT::f32)) {
11295     // The cost is actually exactly the same for mov+fmov vs. adrp+ldr;
11296     // however the mov+fmov sequence is always better because of the reduced
11297     // cache pressure. The timings are still the same if you consider
11298     // movw+movk+fmov vs. adrp+ldr (it's one instruction longer, but the
11299   // movw+movk is fused). So we limit the expansion to at most 2 instructions.
11300     SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn;
11301     AArch64_IMM::expandMOVImm(ImmInt.getZExtValue(), VT.getSizeInBits(), Insn);
11302     unsigned Limit = (OptForSize ? 1 : (Subtarget->hasFuseLiterals() ? 5 : 2));
11303     IsLegal = Insn.size() <= Limit;
11304   }
11305 
11306   LLVM_DEBUG(dbgs() << (IsLegal ? "Legal " : "Illegal ") << VT
11307                     << " imm value: "; Imm.dump(););
11308   return IsLegal;
11309 }
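// Worked examples for the legality rules above (illustrative): 1.0, -0.5 and
// 31.0 fit the 8-bit fmov immediate encoding (sign, 3-bit exponent, 4-bit
// fraction), so they are legal directly. 0.1f has no imm8 encoding, but its
// bit pattern 0x3dcccccd expands to MOVZ+MOVK (2 instructions), which is
// within the default Limit of 2, so mov+fmov is still preferred over a
// constant-pool load; under OptForSize the Limit drops to 1 and 0.1f is
// rejected here.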
11310 
11311 //===----------------------------------------------------------------------===//
11312 //                          AArch64 Optimization Hooks
11313 //===----------------------------------------------------------------------===//
11314 
11315 static SDValue getEstimate(const AArch64Subtarget *ST, unsigned Opcode,
11316                            SDValue Operand, SelectionDAG &DAG,
11317                            int &ExtraSteps) {
11318   EVT VT = Operand.getValueType();
11319   if ((ST->hasNEON() &&
11320        (VT == MVT::f64 || VT == MVT::v1f64 || VT == MVT::v2f64 ||
11321         VT == MVT::f32 || VT == MVT::v1f32 || VT == MVT::v2f32 ||
11322         VT == MVT::v4f32)) ||
11323       (ST->hasSVE() &&
11324        (VT == MVT::nxv8f16 || VT == MVT::nxv4f32 || VT == MVT::nxv2f64))) {
11325     if (ExtraSteps == TargetLoweringBase::ReciprocalEstimate::Unspecified) {
11326       // For the reciprocal estimates, convergence is quadratic, so the number
11327       // of digits is doubled after each iteration.  In ARMv8, the accuracy of
11328       // the initial estimate is 2^-8.  Thus the number of extra steps to refine
11329       // the result for float (23 mantissa bits) is 2 and for double (52
11330       // mantissa bits) is 3.
11331       constexpr unsigned AccurateBits = 8;
11332       unsigned DesiredBits =
11333           APFloat::semanticsPrecision(DAG.EVTToAPFloatSemantics(VT));
11334       ExtraSteps = DesiredBits <= AccurateBits
11335                        ? 0
11336                        : Log2_64_Ceil(DesiredBits) - Log2_64_Ceil(AccurateBits);
11337     }
11338 
11339     return DAG.getNode(Opcode, SDLoc(Operand), VT, Operand);
11340   }
11341 
11342   return SDValue();
11343 }
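// Worked example of the ExtraSteps computation above: the hardware estimate
// is accurate to about 2^-8, i.e. AccurateBits = 8 and Log2_64_Ceil(8) = 3.
//   f16:  11-bit significand -> Log2_64_Ceil(11) - 3 = 4 - 3 = 1 step
//   f32:  24-bit significand -> Log2_64_Ceil(24) - 3 = 5 - 3 = 2 steps
//   f64:  53-bit significand -> Log2_64_Ceil(53) - 3 = 6 - 3 = 3 steps
// matching the "2 for float, 3 for double" figures quoted in the comment.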
11344 
11345 SDValue
11346 AArch64TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
11347                                         const DenormalMode &Mode) const {
11348   SDLoc DL(Op);
11349   EVT VT = Op.getValueType();
11350   EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
11351   SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
11352   return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
11353 }
11354 
11355 SDValue
11356 AArch64TargetLowering::getSqrtResultForDenormInput(SDValue Op,
11357                                                    SelectionDAG &DAG) const {
11358   return Op;
11359 }
11360 
11361 SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand,
11362                                                SelectionDAG &DAG, int Enabled,
11363                                                int &ExtraSteps,
11364                                                bool &UseOneConst,
11365                                                bool Reciprocal) const {
11366   if (Enabled == ReciprocalEstimate::Enabled ||
11367       (Enabled == ReciprocalEstimate::Unspecified && Subtarget->useRSqrt()))
11368     if (SDValue Estimate = getEstimate(Subtarget, AArch64ISD::FRSQRTE, Operand,
11369                                        DAG, ExtraSteps)) {
11370       SDLoc DL(Operand);
11371       EVT VT = Operand.getValueType();
11372 
11373       SDNodeFlags Flags;
11374       Flags.setAllowReassociation(true);
11375 
11376       // Newton reciprocal square root iteration: E * 0.5 * (3 - X * E^2)
11377       // AArch64 reciprocal square root iteration instruction: 0.5 * (3 - M * N)
11378       for (int i = ExtraSteps; i > 0; --i) {
11379         SDValue Step = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Estimate,
11380                                    Flags);
11381         Step = DAG.getNode(AArch64ISD::FRSQRTS, DL, VT, Operand, Step, Flags);
11382         Estimate = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Step, Flags);
11383       }
11384       if (!Reciprocal)
11385         Estimate = DAG.getNode(ISD::FMUL, DL, VT, Operand, Estimate, Flags);
11386 
11387       ExtraSteps = 0;
11388       return Estimate;
11389     }
11390 
11391   return SDValue();
11392 }
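// Illustrative shape of the DAG built above for rsqrt(x) with 2 refinement
// steps (FRSQRTS(a, b) computes (3 - a*b) / 2):
//   e0 = FRSQRTE(x)
//   e1 = e0 * FRSQRTS(x, e0 * e0)
//   e2 = e1 * FRSQRTS(x, e1 * e1)
// and, when Reciprocal is false, the final sqrt(x) result is x * e2.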
11393 
11394 SDValue AArch64TargetLowering::getRecipEstimate(SDValue Operand,
11395                                                 SelectionDAG &DAG, int Enabled,
11396                                                 int &ExtraSteps) const {
11397   if (Enabled == ReciprocalEstimate::Enabled)
11398     if (SDValue Estimate = getEstimate(Subtarget, AArch64ISD::FRECPE, Operand,
11399                                        DAG, ExtraSteps)) {
11400       SDLoc DL(Operand);
11401       EVT VT = Operand.getValueType();
11402 
11403       SDNodeFlags Flags;
11404       Flags.setAllowReassociation(true);
11405 
11406       // Newton reciprocal iteration: E * (2 - X * E)
11407       // AArch64 reciprocal iteration instruction: (2 - M * N)
11408       for (int i = ExtraSteps; i > 0; --i) {
11409         SDValue Step = DAG.getNode(AArch64ISD::FRECPS, DL, VT, Operand,
11410                                    Estimate, Flags);
11411         Estimate = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Step, Flags);
11412       }
11413 
11414       ExtraSteps = 0;
11415       return Estimate;
11416     }
11417 
11418   return SDValue();
11419 }
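// Illustrative shape of the DAG built above for 1.0/x with 2 refinement
// steps (FRECPS(a, b) computes 2 - a*b):
//   e0 = FRECPE(x)
//   e1 = e0 * FRECPS(x, e0)
//   e2 = e1 * FRECPS(x, e1)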
11420 
11421 //===----------------------------------------------------------------------===//
11422 //                          AArch64 Inline Assembly Support
11423 //===----------------------------------------------------------------------===//
11424 
11425 // Table of Constraints
11426 // TODO: This is the current set of constraints supported by ARM for the
11427 // compiler, not all of them may make sense.
11428 //
11429 // r - A general register
11430 // w - An FP/SIMD register of some size in the range v0-v31
11431 // x - An FP/SIMD register of some size in the range v0-v15
11432 // I - Constant that can be used with an ADD instruction
11433 // J - Constant that can be used with a SUB instruction
11434 // K - Constant that can be used with a 32-bit logical instruction
11435 // L - Constant that can be used with a 64-bit logical instruction
11436 // M - Constant that can be used as a 32-bit MOV immediate
11437 // N - Constant that can be used as a 64-bit MOV immediate
11438 // Q - A memory reference with base register and no offset
11439 // S - A symbolic address
11440 // Y - Floating point constant zero
11441 // Z - Integer constant zero
11442 //
11443 //   Note that general register operands will be output using their 64-bit x
11444 // register name, whatever the size of the variable, unless the asm operand
11445 // is prefixed by the %w modifier. Floating-point and SIMD register operands
11446 // will be output with the v prefix unless prefixed by the %b, %h, %s, %d or
11447 // %q modifier.
11448 const char *AArch64TargetLowering::LowerXConstraint(EVT ConstraintVT) const {
11449   // At this point, we have to lower this constraint to something else, so we
11450   // lower it to an "r" or "w". However, by doing this we will force the result
11451   // to be in register, while the X constraint is much more permissive.
11452   //
11453   // Although we are correct (we are free to emit anything, without
11454   // constraints), we might break use cases that would expect us to be more
11455   // efficient and emit something else.
11456   if (!Subtarget->hasFPARMv8())
11457     return "r";
11458 
11459   if (ConstraintVT.isFloatingPoint())
11460     return "w";
11461 
11462   if (ConstraintVT.isVector() &&
11463      (ConstraintVT.getSizeInBits() == 64 ||
11464       ConstraintVT.getSizeInBits() == 128))
11465     return "w";
11466 
11467   return "r";
11468 }
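// Hypothetical C examples (not from this file) of the constraints documented
// above, using the 'r' and 'w' register classes and the %w/%s modifiers:
//
//   int add1(int x) {
//     int r;
//     __asm__("add %w0, %w1, #1" : "=r"(r) : "r"(x));
//     return r;
//   }
//
//   float negate(float x) {
//     float r;
//     __asm__("fneg %s0, %s1" : "=w"(r) : "w"(x));
//     return r;
//   }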
11469 
11470 enum class PredicateConstraint { Uph, Upl, Upa };
11471 
11472 static std::optional<PredicateConstraint>
11473 parsePredicateConstraint(StringRef Constraint) {
11474   return StringSwitch<std::optional<PredicateConstraint>>(Constraint)
11475       .Case("Uph", PredicateConstraint::Uph)
11476       .Case("Upl", PredicateConstraint::Upl)
11477       .Case("Upa", PredicateConstraint::Upa)
11478       .Default(std::nullopt);
11479 }
11480 
11481 static const TargetRegisterClass *
11482 getPredicateRegisterClass(PredicateConstraint Constraint, EVT VT) {
11483   if (VT != MVT::aarch64svcount &&
11484       (!VT.isScalableVector() || VT.getVectorElementType() != MVT::i1))
11485     return nullptr;
11486 
11487   switch (Constraint) {
11488   case PredicateConstraint::Uph:
11489     return VT == MVT::aarch64svcount ? &AArch64::PNR_p8to15RegClass
11490                                      : &AArch64::PPR_p8to15RegClass;
11491   case PredicateConstraint::Upl:
11492     return VT == MVT::aarch64svcount ? &AArch64::PNR_3bRegClass
11493                                      : &AArch64::PPR_3bRegClass;
11494   case PredicateConstraint::Upa:
11495     return VT == MVT::aarch64svcount ? &AArch64::PNRRegClass
11496                                      : &AArch64::PPRRegClass;
11497   }
11498 
11499   llvm_unreachable("Missing PredicateConstraint!");
11500 }
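// Hypothetical SVE example (not from this file) of the predicate constraints
// classified above: "Upa" allows any of p0-p15, "Upl" restricts the operand
// to p0-p7.
//
//   svbool_t make_ptrue(void) {
//     svbool_t pg;
//     __asm__("ptrue %0.b" : "=Upa"(pg));
//     return pg;
//   }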
11501 
11502 enum class ReducedGprConstraint { Uci, Ucj };
11503 
11504 static std::optional<ReducedGprConstraint>
11505 parseReducedGprConstraint(StringRef Constraint) {
11506   return StringSwitch<std::optional<ReducedGprConstraint>>(Constraint)
11507       .Case("Uci", ReducedGprConstraint::Uci)
11508       .Case("Ucj", ReducedGprConstraint::Ucj)
11509       .Default(std::nullopt);
11510 }
11511 
11512 static const TargetRegisterClass *
11513 getReducedGprRegisterClass(ReducedGprConstraint Constraint, EVT VT) {
11514   if (!VT.isScalarInteger() || VT.getFixedSizeInBits() > 64)
11515     return nullptr;
11516 
11517   switch (Constraint) {
11518   case ReducedGprConstraint::Uci:
11519     return &AArch64::MatrixIndexGPR32_8_11RegClass;
11520   case ReducedGprConstraint::Ucj:
11521     return &AArch64::MatrixIndexGPR32_12_15RegClass;
11522   }
11523 
11524   llvm_unreachable("Missing ReducedGprConstraint!");
11525 }
11526 
11527 // The set of cc code supported is from
11528 // https://gcc.gnu.org/onlinedocs/gcc/Extended-Asm.html#Flag-Output-Operands
11529 static AArch64CC::CondCode parseConstraintCode(llvm::StringRef Constraint) {
11530   AArch64CC::CondCode Cond = StringSwitch<AArch64CC::CondCode>(Constraint)
11531                                  .Case("{@cchi}", AArch64CC::HI)
11532                                  .Case("{@cccs}", AArch64CC::HS)
11533                                  .Case("{@cclo}", AArch64CC::LO)
11534                                  .Case("{@ccls}", AArch64CC::LS)
11535                                  .Case("{@cccc}", AArch64CC::LO)
11536                                  .Case("{@cceq}", AArch64CC::EQ)
11537                                  .Case("{@ccgt}", AArch64CC::GT)
11538                                  .Case("{@ccge}", AArch64CC::GE)
11539                                  .Case("{@cclt}", AArch64CC::LT)
11540                                  .Case("{@ccle}", AArch64CC::LE)
11541                                  .Case("{@cchs}", AArch64CC::HS)
11542                                  .Case("{@ccne}", AArch64CC::NE)
11543                                  .Case("{@ccvc}", AArch64CC::VC)
11544                                  .Case("{@ccpl}", AArch64CC::PL)
11545                                  .Case("{@ccvs}", AArch64CC::VS)
11546                                  .Case("{@ccmi}", AArch64CC::MI)
11547                                  .Default(AArch64CC::Invalid);
11548   return Cond;
11549 }
11550 
11551 /// Helper function to create 'CSET', which is equivalent to 'CSINC <Wd>, WZR,
11552 /// WZR, invert(<cond>)'.
11553 static SDValue getSETCC(AArch64CC::CondCode CC, SDValue NZCV, const SDLoc &DL,
11554                         SelectionDAG &DAG) {
11555   return DAG.getNode(
11556       AArch64ISD::CSINC, DL, MVT::i32, DAG.getConstant(0, DL, MVT::i32),
11557       DAG.getConstant(0, DL, MVT::i32),
11558       DAG.getConstant(getInvertedCondCode(CC), DL, MVT::i32), NZCV);
11559 }
11560 
11561 // Lower @cc flag output via getSETCC.
11562 SDValue AArch64TargetLowering::LowerAsmOutputForConstraint(
11563     SDValue &Chain, SDValue &Glue, const SDLoc &DL,
11564     const AsmOperandInfo &OpInfo, SelectionDAG &DAG) const {
11565   AArch64CC::CondCode Cond = parseConstraintCode(OpInfo.ConstraintCode);
11566   if (Cond == AArch64CC::Invalid)
11567     return SDValue();
11568   // The output variable should be a scalar integer.
11569   if (OpInfo.ConstraintVT.isVector() || !OpInfo.ConstraintVT.isInteger() ||
11570       OpInfo.ConstraintVT.getSizeInBits() < 8)
11571     report_fatal_error("Flag output operand is of invalid type");
11572 
11573   // Get NZCV register. Only update chain when copyfrom is glued.
11574   if (Glue.getNode()) {
11575     Glue = DAG.getCopyFromReg(Chain, DL, AArch64::NZCV, MVT::i32, Glue);
11576     Chain = Glue.getValue(1);
11577   } else
11578     Glue = DAG.getCopyFromReg(Chain, DL, AArch64::NZCV, MVT::i32);
11579   // Extract CC code.
11580   SDValue CC = getSETCC(Cond, Glue, DL, DAG);
11581 
11582   SDValue Result;
11583 
11584   // Truncate or ZERO_EXTEND based on value types.
11585   if (OpInfo.ConstraintVT.getSizeInBits() <= 32)
11586     Result = DAG.getNode(ISD::TRUNCATE, DL, OpInfo.ConstraintVT, CC);
11587   else
11588     Result = DAG.getNode(ISD::ZERO_EXTEND, DL, OpInfo.ConstraintVT, CC);
11589 
11590   return Result;
11591 }
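// Hypothetical example (not from this file) of the "@cc" flag-output
// constraints lowered above; the result is materialized with CSINC (CSET):
//
//   int same(long a, long b) {
//     int eq;
//     __asm__("cmp %1, %2" : "=@cceq"(eq) : "r"(a), "r"(b));
//     return eq;
//   }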
11592 
11593 /// getConstraintType - Given a constraint letter, return the type of
11594 /// constraint it is for this target.
11595 AArch64TargetLowering::ConstraintType
11596 AArch64TargetLowering::getConstraintType(StringRef Constraint) const {
11597   if (Constraint.size() == 1) {
11598     switch (Constraint[0]) {
11599     default:
11600       break;
11601     case 'x':
11602     case 'w':
11603     case 'y':
11604       return C_RegisterClass;
11605     // An address with a single base register. Due to the way we
11606     // currently handle addresses it is the same as 'r'.
11607     case 'Q':
11608       return C_Memory;
11609     case 'I':
11610     case 'J':
11611     case 'K':
11612     case 'L':
11613     case 'M':
11614     case 'N':
11615     case 'Y':
11616     case 'Z':
11617       return C_Immediate;
11618     case 'z':
11619     case 'S': // A symbol or label reference with a constant offset
11620       return C_Other;
11621     }
11622   } else if (parsePredicateConstraint(Constraint))
11623     return C_RegisterClass;
11624   else if (parseReducedGprConstraint(Constraint))
11625     return C_RegisterClass;
11626   else if (parseConstraintCode(Constraint) != AArch64CC::Invalid)
11627     return C_Other;
11628   return TargetLowering::getConstraintType(Constraint);
11629 }
11630 
11631 /// Examine constraint type and operand type and determine a weight value.
11632 /// This object must already have been set up with the operand type
11633 /// and the current alternative constraint selected.
11634 TargetLowering::ConstraintWeight
11635 AArch64TargetLowering::getSingleConstraintMatchWeight(
11636     AsmOperandInfo &info, const char *constraint) const {
11637   ConstraintWeight weight = CW_Invalid;
11638   Value *CallOperandVal = info.CallOperandVal;
11639   // If we don't have a value, we can't do a match,
11640   // but allow it at the lowest weight.
11641   if (!CallOperandVal)
11642     return CW_Default;
11643   Type *type = CallOperandVal->getType();
11644   // Look at the constraint type.
11645   switch (*constraint) {
11646   default:
11647     weight = TargetLowering::getSingleConstraintMatchWeight(info, constraint);
11648     break;
11649   case 'x':
11650   case 'w':
11651   case 'y':
11652     if (type->isFloatingPointTy() || type->isVectorTy())
11653       weight = CW_Register;
11654     break;
11655   case 'z':
11656     weight = CW_Constant;
11657     break;
11658   case 'U':
11659     if (parsePredicateConstraint(constraint) ||
11660         parseReducedGprConstraint(constraint))
11661       weight = CW_Register;
11662     break;
11663   }
11664   return weight;
11665 }
11666 
11667 std::pair<unsigned, const TargetRegisterClass *>
11668 AArch64TargetLowering::getRegForInlineAsmConstraint(
11669     const TargetRegisterInfo *TRI, StringRef Constraint, MVT VT) const {
11670   if (Constraint.size() == 1) {
11671     switch (Constraint[0]) {
11672     case 'r':
11673       if (VT.isScalableVector())
11674         return std::make_pair(0U, nullptr);
11675       if (Subtarget->hasLS64() && VT.getSizeInBits() == 512)
11676         return std::make_pair(0U, &AArch64::GPR64x8ClassRegClass);
11677       if (VT.getFixedSizeInBits() == 64)
11678         return std::make_pair(0U, &AArch64::GPR64commonRegClass);
11679       return std::make_pair(0U, &AArch64::GPR32commonRegClass);
11680     case 'w': {
11681       if (!Subtarget->hasFPARMv8())
11682         break;
11683       if (VT.isScalableVector()) {
11684         if (VT.getVectorElementType() != MVT::i1)
11685           return std::make_pair(0U, &AArch64::ZPRRegClass);
11686         return std::make_pair(0U, nullptr);
11687       }
11688       if (VT == MVT::Other)
11689         break;
11690       uint64_t VTSize = VT.getFixedSizeInBits();
11691       if (VTSize == 16)
11692         return std::make_pair(0U, &AArch64::FPR16RegClass);
11693       if (VTSize == 32)
11694         return std::make_pair(0U, &AArch64::FPR32RegClass);
11695       if (VTSize == 64)
11696         return std::make_pair(0U, &AArch64::FPR64RegClass);
11697       if (VTSize == 128)
11698         return std::make_pair(0U, &AArch64::FPR128RegClass);
11699       break;
11700     }
11701     // The instructions that this constraint is designed for can
11702     // only take 128-bit registers so just use that regclass.
11703     case 'x':
11704       if (!Subtarget->hasFPARMv8())
11705         break;
11706       if (VT.isScalableVector())
11707         return std::make_pair(0U, &AArch64::ZPR_4bRegClass);
11708       if (VT.getSizeInBits() == 128)
11709         return std::make_pair(0U, &AArch64::FPR128_loRegClass);
11710       break;
11711     case 'y':
11712       if (!Subtarget->hasFPARMv8())
11713         break;
11714       if (VT.isScalableVector())
11715         return std::make_pair(0U, &AArch64::ZPR_3bRegClass);
11716       break;
11717     }
11718   } else {
11719     if (const auto PC = parsePredicateConstraint(Constraint))
11720       if (const auto *RegClass = getPredicateRegisterClass(*PC, VT))
11721         return std::make_pair(0U, RegClass);
11722 
11723     if (const auto RGC = parseReducedGprConstraint(Constraint))
11724       if (const auto *RegClass = getReducedGprRegisterClass(*RGC, VT))
11725         return std::make_pair(0U, RegClass);
11726   }
11727   if (StringRef("{cc}").equals_insensitive(Constraint) ||
11728       parseConstraintCode(Constraint) != AArch64CC::Invalid)
11729     return std::make_pair(unsigned(AArch64::NZCV), &AArch64::CCRRegClass);
11730 
11731   if (Constraint == "{za}") {
11732     return std::make_pair(unsigned(AArch64::ZA), &AArch64::MPRRegClass);
11733   }
11734 
11735   if (Constraint == "{zt0}") {
11736     return std::make_pair(unsigned(AArch64::ZT0), &AArch64::ZTRRegClass);
11737   }
11738 
11739   // Use the default implementation in TargetLowering to convert the register
11740   // constraint into a member of a register class.
11741   std::pair<unsigned, const TargetRegisterClass *> Res;
11742   Res = TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
11743 
11744   // Not found as a standard register?
11745   if (!Res.second) {
11746     unsigned Size = Constraint.size();
11747     if ((Size == 4 || Size == 5) && Constraint[0] == '{' &&
11748         tolower(Constraint[1]) == 'v' && Constraint[Size - 1] == '}') {
11749       int RegNo;
11750       bool Failed = Constraint.slice(2, Size - 1).getAsInteger(10, RegNo);
11751       if (!Failed && RegNo >= 0 && RegNo <= 31) {
11752         // v0 - v31 are aliases of q0 - q31 or d0 - d31 depending on size.
11753         // By default we'll emit v0-v31 for this unless there's a modifier where
11754         // we'll emit the correct register as well.
11755         if (VT != MVT::Other && VT.getSizeInBits() == 64) {
11756           Res.first = AArch64::FPR64RegClass.getRegister(RegNo);
11757           Res.second = &AArch64::FPR64RegClass;
11758         } else {
11759           Res.first = AArch64::FPR128RegClass.getRegister(RegNo);
11760           Res.second = &AArch64::FPR128RegClass;
11761         }
11762       }
11763     }
11764   }
11765 
11766   if (Res.second && !Subtarget->hasFPARMv8() &&
11767       !AArch64::GPR32allRegClass.hasSubClassEq(Res.second) &&
11768       !AArch64::GPR64allRegClass.hasSubClassEq(Res.second))
11769     return std::make_pair(0U, nullptr);
11770 
11771   return Res;
11772 }
11773 
11774 EVT AArch64TargetLowering::getAsmOperandValueType(const DataLayout &DL,
11775                                                   llvm::Type *Ty,
11776                                                   bool AllowUnknown) const {
11777   if (Subtarget->hasLS64() && Ty->isIntegerTy(512))
11778     return EVT(MVT::i64x8);
11779 
11780   return TargetLowering::getAsmOperandValueType(DL, Ty, AllowUnknown);
11781 }
11782 
11783 /// LowerAsmOperandForConstraint - Lower the specified operand into the Ops
11784 /// vector.  If it is invalid, don't add anything to Ops.
11785 void AArch64TargetLowering::LowerAsmOperandForConstraint(
11786     SDValue Op, StringRef Constraint, std::vector<SDValue> &Ops,
11787     SelectionDAG &DAG) const {
11788   SDValue Result;
11789 
11790   // Currently only support length 1 constraints.
11791   if (Constraint.size() != 1)
11792     return;
11793 
11794   char ConstraintLetter = Constraint[0];
11795   switch (ConstraintLetter) {
11796   default:
11797     break;
11798 
11799   // This set of constraints deals with valid constants for various instructions.
11800   // Validate and return a target constant for them if we can.
11801   case 'z': {
11802     // 'z' maps to xzr or wzr so it needs an input of 0.
11803     if (!isNullConstant(Op))
11804       return;
11805 
11806     if (Op.getValueType() == MVT::i64)
11807       Result = DAG.getRegister(AArch64::XZR, MVT::i64);
11808     else
11809       Result = DAG.getRegister(AArch64::WZR, MVT::i32);
11810     break;
11811   }
11812   case 'S':
11813     // Use the generic code path for "s". In GCC's aarch64 port, "S" is
11814     // supported for PIC while "s" isn't, making "s" less useful. We implement
11815     // "S" but not "s".
11816     TargetLowering::LowerAsmOperandForConstraint(Op, "s", Ops, DAG);
11817     break;
11818 
11819   case 'I':
11820   case 'J':
11821   case 'K':
11822   case 'L':
11823   case 'M':
11824   case 'N':
11825     ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op);
11826     if (!C)
11827       return;
11828 
11829     // Grab the value and do some validation.
11830     uint64_t CVal = C->getZExtValue();
11831     switch (ConstraintLetter) {
11832     // The I constraint applies only to simple ADD or SUB immediate operands:
11833     // i.e. 0 to 4095 with optional shift by 12
11834     // The J constraint applies only to ADD or SUB immediates that would be
11835     // valid when negated, i.e. if [an add pattern] were to be output as a SUB
11836     // instruction [or vice versa], in other words -1 to -4095 with optional
11837     // left shift by 12.
11838     case 'I':
11839       if (isUInt<12>(CVal) || isShiftedUInt<12, 12>(CVal))
11840         break;
11841       return;
11842     case 'J': {
11843       uint64_t NVal = -C->getSExtValue();
11844       if (isUInt<12>(NVal) || isShiftedUInt<12, 12>(NVal)) {
11845         CVal = C->getSExtValue();
11846         break;
11847       }
11848       return;
11849     }
11850     // The K and L constraints apply *only* to logical immediates, including
11851     // what used to be the MOVI alias for ORR (though the MOVI alias has now
11852     // been removed and MOV should be used). So these constraints have to
11853     // distinguish between bit patterns that are valid 32-bit or 64-bit
11854     // "bitmask immediates": for example 0xaaaaaaaa is a valid bimm32 (K), but
11855     // not a valid bimm64 (L) where 0xaaaaaaaaaaaaaaaa would be valid, and vice
11856     // versa.
11857     case 'K':
11858       if (AArch64_AM::isLogicalImmediate(CVal, 32))
11859         break;
11860       return;
11861     case 'L':
11862       if (AArch64_AM::isLogicalImmediate(CVal, 64))
11863         break;
11864       return;
11865     // The M and N constraints are a superset of K and L respectively, for use
11866     // with the MOV (immediate) alias. As well as the logical immediates they
11867     // also match 32 or 64-bit immediates that can be loaded either using a
11868     // *single* MOVZ or MOVN, such as 32-bit 0x12340000, 0x00001234, 0xffffedca
11869     // (M) or 64-bit 0x1234000000000000 (N) etc.
11870     // As a note some of this code is liberally stolen from the asm parser.
11871     // As a note, some of this code is liberally stolen from the asm parser.
11872       if (!isUInt<32>(CVal))
11873         return;
11874       if (AArch64_AM::isLogicalImmediate(CVal, 32))
11875         break;
11876       if ((CVal & 0xFFFF) == CVal)
11877         break;
11878       if ((CVal & 0xFFFF0000ULL) == CVal)
11879         break;
11880       uint64_t NCVal = ~(uint32_t)CVal;
11881       if ((NCVal & 0xFFFFULL) == NCVal)
11882         break;
11883       if ((NCVal & 0xFFFF0000ULL) == NCVal)
11884         break;
11885       return;
11886     }
11887     case 'N': {
11888       if (AArch64_AM::isLogicalImmediate(CVal, 64))
11889         break;
11890       if ((CVal & 0xFFFFULL) == CVal)
11891         break;
11892       if ((CVal & 0xFFFF0000ULL) == CVal)
11893         break;
11894       if ((CVal & 0xFFFF00000000ULL) == CVal)
11895         break;
11896       if ((CVal & 0xFFFF000000000000ULL) == CVal)
11897         break;
11898       uint64_t NCVal = ~CVal;
11899       if ((NCVal & 0xFFFFULL) == NCVal)
11900         break;
11901       if ((NCVal & 0xFFFF0000ULL) == NCVal)
11902         break;
11903       if ((NCVal & 0xFFFF00000000ULL) == NCVal)
11904         break;
11905       if ((NCVal & 0xFFFF000000000000ULL) == NCVal)
11906         break;
11907       return;
11908     }
11909     default:
11910       return;
11911     }
11912 
11913     // All assembler immediates are 64-bit integers.
11914     Result = DAG.getTargetConstant(CVal, SDLoc(Op), MVT::i64);
11915     break;
11916   }
11917 
11918   if (Result.getNode()) {
11919     Ops.push_back(Result);
11920     return;
11921   }
11922 
11923   return TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
11924 }
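// Illustrative values accepted by the immediate constraints validated above:
//   'I' : 0..4095, optionally shifted left by 12        e.g. 4095, 0xFFF000
//   'J' : the negation of an 'I' value                  e.g. -1, -4095
//   'K' : a 32-bit bitmask (logical) immediate          e.g. 0xAAAAAAAA
//   'L' : a 64-bit bitmask (logical) immediate          e.g. 0xAAAAAAAAAAAAAAAA
//   'M' : 'K' plus anything a single 32-bit MOVZ/MOVN can build, e.g. 0x12340000
//   'N' : 'L' plus anything a single 64-bit MOVZ/MOVN can build, e.g. 0x1234000000000000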
11925 
11926 //===----------------------------------------------------------------------===//
11927 //                     AArch64 Advanced SIMD Support
11928 //===----------------------------------------------------------------------===//
11929 
11930 /// WidenVector - Given a value in the V64 register class, produce the
11931 /// equivalent value in the V128 register class.
11932 static SDValue WidenVector(SDValue V64Reg, SelectionDAG &DAG) {
11933   EVT VT = V64Reg.getValueType();
11934   unsigned NarrowSize = VT.getVectorNumElements();
11935   MVT EltTy = VT.getVectorElementType().getSimpleVT();
11936   MVT WideTy = MVT::getVectorVT(EltTy, 2 * NarrowSize);
11937   SDLoc DL(V64Reg);
11938 
11939   return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, WideTy, DAG.getUNDEF(WideTy),
11940                      V64Reg, DAG.getConstant(0, DL, MVT::i64));
11941 }
11942 
11943 /// getExtFactor - Determine the adjustment factor for the position when
11944 /// generating an "extract from vector registers" instruction.
11945 static unsigned getExtFactor(SDValue &V) {
11946   EVT EltType = V.getValueType().getVectorElementType();
11947   return EltType.getSizeInBits() / 8;
11948 }
11949 
11950 // Check whether a vector is built out of elements of one vector, with the
11951 // lane indices taken from another vector (optionally ANDed with a constant
11952 // mask so all indices stay in range). This can be rebuilt using AND and NEON's TBL1.
11953 SDValue ReconstructShuffleWithRuntimeMask(SDValue Op, SelectionDAG &DAG) {
11954   assert(Op.getOpcode() == ISD::BUILD_VECTOR && "Unknown opcode!");
11955   SDLoc dl(Op);
11956   EVT VT = Op.getValueType();
11957   assert(!VT.isScalableVector() &&
11958          "Scalable vectors cannot be used with ISD::BUILD_VECTOR");
11959 
11960   // Can only recreate a shuffle with 16xi8 or 8xi8 elements, as they map
11961   // directly to TBL1.
11962   if (VT != MVT::v16i8 && VT != MVT::v8i8)
11963     return SDValue();
11964 
11965   unsigned NumElts = VT.getVectorNumElements();
11966   assert((NumElts == 8 || NumElts == 16) &&
11967          "Need to have exactly 8 or 16 elements in vector.");
11968 
11969   SDValue SourceVec;
11970   SDValue MaskSourceVec;
11971   SmallVector<SDValue, 16> AndMaskConstants;
11972 
11973   for (unsigned i = 0; i < NumElts; ++i) {
11974     SDValue V = Op.getOperand(i);
11975     if (V.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
11976       return SDValue();
11977 
11978     SDValue OperandSourceVec = V.getOperand(0);
11979     if (!SourceVec)
11980       SourceVec = OperandSourceVec;
11981     else if (SourceVec != OperandSourceVec)
11982       return SDValue();
11983 
11984     // This only looks at shuffles with elements that are
11985     // a) truncated by a constant AND mask extracted from a mask vector, or
11986     // b) extracted directly from a mask vector.
11987     SDValue MaskSource = V.getOperand(1);
11988     if (MaskSource.getOpcode() == ISD::AND) {
11989       if (!isa<ConstantSDNode>(MaskSource.getOperand(1)))
11990         return SDValue();
11991 
11992       AndMaskConstants.push_back(MaskSource.getOperand(1));
11993       MaskSource = MaskSource->getOperand(0);
11994     } else if (!AndMaskConstants.empty()) {
11995       // Either all or no operands should have an AND mask.
11996       return SDValue();
11997     }
11998 
11999     // An ANY_EXTEND may be inserted between the AND and the source vector
12000     // extraction. We don't care about that, so we can just skip it.
12001     if (MaskSource.getOpcode() == ISD::ANY_EXTEND)
12002       MaskSource = MaskSource.getOperand(0);
12003 
12004     if (MaskSource.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
12005       return SDValue();
12006 
12007     SDValue MaskIdx = MaskSource.getOperand(1);
12008     if (!isa<ConstantSDNode>(MaskIdx) ||
12009         !cast<ConstantSDNode>(MaskIdx)->getConstantIntValue()->equalsInt(i))
12010       return SDValue();
12011 
12012     // We only apply this if all elements come from the same vector with the
12013     // same vector type.
12014     if (!MaskSourceVec) {
12015       MaskSourceVec = MaskSource->getOperand(0);
12016       if (MaskSourceVec.getValueType() != VT)
12017         return SDValue();
12018     } else if (MaskSourceVec != MaskSource->getOperand(0)) {
12019       return SDValue();
12020     }
12021   }
12022 
12023   // We need a v16i8 for TBL, so we extend the source with a placeholder vector
12024   // for v8i8 to get a v16i8. As the pattern we are replacing is extract +
12025   // insert, we know that the index in the mask must be smaller than the number
12026   // of elements in the source, or we would have an out-of-bounds access.
12027   if (NumElts == 8)
12028     SourceVec = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v16i8, SourceVec,
12029                             DAG.getUNDEF(VT));
12030 
12031   // Preconditions met, so we can use a vector (AND +) TBL to build this vector.
12032   if (!AndMaskConstants.empty())
12033     MaskSourceVec = DAG.getNode(ISD::AND, dl, VT, MaskSourceVec,
12034                                 DAG.getBuildVector(VT, dl, AndMaskConstants));
12035 
12036   return DAG.getNode(
12037       ISD::INTRINSIC_WO_CHAIN, dl, VT,
12038       DAG.getConstant(Intrinsic::aarch64_neon_tbl1, dl, MVT::i32), SourceVec,
12039       MaskSourceVec);
12040 }
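// Illustrative DAG shape (assumed input, not from this file) matched by the
// routine above: every lane of the BUILD_VECTOR extracts from the same
// source, with the lane index read out of a mask vector (optionally ANDed
// with a constant), so the whole thing collapses to AND + tbl1:
//
//   t[i]   = extractelement(Mask, i)          // one per output lane i
//   idx[i] = and(t[i], C[i])                  // optional, constant C[i]
//   r[i]   = extractelement(Src, idx[i])
//   r      = build_vector(r[0], ..., r[N-1])
//   ; => r = aarch64.neon.tbl1(Src, and(Mask, build_vector(C[0..N-1])))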
12041 
12042 // Gather data to see if the operation can be modelled as a
12043 // shuffle in combination with VEXTs.
12044 SDValue AArch64TargetLowering::ReconstructShuffle(SDValue Op,
12045                                                   SelectionDAG &DAG) const {
12046   assert(Op.getOpcode() == ISD::BUILD_VECTOR && "Unknown opcode!");
12047   LLVM_DEBUG(dbgs() << "AArch64TargetLowering::ReconstructShuffle\n");
12048   SDLoc dl(Op);
12049   EVT VT = Op.getValueType();
12050   assert(!VT.isScalableVector() &&
12051          "Scalable vectors cannot be used with ISD::BUILD_VECTOR");
12052   unsigned NumElts = VT.getVectorNumElements();
12053 
12054   struct ShuffleSourceInfo {
12055     SDValue Vec;
12056     unsigned MinElt;
12057     unsigned MaxElt;
12058 
12059     // We may insert some combination of BITCASTs and VEXT nodes to force Vec to
12060     // be compatible with the shuffle we intend to construct. As a result
12061     // ShuffleVec will be some sliding window into the original Vec.
12062     SDValue ShuffleVec;
12063 
12064     // Code should guarantee that element i in Vec starts at element "WindowBase
12065     // Code should guarantee that element i in Vec starts at element
12066     // "WindowBase + i * WindowScale" in ShuffleVec.
12067     int WindowScale;
12068 
12069     ShuffleSourceInfo(SDValue Vec)
12070       : Vec(Vec), MinElt(std::numeric_limits<unsigned>::max()), MaxElt(0),
12071           ShuffleVec(Vec), WindowBase(0), WindowScale(1) {}
12072 
12073     bool operator ==(SDValue OtherVec) { return Vec == OtherVec; }
12074   };
12075 
12076   // First gather all vectors used as an immediate source for this BUILD_VECTOR
12077   // node.
12078   SmallVector<ShuffleSourceInfo, 2> Sources;
12079   for (unsigned i = 0; i < NumElts; ++i) {
12080     SDValue V = Op.getOperand(i);
12081     if (V.isUndef())
12082       continue;
12083     else if (V.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
12084              !isa<ConstantSDNode>(V.getOperand(1)) ||
12085              V.getOperand(0).getValueType().isScalableVector()) {
12086       LLVM_DEBUG(
12087           dbgs() << "Reshuffle failed: "
12088                     "a shuffle can only come from building a vector from "
12089                     "various elements of other fixed-width vectors, provided "
12090                     "their indices are constant\n");
12091       return SDValue();
12092     }
12093 
12094     // Add this element source to the list if it's not already there.
12095     SDValue SourceVec = V.getOperand(0);
12096     auto Source = find(Sources, SourceVec);
12097     if (Source == Sources.end())
12098       Source = Sources.insert(Sources.end(), ShuffleSourceInfo(SourceVec));
12099 
12100     // Update the minimum and maximum lane number seen.
12101     unsigned EltNo = V.getConstantOperandVal(1);
12102     Source->MinElt = std::min(Source->MinElt, EltNo);
12103     Source->MaxElt = std::max(Source->MaxElt, EltNo);
12104   }
12105 
12106   // If we have 3 or 4 sources, try to generate a TBL, which will at least be
12107   // better than moving to/from gpr registers for larger vectors.
12108   if ((Sources.size() == 3 || Sources.size() == 4) && NumElts > 4) {
12109     // Construct a mask for the tbl. We may need to adjust the index for types
12110     // larger than i8.
12111     SmallVector<unsigned, 16> Mask;
12112     unsigned OutputFactor = VT.getScalarSizeInBits() / 8;
12113     for (unsigned I = 0; I < NumElts; ++I) {
12114       SDValue V = Op.getOperand(I);
12115       if (V.isUndef()) {
12116         for (unsigned OF = 0; OF < OutputFactor; OF++)
12117           Mask.push_back(-1);
12118         continue;
12119       }
12120       // Set the Mask lanes adjusted for the size of the input and output
12121       // lanes. The Mask is always i8, so it will set OutputFactor lanes per
12122       // output element, adjusted in their positions per input and output types.
12123       unsigned Lane = V.getConstantOperandVal(1);
12124       for (unsigned S = 0; S < Sources.size(); S++) {
12125         if (V.getOperand(0) == Sources[S].Vec) {
12126           unsigned InputSize = Sources[S].Vec.getScalarValueSizeInBits();
12127           unsigned InputBase = 16 * S + Lane * InputSize / 8;
12128           for (unsigned OF = 0; OF < OutputFactor; OF++)
12129             Mask.push_back(InputBase + OF);
12130           break;
12131         }
12132       }
12133     }
12134 
12135     // Construct the tbl3/tbl4 out of an intrinsic, the sources converted to
12136     // v16i8, and the TBLMask
12137     SmallVector<SDValue, 16> TBLOperands;
12138     TBLOperands.push_back(DAG.getConstant(Sources.size() == 3
12139                                               ? Intrinsic::aarch64_neon_tbl3
12140                                               : Intrinsic::aarch64_neon_tbl4,
12141                                           dl, MVT::i32));
12142     for (unsigned i = 0; i < Sources.size(); i++) {
12143       SDValue Src = Sources[i].Vec;
12144       EVT SrcVT = Src.getValueType();
12145       Src = DAG.getBitcast(SrcVT.is64BitVector() ? MVT::v8i8 : MVT::v16i8, Src);
12146       assert((SrcVT.is64BitVector() || SrcVT.is128BitVector()) &&
12147              "Expected a legally typed vector");
12148       if (SrcVT.is64BitVector())
12149         Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v16i8, Src,
12150                           DAG.getUNDEF(MVT::v8i8));
12151       TBLOperands.push_back(Src);
12152     }
12153 
12154     SmallVector<SDValue, 16> TBLMask;
12155     for (unsigned i = 0; i < Mask.size(); i++)
12156       TBLMask.push_back(DAG.getConstant(Mask[i], dl, MVT::i32));
12157     assert((Mask.size() == 8 || Mask.size() == 16) &&
12158            "Expected a v8i8 or v16i8 Mask");
12159     TBLOperands.push_back(
12160         DAG.getBuildVector(Mask.size() == 8 ? MVT::v8i8 : MVT::v16i8, dl, TBLMask));
12161 
12162     SDValue Shuffle =
12163         DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl,
12164                     Mask.size() == 8 ? MVT::v8i8 : MVT::v16i8, TBLOperands);
12165     return DAG.getBitcast(VT, Shuffle);
12166   }
12167 
12168   if (Sources.size() > 2) {
12169     LLVM_DEBUG(dbgs() << "Reshuffle failed: currently only do something "
12170                       << "sensible when at most two source vectors are "
12171                       << "involved\n");
12172     return SDValue();
12173   }
12174 
12175   // Find out the smallest element size among result and two sources, and use
12176   // it as element size to build the shuffle_vector.
12177   EVT SmallestEltTy = VT.getVectorElementType();
12178   for (auto &Source : Sources) {
12179     EVT SrcEltTy = Source.Vec.getValueType().getVectorElementType();
12180     if (SrcEltTy.bitsLT(SmallestEltTy)) {
12181       SmallestEltTy = SrcEltTy;
12182     }
12183   }
12184   unsigned ResMultiplier =
12185       VT.getScalarSizeInBits() / SmallestEltTy.getFixedSizeInBits();
12186   uint64_t VTSize = VT.getFixedSizeInBits();
12187   NumElts = VTSize / SmallestEltTy.getFixedSizeInBits();
12188   EVT ShuffleVT = EVT::getVectorVT(*DAG.getContext(), SmallestEltTy, NumElts);
12189 
12190   // If the source vector is too wide or too narrow, we may nevertheless be able
12191   // to construct a compatible shuffle either by concatenating it with UNDEF or
12192   // extracting a suitable range of elements.
12193   for (auto &Src : Sources) {
12194     EVT SrcVT = Src.ShuffleVec.getValueType();
12195 
12196     TypeSize SrcVTSize = SrcVT.getSizeInBits();
12197     if (SrcVTSize == TypeSize::getFixed(VTSize))
12198       continue;
12199 
12200     // This stage of the search produces a source with the same element type as
12201     // the original, but with a total width matching the BUILD_VECTOR output.
12202     EVT EltVT = SrcVT.getVectorElementType();
12203     unsigned NumSrcElts = VTSize / EltVT.getFixedSizeInBits();
12204     EVT DestVT = EVT::getVectorVT(*DAG.getContext(), EltVT, NumSrcElts);
12205 
12206     if (SrcVTSize.getFixedValue() < VTSize) {
12207       assert(2 * SrcVTSize == VTSize);
12208       // We can pad out the smaller vector for free, so if it's part of a
12209       // shuffle...
12210       Src.ShuffleVec =
12211           DAG.getNode(ISD::CONCAT_VECTORS, dl, DestVT, Src.ShuffleVec,
12212                       DAG.getUNDEF(Src.ShuffleVec.getValueType()));
12213       continue;
12214     }
12215 
12216     if (SrcVTSize.getFixedValue() != 2 * VTSize) {
12217       LLVM_DEBUG(
12218           dbgs() << "Reshuffle failed: result vector too small to extract\n");
12219       return SDValue();
12220     }
12221 
12222     if (Src.MaxElt - Src.MinElt >= NumSrcElts) {
12223       LLVM_DEBUG(
12224           dbgs() << "Reshuffle failed: span too large for a VEXT to cope\n");
12225       return SDValue();
12226     }
12227 
12228     if (Src.MinElt >= NumSrcElts) {
12229       // The extraction can just take the second half
12230       Src.ShuffleVec =
12231           DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, DestVT, Src.ShuffleVec,
12232                       DAG.getConstant(NumSrcElts, dl, MVT::i64));
12233       Src.WindowBase = -NumSrcElts;
12234     } else if (Src.MaxElt < NumSrcElts) {
12235       // The extraction can just take the first half
12236       Src.ShuffleVec =
12237           DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, DestVT, Src.ShuffleVec,
12238                       DAG.getConstant(0, dl, MVT::i64));
12239     } else {
12240       // An actual VEXT is needed
12241       SDValue VEXTSrc1 =
12242           DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, DestVT, Src.ShuffleVec,
12243                       DAG.getConstant(0, dl, MVT::i64));
12244       SDValue VEXTSrc2 =
12245           DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, DestVT, Src.ShuffleVec,
12246                       DAG.getConstant(NumSrcElts, dl, MVT::i64));
12247       unsigned Imm = Src.MinElt * getExtFactor(VEXTSrc1);
12248 
12249       if (!SrcVT.is64BitVector()) {
12250         LLVM_DEBUG(
12251           dbgs() << "Reshuffle failed: don't know how to lower AArch64ISD::EXT "
12252                     "for SVE vectors.");
12253         return SDValue();
12254       }
12255 
12256       Src.ShuffleVec = DAG.getNode(AArch64ISD::EXT, dl, DestVT, VEXTSrc1,
12257                                    VEXTSrc2,
12258                                    DAG.getConstant(Imm, dl, MVT::i32));
12259       Src.WindowBase = -Src.MinElt;
12260     }
12261   }
12262 
12263   // Another possible incompatibility occurs from the vector element types. We
12264   // can fix this by bitcasting the source vectors to the same type we intend
12265   // for the shuffle.
12266   for (auto &Src : Sources) {
12267     EVT SrcEltTy = Src.ShuffleVec.getValueType().getVectorElementType();
12268     if (SrcEltTy == SmallestEltTy)
12269       continue;
12270     assert(ShuffleVT.getVectorElementType() == SmallestEltTy);
12271     if (DAG.getDataLayout().isBigEndian()) {
12272       Src.ShuffleVec =
12273           DAG.getNode(AArch64ISD::NVCAST, dl, ShuffleVT, Src.ShuffleVec);
12274     } else {
12275       Src.ShuffleVec = DAG.getNode(ISD::BITCAST, dl, ShuffleVT, Src.ShuffleVec);
12276     }
12277     Src.WindowScale =
12278         SrcEltTy.getFixedSizeInBits() / SmallestEltTy.getFixedSizeInBits();
12279     Src.WindowBase *= Src.WindowScale;
12280   }
12281 
12282   // Final check before we try to actually produce a shuffle.
12283   LLVM_DEBUG(for (auto Src
12284                   : Sources)
12285                  assert(Src.ShuffleVec.getValueType() == ShuffleVT););
12286 
12287   // The stars all align, our next step is to produce the mask for the shuffle.
12288   SmallVector<int, 8> Mask(ShuffleVT.getVectorNumElements(), -1);
12289   int BitsPerShuffleLane = ShuffleVT.getScalarSizeInBits();
12290   for (unsigned i = 0; i < VT.getVectorNumElements(); ++i) {
12291     SDValue Entry = Op.getOperand(i);
12292     if (Entry.isUndef())
12293       continue;
12294 
12295     auto Src = find(Sources, Entry.getOperand(0));
12296     int EltNo = cast<ConstantSDNode>(Entry.getOperand(1))->getSExtValue();
12297 
12298     // EXTRACT_VECTOR_ELT performs an implicit any_ext; BUILD_VECTOR an implicit
12299     // trunc. So only std::min(SrcBits, DestBits) actually get defined in this
12300     // segment.
12301     EVT OrigEltTy = Entry.getOperand(0).getValueType().getVectorElementType();
12302     int BitsDefined = std::min(OrigEltTy.getScalarSizeInBits(),
12303                                VT.getScalarSizeInBits());
12304     int LanesDefined = BitsDefined / BitsPerShuffleLane;
12305 
12306     // This source is expected to fill ResMultiplier lanes of the final shuffle,
12307     // starting at the appropriate offset.
12308     int *LaneMask = &Mask[i * ResMultiplier];
12309 
12310     int ExtractBase = EltNo * Src->WindowScale + Src->WindowBase;
12311     ExtractBase += NumElts * (Src - Sources.begin());
12312     for (int j = 0; j < LanesDefined; ++j)
12313       LaneMask[j] = ExtractBase + j;
12314   }
12315 
12316   // Final check before we try to produce nonsense...
12317   if (!isShuffleMaskLegal(Mask, ShuffleVT)) {
12318     LLVM_DEBUG(dbgs() << "Reshuffle failed: illegal shuffle mask\n");
12319     return SDValue();
12320   }
12321 
12322   SDValue ShuffleOps[] = { DAG.getUNDEF(ShuffleVT), DAG.getUNDEF(ShuffleVT) };
12323   for (unsigned i = 0; i < Sources.size(); ++i)
12324     ShuffleOps[i] = Sources[i].ShuffleVec;
12325 
12326   SDValue Shuffle = DAG.getVectorShuffle(ShuffleVT, dl, ShuffleOps[0],
12327                                          ShuffleOps[1], Mask);
12328   SDValue V;
12329   if (DAG.getDataLayout().isBigEndian()) {
12330     V = DAG.getNode(AArch64ISD::NVCAST, dl, VT, Shuffle);
12331   } else {
12332     V = DAG.getNode(ISD::BITCAST, dl, VT, Shuffle);
12333   }
12334 
12335   LLVM_DEBUG(dbgs() << "Reshuffle, creating node: "; Shuffle.dump();
12336              dbgs() << "Reshuffle, creating node: "; V.dump(););
12337 
12338   return V;
12339 }
12340 
12341 // check if an EXT instruction can handle the shuffle mask when the
12342 // vector sources of the shuffle are the same.
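// For example, for v8i8 with both shuffle sources equal, the mask
// <5, 6, 7, 0, 1, 2, 3, 4> is a rotation that a single EXT with Imm == 5 can
// produce; undef entries are skipped while walking the mask.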
12343 static bool isSingletonEXTMask(ArrayRef<int> M, EVT VT, unsigned &Imm) {
12344   unsigned NumElts = VT.getVectorNumElements();
12345 
12346   // Assume that the first shuffle index is not UNDEF.  Fail if it is.
12347   if (M[0] < 0)
12348     return false;
12349 
12350   Imm = M[0];
12351 
12352   // If this is a VEXT shuffle, the immediate value is the index of the first
12353   // element.  The other shuffle indices must be the successive elements after
12354   // the first one.
12355   unsigned ExpectedElt = Imm;
12356   for (unsigned i = 1; i < NumElts; ++i) {
12357     // Increment the expected index.  If it wraps around, just follow it
12358     // back to index zero and keep going.
12359     ++ExpectedElt;
12360     if (ExpectedElt == NumElts)
12361       ExpectedElt = 0;
12362 
12363     if (M[i] < 0)
12364       continue; // ignore UNDEF indices
12365     if (ExpectedElt != static_cast<unsigned>(M[i]))
12366       return false;
12367   }
12368 
12369   return true;
12370 }
12371 
12372 // Detect patterns of a0,a1,a2,a3,b0,b1,b2,b3,c0,c1,c2,c3,d0,d1,d2,d3 from
12373 // v4i32s. This is really a truncate, which we can construct out of (legal)
12374 // concats and truncate nodes.
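// For example, a v16i8 build_vector of lanes 0..3 of four v4i32 values A, B, C
// and D is rebuilt as
//   concat(trunc(concat(trunc A, trunc B)), trunc(concat(trunc C, trunc D)))
// i.e. each v4i32 is truncated to v4i16, pairs are concatenated to v8i16,
// truncated again to v8i8 and finally concatenated back to v16i8.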
12375 static SDValue ReconstructTruncateFromBuildVector(SDValue V, SelectionDAG &DAG) {
12376   if (V.getValueType() != MVT::v16i8)
12377     return SDValue();
12378   assert(V.getNumOperands() == 16 && "Expected 16 operands on the BUILDVECTOR");
12379 
12380   for (unsigned X = 0; X < 4; X++) {
12381     // Check the first item in each group is an extract from lane 0 of a v4i32
12382     // or v4i16.
12383     SDValue BaseExt = V.getOperand(X * 4);
12384     if (BaseExt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
12385         (BaseExt.getOperand(0).getValueType() != MVT::v4i16 &&
12386          BaseExt.getOperand(0).getValueType() != MVT::v4i32) ||
12387         !isa<ConstantSDNode>(BaseExt.getOperand(1)) ||
12388         BaseExt.getConstantOperandVal(1) != 0)
12389       return SDValue();
12390     SDValue Base = BaseExt.getOperand(0);
12391     // And check the other items are extracts from the same vector.
12392     for (unsigned Y = 1; Y < 4; Y++) {
12393       SDValue Ext = V.getOperand(X * 4 + Y);
12394       if (Ext.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
12395           Ext.getOperand(0) != Base ||
12396           !isa<ConstantSDNode>(Ext.getOperand(1)) ||
12397           Ext.getConstantOperandVal(1) != Y)
12398         return SDValue();
12399     }
12400   }
12401 
12402   // Turn the buildvector into a series of truncates and concats, which will
12403   // become uzp1s. Any v4i32s we found get truncated to v4i16, and pairs are
12404   // concatenated to produce two v8i16s. These are both truncated and
12405   // concatenated together.
12406   SDLoc DL(V);
12407   SDValue Trunc[4] = {
12408       V.getOperand(0).getOperand(0), V.getOperand(4).getOperand(0),
12409       V.getOperand(8).getOperand(0), V.getOperand(12).getOperand(0)};
12410   for (SDValue &V : Trunc)
12411     if (V.getValueType() == MVT::v4i32)
12412       V = DAG.getNode(ISD::TRUNCATE, DL, MVT::v4i16, V);
12413   SDValue Concat0 =
12414       DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v8i16, Trunc[0], Trunc[1]);
12415   SDValue Concat1 =
12416       DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v8i16, Trunc[2], Trunc[3]);
12417   SDValue Trunc0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::v8i8, Concat0);
12418   SDValue Trunc1 = DAG.getNode(ISD::TRUNCATE, DL, MVT::v8i8, Concat1);
12419   return DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, Trunc0, Trunc1);
12420 }
12421 
12422 /// Check if a vector shuffle corresponds to a DUP instruction with a larger
12423 /// element width than the vector lane type. If that is the case, the function
12424 /// returns true and writes the value of the DUP instruction lane operand into
12425 /// DupLaneOp.
12426 static bool isWideDUPMask(ArrayRef<int> M, EVT VT, unsigned BlockSize,
12427                           unsigned &DupLaneOp) {
12428   assert((BlockSize == 16 || BlockSize == 32 || BlockSize == 64) &&
12429          "Only possible block sizes for wide DUP are: 16, 32, 64");
12430 
12431   if (BlockSize <= VT.getScalarSizeInBits())
12432     return false;
12433   if (BlockSize % VT.getScalarSizeInBits() != 0)
12434     return false;
12435   if (VT.getSizeInBits() % BlockSize != 0)
12436     return false;
12437 
12438   size_t SingleVecNumElements = VT.getVectorNumElements();
12439   size_t NumEltsPerBlock = BlockSize / VT.getScalarSizeInBits();
12440   size_t NumBlocks = VT.getSizeInBits() / BlockSize;
12441 
12442   // We are looking for masks like
12443   // [0, 1, 0, 1] or [2, 3, 2, 3] or [4, 5, 6, 7, 4, 5, 6, 7] where any element
12444   // might be replaced by 'undefined'. BlockElts will eventually contain the
12445   // lane indices of the duplicated block (i.e. [0, 1], [2, 3] and [4, 5, 6, 7]
12446   // for the above examples).
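  // For example, with VT == v8i16 and BlockSize == 32, the mask
  // <2, 3, 2, 3, 2, 3, 2, 3> gives BlockElts == [2, 3], so Elt0 == 2 and
  // DupLaneOp == Elt0 / NumEltsPerBlock == 1, i.e. duplicate 32-bit lane 1.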
12447   SmallVector<int, 8> BlockElts(NumEltsPerBlock, -1);
12448   for (size_t BlockIndex = 0; BlockIndex < NumBlocks; BlockIndex++)
12449     for (size_t I = 0; I < NumEltsPerBlock; I++) {
12450       int Elt = M[BlockIndex * NumEltsPerBlock + I];
12451       if (Elt < 0)
12452         continue;
12453       // For now we don't support shuffles that use the second operand
12454       if ((unsigned)Elt >= SingleVecNumElements)
12455         return false;
12456       if (BlockElts[I] < 0)
12457         BlockElts[I] = Elt;
12458       else if (BlockElts[I] != Elt)
12459         return false;
12460     }
12461 
12462   // We found a candidate block (possibly with some undefs). It must be a
12463   // sequence of consecutive integers starting with a value divisible by
12464   // NumEltsPerBlock, with some values possibly replaced by undefs.
12465 
12466   // Find first non-undef element
12467   auto FirstRealEltIter = find_if(BlockElts, [](int Elt) { return Elt >= 0; });
12468   assert(FirstRealEltIter != BlockElts.end() &&
12469          "Shuffle with all-undefs must have been caught by previous cases, "
12470          "e.g. isSplat()");
12471   if (FirstRealEltIter == BlockElts.end()) {
12472     DupLaneOp = 0;
12473     return true;
12474   }
12475 
12476   // Index of FirstRealElt in BlockElts
12477   size_t FirstRealIndex = FirstRealEltIter - BlockElts.begin();
12478 
12479   if ((unsigned)*FirstRealEltIter < FirstRealIndex)
12480     return false;
12481   // BlockElts[0] must have the following value if it isn't undef:
12482   size_t Elt0 = *FirstRealEltIter - FirstRealIndex;
12483 
12484   // Check the first element
12485   if (Elt0 % NumEltsPerBlock != 0)
12486     return false;
12487   // Check that the sequence indeed consists of consecutive integers (modulo
12488   // undefs)
12489   for (size_t I = 0; I < NumEltsPerBlock; I++)
12490     if (BlockElts[I] >= 0 && (unsigned)BlockElts[I] != Elt0 + I)
12491       return false;
12492 
12493   DupLaneOp = Elt0 / NumEltsPerBlock;
12494   return true;
12495 }
12496 
12497 // check if an EXT instruction can handle the shuffle mask when the
12498 // vector sources of the shuffle are different.
12499 static bool isEXTMask(ArrayRef<int> M, EVT VT, bool &ReverseEXT,
12500                       unsigned &Imm) {
12501   // Look for the first non-undef element.
12502   const int *FirstRealElt = find_if(M, [](int Elt) { return Elt >= 0; });
12503 
12504   // Benefit from APInt to handle overflow when calculating the expected element.
12505   unsigned NumElts = VT.getVectorNumElements();
12506   unsigned MaskBits = APInt(32, NumElts * 2).logBase2();
12507   APInt ExpectedElt = APInt(MaskBits, *FirstRealElt + 1);
12508   // The following shuffle indices must be the successive elements after the
12509   // first real element.
12510   bool FoundWrongElt = std::any_of(FirstRealElt + 1, M.end(), [&](int Elt) {
12511     return Elt != ExpectedElt++ && Elt != -1;
12512   });
12513   if (FoundWrongElt)
12514     return false;
12515 
12516   // The index of an EXT is the first element if it is not UNDEF.
12517   // Watch out for the beginning UNDEFs. The EXT index should be the expected
12518   // value of the first element.  E.g.
12519   // <-1, -1, 3, ...> is treated as <1, 2, 3, ...>.
12520   // <-1, -1, 0, 1, ...> is treated as <2*NumElts-2, 2*NumElts-1, 0, 1, ...>.
12521   // ExpectedElt is the last mask index plus 1.
12522   Imm = ExpectedElt.getZExtValue();
12523 
12524   // There are two different cases that require reversing the input vectors.
12525   // For example, for vector <4 x i32> we have the following cases,
12526   // Case 1: shufflevector(<4 x i32>,<4 x i32>,<-1, -1, -1, 0>)
12527   // Case 2: shufflevector(<4 x i32>,<4 x i32>,<-1, -1, 7, 0>)
12528   // For both cases, we finally use mask <5, 6, 7, 0>, which requires
12529   // to reverse two input vectors.
12530   if (Imm < NumElts)
12531     ReverseEXT = true;
12532   else
12533     Imm -= NumElts;
12534 
12535   return true;
12536 }
12537 
12538 /// isZIP_v_undef_Mask - Special case of isZIPMask for canonical form of
12539 /// "vector_shuffle v, v", i.e., "vector_shuffle v, undef".
12540 /// Mask is e.g., <0, 0, 1, 1> instead of <0, 4, 1, 5>.
12541 static bool isZIP_v_undef_Mask(ArrayRef<int> M, EVT VT, unsigned &WhichResult) {
12542   unsigned NumElts = VT.getVectorNumElements();
12543   if (NumElts % 2 != 0)
12544     return false;
12545   WhichResult = (M[0] == 0 ? 0 : 1);
12546   unsigned Idx = WhichResult * NumElts / 2;
12547   for (unsigned i = 0; i != NumElts; i += 2) {
12548     if ((M[i] >= 0 && (unsigned)M[i] != Idx) ||
12549         (M[i + 1] >= 0 && (unsigned)M[i + 1] != Idx))
12550       return false;
12551     Idx += 1;
12552   }
12553 
12554   return true;
12555 }
12556 
12557 /// isUZP_v_undef_Mask - Special case of isUZPMask for canonical form of
12558 /// "vector_shuffle v, v", i.e., "vector_shuffle v, undef".
12559 /// Mask is e.g., <0, 2, 0, 2> instead of <0, 2, 4, 6>,
12560 static bool isUZP_v_undef_Mask(ArrayRef<int> M, EVT VT, unsigned &WhichResult) {
12561   unsigned Half = VT.getVectorNumElements() / 2;
12562   WhichResult = (M[0] == 0 ? 0 : 1);
12563   for (unsigned j = 0; j != 2; ++j) {
12564     unsigned Idx = WhichResult;
12565     for (unsigned i = 0; i != Half; ++i) {
12566       int MIdx = M[i + j * Half];
12567       if (MIdx >= 0 && (unsigned)MIdx != Idx)
12568         return false;
12569       Idx += 2;
12570     }
12571   }
12572 
12573   return true;
12574 }
12575 
12576 /// isTRN_v_undef_Mask - Special case of isTRNMask for canonical form of
12577 /// "vector_shuffle v, v", i.e., "vector_shuffle v, undef".
12578 /// Mask is e.g., <0, 0, 2, 2> instead of <0, 4, 2, 6>.
12579 static bool isTRN_v_undef_Mask(ArrayRef<int> M, EVT VT, unsigned &WhichResult) {
12580   unsigned NumElts = VT.getVectorNumElements();
12581   if (NumElts % 2 != 0)
12582     return false;
12583   WhichResult = (M[0] == 0 ? 0 : 1);
12584   for (unsigned i = 0; i < NumElts; i += 2) {
12585     if ((M[i] >= 0 && (unsigned)M[i] != i + WhichResult) ||
12586         (M[i + 1] >= 0 && (unsigned)M[i + 1] != i + WhichResult))
12587       return false;
12588   }
12589   return true;
12590 }
12591 
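// Detect a shuffle that is an identity of one input except for a single
// anomalous lane. For example <0, 1, 6, 3> for two v4 inputs deviates from the
// identity of the first input only in lane 2, so the shuffle can be lowered to
// an INS of that one lane.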
12592 static bool isINSMask(ArrayRef<int> M, int NumInputElements,
12593                       bool &DstIsLeft, int &Anomaly) {
12594   if (M.size() != static_cast<size_t>(NumInputElements))
12595     return false;
12596 
12597   int NumLHSMatch = 0, NumRHSMatch = 0;
12598   int LastLHSMismatch = -1, LastRHSMismatch = -1;
12599 
12600   for (int i = 0; i < NumInputElements; ++i) {
12601     if (M[i] == -1) {
12602       ++NumLHSMatch;
12603       ++NumRHSMatch;
12604       continue;
12605     }
12606 
12607     if (M[i] == i)
12608       ++NumLHSMatch;
12609     else
12610       LastLHSMismatch = i;
12611 
12612     if (M[i] == i + NumInputElements)
12613       ++NumRHSMatch;
12614     else
12615       LastRHSMismatch = i;
12616   }
12617 
12618   if (NumLHSMatch == NumInputElements - 1) {
12619     DstIsLeft = true;
12620     Anomaly = LastLHSMismatch;
12621     return true;
12622   } else if (NumRHSMatch == NumInputElements - 1) {
12623     DstIsLeft = false;
12624     Anomaly = LastRHSMismatch;
12625     return true;
12626   }
12627 
12628   return false;
12629 }
12630 
12631 static bool isConcatMask(ArrayRef<int> Mask, EVT VT, bool SplitLHS) {
12632   if (VT.getSizeInBits() != 128)
12633     return false;
12634 
12635   unsigned NumElts = VT.getVectorNumElements();
12636 
12637   for (int I = 0, E = NumElts / 2; I != E; I++) {
12638     if (Mask[I] != I)
12639       return false;
12640   }
12641 
12642   int Offset = NumElts / 2;
12643   for (int I = NumElts / 2, E = NumElts; I != E; I++) {
12644     if (Mask[I] != I + SplitLHS * Offset)
12645       return false;
12646   }
12647 
12648   return true;
12649 }
12650 
12651 static SDValue tryFormConcatFromShuffle(SDValue Op, SelectionDAG &DAG) {
12652   SDLoc DL(Op);
12653   EVT VT = Op.getValueType();
12654   SDValue V0 = Op.getOperand(0);
12655   SDValue V1 = Op.getOperand(1);
12656   ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(Op)->getMask();
12657 
12658   if (VT.getVectorElementType() != V0.getValueType().getVectorElementType() ||
12659       VT.getVectorElementType() != V1.getValueType().getVectorElementType())
12660     return SDValue();
12661 
12662   bool SplitV0 = V0.getValueSizeInBits() == 128;
12663 
12664   if (!isConcatMask(Mask, VT, SplitV0))
12665     return SDValue();
12666 
12667   EVT CastVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
12668   if (SplitV0) {
12669     V0 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, CastVT, V0,
12670                      DAG.getConstant(0, DL, MVT::i64));
12671   }
12672   if (V1.getValueSizeInBits() == 128) {
12673     V1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, CastVT, V1,
12674                      DAG.getConstant(0, DL, MVT::i64));
12675   }
12676   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, V0, V1);
12677 }
12678 
12679 /// GeneratePerfectShuffle - Given an entry in the perfect-shuffle table, emit
12680 /// the specified operations to build the shuffle. ID is the perfect-shuffle
12681 /// ID, V1 and V2 are the original shuffle inputs. PFEntry is the perfect-shuffle
12682 /// table entry and LHS/RHS are the immediate inputs for this stage of the
12683 /// shuffle.
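/// Each 13-bit ID encodes the four result lanes in base 9, where the digits 0-7
/// select a lane of the two original inputs and 8 means undef. For example the
/// identity of the first input, <0, 1, 2, 3>, encodes to
/// ((0 * 9 + 1) * 9 + 2) * 9 + 3 == 102, which is the LHSID tested in the
/// OP_COPY case below; the operation number lives in bits [29:26] of the entry.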
12684 static SDValue GeneratePerfectShuffle(unsigned ID, SDValue V1,
12685                                       SDValue V2, unsigned PFEntry, SDValue LHS,
12686                                       SDValue RHS, SelectionDAG &DAG,
12687                                       const SDLoc &dl) {
12688   unsigned OpNum = (PFEntry >> 26) & 0x0F;
12689   unsigned LHSID = (PFEntry >> 13) & ((1 << 13) - 1);
12690   unsigned RHSID = (PFEntry >> 0) & ((1 << 13) - 1);
12691 
12692   enum {
12693     OP_COPY = 0, // Copy, used for things like <u,u,u,3> to say it is <0,1,2,3>
12694     OP_VREV,
12695     OP_VDUP0,
12696     OP_VDUP1,
12697     OP_VDUP2,
12698     OP_VDUP3,
12699     OP_VEXT1,
12700     OP_VEXT2,
12701     OP_VEXT3,
12702     OP_VUZPL,  // VUZP, left result
12703     OP_VUZPR,  // VUZP, right result
12704     OP_VZIPL,  // VZIP, left result
12705     OP_VZIPR,  // VZIP, right result
12706     OP_VTRNL,  // VTRN, left result
12707     OP_VTRNR,  // VTRN, right result
12708     OP_MOVLANE // Move lane. RHSID is the lane to move into
12709   };
12710 
12711   if (OpNum == OP_COPY) {
12712     if (LHSID == (1 * 9 + 2) * 9 + 3)
12713       return LHS;
12714     assert(LHSID == ((4 * 9 + 5) * 9 + 6) * 9 + 7 && "Illegal OP_COPY!");
12715     return RHS;
12716   }
12717 
12718   if (OpNum == OP_MOVLANE) {
12719     // Decompose a PerfectShuffle ID to get the Mask for lane Elt
12720     auto getPFIDLane = [](unsigned ID, int Elt) -> int {
12721       assert(Elt < 4 && "Expected Perfect Lanes to be less than 4");
12722       Elt = 3 - Elt;
12723       while (Elt > 0) {
12724         ID /= 9;
12725         Elt--;
12726       }
12727       return (ID % 9 == 8) ? -1 : ID % 9;
12728     };
12729 
12730     // For OP_MOVLANE shuffles, the RHSID represents the lane to move into. We
12731     // get the lane to move from the PFID, which is always from the
12732     // original vectors (V1 or V2).
12733     SDValue OpLHS = GeneratePerfectShuffle(
12734         LHSID, V1, V2, PerfectShuffleTable[LHSID], LHS, RHS, DAG, dl);
12735     EVT VT = OpLHS.getValueType();
12736     assert(RHSID < 8 && "Expected a lane index for RHSID!");
12737     unsigned ExtLane = 0;
12738     SDValue Input;
12739 
12740     // OP_MOVLANE are either D movs (if bit 0x4 is set) or S movs. D movs
12741     // convert into a higher type.
12742     if (RHSID & 0x4) {
12743       int MaskElt = getPFIDLane(ID, (RHSID & 0x01) << 1) >> 1;
12744       if (MaskElt == -1)
12745         MaskElt = (getPFIDLane(ID, ((RHSID & 0x01) << 1) + 1) - 1) >> 1;
12746       assert(MaskElt >= 0 && "Didn't expect an undef movlane index!");
12747       ExtLane = MaskElt < 2 ? MaskElt : (MaskElt - 2);
12748       Input = MaskElt < 2 ? V1 : V2;
12749       if (VT.getScalarSizeInBits() == 16) {
12750         Input = DAG.getBitcast(MVT::v2f32, Input);
12751         OpLHS = DAG.getBitcast(MVT::v2f32, OpLHS);
12752       } else {
12753         assert(VT.getScalarSizeInBits() == 32 &&
12754                "Expected 16 or 32 bit shuffle elements");
12755         Input = DAG.getBitcast(MVT::v2f64, Input);
12756         OpLHS = DAG.getBitcast(MVT::v2f64, OpLHS);
12757       }
12758     } else {
12759       int MaskElt = getPFIDLane(ID, RHSID);
12760       assert(MaskElt >= 0 && "Didn't expect an undef movlane index!");
12761       ExtLane = MaskElt < 4 ? MaskElt : (MaskElt - 4);
12762       Input = MaskElt < 4 ? V1 : V2;
12763       // Be careful about creating illegal types. Use f16 instead of i16.
12764       if (VT == MVT::v4i16) {
12765         Input = DAG.getBitcast(MVT::v4f16, Input);
12766         OpLHS = DAG.getBitcast(MVT::v4f16, OpLHS);
12767       }
12768     }
12769     SDValue Ext = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
12770                               Input.getValueType().getVectorElementType(),
12771                               Input, DAG.getVectorIdxConstant(ExtLane, dl));
12772     SDValue Ins =
12773         DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, Input.getValueType(), OpLHS,
12774                     Ext, DAG.getVectorIdxConstant(RHSID & 0x3, dl));
12775     return DAG.getBitcast(VT, Ins);
12776   }
12777 
12778   SDValue OpLHS, OpRHS;
12779   OpLHS = GeneratePerfectShuffle(LHSID, V1, V2, PerfectShuffleTable[LHSID], LHS,
12780                                  RHS, DAG, dl);
12781   OpRHS = GeneratePerfectShuffle(RHSID, V1, V2, PerfectShuffleTable[RHSID], LHS,
12782                                  RHS, DAG, dl);
12783   EVT VT = OpLHS.getValueType();
12784 
12785   switch (OpNum) {
12786   default:
12787     llvm_unreachable("Unknown shuffle opcode!");
12788   case OP_VREV:
12789     // VREV divides the vector in half and swaps within the half.
12790     if (VT.getVectorElementType() == MVT::i32 ||
12791         VT.getVectorElementType() == MVT::f32)
12792       return DAG.getNode(AArch64ISD::REV64, dl, VT, OpLHS);
12793     // vrev <4 x i16> -> REV32
12794     if (VT.getVectorElementType() == MVT::i16 ||
12795         VT.getVectorElementType() == MVT::f16 ||
12796         VT.getVectorElementType() == MVT::bf16)
12797       return DAG.getNode(AArch64ISD::REV32, dl, VT, OpLHS);
12798     // vrev <4 x i8> -> REV16
12799     assert(VT.getVectorElementType() == MVT::i8);
12800     return DAG.getNode(AArch64ISD::REV16, dl, VT, OpLHS);
12801   case OP_VDUP0:
12802   case OP_VDUP1:
12803   case OP_VDUP2:
12804   case OP_VDUP3: {
12805     EVT EltTy = VT.getVectorElementType();
12806     unsigned Opcode;
12807     if (EltTy == MVT::i8)
12808       Opcode = AArch64ISD::DUPLANE8;
12809     else if (EltTy == MVT::i16 || EltTy == MVT::f16 || EltTy == MVT::bf16)
12810       Opcode = AArch64ISD::DUPLANE16;
12811     else if (EltTy == MVT::i32 || EltTy == MVT::f32)
12812       Opcode = AArch64ISD::DUPLANE32;
12813     else if (EltTy == MVT::i64 || EltTy == MVT::f64)
12814       Opcode = AArch64ISD::DUPLANE64;
12815     else
12816       llvm_unreachable("Invalid vector element type?");
12817 
12818     if (VT.getSizeInBits() == 64)
12819       OpLHS = WidenVector(OpLHS, DAG);
12820     SDValue Lane = DAG.getConstant(OpNum - OP_VDUP0, dl, MVT::i64);
12821     return DAG.getNode(Opcode, dl, VT, OpLHS, Lane);
12822   }
12823   case OP_VEXT1:
12824   case OP_VEXT2:
12825   case OP_VEXT3: {
12826     unsigned Imm = (OpNum - OP_VEXT1 + 1) * getExtFactor(OpLHS);
12827     return DAG.getNode(AArch64ISD::EXT, dl, VT, OpLHS, OpRHS,
12828                        DAG.getConstant(Imm, dl, MVT::i32));
12829   }
12830   case OP_VUZPL:
12831     return DAG.getNode(AArch64ISD::UZP1, dl, VT, OpLHS, OpRHS);
12832   case OP_VUZPR:
12833     return DAG.getNode(AArch64ISD::UZP2, dl, VT, OpLHS, OpRHS);
12834   case OP_VZIPL:
12835     return DAG.getNode(AArch64ISD::ZIP1, dl, VT, OpLHS, OpRHS);
12836   case OP_VZIPR:
12837     return DAG.getNode(AArch64ISD::ZIP2, dl, VT, OpLHS, OpRHS);
12838   case OP_VTRNL:
12839     return DAG.getNode(AArch64ISD::TRN1, dl, VT, OpLHS, OpRHS);
12840   case OP_VTRNR:
12841     return DAG.getNode(AArch64ISD::TRN2, dl, VT, OpLHS, OpRHS);
12842   }
12843 }
12844 
12845 static SDValue GenerateTBL(SDValue Op, ArrayRef<int> ShuffleMask,
12846                            SelectionDAG &DAG) {
12847   // Check to see if we can use the TBL instruction.
12848   SDValue V1 = Op.getOperand(0);
12849   SDValue V2 = Op.getOperand(1);
12850   SDLoc DL(Op);
12851 
12852   EVT EltVT = Op.getValueType().getVectorElementType();
12853   unsigned BytesPerElt = EltVT.getSizeInBits() / 8;
12854 
12855   bool Swap = false;
12856   if (V1.isUndef() || isZerosVector(V1.getNode())) {
12857     std::swap(V1, V2);
12858     Swap = true;
12859   }
12860 
12861   // If the V2 source is undef or zero then we can use a tbl1, as tbl1 will fill
12862   // out of range values with 0s. We do need to make sure that any out-of-range
12863   // values are really out-of-range for a v16i8 vector.
12864   bool IsUndefOrZero = V2.isUndef() || isZerosVector(V2.getNode());
12865   MVT IndexVT = MVT::v8i8;
12866   unsigned IndexLen = 8;
12867   if (Op.getValueSizeInBits() == 128) {
12868     IndexVT = MVT::v16i8;
12869     IndexLen = 16;
12870   }
12871 
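  // Expand each shuffle index into BytesPerElt byte indices into the
  // concatenated table registers. For example, for a v8i16 shuffle
  // (BytesPerElt == 2, IndexLen == 16) the mask value 9 becomes the byte
  // offsets {18, 19}; when V2 is undef or zero those out-of-range offsets are
  // rewritten to 255 so that the single-register tbl1 yields 0 for them.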
12872   SmallVector<SDValue, 8> TBLMask;
12873   for (int Val : ShuffleMask) {
12874     for (unsigned Byte = 0; Byte < BytesPerElt; ++Byte) {
12875       unsigned Offset = Byte + Val * BytesPerElt;
12876       if (Swap)
12877         Offset = Offset < IndexLen ? Offset + IndexLen : Offset - IndexLen;
12878       if (IsUndefOrZero && Offset >= IndexLen)
12879         Offset = 255;
12880       TBLMask.push_back(DAG.getConstant(Offset, DL, MVT::i32));
12881     }
12882   }
12883 
12884   SDValue V1Cst = DAG.getNode(ISD::BITCAST, DL, IndexVT, V1);
12885   SDValue V2Cst = DAG.getNode(ISD::BITCAST, DL, IndexVT, V2);
12886 
12887   SDValue Shuffle;
12888   if (IsUndefOrZero) {
12889     if (IndexLen == 8)
12890       V1Cst = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, V1Cst, V1Cst);
12891     Shuffle = DAG.getNode(
12892         ISD::INTRINSIC_WO_CHAIN, DL, IndexVT,
12893         DAG.getConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32), V1Cst,
12894         DAG.getBuildVector(IndexVT, DL, ArrayRef(TBLMask.data(), IndexLen)));
12895   } else {
12896     if (IndexLen == 8) {
12897       V1Cst = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, V1Cst, V2Cst);
12898       Shuffle = DAG.getNode(
12899           ISD::INTRINSIC_WO_CHAIN, DL, IndexVT,
12900           DAG.getConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32), V1Cst,
12901           DAG.getBuildVector(IndexVT, DL, ArrayRef(TBLMask.data(), IndexLen)));
12902     } else {
12903       // FIXME: We cannot, for the moment, emit a TBL2 instruction because we
12904       // cannot currently represent the register constraints on the input
12905       // table registers.
12906       //  Shuffle = DAG.getNode(AArch64ISD::TBL2, DL, IndexVT, V1Cst, V2Cst,
12907       //                   DAG.getBuildVector(IndexVT, DL, &TBLMask[0],
12908       //                   IndexLen));
12909       Shuffle = DAG.getNode(
12910           ISD::INTRINSIC_WO_CHAIN, DL, IndexVT,
12911           DAG.getConstant(Intrinsic::aarch64_neon_tbl2, DL, MVT::i32), V1Cst,
12912           V2Cst,
12913           DAG.getBuildVector(IndexVT, DL, ArrayRef(TBLMask.data(), IndexLen)));
12914     }
12915   }
12916   return DAG.getNode(ISD::BITCAST, DL, Op.getValueType(), Shuffle);
12917 }
12918 
12919 static unsigned getDUPLANEOp(EVT EltType) {
12920   if (EltType == MVT::i8)
12921     return AArch64ISD::DUPLANE8;
12922   if (EltType == MVT::i16 || EltType == MVT::f16 || EltType == MVT::bf16)
12923     return AArch64ISD::DUPLANE16;
12924   if (EltType == MVT::i32 || EltType == MVT::f32)
12925     return AArch64ISD::DUPLANE32;
12926   if (EltType == MVT::i64 || EltType == MVT::f64)
12927     return AArch64ISD::DUPLANE64;
12928 
12929   llvm_unreachable("Invalid vector element type?");
12930 }
12931 
12932 static SDValue constructDup(SDValue V, int Lane, SDLoc dl, EVT VT,
12933                             unsigned Opcode, SelectionDAG &DAG) {
12934   // Try to eliminate a bitcasted extract subvector before a DUPLANE.
12935   auto getScaledOffsetDup = [](SDValue BitCast, int &LaneC, MVT &CastVT) {
12936     // Match: dup (bitcast (extract_subv X, C)), LaneC
12937     if (BitCast.getOpcode() != ISD::BITCAST ||
12938         BitCast.getOperand(0).getOpcode() != ISD::EXTRACT_SUBVECTOR)
12939       return false;
12940 
12941     // The extract index must align in the destination type. That may not
12942     // happen if the bitcast is from narrow to wide type.
12943     SDValue Extract = BitCast.getOperand(0);
12944     unsigned ExtIdx = Extract.getConstantOperandVal(1);
12945     unsigned SrcEltBitWidth = Extract.getScalarValueSizeInBits();
12946     unsigned ExtIdxInBits = ExtIdx * SrcEltBitWidth;
12947     unsigned CastedEltBitWidth = BitCast.getScalarValueSizeInBits();
12948     if (ExtIdxInBits % CastedEltBitWidth != 0)
12949       return false;
12950 
12951     // Can't handle cases where vector size is not 128-bit
12952     if (!Extract.getOperand(0).getValueType().is128BitVector())
12953       return false;
12954 
12955     // Update the lane value by offsetting with the scaled extract index.
12956     LaneC += ExtIdxInBits / CastedEltBitWidth;
12957 
12958     // Determine the casted vector type of the wide vector input.
12959     // dup (bitcast (extract_subv X, C)), LaneC --> dup (bitcast X), LaneC'
12960     // Examples:
12961     // dup (bitcast (extract_subv v2f64 X, 1) to v2f32), 1 --> dup v4f32 X, 3
12962     // dup (bitcast (extract_subv v16i8 X, 8) to v4i16), 1 --> dup v8i16 X, 5
12963     unsigned SrcVecNumElts =
12964         Extract.getOperand(0).getValueSizeInBits() / CastedEltBitWidth;
12965     CastVT = MVT::getVectorVT(BitCast.getSimpleValueType().getScalarType(),
12966                               SrcVecNumElts);
12967     return true;
12968   };
12969   MVT CastVT;
12970   if (getScaledOffsetDup(V, Lane, CastVT)) {
12971     V = DAG.getBitcast(CastVT, V.getOperand(0).getOperand(0));
12972   } else if (V.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
12973              V.getOperand(0).getValueType().is128BitVector()) {
12974     // The lane is incremented by the index of the extract.
12975     // Example: dup v2f32 (extract v4f32 X, 2), 1 --> dup v4f32 X, 3
12976     Lane += V.getConstantOperandVal(1);
12977     V = V.getOperand(0);
12978   } else if (V.getOpcode() == ISD::CONCAT_VECTORS) {
12979     // The lane is decremented if we are splatting from the 2nd operand.
12980     // Example: dup v4i32 (concat v2i32 X, v2i32 Y), 3 --> dup v4i32 Y, 1
12981     unsigned Idx = Lane >= (int)VT.getVectorNumElements() / 2;
12982     Lane -= Idx * VT.getVectorNumElements() / 2;
12983     V = WidenVector(V.getOperand(Idx), DAG);
12984   } else if (VT.getSizeInBits() == 64) {
12985     // Widen the operand to 128-bit register with undef.
12986     V = WidenVector(V, DAG);
12987   }
12988   return DAG.getNode(Opcode, dl, VT, V, DAG.getConstant(Lane, dl, MVT::i64));
12989 }
12990 
12991 // Return true if we can get a new shuffle mask by checking the parameter mask
12992 // array to test whether every two adjacent mask values are consecutive and
12993 // start from an even number.
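// For example, for v4i32 the mask <6, 7, 2, 3> widens to the v2i64 mask
// <3, 1>, and <-1, -1, 2, 3> widens to <-1, 1>, while a mask such as
// <1, 2, 4, 6> cannot be widened because its pairs do not start at even
// indices.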
12994 static bool isWideTypeMask(ArrayRef<int> M, EVT VT,
12995                            SmallVectorImpl<int> &NewMask) {
12996   unsigned NumElts = VT.getVectorNumElements();
12997   if (NumElts % 2 != 0)
12998     return false;
12999 
13000   NewMask.clear();
13001   for (unsigned i = 0; i < NumElts; i += 2) {
13002     int M0 = M[i];
13003     int M1 = M[i + 1];
13004 
13005     // If both elements are undef, new mask is undef too.
13006     if (M0 == -1 && M1 == -1) {
13007       NewMask.push_back(-1);
13008       continue;
13009     }
13010 
13011     if (M0 == -1 && M1 != -1 && (M1 % 2) == 1) {
13012       NewMask.push_back(M1 / 2);
13013       continue;
13014     }
13015 
13016     if (M0 != -1 && (M0 % 2) == 0 && ((M0 + 1) == M1 || M1 == -1)) {
13017       NewMask.push_back(M0 / 2);
13018       continue;
13019     }
13020 
13021     NewMask.clear();
13022     return false;
13023   }
13024 
13025   assert(NewMask.size() == NumElts / 2 && "Incorrect size for mask!");
13026   return true;
13027 }
13028 
13029 // Try to widen element type to get a new mask value for a better permutation
13030 // sequence, so that we can use NEON shuffle instructions, such as zip1/2,
13031 // UZP1/2, TRN1/2, REV, INS, etc.
13032 // For example:
13033 //  shufflevector <4 x i32> %a, <4 x i32> %b,
13034 //                <4 x i32> <i32 6, i32 7, i32 2, i32 3>
13035 // is equivalent to:
13036 //  shufflevector <2 x i64> %a, <2 x i64> %b, <2 x i32> <i32 3, i32 1>
13037 // Finally, we can get:
13038 //  mov     v0.d[0], v1.d[1]
13039 static SDValue tryWidenMaskForShuffle(SDValue Op, SelectionDAG &DAG) {
13040   SDLoc DL(Op);
13041   EVT VT = Op.getValueType();
13042   EVT ScalarVT = VT.getVectorElementType();
13043   unsigned ElementSize = ScalarVT.getFixedSizeInBits();
13044   SDValue V0 = Op.getOperand(0);
13045   SDValue V1 = Op.getOperand(1);
13046   ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(Op)->getMask();
13047 
13048   // When combining adjacent elements, e.g. two i16s -> i32 or two i32s -> i64,
13049   // we need to make sure the wider element type is legal. Thus, ElementSize
13050   // should not be larger than 32 bits, and the i1 type is also excluded.
13051   if (ElementSize > 32 || ElementSize == 1)
13052     return SDValue();
13053 
13054   SmallVector<int, 8> NewMask;
13055   if (isWideTypeMask(Mask, VT, NewMask)) {
13056     MVT NewEltVT = VT.isFloatingPoint()
13057                        ? MVT::getFloatingPointVT(ElementSize * 2)
13058                        : MVT::getIntegerVT(ElementSize * 2);
13059     MVT NewVT = MVT::getVectorVT(NewEltVT, VT.getVectorNumElements() / 2);
13060     if (DAG.getTargetLoweringInfo().isTypeLegal(NewVT)) {
13061       V0 = DAG.getBitcast(NewVT, V0);
13062       V1 = DAG.getBitcast(NewVT, V1);
13063       return DAG.getBitcast(VT,
13064                             DAG.getVectorShuffle(NewVT, DL, V0, V1, NewMask));
13065     }
13066   }
13067 
13068   return SDValue();
13069 }
13070 
13071 // Try to fold shuffle (tbl2, tbl2) into a single tbl4.
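// Each tbl2 reads from two 16-byte table registers, so the combined tbl4 reads
// from four; mask entries that referred to the second tbl2's tables therefore
// have 32 added to skip over the first tbl2's 32 bytes of table data.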
13072 static SDValue tryToConvertShuffleOfTbl2ToTbl4(SDValue Op,
13073                                                ArrayRef<int> ShuffleMask,
13074                                                SelectionDAG &DAG) {
13075   SDValue Tbl1 = Op->getOperand(0);
13076   SDValue Tbl2 = Op->getOperand(1);
13077   SDLoc dl(Op);
13078   SDValue Tbl2ID =
13079       DAG.getTargetConstant(Intrinsic::aarch64_neon_tbl2, dl, MVT::i64);
13080 
13081   EVT VT = Op.getValueType();
13082   if (Tbl1->getOpcode() != ISD::INTRINSIC_WO_CHAIN ||
13083       Tbl1->getOperand(0) != Tbl2ID ||
13084       Tbl2->getOpcode() != ISD::INTRINSIC_WO_CHAIN ||
13085       Tbl2->getOperand(0) != Tbl2ID)
13086     return SDValue();
13087 
13088   if (Tbl1->getValueType(0) != MVT::v16i8 ||
13089       Tbl2->getValueType(0) != MVT::v16i8)
13090     return SDValue();
13091 
13092   SDValue Mask1 = Tbl1->getOperand(3);
13093   SDValue Mask2 = Tbl2->getOperand(3);
13094   SmallVector<SDValue, 16> TBLMaskParts(16, SDValue());
13095   for (unsigned I = 0; I < 16; I++) {
13096     if (ShuffleMask[I] < 16)
13097       TBLMaskParts[I] = Mask1->getOperand(ShuffleMask[I]);
13098     else {
13099       auto *C =
13100           dyn_cast<ConstantSDNode>(Mask2->getOperand(ShuffleMask[I] - 16));
13101       if (!C)
13102         return SDValue();
13103       TBLMaskParts[I] = DAG.getConstant(C->getSExtValue() + 32, dl, MVT::i32);
13104     }
13105   }
13106 
13107   SDValue TBLMask = DAG.getBuildVector(VT, dl, TBLMaskParts);
13108   SDValue ID =
13109       DAG.getTargetConstant(Intrinsic::aarch64_neon_tbl4, dl, MVT::i64);
13110 
13111   return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, MVT::v16i8,
13112                      {ID, Tbl1->getOperand(1), Tbl1->getOperand(2),
13113                       Tbl2->getOperand(1), Tbl2->getOperand(2), TBLMask});
13114 }
13115 
13116 // Baseline legalization for ZERO_EXTEND_VECTOR_INREG will blend-in zeros,
13117 // but we don't have an appropriate instruction,
13118 // so custom-lower it as ZIP1-with-zeros.
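// For a 2x extension, ZIP1 of the source with an all-zeros vector interleaves a
// zero element after each source element; a bitcast to the wider element type
// then reinterprets each (element, zero) pair as one zero-extended wide
// element.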
13119 SDValue
13120 AArch64TargetLowering::LowerZERO_EXTEND_VECTOR_INREG(SDValue Op,
13121                                                      SelectionDAG &DAG) const {
13122   SDLoc dl(Op);
13123   EVT VT = Op.getValueType();
13124   SDValue SrcOp = Op.getOperand(0);
13125   EVT SrcVT = SrcOp.getValueType();
13126   assert(VT.getScalarSizeInBits() % SrcVT.getScalarSizeInBits() == 0 &&
13127          "Unexpected extension factor.");
13128   unsigned Scale = VT.getScalarSizeInBits() / SrcVT.getScalarSizeInBits();
13129   // FIXME: support multi-step zipping?
13130   if (Scale != 2)
13131     return SDValue();
13132   SDValue Zeros = DAG.getConstant(0, dl, SrcVT);
13133   return DAG.getBitcast(VT,
13134                         DAG.getNode(AArch64ISD::ZIP1, dl, SrcVT, SrcOp, Zeros));
13135 }
13136 
13137 SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
13138                                                    SelectionDAG &DAG) const {
13139   SDLoc dl(Op);
13140   EVT VT = Op.getValueType();
13141 
13142   ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
13143 
13144   if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
13145     return LowerFixedLengthVECTOR_SHUFFLEToSVE(Op, DAG);
13146 
13147   // Convert shuffles that are directly supported on NEON to target-specific
13148   // DAG nodes, instead of keeping them as shuffles and matching them again
13149   // during code selection.  This is more efficient and avoids the possibility
13150   // of inconsistencies between legalization and selection.
13151   ArrayRef<int> ShuffleMask = SVN->getMask();
13152 
13153   SDValue V1 = Op.getOperand(0);
13154   SDValue V2 = Op.getOperand(1);
13155 
13156   assert(V1.getValueType() == VT && "Unexpected VECTOR_SHUFFLE type!");
13157   assert(ShuffleMask.size() == VT.getVectorNumElements() &&
13158          "Unexpected VECTOR_SHUFFLE mask size!");
13159 
13160   if (SDValue Res = tryToConvertShuffleOfTbl2ToTbl4(Op, ShuffleMask, DAG))
13161     return Res;
13162 
13163   if (SVN->isSplat()) {
13164     int Lane = SVN->getSplatIndex();
13165     // If this is undef splat, generate it via "just" vdup, if possible.
13166     if (Lane == -1)
13167       Lane = 0;
13168 
13169     if (Lane == 0 && V1.getOpcode() == ISD::SCALAR_TO_VECTOR)
13170       return DAG.getNode(AArch64ISD::DUP, dl, V1.getValueType(),
13171                          V1.getOperand(0));
13172     // Test if V1 is a BUILD_VECTOR and the lane being referenced is a non-
13173     // constant. If so, we can just reference the lane's definition directly.
13174     if (V1.getOpcode() == ISD::BUILD_VECTOR &&
13175         !isa<ConstantSDNode>(V1.getOperand(Lane)))
13176       return DAG.getNode(AArch64ISD::DUP, dl, VT, V1.getOperand(Lane));
13177 
13178     // Otherwise, duplicate from the lane of the input vector.
13179     unsigned Opcode = getDUPLANEOp(V1.getValueType().getVectorElementType());
13180     return constructDup(V1, Lane, dl, VT, Opcode, DAG);
13181   }
13182 
13183   // Check if the mask matches a DUP for a wider element
13184   for (unsigned LaneSize : {64U, 32U, 16U}) {
13185     unsigned Lane = 0;
13186     if (isWideDUPMask(ShuffleMask, VT, LaneSize, Lane)) {
13187       unsigned Opcode = LaneSize == 64 ? AArch64ISD::DUPLANE64
13188                                        : LaneSize == 32 ? AArch64ISD::DUPLANE32
13189                                                         : AArch64ISD::DUPLANE16;
13190       // Cast V1 to an integer vector with required lane size
13191       MVT NewEltTy = MVT::getIntegerVT(LaneSize);
13192       unsigned NewEltCount = VT.getSizeInBits() / LaneSize;
13193       MVT NewVecTy = MVT::getVectorVT(NewEltTy, NewEltCount);
13194       V1 = DAG.getBitcast(NewVecTy, V1);
13195       // Construct the DUP instruction
13196       V1 = constructDup(V1, Lane, dl, NewVecTy, Opcode, DAG);
13197       // Cast back to the original type
13198       return DAG.getBitcast(VT, V1);
13199     }
13200   }
13201 
13202   unsigned NumElts = VT.getVectorNumElements();
13203   unsigned EltSize = VT.getScalarSizeInBits();
13204   if (isREVMask(ShuffleMask, EltSize, NumElts, 64))
13205     return DAG.getNode(AArch64ISD::REV64, dl, V1.getValueType(), V1, V2);
13206   if (isREVMask(ShuffleMask, EltSize, NumElts, 32))
13207     return DAG.getNode(AArch64ISD::REV32, dl, V1.getValueType(), V1, V2);
13208   if (isREVMask(ShuffleMask, EltSize, NumElts, 16))
13209     return DAG.getNode(AArch64ISD::REV16, dl, V1.getValueType(), V1, V2);
13210 
13211   if (((NumElts == 8 && EltSize == 16) || (NumElts == 16 && EltSize == 8)) &&
13212       ShuffleVectorInst::isReverseMask(ShuffleMask, ShuffleMask.size())) {
13213     SDValue Rev = DAG.getNode(AArch64ISD::REV64, dl, VT, V1);
13214     return DAG.getNode(AArch64ISD::EXT, dl, VT, Rev, Rev,
13215                        DAG.getConstant(8, dl, MVT::i32));
13216   }
13217 
13218   bool ReverseEXT = false;
13219   unsigned Imm;
13220   if (isEXTMask(ShuffleMask, VT, ReverseEXT, Imm)) {
13221     if (ReverseEXT)
13222       std::swap(V1, V2);
13223     Imm *= getExtFactor(V1);
13224     return DAG.getNode(AArch64ISD::EXT, dl, V1.getValueType(), V1, V2,
13225                        DAG.getConstant(Imm, dl, MVT::i32));
13226   } else if (V2->isUndef() && isSingletonEXTMask(ShuffleMask, VT, Imm)) {
13227     Imm *= getExtFactor(V1);
13228     return DAG.getNode(AArch64ISD::EXT, dl, V1.getValueType(), V1, V1,
13229                        DAG.getConstant(Imm, dl, MVT::i32));
13230   }
13231 
13232   unsigned WhichResult;
13233   if (isZIPMask(ShuffleMask, NumElts, WhichResult)) {
13234     unsigned Opc = (WhichResult == 0) ? AArch64ISD::ZIP1 : AArch64ISD::ZIP2;
13235     return DAG.getNode(Opc, dl, V1.getValueType(), V1, V2);
13236   }
13237   if (isUZPMask(ShuffleMask, NumElts, WhichResult)) {
13238     unsigned Opc = (WhichResult == 0) ? AArch64ISD::UZP1 : AArch64ISD::UZP2;
13239     return DAG.getNode(Opc, dl, V1.getValueType(), V1, V2);
13240   }
13241   if (isTRNMask(ShuffleMask, NumElts, WhichResult)) {
13242     unsigned Opc = (WhichResult == 0) ? AArch64ISD::TRN1 : AArch64ISD::TRN2;
13243     return DAG.getNode(Opc, dl, V1.getValueType(), V1, V2);
13244   }
13245 
13246   if (isZIP_v_undef_Mask(ShuffleMask, VT, WhichResult)) {
13247     unsigned Opc = (WhichResult == 0) ? AArch64ISD::ZIP1 : AArch64ISD::ZIP2;
13248     return DAG.getNode(Opc, dl, V1.getValueType(), V1, V1);
13249   }
13250   if (isUZP_v_undef_Mask(ShuffleMask, VT, WhichResult)) {
13251     unsigned Opc = (WhichResult == 0) ? AArch64ISD::UZP1 : AArch64ISD::UZP2;
13252     return DAG.getNode(Opc, dl, V1.getValueType(), V1, V1);
13253   }
13254   if (isTRN_v_undef_Mask(ShuffleMask, VT, WhichResult)) {
13255     unsigned Opc = (WhichResult == 0) ? AArch64ISD::TRN1 : AArch64ISD::TRN2;
13256     return DAG.getNode(Opc, dl, V1.getValueType(), V1, V1);
13257   }
13258 
13259   if (SDValue Concat = tryFormConcatFromShuffle(Op, DAG))
13260     return Concat;
13261 
13262   bool DstIsLeft;
13263   int Anomaly;
13264   int NumInputElements = V1.getValueType().getVectorNumElements();
13265   if (isINSMask(ShuffleMask, NumInputElements, DstIsLeft, Anomaly)) {
13266     SDValue DstVec = DstIsLeft ? V1 : V2;
13267     SDValue DstLaneV = DAG.getConstant(Anomaly, dl, MVT::i64);
13268 
13269     SDValue SrcVec = V1;
13270     int SrcLane = ShuffleMask[Anomaly];
13271     if (SrcLane >= NumInputElements) {
13272       SrcVec = V2;
13273       SrcLane -= NumElts;
13274     }
13275     SDValue SrcLaneV = DAG.getConstant(SrcLane, dl, MVT::i64);
13276 
13277     EVT ScalarVT = VT.getVectorElementType();
13278 
13279     if (ScalarVT.getFixedSizeInBits() < 32 && ScalarVT.isInteger())
13280       ScalarVT = MVT::i32;
13281 
13282     return DAG.getNode(
13283         ISD::INSERT_VECTOR_ELT, dl, VT, DstVec,
13284         DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, ScalarVT, SrcVec, SrcLaneV),
13285         DstLaneV);
13286   }
13287 
13288   if (SDValue NewSD = tryWidenMaskForShuffle(Op, DAG))
13289     return NewSD;
13290 
13291   // If the shuffle is not directly supported and it has 4 elements, use
13292   // the PerfectShuffle-generated table to synthesize it from other shuffles.
13293   if (NumElts == 4) {
13294     unsigned PFIndexes[4];
13295     for (unsigned i = 0; i != 4; ++i) {
13296       if (ShuffleMask[i] < 0)
13297         PFIndexes[i] = 8;
13298       else
13299         PFIndexes[i] = ShuffleMask[i];
13300     }
13301 
13302     // Compute the index in the perfect shuffle table.
13303     unsigned PFTableIndex = PFIndexes[0] * 9 * 9 * 9 + PFIndexes[1] * 9 * 9 +
13304                             PFIndexes[2] * 9 + PFIndexes[3];
13305     unsigned PFEntry = PerfectShuffleTable[PFTableIndex];
13306     return GeneratePerfectShuffle(PFTableIndex, V1, V2, PFEntry, V1, V2, DAG,
13307                                   dl);
13308   }
13309 
13310   return GenerateTBL(Op, ShuffleMask, DAG);
13311 }
13312 
13313 SDValue AArch64TargetLowering::LowerSPLAT_VECTOR(SDValue Op,
13314                                                  SelectionDAG &DAG) const {
13315   EVT VT = Op.getValueType();
13316 
13317   if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
13318     return LowerToScalableOp(Op, DAG);
13319 
13320   assert(VT.isScalableVector() && VT.getVectorElementType() == MVT::i1 &&
13321          "Unexpected vector type!");
13322 
13323   // We can handle the constant cases during isel.
13324   if (isa<ConstantSDNode>(Op.getOperand(0)))
13325     return Op;
13326 
13327   // There isn't a natural way to handle the general i1 case, so we use some
13328   // trickery with whilelo.
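  // The splatted i1 is sign-extended to i64, so a true input becomes -1 (the
  // maximum unsigned value) and whilelo(0, -1) produces an all-true predicate,
  // while a false input becomes 0 and whilelo(0, 0) is all-false.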
13329   SDLoc DL(Op);
13330   SDValue SplatVal = DAG.getAnyExtOrTrunc(Op.getOperand(0), DL, MVT::i64);
13331   SplatVal = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, SplatVal,
13332                          DAG.getValueType(MVT::i1));
13333   SDValue ID =
13334       DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
13335   SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
13336   if (VT == MVT::nxv1i1)
13337     return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::nxv1i1,
13338                        DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv2i1, ID,
13339                                    Zero, SplatVal),
13340                        Zero);
13341   return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, ID, Zero, SplatVal);
13342 }
13343 
13344 SDValue AArch64TargetLowering::LowerDUPQLane(SDValue Op,
13345                                              SelectionDAG &DAG) const {
13346   SDLoc DL(Op);
13347 
13348   EVT VT = Op.getValueType();
13349   if (!isTypeLegal(VT) || !VT.isScalableVector())
13350     return SDValue();
13351 
13352   // Current lowering only supports the SVE-ACLE types.
13353   if (VT.getSizeInBits().getKnownMinValue() != AArch64::SVEBitsPerBlock)
13354     return SDValue();
13355 
13356   // The DUPQ operation is independent of element type, so normalise to i64s.
13357   SDValue Idx128 = Op.getOperand(2);
13358 
13359   // DUPQ can be used when idx is in range.
13360   auto *CIdx = dyn_cast<ConstantSDNode>(Idx128);
13361   if (CIdx && (CIdx->getZExtValue() <= 3)) {
13362     SDValue CI = DAG.getTargetConstant(CIdx->getZExtValue(), DL, MVT::i64);
13363     return DAG.getNode(AArch64ISD::DUPLANE128, DL, VT, Op.getOperand(1), CI);
13364   }
13365 
13366   SDValue V = DAG.getNode(ISD::BITCAST, DL, MVT::nxv2i64, Op.getOperand(1));
13367 
13368   // The ACLE says this must produce the same result as:
13369   //   svtbl(data, svadd_x(svptrue_b64(),
13370   //                       svand_x(svptrue_b64(), svindex_u64(0, 1), 1),
13371   //                       index * 2))
13372   SDValue One = DAG.getConstant(1, DL, MVT::i64);
13373   SDValue SplatOne = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv2i64, One);
13374 
13375   // create the vector 0,1,0,1,...
13376   SDValue SV = DAG.getStepVector(DL, MVT::nxv2i64);
13377   SV = DAG.getNode(ISD::AND, DL, MVT::nxv2i64, SV, SplatOne);
13378 
13379   // create the vector idx64,idx64+1,idx64,idx64+1,...
13380   SDValue Idx64 = DAG.getNode(ISD::ADD, DL, MVT::i64, Idx128, Idx128);
13381   SDValue SplatIdx64 = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv2i64, Idx64);
13382   SDValue ShuffleMask = DAG.getNode(ISD::ADD, DL, MVT::nxv2i64, SV, SplatIdx64);
13383 
13384   // create the vector Val[idx64],Val[idx64+1],Val[idx64],Val[idx64+1],...
13385   SDValue TBL = DAG.getNode(AArch64ISD::TBL, DL, MVT::nxv2i64, V, ShuffleMask);
13386   return DAG.getNode(ISD::BITCAST, DL, VT, TBL);
13387 }
13388 
13389 
13390 static bool resolveBuildVector(BuildVectorSDNode *BVN, APInt &CnstBits,
13391                                APInt &UndefBits) {
13392   EVT VT = BVN->getValueType(0);
13393   APInt SplatBits, SplatUndef;
13394   unsigned SplatBitSize;
13395   bool HasAnyUndefs;
13396   if (BVN->isConstantSplat(SplatBits, SplatUndef, SplatBitSize, HasAnyUndefs)) {
13397     unsigned NumSplats = VT.getSizeInBits() / SplatBitSize;
13398 
13399     for (unsigned i = 0; i < NumSplats; ++i) {
13400       CnstBits <<= SplatBitSize;
13401       UndefBits <<= SplatBitSize;
13402       CnstBits |= SplatBits.zextOrTrunc(VT.getSizeInBits());
13403       UndefBits |= (SplatBits ^ SplatUndef).zextOrTrunc(VT.getSizeInBits());
13404     }
13405 
13406     return true;
13407   }
13408 
13409   return false;
13410 }
13411 
13412 // Try 64-bit splatted SIMD immediate.
13413 static SDValue tryAdvSIMDModImm64(unsigned NewOp, SDValue Op, SelectionDAG &DAG,
13414                                  const APInt &Bits) {
13415   if (Bits.getHiBits(64) == Bits.getLoBits(64)) {
13416     uint64_t Value = Bits.zextOrTrunc(64).getZExtValue();
13417     EVT VT = Op.getValueType();
13418     MVT MovTy = (VT.getSizeInBits() == 128) ? MVT::v2i64 : MVT::f64;
13419 
13420     if (AArch64_AM::isAdvSIMDModImmType10(Value)) {
13421       Value = AArch64_AM::encodeAdvSIMDModImmType10(Value);
13422 
13423       SDLoc dl(Op);
13424       SDValue Mov = DAG.getNode(NewOp, dl, MovTy,
13425                                 DAG.getConstant(Value, dl, MVT::i32));
13426       return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov);
13427     }
13428   }
13429 
13430   return SDValue();
13431 }
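// Example (assumed, for illustration only): a v2i64 splat of 0x00FF00FF00FF00FF
// is a "type 10" immediate (every byte either 0x00 or 0xFF), so it is expected
// to become a single instruction, roughly:
//   movi v0.2d, #0x00ff00ff00ff00ff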
13432 
13433 // Try 32-bit splatted SIMD immediate.
13434 static SDValue tryAdvSIMDModImm32(unsigned NewOp, SDValue Op, SelectionDAG &DAG,
13435                                   const APInt &Bits,
13436                                   const SDValue *LHS = nullptr) {
13437   EVT VT = Op.getValueType();
13438   if (VT.isFixedLengthVector() &&
13439       !DAG.getSubtarget<AArch64Subtarget>().isNeonAvailable())
13440     return SDValue();
13441 
13442   if (Bits.getHiBits(64) == Bits.getLoBits(64)) {
13443     uint64_t Value = Bits.zextOrTrunc(64).getZExtValue();
13444     MVT MovTy = (VT.getSizeInBits() == 128) ? MVT::v4i32 : MVT::v2i32;
13445     bool isAdvSIMDModImm = false;
13446     uint64_t Shift;
13447 
13448     if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType1(Value))) {
13449       Value = AArch64_AM::encodeAdvSIMDModImmType1(Value);
13450       Shift = 0;
13451     }
13452     else if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType2(Value))) {
13453       Value = AArch64_AM::encodeAdvSIMDModImmType2(Value);
13454       Shift = 8;
13455     }
13456     else if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType3(Value))) {
13457       Value = AArch64_AM::encodeAdvSIMDModImmType3(Value);
13458       Shift = 16;
13459     }
13460     else if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType4(Value))) {
13461       Value = AArch64_AM::encodeAdvSIMDModImmType4(Value);
13462       Shift = 24;
13463     }
13464 
13465     if (isAdvSIMDModImm) {
13466       SDLoc dl(Op);
13467       SDValue Mov;
13468 
13469       if (LHS)
13470         Mov = DAG.getNode(NewOp, dl, MovTy,
13471                           DAG.getNode(AArch64ISD::NVCAST, dl, MovTy, *LHS),
13472                           DAG.getConstant(Value, dl, MVT::i32),
13473                           DAG.getConstant(Shift, dl, MVT::i32));
13474       else
13475         Mov = DAG.getNode(NewOp, dl, MovTy,
13476                           DAG.getConstant(Value, dl, MVT::i32),
13477                           DAG.getConstant(Shift, dl, MVT::i32));
13478 
13479       return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov);
13480     }
13481   }
13482 
13483   return SDValue();
13484 }
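// Example (assumed): a v4i32 splat of 0x00AB0000 matches "type 3" (an 8-bit
// value shifted left by 16 within each 32-bit lane) and is expected to lower to
// roughly:
//   movi v0.4s, #0xab, lsl #16
// When LHS is provided (the ORR/BIC callers), the same immediate is fused into
// the logical op instead, e.g. "orr v0.4s, #0xab, lsl #16".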
13485 
13486 // Try 16-bit splatted SIMD immediate.
13487 static SDValue tryAdvSIMDModImm16(unsigned NewOp, SDValue Op, SelectionDAG &DAG,
13488                                   const APInt &Bits,
13489                                   const SDValue *LHS = nullptr) {
13490   EVT VT = Op.getValueType();
13491   if (VT.isFixedLengthVector() &&
13492       !DAG.getSubtarget<AArch64Subtarget>().isNeonAvailable())
13493     return SDValue();
13494 
13495   if (Bits.getHiBits(64) == Bits.getLoBits(64)) {
13496     uint64_t Value = Bits.zextOrTrunc(64).getZExtValue();
13497     MVT MovTy = (VT.getSizeInBits() == 128) ? MVT::v8i16 : MVT::v4i16;
13498     bool isAdvSIMDModImm = false;
13499     uint64_t Shift;
13500 
13501     if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType5(Value))) {
13502       Value = AArch64_AM::encodeAdvSIMDModImmType5(Value);
13503       Shift = 0;
13504     }
13505     else if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType6(Value))) {
13506       Value = AArch64_AM::encodeAdvSIMDModImmType6(Value);
13507       Shift = 8;
13508     }
13509 
13510     if (isAdvSIMDModImm) {
13511       SDLoc dl(Op);
13512       SDValue Mov;
13513 
13514       if (LHS)
13515         Mov = DAG.getNode(NewOp, dl, MovTy,
13516                           DAG.getNode(AArch64ISD::NVCAST, dl, MovTy, *LHS),
13517                           DAG.getConstant(Value, dl, MVT::i32),
13518                           DAG.getConstant(Shift, dl, MVT::i32));
13519       else
13520         Mov = DAG.getNode(NewOp, dl, MovTy,
13521                           DAG.getConstant(Value, dl, MVT::i32),
13522                           DAG.getConstant(Shift, dl, MVT::i32));
13523 
13524       return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov);
13525     }
13526   }
13527 
13528   return SDValue();
13529 }
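// Example (assumed): a v8i16 splat of 0x1200 matches "type 6" (an 8-bit value
// shifted left by 8 within each 16-bit lane) and is expected to lower to
// roughly:
//   movi v0.8h, #0x12, lsl #8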
13530 
13531 // Try 32-bit splatted SIMD immediate with shifted ones.
13532 static SDValue tryAdvSIMDModImm321s(unsigned NewOp, SDValue Op,
13533                                     SelectionDAG &DAG, const APInt &Bits) {
13534   if (Bits.getHiBits(64) == Bits.getLoBits(64)) {
13535     uint64_t Value = Bits.zextOrTrunc(64).getZExtValue();
13536     EVT VT = Op.getValueType();
13537     MVT MovTy = (VT.getSizeInBits() == 128) ? MVT::v4i32 : MVT::v2i32;
13538     bool isAdvSIMDModImm = false;
13539     uint64_t Shift;
13540 
13541     if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType7(Value))) {
13542       Value = AArch64_AM::encodeAdvSIMDModImmType7(Value);
13543       Shift = 264;
13544     }
13545     else if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType8(Value))) {
13546       Value = AArch64_AM::encodeAdvSIMDModImmType8(Value);
13547       Shift = 272;
13548     }
13549 
13550     if (isAdvSIMDModImm) {
13551       SDLoc dl(Op);
13552       SDValue Mov = DAG.getNode(NewOp, dl, MovTy,
13553                                 DAG.getConstant(Value, dl, MVT::i32),
13554                                 DAG.getConstant(Shift, dl, MVT::i32));
13555       return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov);
13556     }
13557   }
13558 
13559   return SDValue();
13560 }
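// The Shift values 264 and 272 select the "MSL" (shifting ones) forms MSL #8
// and MSL #16 respectively. Example (assumed): a v4i32 splat of 0x0000ABFF is a
// "type 7" immediate and is expected to lower to roughly:
//   movi v0.4s, #0xab, msl #8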
13561 
13562 // Try 8-bit splatted SIMD immediate.
13563 static SDValue tryAdvSIMDModImm8(unsigned NewOp, SDValue Op, SelectionDAG &DAG,
13564                                  const APInt &Bits) {
13565   if (Bits.getHiBits(64) == Bits.getLoBits(64)) {
13566     uint64_t Value = Bits.zextOrTrunc(64).getZExtValue();
13567     EVT VT = Op.getValueType();
13568     MVT MovTy = (VT.getSizeInBits() == 128) ? MVT::v16i8 : MVT::v8i8;
13569 
13570     if (AArch64_AM::isAdvSIMDModImmType9(Value)) {
13571       Value = AArch64_AM::encodeAdvSIMDModImmType9(Value);
13572 
13573       SDLoc dl(Op);
13574       SDValue Mov = DAG.getNode(NewOp, dl, MovTy,
13575                                 DAG.getConstant(Value, dl, MVT::i32));
13576       return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov);
13577     }
13578   }
13579 
13580   return SDValue();
13581 }
13582 
13583 // Try FP splatted SIMD immediate.
13584 static SDValue tryAdvSIMDModImmFP(unsigned NewOp, SDValue Op, SelectionDAG &DAG,
13585                                   const APInt &Bits) {
13586   if (Bits.getHiBits(64) == Bits.getLoBits(64)) {
13587     uint64_t Value = Bits.zextOrTrunc(64).getZExtValue();
13588     EVT VT = Op.getValueType();
13589     bool isWide = (VT.getSizeInBits() == 128);
13590     MVT MovTy;
13591     bool isAdvSIMDModImm = false;
13592 
13593     if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType11(Value))) {
13594       Value = AArch64_AM::encodeAdvSIMDModImmType11(Value);
13595       MovTy = isWide ? MVT::v4f32 : MVT::v2f32;
13596     }
13597     else if (isWide &&
13598              (isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType12(Value))) {
13599       Value = AArch64_AM::encodeAdvSIMDModImmType12(Value);
13600       MovTy = MVT::v2f64;
13601     }
13602 
13603     if (isAdvSIMDModImm) {
13604       SDLoc dl(Op);
13605       SDValue Mov = DAG.getNode(NewOp, dl, MovTy,
13606                                 DAG.getConstant(Value, dl, MVT::i32));
13607       return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov);
13608     }
13609   }
13610 
13611   return SDValue();
13612 }
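// Example (assumed): a v4f32 splat of 1.0f is representable as an 8-bit
// floating-point immediate ("type 11"), so it is expected to lower to roughly:
//   fmov v0.4s, #1.0
// The v2f64 "type 12" form similarly yields e.g. "fmov v0.2d, #1.0".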
13613 
13614 // Specialized code to quickly find if PotentialBVec is a BuildVector that
13615 // consists of only the same constant int value; that value is returned in
13616 // the reference argument ConstVal.
13617 static bool isAllConstantBuildVector(const SDValue &PotentialBVec,
13618                                      uint64_t &ConstVal) {
13619   BuildVectorSDNode *Bvec = dyn_cast<BuildVectorSDNode>(PotentialBVec);
13620   if (!Bvec)
13621     return false;
13622   ConstantSDNode *FirstElt = dyn_cast<ConstantSDNode>(Bvec->getOperand(0));
13623   if (!FirstElt)
13624     return false;
13625   EVT VT = Bvec->getValueType(0);
13626   unsigned NumElts = VT.getVectorNumElements();
13627   for (unsigned i = 1; i < NumElts; ++i)
13628     if (dyn_cast<ConstantSDNode>(Bvec->getOperand(i)) != FirstElt)
13629       return false;
13630   ConstVal = FirstElt->getZExtValue();
13631   return true;
13632 }
13633 
13634 static bool isAllInactivePredicate(SDValue N) {
13635   // Look through cast.
13636   while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST)
13637     N = N.getOperand(0);
13638 
13639   return ISD::isConstantSplatVectorAllZeros(N.getNode());
13640 }
13641 
13642 static bool isAllActivePredicate(SelectionDAG &DAG, SDValue N) {
13643   unsigned NumElts = N.getValueType().getVectorMinNumElements();
13644 
13645   // Look through cast.
13646   while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST) {
13647     N = N.getOperand(0);
13648     // When reinterpreting from a type with fewer elements the "new" elements
13649     // are not active, so bail if they're likely to be used.
13650     if (N.getValueType().getVectorMinNumElements() < NumElts)
13651       return false;
13652   }
13653 
13654   if (ISD::isConstantSplatVectorAllOnes(N.getNode()))
13655     return true;
13656 
13657   // "ptrue p.<ty>, all" can be considered all active when <ty> is the same size
13658   // or smaller than the implicit element type represented by N.
13659   // NOTE: A larger element count implies a smaller element type.
13660   if (N.getOpcode() == AArch64ISD::PTRUE &&
13661       N.getConstantOperandVal(0) == AArch64SVEPredPattern::all)
13662     return N.getValueType().getVectorMinNumElements() >= NumElts;
13663 
13664   // If we're compiling for a specific vector-length, we can check if the
13665   // pattern's VL equals that of the scalable vector at runtime.
13666   if (N.getOpcode() == AArch64ISD::PTRUE) {
13667     const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
13668     unsigned MinSVESize = Subtarget.getMinSVEVectorSizeInBits();
13669     unsigned MaxSVESize = Subtarget.getMaxSVEVectorSizeInBits();
13670     if (MaxSVESize && MinSVESize == MaxSVESize) {
13671       unsigned VScale = MaxSVESize / AArch64::SVEBitsPerBlock;
13672       unsigned PatNumElts =
13673           getNumElementsFromSVEPredPattern(N.getConstantOperandVal(0));
13674       return PatNumElts == (NumElts * VScale);
13675     }
13676   }
13677 
13678   return false;
13679 }
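// Illustrative example of the fixed-VL case (assumed): when compiling with a
// known 256-bit SVE vector length (MinSVESize == MaxSVESize == 256, so
// vscale == 2), an nxv4i1 predicate produced by "ptrue p0.s, vl8" covers
// exactly 4 * 2 == 8 runtime lanes and is therefore treated as all active,
// even though its pattern is not "all".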
13680 
13681 // Attempt to form a vector S[LR]I from (or (and X, BvecC1), (lsl Y, C2)),
13682 // to (SLI X, Y, C2), where X and Y have matching vector types, BvecC1 is a
13683 // BUILD_VECTOR with constant element C1, C2 is a constant, and:
13684 //   - for the SLI case: C1 == ~(Ones(ElemSizeInBits) << C2)
13685 //   - for the SRI case: C1 == ~(Ones(ElemSizeInBits) >> C2)
13686 // The (or (lsl Y, C2), (and X, BvecC1)) case is also handled.
13687 static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) {
13688   EVT VT = N->getValueType(0);
13689 
13690   if (!VT.isVector())
13691     return SDValue();
13692 
13693   SDLoc DL(N);
13694 
13695   SDValue And;
13696   SDValue Shift;
13697 
13698   SDValue FirstOp = N->getOperand(0);
13699   unsigned FirstOpc = FirstOp.getOpcode();
13700   SDValue SecondOp = N->getOperand(1);
13701   unsigned SecondOpc = SecondOp.getOpcode();
13702 
13703   // Is one of the operands an AND or a BICi? The AND may have been optimised to
13704   // a BICi in order to use an immediate instead of a register.
13705   // Is the other operand a shl or lshr? This will have been turned into:
13706   // AArch64ISD::VSHL vector, #shift or AArch64ISD::VLSHR vector, #shift
13707   // or (AArch64ISD::SHL_PRED || AArch64ISD::SRL_PRED) mask, vector, #shiftVec.
13708   if ((FirstOpc == ISD::AND || FirstOpc == AArch64ISD::BICi) &&
13709       (SecondOpc == AArch64ISD::VSHL || SecondOpc == AArch64ISD::VLSHR ||
13710        SecondOpc == AArch64ISD::SHL_PRED ||
13711        SecondOpc == AArch64ISD::SRL_PRED)) {
13712     And = FirstOp;
13713     Shift = SecondOp;
13714 
13715   } else if ((SecondOpc == ISD::AND || SecondOpc == AArch64ISD::BICi) &&
13716              (FirstOpc == AArch64ISD::VSHL || FirstOpc == AArch64ISD::VLSHR ||
13717               FirstOpc == AArch64ISD::SHL_PRED ||
13718               FirstOpc == AArch64ISD::SRL_PRED)) {
13719     And = SecondOp;
13720     Shift = FirstOp;
13721   } else
13722     return SDValue();
13723 
13724   bool IsAnd = And.getOpcode() == ISD::AND;
13725   bool IsShiftRight = Shift.getOpcode() == AArch64ISD::VLSHR ||
13726                       Shift.getOpcode() == AArch64ISD::SRL_PRED;
13727   bool ShiftHasPredOp = Shift.getOpcode() == AArch64ISD::SHL_PRED ||
13728                         Shift.getOpcode() == AArch64ISD::SRL_PRED;
13729 
13730   // Is the shift amount constant and are all lanes active?
13731   uint64_t C2;
13732   if (ShiftHasPredOp) {
13733     if (!isAllActivePredicate(DAG, Shift.getOperand(0)))
13734       return SDValue();
13735     APInt C;
13736     if (!ISD::isConstantSplatVector(Shift.getOperand(2).getNode(), C))
13737       return SDValue();
13738     C2 = C.getZExtValue();
13739   } else if (ConstantSDNode *C2node =
13740                  dyn_cast<ConstantSDNode>(Shift.getOperand(1)))
13741     C2 = C2node->getZExtValue();
13742   else
13743     return SDValue();
13744 
13745   APInt C1AsAPInt;
13746   unsigned ElemSizeInBits = VT.getScalarSizeInBits();
13747   if (IsAnd) {
13748     // Is the and mask vector all constant?
13749     if (!ISD::isConstantSplatVector(And.getOperand(1).getNode(), C1AsAPInt))
13750       return SDValue();
13751   } else {
13752     // Reconstruct the corresponding AND immediate from the two BICi immediates.
13753     ConstantSDNode *C1nodeImm = dyn_cast<ConstantSDNode>(And.getOperand(1));
13754     ConstantSDNode *C1nodeShift = dyn_cast<ConstantSDNode>(And.getOperand(2));
13755     assert(C1nodeImm && C1nodeShift);
13756     C1AsAPInt = ~(C1nodeImm->getAPIntValue() << C1nodeShift->getAPIntValue());
13757     C1AsAPInt = C1AsAPInt.zextOrTrunc(ElemSizeInBits);
13758   }
13759 
13760   // Is C1 == ~(Ones(ElemSizeInBits) << C2) or
13761   // C1 == ~(Ones(ElemSizeInBits) >> C2), taking into account
13762   // how much one can shift elements of a particular size?
13763   if (C2 > ElemSizeInBits)
13764     return SDValue();
13765 
13766   APInt RequiredC1 = IsShiftRight ? APInt::getHighBitsSet(ElemSizeInBits, C2)
13767                                   : APInt::getLowBitsSet(ElemSizeInBits, C2);
13768   if (C1AsAPInt != RequiredC1)
13769     return SDValue();
13770 
13771   SDValue X = And.getOperand(0);
13772   SDValue Y = ShiftHasPredOp ? Shift.getOperand(1) : Shift.getOperand(0);
13773   SDValue Imm = ShiftHasPredOp ? DAG.getTargetConstant(C2, DL, MVT::i32)
13774                                : Shift.getOperand(1);
13775 
13776   unsigned Inst = IsShiftRight ? AArch64ISD::VSRI : AArch64ISD::VSLI;
13777   SDValue ResultSLI = DAG.getNode(Inst, DL, VT, X, Y, Imm);
13778 
13779   LLVM_DEBUG(dbgs() << "aarch64-lower: transformed: \n");
13780   LLVM_DEBUG(N->dump(&DAG));
13781   LLVM_DEBUG(dbgs() << "into: \n");
13782   LLVM_DEBUG(ResultSLI->dump(&DAG));
13783 
13784   ++NumShiftInserts;
13785   return ResultSLI;
13786 }
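// Example transformation (assumed, illustrative only): for v8i8,
//   (or (and X, splat(0x07)), (AArch64ISD::VSHL Y, #3))
// has C2 == 3 and C1 == ~(0xff << 3) == 0x07, so it is rewritten to
//   (AArch64ISD::VSLI X, Y, #3)
// which selects a single "sli v0.8b, v1.8b, #3" instead of an AND, a shift and
// an ORR.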
13787 
13788 SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op,
13789                                              SelectionDAG &DAG) const {
13790   if (useSVEForFixedLengthVectorVT(Op.getValueType(),
13791                                    !Subtarget->isNeonAvailable()))
13792     return LowerToScalableOp(Op, DAG);
13793 
13794   // Attempt to form a vector S[LR]I from (or (and X, C1), (lsl Y, C2))
13795   if (SDValue Res = tryLowerToSLI(Op.getNode(), DAG))
13796     return Res;
13797 
13798   EVT VT = Op.getValueType();
13799   if (VT.isScalableVector())
13800     return Op;
13801 
13802   SDValue LHS = Op.getOperand(0);
13803   BuildVectorSDNode *BVN =
13804       dyn_cast<BuildVectorSDNode>(Op.getOperand(1).getNode());
13805   if (!BVN) {
13806     // OR commutes, so try swapping the operands.
13807     LHS = Op.getOperand(1);
13808     BVN = dyn_cast<BuildVectorSDNode>(Op.getOperand(0).getNode());
13809   }
13810   if (!BVN)
13811     return Op;
13812 
13813   APInt DefBits(VT.getSizeInBits(), 0);
13814   APInt UndefBits(VT.getSizeInBits(), 0);
13815   if (resolveBuildVector(BVN, DefBits, UndefBits)) {
13816     SDValue NewOp;
13817 
13818     if ((NewOp = tryAdvSIMDModImm32(AArch64ISD::ORRi, Op, DAG,
13819                                     DefBits, &LHS)) ||
13820         (NewOp = tryAdvSIMDModImm16(AArch64ISD::ORRi, Op, DAG,
13821                                     DefBits, &LHS)))
13822       return NewOp;
13823 
13824     if ((NewOp = tryAdvSIMDModImm32(AArch64ISD::ORRi, Op, DAG,
13825                                     UndefBits, &LHS)) ||
13826         (NewOp = tryAdvSIMDModImm16(AArch64ISD::ORRi, Op, DAG,
13827                                     UndefBits, &LHS)))
13828       return NewOp;
13829   }
13830 
13831   // We can always fall back to a non-immediate OR.
13832   return Op;
13833 }
13834 
13835 // Normalize the operands of BUILD_VECTOR. The value of constant operands will
13836 // be truncated to fit element width.
13837 static SDValue NormalizeBuildVector(SDValue Op,
13838                                     SelectionDAG &DAG) {
13839   assert(Op.getOpcode() == ISD::BUILD_VECTOR && "Unknown opcode!");
13840   SDLoc dl(Op);
13841   EVT VT = Op.getValueType();
13842   EVT EltTy = VT.getVectorElementType();
13843 
13844   if (EltTy.isFloatingPoint() || EltTy.getSizeInBits() > 16)
13845     return Op;
13846 
13847   SmallVector<SDValue, 16> Ops;
13848   for (SDValue Lane : Op->ops()) {
13849     // For integer vectors, type legalization would have promoted the
13850     // operands already. Otherwise, if Op is a floating-point splat
13851     // (with operands cast to integers), then the only possibilities
13852     // are constants and UNDEFs.
13853     if (auto *CstLane = dyn_cast<ConstantSDNode>(Lane)) {
13854       APInt LowBits(EltTy.getSizeInBits(),
13855                     CstLane->getZExtValue());
13856       Lane = DAG.getConstant(LowBits.getZExtValue(), dl, MVT::i32);
13857     } else if (Lane.getNode()->isUndef()) {
13858       Lane = DAG.getUNDEF(MVT::i32);
13859     } else {
13860       assert(Lane.getValueType() == MVT::i32 &&
13861              "Unexpected BUILD_VECTOR operand type");
13862     }
13863     Ops.push_back(Lane);
13864   }
13865   return DAG.getBuildVector(VT, dl, Ops);
13866 }
13867 
13868 static SDValue ConstantBuildVector(SDValue Op, SelectionDAG &DAG,
13869                                    const AArch64Subtarget *ST) {
13870   EVT VT = Op.getValueType();
13871   assert((VT.getSizeInBits() == 64 || VT.getSizeInBits() == 128) &&
13872          "Expected a legal NEON vector");
13873 
13874   APInt DefBits(VT.getSizeInBits(), 0);
13875   APInt UndefBits(VT.getSizeInBits(), 0);
13876   BuildVectorSDNode *BVN = cast<BuildVectorSDNode>(Op.getNode());
13877   if (resolveBuildVector(BVN, DefBits, UndefBits)) {
13878     auto TryMOVIWithBits = [&](APInt DefBits) {
13879       SDValue NewOp;
13880       if ((NewOp =
13881                tryAdvSIMDModImm64(AArch64ISD::MOVIedit, Op, DAG, DefBits)) ||
13882           (NewOp =
13883                tryAdvSIMDModImm32(AArch64ISD::MOVIshift, Op, DAG, DefBits)) ||
13884           (NewOp =
13885                tryAdvSIMDModImm321s(AArch64ISD::MOVImsl, Op, DAG, DefBits)) ||
13886           (NewOp =
13887                tryAdvSIMDModImm16(AArch64ISD::MOVIshift, Op, DAG, DefBits)) ||
13888           (NewOp = tryAdvSIMDModImm8(AArch64ISD::MOVI, Op, DAG, DefBits)) ||
13889           (NewOp = tryAdvSIMDModImmFP(AArch64ISD::FMOV, Op, DAG, DefBits)))
13890         return NewOp;
13891 
13892       APInt NotDefBits = ~DefBits;
13893       if ((NewOp = tryAdvSIMDModImm32(AArch64ISD::MVNIshift, Op, DAG,
13894                                       NotDefBits)) ||
13895           (NewOp = tryAdvSIMDModImm321s(AArch64ISD::MVNImsl, Op, DAG,
13896                                         NotDefBits)) ||
13897           (NewOp =
13898                tryAdvSIMDModImm16(AArch64ISD::MVNIshift, Op, DAG, NotDefBits)))
13899         return NewOp;
13900       return SDValue();
13901     };
13902     if (SDValue R = TryMOVIWithBits(DefBits))
13903       return R;
13904     if (SDValue R = TryMOVIWithBits(UndefBits))
13905       return R;
13906 
13907     // See if a fneg of the constant can be materialized with a MOVI, etc
13908     auto TryWithFNeg = [&](APInt DefBits, MVT FVT) {
13909       // FNegate each sub-element of the constant
13910       assert(VT.getSizeInBits() % FVT.getScalarSizeInBits() == 0);
13911       APInt Neg = APInt::getHighBitsSet(FVT.getSizeInBits(), 1)
13912                       .zext(VT.getSizeInBits());
13913       APInt NegBits(VT.getSizeInBits(), 0);
13914       unsigned NumElts = VT.getSizeInBits() / FVT.getScalarSizeInBits();
13915       for (unsigned i = 0; i < NumElts; i++)
13916         NegBits |= Neg << (FVT.getScalarSizeInBits() * i);
13917       NegBits = DefBits ^ NegBits;
13918 
13919       // Try to create the new constants with MOVI, and if so generate a fneg
13920       // for it.
13921       if (SDValue NewOp = TryMOVIWithBits(NegBits)) {
13922         SDLoc DL(Op);
13923         MVT VFVT = NumElts == 1 ? FVT : MVT::getVectorVT(FVT, NumElts);
13924         return DAG.getNode(
13925             AArch64ISD::NVCAST, DL, VT,
13926             DAG.getNode(ISD::FNEG, DL, VFVT,
13927                         DAG.getNode(AArch64ISD::NVCAST, DL, VFVT, NewOp)));
13928       }
13929       return SDValue();
13930     };
13931     SDValue R;
13932     if ((R = TryWithFNeg(DefBits, MVT::f32)) ||
13933         (R = TryWithFNeg(DefBits, MVT::f64)) ||
13934         (ST->hasFullFP16() && (R = TryWithFNeg(DefBits, MVT::f16))))
13935       return R;
13936   }
13937 
13938   return SDValue();
13939 }
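// Example of the FNEG path above (assumed): a v2f64 splat of -0.0 has
// DefBits == 0x8000000000000000 in each lane, which no MOVI/MVNI/FMOV form can
// encode directly; flipping the per-lane sign bit gives all-zero bits, which
// MOVI can build, so the constant is expected to materialize as roughly:
//   movi v0.2d, #0
//   fneg v0.2d, v0.2d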
13940 
13941 SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op,
13942                                                  SelectionDAG &DAG) const {
13943   EVT VT = Op.getValueType();
13944 
13945   if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable())) {
13946     if (auto SeqInfo = cast<BuildVectorSDNode>(Op)->isConstantSequence()) {
13947       SDLoc DL(Op);
13948       EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
13949       SDValue Start = DAG.getConstant(SeqInfo->first, DL, ContainerVT);
13950       SDValue Steps = DAG.getStepVector(DL, ContainerVT, SeqInfo->second);
13951       SDValue Seq = DAG.getNode(ISD::ADD, DL, ContainerVT, Start, Steps);
13952       return convertFromScalableVector(DAG, Op.getValueType(), Seq);
13953     }
13954 
13955     // Revert to common legalisation for all other variants.
13956     return SDValue();
13957   }
13958 
13959   // Try to build a simple constant vector.
13960   Op = NormalizeBuildVector(Op, DAG);
13961   // This might return a non-BUILD_VECTOR (e.g. CONCAT_VECTORS); if so,
13962   // abort.
13963   if (Op.getOpcode() != ISD::BUILD_VECTOR)
13964     return SDValue();
13965 
13966   // Certain vector constants, used to express things like logical NOT and
13967   // arithmetic NEG, are passed through unmodified.  This allows special
13968   // patterns for these operations to match, which will lower these constants
13969   // to whatever is proven necessary.
13970   BuildVectorSDNode *BVN = cast<BuildVectorSDNode>(Op.getNode());
13971   if (BVN->isConstant()) {
13972     if (ConstantSDNode *Const = BVN->getConstantSplatNode()) {
13973       unsigned BitSize = VT.getVectorElementType().getSizeInBits();
13974       APInt Val(BitSize,
13975                 Const->getAPIntValue().zextOrTrunc(BitSize).getZExtValue());
13976       if (Val.isZero() || (VT.isInteger() && Val.isAllOnes()))
13977         return Op;
13978     }
13979     if (ConstantFPSDNode *Const = BVN->getConstantFPSplatNode())
13980       if (Const->isZero() && !Const->isNegative())
13981         return Op;
13982   }
13983 
13984   if (SDValue V = ConstantBuildVector(Op, DAG, Subtarget))
13985     return V;
13986 
13987   // Scan through the operands to find some interesting properties we can
13988   // exploit:
13989   //   1) If only one value is used, we can use a DUP, or
13990   //   2) if only the low element is not undef, we can just insert that, or
13991   //   3) if only one constant value is used (w/ some non-constant lanes),
13992   //      we can splat the constant value into the whole vector then fill
13993   //      in the non-constant lanes.
13994   //   4) FIXME: If different constant values are used, but we can intelligently
13995   //             select the values we'll be overwriting for the non-constant
13996   //             lanes such that we can directly materialize the vector
13997   //             some other way (MOVI, e.g.), we can be sneaky.
13998   //   5) if all operands are EXTRACT_VECTOR_ELT, check for VUZP.
13999   SDLoc dl(Op);
14000   unsigned NumElts = VT.getVectorNumElements();
14001   bool isOnlyLowElement = true;
14002   bool usesOnlyOneValue = true;
14003   bool usesOnlyOneConstantValue = true;
14004   bool isConstant = true;
14005   bool AllLanesExtractElt = true;
14006   unsigned NumConstantLanes = 0;
14007   unsigned NumDifferentLanes = 0;
14008   unsigned NumUndefLanes = 0;
14009   SDValue Value;
14010   SDValue ConstantValue;
14011   SmallMapVector<SDValue, unsigned, 16> DifferentValueMap;
14012   unsigned ConsecutiveValCount = 0;
14013   SDValue PrevVal;
14014   for (unsigned i = 0; i < NumElts; ++i) {
14015     SDValue V = Op.getOperand(i);
14016     if (V.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
14017       AllLanesExtractElt = false;
14018     if (V.isUndef()) {
14019       ++NumUndefLanes;
14020       continue;
14021     }
14022     if (i > 0)
14023       isOnlyLowElement = false;
14024     if (!isIntOrFPConstant(V))
14025       isConstant = false;
14026 
14027     if (isIntOrFPConstant(V)) {
14028       ++NumConstantLanes;
14029       if (!ConstantValue.getNode())
14030         ConstantValue = V;
14031       else if (ConstantValue != V)
14032         usesOnlyOneConstantValue = false;
14033     }
14034 
14035     if (!Value.getNode())
14036       Value = V;
14037     else if (V != Value) {
14038       usesOnlyOneValue = false;
14039       ++NumDifferentLanes;
14040     }
14041 
14042     if (PrevVal != V) {
14043       ConsecutiveValCount = 0;
14044       PrevVal = V;
14045     }
14046 
14047     // Keep each different value and its last consecutive count. For example,
14048     //
14049     //  t22: v16i8 = build_vector t23, t23, t23, t23, t23, t23, t23, t23,
14050     //                            t24, t24, t24, t24, t24, t24, t24, t24
14051     //  t23 = consecutive count 8
14052     //  t24 = consecutive count 8
14053     // ------------------------------------------------------------------
14054     //  t22: v16i8 = build_vector t24, t24, t23, t23, t23, t23, t23, t24,
14055     //                            t24, t24, t24, t24, t24, t24, t24, t24
14056     //  t23 = consecutive count 5
14057     //  t24 = consecutive count 9
14058     DifferentValueMap[V] = ++ConsecutiveValCount;
14059   }
14060 
14061   if (!Value.getNode()) {
14062     LLVM_DEBUG(
14063         dbgs() << "LowerBUILD_VECTOR: value undefined, creating undef node\n");
14064     return DAG.getUNDEF(VT);
14065   }
14066 
14067   // Convert BUILD_VECTOR where all elements but the lowest are undef into
14068   // SCALAR_TO_VECTOR, except for when we have a single-element constant vector
14069   // as SimplifyDemandedBits will just turn that back into BUILD_VECTOR.
14070   if (isOnlyLowElement && !(NumElts == 1 && isIntOrFPConstant(Value))) {
14071     LLVM_DEBUG(dbgs() << "LowerBUILD_VECTOR: only low element used, creating 1 "
14072                          "SCALAR_TO_VECTOR node\n");
14073     return DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, Value);
14074   }
14075 
14076   if (AllLanesExtractElt) {
14077     SDNode *Vector = nullptr;
14078     bool Even = false;
14079     bool Odd = false;
14080     // Check whether the extract elements match the Even pattern <0,2,4,...> or
14081     // the Odd pattern <1,3,5,...>.
14082     for (unsigned i = 0; i < NumElts; ++i) {
14083       SDValue V = Op.getOperand(i);
14084       const SDNode *N = V.getNode();
14085       if (!isa<ConstantSDNode>(N->getOperand(1))) {
14086         Even = false;
14087         Odd = false;
14088         break;
14089       }
14090       SDValue N0 = N->getOperand(0);
14091 
14092       // All elements are extracted from the same vector.
14093       if (!Vector) {
14094         Vector = N0.getNode();
14095         // Check that the type of EXTRACT_VECTOR_ELT matches the type of
14096         // BUILD_VECTOR.
14097         if (VT.getVectorElementType() !=
14098             N0.getValueType().getVectorElementType())
14099           break;
14100       } else if (Vector != N0.getNode()) {
14101         Odd = false;
14102         Even = false;
14103         break;
14104       }
14105 
14106       // Extracted values are either at Even indices <0,2,4,...> or at Odd
14107       // indices <1,3,5,...>.
14108       uint64_t Val = N->getConstantOperandVal(1);
14109       if (Val == 2 * i) {
14110         Even = true;
14111         continue;
14112       }
14113       if (Val - 1 == 2 * i) {
14114         Odd = true;
14115         continue;
14116       }
14117 
14118       // Something does not match: abort.
14119       Odd = false;
14120       Even = false;
14121       break;
14122     }
14123     if (Even || Odd) {
14124       SDValue LHS =
14125           DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, SDValue(Vector, 0),
14126                       DAG.getConstant(0, dl, MVT::i64));
14127       SDValue RHS =
14128           DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, SDValue(Vector, 0),
14129                       DAG.getConstant(NumElts, dl, MVT::i64));
14130 
14131       if (Even && !Odd)
14132         return DAG.getNode(AArch64ISD::UZP1, dl, VT, LHS, RHS);
14133       if (Odd && !Even)
14134         return DAG.getNode(AArch64ISD::UZP2, dl, VT, LHS, RHS);
14135     }
14136   }
14137 
14138   // Use DUP for non-constant splats. For f32 constant splats, reduce to
14139   // i32 and try again.
14140   if (usesOnlyOneValue) {
14141     if (!isConstant) {
14142       if (Value.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
14143           Value.getValueType() != VT) {
14144         LLVM_DEBUG(
14145             dbgs() << "LowerBUILD_VECTOR: use DUP for non-constant splats\n");
14146         return DAG.getNode(AArch64ISD::DUP, dl, VT, Value);
14147       }
14148 
14149       // This is actually a DUPLANExx operation, which keeps everything vectory.
14150 
14151       SDValue Lane = Value.getOperand(1);
14152       Value = Value.getOperand(0);
14153       if (Value.getValueSizeInBits() == 64) {
14154         LLVM_DEBUG(
14155             dbgs() << "LowerBUILD_VECTOR: DUPLANE works on 128-bit vectors, "
14156                       "widening it\n");
14157         Value = WidenVector(Value, DAG);
14158       }
14159 
14160       unsigned Opcode = getDUPLANEOp(VT.getVectorElementType());
14161       return DAG.getNode(Opcode, dl, VT, Value, Lane);
14162     }
14163 
14164     if (VT.getVectorElementType().isFloatingPoint()) {
14165       SmallVector<SDValue, 8> Ops;
14166       EVT EltTy = VT.getVectorElementType();
14167       assert ((EltTy == MVT::f16 || EltTy == MVT::bf16 || EltTy == MVT::f32 ||
14168                EltTy == MVT::f64) && "Unsupported floating-point vector type");
14169       LLVM_DEBUG(
14170           dbgs() << "LowerBUILD_VECTOR: float constant splats, creating int "
14171                     "BITCASTS, and try again\n");
14172       MVT NewType = MVT::getIntegerVT(EltTy.getSizeInBits());
14173       for (unsigned i = 0; i < NumElts; ++i)
14174         Ops.push_back(DAG.getNode(ISD::BITCAST, dl, NewType, Op.getOperand(i)));
14175       EVT VecVT = EVT::getVectorVT(*DAG.getContext(), NewType, NumElts);
14176       SDValue Val = DAG.getBuildVector(VecVT, dl, Ops);
14177       LLVM_DEBUG(dbgs() << "LowerBUILD_VECTOR: trying to lower new vector: ";
14178                  Val.dump(););
14179       Val = LowerBUILD_VECTOR(Val, DAG);
14180       if (Val.getNode())
14181         return DAG.getNode(ISD::BITCAST, dl, VT, Val);
14182     }
14183   }
14184 
14185   // If we need to insert a small number of different non-constant elements and
14186   // the vector width is sufficiently large, prefer using DUP with the common
14187   // value and INSERT_VECTOR_ELT for the different lanes. If DUP is preferred,
14188   // skip the constant lane handling below.
14189   bool PreferDUPAndInsert =
14190       !isConstant && NumDifferentLanes >= 1 &&
14191       NumDifferentLanes < ((NumElts - NumUndefLanes) / 2) &&
14192       NumDifferentLanes >= NumConstantLanes;
14193 
14194   // If there was only one constant value used and for more than one lane,
14195   // start by splatting that value, then replace the non-constant lanes. This
14196   // is better than the default, which will perform a separate initialization
14197   // for each lane.
14198   if (!PreferDUPAndInsert && NumConstantLanes > 0 && usesOnlyOneConstantValue) {
14199     // Firstly, try to materialize the splat constant.
14200     SDValue Val = DAG.getSplatBuildVector(VT, dl, ConstantValue);
14201     unsigned BitSize = VT.getScalarSizeInBits();
14202     APInt ConstantValueAPInt(1, 0);
14203     if (auto *C = dyn_cast<ConstantSDNode>(ConstantValue))
14204       ConstantValueAPInt = C->getAPIntValue().zextOrTrunc(BitSize);
14205     if (!isNullConstant(ConstantValue) && !isNullFPConstant(ConstantValue) &&
14206         !ConstantValueAPInt.isAllOnes()) {
14207       Val = ConstantBuildVector(Val, DAG, Subtarget);
14208       if (!Val)
14209         // Otherwise, materialize the constant and splat it.
14210         Val = DAG.getNode(AArch64ISD::DUP, dl, VT, ConstantValue);
14211     }
14212 
14213     // Now insert the non-constant lanes.
14214     for (unsigned i = 0; i < NumElts; ++i) {
14215       SDValue V = Op.getOperand(i);
14216       SDValue LaneIdx = DAG.getConstant(i, dl, MVT::i64);
14217       if (!isIntOrFPConstant(V))
14218         // Note that type legalization likely mucked about with the VT of the
14219         // source operand, so we may have to convert it here before inserting.
14220         Val = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, Val, V, LaneIdx);
14221     }
14222     return Val;
14223   }
14224 
14225   // This will generate a load from the constant pool.
14226   if (isConstant) {
14227     LLVM_DEBUG(
14228         dbgs() << "LowerBUILD_VECTOR: all elements are constant, use default "
14229                   "expansion\n");
14230     return SDValue();
14231   }
14232 
14233   // Detect patterns of a0,a1,a2,a3,b0,b1,b2,b3,c0,c1,c2,c3,d0,d1,d2,d3 from
14234   // v4i32s. This is really a truncate, which we can construct out of (legal)
14235   // concats and truncate nodes.
14236   if (SDValue M = ReconstructTruncateFromBuildVector(Op, DAG))
14237     return M;
14238 
14239   // Empirical tests suggest this is rarely worth it for vectors of length <= 2.
14240   if (NumElts >= 4) {
14241     if (SDValue Shuffle = ReconstructShuffle(Op, DAG))
14242       return Shuffle;
14243 
14244     if (SDValue Shuffle = ReconstructShuffleWithRuntimeMask(Op, DAG))
14245       return Shuffle;
14246   }
14247 
14248   if (PreferDUPAndInsert) {
14249     // First, build a constant vector with the common element.
14250     SmallVector<SDValue, 8> Ops(NumElts, Value);
14251     SDValue NewVector = LowerBUILD_VECTOR(DAG.getBuildVector(VT, dl, Ops), DAG);
14252     // Next, insert the elements that do not match the common value.
14253     for (unsigned I = 0; I < NumElts; ++I)
14254       if (Op.getOperand(I) != Value)
14255         NewVector =
14256             DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, NewVector,
14257                         Op.getOperand(I), DAG.getConstant(I, dl, MVT::i64));
14258 
14259     return NewVector;
14260   }
14261 
14262   // If vector consists of two different values, try to generate two DUPs and
14263   // (CONCAT_VECTORS or VECTOR_SHUFFLE).
14264   if (DifferentValueMap.size() == 2 && NumUndefLanes == 0) {
14265     SmallVector<SDValue, 2> Vals;
14266     // Check the consecutive count of the value is the half number of vector
14267     // elements. In this case, we can use CONCAT_VECTORS. For example,
14268     //
14269     // canUseVECTOR_CONCAT = true;
14270     //  t22: v16i8 = build_vector t23, t23, t23, t23, t23, t23, t23, t23,
14271     //                            t24, t24, t24, t24, t24, t24, t24, t24
14272     //
14273     // canUseVECTOR_CONCAT = false;
14274     //  t22: v16i8 = build_vector t23, t23, t23, t23, t23, t24, t24, t24,
14275     //                            t24, t24, t24, t24, t24, t24, t24, t24
14276     bool canUseVECTOR_CONCAT = true;
14277     for (auto Pair : DifferentValueMap) {
14278       // Check different values have same length which is NumElts / 2.
14279       if (Pair.second != NumElts / 2)
14280         canUseVECTOR_CONCAT = false;
14281       Vals.push_back(Pair.first);
14282     }
14283 
14284     // If canUseVECTOR_CONCAT is true, we can generate two DUPs and
14285     // CONCAT_VECTORs. For example,
14286     //
14287     //  t22: v16i8 = BUILD_VECTOR t23, t23, t23, t23, t23, t23, t23, t23,
14288     //                            t24, t24, t24, t24, t24, t24, t24, t24
14289     // ==>
14290     //    t26: v8i8 = AArch64ISD::DUP t23
14291     //    t28: v8i8 = AArch64ISD::DUP t24
14292     //  t29: v16i8 = concat_vectors t26, t28
14293     if (canUseVECTOR_CONCAT) {
14294       EVT SubVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
14295       if (isTypeLegal(SubVT) && SubVT.isVector() &&
14296           SubVT.getVectorNumElements() >= 2) {
14297         SmallVector<SDValue, 8> Ops1(NumElts / 2, Vals[0]);
14298         SmallVector<SDValue, 8> Ops2(NumElts / 2, Vals[1]);
14299         SDValue DUP1 =
14300             LowerBUILD_VECTOR(DAG.getBuildVector(SubVT, dl, Ops1), DAG);
14301         SDValue DUP2 =
14302             LowerBUILD_VECTOR(DAG.getBuildVector(SubVT, dl, Ops2), DAG);
14303         SDValue CONCAT_VECTORS =
14304             DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, DUP1, DUP2);
14305         return CONCAT_VECTORS;
14306       }
14307     }
14308 
14309     // Let's try to generate VECTOR_SHUFFLE. For example,
14310     //
14311     //  t24: v8i8 = BUILD_VECTOR t25, t25, t25, t25, t26, t26, t26, t26
14312     //  ==>
14313     //    t27: v8i8 = BUILD_VECTOR t26, t26, t26, t26, t26, t26, t26, t26
14314     //    t28: v8i8 = BUILD_VECTOR t25, t25, t25, t25, t25, t25, t25, t25
14315     //  t29: v8i8 = vector_shuffle<0,1,2,3,12,13,14,15> t27, t28
14316     if (NumElts >= 8) {
14317       SmallVector<int, 16> MaskVec;
14318       // Build the mask for VECTOR_SHUFFLE.
14319       SDValue FirstLaneVal = Op.getOperand(0);
14320       for (unsigned i = 0; i < NumElts; ++i) {
14321         SDValue Val = Op.getOperand(i);
14322         if (FirstLaneVal == Val)
14323           MaskVec.push_back(i);
14324         else
14325           MaskVec.push_back(i + NumElts);
14326       }
14327 
14328       SmallVector<SDValue, 8> Ops1(NumElts, Vals[0]);
14329       SmallVector<SDValue, 8> Ops2(NumElts, Vals[1]);
14330       SDValue VEC1 = DAG.getBuildVector(VT, dl, Ops1);
14331       SDValue VEC2 = DAG.getBuildVector(VT, dl, Ops2);
14332       SDValue VECTOR_SHUFFLE =
14333           DAG.getVectorShuffle(VT, dl, VEC1, VEC2, MaskVec);
14334       return VECTOR_SHUFFLE;
14335     }
14336   }
14337 
14338   // If all else fails, just use a sequence of INSERT_VECTOR_ELT when we
14339   // know the default expansion would otherwise fall back on something even
14340   // worse. For a vector with one or two non-undef values, that's
14341   // scalar_to_vector for the elements followed by a shuffle (provided the
14342   // shuffle is valid for the target) and materialization element by element
14343   // on the stack followed by a load for everything else.
14344   if (!isConstant && !usesOnlyOneValue) {
14345     LLVM_DEBUG(
14346         dbgs() << "LowerBUILD_VECTOR: alternatives failed, creating sequence "
14347                   "of INSERT_VECTOR_ELT\n");
14348 
14349     SDValue Vec = DAG.getUNDEF(VT);
14350     SDValue Op0 = Op.getOperand(0);
14351     unsigned i = 0;
14352 
14353     // Use SCALAR_TO_VECTOR for lane zero to
14354     // a) Avoid a RMW dependency on the full vector register, and
14355     // b) Allow the register coalescer to fold away the copy if the
14356     //    value is already in an S or D register, and we're forced to emit an
14357     //    INSERT_SUBREG that we can't fold anywhere.
14358     //
14359     // We also allow types like i8 and i16 which are illegal scalar but legal
14360     // vector element types. After type-legalization the inserted value is
14361     // extended (i32) and it is safe to cast them to the vector type by ignoring
14362     // the upper bits of the lowest lane (e.g. v8i8, v4i16).
14363     if (!Op0.isUndef()) {
14364       LLVM_DEBUG(dbgs() << "Creating node for op0, it is not undefined:\n");
14365       Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, Op0);
14366       ++i;
14367     }
14368     LLVM_DEBUG(if (i < NumElts) dbgs()
14369                    << "Creating nodes for the other vector elements:\n";);
14370     for (; i < NumElts; ++i) {
14371       SDValue V = Op.getOperand(i);
14372       if (V.isUndef())
14373         continue;
14374       SDValue LaneIdx = DAG.getConstant(i, dl, MVT::i64);
14375       Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, Vec, V, LaneIdx);
14376     }
14377     return Vec;
14378   }
14379 
14380   LLVM_DEBUG(
14381       dbgs() << "LowerBUILD_VECTOR: use default expansion, failed to find "
14382                 "better alternative\n");
14383   return SDValue();
14384 }
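// Worked example of the PreferDUPAndInsert path above (assumed): a v4f32
// BUILD_VECTOR of the form (a, a, a, b) has one "different" lane out of four,
// so it is expected to be built as a splat plus a single lane insert, roughly:
//   dup  v0.4s, v1.s[0]
//   mov  v0.s[3], v2.s[0]
// rather than four independent lane insertions.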
14385 
14386 SDValue AArch64TargetLowering::LowerCONCAT_VECTORS(SDValue Op,
14387                                                    SelectionDAG &DAG) const {
14388   if (useSVEForFixedLengthVectorVT(Op.getValueType(),
14389                                    !Subtarget->isNeonAvailable()))
14390     return LowerFixedLengthConcatVectorsToSVE(Op, DAG);
14391 
14392   assert(Op.getValueType().isScalableVector() &&
14393          isTypeLegal(Op.getValueType()) &&
14394          "Expected legal scalable vector type!");
14395 
14396   if (isTypeLegal(Op.getOperand(0).getValueType())) {
14397     unsigned NumOperands = Op->getNumOperands();
14398     assert(NumOperands > 1 && isPowerOf2_32(NumOperands) &&
14399            "Unexpected number of operands in CONCAT_VECTORS");
14400 
14401     if (NumOperands == 2)
14402       return Op;
14403 
14404     // Concat each pair of subvectors and pack into the lower half of the array.
14405     SmallVector<SDValue> ConcatOps(Op->op_begin(), Op->op_end());
14406     while (ConcatOps.size() > 1) {
14407       for (unsigned I = 0, E = ConcatOps.size(); I != E; I += 2) {
14408         SDValue V1 = ConcatOps[I];
14409         SDValue V2 = ConcatOps[I + 1];
14410         EVT SubVT = V1.getValueType();
14411         EVT PairVT = SubVT.getDoubleNumVectorElementsVT(*DAG.getContext());
14412         ConcatOps[I / 2] =
14413             DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(Op), PairVT, V1, V2);
14414       }
14415       ConcatOps.resize(ConcatOps.size() / 2);
14416     }
14417     return ConcatOps[0];
14418   }
14419 
14420   return SDValue();
14421 }
14422 
14423 SDValue AArch64TargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
14424                                                       SelectionDAG &DAG) const {
14425   assert(Op.getOpcode() == ISD::INSERT_VECTOR_ELT && "Unknown opcode!");
14426 
14427   if (useSVEForFixedLengthVectorVT(Op.getValueType(),
14428                                    !Subtarget->isNeonAvailable()))
14429     return LowerFixedLengthInsertVectorElt(Op, DAG);
14430 
14431   EVT VT = Op.getOperand(0).getValueType();
14432 
14433   if (VT.getScalarType() == MVT::i1) {
14434     EVT VectorVT = getPromotedVTForPredicate(VT);
14435     SDLoc DL(Op);
14436     SDValue ExtendedVector =
14437         DAG.getAnyExtOrTrunc(Op.getOperand(0), DL, VectorVT);
14438     SDValue ExtendedValue =
14439         DAG.getAnyExtOrTrunc(Op.getOperand(1), DL,
14440                              VectorVT.getScalarType().getSizeInBits() < 32
14441                                  ? MVT::i32
14442                                  : VectorVT.getScalarType());
14443     ExtendedVector =
14444         DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VectorVT, ExtendedVector,
14445                     ExtendedValue, Op.getOperand(2));
14446     return DAG.getAnyExtOrTrunc(ExtendedVector, DL, VT);
14447   }
14448 
14449   // Check for non-constant or out of range lane.
14450   ConstantSDNode *CI = dyn_cast<ConstantSDNode>(Op.getOperand(2));
14451   if (!CI || CI->getZExtValue() >= VT.getVectorNumElements())
14452     return SDValue();
14453 
14454   return Op;
14455 }
14456 
14457 SDValue
14458 AArch64TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
14459                                                SelectionDAG &DAG) const {
14460   assert(Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT && "Unknown opcode!");
14461   EVT VT = Op.getOperand(0).getValueType();
14462 
14463   if (VT.getScalarType() == MVT::i1) {
14464     // We can't directly extract from an SVE predicate; extend it first.
14465     // (This isn't the only possible lowering, but it's straightforward.)
14466     EVT VectorVT = getPromotedVTForPredicate(VT);
14467     SDLoc DL(Op);
14468     SDValue Extend =
14469         DAG.getNode(ISD::ANY_EXTEND, DL, VectorVT, Op.getOperand(0));
14470     MVT ExtractTy = VectorVT == MVT::nxv2i64 ? MVT::i64 : MVT::i32;
14471     SDValue Extract = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtractTy,
14472                                   Extend, Op.getOperand(1));
14473     return DAG.getAnyExtOrTrunc(Extract, DL, Op.getValueType());
14474   }
14475 
14476   if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
14477     return LowerFixedLengthExtractVectorElt(Op, DAG);
14478 
14479   // Check for non-constant or out of range lane.
14480   ConstantSDNode *CI = dyn_cast<ConstantSDNode>(Op.getOperand(1));
14481   if (!CI || CI->getZExtValue() >= VT.getVectorNumElements())
14482     return SDValue();
14483 
14484   // Insertion/extraction are legal for V128 types.
14485   if (VT == MVT::v16i8 || VT == MVT::v8i16 || VT == MVT::v4i32 ||
14486       VT == MVT::v2i64 || VT == MVT::v4f32 || VT == MVT::v2f64 ||
14487       VT == MVT::v8f16 || VT == MVT::v8bf16)
14488     return Op;
14489 
14490   if (VT != MVT::v8i8 && VT != MVT::v4i16 && VT != MVT::v2i32 &&
14491       VT != MVT::v1i64 && VT != MVT::v2f32 && VT != MVT::v4f16 &&
14492       VT != MVT::v4bf16)
14493     return SDValue();
14494 
14495   // For V64 types, we perform extraction by expanding the value
14496   // to a V128 type and performing the extraction on that.
14497   SDLoc DL(Op);
14498   SDValue WideVec = WidenVector(Op.getOperand(0), DAG);
14499   EVT WideTy = WideVec.getValueType();
14500 
14501   EVT ExtrTy = WideTy.getVectorElementType();
14502   if (ExtrTy == MVT::i16 || ExtrTy == MVT::i8)
14503     ExtrTy = MVT::i32;
14504 
14505   // For extractions, we just return the result directly.
14506   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtrTy, WideVec,
14507                      Op.getOperand(1));
14508 }
14509 
14510 SDValue AArch64TargetLowering::LowerEXTRACT_SUBVECTOR(SDValue Op,
14511                                                       SelectionDAG &DAG) const {
14512   EVT VT = Op.getValueType();
14513   assert(VT.isFixedLengthVector() &&
14514          "Only cases that extract a fixed length vector are supported!");
14515   EVT InVT = Op.getOperand(0).getValueType();
14516 
14517   // If we don't have legal types yet, do nothing
14518   if (!isTypeLegal(InVT))
14519     return SDValue();
14520 
14521   if (InVT.is128BitVector()) {
14522     assert(VT.is64BitVector() && "Extracting unexpected vector type!");
14523     unsigned Idx = Op.getConstantOperandVal(1);
14524 
14525     // This will get lowered to an appropriate EXTRACT_SUBREG in ISel.
14526     if (Idx == 0)
14527       return Op;
14528 
14529     // If this is extracting the upper 64-bits of a 128-bit vector, we match
14530     // that directly.
14531     if (Idx * InVT.getScalarSizeInBits() == 64 && Subtarget->isNeonAvailable())
14532       return Op;
14533   }
14534 
14535   if (InVT.isScalableVector() ||
14536       useSVEForFixedLengthVectorVT(InVT, !Subtarget->isNeonAvailable())) {
14537     SDLoc DL(Op);
14538     SDValue Vec = Op.getOperand(0);
14539     SDValue Idx = Op.getOperand(1);
14540 
14541     EVT PackedVT = getPackedSVEVectorVT(InVT.getVectorElementType());
14542     if (PackedVT != InVT) {
14543       // Pack input into the bottom part of an SVE register and try again.
14544       SDValue Container = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, PackedVT,
14545                                       DAG.getUNDEF(PackedVT), Vec,
14546                                       DAG.getVectorIdxConstant(0, DL));
14547       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Container, Idx);
14548     }
14549 
14550     // This will get matched by custom code during ISelDAGToDAG.
14551     if (isNullConstant(Idx))
14552       return Op;
14553 
14554     assert(InVT.isScalableVector() && "Unexpected vector type!");
14555     // Move requested subvector to the start of the vector and try again.
14556     SDValue Splice = DAG.getNode(ISD::VECTOR_SPLICE, DL, InVT, Vec, Vec, Idx);
14557     return convertFromScalableVector(DAG, VT, Splice);
14558   }
14559 
14560   return SDValue();
14561 }
14562 
14563 SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
14564                                                      SelectionDAG &DAG) const {
14565   assert(Op.getValueType().isScalableVector() &&
14566          "Only expect to lower inserts into scalable vectors!");
14567 
14568   EVT InVT = Op.getOperand(1).getValueType();
14569   unsigned Idx = Op.getConstantOperandVal(2);
14570 
14571   SDValue Vec0 = Op.getOperand(0);
14572   SDValue Vec1 = Op.getOperand(1);
14573   SDLoc DL(Op);
14574   EVT VT = Op.getValueType();
14575 
14576   if (InVT.isScalableVector()) {
14577     if (!isTypeLegal(VT))
14578       return SDValue();
14579 
14580     // Break down insert_subvector into simpler parts.
14581     if (VT.getVectorElementType() == MVT::i1) {
14582       unsigned NumElts = VT.getVectorMinNumElements();
14583       EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
14584 
14585       SDValue Lo, Hi;
14586       Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, Vec0,
14587                        DAG.getVectorIdxConstant(0, DL));
14588       Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, Vec0,
14589                        DAG.getVectorIdxConstant(NumElts / 2, DL));
14590       if (Idx < (NumElts / 2))
14591         Lo = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, HalfVT, Lo, Vec1,
14592                          DAG.getVectorIdxConstant(Idx, DL));
14593       else
14594         Hi = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, HalfVT, Hi, Vec1,
14595                          DAG.getVectorIdxConstant(Idx - (NumElts / 2), DL));
14596 
14597       return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
14598     }
14599 
14600     // Ensure the subvector is half the size of the main vector.
14601     if (VT.getVectorElementCount() != (InVT.getVectorElementCount() * 2))
14602       return SDValue();
14603 
14604     // Here narrow and wide refer to the vector element types. After "casting"
14605     // both vectors must have the same bit length and so because the subvector
14606     // has fewer elements, those elements need to be bigger.
14607     EVT NarrowVT = getPackedSVEVectorVT(VT.getVectorElementCount());
14608     EVT WideVT = getPackedSVEVectorVT(InVT.getVectorElementCount());
14609 
14610     // NOP cast operands to the largest legal vector of the same element count.
14611     if (VT.isFloatingPoint()) {
14612       Vec0 = getSVESafeBitCast(NarrowVT, Vec0, DAG);
14613       Vec1 = getSVESafeBitCast(WideVT, Vec1, DAG);
14614     } else {
14615       // Legal integer vectors are already their largest so Vec0 is fine as is.
14616       Vec1 = DAG.getNode(ISD::ANY_EXTEND, DL, WideVT, Vec1);
14617     }
14618 
14619     // To replace the top/bottom half of vector V with vector SubV we widen the
14620     // preserved half of V, concatenate this to SubV (the order depending on the
14621     // half being replaced) and then narrow the result.
14622     SDValue Narrow;
14623     if (Idx == 0) {
14624       SDValue HiVec0 = DAG.getNode(AArch64ISD::UUNPKHI, DL, WideVT, Vec0);
14625       Narrow = DAG.getNode(AArch64ISD::UZP1, DL, NarrowVT, Vec1, HiVec0);
14626     } else {
14627       assert(Idx == InVT.getVectorMinNumElements() &&
14628              "Invalid subvector index!");
14629       SDValue LoVec0 = DAG.getNode(AArch64ISD::UUNPKLO, DL, WideVT, Vec0);
14630       Narrow = DAG.getNode(AArch64ISD::UZP1, DL, NarrowVT, LoVec0, Vec1);
14631     }
14632 
14633     return getSVESafeBitCast(VT, Narrow, DAG);
14634   }
14635 
14636   if (Idx == 0 && isPackedVectorType(VT, DAG)) {
14637     // This will be matched by custom code during ISelDAGToDAG.
14638     if (Vec0.isUndef())
14639       return Op;
14640 
14641     std::optional<unsigned> PredPattern =
14642         getSVEPredPatternFromNumElements(InVT.getVectorNumElements());
14643     auto PredTy = VT.changeVectorElementType(MVT::i1);
14644     SDValue PTrue = getPTrue(DAG, DL, PredTy, *PredPattern);
14645     SDValue ScalableVec1 = convertToScalableVector(DAG, VT, Vec1);
14646     return DAG.getNode(ISD::VSELECT, DL, VT, PTrue, ScalableVec1, Vec0);
14647   }
14648 
14649   return SDValue();
14650 }
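// Example of the unpack/zip trick above (assumed): inserting an nxv2f32
// subvector into the low half (Idx == 0) of an nxv4f32 vector preserves the
// top half of Vec0 by widening it and interleaving it back, roughly:
//   uunpkhi z2.d, z0.s        // preserved top half of Vec0, widened
//   uzp1    z0.s, z1.s, z2.s  // Vec1 in the low half, preserved data on top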
14651 
14652 static bool isPow2Splat(SDValue Op, uint64_t &SplatVal, bool &Negated) {
14653   if (Op.getOpcode() != AArch64ISD::DUP &&
14654       Op.getOpcode() != ISD::SPLAT_VECTOR &&
14655       Op.getOpcode() != ISD::BUILD_VECTOR)
14656     return false;
14657 
14658   if (Op.getOpcode() == ISD::BUILD_VECTOR &&
14659       !isAllConstantBuildVector(Op, SplatVal))
14660     return false;
14661 
14662   if (Op.getOpcode() != ISD::BUILD_VECTOR &&
14663       !isa<ConstantSDNode>(Op->getOperand(0)))
14664     return false;
14665 
14666   SplatVal = Op->getConstantOperandVal(0);
14667   if (Op.getValueType().getVectorElementType() != MVT::i64)
14668     SplatVal = (int32_t)SplatVal;
14669 
14670   Negated = false;
14671   if (isPowerOf2_64(SplatVal))
14672     return true;
14673 
14674   Negated = true;
14675   if (isPowerOf2_64(-SplatVal)) {
14676     SplatVal = -SplatVal;
14677     return true;
14678   }
14679 
14680   return false;
14681 }
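// Worked example (illustrative): a splat of 16 yields SplatVal = 16 and
// Negated = false, while a splat of -16 yields SplatVal = 16 and
// Negated = true. For element types narrower than i64 the constant is first
// sign-extended from its low 32 bits, as done above.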
14682 
14683 SDValue AArch64TargetLowering::LowerDIV(SDValue Op, SelectionDAG &DAG) const {
14684   EVT VT = Op.getValueType();
14685   SDLoc dl(Op);
14686 
14687   if (useSVEForFixedLengthVectorVT(VT, /*OverrideNEON=*/true))
14688     return LowerFixedLengthVectorIntDivideToSVE(Op, DAG);
14689 
14690   assert(VT.isScalableVector() && "Expected a scalable vector.");
14691 
14692   bool Signed = Op.getOpcode() == ISD::SDIV;
14693   unsigned PredOpcode = Signed ? AArch64ISD::SDIV_PRED : AArch64ISD::UDIV_PRED;
14694 
14695   bool Negated;
14696   uint64_t SplatVal;
14697   if (Signed && isPow2Splat(Op.getOperand(1), SplatVal, Negated)) {
14698     SDValue Pg = getPredicateForScalableVector(DAG, dl, VT);
14699     SDValue Res =
14700         DAG.getNode(AArch64ISD::SRAD_MERGE_OP1, dl, VT, Pg, Op->getOperand(0),
14701                     DAG.getTargetConstant(Log2_64(SplatVal), dl, MVT::i32));
14702     if (Negated)
14703       Res = DAG.getNode(ISD::SUB, dl, VT, DAG.getConstant(0, dl, VT), Res);
14704 
14705     return Res;
14706   }
14707 
14708   if (VT == MVT::nxv4i32 || VT == MVT::nxv2i64)
14709     return LowerToPredicatedOp(Op, DAG, PredOpcode);
14710 
14711   // SVE doesn't have i8 and i16 DIV operations; widen them to 32-bit
14712   // operations, and truncate the result.
14713   EVT WidenedVT;
14714   if (VT == MVT::nxv16i8)
14715     WidenedVT = MVT::nxv8i16;
14716   else if (VT == MVT::nxv8i16)
14717     WidenedVT = MVT::nxv4i32;
14718   else
14719     llvm_unreachable("Unexpected Custom DIV operation");
14720 
14721   unsigned UnpkLo = Signed ? AArch64ISD::SUNPKLO : AArch64ISD::UUNPKLO;
14722   unsigned UnpkHi = Signed ? AArch64ISD::SUNPKHI : AArch64ISD::UUNPKHI;
14723   SDValue Op0Lo = DAG.getNode(UnpkLo, dl, WidenedVT, Op.getOperand(0));
14724   SDValue Op1Lo = DAG.getNode(UnpkLo, dl, WidenedVT, Op.getOperand(1));
14725   SDValue Op0Hi = DAG.getNode(UnpkHi, dl, WidenedVT, Op.getOperand(0));
14726   SDValue Op1Hi = DAG.getNode(UnpkHi, dl, WidenedVT, Op.getOperand(1));
14727   SDValue ResultLo = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0Lo, Op1Lo);
14728   SDValue ResultHi = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0Hi, Op1Hi);
14729   return DAG.getNode(AArch64ISD::UZP1, dl, VT, ResultLo, ResultHi);
14730 }
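// Rough illustration of the widening path above: an sdiv of nxv16i8 is
// unpacked into two nxv8i16 halves, each of which is custom-lowered again
// into two nxv4i32 divides, and the narrow results are re-packed with UZP1.
// A signed divide by a power-of-two splat instead becomes a single
// SRAD_MERGE_OP1, followed by a negate when the divisor was negative.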
14731 
14732 bool AArch64TargetLowering::shouldExpandBuildVectorWithShuffles(
14733     EVT VT, unsigned DefinedValues) const {
14734   if (!Subtarget->isNeonAvailable())
14735     return false;
14736   return TargetLowering::shouldExpandBuildVectorWithShuffles(VT, DefinedValues);
14737 }
14738 
14739 bool AArch64TargetLowering::isShuffleMaskLegal(ArrayRef<int> M, EVT VT) const {
14740   // Currently no fixed length shuffles that require SVE are legal.
14741   if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
14742     return false;
14743 
14744   if (VT.getVectorNumElements() == 4 &&
14745       (VT.is128BitVector() || VT.is64BitVector())) {
14746     unsigned Cost = getPerfectShuffleCost(M);
14747     if (Cost <= 1)
14748       return true;
14749   }
14750 
14751   bool DummyBool;
14752   int DummyInt;
14753   unsigned DummyUnsigned;
14754 
14755   unsigned EltSize = VT.getScalarSizeInBits();
14756   unsigned NumElts = VT.getVectorNumElements();
14757   return (ShuffleVectorSDNode::isSplatMask(&M[0], VT) ||
14758           isREVMask(M, EltSize, NumElts, 64) ||
14759           isREVMask(M, EltSize, NumElts, 32) ||
14760           isREVMask(M, EltSize, NumElts, 16) ||
14761           isEXTMask(M, VT, DummyBool, DummyUnsigned) ||
14762           isTRNMask(M, NumElts, DummyUnsigned) ||
14763           isUZPMask(M, NumElts, DummyUnsigned) ||
14764           isZIPMask(M, NumElts, DummyUnsigned) ||
14765           isTRN_v_undef_Mask(M, VT, DummyUnsigned) ||
14766           isUZP_v_undef_Mask(M, VT, DummyUnsigned) ||
14767           isZIP_v_undef_Mask(M, VT, DummyUnsigned) ||
14768           isINSMask(M, NumElts, DummyBool, DummyInt) ||
14769           isConcatMask(M, VT, VT.getSizeInBits() == 128));
14770 }
14771 
14772 bool AArch64TargetLowering::isVectorClearMaskLegal(ArrayRef<int> M,
14773                                                    EVT VT) const {
14774   // Just delegate to the generic legality, clear masks aren't special.
14775   return isShuffleMaskLegal(M, VT);
14776 }
14777 
14778 /// getVShiftImm - Check if this is a valid build_vector for the immediate
14779 /// operand of a vector shift operation, where all the elements of the
14780 /// build_vector must have the same constant integer value.
14781 static bool getVShiftImm(SDValue Op, unsigned ElementBits, int64_t &Cnt) {
14782   // Ignore bit_converts.
14783   while (Op.getOpcode() == ISD::BITCAST)
14784     Op = Op.getOperand(0);
14785   BuildVectorSDNode *BVN = dyn_cast<BuildVectorSDNode>(Op.getNode());
14786   APInt SplatBits, SplatUndef;
14787   unsigned SplatBitSize;
14788   bool HasAnyUndefs;
14789   if (!BVN || !BVN->isConstantSplat(SplatBits, SplatUndef, SplatBitSize,
14790                                     HasAnyUndefs, ElementBits) ||
14791       SplatBitSize > ElementBits)
14792     return false;
14793   Cnt = SplatBits.getSExtValue();
14794   return true;
14795 }
14796 
14797 /// isVShiftLImm - Check if this is a valid build_vector for the immediate
14798 /// operand of a vector shift left operation.  That value must be in the range:
14799 ///   0 <= Value < ElementBits for a left shift; or
14800 ///   0 <= Value <= ElementBits for a long left shift.
14801 static bool isVShiftLImm(SDValue Op, EVT VT, bool isLong, int64_t &Cnt) {
14802   assert(VT.isVector() && "vector shift count is not a vector type");
14803   int64_t ElementBits = VT.getScalarSizeInBits();
14804   if (!getVShiftImm(Op, ElementBits, Cnt))
14805     return false;
14806   return (Cnt >= 0 && (isLong ? Cnt - 1 : Cnt) < ElementBits);
14807 }
14808 
14809 /// isVShiftRImm - Check if this is a valid build_vector for the immediate
14810 /// operand of a vector shift right operation. The value must be in the range:
14811 ///   1 <= Value <= ElementBits, or ElementBits/2 for a narrowing right shift.
14812 static bool isVShiftRImm(SDValue Op, EVT VT, bool isNarrow, int64_t &Cnt) {
14813   assert(VT.isVector() && "vector shift count is not a vector type");
14814   int64_t ElementBits = VT.getScalarSizeInBits();
14815   if (!getVShiftImm(Op, ElementBits, Cnt))
14816     return false;
14817   return (Cnt >= 1 && Cnt <= (isNarrow ? ElementBits / 2 : ElementBits));
14818 }
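// For example, with v4i32 operands a shift-left immediate must lie in
// [0, 31] ([0, 32] for a long shift), while a shift-right immediate must lie
// in [1, 32] ([1, 16] for a narrowing shift).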
14819 
14820 SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op,
14821                                              SelectionDAG &DAG) const {
14822   EVT VT = Op.getValueType();
14823 
14824   if (VT.getScalarType() == MVT::i1) {
14825     // Lower i1 truncate to `(x & 1) != 0`.
14826     SDLoc dl(Op);
14827     EVT OpVT = Op.getOperand(0).getValueType();
14828     SDValue Zero = DAG.getConstant(0, dl, OpVT);
14829     SDValue One = DAG.getConstant(1, dl, OpVT);
14830     SDValue And = DAG.getNode(ISD::AND, dl, OpVT, Op.getOperand(0), One);
14831     return DAG.getSetCC(dl, VT, And, Zero, ISD::SETNE);
14832   }
14833 
14834   if (!VT.isVector() || VT.isScalableVector())
14835     return SDValue();
14836 
14837   if (useSVEForFixedLengthVectorVT(Op.getOperand(0).getValueType(),
14838                                    !Subtarget->isNeonAvailable()))
14839     return LowerFixedLengthVectorTruncateToSVE(Op, DAG);
14840 
14841   return SDValue();
14842 }
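// For example, `trunc i32 %x to i1` is lowered here to
// `setcc (and %x, 1), 0, setne`, i.e. only the low bit of the source decides
// the result.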
14843 
14844 // Check if we can lower this SRL to a rounding shift instruction. ResVT is
14845 // possibly a truncated type; it tells how many bits of the value are to be
14846 // used.
14847 static bool canLowerSRLToRoundingShiftForVT(SDValue Shift, EVT ResVT,
14848                                             SelectionDAG &DAG,
14849                                             unsigned &ShiftValue,
14850                                             SDValue &RShOperand) {
14851   if (Shift->getOpcode() != ISD::SRL)
14852     return false;
14853 
14854   EVT VT = Shift.getValueType();
14855   assert(VT.isScalableVT());
14856 
14857   auto ShiftOp1 =
14858       dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Shift->getOperand(1)));
14859   if (!ShiftOp1)
14860     return false;
14861 
14862   ShiftValue = ShiftOp1->getZExtValue();
14863   if (ShiftValue < 1 || ShiftValue > ResVT.getScalarSizeInBits())
14864     return false;
14865 
14866   SDValue Add = Shift->getOperand(0);
14867   if (Add->getOpcode() != ISD::ADD || !Add->hasOneUse())
14868     return false;
14869 
14870   assert(ResVT.getScalarSizeInBits() <= VT.getScalarSizeInBits() &&
14871          "ResVT must be truncated or same type as the shift.");
14872   // Check if an overflow can lead to incorrect results.
14873   uint64_t ExtraBits = VT.getScalarSizeInBits() - ResVT.getScalarSizeInBits();
14874   if (ShiftValue > ExtraBits && !Add->getFlags().hasNoUnsignedWrap())
14875     return false;
14876 
14877   auto AddOp1 =
14878       dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
14879   if (!AddOp1)
14880     return false;
14881   uint64_t AddValue = AddOp1->getZExtValue();
14882   if (AddValue != 1ULL << (ShiftValue - 1))
14883     return false;
14884 
14885   RShOperand = Add->getOperand(0);
14886   return true;
14887 }
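// In other words, a pattern of the form `srl (add X, 1 << (N - 1)), N` is
// accepted as a rounding shift right by N, provided the add cannot carry
// into bits that survive a later truncation (hence the no-unsigned-wrap
// check above).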
14888 
14889 SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
14890                                                       SelectionDAG &DAG) const {
14891   EVT VT = Op.getValueType();
14892   SDLoc DL(Op);
14893   int64_t Cnt;
14894 
14895   if (!Op.getOperand(1).getValueType().isVector())
14896     return Op;
14897   unsigned EltSize = VT.getScalarSizeInBits();
14898 
14899   switch (Op.getOpcode()) {
14900   case ISD::SHL:
14901     if (VT.isScalableVector() ||
14902         useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
14903       return LowerToPredicatedOp(Op, DAG, AArch64ISD::SHL_PRED);
14904 
14905     if (isVShiftLImm(Op.getOperand(1), VT, false, Cnt) && Cnt < EltSize)
14906       return DAG.getNode(AArch64ISD::VSHL, DL, VT, Op.getOperand(0),
14907                          DAG.getConstant(Cnt, DL, MVT::i32));
14908     return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
14909                        DAG.getConstant(Intrinsic::aarch64_neon_ushl, DL,
14910                                        MVT::i32),
14911                        Op.getOperand(0), Op.getOperand(1));
14912   case ISD::SRA:
14913   case ISD::SRL:
14914     if (VT.isScalableVector() &&
14915         (Subtarget->hasSVE2() ||
14916          (Subtarget->hasSME() && Subtarget->isStreaming()))) {
14917       SDValue RShOperand;
14918       unsigned ShiftValue;
14919       if (canLowerSRLToRoundingShiftForVT(Op, VT, DAG, ShiftValue, RShOperand))
14920         return DAG.getNode(AArch64ISD::URSHR_I_PRED, DL, VT,
14921                            getPredicateForVector(DAG, DL, VT), RShOperand,
14922                            DAG.getTargetConstant(ShiftValue, DL, MVT::i32));
14923     }
14924 
14925     if (VT.isScalableVector() ||
14926         useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable())) {
14927       unsigned Opc = Op.getOpcode() == ISD::SRA ? AArch64ISD::SRA_PRED
14928                                                 : AArch64ISD::SRL_PRED;
14929       return LowerToPredicatedOp(Op, DAG, Opc);
14930     }
14931 
14932     // Right shift immediate
14933     if (isVShiftRImm(Op.getOperand(1), VT, false, Cnt) && Cnt < EltSize) {
14934       unsigned Opc =
14935           (Op.getOpcode() == ISD::SRA) ? AArch64ISD::VASHR : AArch64ISD::VLSHR;
14936       return DAG.getNode(Opc, DL, VT, Op.getOperand(0),
14937                          DAG.getConstant(Cnt, DL, MVT::i32), Op->getFlags());
14938     }
14939 
14940     // Right shift register.  Note that there is no shift right register
14941     // instruction, but the shift left register instruction takes a signed
14942     // value, where negative numbers specify a right shift.
14943     unsigned Opc = (Op.getOpcode() == ISD::SRA) ? Intrinsic::aarch64_neon_sshl
14944                                                 : Intrinsic::aarch64_neon_ushl;
14945     // Negate the shift amount.
14946     SDValue NegShift = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
14947                                    Op.getOperand(1));
14948     SDValue NegShiftLeft =
14949         DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
14950                     DAG.getConstant(Opc, DL, MVT::i32), Op.getOperand(0),
14951                     NegShift);
14952     return NegShiftLeft;
14953   }
14954 
14955   llvm_unreachable("unexpected shift opcode");
14956 }
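// For example, an SRA of v4i32 by a constant splat of 3 becomes
// AArch64ISD::VASHR with immediate 3, whereas a variable shift amount is
// negated and routed through the sshl/ushl intrinsics as described above.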
14957 
14958 static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
14959                                     AArch64CC::CondCode CC, bool NoNans, EVT VT,
14960                                     const SDLoc &dl, SelectionDAG &DAG) {
14961   EVT SrcVT = LHS.getValueType();
14962   assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
14963          "function only supposed to emit natural comparisons");
14964 
14965   APInt SplatValue;
14966   APInt SplatUndef;
14967   unsigned SplatBitSize = 0;
14968   bool HasAnyUndefs;
14969 
14970   BuildVectorSDNode *BVN = dyn_cast<BuildVectorSDNode>(RHS.getNode());
14971   bool IsCnst = BVN && BVN->isConstantSplat(SplatValue, SplatUndef,
14972                                             SplatBitSize, HasAnyUndefs);
14973 
14974   bool IsZero = IsCnst && SplatValue == 0;
14975   bool IsOne =
14976       IsCnst && SrcVT.getScalarSizeInBits() == SplatBitSize && SplatValue == 1;
14977   bool IsMinusOne = IsCnst && SplatValue.isAllOnes();
14978 
14979   if (SrcVT.getVectorElementType().isFloatingPoint()) {
14980     switch (CC) {
14981     default:
14982       return SDValue();
14983     case AArch64CC::NE: {
14984       SDValue Fcmeq;
14985       if (IsZero)
14986         Fcmeq = DAG.getNode(AArch64ISD::FCMEQz, dl, VT, LHS);
14987       else
14988         Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
14989       return DAG.getNOT(dl, Fcmeq, VT);
14990     }
14991     case AArch64CC::EQ:
14992       if (IsZero)
14993         return DAG.getNode(AArch64ISD::FCMEQz, dl, VT, LHS);
14994       return DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
14995     case AArch64CC::GE:
14996       if (IsZero)
14997         return DAG.getNode(AArch64ISD::FCMGEz, dl, VT, LHS);
14998       return DAG.getNode(AArch64ISD::FCMGE, dl, VT, LHS, RHS);
14999     case AArch64CC::GT:
15000       if (IsZero)
15001         return DAG.getNode(AArch64ISD::FCMGTz, dl, VT, LHS);
15002       return DAG.getNode(AArch64ISD::FCMGT, dl, VT, LHS, RHS);
15003     case AArch64CC::LE:
15004       if (!NoNans)
15005         return SDValue();
15006       // If we ignore NaNs then we can use the LS implementation.
15007       [[fallthrough]];
15008     case AArch64CC::LS:
15009       if (IsZero)
15010         return DAG.getNode(AArch64ISD::FCMLEz, dl, VT, LHS);
15011       return DAG.getNode(AArch64ISD::FCMGE, dl, VT, RHS, LHS);
15012     case AArch64CC::LT:
15013       if (!NoNans)
15014         return SDValue();
15015       // If we ignore NaNs then we can use the MI implementation.
15016       [[fallthrough]];
15017     case AArch64CC::MI:
15018       if (IsZero)
15019         return DAG.getNode(AArch64ISD::FCMLTz, dl, VT, LHS);
15020       return DAG.getNode(AArch64ISD::FCMGT, dl, VT, RHS, LHS);
15021     }
15022   }
15023 
15024   switch (CC) {
15025   default:
15026     return SDValue();
15027   case AArch64CC::NE: {
15028     SDValue Cmeq;
15029     if (IsZero)
15030       Cmeq = DAG.getNode(AArch64ISD::CMEQz, dl, VT, LHS);
15031     else
15032       Cmeq = DAG.getNode(AArch64ISD::CMEQ, dl, VT, LHS, RHS);
15033     return DAG.getNOT(dl, Cmeq, VT);
15034   }
15035   case AArch64CC::EQ:
15036     if (IsZero)
15037       return DAG.getNode(AArch64ISD::CMEQz, dl, VT, LHS);
15038     return DAG.getNode(AArch64ISD::CMEQ, dl, VT, LHS, RHS);
15039   case AArch64CC::GE:
15040     if (IsZero)
15041       return DAG.getNode(AArch64ISD::CMGEz, dl, VT, LHS);
15042     return DAG.getNode(AArch64ISD::CMGE, dl, VT, LHS, RHS);
15043   case AArch64CC::GT:
15044     if (IsZero)
15045       return DAG.getNode(AArch64ISD::CMGTz, dl, VT, LHS);
15046     if (IsMinusOne)
15047       return DAG.getNode(AArch64ISD::CMGEz, dl, VT, LHS, RHS);
15048     return DAG.getNode(AArch64ISD::CMGT, dl, VT, LHS, RHS);
15049   case AArch64CC::LE:
15050     if (IsZero)
15051       return DAG.getNode(AArch64ISD::CMLEz, dl, VT, LHS);
15052     return DAG.getNode(AArch64ISD::CMGE, dl, VT, RHS, LHS);
15053   case AArch64CC::LS:
15054     return DAG.getNode(AArch64ISD::CMHS, dl, VT, RHS, LHS);
15055   case AArch64CC::LO:
15056     return DAG.getNode(AArch64ISD::CMHI, dl, VT, RHS, LHS);
15057   case AArch64CC::LT:
15058     if (IsZero)
15059       return DAG.getNode(AArch64ISD::CMLTz, dl, VT, LHS);
15060     if (IsOne)
15061       return DAG.getNode(AArch64ISD::CMLEz, dl, VT, LHS);
15062     return DAG.getNode(AArch64ISD::CMGT, dl, VT, RHS, LHS);
15063   case AArch64CC::HI:
15064     return DAG.getNode(AArch64ISD::CMHI, dl, VT, LHS, RHS);
15065   case AArch64CC::HS:
15066     return DAG.getNode(AArch64ISD::CMHS, dl, VT, LHS, RHS);
15067   }
15068 }
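// For example, an unsigned less-than (LO) is emitted as CMHI with the
// operands swapped, and comparisons against a zero splat use the
// compare-with-zero forms such as CMEQz/CMLTz (FCMEQz/FCMLTz for floating
// point).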
15069 
15070 SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
15071                                            SelectionDAG &DAG) const {
15072   if (Op.getValueType().isScalableVector())
15073     return LowerToPredicatedOp(Op, DAG, AArch64ISD::SETCC_MERGE_ZERO);
15074 
15075   if (useSVEForFixedLengthVectorVT(Op.getOperand(0).getValueType(),
15076                                    !Subtarget->isNeonAvailable()))
15077     return LowerFixedLengthVectorSetccToSVE(Op, DAG);
15078 
15079   ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
15080   SDValue LHS = Op.getOperand(0);
15081   SDValue RHS = Op.getOperand(1);
15082   EVT CmpVT = LHS.getValueType().changeVectorElementTypeToInteger();
15083   SDLoc dl(Op);
15084 
15085   if (LHS.getValueType().getVectorElementType().isInteger()) {
15086     assert(LHS.getValueType() == RHS.getValueType());
15087     AArch64CC::CondCode AArch64CC = changeIntCCToAArch64CC(CC);
15088     SDValue Cmp =
15089         EmitVectorComparison(LHS, RHS, AArch64CC, false, CmpVT, dl, DAG);
15090     return DAG.getSExtOrTrunc(Cmp, dl, Op.getValueType());
15091   }
15092 
15093   // Lower isnan(x) | isnan(never-nan) to x != x.
15094   // Lower !isnan(x) & !isnan(never-nan) to x == x.
15095   if (CC == ISD::SETUO || CC == ISD::SETO) {
15096     bool OneNaN = false;
15097     if (LHS == RHS) {
15098       OneNaN = true;
15099     } else if (DAG.isKnownNeverNaN(RHS)) {
15100       OneNaN = true;
15101       RHS = LHS;
15102     } else if (DAG.isKnownNeverNaN(LHS)) {
15103       OneNaN = true;
15104       LHS = RHS;
15105     }
15106     if (OneNaN) {
15107       CC = CC == ISD::SETUO ? ISD::SETUNE : ISD::SETOEQ;
15108     }
15109   }
15110 
15111   const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();
15112 
15113   // Make v4f16 (only) fcmp operations utilise vector instructions.
15114   // v8f16 support will be a little more complicated.
15115   if ((!FullFP16 && LHS.getValueType().getVectorElementType() == MVT::f16) ||
15116       LHS.getValueType().getVectorElementType() == MVT::bf16) {
15117     if (LHS.getValueType().getVectorNumElements() == 4) {
15118       LHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::v4f32, LHS);
15119       RHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::v4f32, RHS);
15120       SDValue NewSetcc = DAG.getSetCC(dl, MVT::v4i16, LHS, RHS, CC);
15121       DAG.ReplaceAllUsesWith(Op, NewSetcc);
15122       CmpVT = MVT::v4i32;
15123     } else
15124       return SDValue();
15125   }
15126 
15127   assert((!FullFP16 && LHS.getValueType().getVectorElementType() != MVT::f16) ||
15128          LHS.getValueType().getVectorElementType() != MVT::bf16 ||
15129          LHS.getValueType().getVectorElementType() != MVT::f128);
15130 
15131   // Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
15132   // clean.  Some of them require two branches to implement.
15133   AArch64CC::CondCode CC1, CC2;
15134   bool ShouldInvert;
15135   changeVectorFPCCToAArch64CC(CC, CC1, CC2, ShouldInvert);
15136 
15137   bool NoNaNs = getTargetMachine().Options.NoNaNsFPMath || Op->getFlags().hasNoNaNs();
15138   SDValue Cmp =
15139       EmitVectorComparison(LHS, RHS, CC1, NoNaNs, CmpVT, dl, DAG);
15140   if (!Cmp.getNode())
15141     return SDValue();
15142 
15143   if (CC2 != AArch64CC::AL) {
15144     SDValue Cmp2 =
15145         EmitVectorComparison(LHS, RHS, CC2, NoNaNs, CmpVT, dl, DAG);
15146     if (!Cmp2.getNode())
15147       return SDValue();
15148 
15149     Cmp = DAG.getNode(ISD::OR, dl, CmpVT, Cmp, Cmp2);
15150   }
15151 
15152   Cmp = DAG.getSExtOrTrunc(Cmp, dl, Op.getValueType());
15153 
15154   if (ShouldInvert)
15155     Cmp = DAG.getNOT(dl, Cmp, Cmp.getValueType());
15156 
15157   return Cmp;
15158 }
15159 
15160 static SDValue getReductionSDNode(unsigned Op, SDLoc DL, SDValue ScalarOp,
15161                                   SelectionDAG &DAG) {
15162   SDValue VecOp = ScalarOp.getOperand(0);
15163   auto Rdx = DAG.getNode(Op, DL, VecOp.getSimpleValueType(), VecOp);
15164   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarOp.getValueType(), Rdx,
15165                      DAG.getConstant(0, DL, MVT::i64));
15166 }
15167 
15168 static SDValue getVectorBitwiseReduce(unsigned Opcode, SDValue Vec, EVT VT,
15169                                       SDLoc DL, SelectionDAG &DAG) {
15170   unsigned ScalarOpcode;
15171   switch (Opcode) {
15172   case ISD::VECREDUCE_AND:
15173     ScalarOpcode = ISD::AND;
15174     break;
15175   case ISD::VECREDUCE_OR:
15176     ScalarOpcode = ISD::OR;
15177     break;
15178   case ISD::VECREDUCE_XOR:
15179     ScalarOpcode = ISD::XOR;
15180     break;
15181   default:
15182     llvm_unreachable("Expected bitwise vector reduction");
15183     return SDValue();
15184   }
15185 
15186   EVT VecVT = Vec.getValueType();
15187   assert(VecVT.isFixedLengthVector() && VecVT.isPow2VectorType() &&
15188          "Expected power-of-2 length vector");
15189 
15190   EVT ElemVT = VecVT.getVectorElementType();
15191 
15192   SDValue Result;
15193   unsigned NumElems = VecVT.getVectorNumElements();
15194 
15195   // Special case for boolean reductions
15196   if (ElemVT == MVT::i1) {
15197     // Split large vectors into smaller ones
15198     if (NumElems > 16) {
15199       SDValue Lo, Hi;
15200       std::tie(Lo, Hi) = DAG.SplitVector(Vec, DL);
15201       EVT HalfVT = Lo.getValueType();
15202       SDValue HalfVec = DAG.getNode(ScalarOpcode, DL, HalfVT, Lo, Hi);
15203       return getVectorBitwiseReduce(Opcode, HalfVec, VT, DL, DAG);
15204     }
15205 
15206     // Vectors that are less than 64 bits get widened to neatly fit a 64 bit
15207     // register, so e.g. <4 x i1> gets lowered to <4 x i16>. Sign extending to
15208     // this element size leads to the best codegen, since e.g. setcc results
15209     // might need to be truncated otherwise.
15210     EVT ExtendedVT = MVT::getIntegerVT(std::max(64u / NumElems, 8u));
15211 
15212     // any_ext doesn't work with umin/umax, so only use it for uadd.
15213     unsigned ExtendOp =
15214         ScalarOpcode == ISD::XOR ? ISD::ANY_EXTEND : ISD::SIGN_EXTEND;
15215     SDValue Extended = DAG.getNode(
15216         ExtendOp, DL, VecVT.changeVectorElementType(ExtendedVT), Vec);
15217     switch (ScalarOpcode) {
15218     case ISD::AND:
15219       Result = DAG.getNode(ISD::VECREDUCE_UMIN, DL, ExtendedVT, Extended);
15220       break;
15221     case ISD::OR:
15222       Result = DAG.getNode(ISD::VECREDUCE_UMAX, DL, ExtendedVT, Extended);
15223       break;
15224     case ISD::XOR:
15225       Result = DAG.getNode(ISD::VECREDUCE_ADD, DL, ExtendedVT, Extended);
15226       break;
15227     default:
15228       llvm_unreachable("Unexpected Opcode");
15229     }
15230 
15231     Result = DAG.getAnyExtOrTrunc(Result, DL, MVT::i1);
15232   } else {
15233     // Iteratively split the vector in half and combine using the bitwise
15234     // operation until it fits in a 64 bit register.
15235     while (VecVT.getSizeInBits() > 64) {
15236       SDValue Lo, Hi;
15237       std::tie(Lo, Hi) = DAG.SplitVector(Vec, DL);
15238       VecVT = Lo.getValueType();
15239       NumElems = VecVT.getVectorNumElements();
15240       Vec = DAG.getNode(ScalarOpcode, DL, VecVT, Lo, Hi);
15241     }
15242 
15243     EVT ScalarVT = EVT::getIntegerVT(*DAG.getContext(), VecVT.getSizeInBits());
15244 
15245     // Do the remaining work on a scalar since it allows the code generator to
15246     // combine the shift and bitwise operation into one instruction and since
15247     // integer instructions can have higher throughput than vector instructions.
15248     SDValue Scalar = DAG.getBitcast(ScalarVT, Vec);
15249 
15250     // Iteratively combine the lower and upper halves of the scalar using the
15251     // bitwise operation, halving the relevant region of the scalar in each
15252     // iteration, until the relevant region is just one element of the original
15253     // vector.
15254     for (unsigned Shift = NumElems / 2; Shift > 0; Shift /= 2) {
15255       SDValue ShiftAmount =
15256           DAG.getConstant(Shift * ElemVT.getSizeInBits(), DL, MVT::i64);
15257       SDValue Shifted =
15258           DAG.getNode(ISD::SRL, DL, ScalarVT, Scalar, ShiftAmount);
15259       Scalar = DAG.getNode(ScalarOpcode, DL, ScalarVT, Scalar, Shifted);
15260     }
15261 
15262     Result = DAG.getAnyExtOrTrunc(Scalar, DL, ElemVT);
15263   }
15264 
15265   return DAG.getAnyExtOrTrunc(Result, DL, VT);
15266 }
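// Rough example of the non-i1 path: a v4i32 vecreduce_and is split and
// AND'ed down to a v2i32, bitcast to i64, folded with a 32-bit shift plus a
// scalar AND, and finally truncated back to i32. i1 reductions instead
// extend the elements and use umin/umax/add reductions as above.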
15267 
15268 SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op,
15269                                               SelectionDAG &DAG) const {
15270   SDValue Src = Op.getOperand(0);
15271 
15272   // Try to lower fixed length reductions to SVE.
15273   EVT SrcVT = Src.getValueType();
15274   bool OverrideNEON = !Subtarget->isNeonAvailable() ||
15275                       Op.getOpcode() == ISD::VECREDUCE_AND ||
15276                       Op.getOpcode() == ISD::VECREDUCE_OR ||
15277                       Op.getOpcode() == ISD::VECREDUCE_XOR ||
15278                       Op.getOpcode() == ISD::VECREDUCE_FADD ||
15279                       (Op.getOpcode() != ISD::VECREDUCE_ADD &&
15280                        SrcVT.getVectorElementType() == MVT::i64);
15281   if (SrcVT.isScalableVector() ||
15282       useSVEForFixedLengthVectorVT(
15283           SrcVT, OverrideNEON && Subtarget->useSVEForFixedLengthVectors())) {
15284 
15285     if (SrcVT.getVectorElementType() == MVT::i1)
15286       return LowerPredReductionToSVE(Op, DAG);
15287 
15288     switch (Op.getOpcode()) {
15289     case ISD::VECREDUCE_ADD:
15290       return LowerReductionToSVE(AArch64ISD::UADDV_PRED, Op, DAG);
15291     case ISD::VECREDUCE_AND:
15292       return LowerReductionToSVE(AArch64ISD::ANDV_PRED, Op, DAG);
15293     case ISD::VECREDUCE_OR:
15294       return LowerReductionToSVE(AArch64ISD::ORV_PRED, Op, DAG);
15295     case ISD::VECREDUCE_SMAX:
15296       return LowerReductionToSVE(AArch64ISD::SMAXV_PRED, Op, DAG);
15297     case ISD::VECREDUCE_SMIN:
15298       return LowerReductionToSVE(AArch64ISD::SMINV_PRED, Op, DAG);
15299     case ISD::VECREDUCE_UMAX:
15300       return LowerReductionToSVE(AArch64ISD::UMAXV_PRED, Op, DAG);
15301     case ISD::VECREDUCE_UMIN:
15302       return LowerReductionToSVE(AArch64ISD::UMINV_PRED, Op, DAG);
15303     case ISD::VECREDUCE_XOR:
15304       return LowerReductionToSVE(AArch64ISD::EORV_PRED, Op, DAG);
15305     case ISD::VECREDUCE_FADD:
15306       return LowerReductionToSVE(AArch64ISD::FADDV_PRED, Op, DAG);
15307     case ISD::VECREDUCE_FMAX:
15308       return LowerReductionToSVE(AArch64ISD::FMAXNMV_PRED, Op, DAG);
15309     case ISD::VECREDUCE_FMIN:
15310       return LowerReductionToSVE(AArch64ISD::FMINNMV_PRED, Op, DAG);
15311     case ISD::VECREDUCE_FMAXIMUM:
15312       return LowerReductionToSVE(AArch64ISD::FMAXV_PRED, Op, DAG);
15313     case ISD::VECREDUCE_FMINIMUM:
15314       return LowerReductionToSVE(AArch64ISD::FMINV_PRED, Op, DAG);
15315     default:
15316       llvm_unreachable("Unhandled fixed length reduction");
15317     }
15318   }
15319 
15320   // Lower NEON reductions.
15321   SDLoc dl(Op);
15322   switch (Op.getOpcode()) {
15323   case ISD::VECREDUCE_AND:
15324   case ISD::VECREDUCE_OR:
15325   case ISD::VECREDUCE_XOR:
15326     return getVectorBitwiseReduce(Op.getOpcode(), Op.getOperand(0),
15327                                   Op.getValueType(), dl, DAG);
15328   case ISD::VECREDUCE_ADD:
15329     return getReductionSDNode(AArch64ISD::UADDV, dl, Op, DAG);
15330   case ISD::VECREDUCE_SMAX:
15331     return getReductionSDNode(AArch64ISD::SMAXV, dl, Op, DAG);
15332   case ISD::VECREDUCE_SMIN:
15333     return getReductionSDNode(AArch64ISD::SMINV, dl, Op, DAG);
15334   case ISD::VECREDUCE_UMAX:
15335     return getReductionSDNode(AArch64ISD::UMAXV, dl, Op, DAG);
15336   case ISD::VECREDUCE_UMIN:
15337     return getReductionSDNode(AArch64ISD::UMINV, dl, Op, DAG);
15338   default:
15339     llvm_unreachable("Unhandled reduction");
15340   }
15341 }
15342 
15343 SDValue AArch64TargetLowering::LowerATOMIC_LOAD_AND(SDValue Op,
15344                                                     SelectionDAG &DAG) const {
15345   auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
15346   // No point replacing if we don't have the relevant instruction/libcall anyway
15347   if (!Subtarget.hasLSE() && !Subtarget.outlineAtomics())
15348     return SDValue();
15349 
15350   // LSE has an atomic load-clear instruction, but not a load-and.
15351   SDLoc dl(Op);
15352   MVT VT = Op.getSimpleValueType();
15353   assert(VT != MVT::i128 && "Handled elsewhere, code replicated.");
15354   SDValue RHS = Op.getOperand(2);
15355   AtomicSDNode *AN = cast<AtomicSDNode>(Op.getNode());
15356   RHS = DAG.getNode(ISD::XOR, dl, VT, DAG.getConstant(-1ULL, dl, VT), RHS);
15357   return DAG.getAtomic(ISD::ATOMIC_LOAD_CLR, dl, AN->getMemoryVT(),
15358                        Op.getOperand(0), Op.getOperand(1), RHS,
15359                        AN->getMemOperand());
15360 }
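// That is, `atomicrmw and ptr, val` is rewritten as an atomic load-clear of
// ~val, since LSE provides a load-clear instruction but no load-and.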
15361 
15362 SDValue
15363 AArch64TargetLowering::LowerWindowsDYNAMIC_STACKALLOC(SDValue Op,
15364                                                       SelectionDAG &DAG) const {
15365 
15366   SDLoc dl(Op);
15367   // Get the inputs.
15368   SDNode *Node = Op.getNode();
15369   SDValue Chain = Op.getOperand(0);
15370   SDValue Size = Op.getOperand(1);
15371   MaybeAlign Align =
15372       cast<ConstantSDNode>(Op.getOperand(2))->getMaybeAlignValue();
15373   EVT VT = Node->getValueType(0);
15374 
15375   if (DAG.getMachineFunction().getFunction().hasFnAttribute(
15376           "no-stack-arg-probe")) {
15377     SDValue SP = DAG.getCopyFromReg(Chain, dl, AArch64::SP, MVT::i64);
15378     Chain = SP.getValue(1);
15379     SP = DAG.getNode(ISD::SUB, dl, MVT::i64, SP, Size);
15380     if (Align)
15381       SP = DAG.getNode(ISD::AND, dl, VT, SP.getValue(0),
15382                        DAG.getConstant(-(uint64_t)Align->value(), dl, VT));
15383     Chain = DAG.getCopyToReg(Chain, dl, AArch64::SP, SP);
15384     SDValue Ops[2] = {SP, Chain};
15385     return DAG.getMergeValues(Ops, dl);
15386   }
15387 
15388   Chain = DAG.getCALLSEQ_START(Chain, 0, 0, dl);
15389 
15390   EVT PtrVT = getPointerTy(DAG.getDataLayout());
15391   SDValue Callee = DAG.getTargetExternalSymbol(Subtarget->getChkStkName(),
15392                                                PtrVT, 0);
15393 
15394   const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
15395   const uint32_t *Mask = TRI->getWindowsStackProbePreservedMask();
15396   if (Subtarget->hasCustomCallingConv())
15397     TRI->UpdateCustomCallPreservedMask(DAG.getMachineFunction(), &Mask);
15398 
15399   Size = DAG.getNode(ISD::SRL, dl, MVT::i64, Size,
15400                      DAG.getConstant(4, dl, MVT::i64));
15401   Chain = DAG.getCopyToReg(Chain, dl, AArch64::X15, Size, SDValue());
15402   Chain =
15403       DAG.getNode(AArch64ISD::CALL, dl, DAG.getVTList(MVT::Other, MVT::Glue),
15404                   Chain, Callee, DAG.getRegister(AArch64::X15, MVT::i64),
15405                   DAG.getRegisterMask(Mask), Chain.getValue(1));
15406   // To match the actual intent better, we should read the output from X15 here
15407   // again (instead of potentially spilling it to the stack), but rereading Size
15408   // from X15 here doesn't work at -O0, since it thinks that X15 is undefined
15409   // here.
15410 
15411   Size = DAG.getNode(ISD::SHL, dl, MVT::i64, Size,
15412                      DAG.getConstant(4, dl, MVT::i64));
15413 
15414   SDValue SP = DAG.getCopyFromReg(Chain, dl, AArch64::SP, MVT::i64);
15415   Chain = SP.getValue(1);
15416   SP = DAG.getNode(ISD::SUB, dl, MVT::i64, SP, Size);
15417   if (Align)
15418     SP = DAG.getNode(ISD::AND, dl, VT, SP.getValue(0),
15419                      DAG.getConstant(-(uint64_t)Align->value(), dl, VT));
15420   Chain = DAG.getCopyToReg(Chain, dl, AArch64::SP, SP);
15421 
15422   Chain = DAG.getCALLSEQ_END(Chain, 0, 0, SDValue(), dl);
15423 
15424   SDValue Ops[2] = {SP, Chain};
15425   return DAG.getMergeValues(Ops, dl);
15426 }
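// Note on the shifts above: the Windows stack probe helper (typically
// __chkstk) expects the allocation size in X15 in units of 16 bytes, so the
// byte count is shifted right by 4 before the call and shifted back left by
// 4 before being subtracted from SP.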
15427 
15428 SDValue
15429 AArch64TargetLowering::LowerInlineDYNAMIC_STACKALLOC(SDValue Op,
15430                                                      SelectionDAG &DAG) const {
15431   // Get the inputs.
15432   SDNode *Node = Op.getNode();
15433   SDValue Chain = Op.getOperand(0);
15434   SDValue Size = Op.getOperand(1);
15435 
15436   MaybeAlign Align =
15437       cast<ConstantSDNode>(Op.getOperand(2))->getMaybeAlignValue();
15438   SDLoc dl(Op);
15439   EVT VT = Node->getValueType(0);
15440 
15441   // Construct the new SP value in a GPR.
15442   SDValue SP = DAG.getCopyFromReg(Chain, dl, AArch64::SP, MVT::i64);
15443   Chain = SP.getValue(1);
15444   SP = DAG.getNode(ISD::SUB, dl, MVT::i64, SP, Size);
15445   if (Align)
15446     SP = DAG.getNode(ISD::AND, dl, VT, SP.getValue(0),
15447                      DAG.getConstant(-(uint64_t)Align->value(), dl, VT));
15448 
15449   // Set the real SP to the new value with a probing loop.
15450   Chain = DAG.getNode(AArch64ISD::PROBED_ALLOCA, dl, MVT::Other, Chain, SP);
15451   SDValue Ops[2] = {SP, Chain};
15452   return DAG.getMergeValues(Ops, dl);
15453 }
15454 
15455 SDValue
15456 AArch64TargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
15457                                                SelectionDAG &DAG) const {
15458   MachineFunction &MF = DAG.getMachineFunction();
15459 
15460   if (Subtarget->isTargetWindows())
15461     return LowerWindowsDYNAMIC_STACKALLOC(Op, DAG);
15462   else if (hasInlineStackProbe(MF))
15463     return LowerInlineDYNAMIC_STACKALLOC(Op, DAG);
15464   else
15465     return SDValue();
15466 }
15467 
15468 SDValue AArch64TargetLowering::LowerAVG(SDValue Op, SelectionDAG &DAG,
15469                                         unsigned NewOp) const {
15470   if (Subtarget->hasSVE2())
15471     return LowerToPredicatedOp(Op, DAG, NewOp);
15472 
15473   // Default to expand.
15474   return SDValue();
15475 }
15476 
15477 SDValue AArch64TargetLowering::LowerVSCALE(SDValue Op,
15478                                            SelectionDAG &DAG) const {
15479   EVT VT = Op.getValueType();
15480   assert(VT != MVT::i64 && "Expected illegal VSCALE node");
15481 
15482   SDLoc DL(Op);
15483   APInt MulImm = Op.getConstantOperandAPInt(0);
15484   return DAG.getZExtOrTrunc(DAG.getVScale(DL, MVT::i64, MulImm.sext(64)), DL,
15485                             VT);
15486 }
15487 
15488 /// Set the IntrinsicInfo for the `aarch64_sve_st<N>` intrinsics.
15489 template <unsigned NumVecs>
15490 static bool
15491 setInfoSVEStN(const AArch64TargetLowering &TLI, const DataLayout &DL,
15492               AArch64TargetLowering::IntrinsicInfo &Info, const CallInst &CI) {
15493   Info.opc = ISD::INTRINSIC_VOID;
15494   // Retrieve EC from first vector argument.
15495   const EVT VT = TLI.getMemValueType(DL, CI.getArgOperand(0)->getType());
15496   ElementCount EC = VT.getVectorElementCount();
15497 #ifndef NDEBUG
15498   // Check the assumption that all input vectors are the same type.
15499   for (unsigned I = 0; I < NumVecs; ++I)
15500     assert(VT == TLI.getMemValueType(DL, CI.getArgOperand(I)->getType()) &&
15501            "Invalid type.");
15502 #endif
15503   // memVT is `NumVecs * VT`.
15504   Info.memVT = EVT::getVectorVT(CI.getType()->getContext(), VT.getScalarType(),
15505                                 EC * NumVecs);
15506   Info.ptrVal = CI.getArgOperand(CI.arg_size() - 1);
15507   Info.offset = 0;
15508   Info.align.reset();
15509   Info.flags = MachineMemOperand::MOStore;
15510   return true;
15511 }
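// For example, for aarch64_sve_st3 with nxv4i32 vector arguments, memVT is
// reported as nxv12i32: one argument's element type with three times its
// element count.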
15512 
15513 /// getTgtMemIntrinsic - Represent NEON load and store intrinsics as
15514 /// MemIntrinsicNodes.  The associated MachineMemOperands record the alignment
15515 /// specified in the intrinsic calls.
15516 bool AArch64TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
15517                                                const CallInst &I,
15518                                                MachineFunction &MF,
15519                                                unsigned Intrinsic) const {
15520   auto &DL = I.getDataLayout();
15521   switch (Intrinsic) {
15522   case Intrinsic::aarch64_sve_st2:
15523     return setInfoSVEStN<2>(*this, DL, Info, I);
15524   case Intrinsic::aarch64_sve_st3:
15525     return setInfoSVEStN<3>(*this, DL, Info, I);
15526   case Intrinsic::aarch64_sve_st4:
15527     return setInfoSVEStN<4>(*this, DL, Info, I);
15528   case Intrinsic::aarch64_neon_ld2:
15529   case Intrinsic::aarch64_neon_ld3:
15530   case Intrinsic::aarch64_neon_ld4:
15531   case Intrinsic::aarch64_neon_ld1x2:
15532   case Intrinsic::aarch64_neon_ld1x3:
15533   case Intrinsic::aarch64_neon_ld1x4: {
15534     Info.opc = ISD::INTRINSIC_W_CHAIN;
15535     uint64_t NumElts = DL.getTypeSizeInBits(I.getType()) / 64;
15536     Info.memVT = EVT::getVectorVT(I.getType()->getContext(), MVT::i64, NumElts);
15537     Info.ptrVal = I.getArgOperand(I.arg_size() - 1);
15538     Info.offset = 0;
15539     Info.align.reset();
15540     // volatile loads with NEON intrinsics not supported
15541     Info.flags = MachineMemOperand::MOLoad;
15542     return true;
15543   }
15544   case Intrinsic::aarch64_neon_ld2lane:
15545   case Intrinsic::aarch64_neon_ld3lane:
15546   case Intrinsic::aarch64_neon_ld4lane:
15547   case Intrinsic::aarch64_neon_ld2r:
15548   case Intrinsic::aarch64_neon_ld3r:
15549   case Intrinsic::aarch64_neon_ld4r: {
15550     Info.opc = ISD::INTRINSIC_W_CHAIN;
15551     // These intrinsics return a struct of vectors that all have the same type.
15552     Type *RetTy = I.getType();
15553     auto *StructTy = cast<StructType>(RetTy);
15554     unsigned NumElts = StructTy->getNumElements();
15555     Type *VecTy = StructTy->getElementType(0);
15556     MVT EleVT = MVT::getVT(VecTy).getVectorElementType();
15557     Info.memVT = EVT::getVectorVT(I.getType()->getContext(), EleVT, NumElts);
15558     Info.ptrVal = I.getArgOperand(I.arg_size() - 1);
15559     Info.offset = 0;
15560     Info.align.reset();
15561     // volatile loads with NEON intrinsics not supported
15562     Info.flags = MachineMemOperand::MOLoad;
15563     return true;
15564   }
15565   case Intrinsic::aarch64_neon_st2:
15566   case Intrinsic::aarch64_neon_st3:
15567   case Intrinsic::aarch64_neon_st4:
15568   case Intrinsic::aarch64_neon_st1x2:
15569   case Intrinsic::aarch64_neon_st1x3:
15570   case Intrinsic::aarch64_neon_st1x4: {
15571     Info.opc = ISD::INTRINSIC_VOID;
15572     unsigned NumElts = 0;
15573     for (const Value *Arg : I.args()) {
15574       Type *ArgTy = Arg->getType();
15575       if (!ArgTy->isVectorTy())
15576         break;
15577       NumElts += DL.getTypeSizeInBits(ArgTy) / 64;
15578     }
15579     Info.memVT = EVT::getVectorVT(I.getType()->getContext(), MVT::i64, NumElts);
15580     Info.ptrVal = I.getArgOperand(I.arg_size() - 1);
15581     Info.offset = 0;
15582     Info.align.reset();
15583     // volatile stores with NEON intrinsics not supported
15584     Info.flags = MachineMemOperand::MOStore;
15585     return true;
15586   }
15587   case Intrinsic::aarch64_neon_st2lane:
15588   case Intrinsic::aarch64_neon_st3lane:
15589   case Intrinsic::aarch64_neon_st4lane: {
15590     Info.opc = ISD::INTRINSIC_VOID;
15591     unsigned NumElts = 0;
15592     // All the vector arguments have the same type.
15593     Type *VecTy = I.getArgOperand(0)->getType();
15594     MVT EleVT = MVT::getVT(VecTy).getVectorElementType();
15595 
15596     for (const Value *Arg : I.args()) {
15597       Type *ArgTy = Arg->getType();
15598       if (!ArgTy->isVectorTy())
15599         break;
15600       NumElts += 1;
15601     }
15602 
15603     Info.memVT = EVT::getVectorVT(I.getType()->getContext(), EleVT, NumElts);
15604     Info.ptrVal = I.getArgOperand(I.arg_size() - 1);
15605     Info.offset = 0;
15606     Info.align.reset();
15607     // volatile stores with NEON intrinsics not supported
15608     Info.flags = MachineMemOperand::MOStore;
15609     return true;
15610   }
15611   case Intrinsic::aarch64_ldaxr:
15612   case Intrinsic::aarch64_ldxr: {
15613     Type *ValTy = I.getParamElementType(0);
15614     Info.opc = ISD::INTRINSIC_W_CHAIN;
15615     Info.memVT = MVT::getVT(ValTy);
15616     Info.ptrVal = I.getArgOperand(0);
15617     Info.offset = 0;
15618     Info.align = DL.getABITypeAlign(ValTy);
15619     Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOVolatile;
15620     return true;
15621   }
15622   case Intrinsic::aarch64_stlxr:
15623   case Intrinsic::aarch64_stxr: {
15624     Type *ValTy = I.getParamElementType(1);
15625     Info.opc = ISD::INTRINSIC_W_CHAIN;
15626     Info.memVT = MVT::getVT(ValTy);
15627     Info.ptrVal = I.getArgOperand(1);
15628     Info.offset = 0;
15629     Info.align = DL.getABITypeAlign(ValTy);
15630     Info.flags = MachineMemOperand::MOStore | MachineMemOperand::MOVolatile;
15631     return true;
15632   }
15633   case Intrinsic::aarch64_ldaxp:
15634   case Intrinsic::aarch64_ldxp:
15635     Info.opc = ISD::INTRINSIC_W_CHAIN;
15636     Info.memVT = MVT::i128;
15637     Info.ptrVal = I.getArgOperand(0);
15638     Info.offset = 0;
15639     Info.align = Align(16);
15640     Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOVolatile;
15641     return true;
15642   case Intrinsic::aarch64_stlxp:
15643   case Intrinsic::aarch64_stxp:
15644     Info.opc = ISD::INTRINSIC_W_CHAIN;
15645     Info.memVT = MVT::i128;
15646     Info.ptrVal = I.getArgOperand(2);
15647     Info.offset = 0;
15648     Info.align = Align(16);
15649     Info.flags = MachineMemOperand::MOStore | MachineMemOperand::MOVolatile;
15650     return true;
15651   case Intrinsic::aarch64_sve_ldnt1: {
15652     Type *ElTy = cast<VectorType>(I.getType())->getElementType();
15653     Info.opc = ISD::INTRINSIC_W_CHAIN;
15654     Info.memVT = MVT::getVT(I.getType());
15655     Info.ptrVal = I.getArgOperand(1);
15656     Info.offset = 0;
15657     Info.align = DL.getABITypeAlign(ElTy);
15658     Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MONonTemporal;
15659     return true;
15660   }
15661   case Intrinsic::aarch64_sve_stnt1: {
15662     Type *ElTy =
15663         cast<VectorType>(I.getArgOperand(0)->getType())->getElementType();
15664     Info.opc = ISD::INTRINSIC_W_CHAIN;
15665     Info.memVT = MVT::getVT(I.getOperand(0)->getType());
15666     Info.ptrVal = I.getArgOperand(2);
15667     Info.offset = 0;
15668     Info.align = DL.getABITypeAlign(ElTy);
15669     Info.flags = MachineMemOperand::MOStore | MachineMemOperand::MONonTemporal;
15670     return true;
15671   }
15672   case Intrinsic::aarch64_mops_memset_tag: {
15673     Value *Dst = I.getArgOperand(0);
15674     Value *Val = I.getArgOperand(1);
15675     Info.opc = ISD::INTRINSIC_W_CHAIN;
15676     Info.memVT = MVT::getVT(Val->getType());
15677     Info.ptrVal = Dst;
15678     Info.offset = 0;
15679     Info.align = I.getParamAlign(0).valueOrOne();
15680     Info.flags = MachineMemOperand::MOStore;
15681     // The size of the memory being operated on is unknown at this point
15682     Info.size = MemoryLocation::UnknownSize;
15683     return true;
15684   }
15685   default:
15686     break;
15687   }
15688 
15689   return false;
15690 }
15691 
15692 bool AArch64TargetLowering::shouldReduceLoadWidth(SDNode *Load,
15693                                                   ISD::LoadExtType ExtTy,
15694                                                   EVT NewVT) const {
15695   // TODO: This may be worth removing. Check regression tests for diffs.
15696   if (!TargetLoweringBase::shouldReduceLoadWidth(Load, ExtTy, NewVT))
15697     return false;
15698 
15699   // If we're reducing the load width in order to avoid having to use an extra
15700   // instruction to do the extension, then it's probably a good idea.
15701   if (ExtTy != ISD::NON_EXTLOAD)
15702     return true;
15703   // Don't reduce load width if it would prevent us from combining a shift into
15704   // the offset.
15705   MemSDNode *Mem = dyn_cast<MemSDNode>(Load);
15706   assert(Mem);
15707   const SDValue &Base = Mem->getBasePtr();
15708   if (Base.getOpcode() == ISD::ADD &&
15709       Base.getOperand(1).getOpcode() == ISD::SHL &&
15710       Base.getOperand(1).hasOneUse() &&
15711       Base.getOperand(1).getOperand(1).getOpcode() == ISD::Constant) {
15712     // It's unknown whether a scalable vector has a power-of-2 bitwidth.
15713     if (Mem->getMemoryVT().isScalableVector())
15714       return false;
15715     // The shift can be combined if it matches the size of the value being
15716     // loaded (and so reducing the width would make it not match).
15717     uint64_t ShiftAmount = Base.getOperand(1).getConstantOperandVal(1);
15718     uint64_t LoadBytes = Mem->getMemoryVT().getSizeInBits()/8;
15719     if (ShiftAmount == Log2_32(LoadBytes))
15720       return false;
15721   }
15722   // We have no reason to disallow reducing the load width, so allow it.
15723   return true;
15724 }
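// For example, a 64-bit load addressed as `base + (index << 3)` keeps its
// width: the shift matches the 8-byte access size and folds into the
// register-offset addressing mode, which a narrower load would break.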
15725 
15726 // Treat a sext_inreg(extract(..)) as free if it has multiple uses.
15727 bool AArch64TargetLowering::shouldRemoveRedundantExtend(SDValue Extend) const {
15728   EVT VT = Extend.getValueType();
15729   if ((VT == MVT::i64 || VT == MVT::i32) && Extend->use_size()) {
15730     SDValue Extract = Extend.getOperand(0);
15731     if (Extract.getOpcode() == ISD::ANY_EXTEND && Extract.hasOneUse())
15732       Extract = Extract.getOperand(0);
15733     if (Extract.getOpcode() == ISD::EXTRACT_VECTOR_ELT && Extract.hasOneUse()) {
15734       EVT VecVT = Extract.getOperand(0).getValueType();
15735       if (VecVT.getScalarType() == MVT::i8 || VecVT.getScalarType() == MVT::i16)
15736         return false;
15737     }
15738   }
15739   return true;
15740 }
15741 
15742 // Truncations from a 64-bit GPR to a 32-bit GPR are free.
15743 bool AArch64TargetLowering::isTruncateFree(Type *Ty1, Type *Ty2) const {
15744   if (!Ty1->isIntegerTy() || !Ty2->isIntegerTy())
15745     return false;
15746   uint64_t NumBits1 = Ty1->getPrimitiveSizeInBits().getFixedValue();
15747   uint64_t NumBits2 = Ty2->getPrimitiveSizeInBits().getFixedValue();
15748   return NumBits1 > NumBits2;
15749 }
15750 bool AArch64TargetLowering::isTruncateFree(EVT VT1, EVT VT2) const {
15751   if (VT1.isVector() || VT2.isVector() || !VT1.isInteger() || !VT2.isInteger())
15752     return false;
15753   uint64_t NumBits1 = VT1.getFixedSizeInBits();
15754   uint64_t NumBits2 = VT2.getFixedSizeInBits();
15755   return NumBits1 > NumBits2;
15756 }
15757 
15758 /// Check if it is profitable to hoist an instruction in then/else to if.
15759 /// Not profitable if I and its user can form an FMA instruction
15760 /// because we prefer FMSUB/FMADD.
15761 bool AArch64TargetLowering::isProfitableToHoist(Instruction *I) const {
15762   if (I->getOpcode() != Instruction::FMul)
15763     return true;
15764 
15765   if (!I->hasOneUse())
15766     return true;
15767 
15768   Instruction *User = I->user_back();
15769 
15770   if (!(User->getOpcode() == Instruction::FSub ||
15771         User->getOpcode() == Instruction::FAdd))
15772     return true;
15773 
15774   const TargetOptions &Options = getTargetMachine().Options;
15775   const Function *F = I->getFunction();
15776   const DataLayout &DL = F->getDataLayout();
15777   Type *Ty = User->getOperand(0)->getType();
15778 
15779   return !(isFMAFasterThanFMulAndFAdd(*F, Ty) &&
15780            isOperationLegalOrCustom(ISD::FMA, getValueType(DL, Ty)) &&
15781            (Options.AllowFPOpFusion == FPOpFusion::Fast ||
15782             Options.UnsafeFPMath));
15783 }
15784 
15785 // All 32-bit GPR operations implicitly zero the high-half of the corresponding
15786 // 64-bit GPR.
15787 bool AArch64TargetLowering::isZExtFree(Type *Ty1, Type *Ty2) const {
15788   if (!Ty1->isIntegerTy() || !Ty2->isIntegerTy())
15789     return false;
15790   unsigned NumBits1 = Ty1->getPrimitiveSizeInBits();
15791   unsigned NumBits2 = Ty2->getPrimitiveSizeInBits();
15792   return NumBits1 == 32 && NumBits2 == 64;
15793 }
15794 bool AArch64TargetLowering::isZExtFree(EVT VT1, EVT VT2) const {
15795   if (VT1.isVector() || VT2.isVector() || !VT1.isInteger() || !VT2.isInteger())
15796     return false;
15797   unsigned NumBits1 = VT1.getSizeInBits();
15798   unsigned NumBits2 = VT2.getSizeInBits();
15799   return NumBits1 == 32 && NumBits2 == 64;
15800 }
15801 
15802 bool AArch64TargetLowering::isZExtFree(SDValue Val, EVT VT2) const {
15803   EVT VT1 = Val.getValueType();
15804   if (isZExtFree(VT1, VT2)) {
15805     return true;
15806   }
15807 
15808   if (Val.getOpcode() != ISD::LOAD)
15809     return false;
15810 
15811   // 8-, 16-, and 32-bit integer loads all implicitly zero-extend.
15812   return (VT1.isSimple() && !VT1.isVector() && VT1.isInteger() &&
15813           VT2.isSimple() && !VT2.isVector() && VT2.isInteger() &&
15814           VT1.getSizeInBits() <= 32);
15815 }
15816 
15817 bool AArch64TargetLowering::isExtFreeImpl(const Instruction *Ext) const {
15818   if (isa<FPExtInst>(Ext))
15819     return false;
15820 
15821   // Vector types are not free.
15822   if (Ext->getType()->isVectorTy())
15823     return false;
15824 
15825   for (const Use &U : Ext->uses()) {
15826     // The extension is free if we can fold it with a left shift in an
15827     // addressing mode or an arithmetic operation: add, sub, and cmp.
15828 
15829     // Is there a shift?
15830     const Instruction *Instr = cast<Instruction>(U.getUser());
15831 
15832     // Is this a constant shift?
15833     switch (Instr->getOpcode()) {
15834     case Instruction::Shl:
15835       if (!isa<ConstantInt>(Instr->getOperand(1)))
15836         return false;
15837       break;
15838     case Instruction::GetElementPtr: {
15839       gep_type_iterator GTI = gep_type_begin(Instr);
15840       auto &DL = Ext->getDataLayout();
15841       std::advance(GTI, U.getOperandNo()-1);
15842       Type *IdxTy = GTI.getIndexedType();
15843       // This extension will end up with a shift because of the scaling factor.
15844       // 8-bit sized types have a scaling factor of 1, thus a shift amount of 0.
15845       // Get the shift amount based on the scaling factor:
15846       // log2(sizeof(IdxTy)) - log2(8).
15847       if (IdxTy->isScalableTy())
15848         return false;
15849       uint64_t ShiftAmt =
15850           llvm::countr_zero(DL.getTypeStoreSizeInBits(IdxTy).getFixedValue()) -
15851           3;
15852       // Is the constant foldable in the shift of the addressing mode?
15853       // I.e., shift amount is between 1 and 4 inclusive.
15854       if (ShiftAmt == 0 || ShiftAmt > 4)
15855         return false;
15856       break;
15857     }
15858     case Instruction::Trunc:
15859       // Check if this is a noop.
15860       // trunc(sext ty1 to ty2) to ty1.
15861       if (Instr->getType() == Ext->getOperand(0)->getType())
15862         continue;
15863       [[fallthrough]];
15864     default:
15865       return false;
15866     }
15867 
15868     // At this point we can use the bfm family, so this extension is free
15869     // for that use.
15870   }
15871   return true;
15872 }
15873 
15874 static bool isSplatShuffle(Value *V) {
15875   if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V))
15876     return all_equal(Shuf->getShuffleMask());
15877   return false;
15878 }
15879 
15880 /// Check if both Op1 and Op2 are shufflevector extracts of either the lower
15881 /// or upper half of the vector elements.
15882 static bool areExtractShuffleVectors(Value *Op1, Value *Op2,
15883                                      bool AllowSplat = false) {
15884   auto areTypesHalfed = [](Value *FullV, Value *HalfV) {
15885     auto *FullTy = FullV->getType();
15886     auto *HalfTy = HalfV->getType();
15887     return FullTy->getPrimitiveSizeInBits().getFixedValue() ==
15888            2 * HalfTy->getPrimitiveSizeInBits().getFixedValue();
15889   };
15890 
15891   auto extractHalf = [](Value *FullV, Value *HalfV) {
15892     auto *FullVT = cast<FixedVectorType>(FullV->getType());
15893     auto *HalfVT = cast<FixedVectorType>(HalfV->getType());
15894     return FullVT->getNumElements() == 2 * HalfVT->getNumElements();
15895   };
15896 
15897   ArrayRef<int> M1, M2;
15898   Value *S1Op1 = nullptr, *S2Op1 = nullptr;
15899   if (!match(Op1, m_Shuffle(m_Value(S1Op1), m_Undef(), m_Mask(M1))) ||
15900       !match(Op2, m_Shuffle(m_Value(S2Op1), m_Undef(), m_Mask(M2))))
15901     return false;
15902 
15903   // If we allow splats, set S1Op1/S2Op1 to nullptr for the relevant arg so that
15904   // it is not checked as an extract below.
15905   if (AllowSplat && isSplatShuffle(Op1))
15906     S1Op1 = nullptr;
15907   if (AllowSplat && isSplatShuffle(Op2))
15908     S2Op1 = nullptr;
15909 
15910   // Check that the operands are half as wide as the result and we extract
15911   // half of the elements of the input vectors.
15912   if ((S1Op1 && (!areTypesHalfed(S1Op1, Op1) || !extractHalf(S1Op1, Op1))) ||
15913       (S2Op1 && (!areTypesHalfed(S2Op1, Op2) || !extractHalf(S2Op1, Op2))))
15914     return false;
15915 
15916   // Check the mask extracts either the lower or upper half of vector
15917   // elements.
15918   int M1Start = 0;
15919   int M2Start = 0;
15920   int NumElements = cast<FixedVectorType>(Op1->getType())->getNumElements() * 2;
15921   if ((S1Op1 &&
15922        !ShuffleVectorInst::isExtractSubvectorMask(M1, NumElements, M1Start)) ||
15923       (S2Op1 &&
15924        !ShuffleVectorInst::isExtractSubvectorMask(M2, NumElements, M2Start)))
15925     return false;
15926 
15927   if ((M1Start != 0 && M1Start != (NumElements / 2)) ||
15928       (M2Start != 0 && M2Start != (NumElements / 2)))
15929     return false;
15930   if (S1Op1 && S2Op1 && M1Start != M2Start)
15931     return false;
15932 
15933   return true;
15934 }
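// Illustrative sketch (hypothetical IR): a pair accepted by
// areExtractShuffleVectors, where both shuffles extract the same (upper) half
// of their 8-element sources:
//   %hi1 = shufflevector <8 x i16> %a, <8 x i16> undef,
//                        <4 x i32> <i32 4, i32 5, i32 6, i32 7>
//   %hi2 = shufflevector <8 x i16> %b, <8 x i16> undef,
//                        <4 x i32> <i32 4, i32 5, i32 6, i32 7>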
15935 
15936 /// Check if Ext1 and Ext2 are extends of the same type, doubling the bitwidth
15937 /// of the vector elements.
15938 static bool areExtractExts(Value *Ext1, Value *Ext2) {
15939   auto areExtDoubled = [](Instruction *Ext) {
15940     return Ext->getType()->getScalarSizeInBits() ==
15941            2 * Ext->getOperand(0)->getType()->getScalarSizeInBits();
15942   };
15943 
15944   if (!match(Ext1, m_ZExtOrSExt(m_Value())) ||
15945       !match(Ext2, m_ZExtOrSExt(m_Value())) ||
15946       !areExtDoubled(cast<Instruction>(Ext1)) ||
15947       !areExtDoubled(cast<Instruction>(Ext2)))
15948     return false;
15949 
15950   return true;
15951 }
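// Illustrative sketch (hypothetical IR): a pair accepted by areExtractExts,
// since each extend doubles the element width (a sext/zext mix is fine here):
//   %e1 = sext <4 x i16> %x to <4 x i32>
//   %e2 = zext <4 x i16> %y to <4 x i32>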
15952 
15953 /// Check if Op could be used with vmull_high_p64 intrinsic.
15954 static bool isOperandOfVmullHighP64(Value *Op) {
15955   Value *VectorOperand = nullptr;
15956   ConstantInt *ElementIndex = nullptr;
15957   return match(Op, m_ExtractElt(m_Value(VectorOperand),
15958                                 m_ConstantInt(ElementIndex))) &&
15959          ElementIndex->getValue() == 1 &&
15960          isa<FixedVectorType>(VectorOperand->getType()) &&
15961          cast<FixedVectorType>(VectorOperand->getType())->getNumElements() == 2;
15962 }
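// Illustrative sketch (hypothetical IR): the operand shape accepted above,
// i.e. extracting lane 1 of a two-element i64 vector:
//   %lane = extractelement <2 x i64> %v, i64 1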
15963 
15964 /// Check if Op1 and Op2 could be used with vmull_high_p64 intrinsic.
15965 static bool areOperandsOfVmullHighP64(Value *Op1, Value *Op2) {
15966   return isOperandOfVmullHighP64(Op1) && isOperandOfVmullHighP64(Op2);
15967 }
15968 
15969 static bool shouldSinkVectorOfPtrs(Value *Ptrs, SmallVectorImpl<Use *> &Ops) {
15970   // Restrict ourselves to the form CodeGenPrepare typically constructs.
15971   auto *GEP = dyn_cast<GetElementPtrInst>(Ptrs);
15972   if (!GEP || GEP->getNumOperands() != 2)
15973     return false;
15974 
15975   Value *Base = GEP->getOperand(0);
15976   Value *Offsets = GEP->getOperand(1);
15977 
15978   // We only care about scalar_base+vector_offsets.
15979   if (Base->getType()->isVectorTy() || !Offsets->getType()->isVectorTy())
15980     return false;
15981 
15982   // Sink extends that would allow us to use 32-bit offset vectors.
15983   if (isa<SExtInst>(Offsets) || isa<ZExtInst>(Offsets)) {
15984     auto *OffsetsInst = cast<Instruction>(Offsets);
15985     if (OffsetsInst->getType()->getScalarSizeInBits() > 32 &&
15986         OffsetsInst->getOperand(0)->getType()->getScalarSizeInBits() <= 32)
15987       Ops.push_back(&GEP->getOperandUse(1));
15988   }
15989 
15990   // Sink the GEP.
15991   return true;
15992 }
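// Illustrative sketch (hypothetical IR): a gather pointer operand this helper
// accepts; the gep's use of the zext is recorded here and the caller then
// sinks the gep itself:
//   %off  = zext <4 x i32> %idx to <4 x i64>
//   %ptrs = getelementptr float, ptr %base, <4 x i64> %off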
15993 
15994 /// We want to sink the following cases:
15995 /// (add|sub|gep) A, ((mul|shl) vscale, imm); (add|sub|gep) A, vscale;
15996 /// (add|sub|gep) A, ((mul|shl) zext(vscale), imm);
15997 static bool shouldSinkVScale(Value *Op, SmallVectorImpl<Use *> &Ops) {
15998   if (match(Op, m_VScale()))
15999     return true;
16000   if (match(Op, m_Shl(m_VScale(), m_ConstantInt())) ||
16001       match(Op, m_Mul(m_VScale(), m_ConstantInt()))) {
16002     Ops.push_back(&cast<Instruction>(Op)->getOperandUse(0));
16003     return true;
16004   }
16005   if (match(Op, m_Shl(m_ZExt(m_VScale()), m_ConstantInt())) ||
16006       match(Op, m_Mul(m_ZExt(m_VScale()), m_ConstantInt()))) {
16007     Value *ZExtOp = cast<Instruction>(Op)->getOperand(0);
16008     Ops.push_back(&cast<Instruction>(ZExtOp)->getOperandUse(0));
16009     Ops.push_back(&cast<Instruction>(Op)->getOperandUse(0));
16010     return true;
16011   }
16012   return false;
16013 }
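// Illustrative sketch (hypothetical IR) of the second pattern above: the
// vscale use is recorded here, and the caller also sinks the mul, so both end
// up next to the gep for isel:
//   %vs   = call i64 @llvm.vscale.i64()
//   %mul  = mul i64 %vs, 16
//   %addr = getelementptr i8, ptr %p, i64 %mul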
16014 
16015 /// Check if sinking \p I's operands to I's basic block is profitable, because
16016 /// the operands can be folded into a target instruction, e.g.
16017 /// shufflevector extracts and/or sext/zext can be folded into (u,s)subl(2).
16018 bool AArch64TargetLowering::shouldSinkOperands(
16019     Instruction *I, SmallVectorImpl<Use *> &Ops) const {
16020   if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
16021     switch (II->getIntrinsicID()) {
16022     case Intrinsic::aarch64_neon_smull:
16023     case Intrinsic::aarch64_neon_umull:
16024       if (areExtractShuffleVectors(II->getOperand(0), II->getOperand(1),
16025                                    /*AllowSplat=*/true)) {
16026         Ops.push_back(&II->getOperandUse(0));
16027         Ops.push_back(&II->getOperandUse(1));
16028         return true;
16029       }
16030       [[fallthrough]];
16031 
16032     case Intrinsic::fma:
16033       if (isa<VectorType>(I->getType()) &&
16034           cast<VectorType>(I->getType())->getElementType()->isHalfTy() &&
16035           !Subtarget->hasFullFP16())
16036         return false;
16037       [[fallthrough]];
16038     case Intrinsic::aarch64_neon_sqdmull:
16039     case Intrinsic::aarch64_neon_sqdmulh:
16040     case Intrinsic::aarch64_neon_sqrdmulh:
16041       // Sink splats for index lane variants
16042       if (isSplatShuffle(II->getOperand(0)))
16043         Ops.push_back(&II->getOperandUse(0));
16044       if (isSplatShuffle(II->getOperand(1)))
16045         Ops.push_back(&II->getOperandUse(1));
16046       return !Ops.empty();
16047     case Intrinsic::aarch64_neon_fmlal:
16048     case Intrinsic::aarch64_neon_fmlal2:
16049     case Intrinsic::aarch64_neon_fmlsl:
16050     case Intrinsic::aarch64_neon_fmlsl2:
16051       // Sink splats for index lane variants
16052       if (isSplatShuffle(II->getOperand(1)))
16053         Ops.push_back(&II->getOperandUse(1));
16054       if (isSplatShuffle(II->getOperand(2)))
16055         Ops.push_back(&II->getOperandUse(2));
16056       return !Ops.empty();
16057     case Intrinsic::aarch64_sve_ptest_first:
16058     case Intrinsic::aarch64_sve_ptest_last:
16059       if (auto *IIOp = dyn_cast<IntrinsicInst>(II->getOperand(0)))
16060         if (IIOp->getIntrinsicID() == Intrinsic::aarch64_sve_ptrue)
16061           Ops.push_back(&II->getOperandUse(0));
16062       return !Ops.empty();
16063     case Intrinsic::aarch64_sme_write_horiz:
16064     case Intrinsic::aarch64_sme_write_vert:
16065     case Intrinsic::aarch64_sme_writeq_horiz:
16066     case Intrinsic::aarch64_sme_writeq_vert: {
16067       auto *Idx = dyn_cast<Instruction>(II->getOperand(1));
16068       if (!Idx || Idx->getOpcode() != Instruction::Add)
16069         return false;
16070       Ops.push_back(&II->getOperandUse(1));
16071       return true;
16072     }
16073     case Intrinsic::aarch64_sme_read_horiz:
16074     case Intrinsic::aarch64_sme_read_vert:
16075     case Intrinsic::aarch64_sme_readq_horiz:
16076     case Intrinsic::aarch64_sme_readq_vert:
16077     case Intrinsic::aarch64_sme_ld1b_vert:
16078     case Intrinsic::aarch64_sme_ld1h_vert:
16079     case Intrinsic::aarch64_sme_ld1w_vert:
16080     case Intrinsic::aarch64_sme_ld1d_vert:
16081     case Intrinsic::aarch64_sme_ld1q_vert:
16082     case Intrinsic::aarch64_sme_st1b_vert:
16083     case Intrinsic::aarch64_sme_st1h_vert:
16084     case Intrinsic::aarch64_sme_st1w_vert:
16085     case Intrinsic::aarch64_sme_st1d_vert:
16086     case Intrinsic::aarch64_sme_st1q_vert:
16087     case Intrinsic::aarch64_sme_ld1b_horiz:
16088     case Intrinsic::aarch64_sme_ld1h_horiz:
16089     case Intrinsic::aarch64_sme_ld1w_horiz:
16090     case Intrinsic::aarch64_sme_ld1d_horiz:
16091     case Intrinsic::aarch64_sme_ld1q_horiz:
16092     case Intrinsic::aarch64_sme_st1b_horiz:
16093     case Intrinsic::aarch64_sme_st1h_horiz:
16094     case Intrinsic::aarch64_sme_st1w_horiz:
16095     case Intrinsic::aarch64_sme_st1d_horiz:
16096     case Intrinsic::aarch64_sme_st1q_horiz: {
16097       auto *Idx = dyn_cast<Instruction>(II->getOperand(3));
16098       if (!Idx || Idx->getOpcode() != Instruction::Add)
16099         return false;
16100       Ops.push_back(&II->getOperandUse(3));
16101       return true;
16102     }
16103     case Intrinsic::aarch64_neon_pmull:
16104       if (!areExtractShuffleVectors(II->getOperand(0), II->getOperand(1)))
16105         return false;
16106       Ops.push_back(&II->getOperandUse(0));
16107       Ops.push_back(&II->getOperandUse(1));
16108       return true;
16109     case Intrinsic::aarch64_neon_pmull64:
16110       if (!areOperandsOfVmullHighP64(II->getArgOperand(0),
16111                                      II->getArgOperand(1)))
16112         return false;
16113       Ops.push_back(&II->getArgOperandUse(0));
16114       Ops.push_back(&II->getArgOperandUse(1));
16115       return true;
16116     case Intrinsic::masked_gather:
16117       if (!shouldSinkVectorOfPtrs(II->getArgOperand(0), Ops))
16118         return false;
16119       Ops.push_back(&II->getArgOperandUse(0));
16120       return true;
16121     case Intrinsic::masked_scatter:
16122       if (!shouldSinkVectorOfPtrs(II->getArgOperand(1), Ops))
16123         return false;
16124       Ops.push_back(&II->getArgOperandUse(1));
16125       return true;
16126     default:
16127       return false;
16128     }
16129   }
16130 
16131   // Sink vscales closer to uses for better isel
16132   switch (I->getOpcode()) {
16133   case Instruction::GetElementPtr:
16134   case Instruction::Add:
16135   case Instruction::Sub:
16136     for (unsigned Op = 0; Op < I->getNumOperands(); ++Op) {
16137       if (shouldSinkVScale(I->getOperand(Op), Ops)) {
16138         Ops.push_back(&I->getOperandUse(Op));
16139         return true;
16140       }
16141     }
16142     break;
16143   default:
16144     break;
16145   }
16146 
16147   if (!I->getType()->isVectorTy())
16148     return false;
16149 
16150   switch (I->getOpcode()) {
16151   case Instruction::Sub:
16152   case Instruction::Add: {
16153     if (!areExtractExts(I->getOperand(0), I->getOperand(1)))
16154       return false;
16155 
16156     // If the exts' operands extract either the lower or upper elements, we
16157     // can sink them too.
16158     auto Ext1 = cast<Instruction>(I->getOperand(0));
16159     auto Ext2 = cast<Instruction>(I->getOperand(1));
16160     if (areExtractShuffleVectors(Ext1->getOperand(0), Ext2->getOperand(0))) {
16161       Ops.push_back(&Ext1->getOperandUse(0));
16162       Ops.push_back(&Ext2->getOperandUse(0));
16163     }
16164 
16165     Ops.push_back(&I->getOperandUse(0));
16166     Ops.push_back(&I->getOperandUse(1));
16167 
16168     return true;
16169   }
16170   case Instruction::Or: {
16171     // Pattern: Or(And(MaskValue, A), And(Not(MaskValue), B)) ->
16172     // bitselect(MaskValue, A, B) where Not(MaskValue) = Xor(MaskValue, -1)
16173     if (Subtarget->hasNEON()) {
16174       Instruction *OtherAnd, *IA, *IB;
16175       Value *MaskValue;
16176       // MainAnd refers to the And instruction that has 'Not' as one of its operands
16177       if (match(I, m_c_Or(m_OneUse(m_Instruction(OtherAnd)),
16178                           m_OneUse(m_c_And(m_OneUse(m_Not(m_Value(MaskValue))),
16179                                            m_Instruction(IA)))))) {
16180         if (match(OtherAnd,
16181                   m_c_And(m_Specific(MaskValue), m_Instruction(IB)))) {
16182           Instruction *MainAnd = I->getOperand(0) == OtherAnd
16183                                      ? cast<Instruction>(I->getOperand(1))
16184                                      : cast<Instruction>(I->getOperand(0));
16185 
16186           // Both Ands should be in the same basic block as the Or
16187           if (I->getParent() != MainAnd->getParent() ||
16188               I->getParent() != OtherAnd->getParent())
16189             return false;
16190 
16191           // Non-mask operands of both Ands should also be in the same basic block
16192           if (I->getParent() != IA->getParent() ||
16193               I->getParent() != IB->getParent())
16194             return false;
16195 
16196           Ops.push_back(&MainAnd->getOperandUse(MainAnd->getOperand(0) == IA ? 1 : 0));
16197           Ops.push_back(&I->getOperandUse(0));
16198           Ops.push_back(&I->getOperandUse(1));
16199 
16200           return true;
16201         }
16202       }
16203     }
16204 
16205     return false;
16206   }
16207   case Instruction::Mul: {
16208     int NumZExts = 0, NumSExts = 0;
16209     for (auto &Op : I->operands()) {
16210       // Make sure we are not already sinking this operand
16211       if (any_of(Ops, [&](Use *U) { return U->get() == Op; }))
16212         continue;
16213 
16214       if (match(&Op, m_SExt(m_Value()))) {
16215         NumSExts++;
16216         continue;
16217       } else if (match(&Op, m_ZExt(m_Value()))) {
16218         NumZExts++;
16219         continue;
16220       }
16221 
16222       ShuffleVectorInst *Shuffle = dyn_cast<ShuffleVectorInst>(Op);
16223 
16224       // If the Shuffle is a splat and the operand is a zext/sext, sinking the
16225       // operand and the s/zext can help create indexed s/umull. This is
16226       // especially useful to prevent an i64 mul from being scalarized.
16227       if (Shuffle && isSplatShuffle(Shuffle) &&
16228           match(Shuffle->getOperand(0), m_ZExtOrSExt(m_Value()))) {
16229         Ops.push_back(&Shuffle->getOperandUse(0));
16230         Ops.push_back(&Op);
16231         if (match(Shuffle->getOperand(0), m_SExt(m_Value())))
16232           NumSExts++;
16233         else
16234           NumZExts++;
16235         continue;
16236       }
16237 
16238       if (!Shuffle)
16239         continue;
16240 
16241       Value *ShuffleOperand = Shuffle->getOperand(0);
16242       InsertElementInst *Insert = dyn_cast<InsertElementInst>(ShuffleOperand);
16243       if (!Insert)
16244         continue;
16245 
16246       Instruction *OperandInstr = dyn_cast<Instruction>(Insert->getOperand(1));
16247       if (!OperandInstr)
16248         continue;
16249 
16250       ConstantInt *ElementConstant =
16251           dyn_cast<ConstantInt>(Insert->getOperand(2));
16252       // Check that the insertelement is inserting into element 0
16253       if (!ElementConstant || !ElementConstant->isZero())
16254         continue;
16255 
16256       unsigned Opcode = OperandInstr->getOpcode();
16257       if (Opcode == Instruction::SExt)
16258         NumSExts++;
16259       else if (Opcode == Instruction::ZExt)
16260         NumZExts++;
16261       else {
16262         // If we find that the top bits are known 0, then we can sink and allow
16263         // the backend to generate a umull.
16264         unsigned Bitwidth = I->getType()->getScalarSizeInBits();
16265         APInt UpperMask = APInt::getHighBitsSet(Bitwidth, Bitwidth / 2);
16266         const DataLayout &DL = I->getDataLayout();
16267         if (!MaskedValueIsZero(OperandInstr, UpperMask, DL))
16268           continue;
16269         NumZExts++;
16270       }
16271 
16272       Ops.push_back(&Shuffle->getOperandUse(0));
16273       Ops.push_back(&Op);
16274     }
16275 
16276     // It is only profitable to sink if we found two extends of the same type.
16277     return !Ops.empty() && (NumSExts == 2 || NumZExts == 2);
16278   }
16279   default:
16280     return false;
16281   }
16282   return false;
16283 }
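// Illustrative sketch (hypothetical IR) of the add case above: when the two
// doubling extends are defined away from the add (typically in another block),
// their uses are recorded so CodeGenPrepare can sink them and ISel can form
// saddl/uaddl:
//   %e1  = sext <8 x i8> %a to <8 x i16>
//   %e2  = sext <8 x i8> %b to <8 x i16>
//   %sum = add <8 x i16> %e1, %e2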
16284 
16285 static bool createTblShuffleMask(unsigned SrcWidth, unsigned DstWidth,
16286                                  unsigned NumElts, bool IsLittleEndian,
16287                                  SmallVectorImpl<int> &Mask) {
16288   if (DstWidth % 8 != 0 || DstWidth <= 16 || DstWidth >= 64)
16289     return false;
16290 
16291   assert(DstWidth % SrcWidth == 0 &&
16292          "TBL lowering is not supported for a conversion instruction with this "
16293          "source and destination element type.");
16294 
16295   unsigned Factor = DstWidth / SrcWidth;
16296   unsigned MaskLen = NumElts * Factor;
16297 
16298   Mask.clear();
16299   Mask.resize(MaskLen, NumElts);
16300 
16301   unsigned SrcIndex = 0;
16302   for (unsigned I = IsLittleEndian ? 0 : Factor - 1; I < MaskLen; I += Factor)
16303     Mask[I] = SrcIndex++;
16304 
16305   return true;
16306 }
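// Worked example (illustrative): SrcWidth = 8, DstWidth = 32, NumElts = 4 on a
// little-endian target gives Factor = 4, MaskLen = 16 and the mask
//   [0, 4, 4, 4, 1, 4, 4, 4, 2, 4, 4, 4, 3, 4, 4, 4]
// where index 4 (== NumElts) selects the zero element the callers append as
// the first lane of the second shuffle input.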
16307 
16308 static Value *createTblShuffleForZExt(IRBuilderBase &Builder, Value *Op,
16309                                       FixedVectorType *ZExtTy,
16310                                       FixedVectorType *DstTy,
16311                                       bool IsLittleEndian) {
16312   auto *SrcTy = cast<FixedVectorType>(Op->getType());
16313   unsigned NumElts = SrcTy->getNumElements();
16314   auto SrcWidth = cast<IntegerType>(SrcTy->getElementType())->getBitWidth();
16315   auto DstWidth = cast<IntegerType>(DstTy->getElementType())->getBitWidth();
16316 
16317   SmallVector<int> Mask;
16318   if (!createTblShuffleMask(SrcWidth, DstWidth, NumElts, IsLittleEndian, Mask))
16319     return nullptr;
16320 
16321   auto *FirstEltZero = Builder.CreateInsertElement(
16322       PoisonValue::get(SrcTy), Builder.getInt8(0), uint64_t(0));
16323   Value *Result = Builder.CreateShuffleVector(Op, FirstEltZero, Mask);
16324   Result = Builder.CreateBitCast(Result, DstTy);
16325   if (DstTy != ZExtTy)
16326     Result = Builder.CreateZExt(Result, ZExtTy);
16327   return Result;
16328 }
16329 
16330 static Value *createTblShuffleForSExt(IRBuilderBase &Builder, Value *Op,
16331                                       FixedVectorType *DstTy,
16332                                       bool IsLittleEndian) {
16333   auto *SrcTy = cast<FixedVectorType>(Op->getType());
16334   auto SrcWidth = cast<IntegerType>(SrcTy->getElementType())->getBitWidth();
16335   auto DstWidth = cast<IntegerType>(DstTy->getElementType())->getBitWidth();
16336 
16337   SmallVector<int> Mask;
16338   if (!createTblShuffleMask(SrcWidth, DstWidth, SrcTy->getNumElements(),
16339                             !IsLittleEndian, Mask))
16340     return nullptr;
16341 
16342   auto *FirstEltZero = Builder.CreateInsertElement(
16343       PoisonValue::get(SrcTy), Builder.getInt8(0), uint64_t(0));
16344 
16345   return Builder.CreateShuffleVector(Op, FirstEltZero, Mask);
16346 }
16347 
16348 static void createTblForTrunc(TruncInst *TI, bool IsLittleEndian) {
16349   IRBuilder<> Builder(TI);
16350   SmallVector<Value *> Parts;
16351   int NumElements = cast<FixedVectorType>(TI->getType())->getNumElements();
16352   auto *SrcTy = cast<FixedVectorType>(TI->getOperand(0)->getType());
16353   auto *DstTy = cast<FixedVectorType>(TI->getType());
16354   assert(SrcTy->getElementType()->isIntegerTy() &&
16355          "Non-integer type source vector element is not supported");
16356   assert(DstTy->getElementType()->isIntegerTy(8) &&
16357          "Unsupported destination vector element type");
16358   unsigned SrcElemTySz =
16359       cast<IntegerType>(SrcTy->getElementType())->getBitWidth();
16360   unsigned DstElemTySz =
16361       cast<IntegerType>(DstTy->getElementType())->getBitWidth();
16362   assert((SrcElemTySz % DstElemTySz == 0) &&
16363          "Cannot lower truncate to tbl instructions for a source element size "
16364          "that is not divisible by the destination element size");
16365   unsigned TruncFactor = SrcElemTySz / DstElemTySz;
16366   assert((SrcElemTySz == 16 || SrcElemTySz == 32 || SrcElemTySz == 64) &&
16367          "Unsupported source vector element type size");
16368   Type *VecTy = FixedVectorType::get(Builder.getInt8Ty(), 16);
16369 
16370   // Create a mask to choose every nth byte from the source vector table of
16371   // bytes to create the truncated destination vector, where 'n' is the truncate
16372   // ratio. For example, for a truncate from Yxi64 to Yxi8, choose bytes
16373   // 0, 8, 16, ..., (Y-1)*8 for the little-endian format.
16374   SmallVector<Constant *, 16> MaskConst;
16375   for (int Itr = 0; Itr < 16; Itr++) {
16376     if (Itr < NumElements)
16377       MaskConst.push_back(Builder.getInt8(
16378           IsLittleEndian ? Itr * TruncFactor
16379                          : Itr * TruncFactor + (TruncFactor - 1)));
16380     else
16381       MaskConst.push_back(Builder.getInt8(255));
16382   }
16383 
16384   int MaxTblSz = 128 * 4;
16385   int MaxSrcSz = SrcElemTySz * NumElements;
16386   int ElemsPerTbl =
16387       (MaxTblSz > MaxSrcSz) ? NumElements : (MaxTblSz / SrcElemTySz);
16388   assert(ElemsPerTbl <= 16 &&
16389          "Maximum elements selected using TBL instruction cannot exceed 16!");
16390 
16391   int ShuffleCount = 128 / SrcElemTySz;
16392   SmallVector<int> ShuffleLanes;
16393   for (int i = 0; i < ShuffleCount; ++i)
16394     ShuffleLanes.push_back(i);
16395 
16396   // Create TBL's table of bytes in 1,2,3 or 4 FP/SIMD registers using shuffles
16397   // over the source vector. If TBL's maximum 4 FP/SIMD registers are saturated,
16398   // call TBL & save the result in a vector of TBL results for combining later.
16399   SmallVector<Value *> Results;
16400   while (ShuffleLanes.back() < NumElements) {
16401     Parts.push_back(Builder.CreateBitCast(
16402         Builder.CreateShuffleVector(TI->getOperand(0), ShuffleLanes), VecTy));
16403 
16404     if (Parts.size() == 4) {
16405       auto *F = Intrinsic::getDeclaration(TI->getModule(),
16406                                           Intrinsic::aarch64_neon_tbl4, VecTy);
16407       Parts.push_back(ConstantVector::get(MaskConst));
16408       Results.push_back(Builder.CreateCall(F, Parts));
16409       Parts.clear();
16410     }
16411 
16412     for (int i = 0; i < ShuffleCount; ++i)
16413       ShuffleLanes[i] += ShuffleCount;
16414   }
16415 
16416   assert((Parts.empty() || Results.empty()) &&
16417          "Lowering trunc for vectors requiring different TBL instructions is "
16418          "not supported!");
16419   // Call TBL for the residual table bytes present in 1,2, or 3 FP/SIMD
16420   // registers
16421   if (!Parts.empty()) {
16422     Intrinsic::ID TblID;
16423     switch (Parts.size()) {
16424     case 1:
16425       TblID = Intrinsic::aarch64_neon_tbl1;
16426       break;
16427     case 2:
16428       TblID = Intrinsic::aarch64_neon_tbl2;
16429       break;
16430     case 3:
16431       TblID = Intrinsic::aarch64_neon_tbl3;
16432       break;
16433     }
16434 
16435     auto *F = Intrinsic::getDeclaration(TI->getModule(), TblID, VecTy);
16436     Parts.push_back(ConstantVector::get(MaskConst));
16437     Results.push_back(Builder.CreateCall(F, Parts));
16438   }
16439 
16440   // Extract the destination vector from TBL result(s) after combining them
16441   // where applicable. Currently, at most two TBLs are supported.
16442   assert(Results.size() <= 2 && "Trunc lowering does not support generation of "
16443                                 "more than 2 tbl instructions!");
16444   Value *FinalResult = Results[0];
16445   if (Results.size() == 1) {
16446     if (ElemsPerTbl < 16) {
16447       SmallVector<int> FinalMask(ElemsPerTbl);
16448       std::iota(FinalMask.begin(), FinalMask.end(), 0);
16449       FinalResult = Builder.CreateShuffleVector(Results[0], FinalMask);
16450     }
16451   } else {
16452     SmallVector<int> FinalMask(ElemsPerTbl * Results.size());
16453     if (ElemsPerTbl < 16) {
16454       std::iota(FinalMask.begin(), FinalMask.begin() + ElemsPerTbl, 0);
16455       std::iota(FinalMask.begin() + ElemsPerTbl, FinalMask.end(), 16);
16456     } else {
16457       std::iota(FinalMask.begin(), FinalMask.end(), 0);
16458     }
16459     FinalResult =
16460         Builder.CreateShuffleVector(Results[0], Results[1], FinalMask);
16461   }
16462 
16463   TI->replaceAllUsesWith(FinalResult);
16464   TI->eraseFromParent();
16465 }
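// Worked example (illustrative): trunc <16 x i32> %x to <16 x i8> on a
// little-endian target gives TruncFactor = 4, a byte mask of 0, 4, 8, ..., 60,
// and four 128-bit shuffle parts, so a single tbl4 call already yields the
// full <16 x i8> result and no final shuffle is needed.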
16466 
16467 bool AArch64TargetLowering::optimizeExtendOrTruncateConversion(
16468     Instruction *I, Loop *L, const TargetTransformInfo &TTI) const {
16469   // shuffle_vector instructions are serialized when targeting SVE,
16470   // see LowerSPLAT_VECTOR. This peephole is not beneficial.
16471   if (!EnableExtToTBL || Subtarget->useSVEForFixedLengthVectors())
16472     return false;
16473 
16474   // Try to optimize conversions using tbl. This requires materializing constant
16475   // index vectors, which can increase code size and add loads. Skip the
16476   // transform unless the conversion is in a loop block guaranteed to execute
16477   // and we are not optimizing for size.
16478   Function *F = I->getParent()->getParent();
16479   if (!L || L->getHeader() != I->getParent() || F->hasMinSize() ||
16480       F->hasOptSize())
16481     return false;
16482 
16483   auto *SrcTy = dyn_cast<FixedVectorType>(I->getOperand(0)->getType());
16484   auto *DstTy = dyn_cast<FixedVectorType>(I->getType());
16485   if (!SrcTy || !DstTy)
16486     return false;
16487 
16488   // Convert 'zext <Y x i8> %x to <Y x i8X>' to a shuffle that can be
16489   // lowered to tbl instructions to insert the original i8 elements
16490   // into i8x lanes. This is enabled for cases where it is beneficial.
16491   auto *ZExt = dyn_cast<ZExtInst>(I);
16492   if (ZExt && SrcTy->getElementType()->isIntegerTy(8)) {
16493     auto DstWidth = DstTy->getElementType()->getScalarSizeInBits();
16494     if (DstWidth % 8 != 0)
16495       return false;
16496 
16497     auto *TruncDstType =
16498         cast<FixedVectorType>(VectorType::getTruncatedElementVectorType(DstTy));
16499     // If the ZExt can be lowered to a single ZExt to the next power-of-2 and
16500     // the remaining ZExt folded into the user, don't use tbl lowering.
16501     auto SrcWidth = SrcTy->getElementType()->getScalarSizeInBits();
16502     if (TTI.getCastInstrCost(I->getOpcode(), DstTy, TruncDstType,
16503                              TargetTransformInfo::getCastContextHint(I),
16504                              TTI::TCK_SizeAndLatency, I) == TTI::TCC_Free) {
16505       if (SrcWidth * 2 >= TruncDstType->getElementType()->getScalarSizeInBits())
16506         return false;
16507 
16508       DstTy = TruncDstType;
16509     }
16510     IRBuilder<> Builder(ZExt);
16511     Value *Result = createTblShuffleForZExt(
16512         Builder, ZExt->getOperand(0), cast<FixedVectorType>(ZExt->getType()),
16513         DstTy, Subtarget->isLittleEndian());
16514     if (!Result)
16515       return false;
16516     ZExt->replaceAllUsesWith(Result);
16517     ZExt->eraseFromParent();
16518     return true;
16519   }
16520 
16521   auto *UIToFP = dyn_cast<UIToFPInst>(I);
16522   if (UIToFP && SrcTy->getElementType()->isIntegerTy(8) &&
16523       DstTy->getElementType()->isFloatTy()) {
16524     IRBuilder<> Builder(I);
16525     Value *ZExt = createTblShuffleForZExt(
16526         Builder, I->getOperand(0), FixedVectorType::getInteger(DstTy),
16527         FixedVectorType::getInteger(DstTy), Subtarget->isLittleEndian());
16528     assert(ZExt && "Cannot fail for the i8 to float conversion");
16529     auto *UI = Builder.CreateUIToFP(ZExt, DstTy);
16530     I->replaceAllUsesWith(UI);
16531     I->eraseFromParent();
16532     return true;
16533   }
16534 
16535   auto *SIToFP = dyn_cast<SIToFPInst>(I);
16536   if (SIToFP && SrcTy->getElementType()->isIntegerTy(8) &&
16537       DstTy->getElementType()->isFloatTy()) {
16538     IRBuilder<> Builder(I);
16539     auto *Shuffle = createTblShuffleForSExt(Builder, I->getOperand(0),
16540                                             FixedVectorType::getInteger(DstTy),
16541                                             Subtarget->isLittleEndian());
16542     assert(Shuffle && "Cannot fail for the i8 to float conversion");
16543     auto *Cast = Builder.CreateBitCast(Shuffle, VectorType::getInteger(DstTy));
16544     auto *AShr = Builder.CreateAShr(Cast, 24, "", true);
16545     auto *SI = Builder.CreateSIToFP(AShr, DstTy);
16546     I->replaceAllUsesWith(SI);
16547     I->eraseFromParent();
16548     return true;
16549   }
16550 
16551   // Convert 'fptoui <(8|16) x float> to <(8|16) x i8>' to a wide fptoui
16552   // followed by a truncate lowered using tbl4.
16553   auto *FPToUI = dyn_cast<FPToUIInst>(I);
16554   if (FPToUI &&
16555       (SrcTy->getNumElements() == 8 || SrcTy->getNumElements() == 16) &&
16556       SrcTy->getElementType()->isFloatTy() &&
16557       DstTy->getElementType()->isIntegerTy(8)) {
16558     IRBuilder<> Builder(I);
16559     auto *WideConv = Builder.CreateFPToUI(FPToUI->getOperand(0),
16560                                           VectorType::getInteger(SrcTy));
16561     auto *TruncI = Builder.CreateTrunc(WideConv, DstTy);
16562     I->replaceAllUsesWith(TruncI);
16563     I->eraseFromParent();
16564     createTblForTrunc(cast<TruncInst>(TruncI), Subtarget->isLittleEndian());
16565     return true;
16566   }
16567 
16568   // Convert 'trunc <(8|16) x (i32|i64)> %x to <(8|16) x i8>' to an appropriate
16569   // tbl instruction selecting the lowest/highest (little/big endian) 8 bits
16570   // per lane of the input that is represented using 1,2,3 or 4 128-bit table
16571   // registers
16572   auto *TI = dyn_cast<TruncInst>(I);
16573   if (TI && DstTy->getElementType()->isIntegerTy(8) &&
16574       ((SrcTy->getElementType()->isIntegerTy(32) ||
16575         SrcTy->getElementType()->isIntegerTy(64)) &&
16576        (SrcTy->getNumElements() == 16 || SrcTy->getNumElements() == 8))) {
16577     createTblForTrunc(TI, Subtarget->isLittleEndian());
16578     return true;
16579   }
16580 
16581   return false;
16582 }
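// Illustrative sketch (hypothetical IR): assuming the loop/size preconditions
// above hold, 'uitofp <8 x i8> %x to <8 x float>' is rewritten as a tbl-style
// zero-extending shuffle to <8 x i32> followed by a uitofp of that result.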
16583 
16584 bool AArch64TargetLowering::hasPairedLoad(EVT LoadedType,
16585                                           Align &RequiredAligment) const {
16586   if (!LoadedType.isSimple() ||
16587       (!LoadedType.isInteger() && !LoadedType.isFloatingPoint()))
16588     return false;
16589   // Cyclone supports unaligned accesses.
16590   RequiredAligment = Align(1);
16591   unsigned NumBits = LoadedType.getSizeInBits();
16592   return NumBits == 32 || NumBits == 64;
16593 }
16594 
16595 /// A helper function for determining the number of interleaved accesses we
16596 /// will generate when lowering accesses of the given type.
16597 unsigned AArch64TargetLowering::getNumInterleavedAccesses(
16598     VectorType *VecTy, const DataLayout &DL, bool UseScalable) const {
16599   unsigned VecSize = 128;
16600   unsigned ElSize = DL.getTypeSizeInBits(VecTy->getElementType());
16601   unsigned MinElts = VecTy->getElementCount().getKnownMinValue();
16602   if (UseScalable && isa<FixedVectorType>(VecTy))
16603     VecSize = std::max(Subtarget->getMinSVEVectorSizeInBits(), 128u);
16604   return std::max<unsigned>(1, (MinElts * ElSize + 127) / VecSize);
16605 }
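// Worked example (illustrative): for a fixed <16 x i32> with UseScalable ==
// false, MinElts * ElSize = 512 bits, so (512 + 127) / 128 = 4 interleaved
// accesses are generated.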
16606 
16607 MachineMemOperand::Flags
16608 AArch64TargetLowering::getTargetMMOFlags(const Instruction &I) const {
16609   if (Subtarget->getProcFamily() == AArch64Subtarget::Falkor &&
16610       I.hasMetadata(FALKOR_STRIDED_ACCESS_MD))
16611     return MOStridedAccess;
16612   return MachineMemOperand::MONone;
16613 }
16614 
16615 bool AArch64TargetLowering::isLegalInterleavedAccessType(
16616     VectorType *VecTy, const DataLayout &DL, bool &UseScalable) const {
16617   unsigned ElSize = DL.getTypeSizeInBits(VecTy->getElementType());
16618   auto EC = VecTy->getElementCount();
16619   unsigned MinElts = EC.getKnownMinValue();
16620 
16621   UseScalable = false;
16622 
16623   if (isa<FixedVectorType>(VecTy) && !Subtarget->isNeonAvailable() &&
16624       (!Subtarget->useSVEForFixedLengthVectors() ||
16625        !getSVEPredPatternFromNumElements(MinElts)))
16626     return false;
16627 
16628   if (isa<ScalableVectorType>(VecTy) &&
16629       !Subtarget->isSVEorStreamingSVEAvailable())
16630     return false;
16631 
16632   // Ensure the number of vector elements is greater than 1.
16633   if (MinElts < 2)
16634     return false;
16635 
16636   // Ensure the element type is legal.
16637   if (ElSize != 8 && ElSize != 16 && ElSize != 32 && ElSize != 64)
16638     return false;
16639 
16640   if (EC.isScalable()) {
16641     UseScalable = true;
16642     return isPowerOf2_32(MinElts) && (MinElts * ElSize) % 128 == 0;
16643   }
16644 
16645   unsigned VecSize = DL.getTypeSizeInBits(VecTy);
16646   if (Subtarget->useSVEForFixedLengthVectors()) {
16647     unsigned MinSVEVectorSize =
16648         std::max(Subtarget->getMinSVEVectorSizeInBits(), 128u);
16649     if (VecSize % MinSVEVectorSize == 0 ||
16650         (VecSize < MinSVEVectorSize && isPowerOf2_32(MinElts) &&
16651          (!Subtarget->isNeonAvailable() || VecSize > 128))) {
16652       UseScalable = true;
16653       return true;
16654     }
16655   }
16656 
16657   // Ensure the total vector size is 64 or a multiple of 128. Types larger than
16658   // 128 will be split into multiple interleaved accesses.
16659   return Subtarget->isNeonAvailable() && (VecSize == 64 || VecSize % 128 == 0);
16660 }
16661 
16662 static ScalableVectorType *getSVEContainerIRType(FixedVectorType *VTy) {
16663   if (VTy->getElementType() == Type::getDoubleTy(VTy->getContext()))
16664     return ScalableVectorType::get(VTy->getElementType(), 2);
16665 
16666   if (VTy->getElementType() == Type::getFloatTy(VTy->getContext()))
16667     return ScalableVectorType::get(VTy->getElementType(), 4);
16668 
16669   if (VTy->getElementType() == Type::getBFloatTy(VTy->getContext()))
16670     return ScalableVectorType::get(VTy->getElementType(), 8);
16671 
16672   if (VTy->getElementType() == Type::getHalfTy(VTy->getContext()))
16673     return ScalableVectorType::get(VTy->getElementType(), 8);
16674 
16675   if (VTy->getElementType() == Type::getInt64Ty(VTy->getContext()))
16676     return ScalableVectorType::get(VTy->getElementType(), 2);
16677 
16678   if (VTy->getElementType() == Type::getInt32Ty(VTy->getContext()))
16679     return ScalableVectorType::get(VTy->getElementType(), 4);
16680 
16681   if (VTy->getElementType() == Type::getInt16Ty(VTy->getContext()))
16682     return ScalableVectorType::get(VTy->getElementType(), 8);
16683 
16684   if (VTy->getElementType() == Type::getInt8Ty(VTy->getContext()))
16685     return ScalableVectorType::get(VTy->getElementType(), 16);
16686 
16687   llvm_unreachable("Cannot handle input vector type");
16688 }
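// Illustrative mapping (follows directly from the cases above): <4 x float>
// maps to <vscale x 4 x float> and <8 x i16> to <vscale x 8 x i16>, i.e. the
// scalable container whose known-minimum size is 128 bits.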
16689 
16690 static Function *getStructuredLoadFunction(Module *M, unsigned Factor,
16691                                            bool Scalable, Type *LDVTy,
16692                                            Type *PtrTy) {
16693   assert(Factor >= 2 && Factor <= 4 && "Invalid interleave factor");
16694   static const Intrinsic::ID SVELoads[3] = {Intrinsic::aarch64_sve_ld2_sret,
16695                                             Intrinsic::aarch64_sve_ld3_sret,
16696                                             Intrinsic::aarch64_sve_ld4_sret};
16697   static const Intrinsic::ID NEONLoads[3] = {Intrinsic::aarch64_neon_ld2,
16698                                              Intrinsic::aarch64_neon_ld3,
16699                                              Intrinsic::aarch64_neon_ld4};
16700   if (Scalable)
16701     return Intrinsic::getDeclaration(M, SVELoads[Factor - 2], {LDVTy});
16702 
16703   return Intrinsic::getDeclaration(M, NEONLoads[Factor - 2], {LDVTy, PtrTy});
16704 }
16705 
16706 static Function *getStructuredStoreFunction(Module *M, unsigned Factor,
16707                                             bool Scalable, Type *STVTy,
16708                                             Type *PtrTy) {
16709   assert(Factor >= 2 && Factor <= 4 && "Invalid interleave factor");
16710   static const Intrinsic::ID SVEStores[3] = {Intrinsic::aarch64_sve_st2,
16711                                              Intrinsic::aarch64_sve_st3,
16712                                              Intrinsic::aarch64_sve_st4};
16713   static const Intrinsic::ID NEONStores[3] = {Intrinsic::aarch64_neon_st2,
16714                                               Intrinsic::aarch64_neon_st3,
16715                                               Intrinsic::aarch64_neon_st4};
16716   if (Scalable)
16717     return Intrinsic::getDeclaration(M, SVEStores[Factor - 2], {STVTy});
16718 
16719   return Intrinsic::getDeclaration(M, NEONStores[Factor - 2], {STVTy, PtrTy});
16720 }
16721 
16722 /// Lower an interleaved load into a ldN intrinsic.
16723 ///
16724 /// E.g. Lower an interleaved load (Factor = 2):
16725 ///        %wide.vec = load <8 x i32>, <8 x i32>* %ptr
16726 ///        %v0 = shuffle %wide.vec, undef, <0, 2, 4, 6>  ; Extract even elements
16727 ///        %v1 = shuffle %wide.vec, undef, <1, 3, 5, 7>  ; Extract odd elements
16728 ///
16729 ///      Into:
16730 ///        %ld2 = { <4 x i32>, <4 x i32> } call llvm.aarch64.neon.ld2(%ptr)
16731 ///        %vec0 = extractvalue { <4 x i32>, <4 x i32> } %ld2, 0
16732 ///        %vec1 = extractvalue { <4 x i32>, <4 x i32> } %ld2, 1
16733 bool AArch64TargetLowering::lowerInterleavedLoad(
16734     LoadInst *LI, ArrayRef<ShuffleVectorInst *> Shuffles,
16735     ArrayRef<unsigned> Indices, unsigned Factor) const {
16736   assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() &&
16737          "Invalid interleave factor");
16738   assert(!Shuffles.empty() && "Empty shufflevector input");
16739   assert(Shuffles.size() == Indices.size() &&
16740          "Unmatched number of shufflevectors and indices");
16741 
16742   const DataLayout &DL = LI->getDataLayout();
16743 
16744   VectorType *VTy = Shuffles[0]->getType();
16745 
16746   // Skip if we do not have NEON and skip illegal vector types. We can
16747   // "legalize" wide vector types into multiple interleaved accesses as long as
16748   // the vector types are divisible by 128.
16749   bool UseScalable;
16750   if (!isLegalInterleavedAccessType(VTy, DL, UseScalable))
16751     return false;
16752 
16753   unsigned NumLoads = getNumInterleavedAccesses(VTy, DL, UseScalable);
16754 
16755   auto *FVTy = cast<FixedVectorType>(VTy);
16756 
16757   // A pointer vector cannot be the return type of the ldN intrinsics. Need to
16758   // load integer vectors first and then convert to pointer vectors.
16759   Type *EltTy = FVTy->getElementType();
16760   if (EltTy->isPointerTy())
16761     FVTy =
16762         FixedVectorType::get(DL.getIntPtrType(EltTy), FVTy->getNumElements());
16763 
16764   // If we're going to generate more than one load, reset the sub-vector type
16765   // to something legal.
16766   FVTy = FixedVectorType::get(FVTy->getElementType(),
16767                               FVTy->getNumElements() / NumLoads);
16768 
16769   auto *LDVTy =
16770       UseScalable ? cast<VectorType>(getSVEContainerIRType(FVTy)) : FVTy;
16771 
16772   IRBuilder<> Builder(LI);
16773 
16774   // The base address of the load.
16775   Value *BaseAddr = LI->getPointerOperand();
16776 
16777   Type *PtrTy = LI->getPointerOperandType();
16778   Type *PredTy = VectorType::get(Type::getInt1Ty(LDVTy->getContext()),
16779                                  LDVTy->getElementCount());
16780 
16781   Function *LdNFunc = getStructuredLoadFunction(LI->getModule(), Factor,
16782                                                 UseScalable, LDVTy, PtrTy);
16783 
16784   // Holds sub-vectors extracted from the load intrinsic return values. The
16785   // sub-vectors are associated with the shufflevector instructions they will
16786   // replace.
16787   DenseMap<ShuffleVectorInst *, SmallVector<Value *, 4>> SubVecs;
16788 
16789   Value *PTrue = nullptr;
16790   if (UseScalable) {
16791     std::optional<unsigned> PgPattern =
16792         getSVEPredPatternFromNumElements(FVTy->getNumElements());
16793     if (Subtarget->getMinSVEVectorSizeInBits() ==
16794             Subtarget->getMaxSVEVectorSizeInBits() &&
16795         Subtarget->getMinSVEVectorSizeInBits() == DL.getTypeSizeInBits(FVTy))
16796       PgPattern = AArch64SVEPredPattern::all;
16797 
16798     auto *PTruePat =
16799         ConstantInt::get(Type::getInt32Ty(LDVTy->getContext()), *PgPattern);
16800     PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue, {PredTy},
16801                                     {PTruePat});
16802   }
16803 
16804   for (unsigned LoadCount = 0; LoadCount < NumLoads; ++LoadCount) {
16805 
16806     // If we're generating more than one load, compute the base address of
16807     // subsequent loads as an offset from the previous.
16808     if (LoadCount > 0)
16809       BaseAddr = Builder.CreateConstGEP1_32(LDVTy->getElementType(), BaseAddr,
16810                                             FVTy->getNumElements() * Factor);
16811 
16812     CallInst *LdN;
16813     if (UseScalable)
16814       LdN = Builder.CreateCall(LdNFunc, {PTrue, BaseAddr}, "ldN");
16815     else
16816       LdN = Builder.CreateCall(LdNFunc, BaseAddr, "ldN");
16817 
16818     // Extract and store the sub-vectors returned by the load intrinsic.
16819     for (unsigned i = 0; i < Shuffles.size(); i++) {
16820       ShuffleVectorInst *SVI = Shuffles[i];
16821       unsigned Index = Indices[i];
16822 
16823       Value *SubVec = Builder.CreateExtractValue(LdN, Index);
16824 
16825       if (UseScalable)
16826         SubVec = Builder.CreateExtractVector(
16827             FVTy, SubVec,
16828             ConstantInt::get(Type::getInt64Ty(VTy->getContext()), 0));
16829 
16830       // Convert the integer vector to pointer vector if the element is pointer.
16831       if (EltTy->isPointerTy())
16832         SubVec = Builder.CreateIntToPtr(
16833             SubVec, FixedVectorType::get(SVI->getType()->getElementType(),
16834                                          FVTy->getNumElements()));
16835 
16836       SubVecs[SVI].push_back(SubVec);
16837     }
16838   }
16839 
16840   // Replace uses of the shufflevector instructions with the sub-vectors
16841   // returned by the load intrinsic. If a shufflevector instruction is
16842   // associated with more than one sub-vector, those sub-vectors will be
16843   // concatenated into a single wide vector.
16844   for (ShuffleVectorInst *SVI : Shuffles) {
16845     auto &SubVec = SubVecs[SVI];
16846     auto *WideVec =
16847         SubVec.size() > 1 ? concatenateVectors(Builder, SubVec) : SubVec[0];
16848     SVI->replaceAllUsesWith(WideVec);
16849   }
16850 
16851   return true;
16852 }
16853 
16854 template <typename Iter>
16855 bool hasNearbyPairedStore(Iter It, Iter End, Value *Ptr, const DataLayout &DL) {
16856   int MaxLookupDist = 20;
16857   unsigned IdxWidth = DL.getIndexSizeInBits(0);
16858   APInt OffsetA(IdxWidth, 0), OffsetB(IdxWidth, 0);
16859   const Value *PtrA1 =
16860       Ptr->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA);
16861 
16862   while (++It != End) {
16863     if (It->isDebugOrPseudoInst())
16864       continue;
16865     if (MaxLookupDist-- == 0)
16866       break;
16867     if (const auto *SI = dyn_cast<StoreInst>(&*It)) {
16868       const Value *PtrB1 =
16869           SI->getPointerOperand()->stripAndAccumulateInBoundsConstantOffsets(
16870               DL, OffsetB);
16871       if (PtrA1 == PtrB1 &&
16872           (OffsetA.sextOrTrunc(IdxWidth) - OffsetB.sextOrTrunc(IdxWidth))
16873                   .abs() == 16)
16874         return true;
16875     }
16876   }
16877 
16878   return false;
16879 }
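// Illustrative sketch (hypothetical IR): a query on %p returns true when the
// scan reaches the second store below, since both pointers strip to %p and
// their constant offsets differ by exactly 16 bytes:
//   store <2 x i64> %v, ptr %p
//   %q = getelementptr inbounds i8, ptr %p, i64 16
//   store <2 x i64> %w, ptr %q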
16880 
16881 /// Lower an interleaved store into a stN intrinsic.
16882 ///
16883 /// E.g. Lower an interleaved store (Factor = 3):
16884 ///        %i.vec = shuffle <8 x i32> %v0, <8 x i32> %v1,
16885 ///                 <0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11>
16886 ///        store <12 x i32> %i.vec, <12 x i32>* %ptr
16887 ///
16888 ///      Into:
16889 ///        %sub.v0 = shuffle <8 x i32> %v0, <8 x i32> v1, <0, 1, 2, 3>
16890 ///        %sub.v1 = shuffle <8 x i32> %v0, <8 x i32> v1, <4, 5, 6, 7>
16891 ///        %sub.v2 = shuffle <8 x i32> %v0, <8 x i32> v1, <8, 9, 10, 11>
16892 ///        call void llvm.aarch64.neon.st3(%sub.v0, %sub.v1, %sub.v2, %ptr)
16893 ///
16894 /// Note that the new shufflevectors will be removed and we'll only generate one
16895 /// st3 instruction in CodeGen.
16896 ///
16897 /// Example for a more general valid mask (Factor 3). Lower:
16898 ///        %i.vec = shuffle <32 x i32> %v0, <32 x i32> %v1,
16899 ///                 <4, 32, 16, 5, 33, 17, 6, 34, 18, 7, 35, 19>
16900 ///        store <12 x i32> %i.vec, <12 x i32>* %ptr
16901 ///
16902 ///      Into:
16903 ///        %sub.v0 = shuffle <32 x i32> %v0, <32 x i32> v1, <4, 5, 6, 7>
16904 ///        %sub.v1 = shuffle <32 x i32> %v0, <32 x i32> v1, <32, 33, 34, 35>
16905 ///        %sub.v2 = shuffle <32 x i32> %v0, <32 x i32> v1, <16, 17, 18, 19>
16906 ///        call void llvm.aarch64.neon.st3(%sub.v0, %sub.v1, %sub.v2, %ptr)
16907 bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
16908                                                   ShuffleVectorInst *SVI,
16909                                                   unsigned Factor) const {
16910 
16911   assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() &&
16912          "Invalid interleave factor");
16913 
16914   auto *VecTy = cast<FixedVectorType>(SVI->getType());
16915   assert(VecTy->getNumElements() % Factor == 0 && "Invalid interleaved store");
16916 
16917   unsigned LaneLen = VecTy->getNumElements() / Factor;
16918   Type *EltTy = VecTy->getElementType();
16919   auto *SubVecTy = FixedVectorType::get(EltTy, LaneLen);
16920 
16921   const DataLayout &DL = SI->getDataLayout();
16922   bool UseScalable;
16923 
16924   // Skip if we do not have NEON and skip illegal vector types. We can
16925   // "legalize" wide vector types into multiple interleaved accesses as long as
16926   // the vector types are divisible by 128.
16927   if (!isLegalInterleavedAccessType(SubVecTy, DL, UseScalable))
16928     return false;
16929 
16930   unsigned NumStores = getNumInterleavedAccesses(SubVecTy, DL, UseScalable);
16931 
16932   Value *Op0 = SVI->getOperand(0);
16933   Value *Op1 = SVI->getOperand(1);
16934   IRBuilder<> Builder(SI);
16935 
16936   // StN intrinsics don't support pointer vectors as arguments. Convert pointer
16937   // vectors to integer vectors.
16938   if (EltTy->isPointerTy()) {
16939     Type *IntTy = DL.getIntPtrType(EltTy);
16940     unsigned NumOpElts =
16941         cast<FixedVectorType>(Op0->getType())->getNumElements();
16942 
16943     // Convert to the corresponding integer vector.
16944     auto *IntVecTy = FixedVectorType::get(IntTy, NumOpElts);
16945     Op0 = Builder.CreatePtrToInt(Op0, IntVecTy);
16946     Op1 = Builder.CreatePtrToInt(Op1, IntVecTy);
16947 
16948     SubVecTy = FixedVectorType::get(IntTy, LaneLen);
16949   }
16950 
16951   // If we're going to generate more than one store, reset the lane length
16952   // and sub-vector type to something legal.
16953   LaneLen /= NumStores;
16954   SubVecTy = FixedVectorType::get(SubVecTy->getElementType(), LaneLen);
16955 
16956   auto *STVTy = UseScalable ? cast<VectorType>(getSVEContainerIRType(SubVecTy))
16957                             : SubVecTy;
16958 
16959   // The base address of the store.
16960   Value *BaseAddr = SI->getPointerOperand();
16961 
16962   auto Mask = SVI->getShuffleMask();
16963 
16964   // Sanity check: bail out if none of the indices are in range.
16965   // If the shuffle mask is `poison`, `Mask` may be a vector of -1s.
16966   // If all of them are `poison`, an out-of-bounds read would happen later.
16967   if (llvm::all_of(Mask, [](int Idx) { return Idx == PoisonMaskElem; })) {
16968     return false;
16969   }
16970   // A 64-bit st2 which does not start at element 0 will involve adding extra
16971   // ext instructions, making the st2 unprofitable. If there is a nearby store
16972   // that points to BaseAddr+16 or BaseAddr-16, it can be better left as a
16973   // zip;stp pair, which has higher throughput.
16974   if (Factor == 2 && SubVecTy->getPrimitiveSizeInBits() == 64 &&
16975       (Mask[0] != 0 ||
16976        hasNearbyPairedStore(SI->getIterator(), SI->getParent()->end(), BaseAddr,
16977                             DL) ||
16978        hasNearbyPairedStore(SI->getReverseIterator(), SI->getParent()->rend(),
16979                             BaseAddr, DL)))
16980     return false;
16981 
16982   Type *PtrTy = SI->getPointerOperandType();
16983   Type *PredTy = VectorType::get(Type::getInt1Ty(STVTy->getContext()),
16984                                  STVTy->getElementCount());
16985 
16986   Function *StNFunc = getStructuredStoreFunction(SI->getModule(), Factor,
16987                                                  UseScalable, STVTy, PtrTy);
16988 
16989   Value *PTrue = nullptr;
16990   if (UseScalable) {
16991     std::optional<unsigned> PgPattern =
16992         getSVEPredPatternFromNumElements(SubVecTy->getNumElements());
16993     if (Subtarget->getMinSVEVectorSizeInBits() ==
16994             Subtarget->getMaxSVEVectorSizeInBits() &&
16995         Subtarget->getMinSVEVectorSizeInBits() ==
16996             DL.getTypeSizeInBits(SubVecTy))
16997       PgPattern = AArch64SVEPredPattern::all;
16998 
16999     auto *PTruePat =
17000         ConstantInt::get(Type::getInt32Ty(STVTy->getContext()), *PgPattern);
17001     PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue, {PredTy},
17002                                     {PTruePat});
17003   }
17004 
17005   for (unsigned StoreCount = 0; StoreCount < NumStores; ++StoreCount) {
17006 
17007     SmallVector<Value *, 5> Ops;
17008 
17009     // Split the shufflevector operands into sub vectors for the new stN call.
17010     for (unsigned i = 0; i < Factor; i++) {
17011       Value *Shuffle;
17012       unsigned IdxI = StoreCount * LaneLen * Factor + i;
17013       if (Mask[IdxI] >= 0) {
17014         Shuffle = Builder.CreateShuffleVector(
17015             Op0, Op1, createSequentialMask(Mask[IdxI], LaneLen, 0));
17016       } else {
17017         unsigned StartMask = 0;
17018         for (unsigned j = 1; j < LaneLen; j++) {
17019           unsigned IdxJ = StoreCount * LaneLen * Factor + j * Factor + i;
17020           if (Mask[IdxJ] >= 0) {
17021             StartMask = Mask[IdxJ] - j;
17022             break;
17023           }
17024         }
17025         // Note: Filling undef gaps with random elements is ok, since
17026         // those elements were being written anyway (with undefs).
17027         // In the case of all undefs we're defaulting to using elems from 0
17028         // Note: StartMask cannot be negative, it's checked in
17029         // isReInterleaveMask
17030         Shuffle = Builder.CreateShuffleVector(
17031             Op0, Op1, createSequentialMask(StartMask, LaneLen, 0));
17032       }
17033 
17034       if (UseScalable)
17035         Shuffle = Builder.CreateInsertVector(
17036             STVTy, UndefValue::get(STVTy), Shuffle,
17037             ConstantInt::get(Type::getInt64Ty(STVTy->getContext()), 0));
17038 
17039       Ops.push_back(Shuffle);
17040     }
17041 
17042     if (UseScalable)
17043       Ops.push_back(PTrue);
17044 
17045     // If we're generating more than one store, compute the base address of
17046     // subsequent stores as an offset from the previous.
17047     if (StoreCount > 0)
17048       BaseAddr = Builder.CreateConstGEP1_32(SubVecTy->getElementType(),
17049                                             BaseAddr, LaneLen * Factor);
17050 
17051     Ops.push_back(BaseAddr);
17052     Builder.CreateCall(StNFunc, Ops);
17053   }
17054   return true;
17055 }
17056 
17057 bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
17058     IntrinsicInst *DI, LoadInst *LI) const {
17059   // Only deinterleave2 supported at present.
17060   if (DI->getIntrinsicID() != Intrinsic::vector_deinterleave2)
17061     return false;
17062 
17063   // Only a factor of 2 supported at present.
17064   const unsigned Factor = 2;
17065 
17066   VectorType *VTy = cast<VectorType>(DI->getType()->getContainedType(0));
17067   const DataLayout &DL = DI->getDataLayout();
17068   bool UseScalable;
17069   if (!isLegalInterleavedAccessType(VTy, DL, UseScalable))
17070     return false;
17071 
17072   // TODO: Add support for using SVE instructions with fixed types later, using
17073   // the code from lowerInterleavedLoad to obtain the correct container type.
17074   if (UseScalable && !VTy->isScalableTy())
17075     return false;
17076 
17077   unsigned NumLoads = getNumInterleavedAccesses(VTy, DL, UseScalable);
17078 
17079   VectorType *LdTy =
17080       VectorType::get(VTy->getElementType(),
17081                       VTy->getElementCount().divideCoefficientBy(NumLoads));
17082 
17083   Type *PtrTy = LI->getPointerOperandType();
17084   Function *LdNFunc = getStructuredLoadFunction(DI->getModule(), Factor,
17085                                                 UseScalable, LdTy, PtrTy);
17086 
17087   IRBuilder<> Builder(LI);
17088 
17089   Value *Pred = nullptr;
17090   if (UseScalable)
17091     Pred =
17092         Builder.CreateVectorSplat(LdTy->getElementCount(), Builder.getTrue());
17093 
17094   Value *BaseAddr = LI->getPointerOperand();
17095   Value *Result;
17096   if (NumLoads > 1) {
17097     Value *Left = PoisonValue::get(VTy);
17098     Value *Right = PoisonValue::get(VTy);
17099 
17100     for (unsigned I = 0; I < NumLoads; ++I) {
17101       Value *Offset = Builder.getInt64(I * Factor);
17102 
17103       Value *Address = Builder.CreateGEP(LdTy, BaseAddr, {Offset});
17104       Value *LdN = nullptr;
17105       if (UseScalable)
17106         LdN = Builder.CreateCall(LdNFunc, {Pred, Address}, "ldN");
17107       else
17108         LdN = Builder.CreateCall(LdNFunc, Address, "ldN");
17109 
17110       Value *Idx =
17111           Builder.getInt64(I * LdTy->getElementCount().getKnownMinValue());
17112       Left = Builder.CreateInsertVector(
17113           VTy, Left, Builder.CreateExtractValue(LdN, 0), Idx);
17114       Right = Builder.CreateInsertVector(
17115           VTy, Right, Builder.CreateExtractValue(LdN, 1), Idx);
17116     }
17117 
17118     Result = PoisonValue::get(DI->getType());
17119     Result = Builder.CreateInsertValue(Result, Left, 0);
17120     Result = Builder.CreateInsertValue(Result, Right, 1);
17121   } else {
17122     if (UseScalable)
17123       Result = Builder.CreateCall(LdNFunc, {Pred, BaseAddr}, "ldN");
17124     else
17125       Result = Builder.CreateCall(LdNFunc, BaseAddr, "ldN");
17126   }
17127 
17128   DI->replaceAllUsesWith(Result);
17129   return true;
17130 }
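// Illustrative sketch (hypothetical case): a vector.deinterleave2 of
// <vscale x 8 x i32> fed by a load of the same type is lowered to a single
// SVE ld2 (sret) intrinsic call with an all-true predicate, whose two results
// replace the results of the deinterleave.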
17131 
17132 bool AArch64TargetLowering::lowerInterleaveIntrinsicToStore(
17133     IntrinsicInst *II, StoreInst *SI) const {
17134   // Only interleave2 supported at present.
17135   if (II->getIntrinsicID() != Intrinsic::vector_interleave2)
17136     return false;
17137 
17138   // Only a factor of 2 supported at present.
17139   const unsigned Factor = 2;
17140 
17141   VectorType *VTy = cast<VectorType>(II->getOperand(0)->getType());
17142   const DataLayout &DL = II->getDataLayout();
17143   bool UseScalable;
17144   if (!isLegalInterleavedAccessType(VTy, DL, UseScalable))
17145     return false;
17146 
17147   // TODO: Add support for using SVE instructions with fixed types later, using
17148   // the code from lowerInterleavedStore to obtain the correct container type.
17149   if (UseScalable && !VTy->isScalableTy())
17150     return false;
17151 
17152   unsigned NumStores = getNumInterleavedAccesses(VTy, DL, UseScalable);
17153 
17154   VectorType *StTy =
17155       VectorType::get(VTy->getElementType(),
17156                       VTy->getElementCount().divideCoefficientBy(NumStores));
17157 
17158   Type *PtrTy = SI->getPointerOperandType();
17159   Function *StNFunc = getStructuredStoreFunction(SI->getModule(), Factor,
17160                                                  UseScalable, StTy, PtrTy);
17161 
17162   IRBuilder<> Builder(SI);
17163 
17164   Value *BaseAddr = SI->getPointerOperand();
17165   Value *Pred = nullptr;
17166 
17167   if (UseScalable)
17168     Pred =
17169         Builder.CreateVectorSplat(StTy->getElementCount(), Builder.getTrue());
17170 
17171   Value *L = II->getOperand(0);
17172   Value *R = II->getOperand(1);
17173 
17174   for (unsigned I = 0; I < NumStores; ++I) {
17175     Value *Address = BaseAddr;
17176     if (NumStores > 1) {
17177       Value *Offset = Builder.getInt64(I * Factor);
17178       Address = Builder.CreateGEP(StTy, BaseAddr, {Offset});
17179 
17180       Value *Idx =
17181           Builder.getInt64(I * StTy->getElementCount().getKnownMinValue());
17182       L = Builder.CreateExtractVector(StTy, II->getOperand(0), Idx);
17183       R = Builder.CreateExtractVector(StTy, II->getOperand(1), Idx);
17184     }
17185 
17186     if (UseScalable)
17187       Builder.CreateCall(StNFunc, {L, R, Pred, Address});
17188     else
17189       Builder.CreateCall(StNFunc, {L, R, Address});
17190   }
17191 
17192   return true;
17193 }
17194 
17195 EVT AArch64TargetLowering::getOptimalMemOpType(
17196     const MemOp &Op, const AttributeList &FuncAttributes) const {
17197   bool CanImplicitFloat = !FuncAttributes.hasFnAttr(Attribute::NoImplicitFloat);
17198   bool CanUseNEON = Subtarget->hasNEON() && CanImplicitFloat;
17199   bool CanUseFP = Subtarget->hasFPARMv8() && CanImplicitFloat;
17200   // Only use AdvSIMD to implement memsets of 32 bytes or more. Below that,
17201   // AdvSIMD would take one instruction to materialize the v2i64 zero and one
17202   // store (with a restrictive addressing mode), so just do i64 stores.
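  // For instance, a 16-byte zeroing memset done with AdvSIMD would be roughly a
  // MOVI plus one Q-register store, whereas plain XZR stores need no setup and
  // may even be merged into a single STP.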
17203   bool IsSmallMemset = Op.isMemset() && Op.size() < 32;
17204   auto AlignmentIsAcceptable = [&](EVT VT, Align AlignCheck) {
17205     if (Op.isAligned(AlignCheck))
17206       return true;
17207     unsigned Fast;
17208     return allowsMisalignedMemoryAccesses(VT, 0, Align(1),
17209                                           MachineMemOperand::MONone, &Fast) &&
17210            Fast;
17211   };
17212 
17213   if (CanUseNEON && Op.isMemset() && !IsSmallMemset &&
17214       AlignmentIsAcceptable(MVT::v16i8, Align(16)))
17215     return MVT::v16i8;
17216   if (CanUseFP && !IsSmallMemset && AlignmentIsAcceptable(MVT::f128, Align(16)))
17217     return MVT::f128;
17218   if (Op.size() >= 8 && AlignmentIsAcceptable(MVT::i64, Align(8)))
17219     return MVT::i64;
17220   if (Op.size() >= 4 && AlignmentIsAcceptable(MVT::i32, Align(4)))
17221     return MVT::i32;
17222   return MVT::Other;
17223 }
17224 
17225 LLT AArch64TargetLowering::getOptimalMemOpLLT(
17226     const MemOp &Op, const AttributeList &FuncAttributes) const {
17227   bool CanImplicitFloat = !FuncAttributes.hasFnAttr(Attribute::NoImplicitFloat);
17228   bool CanUseNEON = Subtarget->hasNEON() && CanImplicitFloat;
17229   bool CanUseFP = Subtarget->hasFPARMv8() && CanImplicitFloat;
17230   // Only use AdvSIMD to implement memsets of 32 bytes or more. Below that,
17231   // AdvSIMD would take one instruction to materialize the v2i64 zero and one
17232   // store (with a restrictive addressing mode), so just do i64 stores.
17233   bool IsSmallMemset = Op.isMemset() && Op.size() < 32;
17234   auto AlignmentIsAcceptable = [&](EVT VT, Align AlignCheck) {
17235     if (Op.isAligned(AlignCheck))
17236       return true;
17237     unsigned Fast;
17238     return allowsMisalignedMemoryAccesses(VT, 0, Align(1),
17239                                           MachineMemOperand::MONone, &Fast) &&
17240            Fast;
17241   };
17242 
17243   if (CanUseNEON && Op.isMemset() && !IsSmallMemset &&
17244       AlignmentIsAcceptable(MVT::v2i64, Align(16)))
17245     return LLT::fixed_vector(2, 64);
17246   if (CanUseFP && !IsSmallMemset && AlignmentIsAcceptable(MVT::f128, Align(16)))
17247     return LLT::scalar(128);
17248   if (Op.size() >= 8 && AlignmentIsAcceptable(MVT::i64, Align(8)))
17249     return LLT::scalar(64);
17250   if (Op.size() >= 4 && AlignmentIsAcceptable(MVT::i32, Align(4)))
17251     return LLT::scalar(32);
17252   return LLT();
17253 }
17254 
17255 // 12-bit optionally shifted immediates are legal for adds.
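// For example, #0xfff and #0xfff000 (0xfff shifted left by 12) are legal add
// immediates, but a value such as 0x1001000 is not and must first be
// materialized into a register.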
17256 bool AArch64TargetLowering::isLegalAddImmediate(int64_t Immed) const {
17257   if (Immed == std::numeric_limits<int64_t>::min()) {
17258     LLVM_DEBUG(dbgs() << "Illegal add imm " << Immed
17259                       << ": avoid UB for INT64_MIN\n");
17260     return false;
17261   }
17262   // Same encoding for add/sub, just flip the sign.
17263   Immed = std::abs(Immed);
17264   bool IsLegal = ((Immed >> 12) == 0 ||
17265                   ((Immed & 0xfff) == 0 && Immed >> 24 == 0));
17266   LLVM_DEBUG(dbgs() << "Is " << Immed
17267                     << " legal add imm: " << (IsLegal ? "yes" : "no") << "\n");
17268   return IsLegal;
17269 }
17270 
17271 bool AArch64TargetLowering::isLegalAddScalableImmediate(int64_t Imm) const {
17272   // We will only emit addvl/inc* instructions for SVE2
17273   if (!Subtarget->hasSVE2())
17274     return false;
17275 
17276   // addvl's immediates are in terms of the number of bytes in a register.
17277   // Since there are 16 in the base supported size (128bits), we need to
17278   // divide the immediate by that much to give us a useful immediate to
17279   // multiply by vscale. We can't have a remainder as a result of this.
17280   if (Imm % 16 == 0)
17281     return isInt<6>(Imm / 16);
17282 
17283   // Inc[b|h|w|d] instructions take a pattern and a positive immediate
17284   // multiplier. For now, assume a pattern of 'all'. Incb would be a subset
17285   // of addvl as a result, so only take h|w|d into account.
17286   // Dec[h|w|d] will cover subtractions.
17287   // Immediates are in the range [1,16], so we can't do a 2's complement check.
17288   // FIXME: Can we make use of other patterns to cover other immediates?
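  // For example, Imm == 64 (16 * 4) is accepted by the addvl check above, while
  // Imm == 24 is not a multiple of 16 but 24 / 8 == 3 <= 16, so it is accepted
  // by the inch|dech check below.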
17289 
17290   // inch|dech
17291   if (Imm % 8 == 0)
17292     return std::abs(Imm / 8) <= 16;
17293   // incw|decw
17294   if (Imm % 4 == 0)
17295     return std::abs(Imm / 4) <= 16;
17296   // incd|decd
17297   if (Imm % 2 == 0)
17298     return std::abs(Imm / 2) <= 16;
17299 
17300   return false;
17301 }
17302 
17303 // Return false to prevent folding
17304 // (mul (add x, c1), c2) -> (add (mul x, c2), c2*c1) in DAGCombine,
17305 // if the folding leads to worse code.
17306 bool AArch64TargetLowering::isMulAddWithConstProfitable(
17307     SDValue AddNode, SDValue ConstNode) const {
17308   // Let the DAGCombiner decide for vector types and large types.
17309   const EVT VT = AddNode.getValueType();
17310   if (VT.isVector() || VT.getScalarSizeInBits() > 64)
17311     return true;
17312 
17313   // It is worse if c1 is a legal add immediate while c1*c2 is not and has to
17314   // be materialized with at least two instructions.
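  // Illustrative values: c1 = 0x456 is a legal add immediate, but with
  // c2 = 0x1001 the product c1*c2 = 0x456456 is not, and it takes a MOVZ+MOVK
  // pair to materialize, so the fold is rejected below.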
17315   const ConstantSDNode *C1Node = cast<ConstantSDNode>(AddNode.getOperand(1));
17316   const ConstantSDNode *C2Node = cast<ConstantSDNode>(ConstNode);
17317   const int64_t C1 = C1Node->getSExtValue();
17318   const APInt C1C2 = C1Node->getAPIntValue() * C2Node->getAPIntValue();
17319   if (!isLegalAddImmediate(C1) || isLegalAddImmediate(C1C2.getSExtValue()))
17320     return true;
17321   SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn;
17322   // Adapt to the width of a register.
17323   unsigned BitSize = VT.getSizeInBits() <= 32 ? 32 : 64;
17324   AArch64_IMM::expandMOVImm(C1C2.getZExtValue(), BitSize, Insn);
17325   if (Insn.size() > 1)
17326     return false;
17327 
17328   // Default to true and let the DAGCombiner decide.
17329   return true;
17330 }
17331 
17332 // Integer comparisons are implemented with ADDS/SUBS, so the range of valid
17333 // immediates is the same as for an add or a sub.
17334 bool AArch64TargetLowering::isLegalICmpImmediate(int64_t Immed) const {
17335   return isLegalAddImmediate(Immed);
17336 }
17337 
17338 /// isLegalAddressingMode - Return true if the addressing mode represented
17339 /// by AM is legal for this target, for a load/store of the specified type.
17340 bool AArch64TargetLowering::isLegalAddressingMode(const DataLayout &DL,
17341                                                   const AddrMode &AMode, Type *Ty,
17342                                                   unsigned AS, Instruction *I) const {
17343   // AArch64 has five basic addressing modes:
17344   //  reg
17345   //  reg + 9-bit signed offset
17346   //  reg + SIZE_IN_BYTES * 12-bit unsigned offset
17347   //  reg1 + reg2
17348   //  reg + SIZE_IN_BYTES * reg
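  // For example, an 8-byte access can use [x0, #32760] (4095 * 8) or
  // [x0, x1, lsl #3], but an immediate offset of 32768 fits none of these
  // forms and requires the offset to be materialized separately.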
17349 
17350   // No global is ever allowed as a base.
17351   if (AMode.BaseGV)
17352     return false;
17353 
17354   // No reg+reg+imm addressing.
17355   if (AMode.HasBaseReg && AMode.BaseOffs && AMode.Scale)
17356     return false;
17357 
17358   // Canonicalise `1*ScaledReg + imm` into `BaseReg + imm` and
17359   // `2*ScaledReg` into `BaseReg + ScaledReg`
17360   AddrMode AM = AMode;
17361   if (AM.Scale && !AM.HasBaseReg) {
17362     if (AM.Scale == 1) {
17363       AM.HasBaseReg = true;
17364       AM.Scale = 0;
17365     } else if (AM.Scale == 2) {
17366       AM.HasBaseReg = true;
17367       AM.Scale = 1;
17368     } else {
17369       return false;
17370     }
17371   }
17372 
17373   // A base register is required in all addressing modes.
17374   if (!AM.HasBaseReg)
17375     return false;
17376 
17377   if (Ty->isScalableTy()) {
17378     if (isa<ScalableVectorType>(Ty)) {
17379       // See if we have a foldable vscale-based offset, for vector types which
17380       // are either legal or smaller than the minimum; more work will be
17381       // required if we need to consider addressing for types which need
17382       // legalization by splitting.
17383       uint64_t VecNumBytes = DL.getTypeSizeInBits(Ty).getKnownMinValue() / 8;
17384       if (AM.HasBaseReg && !AM.BaseOffs && AM.ScalableOffset && !AM.Scale &&
17385           (AM.ScalableOffset % VecNumBytes == 0) && VecNumBytes <= 16 &&
17386           isPowerOf2_64(VecNumBytes))
17387         return isInt<4>(AM.ScalableOffset / (int64_t)VecNumBytes);
17388 
17389       uint64_t VecElemNumBytes =
17390           DL.getTypeSizeInBits(cast<VectorType>(Ty)->getElementType()) / 8;
17391       return AM.HasBaseReg && !AM.BaseOffs && !AM.ScalableOffset &&
17392              (AM.Scale == 0 || (uint64_t)AM.Scale == VecElemNumBytes);
17393     }
17394 
17395     return AM.HasBaseReg && !AM.BaseOffs && !AM.ScalableOffset && !AM.Scale;
17396   }
17397 
17398   // No scalable offsets allowed for non-scalable types.
17399   if (AM.ScalableOffset)
17400     return false;
17401 
17402   // check reg + imm case:
17403   // i.e., reg + 0, reg + imm9, reg + SIZE_IN_BYTES * uimm12
17404   uint64_t NumBytes = 0;
17405   if (Ty->isSized()) {
17406     uint64_t NumBits = DL.getTypeSizeInBits(Ty);
17407     NumBytes = NumBits / 8;
17408     if (!isPowerOf2_64(NumBits))
17409       NumBytes = 0;
17410   }
17411 
17412   return Subtarget->getInstrInfo()->isLegalAddressingMode(NumBytes, AM.BaseOffs,
17413                                                           AM.Scale);
17414 }
17415 
17416 // Check whether the two offsets belong to the same imm24 range and share the
17417 // same high 12 bits; if so, the high part can be covered by the add immediate.
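// For example, MinOffset = 0x1234 and MaxOffset = 0x1345 both have high part
// 0x1000, which is a legal add immediate, so 0x1000 is returned and the
// accesses are rebased onto it.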
17418 int64_t
17419 AArch64TargetLowering::getPreferredLargeGEPBaseOffset(int64_t MinOffset,
17420                                                       int64_t MaxOffset) const {
17421   int64_t HighPart = MinOffset & ~0xfffULL;
17422   if (MinOffset >> 12 == MaxOffset >> 12 && isLegalAddImmediate(HighPart)) {
17423     // Rebase the value to an integer multiple of imm12.
17424     return HighPart;
17425   }
17426 
17427   return 0;
17428 }
17429 
17430 bool AArch64TargetLowering::shouldConsiderGEPOffsetSplit() const {
17431   // Consider splitting large offsets of structs or arrays.
17432   return true;
17433 }
17434 
17435 bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd(
17436     const MachineFunction &MF, EVT VT) const {
17437   VT = VT.getScalarType();
17438 
17439   if (!VT.isSimple())
17440     return false;
17441 
17442   switch (VT.getSimpleVT().SimpleTy) {
17443   case MVT::f16:
17444     return Subtarget->hasFullFP16();
17445   case MVT::f32:
17446   case MVT::f64:
17447     return true;
17448   default:
17449     break;
17450   }
17451 
17452   return false;
17453 }
17454 
17455 bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd(const Function &F,
17456                                                        Type *Ty) const {
17457   switch (Ty->getScalarType()->getTypeID()) {
17458   case Type::FloatTyID:
17459   case Type::DoubleTyID:
17460     return true;
17461   default:
17462     return false;
17463   }
17464 }
17465 
17466 bool AArch64TargetLowering::generateFMAsInMachineCombiner(
17467     EVT VT, CodeGenOptLevel OptLevel) const {
17468   return (OptLevel >= CodeGenOptLevel::Aggressive) && !VT.isScalableVector() &&
17469          !useSVEForFixedLengthVectorVT(VT);
17470 }
17471 
17472 const MCPhysReg *
17473 AArch64TargetLowering::getScratchRegisters(CallingConv::ID) const {
17474   // LR is a callee-save register, but we must treat it as clobbered by any call
17475   // site. Hence we include LR in the scratch registers, which are in turn added
17476   // as implicit-defs for stackmaps and patchpoints.
17477   static const MCPhysReg ScratchRegs[] = {
17478     AArch64::X16, AArch64::X17, AArch64::LR, 0
17479   };
17480   return ScratchRegs;
17481 }
17482 
17483 ArrayRef<MCPhysReg> AArch64TargetLowering::getRoundingControlRegisters() const {
17484   static const MCPhysReg RCRegs[] = {AArch64::FPCR};
17485   return RCRegs;
17486 }
17487 
17488 bool
17489 AArch64TargetLowering::isDesirableToCommuteWithShift(const SDNode *N,
17490                                                      CombineLevel Level) const {
17491   assert((N->getOpcode() == ISD::SHL || N->getOpcode() == ISD::SRA ||
17492           N->getOpcode() == ISD::SRL) &&
17493          "Expected shift op");
17494 
17495   SDValue ShiftLHS = N->getOperand(0);
17496   EVT VT = N->getValueType(0);
17497 
17498   // If ShiftLHS is an unsigned bit extraction ((x >> C) & mask), do not
17499   // combine it with shift 'N' so that it can still be lowered to UBFX, except
17500   // for the case ((x >> C) & mask) << C.
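  // For example, ((x >> 4) & 0xff) << 2 is left alone so the srl/and still
  // lowers to UBFX, while ((x >> 4) & 0xff) << 4 may be commuted because the
  // shift amounts match.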
17501   if (ShiftLHS.getOpcode() == ISD::AND && (VT == MVT::i32 || VT == MVT::i64) &&
17502       isa<ConstantSDNode>(ShiftLHS.getOperand(1))) {
17503     uint64_t TruncMask = ShiftLHS.getConstantOperandVal(1);
17504     if (isMask_64(TruncMask)) {
17505       SDValue AndLHS = ShiftLHS.getOperand(0);
17506       if (AndLHS.getOpcode() == ISD::SRL) {
17507         if (auto *SRLC = dyn_cast<ConstantSDNode>(AndLHS.getOperand(1))) {
17508           if (N->getOpcode() == ISD::SHL)
17509             if (auto *SHLC = dyn_cast<ConstantSDNode>(N->getOperand(1)))
17510               return SRLC->getZExtValue() == SHLC->getZExtValue();
17511           return false;
17512         }
17513       }
17514     }
17515   }
17516   return true;
17517 }
17518 
17519 bool AArch64TargetLowering::isDesirableToCommuteXorWithShift(
17520     const SDNode *N) const {
17521   assert(N->getOpcode() == ISD::XOR &&
17522          (N->getOperand(0).getOpcode() == ISD::SHL ||
17523           N->getOperand(0).getOpcode() == ISD::SRL) &&
17524          "Expected XOR(SHIFT) pattern");
17525 
17526   // Only commute if the entire NOT mask is a hidden shifted mask.
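  // For example, for i32 xor(shl(x, 8), 0xffffff00), the mask 0xffffff00 is a
  // shifted mask of length 24 starting at bit 8, which matches the shift
  // amount, so the commute is allowed.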
17527   auto *XorC = dyn_cast<ConstantSDNode>(N->getOperand(1));
17528   auto *ShiftC = dyn_cast<ConstantSDNode>(N->getOperand(0).getOperand(1));
17529   if (XorC && ShiftC) {
17530     unsigned MaskIdx, MaskLen;
17531     if (XorC->getAPIntValue().isShiftedMask(MaskIdx, MaskLen)) {
17532       unsigned ShiftAmt = ShiftC->getZExtValue();
17533       unsigned BitWidth = N->getValueType(0).getScalarSizeInBits();
17534       if (N->getOperand(0).getOpcode() == ISD::SHL)
17535         return MaskIdx == ShiftAmt && MaskLen == (BitWidth - ShiftAmt);
17536       return MaskIdx == 0 && MaskLen == (BitWidth - ShiftAmt);
17537     }
17538   }
17539 
17540   return false;
17541 }
17542 
17543 bool AArch64TargetLowering::shouldFoldConstantShiftPairToMask(
17544     const SDNode *N, CombineLevel Level) const {
17545   assert(((N->getOpcode() == ISD::SHL &&
17546            N->getOperand(0).getOpcode() == ISD::SRL) ||
17547           (N->getOpcode() == ISD::SRL &&
17548            N->getOperand(0).getOpcode() == ISD::SHL)) &&
17549          "Expected shift-shift mask");
17550   // Don't allow multiuse shift folding with the same shift amount.
17551   if (!N->getOperand(0)->hasOneUse())
17552     return false;
17553 
17554   // Only fold srl(shl(x,c1),c2) iff C1 >= C2 to prevent loss of UBFX patterns.
17555   EVT VT = N->getValueType(0);
17556   if (N->getOpcode() == ISD::SRL && (VT == MVT::i32 || VT == MVT::i64)) {
17557     auto *C1 = dyn_cast<ConstantSDNode>(N->getOperand(0).getOperand(1));
17558     auto *C2 = dyn_cast<ConstantSDNode>(N->getOperand(1));
17559     return (!C1 || !C2 || C1->getZExtValue() >= C2->getZExtValue());
17560   }
17561 
17562   return true;
17563 }
17564 
17565 bool AArch64TargetLowering::shouldFoldSelectWithIdentityConstant(
17566     unsigned BinOpcode, EVT VT) const {
17567   return VT.isScalableVector() && isTypeLegal(VT);
17568 }
17569 
17570 bool AArch64TargetLowering::shouldConvertConstantLoadToIntImm(const APInt &Imm,
17571                                                               Type *Ty) const {
17572   assert(Ty->isIntegerTy());
17573 
17574   unsigned BitSize = Ty->getPrimitiveSizeInBits();
17575   if (BitSize == 0)
17576     return false;
17577 
17578   int64_t Val = Imm.getSExtValue();
17579   if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, BitSize))
17580     return true;
17581 
17582   if ((int64_t)Val < 0)
17583     Val = ~Val;
17584   if (BitSize == 32)
17585     Val &= (1LL << 32) - 1;
17586 
17587   unsigned Shift = llvm::Log2_64((uint64_t)Val) / 16;
17588   // MOVZ is free so return true for one or fewer MOVK.
17589   return Shift < 3;
17590 }
17591 
17592 bool AArch64TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
17593                                                     unsigned Index) const {
17594   if (!isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, ResVT))
17595     return false;
17596 
17597   return (Index == 0 || Index == ResVT.getVectorMinNumElements());
17598 }
17599 
17600 /// Turn vector tests of the signbit in the form of:
17601 ///   xor (sra X, elt_size(X)-1), -1
17602 /// into:
17603 ///   cmge X, X, #0
17604 static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG,
17605                                          const AArch64Subtarget *Subtarget) {
17606   EVT VT = N->getValueType(0);
17607   if (!Subtarget->hasNEON() || !VT.isVector())
17608     return SDValue();
17609 
17610   // There must be a shift right algebraic before the xor, and the xor must be a
17611   // 'not' operation.
17612   SDValue Shift = N->getOperand(0);
17613   SDValue Ones = N->getOperand(1);
17614   if (Shift.getOpcode() != AArch64ISD::VASHR || !Shift.hasOneUse() ||
17615       !ISD::isBuildVectorAllOnes(Ones.getNode()))
17616     return SDValue();
17617 
17618   // The shift should be smearing the sign bit across each vector element.
17619   auto *ShiftAmt = dyn_cast<ConstantSDNode>(Shift.getOperand(1));
17620   EVT ShiftEltTy = Shift.getValueType().getVectorElementType();
17621   if (!ShiftAmt || ShiftAmt->getZExtValue() != ShiftEltTy.getSizeInBits() - 1)
17622     return SDValue();
17623 
17624   return DAG.getNode(AArch64ISD::CMGEz, SDLoc(N), VT, Shift.getOperand(0));
17625 }
17626 
17627 // Given a vecreduce_add node, detect the below pattern and convert it to the
17628 // node sequence with UABDL, [S|U]ABD and UADDLP.
17629 //
17630 // i32 vecreduce_add(
17631 //  v16i32 abs(
17632 //    v16i32 sub(
17633 //     v16i32 [sign|zero]_extend(v16i8 a), v16i32 [sign|zero]_extend(v16i8 b))))
17634 // =================>
17635 // i32 vecreduce_add(
17636 //   v4i32 UADDLP(
17637 //     v8i16 add(
17638 //       v8i16 zext(
17639 //         v8i8 [S|U]ABD low8:v16i8 a, low8:v16i8 b
17640 //       v8i16 zext(
17641 //         v8i8 [S|U]ABD high8:v16i8 a, high8:v16i8 b
17642 static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
17643                                                     SelectionDAG &DAG) {
17644   // Assumed i32 vecreduce_add
17645   if (N->getValueType(0) != MVT::i32)
17646     return SDValue();
17647 
17648   SDValue VecReduceOp0 = N->getOperand(0);
17649   unsigned Opcode = VecReduceOp0.getOpcode();
17650   // Assumed v16i32 abs
17651   if (Opcode != ISD::ABS || VecReduceOp0->getValueType(0) != MVT::v16i32)
17652     return SDValue();
17653 
17654   SDValue ABS = VecReduceOp0;
17655   // Assumed v16i32 sub
17656   if (ABS->getOperand(0)->getOpcode() != ISD::SUB ||
17657       ABS->getOperand(0)->getValueType(0) != MVT::v16i32)
17658     return SDValue();
17659 
17660   SDValue SUB = ABS->getOperand(0);
17661   unsigned Opcode0 = SUB->getOperand(0).getOpcode();
17662   unsigned Opcode1 = SUB->getOperand(1).getOpcode();
17663   // Assumed v16i32 type
17664   if (SUB->getOperand(0)->getValueType(0) != MVT::v16i32 ||
17665       SUB->getOperand(1)->getValueType(0) != MVT::v16i32)
17666     return SDValue();
17667 
17668   // Assumed zext or sext
17669   bool IsZExt = false;
17670   if (Opcode0 == ISD::ZERO_EXTEND && Opcode1 == ISD::ZERO_EXTEND) {
17671     IsZExt = true;
17672   } else if (Opcode0 == ISD::SIGN_EXTEND && Opcode1 == ISD::SIGN_EXTEND) {
17673     IsZExt = false;
17674   } else
17675     return SDValue();
17676 
17677   SDValue EXT0 = SUB->getOperand(0);
17678   SDValue EXT1 = SUB->getOperand(1);
17679   // Assumed zext's operand has v16i8 type
17680   if (EXT0->getOperand(0)->getValueType(0) != MVT::v16i8 ||
17681       EXT1->getOperand(0)->getValueType(0) != MVT::v16i8)
17682     return SDValue();
17683 
17684   // Pattern is detected. Let's convert it to a sequence of nodes.
17685   SDLoc DL(N);
17686 
17687   // First, create the node pattern of UABD/SABD.
17688   SDValue UABDHigh8Op0 =
17689       DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT0->getOperand(0),
17690                   DAG.getConstant(8, DL, MVT::i64));
17691   SDValue UABDHigh8Op1 =
17692       DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0),
17693                   DAG.getConstant(8, DL, MVT::i64));
17694   SDValue UABDHigh8 = DAG.getNode(IsZExt ? ISD::ABDU : ISD::ABDS, DL, MVT::v8i8,
17695                                   UABDHigh8Op0, UABDHigh8Op1);
17696   SDValue UABDL = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDHigh8);
17697 
17698   // Second, create the node pattern of UABAL.
17699   SDValue UABDLo8Op0 =
17700       DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT0->getOperand(0),
17701                   DAG.getConstant(0, DL, MVT::i64));
17702   SDValue UABDLo8Op1 =
17703       DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0),
17704                   DAG.getConstant(0, DL, MVT::i64));
17705   SDValue UABDLo8 = DAG.getNode(IsZExt ? ISD::ABDU : ISD::ABDS, DL, MVT::v8i8,
17706                                 UABDLo8Op0, UABDLo8Op1);
17707   SDValue ZExtUABD = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDLo8);
17708   SDValue UABAL = DAG.getNode(ISD::ADD, DL, MVT::v8i16, UABDL, ZExtUABD);
17709 
17710   // Third, create the node of UADDLP.
17711   SDValue UADDLP = DAG.getNode(AArch64ISD::UADDLP, DL, MVT::v4i32, UABAL);
17712 
17713   // Fourth, create the node of VECREDUCE_ADD.
17714   return DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, UADDLP);
17715 }
17716 
17717 // Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce
17718 //   vecreduce.add(ext(A)) to vecreduce.add(DOT(zero, A, one))
17719 //   vecreduce.add(mul(ext(A), ext(B))) to vecreduce.add(DOT(zero, A, B))
17720 // If we have vectors larger than v16i8 we extract v16i8 subvectors,
17721 // follow the same steps above to get DOT instructions, concatenate them,
17722 // and generate vecreduce.add(concat_vector(DOT, DOT2, ..)).
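// For example, i32 vecreduce.add(zext <16 x i8> %a to <16 x i32>) becomes
// vecreduce.add(UDOT(zeroes, %a, ones)), where UDOT/SDOT are the AArch64ISD
// dot-product nodes built below.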
17723 static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
17724                                           const AArch64Subtarget *ST) {
17725   if (!ST->isNeonAvailable())
17726     return SDValue();
17727 
17728   if (!ST->hasDotProd())
17729     return performVecReduceAddCombineWithUADDLP(N, DAG);
17730 
17731   SDValue Op0 = N->getOperand(0);
17732   if (N->getValueType(0) != MVT::i32 || Op0.getValueType().isScalableVT() ||
17733       Op0.getValueType().getVectorElementType() != MVT::i32)
17734     return SDValue();
17735 
17736   unsigned ExtOpcode = Op0.getOpcode();
17737   SDValue A = Op0;
17738   SDValue B;
17739   if (ExtOpcode == ISD::MUL) {
17740     A = Op0.getOperand(0);
17741     B = Op0.getOperand(1);
17742     if (A.getOpcode() != B.getOpcode() ||
17743         A.getOperand(0).getValueType() != B.getOperand(0).getValueType())
17744       return SDValue();
17745     ExtOpcode = A.getOpcode();
17746   }
17747   if (ExtOpcode != ISD::ZERO_EXTEND && ExtOpcode != ISD::SIGN_EXTEND)
17748     return SDValue();
17749 
17750   EVT Op0VT = A.getOperand(0).getValueType();
17751   bool IsValidElementCount = Op0VT.getVectorNumElements() % 8 == 0;
17752   bool IsValidSize = Op0VT.getScalarSizeInBits() == 8;
17753   if (!IsValidElementCount || !IsValidSize)
17754     return SDValue();
17755 
17756   SDLoc DL(Op0);
17757   // For non-mla reductions B can be set to 1. For MLA we take the operand of
17758   // the extend B.
17759   if (!B)
17760     B = DAG.getConstant(1, DL, Op0VT);
17761   else
17762     B = B.getOperand(0);
17763 
17764   unsigned IsMultipleOf16 = Op0VT.getVectorNumElements() % 16 == 0;
17765   unsigned NumOfVecReduce;
17766   EVT TargetType;
17767   if (IsMultipleOf16) {
17768     NumOfVecReduce = Op0VT.getVectorNumElements() / 16;
17769     TargetType = MVT::v4i32;
17770   } else {
17771     NumOfVecReduce = Op0VT.getVectorNumElements() / 8;
17772     TargetType = MVT::v2i32;
17773   }
17774   auto DotOpcode =
17775       (ExtOpcode == ISD::ZERO_EXTEND) ? AArch64ISD::UDOT : AArch64ISD::SDOT;
17776   // Handle the case where we need to generate only one Dot operation.
17777   if (NumOfVecReduce == 1) {
17778     SDValue Zeros = DAG.getConstant(0, DL, TargetType);
17779     SDValue Dot = DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros,
17780                               A.getOperand(0), B);
17781     return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
17782   }
17783   // Generate Dot instructions that are multiple of 16.
17784   unsigned VecReduce16Num = Op0VT.getVectorNumElements() / 16;
17785   SmallVector<SDValue, 4> SDotVec16;
17786   unsigned I = 0;
17787   for (; I < VecReduce16Num; I += 1) {
17788     SDValue Zeros = DAG.getConstant(0, DL, MVT::v4i32);
17789     SDValue Op0 =
17790         DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v16i8, A.getOperand(0),
17791                     DAG.getConstant(I * 16, DL, MVT::i64));
17792     SDValue Op1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v16i8, B,
17793                               DAG.getConstant(I * 16, DL, MVT::i64));
17794     SDValue Dot =
17795         DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros, Op0, Op1);
17796     SDotVec16.push_back(Dot);
17797   }
17798   // Concatenate dot operations.
17799   EVT SDot16EVT =
17800       EVT::getVectorVT(*DAG.getContext(), MVT::i32, 4 * VecReduce16Num);
17801   SDValue ConcatSDot16 =
17802       DAG.getNode(ISD::CONCAT_VECTORS, DL, SDot16EVT, SDotVec16);
17803   SDValue VecReduceAdd16 =
17804       DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), ConcatSDot16);
17805   unsigned VecReduce8Num = (Op0VT.getVectorNumElements() % 16) / 8;
17806   if (VecReduce8Num == 0)
17807     return VecReduceAdd16;
17808 
17809   // Generate the remainder Dot operation that is multiple of 8.
17810   SmallVector<SDValue, 4> SDotVec8;
17811   SDValue Zeros = DAG.getConstant(0, DL, MVT::v2i32);
17812   SDValue Vec8Op0 =
17813       DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, A.getOperand(0),
17814                   DAG.getConstant(I * 16, DL, MVT::i64));
17815   SDValue Vec8Op1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, B,
17816                                 DAG.getConstant(I * 16, DL, MVT::i64));
17817   SDValue Dot =
17818       DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros, Vec8Op0, Vec8Op1);
17819   SDValue VecReduceAdd8 =
17820       DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
17821   return DAG.getNode(ISD::ADD, DL, N->getValueType(0), VecReduceAdd16,
17822                      VecReduceAdd8);
17823 }
17824 
17825 // Given an (integer) vecreduce, we know the order of the inputs does not
17826 // matter. We can convert UADDV(add(zext(extract_lo(x)), zext(extract_hi(x))))
17827 // into UADDV(UADDLP(x)). This can also happen through an extra add, where we
17828 // transform UADDV(add(y, add(zext(extract_lo(x)), zext(extract_hi(x))))).
17829 static SDValue performUADDVAddCombine(SDValue A, SelectionDAG &DAG) {
17830   auto DetectAddExtract = [&](SDValue A) {
17831     // Look for add(zext(extract_lo(x)), zext(extract_hi(x))), returning
17832     // UADDLP(x) if found.
17833     assert(A.getOpcode() == ISD::ADD);
17834     EVT VT = A.getValueType();
17835     SDValue Op0 = A.getOperand(0);
17836     SDValue Op1 = A.getOperand(1);
17837     if (Op0.getOpcode() != Op1.getOpcode() ||
17838         (Op0.getOpcode() != ISD::ZERO_EXTEND &&
17839          Op0.getOpcode() != ISD::SIGN_EXTEND))
17840       return SDValue();
17841     SDValue Ext0 = Op0.getOperand(0);
17842     SDValue Ext1 = Op1.getOperand(0);
17843     if (Ext0.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
17844         Ext1.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
17845         Ext0.getOperand(0) != Ext1.getOperand(0))
17846       return SDValue();
17847     // Check that the source type has twice as many elements as the add type,
17848     // and that the extracts are from the upper/lower halves of the same source.
17849     if (Ext0.getOperand(0).getValueType().getVectorNumElements() !=
17850         VT.getVectorNumElements() * 2)
17851       return SDValue();
17852     if ((Ext0.getConstantOperandVal(1) != 0 ||
17853          Ext1.getConstantOperandVal(1) != VT.getVectorNumElements()) &&
17854         (Ext1.getConstantOperandVal(1) != 0 ||
17855          Ext0.getConstantOperandVal(1) != VT.getVectorNumElements()))
17856       return SDValue();
17857     unsigned Opcode = Op0.getOpcode() == ISD::ZERO_EXTEND ? AArch64ISD::UADDLP
17858                                                           : AArch64ISD::SADDLP;
17859     return DAG.getNode(Opcode, SDLoc(A), VT, Ext0.getOperand(0));
17860   };
17861 
17862   if (SDValue R = DetectAddExtract(A))
17863     return R;
17864 
17865   if (A.getOperand(0).getOpcode() == ISD::ADD && A.getOperand(0).hasOneUse())
17866     if (SDValue R = performUADDVAddCombine(A.getOperand(0), DAG))
17867       return DAG.getNode(ISD::ADD, SDLoc(A), A.getValueType(), R,
17868                          A.getOperand(1));
17869   if (A.getOperand(1).getOpcode() == ISD::ADD && A.getOperand(1).hasOneUse())
17870     if (SDValue R = performUADDVAddCombine(A.getOperand(1), DAG))
17871       return DAG.getNode(ISD::ADD, SDLoc(A), A.getValueType(), R,
17872                          A.getOperand(0));
17873   return SDValue();
17874 }
17875 
17876 // We can convert a UADDV(add(zext(64-bit source), zext(64-bit source))) into
17877 // UADDLV(concat), where the concat represents the 64-bit zext sources.
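// For example, UADDV(add(zext <8 x i8> %a to <8 x i16>, zext <8 x i8> %b to
// <8 x i16>)) can instead use UADDLV on concat(%a, %b), i.e. a single
// horizontal add over v16i8.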
17878 static SDValue performUADDVZextCombine(SDValue A, SelectionDAG &DAG) {
17879   // Look for add(zext(64-bit source), zext(64-bit source)), returning
17880   // UADDLV(concat(zext, zext)) if found.
17881   assert(A.getOpcode() == ISD::ADD);
17882   EVT VT = A.getValueType();
17883   if (VT != MVT::v8i16 && VT != MVT::v4i32 && VT != MVT::v2i64)
17884     return SDValue();
17885   SDValue Op0 = A.getOperand(0);
17886   SDValue Op1 = A.getOperand(1);
17887   if (Op0.getOpcode() != ISD::ZERO_EXTEND || Op0.getOpcode() != Op1.getOpcode())
17888     return SDValue();
17889   SDValue Ext0 = Op0.getOperand(0);
17890   SDValue Ext1 = Op1.getOperand(0);
17891   EVT ExtVT0 = Ext0.getValueType();
17892   EVT ExtVT1 = Ext1.getValueType();
17893   // Check zext VTs are the same and 64-bit length.
17894   if (ExtVT0 != ExtVT1 ||
17895       VT.getScalarSizeInBits() != (2 * ExtVT0.getScalarSizeInBits()))
17896     return SDValue();
17897   // Get VT for concat of zext sources.
17898   EVT PairVT = ExtVT0.getDoubleNumVectorElementsVT(*DAG.getContext());
17899   SDValue Concat =
17900       DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(A), PairVT, Ext0, Ext1);
17901 
17902   switch (VT.getSimpleVT().SimpleTy) {
17903   case MVT::v2i64:
17904   case MVT::v4i32:
17905     return DAG.getNode(AArch64ISD::UADDLV, SDLoc(A), VT, Concat);
17906   case MVT::v8i16: {
17907     SDValue Uaddlv =
17908         DAG.getNode(AArch64ISD::UADDLV, SDLoc(A), MVT::v4i32, Concat);
17909     return DAG.getNode(AArch64ISD::NVCAST, SDLoc(A), MVT::v8i16, Uaddlv);
17910   }
17911   default:
17912     llvm_unreachable("Unhandled vector type");
17913   }
17914 }
17915 
17916 static SDValue performUADDVCombine(SDNode *N, SelectionDAG &DAG) {
17917   SDValue A = N->getOperand(0);
17918   if (A.getOpcode() == ISD::ADD) {
17919     if (SDValue R = performUADDVAddCombine(A, DAG))
17920       return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), R);
17921     else if (SDValue R = performUADDVZextCombine(A, DAG))
17922       return R;
17923   }
17924   return SDValue();
17925 }
17926 
17927 static SDValue performXorCombine(SDNode *N, SelectionDAG &DAG,
17928                                  TargetLowering::DAGCombinerInfo &DCI,
17929                                  const AArch64Subtarget *Subtarget) {
17930   if (DCI.isBeforeLegalizeOps())
17931     return SDValue();
17932 
17933   return foldVectorXorShiftIntoCmp(N, DAG, Subtarget);
17934 }
17935 
17936 SDValue
17937 AArch64TargetLowering::BuildSDIVPow2(SDNode *N, const APInt &Divisor,
17938                                      SelectionDAG &DAG,
17939                                      SmallVectorImpl<SDNode *> &Created) const {
17940   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
17941   if (isIntDivCheap(N->getValueType(0), Attr))
17942     return SDValue(N, 0); // Lower SDIV as SDIV
17943 
17944   EVT VT = N->getValueType(0);
17945 
17946   // For scalable and fixed types, mark them as cheap so we can handle them
17947   // much later. This allows us to handle larger-than-legal types.
17948   if (VT.isScalableVector() ||
17949       (VT.isFixedLengthVector() && Subtarget->useSVEForFixedLengthVectors()))
17950     return SDValue(N, 0);
17951 
17952   // fold (sdiv X, pow2)
17953   if ((VT != MVT::i32 && VT != MVT::i64) ||
17954       !(Divisor.isPowerOf2() || Divisor.isNegatedPowerOf2()))
17955     return SDValue();
17956 
17957   // If the divisor is 2 or -2, the default expansion is better. It will add
17958   // (N->getValueType(0) >> (BitWidth - 1)) to it before shifting right.
17959   if (Divisor == 2 ||
17960       Divisor == APInt(Divisor.getBitWidth(), -2, /*isSigned*/ true))
17961     return SDValue();
17962 
17963   return TargetLowering::buildSDIVPow2WithCMov(N, Divisor, DAG, Created);
17964 }
17965 
17966 SDValue
17967 AArch64TargetLowering::BuildSREMPow2(SDNode *N, const APInt &Divisor,
17968                                      SelectionDAG &DAG,
17969                                      SmallVectorImpl<SDNode *> &Created) const {
17970   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
17971   if (isIntDivCheap(N->getValueType(0), Attr))
17972     return SDValue(N, 0); // Lower SREM as SREM
17973 
17974   EVT VT = N->getValueType(0);
17975 
17976   // For scalable and fixed types, mark them as cheap so we can handle them
17977   // much later. This allows us to handle larger-than-legal types.
17978   if (VT.isScalableVector() || Subtarget->useSVEForFixedLengthVectors())
17979     return SDValue(N, 0);
17980 
17981   // fold (srem X, pow2)
17982   if ((VT != MVT::i32 && VT != MVT::i64) ||
17983       !(Divisor.isPowerOf2() || Divisor.isNegatedPowerOf2()))
17984     return SDValue();
17985 
17986   unsigned Lg2 = Divisor.countr_zero();
17987   if (Lg2 == 0)
17988     return SDValue();
17989 
17990   SDLoc DL(N);
17991   SDValue N0 = N->getOperand(0);
17992   SDValue Pow2MinusOne = DAG.getConstant((1ULL << Lg2) - 1, DL, VT);
17993   SDValue Zero = DAG.getConstant(0, DL, VT);
17994   SDValue CCVal, CSNeg;
17995   if (Lg2 == 1) {
17996     SDValue Cmp = getAArch64Cmp(N0, Zero, ISD::SETGE, CCVal, DAG, DL);
17997     SDValue And = DAG.getNode(ISD::AND, DL, VT, N0, Pow2MinusOne);
17998     CSNeg = DAG.getNode(AArch64ISD::CSNEG, DL, VT, And, And, CCVal, Cmp);
17999 
18000     Created.push_back(Cmp.getNode());
18001     Created.push_back(And.getNode());
18002   } else {
18003     SDValue CCVal = DAG.getConstant(AArch64CC::MI, DL, MVT_CC);
18004     SDVTList VTs = DAG.getVTList(VT, MVT::i32);
18005 
18006     SDValue Negs = DAG.getNode(AArch64ISD::SUBS, DL, VTs, Zero, N0);
18007     SDValue AndPos = DAG.getNode(ISD::AND, DL, VT, N0, Pow2MinusOne);
18008     SDValue AndNeg = DAG.getNode(ISD::AND, DL, VT, Negs, Pow2MinusOne);
18009     CSNeg = DAG.getNode(AArch64ISD::CSNEG, DL, VT, AndPos, AndNeg, CCVal,
18010                         Negs.getValue(1));
18011 
18012     Created.push_back(Negs.getNode());
18013     Created.push_back(AndPos.getNode());
18014     Created.push_back(AndNeg.getNode());
18015   }
18016 
18017   return CSNeg;
18018 }
18019 
18020 static std::optional<unsigned> IsSVECntIntrinsic(SDValue S) {
18021   switch(getIntrinsicID(S.getNode())) {
18022   default:
18023     break;
18024   case Intrinsic::aarch64_sve_cntb:
18025     return 8;
18026   case Intrinsic::aarch64_sve_cnth:
18027     return 16;
18028   case Intrinsic::aarch64_sve_cntw:
18029     return 32;
18030   case Intrinsic::aarch64_sve_cntd:
18031     return 64;
18032   }
18033   return {};
18034 }
18035 
18036 /// Calculates what the pre-extend type is, based on the extension
18037 /// operation node provided by \p Extend.
18038 ///
18039 /// In the case that \p Extend is a SIGN_EXTEND or a ZERO_EXTEND, the
18040 /// pre-extend type is pulled directly from the operand, while other extend
18041 /// operations need a bit more inspection to get this information.
18042 ///
18043 /// \param Extend The SDNode from the DAG that represents the extend operation
18044 ///
18045 /// \returns The type representing the \p Extend source type, or \p MVT::Other
18046 /// if no valid type can be determined
18047 static EVT calculatePreExtendType(SDValue Extend) {
18048   switch (Extend.getOpcode()) {
18049   case ISD::SIGN_EXTEND:
18050   case ISD::ZERO_EXTEND:
18051     return Extend.getOperand(0).getValueType();
18052   case ISD::AssertSext:
18053   case ISD::AssertZext:
18054   case ISD::SIGN_EXTEND_INREG: {
18055     VTSDNode *TypeNode = dyn_cast<VTSDNode>(Extend.getOperand(1));
18056     if (!TypeNode)
18057       return MVT::Other;
18058     return TypeNode->getVT();
18059   }
18060   case ISD::AND: {
18061     ConstantSDNode *Constant =
18062         dyn_cast<ConstantSDNode>(Extend.getOperand(1).getNode());
18063     if (!Constant)
18064       return MVT::Other;
18065 
18066     uint32_t Mask = Constant->getZExtValue();
18067 
18068     if (Mask == UCHAR_MAX)
18069       return MVT::i8;
18070     else if (Mask == USHRT_MAX)
18071       return MVT::i16;
18072     else if (Mask == UINT_MAX)
18073       return MVT::i32;
18074 
18075     return MVT::Other;
18076   }
18077   default:
18078     return MVT::Other;
18079   }
18080 }
18081 
18082 /// Combines a buildvector(sext/zext) or shuffle(sext/zext, undef) node pattern
18083 /// into sext/zext(buildvector) or sext/zext(shuffle) making use of the vector
18084 /// SExt/ZExt rather than the scalar SExt/ZExt
18085 static SDValue performBuildShuffleExtendCombine(SDValue BV, SelectionDAG &DAG) {
18086   EVT VT = BV.getValueType();
18087   if (BV.getOpcode() != ISD::BUILD_VECTOR &&
18088       BV.getOpcode() != ISD::VECTOR_SHUFFLE)
18089     return SDValue();
18090 
18091   // Use the first item in the buildvector/shuffle to get the size of the
18092   // extend, and make sure it looks valid.
18093   SDValue Extend = BV->getOperand(0);
18094   unsigned ExtendOpcode = Extend.getOpcode();
18095   bool IsSExt = ExtendOpcode == ISD::SIGN_EXTEND ||
18096                 ExtendOpcode == ISD::SIGN_EXTEND_INREG ||
18097                 ExtendOpcode == ISD::AssertSext;
18098   if (!IsSExt && ExtendOpcode != ISD::ZERO_EXTEND &&
18099       ExtendOpcode != ISD::AssertZext && ExtendOpcode != ISD::AND)
18100     return SDValue();
18101   // Shuffle inputs are vectors; limit to SIGN_EXTEND and ZERO_EXTEND to ensure
18102   // calculatePreExtendType will work without issue.
18103   if (BV.getOpcode() == ISD::VECTOR_SHUFFLE &&
18104       ExtendOpcode != ISD::SIGN_EXTEND && ExtendOpcode != ISD::ZERO_EXTEND)
18105     return SDValue();
18106 
18107   // Restrict valid pre-extend data type
18108   EVT PreExtendType = calculatePreExtendType(Extend);
18109   if (PreExtendType == MVT::Other ||
18110       PreExtendType.getScalarSizeInBits() != VT.getScalarSizeInBits() / 2)
18111     return SDValue();
18112 
18113   // Make sure all other operands are equally extended
18114   for (SDValue Op : drop_begin(BV->ops())) {
18115     if (Op.isUndef())
18116       continue;
18117     unsigned Opc = Op.getOpcode();
18118     bool OpcIsSExt = Opc == ISD::SIGN_EXTEND || Opc == ISD::SIGN_EXTEND_INREG ||
18119                      Opc == ISD::AssertSext;
18120     if (OpcIsSExt != IsSExt || calculatePreExtendType(Op) != PreExtendType)
18121       return SDValue();
18122   }
18123 
18124   SDValue NBV;
18125   SDLoc DL(BV);
18126   if (BV.getOpcode() == ISD::BUILD_VECTOR) {
18127     EVT PreExtendVT = VT.changeVectorElementType(PreExtendType);
18128     EVT PreExtendLegalType =
18129         PreExtendType.getScalarSizeInBits() < 32 ? MVT::i32 : PreExtendType;
18130     SmallVector<SDValue, 8> NewOps;
18131     for (SDValue Op : BV->ops())
18132       NewOps.push_back(Op.isUndef() ? DAG.getUNDEF(PreExtendLegalType)
18133                                     : DAG.getAnyExtOrTrunc(Op.getOperand(0), DL,
18134                                                            PreExtendLegalType));
18135     NBV = DAG.getNode(ISD::BUILD_VECTOR, DL, PreExtendVT, NewOps);
18136   } else { // BV.getOpcode() == ISD::VECTOR_SHUFFLE
18137     EVT PreExtendVT = VT.changeVectorElementType(PreExtendType.getScalarType());
18138     NBV = DAG.getVectorShuffle(PreExtendVT, DL, BV.getOperand(0).getOperand(0),
18139                                BV.getOperand(1).isUndef()
18140                                    ? DAG.getUNDEF(PreExtendVT)
18141                                    : BV.getOperand(1).getOperand(0),
18142                                cast<ShuffleVectorSDNode>(BV)->getMask());
18143   }
18144   return DAG.getNode(IsSExt ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL, VT, NBV);
18145 }
18146 
18147 /// Combines a mul(dup(sext/zext)) node pattern into mul(sext/zext(dup))
18148 /// making use of the vector SExt/ZExt rather than the scalar SExt/ZExt
18149 static SDValue performMulVectorExtendCombine(SDNode *Mul, SelectionDAG &DAG) {
18150   // If the value type isn't a vector, none of the operands are going to be dups
18151   EVT VT = Mul->getValueType(0);
18152   if (VT != MVT::v8i16 && VT != MVT::v4i32 && VT != MVT::v2i64)
18153     return SDValue();
18154 
18155   SDValue Op0 = performBuildShuffleExtendCombine(Mul->getOperand(0), DAG);
18156   SDValue Op1 = performBuildShuffleExtendCombine(Mul->getOperand(1), DAG);
18157 
18158   // Neither operand has been changed; don't make any further changes.
18159   if (!Op0 && !Op1)
18160     return SDValue();
18161 
18162   SDLoc DL(Mul);
18163   return DAG.getNode(Mul->getOpcode(), DL, VT, Op0 ? Op0 : Mul->getOperand(0),
18164                      Op1 ? Op1 : Mul->getOperand(1));
18165 }
18166 
18167 // Combine v4i32 Mul(And(Srl(X, 15), 0x10001), 0xffff) -> v8i16 CMLTz
18168 // Same for other types with equivalent constants.
18169 static SDValue performMulVectorCmpZeroCombine(SDNode *N, SelectionDAG &DAG) {
18170   EVT VT = N->getValueType(0);
18171   if (VT != MVT::v2i64 && VT != MVT::v1i64 && VT != MVT::v2i32 &&
18172       VT != MVT::v4i32 && VT != MVT::v4i16 && VT != MVT::v8i16)
18173     return SDValue();
18174   if (N->getOperand(0).getOpcode() != ISD::AND ||
18175       N->getOperand(0).getOperand(0).getOpcode() != ISD::SRL)
18176     return SDValue();
18177 
18178   SDValue And = N->getOperand(0);
18179   SDValue Srl = And.getOperand(0);
18180 
18181   APInt V1, V2, V3;
18182   if (!ISD::isConstantSplatVector(N->getOperand(1).getNode(), V1) ||
18183       !ISD::isConstantSplatVector(And.getOperand(1).getNode(), V2) ||
18184       !ISD::isConstantSplatVector(Srl.getOperand(1).getNode(), V3))
18185     return SDValue();
18186 
18187   unsigned HalfSize = VT.getScalarSizeInBits() / 2;
18188   if (!V1.isMask(HalfSize) || V2 != (1ULL | 1ULL << HalfSize) ||
18189       V3 != (HalfSize - 1))
18190     return SDValue();
18191 
18192   EVT HalfVT = EVT::getVectorVT(*DAG.getContext(),
18193                                 EVT::getIntegerVT(*DAG.getContext(), HalfSize),
18194                                 VT.getVectorElementCount() * 2);
18195 
18196   SDLoc DL(N);
18197   SDValue In = DAG.getNode(AArch64ISD::NVCAST, DL, HalfVT, Srl.getOperand(0));
18198   SDValue CM = DAG.getNode(AArch64ISD::CMLTz, DL, HalfVT, In);
18199   return DAG.getNode(AArch64ISD::NVCAST, DL, VT, CM);
18200 }
18201 
18202 // Transform vector add(zext i8 to i32, zext i8 to i32)
18203 //  into sext(add(zext(i8 to i16), zext(i8 to i16)) to i32)
18204 // This allows extra uses of saddl/uaddl at the lower vector widths, and less
18205 // extends.
18206 static SDValue performVectorExtCombine(SDNode *N, SelectionDAG &DAG) {
18207   EVT VT = N->getValueType(0);
18208   if (!VT.isFixedLengthVector() || VT.getSizeInBits() <= 128 ||
18209       (N->getOperand(0).getOpcode() != ISD::ZERO_EXTEND &&
18210        N->getOperand(0).getOpcode() != ISD::SIGN_EXTEND) ||
18211       (N->getOperand(1).getOpcode() != ISD::ZERO_EXTEND &&
18212        N->getOperand(1).getOpcode() != ISD::SIGN_EXTEND) ||
18213       N->getOperand(0).getOperand(0).getValueType() !=
18214           N->getOperand(1).getOperand(0).getValueType())
18215     return SDValue();
18216 
18217   if (N->getOpcode() == ISD::MUL &&
18218       N->getOperand(0).getOpcode() != N->getOperand(1).getOpcode())
18219     return SDValue();
18220 
18221   SDValue N0 = N->getOperand(0).getOperand(0);
18222   SDValue N1 = N->getOperand(1).getOperand(0);
18223   EVT InVT = N0.getValueType();
18224 
18225   EVT S1 = InVT.getScalarType();
18226   EVT S2 = VT.getScalarType();
18227   if ((S2 == MVT::i32 && S1 == MVT::i8) ||
18228       (S2 == MVT::i64 && (S1 == MVT::i8 || S1 == MVT::i16))) {
18229     SDLoc DL(N);
18230     EVT HalfVT = EVT::getVectorVT(*DAG.getContext(),
18231                                   S2.getHalfSizedIntegerVT(*DAG.getContext()),
18232                                   VT.getVectorElementCount());
18233     SDValue NewN0 = DAG.getNode(N->getOperand(0).getOpcode(), DL, HalfVT, N0);
18234     SDValue NewN1 = DAG.getNode(N->getOperand(1).getOpcode(), DL, HalfVT, N1);
18235     SDValue NewOp = DAG.getNode(N->getOpcode(), DL, HalfVT, NewN0, NewN1);
18236     return DAG.getNode(N->getOpcode() == ISD::MUL ? N->getOperand(0).getOpcode()
18237                                                   : (unsigned)ISD::SIGN_EXTEND,
18238                        DL, VT, NewOp);
18239   }
18240   return SDValue();
18241 }
18242 
18243 static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
18244                                  TargetLowering::DAGCombinerInfo &DCI,
18245                                  const AArch64Subtarget *Subtarget) {
18246 
18247   if (SDValue Ext = performMulVectorExtendCombine(N, DAG))
18248     return Ext;
18249   if (SDValue Ext = performMulVectorCmpZeroCombine(N, DAG))
18250     return Ext;
18251   if (SDValue Ext = performVectorExtCombine(N, DAG))
18252     return Ext;
18253 
18254   if (DCI.isBeforeLegalizeOps())
18255     return SDValue();
18256 
18257   // Canonicalize X*(Y+1) -> X*Y+X and (X+1)*Y -> X*Y+Y,
18258   // and in MachineCombiner pass, add+mul will be combined into madd.
18259   // Similarly, X*(1-Y) -> X - X*Y and (1-Y)*X -> X - Y*X.
18260   SDLoc DL(N);
18261   EVT VT = N->getValueType(0);
18262   SDValue N0 = N->getOperand(0);
18263   SDValue N1 = N->getOperand(1);
18264   SDValue MulOper;
18265   unsigned AddSubOpc;
18266 
18267   auto IsAddSubWith1 = [&](SDValue V) -> bool {
18268     AddSubOpc = V->getOpcode();
18269     if ((AddSubOpc == ISD::ADD || AddSubOpc == ISD::SUB) && V->hasOneUse()) {
18270       SDValue Opnd = V->getOperand(1);
18271       MulOper = V->getOperand(0);
18272       if (AddSubOpc == ISD::SUB)
18273         std::swap(Opnd, MulOper);
18274       if (auto C = dyn_cast<ConstantSDNode>(Opnd))
18275         return C->isOne();
18276     }
18277     return false;
18278   };
18279 
18280   if (IsAddSubWith1(N0)) {
18281     SDValue MulVal = DAG.getNode(ISD::MUL, DL, VT, N1, MulOper);
18282     return DAG.getNode(AddSubOpc, DL, VT, N1, MulVal);
18283   }
18284 
18285   if (IsAddSubWith1(N1)) {
18286     SDValue MulVal = DAG.getNode(ISD::MUL, DL, VT, N0, MulOper);
18287     return DAG.getNode(AddSubOpc, DL, VT, N0, MulVal);
18288   }
18289 
18290   // The below optimizations require a constant RHS.
18291   if (!isa<ConstantSDNode>(N1))
18292     return SDValue();
18293 
18294   ConstantSDNode *C = cast<ConstantSDNode>(N1);
18295   const APInt &ConstValue = C->getAPIntValue();
18296 
18297   // Allow the scaling to be folded into the `cnt` instruction by preventing
18298   // the scaling from being obscured here. This makes it easier to pattern match.
18299   if (IsSVECntIntrinsic(N0) ||
18300      (N0->getOpcode() == ISD::TRUNCATE &&
18301       (IsSVECntIntrinsic(N0->getOperand(0)))))
18302        if (ConstValue.sge(1) && ConstValue.sle(16))
18303          return SDValue();
18304 
18305   // Multiplication of a power of two plus/minus one can be done more
18306   // cheaply as shift+add/sub. For now, this is true unilaterally. If
18307   // future CPUs have a cheaper MADD instruction, this may need to be
18308   // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and
18309   // 64-bit is 5 cycles, so this is always a win.
18310   // More aggressively, some multiplications N0 * C can be lowered to
18311   // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
18312   // e.g. 6=3*2=(2+1)*2, 45=(1+4)*(1+8)
18313   // TODO: lower more cases.
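  // For example, x*6 becomes ((x << 1) + x) << 1, and x*45 = (1+4)*(1+8)*x
  // becomes MV = (x << 2) + x; (MV << 3) + MV.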
18314 
18315   // TrailingZeroes is used to test if the mul can be lowered to
18316   // shift+add+shift.
18317   unsigned TrailingZeroes = ConstValue.countr_zero();
18318   if (TrailingZeroes) {
18319     // Conservatively do not lower to shift+add+shift if the mul might be
18320     // folded into smul or umul.
18321     if (N0->hasOneUse() && (isSignExtended(N0, DAG) ||
18322                             isZeroExtended(N0, DAG)))
18323       return SDValue();
18324     // Conservatively do not lower to shift+add+shift if the mul might be
18325     // folded into madd or msub.
18326     if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ADD ||
18327                            N->use_begin()->getOpcode() == ISD::SUB))
18328       return SDValue();
18329   }
18330   // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub
18331   // and shift+add+shift.
18332   APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);
18333   unsigned ShiftAmt;
18334 
18335   auto Shl = [&](SDValue N0, unsigned N1) {
18336     if (!N0.getNode())
18337       return SDValue();
18338     // If shift causes overflow, ignore this combine.
18339     if (N1 >= N0.getValueSizeInBits())
18340       return SDValue();
18341     SDValue RHS = DAG.getConstant(N1, DL, MVT::i64);
18342     return DAG.getNode(ISD::SHL, DL, VT, N0, RHS);
18343   };
18344   auto Add = [&](SDValue N0, SDValue N1) {
18345     if (!N0.getNode() || !N1.getNode())
18346       return SDValue();
18347     return DAG.getNode(ISD::ADD, DL, VT, N0, N1);
18348   };
18349   auto Sub = [&](SDValue N0, SDValue N1) {
18350     if (!N0.getNode() || !N1.getNode())
18351       return SDValue();
18352     return DAG.getNode(ISD::SUB, DL, VT, N0, N1);
18353   };
18354   auto Negate = [&](SDValue N) {
18355     if (!N.getNode())
18356       return SDValue();
18357     SDValue Zero = DAG.getConstant(0, DL, VT);
18358     return DAG.getNode(ISD::SUB, DL, VT, Zero, N);
18359   };
18360 
18361   // Can the constant C be decomposed into (1+2^M1)*(1+2^N1)? E.g.
18362   // C = 45 equals (1+4)*(1+8); we don't decompose it into (1+2)*(16-1) as
18363   // the (2^N - 1) can't be executed via a single instruction.
18364   auto isPowPlusPlusConst = [](APInt C, APInt &M, APInt &N) {
18365     unsigned BitWidth = C.getBitWidth();
18366     for (unsigned i = 1; i < BitWidth / 2; i++) {
18367       APInt Rem;
18368       APInt X(BitWidth, (1 << i) + 1);
18369       APInt::sdivrem(C, X, N, Rem);
18370       APInt NVMinus1 = N - 1;
18371       if (Rem == 0 && NVMinus1.isPowerOf2()) {
18372         M = X;
18373         return true;
18374       }
18375     }
18376     return false;
18377   };
18378 
18379   // Can the constant C be decomposed into (2^M + 1) * 2^N + 1? E.g.
18380   // C = 11 equals (1+4)*2+1; we don't decompose it into (1+2)*4-1 as
18381   // the (2^N - 1) can't be executed via a single instruction.
18382   auto isPowPlusPlusOneConst = [](APInt C, APInt &M, APInt &N) {
18383     APInt CVMinus1 = C - 1;
18384     if (CVMinus1.isNegative())
18385       return false;
18386     unsigned TrailingZeroes = CVMinus1.countr_zero();
18387     APInt SCVMinus1 = CVMinus1.ashr(TrailingZeroes) - 1;
18388     if (SCVMinus1.isPowerOf2()) {
18389       unsigned BitWidth = SCVMinus1.getBitWidth();
18390       M = APInt(BitWidth, SCVMinus1.logBase2());
18391       N = APInt(BitWidth, TrailingZeroes);
18392       return true;
18393     }
18394     return false;
18395   };
18396 
18397   // Can the const C be decomposed into (1 - (1 - 2^M) * 2^N), eg:
18398   // C = 29 is equal to 1 - (1 - 2^3) * 2^2.
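        // For C = 29: C - 1 = 28 has two trailing zeros and (28 >> 2) + 1 = 8 = 2^3,
        // so M = 3 and N = 2; the combine below builds MV = x - (x << 3) = -7*x and
        // returns x - (MV << 2) = x + 28*x = 29*x.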
18399   auto isPowMinusMinusOneConst = [](APInt C, APInt &M, APInt &N) {
18400     APInt CVMinus1 = C - 1;
18401     if (CVMinus1.isNegative())
18402       return false;
18403     unsigned TrailingZeroes = CVMinus1.countr_zero();
18404     APInt CVPlus1 = CVMinus1.ashr(TrailingZeroes) + 1;
18405     if (CVPlus1.isPowerOf2()) {
18406       unsigned BitWidth = CVPlus1.getBitWidth();
18407       M = APInt(BitWidth, CVPlus1.logBase2());
18408       N = APInt(BitWidth, TrailingZeroes);
18409       return true;
18410     }
18411     return false;
18412   };
18413 
18414   if (ConstValue.isNonNegative()) {
18415     // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
18416     // (mul x, 2^N - 1) => (sub (shl x, N), x)
18417     // (mul x, (2^(N-M) - 1) * 2^M) => (sub (shl x, N), (shl x, M))
18418     // (mul x, (2^M + 1) * (2^N + 1))
18419     //     => MV = (add (shl x, M), x); (add (shl MV, N), MV)
18420     // (mul x, ((2^M + 1) * 2^N + 1))
18421     //     => MV = (add (shl x, M), x); (add (shl MV, N), x)
18422     // (mul x, (1 - (1 - 2^M) * 2^N))
18423     //     => MV = (sub x, (shl x, M)); (sub x, (shl MV, N))
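          // For example, C = 6 = (2^1 + 1) * 2^1 gives TrailingZeroes = 1 and
          // ShiftedConstValue = 3, so the first case below emits
          // (shl (add (shl x, 1), x), 1).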
18424     APInt SCVMinus1 = ShiftedConstValue - 1;
18425     APInt SCVPlus1 = ShiftedConstValue + 1;
18426     APInt CVPlus1 = ConstValue + 1;
18427     APInt CVM, CVN;
18428     if (SCVMinus1.isPowerOf2()) {
18429       ShiftAmt = SCVMinus1.logBase2();
18430       return Shl(Add(Shl(N0, ShiftAmt), N0), TrailingZeroes);
18431     } else if (CVPlus1.isPowerOf2()) {
18432       ShiftAmt = CVPlus1.logBase2();
18433       return Sub(Shl(N0, ShiftAmt), N0);
18434     } else if (SCVPlus1.isPowerOf2()) {
18435       ShiftAmt = SCVPlus1.logBase2() + TrailingZeroes;
18436       return Sub(Shl(N0, ShiftAmt), Shl(N0, TrailingZeroes));
18437     }
18438     if (Subtarget->hasALULSLFast() &&
18439         isPowPlusPlusConst(ConstValue, CVM, CVN)) {
18440       APInt CVMMinus1 = CVM - 1;
18441       APInt CVNMinus1 = CVN - 1;
18442       unsigned ShiftM1 = CVMMinus1.logBase2();
18443       unsigned ShiftN1 = CVNMinus1.logBase2();
18444       // ALULSLFast implies that shifts of at most 4 places are fast
18445       if (ShiftM1 <= 4 && ShiftN1 <= 4) {
18446         SDValue MVal = Add(Shl(N0, ShiftM1), N0);
18447         return Add(Shl(MVal, ShiftN1), MVal);
18448       }
18449     }
18450     if (Subtarget->hasALULSLFast() &&
18451         isPowPlusPlusOneConst(ConstValue, CVM, CVN)) {
18452       unsigned ShiftM = CVM.getZExtValue();
18453       unsigned ShiftN = CVN.getZExtValue();
18454       // ALULSLFast implies that shifts of at most 4 places are fast
18455       if (ShiftM <= 4 && ShiftN <= 4) {
18456         SDValue MVal = Add(Shl(N0, ShiftM), N0);
18457         return Add(Shl(MVal, ShiftN), N0);
18458       }
18459     }
18460 
18461     if (Subtarget->hasALULSLFast() &&
18462         isPowMinusMinusOneConst(ConstValue, CVM, CVN)) {
18463       unsigned ShiftM = CVM.getZExtValue();
18464       unsigned ShiftN = CVN.getZExtValue();
18465       // ALULSLFast implies that shifts of at most 4 places are fast
18466       if (ShiftM <= 4 && ShiftN <= 4) {
18467         SDValue MVal = Sub(N0, Shl(N0, ShiftM));
18468         return Sub(N0, Shl(MVal, ShiftN));
18469       }
18470     }
18471   } else {
18472     // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
18473     // (mul x, -(2^N + 1)) => - (add (shl x, N), x)
18474     // (mul x, -(2^(N-M) - 1) * 2^M) => (sub (shl x, M), (shl x, N))
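          // For example, C = -7 gives CVNegPlus1 = 8 = 2^3, so the first case below
          // emits (sub x, (shl x, 3)), i.e. x - 8*x = -7*x.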
18475     APInt SCVPlus1 = -ShiftedConstValue + 1;
18476     APInt CVNegPlus1 = -ConstValue + 1;
18477     APInt CVNegMinus1 = -ConstValue - 1;
18478     if (CVNegPlus1.isPowerOf2()) {
18479       ShiftAmt = CVNegPlus1.logBase2();
18480       return Sub(N0, Shl(N0, ShiftAmt));
18481     } else if (CVNegMinus1.isPowerOf2()) {
18482       ShiftAmt = CVNegMinus1.logBase2();
18483       return Negate(Add(Shl(N0, ShiftAmt), N0));
18484     } else if (SCVPlus1.isPowerOf2()) {
18485       ShiftAmt = SCVPlus1.logBase2() + TrailingZeroes;
18486       return Sub(Shl(N0, TrailingZeroes), Shl(N0, ShiftAmt));
18487     }
18488   }
18489 
18490   return SDValue();
18491 }
18492 
18493 static SDValue performVectorCompareAndMaskUnaryOpCombine(SDNode *N,
18494                                                          SelectionDAG &DAG) {
18495   // Take advantage of vector comparisons producing 0 or -1 in each lane to
18496   // optimize away operation when it's from a constant.
18497   //
18498   // The general transformation is:
18499   //    UNARYOP(AND(VECTOR_CMP(x,y), constant)) -->
18500   //       AND(VECTOR_CMP(x,y), constant2)
18501   //    constant2 = UNARYOP(constant)
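        //
        // For example, (v4f32 (sint_to_fp (and (setcc x, y), splat(1)))) can become
        // (v4f32 (bitcast (and (setcc x, y), (bitcast (splat 1.0f))))), since each
        // compare lane is 0 or -1 and the AND then yields 0.0 or 1.0 per lane.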
18502 
18503   // Early exit if this isn't a vector operation, the operand of the
18504   // unary operation isn't a bitwise AND, or if the sizes of the operations
18505   // aren't the same.
18506   EVT VT = N->getValueType(0);
18507   if (!VT.isVector() || N->getOperand(0)->getOpcode() != ISD::AND ||
18508       N->getOperand(0)->getOperand(0)->getOpcode() != ISD::SETCC ||
18509       VT.getSizeInBits() != N->getOperand(0)->getValueType(0).getSizeInBits())
18510     return SDValue();
18511 
18512   // Now check that the other operand of the AND is a constant. We could
18513   // make the transformation for non-constant splats as well, but it's unclear
18514   // that would be a benefit as it would not eliminate any operations, just
18515   // perform one more step in scalar code before moving to the vector unit.
18516   if (BuildVectorSDNode *BV =
18517           dyn_cast<BuildVectorSDNode>(N->getOperand(0)->getOperand(1))) {
18518     // Bail out if the vector isn't a constant.
18519     if (!BV->isConstant())
18520       return SDValue();
18521 
18522     // Everything checks out. Build up the new and improved node.
18523     SDLoc DL(N);
18524     EVT IntVT = BV->getValueType(0);
18525     // Create a new constant of the appropriate type for the transformed
18526     // DAG.
18527     SDValue SourceConst = DAG.getNode(N->getOpcode(), DL, VT, SDValue(BV, 0));
18528     // The AND node needs bitcasts to/from an integer vector type around it.
18529     SDValue MaskConst = DAG.getNode(ISD::BITCAST, DL, IntVT, SourceConst);
18530     SDValue NewAnd = DAG.getNode(ISD::AND, DL, IntVT,
18531                                  N->getOperand(0)->getOperand(0), MaskConst);
18532     SDValue Res = DAG.getNode(ISD::BITCAST, DL, VT, NewAnd);
18533     return Res;
18534   }
18535 
18536   return SDValue();
18537 }
18538 
18539 static SDValue performIntToFpCombine(SDNode *N, SelectionDAG &DAG,
18540                                      const AArch64Subtarget *Subtarget) {
18541   // First try to optimize away the conversion when it's conditionally from
18542   // a constant. Vectors only.
18543   if (SDValue Res = performVectorCompareAndMaskUnaryOpCombine(N, DAG))
18544     return Res;
18545 
18546   EVT VT = N->getValueType(0);
18547   if (VT != MVT::f32 && VT != MVT::f64)
18548     return SDValue();
18549 
18550   // Only optimize when the source and destination types have the same width.
18551   if (VT.getSizeInBits() != N->getOperand(0).getValueSizeInBits())
18552     return SDValue();
18553 
18554   // If the result of an integer load is only used by an integer-to-float
18555   // conversion, use an FP load and an AdvSIMD scalar {S|U}CVTF instead.
18556   // This eliminates an "integer-to-vector-move" UOP and improves throughput.
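        // For example, instead of "ldr w8, [x0]; scvtf s0, w8", this can become
        // "ldr s0, [x0]; scvtf s0, s0", keeping the value in the FP/SIMD registers.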
18557   SDValue N0 = N->getOperand(0);
18558   if (Subtarget->isNeonAvailable() && ISD::isNormalLoad(N0.getNode()) &&
18559       N0.hasOneUse() &&
18560       // Do not change the width of a volatile load.
18561       !cast<LoadSDNode>(N0)->isVolatile()) {
18562     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
18563     SDValue Load = DAG.getLoad(VT, SDLoc(N), LN0->getChain(), LN0->getBasePtr(),
18564                                LN0->getPointerInfo(), LN0->getAlign(),
18565                                LN0->getMemOperand()->getFlags());
18566 
18567     // Make sure successors of the original load stay after it by updating them
18568     // to use the new Chain.
18569     DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), Load.getValue(1));
18570 
18571     unsigned Opcode =
18572         (N->getOpcode() == ISD::SINT_TO_FP) ? AArch64ISD::SITOF : AArch64ISD::UITOF;
18573     return DAG.getNode(Opcode, SDLoc(N), VT, Load);
18574   }
18575 
18576   return SDValue();
18577 }
18578 
18579 /// Fold a floating-point multiply by power of two into floating-point to
18580 /// fixed-point conversion.
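      ///
      /// For example, (fp_to_sint (fmul v4f32:x, splat(8.0))) can be lowered to the
      /// aarch64_neon_vcvtfp2fxs intrinsic with 3 fractional bits (an FCVTZS with
      /// #3), since multiplying by 2^3 before the conversion is equivalent.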
18581 static SDValue performFpToIntCombine(SDNode *N, SelectionDAG &DAG,
18582                                      TargetLowering::DAGCombinerInfo &DCI,
18583                                      const AArch64Subtarget *Subtarget) {
18584   if (!Subtarget->isNeonAvailable())
18585     return SDValue();
18586 
18587   if (!N->getValueType(0).isSimple())
18588     return SDValue();
18589 
18590   SDValue Op = N->getOperand(0);
18591   if (!Op.getValueType().isSimple() || Op.getOpcode() != ISD::FMUL)
18592     return SDValue();
18593 
18594   if (!Op.getValueType().is64BitVector() && !Op.getValueType().is128BitVector())
18595     return SDValue();
18596 
18597   SDValue ConstVec = Op->getOperand(1);
18598   if (!isa<BuildVectorSDNode>(ConstVec))
18599     return SDValue();
18600 
18601   MVT FloatTy = Op.getSimpleValueType().getVectorElementType();
18602   uint32_t FloatBits = FloatTy.getSizeInBits();
18603   if (FloatBits != 32 && FloatBits != 64 &&
18604       (FloatBits != 16 || !Subtarget->hasFullFP16()))
18605     return SDValue();
18606 
18607   MVT IntTy = N->getSimpleValueType(0).getVectorElementType();
18608   uint32_t IntBits = IntTy.getSizeInBits();
18609   if (IntBits != 16 && IntBits != 32 && IntBits != 64)
18610     return SDValue();
18611 
18612   // Avoid conversions where iN is larger than the float (e.g., float -> i64).
18613   if (IntBits > FloatBits)
18614     return SDValue();
18615 
18616   BitVector UndefElements;
18617   BuildVectorSDNode *BV = cast<BuildVectorSDNode>(ConstVec);
18618   int32_t Bits = IntBits == 64 ? 64 : 32;
18619   int32_t C = BV->getConstantFPSplatPow2ToLog2Int(&UndefElements, Bits + 1);
18620   if (C == -1 || C == 0 || C > Bits)
18621     return SDValue();
18622 
18623   EVT ResTy = Op.getValueType().changeVectorElementTypeToInteger();
18624   if (!DAG.getTargetLoweringInfo().isTypeLegal(ResTy))
18625     return SDValue();
18626 
18627   if (N->getOpcode() == ISD::FP_TO_SINT_SAT ||
18628       N->getOpcode() == ISD::FP_TO_UINT_SAT) {
18629     EVT SatVT = cast<VTSDNode>(N->getOperand(1))->getVT();
18630     if (SatVT.getScalarSizeInBits() != IntBits || IntBits != FloatBits)
18631       return SDValue();
18632   }
18633 
18634   SDLoc DL(N);
18635   bool IsSigned = (N->getOpcode() == ISD::FP_TO_SINT ||
18636                    N->getOpcode() == ISD::FP_TO_SINT_SAT);
18637   unsigned IntrinsicOpcode = IsSigned ? Intrinsic::aarch64_neon_vcvtfp2fxs
18638                                       : Intrinsic::aarch64_neon_vcvtfp2fxu;
18639   SDValue FixConv =
18640       DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ResTy,
18641                   DAG.getConstant(IntrinsicOpcode, DL, MVT::i32),
18642                   Op->getOperand(0), DAG.getConstant(C, DL, MVT::i32));
18643   // We can handle smaller integers by generating an extra trunc.
18644   if (IntBits < FloatBits)
18645     FixConv = DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), FixConv);
18646 
18647   return FixConv;
18648 }
18649 
18650 static SDValue tryCombineToBSL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
18651                                const AArch64TargetLowering &TLI) {
18652   EVT VT = N->getValueType(0);
18653   SelectionDAG &DAG = DCI.DAG;
18654   SDLoc DL(N);
18655   const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
18656 
18657   if (!VT.isVector())
18658     return SDValue();
18659 
18660   if (VT.isScalableVector() && !Subtarget.hasSVE2())
18661     return SDValue();
18662 
18663   if (VT.isFixedLengthVector() &&
18664       (!Subtarget.isNeonAvailable() || TLI.useSVEForFixedLengthVectorVT(VT)))
18665     return SDValue();
18666 
18667   SDValue N0 = N->getOperand(0);
18668   if (N0.getOpcode() != ISD::AND)
18669     return SDValue();
18670 
18671   SDValue N1 = N->getOperand(1);
18672   if (N1.getOpcode() != ISD::AND)
18673     return SDValue();
18674 
18675   // InstCombine does (not (neg a)) => (add a -1).
18676   // Try: (or (and (neg a) b) (and (add a -1) c)) => (bsl (neg a) b c)
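        // Note that (add a, -1) == (not (sub 0, a)), so the two AND masks are
        // bitwise complements of each other, which is what makes the BSL valid.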
18677   // Loop over all combinations of AND operands.
18678   for (int i = 1; i >= 0; --i) {
18679     for (int j = 1; j >= 0; --j) {
18680       SDValue O0 = N0->getOperand(i);
18681       SDValue O1 = N1->getOperand(j);
18682       SDValue Sub, Add, SubSibling, AddSibling;
18683 
18684       // Find a SUB and an ADD operand, one from each AND.
18685       if (O0.getOpcode() == ISD::SUB && O1.getOpcode() == ISD::ADD) {
18686         Sub = O0;
18687         Add = O1;
18688         SubSibling = N0->getOperand(1 - i);
18689         AddSibling = N1->getOperand(1 - j);
18690       } else if (O0.getOpcode() == ISD::ADD && O1.getOpcode() == ISD::SUB) {
18691         Add = O0;
18692         Sub = O1;
18693         AddSibling = N0->getOperand(1 - i);
18694         SubSibling = N1->getOperand(1 - j);
18695       } else
18696         continue;
18697 
18698       if (!ISD::isConstantSplatVectorAllZeros(Sub.getOperand(0).getNode()))
18699         continue;
18700 
18701       // The constant all-ones vector is always the right-hand operand of the Add.
18702       if (!ISD::isConstantSplatVectorAllOnes(Add.getOperand(1).getNode()))
18703         continue;
18704 
18705       if (Sub.getOperand(1) != Add.getOperand(0))
18706         continue;
18707 
18708       return DAG.getNode(AArch64ISD::BSP, DL, VT, Sub, SubSibling, AddSibling);
18709     }
18710   }
18711 
18712   // (or (and a b) (and (not a) c)) => (bsl a b c)
18713   // We only have to look for constant vectors here since the general, variable
18714   // case can be handled in TableGen.
18715   unsigned Bits = VT.getScalarSizeInBits();
18716   uint64_t BitMask = Bits == 64 ? -1ULL : ((1ULL << Bits) - 1);
18717   for (int i = 1; i >= 0; --i)
18718     for (int j = 1; j >= 0; --j) {
18719       APInt Val1, Val2;
18720 
18721       if (ISD::isConstantSplatVector(N0->getOperand(i).getNode(), Val1) &&
18722           ISD::isConstantSplatVector(N1->getOperand(j).getNode(), Val2) &&
18723           (BitMask & ~Val1.getZExtValue()) == Val2.getZExtValue()) {
18724         return DAG.getNode(AArch64ISD::BSP, DL, VT, N0->getOperand(i),
18725                            N0->getOperand(1 - i), N1->getOperand(1 - j));
18726       }
18727       BuildVectorSDNode *BVN0 = dyn_cast<BuildVectorSDNode>(N0->getOperand(i));
18728       BuildVectorSDNode *BVN1 = dyn_cast<BuildVectorSDNode>(N1->getOperand(j));
18729       if (!BVN0 || !BVN1)
18730         continue;
18731 
18732       bool FoundMatch = true;
18733       for (unsigned k = 0; k < VT.getVectorNumElements(); ++k) {
18734         ConstantSDNode *CN0 = dyn_cast<ConstantSDNode>(BVN0->getOperand(k));
18735         ConstantSDNode *CN1 = dyn_cast<ConstantSDNode>(BVN1->getOperand(k));
18736         if (!CN0 || !CN1 ||
18737             CN0->getZExtValue() != (BitMask & ~CN1->getZExtValue())) {
18738           FoundMatch = false;
18739           break;
18740         }
18741       }
18742       if (FoundMatch)
18743         return DAG.getNode(AArch64ISD::BSP, DL, VT, N0->getOperand(i),
18744                            N0->getOperand(1 - i), N1->getOperand(1 - j));
18745     }
18746 
18747   return SDValue();
18748 }
18749 
18750 // Given a tree of and/or(csel(0, 1, cc0), csel(0, 1, cc1)), we may be able to
18751 // convert to csel(ccmp(.., cc0)), depending on cc1:
18752 
18753 // (AND (CSET cc0 cmp0) (CSET cc1 (CMP x1 y1)))
18754 // =>
18755 // (CSET cc1 (CCMP x1 y1 !cc1 cc0 cmp0))
18756 //
18757 // (OR (CSET cc0 cmp0) (CSET cc1 (CMP x1 y1)))
18758 // =>
18759 // (CSET cc1 (CCMP x1 y1 cc1 !cc0 cmp0))
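      //
      // The net effect is that code like "(a != 0) && (b == c)" can be selected as
      // a CMP, a CCMP and a single CSET instead of two CSETs followed by an AND.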
18760 static SDValue performANDORCSELCombine(SDNode *N, SelectionDAG &DAG) {
18761   EVT VT = N->getValueType(0);
18762   SDValue CSel0 = N->getOperand(0);
18763   SDValue CSel1 = N->getOperand(1);
18764 
18765   if (CSel0.getOpcode() != AArch64ISD::CSEL ||
18766       CSel1.getOpcode() != AArch64ISD::CSEL)
18767     return SDValue();
18768 
18769   if (!CSel0->hasOneUse() || !CSel1->hasOneUse())
18770     return SDValue();
18771 
18772   if (!isNullConstant(CSel0.getOperand(0)) ||
18773       !isOneConstant(CSel0.getOperand(1)) ||
18774       !isNullConstant(CSel1.getOperand(0)) ||
18775       !isOneConstant(CSel1.getOperand(1)))
18776     return SDValue();
18777 
18778   SDValue Cmp0 = CSel0.getOperand(3);
18779   SDValue Cmp1 = CSel1.getOperand(3);
18780   AArch64CC::CondCode CC0 = (AArch64CC::CondCode)CSel0.getConstantOperandVal(2);
18781   AArch64CC::CondCode CC1 = (AArch64CC::CondCode)CSel1.getConstantOperandVal(2);
18782   if (!Cmp0->hasOneUse() || !Cmp1->hasOneUse())
18783     return SDValue();
18784   if (Cmp1.getOpcode() != AArch64ISD::SUBS &&
18785       Cmp0.getOpcode() == AArch64ISD::SUBS) {
18786     std::swap(Cmp0, Cmp1);
18787     std::swap(CC0, CC1);
18788   }
18789 
18790   if (Cmp1.getOpcode() != AArch64ISD::SUBS)
18791     return SDValue();
18792 
18793   SDLoc DL(N);
18794   SDValue CCmp, Condition;
18795   unsigned NZCV;
18796 
18797   if (N->getOpcode() == ISD::AND) {
18798     AArch64CC::CondCode InvCC0 = AArch64CC::getInvertedCondCode(CC0);
18799     Condition = DAG.getConstant(InvCC0, DL, MVT_CC);
18800     NZCV = AArch64CC::getNZCVToSatisfyCondCode(CC1);
18801   } else {
18802     AArch64CC::CondCode InvCC1 = AArch64CC::getInvertedCondCode(CC1);
18803     Condition = DAG.getConstant(CC0, DL, MVT_CC);
18804     NZCV = AArch64CC::getNZCVToSatisfyCondCode(InvCC1);
18805   }
18806 
18807   SDValue NZCVOp = DAG.getConstant(NZCV, DL, MVT::i32);
18808 
18809   auto *Op1 = dyn_cast<ConstantSDNode>(Cmp1.getOperand(1));
18810   if (Op1 && Op1->getAPIntValue().isNegative() &&
18811       Op1->getAPIntValue().sgt(-32)) {
18812     // CCMP accepts a constant in the range [0, 31].
18813     // If Op1 is a constant in the range [-31, -1], we can
18814     // select CCMN instead to avoid the extra mov.
18815     SDValue AbsOp1 =
18816         DAG.getConstant(Op1->getAPIntValue().abs(), DL, Op1->getValueType(0));
18817     CCmp = DAG.getNode(AArch64ISD::CCMN, DL, MVT_CC, Cmp1.getOperand(0), AbsOp1,
18818                        NZCVOp, Condition, Cmp0);
18819   } else {
18820     CCmp = DAG.getNode(AArch64ISD::CCMP, DL, MVT_CC, Cmp1.getOperand(0),
18821                        Cmp1.getOperand(1), NZCVOp, Condition, Cmp0);
18822   }
18823   return DAG.getNode(AArch64ISD::CSEL, DL, VT, CSel0.getOperand(0),
18824                      CSel0.getOperand(1), DAG.getConstant(CC1, DL, MVT::i32),
18825                      CCmp);
18826 }
18827 
18828 static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
18829                                 const AArch64Subtarget *Subtarget,
18830                                 const AArch64TargetLowering &TLI) {
18831   SelectionDAG &DAG = DCI.DAG;
18832   EVT VT = N->getValueType(0);
18833 
18834   if (SDValue R = performANDORCSELCombine(N, DAG))
18835     return R;
18836 
18837   if (!DAG.getTargetLoweringInfo().isTypeLegal(VT))
18838     return SDValue();
18839 
18840   if (SDValue Res = tryCombineToBSL(N, DCI, TLI))
18841     return Res;
18842 
18843   return SDValue();
18844 }
18845 
18846 static bool isConstantSplatVectorMaskForType(SDNode *N, EVT MemVT) {
18847   if (!MemVT.getVectorElementType().isSimple())
18848     return false;
18849 
18850   uint64_t MaskForTy = 0ull;
18851   switch (MemVT.getVectorElementType().getSimpleVT().SimpleTy) {
18852   case MVT::i8:
18853     MaskForTy = 0xffull;
18854     break;
18855   case MVT::i16:
18856     MaskForTy = 0xffffull;
18857     break;
18858   case MVT::i32:
18859     MaskForTy = 0xffffffffull;
18860     break;
18861   default:
18862     return false;
18863     break;
18864   }
18865 
18866   if (N->getOpcode() == AArch64ISD::DUP || N->getOpcode() == ISD::SPLAT_VECTOR)
18867     if (auto *Op0 = dyn_cast<ConstantSDNode>(N->getOperand(0)))
18868       return Op0->getAPIntValue().getLimitedValue() == MaskForTy;
18869 
18870   return false;
18871 }
18872 
18873 static SDValue performReinterpretCastCombine(SDNode *N) {
18874   SDValue LeafOp = SDValue(N, 0);
18875   SDValue Op = N->getOperand(0);
18876   while (Op.getOpcode() == AArch64ISD::REINTERPRET_CAST &&
18877          LeafOp.getValueType() != Op.getValueType())
18878     Op = Op->getOperand(0);
18879   if (LeafOp.getValueType() == Op.getValueType())
18880     return Op;
18881   return SDValue();
18882 }
18883 
18884 static SDValue performSVEAndCombine(SDNode *N,
18885                                     TargetLowering::DAGCombinerInfo &DCI) {
18886   SelectionDAG &DAG = DCI.DAG;
18887   SDValue Src = N->getOperand(0);
18888   unsigned Opc = Src->getOpcode();
18889 
18890   // Zero/any extend of an unsigned unpack
18891   if (Opc == AArch64ISD::UUNPKHI || Opc == AArch64ISD::UUNPKLO) {
18892     SDValue UnpkOp = Src->getOperand(0);
18893     SDValue Dup = N->getOperand(1);
18894 
18895     if (Dup.getOpcode() != ISD::SPLAT_VECTOR)
18896       return SDValue();
18897 
18898     SDLoc DL(N);
18899     ConstantSDNode *C = dyn_cast<ConstantSDNode>(Dup->getOperand(0));
18900     if (!C)
18901       return SDValue();
18902 
18903     uint64_t ExtVal = C->getZExtValue();
18904 
18905     auto MaskAndTypeMatch = [ExtVal](EVT VT) -> bool {
18906       return ((ExtVal == 0xFF && VT == MVT::i8) ||
18907               (ExtVal == 0xFFFF && VT == MVT::i16) ||
18908               (ExtVal == 0xFFFFFFFF && VT == MVT::i32));
18909     };
18910 
18911     // If the mask is fully covered by the unpack, we don't need to push
18912     // a new AND onto the operand
18913     EVT EltTy = UnpkOp->getValueType(0).getVectorElementType();
18914     if (MaskAndTypeMatch(EltTy))
18915       return Src;
18916 
18917     // If this is 'and (uunpklo/hi (extload MemTy -> ExtTy)), mask', then check
18918     // to see if the mask is all-ones of size MemTy.
18919     auto MaskedLoadOp = dyn_cast<MaskedLoadSDNode>(UnpkOp);
18920     if (MaskedLoadOp && (MaskedLoadOp->getExtensionType() == ISD::ZEXTLOAD ||
18921                          MaskedLoadOp->getExtensionType() == ISD::EXTLOAD)) {
18922       EVT EltTy = MaskedLoadOp->getMemoryVT().getVectorElementType();
18923       if (MaskAndTypeMatch(EltTy))
18924         return Src;
18925     }
18926 
18927     // Truncate to prevent a DUP with an over-wide constant
18928     APInt Mask = C->getAPIntValue().trunc(EltTy.getSizeInBits());
18929 
18930     // Otherwise, make sure we propagate the AND to the operand
18931     // of the unpack
18932     Dup = DAG.getNode(ISD::SPLAT_VECTOR, DL, UnpkOp->getValueType(0),
18933                       DAG.getConstant(Mask.zextOrTrunc(32), DL, MVT::i32));
18934 
18935     SDValue And = DAG.getNode(ISD::AND, DL,
18936                               UnpkOp->getValueType(0), UnpkOp, Dup);
18937 
18938     return DAG.getNode(Opc, DL, N->getValueType(0), And);
18939   }
18940 
18941   if (DCI.isBeforeLegalizeOps())
18942     return SDValue();
18943 
18944   // If either operand of the AND is an all-active predicate, the AND is a
18945   // no-op and we can return the other operand directly.
18946   if (isAllActivePredicate(DAG, N->getOperand(0)))
18947     return N->getOperand(1);
18948   if (isAllActivePredicate(DAG, N->getOperand(1)))
18949     return N->getOperand(0);
18950 
18951   if (!EnableCombineMGatherIntrinsics)
18952     return SDValue();
18953 
18954   SDValue Mask = N->getOperand(1);
18955 
18956   if (!Src.hasOneUse())
18957     return SDValue();
18958 
18959   EVT MemVT;
18960 
18961   // SVE load instructions perform an implicit zero-extend, which makes them
18962   // perfect candidates for combining.
18963   switch (Opc) {
18964   case AArch64ISD::LD1_MERGE_ZERO:
18965   case AArch64ISD::LDNF1_MERGE_ZERO:
18966   case AArch64ISD::LDFF1_MERGE_ZERO:
18967     MemVT = cast<VTSDNode>(Src->getOperand(3))->getVT();
18968     break;
18969   case AArch64ISD::GLD1_MERGE_ZERO:
18970   case AArch64ISD::GLD1_SCALED_MERGE_ZERO:
18971   case AArch64ISD::GLD1_SXTW_MERGE_ZERO:
18972   case AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO:
18973   case AArch64ISD::GLD1_UXTW_MERGE_ZERO:
18974   case AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO:
18975   case AArch64ISD::GLD1_IMM_MERGE_ZERO:
18976   case AArch64ISD::GLDFF1_MERGE_ZERO:
18977   case AArch64ISD::GLDFF1_SCALED_MERGE_ZERO:
18978   case AArch64ISD::GLDFF1_SXTW_MERGE_ZERO:
18979   case AArch64ISD::GLDFF1_SXTW_SCALED_MERGE_ZERO:
18980   case AArch64ISD::GLDFF1_UXTW_MERGE_ZERO:
18981   case AArch64ISD::GLDFF1_UXTW_SCALED_MERGE_ZERO:
18982   case AArch64ISD::GLDFF1_IMM_MERGE_ZERO:
18983   case AArch64ISD::GLDNT1_MERGE_ZERO:
18984     MemVT = cast<VTSDNode>(Src->getOperand(4))->getVT();
18985     break;
18986   default:
18987     return SDValue();
18988   }
18989 
18990   if (isConstantSplatVectorMaskForType(Mask.getNode(), MemVT))
18991     return Src;
18992 
18993   return SDValue();
18994 }
18995 
18996 // Transform and(fcmp(a, b), fcmp(c, d)) into fccmp(fcmp(a, b), c, d)
18997 static SDValue performANDSETCCCombine(SDNode *N,
18998                                       TargetLowering::DAGCombinerInfo &DCI) {
18999 
19000   // This function performs an optimization on a specific pattern involving
19001   // an AND operation and SETCC (Set Condition Code) node.
19002 
19003   SDValue SetCC = N->getOperand(0);
19004   EVT VT = N->getValueType(0);
19005   SelectionDAG &DAG = DCI.DAG;
19006 
19007   // Checks if the current node (N) is used by any SELECT instruction and
19008   // returns an empty SDValue to avoid applying the optimization to prevent
19009   // incorrect results
19010   for (auto U : N->uses())
19011     if (U->getOpcode() == ISD::SELECT)
19012       return SDValue();
19013 
19014   // Check if the operand is a SETCC node with floating-point comparison
19015   if (SetCC.getOpcode() == ISD::SETCC &&
19016       SetCC.getOperand(0).getValueType() == MVT::f32) {
19017 
19018     SDValue Cmp;
19019     AArch64CC::CondCode CC;
19020 
19021     // Check if the DAG is after legalization and if we can emit the conjunction
19022     if (!DCI.isBeforeLegalize() &&
19023         (Cmp = emitConjunction(DAG, SDValue(N, 0), CC))) {
19024 
19025       AArch64CC::CondCode InvertedCC = AArch64CC::getInvertedCondCode(CC);
19026 
19027       SDLoc DL(N);
19028       return DAG.getNode(AArch64ISD::CSINC, DL, VT, DAG.getConstant(0, DL, VT),
19029                          DAG.getConstant(0, DL, VT),
19030                          DAG.getConstant(InvertedCC, DL, MVT::i32), Cmp);
19031     }
19032   }
19033   return SDValue();
19034 }
19035 
19036 static SDValue performANDCombine(SDNode *N,
19037                                  TargetLowering::DAGCombinerInfo &DCI) {
19038   SelectionDAG &DAG = DCI.DAG;
19039   SDValue LHS = N->getOperand(0);
19040   SDValue RHS = N->getOperand(1);
19041   EVT VT = N->getValueType(0);
19042 
19043   if (SDValue R = performANDORCSELCombine(N, DAG))
19044     return R;
19045 
19046   if (SDValue R = performANDSETCCCombine(N, DCI))
19047     return R;
19048 
19049   if (!DAG.getTargetLoweringInfo().isTypeLegal(VT))
19050     return SDValue();
19051 
19052   if (VT.isScalableVector())
19053     return performSVEAndCombine(N, DCI);
19054 
19055   // The combining code below works only for NEON vectors. In particular, it
19056   // does not work for SVE when dealing with vectors wider than 128 bits.
19057   if (!VT.is64BitVector() && !VT.is128BitVector())
19058     return SDValue();
19059 
19060   BuildVectorSDNode *BVN = dyn_cast<BuildVectorSDNode>(RHS.getNode());
19061   if (!BVN)
19062     return SDValue();
19063 
19064   // AND does not accept an immediate, so check if we can use a BIC immediate
19065   // instruction instead. We do this here instead of using a (and x, (mvni imm))
19066   // pattern in isel, because some immediates may be lowered to the preferred
19067   // (and x, (movi imm)) form, even though an mvni representation also exists.
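        // For example, (and v4i32:x, splat(0xFFFFFF00)) can be selected as
        // "BIC Vd.4s, #0xff", which clears the low byte of each lane without
        // materialising the mask in a register.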
19068   APInt DefBits(VT.getSizeInBits(), 0);
19069   APInt UndefBits(VT.getSizeInBits(), 0);
19070   if (resolveBuildVector(BVN, DefBits, UndefBits)) {
19071     SDValue NewOp;
19072 
19073     // Any bits known to already be 0 need not be cleared again, which can help
19074     // reduce the size of the immediate to one supported by the instruction.
19075     KnownBits Known = DAG.computeKnownBits(LHS);
19076     APInt ZeroSplat(VT.getSizeInBits(), 0);
19077     for (unsigned I = 0; I < VT.getSizeInBits() / Known.Zero.getBitWidth(); I++)
19078       ZeroSplat |= Known.Zero.zext(VT.getSizeInBits())
19079                    << (Known.Zero.getBitWidth() * I);
19080 
19081     DefBits = ~(DefBits | ZeroSplat);
19082     if ((NewOp = tryAdvSIMDModImm32(AArch64ISD::BICi, SDValue(N, 0), DAG,
19083                                     DefBits, &LHS)) ||
19084         (NewOp = tryAdvSIMDModImm16(AArch64ISD::BICi, SDValue(N, 0), DAG,
19085                                     DefBits, &LHS)))
19086       return NewOp;
19087 
19088     UndefBits = ~(UndefBits | ZeroSplat);
19089     if ((NewOp = tryAdvSIMDModImm32(AArch64ISD::BICi, SDValue(N, 0), DAG,
19090                                     UndefBits, &LHS)) ||
19091         (NewOp = tryAdvSIMDModImm16(AArch64ISD::BICi, SDValue(N, 0), DAG,
19092                                     UndefBits, &LHS)))
19093       return NewOp;
19094   }
19095 
19096   return SDValue();
19097 }
19098 
19099 static SDValue performFADDCombine(SDNode *N,
19100                                   TargetLowering::DAGCombinerInfo &DCI) {
19101   SelectionDAG &DAG = DCI.DAG;
19102   SDValue LHS = N->getOperand(0);
19103   SDValue RHS = N->getOperand(1);
19104   EVT VT = N->getValueType(0);
19105   SDLoc DL(N);
19106 
19107   if (!N->getFlags().hasAllowReassociation())
19108     return SDValue();
19109 
19110   // Combine fadd(a, vcmla(b, c, d)) -> vcmla(fadd(a, b), c, d)
19111   auto ReassocComplex = [&](SDValue A, SDValue B) {
19112     if (A.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
19113       return SDValue();
19114     unsigned Opc = A.getConstantOperandVal(0);
19115     if (Opc != Intrinsic::aarch64_neon_vcmla_rot0 &&
19116         Opc != Intrinsic::aarch64_neon_vcmla_rot90 &&
19117         Opc != Intrinsic::aarch64_neon_vcmla_rot180 &&
19118         Opc != Intrinsic::aarch64_neon_vcmla_rot270)
19119       return SDValue();
19120     SDValue VCMLA = DAG.getNode(
19121         ISD::INTRINSIC_WO_CHAIN, DL, VT, A.getOperand(0),
19122         DAG.getNode(ISD::FADD, DL, VT, A.getOperand(1), B, N->getFlags()),
19123         A.getOperand(2), A.getOperand(3));
19124     VCMLA->setFlags(A->getFlags());
19125     return VCMLA;
19126   };
19127   if (SDValue R = ReassocComplex(LHS, RHS))
19128     return R;
19129   if (SDValue R = ReassocComplex(RHS, LHS))
19130     return R;
19131 
19132   return SDValue();
19133 }
19134 
19135 static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
19136   switch (Opcode) {
19137   case ISD::STRICT_FADD:
19138   case ISD::FADD:
19139     return (FullFP16 && VT == MVT::f16) || VT == MVT::f32 || VT == MVT::f64;
19140   case ISD::ADD:
19141     return VT == MVT::i64;
19142   default:
19143     return false;
19144   }
19145 }
19146 
19147 static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
19148                         AArch64CC::CondCode Cond);
19149 
19150 static bool isPredicateCCSettingOp(SDValue N) {
19151   if ((N.getOpcode() == ISD::SETCC) ||
19152       (N.getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
19153        (N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilege ||
19154         N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilegt ||
19155         N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehi ||
19156         N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehs ||
19157         N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilele ||
19158         N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelo ||
19159         N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilels ||
19160         N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelt ||
19161         // get_active_lane_mask is lowered to a whilelo instruction.
19162         N.getConstantOperandVal(0) == Intrinsic::get_active_lane_mask)))
19163     return true;
19164 
19165   return false;
19166 }
19167 
19168 // Materialize : i1 = extract_vector_elt t37, Constant:i64<0>
19169 // ... into: "ptrue p, all" + PTEST
19170 static SDValue
19171 performFirstTrueTestVectorCombine(SDNode *N,
19172                                   TargetLowering::DAGCombinerInfo &DCI,
19173                                   const AArch64Subtarget *Subtarget) {
19174   assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);
19175   // Make sure PTEST can be legalised with illegal types.
19176   if (!Subtarget->hasSVE() || DCI.isBeforeLegalize())
19177     return SDValue();
19178 
19179   SDValue N0 = N->getOperand(0);
19180   EVT VT = N0.getValueType();
19181 
19182   if (!VT.isScalableVector() || VT.getVectorElementType() != MVT::i1 ||
19183       !isNullConstant(N->getOperand(1)))
19184     return SDValue();
19185 
19186   // Restrict the DAG combine to only cases where we're extracting from a
19187   // flag-setting operation.
19188   if (!isPredicateCCSettingOp(N0))
19189     return SDValue();
19190 
19191   // Extracts of lane 0 for SVE can be expressed as PTEST(Op, FIRST) ? 1 : 0
19192   SelectionDAG &DAG = DCI.DAG;
19193   SDValue Pg = getPTrue(DAG, SDLoc(N), VT, AArch64SVEPredPattern::all);
19194   return getPTest(DAG, N->getValueType(0), Pg, N0, AArch64CC::FIRST_ACTIVE);
19195 }
19196 
19197 // Materialize : Idx = (add (mul vscale, NumEls), -1)
19198 //               i1 = extract_vector_elt t37, Constant:i64<Idx>
19199 //     ... into: "ptrue p, all" + PTEST
19200 static SDValue
19201 performLastTrueTestVectorCombine(SDNode *N,
19202                                  TargetLowering::DAGCombinerInfo &DCI,
19203                                  const AArch64Subtarget *Subtarget) {
19204   assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);
19205   // Make sure PTEST can be legalised with illegal types.
19206   if (!Subtarget->hasSVE() || DCI.isBeforeLegalize())
19207     return SDValue();
19208 
19209   SDValue N0 = N->getOperand(0);
19210   EVT OpVT = N0.getValueType();
19211 
19212   if (!OpVT.isScalableVector() || OpVT.getVectorElementType() != MVT::i1)
19213     return SDValue();
19214 
19215   // Idx == (add (mul vscale, NumEls), -1)
19216   SDValue Idx = N->getOperand(1);
19217   if (Idx.getOpcode() != ISD::ADD || !isAllOnesConstant(Idx.getOperand(1)))
19218     return SDValue();
19219 
19220   SDValue VS = Idx.getOperand(0);
19221   if (VS.getOpcode() != ISD::VSCALE)
19222     return SDValue();
19223 
19224   unsigned NumEls = OpVT.getVectorElementCount().getKnownMinValue();
19225   if (VS.getConstantOperandVal(0) != NumEls)
19226     return SDValue();
19227 
19228   // Extracts of lane EC-1 for SVE can be expressed as PTEST(Op, LAST) ? 1 : 0
19229   SelectionDAG &DAG = DCI.DAG;
19230   SDValue Pg = getPTrue(DAG, SDLoc(N), OpVT, AArch64SVEPredPattern::all);
19231   return getPTest(DAG, N->getValueType(0), Pg, N0, AArch64CC::LAST_ACTIVE);
19232 }
19233 
19234 static SDValue
19235 performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
19236                                const AArch64Subtarget *Subtarget) {
19237   assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);
19238   if (SDValue Res = performFirstTrueTestVectorCombine(N, DCI, Subtarget))
19239     return Res;
19240   if (SDValue Res = performLastTrueTestVectorCombine(N, DCI, Subtarget))
19241     return Res;
19242 
19243   SelectionDAG &DAG = DCI.DAG;
19244   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
19245 
19246   EVT VT = N->getValueType(0);
19247   const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();
19248   bool IsStrict = N0->isStrictFPOpcode();
19249 
19250   // extract(dup x) -> x
19251   if (N0.getOpcode() == AArch64ISD::DUP)
19252     return VT.isInteger() ? DAG.getZExtOrTrunc(N0.getOperand(0), SDLoc(N), VT)
19253                           : N0.getOperand(0);
19254 
19255   // Rewrite for pairwise fadd pattern
19256   //   (f32 (extract_vector_elt
19257   //           (fadd (vXf32 Other)
19258   //                 (vector_shuffle (vXf32 Other) undef <1,X,...> )) 0))
19259   // ->
19260   //   (f32 (fadd (extract_vector_elt (vXf32 Other) 0)
19261   //              (extract_vector_elt (vXf32 Other) 1))
19262   // For strict_fadd we need to make sure the old strict_fadd can be deleted, so
19263   // we can only do this when it's used only by the extract_vector_elt.
19264   if (isNullConstant(N1) && hasPairwiseAdd(N0->getOpcode(), VT, FullFP16) &&
19265       (!IsStrict || N0.hasOneUse())) {
19266     SDLoc DL(N0);
19267     SDValue N00 = N0->getOperand(IsStrict ? 1 : 0);
19268     SDValue N01 = N0->getOperand(IsStrict ? 2 : 1);
19269 
19270     ShuffleVectorSDNode *Shuffle = dyn_cast<ShuffleVectorSDNode>(N01);
19271     SDValue Other = N00;
19272 
19273     // And handle the commutative case.
19274     if (!Shuffle) {
19275       Shuffle = dyn_cast<ShuffleVectorSDNode>(N00);
19276       Other = N01;
19277     }
19278 
19279     if (Shuffle && Shuffle->getMaskElt(0) == 1 &&
19280         Other == Shuffle->getOperand(0)) {
19281       SDValue Extract1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
19282                                      DAG.getConstant(0, DL, MVT::i64));
19283       SDValue Extract2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
19284                                      DAG.getConstant(1, DL, MVT::i64));
19285       if (!IsStrict)
19286         return DAG.getNode(N0->getOpcode(), DL, VT, Extract1, Extract2);
19287 
19288       // For strict_fadd we need uses of the final extract_vector to be replaced
19289       // with the strict_fadd, but we also need uses of the chain output of the
19290       // original strict_fadd to use the chain output of the new strict_fadd as
19291       // otherwise it may not be deleted.
19292       SDValue Ret = DAG.getNode(N0->getOpcode(), DL,
19293                                 {VT, MVT::Other},
19294                                 {N0->getOperand(0), Extract1, Extract2});
19295       DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Ret);
19296       DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Ret.getValue(1));
19297       return SDValue(N, 0);
19298     }
19299   }
19300 
19301   return SDValue();
19302 }
19303 
19304 static SDValue performConcatVectorsCombine(SDNode *N,
19305                                            TargetLowering::DAGCombinerInfo &DCI,
19306                                            SelectionDAG &DAG) {
19307   SDLoc dl(N);
19308   EVT VT = N->getValueType(0);
19309   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
19310   unsigned N0Opc = N0->getOpcode(), N1Opc = N1->getOpcode();
19311 
19312   if (VT.isScalableVector())
19313     return SDValue();
19314 
19315   // Optimize concat_vectors of truncated vectors, where the intermediate
19316   // type is illegal, to avoid said illegality,  e.g.,
19317   //   (v4i16 (concat_vectors (v2i16 (truncate (v2i64))),
19318   //                          (v2i16 (truncate (v2i64)))))
19319   // ->
19320   //   (v4i16 (truncate (vector_shuffle (v4i32 (bitcast (v2i64))),
19321   //                                    (v4i32 (bitcast (v2i64))),
19322   //                                    <0, 2, 4, 6>)))
19323   // This isn't really target-specific, but ISD::TRUNCATE legality isn't keyed
19324   // on both input and result type, so we might generate worse code.
19325   // On AArch64 we know it's fine for v2i64->v4i16 and v4i32->v8i8.
19326   if (N->getNumOperands() == 2 && N0Opc == ISD::TRUNCATE &&
19327       N1Opc == ISD::TRUNCATE) {
19328     SDValue N00 = N0->getOperand(0);
19329     SDValue N10 = N1->getOperand(0);
19330     EVT N00VT = N00.getValueType();
19331 
19332     if (N00VT == N10.getValueType() &&
19333         (N00VT == MVT::v2i64 || N00VT == MVT::v4i32) &&
19334         N00VT.getScalarSizeInBits() == 4 * VT.getScalarSizeInBits()) {
19335       MVT MidVT = (N00VT == MVT::v2i64 ? MVT::v4i32 : MVT::v8i16);
19336       SmallVector<int, 8> Mask(MidVT.getVectorNumElements());
19337       for (size_t i = 0; i < Mask.size(); ++i)
19338         Mask[i] = i * 2;
19339       return DAG.getNode(ISD::TRUNCATE, dl, VT,
19340                          DAG.getVectorShuffle(
19341                              MidVT, dl,
19342                              DAG.getNode(ISD::BITCAST, dl, MidVT, N00),
19343                              DAG.getNode(ISD::BITCAST, dl, MidVT, N10), Mask));
19344     }
19345   }
19346 
19347   if (N->getOperand(0).getValueType() == MVT::v4i8 ||
19348       N->getOperand(0).getValueType() == MVT::v2i16 ||
19349       N->getOperand(0).getValueType() == MVT::v2i8) {
19350     EVT SrcVT = N->getOperand(0).getValueType();
19351     // If we have a concat of v4i8 loads, convert them to a buildvector of f32
19352     // loads to prevent having to go through the v4i8 load legalization that
19353     // needs to extend each element into a larger type.
19354     if (N->getNumOperands() % 2 == 0 &&
19355         all_of(N->op_values(), [SrcVT](SDValue V) {
19356           if (V.getValueType() != SrcVT)
19357             return false;
19358           if (V.isUndef())
19359             return true;
19360           LoadSDNode *LD = dyn_cast<LoadSDNode>(V);
19361           return LD && V.hasOneUse() && LD->isSimple() && !LD->isIndexed() &&
19362                  LD->getExtensionType() == ISD::NON_EXTLOAD;
19363         })) {
19364       EVT FVT = SrcVT == MVT::v2i8 ? MVT::f16 : MVT::f32;
19365       EVT NVT = EVT::getVectorVT(*DAG.getContext(), FVT, N->getNumOperands());
19366       SmallVector<SDValue> Ops;
19367 
19368       for (unsigned i = 0; i < N->getNumOperands(); i++) {
19369         SDValue V = N->getOperand(i);
19370         if (V.isUndef())
19371           Ops.push_back(DAG.getUNDEF(FVT));
19372         else {
19373           LoadSDNode *LD = cast<LoadSDNode>(V);
19374           SDValue NewLoad = DAG.getLoad(FVT, dl, LD->getChain(),
19375                                         LD->getBasePtr(), LD->getMemOperand());
19376           DAG.ReplaceAllUsesOfValueWith(SDValue(LD, 1), NewLoad.getValue(1));
19377           Ops.push_back(NewLoad);
19378         }
19379       }
19380       return DAG.getBitcast(N->getValueType(0),
19381                             DAG.getBuildVector(NVT, dl, Ops));
19382     }
19383   }
19384 
19385   // Canonicalise concat_vectors to replace concatenations of truncated nots
19386   // with nots of concatenated truncates. This in some cases allows for multiple
19387   // redundant negations to be eliminated.
19388   //  (concat_vectors (v4i16 (truncate (not (v4i32)))),
19389   //                  (v4i16 (truncate (not (v4i32)))))
19390   // ->
19391   //  (not (concat_vectors (v4i16 (truncate (v4i32))),
19392   //                       (v4i16 (truncate (v4i32)))))
19393   if (N->getNumOperands() == 2 && N0Opc == ISD::TRUNCATE &&
19394       N1Opc == ISD::TRUNCATE && N->isOnlyUserOf(N0.getNode()) &&
19395       N->isOnlyUserOf(N1.getNode())) {
19396     auto isBitwiseVectorNegate = [](SDValue V) {
19397       return V->getOpcode() == ISD::XOR &&
19398              ISD::isConstantSplatVectorAllOnes(V.getOperand(1).getNode());
19399     };
19400     SDValue N00 = N0->getOperand(0);
19401     SDValue N10 = N1->getOperand(0);
19402     if (isBitwiseVectorNegate(N00) && N0->isOnlyUserOf(N00.getNode()) &&
19403         isBitwiseVectorNegate(N10) && N1->isOnlyUserOf(N10.getNode())) {
19404       return DAG.getNOT(
19405           dl,
19406           DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
19407                       DAG.getNode(ISD::TRUNCATE, dl, N0.getValueType(),
19408                                   N00->getOperand(0)),
19409                       DAG.getNode(ISD::TRUNCATE, dl, N1.getValueType(),
19410                                   N10->getOperand(0))),
19411           VT);
19412     }
19413   }
19414 
19415   // Wait till after everything is legalized to try this. That way we have
19416   // legal vector types and such.
19417   if (DCI.isBeforeLegalizeOps())
19418     return SDValue();
19419 
19420   // Optimise concat_vectors of two identical binops with a 128-bit destination
19421   // size, combine into a binop of two concats of the source vectors, e.g.:
19422   // concat(uhadd(a,b), uhadd(c, d)) -> uhadd(concat(a, c), concat(b, d))
19423   if (N->getNumOperands() == 2 && N0Opc == N1Opc && VT.is128BitVector() &&
19424       DAG.getTargetLoweringInfo().isBinOp(N0Opc) && N0->hasOneUse() &&
19425       N1->hasOneUse()) {
19426     SDValue N00 = N0->getOperand(0);
19427     SDValue N01 = N0->getOperand(1);
19428     SDValue N10 = N1->getOperand(0);
19429     SDValue N11 = N1->getOperand(1);
19430 
19431     if (!N00.isUndef() && !N01.isUndef() && !N10.isUndef() && !N11.isUndef()) {
19432       SDValue Concat0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, N00, N10);
19433       SDValue Concat1 = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, N01, N11);
19434       return DAG.getNode(N0Opc, dl, VT, Concat0, Concat1);
19435     }
19436   }
19437 
19438   auto IsRSHRN = [](SDValue Shr) {
19439     if (Shr.getOpcode() != AArch64ISD::VLSHR)
19440       return false;
19441     SDValue Op = Shr.getOperand(0);
19442     EVT VT = Op.getValueType();
19443     unsigned ShtAmt = Shr.getConstantOperandVal(1);
19444     if (ShtAmt > VT.getScalarSizeInBits() / 2 || Op.getOpcode() != ISD::ADD)
19445       return false;
19446 
19447     APInt Imm;
19448     if (Op.getOperand(1).getOpcode() == AArch64ISD::MOVIshift)
19449       Imm = APInt(VT.getScalarSizeInBits(),
19450                   Op.getOperand(1).getConstantOperandVal(0)
19451                       << Op.getOperand(1).getConstantOperandVal(1));
19452     else if (Op.getOperand(1).getOpcode() == AArch64ISD::DUP &&
19453              isa<ConstantSDNode>(Op.getOperand(1).getOperand(0)))
19454       Imm = APInt(VT.getScalarSizeInBits(),
19455                   Op.getOperand(1).getConstantOperandVal(0));
19456     else
19457       return false;
19458 
19459     if (Imm != 1ULL << (ShtAmt - 1))
19460       return false;
19461     return true;
19462   };
19463 
19464   // concat(rshrn(x), rshrn(y)) -> rshrn(concat(x, y))
19465   if (N->getNumOperands() == 2 && IsRSHRN(N0) &&
19466       ((IsRSHRN(N1) &&
19467         N0.getConstantOperandVal(1) == N1.getConstantOperandVal(1)) ||
19468        N1.isUndef())) {
19469     SDValue X = N0.getOperand(0).getOperand(0);
19470     SDValue Y = N1.isUndef() ? DAG.getUNDEF(X.getValueType())
19471                              : N1.getOperand(0).getOperand(0);
19472     EVT BVT =
19473         X.getValueType().getDoubleNumVectorElementsVT(*DCI.DAG.getContext());
19474     SDValue CC = DAG.getNode(ISD::CONCAT_VECTORS, dl, BVT, X, Y);
19475     SDValue Add = DAG.getNode(
19476         ISD::ADD, dl, BVT, CC,
19477         DAG.getConstant(1ULL << (N0.getConstantOperandVal(1) - 1), dl, BVT));
19478     SDValue Shr =
19479         DAG.getNode(AArch64ISD::VLSHR, dl, BVT, Add, N0.getOperand(1));
19480     return Shr;
19481   }
19482 
19483   // concat(zip1(a, b), zip2(a, b)) is zip1(a, b)
19484   if (N->getNumOperands() == 2 && N0Opc == AArch64ISD::ZIP1 &&
19485       N1Opc == AArch64ISD::ZIP2 && N0.getOperand(0) == N1.getOperand(0) &&
19486       N0.getOperand(1) == N1.getOperand(1)) {
19487     SDValue E0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, N0.getOperand(0),
19488                              DAG.getUNDEF(N0.getValueType()));
19489     SDValue E1 = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, N0.getOperand(1),
19490                              DAG.getUNDEF(N0.getValueType()));
19491     return DAG.getNode(AArch64ISD::ZIP1, dl, VT, E0, E1);
19492   }
19493 
19494   // If we see a (concat_vectors (v1x64 A), (v1x64 A)) it's really a vector
19495   // splat. The indexed instructions are going to be expecting a DUPLANE64, so
19496   // canonicalise to that.
19497   if (N->getNumOperands() == 2 && N0 == N1 && VT.getVectorNumElements() == 2) {
19498     assert(VT.getScalarSizeInBits() == 64);
19499     return DAG.getNode(AArch64ISD::DUPLANE64, dl, VT, WidenVector(N0, DAG),
19500                        DAG.getConstant(0, dl, MVT::i64));
19501   }
19502 
19503   // Canonicalise concat_vectors so that the right-hand vector has as few
19504   // bit-casts as possible before its real operation. The primary matching
19505   // destination for these operations will be the narrowing "2" instructions,
19506   // which depend on the operation being performed on this right-hand vector.
19507   // For example,
19508   //    (concat_vectors LHS,  (v1i64 (bitconvert (v4i16 RHS))))
19509   // becomes
19510   //    (bitconvert (concat_vectors (v4i16 (bitconvert LHS)), RHS))
19511 
19512   if (N->getNumOperands() != 2 || N1Opc != ISD::BITCAST)
19513     return SDValue();
19514   SDValue RHS = N1->getOperand(0);
19515   MVT RHSTy = RHS.getValueType().getSimpleVT();
19516   // If the RHS is not a vector, this is not the pattern we're looking for.
19517   if (!RHSTy.isVector())
19518     return SDValue();
19519 
19520   LLVM_DEBUG(
19521       dbgs() << "aarch64-lower: concat_vectors bitcast simplification\n");
19522 
19523   MVT ConcatTy = MVT::getVectorVT(RHSTy.getVectorElementType(),
19524                                   RHSTy.getVectorNumElements() * 2);
19525   return DAG.getNode(ISD::BITCAST, dl, VT,
19526                      DAG.getNode(ISD::CONCAT_VECTORS, dl, ConcatTy,
19527                                  DAG.getNode(ISD::BITCAST, dl, RHSTy, N0),
19528                                  RHS));
19529 }
19530 
19531 static SDValue
19532 performExtractSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
19533                                SelectionDAG &DAG) {
19534   if (DCI.isBeforeLegalizeOps())
19535     return SDValue();
19536 
19537   EVT VT = N->getValueType(0);
19538   if (!VT.isScalableVector() || VT.getVectorElementType() != MVT::i1)
19539     return SDValue();
19540 
19541   SDValue V = N->getOperand(0);
19542 
19543   // NOTE: This combine exists in DAGCombiner, but that version's legality check
19544   // blocks this combine because the non-const case requires custom lowering.
19545   //
19546   // ty1 extract_vector(ty2 splat(const))) -> ty1 splat(const)
19547   if (V.getOpcode() == ISD::SPLAT_VECTOR)
19548     if (isa<ConstantSDNode>(V.getOperand(0)))
19549       return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, V.getOperand(0));
19550 
19551   return SDValue();
19552 }
19553 
19554 static SDValue
19555 performInsertSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
19556                               SelectionDAG &DAG) {
19557   SDLoc DL(N);
19558   SDValue Vec = N->getOperand(0);
19559   SDValue SubVec = N->getOperand(1);
19560   uint64_t IdxVal = N->getConstantOperandVal(2);
19561   EVT VecVT = Vec.getValueType();
19562   EVT SubVT = SubVec.getValueType();
19563 
19564   // Only do this for legal fixed vector types.
19565   if (!VecVT.isFixedLengthVector() ||
19566       !DAG.getTargetLoweringInfo().isTypeLegal(VecVT) ||
19567       !DAG.getTargetLoweringInfo().isTypeLegal(SubVT))
19568     return SDValue();
19569 
19570   // Ignore widening patterns.
19571   if (IdxVal == 0 && Vec.isUndef())
19572     return SDValue();
19573 
19574   // Subvector must be half the width and an "aligned" insertion.
19575   unsigned NumSubElts = SubVT.getVectorNumElements();
19576   if ((SubVT.getSizeInBits() * 2) != VecVT.getSizeInBits() ||
19577       (IdxVal != 0 && IdxVal != NumSubElts))
19578     return SDValue();
19579 
19580   // Fold insert_subvector -> concat_vectors
19581   // insert_subvector(Vec,Sub,lo) -> concat_vectors(Sub,extract(Vec,hi))
19582   // insert_subvector(Vec,Sub,hi) -> concat_vectors(extract(Vec,lo),Sub)
19583   SDValue Lo, Hi;
19584   if (IdxVal == 0) {
19585     Lo = SubVec;
19586     Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, Vec,
19587                      DAG.getVectorIdxConstant(NumSubElts, DL));
19588   } else {
19589     Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, Vec,
19590                      DAG.getVectorIdxConstant(0, DL));
19591     Hi = SubVec;
19592   }
19593   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT, Lo, Hi);
19594 }
19595 
19596 static SDValue tryCombineFixedPointConvert(SDNode *N,
19597                                            TargetLowering::DAGCombinerInfo &DCI,
19598                                            SelectionDAG &DAG) {
19599   // Wait until after everything is legalized to try this. That way we have
19600   // legal vector types and such.
19601   if (DCI.isBeforeLegalizeOps())
19602     return SDValue();
19603   // Transform a scalar conversion of a value from a lane extract into a
19604   // lane extract of a vector conversion. E.g., from foo1 to foo2:
19605   // double foo1(int64x2_t a) { return vcvtd_n_f64_s64(a[1], 9); }
19606   // double foo2(int64x2_t a) { return vcvtq_n_f64_s64(a, 9)[1]; }
19607   //
19608   // The second form interacts better with instruction selection and the
19609   // register allocator to avoid cross-class register copies that aren't
19610   // coalescable due to a lane reference.
19611 
19612   // Check the operand and see if it originates from a lane extract.
19613   SDValue Op1 = N->getOperand(1);
19614   if (Op1.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
19615     return SDValue();
19616 
19617   // Yep, no additional predication needed. Perform the transform.
19618   SDValue IID = N->getOperand(0);
19619   SDValue Shift = N->getOperand(2);
19620   SDValue Vec = Op1.getOperand(0);
19621   SDValue Lane = Op1.getOperand(1);
19622   EVT ResTy = N->getValueType(0);
19623   EVT VecResTy;
19624   SDLoc DL(N);
19625 
19626   // The vector width should be 128 bits by the time we get here, even
19627   // if it started as 64 bits (the extract_vector handling will have
19628   // done so). Bail if it is not.
19629   if (Vec.getValueSizeInBits() != 128)
19630     return SDValue();
19631 
19632   if (Vec.getValueType() == MVT::v4i32)
19633     VecResTy = MVT::v4f32;
19634   else if (Vec.getValueType() == MVT::v2i64)
19635     VecResTy = MVT::v2f64;
19636   else
19637     return SDValue();
19638 
19639   SDValue Convert =
19640       DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VecResTy, IID, Vec, Shift);
19641   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResTy, Convert, Lane);
19642 }
19643 
19644 // AArch64 high-vector "long" operations are formed by performing the non-high
19645 // version on an extract_subvector of each operand which gets the high half:
19646 //
19647 //  (longop2 LHS, RHS) == (longop (extract_high LHS), (extract_high RHS))
19648 //
19649 // However, there are cases which don't have an extract_high explicitly, but
19650 // have another operation that can be made compatible with one for free. For
19651 // example:
19652 //
19653 //  (dupv64 scalar) --> (extract_high (dup128 scalar))
19654 //
19655 // This routine does the actual conversion of such DUPs, once outer routines
19656 // have determined that everything else is in order.
19657 // It also supports immediate DUP-like nodes (MOVI/MVNi), which we can fold
19658 // similarly here.
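// Illustrative sketch (hypothetical types, not from the original comment): a
// 64-bit (dup v2i32 s) is rebuilt as (dup v4i32 s) and its high half is taken
// with (extract_subvector ..., 2), which the long-op patterns treat as an
// extract_high.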
19659 static SDValue tryExtendDUPToExtractHigh(SDValue N, SelectionDAG &DAG) {
19660   MVT VT = N.getSimpleValueType();
19661   if (N.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
19662       N.getConstantOperandVal(1) == 0)
19663     N = N.getOperand(0);
19664 
19665   switch (N.getOpcode()) {
19666   case AArch64ISD::DUP:
19667   case AArch64ISD::DUPLANE8:
19668   case AArch64ISD::DUPLANE16:
19669   case AArch64ISD::DUPLANE32:
19670   case AArch64ISD::DUPLANE64:
19671   case AArch64ISD::MOVI:
19672   case AArch64ISD::MOVIshift:
19673   case AArch64ISD::MOVIedit:
19674   case AArch64ISD::MOVImsl:
19675   case AArch64ISD::MVNIshift:
19676   case AArch64ISD::MVNImsl:
19677     break;
19678   default:
19679     // FMOV could be supported, but isn't very useful, as it would only occur
19680     // if you passed a bitcast' floating point immediate to an eligible long
19681     // integer op (addl, smull, ...).
19682     return SDValue();
19683   }
19684 
19685   if (!VT.is64BitVector())
19686     return SDValue();
19687 
19688   SDLoc DL(N);
19689   unsigned NumElems = VT.getVectorNumElements();
19690   if (N.getValueType().is64BitVector()) {
19691     MVT ElementTy = VT.getVectorElementType();
19692     MVT NewVT = MVT::getVectorVT(ElementTy, NumElems * 2);
19693     N = DAG.getNode(N->getOpcode(), DL, NewVT, N->ops());
19694   }
19695 
19696   return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, N,
19697                      DAG.getConstant(NumElems, DL, MVT::i64));
19698 }
19699 
19700 static bool isEssentiallyExtractHighSubvector(SDValue N) {
19701   if (N.getOpcode() == ISD::BITCAST)
19702     N = N.getOperand(0);
19703   if (N.getOpcode() != ISD::EXTRACT_SUBVECTOR)
19704     return false;
19705   if (N.getOperand(0).getValueType().isScalableVector())
19706     return false;
19707   return N.getConstantOperandAPInt(1) ==
19708          N.getOperand(0).getValueType().getVectorNumElements() / 2;
19709 }
19710 
19711 /// Helper structure to keep track of ISD::SET_CC operands.
19712 struct GenericSetCCInfo {
19713   const SDValue *Opnd0;
19714   const SDValue *Opnd1;
19715   ISD::CondCode CC;
19716 };
19717 
19718 /// Helper structure to keep track of a SET_CC lowered into AArch64 code.
19719 struct AArch64SetCCInfo {
19720   const SDValue *Cmp;
19721   AArch64CC::CondCode CC;
19722 };
19723 
19724 /// Helper structure to keep track of SetCC information.
19725 union SetCCInfo {
19726   GenericSetCCInfo Generic;
19727   AArch64SetCCInfo AArch64;
19728 };
19729 
19730 /// Helper structure to be able to read SetCC information.  If set to
19731 /// true, IsAArch64 field, Info is a AArch64SetCCInfo, otherwise Info is a
19732 /// GenericSetCCInfo.
19733 struct SetCCInfoAndKind {
19734   SetCCInfo Info;
19735   bool IsAArch64;
19736 };
19737 
19738 /// Check whether or not \p Op is a SET_CC operation, either a generic
19739 /// or an AArch64 lowered one.
19740 ///
19741 /// \p SetCCInfo is filled accordingly.
19742 /// \post SetCCInfo is meaningful only when this function returns true.
19743 /// \return True when Op is a kind of SET_CC operation.
19744 static bool isSetCC(SDValue Op, SetCCInfoAndKind &SetCCInfo) {
19745   // If this is a setcc, this is straightforward.
19746   if (Op.getOpcode() == ISD::SETCC) {
19747     SetCCInfo.Info.Generic.Opnd0 = &Op.getOperand(0);
19748     SetCCInfo.Info.Generic.Opnd1 = &Op.getOperand(1);
19749     SetCCInfo.Info.Generic.CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
19750     SetCCInfo.IsAArch64 = false;
19751     return true;
19752   }
19753   // Otherwise, check if this is a matching csel instruction.
19754   // In other words:
19755   // - csel 1, 0, cc
19756   // - csel 0, 1, !cc
19757   if (Op.getOpcode() != AArch64ISD::CSEL)
19758     return false;
19759   // Set the information about the operands.
19760   // TODO: we want the operands of the Cmp not the csel
19761   SetCCInfo.Info.AArch64.Cmp = &Op.getOperand(3);
19762   SetCCInfo.IsAArch64 = true;
19763   SetCCInfo.Info.AArch64.CC =
19764       static_cast<AArch64CC::CondCode>(Op.getConstantOperandVal(2));
19765 
19766   // Check that the operands match the constraints:
19767   // (1) Both operands must be constants.
19768   // (2) One must be 1 and the other must be 0.
19769   ConstantSDNode *TValue = dyn_cast<ConstantSDNode>(Op.getOperand(0));
19770   ConstantSDNode *FValue = dyn_cast<ConstantSDNode>(Op.getOperand(1));
19771 
19772   // Check (1).
19773   if (!TValue || !FValue)
19774     return false;
19775 
19776   // Check (2).
19777   if (!TValue->isOne()) {
19778     // Update the comparison when we are interested in !cc.
19779     std::swap(TValue, FValue);
19780     SetCCInfo.Info.AArch64.CC =
19781         AArch64CC::getInvertedCondCode(SetCCInfo.Info.AArch64.CC);
19782   }
19783   return TValue->isOne() && FValue->isZero();
19784 }
19785 
19786 // Returns true if Op is setcc or zext of setcc.
19787 static bool isSetCCOrZExtSetCC(const SDValue& Op, SetCCInfoAndKind &Info) {
19788   if (isSetCC(Op, Info))
19789     return true;
19790   return ((Op.getOpcode() == ISD::ZERO_EXTEND) &&
19791     isSetCC(Op->getOperand(0), Info));
19792 }
19793 
19794 // The folding we want to perform is:
19795 // (add x, [zext] (setcc cc ...) )
19796 //   -->
19797 // (csel x, (add x, 1), !cc ...)
19798 //
19799 // The latter will get matched to a CSINC instruction.
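// Illustrative example (assumed, not from the original comment): for
// "x + (a == b)" this lets us emit something like "cmp a, b; csinc x, x, x, ne"
// instead of materializing the setcc result in a register and adding it.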
19800 static SDValue performSetccAddFolding(SDNode *Op, SelectionDAG &DAG) {
19801   assert(Op && Op->getOpcode() == ISD::ADD && "Unexpected operation!");
19802   SDValue LHS = Op->getOperand(0);
19803   SDValue RHS = Op->getOperand(1);
19804   SetCCInfoAndKind InfoAndKind;
19805 
19806   // If both operands are a SET_CC, then we don't want to perform this
19807   // folding and create another csel as this results in more instructions
19808   // (and higher register usage).
19809   if (isSetCCOrZExtSetCC(LHS, InfoAndKind) &&
19810       isSetCCOrZExtSetCC(RHS, InfoAndKind))
19811     return SDValue();
19812 
19813   // If neither operand is a SET_CC, give up.
19814   if (!isSetCCOrZExtSetCC(LHS, InfoAndKind)) {
19815     std::swap(LHS, RHS);
19816     if (!isSetCCOrZExtSetCC(LHS, InfoAndKind))
19817       return SDValue();
19818   }
19819 
19820   // FIXME: This could be generalized to work for FP comparisons.
19821   EVT CmpVT = InfoAndKind.IsAArch64
19822                   ? InfoAndKind.Info.AArch64.Cmp->getOperand(0).getValueType()
19823                   : InfoAndKind.Info.Generic.Opnd0->getValueType();
19824   if (CmpVT != MVT::i32 && CmpVT != MVT::i64)
19825     return SDValue();
19826 
19827   SDValue CCVal;
19828   SDValue Cmp;
19829   SDLoc dl(Op);
19830   if (InfoAndKind.IsAArch64) {
19831     CCVal = DAG.getConstant(
19832         AArch64CC::getInvertedCondCode(InfoAndKind.Info.AArch64.CC), dl,
19833         MVT::i32);
19834     Cmp = *InfoAndKind.Info.AArch64.Cmp;
19835   } else
19836     Cmp = getAArch64Cmp(
19837         *InfoAndKind.Info.Generic.Opnd0, *InfoAndKind.Info.Generic.Opnd1,
19838         ISD::getSetCCInverse(InfoAndKind.Info.Generic.CC, CmpVT), CCVal, DAG,
19839         dl);
19840 
19841   EVT VT = Op->getValueType(0);
19842   LHS = DAG.getNode(ISD::ADD, dl, VT, RHS, DAG.getConstant(1, dl, VT));
19843   return DAG.getNode(AArch64ISD::CSEL, dl, VT, RHS, LHS, CCVal, Cmp);
19844 }
19845 
19846 // ADD(UADDV a, UADDV b) -->  UADDV(ADD a, b)
19847 static SDValue performAddUADDVCombine(SDNode *N, SelectionDAG &DAG) {
19848   EVT VT = N->getValueType(0);
19849   // Only scalar integer and vector types.
19850   if (N->getOpcode() != ISD::ADD || !VT.isScalarInteger())
19851     return SDValue();
19852 
19853   SDValue LHS = N->getOperand(0);
19854   SDValue RHS = N->getOperand(1);
19855   if (LHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
19856       RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT || LHS.getValueType() != VT)
19857     return SDValue();
19858 
19859   auto *LHSN1 = dyn_cast<ConstantSDNode>(LHS->getOperand(1));
19860   auto *RHSN1 = dyn_cast<ConstantSDNode>(RHS->getOperand(1));
19861   if (!LHSN1 || LHSN1 != RHSN1 || !RHSN1->isZero())
19862     return SDValue();
19863 
19864   SDValue Op1 = LHS->getOperand(0);
19865   SDValue Op2 = RHS->getOperand(0);
19866   EVT OpVT1 = Op1.getValueType();
19867   EVT OpVT2 = Op2.getValueType();
19868   if (Op1.getOpcode() != AArch64ISD::UADDV || OpVT1 != OpVT2 ||
19869       Op2.getOpcode() != AArch64ISD::UADDV ||
19870       OpVT1.getVectorElementType() != VT)
19871     return SDValue();
19872 
19873   SDValue Val1 = Op1.getOperand(0);
19874   SDValue Val2 = Op2.getOperand(0);
19875   EVT ValVT = Val1->getValueType(0);
19876   SDLoc DL(N);
19877   SDValue AddVal = DAG.getNode(ISD::ADD, DL, ValVT, Val1, Val2);
19878   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT,
19879                      DAG.getNode(AArch64ISD::UADDV, DL, ValVT, AddVal),
19880                      DAG.getConstant(0, DL, MVT::i64));
19881 }
19882 
19883 /// Perform the scalar expression combine in the form of:
19884 ///   CSEL(c, 1, cc) + b => CSINC(b+c, b, cc)
19885 ///   CSNEG(c, -1, cc) + b => CSINC(b+c, b, cc)
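/// Illustrative example (not from the original comment), taking c == 5:
///   CSEL(5, 1, cc) + b  ==>  cc ? b + 5 : b + 1  ==>  CSINC(b + 5, b, cc)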
19886 static SDValue performAddCSelIntoCSinc(SDNode *N, SelectionDAG &DAG) {
19887   EVT VT = N->getValueType(0);
19888   if (!VT.isScalarInteger() || N->getOpcode() != ISD::ADD)
19889     return SDValue();
19890 
19891   SDValue LHS = N->getOperand(0);
19892   SDValue RHS = N->getOperand(1);
19893 
19894   // Handle commutativity.
19895   if (LHS.getOpcode() != AArch64ISD::CSEL &&
19896       LHS.getOpcode() != AArch64ISD::CSNEG) {
19897     std::swap(LHS, RHS);
19898     if (LHS.getOpcode() != AArch64ISD::CSEL &&
19899         LHS.getOpcode() != AArch64ISD::CSNEG) {
19900       return SDValue();
19901     }
19902   }
19903 
19904   if (!LHS.hasOneUse())
19905     return SDValue();
19906 
19907   AArch64CC::CondCode AArch64CC =
19908       static_cast<AArch64CC::CondCode>(LHS.getConstantOperandVal(2));
19909 
19910   // The CSEL should include a const one operand, and the CSNEG should include
19911   // One or NegOne operand.
19912   ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(LHS.getOperand(0));
19913   ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(LHS.getOperand(1));
19914   if (!CTVal || !CFVal)
19915     return SDValue();
19916 
19917   if (!(LHS.getOpcode() == AArch64ISD::CSEL &&
19918         (CTVal->isOne() || CFVal->isOne())) &&
19919       !(LHS.getOpcode() == AArch64ISD::CSNEG &&
19920         (CTVal->isOne() || CFVal->isAllOnes())))
19921     return SDValue();
19922 
19923   // Switch CSEL(1, c, cc) to CSEL(c, 1, !cc)
19924   if (LHS.getOpcode() == AArch64ISD::CSEL && CTVal->isOne() &&
19925       !CFVal->isOne()) {
19926     std::swap(CTVal, CFVal);
19927     AArch64CC = AArch64CC::getInvertedCondCode(AArch64CC);
19928   }
19929 
19930   SDLoc DL(N);
19931   // Switch CSNEG(1, c, cc) to CSNEG(-c, -1, !cc)
19932   if (LHS.getOpcode() == AArch64ISD::CSNEG && CTVal->isOne() &&
19933       !CFVal->isAllOnes()) {
19934     APInt C = -1 * CFVal->getAPIntValue();
19935     CTVal = cast<ConstantSDNode>(DAG.getConstant(C, DL, VT));
19936     CFVal = cast<ConstantSDNode>(DAG.getAllOnesConstant(DL, VT));
19937     AArch64CC = AArch64CC::getInvertedCondCode(AArch64CC);
19938   }
19939 
19940   // It might be neutral for larger constants, as the immediate needs to be
19941   // materialized in a register.
19942   APInt ADDC = CTVal->getAPIntValue();
19943   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
19944   if (!TLI.isLegalAddImmediate(ADDC.getSExtValue()))
19945     return SDValue();
19946 
19947   assert(((LHS.getOpcode() == AArch64ISD::CSEL && CFVal->isOne()) ||
19948           (LHS.getOpcode() == AArch64ISD::CSNEG && CFVal->isAllOnes())) &&
19949          "Unexpected constant value");
19950 
19951   SDValue NewNode = DAG.getNode(ISD::ADD, DL, VT, RHS, SDValue(CTVal, 0));
19952   SDValue CCVal = DAG.getConstant(AArch64CC, DL, MVT::i32);
19953   SDValue Cmp = LHS.getOperand(3);
19954 
19955   return DAG.getNode(AArch64ISD::CSINC, DL, VT, NewNode, RHS, CCVal, Cmp);
19956 }
19957 
19958 // ADD(UDOT(zero, x, y), A) -->  UDOT(A, x, y)
19959 static SDValue performAddDotCombine(SDNode *N, SelectionDAG &DAG) {
19960   EVT VT = N->getValueType(0);
19961   if (N->getOpcode() != ISD::ADD)
19962     return SDValue();
19963 
19964   SDValue Dot = N->getOperand(0);
19965   SDValue A = N->getOperand(1);
19966   // Handle commutativity
19967   auto isZeroDot = [](SDValue Dot) {
19968     return (Dot.getOpcode() == AArch64ISD::UDOT ||
19969             Dot.getOpcode() == AArch64ISD::SDOT) &&
19970            isZerosVector(Dot.getOperand(0).getNode());
19971   };
19972   if (!isZeroDot(Dot))
19973     std::swap(Dot, A);
19974   if (!isZeroDot(Dot))
19975     return SDValue();
19976 
19977   return DAG.getNode(Dot.getOpcode(), SDLoc(N), VT, A, Dot.getOperand(1),
19978                      Dot.getOperand(2));
19979 }
19980 
19981 static bool isNegatedInteger(SDValue Op) {
19982   return Op.getOpcode() == ISD::SUB && isNullConstant(Op.getOperand(0));
19983 }
19984 
19985 static SDValue getNegatedInteger(SDValue Op, SelectionDAG &DAG) {
19986   SDLoc DL(Op);
19987   EVT VT = Op.getValueType();
19988   SDValue Zero = DAG.getConstant(0, DL, VT);
19989   return DAG.getNode(ISD::SUB, DL, VT, Zero, Op);
19990 }
19991 
19992 // Try to fold
19993 //
19994 // (neg (csel X, Y)) -> (csel (neg X), (neg Y))
19995 //
19996 // The folding helps csel to be matched with csneg without generating
19997 // redundant neg instruction, which includes negation of the csel expansion
19998 // of abs node lowered by lowerABS.
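// Illustrative example (not from the original comment):
//   (neg (csel X, (neg X), cc)) -> (csel (neg X), X, cc)
// which can then be selected as a single csneg instead of a csel plus a neg.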
19999 static SDValue performNegCSelCombine(SDNode *N, SelectionDAG &DAG) {
20000   if (!isNegatedInteger(SDValue(N, 0)))
20001     return SDValue();
20002 
20003   SDValue CSel = N->getOperand(1);
20004   if (CSel.getOpcode() != AArch64ISD::CSEL || !CSel->hasOneUse())
20005     return SDValue();
20006 
20007   SDValue N0 = CSel.getOperand(0);
20008   SDValue N1 = CSel.getOperand(1);
20009 
20010   // If neither of them is a negation, the fold is not worthwhile as it
20011   // would introduce two additional negations while removing only one.
20012   if (!isNegatedInteger(N0) && !isNegatedInteger(N1))
20013     return SDValue();
20014 
20015   SDValue N0N = getNegatedInteger(N0, DAG);
20016   SDValue N1N = getNegatedInteger(N1, DAG);
20017 
20018   SDLoc DL(N);
20019   EVT VT = CSel.getValueType();
20020   return DAG.getNode(AArch64ISD::CSEL, DL, VT, N0N, N1N, CSel.getOperand(2),
20021                      CSel.getOperand(3));
20022 }
20023 
20024 // The basic add/sub long vector instructions have variants with "2" on the end
20025 // which act on the high-half of their inputs. They are normally matched by
20026 // patterns like:
20027 //
20028 // (add (zeroext (extract_high LHS)),
20029 //      (zeroext (extract_high RHS)))
20030 // -> uaddl2 vD, vN, vM
20031 //
20032 // However, if one of the extracts is something like a duplicate, this
20033 // instruction can still be used profitably. This function puts the DAG into a
20034 // more appropriate form for those patterns to trigger.
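// Illustrative example (assumed, not from the original comment):
//   (add (zext (extract_high LHS)), (zext (dup64 s)))
// is rewritten so that the dup becomes (extract_high (dup128 s)), letting the
// whole expression match uaddl2.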
20035 static SDValue performAddSubLongCombine(SDNode *N,
20036                                         TargetLowering::DAGCombinerInfo &DCI) {
20037   SelectionDAG &DAG = DCI.DAG;
20038   if (DCI.isBeforeLegalizeOps())
20039     return SDValue();
20040 
20041   MVT VT = N->getSimpleValueType(0);
20042   if (!VT.is128BitVector()) {
20043     if (N->getOpcode() == ISD::ADD)
20044       return performSetccAddFolding(N, DAG);
20045     return SDValue();
20046   }
20047 
20048   // Make sure both branches are extended in the same way.
20049   SDValue LHS = N->getOperand(0);
20050   SDValue RHS = N->getOperand(1);
20051   if ((LHS.getOpcode() != ISD::ZERO_EXTEND &&
20052        LHS.getOpcode() != ISD::SIGN_EXTEND) ||
20053       LHS.getOpcode() != RHS.getOpcode())
20054     return SDValue();
20055 
20056   unsigned ExtType = LHS.getOpcode();
20057 
20058   // It's only worth doing if at least one of the inputs is already an
20059   // extract, but we don't know which it'll be, so we have to try both.
20060   if (isEssentiallyExtractHighSubvector(LHS.getOperand(0))) {
20061     RHS = tryExtendDUPToExtractHigh(RHS.getOperand(0), DAG);
20062     if (!RHS.getNode())
20063       return SDValue();
20064 
20065     RHS = DAG.getNode(ExtType, SDLoc(N), VT, RHS);
20066   } else if (isEssentiallyExtractHighSubvector(RHS.getOperand(0))) {
20067     LHS = tryExtendDUPToExtractHigh(LHS.getOperand(0), DAG);
20068     if (!LHS.getNode())
20069       return SDValue();
20070 
20071     LHS = DAG.getNode(ExtType, SDLoc(N), VT, LHS);
20072   }
20073 
20074   return DAG.getNode(N->getOpcode(), SDLoc(N), VT, LHS, RHS);
20075 }
20076 
20077 static bool isCMP(SDValue Op) {
20078   return Op.getOpcode() == AArch64ISD::SUBS &&
20079          !Op.getNode()->hasAnyUseOfValue(0);
20080 }
20081 
20082 // (CSEL 1 0 CC Cond) => CC
20083 // (CSEL 0 1 CC Cond) => !CC
20084 static std::optional<AArch64CC::CondCode> getCSETCondCode(SDValue Op) {
20085   if (Op.getOpcode() != AArch64ISD::CSEL)
20086     return std::nullopt;
20087   auto CC = static_cast<AArch64CC::CondCode>(Op.getConstantOperandVal(2));
20088   if (CC == AArch64CC::AL || CC == AArch64CC::NV)
20089     return std::nullopt;
20090   SDValue OpLHS = Op.getOperand(0);
20091   SDValue OpRHS = Op.getOperand(1);
20092   if (isOneConstant(OpLHS) && isNullConstant(OpRHS))
20093     return CC;
20094   if (isNullConstant(OpLHS) && isOneConstant(OpRHS))
20095     return getInvertedCondCode(CC);
20096 
20097   return std::nullopt;
20098 }
20099 
20100 // (ADC{S} l r (CMP (CSET HS carry) 1)) => (ADC{S} l r carry)
20101 // (SBC{S} l r (CMP 0 (CSET LO carry))) => (SBC{S} l r carry)
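// Illustrative note (an assumption, not from the original comment): these
// patterns typically arise from wide (e.g. i128) add/sub legalization, where
// the carry is round-tripped through a cset and a compare; feeding the
// original flags directly to ADC/SBC removes that round trip.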
20102 static SDValue foldOverflowCheck(SDNode *Op, SelectionDAG &DAG, bool IsAdd) {
20103   SDValue CmpOp = Op->getOperand(2);
20104   if (!isCMP(CmpOp))
20105     return SDValue();
20106 
20107   if (IsAdd) {
20108     if (!isOneConstant(CmpOp.getOperand(1)))
20109       return SDValue();
20110   } else {
20111     if (!isNullConstant(CmpOp.getOperand(0)))
20112       return SDValue();
20113   }
20114 
20115   SDValue CsetOp = CmpOp->getOperand(IsAdd ? 0 : 1);
20116   auto CC = getCSETCondCode(CsetOp);
20117   if (CC != (IsAdd ? AArch64CC::HS : AArch64CC::LO))
20118     return SDValue();
20119 
20120   return DAG.getNode(Op->getOpcode(), SDLoc(Op), Op->getVTList(),
20121                      Op->getOperand(0), Op->getOperand(1),
20122                      CsetOp.getOperand(3));
20123 }
20124 
20125 // (ADC x 0 cond) => (CINC x HS cond)
20126 static SDValue foldADCToCINC(SDNode *N, SelectionDAG &DAG) {
20127   SDValue LHS = N->getOperand(0);
20128   SDValue RHS = N->getOperand(1);
20129   SDValue Cond = N->getOperand(2);
20130 
20131   if (!isNullConstant(RHS))
20132     return SDValue();
20133 
20134   EVT VT = N->getValueType(0);
20135   SDLoc DL(N);
20136 
20137   // (CINC x cc cond) <=> (CSINC x x !cc cond)
20138   SDValue CC = DAG.getConstant(AArch64CC::LO, DL, MVT::i32);
20139   return DAG.getNode(AArch64ISD::CSINC, DL, VT, LHS, LHS, CC, Cond);
20140 }
20141 
20142 static SDValue performBuildVectorCombine(SDNode *N,
20143                                          TargetLowering::DAGCombinerInfo &DCI,
20144                                          SelectionDAG &DAG) {
20145   SDLoc DL(N);
20146   EVT VT = N->getValueType(0);
20147 
20148   if (DAG.getSubtarget<AArch64Subtarget>().isNeonAvailable() &&
20149       (VT == MVT::v4f16 || VT == MVT::v4bf16)) {
20150     SDValue Elt0 = N->getOperand(0), Elt1 = N->getOperand(1),
20151             Elt2 = N->getOperand(2), Elt3 = N->getOperand(3);
20152     if (Elt0->getOpcode() == ISD::FP_ROUND &&
20153         Elt1->getOpcode() == ISD::FP_ROUND &&
20154         isa<ConstantSDNode>(Elt0->getOperand(1)) &&
20155         isa<ConstantSDNode>(Elt1->getOperand(1)) &&
20156         Elt0->getConstantOperandVal(1) == Elt1->getConstantOperandVal(1) &&
20157         Elt0->getOperand(0)->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20158         Elt1->getOperand(0)->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20159         // Constant index.
20160         isa<ConstantSDNode>(Elt0->getOperand(0)->getOperand(1)) &&
20161         isa<ConstantSDNode>(Elt1->getOperand(0)->getOperand(1)) &&
20162         Elt0->getOperand(0)->getOperand(0) ==
20163             Elt1->getOperand(0)->getOperand(0) &&
20164         Elt0->getOperand(0)->getConstantOperandVal(1) == 0 &&
20165         Elt1->getOperand(0)->getConstantOperandVal(1) == 1) {
20166       SDValue LowLanesSrcVec = Elt0->getOperand(0)->getOperand(0);
20167       if (LowLanesSrcVec.getValueType() == MVT::v2f64) {
20168         SDValue HighLanes;
20169         if (Elt2->getOpcode() == ISD::UNDEF &&
20170             Elt3->getOpcode() == ISD::UNDEF) {
20171           HighLanes = DAG.getUNDEF(MVT::v2f32);
20172         } else if (Elt2->getOpcode() == ISD::FP_ROUND &&
20173                    Elt3->getOpcode() == ISD::FP_ROUND &&
20174                    isa<ConstantSDNode>(Elt2->getOperand(1)) &&
20175                    isa<ConstantSDNode>(Elt3->getOperand(1)) &&
20176                    Elt2->getConstantOperandVal(1) ==
20177                        Elt3->getConstantOperandVal(1) &&
20178                    Elt2->getOperand(0)->getOpcode() ==
20179                        ISD::EXTRACT_VECTOR_ELT &&
20180                    Elt3->getOperand(0)->getOpcode() ==
20181                        ISD::EXTRACT_VECTOR_ELT &&
20182                    // Constant index.
20183                    isa<ConstantSDNode>(Elt2->getOperand(0)->getOperand(1)) &&
20184                    isa<ConstantSDNode>(Elt3->getOperand(0)->getOperand(1)) &&
20185                    Elt2->getOperand(0)->getOperand(0) ==
20186                        Elt3->getOperand(0)->getOperand(0) &&
20187                    Elt2->getOperand(0)->getConstantOperandVal(1) == 0 &&
20188                    Elt3->getOperand(0)->getConstantOperandVal(1) == 1) {
20189           SDValue HighLanesSrcVec = Elt2->getOperand(0)->getOperand(0);
20190           HighLanes =
20191               DAG.getNode(AArch64ISD::FCVTXN, DL, MVT::v2f32, HighLanesSrcVec);
20192         }
20193         if (HighLanes) {
20194           SDValue DoubleToSingleSticky =
20195               DAG.getNode(AArch64ISD::FCVTXN, DL, MVT::v2f32, LowLanesSrcVec);
20196           SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v4f32,
20197                                        DoubleToSingleSticky, HighLanes);
20198           return DAG.getNode(ISD::FP_ROUND, DL, VT, Concat,
20199                              Elt0->getOperand(1));
20200         }
20201       }
20202     }
20203   }
20204 
20205   if (VT == MVT::v2f64) {
20206     SDValue Elt0 = N->getOperand(0), Elt1 = N->getOperand(1);
20207     if (Elt0->getOpcode() == ISD::FP_EXTEND &&
20208         Elt1->getOpcode() == ISD::FP_EXTEND &&
20209         Elt0->getOperand(0)->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20210         Elt1->getOperand(0)->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20211         Elt0->getOperand(0)->getOperand(0) ==
20212             Elt1->getOperand(0)->getOperand(0) &&
20213         // Constant index.
20214         isa<ConstantSDNode>(Elt0->getOperand(0)->getOperand(1)) &&
20215         isa<ConstantSDNode>(Elt1->getOperand(0)->getOperand(1)) &&
20216         Elt0->getOperand(0)->getConstantOperandVal(1) + 1 ==
20217             Elt1->getOperand(0)->getConstantOperandVal(1) &&
20218         // EXTRACT_SUBVECTOR requires that Idx be a constant multiple of
20219         // ResultType's known minimum vector length.
20220         Elt0->getOperand(0)->getConstantOperandVal(1) %
20221                 VT.getVectorMinNumElements() ==
20222             0) {
20223       SDValue SrcVec = Elt0->getOperand(0)->getOperand(0);
20224       if (SrcVec.getValueType() == MVT::v4f16 ||
20225           SrcVec.getValueType() == MVT::v4bf16) {
20226         SDValue HalfToSingle =
20227             DAG.getNode(ISD::FP_EXTEND, DL, MVT::v4f32, SrcVec);
20228         SDValue SubvectorIdx = Elt0->getOperand(0)->getOperand(1);
20229         SDValue Extract = DAG.getNode(
20230             ISD::EXTRACT_SUBVECTOR, DL, VT.changeVectorElementType(MVT::f32),
20231             HalfToSingle, SubvectorIdx);
20232         return DAG.getNode(ISD::FP_EXTEND, DL, VT, Extract);
20233       }
20234     }
20235   }
20236 
20237   // A build vector of two extracted elements is equivalent to an
20238   // extract subvector where the inner vector is any-extended to the
20239   // extract_vector_elt VT.
20240   //    (build_vector (extract_elt_iXX_to_i32 vec Idx+0)
20241   //                  (extract_elt_iXX_to_i32 vec Idx+1))
20242   // => (extract_subvector (anyext_iXX_to_i32 vec) Idx)
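  // Illustrative example (hypothetical types, not from the original comment):
  //   (build_vector (extract_elt v4i16 V, 2), (extract_elt v4i16 V, 3))
  //   => (extract_subvector (any_extend V to v4i32), 2)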
20243 
20244   // For now, only consider the v2i32 case, which arises as a result of
20245   // legalization.
20246   if (VT != MVT::v2i32)
20247     return SDValue();
20248 
20249   SDValue Elt0 = N->getOperand(0), Elt1 = N->getOperand(1);
20250   // Reminder, EXTRACT_VECTOR_ELT has the effect of any-extending to its VT.
20251   if (Elt0->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20252       Elt1->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20253       // Constant index.
20254       isa<ConstantSDNode>(Elt0->getOperand(1)) &&
20255       isa<ConstantSDNode>(Elt1->getOperand(1)) &&
20256       // Both EXTRACT_VECTOR_ELT from same vector...
20257       Elt0->getOperand(0) == Elt1->getOperand(0) &&
20258       // ... and contiguous. First element's index +1 == second element's index.
20259       Elt0->getConstantOperandVal(1) + 1 == Elt1->getConstantOperandVal(1) &&
20260       // EXTRACT_SUBVECTOR requires that Idx be a constant multiple of
20261       // ResultType's known minimum vector length.
20262       Elt0->getConstantOperandVal(1) % VT.getVectorMinNumElements() == 0) {
20263     SDValue VecToExtend = Elt0->getOperand(0);
20264     EVT ExtVT = VecToExtend.getValueType().changeVectorElementType(MVT::i32);
20265     if (!DAG.getTargetLoweringInfo().isTypeLegal(ExtVT))
20266       return SDValue();
20267 
20268     SDValue SubvectorIdx = DAG.getVectorIdxConstant(Elt0->getConstantOperandVal(1), DL);
20269 
20270     SDValue Ext = DAG.getNode(ISD::ANY_EXTEND, DL, ExtVT, VecToExtend);
20271     return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v2i32, Ext,
20272                        SubvectorIdx);
20273   }
20274 
20275   return SDValue();
20276 }
20277 
20278 static SDValue performTruncateCombine(SDNode *N,
20279                                       SelectionDAG &DAG) {
20280   EVT VT = N->getValueType(0);
20281   SDValue N0 = N->getOperand(0);
20282   if (VT.isFixedLengthVector() && VT.is64BitVector() && N0.hasOneUse() &&
20283       N0.getOpcode() == AArch64ISD::DUP) {
20284     SDValue Op = N0.getOperand(0);
20285     if (VT.getScalarType() == MVT::i32 &&
20286         N0.getOperand(0).getValueType().getScalarType() == MVT::i64)
20287       Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N), MVT::i32, Op);
20288     return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, Op);
20289   }
20290 
20291   return SDValue();
20292 }
20293 
20294 // Check whether a node is an extend or shift operand.
20295 static bool isExtendOrShiftOperand(SDValue N) {
20296   unsigned Opcode = N.getOpcode();
20297   if (ISD::isExtOpcode(Opcode) || Opcode == ISD::SIGN_EXTEND_INREG) {
20298     EVT SrcVT;
20299     if (Opcode == ISD::SIGN_EXTEND_INREG)
20300       SrcVT = cast<VTSDNode>(N.getOperand(1))->getVT();
20301     else
20302       SrcVT = N.getOperand(0).getValueType();
20303 
20304     return SrcVT == MVT::i32 || SrcVT == MVT::i16 || SrcVT == MVT::i8;
20305   } else if (Opcode == ISD::AND) {
20306     ConstantSDNode *CSD = dyn_cast<ConstantSDNode>(N.getOperand(1));
20307     if (!CSD)
20308       return false;
20309     uint64_t AndMask = CSD->getZExtValue();
20310     return AndMask == 0xff || AndMask == 0xffff || AndMask == 0xffffffff;
20311   } else if (Opcode == ISD::SHL || Opcode == ISD::SRL || Opcode == ISD::SRA) {
20312     return isa<ConstantSDNode>(N.getOperand(1));
20313   }
20314 
20315   return false;
20316 }
20317 
20318 // (N - Y) + Z --> (Z - Y) + N
20319 // when N is an extend or shift operand
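// Illustrative example (assumed, not from the original comment):
//   ((zext a) - y) + z  -->  (z - y) + (zext a)
// so the extend can fold into the final add as an extended-register operand,
// e.g. "add x0, x1, w8, uxtw".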
20320 static SDValue performAddCombineSubShift(SDNode *N, SDValue SUB, SDValue Z,
20321                                          SelectionDAG &DAG) {
20322   auto IsOneUseExtend = [](SDValue N) {
20323     return N.hasOneUse() && isExtendOrShiftOperand(N);
20324   };
20325 
20326   // If Z is a constant, DAGCombiner will revert the combination, causing an
20327   // infinite loop, so don't perform the combination when Z is constant.
20328   // Similarly, if Z is a one-use extend or shift, the optimization would also
20329   // fall into an infinite loop, so bail out in that case too.
20330   if (isa<ConstantSDNode>(Z) || IsOneUseExtend(Z))
20331     return SDValue();
20332 
20333   if (SUB.getOpcode() != ISD::SUB || !SUB.hasOneUse())
20334     return SDValue();
20335 
20336   SDValue Shift = SUB.getOperand(0);
20337   if (!IsOneUseExtend(Shift))
20338     return SDValue();
20339 
20340   SDLoc DL(N);
20341   EVT VT = N->getValueType(0);
20342 
20343   SDValue Y = SUB.getOperand(1);
20344   SDValue NewSub = DAG.getNode(ISD::SUB, DL, VT, Z, Y);
20345   return DAG.getNode(ISD::ADD, DL, VT, NewSub, Shift);
20346 }
20347 
20348 static SDValue performAddCombineForShiftedOperands(SDNode *N,
20349                                                    SelectionDAG &DAG) {
20350   // NOTE: Swapping LHS and RHS is not done for SUB, since SUB is not
20351   // commutative.
20352   if (N->getOpcode() != ISD::ADD)
20353     return SDValue();
20354 
20355   // Bail out when value type is not one of {i32, i64}, since AArch64 ADD with
20356   // shifted register is only available for i32 and i64.
20357   EVT VT = N->getValueType(0);
20358   if (VT != MVT::i32 && VT != MVT::i64)
20359     return SDValue();
20360 
20361   SDLoc DL(N);
20362   SDValue LHS = N->getOperand(0);
20363   SDValue RHS = N->getOperand(1);
20364 
20365   if (SDValue Val = performAddCombineSubShift(N, LHS, RHS, DAG))
20366     return Val;
20367   if (SDValue Val = performAddCombineSubShift(N, RHS, LHS, DAG))
20368     return Val;
20369 
20370   uint64_t LHSImm = 0, RHSImm = 0;
20371   // If both operands are shifted by an immediate and the shift amount is not
20372   // greater than 4 for one operand, swap LHS and RHS to put the operand with
20373   // the smaller shift amount on the RHS.
20374   //
20375   // On many AArch64 processors (Cortex A78, Neoverse N1/N2/V1, etc), ADD with
20376   // LSL shift (shift <= 4) has smaller latency and larger throughput than ADD
20377   // with LSL (shift > 4). For the remaining processors, this is a no-op for
20378   // both performance and correctness.
20379   if (isOpcWithIntImmediate(LHS.getNode(), ISD::SHL, LHSImm) &&
20380       isOpcWithIntImmediate(RHS.getNode(), ISD::SHL, RHSImm) && LHSImm <= 4 &&
20381       RHSImm > 4 && LHS.hasOneUse())
20382     return DAG.getNode(ISD::ADD, DL, VT, RHS, LHS);
20383 
20384   return SDValue();
20385 }
20386 
20387 // The mid end will reassociate sub(sub(x, m1), m2) to sub(x, add(m1, m2))
20388 // This reassociates it back to allow the creation of more mls instructions.
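// Illustrative example (not from the original comment):
//   sub(x, add(mul(a, b), mul(c, d))) --> sub(sub(x, mul(a, b)), mul(c, d))
// so that each subtract-of-multiply can be selected as an mls.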
20389 static SDValue performSubAddMULCombine(SDNode *N, SelectionDAG &DAG) {
20390   if (N->getOpcode() != ISD::SUB)
20391     return SDValue();
20392 
20393   SDValue Add = N->getOperand(1);
20394   SDValue X = N->getOperand(0);
20395   if (Add.getOpcode() != ISD::ADD)
20396     return SDValue();
20397 
20398   if (!Add.hasOneUse())
20399     return SDValue();
20400   if (DAG.isConstantIntBuildVectorOrConstantInt(peekThroughBitcasts(X)))
20401     return SDValue();
20402 
20403   SDValue M1 = Add.getOperand(0);
20404   SDValue M2 = Add.getOperand(1);
20405   if (M1.getOpcode() != ISD::MUL && M1.getOpcode() != AArch64ISD::SMULL &&
20406       M1.getOpcode() != AArch64ISD::UMULL)
20407     return SDValue();
20408   if (M2.getOpcode() != ISD::MUL && M2.getOpcode() != AArch64ISD::SMULL &&
20409       M2.getOpcode() != AArch64ISD::UMULL)
20410     return SDValue();
20411 
20412   EVT VT = N->getValueType(0);
20413   SDValue Sub = DAG.getNode(ISD::SUB, SDLoc(N), VT, X, M1);
20414   return DAG.getNode(ISD::SUB, SDLoc(N), VT, Sub, M2);
20415 }
20416 
20417 // Combine into mla/mls.
20418 // This works on the patterns of:
20419 //   add v1, (mul v2, v3)
20420 //   sub v1, (mul v2, v3)
20421 // for vectors of type <1 x i64> and <2 x i64> when SVE is available.
20422 // It will transform the add/sub to a scalable version, so that we can
20423 // make use of SVE's MLA/MLS that will be generated for that pattern
20424 static SDValue
20425 performSVEMulAddSubCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
20426   SelectionDAG &DAG = DCI.DAG;
20427   // Make sure that the types are legal
20428   if (!DCI.isAfterLegalizeDAG())
20429     return SDValue();
20430   // Before using SVE's features, check first if it's available.
20431   if (!DAG.getSubtarget<AArch64Subtarget>().hasSVE())
20432     return SDValue();
20433 
20434   if (N->getOpcode() != ISD::ADD && N->getOpcode() != ISD::SUB)
20435     return SDValue();
20436 
20437   if (!N->getValueType(0).isFixedLengthVector())
20438     return SDValue();
20439 
20440   auto performOpt = [&DAG, &N](SDValue Op0, SDValue Op1) -> SDValue {
20441     if (Op1.getOpcode() != ISD::EXTRACT_SUBVECTOR)
20442       return SDValue();
20443 
20444     if (!cast<ConstantSDNode>(Op1->getOperand(1))->isZero())
20445       return SDValue();
20446 
20447     SDValue MulValue = Op1->getOperand(0);
20448     if (MulValue.getOpcode() != AArch64ISD::MUL_PRED)
20449       return SDValue();
20450 
20451     if (!Op1.hasOneUse() || !MulValue.hasOneUse())
20452       return SDValue();
20453 
20454     EVT ScalableVT = MulValue.getValueType();
20455     if (!ScalableVT.isScalableVector())
20456       return SDValue();
20457 
20458     SDValue ScaledOp = convertToScalableVector(DAG, ScalableVT, Op0);
20459     SDValue NewValue =
20460         DAG.getNode(N->getOpcode(), SDLoc(N), ScalableVT, {ScaledOp, MulValue});
20461     return convertFromScalableVector(DAG, N->getValueType(0), NewValue);
20462   };
20463 
20464   if (SDValue res = performOpt(N->getOperand(0), N->getOperand(1)))
20465     return res;
20466   else if (N->getOpcode() == ISD::ADD)
20467     return performOpt(N->getOperand(1), N->getOperand(0));
20468 
20469   return SDValue();
20470 }
20471 
20472 // Given an i64 add from a v1i64 extract, convert to a neon v1i64 add. This can
20473 // help, for example, to produce ssra from sshr+add.
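// Illustrative example (assumed, not from the original comment): an i64
//   add (load p), (extract_elt (sshr v1i64 V, #3), 0)
// is performed as a v1i64 vector add instead, after which the shift and add
// can be selected together as an ssra.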
20474 static SDValue performAddSubIntoVectorOp(SDNode *N, SelectionDAG &DAG) {
20475   EVT VT = N->getValueType(0);
20476   if (VT != MVT::i64 ||
20477       DAG.getTargetLoweringInfo().isOperationExpand(N->getOpcode(), MVT::v1i64))
20478     return SDValue();
20479   SDValue Op0 = N->getOperand(0);
20480   SDValue Op1 = N->getOperand(1);
20481 
20482   // At least one of the operands should be an extract, and the other should be
20483   // something that is easy to convert to v1i64 type (in this case a load).
20484   if (Op0.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
20485       Op0.getOpcode() != ISD::LOAD)
20486     return SDValue();
20487   if (Op1.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
20488       Op1.getOpcode() != ISD::LOAD)
20489     return SDValue();
20490 
20491   SDLoc DL(N);
20492   if (Op0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20493       Op0.getOperand(0).getValueType() == MVT::v1i64) {
20494     Op0 = Op0.getOperand(0);
20495     Op1 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v1i64, Op1);
20496   } else if (Op1.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20497              Op1.getOperand(0).getValueType() == MVT::v1i64) {
20498     Op0 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v1i64, Op0);
20499     Op1 = Op1.getOperand(0);
20500   } else
20501     return SDValue();
20502 
20503   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64,
20504                      DAG.getNode(N->getOpcode(), DL, MVT::v1i64, Op0, Op1),
20505                      DAG.getConstant(0, DL, MVT::i64));
20506 }
20507 
20508 static bool isLoadOrMultipleLoads(SDValue B, SmallVector<LoadSDNode *> &Loads) {
20509   SDValue BV = peekThroughOneUseBitcasts(B);
20510   if (!BV->hasOneUse())
20511     return false;
20512   if (auto *Ld = dyn_cast<LoadSDNode>(BV)) {
20513     if (!Ld || !Ld->isSimple())
20514       return false;
20515     Loads.push_back(Ld);
20516     return true;
20517   } else if (BV.getOpcode() == ISD::BUILD_VECTOR ||
20518              BV.getOpcode() == ISD::CONCAT_VECTORS) {
20519     for (unsigned Op = 0; Op < BV.getNumOperands(); Op++) {
20520       auto *Ld = dyn_cast<LoadSDNode>(BV.getOperand(Op));
20521       if (!Ld || !Ld->isSimple() || !BV.getOperand(Op).hasOneUse())
20522         return false;
20523       Loads.push_back(Ld);
20524     }
20525     return true;
20526   } else if (B.getOpcode() == ISD::VECTOR_SHUFFLE) {
20527     // Try to find a tree of shuffles and concats from how IR shuffles of loads
20528     // are lowered. Note that this only comes up because we do not always visit
20529     // operands before uses. Once that is fixed this can be removed; in the
20530     // meantime this is fairly specific to the lowering we expect from IR.
20531     // t46: v16i8 = vector_shuffle<0,1,2,3,4,5,6,7,8,9,10,11,16,17,18,19> t44, t45
20532     //   t44: v16i8 = vector_shuffle<0,1,2,3,4,5,6,7,16,17,18,19,u,u,u,u> t42, t43
20533     //     t42: v16i8 = concat_vectors t40, t36, undef:v4i8, undef:v4i8
20534     //       t40: v4i8,ch = load<(load (s32) from %ir.17)> t0, t22, undef:i64
20535     //       t36: v4i8,ch = load<(load (s32) from %ir.13)> t0, t18, undef:i64
20536     //     t43: v16i8 = concat_vectors t32, undef:v4i8, undef:v4i8, undef:v4i8
20537     //       t32: v4i8,ch = load<(load (s32) from %ir.9)> t0, t14, undef:i64
20538     //   t45: v16i8 = concat_vectors t28, undef:v4i8, undef:v4i8, undef:v4i8
20539     //     t28: v4i8,ch = load<(load (s32) from %ir.0)> t0, t2, undef:i64
20540     if (B.getOperand(0).getOpcode() != ISD::VECTOR_SHUFFLE ||
20541         B.getOperand(0).getOperand(0).getOpcode() != ISD::CONCAT_VECTORS ||
20542         B.getOperand(0).getOperand(1).getOpcode() != ISD::CONCAT_VECTORS ||
20543         B.getOperand(1).getOpcode() != ISD::CONCAT_VECTORS ||
20544         B.getOperand(1).getNumOperands() != 4)
20545       return false;
20546     auto SV1 = cast<ShuffleVectorSDNode>(B);
20547     auto SV2 = cast<ShuffleVectorSDNode>(B.getOperand(0));
20548     int NumElts = B.getValueType().getVectorNumElements();
20549     int NumSubElts = NumElts / 4;
20550     for (int I = 0; I < NumSubElts; I++) {
20551       // <0,1,2,3,4,5,6,7,8,9,10,11,16,17,18,19>
20552       if (SV1->getMaskElt(I) != I ||
20553           SV1->getMaskElt(I + NumSubElts) != I + NumSubElts ||
20554           SV1->getMaskElt(I + NumSubElts * 2) != I + NumSubElts * 2 ||
20555           SV1->getMaskElt(I + NumSubElts * 3) != I + NumElts)
20556         return false;
20557       // <0,1,2,3,4,5,6,7,16,17,18,19,u,u,u,u>
20558       if (SV2->getMaskElt(I) != I ||
20559           SV2->getMaskElt(I + NumSubElts) != I + NumSubElts ||
20560           SV2->getMaskElt(I + NumSubElts * 2) != I + NumElts)
20561         return false;
20562     }
20563     auto *Ld0 = dyn_cast<LoadSDNode>(SV2->getOperand(0).getOperand(0));
20564     auto *Ld1 = dyn_cast<LoadSDNode>(SV2->getOperand(0).getOperand(1));
20565     auto *Ld2 = dyn_cast<LoadSDNode>(SV2->getOperand(1).getOperand(0));
20566     auto *Ld3 = dyn_cast<LoadSDNode>(B.getOperand(1).getOperand(0));
20567     if (!Ld0 || !Ld1 || !Ld2 || !Ld3 || !Ld0->isSimple() || !Ld1->isSimple() ||
20568         !Ld2->isSimple() || !Ld3->isSimple())
20569       return false;
20570     Loads.push_back(Ld0);
20571     Loads.push_back(Ld1);
20572     Loads.push_back(Ld2);
20573     Loads.push_back(Ld3);
20574     return true;
20575   }
20576   return false;
20577 }
20578 
20579 static bool areLoadedOffsetButOtherwiseSame(SDValue Op0, SDValue Op1,
20580                                             SelectionDAG &DAG,
20581                                             unsigned &NumSubLoads) {
20582   if (!Op0.hasOneUse() || !Op1.hasOneUse())
20583     return false;
20584 
20585   SmallVector<LoadSDNode *> Loads0, Loads1;
20586   if (isLoadOrMultipleLoads(Op0, Loads0) &&
20587       isLoadOrMultipleLoads(Op1, Loads1)) {
20588     if (NumSubLoads && Loads0.size() != NumSubLoads)
20589       return false;
20590     NumSubLoads = Loads0.size();
20591     return Loads0.size() == Loads1.size() &&
20592            all_of(zip(Loads0, Loads1), [&DAG](auto L) {
20593              unsigned Size = get<0>(L)->getValueType(0).getSizeInBits();
20594              return Size == get<1>(L)->getValueType(0).getSizeInBits() &&
20595                     DAG.areNonVolatileConsecutiveLoads(get<1>(L), get<0>(L),
20596                                                        Size / 8, 1);
20597            });
20598   }
20599 
20600   if (Op0.getOpcode() != Op1.getOpcode())
20601     return false;
20602 
20603   switch (Op0.getOpcode()) {
20604   case ISD::ADD:
20605   case ISD::SUB:
20606     return areLoadedOffsetButOtherwiseSame(Op0.getOperand(0), Op1.getOperand(0),
20607                                            DAG, NumSubLoads) &&
20608            areLoadedOffsetButOtherwiseSame(Op0.getOperand(1), Op1.getOperand(1),
20609                                            DAG, NumSubLoads);
20610   case ISD::SIGN_EXTEND:
20611   case ISD::ANY_EXTEND:
20612   case ISD::ZERO_EXTEND:
20613     EVT XVT = Op0.getOperand(0).getValueType();
20614     if (XVT.getScalarSizeInBits() != 8 && XVT.getScalarSizeInBits() != 16 &&
20615         XVT.getScalarSizeInBits() != 32)
20616       return false;
20617     return areLoadedOffsetButOtherwiseSame(Op0.getOperand(0), Op1.getOperand(0),
20618                                            DAG, NumSubLoads);
20619   }
20620   return false;
20621 }
20622 
20623 // This method attempts to fold trees of add(ext(load p), shl(ext(load p+4)))
20624 // into a single load of twice the size, from which we extract the bottom and
20625 // top parts so that the shl can use a shll2 instruction. The two loads in that
20626 // example can also be larger trees of instructions, which are identical except
20627 // for the leaves which are all loads offset from the LHS, including
20628 // buildvectors of multiple loads. For example the RHS tree could be
20629 // sub(zext(buildvec(load p+4, load q+4)), zext(buildvec(load r+4, load s+4)))
20630 // Whilst it can be common for the larger loads to replace LDP instructions
20631 // (which doesn't gain anything on its own), the larger loads can help create
20632 // more efficient code, and in buildvectors prevent the need for ld1 lane
20633 // inserts which can be slower than normal loads.
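// Illustrative example (hypothetical, not from the original comment):
//   add(zext(v8i8 load p), shl(zext(v8i8 load p+8), splat(8)))
// can be rebuilt around a single v16i8 load of p, extracting its low and high
// halves so that the shifted half can use ushll2.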
20634 static SDValue performExtBinopLoadFold(SDNode *N, SelectionDAG &DAG) {
20635   EVT VT = N->getValueType(0);
20636   if (!VT.isFixedLengthVector() ||
20637       (VT.getScalarSizeInBits() != 16 && VT.getScalarSizeInBits() != 32 &&
20638        VT.getScalarSizeInBits() != 64))
20639     return SDValue();
20640 
20641   SDValue Other = N->getOperand(0);
20642   SDValue Shift = N->getOperand(1);
20643   if (Shift.getOpcode() != ISD::SHL && N->getOpcode() != ISD::SUB)
20644     std::swap(Shift, Other);
20645   APInt ShiftAmt;
20646   if (Shift.getOpcode() != ISD::SHL || !Shift.hasOneUse() ||
20647       !ISD::isConstantSplatVector(Shift.getOperand(1).getNode(), ShiftAmt))
20648     return SDValue();
20649 
20650   if (!ISD::isExtOpcode(Shift.getOperand(0).getOpcode()) ||
20651       !ISD::isExtOpcode(Other.getOpcode()) ||
20652       Shift.getOperand(0).getOperand(0).getValueType() !=
20653           Other.getOperand(0).getValueType() ||
20654       !Other.hasOneUse() || !Shift.getOperand(0).hasOneUse())
20655     return SDValue();
20656 
20657   SDValue Op0 = Other.getOperand(0);
20658   SDValue Op1 = Shift.getOperand(0).getOperand(0);
20659 
20660   unsigned NumSubLoads = 0;
20661   if (!areLoadedOffsetButOtherwiseSame(Op0, Op1, DAG, NumSubLoads))
20662     return SDValue();
20663 
20664   // Attempt to rule out some unprofitable cases using heuristics (some working
20665   // around suboptimal code generation), notably if the extend not be able to
20666   // use ushll2 instructions as the types are not large enough. Otherwise zip's
20667   // will need to be created which can increase the instruction count.
20668   unsigned NumElts = Op0.getValueType().getVectorNumElements();
20669   unsigned NumSubElts = NumElts / NumSubLoads;
20670   if (NumSubElts * VT.getScalarSizeInBits() < 128 ||
20671       (Other.getOpcode() != Shift.getOperand(0).getOpcode() &&
20672        Op0.getValueType().getSizeInBits() < 128 &&
20673        !DAG.getTargetLoweringInfo().isTypeLegal(Op0.getValueType())))
20674     return SDValue();
20675 
20676   // Recreate the tree with the new combined loads.
20677   std::function<SDValue(SDValue, SDValue, SelectionDAG &)> GenCombinedTree =
20678       [&GenCombinedTree](SDValue Op0, SDValue Op1, SelectionDAG &DAG) {
20679         EVT DVT =
20680             Op0.getValueType().getDoubleNumVectorElementsVT(*DAG.getContext());
20681 
20682         SmallVector<LoadSDNode *> Loads0, Loads1;
20683         if (isLoadOrMultipleLoads(Op0, Loads0) &&
20684             isLoadOrMultipleLoads(Op1, Loads1)) {
20685           EVT LoadVT = EVT::getVectorVT(
20686               *DAG.getContext(), Op0.getValueType().getScalarType(),
20687               Op0.getValueType().getVectorNumElements() / Loads0.size());
20688           EVT DLoadVT = LoadVT.getDoubleNumVectorElementsVT(*DAG.getContext());
20689 
20690           SmallVector<SDValue> NewLoads;
20691           for (const auto &[L0, L1] : zip(Loads0, Loads1)) {
20692             SDValue Load = DAG.getLoad(DLoadVT, SDLoc(L0), L0->getChain(),
20693                                        L0->getBasePtr(), L0->getPointerInfo(),
20694                                        L0->getOriginalAlign());
20695             DAG.makeEquivalentMemoryOrdering(L0, Load.getValue(1));
20696             DAG.makeEquivalentMemoryOrdering(L1, Load.getValue(1));
20697             NewLoads.push_back(Load);
20698           }
20699           return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(Op0), DVT, NewLoads);
20700         }
20701 
20702         SmallVector<SDValue> Ops;
20703         for (const auto &[O0, O1] : zip(Op0->op_values(), Op1->op_values()))
20704           Ops.push_back(GenCombinedTree(O0, O1, DAG));
20705         return DAG.getNode(Op0.getOpcode(), SDLoc(Op0), DVT, Ops);
20706       };
20707   SDValue NewOp = GenCombinedTree(Op0, Op1, DAG);
20708 
20709   SmallVector<int> LowMask(NumElts, 0), HighMask(NumElts, 0);
20710   int Hi = NumSubElts, Lo = 0;
20711   for (unsigned i = 0; i < NumSubLoads; i++) {
20712     for (unsigned j = 0; j < NumSubElts; j++) {
20713       LowMask[i * NumSubElts + j] = Lo++;
20714       HighMask[i * NumSubElts + j] = Hi++;
20715     }
20716     Lo += NumSubElts;
20717     Hi += NumSubElts;
20718   }
20719   SDLoc DL(N);
20720   SDValue Ext0, Ext1;
20721   // Extract the top and bottom lanes, then extend the result. Possibly extend
20722   // the result and then extract the lanes if the two operands match, as that
20723   // produces slightly smaller code.
20724   if (Other.getOpcode() != Shift.getOperand(0).getOpcode()) {
20725     SDValue SubL = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Op0.getValueType(),
20726                                NewOp, DAG.getConstant(0, DL, MVT::i64));
20727     SDValue SubH =
20728         DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Op0.getValueType(), NewOp,
20729                     DAG.getConstant(NumSubElts * NumSubLoads, DL, MVT::i64));
20730     SDValue Extr0 =
20731         DAG.getVectorShuffle(Op0.getValueType(), DL, SubL, SubH, LowMask);
20732     SDValue Extr1 =
20733         DAG.getVectorShuffle(Op0.getValueType(), DL, SubL, SubH, HighMask);
20734     Ext0 = DAG.getNode(Other.getOpcode(), DL, VT, Extr0);
20735     Ext1 = DAG.getNode(Shift.getOperand(0).getOpcode(), DL, VT, Extr1);
20736   } else {
20737     EVT DVT = VT.getDoubleNumVectorElementsVT(*DAG.getContext());
20738     SDValue Ext = DAG.getNode(Other.getOpcode(), DL, DVT, NewOp);
20739     SDValue SubL = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Ext,
20740                                DAG.getConstant(0, DL, MVT::i64));
20741     SDValue SubH =
20742         DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Ext,
20743                     DAG.getConstant(NumSubElts * NumSubLoads, DL, MVT::i64));
20744     Ext0 = DAG.getVectorShuffle(VT, DL, SubL, SubH, LowMask);
20745     Ext1 = DAG.getVectorShuffle(VT, DL, SubL, SubH, HighMask);
20746   }
20747   SDValue NShift =
20748       DAG.getNode(Shift.getOpcode(), DL, VT, Ext1, Shift.getOperand(1));
20749   return DAG.getNode(N->getOpcode(), DL, VT, Ext0, NShift);
20750 }
20751 
20752 static SDValue performAddSubCombine(SDNode *N,
20753                                     TargetLowering::DAGCombinerInfo &DCI) {
20754   // Try to change sum of two reductions.
20755   if (SDValue Val = performAddUADDVCombine(N, DCI.DAG))
20756     return Val;
20757   if (SDValue Val = performAddDotCombine(N, DCI.DAG))
20758     return Val;
20759   if (SDValue Val = performAddCSelIntoCSinc(N, DCI.DAG))
20760     return Val;
20761   if (SDValue Val = performNegCSelCombine(N, DCI.DAG))
20762     return Val;
20763   if (SDValue Val = performVectorExtCombine(N, DCI.DAG))
20764     return Val;
20765   if (SDValue Val = performAddCombineForShiftedOperands(N, DCI.DAG))
20766     return Val;
20767   if (SDValue Val = performSubAddMULCombine(N, DCI.DAG))
20768     return Val;
20769   if (SDValue Val = performSVEMulAddSubCombine(N, DCI))
20770     return Val;
20771   if (SDValue Val = performAddSubIntoVectorOp(N, DCI.DAG))
20772     return Val;
20773 
20774   if (SDValue Val = performExtBinopLoadFold(N, DCI.DAG))
20775     return Val;
20776 
20777   return performAddSubLongCombine(N, DCI);
20778 }
20779 
20780 // Massage DAGs which we can use the high-half "long" operations on into
20781 // something isel will recognize better. E.g.
20782 //
20783 // (aarch64_neon_umull (extract_high vec) (dupv64 scalar)) -->
20784 //   (aarch64_neon_umull (extract_high (v2i64 vec))
20785 //                       (extract_high (v2i64 (dup128 scalar))))
20786 //
20787 static SDValue tryCombineLongOpWithDup(unsigned IID, SDNode *N,
20788                                        TargetLowering::DAGCombinerInfo &DCI,
20789                                        SelectionDAG &DAG) {
20790   if (DCI.isBeforeLegalizeOps())
20791     return SDValue();
20792 
20793   SDValue LHS = N->getOperand((IID == Intrinsic::not_intrinsic) ? 0 : 1);
20794   SDValue RHS = N->getOperand((IID == Intrinsic::not_intrinsic) ? 1 : 2);
20795   assert(LHS.getValueType().is64BitVector() &&
20796          RHS.getValueType().is64BitVector() &&
20797          "unexpected shape for long operation");
20798 
20799   // Either node could be a DUP, but it's not worth doing both of them (you'd
20800   // just as well use the non-high version) so look for a corresponding extract
20801   // operation on the other "wing".
20802   if (isEssentiallyExtractHighSubvector(LHS)) {
20803     RHS = tryExtendDUPToExtractHigh(RHS, DAG);
20804     if (!RHS.getNode())
20805       return SDValue();
20806   } else if (isEssentiallyExtractHighSubvector(RHS)) {
20807     LHS = tryExtendDUPToExtractHigh(LHS, DAG);
20808     if (!LHS.getNode())
20809       return SDValue();
20810   } else
20811     return SDValue();
20812 
20813   if (IID == Intrinsic::not_intrinsic)
20814     return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), LHS, RHS);
20815 
20816   return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SDLoc(N), N->getValueType(0),
20817                      N->getOperand(0), LHS, RHS);
20818 }
20819 
20820 static SDValue tryCombineShiftImm(unsigned IID, SDNode *N, SelectionDAG &DAG) {
20821   MVT ElemTy = N->getSimpleValueType(0).getScalarType();
20822   unsigned ElemBits = ElemTy.getSizeInBits();
20823 
20824   int64_t ShiftAmount;
20825   if (BuildVectorSDNode *BVN = dyn_cast<BuildVectorSDNode>(N->getOperand(2))) {
20826     APInt SplatValue, SplatUndef;
20827     unsigned SplatBitSize;
20828     bool HasAnyUndefs;
20829     if (!BVN->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
20830                               HasAnyUndefs, ElemBits) ||
20831         SplatBitSize != ElemBits)
20832       return SDValue();
20833 
20834     ShiftAmount = SplatValue.getSExtValue();
20835   } else if (ConstantSDNode *CVN = dyn_cast<ConstantSDNode>(N->getOperand(2))) {
20836     ShiftAmount = CVN->getSExtValue();
20837   } else
20838     return SDValue();
20839 
20840   // If the shift amount is zero, remove the shift intrinsic.
20841   if (ShiftAmount == 0 && IID != Intrinsic::aarch64_neon_sqshlu)
20842     return N->getOperand(1);
20843 
20844   unsigned Opcode;
20845   bool IsRightShift;
20846   switch (IID) {
20847   default:
20848     llvm_unreachable("Unknown shift intrinsic");
20849   case Intrinsic::aarch64_neon_sqshl:
20850     Opcode = AArch64ISD::SQSHL_I;
20851     IsRightShift = false;
20852     break;
20853   case Intrinsic::aarch64_neon_uqshl:
20854     Opcode = AArch64ISD::UQSHL_I;
20855     IsRightShift = false;
20856     break;
20857   case Intrinsic::aarch64_neon_srshl:
20858     Opcode = AArch64ISD::SRSHR_I;
20859     IsRightShift = true;
20860     break;
20861   case Intrinsic::aarch64_neon_urshl:
20862     Opcode = AArch64ISD::URSHR_I;
20863     IsRightShift = true;
20864     break;
20865   case Intrinsic::aarch64_neon_sqshlu:
20866     Opcode = AArch64ISD::SQSHLU_I;
20867     IsRightShift = false;
20868     break;
20869   case Intrinsic::aarch64_neon_sshl:
20870   case Intrinsic::aarch64_neon_ushl:
20871     // For positive shift amounts we can use SHL, as ushl/sshl perform a regular
20872     // left shift for positive shift amounts. For negative shifts we can use a
20873     // VASHR/VLSHR as appropiate.
20874     if (ShiftAmount < 0) {
20875       Opcode = IID == Intrinsic::aarch64_neon_sshl ? AArch64ISD::VASHR
20876                                                    : AArch64ISD::VLSHR;
20877       ShiftAmount = -ShiftAmount;
20878     } else
20879       Opcode = AArch64ISD::VSHL;
20880     IsRightShift = false;
20881     break;
20882   }
20883 
20884   EVT VT = N->getValueType(0);
20885   SDValue Op = N->getOperand(1);
20886   SDLoc dl(N);
20887   if (VT == MVT::i64) {
20888     Op = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v1i64, Op);
20889     VT = MVT::v1i64;
20890   }
20891 
20892   if (IsRightShift && ShiftAmount <= -1 && ShiftAmount >= -(int)ElemBits) {
20893     Op = DAG.getNode(Opcode, dl, VT, Op,
20894                      DAG.getConstant(-ShiftAmount, dl, MVT::i32));
20895     if (N->getValueType(0) == MVT::i64)
20896       Op = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i64, Op,
20897                        DAG.getConstant(0, dl, MVT::i64));
20898     return Op;
20899   } else if (!IsRightShift && ShiftAmount >= 0 && ShiftAmount < ElemBits) {
20900     Op = DAG.getNode(Opcode, dl, VT, Op,
20901                      DAG.getConstant(ShiftAmount, dl, MVT::i32));
20902     if (N->getValueType(0) == MVT::i64)
20903       Op = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i64, Op,
20904                        DAG.getConstant(0, dl, MVT::i64));
20905     return Op;
20906   }
20907 
20908   return SDValue();
20909 }
20910 
20911 // The CRC32[BH] instructions ignore the high bits of their data operand. Since
20912 // the intrinsics must be legal and take an i32, this means there's almost
20913 // certainly going to be a zext in the DAG which we can eliminate.
20914 static SDValue tryCombineCRC32(unsigned Mask, SDNode *N, SelectionDAG &DAG) {
20915   SDValue AndN = N->getOperand(2);
20916   if (AndN.getOpcode() != ISD::AND)
20917     return SDValue();
20918 
20919   ConstantSDNode *CMask = dyn_cast<ConstantSDNode>(AndN.getOperand(1));
20920   if (!CMask || CMask->getZExtValue() != Mask)
20921     return SDValue();
20922 
20923   return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SDLoc(N), MVT::i32,
20924                      N->getOperand(0), N->getOperand(1), AndN.getOperand(0));
20925 }
20926 
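// Lower a NEON across-lanes reduction intrinsic (e.g. saddv, uminv) to the
// corresponding AArch64ISD node and extract lane 0 as the scalar result.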
20927 static SDValue combineAcrossLanesIntrinsic(unsigned Opc, SDNode *N,
20928                                            SelectionDAG &DAG) {
20929   SDLoc dl(N);
20930   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, N->getValueType(0),
20931                      DAG.getNode(Opc, dl,
20932                                  N->getOperand(1).getSimpleValueType(),
20933                                  N->getOperand(1)),
20934                      DAG.getConstant(0, dl, MVT::i64));
20935 }
20936 
20937 static SDValue LowerSVEIntrinsicIndex(SDNode *N, SelectionDAG &DAG) {
20938   SDLoc DL(N);
20939   SDValue Op1 = N->getOperand(1);
20940   SDValue Op2 = N->getOperand(2);
20941   EVT ScalarTy = Op2.getValueType();
20942   if ((ScalarTy == MVT::i8) || (ScalarTy == MVT::i16))
20943     ScalarTy = MVT::i32;
20944 
20945   // Lower index_vector(base, step) to mul(step, step_vector(1)) + splat(base).
20946   SDValue StepVector = DAG.getStepVector(DL, N->getValueType(0));
20947   SDValue Step = DAG.getNode(ISD::SPLAT_VECTOR, DL, N->getValueType(0), Op2);
20948   SDValue Mul = DAG.getNode(ISD::MUL, DL, N->getValueType(0), StepVector, Step);
20949   SDValue Base = DAG.getNode(ISD::SPLAT_VECTOR, DL, N->getValueType(0), Op1);
20950   return DAG.getNode(ISD::ADD, DL, N->getValueType(0), Mul, Base);
20951 }
20952 
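// Lower aarch64_sve_dup to AArch64ISD::DUP_MERGE_PASSTHRU, any-extending i8
// and i16 scalar operands to i32 first.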
20953 static SDValue LowerSVEIntrinsicDUP(SDNode *N, SelectionDAG &DAG) {
20954   SDLoc dl(N);
20955   SDValue Scalar = N->getOperand(3);
20956   EVT ScalarTy = Scalar.getValueType();
20957 
20958   if ((ScalarTy == MVT::i8) || (ScalarTy == MVT::i16))
20959     Scalar = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, Scalar);
20960 
20961   SDValue Passthru = N->getOperand(1);
20962   SDValue Pred = N->getOperand(2);
20963   return DAG.getNode(AArch64ISD::DUP_MERGE_PASSTHRU, dl, N->getValueType(0),
20964                      Pred, Scalar, Passthru);
20965 }
20966 
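// Lower aarch64_sve_ext by performing the operation in the byte domain: the
// vector operands are bitcast to bytes, the index is scaled by the element
// size, and the result is bitcast back to the original type.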
20967 static SDValue LowerSVEIntrinsicEXT(SDNode *N, SelectionDAG &DAG) {
20968   SDLoc dl(N);
20969   LLVMContext &Ctx = *DAG.getContext();
20970   EVT VT = N->getValueType(0);
20971 
20972   assert(VT.isScalableVector() && "Expected a scalable vector.");
20973 
20974   // Current lowering only supports the SVE-ACLE types.
20975   if (VT.getSizeInBits().getKnownMinValue() != AArch64::SVEBitsPerBlock)
20976     return SDValue();
20977 
20978   unsigned ElemSize = VT.getVectorElementType().getSizeInBits() / 8;
20979   unsigned ByteSize = VT.getSizeInBits().getKnownMinValue() / 8;
20980   EVT ByteVT =
20981       EVT::getVectorVT(Ctx, MVT::i8, ElementCount::getScalable(ByteSize));
20982 
20983   // Convert everything to the domain of EXT (i.e. bytes).
20984   SDValue Op0 = DAG.getNode(ISD::BITCAST, dl, ByteVT, N->getOperand(1));
20985   SDValue Op1 = DAG.getNode(ISD::BITCAST, dl, ByteVT, N->getOperand(2));
20986   SDValue Op2 = DAG.getNode(ISD::MUL, dl, MVT::i32, N->getOperand(3),
20987                             DAG.getConstant(ElemSize, dl, MVT::i32));
20988 
20989   SDValue EXT = DAG.getNode(AArch64ISD::EXT, dl, ByteVT, Op0, Op1, Op2);
20990   return DAG.getNode(ISD::BITCAST, dl, VT, EXT);
20991 }
20992 
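// Convert an SVE wide-element compare intrinsic whose splatted comparator is a
// small immediate (-16..15 for the signed forms, 0..127 for the unsigned
// forms) into a SETCC_MERGE_ZERO against a splat of that immediate in the
// operand's element type.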
20993 static SDValue tryConvertSVEWideCompare(SDNode *N, ISD::CondCode CC,
20994                                         TargetLowering::DAGCombinerInfo &DCI,
20995                                         SelectionDAG &DAG) {
20996   if (DCI.isBeforeLegalize())
20997     return SDValue();
20998 
20999   SDValue Comparator = N->getOperand(3);
21000   if (Comparator.getOpcode() == AArch64ISD::DUP ||
21001       Comparator.getOpcode() == ISD::SPLAT_VECTOR) {
21002     unsigned IID = getIntrinsicID(N);
21003     EVT VT = N->getValueType(0);
21004     EVT CmpVT = N->getOperand(2).getValueType();
21005     SDValue Pred = N->getOperand(1);
21006     SDValue Imm;
21007     SDLoc DL(N);
21008 
21009     switch (IID) {
21010     default:
21011       llvm_unreachable("Called with wrong intrinsic!");
21012       break;
21013 
21014     // Signed comparisons
21015     case Intrinsic::aarch64_sve_cmpeq_wide:
21016     case Intrinsic::aarch64_sve_cmpne_wide:
21017     case Intrinsic::aarch64_sve_cmpge_wide:
21018     case Intrinsic::aarch64_sve_cmpgt_wide:
21019     case Intrinsic::aarch64_sve_cmplt_wide:
21020     case Intrinsic::aarch64_sve_cmple_wide: {
21021       if (auto *CN = dyn_cast<ConstantSDNode>(Comparator.getOperand(0))) {
21022         int64_t ImmVal = CN->getSExtValue();
21023         if (ImmVal >= -16 && ImmVal <= 15)
21024           Imm = DAG.getConstant(ImmVal, DL, MVT::i32);
21025         else
21026           return SDValue();
21027       }
21028       break;
21029     }
21030     // Unsigned comparisons
21031     case Intrinsic::aarch64_sve_cmphs_wide:
21032     case Intrinsic::aarch64_sve_cmphi_wide:
21033     case Intrinsic::aarch64_sve_cmplo_wide:
21034     case Intrinsic::aarch64_sve_cmpls_wide:  {
21035       if (auto *CN = dyn_cast<ConstantSDNode>(Comparator.getOperand(0))) {
21036         uint64_t ImmVal = CN->getZExtValue();
21037         if (ImmVal <= 127)
21038           Imm = DAG.getConstant(ImmVal, DL, MVT::i32);
21039         else
21040           return SDValue();
21041       }
21042       break;
21043     }
21044     }
21045 
21046     if (!Imm)
21047       return SDValue();
21048 
21049     SDValue Splat = DAG.getNode(ISD::SPLAT_VECTOR, DL, CmpVT, Imm);
21050     return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, VT, Pred,
21051                        N->getOperand(2), Splat, DAG.getCondCode(CC));
21052   }
21053 
21054   return SDValue();
21055 }
21056 
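// Emit a PTEST of Op governed by the predicate Pg and materialise the
// requested condition as a 0/1 value of type VT via CSEL. Operands are
// reinterpreted as nxv16i1 where necessary.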
21057 static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
21058                         AArch64CC::CondCode Cond) {
21059   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
21060 
21061   SDLoc DL(Op);
21062   assert(Op.getValueType().isScalableVector() &&
21063          TLI.isTypeLegal(Op.getValueType()) &&
21064          "Expected legal scalable vector type!");
21065   assert(Op.getValueType() == Pg.getValueType() &&
21066          "Expected same type for PTEST operands");
21067 
21068   // Ensure the target-specific opcodes below use a legal result type.
21069   EVT OutVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
21070   SDValue TVal = DAG.getConstant(1, DL, OutVT);
21071   SDValue FVal = DAG.getConstant(0, DL, OutVT);
21072 
21073   // Ensure operands have type nxv16i1.
21074   if (Op.getValueType() != MVT::nxv16i1) {
21075     if ((Cond == AArch64CC::ANY_ACTIVE || Cond == AArch64CC::NONE_ACTIVE) &&
21076         isZeroingInactiveLanes(Op))
21077       Pg = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv16i1, Pg);
21078     else
21079       Pg = getSVEPredicateBitCast(MVT::nxv16i1, Pg, DAG);
21080     Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv16i1, Op);
21081   }
21082 
21083   // Set condition code (CC) flags.
21084   SDValue Test = DAG.getNode(
21085       Cond == AArch64CC::ANY_ACTIVE ? AArch64ISD::PTEST_ANY : AArch64ISD::PTEST,
21086       DL, MVT::Other, Pg, Op);
21087 
21088   // Convert CC to integer based on requested condition.
21089   // NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare.
21090   SDValue CC = DAG.getConstant(getInvertedCondCode(Cond), DL, MVT::i32);
21091   SDValue Res = DAG.getNode(AArch64ISD::CSEL, DL, OutVT, FVal, TVal, CC, Test);
21092   return DAG.getZExtOrTrunc(Res, DL, VT);
21093 }
21094 
21095 static SDValue combineSVEReductionInt(SDNode *N, unsigned Opc,
21096                                       SelectionDAG &DAG) {
21097   SDLoc DL(N);
21098 
21099   SDValue Pred = N->getOperand(1);
21100   SDValue VecToReduce = N->getOperand(2);
21101 
21102   // NOTE: The integer reduction's result type is not always linked to the
21103   // operand's element type so we construct it from the intrinsic's result type.
21104   EVT ReduceVT = getPackedSVEVectorVT(N->getValueType(0));
21105   SDValue Reduce = DAG.getNode(Opc, DL, ReduceVT, Pred, VecToReduce);
21106 
21107   // SVE reductions set the whole vector register with the first element
21108   // containing the reduction result, which we'll now extract.
21109   SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
21110   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, N->getValueType(0), Reduce,
21111                      Zero);
21112 }
21113 
21114 static SDValue combineSVEReductionFP(SDNode *N, unsigned Opc,
21115                                      SelectionDAG &DAG) {
21116   SDLoc DL(N);
21117 
21118   SDValue Pred = N->getOperand(1);
21119   SDValue VecToReduce = N->getOperand(2);
21120 
21121   EVT ReduceVT = VecToReduce.getValueType();
21122   SDValue Reduce = DAG.getNode(Opc, DL, ReduceVT, Pred, VecToReduce);
21123 
21124   // SVE reductions set the whole vector register with the first element
21125   // containing the reduction result, which we'll now extract.
21126   SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
21127   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, N->getValueType(0), Reduce,
21128                      Zero);
21129 }
21130 
21131 static SDValue combineSVEReductionOrderedFP(SDNode *N, unsigned Opc,
21132                                             SelectionDAG &DAG) {
21133   SDLoc DL(N);
21134 
21135   SDValue Pred = N->getOperand(1);
21136   SDValue InitVal = N->getOperand(2);
21137   SDValue VecToReduce = N->getOperand(3);
21138   EVT ReduceVT = VecToReduce.getValueType();
21139 
21140   // Ordered reductions use the first lane of the result vector as the
21141   // reduction's initial value.
21142   SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
21143   InitVal = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ReduceVT,
21144                         DAG.getUNDEF(ReduceVT), InitVal, Zero);
21145 
21146   SDValue Reduce = DAG.getNode(Opc, DL, ReduceVT, Pred, InitVal, VecToReduce);
21147 
21148   // SVE reductions set the whole vector register with the first element
21149   // containing the reduction result, which we'll now extract.
21150   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, N->getValueType(0), Reduce,
21151                      Zero);
21152 }
21153 
21154 // If a merged operation has no inactive lanes we can relax it to a predicated
21155 // or unpredicated operation, which potentially allows better isel (perhaps
21156 // using immediate forms) or relaxing register reuse requirements.
21157 static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
21158                                        SelectionDAG &DAG, bool UnpredOp = false,
21159                                        bool SwapOperands = false) {
21160   assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN && "Expected intrinsic!");
21161   assert(N->getNumOperands() == 4 && "Expected 3 operand intrinsic!");
21162   SDValue Pg = N->getOperand(1);
21163   SDValue Op1 = N->getOperand(SwapOperands ? 3 : 2);
21164   SDValue Op2 = N->getOperand(SwapOperands ? 2 : 3);
21165 
21166   // ISD way to specify an all active predicate.
21167   if (isAllActivePredicate(DAG, Pg)) {
21168     if (UnpredOp)
21169       return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), Op1, Op2);
21170 
21171     return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), Pg, Op1, Op2);
21172   }
21173 
21174   // FUTURE: SplatVector(true)
21175   return SDValue();
21176 }
21177 
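// With SVE2p1, a whilelo whose only users extract the low and high halves of
// its result can be replaced by the whilelo_x2 intrinsic, which produces both
// half-sized predicates directly.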
21178 static SDValue tryCombineWhileLo(SDNode *N,
21179                                  TargetLowering::DAGCombinerInfo &DCI,
21180                                  const AArch64Subtarget *Subtarget) {
21181   if (DCI.isBeforeLegalize())
21182     return SDValue();
21183 
21184   if (!Subtarget->hasSVE2p1())
21185     return SDValue();
21186 
21187   if (!N->hasNUsesOfValue(2, 0))
21188     return SDValue();
21189 
21190   const uint64_t HalfSize = N->getValueType(0).getVectorMinNumElements() / 2;
21191   if (HalfSize < 2)
21192     return SDValue();
21193 
21194   auto It = N->use_begin();
21195   SDNode *Lo = *It++;
21196   SDNode *Hi = *It;
21197 
21198   if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
21199       Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR)
21200     return SDValue();
21201 
21202   uint64_t OffLo = Lo->getConstantOperandVal(1);
21203   uint64_t OffHi = Hi->getConstantOperandVal(1);
21204 
21205   if (OffLo > OffHi) {
21206     std::swap(Lo, Hi);
21207     std::swap(OffLo, OffHi);
21208   }
21209 
21210   if (OffLo != 0 || OffHi != HalfSize)
21211     return SDValue();
21212 
21213   EVT HalfVec = Lo->getValueType(0);
21214   if (HalfVec != Hi->getValueType(0) ||
21215       HalfVec.getVectorElementCount() != ElementCount::getScalable(HalfSize))
21216     return SDValue();
21217 
21218   SelectionDAG &DAG = DCI.DAG;
21219   SDLoc DL(N);
21220   SDValue ID =
21221       DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
21222   SDValue Idx = N->getOperand(1);
21223   SDValue TC = N->getOperand(2);
21224   if (Idx.getValueType() != MVT::i64) {
21225     Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
21226     TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
21227   }
21228   auto R =
21229       DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
21230                   {Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
21231 
21232   DCI.CombineTo(Lo, R.getValue(0));
21233   DCI.CombineTo(Hi, R.getValue(1));
21234 
21235   return SDValue(N, 0);
21236 }
21237 
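// Perform custom combines for INTRINSIC_WO_CHAIN nodes, lowering NEON and SVE
// intrinsics to generic ISD or AArch64ISD nodes where a direct mapping exists
// and applying intrinsic-specific folds otherwise.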
21238 static SDValue performIntrinsicCombine(SDNode *N,
21239                                        TargetLowering::DAGCombinerInfo &DCI,
21240                                        const AArch64Subtarget *Subtarget) {
21241   SelectionDAG &DAG = DCI.DAG;
21242   unsigned IID = getIntrinsicID(N);
21243   switch (IID) {
21244   default:
21245     break;
21246   case Intrinsic::aarch64_neon_vcvtfxs2fp:
21247   case Intrinsic::aarch64_neon_vcvtfxu2fp:
21248     return tryCombineFixedPointConvert(N, DCI, DAG);
21249   case Intrinsic::aarch64_neon_saddv:
21250     return combineAcrossLanesIntrinsic(AArch64ISD::SADDV, N, DAG);
21251   case Intrinsic::aarch64_neon_uaddv:
21252     return combineAcrossLanesIntrinsic(AArch64ISD::UADDV, N, DAG);
21253   case Intrinsic::aarch64_neon_sminv:
21254     return combineAcrossLanesIntrinsic(AArch64ISD::SMINV, N, DAG);
21255   case Intrinsic::aarch64_neon_uminv:
21256     return combineAcrossLanesIntrinsic(AArch64ISD::UMINV, N, DAG);
21257   case Intrinsic::aarch64_neon_smaxv:
21258     return combineAcrossLanesIntrinsic(AArch64ISD::SMAXV, N, DAG);
21259   case Intrinsic::aarch64_neon_umaxv:
21260     return combineAcrossLanesIntrinsic(AArch64ISD::UMAXV, N, DAG);
21261   case Intrinsic::aarch64_neon_fmax:
21262     return DAG.getNode(ISD::FMAXIMUM, SDLoc(N), N->getValueType(0),
21263                        N->getOperand(1), N->getOperand(2));
21264   case Intrinsic::aarch64_neon_fmin:
21265     return DAG.getNode(ISD::FMINIMUM, SDLoc(N), N->getValueType(0),
21266                        N->getOperand(1), N->getOperand(2));
21267   case Intrinsic::aarch64_neon_fmaxnm:
21268     return DAG.getNode(ISD::FMAXNUM, SDLoc(N), N->getValueType(0),
21269                        N->getOperand(1), N->getOperand(2));
21270   case Intrinsic::aarch64_neon_fminnm:
21271     return DAG.getNode(ISD::FMINNUM, SDLoc(N), N->getValueType(0),
21272                        N->getOperand(1), N->getOperand(2));
21273   case Intrinsic::aarch64_neon_smull:
21274     return DAG.getNode(AArch64ISD::SMULL, SDLoc(N), N->getValueType(0),
21275                        N->getOperand(1), N->getOperand(2));
21276   case Intrinsic::aarch64_neon_umull:
21277     return DAG.getNode(AArch64ISD::UMULL, SDLoc(N), N->getValueType(0),
21278                        N->getOperand(1), N->getOperand(2));
21279   case Intrinsic::aarch64_neon_pmull:
21280     return DAG.getNode(AArch64ISD::PMULL, SDLoc(N), N->getValueType(0),
21281                        N->getOperand(1), N->getOperand(2));
21282   case Intrinsic::aarch64_neon_sqdmull:
21283     return tryCombineLongOpWithDup(IID, N, DCI, DAG);
21284   case Intrinsic::aarch64_neon_sqshl:
21285   case Intrinsic::aarch64_neon_uqshl:
21286   case Intrinsic::aarch64_neon_sqshlu:
21287   case Intrinsic::aarch64_neon_srshl:
21288   case Intrinsic::aarch64_neon_urshl:
21289   case Intrinsic::aarch64_neon_sshl:
21290   case Intrinsic::aarch64_neon_ushl:
21291     return tryCombineShiftImm(IID, N, DAG);
21292   case Intrinsic::aarch64_neon_sabd:
21293     return DAG.getNode(ISD::ABDS, SDLoc(N), N->getValueType(0),
21294                        N->getOperand(1), N->getOperand(2));
21295   case Intrinsic::aarch64_neon_uabd:
21296     return DAG.getNode(ISD::ABDU, SDLoc(N), N->getValueType(0),
21297                        N->getOperand(1), N->getOperand(2));
21298   case Intrinsic::aarch64_crc32b:
21299   case Intrinsic::aarch64_crc32cb:
21300     return tryCombineCRC32(0xff, N, DAG);
21301   case Intrinsic::aarch64_crc32h:
21302   case Intrinsic::aarch64_crc32ch:
21303     return tryCombineCRC32(0xffff, N, DAG);
21304   case Intrinsic::aarch64_sve_saddv:
21305     // There is no i64 version of SADDV because the sign is irrelevant.
21306     if (N->getOperand(2)->getValueType(0).getVectorElementType() == MVT::i64)
21307       return combineSVEReductionInt(N, AArch64ISD::UADDV_PRED, DAG);
21308     else
21309       return combineSVEReductionInt(N, AArch64ISD::SADDV_PRED, DAG);
21310   case Intrinsic::aarch64_sve_uaddv:
21311     return combineSVEReductionInt(N, AArch64ISD::UADDV_PRED, DAG);
21312   case Intrinsic::aarch64_sve_smaxv:
21313     return combineSVEReductionInt(N, AArch64ISD::SMAXV_PRED, DAG);
21314   case Intrinsic::aarch64_sve_umaxv:
21315     return combineSVEReductionInt(N, AArch64ISD::UMAXV_PRED, DAG);
21316   case Intrinsic::aarch64_sve_sminv:
21317     return combineSVEReductionInt(N, AArch64ISD::SMINV_PRED, DAG);
21318   case Intrinsic::aarch64_sve_uminv:
21319     return combineSVEReductionInt(N, AArch64ISD::UMINV_PRED, DAG);
21320   case Intrinsic::aarch64_sve_orv:
21321     return combineSVEReductionInt(N, AArch64ISD::ORV_PRED, DAG);
21322   case Intrinsic::aarch64_sve_eorv:
21323     return combineSVEReductionInt(N, AArch64ISD::EORV_PRED, DAG);
21324   case Intrinsic::aarch64_sve_andv:
21325     return combineSVEReductionInt(N, AArch64ISD::ANDV_PRED, DAG);
21326   case Intrinsic::aarch64_sve_index:
21327     return LowerSVEIntrinsicIndex(N, DAG);
21328   case Intrinsic::aarch64_sve_dup:
21329     return LowerSVEIntrinsicDUP(N, DAG);
21330   case Intrinsic::aarch64_sve_dup_x:
21331     return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), N->getValueType(0),
21332                        N->getOperand(1));
21333   case Intrinsic::aarch64_sve_ext:
21334     return LowerSVEIntrinsicEXT(N, DAG);
21335   case Intrinsic::aarch64_sve_mul_u:
21336     return DAG.getNode(AArch64ISD::MUL_PRED, SDLoc(N), N->getValueType(0),
21337                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
21338   case Intrinsic::aarch64_sve_smulh_u:
21339     return DAG.getNode(AArch64ISD::MULHS_PRED, SDLoc(N), N->getValueType(0),
21340                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
21341   case Intrinsic::aarch64_sve_umulh_u:
21342     return DAG.getNode(AArch64ISD::MULHU_PRED, SDLoc(N), N->getValueType(0),
21343                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
21344   case Intrinsic::aarch64_sve_smin_u:
21345     return DAG.getNode(AArch64ISD::SMIN_PRED, SDLoc(N), N->getValueType(0),
21346                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
21347   case Intrinsic::aarch64_sve_umin_u:
21348     return DAG.getNode(AArch64ISD::UMIN_PRED, SDLoc(N), N->getValueType(0),
21349                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
21350   case Intrinsic::aarch64_sve_smax_u:
21351     return DAG.getNode(AArch64ISD::SMAX_PRED, SDLoc(N), N->getValueType(0),
21352                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
21353   case Intrinsic::aarch64_sve_umax_u:
21354     return DAG.getNode(AArch64ISD::UMAX_PRED, SDLoc(N), N->getValueType(0),
21355                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
21356   case Intrinsic::aarch64_sve_lsl_u:
21357     return DAG.getNode(AArch64ISD::SHL_PRED, SDLoc(N), N->getValueType(0),
21358                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
21359   case Intrinsic::aarch64_sve_lsr_u:
21360     return DAG.getNode(AArch64ISD::SRL_PRED, SDLoc(N), N->getValueType(0),
21361                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
21362   case Intrinsic::aarch64_sve_asr_u:
21363     return DAG.getNode(AArch64ISD::SRA_PRED, SDLoc(N), N->getValueType(0),
21364                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
21365   case Intrinsic::aarch64_sve_fadd_u:
21366     return DAG.getNode(AArch64ISD::FADD_PRED, SDLoc(N), N->getValueType(0),
21367                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
21368   case Intrinsic::aarch64_sve_fdiv_u:
21369     return DAG.getNode(AArch64ISD::FDIV_PRED, SDLoc(N), N->getValueType(0),
21370                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
21371   case Intrinsic::aarch64_sve_fmax_u:
21372     return DAG.getNode(AArch64ISD::FMAX_PRED, SDLoc(N), N->getValueType(0),
21373                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
21374   case Intrinsic::aarch64_sve_fmaxnm_u:
21375     return DAG.getNode(AArch64ISD::FMAXNM_PRED, SDLoc(N), N->getValueType(0),
21376                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
21377   case Intrinsic::aarch64_sve_fmla_u:
21378     return DAG.getNode(AArch64ISD::FMA_PRED, SDLoc(N), N->getValueType(0),
21379                        N->getOperand(1), N->getOperand(3), N->getOperand(4),
21380                        N->getOperand(2));
21381   case Intrinsic::aarch64_sve_fmin_u:
21382     return DAG.getNode(AArch64ISD::FMIN_PRED, SDLoc(N), N->getValueType(0),
21383                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
21384   case Intrinsic::aarch64_sve_fminnm_u:
21385     return DAG.getNode(AArch64ISD::FMINNM_PRED, SDLoc(N), N->getValueType(0),
21386                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
21387   case Intrinsic::aarch64_sve_fmul_u:
21388     return DAG.getNode(AArch64ISD::FMUL_PRED, SDLoc(N), N->getValueType(0),
21389                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
21390   case Intrinsic::aarch64_sve_fsub_u:
21391     return DAG.getNode(AArch64ISD::FSUB_PRED, SDLoc(N), N->getValueType(0),
21392                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
21393   case Intrinsic::aarch64_sve_add_u:
21394     return DAG.getNode(ISD::ADD, SDLoc(N), N->getValueType(0), N->getOperand(2),
21395                        N->getOperand(3));
21396   case Intrinsic::aarch64_sve_sub_u:
21397     return DAG.getNode(ISD::SUB, SDLoc(N), N->getValueType(0), N->getOperand(2),
21398                        N->getOperand(3));
21399   case Intrinsic::aarch64_sve_subr:
21400     return convertMergedOpToPredOp(N, ISD::SUB, DAG, true, true);
21401   case Intrinsic::aarch64_sve_and_u:
21402     return DAG.getNode(ISD::AND, SDLoc(N), N->getValueType(0), N->getOperand(2),
21403                        N->getOperand(3));
21404   case Intrinsic::aarch64_sve_bic_u:
21405     return DAG.getNode(AArch64ISD::BIC, SDLoc(N), N->getValueType(0),
21406                        N->getOperand(2), N->getOperand(3));
21407   case Intrinsic::aarch64_sve_eor_u:
21408     return DAG.getNode(ISD::XOR, SDLoc(N), N->getValueType(0), N->getOperand(2),
21409                        N->getOperand(3));
21410   case Intrinsic::aarch64_sve_orr_u:
21411     return DAG.getNode(ISD::OR, SDLoc(N), N->getValueType(0), N->getOperand(2),
21412                        N->getOperand(3));
21413   case Intrinsic::aarch64_sve_sabd_u:
21414     return DAG.getNode(ISD::ABDS, SDLoc(N), N->getValueType(0),
21415                        N->getOperand(2), N->getOperand(3));
21416   case Intrinsic::aarch64_sve_uabd_u:
21417     return DAG.getNode(ISD::ABDU, SDLoc(N), N->getValueType(0),
21418                        N->getOperand(2), N->getOperand(3));
21419   case Intrinsic::aarch64_sve_sdiv_u:
21420     return DAG.getNode(AArch64ISD::SDIV_PRED, SDLoc(N), N->getValueType(0),
21421                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
21422   case Intrinsic::aarch64_sve_udiv_u:
21423     return DAG.getNode(AArch64ISD::UDIV_PRED, SDLoc(N), N->getValueType(0),
21424                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
21425   case Intrinsic::aarch64_sve_sqadd:
21426     return convertMergedOpToPredOp(N, ISD::SADDSAT, DAG, true);
21427   case Intrinsic::aarch64_sve_sqsub_u:
21428     return DAG.getNode(ISD::SSUBSAT, SDLoc(N), N->getValueType(0),
21429                        N->getOperand(2), N->getOperand(3));
21430   case Intrinsic::aarch64_sve_uqadd:
21431     return convertMergedOpToPredOp(N, ISD::UADDSAT, DAG, true);
21432   case Intrinsic::aarch64_sve_uqsub_u:
21433     return DAG.getNode(ISD::USUBSAT, SDLoc(N), N->getValueType(0),
21434                        N->getOperand(2), N->getOperand(3));
21435   case Intrinsic::aarch64_sve_sqadd_x:
21436     return DAG.getNode(ISD::SADDSAT, SDLoc(N), N->getValueType(0),
21437                        N->getOperand(1), N->getOperand(2));
21438   case Intrinsic::aarch64_sve_sqsub_x:
21439     return DAG.getNode(ISD::SSUBSAT, SDLoc(N), N->getValueType(0),
21440                        N->getOperand(1), N->getOperand(2));
21441   case Intrinsic::aarch64_sve_uqadd_x:
21442     return DAG.getNode(ISD::UADDSAT, SDLoc(N), N->getValueType(0),
21443                        N->getOperand(1), N->getOperand(2));
21444   case Intrinsic::aarch64_sve_uqsub_x:
21445     return DAG.getNode(ISD::USUBSAT, SDLoc(N), N->getValueType(0),
21446                        N->getOperand(1), N->getOperand(2));
21447   case Intrinsic::aarch64_sve_asrd:
21448     return DAG.getNode(AArch64ISD::SRAD_MERGE_OP1, SDLoc(N), N->getValueType(0),
21449                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
21450   case Intrinsic::aarch64_sve_cmphs:
21451     if (!N->getOperand(2).getValueType().isFloatingPoint())
21452       return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
21453                          N->getValueType(0), N->getOperand(1), N->getOperand(2),
21454                          N->getOperand(3), DAG.getCondCode(ISD::SETUGE));
21455     break;
21456   case Intrinsic::aarch64_sve_cmphi:
21457     if (!N->getOperand(2).getValueType().isFloatingPoint())
21458       return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
21459                          N->getValueType(0), N->getOperand(1), N->getOperand(2),
21460                          N->getOperand(3), DAG.getCondCode(ISD::SETUGT));
21461     break;
21462   case Intrinsic::aarch64_sve_fcmpge:
21463   case Intrinsic::aarch64_sve_cmpge:
21464     return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
21465                        N->getValueType(0), N->getOperand(1), N->getOperand(2),
21466                        N->getOperand(3), DAG.getCondCode(ISD::SETGE));
21467     break;
21468   case Intrinsic::aarch64_sve_fcmpgt:
21469   case Intrinsic::aarch64_sve_cmpgt:
21470     return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
21471                        N->getValueType(0), N->getOperand(1), N->getOperand(2),
21472                        N->getOperand(3), DAG.getCondCode(ISD::SETGT));
21473     break;
21474   case Intrinsic::aarch64_sve_fcmpeq:
21475   case Intrinsic::aarch64_sve_cmpeq:
21476     return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
21477                        N->getValueType(0), N->getOperand(1), N->getOperand(2),
21478                        N->getOperand(3), DAG.getCondCode(ISD::SETEQ));
21479     break;
21480   case Intrinsic::aarch64_sve_fcmpne:
21481   case Intrinsic::aarch64_sve_cmpne:
21482     return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
21483                        N->getValueType(0), N->getOperand(1), N->getOperand(2),
21484                        N->getOperand(3), DAG.getCondCode(ISD::SETNE));
21485     break;
21486   case Intrinsic::aarch64_sve_fcmpuo:
21487     return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
21488                        N->getValueType(0), N->getOperand(1), N->getOperand(2),
21489                        N->getOperand(3), DAG.getCondCode(ISD::SETUO));
21490     break;
21491   case Intrinsic::aarch64_sve_fadda:
21492     return combineSVEReductionOrderedFP(N, AArch64ISD::FADDA_PRED, DAG);
21493   case Intrinsic::aarch64_sve_faddv:
21494     return combineSVEReductionFP(N, AArch64ISD::FADDV_PRED, DAG);
21495   case Intrinsic::aarch64_sve_fmaxnmv:
21496     return combineSVEReductionFP(N, AArch64ISD::FMAXNMV_PRED, DAG);
21497   case Intrinsic::aarch64_sve_fmaxv:
21498     return combineSVEReductionFP(N, AArch64ISD::FMAXV_PRED, DAG);
21499   case Intrinsic::aarch64_sve_fminnmv:
21500     return combineSVEReductionFP(N, AArch64ISD::FMINNMV_PRED, DAG);
21501   case Intrinsic::aarch64_sve_fminv:
21502     return combineSVEReductionFP(N, AArch64ISD::FMINV_PRED, DAG);
21503   case Intrinsic::aarch64_sve_sel:
21504     return DAG.getNode(ISD::VSELECT, SDLoc(N), N->getValueType(0),
21505                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
21506   case Intrinsic::aarch64_sve_cmpeq_wide:
21507     return tryConvertSVEWideCompare(N, ISD::SETEQ, DCI, DAG);
21508   case Intrinsic::aarch64_sve_cmpne_wide:
21509     return tryConvertSVEWideCompare(N, ISD::SETNE, DCI, DAG);
21510   case Intrinsic::aarch64_sve_cmpge_wide:
21511     return tryConvertSVEWideCompare(N, ISD::SETGE, DCI, DAG);
21512   case Intrinsic::aarch64_sve_cmpgt_wide:
21513     return tryConvertSVEWideCompare(N, ISD::SETGT, DCI, DAG);
21514   case Intrinsic::aarch64_sve_cmplt_wide:
21515     return tryConvertSVEWideCompare(N, ISD::SETLT, DCI, DAG);
21516   case Intrinsic::aarch64_sve_cmple_wide:
21517     return tryConvertSVEWideCompare(N, ISD::SETLE, DCI, DAG);
21518   case Intrinsic::aarch64_sve_cmphs_wide:
21519     return tryConvertSVEWideCompare(N, ISD::SETUGE, DCI, DAG);
21520   case Intrinsic::aarch64_sve_cmphi_wide:
21521     return tryConvertSVEWideCompare(N, ISD::SETUGT, DCI, DAG);
21522   case Intrinsic::aarch64_sve_cmplo_wide:
21523     return tryConvertSVEWideCompare(N, ISD::SETULT, DCI, DAG);
21524   case Intrinsic::aarch64_sve_cmpls_wide:
21525     return tryConvertSVEWideCompare(N, ISD::SETULE, DCI, DAG);
21526   case Intrinsic::aarch64_sve_ptest_any:
21527     return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2),
21528                     AArch64CC::ANY_ACTIVE);
21529   case Intrinsic::aarch64_sve_ptest_first:
21530     return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2),
21531                     AArch64CC::FIRST_ACTIVE);
21532   case Intrinsic::aarch64_sve_ptest_last:
21533     return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2),
21534                     AArch64CC::LAST_ACTIVE);
21535   case Intrinsic::aarch64_sve_whilelo:
21536     return tryCombineWhileLo(N, DCI, Subtarget);
21537   }
21538   return SDValue();
21539 }
21540 
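// Returns true if extending N is expected to be cheap, i.e. N is a (masked)
// load that can be turned into an extending load, or an all-zeros splat.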
21541 static bool isCheapToExtend(const SDValue &N) {
21542   unsigned OC = N->getOpcode();
21543   return OC == ISD::LOAD || OC == ISD::MLOAD ||
21544          ISD::isConstantSplatVectorAllZeros(N.getNode());
21545 }
21546 
21547 static SDValue
21548 performSignExtendSetCCCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
21549                               SelectionDAG &DAG) {
21550   // If we have (sext (setcc A B)) and A and B are cheap to extend,
21551   // we can move the sext into the arguments and have the same result. For
21552   // example, if A and B are both loads, we can make those extending loads and
21553   // avoid an extra instruction. This pattern appears often in VLS code
21554   // generation where the inputs to the setcc have a different size to the
21555   // instruction that wants to use the result of the setcc.
21556   assert(N->getOpcode() == ISD::SIGN_EXTEND &&
21557          N->getOperand(0)->getOpcode() == ISD::SETCC);
21558   const SDValue SetCC = N->getOperand(0);
21559 
21560   const SDValue CCOp0 = SetCC.getOperand(0);
21561   const SDValue CCOp1 = SetCC.getOperand(1);
21562   if (!CCOp0->getValueType(0).isInteger() ||
21563       !CCOp1->getValueType(0).isInteger())
21564     return SDValue();
21565 
21566   ISD::CondCode Code =
21567       cast<CondCodeSDNode>(SetCC->getOperand(2).getNode())->get();
21568 
21569   ISD::NodeType ExtType =
21570       isSignedIntSetCC(Code) ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
21571 
21572   if (isCheapToExtend(SetCC.getOperand(0)) &&
21573       isCheapToExtend(SetCC.getOperand(1))) {
21574     const SDValue Ext1 =
21575         DAG.getNode(ExtType, SDLoc(N), N->getValueType(0), CCOp0);
21576     const SDValue Ext2 =
21577         DAG.getNode(ExtType, SDLoc(N), N->getValueType(0), CCOp1);
21578 
21579     return DAG.getSetCC(
21580         SDLoc(SetCC), N->getValueType(0), Ext1, Ext2,
21581         cast<CondCodeSDNode>(SetCC->getOperand(2).getNode())->get());
21582   }
21583 
21584   return SDValue();
21585 }
21586 
21587 static SDValue performExtendCombine(SDNode *N,
21588                                     TargetLowering::DAGCombinerInfo &DCI,
21589                                     SelectionDAG &DAG) {
21590   // If we see something like (zext (sabd (extract_high ...), (DUP ...))) then
21591   // we can convert that DUP into another extract_high (of a bigger DUP), which
21592   // helps the backend to decide that an sabdl2 would be useful, saving a real
21593   // extract_high operation.
21594   if (!DCI.isBeforeLegalizeOps() && N->getOpcode() == ISD::ZERO_EXTEND &&
21595       (N->getOperand(0).getOpcode() == ISD::ABDU ||
21596        N->getOperand(0).getOpcode() == ISD::ABDS)) {
21597     SDNode *ABDNode = N->getOperand(0).getNode();
21598     SDValue NewABD =
21599         tryCombineLongOpWithDup(Intrinsic::not_intrinsic, ABDNode, DCI, DAG);
21600     if (!NewABD.getNode())
21601       return SDValue();
21602 
21603     return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), NewABD);
21604   }
21605 
21606   if (N->getValueType(0).isFixedLengthVector() &&
21607       N->getOpcode() == ISD::SIGN_EXTEND &&
21608       N->getOperand(0)->getOpcode() == ISD::SETCC)
21609     return performSignExtendSetCCCombine(N, DCI, DAG);
21610 
21611   return SDValue();
21612 }
21613 
21614 static SDValue splitStoreSplat(SelectionDAG &DAG, StoreSDNode &St,
21615                                SDValue SplatVal, unsigned NumVecElts) {
21616   assert(!St.isTruncatingStore() && "cannot split truncating vector store");
21617   Align OrigAlignment = St.getAlign();
21618   unsigned EltOffset = SplatVal.getValueType().getSizeInBits() / 8;
21619 
21620   // Create scalar stores. This is at least as good as the code sequence for a
21621   // split unaligned store which is a dup.s, ext.b, and two stores.
21622   // Most of the time the three stores should be replaced by store pair
21623   // instructions (stp).
21624   SDLoc DL(&St);
21625   SDValue BasePtr = St.getBasePtr();
21626   uint64_t BaseOffset = 0;
21627 
21628   const MachinePointerInfo &PtrInfo = St.getPointerInfo();
21629   SDValue NewST1 =
21630       DAG.getStore(St.getChain(), DL, SplatVal, BasePtr, PtrInfo,
21631                    OrigAlignment, St.getMemOperand()->getFlags());
21632 
21633   // As this is in ISel, we will not merge this add, which may degrade results.
21634   if (BasePtr->getOpcode() == ISD::ADD &&
21635       isa<ConstantSDNode>(BasePtr->getOperand(1))) {
21636     BaseOffset = cast<ConstantSDNode>(BasePtr->getOperand(1))->getSExtValue();
21637     BasePtr = BasePtr->getOperand(0);
21638   }
21639 
21640   unsigned Offset = EltOffset;
21641   while (--NumVecElts) {
21642     Align Alignment = commonAlignment(OrigAlignment, Offset);
21643     SDValue OffsetPtr =
21644         DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr,
21645                     DAG.getConstant(BaseOffset + Offset, DL, MVT::i64));
21646     NewST1 = DAG.getStore(NewST1.getValue(0), DL, SplatVal, OffsetPtr,
21647                           PtrInfo.getWithOffset(Offset), Alignment,
21648                           St.getMemOperand()->getFlags());
21649     Offset += EltOffset;
21650   }
21651   return NewST1;
21652 }
21653 
21654 // Returns an SVE type that ContentTy can be trivially sign or zero extended
21655 // into.
21656 static MVT getSVEContainerType(EVT ContentTy) {
21657   assert(ContentTy.isSimple() && "No SVE containers for extended types");
21658 
21659   switch (ContentTy.getSimpleVT().SimpleTy) {
21660   default:
21661     llvm_unreachable("No known SVE container for this MVT type");
21662   case MVT::nxv2i8:
21663   case MVT::nxv2i16:
21664   case MVT::nxv2i32:
21665   case MVT::nxv2i64:
21666   case MVT::nxv2f32:
21667   case MVT::nxv2f64:
21668     return MVT::nxv2i64;
21669   case MVT::nxv4i8:
21670   case MVT::nxv4i16:
21671   case MVT::nxv4i32:
21672   case MVT::nxv4f32:
21673     return MVT::nxv4i32;
21674   case MVT::nxv8i8:
21675   case MVT::nxv8i16:
21676   case MVT::nxv8f16:
21677   case MVT::nxv8bf16:
21678     return MVT::nxv8i16;
21679   case MVT::nxv16i8:
21680     return MVT::nxv16i8;
21681   }
21682 }
21683 
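// Lower an SVE predicated-load intrinsic to the target load node Opc. The load
// is performed in the wider SVE container type and integer results are
// truncated back to the original type.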
21684 static SDValue performLD1Combine(SDNode *N, SelectionDAG &DAG, unsigned Opc) {
21685   SDLoc DL(N);
21686   EVT VT = N->getValueType(0);
21687 
21688   if (VT.getSizeInBits().getKnownMinValue() > AArch64::SVEBitsPerBlock)
21689     return SDValue();
21690 
21691   EVT ContainerVT = VT;
21692   if (ContainerVT.isInteger())
21693     ContainerVT = getSVEContainerType(ContainerVT);
21694 
21695   SDVTList VTs = DAG.getVTList(ContainerVT, MVT::Other);
21696   SDValue Ops[] = { N->getOperand(0), // Chain
21697                     N->getOperand(2), // Pg
21698                     N->getOperand(3), // Base
21699                     DAG.getValueType(VT) };
21700 
21701   SDValue Load = DAG.getNode(Opc, DL, VTs, Ops);
21702   SDValue LoadChain = SDValue(Load.getNode(), 1);
21703 
21704   if (ContainerVT.isInteger() && (VT != ContainerVT))
21705     Load = DAG.getNode(ISD::TRUNCATE, DL, VT, Load.getValue(0));
21706 
21707   return DAG.getMergeValues({ Load, LoadChain }, DL);
21708 }
21709 
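// Lower an SVE non-temporal load intrinsic to an equivalent masked load with a
// zero passthru. Floating-point types are loaded as integers and bitcast back.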
21710 static SDValue performLDNT1Combine(SDNode *N, SelectionDAG &DAG) {
21711   SDLoc DL(N);
21712   EVT VT = N->getValueType(0);
21713   EVT PtrTy = N->getOperand(3).getValueType();
21714 
21715   EVT LoadVT = VT;
21716   if (VT.isFloatingPoint())
21717     LoadVT = VT.changeTypeToInteger();
21718 
21719   auto *MINode = cast<MemIntrinsicSDNode>(N);
21720   SDValue PassThru = DAG.getConstant(0, DL, LoadVT);
21721   SDValue L = DAG.getMaskedLoad(LoadVT, DL, MINode->getChain(),
21722                                 MINode->getOperand(3), DAG.getUNDEF(PtrTy),
21723                                 MINode->getOperand(2), PassThru,
21724                                 MINode->getMemoryVT(), MINode->getMemOperand(),
21725                                 ISD::UNINDEXED, ISD::NON_EXTLOAD, false);
21726 
21727   if (VT.isFloatingPoint()) {
21728     SDValue Ops[] = { DAG.getNode(ISD::BITCAST, DL, VT, L), L.getValue(1) };
21729     return DAG.getMergeValues(Ops, DL);
21730   }
21731 
21732   return L;
21733 }
21734 
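// Lower the SVE ld1rq/ld1ro replicating-load intrinsics to their merge-zero
// AArch64ISD nodes, loading as an integer type and bitcasting back for
// floating-point vectors.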
21735 template <unsigned Opcode>
21736 static SDValue performLD1ReplicateCombine(SDNode *N, SelectionDAG &DAG) {
21737   static_assert(Opcode == AArch64ISD::LD1RQ_MERGE_ZERO ||
21738                     Opcode == AArch64ISD::LD1RO_MERGE_ZERO,
21739                 "Unsupported opcode.");
21740   SDLoc DL(N);
21741   EVT VT = N->getValueType(0);
21742 
21743   EVT LoadVT = VT;
21744   if (VT.isFloatingPoint())
21745     LoadVT = VT.changeTypeToInteger();
21746 
21747   SDValue Ops[] = {N->getOperand(0), N->getOperand(2), N->getOperand(3)};
21748   SDValue Load = DAG.getNode(Opcode, DL, {LoadVT, MVT::Other}, Ops);
21749   SDValue LoadChain = SDValue(Load.getNode(), 1);
21750 
21751   if (VT.isFloatingPoint())
21752     Load = DAG.getNode(ISD::BITCAST, DL, VT, Load.getValue(0));
21753 
21754   return DAG.getMergeValues({Load, LoadChain}, DL);
21755 }
21756 
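// Lower an SVE st1 intrinsic to AArch64ISD::ST1_PRED, bitcasting floating-point
// data or any-extending integer data to the SVE container type first.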
21757 static SDValue performST1Combine(SDNode *N, SelectionDAG &DAG) {
21758   SDLoc DL(N);
21759   SDValue Data = N->getOperand(2);
21760   EVT DataVT = Data.getValueType();
21761   EVT HwSrcVt = getSVEContainerType(DataVT);
21762   SDValue InputVT = DAG.getValueType(DataVT);
21763 
21764   if (DataVT.isFloatingPoint())
21765     InputVT = DAG.getValueType(HwSrcVt);
21766 
21767   SDValue SrcNew;
21768   if (Data.getValueType().isFloatingPoint())
21769     SrcNew = DAG.getNode(ISD::BITCAST, DL, HwSrcVt, Data);
21770   else
21771     SrcNew = DAG.getNode(ISD::ANY_EXTEND, DL, HwSrcVt, Data);
21772 
21773   SDValue Ops[] = { N->getOperand(0), // Chain
21774                     SrcNew,
21775                     N->getOperand(4), // Base
21776                     N->getOperand(3), // Pg
21777                     InputVT
21778                   };
21779 
21780   return DAG.getNode(AArch64ISD::ST1_PRED, DL, N->getValueType(0), Ops);
21781 }
21782 
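// Lower an SVE non-temporal store intrinsic to an equivalent masked store,
// bitcasting floating-point data to the corresponding integer type.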
21783 static SDValue performSTNT1Combine(SDNode *N, SelectionDAG &DAG) {
21784   SDLoc DL(N);
21785 
21786   SDValue Data = N->getOperand(2);
21787   EVT DataVT = Data.getValueType();
21788   EVT PtrTy = N->getOperand(4).getValueType();
21789 
21790   if (DataVT.isFloatingPoint())
21791     Data = DAG.getNode(ISD::BITCAST, DL, DataVT.changeTypeToInteger(), Data);
21792 
21793   auto *MINode = cast<MemIntrinsicSDNode>(N);
21794   return DAG.getMaskedStore(MINode->getChain(), DL, Data, MINode->getOperand(4),
21795                             DAG.getUNDEF(PtrTy), MINode->getOperand(3),
21796                             MINode->getMemoryVT(), MINode->getMemOperand(),
21797                             ISD::UNINDEXED, false, false);
21798 }
21799 
21800 /// Replace a vector store of a splat of zeros by scalar stores of WZR/XZR.  The
21801 /// load store optimizer pass will merge them to store pair stores.  This should
21802 /// be better than a movi to create the vector zero followed by a vector store
21803 /// if the zero constant is not re-used, since one instruction and one register
21804 /// live range will be removed.
21805 ///
21806 /// For example, the final generated code should be:
21807 ///
21808 ///   stp xzr, xzr, [x0]
21809 ///
21810 /// instead of:
21811 ///
21812 ///   movi v0.2d, #0
21813 ///   str q0, [x0]
21814 ///
21815 static SDValue replaceZeroVectorStore(SelectionDAG &DAG, StoreSDNode &St) {
21816   SDValue StVal = St.getValue();
21817   EVT VT = StVal.getValueType();
21818 
21819   // Avoid scalarizing zero splat stores for scalable vectors.
21820   if (VT.isScalableVector())
21821     return SDValue();
21822 
21823   // It is beneficial to scalarize a zero splat store for 2 or 3 i64 elements or
21824   // 2, 3 or 4 i32 elements.
21825   int NumVecElts = VT.getVectorNumElements();
21826   if (!(((NumVecElts == 2 || NumVecElts == 3) &&
21827          VT.getVectorElementType().getSizeInBits() == 64) ||
21828         ((NumVecElts == 2 || NumVecElts == 3 || NumVecElts == 4) &&
21829          VT.getVectorElementType().getSizeInBits() == 32)))
21830     return SDValue();
21831 
21832   if (StVal.getOpcode() != ISD::BUILD_VECTOR)
21833     return SDValue();
21834 
21835   // If the zero constant has more than one use then the vector store could be
21836   // better since the constant mov will be amortized and stp q instructions
21837   // should be able to be formed.
21838   if (!StVal.hasOneUse())
21839     return SDValue();
21840 
21841   // If the store is truncating then it's going down to i16 or smaller, which
21842   // means it can be implemented in a single store anyway.
21843   if (St.isTruncatingStore())
21844     return SDValue();
21845 
21846   // If the immediate offset of the address operand is too large for the stp
21847   // instruction, then bail out.
21848   if (DAG.isBaseWithConstantOffset(St.getBasePtr())) {
21849     int64_t Offset = St.getBasePtr()->getConstantOperandVal(1);
21850     if (Offset < -512 || Offset > 504)
21851       return SDValue();
21852   }
21853 
21854   for (int I = 0; I < NumVecElts; ++I) {
21855     SDValue EltVal = StVal.getOperand(I);
21856     if (!isNullConstant(EltVal) && !isNullFPConstant(EltVal))
21857       return SDValue();
21858   }
21859 
21860   // Use a CopyFromReg WZR/XZR here to prevent
21861   // DAGCombiner::MergeConsecutiveStores from undoing this transformation.
21862   SDLoc DL(&St);
21863   unsigned ZeroReg;
21864   EVT ZeroVT;
21865   if (VT.getVectorElementType().getSizeInBits() == 32) {
21866     ZeroReg = AArch64::WZR;
21867     ZeroVT = MVT::i32;
21868   } else {
21869     ZeroReg = AArch64::XZR;
21870     ZeroVT = MVT::i64;
21871   }
21872   SDValue SplatVal =
21873       DAG.getCopyFromReg(DAG.getEntryNode(), DL, ZeroReg, ZeroVT);
21874   return splitStoreSplat(DAG, St, SplatVal, NumVecElts);
21875 }
21876 
21877 /// Replace a splat of a scalar to a vector store by scalar stores of the scalar
21878 /// value. The load store optimizer pass will merge them to store pair stores.
21879 /// This has better performance than a splat of the scalar followed by a split
21880 /// vector store. Even if the stores are not merged it is four stores vs a dup,
21881 /// followed by an ext.b and two stores.
21882 static SDValue replaceSplatVectorStore(SelectionDAG &DAG, StoreSDNode &St) {
21883   SDValue StVal = St.getValue();
21884   EVT VT = StVal.getValueType();
21885 
21886   // Don't replace floating point stores; they possibly won't be transformed to
21887   // stp because of the store pair suppress pass.
21888   if (VT.isFloatingPoint())
21889     return SDValue();
21890 
21891   // We can express a splat as store pair(s) for 2 or 4 elements.
21892   unsigned NumVecElts = VT.getVectorNumElements();
21893   if (NumVecElts != 4 && NumVecElts != 2)
21894     return SDValue();
21895 
21896   // If the store is truncating then it's going down to i16 or smaller, which
21897   // means it can be implemented in a single store anyway.
21898   if (St.isTruncatingStore())
21899     return SDValue();
21900 
21901   // Check that this is a splat.
21902   // Make sure that each of the relevant vector element locations are inserted
21903   // to, i.e. 0 and 1 for v2i64 and 0, 1, 2, 3 for v4i32.
21904   std::bitset<4> IndexNotInserted((1 << NumVecElts) - 1);
21905   SDValue SplatVal;
21906   for (unsigned I = 0; I < NumVecElts; ++I) {
21907     // Check for insert vector elements.
21908     if (StVal.getOpcode() != ISD::INSERT_VECTOR_ELT)
21909       return SDValue();
21910 
21911     // Check that same value is inserted at each vector element.
21912     if (I == 0)
21913       SplatVal = StVal.getOperand(1);
21914     else if (StVal.getOperand(1) != SplatVal)
21915       return SDValue();
21916 
21917     // Check insert element index.
21918     ConstantSDNode *CIndex = dyn_cast<ConstantSDNode>(StVal.getOperand(2));
21919     if (!CIndex)
21920       return SDValue();
21921     uint64_t IndexVal = CIndex->getZExtValue();
21922     if (IndexVal >= NumVecElts)
21923       return SDValue();
21924     IndexNotInserted.reset(IndexVal);
21925 
21926     StVal = StVal.getOperand(0);
21927   }
21928   // Check that all vector element locations were inserted to.
21929   if (IndexNotInserted.any())
21930     return SDValue();
21931 
21932   return splitStoreSplat(DAG, St, SplatVal, NumVecElts);
21933 }
21934 
21935 static SDValue splitStores(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
21936                            SelectionDAG &DAG,
21937                            const AArch64Subtarget *Subtarget) {
21938 
21939   StoreSDNode *S = cast<StoreSDNode>(N);
21940   if (S->isVolatile() || S->isIndexed())
21941     return SDValue();
21942 
21943   SDValue StVal = S->getValue();
21944   EVT VT = StVal.getValueType();
21945 
21946   if (!VT.isFixedLengthVector())
21947     return SDValue();
21948 
21949   // If we get a splat of zeros, convert this vector store to a store of
21950   // scalars. They will be merged into store pairs of xzr thereby removing one
21951   // instruction and one register.
21952   if (SDValue ReplacedZeroSplat = replaceZeroVectorStore(DAG, *S))
21953     return ReplacedZeroSplat;
21954 
21955   // FIXME: The logic for deciding if an unaligned store should be split should
21956   // be included in TLI.allowsMisalignedMemoryAccesses(), and there should be
21957   // a call to that function here.
21958 
21959   if (!Subtarget->isMisaligned128StoreSlow())
21960     return SDValue();
21961 
21962   // Don't split at -Oz.
21963   if (DAG.getMachineFunction().getFunction().hasMinSize())
21964     return SDValue();
21965 
21966   // Don't split v2i64 vectors. Memcpy lowering produces those and splitting
21967   // those up regresses performance on micro-benchmarks and olden/bh.
21968   if (VT.getVectorNumElements() < 2 || VT == MVT::v2i64)
21969     return SDValue();
21970 
21971   // Split unaligned 16B stores. They are terrible for performance.
21972   // Don't split stores with alignment of 1 or 2. Code that uses clang vector
21973   // extensions can use this to mark that it does not want splitting to happen
21974   // (by underspecifying alignment to be 1 or 2). Furthermore, the chance of
21975   // eliminating alignment hazards is only 1 in 8 for alignment of 2.
21976   if (VT.getSizeInBits() != 128 || S->getAlign() >= Align(16) ||
21977       S->getAlign() <= Align(2))
21978     return SDValue();
21979 
21980   // If we get a splat of a scalar convert this vector store to a store of
21981   // scalars. They will be merged into store pairs thereby removing two
21982   // instructions.
21983   if (SDValue ReplacedSplat = replaceSplatVectorStore(DAG, *S))
21984     return ReplacedSplat;
21985 
21986   SDLoc DL(S);
21987 
21988   // Split VT into two.
21989   EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
21990   unsigned NumElts = HalfVT.getVectorNumElements();
21991   SDValue SubVector0 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, StVal,
21992                                    DAG.getConstant(0, DL, MVT::i64));
21993   SDValue SubVector1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, StVal,
21994                                    DAG.getConstant(NumElts, DL, MVT::i64));
21995   SDValue BasePtr = S->getBasePtr();
21996   SDValue NewST1 =
21997       DAG.getStore(S->getChain(), DL, SubVector0, BasePtr, S->getPointerInfo(),
21998                    S->getAlign(), S->getMemOperand()->getFlags());
21999   SDValue OffsetPtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr,
22000                                   DAG.getConstant(8, DL, MVT::i64));
22001   return DAG.getStore(NewST1.getValue(0), DL, SubVector1, OffsetPtr,
22002                       S->getPointerInfo(), S->getAlign(),
22003                       S->getMemOperand()->getFlags());
22004 }
22005 
22006 static SDValue performSpliceCombine(SDNode *N, SelectionDAG &DAG) {
22007   assert(N->getOpcode() == AArch64ISD::SPLICE && "Unexpected Opcode!");
22008 
22009   // splice(pg, op1, undef) -> op1
22010   if (N->getOperand(2).isUndef())
22011     return N->getOperand(1);
22012 
22013   return SDValue();
22014 }
22015 
22016 static SDValue performUnpackCombine(SDNode *N, SelectionDAG &DAG,
22017                                     const AArch64Subtarget *Subtarget) {
22018   assert((N->getOpcode() == AArch64ISD::UUNPKHI ||
22019           N->getOpcode() == AArch64ISD::UUNPKLO) &&
22020          "Unexpected Opcode!");
22021 
22022   // uunpklo/hi undef -> undef
22023   if (N->getOperand(0).isUndef())
22024     return DAG.getUNDEF(N->getValueType(0));
22025 
22026   // If this is a masked load followed by an UUNPKLO, fold this into a masked
22027   // extending load.  We can do this even if this is already a masked
22028   // {z,}extload.
22029   if (N->getOperand(0).getOpcode() == ISD::MLOAD &&
22030       N->getOpcode() == AArch64ISD::UUNPKLO) {
22031     MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(N->getOperand(0));
22032     SDValue Mask = MLD->getMask();
22033     SDLoc DL(N);
22034 
22035     if (MLD->isUnindexed() && MLD->getExtensionType() != ISD::SEXTLOAD &&
22036         SDValue(MLD, 0).hasOneUse() && Mask->getOpcode() == AArch64ISD::PTRUE &&
22037         (MLD->getPassThru()->isUndef() ||
22038          isZerosVector(MLD->getPassThru().getNode()))) {
22039       unsigned MinSVESize = Subtarget->getMinSVEVectorSizeInBits();
22040       unsigned PgPattern = Mask->getConstantOperandVal(0);
22041       EVT VT = N->getValueType(0);
22042 
22043       // Ensure we can double the size of the predicate pattern
22044       unsigned NumElts = getNumElementsFromSVEPredPattern(PgPattern);
22045       if (NumElts &&
22046           NumElts * VT.getVectorElementType().getSizeInBits() <= MinSVESize) {
22047         Mask =
22048             getPTrue(DAG, DL, VT.changeVectorElementType(MVT::i1), PgPattern);
22049         SDValue PassThru = DAG.getConstant(0, DL, VT);
22050         SDValue NewLoad = DAG.getMaskedLoad(
22051             VT, DL, MLD->getChain(), MLD->getBasePtr(), MLD->getOffset(), Mask,
22052             PassThru, MLD->getMemoryVT(), MLD->getMemOperand(),
22053             MLD->getAddressingMode(), ISD::ZEXTLOAD);
22054 
22055         DAG.ReplaceAllUsesOfValueWith(SDValue(MLD, 1), NewLoad.getValue(1));
22056 
22057         return NewLoad;
22058       }
22059     }
22060   }
22061 
22062   return SDValue();
22063 }
22064 
22065 static bool isHalvingTruncateAndConcatOfLegalIntScalableType(SDNode *N) {
22066   if (N->getOpcode() != AArch64ISD::UZP1)
22067     return false;
22068   SDValue Op0 = N->getOperand(0);
22069   EVT SrcVT = Op0->getValueType(0);
22070   EVT DstVT = N->getValueType(0);
22071   return (SrcVT == MVT::nxv8i16 && DstVT == MVT::nxv16i8) ||
22072          (SrcVT == MVT::nxv4i32 && DstVT == MVT::nxv8i16) ||
22073          (SrcVT == MVT::nxv2i64 && DstVT == MVT::nxv4i32);
22074 }
22075 
22076 // Try to combine rounding shifts where the operands come from an extend, and
22077 // the result is truncated and combined into one vector.
22078 //   uzp1(rshrnb(uunpklo(X),C), rshrnb(uunpkhi(X), C)) -> urshr(X, C)
22079 static SDValue tryCombineExtendRShTrunc(SDNode *N, SelectionDAG &DAG) {
22080   assert(N->getOpcode() == AArch64ISD::UZP1 && "Only UZP1 expected.");
22081   SDValue Op0 = N->getOperand(0);
22082   SDValue Op1 = N->getOperand(1);
22083   EVT ResVT = N->getValueType(0);
22084 
22085   unsigned RshOpc = Op0.getOpcode();
22086   if (RshOpc != AArch64ISD::RSHRNB_I)
22087     return SDValue();
22088 
22089   // Same op code and imm value?
22090   SDValue ShiftValue = Op0.getOperand(1);
22091   if (RshOpc != Op1.getOpcode() || ShiftValue != Op1.getOperand(1))
22092     return SDValue();
22093 
22094   // Same unextended operand value?
22095   SDValue Lo = Op0.getOperand(0);
22096   SDValue Hi = Op1.getOperand(0);
22097   if (Lo.getOpcode() != AArch64ISD::UUNPKLO &&
22098       Hi.getOpcode() != AArch64ISD::UUNPKHI)
22099     return SDValue();
22100   SDValue OrigArg = Lo.getOperand(0);
22101   if (OrigArg != Hi.getOperand(0))
22102     return SDValue();
22103 
22104   SDLoc DL(N);
22105   return DAG.getNode(AArch64ISD::URSHR_I_PRED, DL, ResVT,
22106                      getPredicateForVector(DAG, DL, ResVT), OrigArg,
22107                      ShiftValue);
22108 }
22109 
22110 // Try to simplify:
22111 //    t1 = nxv8i16 add(X, 1 << (ShiftValue - 1))
22112 //    t2 = nxv8i16 srl(t1, ShiftValue)
22113 // to
22114 //    t1 = nxv8i16 rshrnb(X, shiftvalue).
22115 // rshrnb will zero the top half bits of each element. Therefore, this combine
22116 // should only be performed when a following instruction with the rshrnb
22117 // as an operand does not care about the top half of each element. For example,
22118 // a uzp1 or a truncating store.
22119 static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
22120                                          const AArch64Subtarget *Subtarget) {
22121   EVT VT = Srl->getValueType(0);
22122   if (!VT.isScalableVector() || !Subtarget->hasSVE2())
22123     return SDValue();
22124 
22125   EVT ResVT;
22126   if (VT == MVT::nxv8i16)
22127     ResVT = MVT::nxv16i8;
22128   else if (VT == MVT::nxv4i32)
22129     ResVT = MVT::nxv8i16;
22130   else if (VT == MVT::nxv2i64)
22131     ResVT = MVT::nxv4i32;
22132   else
22133     return SDValue();
22134 
22135   SDLoc DL(Srl);
22136   unsigned ShiftValue;
22137   SDValue RShOperand;
22138   if (!canLowerSRLToRoundingShiftForVT(Srl, ResVT, DAG, ShiftValue, RShOperand))
22139     return SDValue();
22140   SDValue Rshrnb = DAG.getNode(
22141       AArch64ISD::RSHRNB_I, DL, ResVT,
22142       {RShOperand, DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
22143   return DAG.getNode(ISD::BITCAST, DL, VT, Rshrnb);
22144 }
22145 
22146 static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
22147                                  const AArch64Subtarget *Subtarget) {
22148   SDLoc DL(N);
22149   SDValue Op0 = N->getOperand(0);
22150   SDValue Op1 = N->getOperand(1);
22151   EVT ResVT = N->getValueType(0);
22152 
22153   // uzp(extract_lo(x), extract_hi(x)) -> extract_lo(uzp x, x)
22154   if (Op0.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
22155       Op1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
22156       Op0.getOperand(0) == Op1.getOperand(0)) {
22157 
22158     SDValue SourceVec = Op0.getOperand(0);
22159     uint64_t ExtIdx0 = Op0.getConstantOperandVal(1);
22160     uint64_t ExtIdx1 = Op1.getConstantOperandVal(1);
22161     uint64_t NumElements = SourceVec.getValueType().getVectorMinNumElements();
22162     if (ExtIdx0 == 0 && ExtIdx1 == NumElements / 2) {
22163       EVT OpVT = Op0.getOperand(1).getValueType();
22164       EVT WidenedResVT = ResVT.getDoubleNumVectorElementsVT(*DAG.getContext());
22165       SDValue Uzp = DAG.getNode(N->getOpcode(), DL, WidenedResVT, SourceVec,
22166                                 DAG.getUNDEF(WidenedResVT));
22167       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ResVT, Uzp,
22168                          DAG.getConstant(0, DL, OpVT));
22169     }
22170   }
22171 
22172   // The following optimizations only work with uzp1.
22173   if (N->getOpcode() == AArch64ISD::UZP2)
22174     return SDValue();
22175 
22176   // uzp1(x, undef) -> concat(truncate(x), undef)
22177   if (Op1.getOpcode() == ISD::UNDEF) {
22178     EVT BCVT = MVT::Other, HalfVT = MVT::Other;
22179     switch (ResVT.getSimpleVT().SimpleTy) {
22180     default:
22181       break;
22182     case MVT::v16i8:
22183       BCVT = MVT::v8i16;
22184       HalfVT = MVT::v8i8;
22185       break;
22186     case MVT::v8i16:
22187       BCVT = MVT::v4i32;
22188       HalfVT = MVT::v4i16;
22189       break;
22190     case MVT::v4i32:
22191       BCVT = MVT::v2i64;
22192       HalfVT = MVT::v2i32;
22193       break;
22194     }
22195     if (BCVT != MVT::Other) {
22196       SDValue BC = DAG.getBitcast(BCVT, Op0);
22197       SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, HalfVT, BC);
22198       return DAG.getNode(ISD::CONCAT_VECTORS, DL, ResVT, Trunc,
22199                          DAG.getUNDEF(HalfVT));
22200     }
22201   }
22202 
22203   if (SDValue Urshr = tryCombineExtendRShTrunc(N, DAG))
22204     return Urshr;
22205 
22206   if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op0, DAG, Subtarget))
22207     return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Rshrnb, Op1);
22208 
22209   if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op1, DAG, Subtarget))
22210     return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Rshrnb);
22211 
22212   // uzp1(unpklo(uzp1(x, y)), z) => uzp1(x, z)
22213   if (Op0.getOpcode() == AArch64ISD::UUNPKLO) {
22214     if (Op0.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
22215       SDValue X = Op0.getOperand(0).getOperand(0);
22216       return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, X, Op1);
22217     }
22218   }
22219 
22220   // uzp1(x, unpkhi(uzp1(y, z))) => uzp1(x, z)
22221   if (Op1.getOpcode() == AArch64ISD::UUNPKHI) {
22222     if (Op1.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
22223       SDValue Z = Op1.getOperand(0).getOperand(1);
22224       return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Z);
22225     }
22226   }
22227 
22228   // These optimizations only work on little endian.
22229   if (!DAG.getDataLayout().isLittleEndian())
22230     return SDValue();
22231 
22232   // uzp1(bitcast(x), bitcast(y)) -> uzp1(x, y)
22233   // Example:
22234   // nxv4i32 = uzp1 bitcast(nxv4i32 x to nxv2i64), bitcast(nxv4i32 y to nxv2i64)
22235   // to
22236   // nxv4i32 = uzp1 nxv4i32 x, nxv4i32 y
22237   if (isHalvingTruncateAndConcatOfLegalIntScalableType(N) &&
22238       Op0.getOpcode() == ISD::BITCAST && Op1.getOpcode() == ISD::BITCAST) {
22239     if (Op0.getOperand(0).getValueType() == Op1.getOperand(0).getValueType()) {
22240       return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0.getOperand(0),
22241                          Op1.getOperand(0));
22242     }
22243   }
22244 
22245   if (ResVT != MVT::v2i32 && ResVT != MVT::v4i16 && ResVT != MVT::v8i8)
22246     return SDValue();
22247 
22248   SDValue SourceOp0 = peekThroughBitcasts(Op0);
22249   SDValue SourceOp1 = peekThroughBitcasts(Op1);
22250 
22251   // truncating uzp1(x, y) -> xtn(concat (x, y))
22252   if (SourceOp0.getValueType() == SourceOp1.getValueType()) {
22253     EVT Op0Ty = SourceOp0.getValueType();
22254     if ((ResVT == MVT::v4i16 && Op0Ty == MVT::v2i32) ||
22255         (ResVT == MVT::v8i8 && Op0Ty == MVT::v4i16)) {
22256       SDValue Concat =
22257           DAG.getNode(ISD::CONCAT_VECTORS, DL,
22258                       Op0Ty.getDoubleNumVectorElementsVT(*DAG.getContext()),
22259                       SourceOp0, SourceOp1);
22260       return DAG.getNode(ISD::TRUNCATE, DL, ResVT, Concat);
22261     }
22262   }
22263 
22264   // uzp1(xtn x, xtn y) -> xtn(uzp1 (x, y))
22265   if (SourceOp0.getOpcode() != ISD::TRUNCATE ||
22266       SourceOp1.getOpcode() != ISD::TRUNCATE)
22267     return SDValue();
22268   SourceOp0 = SourceOp0.getOperand(0);
22269   SourceOp1 = SourceOp1.getOperand(0);
22270 
22271   if (SourceOp0.getValueType() != SourceOp1.getValueType() ||
22272       !SourceOp0.getValueType().isSimple())
22273     return SDValue();
22274 
22275   EVT ResultTy;
22276 
22277   switch (SourceOp0.getSimpleValueType().SimpleTy) {
22278   case MVT::v2i64:
22279     ResultTy = MVT::v4i32;
22280     break;
22281   case MVT::v4i32:
22282     ResultTy = MVT::v8i16;
22283     break;
22284   case MVT::v8i16:
22285     ResultTy = MVT::v16i8;
22286     break;
22287   default:
22288     return SDValue();
22289   }
22290 
22291   SDValue UzpOp0 = DAG.getNode(ISD::BITCAST, DL, ResultTy, SourceOp0);
22292   SDValue UzpOp1 = DAG.getNode(ISD::BITCAST, DL, ResultTy, SourceOp1);
22293   SDValue UzpResult =
22294       DAG.getNode(AArch64ISD::UZP1, DL, UzpOp0.getValueType(), UzpOp0, UzpOp1);
22295 
22296   EVT BitcastResultTy;
22297 
22298   switch (ResVT.getSimpleVT().SimpleTy) {
22299   case MVT::v2i32:
22300     BitcastResultTy = MVT::v2i64;
22301     break;
22302   case MVT::v4i16:
22303     BitcastResultTy = MVT::v4i32;
22304     break;
22305   case MVT::v8i8:
22306     BitcastResultTy = MVT::v8i16;
22307     break;
22308   default:
22309     llvm_unreachable("Should be one of {v2i32, v4i16, v8i8}");
22310   }
22311 
22312   return DAG.getNode(ISD::TRUNCATE, DL, ResVT,
22313                      DAG.getNode(ISD::BITCAST, DL, BitcastResultTy, UzpResult));
22314 }
22315 
22316 static SDValue performGLD1Combine(SDNode *N, SelectionDAG &DAG) {
22317   unsigned Opc = N->getOpcode();
22318 
22319   assert(((Opc >= AArch64ISD::GLD1_MERGE_ZERO && // unsigned gather loads
22320            Opc <= AArch64ISD::GLD1_IMM_MERGE_ZERO) ||
22321           (Opc >= AArch64ISD::GLD1S_MERGE_ZERO && // signed gather loads
22322            Opc <= AArch64ISD::GLD1S_IMM_MERGE_ZERO)) &&
22323          "Invalid opcode.");
22324 
22325   const bool Scaled = Opc == AArch64ISD::GLD1_SCALED_MERGE_ZERO ||
22326                       Opc == AArch64ISD::GLD1S_SCALED_MERGE_ZERO;
22327   const bool Signed = Opc == AArch64ISD::GLD1S_MERGE_ZERO ||
22328                       Opc == AArch64ISD::GLD1S_SCALED_MERGE_ZERO;
22329   const bool Extended = Opc == AArch64ISD::GLD1_SXTW_MERGE_ZERO ||
22330                         Opc == AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO ||
22331                         Opc == AArch64ISD::GLD1_UXTW_MERGE_ZERO ||
22332                         Opc == AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO;
22333 
22334   SDLoc DL(N);
22335   SDValue Chain = N->getOperand(0);
22336   SDValue Pg = N->getOperand(1);
22337   SDValue Base = N->getOperand(2);
22338   SDValue Offset = N->getOperand(3);
22339   SDValue Ty = N->getOperand(4);
22340 
22341   EVT ResVT = N->getValueType(0);
22342 
22343   const auto OffsetOpc = Offset.getOpcode();
22344   const bool OffsetIsZExt =
22345       OffsetOpc == AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU;
22346   const bool OffsetIsSExt =
22347       OffsetOpc == AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU;
22348 
22349   // Fold sign/zero extensions of vector offsets into GLD1 nodes where possible.
22350   if (!Extended && (OffsetIsSExt || OffsetIsZExt)) {
22351     SDValue ExtPg = Offset.getOperand(0);
22352     VTSDNode *ExtFrom = cast<VTSDNode>(Offset.getOperand(2).getNode());
22353     EVT ExtFromEVT = ExtFrom->getVT().getVectorElementType();
22354 
22355     // If the predicate for the sign- or zero-extended offset is the
22356     // same as the predicate used for this load and the sign-/zero-extension
22357     // was from 32 bits, the extension can be folded into the gather itself.
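    // Illustrative example: an unscaled gather whose vector offset is
    // sign-extended from 32 bits becomes the corresponding *_SXTW gather
    // variant operating directly on the unextended 32-bit offsets.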
22358     if (ExtPg == Pg && ExtFromEVT == MVT::i32) {
22359       SDValue UnextendedOffset = Offset.getOperand(1);
22360 
22361       unsigned NewOpc = getGatherVecOpcode(Scaled, OffsetIsSExt, true);
22362       if (Signed)
22363         NewOpc = getSignExtendedGatherOpcode(NewOpc);
22364 
22365       return DAG.getNode(NewOpc, DL, {ResVT, MVT::Other},
22366                          {Chain, Pg, Base, UnextendedOffset, Ty});
22367     }
22368   }
22369 
22370   return SDValue();
22371 }
22372 
22373 /// Optimize a vector shift instruction and its operand if shifted out
22374 /// bits are not used.
22375 static SDValue performVectorShiftCombine(SDNode *N,
22376                                          const AArch64TargetLowering &TLI,
22377                                          TargetLowering::DAGCombinerInfo &DCI) {
22378   assert(N->getOpcode() == AArch64ISD::VASHR ||
22379          N->getOpcode() == AArch64ISD::VLSHR);
22380 
22381   SDValue Op = N->getOperand(0);
22382   unsigned OpScalarSize = Op.getScalarValueSizeInBits();
22383 
22384   unsigned ShiftImm = N->getConstantOperandVal(1);
22385   assert(OpScalarSize > ShiftImm && "Invalid shift imm");
22386 
22387   // Remove sign_extend_inreg, i.e. (ashr (shl x, c), c), based on the number of sign bits.
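  // For example, (vashr (vshl X, #8), #8) can be replaced by X when every
  // lane of X already has more than 8 sign bits.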
22388   if (N->getOpcode() == AArch64ISD::VASHR &&
22389       Op.getOpcode() == AArch64ISD::VSHL &&
22390       N->getOperand(1) == Op.getOperand(1))
22391     if (DCI.DAG.ComputeNumSignBits(Op.getOperand(0)) > ShiftImm)
22392       return Op.getOperand(0);
22393 
22394   // If the shift is exact, the shifted out bits matter.
22395   if (N->getFlags().hasExact())
22396     return SDValue();
22397 
22398   APInt ShiftedOutBits = APInt::getLowBitsSet(OpScalarSize, ShiftImm);
22399   APInt DemandedMask = ~ShiftedOutBits;
22400 
22401   if (TLI.SimplifyDemandedBits(Op, DemandedMask, DCI))
22402     return SDValue(N, 0);
22403 
22404   return SDValue();
22405 }
22406 
22407 static SDValue performSunpkloCombine(SDNode *N, SelectionDAG &DAG) {
22408   // sunpklo(sext(pred)) -> sext(extract_low_half(pred))
22409   // This transform works in partnership with performSetCCPunpkCombine to
22410   // remove unnecessary transfer of predicates into standard registers and back
22411   if (N->getOperand(0).getOpcode() == ISD::SIGN_EXTEND &&
22412       N->getOperand(0)->getOperand(0)->getValueType(0).getScalarType() ==
22413           MVT::i1) {
22414     SDValue CC = N->getOperand(0)->getOperand(0);
22415     auto VT = CC->getValueType(0).getHalfNumVectorElementsVT(*DAG.getContext());
22416     SDValue Unpk = DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), VT, CC,
22417                                DAG.getVectorIdxConstant(0, SDLoc(N)));
22418     return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), N->getValueType(0), Unpk);
22419   }
22420 
22421   return SDValue();
22422 }
22423 
22424 /// Target-specific DAG combine function for post-increment LD1 (lane) and
22425 /// post-increment LD1R.
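/// Illustrative example: (insert_vector_elt v, (load addr), lane) together
/// with a separate (add addr, #elt-size-in-bytes) user of the address can be
/// rewritten as LD1LANEpost, which performs the load and additionally
/// produces the incremented address.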
22426 static SDValue performPostLD1Combine(SDNode *N,
22427                                      TargetLowering::DAGCombinerInfo &DCI,
22428                                      bool IsLaneOp) {
22429   if (DCI.isBeforeLegalizeOps())
22430     return SDValue();
22431 
22432   SelectionDAG &DAG = DCI.DAG;
22433   EVT VT = N->getValueType(0);
22434 
22435   if (!VT.is128BitVector() && !VT.is64BitVector())
22436     return SDValue();
22437 
22438   unsigned LoadIdx = IsLaneOp ? 1 : 0;
22439   SDNode *LD = N->getOperand(LoadIdx).getNode();
22440   // If it is not a LOAD, we cannot do this combine.
22441   if (LD->getOpcode() != ISD::LOAD)
22442     return SDValue();
22443 
22444   // The vector lane must be a constant in the LD1LANE opcode.
22445   SDValue Lane;
22446   if (IsLaneOp) {
22447     Lane = N->getOperand(2);
22448     auto *LaneC = dyn_cast<ConstantSDNode>(Lane);
22449     if (!LaneC || LaneC->getZExtValue() >= VT.getVectorNumElements())
22450       return SDValue();
22451   }
22452 
22453   LoadSDNode *LoadSDN = cast<LoadSDNode>(LD);
22454   EVT MemVT = LoadSDN->getMemoryVT();
22455   // Check if memory operand is the same type as the vector element.
22456   if (MemVT != VT.getVectorElementType())
22457     return SDValue();
22458 
22459   // Check if there are other uses. If so, do not combine as it will introduce
22460   // an extra load.
22461   for (SDNode::use_iterator UI = LD->use_begin(), UE = LD->use_end(); UI != UE;
22462        ++UI) {
22463     if (UI.getUse().getResNo() == 1) // Ignore uses of the chain result.
22464       continue;
22465     if (*UI != N)
22466       return SDValue();
22467   }
22468 
22469   // If there is one use and it can splat the value, prefer that operation.
22470   // TODO: This could be expanded to more operations if they reliably use the
22471   // index variants.
22472   if (N->hasOneUse()) {
22473     unsigned UseOpc = N->use_begin()->getOpcode();
22474     if (UseOpc == ISD::FMUL || UseOpc == ISD::FMA)
22475       return SDValue();
22476   }
22477 
22478   SDValue Addr = LD->getOperand(1);
22479   SDValue Vector = N->getOperand(0);
22480   // Search for a use of the address operand that is an increment.
22481   for (SDNode::use_iterator UI = Addr.getNode()->use_begin(), UE =
22482        Addr.getNode()->use_end(); UI != UE; ++UI) {
22483     SDNode *User = *UI;
22484     if (User->getOpcode() != ISD::ADD
22485         || UI.getUse().getResNo() != Addr.getResNo())
22486       continue;
22487 
22488     // If the increment is a constant, it must match the memory ref size.
22489     SDValue Inc = User->getOperand(User->getOperand(0) == Addr ? 1 : 0);
22490     if (ConstantSDNode *CInc = dyn_cast<ConstantSDNode>(Inc.getNode())) {
22491       uint32_t IncVal = CInc->getZExtValue();
22492       unsigned NumBytes = VT.getScalarSizeInBits() / 8;
22493       if (IncVal != NumBytes)
22494         continue;
22495       Inc = DAG.getRegister(AArch64::XZR, MVT::i64);
22496     }
22497 
22498     // To avoid constructing a cycle, make sure that neither the load nor the
22499     // add is a predecessor of the other or of the Vector.
22500     SmallPtrSet<const SDNode *, 32> Visited;
22501     SmallVector<const SDNode *, 16> Worklist;
22502     Visited.insert(Addr.getNode());
22503     Worklist.push_back(User);
22504     Worklist.push_back(LD);
22505     Worklist.push_back(Vector.getNode());
22506     if (SDNode::hasPredecessorHelper(LD, Visited, Worklist) ||
22507         SDNode::hasPredecessorHelper(User, Visited, Worklist))
22508       continue;
22509 
22510     SmallVector<SDValue, 8> Ops;
22511     Ops.push_back(LD->getOperand(0));  // Chain
22512     if (IsLaneOp) {
22513       Ops.push_back(Vector);           // The vector to be inserted
22514       Ops.push_back(Lane);             // The lane to be inserted in the vector
22515     }
22516     Ops.push_back(Addr);
22517     Ops.push_back(Inc);
22518 
22519     EVT Tys[3] = { VT, MVT::i64, MVT::Other };
22520     SDVTList SDTys = DAG.getVTList(Tys);
22521     unsigned NewOp = IsLaneOp ? AArch64ISD::LD1LANEpost : AArch64ISD::LD1DUPpost;
22522     SDValue UpdN = DAG.getMemIntrinsicNode(NewOp, SDLoc(N), SDTys, Ops,
22523                                            MemVT,
22524                                            LoadSDN->getMemOperand());
22525 
22526     // Update the uses.
22527     SDValue NewResults[] = {
22528         SDValue(LD, 0),            // The result of load
22529         SDValue(UpdN.getNode(), 2) // Chain
22530     };
22531     DCI.CombineTo(LD, NewResults);
22532     DCI.CombineTo(N, SDValue(UpdN.getNode(), 0));     // Dup/Inserted Result
22533     DCI.CombineTo(User, SDValue(UpdN.getNode(), 1));  // Write back register
22534 
22535     break;
22536   }
22537   return SDValue();
22538 }
22539 
22540 /// Simplify ``Addr`` given that the top byte of it is ignored by HW during
22541 /// address translation.
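/// For example, (and X, 0x00FF'FFFF'FFFF'FFFF) used purely as a memory
/// address can be simplified to X, because only bits [55:0] are demanded for
/// address translation when TBI is enabled.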
22542 static bool performTBISimplification(SDValue Addr,
22543                                      TargetLowering::DAGCombinerInfo &DCI,
22544                                      SelectionDAG &DAG) {
22545   APInt DemandedMask = APInt::getLowBitsSet(64, 56);
22546   KnownBits Known;
22547   TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
22548                                         !DCI.isBeforeLegalizeOps());
22549   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22550   if (TLI.SimplifyDemandedBits(Addr, DemandedMask, Known, TLO)) {
22551     DCI.CommitTargetLoweringOpt(TLO);
22552     return true;
22553   }
22554   return false;
22555 }
22556 
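// Fold a truncating store of an extended value into a plain store of the
// original value when the memory type matches the pre-extension type, e.g.
//   truncstore<v4i16> (zext v4i16 X to v4i32), ptr  ->  store<v4i16> X, ptr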
22557 static SDValue foldTruncStoreOfExt(SelectionDAG &DAG, SDNode *N) {
22558   assert((N->getOpcode() == ISD::STORE || N->getOpcode() == ISD::MSTORE) &&
22559          "Expected STORE dag node in input!");
22560 
22561   if (auto Store = dyn_cast<StoreSDNode>(N)) {
22562     if (!Store->isTruncatingStore() || Store->isIndexed())
22563       return SDValue();
22564     SDValue Ext = Store->getValue();
22565     auto ExtOpCode = Ext.getOpcode();
22566     if (ExtOpCode != ISD::ZERO_EXTEND && ExtOpCode != ISD::SIGN_EXTEND &&
22567         ExtOpCode != ISD::ANY_EXTEND)
22568       return SDValue();
22569     SDValue Orig = Ext->getOperand(0);
22570     if (Store->getMemoryVT() != Orig.getValueType())
22571       return SDValue();
22572     return DAG.getStore(Store->getChain(), SDLoc(Store), Orig,
22573                         Store->getBasePtr(), Store->getMemOperand());
22574   }
22575 
22576   return SDValue();
22577 }
22578 
22579 // A custom combine to lower load <3 x i8> as the more efficient sequence
22580 // below:
22581 //    ldrb wX, [x0, #2]
22582 //    ldrh wY, [x0]
22583 //    orr wX, wY, wX, lsl #16
22584 //    fmov s0, wX
22585 //
22586 // Note that an alternative sequence with even fewer (although usually more
22587 // complex/expensive) instructions would be:
22588 //   ld1r.4h { v0 }, [x0], #2
22589 //   ld1.b { v0 }[2], [x0]
22590 //
22591 // Generating this sequence unfortunately results in noticeably worse codegen
22592 // for code that extends the loaded v3i8, due to legalization breaking vector
22593 // shuffle detection in a way that is very difficult to work around.
22594 // TODO: Revisit once v3i8 legalization has been improved in general.
22595 static SDValue combineV3I8LoadExt(LoadSDNode *LD, SelectionDAG &DAG) {
22596   EVT MemVT = LD->getMemoryVT();
22597   if (MemVT != EVT::getVectorVT(*DAG.getContext(), MVT::i8, 3) ||
22598       LD->getOriginalAlign() >= 4)
22599     return SDValue();
22600 
22601   SDLoc DL(LD);
22602   MachineFunction &MF = DAG.getMachineFunction();
22603   SDValue Chain = LD->getChain();
22604   SDValue BasePtr = LD->getBasePtr();
22605   MachineMemOperand *MMO = LD->getMemOperand();
22606   assert(LD->getOffset().isUndef() && "undef offset expected");
22607 
22608   // Load 2 x i8, then 1 x i8.
22609   SDValue L16 = DAG.getLoad(MVT::i16, DL, Chain, BasePtr, MMO);
22610   TypeSize Offset2 = TypeSize::getFixed(2);
22611   SDValue L8 = DAG.getLoad(MVT::i8, DL, Chain,
22612                            DAG.getMemBasePlusOffset(BasePtr, Offset2, DL),
22613                            MF.getMachineMemOperand(MMO, 2, 1));
22614 
22615   // Extend to i32.
22616   SDValue Ext16 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, L16);
22617   SDValue Ext8 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, L8);
22618 
22619   // Pack 2 x i8 and 1 x i8 in an i32 and convert to v4i8.
22620   SDValue Shl = DAG.getNode(ISD::SHL, DL, MVT::i32, Ext8,
22621                             DAG.getConstant(16, DL, MVT::i32));
22622   SDValue Or = DAG.getNode(ISD::OR, DL, MVT::i32, Ext16, Shl);
22623   SDValue Cast = DAG.getNode(ISD::BITCAST, DL, MVT::v4i8, Or);
22624 
22625   // Extract v3i8 again.
22626   SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MemVT, Cast,
22627                                 DAG.getConstant(0, DL, MVT::i64));
22628   SDValue TokenFactor = DAG.getNode(
22629       ISD::TokenFactor, DL, MVT::Other,
22630       {SDValue(cast<SDNode>(L16), 1), SDValue(cast<SDNode>(L8), 1)});
22631   return DAG.getMergeValues({Extract, TokenFactor}, DL);
22632 }
22633 
22634 // Perform TBI simplification if supported by the target and try to break up
22635 // non-temporal loads larger than 256 bits for odd types, so that 256-bit LDNPQ
22636 // load instructions can be selected.
22637 static SDValue performLOADCombine(SDNode *N,
22638                                   TargetLowering::DAGCombinerInfo &DCI,
22639                                   SelectionDAG &DAG,
22640                                   const AArch64Subtarget *Subtarget) {
22641   if (Subtarget->supportsAddressTopByteIgnored())
22642     performTBISimplification(N->getOperand(1), DCI, DAG);
22643 
22644   LoadSDNode *LD = cast<LoadSDNode>(N);
22645   if (LD->isVolatile() || !Subtarget->isLittleEndian())
22646     return SDValue(N, 0);
22647 
22648   if (SDValue Res = combineV3I8LoadExt(LD, DAG))
22649     return Res;
22650 
22651   if (!LD->isNonTemporal())
22652     return SDValue(N, 0);
22653 
22654   EVT MemVT = LD->getMemoryVT();
22655   if (MemVT.isScalableVector() || MemVT.getSizeInBits() <= 256 ||
22656       MemVT.getSizeInBits() % 256 == 0 ||
22657       256 % MemVT.getScalarSizeInBits() != 0)
22658     return SDValue(N, 0);
22659 
22660   SDLoc DL(LD);
22661   SDValue Chain = LD->getChain();
22662   SDValue BasePtr = LD->getBasePtr();
22663   SDNodeFlags Flags = LD->getFlags();
22664   SmallVector<SDValue, 4> LoadOps;
22665   SmallVector<SDValue, 4> LoadOpsChain;
22666   // Replace any non-temporal load over 256 bits with a series of 256-bit loads
22667   // and one smaller scalar/vector load for the remainder. This way we can
22668   // utilize 256-bit loads and reduce the number of load instructions generated.
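  // Illustrative example: a 320-bit non-temporal v20i16 load becomes one
  // 256-bit v16i16 load plus a 64-bit v4i16 load; the results are
  // concatenated and then narrowed back to v20i16 via EXTRACT_SUBVECTOR.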
22669   MVT NewVT =
22670       MVT::getVectorVT(MemVT.getVectorElementType().getSimpleVT(),
22671                        256 / MemVT.getVectorElementType().getSizeInBits());
22672   unsigned Num256Loads = MemVT.getSizeInBits() / 256;
22673   // Create all 256-bit loads, starting from offset 0 up to (Num256Loads - 1) * 32.
22674   for (unsigned I = 0; I < Num256Loads; I++) {
22675     unsigned PtrOffset = I * 32;
22676     SDValue NewPtr = DAG.getMemBasePlusOffset(
22677         BasePtr, TypeSize::getFixed(PtrOffset), DL, Flags);
22678     Align NewAlign = commonAlignment(LD->getAlign(), PtrOffset);
22679     SDValue NewLoad = DAG.getLoad(
22680         NewVT, DL, Chain, NewPtr, LD->getPointerInfo().getWithOffset(PtrOffset),
22681         NewAlign, LD->getMemOperand()->getFlags(), LD->getAAInfo());
22682     LoadOps.push_back(NewLoad);
22683     LoadOpsChain.push_back(SDValue(cast<SDNode>(NewLoad), 1));
22684   }
22685 
22686   // Process remaining bits of the load operation.
22687   // This is done by creating an UNDEF vector to match the size of the
22688   // 256-bit loads and inserting the remaining load into it. We extract the
22689   // original load type at the end using EXTRACT_SUBVECTOR instruction.
22690   unsigned BitsRemaining = MemVT.getSizeInBits() % 256;
22691   unsigned PtrOffset = (MemVT.getSizeInBits() - BitsRemaining) / 8;
22692   MVT RemainingVT = MVT::getVectorVT(
22693       MemVT.getVectorElementType().getSimpleVT(),
22694       BitsRemaining / MemVT.getVectorElementType().getSizeInBits());
22695   SDValue NewPtr = DAG.getMemBasePlusOffset(
22696       BasePtr, TypeSize::getFixed(PtrOffset), DL, Flags);
22697   Align NewAlign = commonAlignment(LD->getAlign(), PtrOffset);
22698   SDValue RemainingLoad =
22699       DAG.getLoad(RemainingVT, DL, Chain, NewPtr,
22700                   LD->getPointerInfo().getWithOffset(PtrOffset), NewAlign,
22701                   LD->getMemOperand()->getFlags(), LD->getAAInfo());
22702   SDValue UndefVector = DAG.getUNDEF(NewVT);
22703   SDValue InsertIdx = DAG.getVectorIdxConstant(0, DL);
22704   SDValue ExtendedRemainingLoad =
22705       DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NewVT,
22706                   {UndefVector, RemainingLoad, InsertIdx});
22707   LoadOps.push_back(ExtendedRemainingLoad);
22708   LoadOpsChain.push_back(SDValue(cast<SDNode>(RemainingLoad), 1));
22709   EVT ConcatVT =
22710       EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(),
22711                        LoadOps.size() * NewVT.getVectorNumElements());
22712   SDValue ConcatVectors =
22713       DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT, LoadOps);
22714   // Extract the original vector type size.
22715   SDValue ExtractSubVector =
22716       DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MemVT,
22717                   {ConcatVectors, DAG.getVectorIdxConstant(0, DL)});
22718   SDValue TokenFactor =
22719       DAG.getNode(ISD::TokenFactor, DL, MVT::Other, LoadOpsChain);
22720   return DAG.getMergeValues({ExtractSubVector, TokenFactor}, DL);
22721 }
22722 
22723 static EVT tryGetOriginalBoolVectorType(SDValue Op, int Depth = 0) {
22724   EVT VecVT = Op.getValueType();
22725   assert(VecVT.isVector() && VecVT.getVectorElementType() == MVT::i1 &&
22726          "Need boolean vector type.");
22727 
22728   if (Depth > 3)
22729     return MVT::INVALID_SIMPLE_VALUE_TYPE;
22730 
22731   // We can get the base type from a vector compare or truncate.
22732   if (Op.getOpcode() == ISD::SETCC || Op.getOpcode() == ISD::TRUNCATE)
22733     return Op.getOperand(0).getValueType();
22734 
22735   // If an operand is a bool vector, continue looking.
22736   EVT BaseVT = MVT::INVALID_SIMPLE_VALUE_TYPE;
22737   for (SDValue Operand : Op->op_values()) {
22738     if (Operand.getValueType() != VecVT)
22739       continue;
22740 
22741     EVT OperandVT = tryGetOriginalBoolVectorType(Operand, Depth + 1);
22742     if (!BaseVT.isSimple())
22743       BaseVT = OperandVT;
22744     else if (OperandVT != BaseVT)
22745       return MVT::INVALID_SIMPLE_VALUE_TYPE;
22746   }
22747 
22748   return BaseVT;
22749 }
22750 
22751 // When converting a <N x iX> vector to <N x i1> to store or use as a scalar
22752 // iN, we can use a trick that extracts the i^th bit from the i^th element and
22753 // then performs a vector add to get a scalar bitmask. This requires that each
22754 // element's bits are either all 1 or all 0.
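// Illustratively, each lane of the (sign-extended) comparison result is ANDed
// with a distinct power of two (1, 2, 4, ...) and the lanes are then summed
// with VECREDUCE_ADD, leaving one bit per original element in the scalar
// result.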
22755 static SDValue vectorToScalarBitmask(SDNode *N, SelectionDAG &DAG) {
22756   SDLoc DL(N);
22757   SDValue ComparisonResult(N, 0);
22758   EVT VecVT = ComparisonResult.getValueType();
22759   assert(VecVT.isVector() && "Must be a vector type");
22760 
22761   unsigned NumElts = VecVT.getVectorNumElements();
22762   if (NumElts != 2 && NumElts != 4 && NumElts != 8 && NumElts != 16)
22763     return SDValue();
22764 
22765   if (VecVT.getVectorElementType() != MVT::i1 &&
22766       !DAG.getTargetLoweringInfo().isTypeLegal(VecVT))
22767     return SDValue();
22768 
22769   // If we can find the original types to work on instead of a vector of i1,
22770   // we can avoid extend/extract conversion instructions.
22771   if (VecVT.getVectorElementType() == MVT::i1) {
22772     VecVT = tryGetOriginalBoolVectorType(ComparisonResult);
22773     if (!VecVT.isSimple()) {
22774       unsigned BitsPerElement = std::max(64 / NumElts, 8u); // >= 64-bit vector
22775       VecVT = MVT::getVectorVT(MVT::getIntegerVT(BitsPerElement), NumElts);
22776     }
22777   }
22778   VecVT = VecVT.changeVectorElementTypeToInteger();
22779 
22780   // Large vectors don't map directly to this conversion, so to avoid too many
22781   // edge cases, we don't apply it here. The conversion will likely still be
22782   // applied later via multiple smaller vectors, whose results are concatenated.
22783   if (VecVT.getSizeInBits() > 128)
22784     return SDValue();
22785 
22786   // Ensure that all elements' bits are either 0s or 1s.
22787   ComparisonResult = DAG.getSExtOrTrunc(ComparisonResult, DL, VecVT);
22788 
22789   SmallVector<SDValue, 16> MaskConstants;
22790   if (DAG.getSubtarget<AArch64Subtarget>().isNeonAvailable() &&
22791       VecVT == MVT::v16i8) {
22792     // v16i8 is a special case, as we have 16 entries but only 8 positional bits
22793     // per entry. We split it into two halves, apply the mask, zip the halves to
22794     // create 8x 16-bit values, and then perform the vector reduce.
22795     for (unsigned Half = 0; Half < 2; ++Half) {
22796       for (unsigned MaskBit = 1; MaskBit <= 128; MaskBit *= 2) {
22797         MaskConstants.push_back(DAG.getConstant(MaskBit, DL, MVT::i32));
22798       }
22799     }
22800     SDValue Mask = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, MaskConstants);
22801     SDValue RepresentativeBits =
22802         DAG.getNode(ISD::AND, DL, VecVT, ComparisonResult, Mask);
22803 
22804     SDValue UpperRepresentativeBits =
22805         DAG.getNode(AArch64ISD::EXT, DL, VecVT, RepresentativeBits,
22806                     RepresentativeBits, DAG.getConstant(8, DL, MVT::i32));
22807     SDValue Zipped = DAG.getNode(AArch64ISD::ZIP1, DL, VecVT,
22808                                  RepresentativeBits, UpperRepresentativeBits);
22809     Zipped = DAG.getNode(ISD::BITCAST, DL, MVT::v8i16, Zipped);
22810     return DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i16, Zipped);
22811   }
22812 
22813   // All other vector sizes.
22814   unsigned MaxBitMask = 1u << (VecVT.getVectorNumElements() - 1);
22815   for (unsigned MaskBit = 1; MaskBit <= MaxBitMask; MaskBit *= 2) {
22816     MaskConstants.push_back(DAG.getConstant(MaskBit, DL, MVT::i64));
22817   }
22818 
22819   SDValue Mask = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, MaskConstants);
22820   SDValue RepresentativeBits =
22821       DAG.getNode(ISD::AND, DL, VecVT, ComparisonResult, Mask);
22822   EVT ResultVT = MVT::getIntegerVT(std::max<unsigned>(
22823       NumElts, VecVT.getVectorElementType().getSizeInBits()));
22824   return DAG.getNode(ISD::VECREDUCE_ADD, DL, ResultVT, RepresentativeBits);
22825 }
22826 
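// Fold a truncating store to a boolean vector type into a scalar store of the
// bitmask produced by vectorToScalarBitmask above, e.g. a store truncating to
// <8 x i1> becomes an i8 store of the corresponding 8-bit mask.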
22827 static SDValue combineBoolVectorAndTruncateStore(SelectionDAG &DAG,
22828                                                  StoreSDNode *Store) {
22829   if (!Store->isTruncatingStore())
22830     return SDValue();
22831 
22832   SDLoc DL(Store);
22833   SDValue VecOp = Store->getValue();
22834   EVT VT = VecOp.getValueType();
22835   EVT MemVT = Store->getMemoryVT();
22836 
22837   if (!MemVT.isVector() || !VT.isVector() ||
22838       MemVT.getVectorElementType() != MVT::i1)
22839     return SDValue();
22840 
22841   // If we are storing a vector that we are currently building, let
22842   // `scalarizeVectorStore()` handle this more efficiently.
22843   if (VecOp.getOpcode() == ISD::BUILD_VECTOR)
22844     return SDValue();
22845 
22846   VecOp = DAG.getNode(ISD::TRUNCATE, DL, MemVT, VecOp);
22847   SDValue VectorBits = vectorToScalarBitmask(VecOp.getNode(), DAG);
22848   if (!VectorBits)
22849     return SDValue();
22850 
22851   EVT StoreVT =
22852       EVT::getIntegerVT(*DAG.getContext(), MemVT.getStoreSizeInBits());
22853   SDValue ExtendedBits = DAG.getZExtOrTrunc(VectorBits, DL, StoreVT);
22854   return DAG.getStore(Store->getChain(), DL, ExtendedBits, Store->getBasePtr(),
22855                       Store->getMemOperand());
22856 }
22857 
22858 bool isHalvingTruncateOfLegalScalableType(EVT SrcVT, EVT DstVT) {
22859   return (SrcVT == MVT::nxv8i16 && DstVT == MVT::nxv8i8) ||
22860          (SrcVT == MVT::nxv4i32 && DstVT == MVT::nxv4i16) ||
22861          (SrcVT == MVT::nxv2i64 && DstVT == MVT::nxv2i32);
22862 }
22863 
22864 // Combine store (trunc X to <3 x i8>) to sequence of ST1.b.
22865 static SDValue combineI8TruncStore(StoreSDNode *ST, SelectionDAG &DAG,
22866                                    const AArch64Subtarget *Subtarget) {
22867   SDValue Value = ST->getValue();
22868   EVT ValueVT = Value.getValueType();
22869 
22870   if (ST->isVolatile() || !Subtarget->isLittleEndian() ||
22871       Value.getOpcode() != ISD::TRUNCATE ||
22872       ValueVT != EVT::getVectorVT(*DAG.getContext(), MVT::i8, 3))
22873     return SDValue();
22874 
22875   assert(ST->getOffset().isUndef() && "undef offset expected");
22876   SDLoc DL(ST);
22877   auto WideVT = EVT::getVectorVT(
22878       *DAG.getContext(),
22879       Value->getOperand(0).getValueType().getVectorElementType(), 4);
22880   SDValue UndefVector = DAG.getUNDEF(WideVT);
22881   SDValue WideTrunc = DAG.getNode(
22882       ISD::INSERT_SUBVECTOR, DL, WideVT,
22883       {UndefVector, Value->getOperand(0), DAG.getVectorIdxConstant(0, DL)});
22884   SDValue Cast = DAG.getNode(
22885       ISD::BITCAST, DL, WideVT.getSizeInBits() == 64 ? MVT::v8i8 : MVT::v16i8,
22886       WideTrunc);
22887 
22888   MachineFunction &MF = DAG.getMachineFunction();
22889   SDValue Chain = ST->getChain();
22890   MachineMemOperand *MMO = ST->getMemOperand();
22891   unsigned IdxScale = WideVT.getScalarSizeInBits() / 8;
22892   SDValue E2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, Cast,
22893                            DAG.getConstant(2 * IdxScale, DL, MVT::i64));
22894   TypeSize Offset2 = TypeSize::getFixed(2);
22895   SDValue Ptr2 = DAG.getMemBasePlusOffset(ST->getBasePtr(), Offset2, DL);
22896   Chain = DAG.getStore(Chain, DL, E2, Ptr2, MF.getMachineMemOperand(MMO, 2, 1));
22897 
22898   SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, Cast,
22899                            DAG.getConstant(1 * IdxScale, DL, MVT::i64));
22900   TypeSize Offset1 = TypeSize::getFixed(1);
22901   SDValue Ptr1 = DAG.getMemBasePlusOffset(ST->getBasePtr(), Offset1, DL);
22902   Chain = DAG.getStore(Chain, DL, E1, Ptr1, MF.getMachineMemOperand(MMO, 1, 1));
22903 
22904   SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, Cast,
22905                            DAG.getConstant(0, DL, MVT::i64));
22906   Chain = DAG.getStore(Chain, DL, E0, ST->getBasePtr(),
22907                        MF.getMachineMemOperand(MMO, 0, 1));
22908   return Chain;
22909 }
22910 
22911 static SDValue performSTORECombine(SDNode *N,
22912                                    TargetLowering::DAGCombinerInfo &DCI,
22913                                    SelectionDAG &DAG,
22914                                    const AArch64Subtarget *Subtarget) {
22915   StoreSDNode *ST = cast<StoreSDNode>(N);
22916   SDValue Chain = ST->getChain();
22917   SDValue Value = ST->getValue();
22918   SDValue Ptr = ST->getBasePtr();
22919   EVT ValueVT = Value.getValueType();
22920 
22921   auto hasValidElementTypeForFPTruncStore = [](EVT VT) {
22922     EVT EltVT = VT.getVectorElementType();
22923     return EltVT == MVT::f32 || EltVT == MVT::f64;
22924   };
22925 
22926   if (SDValue Res = combineI8TruncStore(ST, DAG, Subtarget))
22927     return Res;
22928 
22929   // If this is an FP_ROUND followed by a store, fold this into a truncating
22930   // store. We can do this even if this is already a truncstore.
22931   // We purposefully don't care about legality of the nodes here as we know
22932   // they can be split down into something legal.
22933   if (DCI.isBeforeLegalizeOps() && Value.getOpcode() == ISD::FP_ROUND &&
22934       Value.getNode()->hasOneUse() && ST->isUnindexed() &&
22935       Subtarget->useSVEForFixedLengthVectors() &&
22936       ValueVT.isFixedLengthVector() &&
22937       ValueVT.getFixedSizeInBits() >= Subtarget->getMinSVEVectorSizeInBits() &&
22938       hasValidElementTypeForFPTruncStore(Value.getOperand(0).getValueType()))
22939     return DAG.getTruncStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
22940                              ST->getMemoryVT(), ST->getMemOperand());
22941 
22942   if (SDValue Split = splitStores(N, DCI, DAG, Subtarget))
22943     return Split;
22944 
22945   if (Subtarget->supportsAddressTopByteIgnored() &&
22946       performTBISimplification(N->getOperand(2), DCI, DAG))
22947     return SDValue(N, 0);
22948 
22949   if (SDValue Store = foldTruncStoreOfExt(DAG, N))
22950     return Store;
22951 
22952   if (SDValue Store = combineBoolVectorAndTruncateStore(DAG, ST))
22953     return Store;
22954 
22955   if (ST->isTruncatingStore()) {
22956     EVT StoreVT = ST->getMemoryVT();
22957     if (!isHalvingTruncateOfLegalScalableType(ValueVT, StoreVT))
22958       return SDValue();
22959     if (SDValue Rshrnb =
22960             trySimplifySrlAddToRshrnb(ST->getOperand(1), DAG, Subtarget)) {
22961       return DAG.getTruncStore(ST->getChain(), ST, Rshrnb, ST->getBasePtr(),
22962                                StoreVT, ST->getMemOperand());
22963     }
22964   }
22965 
22966   return SDValue();
22967 }
22968 
22969 static SDValue performMSTORECombine(SDNode *N,
22970                                     TargetLowering::DAGCombinerInfo &DCI,
22971                                     SelectionDAG &DAG,
22972                                     const AArch64Subtarget *Subtarget) {
22973   MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
22974   SDValue Value = MST->getValue();
22975   SDValue Mask = MST->getMask();
22976   SDLoc DL(N);
22977 
22978   // If this is a UZP1 followed by a masked store, fold this into a masked
22979   // truncating store.  We can do this even if this is already a masked
22980   // truncstore.
22981   if (Value.getOpcode() == AArch64ISD::UZP1 && Value->hasOneUse() &&
22982       MST->isUnindexed() && Mask->getOpcode() == AArch64ISD::PTRUE &&
22983       Value.getValueType().isInteger()) {
22984     Value = Value.getOperand(0);
22985     if (Value.getOpcode() == ISD::BITCAST) {
22986       EVT HalfVT =
22987           Value.getValueType().getHalfNumVectorElementsVT(*DAG.getContext());
22988       EVT InVT = Value.getOperand(0).getValueType();
22989 
22990       if (HalfVT.widenIntegerVectorElementType(*DAG.getContext()) == InVT) {
22991         unsigned MinSVESize = Subtarget->getMinSVEVectorSizeInBits();
22992         unsigned PgPattern = Mask->getConstantOperandVal(0);
22993 
22994         // Ensure we can double the size of the predicate pattern
22995         unsigned NumElts = getNumElementsFromSVEPredPattern(PgPattern);
22996         if (NumElts && NumElts * InVT.getVectorElementType().getSizeInBits() <=
22997                            MinSVESize) {
22998           Mask = getPTrue(DAG, DL, InVT.changeVectorElementType(MVT::i1),
22999                           PgPattern);
23000           return DAG.getMaskedStore(MST->getChain(), DL, Value.getOperand(0),
23001                                     MST->getBasePtr(), MST->getOffset(), Mask,
23002                                     MST->getMemoryVT(), MST->getMemOperand(),
23003                                     MST->getAddressingMode(),
23004                                     /*IsTruncating=*/true);
23005         }
23006       }
23007     }
23008   }
23009 
23010   if (MST->isTruncatingStore()) {
23011     EVT ValueVT = Value->getValueType(0);
23012     EVT MemVT = MST->getMemoryVT();
23013     if (!isHalvingTruncateOfLegalScalableType(ValueVT, MemVT))
23014       return SDValue();
23015     if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Value, DAG, Subtarget)) {
23016       return DAG.getMaskedStore(MST->getChain(), DL, Rshrnb, MST->getBasePtr(),
23017                                 MST->getOffset(), MST->getMask(),
23018                                 MST->getMemoryVT(), MST->getMemOperand(),
23019                                 MST->getAddressingMode(), true);
23020     }
23021   }
23022 
23023   return SDValue();
23024 }
23025 
23026 /// \return true if part of the index was folded into the Base.
23027 static bool foldIndexIntoBase(SDValue &BasePtr, SDValue &Index, SDValue Scale,
23028                               SDLoc DL, SelectionDAG &DAG) {
23029   // This function assumes a vector of i64 indices.
23030   EVT IndexVT = Index.getValueType();
23031   if (!IndexVT.isVector() || IndexVT.getVectorElementType() != MVT::i64)
23032     return false;
23033 
23034   // Simplify:
23035   //   BasePtr = Ptr
23036   //   Index = X + splat(Offset)
23037   // ->
23038   //   BasePtr = Ptr + Offset * scale.
23039   //   Index = X
23040   if (Index.getOpcode() == ISD::ADD) {
23041     if (auto Offset = DAG.getSplatValue(Index.getOperand(1))) {
23042       Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, Scale);
23043       BasePtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr, Offset);
23044       Index = Index.getOperand(0);
23045       return true;
23046     }
23047   }
23048 
23049   // Simplify:
23050   //   BasePtr = Ptr
23051   //   Index = (X + splat(Offset)) << splat(Shift)
23052   // ->
23053   //   BasePtr = Ptr + (Offset << Shift) * scale.
23054   //   Index = X << splat(shift)
23055   if (Index.getOpcode() == ISD::SHL &&
23056       Index.getOperand(0).getOpcode() == ISD::ADD) {
23057     SDValue Add = Index.getOperand(0);
23058     SDValue ShiftOp = Index.getOperand(1);
23059     SDValue OffsetOp = Add.getOperand(1);
23060     if (auto Shift = DAG.getSplatValue(ShiftOp))
23061       if (auto Offset = DAG.getSplatValue(OffsetOp)) {
23062         Offset = DAG.getNode(ISD::SHL, DL, MVT::i64, Offset, Shift);
23063         Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, Scale);
23064         BasePtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr, Offset);
23065         Index = DAG.getNode(ISD::SHL, DL, Index.getValueType(),
23066                             Add.getOperand(0), ShiftOp);
23067         return true;
23068       }
23069   }
23070 
23071   return false;
23072 }
23073 
23074 // Analyse the specified address returning true if a more optimal addressing
23075 // mode is available. When returning true all parameters are updated to reflect
23076 // their recommended values.
23077 static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
23078                                      SDValue &BasePtr, SDValue &Index,
23079                                      SelectionDAG &DAG) {
23080   // Try to iteratively fold parts of the index into the base pointer to
23081   // simplify the index as much as possible.
23082   bool Changed = false;
23083   while (foldIndexIntoBase(BasePtr, Index, N->getScale(), SDLoc(N), DAG))
23084     Changed = true;
23085 
23086   // Only consider element types that are pointer sized as smaller types can
23087   // be easily promoted.
23088   EVT IndexVT = Index.getValueType();
23089   if (IndexVT.getVectorElementType() != MVT::i64 || IndexVT == MVT::nxv2i64)
23090     return Changed;
23091 
23092   // Can indices be trivially shrunk?
23093   EVT DataVT = N->getOperand(1).getValueType();
23094   // Don't attempt to shrink the index for fixed vectors of 64-bit data since it
23095   // will later be re-extended to 64 bits in legalization.
23096   if (DataVT.isFixedLengthVector() && DataVT.getScalarSizeInBits() == 64)
23097     return Changed;
23098   if (ISD::isVectorShrinkable(Index.getNode(), 32, N->isIndexSigned())) {
23099     EVT NewIndexVT = IndexVT.changeVectorElementType(MVT::i32);
23100     Index = DAG.getNode(ISD::TRUNCATE, SDLoc(N), NewIndexVT, Index);
23101     return true;
23102   }
23103 
23104   // Match:
23105   //   Index = step(const)
23106   int64_t Stride = 0;
23107   if (Index.getOpcode() == ISD::STEP_VECTOR) {
23108     Stride = cast<ConstantSDNode>(Index.getOperand(0))->getSExtValue();
23109   }
23110   // Match:
23111   //   Index = step(const) << shift(const)
23112   else if (Index.getOpcode() == ISD::SHL &&
23113            Index.getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
23114     SDValue RHS = Index.getOperand(1);
23115     if (auto *Shift =
23116             dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(RHS))) {
23117       int64_t Step = (int64_t)Index.getOperand(0).getConstantOperandVal(1);
23118       Stride = Step << Shift->getZExtValue();
23119     }
23120   }
23121 
23122   // Return early because no supported pattern is found.
23123   if (Stride == 0)
23124     return Changed;
23125 
23126   if (Stride < std::numeric_limits<int32_t>::min() ||
23127       Stride > std::numeric_limits<int32_t>::max())
23128     return Changed;
23129 
23130   const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
23131   unsigned MaxVScale =
23132       Subtarget.getMaxSVEVectorSizeInBits() / AArch64::SVEBitsPerBlock;
23133   int64_t LastElementOffset =
23134       IndexVT.getVectorMinNumElements() * Stride * MaxVScale;
23135 
23136   if (LastElementOffset < std::numeric_limits<int32_t>::min() ||
23137       LastElementOffset > std::numeric_limits<int32_t>::max())
23138     return Changed;
23139 
23140   EVT NewIndexVT = IndexVT.changeVectorElementType(MVT::i32);
23141   // Stride does not scale explicitly by 'Scale', because it happens in
23142   // the gather/scatter addressing mode.
23143   Index = DAG.getStepVector(SDLoc(N), NewIndexVT, APInt(32, Stride));
23144   return true;
23145 }
23146 
23147 static SDValue performMaskedGatherScatterCombine(
23148     SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
23149   MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
23150   assert(MGS && "Can only combine gather load or scatter store nodes");
23151 
23152   if (!DCI.isBeforeLegalize())
23153     return SDValue();
23154 
23155   SDLoc DL(MGS);
23156   SDValue Chain = MGS->getChain();
23157   SDValue Scale = MGS->getScale();
23158   SDValue Index = MGS->getIndex();
23159   SDValue Mask = MGS->getMask();
23160   SDValue BasePtr = MGS->getBasePtr();
23161   ISD::MemIndexType IndexType = MGS->getIndexType();
23162 
23163   if (!findMoreOptimalIndexType(MGS, BasePtr, Index, DAG))
23164     return SDValue();
23165 
23166   // A more optimal BasePtr/Index was found, so rebuild the gather/scatter node
23167   // with operands that are more legalisation friendly.
23168   if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
23169     SDValue PassThru = MGT->getPassThru();
23170     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
23171     return DAG.getMaskedGather(
23172         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
23173         Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
23174   }
23175   auto *MSC = cast<MaskedScatterSDNode>(MGS);
23176   SDValue Data = MSC->getValue();
23177   SDValue Ops[] = {Chain, Data, Mask, BasePtr, Index, Scale};
23178   return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL,
23179                               Ops, MSC->getMemOperand(), IndexType,
23180                               MSC->isTruncatingStore());
23181 }
23182 
23183 /// Target-specific DAG combine function for NEON load/store intrinsics
23184 /// to merge base address updates.
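/// Illustrative example: an aarch64.neon.st2 of two 128-bit vectors to addr,
/// together with a separate (add addr, #32) user of the address, becomes
/// ST2post, which performs the stores and additionally produces addr + 32 as
/// the written-back base address.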
23185 static SDValue performNEONPostLDSTCombine(SDNode *N,
23186                                           TargetLowering::DAGCombinerInfo &DCI,
23187                                           SelectionDAG &DAG) {
23188   if (DCI.isBeforeLegalize() || DCI.isCalledByLegalizer())
23189     return SDValue();
23190 
23191   unsigned AddrOpIdx = N->getNumOperands() - 1;
23192   SDValue Addr = N->getOperand(AddrOpIdx);
23193 
23194   // Search for a use of the address operand that is an increment.
23195   for (SDNode::use_iterator UI = Addr.getNode()->use_begin(),
23196        UE = Addr.getNode()->use_end(); UI != UE; ++UI) {
23197     SDNode *User = *UI;
23198     if (User->getOpcode() != ISD::ADD ||
23199         UI.getUse().getResNo() != Addr.getResNo())
23200       continue;
23201 
23202     // Check that the add is independent of the load/store.  Otherwise, folding
23203     // it would create a cycle.
23204     SmallPtrSet<const SDNode *, 32> Visited;
23205     SmallVector<const SDNode *, 16> Worklist;
23206     Visited.insert(Addr.getNode());
23207     Worklist.push_back(N);
23208     Worklist.push_back(User);
23209     if (SDNode::hasPredecessorHelper(N, Visited, Worklist) ||
23210         SDNode::hasPredecessorHelper(User, Visited, Worklist))
23211       continue;
23212 
23213     // Find the new opcode for the updating load/store.
23214     bool IsStore = false;
23215     bool IsLaneOp = false;
23216     bool IsDupOp = false;
23217     unsigned NewOpc = 0;
23218     unsigned NumVecs = 0;
23219     unsigned IntNo = N->getConstantOperandVal(1);
23220     switch (IntNo) {
23221     default: llvm_unreachable("unexpected intrinsic for Neon base update");
23222     case Intrinsic::aarch64_neon_ld2:       NewOpc = AArch64ISD::LD2post;
23223       NumVecs = 2; break;
23224     case Intrinsic::aarch64_neon_ld3:       NewOpc = AArch64ISD::LD3post;
23225       NumVecs = 3; break;
23226     case Intrinsic::aarch64_neon_ld4:       NewOpc = AArch64ISD::LD4post;
23227       NumVecs = 4; break;
23228     case Intrinsic::aarch64_neon_st2:       NewOpc = AArch64ISD::ST2post;
23229       NumVecs = 2; IsStore = true; break;
23230     case Intrinsic::aarch64_neon_st3:       NewOpc = AArch64ISD::ST3post;
23231       NumVecs = 3; IsStore = true; break;
23232     case Intrinsic::aarch64_neon_st4:       NewOpc = AArch64ISD::ST4post;
23233       NumVecs = 4; IsStore = true; break;
23234     case Intrinsic::aarch64_neon_ld1x2:     NewOpc = AArch64ISD::LD1x2post;
23235       NumVecs = 2; break;
23236     case Intrinsic::aarch64_neon_ld1x3:     NewOpc = AArch64ISD::LD1x3post;
23237       NumVecs = 3; break;
23238     case Intrinsic::aarch64_neon_ld1x4:     NewOpc = AArch64ISD::LD1x4post;
23239       NumVecs = 4; break;
23240     case Intrinsic::aarch64_neon_st1x2:     NewOpc = AArch64ISD::ST1x2post;
23241       NumVecs = 2; IsStore = true; break;
23242     case Intrinsic::aarch64_neon_st1x3:     NewOpc = AArch64ISD::ST1x3post;
23243       NumVecs = 3; IsStore = true; break;
23244     case Intrinsic::aarch64_neon_st1x4:     NewOpc = AArch64ISD::ST1x4post;
23245       NumVecs = 4; IsStore = true; break;
23246     case Intrinsic::aarch64_neon_ld2r:      NewOpc = AArch64ISD::LD2DUPpost;
23247       NumVecs = 2; IsDupOp = true; break;
23248     case Intrinsic::aarch64_neon_ld3r:      NewOpc = AArch64ISD::LD3DUPpost;
23249       NumVecs = 3; IsDupOp = true; break;
23250     case Intrinsic::aarch64_neon_ld4r:      NewOpc = AArch64ISD::LD4DUPpost;
23251       NumVecs = 4; IsDupOp = true; break;
23252     case Intrinsic::aarch64_neon_ld2lane:   NewOpc = AArch64ISD::LD2LANEpost;
23253       NumVecs = 2; IsLaneOp = true; break;
23254     case Intrinsic::aarch64_neon_ld3lane:   NewOpc = AArch64ISD::LD3LANEpost;
23255       NumVecs = 3; IsLaneOp = true; break;
23256     case Intrinsic::aarch64_neon_ld4lane:   NewOpc = AArch64ISD::LD4LANEpost;
23257       NumVecs = 4; IsLaneOp = true; break;
23258     case Intrinsic::aarch64_neon_st2lane:   NewOpc = AArch64ISD::ST2LANEpost;
23259       NumVecs = 2; IsStore = true; IsLaneOp = true; break;
23260     case Intrinsic::aarch64_neon_st3lane:   NewOpc = AArch64ISD::ST3LANEpost;
23261       NumVecs = 3; IsStore = true; IsLaneOp = true; break;
23262     case Intrinsic::aarch64_neon_st4lane:   NewOpc = AArch64ISD::ST4LANEpost;
23263       NumVecs = 4; IsStore = true; IsLaneOp = true; break;
23264     }
23265 
23266     EVT VecTy;
23267     if (IsStore)
23268       VecTy = N->getOperand(2).getValueType();
23269     else
23270       VecTy = N->getValueType(0);
23271 
23272     // If the increment is a constant, it must match the memory ref size.
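    // For example, for an aarch64_neon_ld2 of two v4i32 vectors NumBytes is
    // 2 * 128 / 8 == 32, so only a constant increment of #32 can be folded
    // into the post-incremented form; for the lane/dup variants the per-lane
    // size is used instead, e.g. 2 * 4 == 8 bytes for an ld2lane of v4i32.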
23273     SDValue Inc = User->getOperand(User->getOperand(0) == Addr ? 1 : 0);
23274     if (ConstantSDNode *CInc = dyn_cast<ConstantSDNode>(Inc.getNode())) {
23275       uint32_t IncVal = CInc->getZExtValue();
23276       unsigned NumBytes = NumVecs * VecTy.getSizeInBits() / 8;
23277       if (IsLaneOp || IsDupOp)
23278         NumBytes /= VecTy.getVectorNumElements();
23279       if (IncVal != NumBytes)
23280         continue;
23281       Inc = DAG.getRegister(AArch64::XZR, MVT::i64);
23282     }
23283     SmallVector<SDValue, 8> Ops;
23284     Ops.push_back(N->getOperand(0)); // Incoming chain
23285     // Load-lane and store operations take a vector list as input.
23286     if (IsLaneOp || IsStore)
23287       for (unsigned i = 2; i < AddrOpIdx; ++i)
23288         Ops.push_back(N->getOperand(i));
23289     Ops.push_back(Addr); // Base register
23290     Ops.push_back(Inc);
23291 
23292     // Return Types.
23293     EVT Tys[6];
23294     unsigned NumResultVecs = (IsStore ? 0 : NumVecs);
23295     unsigned n;
23296     for (n = 0; n < NumResultVecs; ++n)
23297       Tys[n] = VecTy;
23298     Tys[n++] = MVT::i64;  // Type of write back register
23299     Tys[n] = MVT::Other;  // Type of the chain
23300     SDVTList SDTys = DAG.getVTList(ArrayRef(Tys, NumResultVecs + 2));
23301 
23302     MemIntrinsicSDNode *MemInt = cast<MemIntrinsicSDNode>(N);
23303     SDValue UpdN = DAG.getMemIntrinsicNode(NewOpc, SDLoc(N), SDTys, Ops,
23304                                            MemInt->getMemoryVT(),
23305                                            MemInt->getMemOperand());
23306 
23307     // Update the uses.
23308     std::vector<SDValue> NewResults;
23309     for (unsigned i = 0; i < NumResultVecs; ++i) {
23310       NewResults.push_back(SDValue(UpdN.getNode(), i));
23311     }
23312     NewResults.push_back(SDValue(UpdN.getNode(), NumResultVecs + 1));
23313     DCI.CombineTo(N, NewResults);
23314     DCI.CombineTo(User, SDValue(UpdN.getNode(), NumResultVecs));
23315 
23316     break;
23317   }
23318   return SDValue();
23319 }
23320 
23321 // Checks to see if the value is the prescribed width and returns information
23322 // about its extension mode.
23323 static
23324 bool checkValueWidth(SDValue V, unsigned width, ISD::LoadExtType &ExtType) {
23325   ExtType = ISD::NON_EXTLOAD;
23326   switch(V.getNode()->getOpcode()) {
23327   default:
23328     return false;
23329   case ISD::LOAD: {
23330     LoadSDNode *LoadNode = cast<LoadSDNode>(V.getNode());
23331     if ((LoadNode->getMemoryVT() == MVT::i8 && width == 8)
23332        || (LoadNode->getMemoryVT() == MVT::i16 && width == 16)) {
23333       ExtType = LoadNode->getExtensionType();
23334       return true;
23335     }
23336     return false;
23337   }
23338   case ISD::AssertSext: {
23339     VTSDNode *TypeNode = cast<VTSDNode>(V.getNode()->getOperand(1));
23340     if ((TypeNode->getVT() == MVT::i8 && width == 8)
23341        || (TypeNode->getVT() == MVT::i16 && width == 16)) {
23342       ExtType = ISD::SEXTLOAD;
23343       return true;
23344     }
23345     return false;
23346   }
23347   case ISD::AssertZext: {
23348     VTSDNode *TypeNode = cast<VTSDNode>(V.getNode()->getOperand(1));
23349     if ((TypeNode->getVT() == MVT::i8 && width == 8)
23350        || (TypeNode->getVT() == MVT::i16 && width == 16)) {
23351       ExtType = ISD::ZEXTLOAD;
23352       return true;
23353     }
23354     return false;
23355   }
23356   case ISD::Constant:
23357   case ISD::TargetConstant: {
23358     return std::abs(cast<ConstantSDNode>(V.getNode())->getSExtValue()) <
23359            1LL << (width - 1);
23360   }
23361   }
23362 
23363   return true;
23364 }
23365 
23366 // This function does a whole lot of voodoo to determine if the tests are
23367 // equivalent without and with a mask. Essentially what happens is that given a
23368 // DAG resembling:
23369 //
23370 //  +-------------+ +-------------+ +-------------+ +-------------+
23371 //  |    Input    | | AddConstant | | CompConstant| |     CC      |
23372 //  +-------------+ +-------------+ +-------------+ +-------------+
23373 //           |           |           |               |
23374 //           V           V           |    +----------+
23375 //          +-------------+  +----+  |    |
23376 //          |     ADD     |  |0xff|  |    |
23377 //          +-------------+  +----+  |    |
23378 //                  |           |    |    |
23379 //                  V           V    |    |
23380 //                 +-------------+   |    |
23381 //                 |     AND     |   |    |
23382 //                 +-------------+   |    |
23383 //                      |            |    |
23384 //                      +-----+      |    |
23385 //                            |      |    |
23386 //                            V      V    V
23387 //                           +-------------+
23388 //                           |     CMP     |
23389 //                           +-------------+
23390 //
23391 // The AND node may be safely removed for some combinations of inputs. In
23392 // particular we need to take into account the extension type of the Input,
23393 // the exact values of AddConstant, CompConstant, and CC, along with the nominal
23394 // width of the input (this can work for inputs of any width; the above graph
23395 // is specific to 8 bits).
23396 //
23397 // The specific equations were worked out by generating output tables for each
23398 // AArch64CC value in terms of the AddConstant (w1) and CompConstant (w2). The
23399 // problem was simplified by working with 4 bit inputs, which means we only
23400 // needed to reason about 24 distinct bit patterns: 8 patterns unique to zero
23401 // extension (8,15), 8 patterns unique to sign extensions (-8,-1), and 8
23402 // patterns present in both extensions (0,7). For every distinct set of
23403 // AddConstant and CompConstant bit patterns we can consider the masked and
23404 // unmasked versions to be equivalent if the result of this function is true for
23405 // all 16 distinct bit patterns for the current extension type of Input (w0).
23406 //
23407 //   sub      w8, w0, w1
23408 //   and      w10, w8, #0x0f
23409 //   cmp      w8, w2
23410 //   cset     w9, AArch64CC
23411 //   cmp      w10, w2
23412 //   cset     w11, AArch64CC
23413 //   cmp      w9, w11
23414 //   cset     w0, eq
23415 //   ret
23416 //
23417 // Since the above function shows when the outputs are equivalent it defines
23418 // when it is safe to remove the AND. Unfortunately it only runs on AArch64 and
23419 // would be expensive to run during compiles. The equations below were written
23420 // in a test harness that confirmed they gave outputs equivalent to the above
23421 // function for all inputs, so they can be used to determine whether the
23422 // removal is legal instead.
23423 //
23424 // isEquivalentMaskless() is the code for testing if the AND can be removed,
23425 // factored out of the DAG recognition because the DAG can take several forms.
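//
// As a concrete reading of the equations below: for a zero-extended input
// with AddConstant == 0 and CC of LE or GT, the first clause of the LE/GT
// case holds, so the masked compare (CMP (AND (ADD Input, 0), 0xff)
// CompConstant CC) behaves the same as (CMP (ADD Input, 0) CompConstant CC)
// and performCONDCombine below may drop the AND.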
23426 
23427 static bool isEquivalentMaskless(unsigned CC, unsigned width,
23428                                  ISD::LoadExtType ExtType, int AddConstant,
23429                                  int CompConstant) {
23430   // By being careful about our equations and only writing them in terms of
23431   // symbolic values and well-known constants (0, 1, -1, MaxUInt) we can
23432   // make them generally applicable to all bit widths.
23433   int MaxUInt = (1 << width);
23434 
23435   // For the purposes of these comparisons sign extending the type is
23436   // equivalent to zero extending the add and displacing it by half the integer
23437   // width. Provided we are careful and make sure our equations are valid over
23438   // the whole range we can just adjust the input and avoid writing equations
23439   // for sign extended inputs.
23440   if (ExtType == ISD::SEXTLOAD)
23441     AddConstant -= (1 << (width-1));
23442 
23443   switch(CC) {
23444   case AArch64CC::LE:
23445   case AArch64CC::GT:
23446     if ((AddConstant == 0) ||
23447         (CompConstant == MaxUInt - 1 && AddConstant < 0) ||
23448         (AddConstant >= 0 && CompConstant < 0) ||
23449         (AddConstant <= 0 && CompConstant <= 0 && CompConstant < AddConstant))
23450       return true;
23451     break;
23452   case AArch64CC::LT:
23453   case AArch64CC::GE:
23454     if ((AddConstant == 0) ||
23455         (AddConstant >= 0 && CompConstant <= 0) ||
23456         (AddConstant <= 0 && CompConstant <= 0 && CompConstant <= AddConstant))
23457       return true;
23458     break;
23459   case AArch64CC::HI:
23460   case AArch64CC::LS:
23461     if ((AddConstant >= 0 && CompConstant < 0) ||
23462        (AddConstant <= 0 && CompConstant >= -1 &&
23463         CompConstant < AddConstant + MaxUInt))
23464       return true;
23465    break;
23466   case AArch64CC::PL:
23467   case AArch64CC::MI:
23468     if ((AddConstant == 0) ||
23469         (AddConstant > 0 && CompConstant <= 0) ||
23470         (AddConstant < 0 && CompConstant <= AddConstant))
23471       return true;
23472     break;
23473   case AArch64CC::LO:
23474   case AArch64CC::HS:
23475     if ((AddConstant >= 0 && CompConstant <= 0) ||
23476         (AddConstant <= 0 && CompConstant >= 0 &&
23477          CompConstant <= AddConstant + MaxUInt))
23478       return true;
23479     break;
23480   case AArch64CC::EQ:
23481   case AArch64CC::NE:
23482     if ((AddConstant > 0 && CompConstant < 0) ||
23483         (AddConstant < 0 && CompConstant >= 0 &&
23484          CompConstant < AddConstant + MaxUInt) ||
23485         (AddConstant >= 0 && CompConstant >= 0 &&
23486          CompConstant >= AddConstant) ||
23487         (AddConstant <= 0 && CompConstant < 0 && CompConstant < AddConstant))
23488       return true;
23489     break;
23490   case AArch64CC::VS:
23491   case AArch64CC::VC:
23492   case AArch64CC::AL:
23493   case AArch64CC::NV:
23494     return true;
23495   case AArch64CC::Invalid:
23496     break;
23497   }
23498 
23499   return false;
23500 }
23501 
23502 // (X & C) >u Mask --> (X & (C & ~Mask)) != 0
23503 // (X & C) <u Pow2 --> (X & (C & ~(Pow2-1))) == 0
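// For example, (X & 0xf0) >u 0x0f becomes (ANDS X, 0xf0) != 0, because
// ~0x0f & 0xf0 == 0xf0 and a multiple of 16 exceeds 15 exactly when it is
// non-zero; likewise (X & 0xff) <u 16 becomes (ANDS X, 0xf0) == 0.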
23504 static SDValue performSubsToAndsCombine(SDNode *N, SDNode *SubsNode,
23505                                         SDNode *AndNode, SelectionDAG &DAG,
23506                                         unsigned CCIndex, unsigned CmpIndex,
23507                                         unsigned CC) {
23508   ConstantSDNode *SubsC = dyn_cast<ConstantSDNode>(SubsNode->getOperand(1));
23509   if (!SubsC)
23510     return SDValue();
23511 
23512   APInt SubsAP = SubsC->getAPIntValue();
23513   if (CC == AArch64CC::HI) {
23514     if (!SubsAP.isMask())
23515       return SDValue();
23516   } else if (CC == AArch64CC::LO) {
23517     if (!SubsAP.isPowerOf2())
23518       return SDValue();
23519   } else
23520     return SDValue();
23521 
23522   ConstantSDNode *AndC = dyn_cast<ConstantSDNode>(AndNode->getOperand(1));
23523   if (!AndC)
23524     return SDValue();
23525 
23526   APInt MaskAP = CC == AArch64CC::HI ? SubsAP : (SubsAP - 1);
23527 
23528   SDLoc DL(N);
23529   APInt AndSMask = (~MaskAP) & AndC->getAPIntValue();
23530   SDValue ANDS = DAG.getNode(
23531       AArch64ISD::ANDS, DL, SubsNode->getVTList(), AndNode->getOperand(0),
23532       DAG.getConstant(AndSMask, DL, SubsC->getValueType(0)));
23533   SDValue AArch64_CC =
23534       DAG.getConstant(CC == AArch64CC::HI ? AArch64CC::NE : AArch64CC::EQ, DL,
23535                       N->getOperand(CCIndex)->getValueType(0));
23536 
23537   // For now, only performCSELCombine and performBRCONDCombine call this
23538   // function, and both pass 2 for CCIndex and 3 for CmpIndex on nodes with 4
23539   // operands, so the operand list is initialized directly to keep the code
23540   // simple. If a caller with a different CCIndex or CmpIndex ever appears,
23541   // this will need to be rewritten with a loop over the operands.
23542   // TODO: Do we need to assert that the number of operands is 4 here?
23543   assert((CCIndex == 2 && CmpIndex == 3) &&
23544          "Expected CCIndex to be 2 and CmpIndex to be 3.");
23545   SDValue Ops[] = {N->getOperand(0), N->getOperand(1), AArch64_CC,
23546                    ANDS.getValue(1)};
23547   return DAG.getNode(N->getOpcode(), N, N->getVTList(), Ops);
23548 }
23549 
23550 static
23551 SDValue performCONDCombine(SDNode *N,
23552                            TargetLowering::DAGCombinerInfo &DCI,
23553                            SelectionDAG &DAG, unsigned CCIndex,
23554                            unsigned CmpIndex) {
23555   unsigned CC = cast<ConstantSDNode>(N->getOperand(CCIndex))->getSExtValue();
23556   SDNode *SubsNode = N->getOperand(CmpIndex).getNode();
23557   unsigned CondOpcode = SubsNode->getOpcode();
23558 
23559   if (CondOpcode != AArch64ISD::SUBS || SubsNode->hasAnyUseOfValue(0) ||
23560       !SubsNode->hasOneUse())
23561     return SDValue();
23562 
23563   // There is a SUBS feeding this condition. Is it fed by a mask we can
23564   // use?
23565 
23566   SDNode *AndNode = SubsNode->getOperand(0).getNode();
23567   unsigned MaskBits = 0;
23568 
23569   if (AndNode->getOpcode() != ISD::AND)
23570     return SDValue();
23571 
23572   if (SDValue Val = performSubsToAndsCombine(N, SubsNode, AndNode, DAG, CCIndex,
23573                                              CmpIndex, CC))
23574     return Val;
23575 
23576   if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(AndNode->getOperand(1))) {
23577     uint32_t CNV = CN->getZExtValue();
23578     if (CNV == 255)
23579       MaskBits = 8;
23580     else if (CNV == 65535)
23581       MaskBits = 16;
23582   }
23583 
23584   if (!MaskBits)
23585     return SDValue();
23586 
23587   SDValue AddValue = AndNode->getOperand(0);
23588 
23589   if (AddValue.getOpcode() != ISD::ADD)
23590     return SDValue();
23591 
23592   // The basic DAG structure is correct; grab the inputs and validate them.
23593 
23594   SDValue AddInputValue1 = AddValue.getNode()->getOperand(0);
23595   SDValue AddInputValue2 = AddValue.getNode()->getOperand(1);
23596   SDValue SubsInputValue = SubsNode->getOperand(1);
23597 
23598   // The mask is present and all of the values originate from a smaller type,
23599   // so let's see if the mask is superfluous.
23600 
23601   if (!isa<ConstantSDNode>(AddInputValue2.getNode()) ||
23602       !isa<ConstantSDNode>(SubsInputValue.getNode()))
23603     return SDValue();
23604 
23605   ISD::LoadExtType ExtType;
23606 
23607   if (!checkValueWidth(SubsInputValue, MaskBits, ExtType) ||
23608       !checkValueWidth(AddInputValue2, MaskBits, ExtType) ||
23609       !checkValueWidth(AddInputValue1, MaskBits, ExtType) )
23610     return SDValue();
23611 
23612   if (!isEquivalentMaskless(CC, MaskBits, ExtType,
23613                 cast<ConstantSDNode>(AddInputValue2.getNode())->getSExtValue(),
23614                 cast<ConstantSDNode>(SubsInputValue.getNode())->getSExtValue()))
23615     return SDValue();
23616 
23617   // The AND is not necessary, remove it.
23618 
23619   SDVTList VTs = DAG.getVTList(SubsNode->getValueType(0),
23620                                SubsNode->getValueType(1));
23621   SDValue Ops[] = { AddValue, SubsNode->getOperand(1) };
23622 
23623   SDValue NewValue = DAG.getNode(CondOpcode, SDLoc(SubsNode), VTs, Ops);
23624   DAG.ReplaceAllUsesWith(SubsNode, NewValue.getNode());
23625 
23626   return SDValue(N, 0);
23627 }
23628 
23629 // Optimize compare with zero and branch.
23630 static SDValue performBRCONDCombine(SDNode *N,
23631                                     TargetLowering::DAGCombinerInfo &DCI,
23632                                     SelectionDAG &DAG) {
23633   MachineFunction &MF = DAG.getMachineFunction();
23634   // Speculation tracking/SLH assumes that optimized TB(N)Z/CB(N)Z instructions
23635   // will not be produced, as they are conditional branch instructions that do
23636   // not set flags.
23637   if (MF.getFunction().hasFnAttribute(Attribute::SpeculativeLoadHardening))
23638     return SDValue();
23639 
23640   if (SDValue NV = performCONDCombine(N, DCI, DAG, 2, 3))
23641     N = NV.getNode();
23642   SDValue Chain = N->getOperand(0);
23643   SDValue Dest = N->getOperand(1);
23644   SDValue CCVal = N->getOperand(2);
23645   SDValue Cmp = N->getOperand(3);
23646 
23647   assert(isa<ConstantSDNode>(CCVal) && "Expected a ConstantSDNode here!");
23648   unsigned CC = CCVal->getAsZExtVal();
23649   if (CC != AArch64CC::EQ && CC != AArch64CC::NE)
23650     return SDValue();
23651 
23652   unsigned CmpOpc = Cmp.getOpcode();
23653   if (CmpOpc != AArch64ISD::ADDS && CmpOpc != AArch64ISD::SUBS)
23654     return SDValue();
23655 
23656   // Only attempt folding if there is only one use of the flag and no use of the
23657   // value.
23658   if (!Cmp->hasNUsesOfValue(0, 0) || !Cmp->hasNUsesOfValue(1, 1))
23659     return SDValue();
23660 
23661   SDValue LHS = Cmp.getOperand(0);
23662   SDValue RHS = Cmp.getOperand(1);
23663 
23664   assert(LHS.getValueType() == RHS.getValueType() &&
23665          "Expected the value type to be the same for both operands!");
23666   if (LHS.getValueType() != MVT::i32 && LHS.getValueType() != MVT::i64)
23667     return SDValue();
23668 
23669   if (isNullConstant(LHS))
23670     std::swap(LHS, RHS);
23671 
23672   if (!isNullConstant(RHS))
23673     return SDValue();
23674 
23675   if (LHS.getOpcode() == ISD::SHL || LHS.getOpcode() == ISD::SRA ||
23676       LHS.getOpcode() == ISD::SRL)
23677     return SDValue();
23678 
23679   // Fold the compare into the branch instruction.
23680   SDValue BR;
23681   if (CC == AArch64CC::EQ)
23682     BR = DAG.getNode(AArch64ISD::CBZ, SDLoc(N), MVT::Other, Chain, LHS, Dest);
23683   else
23684     BR = DAG.getNode(AArch64ISD::CBNZ, SDLoc(N), MVT::Other, Chain, LHS, Dest);
23685 
23686   // Do not add new nodes to DAG combiner worklist.
23687   DCI.CombineTo(N, BR, false);
23688 
23689   return SDValue();
23690 }
23691 
23692 static SDValue foldCSELofCTTZ(SDNode *N, SelectionDAG &DAG) {
23693   unsigned CC = N->getConstantOperandVal(2);
23694   SDValue SUBS = N->getOperand(3);
23695   SDValue Zero, CTTZ;
23696 
23697   if (CC == AArch64CC::EQ && SUBS.getOpcode() == AArch64ISD::SUBS) {
23698     Zero = N->getOperand(0);
23699     CTTZ = N->getOperand(1);
23700   } else if (CC == AArch64CC::NE && SUBS.getOpcode() == AArch64ISD::SUBS) {
23701     Zero = N->getOperand(1);
23702     CTTZ = N->getOperand(0);
23703   } else
23704     return SDValue();
23705 
23706   if ((CTTZ.getOpcode() != ISD::CTTZ && CTTZ.getOpcode() != ISD::TRUNCATE) ||
23707       (CTTZ.getOpcode() == ISD::TRUNCATE &&
23708        CTTZ.getOperand(0).getOpcode() != ISD::CTTZ))
23709     return SDValue();
23710 
23711   assert((CTTZ.getValueType() == MVT::i32 || CTTZ.getValueType() == MVT::i64) &&
23712          "Illegal type in CTTZ folding");
23713 
23714   if (!isNullConstant(Zero) || !isNullConstant(SUBS.getOperand(1)))
23715     return SDValue();
23716 
23717   SDValue X = CTTZ.getOpcode() == ISD::TRUNCATE
23718                   ? CTTZ.getOperand(0).getOperand(0)
23719                   : CTTZ.getOperand(0);
23720 
23721   if (X != SUBS.getOperand(0))
23722     return SDValue();
23723 
23724   unsigned BitWidth = CTTZ.getOpcode() == ISD::TRUNCATE
23725                           ? CTTZ.getOperand(0).getValueSizeInBits()
23726                           : CTTZ.getValueSizeInBits();
23727   SDValue BitWidthMinusOne =
23728       DAG.getConstant(BitWidth - 1, SDLoc(N), CTTZ.getValueType());
23729   return DAG.getNode(ISD::AND, SDLoc(N), CTTZ.getValueType(), CTTZ,
23730                      BitWidthMinusOne);
23731 }
23732 
23733 // (CSEL l r EQ (CMP (CSEL x y cc2 cond) x)) => (CSEL l r cc2 cond)
23734 // (CSEL l r EQ (CMP (CSEL x y cc2 cond) y)) => (CSEL l r !cc2 cond)
23735 // Where x and y are constants and x != y
23736 
23737 // (CSEL l r NE (CMP (CSEL x y cc2 cond) x)) => (CSEL l r !cc2 cond)
23738 // (CSEL l r NE (CMP (CSEL x y cc2 cond) y)) => (CSEL l r cc2 cond)
23739 // Where x and y are constants and x != y
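//
// For example, (CSEL l r EQ (CMP (CSEL 1 0 LT cond) 1)) becomes
// (CSEL l r LT cond): the inner CSEL is 1 exactly when LT holds, so testing
// it for equality with 1 reproduces the original condition.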
23740 static SDValue foldCSELOfCSEL(SDNode *Op, SelectionDAG &DAG) {
23741   SDValue L = Op->getOperand(0);
23742   SDValue R = Op->getOperand(1);
23743   AArch64CC::CondCode OpCC =
23744       static_cast<AArch64CC::CondCode>(Op->getConstantOperandVal(2));
23745 
23746   SDValue OpCmp = Op->getOperand(3);
23747   if (!isCMP(OpCmp))
23748     return SDValue();
23749 
23750   SDValue CmpLHS = OpCmp.getOperand(0);
23751   SDValue CmpRHS = OpCmp.getOperand(1);
23752 
23753   if (CmpRHS.getOpcode() == AArch64ISD::CSEL)
23754     std::swap(CmpLHS, CmpRHS);
23755   else if (CmpLHS.getOpcode() != AArch64ISD::CSEL)
23756     return SDValue();
23757 
23758   SDValue X = CmpLHS->getOperand(0);
23759   SDValue Y = CmpLHS->getOperand(1);
23760   if (!isa<ConstantSDNode>(X) || !isa<ConstantSDNode>(Y) || X == Y) {
23761     return SDValue();
23762   }
23763 
23764   // If one of the constants is an opaque constant, the x and y SDNodes can
23765   // still be distinct even though the underlying values are equal, so compare
23766   // the APInt values here to make sure the code is correct.
23767   ConstantSDNode *CX = cast<ConstantSDNode>(X);
23768   ConstantSDNode *CY = cast<ConstantSDNode>(Y);
23769   if (CX->getAPIntValue() == CY->getAPIntValue())
23770     return SDValue();
23771 
23772   AArch64CC::CondCode CC =
23773       static_cast<AArch64CC::CondCode>(CmpLHS->getConstantOperandVal(2));
23774   SDValue Cond = CmpLHS->getOperand(3);
23775 
23776   if (CmpRHS == Y)
23777     CC = AArch64CC::getInvertedCondCode(CC);
23778   else if (CmpRHS != X)
23779     return SDValue();
23780 
23781   if (OpCC == AArch64CC::NE)
23782     CC = AArch64CC::getInvertedCondCode(CC);
23783   else if (OpCC != AArch64CC::EQ)
23784     return SDValue();
23785 
23786   SDLoc DL(Op);
23787   EVT VT = Op->getValueType(0);
23788 
23789   SDValue CCValue = DAG.getConstant(CC, DL, MVT::i32);
23790   return DAG.getNode(AArch64ISD::CSEL, DL, VT, L, R, CCValue, Cond);
23791 }
23792 
23793 // Optimize CSEL instructions
23794 static SDValue performCSELCombine(SDNode *N,
23795                                   TargetLowering::DAGCombinerInfo &DCI,
23796                                   SelectionDAG &DAG) {
23797   // CSEL x, x, cc -> x
23798   if (N->getOperand(0) == N->getOperand(1))
23799     return N->getOperand(0);
23800 
23801   if (SDValue R = foldCSELOfCSEL(N, DAG))
23802     return R;
23803 
23804   // CSEL 0, cttz(X), eq(X, 0) -> AND cttz bitwidth-1
23805   // CSEL cttz(X), 0, ne(X, 0) -> AND cttz bitwidth-1
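  // ISD::CTTZ is defined to return the bit width for a zero input, so masking
  // with bitwidth-1 yields 0 exactly when X == 0 (matching the CSEL's zero
  // arm) and leaves every other cttz result, all of which are at most
  // bitwidth-1, unchanged.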
23806   if (SDValue Folded = foldCSELofCTTZ(N, DAG))
23807     return Folded;
23808 
23809   return performCONDCombine(N, DCI, DAG, 2, 3);
23810 }
23811 
23812 // Try to re-use an already extended operand of a vector SetCC feeding an
23813 // extended select. Doing so avoids requiring another full extension of the
23814 // SET_CC result when lowering the select.
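// For example (types chosen purely for illustration), if a v8i8 SetCC feeds
// v8i16 VSELECTs and a v8i16 sign_extend of the SetCC's first operand already
// exists in the DAG, the SetCC is rebuilt on the extended operands so the
// select no longer needs to widen the comparison result itself.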
23815 static SDValue tryToWidenSetCCOperands(SDNode *Op, SelectionDAG &DAG) {
23816   EVT Op0MVT = Op->getOperand(0).getValueType();
23817   if (!Op0MVT.isVector() || Op->use_empty())
23818     return SDValue();
23819 
23820   // Make sure that all uses of Op are VSELECTs with result matching types where
23821   // the result type has a larger element type than the SetCC operand.
23822   SDNode *FirstUse = *Op->use_begin();
23823   if (FirstUse->getOpcode() != ISD::VSELECT)
23824     return SDValue();
23825   EVT UseMVT = FirstUse->getValueType(0);
23826   if (UseMVT.getScalarSizeInBits() <= Op0MVT.getScalarSizeInBits())
23827     return SDValue();
23828   if (any_of(Op->uses(), [&UseMVT](const SDNode *N) {
23829         return N->getOpcode() != ISD::VSELECT || N->getValueType(0) != UseMVT;
23830       }))
23831     return SDValue();
23832 
23833   APInt V;
23834   if (!ISD::isConstantSplatVector(Op->getOperand(1).getNode(), V))
23835     return SDValue();
23836 
23837   SDLoc DL(Op);
23838   SDValue Op0ExtV;
23839   SDValue Op1ExtV;
23840   ISD::CondCode CC = cast<CondCodeSDNode>(Op->getOperand(2))->get();
23841   // Check if the first operand of the SET_CC is already extended. If it is,
23842   // split the SET_CC and re-use the extended version of the operand.
23843   SDNode *Op0SExt = DAG.getNodeIfExists(ISD::SIGN_EXTEND, DAG.getVTList(UseMVT),
23844                                         Op->getOperand(0));
23845   SDNode *Op0ZExt = DAG.getNodeIfExists(ISD::ZERO_EXTEND, DAG.getVTList(UseMVT),
23846                                         Op->getOperand(0));
23847   if (Op0SExt && (isSignedIntSetCC(CC) || isIntEqualitySetCC(CC))) {
23848     Op0ExtV = SDValue(Op0SExt, 0);
23849     Op1ExtV = DAG.getNode(ISD::SIGN_EXTEND, DL, UseMVT, Op->getOperand(1));
23850   } else if (Op0ZExt && (isUnsignedIntSetCC(CC) || isIntEqualitySetCC(CC))) {
23851     Op0ExtV = SDValue(Op0ZExt, 0);
23852     Op1ExtV = DAG.getNode(ISD::ZERO_EXTEND, DL, UseMVT, Op->getOperand(1));
23853   } else
23854     return SDValue();
23855 
23856   return DAG.getNode(ISD::SETCC, DL, UseMVT.changeVectorElementType(MVT::i1),
23857                      Op0ExtV, Op1ExtV, Op->getOperand(2));
23858 }
23859 
23860 static SDValue
23861 performVecReduceBitwiseCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
23862                                SelectionDAG &DAG) {
23863   SDValue Vec = N->getOperand(0);
23864   if (DCI.isBeforeLegalize() &&
23865       Vec.getValueType().getVectorElementType() == MVT::i1 &&
23866       Vec.getValueType().isFixedLengthVector() &&
23867       Vec.getValueType().isPow2VectorType()) {
23868     SDLoc DL(N);
23869     return getVectorBitwiseReduce(N->getOpcode(), Vec, N->getValueType(0), DL,
23870                                   DAG);
23871   }
23872 
23873   return SDValue();
23874 }
23875 
23876 static SDValue performSETCCCombine(SDNode *N,
23877                                    TargetLowering::DAGCombinerInfo &DCI,
23878                                    SelectionDAG &DAG) {
23879   assert(N->getOpcode() == ISD::SETCC && "Unexpected opcode!");
23880   SDValue LHS = N->getOperand(0);
23881   SDValue RHS = N->getOperand(1);
23882   ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
23883   SDLoc DL(N);
23884   EVT VT = N->getValueType(0);
23885 
23886   if (SDValue V = tryToWidenSetCCOperands(N, DAG))
23887     return V;
23888 
23889   // setcc (csel 0, 1, cond, X), 1, ne ==> csel 0, 1, !cond, X
23890   if (Cond == ISD::SETNE && isOneConstant(RHS) &&
23891       LHS->getOpcode() == AArch64ISD::CSEL &&
23892       isNullConstant(LHS->getOperand(0)) && isOneConstant(LHS->getOperand(1)) &&
23893       LHS->hasOneUse()) {
23894     // Invert CSEL's condition.
23895     auto OldCond =
23896         static_cast<AArch64CC::CondCode>(LHS.getConstantOperandVal(2));
23897     auto NewCond = getInvertedCondCode(OldCond);
23898 
23899     // csel 0, 1, !cond, X
23900     SDValue CSEL =
23901         DAG.getNode(AArch64ISD::CSEL, DL, LHS.getValueType(), LHS.getOperand(0),
23902                     LHS.getOperand(1), DAG.getConstant(NewCond, DL, MVT::i32),
23903                     LHS.getOperand(3));
23904     return DAG.getZExtOrTrunc(CSEL, DL, VT);
23905   }
23906 
23907   // setcc (srl x, imm), 0, ne ==> setcc (and x, (-1 << imm)), 0, ne
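  // For example, with a 32-bit x and imm == 4 this rewrites (srl x, 4) != 0
  // as (x & 0xfffffff0) != 0; the two are equivalent because the shifted-out
  // low bits cannot affect whether the remaining bits are all zero.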
23908   if (Cond == ISD::SETNE && isNullConstant(RHS) &&
23909       LHS->getOpcode() == ISD::SRL && isa<ConstantSDNode>(LHS->getOperand(1)) &&
23910       LHS->getConstantOperandVal(1) < VT.getScalarSizeInBits() &&
23911       LHS->hasOneUse()) {
23912     EVT TstVT = LHS->getValueType(0);
23913     if (TstVT.isScalarInteger() && TstVT.getFixedSizeInBits() <= 64) {
23914       // This pattern is optimized better in emitComparison.
23915       uint64_t TstImm = -1ULL << LHS->getConstantOperandVal(1);
23916       SDValue TST = DAG.getNode(ISD::AND, DL, TstVT, LHS->getOperand(0),
23917                                 DAG.getConstant(TstImm, DL, TstVT));
23918       return DAG.getNode(ISD::SETCC, DL, VT, TST, RHS, N->getOperand(2));
23919     }
23920   }
23921 
23922   // setcc (iN (bitcast (vNi1 X))), 0, (eq|ne)
23923   //   ==> setcc (iN (zext (i1 (vecreduce_or (vNi1 X))))), 0, (eq|ne)
23924   // setcc (iN (bitcast (vNi1 X))), -1, (eq|ne)
23925   //   ==> setcc (iN (sext (i1 (vecreduce_and (vNi1 X))))), -1, (eq|ne)
23926   if (DCI.isBeforeLegalize() && VT.isScalarInteger() &&
23927       (Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
23928       (isNullConstant(RHS) || isAllOnesConstant(RHS)) &&
23929       LHS->getOpcode() == ISD::BITCAST) {
23930     EVT ToVT = LHS->getValueType(0);
23931     EVT FromVT = LHS->getOperand(0).getValueType();
23932     if (FromVT.isFixedLengthVector() &&
23933         FromVT.getVectorElementType() == MVT::i1) {
23934       bool IsNull = isNullConstant(RHS);
23935       LHS = DAG.getNode(IsNull ? ISD::VECREDUCE_OR : ISD::VECREDUCE_AND,
23936                         DL, MVT::i1, LHS->getOperand(0));
23937       LHS = DAG.getNode(IsNull ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND, DL, ToVT,
23938                         LHS);
23939       return DAG.getSetCC(DL, VT, LHS, RHS, Cond);
23940     }
23941   }
23942 
23943   // Try to perform the memcmp when the result is tested for [in]equality with 0
23944   if (SDValue V = performOrXorChainCombine(N, DAG))
23945     return V;
23946 
23947   return SDValue();
23948 }
23949 
23950 // Replace a flag-setting operator (eg ANDS) with the generic version
23951 // (eg AND) if the flag is unused.
23952 static SDValue performFlagSettingCombine(SDNode *N,
23953                                          TargetLowering::DAGCombinerInfo &DCI,
23954                                          unsigned GenericOpcode) {
23955   SDLoc DL(N);
23956   SDValue LHS = N->getOperand(0);
23957   SDValue RHS = N->getOperand(1);
23958   EVT VT = N->getValueType(0);
23959 
23960   // If the flag result isn't used, convert back to a generic opcode.
23961   if (!N->hasAnyUseOfValue(1)) {
23962     SDValue Res = DCI.DAG.getNode(GenericOpcode, DL, VT, N->ops());
23963     return DCI.DAG.getMergeValues({Res, DCI.DAG.getConstant(0, DL, MVT::i32)},
23964                                   DL);
23965   }
23966 
23967   // Combine identical generic nodes into this node, re-using the result.
23968   if (SDNode *Generic = DCI.DAG.getNodeIfExists(
23969           GenericOpcode, DCI.DAG.getVTList(VT), {LHS, RHS}))
23970     DCI.CombineTo(Generic, SDValue(N, 0));
23971 
23972   return SDValue();
23973 }
23974 
23975 static SDValue performSetCCPunpkCombine(SDNode *N, SelectionDAG &DAG) {
23976   // setcc_merge_zero pred
23977   //   (sign_extend (extract_subvector (setcc_merge_zero ... pred ...))), 0, ne
23978   //   => extract_subvector (inner setcc_merge_zero)
23979   SDValue Pred = N->getOperand(0);
23980   SDValue LHS = N->getOperand(1);
23981   SDValue RHS = N->getOperand(2);
23982   ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(3))->get();
23983 
23984   if (Cond != ISD::SETNE || !isZerosVector(RHS.getNode()) ||
23985       LHS->getOpcode() != ISD::SIGN_EXTEND)
23986     return SDValue();
23987 
23988   SDValue Extract = LHS->getOperand(0);
23989   if (Extract->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
23990       Extract->getValueType(0) != N->getValueType(0) ||
23991       Extract->getConstantOperandVal(1) != 0)
23992     return SDValue();
23993 
23994   SDValue InnerSetCC = Extract->getOperand(0);
23995   if (InnerSetCC->getOpcode() != AArch64ISD::SETCC_MERGE_ZERO)
23996     return SDValue();
23997 
23998   // By this point we've effectively got
23999   // zero_inactive_lanes_and_trunc_i1(sext_i1(A)). If we can prove A's inactive
24000   // lanes are already zero then the trunc(sext()) sequence is redundant and we
24001   // can operate on A directly.
24002   SDValue InnerPred = InnerSetCC.getOperand(0);
24003   if (Pred.getOpcode() == AArch64ISD::PTRUE &&
24004       InnerPred.getOpcode() == AArch64ISD::PTRUE &&
24005       Pred.getConstantOperandVal(0) == InnerPred.getConstantOperandVal(0) &&
24006       Pred->getConstantOperandVal(0) >= AArch64SVEPredPattern::vl1 &&
24007       Pred->getConstantOperandVal(0) <= AArch64SVEPredPattern::vl256)
24008     return Extract;
24009 
24010   return SDValue();
24011 }
24012 
24013 static SDValue
24014 performSetccMergeZeroCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
24015   assert(N->getOpcode() == AArch64ISD::SETCC_MERGE_ZERO &&
24016          "Unexpected opcode!");
24017 
24018   SelectionDAG &DAG = DCI.DAG;
24019   SDValue Pred = N->getOperand(0);
24020   SDValue LHS = N->getOperand(1);
24021   SDValue RHS = N->getOperand(2);
24022   ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(3))->get();
24023 
24024   if (SDValue V = performSetCCPunpkCombine(N, DAG))
24025     return V;
24026 
24027   if (Cond == ISD::SETNE && isZerosVector(RHS.getNode()) &&
24028       LHS->getOpcode() == ISD::SIGN_EXTEND &&
24029       LHS->getOperand(0)->getValueType(0) == N->getValueType(0)) {
24030     //    setcc_merge_zero(
24031     //       pred, extend(setcc_merge_zero(pred, ...)), != splat(0))
24032     // => setcc_merge_zero(pred, ...)
24033     if (LHS->getOperand(0)->getOpcode() == AArch64ISD::SETCC_MERGE_ZERO &&
24034         LHS->getOperand(0)->getOperand(0) == Pred)
24035       return LHS->getOperand(0);
24036 
24037     //    setcc_merge_zero(
24038     //        all_active, extend(nxvNi1 ...), != splat(0))
24039     // -> nxvNi1 ...
24040     if (isAllActivePredicate(DAG, Pred))
24041       return LHS->getOperand(0);
24042 
24043     //    setcc_merge_zero(
24044     //        pred, extend(nxvNi1 ...), != splat(0))
24045     // -> nxvNi1 and(pred, ...)
24046     if (DCI.isAfterLegalizeDAG())
24047       // Do this after legalization to allow more folds on setcc_merge_zero
24048       // to be recognized.
24049       return DAG.getNode(ISD::AND, SDLoc(N), N->getValueType(0),
24050                          LHS->getOperand(0), Pred);
24051   }
24052 
24053   return SDValue();
24054 }
24055 
24056 // Optimize some simple tbz/tbnz cases.  Returns the new operand and bit to test
24057 // as well as whether the test should be inverted.  This code is required to
24058 // catch these cases (as opposed to standard dag combines) because
24059 // AArch64ISD::TBZ is matched during legalization.
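// For example, (tbz (srl x, 2), 3) becomes (tbz x, 5), and for a 32-bit x,
// (tbz (sra x, 8), 30) becomes (tbz x, 31), since every bit at or above the
// MSB of the arithmetic-shift result is a copy of the sign bit.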
24060 static SDValue getTestBitOperand(SDValue Op, unsigned &Bit, bool &Invert,
24061                                  SelectionDAG &DAG) {
24062 
24063   if (!Op->hasOneUse())
24064     return Op;
24065 
24066   // We don't handle undef/constant-fold cases below, as they should have
24067   // already been taken care of (e.g. and of 0, test of undefined shifted bits,
24068   // etc.)
24069 
24070   // (tbz (trunc x), b) -> (tbz x, b)
24071   // This case is just here to enable more of the below cases to be caught.
24072   if (Op->getOpcode() == ISD::TRUNCATE &&
24073       Bit < Op->getValueType(0).getSizeInBits()) {
24074     return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
24075   }
24076 
24077   // (tbz (any_ext x), b) -> (tbz x, b) if we don't use the extended bits.
24078   if (Op->getOpcode() == ISD::ANY_EXTEND &&
24079       Bit < Op->getOperand(0).getValueSizeInBits()) {
24080     return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
24081   }
24082 
24083   if (Op->getNumOperands() != 2)
24084     return Op;
24085 
24086   auto *C = dyn_cast<ConstantSDNode>(Op->getOperand(1));
24087   if (!C)
24088     return Op;
24089 
24090   switch (Op->getOpcode()) {
24091   default:
24092     return Op;
24093 
24094   // (tbz (and x, m), b) -> (tbz x, b)
24095   case ISD::AND:
24096     if ((C->getZExtValue() >> Bit) & 1)
24097       return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
24098     return Op;
24099 
24100   // (tbz (shl x, c), b) -> (tbz x, b-c)
24101   case ISD::SHL:
24102     if (C->getZExtValue() <= Bit &&
24103         (Bit - C->getZExtValue()) < Op->getValueType(0).getSizeInBits()) {
24104       Bit = Bit - C->getZExtValue();
24105       return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
24106     }
24107     return Op;
24108 
24109   // (tbz (sra x, c), b) -> (tbz x, b+c) or (tbz x, msb) if b+c is > # bits in x
24110   case ISD::SRA:
24111     Bit = Bit + C->getZExtValue();
24112     if (Bit >= Op->getValueType(0).getSizeInBits())
24113       Bit = Op->getValueType(0).getSizeInBits() - 1;
24114     return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
24115 
24116   // (tbz (srl x, c), b) -> (tbz x, b+c)
24117   case ISD::SRL:
24118     if ((Bit + C->getZExtValue()) < Op->getValueType(0).getSizeInBits()) {
24119       Bit = Bit + C->getZExtValue();
24120       return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
24121     }
24122     return Op;
24123 
24124   // (tbz (xor x, -1), b) -> (tbnz x, b)
24125   case ISD::XOR:
24126     if ((C->getZExtValue() >> Bit) & 1)
24127       Invert = !Invert;
24128     return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
24129   }
24130 }
24131 
24132 // Optimize test single bit zero/non-zero and branch.
24133 static SDValue performTBZCombine(SDNode *N,
24134                                  TargetLowering::DAGCombinerInfo &DCI,
24135                                  SelectionDAG &DAG) {
24136   unsigned Bit = N->getConstantOperandVal(2);
24137   bool Invert = false;
24138   SDValue TestSrc = N->getOperand(1);
24139   SDValue NewTestSrc = getTestBitOperand(TestSrc, Bit, Invert, DAG);
24140 
24141   if (TestSrc == NewTestSrc)
24142     return SDValue();
24143 
24144   unsigned NewOpc = N->getOpcode();
24145   if (Invert) {
24146     if (NewOpc == AArch64ISD::TBZ)
24147       NewOpc = AArch64ISD::TBNZ;
24148     else {
24149       assert(NewOpc == AArch64ISD::TBNZ);
24150       NewOpc = AArch64ISD::TBZ;
24151     }
24152   }
24153 
24154   SDLoc DL(N);
24155   return DAG.getNode(NewOpc, DL, MVT::Other, N->getOperand(0), NewTestSrc,
24156                      DAG.getConstant(Bit, DL, MVT::i64), N->getOperand(3));
24157 }
24158 
24159 // Swap vselect operands where it may allow a predicated operation to achieve
24160 // the `sel`.
24161 //
24162 //     (vselect (setcc ( condcode) (_) (_)) (a)          (op (a) (b)))
24163 //  => (vselect (setcc (!condcode) (_) (_)) (op (a) (b)) (a))
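// For example, (vselect (setcc cc p q) a (fadd a b)) can be rewritten as
// (vselect (setcc !cc p q) (fadd a b) a), which may allow the fadd to be
// emitted as a predicated operation whose inactive lanes already hold a, so
// no separate sel is required.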
24164 static SDValue trySwapVSelectOperands(SDNode *N, SelectionDAG &DAG) {
24165   auto SelectA = N->getOperand(1);
24166   auto SelectB = N->getOperand(2);
24167   auto NTy = N->getValueType(0);
24168 
24169   if (!NTy.isScalableVector())
24170     return SDValue();
24171   SDValue SetCC = N->getOperand(0);
24172   if (SetCC.getOpcode() != ISD::SETCC || !SetCC.hasOneUse())
24173     return SDValue();
24174 
24175   switch (SelectB.getOpcode()) {
24176   default:
24177     return SDValue();
24178   case ISD::FMUL:
24179   case ISD::FSUB:
24180   case ISD::FADD:
24181     break;
24182   }
24183   if (SelectA != SelectB.getOperand(0))
24184     return SDValue();
24185 
24186   ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
24187   ISD::CondCode InverseCC =
24188       ISD::getSetCCInverse(CC, SetCC.getOperand(0).getValueType());
24189   auto InverseSetCC =
24190       DAG.getSetCC(SDLoc(SetCC), SetCC.getValueType(), SetCC.getOperand(0),
24191                    SetCC.getOperand(1), InverseCC);
24192 
24193   return DAG.getNode(ISD::VSELECT, SDLoc(N), NTy,
24194                      {InverseSetCC, SelectB, SelectA});
24195 }
24196 
24197 // vselect (v1i1 setcc) ->
24198 //     vselect (v1iXX setcc)  (XX is the size of the compared operand type)
24199 // FIXME: Currently the type legalizer can't handle VSELECT having v1i1 as
24200 // condition. If it can legalize "VSELECT v1i1" correctly, no need to combine
24201 // such VSELECT.
24202 static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) {
24203   if (auto SwapResult = trySwapVSelectOperands(N, DAG))
24204     return SwapResult;
24205 
24206   SDValue N0 = N->getOperand(0);
24207   EVT CCVT = N0.getValueType();
24208 
24209   if (isAllActivePredicate(DAG, N0))
24210     return N->getOperand(1);
24211 
24212   if (isAllInactivePredicate(N0))
24213     return N->getOperand(2);
24214 
24215   // Check for sign pattern (VSELECT setgt, iN lhs, -1, 1, -1) and transform
24216   // into (OR (ASR lhs, N-1), 1), which requires fewer instructions for the
24217   // supported types.
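  // This works because (ASR lhs, N-1) is 0 for non-negative lanes and -1 for
  // negative lanes, so OR-ing with 1 produces 1 when lhs > -1 and -1
  // otherwise.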
24218   SDValue SetCC = N->getOperand(0);
24219   if (SetCC.getOpcode() == ISD::SETCC &&
24220       SetCC.getOperand(2) == DAG.getCondCode(ISD::SETGT)) {
24221     SDValue CmpLHS = SetCC.getOperand(0);
24222     EVT VT = CmpLHS.getValueType();
24223     SDNode *CmpRHS = SetCC.getOperand(1).getNode();
24224     SDNode *SplatLHS = N->getOperand(1).getNode();
24225     SDNode *SplatRHS = N->getOperand(2).getNode();
24226     APInt SplatLHSVal;
24227     if (CmpLHS.getValueType() == N->getOperand(1).getValueType() &&
24228         VT.isSimple() &&
24229         is_contained(ArrayRef({MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16,
24230                                MVT::v2i32, MVT::v4i32, MVT::v2i64}),
24231                      VT.getSimpleVT().SimpleTy) &&
24232         ISD::isConstantSplatVector(SplatLHS, SplatLHSVal) &&
24233         SplatLHSVal.isOne() && ISD::isConstantSplatVectorAllOnes(CmpRHS) &&
24234         ISD::isConstantSplatVectorAllOnes(SplatRHS)) {
24235       unsigned NumElts = VT.getVectorNumElements();
24236       SmallVector<SDValue, 8> Ops(
24237           NumElts, DAG.getConstant(VT.getScalarSizeInBits() - 1, SDLoc(N),
24238                                    VT.getScalarType()));
24239       SDValue Val = DAG.getBuildVector(VT, SDLoc(N), Ops);
24240 
24241       auto Shift = DAG.getNode(ISD::SRA, SDLoc(N), VT, CmpLHS, Val);
24242       auto Or = DAG.getNode(ISD::OR, SDLoc(N), VT, Shift, N->getOperand(1));
24243       return Or;
24244     }
24245   }
24246 
24247   EVT CmpVT = N0.getOperand(0).getValueType();
24248   if (N0.getOpcode() != ISD::SETCC ||
24249       CCVT.getVectorElementCount() != ElementCount::getFixed(1) ||
24250       CCVT.getVectorElementType() != MVT::i1 ||
24251       CmpVT.getVectorElementType().isFloatingPoint())
24252     return SDValue();
24253 
24254   EVT ResVT = N->getValueType(0);
24255   // Only combine when the result type is of the same size as the compared
24256   // operands.
24257   if (ResVT.getSizeInBits() != CmpVT.getSizeInBits())
24258     return SDValue();
24259 
24260   SDValue IfTrue = N->getOperand(1);
24261   SDValue IfFalse = N->getOperand(2);
24262   SetCC = DAG.getSetCC(SDLoc(N), CmpVT.changeVectorElementTypeToInteger(),
24263                        N0.getOperand(0), N0.getOperand(1),
24264                        cast<CondCodeSDNode>(N0.getOperand(2))->get());
24265   return DAG.getNode(ISD::VSELECT, SDLoc(N), ResVT, SetCC,
24266                      IfTrue, IfFalse);
24267 }
24268 
24269 /// A vector select: "(select vL, vR, (setcc LHS, RHS))" is best performed with
24270 /// the compare-mask instructions rather than going via NZCV, even if LHS and
24271 /// RHS are really scalar. This replaces any scalar setcc in the above pattern
24272 /// with a vector one followed by a DUP shuffle on the result.
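/// For example, a select between two v2f64 values whose condition is a scalar
/// (setcc f64 a, f64 b, olt) is rewritten as a vector compare of
/// (scalar_to_vector a) and (scalar_to_vector b) producing a v2i64 mask;
/// lane 0 of that mask is then broadcast with a DUP-style shuffle to form the
/// VSELECT condition.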
24273 static SDValue performSelectCombine(SDNode *N,
24274                                     TargetLowering::DAGCombinerInfo &DCI) {
24275   SelectionDAG &DAG = DCI.DAG;
24276   SDValue N0 = N->getOperand(0);
24277   EVT ResVT = N->getValueType(0);
24278 
24279   if (N0.getOpcode() != ISD::SETCC)
24280     return SDValue();
24281 
24282   if (ResVT.isScalableVT())
24283     return SDValue();
24284 
24285   // Make sure the SETCC result is either i1 (initial DAG), or i32, the lowered
24286   // scalar SetCCResultType. We also don't expect vectors, because we assume
24287   // that selects fed by vector SETCCs are canonicalized to VSELECT.
24288   assert((N0.getValueType() == MVT::i1 || N0.getValueType() == MVT::i32) &&
24289          "Scalar-SETCC feeding SELECT has unexpected result type!");
24290 
24291   // If NumMaskElts == 0, the comparison is larger than the select result. The
24292   // largest real NEON comparison is 64 bits per lane, which means the result is
24293   // at most 32 bits and an illegal vector. Just bail out for now.
24294   EVT SrcVT = N0.getOperand(0).getValueType();
24295 
24296   // Don't try to do this optimization when the setcc itself has i1 operands.
24297   // There are no legal vectors of i1, so this would be pointless. v1f16 is
24298   // ruled out to prevent the creation of setcc that need to be scalarized.
24299   if (SrcVT == MVT::i1 ||
24300       (SrcVT.isFloatingPoint() && SrcVT.getSizeInBits() <= 16))
24301     return SDValue();
24302 
24303   int NumMaskElts = ResVT.getSizeInBits() / SrcVT.getSizeInBits();
24304   if (!ResVT.isVector() || NumMaskElts == 0)
24305     return SDValue();
24306 
24307   SrcVT = EVT::getVectorVT(*DAG.getContext(), SrcVT, NumMaskElts);
24308   EVT CCVT = SrcVT.changeVectorElementTypeToInteger();
24309 
24310   // Also bail out if the vector CCVT isn't the same size as ResVT.
24311   // This can happen if the SETCC operand size doesn't divide the ResVT size
24312   // (e.g., f64 vs v3f32).
24313   if (CCVT.getSizeInBits() != ResVT.getSizeInBits())
24314     return SDValue();
24315 
24316   // Make sure we didn't create illegal types, if we're not supposed to.
24317   assert(DCI.isBeforeLegalize() ||
24318          DAG.getTargetLoweringInfo().isTypeLegal(SrcVT));
24319 
24320   // First perform a vector comparison, where lane 0 is the one we're interested
24321   // in.
24322   SDLoc DL(N0);
24323   SDValue LHS =
24324       DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, SrcVT, N0.getOperand(0));
24325   SDValue RHS =
24326       DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, SrcVT, N0.getOperand(1));
24327   SDValue SetCC = DAG.getNode(ISD::SETCC, DL, CCVT, LHS, RHS, N0.getOperand(2));
24328 
24329   // Now duplicate the comparison mask we want across all other lanes.
24330   SmallVector<int, 8> DUPMask(CCVT.getVectorNumElements(), 0);
24331   SDValue Mask = DAG.getVectorShuffle(CCVT, DL, SetCC, SetCC, DUPMask);
24332   Mask = DAG.getNode(ISD::BITCAST, DL,
24333                      ResVT.changeVectorElementTypeToInteger(), Mask);
24334 
24335   return DAG.getSelect(DL, ResVT, Mask, N->getOperand(1), N->getOperand(2));
24336 }
24337 
24338 static SDValue performDUPCombine(SDNode *N,
24339                                  TargetLowering::DAGCombinerInfo &DCI) {
24340   EVT VT = N->getValueType(0);
24341   SDLoc DL(N);
24342   // If "v2i32 DUP(x)" and "v4i32 DUP(x)" both exist, use an extract from the
24343   // 128bit vector version.
24344   if (VT.is64BitVector() && DCI.isAfterLegalizeDAG()) {
24345     EVT LVT = VT.getDoubleNumVectorElementsVT(*DCI.DAG.getContext());
24346     SmallVector<SDValue> Ops(N->ops());
24347     if (SDNode *LN = DCI.DAG.getNodeIfExists(N->getOpcode(),
24348                                              DCI.DAG.getVTList(LVT), Ops)) {
24349       return DCI.DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, SDValue(LN, 0),
24350                              DCI.DAG.getConstant(0, DL, MVT::i64));
24351     }
24352   }
24353 
24354   if (N->getOpcode() == AArch64ISD::DUP) {
24355     if (DCI.isAfterLegalizeDAG()) {
24356       // If scalar dup's operand is extract_vector_elt, try to combine them into
24357       // duplane. For example,
24358       //
24359       //    t21: i32 = extract_vector_elt t19, Constant:i64<0>
24360       //  t18: v4i32 = AArch64ISD::DUP t21
24361       //  ==>
24362       //  t22: v4i32 = AArch64ISD::DUPLANE32 t19, Constant:i64<0>
24363       SDValue EXTRACT_VEC_ELT = N->getOperand(0);
24364       if (EXTRACT_VEC_ELT.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
24365         if (VT == EXTRACT_VEC_ELT.getOperand(0).getValueType()) {
24366           unsigned Opcode = getDUPLANEOp(VT.getVectorElementType());
24367           return DCI.DAG.getNode(Opcode, DL, VT, EXTRACT_VEC_ELT.getOperand(0),
24368                                  EXTRACT_VEC_ELT.getOperand(1));
24369         }
24370       }
24371     }
24372 
24373     return performPostLD1Combine(N, DCI, false);
24374   }
24375 
24376   return SDValue();
24377 }
24378 
24379 /// Get rid of unnecessary NVCASTs (that don't change the type).
24380 static SDValue performNVCASTCombine(SDNode *N, SelectionDAG &DAG) {
24381   if (N->getValueType(0) == N->getOperand(0).getValueType())
24382     return N->getOperand(0);
24383   if (N->getOperand(0).getOpcode() == AArch64ISD::NVCAST)
24384     return DAG.getNode(AArch64ISD::NVCAST, SDLoc(N), N->getValueType(0),
24385                        N->getOperand(0).getOperand(0));
24386 
24387   return SDValue();
24388 }
24389 
24390 // If all users of the globaladdr are of the form (globaladdr + constant), find
24391 // the smallest constant, fold it into the globaladdr's offset and rewrite the
24392 // globaladdr as (globaladdr + constant) - constant.
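// For example, if the only uses are (add G, 8) and (add G, 24), MinOffset is
// 8 and G is rewritten as (sub (globaladdr G + 8), 8); the first use can then
// fold away entirely and the second becomes an add of 16 on top of the offset
// global address.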
24393 static SDValue performGlobalAddressCombine(SDNode *N, SelectionDAG &DAG,
24394                                            const AArch64Subtarget *Subtarget,
24395                                            const TargetMachine &TM) {
24396   auto *GN = cast<GlobalAddressSDNode>(N);
24397   if (Subtarget->ClassifyGlobalReference(GN->getGlobal(), TM) !=
24398       AArch64II::MO_NO_FLAG)
24399     return SDValue();
24400 
24401   uint64_t MinOffset = -1ull;
24402   for (SDNode *N : GN->uses()) {
24403     if (N->getOpcode() != ISD::ADD)
24404       return SDValue();
24405     auto *C = dyn_cast<ConstantSDNode>(N->getOperand(0));
24406     if (!C)
24407       C = dyn_cast<ConstantSDNode>(N->getOperand(1));
24408     if (!C)
24409       return SDValue();
24410     MinOffset = std::min(MinOffset, C->getZExtValue());
24411   }
24412   uint64_t Offset = MinOffset + GN->getOffset();
24413 
24414   // Require that the new offset is larger than the existing one. Otherwise, we
24415   // can end up oscillating between two possible DAGs, for example,
24416   // (add (add globaladdr + 10, -1), 1) and (add globaladdr + 9, 1).
24417   if (Offset <= uint64_t(GN->getOffset()))
24418     return SDValue();
24419 
24420   // Check whether folding this offset is legal. It must not go out of bounds of
24421   // the referenced object to avoid violating the code model, and must be
24422   // smaller than 2^20 because this is the largest offset expressible in all
24423   // object formats. (The IMAGE_REL_ARM64_PAGEBASE_REL21 relocation in COFF
24424   // stores an immediate signed 21 bit offset.)
24425   //
24426   // This check also prevents us from folding negative offsets, which will end
24427   // up being treated in the same way as large positive ones. They could also
24428   // cause code model violations, and aren't really common enough to matter.
24429   if (Offset >= (1 << 20))
24430     return SDValue();
24431 
24432   const GlobalValue *GV = GN->getGlobal();
24433   Type *T = GV->getValueType();
24434   if (!T->isSized() ||
24435       Offset > GV->getDataLayout().getTypeAllocSize(T))
24436     return SDValue();
24437 
24438   SDLoc DL(GN);
24439   SDValue Result = DAG.getGlobalAddress(GV, DL, MVT::i64, Offset);
24440   return DAG.getNode(ISD::SUB, DL, MVT::i64, Result,
24441                      DAG.getConstant(MinOffset, DL, MVT::i64));
24442 }
24443 
24444 static SDValue performCTLZCombine(SDNode *N, SelectionDAG &DAG,
24445                                   const AArch64Subtarget *Subtarget) {
24446   SDValue BR = N->getOperand(0);
24447   if (!Subtarget->hasCSSC() || BR.getOpcode() != ISD::BITREVERSE ||
24448       !BR.getValueType().isScalarInteger())
24449     return SDValue();
24450 
24451   SDLoc DL(N);
24452   return DAG.getNode(ISD::CTTZ, DL, BR.getValueType(), BR.getOperand(0));
24453 }
24454 
24455 // Turns the vector of indices into a vector of byte offsets by scaling Offset
24456 // by (BitWidth / 8).
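// For example, with 32-bit elements each index is shifted left by
// Log2_32(32 / 8) == 2, i.e. multiplied by 4.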
24457 static SDValue getScaledOffsetForBitWidth(SelectionDAG &DAG, SDValue Offset,
24458                                           SDLoc DL, unsigned BitWidth) {
24459   assert(Offset.getValueType().isScalableVector() &&
24460          "This method is only for scalable vectors of offsets");
24461 
24462   SDValue Shift = DAG.getConstant(Log2_32(BitWidth / 8), DL, MVT::i64);
24463   SDValue SplatShift = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv2i64, Shift);
24464 
24465   return DAG.getNode(ISD::SHL, DL, MVT::nxv2i64, Offset, SplatShift);
24466 }
24467 
24468 /// Check if the value of \p OffsetInBytes can be used as an immediate for
24469 /// the gather load/prefetch and scatter store instructions with vector base and
24470 /// immediate offset addressing mode:
24471 ///
24472 ///      [<Zn>.[S|D]{, #<imm>}]
24473 ///
24474 /// where <imm> = sizeof(<T>) * k, for k = 0, 1, ..., 31.
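///
/// For example, with 32-bit elements (ScalarSizeInBytes == 4) the valid
/// immediates are 0, 4, 8, ..., 124.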
24475 inline static bool isValidImmForSVEVecImmAddrMode(unsigned OffsetInBytes,
24476                                                   unsigned ScalarSizeInBytes) {
24477   // The immediate is not a multiple of the scalar size.
24478   if (OffsetInBytes % ScalarSizeInBytes)
24479     return false;
24480 
24481   // The immediate is out of range.
24482   if (OffsetInBytes / ScalarSizeInBytes > 31)
24483     return false;
24484 
24485   return true;
24486 }
24487 
24488 /// Check if the value of \p Offset represents a valid immediate for the SVE
24489 /// gather load/prefetch and scatter store instructions with vector base and
24490 /// immediate offset addressing mode:
24491 ///
24492 ///      [<Zn>.[S|D]{, #<imm>}]
24493 ///
24494 /// where <imm> = sizeof(<T>) * k, for k = 0, 1, ..., 31.
24495 static bool isValidImmForSVEVecImmAddrMode(SDValue Offset,
24496                                            unsigned ScalarSizeInBytes) {
24497   ConstantSDNode *OffsetConst = dyn_cast<ConstantSDNode>(Offset.getNode());
24498   return OffsetConst && isValidImmForSVEVecImmAddrMode(
24499                             OffsetConst->getZExtValue(), ScalarSizeInBytes);
24500 }
24501 
24502 static SDValue performScatterStoreCombine(SDNode *N, SelectionDAG &DAG,
24503                                           unsigned Opcode,
24504                                           bool OnlyPackedOffsets = true) {
24505   const SDValue Src = N->getOperand(2);
24506   const EVT SrcVT = Src->getValueType(0);
24507   assert(SrcVT.isScalableVector() &&
24508          "Scatter stores are only possible for SVE vectors");
24509 
24510   SDLoc DL(N);
24511   MVT SrcElVT = SrcVT.getVectorElementType().getSimpleVT();
24512 
24513   // Make sure that source data will fit into an SVE register
24514   if (SrcVT.getSizeInBits().getKnownMinValue() > AArch64::SVEBitsPerBlock)
24515     return SDValue();
24516 
24517   // For FPs, ACLE only supports _packed_ single and double precision types.
24518   // SST1Q_[INDEX_]PRED is the ST1Q for sve2p1 and should allow all sizes.
24519   if (SrcElVT.isFloatingPoint())
24520     if ((SrcVT != MVT::nxv4f32) && (SrcVT != MVT::nxv2f64) &&
24521         ((Opcode != AArch64ISD::SST1Q_PRED &&
24522           Opcode != AArch64ISD::SST1Q_INDEX_PRED) ||
24523          ((SrcVT != MVT::nxv8f16) && (SrcVT != MVT::nxv8bf16))))
24524       return SDValue();
24525 
24526   // Depending on the addressing mode, this is either a pointer or a vector of
24527   // pointers (that fits into one register)
24528   SDValue Base = N->getOperand(4);
24529   // Depending on the addressing mode, this is either a single offset or a
24530   // vector of offsets  (that fits into one register)
24531   SDValue Offset = N->getOperand(5);
24532 
24533   // For "scalar + vector of indices", just scale the indices. This applies to
24534   // non-temporal and quadword scatters because there's no instruction that
24535   // takes indices.
24536   if (Opcode == AArch64ISD::SSTNT1_INDEX_PRED) {
24537     Offset =
24538         getScaledOffsetForBitWidth(DAG, Offset, DL, SrcElVT.getSizeInBits());
24539     Opcode = AArch64ISD::SSTNT1_PRED;
24540   } else if (Opcode == AArch64ISD::SST1Q_INDEX_PRED) {
24541     Offset =
24542         getScaledOffsetForBitWidth(DAG, Offset, DL, SrcElVT.getSizeInBits());
24543     Opcode = AArch64ISD::SST1Q_PRED;
24544   }
24545 
24546   // In the case of non-temporal and quadword scatter stores there's only one
24547   // addressing mode: "vector + scalar", i.e.
24548   //    * stnt1{b|h|w|d} { z0.s }, p0, [z0.s, x0]
24549   // Since we do have intrinsics that allow the arguments to be in a different
24550   // order, we may need to swap them to match the spec.
24551   if ((Opcode == AArch64ISD::SSTNT1_PRED || Opcode == AArch64ISD::SST1Q_PRED) &&
24552       Offset.getValueType().isVector())
24553     std::swap(Base, Offset);
24554 
24555   // SST1_IMM requires that the offset is an immediate that is:
24556   //    * a multiple of #SizeInBytes,
24557   //    * in the range [0, 31 x #SizeInBytes],
24558   // where #SizeInBytes is the size in bytes of the stored items. For
24559   // immediates outside that range and non-immediate scalar offsets use SST1 or
24560   // SST1_UXTW instead.
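  // For example, with 32-bit elements (#SizeInBytes == 4) an immediate offset of
  // 64 is accepted (64 / 4 == 16), while 128 is rejected (128 / 4 == 32 > 31) and
  // the code below switches to SST1_PRED/SST1_UXTW_PRED with swapped operands.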
24561   if (Opcode == AArch64ISD::SST1_IMM_PRED) {
24562     if (!isValidImmForSVEVecImmAddrMode(Offset,
24563                                         SrcVT.getScalarSizeInBits() / 8)) {
24564       if (MVT::nxv4i32 == Base.getValueType().getSimpleVT().SimpleTy)
24565         Opcode = AArch64ISD::SST1_UXTW_PRED;
24566       else
24567         Opcode = AArch64ISD::SST1_PRED;
24568 
24569       std::swap(Base, Offset);
24570     }
24571   }
24572 
24573   auto &TLI = DAG.getTargetLoweringInfo();
24574   if (!TLI.isTypeLegal(Base.getValueType()))
24575     return SDValue();
24576 
24577   // Some scatter store variants allow unpacked offsets, but only as nxv2i32
24578   // vectors. These are implicitly sign (sxtw) or zero (uxtw) extended to
24579   // nxv2i64. Legalize accordingly.
24580   if (!OnlyPackedOffsets &&
24581       Offset.getValueType().getSimpleVT().SimpleTy == MVT::nxv2i32)
24582     Offset = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::nxv2i64, Offset).getValue(0);
24583 
24584   if (!TLI.isTypeLegal(Offset.getValueType()))
24585     return SDValue();
24586 
24587   // Source value type that is representable in hardware
24588   EVT HwSrcVt = getSVEContainerType(SrcVT);
24589 
24590   // Keep the original type of the input data to store - this is needed to be
24591   // able to select the correct instruction, e.g. ST1B, ST1H, ST1W and ST1D. For
24592   // FP values we want the integer equivalent, so just use HwSrcVt.
24593   SDValue InputVT = DAG.getValueType(SrcVT);
24594   if (SrcVT.isFloatingPoint())
24595     InputVT = DAG.getValueType(HwSrcVt);
24596 
24597   SDVTList VTs = DAG.getVTList(MVT::Other);
24598   SDValue SrcNew;
24599 
24600   if (Src.getValueType().isFloatingPoint())
24601     SrcNew = DAG.getNode(ISD::BITCAST, DL, HwSrcVt, Src);
24602   else
24603     SrcNew = DAG.getNode(ISD::ANY_EXTEND, DL, HwSrcVt, Src);
24604 
24605   SDValue Ops[] = {N->getOperand(0), // Chain
24606                    SrcNew,
24607                    N->getOperand(3), // Pg
24608                    Base,
24609                    Offset,
24610                    InputVT};
24611 
24612   return DAG.getNode(Opcode, DL, VTs, Ops);
24613 }
24614 
24615 static SDValue performGatherLoadCombine(SDNode *N, SelectionDAG &DAG,
24616                                         unsigned Opcode,
24617                                         bool OnlyPackedOffsets = true) {
24618   const EVT RetVT = N->getValueType(0);
24619   assert(RetVT.isScalableVector() &&
24620          "Gather loads are only possible for SVE vectors");
24621 
24622   SDLoc DL(N);
24623 
24624   // Make sure that the loaded data will fit into an SVE register
24625   if (RetVT.getSizeInBits().getKnownMinValue() > AArch64::SVEBitsPerBlock)
24626     return SDValue();
24627 
24628   // Depending on the addressing mode, this is either a pointer or a vector of
24629   // pointers (that fits into one register)
24630   SDValue Base = N->getOperand(3);
24631   // Depending on the addressing mode, this is either a single offset or a
24632   // vector of offsets  (that fits into one register)
24633   SDValue Offset = N->getOperand(4);
24634 
24635   // For "scalar + vector of indices", scale the indices to obtain unscaled
24636   // offsets. This applies to non-temporal and quadword gathers, which do not
24637   // have an addressing mode with scaled offset.
24638   if (Opcode == AArch64ISD::GLDNT1_INDEX_MERGE_ZERO) {
24639     Offset = getScaledOffsetForBitWidth(DAG, Offset, DL,
24640                                         RetVT.getScalarSizeInBits());
24641     Opcode = AArch64ISD::GLDNT1_MERGE_ZERO;
24642   } else if (Opcode == AArch64ISD::GLD1Q_INDEX_MERGE_ZERO) {
24643     Offset = getScaledOffsetForBitWidth(DAG, Offset, DL,
24644                                         RetVT.getScalarSizeInBits());
24645     Opcode = AArch64ISD::GLD1Q_MERGE_ZERO;
24646   }
24647 
24648   // In the case of non-temporal gather loads and quadword gather loads there's
24649   // only one addressing mode: "vector + scalar", e.g.
24650   //   ldnt1{b|h|w|d} { z0.s }, p0/z, [z0.s, x0]
24651   // Since we do have intrinsics that allow the arguments to be in a different
24652   // order, we may need to swap them to match the spec.
24653   if ((Opcode == AArch64ISD::GLDNT1_MERGE_ZERO ||
24654        Opcode == AArch64ISD::GLD1Q_MERGE_ZERO) &&
24655       Offset.getValueType().isVector())
24656     std::swap(Base, Offset);
24657 
24658   // GLD{FF}1_IMM requires that the offset is an immediate that is:
24659   //    * a multiple of #SizeInBytes,
24660   //    * in the range [0, 31 x #SizeInBytes],
24661   // where #SizeInBytes is the size in bytes of the loaded items. For
24662   // immediates outside that range and non-immediate scalar offsets use
24663   // GLD1_MERGE_ZERO or GLD1_UXTW_MERGE_ZERO instead.
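  // For example, with 64-bit elements an immediate of 248 is still valid
  // (248 / 8 == 31), whereas 256 forces the fall-back opcodes chosen below.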
24664   if (Opcode == AArch64ISD::GLD1_IMM_MERGE_ZERO ||
24665       Opcode == AArch64ISD::GLDFF1_IMM_MERGE_ZERO) {
24666     if (!isValidImmForSVEVecImmAddrMode(Offset,
24667                                         RetVT.getScalarSizeInBits() / 8)) {
24668       if (MVT::nxv4i32 == Base.getValueType().getSimpleVT().SimpleTy)
24669         Opcode = (Opcode == AArch64ISD::GLD1_IMM_MERGE_ZERO)
24670                      ? AArch64ISD::GLD1_UXTW_MERGE_ZERO
24671                      : AArch64ISD::GLDFF1_UXTW_MERGE_ZERO;
24672       else
24673         Opcode = (Opcode == AArch64ISD::GLD1_IMM_MERGE_ZERO)
24674                      ? AArch64ISD::GLD1_MERGE_ZERO
24675                      : AArch64ISD::GLDFF1_MERGE_ZERO;
24676 
24677       std::swap(Base, Offset);
24678     }
24679   }
24680 
24681   auto &TLI = DAG.getTargetLoweringInfo();
24682   if (!TLI.isTypeLegal(Base.getValueType()))
24683     return SDValue();
24684 
24685   // Some gather load variants allow unpacked offsets, but only as nxv2i32
24686   // vectors. These are implicitly sign (sxtw) or zero (uxtw) extended to
24687   // nxv2i64. Legalize accordingly.
24688   if (!OnlyPackedOffsets &&
24689       Offset.getValueType().getSimpleVT().SimpleTy == MVT::nxv2i32)
24690     Offset = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::nxv2i64, Offset).getValue(0);
24691 
24692   // Return value type that is representable in hardware
24693   EVT HwRetVt = getSVEContainerType(RetVT);
24694 
24695   // Keep the original output value type around - this is needed to be able to
24696   // select the correct instruction, e.g. LD1B, LD1H, LD1W and LD1D. For FP
24697   // values we want the integer equivalent, so just use HwRetVT.
24698   SDValue OutVT = DAG.getValueType(RetVT);
24699   if (RetVT.isFloatingPoint())
24700     OutVT = DAG.getValueType(HwRetVt);
24701 
24702   SDVTList VTs = DAG.getVTList(HwRetVt, MVT::Other);
24703   SDValue Ops[] = {N->getOperand(0), // Chain
24704                    N->getOperand(2), // Pg
24705                    Base, Offset, OutVT};
24706 
24707   SDValue Load = DAG.getNode(Opcode, DL, VTs, Ops);
24708   SDValue LoadChain = SDValue(Load.getNode(), 1);
24709 
24710   if (RetVT.isInteger() && (RetVT != HwRetVt))
24711     Load = DAG.getNode(ISD::TRUNCATE, DL, RetVT, Load.getValue(0));
24712 
24713   // If the original return value was FP, bitcast accordingly. Doing it here
24714   // means that we can avoid adding TableGen patterns for FPs.
24715   if (RetVT.isFloatingPoint())
24716     Load = DAG.getNode(ISD::BITCAST, DL, RetVT, Load.getValue(0));
24717 
24718   return DAG.getMergeValues({Load, LoadChain}, DL);
24719 }
24720 
24721 static SDValue
24722 performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
24723                               SelectionDAG &DAG) {
24724   SDLoc DL(N);
24725   SDValue Src = N->getOperand(0);
24726   unsigned Opc = Src->getOpcode();
24727 
24728   // Sign extend of an unsigned unpack -> signed unpack
24729   if (Opc == AArch64ISD::UUNPKHI || Opc == AArch64ISD::UUNPKLO) {
24730 
24731     unsigned SOpc = Opc == AArch64ISD::UUNPKHI ? AArch64ISD::SUNPKHI
24732                                                : AArch64ISD::SUNPKLO;
24733 
24734     // Push the sign extend to the operand of the unpack
24735     // This is necessary where, for example, the operand of the unpack
24736     // is another unpack:
24737     // 4i32 sign_extend_inreg (4i32 uunpklo(8i16 uunpklo (16i8 opnd)), from 4i8)
24738     // ->
24739     // 4i32 sunpklo (8i16 sign_extend_inreg(8i16 uunpklo (16i8 opnd), from 8i8)
24740     // ->
24741     // 4i32 sunpklo(8i16 sunpklo(16i8 opnd))
24742     SDValue ExtOp = Src->getOperand(0);
24743     auto VT = cast<VTSDNode>(N->getOperand(1))->getVT();
24744     EVT EltTy = VT.getVectorElementType();
24745     (void)EltTy;
24746 
24747     assert((EltTy == MVT::i8 || EltTy == MVT::i16 || EltTy == MVT::i32) &&
24748            "Sign extending from an invalid type");
24749 
24750     EVT ExtVT = VT.getDoubleNumVectorElementsVT(*DAG.getContext());
24751 
24752     SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, ExtOp.getValueType(),
24753                               ExtOp, DAG.getValueType(ExtVT));
24754 
24755     return DAG.getNode(SOpc, DL, N->getValueType(0), Ext);
24756   }
24757 
24758   if (DCI.isBeforeLegalizeOps())
24759     return SDValue();
24760 
24761   if (!EnableCombineMGatherIntrinsics)
24762     return SDValue();
24763 
24764   // SVE load nodes (e.g. AArch64ISD::GLD1) are straightforward candidates
24765   // for DAG Combine with SIGN_EXTEND_INREG. Bail out for all other nodes.
24766   unsigned NewOpc;
24767   unsigned MemVTOpNum = 4;
24768   switch (Opc) {
24769   case AArch64ISD::LD1_MERGE_ZERO:
24770     NewOpc = AArch64ISD::LD1S_MERGE_ZERO;
24771     MemVTOpNum = 3;
24772     break;
24773   case AArch64ISD::LDNF1_MERGE_ZERO:
24774     NewOpc = AArch64ISD::LDNF1S_MERGE_ZERO;
24775     MemVTOpNum = 3;
24776     break;
24777   case AArch64ISD::LDFF1_MERGE_ZERO:
24778     NewOpc = AArch64ISD::LDFF1S_MERGE_ZERO;
24779     MemVTOpNum = 3;
24780     break;
24781   case AArch64ISD::GLD1_MERGE_ZERO:
24782     NewOpc = AArch64ISD::GLD1S_MERGE_ZERO;
24783     break;
24784   case AArch64ISD::GLD1_SCALED_MERGE_ZERO:
24785     NewOpc = AArch64ISD::GLD1S_SCALED_MERGE_ZERO;
24786     break;
24787   case AArch64ISD::GLD1_SXTW_MERGE_ZERO:
24788     NewOpc = AArch64ISD::GLD1S_SXTW_MERGE_ZERO;
24789     break;
24790   case AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO:
24791     NewOpc = AArch64ISD::GLD1S_SXTW_SCALED_MERGE_ZERO;
24792     break;
24793   case AArch64ISD::GLD1_UXTW_MERGE_ZERO:
24794     NewOpc = AArch64ISD::GLD1S_UXTW_MERGE_ZERO;
24795     break;
24796   case AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO:
24797     NewOpc = AArch64ISD::GLD1S_UXTW_SCALED_MERGE_ZERO;
24798     break;
24799   case AArch64ISD::GLD1_IMM_MERGE_ZERO:
24800     NewOpc = AArch64ISD::GLD1S_IMM_MERGE_ZERO;
24801     break;
24802   case AArch64ISD::GLDFF1_MERGE_ZERO:
24803     NewOpc = AArch64ISD::GLDFF1S_MERGE_ZERO;
24804     break;
24805   case AArch64ISD::GLDFF1_SCALED_MERGE_ZERO:
24806     NewOpc = AArch64ISD::GLDFF1S_SCALED_MERGE_ZERO;
24807     break;
24808   case AArch64ISD::GLDFF1_SXTW_MERGE_ZERO:
24809     NewOpc = AArch64ISD::GLDFF1S_SXTW_MERGE_ZERO;
24810     break;
24811   case AArch64ISD::GLDFF1_SXTW_SCALED_MERGE_ZERO:
24812     NewOpc = AArch64ISD::GLDFF1S_SXTW_SCALED_MERGE_ZERO;
24813     break;
24814   case AArch64ISD::GLDFF1_UXTW_MERGE_ZERO:
24815     NewOpc = AArch64ISD::GLDFF1S_UXTW_MERGE_ZERO;
24816     break;
24817   case AArch64ISD::GLDFF1_UXTW_SCALED_MERGE_ZERO:
24818     NewOpc = AArch64ISD::GLDFF1S_UXTW_SCALED_MERGE_ZERO;
24819     break;
24820   case AArch64ISD::GLDFF1_IMM_MERGE_ZERO:
24821     NewOpc = AArch64ISD::GLDFF1S_IMM_MERGE_ZERO;
24822     break;
24823   case AArch64ISD::GLDNT1_MERGE_ZERO:
24824     NewOpc = AArch64ISD::GLDNT1S_MERGE_ZERO;
24825     break;
24826   default:
24827     return SDValue();
24828   }
24829 
24830   EVT SignExtSrcVT = cast<VTSDNode>(N->getOperand(1))->getVT();
24831   EVT SrcMemVT = cast<VTSDNode>(Src->getOperand(MemVTOpNum))->getVT();
24832 
24833   if ((SignExtSrcVT != SrcMemVT) || !Src.hasOneUse())
24834     return SDValue();
24835 
24836   EVT DstVT = N->getValueType(0);
24837   SDVTList VTs = DAG.getVTList(DstVT, MVT::Other);
24838 
24839   SmallVector<SDValue, 5> Ops;
24840   for (unsigned I = 0; I < Src->getNumOperands(); ++I)
24841     Ops.push_back(Src->getOperand(I));
24842 
24843   SDValue ExtLoad = DAG.getNode(NewOpc, SDLoc(N), VTs, Ops);
24844   DCI.CombineTo(N, ExtLoad);
24845   DCI.CombineTo(Src.getNode(), ExtLoad, ExtLoad.getValue(1));
24846 
24847   // Return N so it doesn't get rechecked
24848   return SDValue(N, 0);
24849 }
24850 
24851 /// Legalize the gather prefetch (scalar + vector addressing mode) when the
24852 /// offset vector is an unpacked 32-bit scalable vector. The other cases (Offset
24853 /// != nxv2i32) do not need legalization.
24854 static SDValue legalizeSVEGatherPrefetchOffsVec(SDNode *N, SelectionDAG &DAG) {
24855   const unsigned OffsetPos = 4;
24856   SDValue Offset = N->getOperand(OffsetPos);
24857 
24858   // Not an unpacked vector, bail out.
24859   if (Offset.getValueType().getSimpleVT().SimpleTy != MVT::nxv2i32)
24860     return SDValue();
24861 
24862   // Extend the unpacked offset vector to 64-bit lanes.
24863   SDLoc DL(N);
24864   Offset = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::nxv2i64, Offset);
24865   SmallVector<SDValue, 5> Ops(N->op_begin(), N->op_end());
24866   // Replace the offset operand with the 64-bit one.
24867   Ops[OffsetPos] = Offset;
24868 
24869   return DAG.getNode(N->getOpcode(), DL, DAG.getVTList(MVT::Other), Ops);
24870 }
24871 
24872 /// Combines a node carrying the intrinsic
24873 /// `aarch64_sve_prf<T>_gather_scalar_offset` into a node that uses
24874 /// `aarch64_sve_prfb_gather_uxtw_index` when the scalar offset passed to
24875 /// `aarch64_sve_prf<T>_gather_scalar_offset` is not a valid immediate for the
24876 /// sve gather prefetch instruction with vector plus immediate addressing mode.
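/// When the immediate is out of range, the two offset operands are swapped and
/// the intrinsic is remapped to aarch64_sve_prfb_gather_uxtw_index so that the
/// uxtw-indexed addressing form is selected instead.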
24877 static SDValue combineSVEPrefetchVecBaseImmOff(SDNode *N, SelectionDAG &DAG,
24878                                                unsigned ScalarSizeInBytes) {
24879   const unsigned ImmPos = 4, OffsetPos = 3;
24880   // No need to combine the node if the immediate is valid...
24881   if (isValidImmForSVEVecImmAddrMode(N->getOperand(ImmPos), ScalarSizeInBytes))
24882     return SDValue();
24883 
24884   // ...otherwise swap the offset base with the offset...
24885   SmallVector<SDValue, 5> Ops(N->op_begin(), N->op_end());
24886   std::swap(Ops[ImmPos], Ops[OffsetPos]);
24887   // ...and remap the intrinsic `aarch64_sve_prf<T>_gather_scalar_offset` to
24888   // `aarch64_sve_prfb_gather_uxtw_index`.
24889   SDLoc DL(N);
24890   Ops[1] = DAG.getConstant(Intrinsic::aarch64_sve_prfb_gather_uxtw_index, DL,
24891                            MVT::i64);
24892 
24893   return DAG.getNode(N->getOpcode(), DL, DAG.getVTList(MVT::Other), Ops);
24894 }
24895 
24896 // Return true if the vector operation can guarantee that only the first lane of
24897 // its result contains data, with all bits in other lanes set to zero.
24898 static bool isLanes1toNKnownZero(SDValue Op) {
24899   switch (Op.getOpcode()) {
24900   default:
24901     return false;
24902   case AArch64ISD::ANDV_PRED:
24903   case AArch64ISD::EORV_PRED:
24904   case AArch64ISD::FADDA_PRED:
24905   case AArch64ISD::FADDV_PRED:
24906   case AArch64ISD::FMAXNMV_PRED:
24907   case AArch64ISD::FMAXV_PRED:
24908   case AArch64ISD::FMINNMV_PRED:
24909   case AArch64ISD::FMINV_PRED:
24910   case AArch64ISD::ORV_PRED:
24911   case AArch64ISD::SADDV_PRED:
24912   case AArch64ISD::SMAXV_PRED:
24913   case AArch64ISD::SMINV_PRED:
24914   case AArch64ISD::UADDV_PRED:
24915   case AArch64ISD::UMAXV_PRED:
24916   case AArch64ISD::UMINV_PRED:
24917     return true;
24918   }
24919 }
24920 
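// Fold insert_vector_elt(zeroinitializer, extract_vector_elt(Op, 0), 0) -> Op
// when Op is known to produce zeros in all lanes other than lane 0, making the
// explicit re-zeroing redundant.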
24921 static SDValue removeRedundantInsertVectorElt(SDNode *N) {
24922   assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT && "Unexpected node!");
24923   SDValue InsertVec = N->getOperand(0);
24924   SDValue InsertElt = N->getOperand(1);
24925   SDValue InsertIdx = N->getOperand(2);
24926 
24927   // We only care about inserts into the first element...
24928   if (!isNullConstant(InsertIdx))
24929     return SDValue();
24930   // ...of a zero'd vector...
24931   if (!ISD::isConstantSplatVectorAllZeros(InsertVec.getNode()))
24932     return SDValue();
24933   // ...where the inserted data was previously extracted...
24934   if (InsertElt.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
24935     return SDValue();
24936 
24937   SDValue ExtractVec = InsertElt.getOperand(0);
24938   SDValue ExtractIdx = InsertElt.getOperand(1);
24939 
24940   // ...from the first element of a vector.
24941   if (!isNullConstant(ExtractIdx))
24942     return SDValue();
24943 
24944   // If we get here we are effectively trying to zero lanes 1-N of a vector.
24945 
24946   // Ensure there's no type conversion going on.
24947   if (N->getValueType(0) != ExtractVec.getValueType())
24948     return SDValue();
24949 
24950   if (!isLanes1toNKnownZero(ExtractVec))
24951     return SDValue();
24952 
24953   // The explicit zeroing is redundant.
24954   return ExtractVec;
24955 }
24956 
24957 static SDValue
24958 performInsertVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
24959   if (SDValue Res = removeRedundantInsertVectorElt(N))
24960     return Res;
24961 
24962   return performPostLD1Combine(N, DCI, true);
24963 }
24964 
24965 static SDValue performFPExtendCombine(SDNode *N, SelectionDAG &DAG,
24966                                       TargetLowering::DAGCombinerInfo &DCI,
24967                                       const AArch64Subtarget *Subtarget) {
24968   SDValue N0 = N->getOperand(0);
24969   EVT VT = N->getValueType(0);
24970 
24971   // If this is fp_round(fpextend), don't fold it, allow ourselves to be folded.
24972   if (N->hasOneUse() && N->use_begin()->getOpcode() == ISD::FP_ROUND)
24973     return SDValue();
24974 
24975   auto hasValidElementTypeForFPExtLoad = [](EVT VT) {
24976     EVT EltVT = VT.getVectorElementType();
24977     return EltVT == MVT::f32 || EltVT == MVT::f64;
24978   };
24979 
24980   // fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
24981   // We purposefully don't care about legality of the nodes here as we know
24982   // they can be split down into something legal.
24983   if (DCI.isBeforeLegalizeOps() && ISD::isNormalLoad(N0.getNode()) &&
24984       N0.hasOneUse() && Subtarget->useSVEForFixedLengthVectors() &&
24985       VT.isFixedLengthVector() && hasValidElementTypeForFPExtLoad(VT) &&
24986       VT.getFixedSizeInBits() >= Subtarget->getMinSVEVectorSizeInBits()) {
24987     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
24988     SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
24989                                      LN0->getChain(), LN0->getBasePtr(),
24990                                      N0.getValueType(), LN0->getMemOperand());
24991     DCI.CombineTo(N, ExtLoad);
24992     DCI.CombineTo(
24993         N0.getNode(),
24994         DAG.getNode(ISD::FP_ROUND, SDLoc(N0), N0.getValueType(), ExtLoad,
24995                     DAG.getIntPtrConstant(1, SDLoc(N0), /*isTarget=*/true)),
24996         ExtLoad.getValue(1));
24997     return SDValue(N, 0); // Return N so it doesn't get rechecked!
24998   }
24999 
25000   return SDValue();
25001 }
25002 
25003 static SDValue performBSPExpandForSVE(SDNode *N, SelectionDAG &DAG,
25004                                       const AArch64Subtarget *Subtarget) {
25005   EVT VT = N->getValueType(0);
25006 
25007   // Don't expand for NEON, SVE2 or SME
25008   if (!VT.isScalableVector() || Subtarget->hasSVE2() || Subtarget->hasSME())
25009     return SDValue();
25010 
25011   SDLoc DL(N);
25012 
25013   SDValue Mask = N->getOperand(0);
25014   SDValue In1 = N->getOperand(1);
25015   SDValue In2 = N->getOperand(2);
25016 
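  // Expand the bit select as (Mask & In1) | (~Mask & In2).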
25017   SDValue InvMask = DAG.getNOT(DL, Mask, VT);
25018   SDValue Sel = DAG.getNode(ISD::AND, DL, VT, Mask, In1);
25019   SDValue SelInv = DAG.getNode(ISD::AND, DL, VT, InvMask, In2);
25020   return DAG.getNode(ISD::OR, DL, VT, Sel, SelInv);
25021 }
25022 
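// Fold duplane128(insert_subvector(undef, bitcast(v), 0), 0) so that the insert
// and the lane duplication are performed on the packed SVE container type of the
// original 128-bit subvector, with a single bitcast applied to the final result.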
25023 static SDValue performDupLane128Combine(SDNode *N, SelectionDAG &DAG) {
25024   EVT VT = N->getValueType(0);
25025 
25026   SDValue Insert = N->getOperand(0);
25027   if (Insert.getOpcode() != ISD::INSERT_SUBVECTOR)
25028     return SDValue();
25029 
25030   if (!Insert.getOperand(0).isUndef())
25031     return SDValue();
25032 
25033   uint64_t IdxInsert = Insert.getConstantOperandVal(2);
25034   uint64_t IdxDupLane = N->getConstantOperandVal(1);
25035   if (IdxInsert != 0 || IdxDupLane != 0)
25036     return SDValue();
25037 
25038   SDValue Bitcast = Insert.getOperand(1);
25039   if (Bitcast.getOpcode() != ISD::BITCAST)
25040     return SDValue();
25041 
25042   SDValue Subvec = Bitcast.getOperand(0);
25043   EVT SubvecVT = Subvec.getValueType();
25044   if (!SubvecVT.is128BitVector())
25045     return SDValue();
25046   EVT NewSubvecVT =
25047       getPackedSVEVectorVT(Subvec.getValueType().getVectorElementType());
25048 
25049   SDLoc DL(N);
25050   SDValue NewInsert =
25051       DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NewSubvecVT,
25052                   DAG.getUNDEF(NewSubvecVT), Subvec, Insert->getOperand(2));
25053   SDValue NewDuplane128 = DAG.getNode(AArch64ISD::DUPLANE128, DL, NewSubvecVT,
25054                                       NewInsert, N->getOperand(1));
25055   return DAG.getNode(ISD::BITCAST, DL, VT, NewDuplane128);
25056 }
25057 
25058 // Try to combine mull with uzp1.
25059 static SDValue tryCombineMULLWithUZP1(SDNode *N,
25060                                       TargetLowering::DAGCombinerInfo &DCI,
25061                                       SelectionDAG &DAG) {
25062   if (DCI.isBeforeLegalizeOps())
25063     return SDValue();
25064 
25065   SDValue LHS = N->getOperand(0);
25066   SDValue RHS = N->getOperand(1);
25067 
25068   SDValue ExtractHigh;
25069   SDValue ExtractLow;
25070   SDValue TruncHigh;
25071   SDValue TruncLow;
25072   SDLoc DL(N);
25073 
25074   // Check the operands are trunc and extract_high.
25075   if (isEssentiallyExtractHighSubvector(LHS) &&
25076       RHS.getOpcode() == ISD::TRUNCATE) {
25077     TruncHigh = RHS;
25078     if (LHS.getOpcode() == ISD::BITCAST)
25079       ExtractHigh = LHS.getOperand(0);
25080     else
25081       ExtractHigh = LHS;
25082   } else if (isEssentiallyExtractHighSubvector(RHS) &&
25083              LHS.getOpcode() == ISD::TRUNCATE) {
25084     TruncHigh = LHS;
25085     if (RHS.getOpcode() == ISD::BITCAST)
25086       ExtractHigh = RHS.getOperand(0);
25087     else
25088       ExtractHigh = RHS;
25089   } else
25090     return SDValue();
25091 
25092   // If the truncate's operand is BUILD_VECTOR with DUP, do not combine the op
25093   // with uzp1.
25094   // You can see the regressions on test/CodeGen/AArch64/aarch64-smull.ll
25095   SDValue TruncHighOp = TruncHigh.getOperand(0);
25096   EVT TruncHighOpVT = TruncHighOp.getValueType();
25097   if (TruncHighOp.getOpcode() == AArch64ISD::DUP ||
25098       DAG.isSplatValue(TruncHighOp, false))
25099     return SDValue();
25100 
25101   // Check whether there is another extract_high with the same source vector.
25102   // For example,
25103   //
25104   //    t18: v4i16 = extract_subvector t2, Constant:i64<0>
25105   //    t12: v4i16 = truncate t11
25106   //  t31: v4i32 = AArch64ISD::SMULL t18, t12
25107   //    t23: v4i16 = extract_subvector t2, Constant:i64<4>
25108   //    t16: v4i16 = truncate t15
25109   //  t30: v4i32 = AArch64ISD::SMULL t23, t16
25110   //
25111   // This dagcombine assumes the two extract_high nodes use the same source
25112   // vector in order to detect the pair of mulls. If they have different source
25113   // vectors, this code will not work.
25114   // TODO: Should also try to look through a bitcast.
25115   bool HasFoundMULLow = true;
25116   SDValue ExtractHighSrcVec = ExtractHigh.getOperand(0);
25117   if (ExtractHighSrcVec->use_size() != 2)
25118     HasFoundMULLow = false;
25119 
25120   // Find ExtractLow.
25121   for (SDNode *User : ExtractHighSrcVec.getNode()->uses()) {
25122     if (User == ExtractHigh.getNode())
25123       continue;
25124 
25125     if (User->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
25126         !isNullConstant(User->getOperand(1))) {
25127       HasFoundMULLow = false;
25128       break;
25129     }
25130 
25131     ExtractLow.setNode(User);
25132   }
25133 
25134   if (!ExtractLow || !ExtractLow->hasOneUse())
25135     HasFoundMULLow = false;
25136 
25137   // Check ExtractLow's user.
25138   if (HasFoundMULLow) {
25139     SDNode *ExtractLowUser = *ExtractLow.getNode()->use_begin();
25140     if (ExtractLowUser->getOpcode() != N->getOpcode()) {
25141       HasFoundMULLow = false;
25142     } else {
25143       if (ExtractLowUser->getOperand(0) == ExtractLow) {
25144         if (ExtractLowUser->getOperand(1).getOpcode() == ISD::TRUNCATE)
25145           TruncLow = ExtractLowUser->getOperand(1);
25146         else
25147           HasFoundMULLow = false;
25148       } else {
25149         if (ExtractLowUser->getOperand(0).getOpcode() == ISD::TRUNCATE)
25150           TruncLow = ExtractLowUser->getOperand(0);
25151         else
25152           HasFoundMULLow = false;
25153       }
25154     }
25155   }
25156 
25157   // If the truncate's operand is BUILD_VECTOR with DUP, do not combine the op
25158   // with uzp1.
25159   // You can see the regressions on test/CodeGen/AArch64/aarch64-smull.ll
25160   EVT TruncHighVT = TruncHigh.getValueType();
25161   EVT UZP1VT = TruncHighVT.getDoubleNumVectorElementsVT(*DAG.getContext());
25162   SDValue TruncLowOp =
25163       HasFoundMULLow ? TruncLow.getOperand(0) : DAG.getUNDEF(UZP1VT);
25164   EVT TruncLowOpVT = TruncLowOp.getValueType();
25165   if (HasFoundMULLow && (TruncLowOp.getOpcode() == AArch64ISD::DUP ||
25166                          DAG.isSplatValue(TruncLowOp, false)))
25167     return SDValue();
25168 
25169   // Create uzp1, extract_high and extract_low.
25170   if (TruncHighOpVT != UZP1VT)
25171     TruncHighOp = DAG.getNode(ISD::BITCAST, DL, UZP1VT, TruncHighOp);
25172   if (TruncLowOpVT != UZP1VT)
25173     TruncLowOp = DAG.getNode(ISD::BITCAST, DL, UZP1VT, TruncLowOp);
25174 
25175   SDValue UZP1 =
25176       DAG.getNode(AArch64ISD::UZP1, DL, UZP1VT, TruncLowOp, TruncHighOp);
25177   SDValue HighIdxCst =
25178       DAG.getConstant(TruncHighVT.getVectorNumElements(), DL, MVT::i64);
25179   SDValue NewTruncHigh =
25180       DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, TruncHighVT, UZP1, HighIdxCst);
25181   DAG.ReplaceAllUsesWith(TruncHigh, NewTruncHigh);
25182 
25183   if (HasFoundMULLow) {
25184     EVT TruncLowVT = TruncLow.getValueType();
25185     SDValue NewTruncLow = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, TruncLowVT,
25186                                       UZP1, ExtractLow.getOperand(1));
25187     DAG.ReplaceAllUsesWith(TruncLow, NewTruncLow);
25188   }
25189 
25190   return SDValue(N, 0);
25191 }
25192 
25193 static SDValue performMULLCombine(SDNode *N,
25194                                   TargetLowering::DAGCombinerInfo &DCI,
25195                                   SelectionDAG &DAG) {
25196   if (SDValue Val =
25197           tryCombineLongOpWithDup(Intrinsic::not_intrinsic, N, DCI, DAG))
25198     return Val;
25199 
25200   if (SDValue Val = tryCombineMULLWithUZP1(N, DCI, DAG))
25201     return Val;
25202 
25203   return SDValue();
25204 }
25205 
25206 static SDValue
25207 performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
25208                              SelectionDAG &DAG) {
25209   // Let's do the transform below.
25210   //
25211   //         t34: v4i32 = AArch64ISD::UADDLV t2
25212   //       t35: i32 = extract_vector_elt t34, Constant:i64<0>
25213   //     t7: i64 = zero_extend t35
25214   //   t20: v1i64 = scalar_to_vector t7
25215   // ==>
25216   //      t34: v4i32 = AArch64ISD::UADDLV t2
25217   //    t39: v2i32 = extract_subvector t34, Constant:i64<0>
25218   //  t40: v1i64 = AArch64ISD::NVCAST t39
25219   if (DCI.isBeforeLegalizeOps())
25220     return SDValue();
25221 
25222   EVT VT = N->getValueType(0);
25223   if (VT != MVT::v1i64)
25224     return SDValue();
25225 
25226   SDValue ZEXT = N->getOperand(0);
25227   if (ZEXT.getOpcode() != ISD::ZERO_EXTEND || ZEXT.getValueType() != MVT::i64)
25228     return SDValue();
25229 
25230   SDValue EXTRACT_VEC_ELT = ZEXT.getOperand(0);
25231   if (EXTRACT_VEC_ELT.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
25232       EXTRACT_VEC_ELT.getValueType() != MVT::i32)
25233     return SDValue();
25234 
25235   if (!isNullConstant(EXTRACT_VEC_ELT.getOperand(1)))
25236     return SDValue();
25237 
25238   SDValue UADDLV = EXTRACT_VEC_ELT.getOperand(0);
25239   if (UADDLV.getOpcode() != AArch64ISD::UADDLV ||
25240       UADDLV.getValueType() != MVT::v4i32 ||
25241       UADDLV.getOperand(0).getValueType() != MVT::v8i8)
25242     return SDValue();
25243 
25244   // Let's generate new sequence with AArch64ISD::NVCAST.
25245   SDLoc DL(N);
25246   SDValue EXTRACT_SUBVEC =
25247       DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v2i32, UADDLV,
25248                   DAG.getConstant(0, DL, MVT::i64));
25249   SDValue NVCAST =
25250       DAG.getNode(AArch64ISD::NVCAST, DL, MVT::v1i64, EXTRACT_SUBVEC);
25251 
25252   return NVCAST;
25253 }
25254 
25255 SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
25256                                                  DAGCombinerInfo &DCI) const {
25257   SelectionDAG &DAG = DCI.DAG;
25258   switch (N->getOpcode()) {
25259   default:
25260     LLVM_DEBUG(dbgs() << "Custom combining: skipping\n");
25261     break;
25262   case ISD::VECREDUCE_AND:
25263   case ISD::VECREDUCE_OR:
25264   case ISD::VECREDUCE_XOR:
25265     return performVecReduceBitwiseCombine(N, DCI, DAG);
25266   case ISD::ADD:
25267   case ISD::SUB:
25268     return performAddSubCombine(N, DCI);
25269   case ISD::BUILD_VECTOR:
25270     return performBuildVectorCombine(N, DCI, DAG);
25271   case ISD::TRUNCATE:
25272     return performTruncateCombine(N, DAG);
25273   case AArch64ISD::ANDS:
25274     return performFlagSettingCombine(N, DCI, ISD::AND);
25275   case AArch64ISD::ADC:
25276     if (auto R = foldOverflowCheck(N, DAG, /* IsAdd */ true))
25277       return R;
25278     return foldADCToCINC(N, DAG);
25279   case AArch64ISD::SBC:
25280     return foldOverflowCheck(N, DAG, /* IsAdd */ false);
25281   case AArch64ISD::ADCS:
25282     if (auto R = foldOverflowCheck(N, DAG, /* IsAdd */ true))
25283       return R;
25284     return performFlagSettingCombine(N, DCI, AArch64ISD::ADC);
25285   case AArch64ISD::SBCS:
25286     if (auto R = foldOverflowCheck(N, DAG, /* IsAdd */ false))
25287       return R;
25288     return performFlagSettingCombine(N, DCI, AArch64ISD::SBC);
25289   case AArch64ISD::BICi: {
25290     APInt DemandedBits =
25291         APInt::getAllOnes(N->getValueType(0).getScalarSizeInBits());
25292     APInt DemandedElts =
25293         APInt::getAllOnes(N->getValueType(0).getVectorNumElements());
25294 
25295     if (DAG.getTargetLoweringInfo().SimplifyDemandedBits(
25296             SDValue(N, 0), DemandedBits, DemandedElts, DCI))
25297       return SDValue();
25298 
25299     break;
25300   }
25301   case ISD::XOR:
25302     return performXorCombine(N, DAG, DCI, Subtarget);
25303   case ISD::MUL:
25304     return performMulCombine(N, DAG, DCI, Subtarget);
25305   case ISD::SINT_TO_FP:
25306   case ISD::UINT_TO_FP:
25307     return performIntToFpCombine(N, DAG, Subtarget);
25308   case ISD::FP_TO_SINT:
25309   case ISD::FP_TO_UINT:
25310   case ISD::FP_TO_SINT_SAT:
25311   case ISD::FP_TO_UINT_SAT:
25312     return performFpToIntCombine(N, DAG, DCI, Subtarget);
25313   case ISD::OR:
25314     return performORCombine(N, DCI, Subtarget, *this);
25315   case ISD::AND:
25316     return performANDCombine(N, DCI);
25317   case ISD::FADD:
25318     return performFADDCombine(N, DCI);
25319   case ISD::INTRINSIC_WO_CHAIN:
25320     return performIntrinsicCombine(N, DCI, Subtarget);
25321   case ISD::ANY_EXTEND:
25322   case ISD::ZERO_EXTEND:
25323   case ISD::SIGN_EXTEND:
25324     return performExtendCombine(N, DCI, DAG);
25325   case ISD::SIGN_EXTEND_INREG:
25326     return performSignExtendInRegCombine(N, DCI, DAG);
25327   case ISD::CONCAT_VECTORS:
25328     return performConcatVectorsCombine(N, DCI, DAG);
25329   case ISD::EXTRACT_SUBVECTOR:
25330     return performExtractSubvectorCombine(N, DCI, DAG);
25331   case ISD::INSERT_SUBVECTOR:
25332     return performInsertSubvectorCombine(N, DCI, DAG);
25333   case ISD::SELECT:
25334     return performSelectCombine(N, DCI);
25335   case ISD::VSELECT:
25336     return performVSelectCombine(N, DCI.DAG);
25337   case ISD::SETCC:
25338     return performSETCCCombine(N, DCI, DAG);
25339   case ISD::LOAD:
25340     return performLOADCombine(N, DCI, DAG, Subtarget);
25341   case ISD::STORE:
25342     return performSTORECombine(N, DCI, DAG, Subtarget);
25343   case ISD::MSTORE:
25344     return performMSTORECombine(N, DCI, DAG, Subtarget);
25345   case ISD::MGATHER:
25346   case ISD::MSCATTER:
25347     return performMaskedGatherScatterCombine(N, DCI, DAG);
25348   case ISD::FP_EXTEND:
25349     return performFPExtendCombine(N, DAG, DCI, Subtarget);
25350   case AArch64ISD::BRCOND:
25351     return performBRCONDCombine(N, DCI, DAG);
25352   case AArch64ISD::TBNZ:
25353   case AArch64ISD::TBZ:
25354     return performTBZCombine(N, DCI, DAG);
25355   case AArch64ISD::CSEL:
25356     return performCSELCombine(N, DCI, DAG);
25357   case AArch64ISD::DUP:
25358   case AArch64ISD::DUPLANE8:
25359   case AArch64ISD::DUPLANE16:
25360   case AArch64ISD::DUPLANE32:
25361   case AArch64ISD::DUPLANE64:
25362     return performDUPCombine(N, DCI);
25363   case AArch64ISD::DUPLANE128:
25364     return performDupLane128Combine(N, DAG);
25365   case AArch64ISD::NVCAST:
25366     return performNVCASTCombine(N, DAG);
25367   case AArch64ISD::SPLICE:
25368     return performSpliceCombine(N, DAG);
25369   case AArch64ISD::UUNPKLO:
25370   case AArch64ISD::UUNPKHI:
25371     return performUnpackCombine(N, DAG, Subtarget);
25372   case AArch64ISD::UZP1:
25373   case AArch64ISD::UZP2:
25374     return performUzpCombine(N, DAG, Subtarget);
25375   case AArch64ISD::SETCC_MERGE_ZERO:
25376     return performSetccMergeZeroCombine(N, DCI);
25377   case AArch64ISD::REINTERPRET_CAST:
25378     return performReinterpretCastCombine(N);
25379   case AArch64ISD::GLD1_MERGE_ZERO:
25380   case AArch64ISD::GLD1_SCALED_MERGE_ZERO:
25381   case AArch64ISD::GLD1_UXTW_MERGE_ZERO:
25382   case AArch64ISD::GLD1_SXTW_MERGE_ZERO:
25383   case AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO:
25384   case AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO:
25385   case AArch64ISD::GLD1_IMM_MERGE_ZERO:
25386   case AArch64ISD::GLD1S_MERGE_ZERO:
25387   case AArch64ISD::GLD1S_SCALED_MERGE_ZERO:
25388   case AArch64ISD::GLD1S_UXTW_MERGE_ZERO:
25389   case AArch64ISD::GLD1S_SXTW_MERGE_ZERO:
25390   case AArch64ISD::GLD1S_UXTW_SCALED_MERGE_ZERO:
25391   case AArch64ISD::GLD1S_SXTW_SCALED_MERGE_ZERO:
25392   case AArch64ISD::GLD1S_IMM_MERGE_ZERO:
25393     return performGLD1Combine(N, DAG);
25394   case AArch64ISD::VASHR:
25395   case AArch64ISD::VLSHR:
25396     return performVectorShiftCombine(N, *this, DCI);
25397   case AArch64ISD::SUNPKLO:
25398     return performSunpkloCombine(N, DAG);
25399   case AArch64ISD::BSP:
25400     return performBSPExpandForSVE(N, DAG, Subtarget);
25401   case ISD::INSERT_VECTOR_ELT:
25402     return performInsertVectorEltCombine(N, DCI);
25403   case ISD::EXTRACT_VECTOR_ELT:
25404     return performExtractVectorEltCombine(N, DCI, Subtarget);
25405   case ISD::VECREDUCE_ADD:
25406     return performVecReduceAddCombine(N, DCI.DAG, Subtarget);
25407   case AArch64ISD::UADDV:
25408     return performUADDVCombine(N, DAG);
25409   case AArch64ISD::SMULL:
25410   case AArch64ISD::UMULL:
25411   case AArch64ISD::PMULL:
25412     return performMULLCombine(N, DCI, DAG);
25413   case ISD::INTRINSIC_VOID:
25414   case ISD::INTRINSIC_W_CHAIN:
25415     switch (N->getConstantOperandVal(1)) {
25416     case Intrinsic::aarch64_sve_prfb_gather_scalar_offset:
25417       return combineSVEPrefetchVecBaseImmOff(N, DAG, 1 /*=ScalarSizeInBytes*/);
25418     case Intrinsic::aarch64_sve_prfh_gather_scalar_offset:
25419       return combineSVEPrefetchVecBaseImmOff(N, DAG, 2 /*=ScalarSizeInBytes*/);
25420     case Intrinsic::aarch64_sve_prfw_gather_scalar_offset:
25421       return combineSVEPrefetchVecBaseImmOff(N, DAG, 4 /*=ScalarSizeInBytes*/);
25422     case Intrinsic::aarch64_sve_prfd_gather_scalar_offset:
25423       return combineSVEPrefetchVecBaseImmOff(N, DAG, 8 /*=ScalarSizeInBytes*/);
25424     case Intrinsic::aarch64_sve_prfb_gather_uxtw_index:
25425     case Intrinsic::aarch64_sve_prfb_gather_sxtw_index:
25426     case Intrinsic::aarch64_sve_prfh_gather_uxtw_index:
25427     case Intrinsic::aarch64_sve_prfh_gather_sxtw_index:
25428     case Intrinsic::aarch64_sve_prfw_gather_uxtw_index:
25429     case Intrinsic::aarch64_sve_prfw_gather_sxtw_index:
25430     case Intrinsic::aarch64_sve_prfd_gather_uxtw_index:
25431     case Intrinsic::aarch64_sve_prfd_gather_sxtw_index:
25432       return legalizeSVEGatherPrefetchOffsVec(N, DAG);
25433     case Intrinsic::aarch64_neon_ld2:
25434     case Intrinsic::aarch64_neon_ld3:
25435     case Intrinsic::aarch64_neon_ld4:
25436     case Intrinsic::aarch64_neon_ld1x2:
25437     case Intrinsic::aarch64_neon_ld1x3:
25438     case Intrinsic::aarch64_neon_ld1x4:
25439     case Intrinsic::aarch64_neon_ld2lane:
25440     case Intrinsic::aarch64_neon_ld3lane:
25441     case Intrinsic::aarch64_neon_ld4lane:
25442     case Intrinsic::aarch64_neon_ld2r:
25443     case Intrinsic::aarch64_neon_ld3r:
25444     case Intrinsic::aarch64_neon_ld4r:
25445     case Intrinsic::aarch64_neon_st2:
25446     case Intrinsic::aarch64_neon_st3:
25447     case Intrinsic::aarch64_neon_st4:
25448     case Intrinsic::aarch64_neon_st1x2:
25449     case Intrinsic::aarch64_neon_st1x3:
25450     case Intrinsic::aarch64_neon_st1x4:
25451     case Intrinsic::aarch64_neon_st2lane:
25452     case Intrinsic::aarch64_neon_st3lane:
25453     case Intrinsic::aarch64_neon_st4lane:
25454       return performNEONPostLDSTCombine(N, DCI, DAG);
25455     case Intrinsic::aarch64_sve_ldnt1:
25456       return performLDNT1Combine(N, DAG);
25457     case Intrinsic::aarch64_sve_ld1rq:
25458       return performLD1ReplicateCombine<AArch64ISD::LD1RQ_MERGE_ZERO>(N, DAG);
25459     case Intrinsic::aarch64_sve_ld1ro:
25460       return performLD1ReplicateCombine<AArch64ISD::LD1RO_MERGE_ZERO>(N, DAG);
25461     case Intrinsic::aarch64_sve_ldnt1_gather_scalar_offset:
25462       return performGatherLoadCombine(N, DAG, AArch64ISD::GLDNT1_MERGE_ZERO);
25463     case Intrinsic::aarch64_sve_ldnt1_gather:
25464       return performGatherLoadCombine(N, DAG, AArch64ISD::GLDNT1_MERGE_ZERO);
25465     case Intrinsic::aarch64_sve_ldnt1_gather_index:
25466       return performGatherLoadCombine(N, DAG,
25467                                       AArch64ISD::GLDNT1_INDEX_MERGE_ZERO);
25468     case Intrinsic::aarch64_sve_ldnt1_gather_uxtw:
25469       return performGatherLoadCombine(N, DAG, AArch64ISD::GLDNT1_MERGE_ZERO);
25470     case Intrinsic::aarch64_sve_ld1:
25471       return performLD1Combine(N, DAG, AArch64ISD::LD1_MERGE_ZERO);
25472     case Intrinsic::aarch64_sve_ldnf1:
25473       return performLD1Combine(N, DAG, AArch64ISD::LDNF1_MERGE_ZERO);
25474     case Intrinsic::aarch64_sve_ldff1:
25475       return performLD1Combine(N, DAG, AArch64ISD::LDFF1_MERGE_ZERO);
25476     case Intrinsic::aarch64_sve_st1:
25477       return performST1Combine(N, DAG);
25478     case Intrinsic::aarch64_sve_stnt1:
25479       return performSTNT1Combine(N, DAG);
25480     case Intrinsic::aarch64_sve_stnt1_scatter_scalar_offset:
25481       return performScatterStoreCombine(N, DAG, AArch64ISD::SSTNT1_PRED);
25482     case Intrinsic::aarch64_sve_stnt1_scatter_uxtw:
25483       return performScatterStoreCombine(N, DAG, AArch64ISD::SSTNT1_PRED);
25484     case Intrinsic::aarch64_sve_stnt1_scatter:
25485       return performScatterStoreCombine(N, DAG, AArch64ISD::SSTNT1_PRED);
25486     case Intrinsic::aarch64_sve_stnt1_scatter_index:
25487       return performScatterStoreCombine(N, DAG, AArch64ISD::SSTNT1_INDEX_PRED);
25488     case Intrinsic::aarch64_sve_ld1_gather:
25489       return performGatherLoadCombine(N, DAG, AArch64ISD::GLD1_MERGE_ZERO);
25490     case Intrinsic::aarch64_sve_ld1q_gather_scalar_offset:
25491     case Intrinsic::aarch64_sve_ld1q_gather_vector_offset:
25492       return performGatherLoadCombine(N, DAG, AArch64ISD::GLD1Q_MERGE_ZERO);
25493     case Intrinsic::aarch64_sve_ld1q_gather_index:
25494       return performGatherLoadCombine(N, DAG,
25495                                       AArch64ISD::GLD1Q_INDEX_MERGE_ZERO);
25496     case Intrinsic::aarch64_sve_ld1_gather_index:
25497       return performGatherLoadCombine(N, DAG,
25498                                       AArch64ISD::GLD1_SCALED_MERGE_ZERO);
25499     case Intrinsic::aarch64_sve_ld1_gather_sxtw:
25500       return performGatherLoadCombine(N, DAG, AArch64ISD::GLD1_SXTW_MERGE_ZERO,
25501                                       /*OnlyPackedOffsets=*/false);
25502     case Intrinsic::aarch64_sve_ld1_gather_uxtw:
25503       return performGatherLoadCombine(N, DAG, AArch64ISD::GLD1_UXTW_MERGE_ZERO,
25504                                       /*OnlyPackedOffsets=*/false);
25505     case Intrinsic::aarch64_sve_ld1_gather_sxtw_index:
25506       return performGatherLoadCombine(N, DAG,
25507                                       AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO,
25508                                       /*OnlyPackedOffsets=*/false);
25509     case Intrinsic::aarch64_sve_ld1_gather_uxtw_index:
25510       return performGatherLoadCombine(N, DAG,
25511                                       AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO,
25512                                       /*OnlyPackedOffsets=*/false);
25513     case Intrinsic::aarch64_sve_ld1_gather_scalar_offset:
25514       return performGatherLoadCombine(N, DAG, AArch64ISD::GLD1_IMM_MERGE_ZERO);
25515     case Intrinsic::aarch64_sve_ldff1_gather:
25516       return performGatherLoadCombine(N, DAG, AArch64ISD::GLDFF1_MERGE_ZERO);
25517     case Intrinsic::aarch64_sve_ldff1_gather_index:
25518       return performGatherLoadCombine(N, DAG,
25519                                       AArch64ISD::GLDFF1_SCALED_MERGE_ZERO);
25520     case Intrinsic::aarch64_sve_ldff1_gather_sxtw:
25521       return performGatherLoadCombine(N, DAG,
25522                                       AArch64ISD::GLDFF1_SXTW_MERGE_ZERO,
25523                                       /*OnlyPackedOffsets=*/false);
25524     case Intrinsic::aarch64_sve_ldff1_gather_uxtw:
25525       return performGatherLoadCombine(N, DAG,
25526                                       AArch64ISD::GLDFF1_UXTW_MERGE_ZERO,
25527                                       /*OnlyPackedOffsets=*/false);
25528     case Intrinsic::aarch64_sve_ldff1_gather_sxtw_index:
25529       return performGatherLoadCombine(N, DAG,
25530                                       AArch64ISD::GLDFF1_SXTW_SCALED_MERGE_ZERO,
25531                                       /*OnlyPackedOffsets=*/false);
25532     case Intrinsic::aarch64_sve_ldff1_gather_uxtw_index:
25533       return performGatherLoadCombine(N, DAG,
25534                                       AArch64ISD::GLDFF1_UXTW_SCALED_MERGE_ZERO,
25535                                       /*OnlyPackedOffsets=*/false);
25536     case Intrinsic::aarch64_sve_ldff1_gather_scalar_offset:
25537       return performGatherLoadCombine(N, DAG,
25538                                       AArch64ISD::GLDFF1_IMM_MERGE_ZERO);
25539     case Intrinsic::aarch64_sve_st1q_scatter_scalar_offset:
25540     case Intrinsic::aarch64_sve_st1q_scatter_vector_offset:
25541       return performScatterStoreCombine(N, DAG, AArch64ISD::SST1Q_PRED);
25542     case Intrinsic::aarch64_sve_st1q_scatter_index:
25543       return performScatterStoreCombine(N, DAG, AArch64ISD::SST1Q_INDEX_PRED);
25544     case Intrinsic::aarch64_sve_st1_scatter:
25545       return performScatterStoreCombine(N, DAG, AArch64ISD::SST1_PRED);
25546     case Intrinsic::aarch64_sve_st1_scatter_index:
25547       return performScatterStoreCombine(N, DAG, AArch64ISD::SST1_SCALED_PRED);
25548     case Intrinsic::aarch64_sve_st1_scatter_sxtw:
25549       return performScatterStoreCombine(N, DAG, AArch64ISD::SST1_SXTW_PRED,
25550                                         /*OnlyPackedOffsets=*/false);
25551     case Intrinsic::aarch64_sve_st1_scatter_uxtw:
25552       return performScatterStoreCombine(N, DAG, AArch64ISD::SST1_UXTW_PRED,
25553                                         /*OnlyPackedOffsets=*/false);
25554     case Intrinsic::aarch64_sve_st1_scatter_sxtw_index:
25555       return performScatterStoreCombine(N, DAG,
25556                                         AArch64ISD::SST1_SXTW_SCALED_PRED,
25557                                         /*OnlyPackedOffsets=*/false);
25558     case Intrinsic::aarch64_sve_st1_scatter_uxtw_index:
25559       return performScatterStoreCombine(N, DAG,
25560                                         AArch64ISD::SST1_UXTW_SCALED_PRED,
25561                                         /*OnlyPackedOffsets=*/false);
25562     case Intrinsic::aarch64_sve_st1_scatter_scalar_offset:
25563       return performScatterStoreCombine(N, DAG, AArch64ISD::SST1_IMM_PRED);
25564     case Intrinsic::aarch64_rndr:
25565     case Intrinsic::aarch64_rndrrs: {
25566       unsigned IntrinsicID = N->getConstantOperandVal(1);
25567       auto Register =
25568           (IntrinsicID == Intrinsic::aarch64_rndr ? AArch64SysReg::RNDR
25569                                                   : AArch64SysReg::RNDRRS);
25570       SDLoc DL(N);
25571       SDValue A = DAG.getNode(
25572           AArch64ISD::MRS, DL, DAG.getVTList(MVT::i64, MVT::Glue, MVT::Other),
25573           N->getOperand(0), DAG.getConstant(Register, DL, MVT::i64));
25574       SDValue B = DAG.getNode(
25575           AArch64ISD::CSINC, DL, MVT::i32, DAG.getConstant(0, DL, MVT::i32),
25576           DAG.getConstant(0, DL, MVT::i32),
25577           DAG.getConstant(AArch64CC::NE, DL, MVT::i32), A.getValue(1));
25578       return DAG.getMergeValues(
25579           {A, DAG.getZExtOrTrunc(B, DL, MVT::i1), A.getValue(2)}, DL);
25580     }
25581     case Intrinsic::aarch64_sme_ldr_zt:
25582       return DAG.getNode(AArch64ISD::RESTORE_ZT, SDLoc(N),
25583                          DAG.getVTList(MVT::Other), N->getOperand(0),
25584                          N->getOperand(2), N->getOperand(3));
25585     case Intrinsic::aarch64_sme_str_zt:
25586       return DAG.getNode(AArch64ISD::SAVE_ZT, SDLoc(N),
25587                          DAG.getVTList(MVT::Other), N->getOperand(0),
25588                          N->getOperand(2), N->getOperand(3));
25589     default:
25590       break;
25591     }
25592     break;
25593   case ISD::GlobalAddress:
25594     return performGlobalAddressCombine(N, DAG, Subtarget, getTargetMachine());
25595   case ISD::CTLZ:
25596     return performCTLZCombine(N, DAG, Subtarget);
25597   case ISD::SCALAR_TO_VECTOR:
25598     return performScalarToVectorCombine(N, DCI, DAG);
25599   }
25600   return SDValue();
25601 }
25602 
25603 // Check if the return value is used as only a return value, as otherwise
25604 // we can't perform a tail-call. In particular, we need to check for
25605 // target ISD nodes that are returns and any other "odd" constructs
25606 // that the generic analysis code won't necessarily catch.
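// For example, a value that is only copied into the return register (CopyToReg)
// and then consumed solely by AArch64ISD::RET_GLUE nodes qualifies, whereas a
// copy carrying a glue operand is conservatively rejected below.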
25607 bool AArch64TargetLowering::isUsedByReturnOnly(SDNode *N,
25608                                                SDValue &Chain) const {
25609   if (N->getNumValues() != 1)
25610     return false;
25611   if (!N->hasNUsesOfValue(1, 0))
25612     return false;
25613 
25614   SDValue TCChain = Chain;
25615   SDNode *Copy = *N->use_begin();
25616   if (Copy->getOpcode() == ISD::CopyToReg) {
25617     // If the copy has a glue operand, we conservatively assume it isn't safe to
25618     // perform a tail call.
25619     if (Copy->getOperand(Copy->getNumOperands() - 1).getValueType() ==
25620         MVT::Glue)
25621       return false;
25622     TCChain = Copy->getOperand(0);
25623   } else if (Copy->getOpcode() != ISD::FP_EXTEND)
25624     return false;
25625 
25626   bool HasRet = false;
25627   for (SDNode *Node : Copy->uses()) {
25628     if (Node->getOpcode() != AArch64ISD::RET_GLUE)
25629       return false;
25630     HasRet = true;
25631   }
25632 
25633   if (!HasRet)
25634     return false;
25635 
25636   Chain = TCChain;
25637   return true;
25638 }
25639 
25640 // Return whether an instruction can potentially be optimized to a tail
25641 // call. This will cause the optimizers to attempt to move, or duplicate,
25642 // return instructions to help enable tail call optimizations for this
25643 // instruction.
25644 bool AArch64TargetLowering::mayBeEmittedAsTailCall(const CallInst *CI) const {
25645   return CI->isTailCall();
25646 }
25647 
25648 bool AArch64TargetLowering::isIndexingLegal(MachineInstr &MI, Register Base,
25649                                             Register Offset, bool IsPre,
25650                                             MachineRegisterInfo &MRI) const {
25651   auto CstOffset = getIConstantVRegVal(Offset, MRI);
25652   if (!CstOffset || CstOffset->isZero())
25653     return false;
25654 
25655   // All of the indexed addressing mode instructions take a signed 9 bit
25656   // immediate offset. Our CstOffset is a G_PTR_ADD offset so it already
25657   // encodes the sign/indexing direction.
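  // For example, an offset of 255 is accepted, while 256 or -257 is not.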
25658   return isInt<9>(CstOffset->getSExtValue());
25659 }
25660 
25661 bool AArch64TargetLowering::getIndexedAddressParts(SDNode *N, SDNode *Op,
25662                                                    SDValue &Base,
25663                                                    SDValue &Offset,
25664                                                    SelectionDAG &DAG) const {
25665   if (Op->getOpcode() != ISD::ADD && Op->getOpcode() != ISD::SUB)
25666     return false;
25667 
25668   // Non-null if there is exactly one user of the loaded value (ignoring chain).
25669   SDNode *ValOnlyUser = nullptr;
25670   for (SDNode::use_iterator UI = N->use_begin(), UE = N->use_end(); UI != UE;
25671        ++UI) {
25672     if (UI.getUse().getResNo() == 1)
25673       continue; // Ignore chain.
25674     if (ValOnlyUser == nullptr)
25675       ValOnlyUser = *UI;
25676     else {
25677       ValOnlyUser = nullptr; // Multiple non-chain uses, bail out.
25678       break;
25679     }
25680   }
25681 
25682   auto IsUndefOrZero = [](SDValue V) {
25683     return V.isUndef() || isNullOrNullSplat(V, /*AllowUndefs*/ true);
25684   };
25685 
25686   // If the only user of the value is a scalable vector splat, it is
25687   // preferable to do a replicating load (ld1r*).
25688   if (ValOnlyUser && ValOnlyUser->getValueType(0).isScalableVector() &&
25689       (ValOnlyUser->getOpcode() == ISD::SPLAT_VECTOR ||
25690        (ValOnlyUser->getOpcode() == AArch64ISD::DUP_MERGE_PASSTHRU &&
25691         IsUndefOrZero(ValOnlyUser->getOperand(2)))))
25692     return false;
25693 
25694   Base = Op->getOperand(0);
25695   // All of the indexed addressing mode instructions take a signed
25696   // 9 bit immediate offset.
25697   if (ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(Op->getOperand(1))) {
25698     int64_t RHSC = RHS->getSExtValue();
25699     if (Op->getOpcode() == ISD::SUB)
25700       RHSC = -(uint64_t)RHSC;
25701     if (!isInt<9>(RHSC))
25702       return false;
25703     // Always emit pre-inc/post-inc addressing mode. Use negated constant offset
25704     // when dealing with subtraction.
25705     Offset = DAG.getConstant(RHSC, SDLoc(N), RHS->getValueType(0));
25706     return true;
25707   }
25708   return false;
25709 }
25710 
25711 bool AArch64TargetLowering::getPreIndexedAddressParts(SDNode *N, SDValue &Base,
25712                                                       SDValue &Offset,
25713                                                       ISD::MemIndexedMode &AM,
25714                                                       SelectionDAG &DAG) const {
25715   EVT VT;
25716   SDValue Ptr;
25717   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N)) {
25718     VT = LD->getMemoryVT();
25719     Ptr = LD->getBasePtr();
25720   } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(N)) {
25721     VT = ST->getMemoryVT();
25722     Ptr = ST->getBasePtr();
25723   } else
25724     return false;
25725 
25726   if (!getIndexedAddressParts(N, Ptr.getNode(), Base, Offset, DAG))
25727     return false;
25728   AM = ISD::PRE_INC;
25729   return true;
25730 }
25731 
25732 bool AArch64TargetLowering::getPostIndexedAddressParts(
25733     SDNode *N, SDNode *Op, SDValue &Base, SDValue &Offset,
25734     ISD::MemIndexedMode &AM, SelectionDAG &DAG) const {
25735   EVT VT;
25736   SDValue Ptr;
25737   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N)) {
25738     VT = LD->getMemoryVT();
25739     Ptr = LD->getBasePtr();
25740   } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(N)) {
25741     VT = ST->getMemoryVT();
25742     Ptr = ST->getBasePtr();
25743   } else
25744     return false;
25745 
25746   if (!getIndexedAddressParts(N, Op, Base, Offset, DAG))
25747     return false;
25748   // Post-indexing updates the base, so it's not a valid transform
25749   // if that's not the same as the load's pointer.
25750   if (Ptr != Base)
25751     return false;
25752   AM = ISD::POST_INC;
25753   return true;
25754 }
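// For illustration (register names arbitrary), the pre- and post-indexed forms
// ultimately selected from these hooks look like:
//   ldr x0, [x1, #16]!   // pre-index:  x1 is updated to x1+16, then loaded from
//   ldr x0, [x1], #16    // post-index: the load uses x1, then x1 becomes x1+16
// Both encodings accept only a signed 9-bit byte offset, which is what the
// isInt<9> checks above enforce.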
25755 
25756 static void replaceBoolVectorBitcast(SDNode *N,
25757                                      SmallVectorImpl<SDValue> &Results,
25758                                      SelectionDAG &DAG) {
25759   SDLoc DL(N);
25760   SDValue Op = N->getOperand(0);
25761   EVT VT = N->getValueType(0);
25762   [[maybe_unused]] EVT SrcVT = Op.getValueType();
25763   assert(SrcVT.isVector() && SrcVT.getVectorElementType() == MVT::i1 &&
25764          "Must be bool vector.");
25765 
25766   // Special handling for Clang's __builtin_convertvector. For vectors with <8
25767   // elements, it adds a vector concatenation with undef(s). If we encounter
25768   // this here, we can skip the concat.
25769   if (Op.getOpcode() == ISD::CONCAT_VECTORS && !Op.getOperand(0).isUndef()) {
25770     bool AllUndef = true;
25771     for (unsigned I = 1; I < Op.getNumOperands(); ++I)
25772       AllUndef &= Op.getOperand(I).isUndef();
25773 
25774     if (AllUndef)
25775       Op = Op.getOperand(0);
25776   }
25777 
25778   SDValue VectorBits = vectorToScalarBitmask(Op.getNode(), DAG);
25779   if (VectorBits)
25780     Results.push_back(DAG.getZExtOrTrunc(VectorBits, DL, VT));
25781 }
25782 
25783 static void CustomNonLegalBITCASTResults(SDNode *N,
25784                                          SmallVectorImpl<SDValue> &Results,
25785                                          SelectionDAG &DAG, EVT ExtendVT,
25786                                          EVT CastVT) {
25787   SDLoc DL(N);
25788   SDValue Op = N->getOperand(0);
25789   EVT VT = N->getValueType(0);
25790 
25791   // Use SCALAR_TO_VECTOR for lane zero
25792   SDValue Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, ExtendVT, Op);
25793   SDValue CastVal = DAG.getNode(ISD::BITCAST, DL, CastVT, Vec);
25794   SDValue IdxZero = DAG.getVectorIdxConstant(0, DL);
25795   Results.push_back(
25796       DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, CastVal, IdxZero));
25797 }
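// Sketch of the (i32 -> v2i16) case handled in ReplaceBITCASTResults below:
//   v2i32 vec  = scalar_to_vector(i32 op)    // op occupies lane 0
//   v4i16 cast = bitcast(vec)
//   v2i16 res  = extract_subvector(cast, 0)  // low half carries the bits of op
// i.e. the illegal small-vector bitcast is performed inside a legal 64-bit
// vector register and the interesting lanes are extracted afterwards.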
25798 
25799 void AArch64TargetLowering::ReplaceBITCASTResults(
25800     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
25801   SDLoc DL(N);
25802   SDValue Op = N->getOperand(0);
25803   EVT VT = N->getValueType(0);
25804   EVT SrcVT = Op.getValueType();
25805 
25806   if (VT == MVT::v2i16 && SrcVT == MVT::i32) {
25807     CustomNonLegalBITCASTResults(N, Results, DAG, MVT::v2i32, MVT::v4i16);
25808     return;
25809   }
25810 
25811   if (VT == MVT::v4i8 && SrcVT == MVT::i32) {
25812     CustomNonLegalBITCASTResults(N, Results, DAG, MVT::v2i32, MVT::v8i8);
25813     return;
25814   }
25815 
25816   if (VT == MVT::v2i8 && SrcVT == MVT::i16) {
25817     CustomNonLegalBITCASTResults(N, Results, DAG, MVT::v4i16, MVT::v8i8);
25818     return;
25819   }
25820 
25821   if (VT.isScalableVector() && !isTypeLegal(VT) && isTypeLegal(SrcVT)) {
25822     assert(!VT.isFloatingPoint() && SrcVT.isFloatingPoint() &&
25823            "Expected fp->int bitcast!");
25824 
25825     // Bitcasting between unpacked vector types of different element counts is
25826     // not a NOP because the live elements are laid out differently.
25827     //                01234567
25828     // e.g. nxv2i32 = XX??XX??
25829     //      nxv4f16 = X?X?X?X?
25830     if (VT.getVectorElementCount() != SrcVT.getVectorElementCount())
25831       return;
25832 
25833     SDValue CastResult = getSVESafeBitCast(getSVEContainerType(VT), Op, DAG);
25834     Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, CastResult));
25835     return;
25836   }
25837 
25838   if (SrcVT.isVector() && SrcVT.getVectorElementType() == MVT::i1 &&
25839       !VT.isVector())
25840     return replaceBoolVectorBitcast(N, Results, DAG);
25841 
25842   if (VT != MVT::i16 || (SrcVT != MVT::f16 && SrcVT != MVT::bf16))
25843     return;
25844 
25845   Op = DAG.getTargetInsertSubreg(AArch64::hsub, DL, MVT::f32,
25846                                  DAG.getUNDEF(MVT::i32), Op);
25847   Op = DAG.getNode(ISD::BITCAST, DL, MVT::i32, Op);
25848   Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Op));
25849 }
25850 
25851 static void ReplaceAddWithADDP(SDNode *N, SmallVectorImpl<SDValue> &Results,
25852                                SelectionDAG &DAG,
25853                                const AArch64Subtarget *Subtarget) {
25854   EVT VT = N->getValueType(0);
25855   if (!VT.is256BitVector() ||
25856       (VT.getScalarType().isFloatingPoint() &&
25857        !N->getFlags().hasAllowReassociation()) ||
25858       (VT.getScalarType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
25859       VT.getScalarType() == MVT::bf16)
25860     return;
25861 
25862   SDValue X = N->getOperand(0);
25863   auto *Shuf = dyn_cast<ShuffleVectorSDNode>(N->getOperand(1));
25864   if (!Shuf) {
25865     Shuf = dyn_cast<ShuffleVectorSDNode>(N->getOperand(0));
25866     X = N->getOperand(1);
25867     if (!Shuf)
25868       return;
25869   }
25870 
25871   if (Shuf->getOperand(0) != X || !Shuf->getOperand(1)->isUndef())
25872     return;
25873 
25874   // Check the mask is 1,0,3,2,5,4,...
25875   ArrayRef<int> Mask = Shuf->getMask();
25876   for (int I = 0, E = Mask.size(); I < E; I++)
25877     if (Mask[I] != (I % 2 == 0 ? I + 1 : I - 1))
25878       return;
25879 
25880   SDLoc DL(N);
25881   auto LoHi = DAG.SplitVector(X, DL);
25882   assert(LoHi.first.getValueType() == LoHi.second.getValueType());
25883   SDValue Addp = DAG.getNode(AArch64ISD::ADDP, N, LoHi.first.getValueType(),
25884                              LoHi.first, LoHi.second);
25885 
25886   // Shuffle the elements back into order.
25887   SmallVector<int> NMask;
25888   for (unsigned I = 0, E = VT.getVectorNumElements() / 2; I < E; I++) {
25889     NMask.push_back(I);
25890     NMask.push_back(I);
25891   }
25892   Results.push_back(
25893       DAG.getVectorShuffle(VT, DL,
25894                            DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Addp,
25895                                        DAG.getUNDEF(LoHi.first.getValueType())),
25896                            DAG.getUNDEF(VT), NMask));
25897 }
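// Worked example (types illustrative): for a v8i32 'add X, shuffle(X, undef,
// <1,0,3,2,5,4,7,6>)', every pair of result lanes holds the same pairwise sum.
// Splitting X into two v4i32 halves and emitting one ADDP yields those four
// sums, and the final shuffle mask <0,0,1,1,2,2,3,3> duplicates each sum back
// into both lanes of its pair.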
25898 
25899 static void ReplaceReductionResults(SDNode *N,
25900                                     SmallVectorImpl<SDValue> &Results,
25901                                     SelectionDAG &DAG, unsigned InterOp,
25902                                     unsigned AcrossOp) {
25903   EVT LoVT, HiVT;
25904   SDValue Lo, Hi;
25905   SDLoc dl(N);
25906   std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(N->getValueType(0));
25907   std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0);
25908   SDValue InterVal = DAG.getNode(InterOp, dl, LoVT, Lo, Hi);
25909   SDValue SplitVal = DAG.getNode(AcrossOp, dl, LoVT, InterVal);
25910   Results.push_back(SplitVal);
25911 }
25912 
25913 void AArch64TargetLowering::ReplaceExtractSubVectorResults(
25914     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
25915   SDValue In = N->getOperand(0);
25916   EVT InVT = In.getValueType();
25917 
25918   // Common code will handle these just fine.
25919   if (!InVT.isScalableVector() || !InVT.isInteger())
25920     return;
25921 
25922   SDLoc DL(N);
25923   EVT VT = N->getValueType(0);
25924 
25925   // The following checks bail if this is not a halving operation.
25926 
25927   ElementCount ResEC = VT.getVectorElementCount();
25928 
25929   if (InVT.getVectorElementCount() != (ResEC * 2))
25930     return;
25931 
25932   auto *CIndex = dyn_cast<ConstantSDNode>(N->getOperand(1));
25933   if (!CIndex)
25934     return;
25935 
25936   unsigned Index = CIndex->getZExtValue();
25937   if ((Index != 0) && (Index != ResEC.getKnownMinValue()))
25938     return;
25939 
25940   unsigned Opcode = (Index == 0) ? AArch64ISD::UUNPKLO : AArch64ISD::UUNPKHI;
25941   EVT ExtendedHalfVT = VT.widenIntegerVectorElementType(*DAG.getContext());
25942 
25943   SDValue Half = DAG.getNode(Opcode, DL, ExtendedHalfVT, N->getOperand(0));
25944   Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, Half));
25945 }
25946 
25947 // Create an even/odd pair of X registers holding integer value V.
25948 static SDValue createGPRPairNode(SelectionDAG &DAG, SDValue V) {
25949   SDLoc dl(V.getNode());
25950   auto [VLo, VHi] = DAG.SplitScalar(V, dl, MVT::i64, MVT::i64);
25951   if (DAG.getDataLayout().isBigEndian())
25952     std::swap (VLo, VHi);
25953   SDValue RegClass =
25954       DAG.getTargetConstant(AArch64::XSeqPairsClassRegClassID, dl, MVT::i32);
25955   SDValue SubReg0 = DAG.getTargetConstant(AArch64::sube64, dl, MVT::i32);
25956   SDValue SubReg1 = DAG.getTargetConstant(AArch64::subo64, dl, MVT::i32);
25957   const SDValue Ops[] = { RegClass, VLo, SubReg0, VHi, SubReg1 };
25958   return SDValue(
25959       DAG.getMachineNode(TargetOpcode::REG_SEQUENCE, dl, MVT::Untyped, Ops), 0);
25960 }
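// Roughly, an i128 value becomes a single untyped even/odd X register pair:
//   %pair = REG_SEQUENCE(%lo -> sube64, %hi -> subo64)   (pseudo-notation)
// CASP requires its compare and store operands in such sequential register
// pairs, which is why REG_SEQUENCE/EXTRACT_SUBREG are used here instead of
// BUILD_PAIR.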
25961 
25962 static void ReplaceCMP_SWAP_128Results(SDNode *N,
25963                                        SmallVectorImpl<SDValue> &Results,
25964                                        SelectionDAG &DAG,
25965                                        const AArch64Subtarget *Subtarget) {
25966   assert(N->getValueType(0) == MVT::i128 &&
25967          "AtomicCmpSwap on types less than 128 should be legal");
25968 
25969   MachineMemOperand *MemOp = cast<MemSDNode>(N)->getMemOperand();
25970   if (Subtarget->hasLSE() || Subtarget->outlineAtomics()) {
25971     // LSE has a 128-bit compare and swap (CASP), but i128 is not a legal type,
25972     // so lower it here, wrapped in REG_SEQUENCE and EXTRACT_SUBREG.
25973     SDValue Ops[] = {
25974         createGPRPairNode(DAG, N->getOperand(2)), // Compare value
25975         createGPRPairNode(DAG, N->getOperand(3)), // Store value
25976         N->getOperand(1), // Ptr
25977         N->getOperand(0), // Chain in
25978     };
25979 
25980     unsigned Opcode;
25981     switch (MemOp->getMergedOrdering()) {
25982     case AtomicOrdering::Monotonic:
25983       Opcode = AArch64::CASPX;
25984       break;
25985     case AtomicOrdering::Acquire:
25986       Opcode = AArch64::CASPAX;
25987       break;
25988     case AtomicOrdering::Release:
25989       Opcode = AArch64::CASPLX;
25990       break;
25991     case AtomicOrdering::AcquireRelease:
25992     case AtomicOrdering::SequentiallyConsistent:
25993       Opcode = AArch64::CASPALX;
25994       break;
25995     default:
25996       llvm_unreachable("Unexpected ordering!");
25997     }
25998 
25999     MachineSDNode *CmpSwap = DAG.getMachineNode(
26000         Opcode, SDLoc(N), DAG.getVTList(MVT::Untyped, MVT::Other), Ops);
26001     DAG.setNodeMemRefs(CmpSwap, {MemOp});
26002 
26003     unsigned SubReg1 = AArch64::sube64, SubReg2 = AArch64::subo64;
26004     if (DAG.getDataLayout().isBigEndian())
26005       std::swap(SubReg1, SubReg2);
26006     SDValue Lo = DAG.getTargetExtractSubreg(SubReg1, SDLoc(N), MVT::i64,
26007                                             SDValue(CmpSwap, 0));
26008     SDValue Hi = DAG.getTargetExtractSubreg(SubReg2, SDLoc(N), MVT::i64,
26009                                             SDValue(CmpSwap, 0));
26010     Results.push_back(
26011         DAG.getNode(ISD::BUILD_PAIR, SDLoc(N), MVT::i128, Lo, Hi));
26012     Results.push_back(SDValue(CmpSwap, 1)); // Chain out
26013     return;
26014   }
26015 
26016   unsigned Opcode;
26017   switch (MemOp->getMergedOrdering()) {
26018   case AtomicOrdering::Monotonic:
26019     Opcode = AArch64::CMP_SWAP_128_MONOTONIC;
26020     break;
26021   case AtomicOrdering::Acquire:
26022     Opcode = AArch64::CMP_SWAP_128_ACQUIRE;
26023     break;
26024   case AtomicOrdering::Release:
26025     Opcode = AArch64::CMP_SWAP_128_RELEASE;
26026     break;
26027   case AtomicOrdering::AcquireRelease:
26028   case AtomicOrdering::SequentiallyConsistent:
26029     Opcode = AArch64::CMP_SWAP_128;
26030     break;
26031   default:
26032     llvm_unreachable("Unexpected ordering!");
26033   }
26034 
26035   SDLoc DL(N);
26036   auto Desired = DAG.SplitScalar(N->getOperand(2), DL, MVT::i64, MVT::i64);
26037   auto New = DAG.SplitScalar(N->getOperand(3), DL, MVT::i64, MVT::i64);
26038   SDValue Ops[] = {N->getOperand(1), Desired.first, Desired.second,
26039                    New.first,        New.second,    N->getOperand(0)};
26040   SDNode *CmpSwap = DAG.getMachineNode(
26041       Opcode, SDLoc(N), DAG.getVTList(MVT::i64, MVT::i64, MVT::i32, MVT::Other),
26042       Ops);
26043   DAG.setNodeMemRefs(cast<MachineSDNode>(CmpSwap), {MemOp});
26044 
26045   Results.push_back(DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i128,
26046                                 SDValue(CmpSwap, 0), SDValue(CmpSwap, 1)));
26047   Results.push_back(SDValue(CmpSwap, 3));
26048 }
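// For example, a seq_cst i128 cmpxchg on an LSE target becomes a CASPALX whose
// compare and new-value operands are the XSeqPairs built by createGPRPairNode,
// and the i128 result is reassembled from the sube64/subo64 halves via
// BUILD_PAIR. Without LSE, the CMP_SWAP_128* pseudos emitted here are expanded
// into an LDXP/STXP loop later by the pseudo-expansion pass.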
26049 
26050 static unsigned getAtomicLoad128Opcode(unsigned ISDOpcode,
26051                                        AtomicOrdering Ordering) {
26052   // ATOMIC_LOAD_CLR only appears when lowering ATOMIC_LOAD_AND (see
26053   // LowerATOMIC_LOAD_AND). We can't take that approach with 128-bit, because
26054   // the type is not legal. Therefore we shouldn't expect to see a 128-bit
26055   // ATOMIC_LOAD_CLR at any point.
26056   assert(ISDOpcode != ISD::ATOMIC_LOAD_CLR &&
26057          "ATOMIC_LOAD_AND should be lowered to LDCLRP directly");
26058   assert(ISDOpcode != ISD::ATOMIC_LOAD_ADD && "There is no 128 bit LDADD");
26059   assert(ISDOpcode != ISD::ATOMIC_LOAD_SUB && "There is no 128 bit LDSUB");
26060 
26061   if (ISDOpcode == ISD::ATOMIC_LOAD_AND) {
26062     // The operand will need to be XORed in a separate step.
26063     switch (Ordering) {
26064     case AtomicOrdering::Monotonic:
26065       return AArch64::LDCLRP;
26066       break;
26067     case AtomicOrdering::Acquire:
26068       return AArch64::LDCLRPA;
26069       break;
26070     case AtomicOrdering::Release:
26071       return AArch64::LDCLRPL;
26072       break;
26073     case AtomicOrdering::AcquireRelease:
26074     case AtomicOrdering::SequentiallyConsistent:
26075       return AArch64::LDCLRPAL;
26076       break;
26077     default:
26078       llvm_unreachable("Unexpected ordering!");
26079     }
26080   }
26081 
26082   if (ISDOpcode == ISD::ATOMIC_LOAD_OR) {
26083     switch (Ordering) {
26084     case AtomicOrdering::Monotonic:
26085       return AArch64::LDSETP;
26086       break;
26087     case AtomicOrdering::Acquire:
26088       return AArch64::LDSETPA;
26089       break;
26090     case AtomicOrdering::Release:
26091       return AArch64::LDSETPL;
26092       break;
26093     case AtomicOrdering::AcquireRelease:
26094     case AtomicOrdering::SequentiallyConsistent:
26095       return AArch64::LDSETPAL;
26096       break;
26097     default:
26098       llvm_unreachable("Unexpected ordering!");
26099     }
26100   }
26101 
26102   if (ISDOpcode == ISD::ATOMIC_SWAP) {
26103     switch (Ordering) {
26104     case AtomicOrdering::Monotonic:
26105       return AArch64::SWPP;
26106       break;
26107     case AtomicOrdering::Acquire:
26108       return AArch64::SWPPA;
26109       break;
26110     case AtomicOrdering::Release:
26111       return AArch64::SWPPL;
26112       break;
26113     case AtomicOrdering::AcquireRelease:
26114     case AtomicOrdering::SequentiallyConsistent:
26115       return AArch64::SWPPAL;
26116       break;
26117     default:
26118       llvm_unreachable("Unexpected ordering!");
26119     }
26120   }
26121 
26122   llvm_unreachable("Unexpected ISDOpcode!");
26123 }
26124 
26125 static void ReplaceATOMIC_LOAD_128Results(SDNode *N,
26126                                           SmallVectorImpl<SDValue> &Results,
26127                                           SelectionDAG &DAG,
26128                                           const AArch64Subtarget *Subtarget) {
26129   // LSE128 has 128-bit RMW ops, but i128 is not a legal type, so lower it
26130   // here. This follows the approach of the CMP_SWAP_XXX pseudo instructions
26131   // rather than the CASP instructions, because CASP has register classes for
26132   // the pairs of registers and therefore uses REG_SEQUENCE and EXTRACT_SUBREG
26133   // to present them as single operands. LSE128 instructions use the GPR64
26134   // register class (because the pair does not have to be sequential), like
26135   // CMP_SWAP_XXX, and therefore we use TRUNCATE and BUILD_PAIR.
26136 
26137   assert(N->getValueType(0) == MVT::i128 &&
26138          "AtomicLoadXXX on types less than 128 should be legal");
26139 
26140   if (!Subtarget->hasLSE128())
26141     return;
26142 
26143   MachineMemOperand *MemOp = cast<MemSDNode>(N)->getMemOperand();
26144   const SDValue &Chain = N->getOperand(0);
26145   const SDValue &Ptr = N->getOperand(1);
26146   const SDValue &Val128 = N->getOperand(2);
26147   std::pair<SDValue, SDValue> Val2x64 =
26148       DAG.SplitScalar(Val128, SDLoc(Val128), MVT::i64, MVT::i64);
26149 
26150   const unsigned ISDOpcode = N->getOpcode();
26151   const unsigned MachineOpcode =
26152       getAtomicLoad128Opcode(ISDOpcode, MemOp->getMergedOrdering());
26153 
26154   if (ISDOpcode == ISD::ATOMIC_LOAD_AND) {
26155     SDLoc dl(Val128);
26156     Val2x64.first =
26157         DAG.getNode(ISD::XOR, dl, MVT::i64,
26158                     DAG.getConstant(-1ULL, dl, MVT::i64), Val2x64.first);
26159     Val2x64.second =
26160         DAG.getNode(ISD::XOR, dl, MVT::i64,
26161                     DAG.getConstant(-1ULL, dl, MVT::i64), Val2x64.second);
26162   }
26163 
26164   SDValue Ops[] = {Val2x64.first, Val2x64.second, Ptr, Chain};
26165   if (DAG.getDataLayout().isBigEndian())
26166     std::swap(Ops[0], Ops[1]);
26167 
26168   MachineSDNode *AtomicInst =
26169       DAG.getMachineNode(MachineOpcode, SDLoc(N),
26170                          DAG.getVTList(MVT::i64, MVT::i64, MVT::Other), Ops);
26171 
26172   DAG.setNodeMemRefs(AtomicInst, {MemOp});
26173 
26174   SDValue Lo = SDValue(AtomicInst, 0), Hi = SDValue(AtomicInst, 1);
26175   if (DAG.getDataLayout().isBigEndian())
26176     std::swap(Lo, Hi);
26177 
26178   Results.push_back(DAG.getNode(ISD::BUILD_PAIR, SDLoc(N), MVT::i128, Lo, Hi));
26179   Results.push_back(SDValue(AtomicInst, 2)); // Chain out
26180 }
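// For example, 'atomicrmw and ptr %p, i128 %m seq_cst' (IR names illustrative)
// reaches this point with both i64 halves of %m inverted and is emitted as
// LDCLRPAL, since LDCLRP atomically clears the bits that are set in its
// operand; the two i64 results are then recombined into an i128 via BUILD_PAIR.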
26181 
26182 void AArch64TargetLowering::ReplaceNodeResults(
26183     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
26184   switch (N->getOpcode()) {
26185   default:
26186     llvm_unreachable("Don't know how to custom expand this");
26187   case ISD::BITCAST:
26188     ReplaceBITCASTResults(N, Results, DAG);
26189     return;
26190   case ISD::VECREDUCE_ADD:
26191   case ISD::VECREDUCE_SMAX:
26192   case ISD::VECREDUCE_SMIN:
26193   case ISD::VECREDUCE_UMAX:
26194   case ISD::VECREDUCE_UMIN:
26195     Results.push_back(LowerVECREDUCE(SDValue(N, 0), DAG));
26196     return;
26197   case ISD::ADD:
26198   case ISD::FADD:
26199     ReplaceAddWithADDP(N, Results, DAG, Subtarget);
26200     return;
26201 
26202   case ISD::CTPOP:
26203   case ISD::PARITY:
26204     if (SDValue Result = LowerCTPOP_PARITY(SDValue(N, 0), DAG))
26205       Results.push_back(Result);
26206     return;
26207   case AArch64ISD::SADDV:
26208     ReplaceReductionResults(N, Results, DAG, ISD::ADD, AArch64ISD::SADDV);
26209     return;
26210   case AArch64ISD::UADDV:
26211     ReplaceReductionResults(N, Results, DAG, ISD::ADD, AArch64ISD::UADDV);
26212     return;
26213   case AArch64ISD::SMINV:
26214     ReplaceReductionResults(N, Results, DAG, ISD::SMIN, AArch64ISD::SMINV);
26215     return;
26216   case AArch64ISD::UMINV:
26217     ReplaceReductionResults(N, Results, DAG, ISD::UMIN, AArch64ISD::UMINV);
26218     return;
26219   case AArch64ISD::SMAXV:
26220     ReplaceReductionResults(N, Results, DAG, ISD::SMAX, AArch64ISD::SMAXV);
26221     return;
26222   case AArch64ISD::UMAXV:
26223     ReplaceReductionResults(N, Results, DAG, ISD::UMAX, AArch64ISD::UMAXV);
26224     return;
26225   case ISD::MULHS:
26226     if (useSVEForFixedLengthVectorVT(SDValue(N, 0).getValueType()))
26227       Results.push_back(
26228           LowerToPredicatedOp(SDValue(N, 0), DAG, AArch64ISD::MULHS_PRED));
26229     return;
26230   case ISD::MULHU:
26231     if (useSVEForFixedLengthVectorVT(SDValue(N, 0).getValueType()))
26232       Results.push_back(
26233           LowerToPredicatedOp(SDValue(N, 0), DAG, AArch64ISD::MULHU_PRED));
26234     return;
26235   case ISD::FP_TO_UINT:
26236   case ISD::FP_TO_SINT:
26237   case ISD::STRICT_FP_TO_SINT:
26238   case ISD::STRICT_FP_TO_UINT:
26239     assert(N->getValueType(0) == MVT::i128 && "unexpected illegal conversion");
26240     // Let normal code take care of it by not adding anything to Results.
26241     return;
26242   case ISD::ATOMIC_CMP_SWAP:
26243     ReplaceCMP_SWAP_128Results(N, Results, DAG, Subtarget);
26244     return;
26245   case ISD::ATOMIC_LOAD_CLR:
26246     assert(N->getValueType(0) != MVT::i128 &&
26247            "128-bit ATOMIC_LOAD_AND should be lowered directly to LDCLRP");
26248     break;
26249   case ISD::ATOMIC_LOAD_AND:
26250   case ISD::ATOMIC_LOAD_OR:
26251   case ISD::ATOMIC_SWAP: {
26252     assert(cast<AtomicSDNode>(N)->getVal().getValueType() == MVT::i128 &&
26253            "Expected 128-bit atomicrmw.");
26254     // These need custom type legalisation so we go directly to instruction.
26255     ReplaceATOMIC_LOAD_128Results(N, Results, DAG, Subtarget);
26256     return;
26257   }
26258   case ISD::ATOMIC_LOAD:
26259   case ISD::LOAD: {
26260     MemSDNode *LoadNode = cast<MemSDNode>(N);
26261     EVT MemVT = LoadNode->getMemoryVT();
26262     // Handle lowering 256 bit non temporal loads into LDNP for little-endian
26263     // targets.
26264     if (LoadNode->isNonTemporal() && Subtarget->isLittleEndian() &&
26265         MemVT.getSizeInBits() == 256u &&
26266         (MemVT.getScalarSizeInBits() == 8u ||
26267          MemVT.getScalarSizeInBits() == 16u ||
26268          MemVT.getScalarSizeInBits() == 32u ||
26269          MemVT.getScalarSizeInBits() == 64u)) {
26270 
26271       SDValue Result = DAG.getMemIntrinsicNode(
26272           AArch64ISD::LDNP, SDLoc(N),
26273           DAG.getVTList({MemVT.getHalfNumVectorElementsVT(*DAG.getContext()),
26274                          MemVT.getHalfNumVectorElementsVT(*DAG.getContext()),
26275                          MVT::Other}),
26276           {LoadNode->getChain(), LoadNode->getBasePtr()},
26277           LoadNode->getMemoryVT(), LoadNode->getMemOperand());
26278 
26279       SDValue Pair = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), MemVT,
26280                                  Result.getValue(0), Result.getValue(1));
26281       Results.append({Pair, Result.getValue(2) /* Chain */});
26282       return;
26283     }
26284 
26285     if ((!LoadNode->isVolatile() && !LoadNode->isAtomic()) ||
26286         LoadNode->getMemoryVT() != MVT::i128) {
26287       // Non-volatile, non-atomic loads are optimized later in AArch64's load/store
26288       // optimizer.
26289       return;
26290     }
26291 
26292     if (SDValue(N, 0).getValueType() == MVT::i128) {
26293       auto *AN = dyn_cast<AtomicSDNode>(LoadNode);
26294       bool isLoadAcquire =
26295           AN && AN->getSuccessOrdering() == AtomicOrdering::Acquire;
26296       unsigned Opcode = isLoadAcquire ? AArch64ISD::LDIAPP : AArch64ISD::LDP;
26297 
26298       if (isLoadAcquire)
26299         assert(Subtarget->hasFeature(AArch64::FeatureRCPC3));
26300 
26301       SDValue Result = DAG.getMemIntrinsicNode(
26302           Opcode, SDLoc(N), DAG.getVTList({MVT::i64, MVT::i64, MVT::Other}),
26303           {LoadNode->getChain(), LoadNode->getBasePtr()},
26304           LoadNode->getMemoryVT(), LoadNode->getMemOperand());
26305 
26306       unsigned FirstRes = DAG.getDataLayout().isBigEndian() ? 1 : 0;
26307 
26308       SDValue Pair =
26309           DAG.getNode(ISD::BUILD_PAIR, SDLoc(N), MVT::i128,
26310                       Result.getValue(FirstRes), Result.getValue(1 - FirstRes));
26311       Results.append({Pair, Result.getValue(2) /* Chain */});
26312     }
26313     return;
26314   }
26315   case ISD::EXTRACT_SUBVECTOR:
26316     ReplaceExtractSubVectorResults(N, Results, DAG);
26317     return;
26318   case ISD::INSERT_SUBVECTOR:
26319   case ISD::CONCAT_VECTORS:
26320     // Custom lowering has been requested for INSERT_SUBVECTOR and
26321     // CONCAT_VECTORS -- but delegate to common code for result type
26322     // legalisation
26323     return;
26324   case ISD::INTRINSIC_WO_CHAIN: {
26325     EVT VT = N->getValueType(0);
26326 
26327     Intrinsic::ID IntID =
26328         static_cast<Intrinsic::ID>(N->getConstantOperandVal(0));
26329     switch (IntID) {
26330     default:
26331       return;
26332     case Intrinsic::aarch64_sve_clasta_n: {
26333       assert((VT == MVT::i8 || VT == MVT::i16) &&
26334              "custom lowering for unexpected type");
26335       SDLoc DL(N);
26336       auto Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, N->getOperand(2));
26337       auto V = DAG.getNode(AArch64ISD::CLASTA_N, DL, MVT::i32,
26338                            N->getOperand(1), Op2, N->getOperand(3));
26339       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
26340       return;
26341     }
26342     case Intrinsic::aarch64_sve_clastb_n: {
26343       assert((VT == MVT::i8 || VT == MVT::i16) &&
26344              "custom lowering for unexpected type");
26345       SDLoc DL(N);
26346       auto Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, N->getOperand(2));
26347       auto V = DAG.getNode(AArch64ISD::CLASTB_N, DL, MVT::i32,
26348                            N->getOperand(1), Op2, N->getOperand(3));
26349       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
26350       return;
26351     }
26352     case Intrinsic::aarch64_sve_lasta: {
26353       assert((VT == MVT::i8 || VT == MVT::i16) &&
26354              "custom lowering for unexpected type");
26355       SDLoc DL(N);
26356       auto V = DAG.getNode(AArch64ISD::LASTA, DL, MVT::i32,
26357                            N->getOperand(1), N->getOperand(2));
26358       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
26359       return;
26360     }
26361     case Intrinsic::aarch64_sve_lastb: {
26362       assert((VT == MVT::i8 || VT == MVT::i16) &&
26363              "custom lowering for unexpected type");
26364       SDLoc DL(N);
26365       auto V = DAG.getNode(AArch64ISD::LASTB, DL, MVT::i32,
26366                            N->getOperand(1), N->getOperand(2));
26367       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
26368       return;
26369     }
26370     case Intrinsic::get_active_lane_mask: {
26371       if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i1)
26372         return;
26373 
26374       // NOTE: Only trivial type promotion is supported.
26375       EVT NewVT = getTypeToTransformTo(*DAG.getContext(), VT);
26376       if (NewVT.getVectorNumElements() != VT.getVectorNumElements())
26377         return;
26378 
26379       SDLoc DL(N);
26380       auto V = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NewVT, N->ops());
26381       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
26382       return;
26383     }
26384     }
26385   }
26386   case ISD::READ_REGISTER: {
26387     SDLoc DL(N);
26388     assert(N->getValueType(0) == MVT::i128 &&
26389            "READ_REGISTER custom lowering is only for 128-bit sysregs");
26390     SDValue Chain = N->getOperand(0);
26391     SDValue SysRegName = N->getOperand(1);
26392 
26393     SDValue Result = DAG.getNode(
26394         AArch64ISD::MRRS, DL, DAG.getVTList({MVT::i64, MVT::i64, MVT::Other}),
26395         Chain, SysRegName);
26396 
26397     // Sysregs are not endian. Result.getValue(0) always contains the lower half
26398     // of the 128-bit System Register value.
26399     SDValue Pair = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i128,
26400                                Result.getValue(0), Result.getValue(1));
26401     Results.push_back(Pair);
26402     Results.push_back(Result.getValue(2)); // Chain
26403     return;
26404   }
26405   }
26406 }
26407 
26408 bool AArch64TargetLowering::useLoadStackGuardNode() const {
26409   if (Subtarget->isTargetAndroid() || Subtarget->isTargetFuchsia())
26410     return TargetLowering::useLoadStackGuardNode();
26411   return true;
26412 }
26413 
26414 unsigned AArch64TargetLowering::combineRepeatedFPDivisors() const {
26415   // Combine multiple FDIVs with the same divisor into multiple FMULs by the
26416   // reciprocal if there are three or more FDIVs.
26417   return 3;
26418 }
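// For example, with the relevant fast-math flags, a/c, b/c and d/c are
// rewritten as r = 1.0/c followed by a*r, b*r and d*r; with only two divisions
// sharing a divisor the extra FDIV for the reciprocal would not pay off, hence
// the threshold of three.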
26419 
26420 TargetLoweringBase::LegalizeTypeAction
26421 AArch64TargetLowering::getPreferredVectorAction(MVT VT) const {
26422   // During type legalization, we prefer to widen v1i8, v1i16, v1i32  to v8i8,
26423   // v4i16, v2i32 instead of to promote.
26424   if (VT == MVT::v1i8 || VT == MVT::v1i16 || VT == MVT::v1i32 ||
26425       VT == MVT::v1f32)
26426     return TypeWidenVector;
26427 
26428   return TargetLoweringBase::getPreferredVectorAction(VT);
26429 }
26430 
26431 // In v8.4a, ldp and stp instructions are guaranteed to be single-copy atomic
26432 // provided the address is 16-byte aligned.
26433 bool AArch64TargetLowering::isOpSuitableForLDPSTP(const Instruction *I) const {
26434   if (!Subtarget->hasLSE2())
26435     return false;
26436 
26437   if (auto LI = dyn_cast<LoadInst>(I))
26438     return LI->getType()->getPrimitiveSizeInBits() == 128 &&
26439            LI->getAlign() >= Align(16);
26440 
26441   if (auto SI = dyn_cast<StoreInst>(I))
26442     return SI->getValueOperand()->getType()->getPrimitiveSizeInBits() == 128 &&
26443            SI->getAlign() >= Align(16);
26444 
26445   return false;
26446 }
26447 
26448 bool AArch64TargetLowering::isOpSuitableForLSE128(const Instruction *I) const {
26449   if (!Subtarget->hasLSE128())
26450     return false;
26451 
26452   // Only use SWPP for stores where LSE2 would require a fence. Unlike STP, SWPP
26453   // will clobber the two registers.
26454   if (const auto *SI = dyn_cast<StoreInst>(I))
26455     return SI->getValueOperand()->getType()->getPrimitiveSizeInBits() == 128 &&
26456            SI->getAlign() >= Align(16) &&
26457            (SI->getOrdering() == AtomicOrdering::Release ||
26458             SI->getOrdering() == AtomicOrdering::SequentiallyConsistent);
26459 
26460   if (const auto *RMW = dyn_cast<AtomicRMWInst>(I))
26461     return RMW->getValOperand()->getType()->getPrimitiveSizeInBits() == 128 &&
26462            RMW->getAlign() >= Align(16) &&
26463            (RMW->getOperation() == AtomicRMWInst::Xchg ||
26464             RMW->getOperation() == AtomicRMWInst::And ||
26465             RMW->getOperation() == AtomicRMWInst::Or);
26466 
26467   return false;
26468 }
26469 
26470 bool AArch64TargetLowering::isOpSuitableForRCPC3(const Instruction *I) const {
26471   if (!Subtarget->hasLSE2() || !Subtarget->hasRCPC3())
26472     return false;
26473 
26474   if (auto LI = dyn_cast<LoadInst>(I))
26475     return LI->getType()->getPrimitiveSizeInBits() == 128 &&
26476            LI->getAlign() >= Align(16) &&
26477            LI->getOrdering() == AtomicOrdering::Acquire;
26478 
26479   if (auto SI = dyn_cast<StoreInst>(I))
26480     return SI->getValueOperand()->getType()->getPrimitiveSizeInBits() == 128 &&
26481            SI->getAlign() >= Align(16) &&
26482            SI->getOrdering() == AtomicOrdering::Release;
26483 
26484   return false;
26485 }
26486 
26487 bool AArch64TargetLowering::shouldInsertFencesForAtomic(
26488     const Instruction *I) const {
26489   if (isOpSuitableForRCPC3(I))
26490     return false;
26491   if (isOpSuitableForLSE128(I))
26492     return false;
26493   if (isOpSuitableForLDPSTP(I))
26494     return true;
26495   return false;
26496 }
26497 
26498 bool AArch64TargetLowering::shouldInsertTrailingFenceForAtomicStore(
26499     const Instruction *I) const {
26500   // Store-Release instructions only provide seq_cst guarantees when paired with
26501   // Load-Acquire instructions. MSVC CRT does not use these instructions to
26502   // implement seq_cst loads and stores, so we need additional explicit fences
26503   // after memory writes.
26504   if (!Subtarget->getTargetTriple().isWindowsMSVCEnvironment())
26505     return false;
26506 
26507   switch (I->getOpcode()) {
26508   default:
26509     return false;
26510   case Instruction::AtomicCmpXchg:
26511     return cast<AtomicCmpXchgInst>(I)->getSuccessOrdering() ==
26512            AtomicOrdering::SequentiallyConsistent;
26513   case Instruction::AtomicRMW:
26514     return cast<AtomicRMWInst>(I)->getOrdering() ==
26515            AtomicOrdering::SequentiallyConsistent;
26516   case Instruction::Store:
26517     return cast<StoreInst>(I)->getOrdering() ==
26518            AtomicOrdering::SequentiallyConsistent;
26519   }
26520 }
26521 
26522 // Loads and stores less than 128-bits are already atomic; ones above that
26523 // are doomed anyway, so defer to the default libcall and blame the OS when
26524 // things go wrong.
26525 TargetLoweringBase::AtomicExpansionKind
26526 AArch64TargetLowering::shouldExpandAtomicStoreInIR(StoreInst *SI) const {
26527   unsigned Size = SI->getValueOperand()->getType()->getPrimitiveSizeInBits();
26528   if (Size != 128)
26529     return AtomicExpansionKind::None;
26530   if (isOpSuitableForRCPC3(SI))
26531     return AtomicExpansionKind::None;
26532   if (isOpSuitableForLSE128(SI))
26533     return AtomicExpansionKind::Expand;
26534   if (isOpSuitableForLDPSTP(SI))
26535     return AtomicExpansionKind::None;
26536   return AtomicExpansionKind::Expand;
26537 }
26538 
26539 // Loads and stores less than 128-bits are already atomic; ones above that
26540 // are doomed anyway, so defer to the default libcall and blame the OS when
26541 // things go wrong.
26542 TargetLowering::AtomicExpansionKind
26543 AArch64TargetLowering::shouldExpandAtomicLoadInIR(LoadInst *LI) const {
26544   unsigned Size = LI->getType()->getPrimitiveSizeInBits();
26545 
26546   if (Size != 128)
26547     return AtomicExpansionKind::None;
26548   if (isOpSuitableForRCPC3(LI))
26549     return AtomicExpansionKind::None;
26550   // No LSE128 loads
26551   if (isOpSuitableForLDPSTP(LI))
26552     return AtomicExpansionKind::None;
26553 
26554   // At -O0, fast-regalloc cannot cope with the live vregs necessary to
26555   // implement atomicrmw without spilling. If the target address is also on the
26556   // stack and close enough to the spill slot, this can lead to a situation
26557   // where the monitor always gets cleared and the atomic operation can never
26558   // succeed. So at -O0 lower this operation to a CAS loop.
26559   if (getTargetMachine().getOptLevel() == CodeGenOptLevel::None)
26560     return AtomicExpansionKind::CmpXChg;
26561 
26562   // Using CAS for an atomic load has a better chance of succeeding under high
26563   // contention situations. So use it if available.
26564   return Subtarget->hasLSE() ? AtomicExpansionKind::CmpXChg
26565                              : AtomicExpansionKind::LLSC;
26566 }
26567 
26568 // The "default" for integer RMW operations is to expand to an LL/SC loop.
26569 // However, with the LSE instructions (or outline-atomics mode, which provides
26570 // library routines in place of the LSE-instructions), we can directly emit many
26571 // operations instead.
26572 //
26573 // Floating-point operations are always emitted to a cmpxchg loop, because they
26574 // may trigger a trap which aborts an LLSC sequence.
26575 TargetLowering::AtomicExpansionKind
26576 AArch64TargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
26577   unsigned Size = AI->getType()->getPrimitiveSizeInBits();
26578   assert(Size <= 128 && "AtomicExpandPass should've handled larger sizes.");
26579 
26580   if (AI->isFloatingPointOperation())
26581     return AtomicExpansionKind::CmpXChg;
26582 
26583   bool CanUseLSE128 = Subtarget->hasLSE128() && Size == 128 &&
26584                       (AI->getOperation() == AtomicRMWInst::Xchg ||
26585                        AI->getOperation() == AtomicRMWInst::Or ||
26586                        AI->getOperation() == AtomicRMWInst::And);
26587   if (CanUseLSE128)
26588     return AtomicExpansionKind::None;
26589 
26590   // Nand is not supported in LSE.
26591   // Leave 128 bits to LLSC or CmpXChg.
26592   if (AI->getOperation() != AtomicRMWInst::Nand && Size < 128) {
26593     if (Subtarget->hasLSE())
26594       return AtomicExpansionKind::None;
26595     if (Subtarget->outlineAtomics()) {
26596       // [U]Min/[U]Max RWM atomics are used in __sync_fetch_ libcalls so far.
26597       // Don't outline them unless
26598       // (1) high level <atomic> support approved:
26599       //   http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p0493r1.pdf
26600       // (2) low level libgcc and compiler-rt support implemented by:
26601       //   min/max outline atomics helpers
26602       if (AI->getOperation() != AtomicRMWInst::Min &&
26603           AI->getOperation() != AtomicRMWInst::Max &&
26604           AI->getOperation() != AtomicRMWInst::UMin &&
26605           AI->getOperation() != AtomicRMWInst::UMax) {
26606         return AtomicExpansionKind::None;
26607       }
26608     }
26609   }
26610 
26611   // At -O0, fast-regalloc cannot cope with the live vregs necessary to
26612   // implement atomicrmw without spilling. If the target address is also on the
26613   // stack and close enough to the spill slot, this can lead to a situation
26614   // where the monitor always gets cleared and the atomic operation can never
26615   // succeed. So at -O0 lower this operation to a CAS loop. Also worthwhile if
26616   // we have a single CAS instruction that can replace the loop.
26617   if (getTargetMachine().getOptLevel() == CodeGenOptLevel::None ||
26618       Subtarget->hasLSE())
26619     return AtomicExpansionKind::CmpXChg;
26620 
26621   return AtomicExpansionKind::LLSC;
26622 }
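// For example, a 32-bit 'atomicrmw add' is kept intact on an LSE target
// (AtomicExpansionKind::None) and selected directly to an LDADD variant,
// whereas without LSE (and outside -O0) the same operation is expanded to an
// LDXR/STXR loop (AtomicExpansionKind::LLSC).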
26623 
26624 TargetLowering::AtomicExpansionKind
26625 AArch64TargetLowering::shouldExpandAtomicCmpXchgInIR(
26626     AtomicCmpXchgInst *AI) const {
26627   // If subtarget has LSE, leave cmpxchg intact for codegen.
26628   if (Subtarget->hasLSE() || Subtarget->outlineAtomics())
26629     return AtomicExpansionKind::None;
26630   // At -O0, fast-regalloc cannot cope with the live vregs necessary to
26631   // implement cmpxchg without spilling. If the address being exchanged is also
26632   // on the stack and close enough to the spill slot, this can lead to a
26633   // situation where the monitor always gets cleared and the atomic operation
26634   // can never succeed. So at -O0 we need a late-expanded pseudo-inst instead.
26635   if (getTargetMachine().getOptLevel() == CodeGenOptLevel::None)
26636     return AtomicExpansionKind::None;
26637 
26638   // 128-bit atomic cmpxchg is weird; AtomicExpand doesn't know how to expand
26639   // it.
26640   unsigned Size = AI->getCompareOperand()->getType()->getPrimitiveSizeInBits();
26641   if (Size > 64)
26642     return AtomicExpansionKind::None;
26643 
26644   return AtomicExpansionKind::LLSC;
26645 }
26646 
26647 Value *AArch64TargetLowering::emitLoadLinked(IRBuilderBase &Builder,
26648                                              Type *ValueTy, Value *Addr,
26649                                              AtomicOrdering Ord) const {
26650   Module *M = Builder.GetInsertBlock()->getParent()->getParent();
26651   bool IsAcquire = isAcquireOrStronger(Ord);
26652 
26653   // Since i128 isn't legal and intrinsics don't get type-lowered, the ldrexd
26654   // intrinsic must return {i64, i64} and we have to recombine them into a
26655   // single i128 here.
26656   if (ValueTy->getPrimitiveSizeInBits() == 128) {
26657     Intrinsic::ID Int =
26658         IsAcquire ? Intrinsic::aarch64_ldaxp : Intrinsic::aarch64_ldxp;
26659     Function *Ldxr = Intrinsic::getDeclaration(M, Int);
26660 
26661     Value *LoHi = Builder.CreateCall(Ldxr, Addr, "lohi");
26662 
26663     Value *Lo = Builder.CreateExtractValue(LoHi, 0, "lo");
26664     Value *Hi = Builder.CreateExtractValue(LoHi, 1, "hi");
26665     Lo = Builder.CreateZExt(Lo, ValueTy, "lo64");
26666     Hi = Builder.CreateZExt(Hi, ValueTy, "hi64");
26667     return Builder.CreateOr(
26668         Lo, Builder.CreateShl(Hi, ConstantInt::get(ValueTy, 64)), "val64");
26669   }
26670 
26671   Type *Tys[] = { Addr->getType() };
26672   Intrinsic::ID Int =
26673       IsAcquire ? Intrinsic::aarch64_ldaxr : Intrinsic::aarch64_ldxr;
26674   Function *Ldxr = Intrinsic::getDeclaration(M, Int, Tys);
26675 
26676   const DataLayout &DL = M->getDataLayout();
26677   IntegerType *IntEltTy = Builder.getIntNTy(DL.getTypeSizeInBits(ValueTy));
26678   CallInst *CI = Builder.CreateCall(Ldxr, Addr);
26679   CI->addParamAttr(
26680       0, Attribute::get(Builder.getContext(), Attribute::ElementType, ValueTy));
26681   Value *Trunc = Builder.CreateTrunc(CI, IntEltTy);
26682 
26683   return Builder.CreateBitCast(Trunc, ValueTy);
26684 }
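// For the i128 path above, the emitted IR is roughly (value names illustrative):
//   %lohi = call { i64, i64 } @llvm.aarch64.ldaxp(ptr %addr)
//   %lo   = extractvalue { i64, i64 } %lohi, 0
//   %hi   = extractvalue { i64, i64 } %lohi, 1
//   %val  = or i128 (zext %lo to i128), (shl (zext %hi to i128), 64)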
26685 
26686 void AArch64TargetLowering::emitAtomicCmpXchgNoStoreLLBalance(
26687     IRBuilderBase &Builder) const {
26688   Module *M = Builder.GetInsertBlock()->getParent()->getParent();
26689   Builder.CreateCall(Intrinsic::getDeclaration(M, Intrinsic::aarch64_clrex));
26690 }
26691 
26692 Value *AArch64TargetLowering::emitStoreConditional(IRBuilderBase &Builder,
26693                                                    Value *Val, Value *Addr,
26694                                                    AtomicOrdering Ord) const {
26695   Module *M = Builder.GetInsertBlock()->getParent()->getParent();
26696   bool IsRelease = isReleaseOrStronger(Ord);
26697 
26698   // Since the intrinsics must have legal type, the i128 intrinsics take two
26699   // parameters: "i64, i64". We must marshal Val into the appropriate form
26700   // before the call.
26701   if (Val->getType()->getPrimitiveSizeInBits() == 128) {
26702     Intrinsic::ID Int =
26703         IsRelease ? Intrinsic::aarch64_stlxp : Intrinsic::aarch64_stxp;
26704     Function *Stxr = Intrinsic::getDeclaration(M, Int);
26705     Type *Int64Ty = Type::getInt64Ty(M->getContext());
26706 
26707     Value *Lo = Builder.CreateTrunc(Val, Int64Ty, "lo");
26708     Value *Hi = Builder.CreateTrunc(Builder.CreateLShr(Val, 64), Int64Ty, "hi");
26709     return Builder.CreateCall(Stxr, {Lo, Hi, Addr});
26710   }
26711 
26712   Intrinsic::ID Int =
26713       IsRelease ? Intrinsic::aarch64_stlxr : Intrinsic::aarch64_stxr;
26714   Type *Tys[] = { Addr->getType() };
26715   Function *Stxr = Intrinsic::getDeclaration(M, Int, Tys);
26716 
26717   const DataLayout &DL = M->getDataLayout();
26718   IntegerType *IntValTy = Builder.getIntNTy(DL.getTypeSizeInBits(Val->getType()));
26719   Val = Builder.CreateBitCast(Val, IntValTy);
26720 
26721   CallInst *CI = Builder.CreateCall(
26722       Stxr, {Builder.CreateZExtOrBitCast(
26723                  Val, Stxr->getFunctionType()->getParamType(0)),
26724              Addr});
26725   CI->addParamAttr(1, Attribute::get(Builder.getContext(),
26726                                      Attribute::ElementType, Val->getType()));
26727   return CI;
26728 }
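// For the i128 path above, the emitted IR is roughly (value names illustrative):
//   %lo     = trunc i128 %val to i64
//   %hi     = trunc i128 (lshr %val, 64) to i64
//   %status = call i32 @llvm.aarch64.stlxp(i64 %lo, i64 %hi, ptr %addr)
// where %status is zero when the store-exclusive succeeded, matching the
// STXP/STLXP status-register convention.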
26729 
26730 bool AArch64TargetLowering::functionArgumentNeedsConsecutiveRegisters(
26731     Type *Ty, CallingConv::ID CallConv, bool isVarArg,
26732     const DataLayout &DL) const {
26733   if (!Ty->isArrayTy()) {
26734     const TypeSize &TySize = Ty->getPrimitiveSizeInBits();
26735     return TySize.isScalable() && TySize.getKnownMinValue() > 128;
26736   }
26737 
26738   // All non aggregate members of the type must have the same type
26739   SmallVector<EVT> ValueVTs;
26740   ComputeValueVTs(*this, DL, Ty, ValueVTs);
26741   return all_equal(ValueVTs);
26742 }
26743 
26744 bool AArch64TargetLowering::shouldNormalizeToSelectSequence(LLVMContext &,
26745                                                             EVT) const {
26746   return false;
26747 }
26748 
26749 static Value *UseTlsOffset(IRBuilderBase &IRB, unsigned Offset) {
26750   Module *M = IRB.GetInsertBlock()->getParent()->getParent();
26751   Function *ThreadPointerFunc =
26752       Intrinsic::getDeclaration(M, Intrinsic::thread_pointer);
26753   return IRB.CreatePointerCast(
26754       IRB.CreateConstGEP1_32(IRB.getInt8Ty(), IRB.CreateCall(ThreadPointerFunc),
26755                              Offset),
26756       IRB.getPtrTy(0));
26757 }
26758 
26759 Value *AArch64TargetLowering::getIRStackGuard(IRBuilderBase &IRB) const {
26760   // Android provides a fixed TLS slot for the stack cookie. See the definition
26761   // of TLS_SLOT_STACK_GUARD in
26762   // https://android.googlesource.com/platform/bionic/+/main/libc/platform/bionic/tls_defines.h
26763   if (Subtarget->isTargetAndroid())
26764     return UseTlsOffset(IRB, 0x28);
26765 
26766   // Fuchsia is similar.
26767   // <zircon/tls.h> defines ZX_TLS_STACK_GUARD_OFFSET with this value.
26768   if (Subtarget->isTargetFuchsia())
26769     return UseTlsOffset(IRB, -0x10);
26770 
26771   return TargetLowering::getIRStackGuard(IRB);
26772 }
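// On Android, for instance, the slot returned above is computed roughly as
// (value names illustrative):
//   %tp   = call ptr @llvm.thread.pointer()
//   %slot = getelementptr i8, ptr %tp, i32 40   ; 0x28 == TLS_SLOT_STACK_GUARD
// and the stack-protector code then loads the cookie from %slot rather than
// from a global variable.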
26773 
26774 void AArch64TargetLowering::insertSSPDeclarations(Module &M) const {
26775   // MSVC CRT provides functionalities for stack protection.
26776   if (Subtarget->getTargetTriple().isWindowsMSVCEnvironment()) {
26777     // MSVC CRT has a global variable holding security cookie.
26778     M.getOrInsertGlobal("__security_cookie",
26779                         PointerType::getUnqual(M.getContext()));
26780 
26781     // MSVC CRT has a function to validate security cookie.
26782     FunctionCallee SecurityCheckCookie =
26783         M.getOrInsertFunction(Subtarget->getSecurityCheckCookieName(),
26784                               Type::getVoidTy(M.getContext()),
26785                               PointerType::getUnqual(M.getContext()));
26786     if (Function *F = dyn_cast<Function>(SecurityCheckCookie.getCallee())) {
26787       F->setCallingConv(CallingConv::Win64);
26788       F->addParamAttr(0, Attribute::AttrKind::InReg);
26789     }
26790     return;
26791   }
26792   TargetLowering::insertSSPDeclarations(M);
26793 }
26794 
26795 Value *AArch64TargetLowering::getSDagStackGuard(const Module &M) const {
26796   // MSVC CRT has a global variable holding security cookie.
26797   if (Subtarget->getTargetTriple().isWindowsMSVCEnvironment())
26798     return M.getGlobalVariable("__security_cookie");
26799   return TargetLowering::getSDagStackGuard(M);
26800 }
26801 
26802 Function *AArch64TargetLowering::getSSPStackGuardCheck(const Module &M) const {
26803   // MSVC CRT has a function to validate security cookie.
26804   if (Subtarget->getTargetTriple().isWindowsMSVCEnvironment())
26805     return M.getFunction(Subtarget->getSecurityCheckCookieName());
26806   return TargetLowering::getSSPStackGuardCheck(M);
26807 }
26808 
26809 Value *
26810 AArch64TargetLowering::getSafeStackPointerLocation(IRBuilderBase &IRB) const {
26811   // Android provides a fixed TLS slot for the SafeStack pointer. See the
26812   // definition of TLS_SLOT_SAFESTACK in
26813   // https://android.googlesource.com/platform/bionic/+/master/libc/private/bionic_tls.h
26814   if (Subtarget->isTargetAndroid())
26815     return UseTlsOffset(IRB, 0x48);
26816 
26817   // Fuchsia is similar.
26818   // <zircon/tls.h> defines ZX_TLS_UNSAFE_SP_OFFSET with this value.
26819   if (Subtarget->isTargetFuchsia())
26820     return UseTlsOffset(IRB, -0x8);
26821 
26822   return TargetLowering::getSafeStackPointerLocation(IRB);
26823 }
26824 
26825 bool AArch64TargetLowering::isMaskAndCmp0FoldingBeneficial(
26826     const Instruction &AndI) const {
26827   // Only sink 'and' mask to cmp use block if it is masking a single bit, since
26828   // this is likely to fold the and/cmp/br into a single tbz instruction. It
26829   // may be beneficial to sink in other cases, but we would have to check that
26830   // the cmp would not get folded into the br to form a cbz for these to be
26831   // beneficial.
26832   ConstantInt* Mask = dyn_cast<ConstantInt>(AndI.getOperand(1));
26833   if (!Mask)
26834     return false;
26835   return Mask->getValue().isPowerOf2();
26836 }
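// For example, 'if ((x & 0x10) == 0)' can be selected as a single TBZ on bit 4
// when the mask is available in the compare's block, which is why only
// power-of-two masks are considered worth sinking here.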
26837 
26838 bool AArch64TargetLowering::
26839     shouldProduceAndByConstByHoistingConstFromShiftsLHSOfAnd(
26840         SDValue X, ConstantSDNode *XC, ConstantSDNode *CC, SDValue Y,
26841         unsigned OldShiftOpcode, unsigned NewShiftOpcode,
26842         SelectionDAG &DAG) const {
26843   // Does baseline recommend not to perform the fold by default?
26844   if (!TargetLowering::shouldProduceAndByConstByHoistingConstFromShiftsLHSOfAnd(
26845           X, XC, CC, Y, OldShiftOpcode, NewShiftOpcode, DAG))
26846     return false;
26847   // Else, if this is a vector shift, prefer 'shl'.
26848   return X.getValueType().isScalarInteger() || NewShiftOpcode == ISD::SHL;
26849 }
26850 
26851 TargetLowering::ShiftLegalizationStrategy
26852 AArch64TargetLowering::preferredShiftLegalizationStrategy(
26853     SelectionDAG &DAG, SDNode *N, unsigned int ExpansionFactor) const {
26854   if (DAG.getMachineFunction().getFunction().hasMinSize() &&
26855       !Subtarget->isTargetWindows() && !Subtarget->isTargetDarwin())
26856     return ShiftLegalizationStrategy::LowerToLibcall;
26857   return TargetLowering::preferredShiftLegalizationStrategy(DAG, N,
26858                                                             ExpansionFactor);
26859 }
26860 
26861 void AArch64TargetLowering::initializeSplitCSR(MachineBasicBlock *Entry) const {
26862   // Update IsSplitCSR in AArch64FunctionInfo.
26863   AArch64FunctionInfo *AFI = Entry->getParent()->getInfo<AArch64FunctionInfo>();
26864   AFI->setIsSplitCSR(true);
26865 }
26866 
26867 void AArch64TargetLowering::insertCopiesSplitCSR(
26868     MachineBasicBlock *Entry,
26869     const SmallVectorImpl<MachineBasicBlock *> &Exits) const {
26870   const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
26871   const MCPhysReg *IStart = TRI->getCalleeSavedRegsViaCopy(Entry->getParent());
26872   if (!IStart)
26873     return;
26874 
26875   const TargetInstrInfo *TII = Subtarget->getInstrInfo();
26876   MachineRegisterInfo *MRI = &Entry->getParent()->getRegInfo();
26877   MachineBasicBlock::iterator MBBI = Entry->begin();
26878   for (const MCPhysReg *I = IStart; *I; ++I) {
26879     const TargetRegisterClass *RC = nullptr;
26880     if (AArch64::GPR64RegClass.contains(*I))
26881       RC = &AArch64::GPR64RegClass;
26882     else if (AArch64::FPR64RegClass.contains(*I))
26883       RC = &AArch64::FPR64RegClass;
26884     else
26885       llvm_unreachable("Unexpected register class in CSRsViaCopy!");
26886 
26887     Register NewVR = MRI->createVirtualRegister(RC);
26888     // Create copy from CSR to a virtual register.
26889     // FIXME: this currently does not emit CFI pseudo-instructions, it works
26890     // fine for CXX_FAST_TLS since the C++-style TLS access functions should be
26891     // nounwind. If we want to generalize this later, we may need to emit
26892     // CFI pseudo-instructions.
26893     assert(Entry->getParent()->getFunction().hasFnAttribute(
26894                Attribute::NoUnwind) &&
26895            "Function should be nounwind in insertCopiesSplitCSR!");
26896     Entry->addLiveIn(*I);
26897     BuildMI(*Entry, MBBI, DebugLoc(), TII->get(TargetOpcode::COPY), NewVR)
26898         .addReg(*I);
26899 
26900     // Insert the copy-back instructions right before the terminator.
26901     for (auto *Exit : Exits)
26902       BuildMI(*Exit, Exit->getFirstTerminator(), DebugLoc(),
26903               TII->get(TargetOpcode::COPY), *I)
26904           .addReg(NewVR);
26905   }
26906 }
26907 
26908 bool AArch64TargetLowering::isIntDivCheap(EVT VT, AttributeList Attr) const {
26909   // Integer division on AArch64 is expensive. However, when aggressively
26910   // optimizing for code size, we prefer to use a div instruction, as it is
26911   // usually smaller than the alternative sequence.
26912   // The exception to this is vector division. Since AArch64 doesn't have vector
26913   // integer division, leaving the division as-is is a loss even in terms of
26914   // size, because it will have to be scalarized, while the alternative code
26915   // sequence can be performed in vector form.
26916   bool OptSize = Attr.hasFnAttr(Attribute::MinSize);
26917   return OptSize && !VT.isVector();
26918 }
26919 
26920 bool AArch64TargetLowering::preferIncOfAddToSubOfNot(EVT VT) const {
26921   // We want inc-of-add for scalars and sub-of-not for vectors.
26922   return VT.isScalarInteger();
26923 }
26924 
26925 bool AArch64TargetLowering::shouldConvertFpToSat(unsigned Op, EVT FPVT,
26926                                                  EVT VT) const {
26927   // v8f16 without fp16 need to be extended to v8f32, which is more difficult to
26928   // legalize.
26929   if (FPVT == MVT::v8f16 && !Subtarget->hasFullFP16())
26930     return false;
26931   if (FPVT == MVT::v8bf16)
26932     return false;
26933   return TargetLowering::shouldConvertFpToSat(Op, FPVT, VT);
26934 }
26935 
26936 MachineInstr *
26937 AArch64TargetLowering::EmitKCFICheck(MachineBasicBlock &MBB,
26938                                      MachineBasicBlock::instr_iterator &MBBI,
26939                                      const TargetInstrInfo *TII) const {
26940   assert(MBBI->isCall() && MBBI->getCFIType() &&
26941          "Invalid call instruction for a KCFI check");
26942 
26943   switch (MBBI->getOpcode()) {
26944   case AArch64::BLR:
26945   case AArch64::BLRNoIP:
26946   case AArch64::TCRETURNri:
26947   case AArch64::TCRETURNrix16x17:
26948   case AArch64::TCRETURNrix17:
26949   case AArch64::TCRETURNrinotx16:
26950     break;
26951   default:
26952     llvm_unreachable("Unexpected CFI call opcode");
26953   }
26954 
26955   MachineOperand &Target = MBBI->getOperand(0);
26956   assert(Target.isReg() && "Invalid target operand for an indirect call");
26957   Target.setIsRenamable(false);
26958 
26959   return BuildMI(MBB, MBBI, MBBI->getDebugLoc(), TII->get(AArch64::KCFI_CHECK))
26960       .addReg(Target.getReg())
26961       .addImm(MBBI->getCFIType())
26962       .getInstr();
26963 }
26964 
26965 bool AArch64TargetLowering::enableAggressiveFMAFusion(EVT VT) const {
26966   return Subtarget->hasAggressiveFMA() && VT.isFloatingPoint();
26967 }
26968 
26969 unsigned
26970 AArch64TargetLowering::getVaListSizeInBits(const DataLayout &DL) const {
26971   if (Subtarget->isTargetDarwin() || Subtarget->isTargetWindows())
26972     return getPointerTy(DL).getSizeInBits();
26973 
26974   return 3 * getPointerTy(DL).getSizeInBits() + 2 * 32;
26975 }
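// For reference, the 256-bit figure corresponds to the AAPCS64 va_list
// record mandated by the procedure call standard, roughly:
//   struct va_list {
//     void *__stack;    // 64 bits
//     void *__gr_top;   // 64 bits
//     void *__vr_top;   // 64 bits
//     int   __gr_offs;  // 32 bits
//     int   __vr_offs;  // 32 bits
//   };                  // 3 * 64 + 2 * 32 = 256 bits
// Darwin and Windows use a plain char* va_list instead, hence the
// pointer-sized result above.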
26976 
26977 void AArch64TargetLowering::finalizeLowering(MachineFunction &MF) const {
26978   MachineFrameInfo &MFI = MF.getFrameInfo();
26979   // If we have any vulnerable SVE stack objects then the stack protector
26980   // needs to be placed at the top of the SVE stack area, as the SVE locals
26981   // are placed above the other locals, so we allocate it as if it were a
26982   // scalable vector.
26983   // FIXME: It may be worthwhile having a specific interface for this rather
26984   // than doing it here in finalizeLowering.
26985   if (MFI.hasStackProtectorIndex()) {
26986     for (unsigned int i = 0, e = MFI.getObjectIndexEnd(); i != e; ++i) {
26987       if (MFI.getStackID(i) == TargetStackID::ScalableVector &&
26988           MFI.getObjectSSPLayout(i) != MachineFrameInfo::SSPLK_None) {
26989         MFI.setStackID(MFI.getStackProtectorIndex(),
26990                        TargetStackID::ScalableVector);
26991         MFI.setObjectAlignment(MFI.getStackProtectorIndex(), Align(16));
26992         break;
26993       }
26994     }
26995   }
26996   MFI.computeMaxCallFrameSize(MF);
26997   TargetLoweringBase::finalizeLowering(MF);
26998 }
26999 
27000 // Unlike X86, we let frame lowering assign offsets to all catch objects.
27001 bool AArch64TargetLowering::needsFixedCatchObjects() const {
27002   return false;
27003 }
27004 
27005 bool AArch64TargetLowering::shouldLocalize(
27006     const MachineInstr &MI, const TargetTransformInfo *TTI) const {
27007   auto &MF = *MI.getMF();
27008   auto &MRI = MF.getRegInfo();
27009   auto maxUses = [](unsigned RematCost) {
27010     // A cost of 1 means remats are basically free.
27011     if (RematCost == 1)
27012       return std::numeric_limits<unsigned>::max();
27013     if (RematCost == 2)
27014       return 2U;
27015 
27016     // Remat is too expensive, only sink if there's one user.
27017     if (RematCost > 2)
27018       return 1U;
27019     llvm_unreachable("Unexpected remat cost");
27020   };
27021 
27022   unsigned Opc = MI.getOpcode();
27023   switch (Opc) {
27024   case TargetOpcode::G_GLOBAL_VALUE: {
27025     // On Darwin, TLS global vars get selected into function calls, which
27026     // we don't want localized, as they can get moved into the middle of
27027     // another call sequence.
27028     const GlobalValue &GV = *MI.getOperand(1).getGlobal();
27029     if (GV.isThreadLocal() && Subtarget->isTargetMachO())
27030       return false;
27031     return true; // Always localize G_GLOBAL_VALUE to avoid high reg pressure.
27032   }
27033   case TargetOpcode::G_FCONSTANT:
27034   case TargetOpcode::G_CONSTANT: {
27035     const ConstantInt *CI;
27036     unsigned AdditionalCost = 0;
27037 
27038     if (Opc == TargetOpcode::G_CONSTANT)
27039       CI = MI.getOperand(1).getCImm();
27040     else {
27041       LLT Ty = MRI.getType(MI.getOperand(0).getReg());
27042       // We try to estimate cost of 32/64b fpimms, as they'll likely be
27043       // materialized as integers.
27044       if (Ty.getScalarSizeInBits() != 32 && Ty.getScalarSizeInBits() != 64)
27045         break;
27046       auto APF = MI.getOperand(1).getFPImm()->getValueAPF();
27047       bool OptForSize =
27048           MF.getFunction().hasOptSize() || MF.getFunction().hasMinSize();
27049       if (isFPImmLegal(APF, EVT::getFloatingPointVT(Ty.getScalarSizeInBits()),
27050                        OptForSize))
27051         return true; // Constant should be cheap.
27052       CI =
27053           ConstantInt::get(MF.getFunction().getContext(), APF.bitcastToAPInt());
27054       // FP materialization also costs an extra move, from gpr to fpr.
27055       AdditionalCost = 1;
27056     }
27057     APInt Imm = CI->getValue();
27058     InstructionCost Cost = TTI->getIntImmCost(
27059         Imm, CI->getType(), TargetTransformInfo::TCK_CodeSize);
27060     assert(Cost.isValid() && "Expected a valid imm cost");
27061 
27062     unsigned RematCost = *Cost.getValue();
27063     RematCost += AdditionalCost;
27064     Register Reg = MI.getOperand(0).getReg();
27065     unsigned MaxUses = maxUses(RematCost);
27066     // Don't pass UINT_MAX sentinel value to hasAtMostUserInstrs().
27067     if (MaxUses == std::numeric_limits<unsigned>::max())
27068       --MaxUses;
27069     return MRI.hasAtMostUserInstrs(Reg, MaxUses);
27070   }
27071   // If we legalized G_GLOBAL_VALUE into ADRP + G_ADD_LOW, mark both as being
27072   // localizable.
27073   case AArch64::ADRP:
27074   case AArch64::G_ADD_LOW:
27075   // Need to localize G_PTR_ADD so that G_GLOBAL_VALUE can be localized too.
27076   case TargetOpcode::G_PTR_ADD:
27077     return true;
27078   default:
27079     break;
27080   }
27081   return TargetLoweringBase::shouldLocalize(MI, TTI);
27082 }
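// Illustrative reading of the maxUses() mapping above: a constant that
// rematerializes in a single cheap instruction (cost 1) is localized
// regardless of use count, a cost-2 constant is localized only when it has
// at most two using instructions, and anything more expensive is localized
// only for a single user.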
27083 
27084 bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
27085   // Fallback for scalable vectors.
27086   // Note that if EnableSVEGISel is true, we allow scalable vector types for
27087   // all instructions, regardless of whether they are actually supported.
27088   if (!EnableSVEGISel) {
27089     if (Inst.getType()->isScalableTy()) {
27090       return true;
27091     }
27092 
27093     for (unsigned i = 0; i < Inst.getNumOperands(); ++i)
27094       if (Inst.getOperand(i)->getType()->isScalableTy())
27095         return true;
27096 
27097     if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) {
27098       if (AI->getAllocatedType()->isScalableTy())
27099         return true;
27100     }
27101   }
27102 
27103   // Checks to allow the use of SME instructions
27104   if (auto *Base = dyn_cast<CallBase>(&Inst)) {
27105     auto CallerAttrs = SMEAttrs(*Inst.getFunction());
27106     auto CalleeAttrs = SMEAttrs(*Base);
27107     if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
27108         CallerAttrs.requiresLazySave(CalleeAttrs) ||
27109         CallerAttrs.requiresPreservingZT0(CalleeAttrs))
27110       return true;
27111   }
27112   return false;
27113 }
27114 
27115 // Return the largest legal scalable vector type that matches VT's element type.
27116 static EVT getContainerForFixedLengthVector(SelectionDAG &DAG, EVT VT) {
27117   assert(VT.isFixedLengthVector() &&
27118          DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
27119          "Expected legal fixed length vector!");
27120   switch (VT.getVectorElementType().getSimpleVT().SimpleTy) {
27121   default:
27122     llvm_unreachable("unexpected element type for SVE container");
27123   case MVT::i8:
27124     return EVT(MVT::nxv16i8);
27125   case MVT::i16:
27126     return EVT(MVT::nxv8i16);
27127   case MVT::i32:
27128     return EVT(MVT::nxv4i32);
27129   case MVT::i64:
27130     return EVT(MVT::nxv2i64);
27131   case MVT::bf16:
27132     return EVT(MVT::nxv8bf16);
27133   case MVT::f16:
27134     return EVT(MVT::nxv8f16);
27135   case MVT::f32:
27136     return EVT(MVT::nxv4f32);
27137   case MVT::f64:
27138     return EVT(MVT::nxv2f64);
27139   }
27140 }
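// Illustrative examples of the mapping above: it depends only on the element
// type, so v4i32, v8i32 and v16i32 all use the packed nxv4i32 container,
// while v8f16 maps to nxv8f16 and v2f64 to nxv2f64.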
27141 
27142 // Return a PTRUE with active lanes corresponding to the extent of VT.
27143 static SDValue getPredicateForFixedLengthVector(SelectionDAG &DAG, SDLoc &DL,
27144                                                 EVT VT) {
27145   assert(VT.isFixedLengthVector() &&
27146          DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
27147          "Expected legal fixed length vector!");
27148 
27149   std::optional<unsigned> PgPattern =
27150       getSVEPredPatternFromNumElements(VT.getVectorNumElements());
27151   assert(PgPattern && "Unexpected element count for SVE predicate");
27152 
27153   // For vectors that are exactly getMaxSVEVectorSizeInBits big, we can use
27154   // AArch64SVEPredPattern::all, which can enable the use of unpredicated
27155   // variants of instructions when available.
27156   const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
27157   unsigned MinSVESize = Subtarget.getMinSVEVectorSizeInBits();
27158   unsigned MaxSVESize = Subtarget.getMaxSVEVectorSizeInBits();
27159   if (MaxSVESize && MinSVESize == MaxSVESize &&
27160       MaxSVESize == VT.getSizeInBits())
27161     PgPattern = AArch64SVEPredPattern::all;
27162 
27163   MVT MaskVT;
27164   switch (VT.getVectorElementType().getSimpleVT().SimpleTy) {
27165   default:
27166     llvm_unreachable("unexpected element type for SVE predicate");
27167   case MVT::i8:
27168     MaskVT = MVT::nxv16i1;
27169     break;
27170   case MVT::i16:
27171   case MVT::f16:
27172   case MVT::bf16:
27173     MaskVT = MVT::nxv8i1;
27174     break;
27175   case MVT::i32:
27176   case MVT::f32:
27177     MaskVT = MVT::nxv4i1;
27178     break;
27179   case MVT::i64:
27180   case MVT::f64:
27181     MaskVT = MVT::nxv2i1;
27182     break;
27183   }
27184 
27185   return getPTrue(DAG, DL, MaskVT, *PgPattern);
27186 }
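// Illustrative example: for a fixed v8i32 operand this produces a PTRUE of
// type nxv4i1 with the VL8 pattern; if the minimum and maximum SVE vector
// lengths are equal and match the fixed type's size, the 'all' pattern is
// used instead so unpredicated instruction forms can be selected later.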
27187 
27188 static SDValue getPredicateForScalableVector(SelectionDAG &DAG, SDLoc &DL,
27189                                              EVT VT) {
27190   assert(VT.isScalableVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
27191          "Expected legal scalable vector!");
27192   auto PredTy = VT.changeVectorElementType(MVT::i1);
27193   return getPTrue(DAG, DL, PredTy, AArch64SVEPredPattern::all);
27194 }
27195 
27196 static SDValue getPredicateForVector(SelectionDAG &DAG, SDLoc &DL, EVT VT) {
27197   if (VT.isFixedLengthVector())
27198     return getPredicateForFixedLengthVector(DAG, DL, VT);
27199 
27200   return getPredicateForScalableVector(DAG, DL, VT);
27201 }
27202 
27203 // Grow V to consume an entire SVE register.
27204 static SDValue convertToScalableVector(SelectionDAG &DAG, EVT VT, SDValue V) {
27205   assert(VT.isScalableVector() &&
27206          "Expected to convert into a scalable vector!");
27207   assert(V.getValueType().isFixedLengthVector() &&
27208          "Expected a fixed length vector operand!");
27209   SDLoc DL(V);
27210   SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
27211   return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT), V, Zero);
27212 }
27213 
27214 // Shrink V so it's just big enough to maintain a VT's worth of data.
27215 static SDValue convertFromScalableVector(SelectionDAG &DAG, EVT VT, SDValue V) {
27216   assert(VT.isFixedLengthVector() &&
27217          "Expected to convert into a fixed length vector!");
27218   assert(V.getValueType().isScalableVector() &&
27219          "Expected a scalable vector operand!");
27220   SDLoc DL(V);
27221   SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
27222   return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, V, Zero);
27223 }
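// Illustrative sketch of the usual fixed-length SVE bracket built from the
// two helpers above (for a v8i32 value on a wide enough SVE target):
//   t = INSERT_SUBVECTOR undef:nxv4i32, v8i32 x, 0   // grow
//   ... predicated SVE operation on t ...
//   r = EXTRACT_SUBVECTOR t, 0                        // shrink to v8i32
// Only the low v8i32 lanes carry meaningful data; the predicate limits the
// computation to those lanes.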
27224 
27225 // Convert all fixed length vector loads larger than NEON to masked_loads.
27226 SDValue AArch64TargetLowering::LowerFixedLengthVectorLoadToSVE(
27227     SDValue Op, SelectionDAG &DAG) const {
27228   auto Load = cast<LoadSDNode>(Op);
27229 
27230   SDLoc DL(Op);
27231   EVT VT = Op.getValueType();
27232   EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
27233   EVT LoadVT = ContainerVT;
27234   EVT MemVT = Load->getMemoryVT();
27235 
27236   auto Pg = getPredicateForFixedLengthVector(DAG, DL, VT);
27237 
27238   if (VT.isFloatingPoint()) {
27239     LoadVT = ContainerVT.changeTypeToInteger();
27240     MemVT = MemVT.changeTypeToInteger();
27241   }
27242 
27243   SDValue NewLoad = DAG.getMaskedLoad(
27244       LoadVT, DL, Load->getChain(), Load->getBasePtr(), Load->getOffset(), Pg,
27245       DAG.getUNDEF(LoadVT), MemVT, Load->getMemOperand(),
27246       Load->getAddressingMode(), Load->getExtensionType());
27247 
27248   SDValue Result = NewLoad;
27249   if (VT.isFloatingPoint() && Load->getExtensionType() == ISD::EXTLOAD) {
27250     EVT ExtendVT = ContainerVT.changeVectorElementType(
27251         Load->getMemoryVT().getVectorElementType());
27252 
27253     Result = getSVESafeBitCast(ExtendVT, Result, DAG);
27254     Result = DAG.getNode(AArch64ISD::FP_EXTEND_MERGE_PASSTHRU, DL, ContainerVT,
27255                          Pg, Result, DAG.getUNDEF(ContainerVT));
27256   } else if (VT.isFloatingPoint()) {
27257     Result = DAG.getNode(ISD::BITCAST, DL, ContainerVT, Result);
27258   }
27259 
27260   Result = convertFromScalableVector(DAG, VT, Result);
27261   SDValue MergedValues[2] = {Result, NewLoad.getValue(1)};
27262   return DAG.getMergeValues(MergedValues, DL);
27263 }
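// Illustrative sketch: a legal fixed-length load such as (v8i32 (load p))
// becomes roughly
//   pg  = ptrue(vl8)                 ; nxv4i1
//   t   = masked_load pg, p, undef   ; nxv4i32
//   res = extract_subvector t, 0     ; v8i32
// with the original chain preserved via getMergeValues. Floating-point
// loads take a detour through the matching integer type, as handled above.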
27264 
27265 static SDValue convertFixedMaskToScalableVector(SDValue Mask,
27266                                                 SelectionDAG &DAG) {
27267   SDLoc DL(Mask);
27268   EVT InVT = Mask.getValueType();
27269   EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
27270 
27271   auto Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
27272 
27273   if (ISD::isBuildVectorAllOnes(Mask.getNode()))
27274     return Pg;
27275 
27276   auto Op1 = convertToScalableVector(DAG, ContainerVT, Mask);
27277   auto Op2 = DAG.getConstant(0, DL, ContainerVT);
27278 
27279   return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, Pg.getValueType(),
27280                      {Pg, Op1, Op2, DAG.getCondCode(ISD::SETNE)});
27281 }
27282 
27283 // Convert fixed length vector masked loads larger than NEON to SVE masked loads.
27284 SDValue AArch64TargetLowering::LowerFixedLengthVectorMLoadToSVE(
27285     SDValue Op, SelectionDAG &DAG) const {
27286   auto Load = cast<MaskedLoadSDNode>(Op);
27287 
27288   SDLoc DL(Op);
27289   EVT VT = Op.getValueType();
27290   EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
27291 
27292   SDValue Mask = Load->getMask();
27293   // If this is an extending load and the mask type is not the same as
27294   // the load's type then we have to extend the mask type.
27295   if (VT.getScalarSizeInBits() > Mask.getValueType().getScalarSizeInBits()) {
27296     assert(Load->getExtensionType() != ISD::NON_EXTLOAD &&
27297            "Incorrect mask type");
27298     Mask = DAG.getNode(ISD::ANY_EXTEND, DL, VT, Mask);
27299   }
27300   Mask = convertFixedMaskToScalableVector(Mask, DAG);
27301 
27302   SDValue PassThru;
27303   bool IsPassThruZeroOrUndef = false;
27304 
27305   if (Load->getPassThru()->isUndef()) {
27306     PassThru = DAG.getUNDEF(ContainerVT);
27307     IsPassThruZeroOrUndef = true;
27308   } else {
27309     if (ContainerVT.isInteger())
27310       PassThru = DAG.getConstant(0, DL, ContainerVT);
27311     else
27312       PassThru = DAG.getConstantFP(0, DL, ContainerVT);
27313     if (isZerosVector(Load->getPassThru().getNode()))
27314       IsPassThruZeroOrUndef = true;
27315   }
27316 
27317   SDValue NewLoad = DAG.getMaskedLoad(
27318       ContainerVT, DL, Load->getChain(), Load->getBasePtr(), Load->getOffset(),
27319       Mask, PassThru, Load->getMemoryVT(), Load->getMemOperand(),
27320       Load->getAddressingMode(), Load->getExtensionType());
27321 
27322   SDValue Result = NewLoad;
27323   if (!IsPassThruZeroOrUndef) {
27324     SDValue OldPassThru =
27325         convertToScalableVector(DAG, ContainerVT, Load->getPassThru());
27326     Result = DAG.getSelect(DL, ContainerVT, Mask, Result, OldPassThru);
27327   }
27328 
27329   Result = convertFromScalableVector(DAG, VT, Result);
27330   SDValue MergedValues[2] = {Result, NewLoad.getValue(1)};
27331   return DAG.getMergeValues(MergedValues, DL);
27332 }
27333 
27334 // Convert all fixed length vector stores larger than NEON to masked_stores.
27335 SDValue AArch64TargetLowering::LowerFixedLengthVectorStoreToSVE(
27336     SDValue Op, SelectionDAG &DAG) const {
27337   auto Store = cast<StoreSDNode>(Op);
27338 
27339   SDLoc DL(Op);
27340   EVT VT = Store->getValue().getValueType();
27341   EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
27342   EVT MemVT = Store->getMemoryVT();
27343 
27344   auto Pg = getPredicateForFixedLengthVector(DAG, DL, VT);
27345   auto NewValue = convertToScalableVector(DAG, ContainerVT, Store->getValue());
27346 
27347   if (VT.isFloatingPoint() && Store->isTruncatingStore()) {
27348     EVT TruncVT = ContainerVT.changeVectorElementType(
27349         Store->getMemoryVT().getVectorElementType());
27350     MemVT = MemVT.changeTypeToInteger();
27351     NewValue = DAG.getNode(AArch64ISD::FP_ROUND_MERGE_PASSTHRU, DL, TruncVT, Pg,
27352                            NewValue, DAG.getTargetConstant(0, DL, MVT::i64),
27353                            DAG.getUNDEF(TruncVT));
27354     NewValue =
27355         getSVESafeBitCast(ContainerVT.changeTypeToInteger(), NewValue, DAG);
27356   } else if (VT.isFloatingPoint()) {
27357     MemVT = MemVT.changeTypeToInteger();
27358     NewValue =
27359         getSVESafeBitCast(ContainerVT.changeTypeToInteger(), NewValue, DAG);
27360   }
27361 
27362   return DAG.getMaskedStore(Store->getChain(), DL, NewValue,
27363                             Store->getBasePtr(), Store->getOffset(), Pg, MemVT,
27364                             Store->getMemOperand(), Store->getAddressingMode(),
27365                             Store->isTruncatingStore());
27366 }
27367 
27368 SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE(
27369     SDValue Op, SelectionDAG &DAG) const {
27370   auto *Store = cast<MaskedStoreSDNode>(Op);
27371 
27372   SDLoc DL(Op);
27373   EVT VT = Store->getValue().getValueType();
27374   EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
27375 
27376   auto NewValue = convertToScalableVector(DAG, ContainerVT, Store->getValue());
27377   SDValue Mask = convertFixedMaskToScalableVector(Store->getMask(), DAG);
27378 
27379   return DAG.getMaskedStore(
27380       Store->getChain(), DL, NewValue, Store->getBasePtr(), Store->getOffset(),
27381       Mask, Store->getMemoryVT(), Store->getMemOperand(),
27382       Store->getAddressingMode(), Store->isTruncatingStore());
27383 }
27384 
27385 SDValue AArch64TargetLowering::LowerFixedLengthVectorIntDivideToSVE(
27386     SDValue Op, SelectionDAG &DAG) const {
27387   SDLoc dl(Op);
27388   EVT VT = Op.getValueType();
27389   EVT EltVT = VT.getVectorElementType();
27390 
27391   bool Signed = Op.getOpcode() == ISD::SDIV;
27392   unsigned PredOpcode = Signed ? AArch64ISD::SDIV_PRED : AArch64ISD::UDIV_PRED;
27393 
27394   bool Negated;
27395   uint64_t SplatVal;
27396   if (Signed && isPow2Splat(Op.getOperand(1), SplatVal, Negated)) {
27397     EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
27398     SDValue Op1 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(0));
27399     SDValue Op2 = DAG.getTargetConstant(Log2_64(SplatVal), dl, MVT::i32);
27400 
27401     SDValue Pg = getPredicateForFixedLengthVector(DAG, dl, VT);
27402     SDValue Res =
27403         DAG.getNode(AArch64ISD::SRAD_MERGE_OP1, dl, ContainerVT, Pg, Op1, Op2);
27404     if (Negated)
27405       Res = DAG.getNode(ISD::SUB, dl, ContainerVT,
27406                         DAG.getConstant(0, dl, ContainerVT), Res);
27407 
27408     return convertFromScalableVector(DAG, VT, Res);
27409   }
27410 
27411   // Scalable vector i32/i64 DIV is supported.
27412   if (EltVT == MVT::i32 || EltVT == MVT::i64)
27413     return LowerToPredicatedOp(Op, DAG, PredOpcode);
27414 
27415   // Scalable vector i8/i16 DIV is not supported. Promote it to i32.
27416   EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
27417   EVT PromVT = HalfVT.widenIntegerVectorElementType(*DAG.getContext());
27418   unsigned ExtendOpcode = Signed ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
27419 
27420   // If the wider type is legal: extend, op, and truncate.
27421   EVT WideVT = VT.widenIntegerVectorElementType(*DAG.getContext());
27422   if (DAG.getTargetLoweringInfo().isTypeLegal(WideVT)) {
27423     SDValue Op0 = DAG.getNode(ExtendOpcode, dl, WideVT, Op.getOperand(0));
27424     SDValue Op1 = DAG.getNode(ExtendOpcode, dl, WideVT, Op.getOperand(1));
27425     SDValue Div = DAG.getNode(Op.getOpcode(), dl, WideVT, Op0, Op1);
27426     return DAG.getNode(ISD::TRUNCATE, dl, VT, Div);
27427   }
27428 
27429   auto HalveAndExtendVector = [&DAG, &dl, &HalfVT, &PromVT,
27430                                &ExtendOpcode](SDValue Op) {
27431     SDValue IdxZero = DAG.getConstant(0, dl, MVT::i64);
27432     SDValue IdxHalf =
27433         DAG.getConstant(HalfVT.getVectorNumElements(), dl, MVT::i64);
27434     SDValue Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, HalfVT, Op, IdxZero);
27435     SDValue Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, HalfVT, Op, IdxHalf);
27436     return std::pair<SDValue, SDValue>(
27437         {DAG.getNode(ExtendOpcode, dl, PromVT, Lo),
27438          DAG.getNode(ExtendOpcode, dl, PromVT, Hi)});
27439   };
27440 
27441   // If the wider type is not legal: split, extend, op, trunc and concat.
27442   auto [Op0LoExt, Op0HiExt] = HalveAndExtendVector(Op.getOperand(0));
27443   auto [Op1LoExt, Op1HiExt] = HalveAndExtendVector(Op.getOperand(1));
27444   SDValue Lo = DAG.getNode(Op.getOpcode(), dl, PromVT, Op0LoExt, Op1LoExt);
27445   SDValue Hi = DAG.getNode(Op.getOpcode(), dl, PromVT, Op0HiExt, Op1HiExt);
27446   SDValue LoTrunc = DAG.getNode(ISD::TRUNCATE, dl, HalfVT, Lo);
27447   SDValue HiTrunc = DAG.getNode(ISD::TRUNCATE, dl, HalfVT, Hi);
27448   return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, {LoTrunc, HiTrunc});
27449 }
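// Illustrative sketch of the final path above: for a v16i16 divide where
// v16i32 is not a legal fixed-length type, each operand is split into two
// v8i16 halves, extended to v8i32, divided, truncated back to v8i16, and the
// halves concatenated into the v16i16 result.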
27450 
27451 SDValue AArch64TargetLowering::LowerFixedLengthVectorIntExtendToSVE(
27452     SDValue Op, SelectionDAG &DAG) const {
27453   EVT VT = Op.getValueType();
27454   assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
27455 
27456   SDLoc DL(Op);
27457   SDValue Val = Op.getOperand(0);
27458   EVT ContainerVT = getContainerForFixedLengthVector(DAG, Val.getValueType());
27459   Val = convertToScalableVector(DAG, ContainerVT, Val);
27460 
27461   bool Signed = Op.getOpcode() == ISD::SIGN_EXTEND;
27462   unsigned ExtendOpc = Signed ? AArch64ISD::SUNPKLO : AArch64ISD::UUNPKLO;
27463 
27464   // Repeatedly unpack Val until the result is of the desired element type.
27465   switch (ContainerVT.getSimpleVT().SimpleTy) {
27466   default:
27467     llvm_unreachable("unimplemented container type");
27468   case MVT::nxv16i8:
27469     Val = DAG.getNode(ExtendOpc, DL, MVT::nxv8i16, Val);
27470     if (VT.getVectorElementType() == MVT::i16)
27471       break;
27472     [[fallthrough]];
27473   case MVT::nxv8i16:
27474     Val = DAG.getNode(ExtendOpc, DL, MVT::nxv4i32, Val);
27475     if (VT.getVectorElementType() == MVT::i32)
27476       break;
27477     [[fallthrough]];
27478   case MVT::nxv4i32:
27479     Val = DAG.getNode(ExtendOpc, DL, MVT::nxv2i64, Val);
27480     assert(VT.getVectorElementType() == MVT::i64 && "Unexpected element type!");
27481     break;
27482   }
27483 
27484   return convertFromScalableVector(DAG, VT, Val);
27485 }
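// Illustrative example: extending v8i8 to v8i32 starts from the nxv16i8
// container and applies the unpack twice (nxv16i8 -> nxv8i16 -> nxv4i32)
// before extracting the fixed-length v8i32 result.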
27486 
27487 SDValue AArch64TargetLowering::LowerFixedLengthVectorTruncateToSVE(
27488     SDValue Op, SelectionDAG &DAG) const {
27489   EVT VT = Op.getValueType();
27490   assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
27491 
27492   SDLoc DL(Op);
27493   SDValue Val = Op.getOperand(0);
27494   EVT ContainerVT = getContainerForFixedLengthVector(DAG, Val.getValueType());
27495   Val = convertToScalableVector(DAG, ContainerVT, Val);
27496 
27497   // Repeatedly truncate Val until the result is of the desired element type.
27498   switch (ContainerVT.getSimpleVT().SimpleTy) {
27499   default:
27500     llvm_unreachable("unimplemented container type");
27501   case MVT::nxv2i64:
27502     Val = DAG.getNode(ISD::BITCAST, DL, MVT::nxv4i32, Val);
27503     Val = DAG.getNode(AArch64ISD::UZP1, DL, MVT::nxv4i32, Val, Val);
27504     if (VT.getVectorElementType() == MVT::i32)
27505       break;
27506     [[fallthrough]];
27507   case MVT::nxv4i32:
27508     Val = DAG.getNode(ISD::BITCAST, DL, MVT::nxv8i16, Val);
27509     Val = DAG.getNode(AArch64ISD::UZP1, DL, MVT::nxv8i16, Val, Val);
27510     if (VT.getVectorElementType() == MVT::i16)
27511       break;
27512     [[fallthrough]];
27513   case MVT::nxv8i16:
27514     Val = DAG.getNode(ISD::BITCAST, DL, MVT::nxv16i8, Val);
27515     Val = DAG.getNode(AArch64ISD::UZP1, DL, MVT::nxv16i8, Val, Val);
27516     assert(VT.getVectorElementType() == MVT::i8 && "Unexpected element type!");
27517     break;
27518   }
27519 
27520   return convertFromScalableVector(DAG, VT, Val);
27521 }
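// Illustrative example of the inverse direction: truncating v4i64 to v4i8
// bitcasts and applies UZP1 three times (nxv2i64 -> nxv4i32 -> nxv8i16 ->
// nxv16i8), keeping the low half of each element at every step.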
27522 
27523 SDValue AArch64TargetLowering::LowerFixedLengthExtractVectorElt(
27524     SDValue Op, SelectionDAG &DAG) const {
27525   EVT VT = Op.getValueType();
27526   EVT InVT = Op.getOperand(0).getValueType();
27527   assert(InVT.isFixedLengthVector() && "Expected fixed length vector type!");
27528 
27529   SDLoc DL(Op);
27530   EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
27531   SDValue Op0 = convertToScalableVector(DAG, ContainerVT, Op->getOperand(0));
27532 
27533   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Op.getOperand(1));
27534 }
27535 
27536 SDValue AArch64TargetLowering::LowerFixedLengthInsertVectorElt(
27537     SDValue Op, SelectionDAG &DAG) const {
27538   EVT VT = Op.getValueType();
27539   assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
27540 
27541   SDLoc DL(Op);
27542   EVT InVT = Op.getOperand(0).getValueType();
27543   EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
27544   SDValue Op0 = convertToScalableVector(DAG, ContainerVT, Op->getOperand(0));
27545 
27546   auto ScalableRes = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ContainerVT, Op0,
27547                                  Op.getOperand(1), Op.getOperand(2));
27548 
27549   return convertFromScalableVector(DAG, VT, ScalableRes);
27550 }
27551 
27552 // Convert vector operation 'Op' to an equivalent predicated operation whereby
27553 // the original operation's type is used to construct a suitable predicate.
27554 // NOTE: The results for inactive lanes are undefined.
27555 SDValue AArch64TargetLowering::LowerToPredicatedOp(SDValue Op,
27556                                                    SelectionDAG &DAG,
27557                                                    unsigned NewOp) const {
27558   EVT VT = Op.getValueType();
27559   SDLoc DL(Op);
27560   auto Pg = getPredicateForVector(DAG, DL, VT);
27561 
27562   if (VT.isFixedLengthVector()) {
27563     assert(isTypeLegal(VT) && "Expected only legal fixed-width types");
27564     EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
27565 
27566     // Create list of operands by converting existing ones to scalable types.
27567     SmallVector<SDValue, 4> Operands = {Pg};
27568     for (const SDValue &V : Op->op_values()) {
27569       if (isa<CondCodeSDNode>(V)) {
27570         Operands.push_back(V);
27571         continue;
27572       }
27573 
27574       if (const VTSDNode *VTNode = dyn_cast<VTSDNode>(V)) {
27575         EVT VTArg = VTNode->getVT().getVectorElementType();
27576         EVT NewVTArg = ContainerVT.changeVectorElementType(VTArg);
27577         Operands.push_back(DAG.getValueType(NewVTArg));
27578         continue;
27579       }
27580 
27581       assert(isTypeLegal(V.getValueType()) &&
27582              "Expected only legal fixed-width types");
27583       Operands.push_back(convertToScalableVector(DAG, ContainerVT, V));
27584     }
27585 
27586     if (isMergePassthruOpcode(NewOp))
27587       Operands.push_back(DAG.getUNDEF(ContainerVT));
27588 
27589     auto ScalableRes = DAG.getNode(NewOp, DL, ContainerVT, Operands);
27590     return convertFromScalableVector(DAG, VT, ScalableRes);
27591   }
27592 
27593   assert(VT.isScalableVector() && "Only expect to lower scalable vector op!");
27594 
27595   SmallVector<SDValue, 4> Operands = {Pg};
27596   for (const SDValue &V : Op->op_values()) {
27597     assert((!V.getValueType().isVector() ||
27598             V.getValueType().isScalableVector()) &&
27599            "Only scalable vectors are supported!");
27600     Operands.push_back(V);
27601   }
27602 
27603   if (isMergePassthruOpcode(NewOp))
27604     Operands.push_back(DAG.getUNDEF(VT));
27605 
27606   return DAG.getNode(NewOp, DL, VT, Operands, Op->getFlags());
27607 }
27608 
27609 // If a fixed length vector operation has no side effects when applied to
27610 // undefined elements, we can safely use scalable vectors to perform the same
27611 // operation without needing to worry about predication.
27612 SDValue AArch64TargetLowering::LowerToScalableOp(SDValue Op,
27613                                                  SelectionDAG &DAG) const {
27614   EVT VT = Op.getValueType();
27615   assert(VT.isFixedLengthVector() && isTypeLegal(VT) &&
27616          "Only expected to lower fixed length vector operation!");
27617   EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
27618 
27619   // Create list of operands by converting existing ones to scalable types.
27620   SmallVector<SDValue, 4> Ops;
27621   for (const SDValue &V : Op->op_values()) {
27622     assert(!isa<VTSDNode>(V) && "Unexpected VTSDNode node!");
27623 
27624     // Pass through non-vector operands.
27625     if (!V.getValueType().isVector()) {
27626       Ops.push_back(V);
27627       continue;
27628     }
27629 
27630     // "cast" fixed length vector to a scalable vector.
27631     assert(V.getValueType().isFixedLengthVector() &&
27632            isTypeLegal(V.getValueType()) &&
27633            "Only fixed length vectors are supported!");
27634     Ops.push_back(convertToScalableVector(DAG, ContainerVT, V));
27635   }
27636 
27637   auto ScalableRes = DAG.getNode(Op.getOpcode(), SDLoc(Op), ContainerVT, Ops);
27638   return convertFromScalableVector(DAG, VT, ScalableRes);
27639 }
27640 
27641 SDValue AArch64TargetLowering::LowerVECREDUCE_SEQ_FADD(SDValue ScalarOp,
27642     SelectionDAG &DAG) const {
27643   SDLoc DL(ScalarOp);
27644   SDValue AccOp = ScalarOp.getOperand(0);
27645   SDValue VecOp = ScalarOp.getOperand(1);
27646   EVT SrcVT = VecOp.getValueType();
27647   EVT ResVT = SrcVT.getVectorElementType();
27648 
27649   EVT ContainerVT = SrcVT;
27650   if (SrcVT.isFixedLengthVector()) {
27651     ContainerVT = getContainerForFixedLengthVector(DAG, SrcVT);
27652     VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
27653   }
27654 
27655   SDValue Pg = getPredicateForVector(DAG, DL, SrcVT);
27656   SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
27657 
27658   // Convert operands to Scalable.
27659   AccOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ContainerVT,
27660                       DAG.getUNDEF(ContainerVT), AccOp, Zero);
27661 
27662   // Perform reduction.
27663   SDValue Rdx = DAG.getNode(AArch64ISD::FADDA_PRED, DL, ContainerVT,
27664                             Pg, AccOp, VecOp);
27665 
27666   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Rdx, Zero);
27667 }
27668 
27669 SDValue AArch64TargetLowering::LowerPredReductionToSVE(SDValue ReduceOp,
27670                                                        SelectionDAG &DAG) const {
27671   SDLoc DL(ReduceOp);
27672   SDValue Op = ReduceOp.getOperand(0);
27673   EVT OpVT = Op.getValueType();
27674   EVT VT = ReduceOp.getValueType();
27675 
27676   if (!OpVT.isScalableVector() || OpVT.getVectorElementType() != MVT::i1)
27677     return SDValue();
27678 
27679   SDValue Pg = getPredicateForVector(DAG, DL, OpVT);
27680 
27681   switch (ReduceOp.getOpcode()) {
27682   default:
27683     return SDValue();
27684   case ISD::VECREDUCE_OR:
27685     if (isAllActivePredicate(DAG, Pg) && OpVT == MVT::nxv16i1)
27686       // The predicate can be 'Op' because
27687       // vecreduce_or(Op & <all true>) <=> vecreduce_or(Op).
27688       return getPTest(DAG, VT, Op, Op, AArch64CC::ANY_ACTIVE);
27689     else
27690       return getPTest(DAG, VT, Pg, Op, AArch64CC::ANY_ACTIVE);
27691   case ISD::VECREDUCE_AND: {
27692     Op = DAG.getNode(ISD::XOR, DL, OpVT, Op, Pg);
27693     return getPTest(DAG, VT, Pg, Op, AArch64CC::NONE_ACTIVE);
27694   }
27695   case ISD::VECREDUCE_XOR: {
27696     SDValue ID =
27697         DAG.getTargetConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64);
27698     if (OpVT == MVT::nxv1i1) {
27699       // Emulate a CNTP on .Q using .D and a different governing predicate.
27700       Pg = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv2i1, Pg);
27701       Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv2i1, Op);
27702     }
27703     SDValue Cntp =
27704         DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64, ID, Pg, Op);
27705     return DAG.getAnyExtOrTrunc(Cntp, DL, VT);
27706   }
27707   }
27708 
27709   return SDValue();
27710 }
27711 
27712 SDValue AArch64TargetLowering::LowerReductionToSVE(unsigned Opcode,
27713                                                    SDValue ScalarOp,
27714                                                    SelectionDAG &DAG) const {
27715   SDLoc DL(ScalarOp);
27716   SDValue VecOp = ScalarOp.getOperand(0);
27717   EVT SrcVT = VecOp.getValueType();
27718 
27719   if (useSVEForFixedLengthVectorVT(
27720           SrcVT,
27721           /*OverrideNEON=*/Subtarget->useSVEForFixedLengthVectors())) {
27722     EVT ContainerVT = getContainerForFixedLengthVector(DAG, SrcVT);
27723     VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
27724   }
27725 
27726   // UADDV always returns an i64 result.
27727   EVT ResVT = (Opcode == AArch64ISD::UADDV_PRED) ? MVT::i64 :
27728                                                    SrcVT.getVectorElementType();
27729   EVT RdxVT = SrcVT;
27730   if (SrcVT.isFixedLengthVector() || Opcode == AArch64ISD::UADDV_PRED)
27731     RdxVT = getPackedSVEVectorVT(ResVT);
27732 
27733   SDValue Pg = getPredicateForVector(DAG, DL, SrcVT);
27734   SDValue Rdx = DAG.getNode(Opcode, DL, RdxVT, Pg, VecOp);
27735   SDValue Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT,
27736                             Rdx, DAG.getConstant(0, DL, MVT::i64));
27737 
27738   // The VEC_REDUCE nodes expect an element size result.
27739   if (ResVT != ScalarOp.getValueType())
27740     Res = DAG.getAnyExtOrTrunc(Res, DL, ScalarOp.getValueType());
27741 
27742   return Res;
27743 }
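// Illustrative example: an i8 vecreduce_add over v16i8 is lowered with
// UADDV_PRED, which accumulates into a 64-bit value, so the extract above
// reads element 0 of an nxv2i64 and the final any-extend/truncate restores
// the i8 scalar type the VECREDUCE node expects.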
27744 
27745 SDValue
27746 AArch64TargetLowering::LowerFixedLengthVectorSelectToSVE(SDValue Op,
27747     SelectionDAG &DAG) const {
27748   EVT VT = Op.getValueType();
27749   SDLoc DL(Op);
27750 
27751   EVT InVT = Op.getOperand(1).getValueType();
27752   EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
27753   SDValue Op1 = convertToScalableVector(DAG, ContainerVT, Op->getOperand(1));
27754   SDValue Op2 = convertToScalableVector(DAG, ContainerVT, Op->getOperand(2));
27755 
27756   // Convert the mask to a predicate (NOTE: We don't need to worry about
27757   // inactive lanes since VSELECT is safe when given undefined elements).
27758   EVT MaskVT = Op.getOperand(0).getValueType();
27759   EVT MaskContainerVT = getContainerForFixedLengthVector(DAG, MaskVT);
27760   auto Mask = convertToScalableVector(DAG, MaskContainerVT, Op.getOperand(0));
27761   Mask = DAG.getNode(ISD::TRUNCATE, DL,
27762                      MaskContainerVT.changeVectorElementType(MVT::i1), Mask);
27763 
27764   auto ScalableRes = DAG.getNode(ISD::VSELECT, DL, ContainerVT,
27765                                 Mask, Op1, Op2);
27766 
27767   return convertFromScalableVector(DAG, VT, ScalableRes);
27768 }
27769 
27770 SDValue AArch64TargetLowering::LowerFixedLengthVectorSetccToSVE(
27771     SDValue Op, SelectionDAG &DAG) const {
27772   SDLoc DL(Op);
27773   EVT InVT = Op.getOperand(0).getValueType();
27774   EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
27775 
27776   assert(InVT.isFixedLengthVector() && isTypeLegal(InVT) &&
27777          "Only expected to lower fixed length vector operation!");
27778   assert(Op.getValueType() == InVT.changeTypeToInteger() &&
27779          "Expected integer result of the same bit length as the inputs!");
27780 
27781   auto Op1 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(0));
27782   auto Op2 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(1));
27783   auto Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
27784 
27785   EVT CmpVT = Pg.getValueType();
27786   auto Cmp = DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, CmpVT,
27787                          {Pg, Op1, Op2, Op.getOperand(2)});
27788 
27789   EVT PromoteVT = ContainerVT.changeTypeToInteger();
27790   auto Promote = DAG.getBoolExtOrTrunc(Cmp, DL, PromoteVT, InVT);
27791   return convertFromScalableVector(DAG, Op.getValueType(), Promote);
27792 }
27793 
27794 SDValue
27795 AArch64TargetLowering::LowerFixedLengthBitcastToSVE(SDValue Op,
27796                                                     SelectionDAG &DAG) const {
27797   SDLoc DL(Op);
27798   auto SrcOp = Op.getOperand(0);
27799   EVT VT = Op.getValueType();
27800   EVT ContainerDstVT = getContainerForFixedLengthVector(DAG, VT);
27801   EVT ContainerSrcVT =
27802       getContainerForFixedLengthVector(DAG, SrcOp.getValueType());
27803 
27804   SrcOp = convertToScalableVector(DAG, ContainerSrcVT, SrcOp);
27805   Op = DAG.getNode(ISD::BITCAST, DL, ContainerDstVT, SrcOp);
27806   return convertFromScalableVector(DAG, VT, Op);
27807 }
27808 
27809 SDValue AArch64TargetLowering::LowerFixedLengthConcatVectorsToSVE(
27810     SDValue Op, SelectionDAG &DAG) const {
27811   SDLoc DL(Op);
27812   unsigned NumOperands = Op->getNumOperands();
27813 
27814   assert(NumOperands > 1 && isPowerOf2_32(NumOperands) &&
27815          "Unexpected number of operands in CONCAT_VECTORS");
27816 
27817   auto SrcOp1 = Op.getOperand(0);
27818   auto SrcOp2 = Op.getOperand(1);
27819   EVT VT = Op.getValueType();
27820   EVT SrcVT = SrcOp1.getValueType();
27821 
27822   if (NumOperands > 2) {
27823     SmallVector<SDValue, 4> Ops;
27824     EVT PairVT = SrcVT.getDoubleNumVectorElementsVT(*DAG.getContext());
27825     for (unsigned I = 0; I < NumOperands; I += 2)
27826       Ops.push_back(DAG.getNode(ISD::CONCAT_VECTORS, DL, PairVT,
27827                                 Op->getOperand(I), Op->getOperand(I + 1)));
27828 
27829     return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Ops);
27830   }
27831 
27832   EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
27833 
27834   SDValue Pg = getPredicateForFixedLengthVector(DAG, DL, SrcVT);
27835   SrcOp1 = convertToScalableVector(DAG, ContainerVT, SrcOp1);
27836   SrcOp2 = convertToScalableVector(DAG, ContainerVT, SrcOp2);
27837 
27838   Op = DAG.getNode(AArch64ISD::SPLICE, DL, ContainerVT, Pg, SrcOp1, SrcOp2);
27839 
27840   return convertFromScalableVector(DAG, VT, Op);
27841 }
27842 
27843 SDValue
27844 AArch64TargetLowering::LowerFixedLengthFPExtendToSVE(SDValue Op,
27845                                                      SelectionDAG &DAG) const {
27846   EVT VT = Op.getValueType();
27847   assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
27848 
27849   SDLoc DL(Op);
27850   SDValue Val = Op.getOperand(0);
27851   SDValue Pg = getPredicateForVector(DAG, DL, VT);
27852   EVT SrcVT = Val.getValueType();
27853   EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
27854   EVT ExtendVT = ContainerVT.changeVectorElementType(
27855       SrcVT.getVectorElementType());
27856 
27857   Val = DAG.getNode(ISD::BITCAST, DL, SrcVT.changeTypeToInteger(), Val);
27858   Val = DAG.getNode(ISD::ANY_EXTEND, DL, VT.changeTypeToInteger(), Val);
27859 
27860   Val = convertToScalableVector(DAG, ContainerVT.changeTypeToInteger(), Val);
27861   Val = getSVESafeBitCast(ExtendVT, Val, DAG);
27862   Val = DAG.getNode(AArch64ISD::FP_EXTEND_MERGE_PASSTHRU, DL, ContainerVT,
27863                     Pg, Val, DAG.getUNDEF(ContainerVT));
27864 
27865   return convertFromScalableVector(DAG, VT, Val);
27866 }
27867 
27868 SDValue
27869 AArch64TargetLowering::LowerFixedLengthFPRoundToSVE(SDValue Op,
27870                                                     SelectionDAG &DAG) const {
27871   EVT VT = Op.getValueType();
27872   assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
27873 
27874   SDLoc DL(Op);
27875   SDValue Val = Op.getOperand(0);
27876   EVT SrcVT = Val.getValueType();
27877   EVT ContainerSrcVT = getContainerForFixedLengthVector(DAG, SrcVT);
27878   EVT RoundVT = ContainerSrcVT.changeVectorElementType(
27879       VT.getVectorElementType());
27880   SDValue Pg = getPredicateForVector(DAG, DL, RoundVT);
27881 
27882   Val = convertToScalableVector(DAG, ContainerSrcVT, Val);
27883   Val = DAG.getNode(AArch64ISD::FP_ROUND_MERGE_PASSTHRU, DL, RoundVT, Pg, Val,
27884                     Op.getOperand(1), DAG.getUNDEF(RoundVT));
27885   Val = getSVESafeBitCast(ContainerSrcVT.changeTypeToInteger(), Val, DAG);
27886   Val = convertFromScalableVector(DAG, SrcVT.changeTypeToInteger(), Val);
27887 
27888   Val = DAG.getNode(ISD::TRUNCATE, DL, VT.changeTypeToInteger(), Val);
27889   return DAG.getNode(ISD::BITCAST, DL, VT, Val);
27890 }
27891 
27892 SDValue
27893 AArch64TargetLowering::LowerFixedLengthIntToFPToSVE(SDValue Op,
27894                                                     SelectionDAG &DAG) const {
27895   EVT VT = Op.getValueType();
27896   assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
27897 
27898   bool IsSigned = Op.getOpcode() == ISD::SINT_TO_FP;
27899   unsigned Opcode = IsSigned ? AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU
27900                              : AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU;
27901 
27902   SDLoc DL(Op);
27903   SDValue Val = Op.getOperand(0);
27904   EVT SrcVT = Val.getValueType();
27905   EVT ContainerDstVT = getContainerForFixedLengthVector(DAG, VT);
27906   EVT ContainerSrcVT = getContainerForFixedLengthVector(DAG, SrcVT);
27907 
27908   if (VT.bitsGE(SrcVT)) {
27909     SDValue Pg = getPredicateForFixedLengthVector(DAG, DL, VT);
27910 
27911     Val = DAG.getNode(IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL,
27912                       VT.changeTypeToInteger(), Val);
27913 
27914     // Safe to use a larger than specified operand because promoting the
27915     // value changes nothing from an arithmetic point of view.
27916     Val =
27917         convertToScalableVector(DAG, ContainerDstVT.changeTypeToInteger(), Val);
27918     Val = DAG.getNode(Opcode, DL, ContainerDstVT, Pg, Val,
27919                       DAG.getUNDEF(ContainerDstVT));
27920     return convertFromScalableVector(DAG, VT, Val);
27921   } else {
27922     EVT CvtVT = ContainerSrcVT.changeVectorElementType(
27923         ContainerDstVT.getVectorElementType());
27924     SDValue Pg = getPredicateForFixedLengthVector(DAG, DL, SrcVT);
27925 
27926     Val = convertToScalableVector(DAG, ContainerSrcVT, Val);
27927     Val = DAG.getNode(Opcode, DL, CvtVT, Pg, Val, DAG.getUNDEF(CvtVT));
27928     Val = getSVESafeBitCast(ContainerSrcVT, Val, DAG);
27929     Val = convertFromScalableVector(DAG, SrcVT, Val);
27930 
27931     Val = DAG.getNode(ISD::TRUNCATE, DL, VT.changeTypeToInteger(), Val);
27932     return DAG.getNode(ISD::BITCAST, DL, VT, Val);
27933   }
27934 }
27935 
27936 SDValue
27937 AArch64TargetLowering::LowerVECTOR_DEINTERLEAVE(SDValue Op,
27938                                                 SelectionDAG &DAG) const {
27939   SDLoc DL(Op);
27940   EVT OpVT = Op.getValueType();
27941   assert(OpVT.isScalableVector() &&
27942          "Expected scalable vector in LowerVECTOR_DEINTERLEAVE.");
27943   SDValue Even = DAG.getNode(AArch64ISD::UZP1, DL, OpVT, Op.getOperand(0),
27944                              Op.getOperand(1));
27945   SDValue Odd = DAG.getNode(AArch64ISD::UZP2, DL, OpVT, Op.getOperand(0),
27946                             Op.getOperand(1));
27947   return DAG.getMergeValues({Even, Odd}, DL);
27948 }
27949 
27950 SDValue AArch64TargetLowering::LowerVECTOR_INTERLEAVE(SDValue Op,
27951                                                       SelectionDAG &DAG) const {
27952   SDLoc DL(Op);
27953   EVT OpVT = Op.getValueType();
27954   assert(OpVT.isScalableVector() &&
27955          "Expected scalable vector in LowerVECTOR_INTERLEAVE.");
27956 
27957   SDValue Lo = DAG.getNode(AArch64ISD::ZIP1, DL, OpVT, Op.getOperand(0),
27958                            Op.getOperand(1));
27959   SDValue Hi = DAG.getNode(AArch64ISD::ZIP2, DL, OpVT, Op.getOperand(0),
27960                            Op.getOperand(1));
27961   return DAG.getMergeValues({Lo, Hi}, DL);
27962 }
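// Illustrative summary: deinterleaving {a0,b0,a1,b1,...} uses UZP1 for the
// even lanes and UZP2 for the odd lanes, while interleaving is the inverse,
// with ZIP1/ZIP2 producing the low and high halves of the interleaved pair.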
27963 
27964 SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
27965                                                      SelectionDAG &DAG) const {
27966   // FIXME: Maybe share some code with LowerMGather/Scatter?
27967   MaskedHistogramSDNode *HG = cast<MaskedHistogramSDNode>(Op);
27968   SDLoc DL(HG);
27969   SDValue Chain = HG->getChain();
27970   SDValue Inc = HG->getInc();
27971   SDValue Mask = HG->getMask();
27972   SDValue Ptr = HG->getBasePtr();
27973   SDValue Index = HG->getIndex();
27974   SDValue Scale = HG->getScale();
27975   SDValue IntID = HG->getIntID();
27976 
27977   // The Intrinsic ID determines the type of update operation.
27978   [[maybe_unused]] ConstantSDNode *CID = cast<ConstantSDNode>(IntID.getNode());
27979   // Right now, we only support 'add' as an update.
27980   assert(CID->getZExtValue() == Intrinsic::experimental_vector_histogram_add &&
27981          "Unexpected histogram update operation");
27982 
27983   EVT IncVT = Inc.getValueType();
27984   EVT IndexVT = Index.getValueType();
27985   EVT MemVT = EVT::getVectorVT(*DAG.getContext(), IncVT,
27986                                IndexVT.getVectorElementCount());
27987   SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
27988   SDValue PassThru = DAG.getSplatVector(MemVT, DL, Zero);
27989   SDValue IncSplat = DAG.getSplatVector(MemVT, DL, Inc);
27990   SDValue Ops[] = {Chain, PassThru, Mask, Ptr, Index, Scale};
27991 
27992   MachineMemOperand *MMO = HG->getMemOperand();
27993   // Create an MMO for the gather, without load|store flags.
27994   MachineMemOperand *GMMO = DAG.getMachineFunction().getMachineMemOperand(
27995       MMO->getPointerInfo(), MachineMemOperand::MOLoad, MMO->getSize(),
27996       MMO->getAlign(), MMO->getAAInfo());
27997   ISD::MemIndexType IndexType = HG->getIndexType();
27998   SDValue Gather =
27999       DAG.getMaskedGather(DAG.getVTList(MemVT, MVT::Other), MemVT, DL, Ops,
28000                           GMMO, IndexType, ISD::NON_EXTLOAD);
28001 
28002   SDValue GChain = Gather.getValue(1);
28003 
28004   // Perform the histcnt, multiply by inc, add to bucket data.
28005   SDValue ID = DAG.getTargetConstant(Intrinsic::aarch64_sve_histcnt, DL, IncVT);
28006   SDValue HistCnt =
28007       DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, IndexVT, ID, Mask, Index, Index);
28008   SDValue Mul = DAG.getNode(ISD::MUL, DL, MemVT, HistCnt, IncSplat);
28009   SDValue Add = DAG.getNode(ISD::ADD, DL, MemVT, Gather, Mul);
28010 
28011   // Create an MMO for the scatter, without load|store flags.
28012   MachineMemOperand *SMMO = DAG.getMachineFunction().getMachineMemOperand(
28013       MMO->getPointerInfo(), MachineMemOperand::MOStore, MMO->getSize(),
28014       MMO->getAlign(), MMO->getAAInfo());
28015 
28016   SDValue ScatterOps[] = {GChain, Add, Mask, Ptr, Index, Scale};
28017   SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MemVT, DL,
28018                                          ScatterOps, SMMO, IndexType, false);
28019   return Scatter;
28020 }
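// Illustrative summary of the sequence built above (roughly):
//   buckets = masked_gather(ptr, index)         ; current bucket values
//   counts  = sve.histcnt(mask, index, index)   ; per-lane duplicate counts
//   updated = buckets + counts * inc
//   masked_scatter(updated, ptr, index)
// so that after the scatter each referenced bucket has been incremented by
// 'inc' once per active lane that referenced it.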
28021 
28022 SDValue
28023 AArch64TargetLowering::LowerFixedLengthFPToIntToSVE(SDValue Op,
28024                                                     SelectionDAG &DAG) const {
28025   EVT VT = Op.getValueType();
28026   assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
28027 
28028   bool IsSigned = Op.getOpcode() == ISD::FP_TO_SINT;
28029   unsigned Opcode = IsSigned ? AArch64ISD::FCVTZS_MERGE_PASSTHRU
28030                              : AArch64ISD::FCVTZU_MERGE_PASSTHRU;
28031 
28032   SDLoc DL(Op);
28033   SDValue Val = Op.getOperand(0);
28034   EVT SrcVT = Val.getValueType();
28035   EVT ContainerDstVT = getContainerForFixedLengthVector(DAG, VT);
28036   EVT ContainerSrcVT = getContainerForFixedLengthVector(DAG, SrcVT);
28037 
28038   if (VT.bitsGT(SrcVT)) {
28039     EVT CvtVT = ContainerDstVT.changeVectorElementType(
28040       ContainerSrcVT.getVectorElementType());
28041     SDValue Pg = getPredicateForFixedLengthVector(DAG, DL, VT);
28042 
28043     Val = DAG.getNode(ISD::BITCAST, DL, SrcVT.changeTypeToInteger(), Val);
28044     Val = DAG.getNode(ISD::ANY_EXTEND, DL, VT, Val);
28045 
28046     Val = convertToScalableVector(DAG, ContainerDstVT, Val);
28047     Val = getSVESafeBitCast(CvtVT, Val, DAG);
28048     Val = DAG.getNode(Opcode, DL, ContainerDstVT, Pg, Val,
28049                       DAG.getUNDEF(ContainerDstVT));
28050     return convertFromScalableVector(DAG, VT, Val);
28051   } else {
28052     EVT CvtVT = ContainerSrcVT.changeTypeToInteger();
28053     SDValue Pg = getPredicateForFixedLengthVector(DAG, DL, SrcVT);
28054 
28055     // Safe to use a larger than specified result since an fp_to_int where the
28056     // result doesn't fit into the destination is undefined.
28057     Val = convertToScalableVector(DAG, ContainerSrcVT, Val);
28058     Val = DAG.getNode(Opcode, DL, CvtVT, Pg, Val, DAG.getUNDEF(CvtVT));
28059     Val = convertFromScalableVector(DAG, SrcVT.changeTypeToInteger(), Val);
28060 
28061     return DAG.getNode(ISD::TRUNCATE, DL, VT, Val);
28062   }
28063 }
28064 
28065 static SDValue GenerateFixedLengthSVETBL(SDValue Op, SDValue Op1, SDValue Op2,
28066                                          ArrayRef<int> ShuffleMask, EVT VT,
28067                                          EVT ContainerVT, SelectionDAG &DAG) {
28068   auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
28069   SDLoc DL(Op);
28070   unsigned MinSVESize = Subtarget.getMinSVEVectorSizeInBits();
28071   unsigned MaxSVESize = Subtarget.getMaxSVEVectorSizeInBits();
28072   bool IsSingleOp =
28073       ShuffleVectorInst::isSingleSourceMask(ShuffleMask, ShuffleMask.size());
28074 
28075   if (!Subtarget.isNeonAvailable() && !MinSVESize)
28076     MinSVESize = 128;
28077 
28078   // Give up on the two-operand case if SVE2 is unavailable or if not all
28079   // index values can be represented.
28080   if (!IsSingleOp && !Subtarget.hasSVE2())
28081     return SDValue();
28082 
28083   EVT VTOp1 = Op.getOperand(0).getValueType();
28084   unsigned BitsPerElt = VTOp1.getVectorElementType().getSizeInBits();
28085   unsigned IndexLen = MinSVESize / BitsPerElt;
28086   unsigned ElementsPerVectorReg = VTOp1.getVectorNumElements();
28087   uint64_t MaxOffset = APInt(BitsPerElt, -1, false).getZExtValue();
28088   EVT MaskEltType = VTOp1.getVectorElementType().changeTypeToInteger();
28089   EVT MaskType = EVT::getVectorVT(*DAG.getContext(), MaskEltType, IndexLen);
28090   bool MinMaxEqual = (MinSVESize == MaxSVESize);
28091   assert(ElementsPerVectorReg <= IndexLen && ShuffleMask.size() <= IndexLen &&
28092          "Incorrectly legalised shuffle operation");
28093 
28094   SmallVector<SDValue, 8> TBLMask;
28095   // If MinSVESize is not equal to MaxSVESize then we need to know which
28096   // TBL mask element needs adjustment.
28097   SmallVector<SDValue, 8> AddRuntimeVLMask;
28098 
28099   // Bail out for 8-bit element types, because with a 2048-bit SVE register
28100   // size 8 bits is only sufficient to index into the first source vector.
28101   if (!IsSingleOp && !MinMaxEqual && BitsPerElt == 8)
28102     return SDValue();
28103 
28104   for (int Index : ShuffleMask) {
28105     // Handle poison index values.
28106     if (Index < 0)
28107       Index = 0;
28108     // If the mask refers to elements in the second operand, then we have to
28109     // offset the index by the number of elements in a vector. If this number
28110     // is not known at compile-time, we need to maintain a mask with 'VL' values
28111     // to add at runtime.
28112     if ((unsigned)Index >= ElementsPerVectorReg) {
28113       if (MinMaxEqual) {
28114         Index += IndexLen - ElementsPerVectorReg;
28115       } else {
28116         Index = Index - ElementsPerVectorReg;
28117         AddRuntimeVLMask.push_back(DAG.getConstant(1, DL, MVT::i64));
28118       }
28119     } else if (!MinMaxEqual)
28120       AddRuntimeVLMask.push_back(DAG.getConstant(0, DL, MVT::i64));
28121     // For 8-bit elements and 1024-bit SVE registers, MaxOffset equals 255 and
28122     // might point to the last element in the second operand of the
28123     // shufflevector, so we reject this transform.
28124     if ((unsigned)Index >= MaxOffset)
28125       return SDValue();
28126     TBLMask.push_back(DAG.getConstant(Index, DL, MVT::i64));
28127   }
28128 
28129   // Pad the remaining mask lanes with out-of-range indices so those lanes
28130   // are zeroed; using an index of zero instead would duplicate the first
28131   // lane into them. For i8 elements an out-of-range index can still be a
28132   // valid index with a 2048-bit vector register size.
28133   for (unsigned i = 0; i < IndexLen - ElementsPerVectorReg; ++i) {
28134     TBLMask.push_back(DAG.getConstant((int)MaxOffset, DL, MVT::i64));
28135     if (!MinMaxEqual)
28136       AddRuntimeVLMask.push_back(DAG.getConstant(0, DL, MVT::i64));
28137   }
28138 
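  // For example, with MinSVESize == MaxSVESize == 256 and v4i32 operands,
  // IndexLen is 8 and ElementsPerVectorReg is 4, so a shuffle mask of
  // <0, 5, 2, 7> becomes the TBL mask <0, 9, 2, 11> padded with four
  // out-of-range entries that leave the trailing container lanes zeroed.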
28139   EVT MaskContainerVT = getContainerForFixedLengthVector(DAG, MaskType);
28140   SDValue VecMask =
28141       DAG.getBuildVector(MaskType, DL, ArrayRef(TBLMask.data(), IndexLen));
28142   SDValue SVEMask = convertToScalableVector(DAG, MaskContainerVT, VecMask);
28143 
28144   SDValue Shuffle;
28145   if (IsSingleOp)
28146     Shuffle =
28147         DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT,
28148                     DAG.getConstant(Intrinsic::aarch64_sve_tbl, DL, MVT::i32),
28149                     Op1, SVEMask);
28150   else if (Subtarget.hasSVE2()) {
28151     if (!MinMaxEqual) {
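      // AddRuntimeVLMask holds 1 for every index that referred to the second
      // operand; multiplying it by the runtime number of elements per vector
      // and adding the product rebases those indices onto the second operand's
      // lanes when the exact register size is only known at runtime.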
28152       unsigned MinNumElts = AArch64::SVEBitsPerBlock / BitsPerElt;
28153       SDValue VScale = (BitsPerElt == 64)
28154                            ? DAG.getVScale(DL, MVT::i64, APInt(64, MinNumElts))
28155                            : DAG.getVScale(DL, MVT::i32, APInt(32, MinNumElts));
28156       SDValue VecMask =
28157           DAG.getBuildVector(MaskType, DL, ArrayRef(TBLMask.data(), IndexLen));
28158       SDValue MulByMask = DAG.getNode(
28159           ISD::MUL, DL, MaskType,
28160           DAG.getNode(ISD::SPLAT_VECTOR, DL, MaskType, VScale),
28161           DAG.getBuildVector(MaskType, DL,
28162                              ArrayRef(AddRuntimeVLMask.data(), IndexLen)));
28163       SDValue UpdatedVecMask =
28164           DAG.getNode(ISD::ADD, DL, MaskType, VecMask, MulByMask);
28165       SVEMask = convertToScalableVector(
28166           DAG, getContainerForFixedLengthVector(DAG, MaskType), UpdatedVecMask);
28167     }
28168     Shuffle =
28169         DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT,
28170                     DAG.getConstant(Intrinsic::aarch64_sve_tbl2, DL, MVT::i32),
28171                     Op1, Op2, SVEMask);
28172   }
28173   Shuffle = convertFromScalableVector(DAG, VT, Shuffle);
28174   return DAG.getNode(ISD::BITCAST, DL, Op.getValueType(), Shuffle);
28175 }
28176 
28177 SDValue AArch64TargetLowering::LowerFixedLengthVECTOR_SHUFFLEToSVE(
28178     SDValue Op, SelectionDAG &DAG) const {
28179   EVT VT = Op.getValueType();
28180   assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
28181 
28182   auto *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
28183   auto ShuffleMask = SVN->getMask();
28184 
28185   SDLoc DL(Op);
28186   SDValue Op1 = Op.getOperand(0);
28187   SDValue Op2 = Op.getOperand(1);
28188 
28189   EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
28190   Op1 = convertToScalableVector(DAG, ContainerVT, Op1);
28191   Op2 = convertToScalableVector(DAG, ContainerVT, Op2);
28192 
28193   auto MinLegalExtractEltScalarTy = [](EVT ScalarTy) -> EVT {
28194     if (ScalarTy == MVT::i8 || ScalarTy == MVT::i16)
28195       return MVT::i32;
28196     return ScalarTy;
28197   };
28198 
28199   if (SVN->isSplat()) {
28200     unsigned Lane = std::max(0, SVN->getSplatIndex());
28201     EVT ScalarTy = MinLegalExtractEltScalarTy(VT.getVectorElementType());
28202     SDValue SplatEl = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarTy, Op1,
28203                                   DAG.getConstant(Lane, DL, MVT::i64));
28204     Op = DAG.getNode(ISD::SPLAT_VECTOR, DL, ContainerVT, SplatEl);
28205     return convertFromScalableVector(DAG, VT, Op);
28206   }
28207 
28208   bool ReverseEXT = false;
28209   unsigned Imm;
28210   if (isEXTMask(ShuffleMask, VT, ReverseEXT, Imm) &&
28211       Imm == VT.getVectorNumElements() - 1) {
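    // An EXT with an immediate of NumElements - 1 produces the last element of
    // the first source followed by the leading elements of the second, which
    // is exactly what INSR gives us when passed that last element as the
    // scalar operand.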
28212     if (ReverseEXT)
28213       std::swap(Op1, Op2);
28214     EVT ScalarTy = MinLegalExtractEltScalarTy(VT.getVectorElementType());
28215     SDValue Scalar = DAG.getNode(
28216         ISD::EXTRACT_VECTOR_ELT, DL, ScalarTy, Op1,
28217         DAG.getConstant(VT.getVectorNumElements() - 1, DL, MVT::i64));
28218     Op = DAG.getNode(AArch64ISD::INSR, DL, ContainerVT, Op2, Scalar);
28219     return convertFromScalableVector(DAG, VT, Op);
28220   }
28221 
28222   unsigned EltSize = VT.getScalarSizeInBits();
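  // REV-style masks reverse the EltSize-bit elements within each LaneSize-bit
  // chunk, so they are lowered by reinterpreting the input as a vector of
  // LaneSize-bit elements and applying the matching predicated byte/half/word
  // reversal.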
28223   for (unsigned LaneSize : {64U, 32U, 16U}) {
28224     if (isREVMask(ShuffleMask, EltSize, VT.getVectorNumElements(), LaneSize)) {
28225       EVT NewVT =
28226           getPackedSVEVectorVT(EVT::getIntegerVT(*DAG.getContext(), LaneSize));
28227       unsigned RevOp;
28228       if (EltSize == 8)
28229         RevOp = AArch64ISD::BSWAP_MERGE_PASSTHRU;
28230       else if (EltSize == 16)
28231         RevOp = AArch64ISD::REVH_MERGE_PASSTHRU;
28232       else
28233         RevOp = AArch64ISD::REVW_MERGE_PASSTHRU;
28234 
28235       Op = DAG.getNode(ISD::BITCAST, DL, NewVT, Op1);
28236       Op = LowerToPredicatedOp(Op, DAG, RevOp);
28237       Op = DAG.getNode(ISD::BITCAST, DL, ContainerVT, Op);
28238       return convertFromScalableVector(DAG, VT, Op);
28239     }
28240   }
28241 
28242   if (Subtarget->hasSVE2p1() && EltSize == 64 &&
28243       isREVMask(ShuffleMask, EltSize, VT.getVectorNumElements(), 128)) {
28244     if (!VT.isFloatingPoint())
28245       return LowerToPredicatedOp(Op, DAG, AArch64ISD::REVD_MERGE_PASSTHRU);
28246 
28247     EVT NewVT = getPackedSVEVectorVT(EVT::getIntegerVT(*DAG.getContext(), 64));
28248     Op = DAG.getNode(ISD::BITCAST, DL, NewVT, Op1);
28249     Op = LowerToPredicatedOp(Op, DAG, AArch64ISD::REVD_MERGE_PASSTHRU);
28250     Op = DAG.getNode(ISD::BITCAST, DL, ContainerVT, Op);
28251     return convertFromScalableVector(DAG, VT, Op);
28252   }
28253 
28254   unsigned WhichResult;
28255   if (isZIPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult) &&
28256       WhichResult == 0)
28257     return convertFromScalableVector(
28258         DAG, VT, DAG.getNode(AArch64ISD::ZIP1, DL, ContainerVT, Op1, Op2));
28259 
28260   if (isTRNMask(ShuffleMask, VT.getVectorNumElements(), WhichResult)) {
28261     unsigned Opc = (WhichResult == 0) ? AArch64ISD::TRN1 : AArch64ISD::TRN2;
28262     return convertFromScalableVector(
28263         DAG, VT, DAG.getNode(Opc, DL, ContainerVT, Op1, Op2));
28264   }
28265 
28266   if (isZIP_v_undef_Mask(ShuffleMask, VT, WhichResult) && WhichResult == 0)
28267     return convertFromScalableVector(
28268         DAG, VT, DAG.getNode(AArch64ISD::ZIP1, DL, ContainerVT, Op1, Op1));
28269 
28270   if (isTRN_v_undef_Mask(ShuffleMask, VT, WhichResult)) {
28271     unsigned Opc = (WhichResult == 0) ? AArch64ISD::TRN1 : AArch64ISD::TRN2;
28272     return convertFromScalableVector(
28273         DAG, VT, DAG.getNode(Opc, DL, ContainerVT, Op1, Op1));
28274   }
28275 
28276   // Functions like isZIPMask return true when an ISD::VECTOR_SHUFFLE's mask
28277   // represents the same logical operation as performed by a ZIP instruction. In
28278   // isolation these functions do not mean the ISD::VECTOR_SHUFFLE is exactly
28279   // equivalent to an AArch64 instruction. There's the extra component of
28280   // ISD::VECTOR_SHUFFLE's value type to consider. Prior to SVE these functions
28281   // only operated on 64/128-bit vector types that have a direct mapping to a
28282   // target register and so an exact mapping is implied.
28283   // However, when using SVE for fixed length vectors, most legal vector types
28284   // are actually sub-vectors of a larger SVE register. When mapping
28285   // ISD::VECTOR_SHUFFLE to an SVE instruction care must be taken to consider
28286   // how the mask's indices translate. Specifically, when the mapping requires
28287   // an exact meaning for a specific vector index (e.g. Index X is the last
28288   // vector element in the register) then such mappings are often only safe when
28289   // the exact SVE register size is known. The main exception to this is when
28290   // indices are logically relative to the first element of either
28291   // ISD::VECTOR_SHUFFLE operand because these relative indices don't change
28292   // when converting from fixed-length to scalable vector types (i.e. the start
28293   // of a fixed length vector is always the start of a scalable vector).
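  // For example, ZIP1 and TRN1/TRN2 only read lanes relative to the start of
  // each operand, so the lowerings above are valid for any register size,
  // whereas ZIP2 and UZP1/UZP2 depend on where the container ends and are only
  // used below when the exact register size equals the fixed-length type size.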
28294   unsigned MinSVESize = Subtarget->getMinSVEVectorSizeInBits();
28295   unsigned MaxSVESize = Subtarget->getMaxSVEVectorSizeInBits();
28296   if (MinSVESize == MaxSVESize && MaxSVESize == VT.getSizeInBits()) {
28297     if (ShuffleVectorInst::isReverseMask(ShuffleMask, ShuffleMask.size()) &&
28298         Op2.isUndef()) {
28299       Op = DAG.getNode(ISD::VECTOR_REVERSE, DL, ContainerVT, Op1);
28300       return convertFromScalableVector(DAG, VT, Op);
28301     }
28302 
28303     if (isZIPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult) &&
28304         WhichResult != 0)
28305       return convertFromScalableVector(
28306           DAG, VT, DAG.getNode(AArch64ISD::ZIP2, DL, ContainerVT, Op1, Op2));
28307 
28308     if (isUZPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult)) {
28309       unsigned Opc = (WhichResult == 0) ? AArch64ISD::UZP1 : AArch64ISD::UZP2;
28310       return convertFromScalableVector(
28311           DAG, VT, DAG.getNode(Opc, DL, ContainerVT, Op1, Op2));
28312     }
28313 
28314     if (isZIP_v_undef_Mask(ShuffleMask, VT, WhichResult) && WhichResult != 0)
28315       return convertFromScalableVector(
28316           DAG, VT, DAG.getNode(AArch64ISD::ZIP2, DL, ContainerVT, Op1, Op1));
28317 
28318     if (isUZP_v_undef_Mask(ShuffleMask, VT, WhichResult)) {
28319       unsigned Opc = (WhichResult == 0) ? AArch64ISD::UZP1 : AArch64ISD::UZP2;
28320       return convertFromScalableVector(
28321           DAG, VT, DAG.getNode(Opc, DL, ContainerVT, Op1, Op1));
28322     }
28323   }
28324 
28325   // Avoid producing a TBL instruction if we don't know the minimum SVE register
28326   // size, unless NEON is not available and we can assume the minimum SVE
28327   // register size is 128 bits.
28328   if (MinSVESize || !Subtarget->isNeonAvailable())
28329     return GenerateFixedLengthSVETBL(Op, Op1, Op2, ShuffleMask, VT, ContainerVT,
28330                                      DAG);
28331 
28332   return SDValue();
28333 }
28334 
28335 SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
28336                                                  SelectionDAG &DAG) const {
28337   SDLoc DL(Op);
28338   EVT InVT = Op.getValueType();
28339 
28340   assert(VT.isScalableVector() && isTypeLegal(VT) &&
28341          InVT.isScalableVector() && isTypeLegal(InVT) &&
28342          "Only expect to cast between legal scalable vector types!");
28343   assert(VT.getVectorElementType() != MVT::i1 &&
28344          InVT.getVectorElementType() != MVT::i1 &&
28345          "For predicate bitcasts, use getSVEPredicateBitCast");
28346 
28347   if (InVT == VT)
28348     return Op;
28349 
28350   EVT PackedVT = getPackedSVEVectorVT(VT.getVectorElementType());
28351   EVT PackedInVT = getPackedSVEVectorVT(InVT.getVectorElementType());
28352 
28353   // Safe bitcasting between unpacked vector types of different element counts
28354   // is currently unsupported because the following is missing the necessary
28355   // work to ensure the result's elements live where they're supposed to within
28356   // an SVE register.
28357   //                01234567
28358   // e.g. nxv2i32 = XX??XX??
28359   //      nxv4f16 = X?X?X?X?
28360   assert((VT.getVectorElementCount() == InVT.getVectorElementCount() ||
28361           VT == PackedVT || InVT == PackedInVT) &&
28362          "Unexpected bitcast!");
28363 
28364   // Pack input if required.
28365   if (InVT != PackedInVT)
28366     Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, PackedInVT, Op);
28367 
28368   Op = DAG.getNode(ISD::BITCAST, DL, PackedVT, Op);
28369 
28370   // Unpack result if required.
28371   if (VT != PackedVT)
28372     Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
28373 
28374   return Op;
28375 }
28376 
28377 bool AArch64TargetLowering::isAllActivePredicate(SelectionDAG &DAG,
28378                                                  SDValue N) const {
28379   return ::isAllActivePredicate(DAG, N);
28380 }
28381 
28382 EVT AArch64TargetLowering::getPromotedVTForPredicate(EVT VT) const {
28383   return ::getPromotedVTForPredicate(VT);
28384 }
28385 
28386 bool AArch64TargetLowering::SimplifyDemandedBitsForTargetNode(
28387     SDValue Op, const APInt &OriginalDemandedBits,
28388     const APInt &OriginalDemandedElts, KnownBits &Known, TargetLoweringOpt &TLO,
28389     unsigned Depth) const {
28390 
28391   unsigned Opc = Op.getOpcode();
28392   switch (Opc) {
28393   case AArch64ISD::VSHL: {
28394     // Match (VSHL (VLSHR Val X) X)
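    // This pair clears the low X bits of Val, so when none of those bits are
    // demanded the whole sequence can be replaced by Val itself; for example,
    // with 32-bit lanes and X == 8 it only zeroes bits [7:0].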
28395     SDValue ShiftL = Op;
28396     SDValue ShiftR = Op->getOperand(0);
28397     if (ShiftR->getOpcode() != AArch64ISD::VLSHR)
28398       return false;
28399 
28400     if (!ShiftL.hasOneUse() || !ShiftR.hasOneUse())
28401       return false;
28402 
28403     unsigned ShiftLBits = ShiftL->getConstantOperandVal(1);
28404     unsigned ShiftRBits = ShiftR->getConstantOperandVal(1);
28405 
28406     // Other cases can be handled as well, but this is not
28407     // implemented.
28408     if (ShiftRBits != ShiftLBits)
28409       return false;
28410 
28411     unsigned ScalarSize = Op.getScalarValueSizeInBits();
28412     assert(ScalarSize > ShiftLBits && "Invalid shift imm");
28413 
28414     APInt ZeroBits = APInt::getLowBitsSet(ScalarSize, ShiftLBits);
28415     APInt UnusedBits = ~OriginalDemandedBits;
28416 
28417     if ((ZeroBits & UnusedBits) != ZeroBits)
28418       return false;
28419 
28420     // All bits that are zeroed by (VSHL (VLSHR Val X) X) are not
28421     // used - simplify to just Val.
28422     return TLO.CombineTo(Op, ShiftR->getOperand(0));
28423   }
28424   case AArch64ISD::BICi: {
28425     // Fold BICi if all destination bits already known to be zeroed
28426     SDValue Op0 = Op.getOperand(0);
28427     KnownBits KnownOp0 =
28428         TLO.DAG.computeKnownBits(Op0, OriginalDemandedElts, Depth + 1);
28429     // Op0 &= ~(ConstantOperandVal(1) << ConstantOperandVal(2))
28430     uint64_t BitsToClear = Op->getConstantOperandVal(1)
28431                            << Op->getConstantOperandVal(2);
28432     APInt AlreadyZeroedBitsToClear = BitsToClear & KnownOp0.Zero;
28433     if (APInt(Known.getBitWidth(), BitsToClear)
28434             .isSubsetOf(AlreadyZeroedBitsToClear))
28435       return TLO.CombineTo(Op, Op0);
28436 
28437     Known = KnownOp0 &
28438             KnownBits::makeConstant(APInt(Known.getBitWidth(), ~BitsToClear));
28439 
28440     return false;
28441   }
28442   case ISD::INTRINSIC_WO_CHAIN: {
28443     if (auto ElementSize = IsSVECntIntrinsic(Op)) {
28444       unsigned MaxSVEVectorSizeInBits = Subtarget->getMaxSVEVectorSizeInBits();
28445       if (!MaxSVEVectorSizeInBits)
28446         MaxSVEVectorSizeInBits = AArch64::SVEMaxBitsPerVector;
28447       unsigned MaxElements = MaxSVEVectorSizeInBits / *ElementSize;
28448       // The SVE count intrinsics don't support the multiplier immediate so we
28449       // don't have to account for that here. The value returned may be slightly
28450       // over the true required bits, as this is based on the "ALL" pattern. The
28451       // other patterns are also exposed by these intrinsics, but they all
28452       // return a value that's strictly less than "ALL".
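      // For example, cntb with a 2048-bit maximum vector length can return at
      // most 256, so only the low 9 bits of the result may be set and the
      // remaining high bits are known zero.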
28453       unsigned RequiredBits = llvm::bit_width(MaxElements);
28454       unsigned BitWidth = Known.Zero.getBitWidth();
28455       if (RequiredBits < BitWidth)
28456         Known.Zero.setHighBits(BitWidth - RequiredBits);
28457       return false;
28458     }
28459   }
28460   }
28461 
28462   return TargetLowering::SimplifyDemandedBitsForTargetNode(
28463       Op, OriginalDemandedBits, OriginalDemandedElts, Known, TLO, Depth);
28464 }
28465 
28466 bool AArch64TargetLowering::isTargetCanonicalConstantNode(SDValue Op) const {
28467   return Op.getOpcode() == AArch64ISD::DUP ||
28468          Op.getOpcode() == AArch64ISD::MOVI ||
28469          (Op.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
28470           Op.getOperand(0).getOpcode() == AArch64ISD::DUP) ||
28471          TargetLowering::isTargetCanonicalConstantNode(Op);
28472 }
28473 
28474 bool AArch64TargetLowering::isComplexDeinterleavingSupported() const {
28475   return Subtarget->hasSVE() || Subtarget->hasSVE2() ||
28476          Subtarget->hasComplxNum();
28477 }
28478 
28479 bool AArch64TargetLowering::isComplexDeinterleavingOperationSupported(
28480     ComplexDeinterleavingOperation Operation, Type *Ty) const {
28481   auto *VTy = dyn_cast<VectorType>(Ty);
28482   if (!VTy)
28483     return false;
28484 
28485   // If the vector is scalable, SVE is enabled, implying support for complex
28486   // numbers. Otherwise, we need to ensure complex number support is available.
28487   if (!VTy->isScalableTy() && !Subtarget->hasComplxNum())
28488     return false;
28489 
28490   auto *ScalarTy = VTy->getScalarType();
28491   unsigned NumElements = VTy->getElementCount().getKnownMinValue();
28492 
28493   // We can only process vectors that have a bit size of 128 or higher (or,
28494   // additionally, 64 bits for NEON). These vectors must also have a
28495   // power-of-2 size, as we later split them into the smallest supported size
28496   // and merge them back together after applying the complex operation.
28497   unsigned VTyWidth = VTy->getScalarSizeInBits() * NumElements;
28498   if ((VTyWidth < 128 && (VTy->isScalableTy() || VTyWidth != 64)) ||
28499       !llvm::isPowerOf2_32(VTyWidth))
28500     return false;
28501 
28502   if (ScalarTy->isIntegerTy() && Subtarget->hasSVE2() && VTy->isScalableTy()) {
28503     unsigned ScalarWidth = ScalarTy->getScalarSizeInBits();
28504     return 8 <= ScalarWidth && ScalarWidth <= 64;
28505   }
28506 
28507   return (ScalarTy->isHalfTy() && Subtarget->hasFullFP16()) ||
28508          ScalarTy->isFloatTy() || ScalarTy->isDoubleTy();
28509 }
28510 
28511 Value *AArch64TargetLowering::createComplexDeinterleavingIR(
28512     IRBuilderBase &B, ComplexDeinterleavingOperation OperationType,
28513     ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB,
28514     Value *Accumulator) const {
28515   VectorType *Ty = cast<VectorType>(InputA->getType());
28516   bool IsScalable = Ty->isScalableTy();
28517   bool IsInt = Ty->getElementType()->isIntegerTy();
28518 
28519   unsigned TyWidth =
28520       Ty->getScalarSizeInBits() * Ty->getElementCount().getKnownMinValue();
28521 
28522   assert(((TyWidth >= 128 && llvm::isPowerOf2_32(TyWidth)) || TyWidth == 64) &&
28523          "Vector type must be either 64 or a power of 2 that is at least 128");
28524 
28525   if (TyWidth > 128) {
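    // Types wider than 128 bits are split in half, the operation is emitted
    // recursively on each half, and the two partial results are reassembled
    // with insertions into a poison vector of the original type.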
28526     int Stride = Ty->getElementCount().getKnownMinValue() / 2;
28527     auto *HalfTy = VectorType::getHalfElementsVectorType(Ty);
28528     auto *LowerSplitA = B.CreateExtractVector(HalfTy, InputA, B.getInt64(0));
28529     auto *LowerSplitB = B.CreateExtractVector(HalfTy, InputB, B.getInt64(0));
28530     auto *UpperSplitA =
28531         B.CreateExtractVector(HalfTy, InputA, B.getInt64(Stride));
28532     auto *UpperSplitB =
28533         B.CreateExtractVector(HalfTy, InputB, B.getInt64(Stride));
28534     Value *LowerSplitAcc = nullptr;
28535     Value *UpperSplitAcc = nullptr;
28536     if (Accumulator) {
28537       LowerSplitAcc = B.CreateExtractVector(HalfTy, Accumulator, B.getInt64(0));
28538       UpperSplitAcc =
28539           B.CreateExtractVector(HalfTy, Accumulator, B.getInt64(Stride));
28540     }
28541     auto *LowerSplitInt = createComplexDeinterleavingIR(
28542         B, OperationType, Rotation, LowerSplitA, LowerSplitB, LowerSplitAcc);
28543     auto *UpperSplitInt = createComplexDeinterleavingIR(
28544         B, OperationType, Rotation, UpperSplitA, UpperSplitB, UpperSplitAcc);
28545 
28546     auto *Result = B.CreateInsertVector(Ty, PoisonValue::get(Ty), LowerSplitInt,
28547                                         B.getInt64(0));
28548     return B.CreateInsertVector(Ty, Result, UpperSplitInt, B.getInt64(Stride));
28549   }
28550 
28551   if (OperationType == ComplexDeinterleavingOperation::CMulPartial) {
28552     if (Accumulator == nullptr)
28553       Accumulator = Constant::getNullValue(Ty);
28554 
28555     if (IsScalable) {
28556       if (IsInt)
28557         return B.CreateIntrinsic(
28558             Intrinsic::aarch64_sve_cmla_x, Ty,
28559             {Accumulator, InputA, InputB, B.getInt32((int)Rotation * 90)});
28560 
28561       auto *Mask = B.getAllOnesMask(Ty->getElementCount());
28562       return B.CreateIntrinsic(
28563           Intrinsic::aarch64_sve_fcmla, Ty,
28564           {Mask, Accumulator, InputA, InputB, B.getInt32((int)Rotation * 90)});
28565     }
28566 
28567     Intrinsic::ID IdMap[4] = {Intrinsic::aarch64_neon_vcmla_rot0,
28568                               Intrinsic::aarch64_neon_vcmla_rot90,
28569                               Intrinsic::aarch64_neon_vcmla_rot180,
28570                               Intrinsic::aarch64_neon_vcmla_rot270};
28571 
28572 
28573     return B.CreateIntrinsic(IdMap[(int)Rotation], Ty,
28574                              {Accumulator, InputA, InputB});
28575   }
28576 
28577   if (OperationType == ComplexDeinterleavingOperation::CAdd) {
28578     if (IsScalable) {
28579       if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
28580           Rotation == ComplexDeinterleavingRotation::Rotation_270) {
28581         if (IsInt)
28582           return B.CreateIntrinsic(
28583               Intrinsic::aarch64_sve_cadd_x, Ty,
28584               {InputA, InputB, B.getInt32((int)Rotation * 90)});
28585 
28586         auto *Mask = B.getAllOnesMask(Ty->getElementCount());
28587         return B.CreateIntrinsic(
28588             Intrinsic::aarch64_sve_fcadd, Ty,
28589             {Mask, InputA, InputB, B.getInt32((int)Rotation * 90)});
28590       }
28591       return nullptr;
28592     }
28593 
28594     Intrinsic::ID IntId = Intrinsic::not_intrinsic;
28595     if (Rotation == ComplexDeinterleavingRotation::Rotation_90)
28596       IntId = Intrinsic::aarch64_neon_vcadd_rot90;
28597     else if (Rotation == ComplexDeinterleavingRotation::Rotation_270)
28598       IntId = Intrinsic::aarch64_neon_vcadd_rot270;
28599 
28600     if (IntId == Intrinsic::not_intrinsic)
28601       return nullptr;
28602 
28603     return B.CreateIntrinsic(IntId, Ty, {InputA, InputB});
28604   }
28605 
28606   return nullptr;
28607 }
28608 
28609 bool AArch64TargetLowering::preferScalarizeSplat(SDNode *N) const {
28610   unsigned Opc = N->getOpcode();
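  // Keep splats of extended values in vector form when the extend feeds a
  // multiply, since that generally allows the widening multiply patterns to
  // be formed instead of scalarising.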
28611   if (ISD::isExtOpcode(Opc)) {
28612     if (any_of(N->uses(),
28613                [&](SDNode *Use) { return Use->getOpcode() == ISD::MUL; }))
28614       return false;
28615   }
28616   return true;
28617 }
28618 
28619 unsigned AArch64TargetLowering::getMinimumJumpTableEntries() const {
28620   return Subtarget->getMinimumJumpTableEntries();
28621 }
28622 
28623 MVT AArch64TargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
28624                                                          CallingConv::ID CC,
28625                                                          EVT VT) const {
28626   bool NonUnitFixedLengthVector =
28627       VT.isFixedLengthVector() && !VT.getVectorElementCount().isScalar();
28628   if (!NonUnitFixedLengthVector || !Subtarget->useSVEForFixedLengthVectors())
28629     return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
28630 
28631   EVT VT1;
28632   MVT RegisterVT;
28633   unsigned NumIntermediates;
28634   getVectorTypeBreakdownForCallingConv(Context, CC, VT, VT1, NumIntermediates,
28635                                        RegisterVT);
28636   return RegisterVT;
28637 }
28638 
28639 unsigned AArch64TargetLowering::getNumRegistersForCallingConv(
28640     LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
28641   bool NonUnitFixedLengthVector =
28642       VT.isFixedLengthVector() && !VT.getVectorElementCount().isScalar();
28643   if (!NonUnitFixedLengthVector || !Subtarget->useSVEForFixedLengthVectors())
28644     return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
28645 
28646   EVT VT1;
28647   MVT VT2;
28648   unsigned NumIntermediates;
28649   return getVectorTypeBreakdownForCallingConv(Context, CC, VT, VT1,
28650                                               NumIntermediates, VT2);
28651 }
28652 
28653 unsigned AArch64TargetLowering::getVectorTypeBreakdownForCallingConv(
28654     LLVMContext &Context, CallingConv::ID CC, EVT VT, EVT &IntermediateVT,
28655     unsigned &NumIntermediates, MVT &RegisterVT) const {
28656   int NumRegs = TargetLowering::getVectorTypeBreakdownForCallingConv(
28657       Context, CC, VT, IntermediateVT, NumIntermediates, RegisterVT);
28658   if (!RegisterVT.isFixedLengthVector() ||
28659       RegisterVT.getFixedSizeInBits() <= 128)
28660     return NumRegs;
28661 
28662   assert(Subtarget->useSVEForFixedLengthVectors() && "Unexpected mode!");
28663   assert(IntermediateVT == RegisterVT && "Unexpected VT mismatch!");
28664   assert(RegisterVT.getFixedSizeInBits() % 128 == 0 && "Unexpected size!");
28665 
28666   // A size mismatch here implies either type promotion or widening and would
28667   // have resulted in scalarisation if larger vectors had not been available.
28668   if (RegisterVT.getSizeInBits() * NumRegs != VT.getSizeInBits()) {
28669     EVT EltTy = VT.getVectorElementType();
28670     EVT NewVT = EVT::getVectorVT(Context, EltTy, ElementCount::getFixed(1));
28671     if (!isTypeLegal(NewVT))
28672       NewVT = EltTy;
28673 
28674     IntermediateVT = NewVT;
28675     NumIntermediates = VT.getVectorNumElements();
28676     RegisterVT = getRegisterType(Context, NewVT);
28677     return NumIntermediates;
28678   }
28679 
28680   // SVE VLS support does not introduce a new ABI so we should use NEON sized
28681   // types for vector arguments and returns.
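  // For example, when fixed-length 512-bit SVE vectors are legal, a v16i32
  // argument is passed as four v4i32 values rather than as a single 512-bit
  // register.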
28682 
28683   unsigned NumSubRegs = RegisterVT.getFixedSizeInBits() / 128;
28684   NumIntermediates *= NumSubRegs;
28685   NumRegs *= NumSubRegs;
28686 
28687   switch (RegisterVT.getVectorElementType().SimpleTy) {
28688   default:
28689     llvm_unreachable("unexpected element type for vector");
28690   case MVT::i8:
28691     IntermediateVT = RegisterVT = MVT::v16i8;
28692     break;
28693   case MVT::i16:
28694     IntermediateVT = RegisterVT = MVT::v8i16;
28695     break;
28696   case MVT::i32:
28697     IntermediateVT = RegisterVT = MVT::v4i32;
28698     break;
28699   case MVT::i64:
28700     IntermediateVT = RegisterVT = MVT::v2i64;
28701     break;
28702   case MVT::f16:
28703     IntermediateVT = RegisterVT = MVT::v8f16;
28704     break;
28705   case MVT::f32:
28706     IntermediateVT = RegisterVT = MVT::v4f32;
28707     break;
28708   case MVT::f64:
28709     IntermediateVT = RegisterVT = MVT::v2f64;
28710     break;
28711   case MVT::bf16:
28712     IntermediateVT = RegisterVT = MVT::v8bf16;
28713     break;
28714   }
28715 
28716   return NumRegs;
28717 }
28718 
28719 bool AArch64TargetLowering::hasInlineStackProbe(
28720     const MachineFunction &MF) const {
28721   return !Subtarget->isTargetWindows() &&
28722          MF.getInfo<AArch64FunctionInfo>()->hasStackProbing();
28723 }
28724 
28725 #ifndef NDEBUG
28726 void AArch64TargetLowering::verifyTargetSDNode(const SDNode *N) const {
28727   switch (N->getOpcode()) {
28728   default:
28729     break;
28730   case AArch64ISD::SUNPKLO:
28731   case AArch64ISD::SUNPKHI:
28732   case AArch64ISD::UUNPKLO:
28733   case AArch64ISD::UUNPKHI: {
28734     assert(N->getNumValues() == 1 && "Expected one result!");
28735     assert(N->getNumOperands() == 1 && "Expected one operand!");
28736     EVT VT = N->getValueType(0);
28737     EVT OpVT = N->getOperand(0).getValueType();
28738     assert(OpVT.isVector() && VT.isVector() && OpVT.isInteger() &&
28739            VT.isInteger() && "Expected integer vectors!");
28740     assert(OpVT.getSizeInBits() == VT.getSizeInBits() &&
28741            "Expected vectors of equal size!");
28742     // TODO: Enable assert once bogus creations have been fixed.
28743     // assert(OpVT.getVectorElementCount() == VT.getVectorElementCount()*2 &&
28744     //       "Expected result vector with half the lanes of its input!");
28745     break;
28746   }
28747   case AArch64ISD::TRN1:
28748   case AArch64ISD::TRN2:
28749   case AArch64ISD::UZP1:
28750   case AArch64ISD::UZP2:
28751   case AArch64ISD::ZIP1:
28752   case AArch64ISD::ZIP2: {
28753     assert(N->getNumValues() == 1 && "Expected one result!");
28754     assert(N->getNumOperands() == 2 && "Expected two operands!");
28755     EVT VT = N->getValueType(0);
28756     EVT Op0VT = N->getOperand(0).getValueType();
28757     EVT Op1VT = N->getOperand(1).getValueType();
28758     assert(VT.isVector() && Op0VT.isVector() && Op1VT.isVector() &&
28759            "Expected vectors!");
28760     // TODO: Enable assert once bogus creations have been fixed.
28761     // assert(VT == Op0VT && VT == Op1VT && "Expected matching vectors!");
28762     break;
28763   }
28764   }
28765 }
28766 #endif
28767